In [1]:
%load_ext autoreload
%autoreload 2
import os
import torch
import torch.utils as utils
import lightning as L
from torchmetrics import Precision, Recall, F1Score, AUROC
from text_clf_base.data import TextClfDataset
from text_clf_base.model import TextClf

os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
# Name of the dataset folder: splits are read from ../data/<DATASET>/ below.
DATASET = "sample"

# Seed the global RNGs once, before the DataLoader and the model are created,
# so the shuffle order and the weight initialisation are reproducible under
# Restart Kernel -> Run All.
SEED = 42
L.seed_everything(SEED, workers=True)

## Preparation

Load the training split and wrap it in a shuffling `DataLoader`:

In [3]:
# Training split: texts and labels come from two files under ../data/<DATASET>/
# (presumably line-aligned — TODO confirm in TextClfDataset).
train_dataset = TextClfDataset(f"../data/{DATASET}/train_text.txt", f"../data/{DATASET}/train_label.txt")

# shuffle=True reshuffles on every pass, i.e. every training epoch.
train_loader = utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

Build the model, sizing its embedding table from the training tokenizer's vocabulary:

In [4]:
# The embedding table must cover every token id the training tokenizer can emit.
vocab_size = train_dataset.tokenizer.get_vocab_size()
model = TextClf(
    vocab_size,
    d_model=128,
    num_layers=4,
    nhead=8,
)

## Training

In [5]:
# Number of passes over train_loader.
MAX_EPOCHS = 2
trainer = L.Trainer(max_epochs=MAX_EPOCHS)

# The evaluation cell below calls model.eval(); switching back explicitly keeps
# this cell correct when it is re-run after evaluation.
model.train()
trainer.fit(model=model, train_dataloaders=train_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name   | Type               | Params
----------------------------------------------
0 | embed  | Embedding          | 2.7 M 
1 | model  | TransformerEncoder | 2.4 M 
2 | output | Linear             | 129   
----------------------------------------------
5.1 M     Trainable params
0         Non-trainable params
5.1 M     Total params
20.306    Total estimated model params size (MB)
/Users/yxonic/miniconda3/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.


## Evaluation

Precision, recall, F1 score and ROC AUC on the test split:

In [9]:
# Decision threshold on the sigmoid probability (predict positive when p > THRESHOLD).
# NOTE(review): document how 0.25 was chosen; picking it on the test split itself
# would bias the reported precision / recall / F1.
THRESHOLD = 0.25

model.eval()

# All metrics receive probabilities. P/R/F1 threshold them internally with a
# strict `>` (same rule as `sigmoid(score) > THRESHOLD`). Passing probabilities
# to AUROC as well keeps every batch on the same scale: torchmetrics only applies
# a sigmoid to an update whose preds fall outside [0, 1], so raw logits could be
# transformed in some batches and not in others, corrupting the pooled ranking.
precision = Precision(task="binary", threshold=THRESHOLD)
recall = Recall(task="binary", threshold=THRESHOLD)
f1 = F1Score(task="binary", threshold=THRESHOLD)
auc = AUROC(task="binary")

# NOTE(review): confirm TextClfDataset encodes the test split with the same
# tokenizer / vocabulary as train_dataset; otherwise token ids will not line up
# with the trained embedding table.
test_dataset = TextClfDataset(f"../data/{DATASET}/test_text.txt", f"../data/{DATASET}/test_label.txt")
test_loader = utils.data.DataLoader(test_dataset, batch_size=64)

# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    for x, y in test_loader:
        y_score = model(x)  # raw scores (logits); the output layer has 128 weights + 1 bias
        y_prob = torch.sigmoid(y_score)
        precision.update(y_prob, y)
        recall.update(y_prob, y)
        f1.update(y_prob, y)
        auc.update(y_prob, y)

print(f"Precision: {precision.compute():.4f}, Recall: {recall.compute():.4f}, F1 score: {f1.compute():.4f}, AUC: {auc.compute():.4f}")


Precision: 0.8333, Recall: 0.8333, F1 score: 0.8333, AUC: 0.9877
