In [1]:
import torch

from torchtext.data.utils import ngrams_iterator
from torchtext.data.utils import get_tokenizer

In [2]:
from ir_classification import datasets, models
from ir_classification import vocab as ir_vocab
from ir_classification import train

In [8]:
# Locations of the training and development splits (file names indicate
# shuffled TSV files).
TRAIN_TSV = "../datasets/systematic_review/phase1.train.shuf.tsv"
DEV_TSV = "../datasets/systematic_review/phase1.dev.shuf.tsv"

# Single source of truth for the values that the vocabulary, the datasets and
# the text transform must agree on. Previously the vocabulary used a literal
# `[2]` and the n-gram order `1` was repeated in two places, so editing one
# copy could silently desynchronise vocabulary and model input.
data_columns = [2]  # presumably TSV column indices holding the text - confirm
NGRAMS = 1          # n-gram order used for both vocab building and tokenizing

vocab = ir_vocab.create_vocab_from_tsv(TRAIN_TSV, data_columns, ngrams=NGRAMS)
train_dataset = datasets.TSVRawTextMapDataset(TRAIN_TSV, data_columns)
val_dataset = datasets.TSVRawTextMapDataset(DEV_TSV, data_columns)

tokenizer = get_tokenizer("basic_english")


def label_transform(x):
    """Return ``x`` when it is positive, otherwise 0."""
    return x if x > 0 else 0


def text_transform(x):
    """Tokenize ``x`` and return the list yielded by ``ngrams_iterator`` for ``NGRAMS``."""
    return list(ngrams_iterator(tokenizer(x), NGRAMS))


# Training loader: batches of 8 with weighted=True (presumably class-weighted
# sampling - confirm in ir_classification.datasets).
dataloader = datasets.create_torch_dataloader(
    train_dataset,
    vocab,
    label_transform,
    text_transform,
    weighted=True,
    batch_size=8,
)
# Validation loader: one example per batch, weighted=False.
val_dataloader = datasets.create_torch_dataloader(
    val_dataset,
    vocab,
    label_transform,
    text_transform,
    weighted=False,
    batch_size=1,
)

In [4]:
# Model dimensions: the vocabulary size comes from the vocab built earlier,
# the embedding width and the number of output classes are fixed here.
embedding_size = 64
num_classes = 2
vocab_size = len(vocab)

# Positional order is (vocab_size, embedding_size, num_classes).
model = models.EmbeddingBagLinearModel(
    vocab_size,
    embedding_size,
    num_classes,
)

In [5]:
from torch.utils.tensorboard import SummaryWriter

# Training hyperparameters.
EPOCHS = 20     # number of epochs to run
LR = 5          # initial learning rate handed to SGD
BATCH_SIZE = 8  # batch size for training (not referenced in this cell)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
# mode="max": the value passed to scheduler.step() below (validation
# precision) is expected to grow; the LR is reduced once it stops improving
# for more than `patience` epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="max",
    patience=2,
)
writer = SummaryWriter()

for i in range(EPOCHS):
    # Offset passed to train_epoch as start_iter: batches per epoch times the
    # number of epochs already completed.
    start_iter = i * len(dataloader)
    train.train_epoch(
        i,
        model,
        optimizer,
        criterion,
        dataloader,
        start_iter=start_iter,
        writer=writer,
    )
    validation_results = train.evaluate_epoch(
        i,
        model,
        criterion,
        val_dataloader,
        writer,
    )
    # Drive the plateau scheduler with this epoch's validation precision.
    scheduler.step(validation_results["precision"])


Epoch 0: 100%|██████████| 2708/2708 [00:14<00:00, 191.55 batch/s, accurracy=1, loss=0.124]
Validation: 0: 100%|██████████| 607/607 [00:02<00:00, 254.06 batch/s, accurracy=1, loss=0.0102]
Epoch 1: 100%|██████████| 2708/2708 [00:13<00:00, 201.70 batch/s, accurracy=1, loss=0.0309]
Validation: 1: 100%|██████████| 607/607 [00:02<00:00, 258.60 batch/s, accurracy=1, loss=0.0159]
Epoch 2: 100%|██████████| 2708/2708 [00:13<00:00, 194.60 batch/s, accurracy=1, loss=0.118]
Validation: 2: 100%|██████████| 607/607 [00:02<00:00, 274.62 batch/s, accurracy=1, loss=0.136]
Epoch 3: 100%|██████████| 2708/2708 [00:14<00:00, 192.02 batch/s, accurracy=1, loss=0.0201]
Validation: 3: 100%|██████████| 607/607 [00:02<00:00, 253.14 batch/s, accurracy=1, loss=0.00377]
Epoch 4: 100%|██████████| 2708/2708 [00:13<00:00, 204.40 batch/s, accurracy=1, loss=0.000683]
Validation: 4: 100%|██████████| 607/607 [00:02<00:00, 252.47 batch/s, accurracy=0.5, loss=2.86]
Epoch 5: 100%|██████████| 2708/2708 [00:13<00:00, 196.45 bat

In [7]:
# Persist the model parameters. `state_dict` must be *called*: passing the
# bound method itself would pickle the method object rather than the
# parameter dictionary, and the file could not be fed to `load_state_dict`.
with open("state_dict.pth", mode="wb") as f:
    torch.save(model.state_dict(), f)

In [None]:
# Collect the model's prediction and the reference label for every batch of
# the validation dataloader.
model.eval()
preds = []
labels = []
# Inference only: no_grad avoids recording autograd history for each batch.
with torch.no_grad():
    for batch in val_dataloader:
        # Each batch unpacks into (label, text, offset); only label and text
        # are used here.
        label, text, offset = batch
        pred_label = train.predict(model, text)
        preds.append(pred_label)
        labels.append(label)

# NOTE(review): this cell previously ended with a fragment that looked pasted
# from inside a function (it used a top-level `return`, which is a
# SyntaxError, and the undefined names aggregate_results, tepoch and
# epoch_num). It was removed so the cell can run; averaged validation metrics
# are already returned by train.evaluate_epoch in the training loop.