In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import io
import re
import zipfile
from collections import Counter

import requests
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split

from lstm import LSTM

In [3]:
DATA_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip"
DATA_DIR = "./data"
MAX_VOCAB_SIZE = 20000
MIN_FREQ = 2
BATCH_SIZE = 32

In [4]:
def download_sms_dataset() -> list[tuple[int, str]]:
    os.makedirs(DATA_DIR, exist_ok=True)
    zip_path = os.path.join(DATA_DIR, "smsspam.zip")
    txt_path = os.path.join(DATA_DIR, "SMSSpamCollection")

    if not os.path.exists(txt_path):
        print("Downloading dataset...")
        r = requests.get(DATA_URL, timeout=30)
        r.raise_for_status()
        with open(zip_path, "wb") as f:
            f.write(r.content)
        with zipfile.ZipFile(zip_path) as zf:
            zf.extractall(DATA_DIR)
    else:
        print("Dataset already present.")

    data = []
    with io.open(txt_path, encoding="utf-8") as f:
        for line in f:
            label, text = line.strip().split("\t", 1)
            y = 1 if label == "spam" else 0
            data.append((y, text))
    return data

# ----------------------------
# Tokenization & vocab
# ----------------------------
TOKEN_RE = re.compile(r"\b\w+\b", flags=re.UNICODE)


def tokenize(s: str) -> list[str]:
    return TOKEN_RE.findall(s.lower())


PAD, UNK = "<pad>", "<unk>"


def build_vocab(texts: list[list[str]]):
    counter = Counter()
    for toks in texts:
        counter.update(toks)
    # keep by freq, cap by size
    vocab = [PAD, UNK]
    for w, c in counter.most_common():
        if c < MIN_FREQ: break
        vocab.append(w)
        if len(vocab) >= MAX_VOCAB_SIZE: break
    stoi = {w:i for i,w in enumerate(vocab)}
    itos = vocab
    return stoi, itos


def encode(tokens: list[str], stoi: dict) -> list[int]:
    unk = stoi.get(UNK)
    return [stoi.get(t, unk) for t in tokens]


# ----------------------------
# Dataset with dynamic padding
# ----------------------------
class SMSDataset(Dataset):
    def __init__(self, samples, stoi):
        self.labels = [y for y, _ in samples]
        self.texts = [encode(tokenize(x), stoi) for _, x in samples]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return torch.tensor(self.texts[idx], dtype=torch.long), torch.tensor(self.labels[idx], dtype=torch.long)


def collate_batch(batch, pad_idx):
    # batch: list of (tensor_ids, label)
    seqs, labels = zip(*batch)
    lengths = torch.tensor([len(s) for s in seqs], dtype=torch.long)
    max_len = lengths.max().item()
    padded = torch.full((len(seqs), max_len), pad_idx, dtype=torch.long)
    for i, s in enumerate(seqs):
        padded[i, :len(s)] = s
    # sort by length desc for pack_padded_sequence
    lengths, sort_idx = lengths.sort(descending=True)
    padded = padded[sort_idx]
    labels = torch.stack(labels)[sort_idx]
    return padded, lengths, labels


class LSTMTextClassifier(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int, hidden_dim: int, num_classes: int, pad_idx: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.lstm = LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, input_ids: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        embedded = self.embedding(input_ids)
        outputs, _ = self.lstm(embedded)
        batch_range = torch.arange(outputs.size(0), device=outputs.device)
        last_indices = lengths - 1
        last_hidden = outputs[batch_range, last_indices]
        return self.fc(last_hidden)

    def to_qat(self, bits: int, qat_linear_class, **qat_kwargs) -> "LSTMTextClassifier":
        new_model = LSTMTextClassifier(
            vocab_size=self.embedding.num_embeddings,
            embed_dim=self.embedding.embedding_dim,
            hidden_dim=self.fc.in_features,
            num_classes=self.fc.out_features,
            pad_idx=self.embedding.padding_idx,
        )
        new_model.lstm = self.lstm.to_qat(bits, qat_linear_class, **qat_kwargs)
        return new_model.to(self.fc.weight.device)

    def quantize(self, bits: int, linear_int_class) -> "LSTMTextClassifier":
        new_model = LSTMTextClassifier(
            vocab_size=self.embedding.num_embeddings,
            embed_dim=self.embedding.embedding_dim,
            hidden_dim=self.fc.in_features,
            num_classes=self.fc.out_features,
            pad_idx=self.embedding.padding_idx,
        )
        new_model.lstm = self.lstm.quantize(bits, linear_int_class)
        return new_model.to(self.fc.weight.device)


def accuracy(logits: torch.Tensor, targets: torch.Tensor) -> float:
    predictions = logits.argmax(dim=1)
    correct = (predictions == targets).sum().item()
    return correct / targets.size(0)


def train_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
) -> float:
    model.train()
    epoch_loss = 0.0
    for inputs, lengths, labels in dataloader:
        inputs, lengths, labels = inputs.to(device), lengths.to(device), labels.to(device)

        optimizer.zero_grad()
        logits = model(inputs, lengths)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item() * inputs.size(0)

    return epoch_loss / len(dataloader.dataset)


