In [5]:
import argparse
import random
import json

import numpy as np
import pandas as pd

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from data_loader import DataGenerator
from tqdm import tqdm

from hw1_copy import MANN

In [7]:
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


In [14]:
k = 5
meta_batch_size = 128
num_classes = 2

In [15]:
test_iterable = DataGenerator(
    data_json_path=f'data/test.json',
    k=k,
    repr="smiles_only",
)
test_loader = iter(
    torch.utils.data.DataLoader(
        test_iterable,
        batch_size=meta_batch_size,
        num_workers=4,
        pin_memory=True,
    )
)

model = torch.load("model/model.pt")
model.to(device)

MANN(
  (layer1): LSTM(769, 128, batch_first=True)
  (dropout): Dropout(p=0.35, inplace=False)
  (layer2): LSTM(128, 2, batch_first=True)
)

In [33]:
num_correct = 0
N = 1000
for _ in tqdm(range(N)):
    i, l = next(test_loader)
    i, l = i.to(device), l.to(device)
    pred = model(i, l).detach()

    pred = torch.reshape(
        pred,
        [
            -1,
            k + 1,
            num_classes,
            num_classes,
        ],
    )
    pred = torch.argmax(pred[:, -1, :, :], axis=2)
    l = torch.argmax(l[:, -1, :, :], axis=2)
    num_correct += pred.eq(l).sum().item()

100%|██████████| 1000/1000 [00:26<00:00, 37.94it/s]


In [36]:
print("Test accuracy", num_correct / (meta_batch_size * num_classes * N))

Test accuracy 0.64705078125
