In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader,TensorDataset
import numpy as np
import scipy.io as sio
import math
import argparse
import random
import os
from sklearn.metrics import accuracy_score
# Pin the default CUDA device only when a GPU is actually present, so this
# cell does not fail on a CPU-only machine.
if torch.cuda.is_available():
    torch.cuda.set_device(0)

rd=1234 #torch.randint(1,10000,[1])
# Seed every random number generator the notebook draws from. Episode
# sampling and generator noise use np.random, so seeding torch alone does not
# make a run repeatable.
random.seed(rd)
np.random.seed(rd)
torch.manual_seed(rd)
torch.cuda.manual_seed_all(rd)

import warnings
warnings.filterwarnings("ignore")

In [None]:
# step 1: init dataset
print("init dataset")

dataroot = '/data/dataset/zsl/data'
dataset = 'CUB1_data'
image_embedding = 'res101'
class_embedding = 'att_splits'

# Image features (transposed so that rows are samples) and labels shifted from
# MATLAB's 1-based numbering to 0-based.
matcontent = sio.loadmat(f"{dataroot}/{dataset}/{image_embedding}.mat")
feature = matcontent['features'].T
label = matcontent['labels'].astype(int).squeeze() - 1

# Split indices; again converted from 1-based (MATLAB) to 0-based (numpy).
matcontent = sio.loadmat(f"{dataroot}/{dataset}/{class_embedding}.mat")
trainval_loc, test_seen_loc, test_unseen_loc = (
    matcontent[split_key].squeeze() - 1
    for split_key in ('trainval_loc', 'test_seen_loc', 'test_unseen_loc')
)

attribute = matcontent['att'].T

# Training split: features, labels and the per-sample attribute rows.
x = feature[trainval_loc]
train_label = label[trainval_loc].astype(int)
att = attribute[train_label]

# Unseen-class test split and seen-class test split.
x_test = feature[test_unseen_loc]
test_label = label[test_unseen_loc].astype(int)
x_test_seen = feature[test_seen_loc]
test_label_seen = label[test_seen_loc].astype(int)
test_id = np.unique(test_label)
att_pro = attribute[test_id]

# From here on `attribute` is replaced by the class embeddings stored in
# cub_attributes_reed.npy, scaled by 15. `att` and `att_pro` above were taken
# from the .mat attributes before this replacement.
path = dataroot
file = path + '/cub_attributes_reed.npy'
attribute = 15 * np.load(file)


# train set
train_features = x
print(f"train_features.shape: {train_features.shape!s}")

# Labels become a column vector of shape (N, 1).
train_label = train_label.reshape(-1, 1).copy()
print(f"train_label.shape:  {train_label.shape!s}")

# attributes
all_attributes = np.array(attribute)
print(f"all_attributes.shape:  {all_attributes.shape!s}")

attributes = torch.from_numpy(attribute)

# test set
test_features = x_test
print(f"test_features.shape:  {test_features.shape!s}")

test_label = test_label.reshape(-1, 1).copy()
print(f"test_label.shape:  {test_label.shape!s}")

testclasses_id = np.array(test_id)
print(f"testclasses_id.shape:  {testclasses_id.shape!s}")

test_attributes = torch.from_numpy(att_pro).float()
print(f"test_attributes.shape:  {test_attributes.shape!s}")


test_seen_features = torch.from_numpy(x_test_seen)
print(f"test_seen_features.shape:  {test_seen_features.shape!s}")

test_seen_label = torch.from_numpy(test_label_seen)

# [features, labels] pairs in the layout zsl_NShot expects.
train_data = [train_features, train_label]
test_data = [test_features, test_label]

unq_train_labels = np.unique(train_label)
unq_test_labels = np.unique(test_label)

In [None]:
def init_weights(m):
    """Initialise a linear layer in place (usable as ``module.apply(init_weights)``).

    Linear layers get weights drawn from a normal distribution with mean 0 and
    standard deviation 0.002 and, when they have a bias, a constant bias of
    0.002. Every other module type is left untouched.

    :param m: any ``nn.Module``; only ``nn.Linear`` instances are modified.
    """
    if isinstance(m, nn.Linear):
        torch.nn.init.normal_(m.weight, mean=0, std=0.002)
        # nn.Linear(..., bias=False) stores bias as None; skip it instead of
        # raising AttributeError.
        if m.bias is not None:
            m.bias.data.fill_(0.002)
class Generator(nn.Module):
    """Conditional generator mapping (attribute vector, noise) to a feature vector.

    Layer widths are read from the module-level ``args`` object
    (``attri_dim``, ``noise_dim`` and ``input_shape``).
    """

    def __init__(self):
        super().__init__()

        hidden = 2048
        layers = [
            nn.Linear(args.attri_dim + args.noise_dim, hidden),
            nn.Dropout(0.5),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden, hidden),
            # NOTE(review): the second positional argument of BatchNorm1d is
            # `eps`, so the 0.8 used here is the epsilon, not the momentum.
            # It is spelled as a keyword now; the value is unchanged.
            nn.BatchNorm1d(hidden, eps=0.8),
            nn.Dropout(0.5),
            nn.Linear(hidden, hidden),
            nn.BatchNorm1d(hidden, eps=0.8),
            nn.Dropout(0.5),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden, args.input_shape),
        ]
        self.modelG = nn.Sequential(*layers)

    def forward(self, noise, attri, weightsM=None):
        """Generate features for ``attri`` conditioned on ``noise``.

        :param noise: noise tensor, concatenated after ``attri`` on the last axis
        :param attri: attribute tensor
        :param weightsM: optional list whose element 0 is a flat list of
            (weight, bias) tensors, one pair per Linear/BatchNorm1d layer in
            layer order. When given, those tensors are written into this
            module's parameters through ``.data`` before the forward pass, so
            the module keeps them afterwards.
        :return: output of the sequential network
        """
        gen_input = torch.cat((attri, noise), -1)

        if weightsM is not None:
            fast = weightsM[0]
            pos = 0
            for layer in self.modelG.modules():
                if isinstance(layer, (nn.Linear, nn.BatchNorm1d)):
                    layer.weight.data = fast[pos]
                    layer.bias.data = fast[pos + 1]
                    pos += 2

        return self.modelG(gen_input)


