<a href="https://colab.research.google.com/github/sirandou/meta-learning-few-shot-classification/blob/master/resnet18_miniImagenet_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setting up Google Colab

In [0]:
!pip install -U -q PyDrive

In [0]:
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

In [0]:
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

In [0]:

from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [0]:
cd 'drive/My Drive/ML project'

/content/drive/My Drive/ML project


In [0]:
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch
from torch.utils.data import DataLoader,Dataset
import random
import os
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data.sampler import Sampler

#Getting data folders

In [0]:
def mini_imagenet_folders():
    train_folder = 'datas/miniImagenet/train'
    test_folder = 'datas/miniImagenet/test'                
    val_folder = 'datas/miniImagenet/val'

    metatrain_folders = [os.path.join(train_folder, label) \
                for label in os.listdir(train_folder) \
                if os.path.isdir(os.path.join(train_folder, label)) \
                ]
    metatest_folders = [os.path.join(test_folder, label) \
                for label in os.listdir(test_folder) \
                if os.path.isdir(os.path.join(test_folder, label)) \
                ]
    metaval_folders = [os.path.join(val_folder, label) \
                for label in os.listdir(val_folder) \
                if os.path.isdir(os.path.join(val_folder, label)) \
                ]

    random.seed(1)
    random.shuffle(metatrain_folders)
    random.shuffle(metatest_folders)
    random.shuffle(metaval_folders)

    return metatrain_folders,metatest_folders, metaval_folders

In [0]:
# Step 1: init data folders
print("init data folders")
# init character folders for dataset construction
metatrain_folders,metatest_folders,metaval_folders = mini_imagenet_folders()

init data folders


#Hyperparameters

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
import numpy as np
import os
import math
import argparse
import scipy as sp
import scipy.stats


# Hyper Parameters
FEATURE_DIM = 32
RELATION_DIM = 8
CLASS_NUM = 5
SAMPLE_NUM_PER_CLASS = 5   #5
BATCH_NUM_PER_CLASS = 10   #10
EPISODE = 1000
TEST_EPISODE = 20
LEARNING_RATE = 0.001
GPU = 0
HIDDEN_UNIT = 10

In [0]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('BatchNorm') != -1:
        m.weight.data.fill_(1)
        m.bias.data.zero_()
    elif classname.find('Linear') != -1:
        n = m.weight.size(1)
        m.weight.data.normal_(0, 0.01)
        m.bias.data = torch.ones(m.bias.data.size())

In [0]:
#if os.path.exists(str("miniimagenet/models/miniimagenet_feature_encoder_" + str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl")):
#    feature_encoder.load_state_dict(torch.load(str("miniimagenet/models/miniimagenet_feature_encoder_" + str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl"),map_location='cuda:0'))
#    print("load feature encoder success")
#if os.path.exists(str("miniimagenet/models/miniimagenet_relation_network_"+ str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl")):
#    relation_network.load_state_dict(torch.load(str("miniimagenet/models/miniimagenet_relation_network_"+ str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl"),map_location='cuda:0'))
#    print("load relation network success")

load feature encoder success
load relation network success


In [0]:
def mean_confidence_interval(data, confidence=0.95):
    a = 1.0*np.array(data)
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    h = se * sp.stats.t._ppf((1+confidence)/2., n-1)
    return m,h

#Loading datasets

In [0]:
class Rotate(object):
    def __init__(self, angle):
        self.angle = angle
    def __call__(self, x, mode="reflect"):
        x = x.rotate(self.angle)
        return x

