In [1]:
import argparse
import random
import json

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import RocCurveDisplay

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from data_loader import DataGenerator
from tqdm import tqdm

from hw1_copy import MANN

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


In [3]:
k = 5
meta_batch_size = 128
num_classes = 2

In [4]:
test_iterable = DataGenerator(
    data_json_path=f'data/test.json',
    k=k,
    repr="smiles_only",
)
test_loader = iter(
    torch.utils.data.DataLoader(
        test_iterable,
        batch_size=meta_batch_size,
        num_workers=4,
        pin_memory=True,
    )
)

model = torch.load("model/bert_model.pt")
model.to(device)

SmilesBertModel(
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=769, out_features=769, bias=True)
            (key): Linear(in_features=769, out_features=769, bias=True)
            (value): Linear(in_features=769, out_features=769, bias=True)
            (dropout): Dropout(p=0.3, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=769, out_features=769, bias=True)
            (LayerNorm): LayerNorm((769,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.3, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=769, out_features=64, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): BertOutput(
          (dense): Linear(in_features=64, out_features=769, bias=True)
          (Lay

In [6]:
### LSTM
num_correct = 0
N = 1000
for _ in tqdm(range(N)):
    i, l = next(test_loader)
    i, l = i.to(device), l.to(device)
    pred = model(i, l).detach()
    # pred ~ (B, K+1, N, N)
    # model sees K support examples for each of N classes and predicts on 1 query example for each of N classes (all shuffled ofc) -> (_, K+1, N, _)
    # here, LSTM outputs logits for each of N classes for each of the (K+1) * N support & query examples -> (_, K+1, N, N)
    # batch to leverage parallelism -> (B, K+1, N, N)

    pred = torch.reshape(
        pred,
        [
            -1,
            k + 1,
            num_classes,
            num_classes,
        ],
    )  # no change, already in correct shape
    pred_class = torch.argmax(pred[:, -1, :, :], axis=2)
    # pred[:, -1, :, :] selects logits for query example (not K support examples) over entire batch, shape ~ (B, N, N)
    # torch.argmax(..., axis=2) selects predicted class (with largest logit) for the N query examples (1 for each class), shape ~ (B, N)
    true_class = torch.argmax(l[:, -1, :, :], axis=2)  # selects ground-truth class for the N query examples
    num_correct += pred_class.eq(true_class).sum().item()  # sums the number of matches between predicted and ground-truth class

  0%|          | 0/1000 [00:00<?, ?it/s]


TypeError: forward() missing 1 required positional argument: 'attention_mask'

In [18]:
### BERT
num_correct = 0
N = 1000
for _ in tqdm(range(N)):
    i, l = next(test_loader)
    i, l = i.to(device), l.to(device)
    attention_mask = torch.ones((meta_batch_size, (k+1)*num_classes))
    attention_mask = attention_mask.to(device)
    _, pred = model(i.float(), l.float(), attention_mask)
    # pred ~ (B, K+1, N, N)
    # model sees K support examples for each of N classes and predicts on 1 query example for each of N classes (all shuffled ofc) -> (_, K+1, N, _)
    # here, LSTM outputs logits for each of N classes for each of the (K+1) * N support & query examples -> (_, K+1, N, N)
    # batch to leverage parallelism -> (B, K+1, N, N)

    pred_class = torch.argmax(pred, axis=1)
    # pred[:, -1, :, :] selects logits for query example (not K support examples) over entire batch, shape ~ (B, N, N)
    # torch.argmax(..., axis=2) selects predicted class (with largest logit) for the N query examples (1 for each class), shape ~ (B, N)
    true_class = torch.argmax(l[:, -1, :, :], axis=2)  # selects ground-truth class for the N query examples
    num_correct += pred_class.eq(true_class).sum().item()  # sums the number of matches between predicted and ground-truth class

100%|██████████| 1000/1000 [00:20<00:00, 49.68it/s]


In [19]:
print("Test accuracy", num_correct / (meta_batch_size * num_classes * N))

Test accuracy 0.653078125


In [None]:
# Plot ROC curve
fpr, tpr, thresholds = torchmetrics.functional.classification.binary_roc(preds=pred[:, -1, :, :], target=l[:, -1, :, :])
auc = torchmetrics.functional.classification.binary_auroc(preds=pred[:, -1, :, :], target=l[:, -1, :, :]).item()
roc_curve = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=auc)
roc_curve.plot()
plt.show()