class Discriminator(nn.Module):
    """Conditional critic scoring a (feature vector, attribute vector) pair.

    Layer widths are read from the module-level ``args`` object
    (``input_shape`` and ``attri_dim``). The network ends in a single
    unbounded output unit.
    """

    def __init__(self):
        super().__init__()

        layers = [
            nn.Linear(args.input_shape + args.attri_dim, 2048),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(2048, 1024),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1024),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
        ]
        self.modelD = nn.Sequential(*layers)
#         self.modelD.apply(init_weights)

    def forward(self, img, attri, weightsM=None):
        """Score ``img`` conditioned on ``attri``.

        :param img: feature tensor; flattened to (batch, -1) before use
        :param attri: attribute tensor, concatenated after the features
        :param weightsM: optional list whose element 1 is a flat list of
            (weight, bias) tensors, one pair per Linear/BatchNorm1d layer in
            layer order. When given, those tensors are written into this
            module's parameters through ``.data`` before the forward pass.
        :return: tensor with one score per sample
        """
        flat_img = img.view(img.size(0), -1)
        d_in = torch.cat((flat_img, attri), -1)

        if weightsM is not None:
            fast = weightsM[1]
            pos = 0
            for layer in self.modelD.modules():
                if isinstance(layer, (nn.Linear, nn.BatchNorm1d)):
                    layer.weight.data = fast[pos]
                    layer.bias.data = fast[pos + 1]
                    pos += 2

        return self.modelD(d_in)
    
class Classifier(nn.Module):
    """Three-layer MLP producing ``args.num_class`` logits from a feature vector.

    Layer widths are read from the module-level ``args`` object
    (``input_shape`` and ``num_class``).
    """

    def __init__(self):
        super().__init__()

        layers = [
            nn.Linear(args.input_shape, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.4),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, args.num_class),
        ]
        self.modelC = nn.Sequential(*layers)

    def forward(self, img_gen, target, weightsM=None):
        """Return class logits for ``img_gen``.

        :param img_gen: feature tensor fed to the network
        :param target: accepted for call-site symmetry; not used here
        :param weightsM: optional list whose element 2 is a flat list of
            (weight, bias) tensors, one pair per Linear/BatchNorm1d layer in
            layer order. When given, those tensors are written into this
            module's parameters through ``.data`` before the forward pass.
        :return: logits tensor
        """
        if weightsM is not None:
            fast = weightsM[2]
            pos = 0
            for layer in self.modelC.modules():
                if isinstance(layer, (nn.Linear, nn.BatchNorm1d)):
                    layer.weight.data = fast[pos]
                    layer.bias.data = fast[pos + 1]
                    pos += 2

        return self.modelC(img_gen)

class Classifier2(nn.Module):
    """Small classifier: Linear -> BatchNorm1d -> Linear, with no non-linearity.

    Layer widths are read from the module-level ``args`` object
    (``input_shape`` and ``num_class``).
    """

    def __init__(self):
        super().__init__()

        width = 1024
        layers = [
            nn.Linear(args.input_shape, width),
            # NOTE(review): the second positional argument of BatchNorm1d is
            # `eps`, so 0.8 is the epsilon here, not the momentum. Spelled as
            # a keyword; the value is unchanged.
            nn.BatchNorm1d(width, eps=0.8),
            nn.Linear(width, args.num_class),
        ]
        self.modelC2 = nn.Sequential(*layers)

    def forward(self, img_gen, target, weightsM=None):
        """Return class logits for ``img_gen``; ``target`` and ``weightsM`` are ignored."""
        return self.modelC2(img_gen)


In [None]:
print(testclasses_id)

# One weight per row of `attribute`: 1.0 for classes listed in
# testclasses_id, 0.1 for every other class.
_weighted_ids = set(np.asarray(testclasses_id).tolist())
class_weight = torch.tensor(
    [1.0 if cls_idx in _weighted_ids else 0.1 for cls_idx in range(attribute.shape[0])]
)
# print(class_weight)

In [None]:

