In [1]:
import argparse
import os

import numpy as np
import torch
import torch.nn.functional as F
from module.convnet import Convnet
from moco import MoCo
import utils.config as config
from utils.utils import (Averager, Timer, count_acc, set_env, MiniImageNet, TrainSampler, load_pretrain, ValSampler)
from utils.logger import Logger
from utils.backup import backup_code
from torch.utils.data import DataLoader

In [2]:
set_env(1234 )

parser = argparse.ArgumentParser()

parser.add_argument('--id', type=str, default='2-3')
parser.add_argument('--pretrain', type=bool, default=True)

# 'mean_tasker' ; 'img0_tasker' ; 'vit_tasker' : 'blstm_tasker'
parser.add_argument('--tasker', type=str, default='blstm_tasker')

parser.add_argument('--memory_size', type=int, default=128)
parser.add_argument('--param_momentum', type=float, default=0.99)
parser.add_argument('--temperature', type=float, default=0.07)

parser.add_argument('--max_epoch', type=int, default=200)
# 1 ; 5(to do)
parser.add_argument('--shot', type=int, default=1)
parser.add_argument('--query', type=int, default=15)
parser.add_argument('--train_way', type=int, default=5)
parser.add_argument('--test_way', type=int, default=5)

# 6 ; 8(to do)
parser.add_argument('--num_task', type=int, default=6)

args = parser.parse_args([])

valset = MiniImageNet('val', config.data_path, twice_sample=False)

val_sampler = ValSampler(valset.label, 2000, args.test_way, args.shot + args.query)
val_loader = DataLoader(dataset=valset, batch_sampler=val_sampler, num_workers=config.num_workers, pin_memory=True)

trainset = MiniImageNet('train', config.data_path, twice_sample=True)

train_sampler = TrainSampler(trainset.label, 100, args.train_way, args.shot + args.query, args.num_task)
train_loader = DataLoader(dataset=trainset, batch_sampler=train_sampler, num_workers=config.num_workers, pin_memory=True)

In [3]:
def val_trainset(args, loader, model):
    model.eval()
    val_acc = Averager()
    accs = []
    t_old = None
    tasks = []
    for i, task in enumerate(train_loader, 1):
        tasks.append(task)
        if i % args.num_task == 0:
            for t in tasks:
                images, labels_all, path, method = [_ for _ in t]
                label_category = np.array(labels_all)[:5]
                images[0] = images[0].cuda()
                images[1] = images[1].cuda()
                p = args.shot * args.train_way
                s1, q1 = images[0][:p], images[0][p:]
                s2 = images[1][:p]

                logits_meta, labels_meta = model(s1, q1, s1, label_category)
                if t_old is None:
                    t_old = t
                    t_c_old = label_category

                sim = len(np.intersect1d(label_category, t_c_old))/5.
                metric = torch.dot(t.squeeze(0), t_old.squeeze(0)).unsqueeze(-1)
                acc = count_acc(logits_meta, labels_meta)
                accs.append(acc)
            tasks = []

    mean = np.mean(accs) * 100
    std = (1.96 * np.std(accs, ddof=1) / np.sqrt(len(val_loader))) * 100

    print('acc={:.4f}±{:.4f}' \
        .format(mean, std))
    
    return mean

In [4]:
model_path = '/root/code/TBFCL/checkpoints/187.pth'
print(model_path)
model = torch.load(model_path)
acc = val_valset(args, val_loader, model)
print(acc)

/root/code/TBFCL/checkpoints/187.pth


RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.