In [12]:
from transformers import AutoModelForMaskedLM, AutoTokenizer
from peft import PeftModel
import torch

In [13]:
# Base checkpoint: the pretrained masked-LM that the LoRA weights are later attached to.
base_model_name = "xlm-roberta-base"
model = AutoModelForMaskedLM.from_pretrained(
    pretrained_model_name_or_path=base_model_name,
)

Some weights of the model checkpoint at xlm-roberta-base were not used when initializing XLMRobertaForMaskedLM: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing XLMRobertaForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [14]:
# Attach the LoRA adapter stored in the checkpoint directory to the base model.
# `model` is rebound to the adapter-wrapped model.
lora_checkpoint_path = "model-variants/models/XLM-R_BPE"
model = PeftModel.from_pretrained(
    model=model,
    model_id=lora_checkpoint_path,
)

In [15]:
# The tokenizer is read from the same directory as the LoRA checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=lora_checkpoint_path,
)

In [16]:
# ✅ Define Inference Function Using Masked Token Prediction
def generate_predictions(model, tokenizer, sentences, top_k=5):
    """Return the top-k candidate tokens for every mask position in each sentence.

    Args:
        model: Masked-language model. It is switched to eval mode, called as
            ``model(**inputs)``, and its output must expose ``.logits`` with
            shape (batch_size, seq_len, vocab_size).
        tokenizer: Tokenizer paired with ``model``; must provide ``mask_token``,
            ``mask_token_id``, ``__call__(text, return_tensors="pt")`` and
            ``decode``.
        sentences: Iterable of strings, each expected to contain the
            tokenizer's mask token at least once.
        top_k: Number of candidates to return per mask position. Values larger
            than the vocabulary size are capped instead of raising.

    Returns:
        dict mapping each processed sentence to its predictions. A sentence
        with exactly one mask maps to a flat list of decoded tokens, best
        candidate first. A sentence with several masks maps to a list of such
        lists, one per mask in left-to-right order (previously only the last
        mask's candidates were kept). Sentences without a mask are skipped
        with a printed warning and do not appear in the result. Because the
        dict is keyed by sentence text, duplicate sentences collapse into one
        entry.
    """
    model.eval()  # Inference mode; note this also leaves the caller's model in eval mode.

    mask_token = tokenizer.mask_token  # Tokenizer-specific, e.g. [MASK] or <mask>
    mask_token_id = tokenizer.mask_token_id

    # The tokenizer builds CPU tensors; move them to wherever the model's
    # weights live so the forward pass also works for a model on GPU.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = None  # Nothing to inspect: leave the inputs where they are.

    predictions = {}

    for text in sentences:
        if mask_token not in text:
            print(f"⚠️ Warning: No `{mask_token}` token found in '{text}'. Ensure the sentence has a masked token.")
            continue

        # Tokenize input
        inputs = tokenizer(text, return_tensors="pt")
        if device is not None:
            inputs = {name: value.to(device) for name, value in inputs.items()}

        # Column indices (sequence positions) of every mask token in the single-row batch.
        mask_positions = (inputs["input_ids"] == mask_token_id).nonzero(as_tuple=True)[1]

        # Defensive: the mask string was in the text but no mask id came out of tokenization.
        if mask_positions.numel() == 0:
            print(f"⚠️ Warning: No `{mask_token}` token found in '{text}'")
            continue

        # Forward pass through the model
        with torch.no_grad():
            logits = model(**inputs).logits  # (batch_size, seq_len, vocab_size)

        # torch.topk raises if k exceeds the size of the dimension, so cap it.
        k = min(top_k, logits.size(-1))

        per_mask = []
        for position in mask_positions.tolist():
            top_ids = torch.topk(logits[0, position, :], k, dim=-1).indices
            per_mask.append([tokenizer.decode([token_id]) for token_id in top_ids.tolist()])

        # Keep the original flat-list shape for the common single-mask case.
        predictions[text] = per_mask[0] if len(per_mask) == 1 else per_mask

    return predictions


In [17]:
# Demo inputs: both use "<mask>", the mask token for XLM-R.
english_example = "This is a <mask> example."
burmese_example = "ဒီစာသားက <mask> ဖြစ်ပါတယ်။"
sentences = [english_example, burmese_example]

In [18]:
# Run masked-token inference over the demo sentences, using the function's default top_k.
xlmr_predictions = generate_predictions(
    model=model,
    tokenizer=tokenizer,
    sentences=sentences,
)

In [19]:
# Show the sentence -> predicted-tokens mapping as this cell's output.
xlmr_predictions

{'This is a <mask> example.': ['', 'နည်း', 'အ', 'ဥပမာ', 'ပုံ'],
 'ဒီစာသားက <mask> ဖြစ်ပါတယ်။': ['', 'က', 'စာ', 'က', 'အ']}