class all_arguments():
    """Plain namespace holding the hyper-parameters read by the rest of the notebook."""

    # --- episode layout ---
    n_way = 10      # classes per support set (and per query set)
    k_spt = 5       # support samples per class
    k_qry = 3       # query samples per class

    # --- feature size and noise scales ---
    imgsz = 2048
    sigma_ts = 0.25     # std of the generator noise drawn in test()
    sigma_tr = 0.5      # std of the generator noise drawn in Meta.all_loss

    # --- learning rates and inner-loop length ---
    meta_lr = 1e-5      # Adam lr for generator + classifier
    meta_lrD = 1e-3     # SGD lr for the discriminator; also Classifier2's Adam lr
    update_lr = 1e-3    # step size used when building fast weights
    update_step = 5

    # --- network dimensions ---
    input_shape = 2048
    num_class = 200
    attri_dim = 1024
    noise_dim = 512
    clssifier_weight = 0.05     # weight of the classifier loss in the joint generator loss

    # Get options
    attributes = attributes
    cuda = torch.cuda.is_available()

    
args = all_arguments()

# Loss functions: least-squares adversarial loss and class-weighted cross-entropy.
adversarial_loss = torch.nn.MSELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss(weight=class_weight)


# Initialize generator, discriminator and classifier (in this order).
generator, discriminator, classifier = Generator(), Discriminator(), Classifier()


# Move the networks and both loss modules to the GPU when one is available.
if args.cuda:
    for _gpu_module in (generator, discriminator, classifier, adversarial_loss, auxiliary_loss):
        _gpu_module.cuda()

para = list(classifier.parameters())
#print(classifier)

In [None]:

import  torch
from    torch import nn
from    torch import optim
from    torch.nn import functional as F
from    torch.utils.data import TensorDataset, DataLoader
from    torch import optim
import  numpy as np

# Generator and classifier parameters share one optimiser; `glen` records how
# many leading entries of `gen_para` belong to the generator.
_generator_params = list(generator.parameters())
glen = len(_generator_params)
gen_para = _generator_params + list(classifier.parameters())

disc_optim = optim.SGD(
    discriminator.parameters(),
    lr=args.meta_lrD,
    weight_decay=1e-6,
)
gen_optim = optim.Adam(
    gen_para,
    lr=args.meta_lr,
    betas=(0.9, 0.99),
    weight_decay=1e-6,
)

# Both schedules multiply the learning rate by 0.95 every 100 scheduler steps.
disc_schedular = StepLR(disc_optim, step_size=100, gamma=0.95)
gen_schedular = StepLR(gen_optim, step_size=100, gamma=0.95)

In [None]:

