In [1]:
import sys

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os

# Make the project-local modules importable. `datasets` and `mixuploss` are
# presumably files under ./dre/ (resolved relative to the kernel's working
# directory) — TODO confirm.
# NOTE(review): `append` places ./dre/ last on sys.path, so an installed
# package that is also named `datasets` would be imported instead of the
# local module — verify which one is picked up.
sys.path.append('./dre/')
from datasets import make_dataset
from mixuploss import MixupLoss

from PIL import ImageFile
# Ask Pillow to decode truncated image files instead of raising on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True
import matplotlib.pyplot as plt
import matplotlib as mpl
from tqdm import tqdm
import numpy as np

In [3]:
# Restrict this process to GPU 0 first, then choose the compute device:
# 'cuda' when PyTorch reports a usable GPU, otherwise 'cpu'.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [6]:
# --- Data settings (all three are forwarded to make_dataset) ---------------
root = '../data/terra_incognita/'
tst_env = 0
batch_size = 16

# --- Optimisation schedule --------------------------------------------------
# Not referenced by the cells visible here; presumably consumed by a training
# loop elsewhere — TODO confirm.
steps = 5_000
warmup_steps = 2_000
lr = 0.00005
check_freq = 200
exp_freq = 1

# --- Loss-term weights --------------------------------------------------------
# Also unused in the visible cells; presumably scale individual loss terms —
# TODO confirm.
exp_weight = 1.0
sparse_weight = 0.5
mix_ce_weight = 0.1

In [4]:
def test_acc(model, device, test_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
    return 100. * correct / len(test_loader.dataset)

In [5]:
# Augmentation pipeline: random resized crop to 224, random horizontal flip,
# colour jitter, random grayscale, then tensor conversion and per-channel
# normalisation with the mean/std given below.
# NOTE(review): this one pipeline is handed to make_dataset, which also returns
# the validation and test loaders — confirm those do not receive the random
# augmentations.
augment_transform = transforms.Compose([
    # Alternative kept for reference: transforms.Resize((224, 224))
    transforms.RandomResizedCrop(size=224, scale=(0.7, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3),
    transforms.RandomGrayscale(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Build the datasets/loaders for the configured test environment.
datasets, iters, valloaders, tstloader, val_length = make_dataset(
    root,
    tst_env=tst_env,
    batch_size=batch_size,
    transform=augment_transform,
    seed=0,
)

In [10]:
# Build the architecture only: with the default strict=True, load_state_dict
# below replaces every parameter and buffer, so fetching the ImageNet weights
# first would just be a discarded download.
net = torchvision.models.resnet50(weights=None)
# Wrap first, then swap the classifier head through `.module`. The checkpoint is
# loaded into the wrapped model, so its keys are presumably `module.`-prefixed
# — TODO confirm against how the checkpoint was saved.
net = nn.DataParallel(net)
net.module.fc = nn.Linear(net.module.fc.in_features, 10)
# map_location keeps the restore working when `device` fell back to 'cpu' but
# the checkpoint was saved from CUDA tensors.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust
# (or pass weights_only=True if the installed torch version supports it).
net.load_state_dict(torch.load('./ckpts/best_model.pth', map_location=device))
net = net.to(device)

In [None]:
# Accuracy (%) of the restored model on `tstloader`; kept as the cell's final
# bare expression so the notebook displays the returned value.
test_acc(model=net, device=device, test_loader=tstloader)