@torch.no_grad()
def evaluate(model: nn.Module, dataloader: DataLoader, criterion: nn.Module, device: torch.device) -> tuple[float, float]:
    model.eval()
    epoch_loss = 0.0
    epoch_acc = 0.0
    for inputs, lengths, labels in dataloader:
        inputs, lengths, labels = inputs.to(device), lengths.to(device), labels.to(device)
        logits = model(inputs, lengths)
        loss = criterion(logits, labels)
        epoch_loss += loss.item() * inputs.size(0)
        epoch_acc += accuracy(logits, labels) * inputs.size(0)

    dataset_size = len(dataloader.dataset)
    return epoch_loss / dataset_size, epoch_acc / dataset_size

In [5]:
data = download_sms_dataset()
# Tokenize once to build vocab on train only
train_p = 0.7
train_set, test_set = random_split(data, lengths=(train_p, 1 - train_p))

all_tokens = [tokenize(txt) for _, txt in data]
stoi, itos = build_vocab(all_tokens)
pad_idx = stoi[PAD]
vocab_size = len(itos)
print(f"Vocab size: {vocab_size}")

# Datasets
ds_train = SMSDataset(train_set, stoi)
ds_test  = SMSDataset(test_set, stoi)

# Loaders
collate = lambda b: collate_batch(b, pad_idx)
dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate)
dl_test  = DataLoader(ds_test, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = LSTMTextClassifier(
    vocab_size=vocab_size,
    embed_dim=64,
    hidden_dim=64,
    num_classes=2,
    pad_idx=pad_idx,
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

Dataset already present.
Vocab size: 4349


In [6]:
train_loss = train_epoch(model, dl_train, criterion, optimizer, device)
val_loss, val_acc = evaluate(model, dl_test, criterion, device)
print(f"Epoch {1:02d} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | val_acc={val_acc*100:.1f}%")

Epoch 01 | train_loss=0.1447 | val_loss=0.1172 | val_acc=97.1%


In [8]:
from lsq.utils import QALinear, LinearInt

In [9]:
model_qat = model.to_qat(bits=8, qat_linear_class=QALinear)
optimizer_qa = torch.optim.Adam(model_qat.parameters(), lr=1e-2)

In [10]:
for i in range(15):
    train_loss = train_epoch(model_qat, dl_train, criterion, optimizer_qa, device)
    val_loss, val_acc = evaluate(model_qat, dl_test, criterion, device)
    print(f"Epoch {i:02d} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | val_acc={val_acc*100:.1f}%")

    model_qantized = model_qat.quantize(bits=8, linear_int_class=LinearInt)
    model_qantized.to('cpu')
    val_loss, val_acc = evaluate(model_qantized, dl_test, criterion, 'cpu')
    print(f"Epoch {i:02d} | val_loss={val_loss:.4f} | val_acc={val_acc*100:.1f}%")

Epoch 00 | train_loss=0.1684 | val_loss=0.0551 | val_acc=98.4%
Epoch 00 | val_loss=0.6582 | val_acc=64.4%
Epoch 01 | train_loss=0.0361 | val_loss=0.0657 | val_acc=98.0%
Epoch 01 | val_loss=0.7103 | val_acc=45.8%
Epoch 02 | train_loss=0.0175 | val_loss=0.0700 | val_acc=98.3%
Epoch 02 | val_loss=0.7854 | val_acc=29.5%
Epoch 03 | train_loss=0.0025 | val_loss=0.0619 | val_acc=98.8%
Epoch 03 | val_loss=0.9196 | val_acc=16.6%
Epoch 04 | train_loss=0.0006 | val_loss=0.0703 | val_acc=98.7%
Epoch 04 | val_loss=0.6319 | val_acc=73.3%
Epoch 05 | train_loss=0.0002 | val_loss=0.0758 | val_acc=98.6%
Epoch 05 | val_loss=0.6595 | val_acc=64.2%
Epoch 06 | train_loss=0.0002 | val_loss=0.0804 | val_acc=98.6%
Epoch 06 | val_loss=0.7313 | val_acc=42.1%
Epoch 07 | train_loss=0.0001 | val_loss=0.0818 | val_acc=98.6%
Epoch 07 | val_loss=0.6017 | val_acc=79.8%
Epoch 08 | train_loss=0.0001 | val_loss=0.0868 | val_acc=98.5%
Epoch 08 | val_loss=0.6380 | val_acc=74.0%
Epoch 09 | train_loss=0.0001 | val_loss=0.0899