class Meta(nn.Module):
    """
    Meta Learner

    Drives episodic training of the module-level ``generator``,
    ``discriminator`` and ``classifier``: for every task it builds "fast
    weights" from the support set, evaluates them on the query set, and then
    steps the shared optimisers on the averaged query losses.

    The networks, optimisers, schedulers, ``glen`` and ``gen_para`` are taken
    from module-level names rather than created here.

    NOTE(review): Generator/Discriminator/Classifier.forward write any supplied
    fast weights into their own parameters through ``.data``. Passing
    ``fast_weight`` therefore also overwrites the live parameters - confirm
    that this is the intended update rule.
    """
    def __init__(self, args):
        """
        :param args: hyper-parameter namespace; reads update_lr, meta_lr, n_way,
                     k_spt, k_qry, clssifier_weight, update_step, cuda, noise_dim
        """
        super(Meta, self).__init__()

        self.update_lr = args.update_lr
        self.meta_lr = args.meta_lr
        self.n_way = args.n_way
        self.k_spt = args.k_spt
        self.k_qry = args.k_qry
        self.cls_weight=args.clssifier_weight
        self.update_step = args.update_step
        cuda=args.cuda
        self.noise_dim=args.noise_dim
        #self.all_loss=all_loss()

        self.FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
        self.LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor

        # Shared, module-level training state.
        self.glen=glen
        self.gen_para=gen_para
        self.disc_optim = disc_optim
        self.gen_optim = gen_optim

        self.disc_schedular = disc_schedular
        self.gen_schedular = gen_schedular

    def clip_grad_by_norm_(self, grad, max_norm):
        """
        in-place gradient clipping.
        :param grad: list of gradients
        :param max_norm: maximum norm allowable
        :return: total L2 norm of all gradients divided by the number of
                 gradient tensors (0.0 for an empty list)
        """
        total_norm = 0
        counter = 0
        for g in grad:
            param_norm = g.data.norm(2)
            total_norm += param_norm.item() ** 2
            counter += 1
        # Nothing to clip; also avoids dividing by zero in the return value.
        if counter == 0:
            return 0.0
        total_norm = total_norm ** (1. / 2)

        clip_coef = max_norm / (total_norm + 1e-6)
        if clip_coef < 1:
            for g in grad:
                g.data.mul_(clip_coef)

        return total_norm/counter

    def all_loss(self, img_feature,img_labels,fast_weight=None):
        """
        Compute generator, discriminator and classifier losses for one batch.

        :param img_feature: real feature batch; its first dimension is the batch size
        :param img_labels: class indices used to look up rows of ``attributes``
        :param fast_weight: optional [generator, discriminator, classifier]
                            weight lists forwarded to the three networks
        :return: (g_loss, d_loss, c_loss, cls_out) where cls_out holds the
                 log-softmax classifier outputs for the generated features
        """
        batch_size = img_feature.shape[0]
        FloatTensor=self.FloatTensor
        LongTensor=self.LongTensor
        img_labels=img_labels.type(LongTensor)

        # Targets for the least-squares adversarial loss.
        valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
        fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(img_feature.type(FloatTensor))
        # `attributes` is a CPU tensor; index it with CPU indices because recent
        # PyTorch releases reject indexing a CPU tensor with a CUDA index tensor.
        gen_attri = Variable(attributes[img_labels.cpu()].type(FloatTensor))

        z = Variable(FloatTensor(np.random.normal(0,args.sigma_tr, (batch_size, self.noise_dim))))

        # Generate a batch of images
        gen_imgs = generator(z, gen_attri,fast_weight)
        validity = discriminator(gen_imgs, gen_attri,fast_weight)
        g_loss = adversarial_loss(validity, valid)
        #print('g_loss: '+str(g_loss))


        validity_real = discriminator(real_imgs, gen_attri,fast_weight)
        d_real_loss = adversarial_loss(validity_real, valid)
        #print('d_real_loss: '+str(d_real_loss))

        # Loss for fake images (detached so it does not reach the generator)
        validity_fake = discriminator(gen_imgs.detach(), gen_attri,fast_weight)
        d_fake_loss = adversarial_loss(validity_fake, fake)
        #print('d_fake_loss: '+str(d_fake_loss))


        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2

        # Classification loss on the generated features.
        cls_out=classifier(gen_imgs,img_labels,fast_weight)
        cls_out=F.log_softmax(cls_out, dim=1)
        c_loss = F.nll_loss(cls_out, img_labels)


        return g_loss, d_loss, c_loss, cls_out


    def forward(self, x_spt, y_spt, x_qry, y_qry):
        """
        Run one meta-training step in two phases: first generator+classifier
        (together with the discriminator), then the discriminator alone.

        :param x_spt:   [b, setsz, d]
        :param y_spt:   [b, setsz]
        :param x_qry:   [b, querysz, d]
        :param y_qry:   [b, querysz]
        :return: (accs, loss_gen, loss_dis, loss_cla). accs has
                 update_step + 1 entries of query accuracy from the second
                 phase; loss_gen and loss_cla come from the first phase and
                 loss_dis from the second.

        NOTE(review): the losses that are back-propagated come from the last
        iteration of the inner ``for k`` loop. With update_step < 2 that loop
        does not run and the recorded losses come from a ``no_grad`` block.
        """
        task_num, setsz,d = x_spt.size()
        querysz = x_qry.size(1)

        # Slots 0..update_step hold per-step query statistics summed over tasks;
        # each task's final loss is additionally appended at the end of the list.
        losses_gen = [0 for _ in range(self.update_step + 1)]  # losses_q[i], i is tasks idx
        losses_dis = [0 for _ in range(self.update_step + 1)]
        losses_cla = [0 for _ in range(self.update_step + 1)]
        corrects = [0 for _ in range(self.update_step + 1)]


        for i in range(task_num):

            # 1. run the i-th task and compute loss for k=0
            g_loss,d_loss,c_loss,output=self.all_loss(x_spt[i],y_spt[i])

            joint_gen_loss=g_loss+self.cls_weight*c_loss

            gen_grad = torch.autograd.grad(joint_gen_loss, self.gen_para)
            fast_weights_G = list(map(lambda g: g[1] - self.update_lr * g[0], zip(gen_grad, self.gen_para)))

            disc_grad = torch.autograd.grad(d_loss, discriminator.parameters())
            fast_weights_D = list(map(lambda d: d[1] - self.update_lr * d[0], zip(disc_grad, discriminator.parameters())))

