In [8]:
import torch
import torch.nn as nn
from models.transformerTagger import TransformerTagger
from models.transformerTagger import PositionalEncoder
from dataset.dataset import NERDataset
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch
import datetime as dt
from torch.utils.tensorboard import SummaryWriter
from torchmetrics import F1Score, Accuracy, Precision, Recall
from typing import Sequence, Optional, Union, Tuple
import utils

In [9]:
# --- Data -------------------------------------------------------------------
# One NERDataset per split, all built with the same tokenizer/casing options.
train_data = NERDataset(tokenizer="spacy", cased=False, mode='train')
test_data = NERDataset(tokenizer="spacy", cased=False, mode='test')
val_data = NERDataset(tokenizer="spacy", cased=False, mode='valid')

# --- Run configuration ------------------------------------------------------
TODAY = dt.datetime.today().strftime("%Y-%m-%d")  # date stamp formatted as YYYY-MM-DD
nhead = 16
no_dense_layers = 5
d_model = 512
layer_norm_eps = 0.01
batch_size = 256


# NOTE(review): anomaly detection is a global autograd debugging switch and adds
# overhead; it is kept here to preserve the notebook's global state, but it is
# presumably a leftover from training -- confirm whether evaluation needs it.
torch.autograd.set_detect_anomaly(True)
# Mixed precision is deliberately NOT enabled here. autocast only has an effect
# when entered as a context manager (or applied as a decorator) around the
# forward pass; a bare `torch.cuda.amp.autocast(enabled=True)` statement builds
# the object and discards it, so it was removed as a misleading no-op.

# --- Metrics ----------------------------------------------------------------
# torchmetrics objects, each configured with the class count reported by the
# training split (train_data.ntargets).
f1 = F1Score(num_classes=train_data.ntargets, threshold=0.5)
accu = Accuracy(threshold=0.5, num_classes=train_data.ntargets)
precision = Precision(num_classes=train_data.ntargets)
recall = Recall(num_classes=train_data.ntargets, threshold=0.5)

# --- Padding ids ------------------------------------------------------------
# Index of the pad token in the token and target lookups of the training split.
# NOTE(review): `.get` presumably returns None when the pad token is missing
# from the lookup -- verify both lookups actually contain train_data.pad_token.
feature_padding_value = train_data._tokenidx.get(train_data.pad_token)
tag_padding_value = train_data._targetidx.get(train_data.pad_token)

def collate_fn(data: Sequence[Tuple],
    n_classes: int=train_data.ntargets,
    feature_padding_value=feature_padding_value,
    tag_padding_value=tag_padding_value,):
    """Collate a list of dataset samples into one padded batch.

    Each sample is unpacked as ``(features, target_prob, targets, idx)``.
    Feature and target sequences are right-padded to the longest sequence in
    the batch; the per-sample target probabilities are padded by
    ``utils.pad_target_prob``, which also produces the mask.

    :param data: samples to batch, each a 4-tuple as described above.
    :param n_classes: class count forwarded to ``utils.pad_target_prob``
        (``n_classes - 1`` is passed as its second argument).
    :param feature_padding_value: fill value for padded feature positions.
    :param tag_padding_value: fill value for padded target positions.
    :return: ``(idx, features, target_prob, targets, mask)`` where ``idx`` is
        the tuple of per-sample indices, ``features``/``targets`` are int64,
        ``target_prob`` is float64 and ``mask`` is bool.
    """
    sample_features, sample_probs, sample_tags, sample_ids = zip(*data)

    padded_features = pad_sequence(
        sample_features, batch_first=True, padding_value=feature_padding_value
    )
    padded_tags = pad_sequence(
        sample_tags, batch_first=True, padding_value=tag_padding_value
    )

    # Batch geometry is read off the padded targets: dim 0 = batch, dim 1 = length.
    n_samples = padded_tags.size(0)
    longest = padded_tags.size(1)
    padded_probs, prob_mask = utils.pad_target_prob(
        sample_probs, n_classes - 1, longest, n_classes, n_samples
    )

    return (
        sample_ids,
        padded_features.long(),
        padded_probs.to(torch.float64),
        padded_tags.long(),
        prob_mask.bool(),
    )

# Batches fed to the evaluation loop.
# NOTE(review): despite its name, this loader wraps `train_data` (shuffled);
# confirm whether evaluating on the training split is intended or whether
# `test_data` was meant.
test_dataloader = DataLoader(
    train_data,
    batch_size=batch_size,
    collate_fn=collate_fn,
    shuffle=True,
)


# Tagger built from the hyper-parameters configured above; vocabulary size,
# tag count and pad index all come from the training split.
model = TransformerTagger(
    vocab_size=train_data.vocab_size + 1,
    n_tags=train_data.ntargets,
    pad_token_idx=train_data._tokenidx.get(train_data.pad_token),
    d_model=d_model,
    nhead=nhead,
    no_dense_layers=no_dense_layers,
    activation=torch.tanh,
    layer_norm_eps=layer_norm_eps,
    batch_first=True,
)



In [12]:
# Restore the trained weights into the freshly constructed model.
# map_location="cpu" lets a checkpoint that was saved from a GPU be read on a
# CPU-only machine; load_state_dict then copies the values into the model's
# existing parameters.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
model.load_state_dict(
    torch.load("./checkpoints/2022-11-15_runo11.pt", map_location="cpu")
)

<All keys matched successfully>

In [13]:
# Per-sequence evaluation: every metric is computed on each sequence separately
# and summed; `counter` holds the number of sequences so the sums can be turned
# into means afterwards.
model.eval()
counter, train_precision, train_recall, train_f1, train_accu = 0., 0., 0., 0., 0.
# Inference only: disable gradient tracking so no autograd graph is built for
# the forward passes (the metric values read via .item() are unaffected).
with torch.no_grad():
    for data in test_dataloader:
        idx, src, tag_prob, tags, mask = data
        pred = model(src, src, mask)
        for prd, truth, mk in zip(pred, tags, mask):
            # Positions where the mask is True are dropped (presumably padding
            # -- confirm against utils.pad_target_prob). The selection is done
            # once and shared by all four metrics.
            keep = ~mk
            prd_kept = prd[keep]
            truth_kept = truth.masked_select(keep).long()
            train_precision += precision(prd_kept, truth_kept).item()
            train_recall += recall(prd_kept, truth_kept).item()
            train_f1 += f1(prd_kept, truth_kept).item()
            train_accu += accu(prd_kept, truth_kept).item()
            counter += 1