In [9]:
import argparse
import os

import numpy as np
import torch
import torch.nn.functional as F
from module.convnet import Convnet
from moco import MoCo
import utils.config as config
from utils.utils import (Averager, Timer, count_acc, set_env, MiniImageNet, TrainSampler, load_pretrain, ValSampler)
from utils.logger import Logger
from utils.backup import backup_code
from torch.utils.data import DataLoader

In [10]:
set_env(0)


def _str2bool(value):
    """Parse a boolean command-line value.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string -- including ``'False'`` and ``'0'`` -- so it cannot be
    used to switch a flag off.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser()

parser.add_argument('--id', type=str, default='2-3')
parser.add_argument('--pretrain', type=_str2bool, default=True)

# Task encoder choices: 'mean_tasker', 'img0_tasker', 'vit_tasker', 'blstm_tasker'
parser.add_argument('--tasker', type=str, default='blstm_tasker')

parser.add_argument('--memory_size', type=int, default=128)
parser.add_argument('--param_momentum', type=float, default=0.99)
parser.add_argument('--temperature', type=float, default=0.07)

parser.add_argument('--max_epoch', type=int, default=200)
# shot: 1 is used; 5 is marked as to-do
parser.add_argument('--shot', type=int, default=1)
parser.add_argument('--query', type=int, default=15)
parser.add_argument('--train_way', type=int, default=5)
parser.add_argument('--test_way', type=int, default=5)

# num_task: 6 is used; 8 is marked as to-do
parser.add_argument('--num_task', type=int, default=6)

# Parse an empty argv so the notebook always runs with the defaults above.
args = parser.parse_args([])

valset = MiniImageNet('val', config.data_path, twice_sample=False)

# NOTE(review): 2000 (val) and 100 (train) are presumably episode/batch counts
# consumed by ValSampler / TrainSampler -- confirm in utils.utils.
val_sampler = ValSampler(valset.label, 2000, args.test_way, args.shot + args.query)
val_loader = DataLoader(dataset=valset, batch_sampler=val_sampler, num_workers=config.num_workers, pin_memory=True)

trainset = MiniImageNet('train', config.data_path, twice_sample=True)

train_sampler = TrainSampler(trainset.label, 100, args.train_way, args.shot + args.query, args.num_task)
train_loader = DataLoader(dataset=trainset, batch_sampler=train_sampler, num_workers=config.num_workers, pin_memory=True)


In [11]:
def val_trainset(args, loader, model):
    """Evaluate `model` on episodes from `loader` and print task-similarity diagnostics.

    Episodes are buffered in groups of ``args.num_task`` and evaluated group by
    group; a trailing partial group (fewer than ``num_task`` episodes) is
    skipped. For every evaluated episode two lines are printed:

    * the fraction of class ids it shares with the first evaluated episode,
      followed by the first episode's class ids and its own, and
    * the dot product between its task vector and the first episode's task vector.

    Args:
        args: parsed options; ``shot``, ``train_way`` and ``num_task`` are read.
        loader: iterable of episodes, each unpacking to
            ``(images, labels, path, method)``. ``images[0]`` is split into the
            first ``shot * train_way`` support samples and the remaining query
            samples. NOTE(review): ``images[1]`` (presumably the second view
            produced by ``twice_sample=True``) is not used -- confirm intended.
        model: callable ``model(support, query, support, label_category)``
            returning ``(logits, labels, task_vector)``.

    Returns:
        Mean episode accuracy (as returned by ``count_acc``) multiplied by 100.

    Raises:
        ValueError: if no episode was evaluated (``loader`` yielded fewer than
            ``num_task`` episodes).
    """
    model.eval()
    way = args.train_way
    n_support = args.shot * way

    accs = []
    t_ref = None        # task vector of the first evaluated episode
    classes_ref = None  # class ids of the first evaluated episode
    pending = []

    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for i, episode in enumerate(loader, 1):
            pending.append(episode)
            if i % args.num_task != 0:
                continue
            for images, labels_all, _path, _method in pending:
                # NOTE(review): assumes the first `way` labels are the episode's
                # class ids -- confirm for TrainSampler's sample ordering.
                label_category = np.array(labels_all)[:way]
                view0 = images[0].cuda()
                support, query = view0[:n_support], view0[n_support:]

                # NOTE(review): the support set is passed twice, as in val_valset.
                logits_meta, labels_meta, task_vec = model(support, query, support, label_category)
                if t_ref is None:
                    t_ref, classes_ref = task_vec, label_category

                sim = len(np.intersect1d(label_category, classes_ref)) / float(way)
                print(sim, classes_ref, label_category)
                metric = torch.dot(task_vec.squeeze(0), t_ref.squeeze(0)).unsqueeze(-1)
                print(metric)
                accs.append(count_acc(logits_meta, labels_meta))
            pending = []

    if not accs:
        raise ValueError(
            'no episodes were evaluated; loader yielded fewer than num_task={} episodes'.format(args.num_task))

    mean = np.mean(accs) * 100
    # 95% confidence half-width over the episodes actually evaluated.
    std = (1.96 * np.std(accs, ddof=1) / np.sqrt(len(accs))) * 100

    print('acc={:.4f}±{:.4f}'.format(mean, std))

    return mean

In [12]:
def val_valset(args, loader, model):
    """Evaluate `model` on every episode of `loader` and print task-similarity diagnostics.

    For every episode two lines are printed:

    * the fraction of class ids it shares with the first episode, followed by
      the first episode's class ids and its own, and
    * the dot product between its task vector and the first episode's task vector.

    A final line reports the mean cross-entropy loss (via ``Averager``) and the
    mean accuracy with its 95% confidence half-width.

    Args:
        args: parsed options; ``shot`` and ``test_way`` are read.
        loader: iterable of episodes, each unpacking to
            ``(images, labels, path, method)``; ``images`` is split along its
            first dimension into ``shot * test_way`` support samples and the
            remaining query samples.
        model: callable ``model(support, query, support, label_category)``
            returning ``(logits, labels, task_vector)``.

    Returns:
        Mean episode accuracy (as returned by ``count_acc``) multiplied by 100.

    Raises:
        ValueError: if ``loader`` yields no episodes.
    """
    model.eval()
    way = args.test_way
    n_support = args.shot * way

    val_loss = Averager()
    accs = []
    t_ref = None        # task vector of the first episode
    classes_ref = None  # class ids of the first episode

    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for images, labels_all, _path, _method in loader:
            # NOTE(review): assumes the first `way` labels are the episode's class ids.
            label_category = np.array(labels_all)[:way]
            images = images.cuda()
            support, query = images[:n_support], images[n_support:]

            logits_meta, labels_meta, task_vec = model(support, query, support, label_category)
            if t_ref is None:
                t_ref, classes_ref = task_vec, label_category

            sim = len(np.intersect1d(label_category, classes_ref)) / float(way)
            print(sim, classes_ref, label_category)
            metric = torch.dot(task_vec.squeeze(0), t_ref.squeeze(0)).unsqueeze(-1)
            print(metric)

            loss_meta = F.cross_entropy(logits_meta, labels_meta)
            val_loss.add(loss_meta.item())
            accs.append(count_acc(logits_meta, labels_meta))

    if not accs:
        raise ValueError('no episodes were evaluated; loader is empty')

    loss_meta = val_loss.item()
    mean = np.mean(accs) * 100
    # 95% confidence half-width over the episodes actually evaluated.
    std = (1.96 * np.std(accs, ddof=1) / np.sqrt(len(accs))) * 100

    print('loss-images={:.4f} acc={:.4f}±{:.4f}'.format(loss_meta, mean, std))

    return mean

In [15]:
# Score checkpoint 195 of run '2-5-test' on the validation episodes.
# NOTE(review): torch.load unpickles a whole model object -- only load trusted checkpoints.
model_path = '/root/code/TBFCL/output/{}/ckpt/{}.pth'.format('2-5-test', 195)
print(model_path)
model = torch.load(model_path)
acc = val_valset(args, val_loader, model)
print(acc)

/root/code/TBFCL/output/2-5-test/ckpt/195.pth
1.0 [12 10  9  6 11] [12 10  9  6 11]
tensor([1.0000], device='cuda:0', grad_fn=<UnsqueezeBackward0>)
0.2 [12 10  9  6 11] [ 2 13  4 14 10]
tensor([0.5189], device='cuda:0', grad_fn=<UnsqueezeBackward0>)
0.6 [12 10  9  6 11] [ 1 10 11  9 13]
tensor([0.5473], device='cuda:0', grad_fn=<UnsqueezeBackward0>)
0.2 [12 10  9  6 11] [7 1 2 0 6]
tensor([0.6210], device='cuda:0', grad_fn=<UnsqueezeBackward0>)
0.2 [12 10  9  6 11] [14  4 13 11  5]
tensor([0.5799], device='cuda:0', grad_fn=<UnsqueezeBackward0>)
0.2 [12 10  9  6 11] [ 7 11  0  8  5]
tensor([0.6792], device='cuda:0', grad_fn=<UnsqueezeBackward0>)
0.0 [12 10  9  6 11] [1 3 5 8 2]
tensor([0.4271], device='cuda:0', grad_fn=<UnsqueezeBackward0>)
0.2 [12 10  9  6 11] [ 2 15  8  0 11]
tensor([0.5217], device='cuda:0', grad_fn=<UnsqueezeBackward0>)
0.0 [12 10  9  6 11] [ 5  7  1  4 15]
tensor([0.4922], device='cuda:0', grad_fn=<UnsqueezeBackward0>)
0.2 [12 10  9  6 11] [14  6  8  5 13]
tensor([

In [14]:
# Score checkpoint 195 of run '2-5-test' on training episodes (grouped by num_task).
# NOTE(review): torch.load unpickles a whole model object -- only load trusted checkpoints.
model_path = '/root/code/TBFCL/output/{}/ckpt/{}.pth'.format('2-5-test', 195)
print(model_path)
model = torch.load(model_path)
acc = val_trainset(args, train_loader, model)
print(acc)

/root/code/TBFCL/output/2-5-test/ckpt/195.pth
1.0 [ 6 37 38 55 63] [ 6 37 38 55 63]
tensor([1.0000], device='cuda:0', grad_fn=<UnsqueezeBackward0>)
0.0 [ 6 37 38 55 63] [10 10 21 41 42]
tensor([0.4980], device='cuda:0', grad_fn=<UnsqueezeBackward0>)
0.0 [ 6 37 38 55 63] [29 33 46 48 49]
tensor([0.4644], device='cuda:0', grad_fn=<UnsqueezeBackward0>)
0.0 [ 6 37 38 55 63] [29 32 34 59 61]
tensor([0.5034], device='cuda:0', grad_fn=<UnsqueezeBackward0>)
0.0 [ 6 37 38 55 63] [ 5  5 15 30 36]
tensor([0.4942], device='cuda:0', grad_fn=<UnsqueezeBackward0>)
0.0 [ 6 37 38 55 63] [ 2 15 19 40 51]
tensor([0.4838], device='cuda:0', grad_fn=<UnsqueezeBackward0>)
0.0 [ 6 37 38 55 63] [ 4  4  7 13 31]
tensor([0.5136], device='cuda:0', grad_fn=<UnsqueezeBackward0>)
0.0 [ 6 37 38 55 63] [ 9 10 14 51 52]
tensor([0.4611], device='cuda:0', grad_fn=<UnsqueezeBackward0>)
0.0 [ 6 37 38 55 63] [ 5 10 20 45 53]
tensor([0.4702], device='cuda:0', grad_fn=<UnsqueezeBackward0>)
0.2 [ 6 37 38 55 63] [ 6 13 41 44 56