#             clas_grad = torch.autograd.grad(c_loss, self.discriminator.parameters())
#             fast_weights_C = list(map(lambda c: c[1] - self.update_lr * c[0], zip(clas_grad, self.classifier.parameters())))
            # Order expected by the networks: [generator, discriminator, classifier].
            fast_weight=[fast_weights_G[0:self.glen],fast_weights_D,fast_weights_G[self.glen:]]


            # this is the loss and accuracy before first update
            with torch.no_grad():
                # [setsz, nway]
                g_loss,d_loss,c_loss,output=self.all_loss(x_qry[i],y_qry[i])
                pred = output.max(1, keepdim=True)[1]
                correct = pred.eq(y_qry[i].view_as(pred)).sum().item()
                losses_gen[0] += g_loss
                losses_dis[0] += d_loss
                losses_cla[0] += c_loss
                corrects[0] = corrects[0] + correct

            # this is the loss and accuracy after the first update
            with torch.no_grad():
                # [setsz, nway]
                g_loss,d_loss,c_loss,output=self.all_loss(x_qry[i],y_qry[i],fast_weight)
                pred = output.max(1, keepdim=True)[1]
                correct = pred.eq(y_qry[i].view_as(pred)).sum().item()
                # how to initialise with the fast weight
                losses_gen[1] += g_loss
                losses_dis[1] += d_loss
                losses_cla[1] += c_loss
                corrects[1] = corrects[1] + correct

            for k in range(1, self.update_step):
                # 1. run the i-th task and compute loss for k=1~K-1
                g_loss,d_loss,c_loss,output=self.all_loss(x_spt[i],y_spt[i],fast_weight)
                joint_gen_loss=g_loss+self.cls_weight*c_loss

                gen_grad = torch.autograd.grad(joint_gen_loss, self.gen_para)
                fast_weights_G = list(map(lambda g: g[1] - self.update_lr * g[0], zip(gen_grad, self.gen_para)))

                disc_grad = torch.autograd.grad(d_loss, discriminator.parameters())
                fast_weights_D = list(map(lambda d: d[1] - self.update_lr * d[0], zip(disc_grad, discriminator.parameters())))

                fast_weight=[fast_weights_G[0:self.glen],fast_weights_D,fast_weights_G[self.glen:]]
                # 2. compute grad on theta_pi

                # 3. theta_pi = theta_pi - train_lr * grad
                g_loss,d_loss,c_loss,output=self.all_loss(x_qry[i],y_qry[i],fast_weight)
                pred = output.max(1, keepdim=True)[1]
                correct = pred.eq(y_qry[i].view_as(pred)).sum().item()
                # loss_q will be overwritten and just keep the loss_q on last update step.
                losses_gen[k+1] += g_loss
                losses_dis[k+1] += d_loss
                losses_cla[k+1] += c_loss
                corrects[k+1] = corrects[k+1] + correct

                ################################## Unseen class test #####################################

            # 4. record last step's loss for task i
            losses_gen.append(g_loss)
            losses_dis.append(d_loss)
            losses_cla.append(c_loss)


        # end of all tasks
        # Average the per-task final query losses (the last task_num list entries).
        loss_gen=0
        for itr in range(task_num):
            loss_gen+=losses_gen[-task_num:][itr]
        loss_gen=loss_gen/task_num

        loss_dis=0
        for itr in range(task_num):
            loss_dis+=losses_dis[-task_num:][itr]
        loss_dis=loss_dis/task_num

        loss_cla=0
        for itr in range(task_num):
            loss_cla+=losses_cla[-task_num:][itr]
        loss_cla=loss_cla/task_num

        joint_Gloss=loss_gen+self.cls_weight*loss_cla

        # optimize theta parameters
        self.disc_optim.zero_grad()
        self.gen_optim.zero_grad()

        joint_Gloss.backward()
        loss_dis.backward()

        self.disc_optim.step()
        self.gen_optim.step()


        ######################################################################################
        # Second phase: the same episode is replayed, but only discriminator fast
        # weights are recomputed and only the discriminator optimiser is stepped.
        # NOTE(review): `fast_weights_G` below is the value left over from the
        # last task of the first phase; it is not recomputed here. Confirm that
        # reusing it (and thereby writing it back into the generator/classifier
        # parameters) is intended.
        losses_gen = [0 for _ in range(self.update_step + 1)]  # losses_q[i], i is tasks idx
        losses_dis = [0 for _ in range(self.update_step + 1)]
        losses_cla = [0 for _ in range(self.update_step + 1)]
        corrects = [0 for _ in range(self.update_step + 1)]


        for i in range(task_num):

            # 1. run the i-th task and compute loss for k=0
            g_loss,d_loss,c_loss,output=self.all_loss(x_spt[i],y_spt[i])

            disc_grad = torch.autograd.grad(d_loss, discriminator.parameters())
            fast_weights_D = list(map(lambda d: d[1] - self.update_lr * d[0], zip(disc_grad, discriminator.parameters())))

#             clas_grad = torch.autograd.grad(c_loss, self.discriminator.parameters())
#             fast_weights_C = list(map(lambda c: c[1] - self.update_lr * c[0], zip(clas_grad, self.classifier.parameters())))
            fast_weight=[fast_weights_G[0:self.glen],fast_weights_D,fast_weights_G[self.glen:]]


            # this is the loss and accuracy before first update
            with torch.no_grad():
                # [setsz, nway]
                g_loss,d_loss,c_loss,output=self.all_loss(x_qry[i],y_qry[i])
                pred = output.max(1, keepdim=True)[1]
                correct = pred.eq(y_qry[i].view_as(pred)).sum().item()
                losses_gen[0] += g_loss
                losses_dis[0] += d_loss
                losses_cla[0] += c_loss
                corrects[0] = corrects[0] + correct

            # this is the loss and accuracy after the first update
            with torch.no_grad():
                # [setsz, nway]
                g_loss,d_loss,c_loss,output=self.all_loss(x_qry[i],y_qry[i],fast_weight)
                pred = output.max(1, keepdim=True)[1]
                correct = pred.eq(y_qry[i].view_as(pred)).sum().item()
                # how to initialise with the fast weight
                losses_gen[1] += g_loss
                losses_dis[1] += d_loss
                losses_cla[1] += c_loss
                corrects[1] = corrects[1] + correct

            for k in range(1, self.update_step):
                # 1. run the i-th task and compute loss for k=1~K-1
                g_loss,d_loss,c_loss,output=self.all_loss(x_spt[i],y_spt[i],fast_weight)

                disc_grad = torch.autograd.grad(d_loss, discriminator.parameters())
                fast_weights_D = list(map(lambda d: d[1] - self.update_lr * d[0], zip(disc_grad, discriminator.parameters())))

                fast_weight=[fast_weights_G[0:self.glen],fast_weights_D,fast_weights_G[self.glen:]]
                # 2. compute grad on theta_pi

                # 3. theta_pi = theta_pi - train_lr * grad
                g_loss,d_loss,c_loss,output=self.all_loss(x_qry[i],y_qry[i],fast_weight)
                pred = output.max(1, keepdim=True)[1]
                correct = pred.eq(y_qry[i].view_as(pred)).sum().item()
                # loss_q will be overwritten and just keep the loss_q on last update step.
                losses_gen[k+1] += g_loss
                losses_dis[k+1] += d_loss
                losses_cla[k+1] += c_loss
                corrects[k+1] = corrects[k+1] + correct

                ################################## Unseen class test #####################################

            # 4. record last step's loss for task i
            losses_gen.append(g_loss)
            losses_dis.append(d_loss)
            losses_cla.append(c_loss)


        # end of all tasks
        # sum over all losses on query set across all tasks

        loss_dis=0
        for itr in range(task_num):
            loss_dis+=losses_dis[-task_num:][itr]
        loss_dis=loss_dis/task_num

        # optimize theta parameters
        self.disc_optim.zero_grad()
        loss_dis.backward()

        self.disc_optim.step()

        ######################################################################################

        # Per-step query accuracy over all tasks (second-phase counts).
        accs = np.array(corrects) / (querysz * task_num)
        #all_loss=[losses_gen,losses_dis,losses_cla]

        return accs, loss_gen,loss_dis, loss_cla


