In [11]:
!pip install torch transformers peft datasets



In [12]:
import json
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForMaskedLM
from peft import LoraConfig, get_peft_model
import torch.nn.functional as F

In [13]:
class SpellCorrectionDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_len=128):
        with open(file_path, "r", encoding="utf-8") as f:
            self.data = json.load(f)
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        encoding = self.tokenizer(
            item["sentence"],
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt"
        )

        input_ids = encoding["input_ids"].squeeze()
        attention_mask = encoding["attention_mask"].squeeze()

        mask_index = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)[0].item()

        candidate_ids = self.tokenizer.convert_tokens_to_ids(item["candidates"])
        label_index = item["candidates"].index(item["label"])

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "mask_index": mask_index,
            "candidate_ids": torch.tensor(candidate_ids),
            "label_index": torch.tensor(label_index)
        }


In [14]:
tokenizer = AutoTokenizer.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
base_model = AutoModelForMaskedLM.from_pretrained("microsoft/MiniLM-L12-H384-uncased")


Some weights of BertForMaskedLM were not initialized from the model checkpoint at microsoft/MiniLM-L12-H384-uncased and are newly initialized: ['cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [15]:
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["query", "value"],
    lora_dropout=0.1,
    bias="none",
    task_type="TOKEN_CLS"
)

model = get_peft_model(base_model, lora_config)
model.train()


PeftModelForTokenClassification(
  (base_model): LoraModel(
    (model): BertForMaskedLM(
      (bert): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(30522, 384, padding_idx=0)
          (position_embeddings): Embedding(512, 384)
          (token_type_embeddings): Embedding(2, 384)
          (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0-11): 12 x BertLayer(
              (attention): BertAttention(
                (self): BertSdpaSelfAttention(
                  (query): lora.Linear(
                    (base_layer): Linear(in_features=384, out_features=384, bias=True)
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.1, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_feat

In [16]:
dataset = SpellCorrectionDataset("dataset.json", tokenizer)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)


In [18]:
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)

for epoch in range(3):
    total_loss = 0

    for batch in dataloader:
        optimizer.zero_grad()

        outputs = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"]
        )

        logits = outputs.logits

        mask_index = batch["mask_index"].item()
        candidate_ids = batch["candidate_ids"].squeeze()
        label_index = batch["label_index"].item()

        mask_logits = logits[0, mask_index]            # [vocab_size]
        restricted_logits = mask_logits[candidate_ids] # [num_candidates]

        loss = F.cross_entropy(
            restricted_logits.unsqueeze(0),
            torch.tensor([label_index])
        )

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1} | Loss: {total_loss / len(dataloader):.4f}")


Epoch 1 | Loss: 0.7122
Epoch 2 | Loss: 0.7287
Epoch 3 | Loss: 0.5696


In [19]:
def predict(sentence, candidates):
    inputs = tokenizer(sentence, return_tensors="pt")
    mask_index = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0].item()

    with torch.no_grad():
        logits = model(**inputs).logits

    candidate_ids = tokenizer.convert_tokens_to_ids(candidates)
    restricted_logits = logits[0, mask_index, candidate_ids]

    return candidates[torch.argmax(restricted_logits).item()]


In [20]:
sentence = "I went to the [MASK] to buy groceries."
candidates = ["market", "marcet", "markit"]

print("Predicted word:", predict(sentence, candidates))




Predicted word: market


In [21]:
model.eval()

PeftModelForTokenClassification(
  (base_model): LoraModel(
    (model): BertForMaskedLM(
      (bert): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(30522, 384, padding_idx=0)
          (position_embeddings): Embedding(512, 384)
          (token_type_embeddings): Embedding(2, 384)
          (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0-11): 12 x BertLayer(
              (attention): BertAttention(
                (self): BertSdpaSelfAttention(
                  (query): lora.Linear(
                    (base_layer): Linear(in_features=384, out_features=384, bias=True)
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.1, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_feat

In [22]:
def test_spell_correction(sentence, candidates):
    """
    sentence  : string containing [MASK]
    candidates: list of candidate words
    """

    inputs = tokenizer(sentence, return_tensors="pt")

    # Find [MASK] index
    mask_index = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0].item()

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # Convert candidates to token IDs
    candidate_ids = tokenizer.convert_tokens_to_ids(candidates)

    # Get logits only for [MASK] token
    mask_logits = logits[0, mask_index]              # [vocab_size]
    restricted_logits = mask_logits[candidate_ids]  # [num_candidates]

    # Pick best candidate
    predicted_index = torch.argmax(restricted_logits).item()
    predicted_word = candidates[predicted_index]

    return predicted_word


In [23]:
tests_en = [
    ("She loves to drink [MASK] in the morning.", ["coffee", "cofee", "coffi"]),
    ("Please close the [MASK] before leaving.", ["door", "doar", "dor"]),
    ("He is going to the [MASK] for higher studies.", ["college", "collage", "colage"]),
    ("I forgot my [MASK] at home.", ["wallet", "walet", "vallet"]),
]

for sent, cands in tests_en:
    print("Sentence :", sent)
    print("Prediction:", test_spell_correction(sent, cands))
    print("-" * 50)


Sentence : She loves to drink [MASK] in the morning.
Prediction: coffee
--------------------------------------------------
Sentence : Please close the [MASK] before leaving.
Prediction: door
--------------------------------------------------
Sentence : He is going to the [MASK] for higher studies.
Prediction: college
--------------------------------------------------
Sentence : I forgot my [MASK] at home.
Prediction: wallet
--------------------------------------------------


In [24]:
tests_ko = [
    ("나는 매일 아침 [MASK]를 마신다.", ["커피", "코피", "커피이"]),
    ("그는 학교에 [MASK] 갔다.", ["갔다", "같다", "갓다"]),
    ("오늘 날씨가 정말 [MASK].", ["좋다", "조타", "좃다"]),
]

for sent, cands in tests_ko:
    print("Sentence :", sent)
    print("Prediction:", test_spell_correction(sent, cands))
    print("-" * 50)


Sentence : 나는 매일 아침 [MASK]를 마신다.
Prediction: 커피
--------------------------------------------------
Sentence : 그는 학교에 [MASK] 갔다.
Prediction: 갔다
--------------------------------------------------
Sentence : 오늘 날씨가 정말 [MASK].
Prediction: 좋다
--------------------------------------------------
