In [1]:
from sklearn.datasets import fetch_20newsgroups
from src.cleaning.cleaners import BasicTextCleaner
from src.tokenisation.tokenisers import BasicTokeniser
from src.word_encoding.word_encoders import BasicEncoder

from itertools import chain
import numpy as np

In [2]:
data = fetch_20newsgroups(subset='train')

X = data.data
y = data.target

In [3]:
# Optional: Get the category names
label_names = data.target_names

cleaner = BasicTextCleaner()
X_clean = [cleaner.clean_text(text) for text in X]

In [11]:
tokeniser = BasicTokeniser()
X_tokens = [tokeniser.tokenise(text) for text in X_clean]
print(list(map(lambda x: len(x), X_tokens))[:20])

[129, 136, 200, 123, 170, 200, 93, 200, 53, 200, 144, 200, 52, 200, 200, 180, 131, 200, 123, 200]


In [5]:
max_doc_len = max(map(lambda x: len(x), X_tokens))

flat = list(chain.from_iterable(X_tokens))
encoder = BasicEncoder(flat)

X_encodings = np.array([encoder.encode(doc, max_len=max_doc_len) for doc in X_tokens])
print(X_encodings[0])

[ 2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19  4  5  6 20 21 22
 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 15 13 27 38 39 40 41 42 28
 43 44 45 13 46 47 48  2 39 49  8 50  8 42 28 51 43 52 39 53 54 55 56 57
 58 39 59 60 28 61  2 39 62 22 39 63 15 14 64 27 65 30 31 66 67 43 68 69
 70 71 72 22 73  7 15 13 14 74 75 76 77 78 79 80 37 15 81 82 13 83 84 85
 86 87 88 47 79 89 90 91  3  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0]


In [6]:
from sklearn.model_selection import train_test_split
import torch

X_train, X_val, y_train, y_val = train_test_split(X_encodings, y, test_size=0.2, random_state=42)

# Convert to tensors
X_train = torch.tensor(X_train, dtype=torch.long)
X_val = torch.tensor(X_val, dtype=torch.long)
y_train = torch.tensor(y_train, dtype=torch.long)
y_val = torch.tensor(y_val, dtype=torch.long)

In [7]:
print(X_train.shape, y_train.shape)
print(X_val.shape, y_val.shape)

torch.Size([9051, 200]) torch.Size([9051])
torch.Size([2263, 200]) torch.Size([2263])


In [8]:
from torch.utils.data import TensorDataset, DataLoader

BATCH_SIZE = 64

train_data = TensorDataset(X_train, y_train)
val_data = TensorDataset(X_val, y_val)

train_loader = DataLoader(train_data, shuffle=True, batch_size=BATCH_SIZE)
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE)

In [9]:
from src.models.SimpleTextClassifier import SimpleTextClassifier

vocab_size = encoder.get_vocab_size()

model = SimpleTextClassifier(vocab_size, 100, 20)

In [10]:
optimizer = model.configure_optimizers()

for epoch in range(100):
    model.train()
    for batch in train_loader:
        loss = model.training_step(batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():
        results = [model.validation_step(batch) for batch in val_loader]
    val_loss = sum(r["val_loss"] for r in results) / len(results)
    val_acc = sum(r["val_acc"] for r in results) / len(results)
    print(f"Epoch {epoch+1}: val_loss={val_loss:.4f}, val_acc={val_acc:.4f}")

Epoch 1: val_loss=2.9114, val_acc=0.2022
Epoch 2: val_loss=2.7669, val_acc=0.3467
Epoch 3: val_loss=2.5301, val_acc=0.4648
Epoch 4: val_loss=2.2392, val_acc=0.5625
Epoch 5: val_loss=1.9432, val_acc=0.6503
Epoch 6: val_loss=1.6736, val_acc=0.7111
Epoch 7: val_loss=1.4440, val_acc=0.7536
Epoch 8: val_loss=1.2586, val_acc=0.7782
Epoch 9: val_loss=1.1092, val_acc=0.7999
Epoch 10: val_loss=0.9901, val_acc=0.8155
Epoch 11: val_loss=0.8936, val_acc=0.8285
Epoch 12: val_loss=0.8157, val_acc=0.8398
Epoch 13: val_loss=0.7520, val_acc=0.8480
Epoch 14: val_loss=0.6996, val_acc=0.8563
Epoch 15: val_loss=0.6554, val_acc=0.8598
Epoch 16: val_loss=0.6184, val_acc=0.8641
Epoch 17: val_loss=0.5868, val_acc=0.8683
Epoch 18: val_loss=0.5593, val_acc=0.8731
Epoch 19: val_loss=0.5359, val_acc=0.8770
Epoch 20: val_loss=0.5152, val_acc=0.8827
Epoch 21: val_loss=0.4976, val_acc=0.8827
Epoch 22: val_loss=0.4814, val_acc=0.8848
Epoch 23: val_loss=0.4680, val_acc=0.8860
Epoch 24: val_loss=0.4555, val_acc=0.8891
E