In [None]:
import  torchvision.transforms as transforms
from    PIL import Image
import  os.path
import  numpy as np


class zsl_NShot:
    """Episode sampler for N-way / K-shot training on pre-extracted features.

    Each dataset is a ``[features, labels]`` pair. A cached "batch" holds
    ``batchsz`` tasks; every task has a support set drawn from ``n_way``
    classes and a query set drawn from ``n_way`` *different* classes.
    """

    def __init__(self, trainData,testData, batchsz, n_way, k_shot, k_query):
        """
        :param trainData: [features, labels] used for mode "train"
        :param testData: [features, labels] used for mode "test"
        :param batchsz: task num
        :param n_way: classes per support set and per query set
        :param k_shot: support samples per class
        :param k_query: query samples per class
        """
        self.x_train, self.x_test = trainData, testData
        # self.normalization()
        self.batchsz = batchsz # number of task
        self.n_way = n_way  # n way
        self.k_shot = k_shot  # k shot
        self.k_query = k_query  # k query

        # save pointer of current read batch in total cache
        self.indexes = {"train": 0, "test": 0}
        self.datasets = {"train": self.x_train, "test": self.x_test}  # original data cached
        #print("DB: train", self.x_train[0].shape, "test", self.x_test[0].shape)

        # Only the training cache is built eagerly; next() builds the test
        # cache the first time it is requested.
        self.datasets_cache = {"train": self.load_data_cache(self.datasets["train"])} # current epoch data cached
                              # "test": self.load_data_cache(self.datasets["test"])}

    def normalization(self):
        """
        Normalizes our data, to have a mean of 0 and sdt of 1

        NOTE(review): not called anywhere in this class (the call in __init__
        is commented out). As this notebook constructs the object, x_train and
        x_test are [features, labels] lists rather than plain arrays - confirm
        the expected input before enabling it.
        """
        self.mean = np.mean(self.x_train)
        self.std = np.std(self.x_train)
        self.max = np.max(self.x_train)
        self.min = np.min(self.x_train)
        # print("before norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
        self.x_train = (self.x_train - self.mean) / self.std
        self.x_test = (self.x_test - self.mean) / self.std

        self.mean = np.mean(self.x_train)
        self.std = np.std(self.x_train)
        self.max = np.max(self.x_train)
        self.min = np.min(self.x_train)

    # print("after norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)

    def load_data_cache(self, data_pack):
        """
        Collects several batches data for N-shot learning
        :param data_pack: [features, labels]; features is [N, d]
        :return: A one-element list holding [support_set_x, support_set_y,
                 target_x, target_y] with shapes [b, setsz, d], [b, setsz],
                 [b, querysz, d] and [b, querysz]
        """
        #  take 5 way 1 shot as example: 5 * 1
        setsz = self.k_shot * self.n_way
        querysz = self.k_query * self.n_way
        data_cache = []

        unq_labels=np.unique(data_pack[1])
        labels=np.array(data_pack[1])

        Data=data_pack[0]
        # Feature width comes from the data instead of a hard-coded 2048.
        feat_dim = Data.shape[-1]

        x_spts, y_spts, x_qrys, y_qrys = [], [], [], []
        for i in range(self.batchsz):  # one batch means one set

            x_spt, y_spt, x_qry, y_qry = [], [], [], []
            # Draw 2*n_way distinct classes: the first half feeds the support
            # set, the second half the query set.
            selected_cls = np.random.choice(unq_labels, 2*self.n_way, False)

            for j, cur_class in enumerate(selected_cls[:self.n_way]):
                indx=np.where(labels==cur_class)[0]
                selected_img = np.random.choice(indx, self.k_shot, False)
                x_spt.append(Data[selected_img])
                y_spt.append(labels[selected_img])

            for j, cur_class in enumerate(selected_cls[self.n_way:]):
                indx=np.where(labels==cur_class)[0]
                selected_img = np.random.choice(indx, self.k_query, False)
                x_qry.append(Data[selected_img])
                y_qry.append(labels[selected_img])

            # shuffle inside a batch (same permutation for features and labels)
            perm = np.random.permutation(setsz)
            x_spt = np.array(x_spt).reshape(setsz, feat_dim)[perm]
            y_spt = np.array(y_spt).reshape(setsz)[perm]
            perm = np.random.permutation(querysz)
            x_qry = np.array(x_qry).reshape(querysz, feat_dim)[perm]
            y_qry = np.array(y_qry).reshape(querysz)[perm]

            # append [N,d] => [b, N,d]
            x_spts.append(x_spt)
            y_spts.append(y_spt)
            x_qrys.append(x_qry)
            y_qrys.append(y_qry)


        # [b, N,d]; np.int was removed from NumPy, so use an explicit width.
        x_spts = np.array(x_spts).astype(np.float32).reshape(self.batchsz, setsz, feat_dim)
        y_spts = np.array(y_spts).astype(np.int64).reshape(self.batchsz, setsz)
        # [b, N,d]
        x_qrys = np.array(x_qrys).astype(np.float32).reshape(self.batchsz, querysz, feat_dim)
        y_qrys = np.array(y_qrys).astype(np.int64).reshape(self.batchsz, querysz)

        data_cache.append([x_spts, y_spts, x_qrys, y_qrys])

        return data_cache

    def next(self, mode='train'):
        """
        Gets next batch from the dataset with name.
        :param mode: The name of the splitting (one of "train", "test")
        :return: [x_spts, y_spts, x_qrys, y_qrys] for one batch of tasks
        """
        # (Re)build the cache when it was never created for this mode or when
        # every cached batch has been handed out.
        if mode not in self.datasets_cache or self.indexes[mode] >= len(self.datasets_cache[mode]):
            self.indexes[mode] = 0
            self.datasets_cache[mode] = self.load_data_cache(self.datasets[mode])

        next_batch = self.datasets_cache[mode][self.indexes[mode]]
        self.indexes[mode] += 1

        return next_batch

