In [1]:
from __future__ import print_function

import argparse
import csv
import os
import collections
import pickle
import random

import numpy as np
import torch
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from data.datamgr import SimpleDataManager , SetDataManager
import configs

import wrn_mixup_model
import res_mixup_model

import torch.nn.functional as F


from io_utils import parse_args, get_resume_file ,get_assigned_file
from os import path

use_gpu = torch.cuda.is_available()

In [2]:
import easydict
np.random.seed(10)
args = easydict.EasyDict({
    'dataset': 'miniImagenet',
    'model': 'WideResNet28_10',
    'method': 'S2M2_R',
    'num_classes': 200,
    'split': 'novel'
})

In [3]:
split = args.split
loadfile = configs.data_dir[args.dataset] + split + '.json'

In [4]:
datamgr         = SimpleDataManager(84, batch_size = 256)
novel_loader      = datamgr.get_data_loader(loadfile, aug = False)

  "please use transforms.Resize instead.")


In [5]:
checkpoint_dir = '%s/checkpoints/%s/%s_%s' %(configs.save_dir, args.dataset, args.model, args.method)
modelfile   = get_resume_file(checkpoint_dir)

In [6]:
model = wrn_mixup_model.wrn28_10(num_classes = args.num_classes)
model = model.cuda()
cudnn.benchmark = True
print(model)

WideResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (block1): NetworkBlock(
    (layer): Sequential(
      (0): BasicBlock(
        (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(16, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (convShortcut): Conv2d(16, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (1): BasicBlock(
        (bn1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNor

In [7]:
checkpoint = torch.load(modelfile)

In [8]:
class WrappedModel(nn.Module):
    def __init__(self, module):
        super(WrappedModel, self).__init__()
        self.module = module 
    def forward(self, x):
        return self.module(x)
    
state = checkpoint['state']
state_keys = list(state.keys())

callwrap = False
if 'module' in state_keys[0]:
    callwrap = True
if callwrap:
    model = WrappedModel(model)
        
model_dict_load = model.state_dict()

model_dict_load.update(state)
model.load_state_dict(model_dict_load)
model.eval()

WideResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (block1): NetworkBlock(
    (layer): Sequential(
      (0): BasicBlock(
        (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(16, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (convShortcut): Conv2d(16, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (1): BasicBlock(
        (bn1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNor

In [9]:
def save_pickle(file, data):
    with open(file, 'wb') as f:
        pickle.dump(data, f)

def load_pickle(file):
    with open(file, 'rb') as f:
        return pickle.load(f)

In [10]:
def extract_feature(novel_loader, model, tag='last'):
    save_dir = '{}/{}'.format(checkpoint_dir, tag)
    if os.path.isfile(save_dir + '/output.plk'):
        data = load_pickle(save_dir + '/output.plk')
        return data
    else:
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)

    with torch.no_grad():
        
        output_dict = collections.defaultdict(list)
        output_dict_aug = collections.defaultdict(list)

        for i, (inputs, labels) in enumerate(novel_loader):
            # compute output
            inputs = inputs.cuda()
            labels = labels.cuda()
            outputs,_ = model(inputs)
            outputs = outputs.cpu().data.numpy()
            
            inputs_ = []
            labels_ = []
            
            for j in range(inputs.size(0)):
                x90 = inputs[j].transpose(2,1).flip(1)
                x180 = x90.transpose(2,1).flip(1)
                x270 =  x180.transpose(2,1).flip(1)
                list_rotation = [x90, x180, x270]
                inputs_ += [list_rotation[random.randint(0, 2)]]
                labels_ += [labels[j] for _ in range(1)]
            
            inputs_aug = Variable(torch.stack(inputs_,0))
            labels_aug = Variable(torch.stack(labels_,0))
            outputs_aug,_ = model(inputs_aug)
            outputs_aug = outputs_aug.cpu().data.numpy()
            
            for out_aug, label_aug in zip(outputs_aug, labels_aug):
                output_dict_aug[label_aug.item()].append(out_aug)
            
            for out, label in zip(outputs, labels):
                output_dict[label.item()].append(out)
    
        save_pickle(save_dir + '/output.plk', output_dict)
        save_pickle(save_dir + '/output_aug.plk', output_dict_aug)
        return output_dict, output_dict_aug

save_dir = '{}/{}'.format(checkpoint_dir, 'last')
if os.path.isfile(save_dir + '/output.plk'):
    out_dict = load_pickle(save_dir + '/output.plk')
else:
    out_dict = extract_feature(novel_loader, model)

In [11]:
save_dir = '{}/{}'.format(checkpoint_dir, 'last')
output_dict, output_dict_aug = extract_feature(novel_loader, model)