class MiniImagenetTask(object):

    def __init__(self, character_folders, num_classes, train_num,test_num):

        self.character_folders = character_folders
        self.num_classes = num_classes
        self.train_num = train_num
        self.test_num = test_num

        class_folders = random.sample(self.character_folders,self.num_classes)
        labels = np.array(range(len(class_folders)))
        labels = dict(zip(class_folders, labels))
        samples = dict()

        self.train_roots = []
        self.test_roots = []
        for c in class_folders:

            temp = [os.path.join(c, x) for x in os.listdir(c)]
            samples[c] = random.sample(temp, len(temp))
            random.shuffle(samples[c])

            self.train_roots += samples[c][:train_num]
            self.test_roots += samples[c][train_num:train_num+test_num]

        self.train_labels = [labels[self.get_class(x)] for x in self.train_roots]
        self.test_labels = [labels[self.get_class(x)] for x in self.test_roots]

    def get_class(self, sample):
        return os.path.join(*sample.split('/')[:-1])

class FewShotDataset(Dataset):

    def __init__(self, task, split='train', transform=None, target_transform=None):
        self.transform = transform # Torch operations on the input image
        self.target_transform = target_transform
        self.task = task
        self.split = split
        self.image_roots = self.task.train_roots if self.split == 'train' else self.task.test_roots
        self.labels = self.task.train_labels if self.split == 'train' else self.task.test_labels

    def __len__(self):
        return len(self.image_roots)

    def __getitem__(self, idx):
        raise NotImplementedError("This is an abstract class. Subclass this class for your particular dataset.")

class MiniImagenet(FewShotDataset):

    def __init__(self, *args, **kwargs):
        super(MiniImagenet, self).__init__(*args, **kwargs)

    def __getitem__(self, idx):
        image_root = self.image_roots[idx]
        image = Image.open(image_root)
        image = image.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        label = self.labels[idx]
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label


class ClassBalancedSampler(Sampler):
    ''' Samples 'num_inst' examples each from 'num_cl' pools
        of examples of size 'num_per_class' '''

    def __init__(self, num_cl, num_inst,shuffle=True):

        self.num_cl = num_cl
        self.num_inst = num_inst
        self.shuffle = shuffle

    def __iter__(self):
        # return a single list of indices, assuming that items will be grouped by class
        if self.shuffle:
            batches = [[i+j*self.num_inst for i in torch.randperm(self.num_inst)] for j in range(self.num_cl)]
        else:
            batches = [[i+j*self.num_inst for i in range(self.num_inst)] for j in range(self.num_cl)]
        batches = [[batches[j][i] for j in range(self.num_cl)] for i in range(self.num_inst)]

        if self.shuffle:
            random.shuffle(batches)
            for sublist in batches:
                   random.shuffle(sublist)
        batches = [item for sublist in batches for item in sublist]
        return iter(batches)

    def __len__(self):
        return 1

class ClassBalancedSamplerOld(Sampler):
    ''' Samples 'num_inst' examples each from 'num_cl' pools
        of examples of size 'num_per_class' '''

    def __init__(self, num_per_class, num_cl, num_inst,shuffle=True):
        self.num_per_class = num_per_class
        self.num_cl = num_cl
        self.num_inst = num_inst
        self.shuffle = shuffle

    def __iter__(self):
        # return a single list of indices, assuming that items will be grouped by class
        if self.shuffle:
            batch = [[i+j*self.num_inst for i in torch.randperm(self.num_inst)[:self.num_per_class]] for j in range(self.num_cl)]
        else:
            batch = [[i+j*self.num_inst for i in range(self.num_inst)[:self.num_per_class]] for j in range(self.num_cl)]
        batch = [item for sublist in batch for item in sublist]

        if self.shuffle:
            random.shuffle(batch)
        return iter(batch)

    def __len__(self):
        return 1


def get_mini_imagenet_data_loader(task, num_per_class=1, split='train',shuffle = False):
    #normalize = transforms.Normalize(mean=[0.92206, 0.92206, 0.92206], std=[0.08426, 0.08426, 0.08426])
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
    resizing =  transforms.Resize(224, interpolation=2)
    dataset = MiniImagenet(task,split=split,transform=transforms.Compose([resizing,transforms.ToTensor(),normalize]))
    if split == 'train':
        sampler = ClassBalancedSamplerOld(num_per_class,task.num_classes, task.train_num,shuffle=shuffle)

    else:
        sampler = ClassBalancedSampler(task.num_classes, task.test_num,shuffle=shuffle)

    loader = DataLoader(dataset, batch_size=num_per_class*task.num_classes, sampler=sampler)
    return loader