In [None]:
from sklearn.preprocessing import normalize
from sklearn.metrics import accuracy_score
from sklearn import svm
from torch.utils.data import DataLoader,TensorDataset

def accuracy_zsl(gen_imgs,img_labels,testData,testLabels):
    """Train a fresh Classifier2 on synthetic features and track GZSL accuracy.

    The classifier is trained for 8001 mini-batch steps on ``gen_imgs`` /
    ``img_labels``. Every 10 steps it is scored on ``testData`` (reported as
    unseen accuracy) and on the module-level ``test_seen_features`` (seen
    accuracy), and the harmonic mean of the two is tracked.

    :param gen_imgs: numpy array of synthetic training features
    :param img_labels: numpy array of labels for ``gen_imgs``
    :param testData: numpy array of test features
    :param testLabels: numpy array of labels for ``testData``
    :return: (H_mean of the last evaluation, best H_mean seen)

    Requires CUDA: tensors and the classifier are placed on the GPU
    unconditionally.
    """
    FloatTensor = torch.cuda.FloatTensor
    LongTensor = torch.cuda.LongTensor
    gen_imgs=(torch.from_numpy(gen_imgs)).type(FloatTensor)
    img_labels=(torch.from_numpy(img_labels)).type(LongTensor)
    testData=(torch.from_numpy(testData)).type(FloatTensor)
    testLabels=(torch.from_numpy(testLabels)).type(LongTensor)

    ########################################################
    # For the training of the classifier, we are using synthesize sample for both seen and unseen class,
    # to overcome the bias towards the seen classes.
    # But at the same time we can use seen class training data also with the synthesize seen and unseen
    # class data. If you want to use this uncomment the below 4 lines. It gives the GZSL H-mean ~60.
    # While in the paper we use former case and have H-mean ~56.

#     train_features=(torch.from_numpy(x)).type(FloatTensor)
#     train_label1=(torch.from_numpy(train_label).squeeze(1)).type(LongTensor)
#     gen_imgs=torch.cat([gen_imgs,train_features])
#     img_labels=torch.cat([img_labels,train_label1])

    train_data = TensorDataset(gen_imgs,img_labels)
    train_loader = DataLoader(train_data,batch_size=32,shuffle=True)


    test_seen_features1=test_seen_features.type(FloatTensor)
    test_seen_label1=test_seen_label.type(LongTensor)

    classifier2=Classifier2()
    classifier2.cuda()

    cls2_para=list(p for p in classifier2.parameters() if p.requires_grad)
    cls2_optim = optim.Adam(cls2_para, lr=args.meta_lrD,betas=(0.9, 0.999),weight_decay=1e-3)
    cls2_schedular=StepLR(cls2_optim,step_size=100,gamma=0.9)

    besthmean=0
    # Added to both accuracies so the harmonic mean never divides by zero.
    # NOTE(review): this also shifts every reported accuracy up by 0.001.
    epsln=0.001
    US=[]
    for it in range(8001):
        # optimize theta parameters
        cls2_optim.zero_grad()
        # A fresh shuffled iterator per step yields one random mini-batch.
        # next(iter(...)) replaces the Python-2-style `.next()` method, which
        # current DataLoader iterators no longer provide.
        batch_imgs,batch_labels = next(iter(train_loader))

        cls_out=classifier2(batch_imgs,batch_labels)
        c_loss2=auxiliary_loss(cls_out, batch_labels)

        c_loss2.backward()
        cls2_optim.step()
        # The scheduler is stepped after the optimiser, the order PyTorch
        # expects; stepping it first skips the first value of the schedule.
        cls2_schedular.step()

        if it%10==0:
            # NOTE(review): classifier2 is never switched to eval(), so its
            # BatchNorm layer normalises with statistics of the evaluated
            # batch - confirm this is intended.
            with torch.no_grad():
                cls_out=classifier2(testData,testLabels)
                pred = cls_out.max(1, keepdim=True)[1]
                correct = pred.eq(testLabels.view_as(pred)).sum().item()
                ACC_unseen=correct/testLabels.shape[0] +epsln

                cls_out=classifier2(test_seen_features1,test_seen_label1)
                pred = cls_out.max(1, keepdim=True)[1]
                correct = pred.eq(test_seen_label1.view_as(pred)).sum().item()
                ACC_seen=correct/test_seen_label1.shape[0]+epsln
                H_mean=(2*ACC_unseen*ACC_seen)/(ACC_unseen+ACC_seen)

            if H_mean>besthmean:
                besthmean=H_mean
                US=[ACC_unseen,ACC_seen]


        if (it%2000)==0:
            print('curr_Hmean: '+str(round(H_mean,3))+ '  Best-Hmean: '+str(round(besthmean,3)) + '  US  '+str(US))

    del classifier2
    return H_mean,besthmean
            


