In [None]:
%pip install torch torchvision torchaudio transformers datasets nltk numpy pandas


In [28]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import nltk
from nltk.tokenize import word_tokenize
from transformers import BertTokenizer, BertForMaskedLM
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import train_test_split


In [35]:
# Load dataset
with open("missp.dat", "r") as file:
    raw_data = file.readlines()

# Process data
corrections = {}
current_correct_word = None

for line in raw_data:
    word = line.strip()
    if word.startswith("$"):  # Correct words start with $
        current_correct_word = word[1:]
        corrections[current_correct_word] = []
    else:
        if current_correct_word:
            corrections[current_correct_word].append(word)

# Convert into a DataFrame
df = pd.DataFrame([(k, v) for k, values in corrections.items() for v in values],
                  columns=["Correct_Word", "Misspelled_Word"])

# Split into train and test sets
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)

print("Training samples:", len(train_df))
print("Testing samples:", len(test_df))
df.head()


Training samples: 28906
Testing samples: 7227


Unnamed: 0,Correct_Word,Misspelled_Word
0,Albert,Ab
1,America,Ameraca
2,America,Amercia
3,American,Ameracan
4,April,Apirl


In [36]:
class SpellingDataset(Dataset):
    def __init__(self, data):
        self.data = data
        self.vocab = set("".join(data["Misspelled_Word"]) + "".join(data["Correct_Word"]))
        self.char2idx = {ch: i+1 for i, ch in enumerate(self.vocab)}  # +1 to reserve 0 for padding
        self.idx2char = {i: ch for ch, i in self.char2idx.items()}
        self.max_len = max(data["Misspelled_Word"].apply(len))  # Max word length
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        misspelled = [self.char2idx[ch] for ch in row["Misspelled_Word"]]
        correct = [self.char2idx[ch] for ch in row["Correct_Word"]]
        
        # Pad sequences
        misspelled += [0] * (self.max_len - len(misspelled))
        correct += [0] * (self.max_len - len(correct))
        
        return torch.tensor(misspelled), torch.tensor(correct)

# Load dataset
train_dataset = SpellingDataset(train_df)
test_dataset = SpellingDataset(test_df)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


In [37]:
class Seq2SeqModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim=32, hidden_dim=64):
        super(Seq2SeqModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size + 1, embedding_dim)  # +1 for padding
        self.encoder = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.decoder = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size + 1)

    def forward(self, x):
        embedded = self.embedding(x)
        _, (hidden, _) = self.encoder(embedded)
        out, _ = self.decoder(embedded, (hidden, torch.zeros_like(hidden)))
        return self.fc(out)

# Initialize model
vocab_size = len(train_dataset.vocab)
model = Seq2SeqModel(vocab_size)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


In [38]:
def collate_fn(batch):
    misspelled, correct = zip(*batch)
    misspelled = pad_sequence(misspelled, batch_first=True, padding_value=0)
    correct = pad_sequence(correct, batch_first=True, padding_value=0)
    return misspelled, correct

train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn)

In [39]:
for misspelled, correct in train_loader:
    optimizer.zero_grad()
    
    outputs = model(misspelled)  # Forward pass
    
    # Debugging shape mismatch issue
    print(f"Output shape: {outputs.shape}")  # Expected: (batch_size, seq_len, vocab_size)
    print(f"Target shape: {correct.shape}")  # Expected: (batch_size, seq_len)

    batch_size, seq_len, vocab_size = outputs.shape  # Unpack output dimensions

    # Ensure the correct tensor is the same shape as the output tensor
    correct = correct[:, :seq_len]

    # Reshape tensors correctly for loss computation
    loss = criterion(
        outputs.reshape(batch_size * seq_len, vocab_size),  # Flatten output for cross-entropy
        correct.reshape(-1)  # Flatten target labels
    )

    loss.backward()
    optimizer.step()
    
    print(f"Loss: {loss.item()}")  # Monitor loss


Output shape: torch.Size([8, 18, 55])
Target shape: torch.Size([8, 18])
Loss: 3.985180377960205
Output shape: torch.Size([8, 18, 55])
Target shape: torch.Size([8, 18])
Loss: 3.954068899154663
Output shape: torch.Size([8, 18, 55])
Target shape: torch.Size([8, 18])
Loss: 3.8797013759613037
Output shape: torch.Size([8, 18, 55])
Target shape: torch.Size([8, 18])
Loss: 3.8206214904785156
Output shape: torch.Size([8, 18, 55])
Target shape: torch.Size([8, 18])
Loss: 3.761223793029785
Output shape: torch.Size([8, 18, 55])
Target shape: torch.Size([8, 18])
Loss: 3.758967161178589
Output shape: torch.Size([8, 18, 55])
Target shape: torch.Size([8, 18])
Loss: 3.694514036178589
Output shape: torch.Size([8, 18, 55])
Target shape: torch.Size([8, 18])
Loss: 3.6091701984405518
Output shape: torch.Size([8, 18, 55])
Target shape: torch.Size([8, 18])
Loss: 3.536564826965332
Output shape: torch.Size([8, 18, 55])
Target shape: torch.Size([8, 18])
Loss: 3.5269298553466797
Output shape: torch.Size([8, 18, 55]

In [40]:
model.eval()
with torch.no_grad():
    for misspelled, correct in test_loader:
        outputs = model(misspelled)
        predictions = torch.argmax(outputs, dim=-1)
        print("Misspelled:", "".join([train_dataset.idx2char[idx.item()] for idx in misspelled[0] if idx.item() != 0]))
        print("Predicted:", "".join([train_dataset.idx2char[idx.item()] for idx in predictions[0] if idx.item() != 0]))
        print("Correct:", "".join([train_dataset.idx2char[idx.item()] for idx in correct[0] if idx.item() != 0]))
        break


Misspelled: rHVuEpEglO
Predicted: rnnuepgl
Correct: rHVupEglO