#Defining networks

In [0]:
import torchvision.models as models
class PreTrainedResNet(nn.Module):
  def __init__(self):
    super(PreTrainedResNet, self).__init__()
    
    self.resnet18 = models.resnet18(pretrained=True)

    #Set gradients to false
    
    for param in self.resnet18.parameters():
        param.requires_grad = False
    
    num_feats = self.resnet18.fc.in_features
    self.resnet18 = nn.Sequential(*(list(self.resnet18.children())[:-1]))

    
  def forward(self, x):
    x = self.resnet18 (x)

    x = x.view(x.shape[0],32,4,4)
    return x

In [0]:
class RelationNetwork(nn.Module):
    """docstring for RelationNetwork"""
    def __init__(self,input_size,hidden_size):
        super(RelationNetwork, self).__init__()
        self.layer1 = nn.Sequential(
                        nn.Conv2d(64,32,kernel_size=3,padding=1),  #padding
                        nn.BatchNorm2d(32, momentum=1, affine=True),
                        nn.ReLU(),
                        nn.MaxPool2d(2))
        self.layer2 = nn.Sequential(
                        nn.Conv2d(32,32,kernel_size=3,padding=1),  #padding
                        nn.BatchNorm2d(32, momentum=1, affine=True),
                        nn.ReLU(),
                        nn.MaxPool2d(2))
        self.fc1 = nn.Linear(input_size,hidden_size)  #3*3: size e ax bade 2 layers
        self.fc2 = nn.Linear(hidden_size,1)

    def forward(self,x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.size(0),-1)
        out = F.relu(self.fc1(out))
        out = torch.sigmoid(self.fc2(out))
        #out = self.fc2(out)
        return out

In [0]:

feature_encoder = PreTrainedResNet()
#feature_encoder.apply(weights_init)
feature_encoder.cuda(0)

relation_network = RelationNetwork(FEATURE_DIM,RELATION_DIM)
relation_network.apply(weights_init)
relation_network.cuda(0)


Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/checkpoints/resnet18-5c106cde.pth


HBox(children=(IntProgress(value=0, max=46827520), HTML(value='')))




RelationNetwork(
  (layer1): Sequential(
    (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer2): Sequential(
    (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc1): Linear(in_features=32, out_features=8, bias=True)
  (fc2): Linear(in_features=8, out_features=1, bias=True)
)

#Optimizers

In [0]:

feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),lr=0.001)
feature_encoder_scheduler = StepLR(feature_encoder_optim,step_size=100000,gamma=0.5)
relation_network_optim = torch.optim.Adam(relation_network.parameters(),lr=0.001)
relation_network_scheduler = StepLR(relation_network_optim,step_size=100000,gamma=0.5)

#5way-5shot

In [0]:
# Step 3: build graph
print("Training...")
CLASS_NUM = 5
x = metatest_folders
last_accuracy = 0.0