def test_zsl(genunseen_input,test_labels_repeat,test_features,testLabels):
    """Synthesise features with the generator and score a classifier trained on them.

    :param genunseen_input: tensor whose first ``args.attri_dim`` columns are
        attribute vectors and whose remaining columns are noise
    :param test_labels_repeat: labels belonging to the rows of ``genunseen_input``
    :param test_features: real test features passed on to ``accuracy_zsl``
    :param testLabels: labels of ``test_features``
    :return: the (hmean, besthmean) pair returned by ``accuracy_zsl``
    """
    attri_part = genunseen_input[:, :args.attri_dim]
    noise_part = genunseen_input[:, args.attri_dim:]

    # No gradients are needed for sample generation.
    with torch.no_grad():
        synthetic = generator(noise_part, attri_part)

    pseudoTrainData = np.array(synthetic.cpu())
    return accuracy_zsl(pseudoTrainData, test_labels_repeat, test_features, testLabels)

def test():
    """Build generator inputs for every unseen and seen class and evaluate them.

    For each class in ``test_label`` (unseen) and in ``unq_train_labels``
    (seen), 100 rows of [class attribute | Gaussian noise] are assembled. The
    unseen rows come first, followed by the seen rows, and the combined batch
    is handed to ``test_zsl`` together with the real unseen test set.

    :return: the (hmean, besthmean) pair returned by ``test_zsl``
    """
    samples_per_class = 100
    cuda_float = torch.cuda.FloatTensor

    def build_inputs(class_ids):
        # One block of `samples_per_class` identical attribute rows per class.
        attri_rows = np.concatenate(
            [
                np.repeat(np.reshape(attribute[cls], [1, args.attri_dim]), samples_per_class, axis=0)
                for cls in class_ids
            ],
            0,
        )
        labels = np.concatenate(
            [np.repeat(cls, samples_per_class, axis=0) for cls in class_ids],
            0,
        )
        noise = np.random.normal(
            0, args.sigma_ts, (samples_per_class * class_ids.shape[0], args.noise_dim)
        )
        stacked = np.concatenate((attri_rows, noise), 1)
        return torch.from_numpy(stacked).type(cuda_float), labels

    # Unseen classes first, then seen classes (the noise is drawn in that order).
    genunseen_input, test_labels_repeat = build_inputs(np.unique(test_label))
    genseen_input, train_labels_repeat = build_inputs(unq_train_labels)

    seen_unseen_input = torch.cat([genunseen_input, genseen_input])
    seen_unseen_labels = np.concatenate([test_labels_repeat, train_labels_repeat])

    return test_zsl(seen_unseen_input, seen_unseen_labels, test_features, test_label)

In [None]:
db = zsl_NShot(train_data, test_data, batchsz=10, n_way=args.n_way, k_shot=args.k_spt, k_query=args.k_qry)

# Pull one batch of episodes up front, only to report its tensor shapes.
x_spt, y_spt, x_qry, y_qry = (torch.from_numpy(part).cuda() for part in db.next('train'))
print('train :  ' + str(x_spt.shape))
print('train_labels :  ' + str(y_spt.shape))
print('test :  ' + str(x_qry.shape))
print('test_labels :  ' + str(y_qry.shape))

best = 0
bestsoft = 0
maml = Meta(args).cuda()
for itr in range(10000):
    x_spt, y_spt, x_qry, y_qry = (torch.from_numpy(part).cuda() for part in db.next('train'))

    accs, loss_gen, loss_dis, loss_cla = maml(x_spt, y_spt, x_qry, y_qry)
    accs = np.mean(accs)

    # Every 100 meta-iterations: run the GZSL evaluation and report progress.
    if itr % 100 == 0:
        hmean, besthmean = test()
#         if soft_accuracy>=bestsoft:
#             bestsoft=soft_accuracy
#         print('accuracy: ' +str(soft_accuracy)+'  Best: '+str(bestsoft))

        # Ties go to the newer value.
        best = max(besthmean, best)

#             torch.save(generator,'./generator.pth')
#             torch.save(discriminator,'./discriminator.pth')
#             torch.save(classifier,'./classifier.pth')
#             torch.save(classifier2,'./classifier2.pth')

        loss_report = ' : '.join(
            str(round(loss.item(), 3)) for loss in (loss_gen, loss_dis, loss_cla)
        )
        print(
            str(itr) + '-Accuracy: ' + str(round(accs, 3))
            + '::::Loss::::' + loss_report
            + ' Accuracy: ' + str(round(hmean, 3))
            + ' Best : ' + str(round(best, 3))
        )