In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

dataset_dir = '/Users/subhojit/datasets/sms_spam_collection'
df = pd.read_csv(dataset_dir + "/SMSSpamCollection", sep='\t', header=None, names=['label', 'text'])

df['label'] = df['label'].map({'ham': 0, 'spam': 1})
texts = df['text'].tolist()
labels = df['label'].tolist()

chars = sorted(set(''.join(texts)))
stoi = {ch: i+1 for i, ch in enumerate(chars)}
stoi['<PAD>'] = 0
vocab_size = len(stoi)
encode = lambda s: [stoi[c] for c in s if c in stoi]

xtrain, xval, ytrain, yval = train_test_split(texts, labels, test_size=0.2, random_state=42)

def pad_sequences(sequences, max_len=256):
    padded = torch.zeros(len(sequences), max_len, dtype=torch.long)
    lengths = torch.zeros(len(sequences), dtype=torch.long)
    for i, seq in enumerate(sequences):
        seq = seq[:max_len]
        padded[i, :len(seq)] = torch.tensor(seq, dtype=torch.long)
        lengths[i] = len(seq)
    return padded, lengths

def get_batch(batch_size, split='train'):
    x = xtrain if split == 'train' else xval
    y = ytrain if split == 'train' else yval
    idx = torch.randint(0, len(x), (batch_size,))
    xb = [encode(x[i]) for i in idx]
    yb = [y[i] for i in idx]
    xb, lengths = pad_sequences(xb)
    return xb, torch.tensor(yb, dtype=torch.long), lengths

class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, lengths):
        x = self.embedding(x)
        packed = pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_out, (h_n, _) = self.lstm(packed)
        return self.fc(h_n[-1])
        # return self.fc(packed_out[:, -1, :])

device = 'mps'
model = LSTMClassifier(vocab_size=vocab_size, embed_dim=32, hidden_dim=64, output_dim=2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

batch_size = 64

for step in range(1000):
    xb, yb, lengths = get_batch(batch_size)
    xb, yb, lengths = xb.to(device), yb.to(device), lengths.to(device)

    model.train()
    logits = model(xb, lengths)
    loss = F.cross_entropy(logits, yb)

    optimizer.zero_grad()
    loss.backward()

    for name, param in model.named_parameters():
        if param.grad is not None:
            if torch.isnan(param.grad).any():
                print(f"🔥 NaN in gradient of {name}")
                param.grad = torch.nan_to_num(param.grad, nan=0.0, posinf=1.0, neginf=-1.0)

    # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    optimizer.step()
    print(f"Step {step}, loss = {loss.item():.4f}")

    if step % 100 == 0:
        preds = torch.argmax(logits, dim=1)
        confidences = torch.softmax(logits, dim=1).max(dim=1).values
        # print(f"Step {step}, loss = {loss.item():.4f}")
        # for name, param in model.named_parameters():
        #     if param.grad is not None:
        #         print(f"{name}: grad norm = {param.grad.norm():.4f}")
        # print("Preds:", preds.tolist())
        # print("Targets:", yb.tolist())
        # print("Confidences:", confidences[:10])
        print()


Step 0, loss = 0.7318

Step 1, loss = 0.5770
Step 2, loss = 0.4510
Step 3, loss = 0.3131
Step 4, loss = 0.5025
Step 5, loss = 0.2335
Step 6, loss = 0.4168
Step 7, loss = 0.2461
Step 8, loss = 0.4098
Step 9, loss = 0.3293
Step 10, loss = 0.2666
Step 11, loss = 0.3041
Step 12, loss = 0.3543
Step 13, loss = 0.2961
Step 14, loss = 0.2607
Step 15, loss = 0.4103
Step 16, loss = 0.3710
Step 17, loss = 0.2879
Step 18, loss = 0.2469
Step 19, loss = 0.2143
Step 20, loss = 0.4174
Step 21, loss = 0.3537
Step 22, loss = 0.2369
Step 23, loss = 0.2018
Step 24, loss = 0.3297
Step 25, loss = 0.1878
Step 26, loss = 0.1536
Step 27, loss = 0.1385
Step 28, loss = 0.1006
Step 29, loss = 0.1741
Step 30, loss = 0.1868
Step 31, loss = 0.1876
Step 32, loss = 0.1029
Step 33, loss = 0.0267
Step 34, loss = 0.2399
Step 35, loss = 0.3145
Step 36, loss = 0.1659
Step 37, loss = 0.0729
Step 38, loss = 0.3802
Step 39, loss = 0.1271
Step 40, loss = 0.0284
Step 41, loss = 0.1076
Step 42, loss = 0.1930
Step 43, loss = 0.18

KeyboardInterrupt: 