for episode in range(1000):


    #print(episode)
    # init dataset
    # sample_dataloader is to obtain previous samples for compare
    # batch_dataloader is to batch samples for training
    task = MiniImagenetTask(x,CLASS_NUM,SAMPLE_NUM_PER_CLASS,BATCH_NUM_PER_CLASS)
    sample_dataloader = get_mini_imagenet_data_loader(task,num_per_class=SAMPLE_NUM_PER_CLASS,split="train",shuffle=False)
    batch_dataloader = get_mini_imagenet_data_loader(task,num_per_class=BATCH_NUM_PER_CLASS,split="test",shuffle=True)

    # sample datas
    samples,sample_labels = sample_dataloader.__iter__().next() 
    batches,batch_labels = batch_dataloader.__iter__().next()



    # calculate features
    sample_features = feature_encoder(Variable(samples).cuda(GPU)) 
    sample_features = sample_features.view(CLASS_NUM,SAMPLE_NUM_PER_CLASS,FEATURE_DIM,4,4)
    sample_features = torch.sum(sample_features,1).squeeze(1)
    batch_features = feature_encoder(Variable(batches).cuda(GPU)) 
    
    # calculate relations
    # each batch sample link to every samples to calculate relations
    # to form a 100x128 matrix for relation network
    sample_features_ext = sample_features.unsqueeze(0).repeat(BATCH_NUM_PER_CLASS*CLASS_NUM,1,1,1,1)
    batch_features_ext = batch_features.unsqueeze(0).repeat(CLASS_NUM,1,1,1,1)
    batch_features_ext = torch.transpose(batch_features_ext,0,1)
    relation_pairs = torch.cat((sample_features_ext,batch_features_ext),2).view(-1,FEATURE_DIM*2,4,4)
    relations = relation_network(relation_pairs).view(-1,CLASS_NUM)

    mse = nn.MSELoss().cuda(GPU)
    crossen = nn.CrossEntropyLoss().cuda(0)
    one_hot_labels = Variable(torch.zeros(BATCH_NUM_PER_CLASS*CLASS_NUM, CLASS_NUM).scatter_(1, batch_labels.view(-1,1), 1).cuda(GPU))
    loss = mse(relations,one_hot_labels)
    #print("loss driven")

    # training

    feature_encoder.zero_grad()
    relation_network.zero_grad()
    loss.backward()

    torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(),0.5)
    torch.nn.utils.clip_grad_norm_(relation_network.parameters(),0.5)

    feature_encoder_optim.step()
    relation_network_optim.step()

    feature_encoder_scheduler.step(episode)
    relation_network_scheduler.step(episode)


    if (episode+1)%100 == 0:
      print("episode:",episode+1,"loss",loss.item())

    if (episode+1)%1000 == 0:

        # test
        print("Testing...")
        accuracies = []
        for i in range(100):
            total_rewards = 0
            #counter = 0       #remove
            task = MiniImagenetTask(metaval_folders,CLASS_NUM,SAMPLE_NUM_PER_CLASS,BATCH_NUM_PER_CLASS)
            sample_dataloader = get_mini_imagenet_data_loader(task,num_per_class=SAMPLE_NUM_PER_CLASS,split="train",shuffle=False)
            num_per_class = 5   #3
            test_dataloader = get_mini_imagenet_data_loader(task,num_per_class=num_per_class,split="test",shuffle=False)

            sample_images,sample_labels = sample_dataloader.__iter__().next()
            for test_images,test_labels in test_dataloader:
                batch_size = test_labels.shape[0]
                # calculate features
                sample_features = feature_encoder(Variable(sample_images).cuda(GPU)) # 5x64
                sample_features = sample_features.view(CLASS_NUM,SAMPLE_NUM_PER_CLASS,FEATURE_DIM,4,4)
                sample_features = torch.sum(sample_features,1).squeeze(1)
                test_features = feature_encoder(Variable(test_images).cuda(GPU)) # 20x64

                # calculate relations
                # each batch sample link to every samples to calculate relations
                # to form a 100x128 matrix for relation network
                sample_features_ext = sample_features.unsqueeze(0).repeat(batch_size,1,1,1,1)

                test_features_ext = test_features.unsqueeze(0).repeat(1*CLASS_NUM,1,1,1,1)
                test_features_ext = torch.transpose(test_features_ext,0,1)
                relation_pairs = torch.cat((sample_features_ext,test_features_ext),2).view(-1,FEATURE_DIM*2,4,4)
                relations = relation_network(relation_pairs).view(-1,CLASS_NUM)

                _,predict_labels = torch.max(relations.data,1)

                rewards = [1 if predict_labels[j]==test_labels[j] else 0 for j in range(batch_size)]

                total_rewards += np.sum(rewards)
                #counter +=batch_size   #remove


            accuracy = total_rewards/1.0/CLASS_NUM/BATCH_NUM_PER_CLASS
            #accuracy = total_rewards/1.0/counter
            accuracies.append(accuracy)


        test_accuracy,h = mean_confidence_interval(accuracies)

        print("test accuracy:",test_accuracy,"h:",h)

        if test_accuracy > last_accuracy:

            # save networks
            torch.save(feature_encoder.state_dict(),str("miniimagenet/models/new/miniimagenet_feature_encoder2_" + str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl"))
            torch.save(relation_network.state_dict(),str("miniimagenet/models/new/miniimagenet_relation_network2_"+ str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl"))

            print("save networks for episode:",episode)

            last_accuracy = test_accuracy


Training...
episode: 100 loss 0.02784169837832451
episode: 200 loss 0.030300322920084
episode: 300 loss 0.046183548867702484
episode: 400 loss 0.03398463502526283
episode: 500 loss 0.04983297362923622
episode: 600 loss 0.018380021676421165
episode: 700 loss 0.01635332591831684
episode: 800 loss 0.04367595165967941
episode: 900 loss 0.026291634887456894
episode: 1000 loss 0.04830363392829895
Testing...


KeyboardInterrupt: ignored

In [0]:
        print("Testing...")
        accuracies = []
        for i in range(20):
            print(i)
            total_rewards = 0
            #counter = 0       #remove
            task = MiniImagenetTask(metaval_folders,CLASS_NUM,SAMPLE_NUM_PER_CLASS,BATCH_NUM_PER_CLASS)
            sample_dataloader = get_mini_imagenet_data_loader(task,num_per_class=SAMPLE_NUM_PER_CLASS,split="train",shuffle=False)
            num_per_class = 5   #3
            test_dataloader = get_mini_imagenet_data_loader(task,num_per_class=num_per_class,split="test",shuffle=False)

            sample_images,sample_labels = sample_dataloader.__iter__().next()
            for test_images,test_labels in test_dataloader:
                batch_size = test_labels.shape[0]
                # calculate features
                sample_features = feature_encoder(Variable(sample_images).cuda(GPU)) # 5x64
                sample_features = sample_features.view(CLASS_NUM,SAMPLE_NUM_PER_CLASS,FEATURE_DIM,4,4)
                sample_features = torch.sum(sample_features,1).squeeze(1)
                test_features = feature_encoder(Variable(test_images).cuda(GPU)) # 20x64

                # calculate relations
                # each batch sample link to every samples to calculate relations
                # to form a 100x128 matrix for relation network
                sample_features_ext = sample_features.unsqueeze(0).repeat(batch_size,1,1,1,1)

                test_features_ext = test_features.unsqueeze(0).repeat(1*CLASS_NUM,1,1,1,1)
                test_features_ext = torch.transpose(test_features_ext,0,1)
                relation_pairs = torch.cat((sample_features_ext,test_features_ext),2).view(-1,FEATURE_DIM*2,4,4)
                relations = relation_network(relation_pairs).view(-1,CLASS_NUM)

                _,predict_labels = torch.max(relations.data,1)

                rewards = [1 if predict_labels[j]==test_labels[j] else 0 for j in range(batch_size)]

                total_rewards += np.sum(rewards)
                #counter +=batch_size   #remove


            accuracy = total_rewards/1.0/CLASS_NUM/BATCH_NUM_PER_CLASS
            #accuracy = total_rewards/1.0/counter
            accuracies.append(accuracy)


        test_accuracy,h = mean_confidence_interval(accuracies)

        print("test accuracy:",test_accuracy,"h:",h)

        if test_accuracy > last_accuracy:

            # save networks
            #torch.save(feature_encoder.state_dict(),str("miniimagenet/models/new/miniimagenet_feature_encoder2_" + str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl"))
            #torch.save(relation_network.state_dict(),str("miniimagenet/models/new/miniimagenet_relation_network2_"+ str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl"))

            print("save networks for episode:",episode)

            last_accuracy = test_accuracy

Testing...
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
test accuracy: 0.733 h: 0.03735175356152608
save networks for episode: 999


#5way-1shot

In [0]:

feature_encoder = PreTrainedResNet()
#feature_encoder.apply(weights_init)
feature_encoder.cuda(0)

relation_network = RelationNetwork(FEATURE_DIM,RELATION_DIM)
relation_network.apply(weights_init)
relation_network.cuda(0)


In [0]:
# Step 3: build graph
print("Training...")
# Hyper Parameters
CLASS_NUM = 5
SAMPLE_NUM_PER_CLASS = 1   #5
BATCH_NUM_PER_CLASS = 10   #10
EPISODE = 1000
TEST_EPISODE = 20


x = metatest_folders
last_accuracy = 0.0

for episode in range(1000):


    #print(episode)
    # init dataset
    # sample_dataloader is to obtain previous samples for compare
    # batch_dataloader is to batch samples for training
    task = MiniImagenetTask(x,CLASS_NUM,SAMPLE_NUM_PER_CLASS,BATCH_NUM_PER_CLASS)
    sample_dataloader = get_mini_imagenet_data_loader(task,num_per_class=SAMPLE_NUM_PER_CLASS,split="train",shuffle=False)
    batch_dataloader = get_mini_imagenet_data_loader(task,num_per_class=BATCH_NUM_PER_CLASS,split="test",shuffle=True)

    # sample datas
    samples,sample_labels = sample_dataloader.__iter__().next() 
    batches,batch_labels = batch_dataloader.__iter__().next()



    # calculate features
    sample_features = feature_encoder(Variable(samples).cuda(GPU)) 
    sample_features = sample_features.view(CLASS_NUM,SAMPLE_NUM_PER_CLASS,FEATURE_DIM,4,4)
    sample_features = torch.sum(sample_features,1).squeeze(1)
    batch_features = feature_encoder(Variable(batches).cuda(GPU)) 
    
    # calculate relations
    # each batch sample link to every samples to calculate relations
    # to form a 100x128 matrix for relation network
    sample_features_ext = sample_features.unsqueeze(0).repeat(BATCH_NUM_PER_CLASS*CLASS_NUM,1,1,1,1)
    batch_features_ext = batch_features.unsqueeze(0).repeat(CLASS_NUM,1,1,1,1)
    batch_features_ext = torch.transpose(batch_features_ext,0,1)
    relation_pairs = torch.cat((sample_features_ext,batch_features_ext),2).view(-1,FEATURE_DIM*2,4,4)
    relations = relation_network(relation_pairs).view(-1,CLASS_NUM)

    mse = nn.MSELoss().cuda(GPU)
    crossen = nn.CrossEntropyLoss().cuda(0)
    one_hot_labels = Variable(torch.zeros(BATCH_NUM_PER_CLASS*CLASS_NUM, CLASS_NUM).scatter_(1, batch_labels.view(-1,1), 1).cuda(GPU))
    loss = mse(relations,one_hot_labels)
    #print("loss driven")

    # training

    feature_encoder.zero_grad()
    relation_network.zero_grad()
    loss.backward()

    torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(),0.5)
    torch.nn.utils.clip_grad_norm_(relation_network.parameters(),0.5)

    feature_encoder_optim.step()
    relation_network_optim.step()

    feature_encoder_scheduler.step(episode)
    relation_network_scheduler.step(episode)


    if (episode+1)%10 == 0:
      print("episode:",episode+1,"loss",loss.item())


Training...
episode: 10 loss 0.32025179266929626
episode: 20 loss 0.2504648268222809
episode: 30 loss 0.18935303390026093
episode: 40 loss 0.1612299680709839
episode: 50 loss 0.16015499830245972
episode: 60 loss 0.1596599817276001
episode: 70 loss 0.15839271247386932
episode: 80 loss 0.151975616812706
episode: 90 loss 0.15307706594467163
episode: 100 loss 0.1280609369277954
episode: 110 loss 0.12190590798854828
episode: 120 loss 0.13942624628543854
episode: 130 loss 0.10935577750205994
episode: 140 loss 0.1136331856250763
episode: 150 loss 0.12622761726379395
episode: 160 loss 0.10424743592739105
episode: 170 loss 0.1000719666481018
episode: 180 loss 0.10642261058092117
episode: 190 loss 0.11744077503681183
episode: 200 loss 0.10847144573926926
episode: 210 loss 0.09370134770870209
episode: 220 loss 0.13775424659252167
episode: 230 loss 0.13025178015232086
episode: 240 loss 0.08624213188886642
episode: 250 loss 0.09775585681200027
episode: 260 loss 0.09387928992509842
episode: 270 loss

In [0]:
        print("Testing...")
        accuracies = []
        for i in range(20):
            print(i)
            total_rewards = 0
            #counter = 0       #remove
            task = MiniImagenetTask(metaval_folders,CLASS_NUM,SAMPLE_NUM_PER_CLASS,BATCH_NUM_PER_CLASS)
            sample_dataloader = get_mini_imagenet_data_loader(task,num_per_class=SAMPLE_NUM_PER_CLASS,split="train",shuffle=False)
            num_per_class = 5   #3
            test_dataloader = get_mini_imagenet_data_loader(task,num_per_class=num_per_class,split="test",shuffle=False)

            sample_images,sample_labels = sample_dataloader.__iter__().next()
            for test_images,test_labels in test_dataloader:
                batch_size = test_labels.shape[0]
                # calculate features
                sample_features = feature_encoder(Variable(sample_images).cuda(GPU)) # 5x64
                sample_features = sample_features.view(CLASS_NUM,SAMPLE_NUM_PER_CLASS,FEATURE_DIM,4,4)
                sample_features = torch.sum(sample_features,1).squeeze(1)
                test_features = feature_encoder(Variable(test_images).cuda(GPU)) # 20x64

                # calculate relations
                # each batch sample link to every samples to calculate relations
                # to form a 100x128 matrix for relation network
                sample_features_ext = sample_features.unsqueeze(0).repeat(batch_size,1,1,1,1)

                test_features_ext = test_features.unsqueeze(0).repeat(1*CLASS_NUM,1,1,1,1)
                test_features_ext = torch.transpose(test_features_ext,0,1)
                relation_pairs = torch.cat((sample_features_ext,test_features_ext),2).view(-1,FEATURE_DIM*2,4,4)
                relations = relation_network(relation_pairs).view(-1,CLASS_NUM)

                _,predict_labels = torch.max(relations.data,1)

                rewards = [1 if predict_labels[j]==test_labels[j] else 0 for j in range(batch_size)]

                total_rewards += np.sum(rewards)
                #counter +=batch_size   #remove


            accuracy = total_rewards/1.0/CLASS_NUM/BATCH_NUM_PER_CLASS
            #accuracy = total_rewards/1.0/counter
            accuracies.append(accuracy)


        test_accuracy,h = mean_confidence_interval(accuracies)

        print("test accuracy:",test_accuracy,"h:",h)

        if test_accuracy > last_accuracy:

            # save networks
            torch.save(feature_encoder.state_dict(),str("miniimagenet/models/new/miniimagenet_feature_encoder2_" + str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl"))
            torch.save(relation_network.state_dict(),str("miniimagenet/models/new/miniimagenet_relation_network2_"+ str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl"))

            print("save networks for episode:",episode)

            last_accuracy = test_accuracy

Testing...
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
test accuracy: 0.5810000000000001 h: 0.06705074207791162
save networks for episode: 999
