# Importing Libraries and Supporting python files

In [3]:
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision
import copy
import os
import numpy as np
from options import Parameters
from collections import OrderedDict


In [4]:
# Experiment hyper-parameters (ways, shots, learning rate, patch size, ...) live in options.Parameters.
args = Parameters()


# Embedding Network
#### The embedding sub-network consists of a deep convolutional network for feature extraction and a non-parametric one-shot classifier
# CNet
1. Given an input image $I$, we use a residual network to produce its feature representation $f_{\theta_{emb}}(I)$.
2. A fully-connected layer on top of the embedding sub-network, trained with a cross-entropy loss (CELoss), outputs $|C_{base}|$ scores.

In [5]:
class Identity(nn.Module):
    """Pass-through layer: forward(x) hands back the very same tensor it received."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        output = x
        return output

In [16]:
class resnet(nn.Module):
    '''
    ResNet-18 embedding backbone initialised from a local checkpoint.

    The checkpoint's parameter names do not match torchvision's, so its tensors are
    assigned to torchvision's resnet18 (with a 64-way fc head) purely by position.
    The fc weight is kept in `self.fc_layer_weight` for reuse by a separate classifier
    head; the fc layer itself is not part of this module, so forward() returns the
    flattened average-pooled features.

    Args:
        checkpoint_path: path of the state-dict file to load (default 'softRandom.t7').

    Raises:
        ValueError: if the checkpoint does not contain exactly as many tensors as
            torchvision's resnet18 state_dict, so the positional mapping would be wrong.
    '''
    def __init__(self, checkpoint_path='softRandom.t7'):
        super(resnet, self).__init__()

        # Replace the last layer so its shape matches the checkpoint's 64-way head.
        resnet18 = torchvision.models.resnet18(pretrained=False)
        num_features = resnet18.fc.in_features
        resnet18.fc = nn.Linear(num_features, 64)

        # NOTE: torch.load unpickles the file - only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

        target_names = list(resnet18.state_dict().keys())
        if len(checkpoint) != len(target_names):
            raise ValueError(
                'checkpoint {!r} has {} tensors but resnet18 expects {}; '
                'cannot map them by position'.format(checkpoint_path, len(checkpoint), len(target_names)))
        # Key names differ between the checkpoint and torchvision, so pair entries by order.
        resnet18.load_state_dict(OrderedDict(zip(target_names, checkpoint.values())))

        # Keep the loaded fc weight before the fc layer is dropped from this module.
        self.fc_layer_weight = resnet18.state_dict()['fc.weight']

        # Reuse the loaded layers directly; they already hold the checkpoint weights.
        # Attribute names are part of the state_dict keys of any model embedding this one.
        self.conv1 = resnet18.conv1
        self.bn1 = resnet18.bn1
        self.relu = resnet18.relu
        self.maxpool = resnet18.maxpool

        self.layer1 = resnet18.layer1
        self.layer2 = resnet18.layer2
        self.layer3 = resnet18.layer3
        self.layer4 = resnet18.layer4
        self.avgpool = resnet18.avgpool

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        layer1 = self.layer1(x)       # (, 64L, 56L, 56L)
        layer2 = self.layer2(layer1)  # (, 128L, 28L, 28L)
        layer3 = self.layer3(layer2)  # (, 256L, 14L, 14L)
        layer4 = self.layer4(layer3)  # (,512,7,7)
        x = self.avgpool(layer4)      # (,512,1,1)
        x = x.view(x.size(0), -1)
        return x


# Deformation Network

In [17]:
class Flatten(nn.Module):
    """Collapse every dimension after the first: [B, d1, d2, ...] -> [B, d1*d2*...] (a view, no copy)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        batch = x.shape[0]
        return x.view(batch, -1)

In [18]:
class DeformationNetwork(nn.Module):
    '''
    Deformation sub-network: encodes a (probe, gallery) image pair into a feature vector.

    The probe and gallery RGB images are stacked along the channel axis (6 channels) and
    passed through five conv3x3 -> BatchNorm -> ReLU -> 2x2 max-pool stages, then flattened.
    Each stage halves H and W, so a 224x224 input ends at 16 x 7 x 7 = 784 features.

    Design notes: a two-branch variant performed about the same as a single branch, and a
    deeper attention network brought no benefit, so one small branch is used.
    '''
    def __init__(self):
        super(DeformationNetwork, self).__init__()

        def conv(inp, out):
            # padding=1 keeps H and W for the 3x3 conv; the max-pool halves them.
            return nn.Sequential(nn.Conv2d(inp, out, 3, padding=1),
                                 nn.BatchNorm2d(out),
                                 nn.ReLU(),
                                 nn.MaxPool2d(2)
                                 )

        # Module indices (encoder.0 ... encoder.5) appear in saved state_dict keys; keep this layout.
        self.encoder = nn.Sequential(conv(6, 32),    # [B, 32, 112, 112]
                                     conv(32, 64),   # [B, 64, 56, 56]
                                     conv(64, 64),   # [B, 64, 28, 28]
                                     conv(64, 32),   # [B, 32, 14, 14]
                                     conv(32, 16),   # [B, 16, 7, 7]
                                     nn.Flatten())   # [B, 784]

    def forward(self, x):
        """
        inputs:  Batchsize*6*224*224 (probe and gallery images concatenated on dim 1)
        outputs: Batchsize*784
        """
        return self.encoder(x)

In [19]:
# Aliased so it is not clobbered by the `summary` dict that train_model returns later.
from torchsummary import summary as model_summary

# Layer-by-layer output shapes for one 6x224x224 input (probe + gallery stacked on channels).
deform = DeformationNetwork()
# The trailing ';' stops Jupyter from also echoing the returned object, in case the
# installed torchsummary both prints and returns the summary (the output appears twice otherwise).
model_summary(deform, (6, 224, 224));

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 784]                 --
|    └─Sequential: 2-1                   [-1, 32, 112, 112]        --
|    |    └─Conv2d: 3-1                  [-1, 32, 224, 224]        1,760
|    |    └─BatchNorm2d: 3-2             [-1, 32, 224, 224]        64
|    |    └─ReLU: 3-3                    [-1, 32, 224, 224]        --
|    |    └─MaxPool2d: 3-4               [-1, 32, 112, 112]        --
|    └─Sequential: 2-2                   [-1, 64, 56, 56]          --
|    |    └─Conv2d: 3-5                  [-1, 64, 112, 112]        18,496
|    |    └─BatchNorm2d: 3-6             [-1, 64, 112, 112]        128
|    |    └─ReLU: 3-7                    [-1, 64, 112, 112]        --
|    |    └─MaxPool2d: 3-8               [-1, 64, 56, 56]          --
|    └─Sequential: 2-3                   [-1, 64, 28, 28]          --
|    |    └─Conv2d: 3-9                  [-1, 64, 56, 56]          36,928
|  

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 784]                 --
|    └─Sequential: 2-1                   [-1, 32, 112, 112]        --
|    |    └─Conv2d: 3-1                  [-1, 32, 224, 224]        1,760
|    |    └─BatchNorm2d: 3-2             [-1, 32, 224, 224]        64
|    |    └─ReLU: 3-3                    [-1, 32, 224, 224]        --
|    |    └─MaxPool2d: 3-4               [-1, 32, 112, 112]        --
|    └─Sequential: 2-2                   [-1, 64, 56, 56]          --
|    |    └─Conv2d: 3-5                  [-1, 64, 112, 112]        18,496
|    |    └─BatchNorm2d: 3-6             [-1, 64, 112, 112]        128
|    |    └─ReLU: 3-7                    [-1, 64, 112, 112]        --
|    |    └─MaxPool2d: 3-8               [-1, 64, 56, 56]          --
|    └─Sequential: 2-3                   [-1, 64, 28, 28]          --
|    |    └─Conv2d: 3-9                  [-1, 64, 56, 56]          36,928
|  

# IDEMENET

In [20]:
class IDeMeNet(nn.Module):
    '''
    Image Deformation Meta-Network: a deformation sub-network plus a ResNet-18 embedding.

    forward() dispatches on `mode`:
      'deform_embedding'   -> fuse probe and gallery images patch-wise and embed the result;
                              returns (syn_embedding, patch weights [B, P], deform features).
      'fully_connected'    -> classifier scores for `syn_embedding`.
      'feature_extraction' -> embedding of `probe`.
    '''
    def __init__(self, patch_size=3):
        super(IDeMeNet, self).__init__()

        self.deform = DeformationNetwork()

        self.embedding = resnet()

        # One mixing weight per cell of the patch_size x patch_size grid (default 3x3),
        # predicted from the 784-d deformation features.
        self.patch = nn.Sequential(nn.Linear(784, patch_size * patch_size))

        # Classifier head on the 512-d embedding, initialised with the fc weight that the
        # embedding network stored before its own fc layer was dropped.
        # NOTE(review): only the weight is reused; fc.bias keeps nn.Linear's random init.
        self.fc = nn.Linear(512, 64)
        self.fc.weight = torch.nn.parameter.Parameter(self.embedding.fc_layer_weight)

    def forward(self, probe, gallery=1, syn_embedding=None, fixSquare=1, oneSquare=1, mode=None):
        """
        Args:
            probe: image batch [B, C, H, W] (unused in 'fully_connected' mode).
            gallery: image batch with the same shape as `probe` ('deform_embedding' only).
            syn_embedding: embedding batch fed to the classifier ('fully_connected' only).
            fixSquare: per-patch masks, expected as [B, P, C, H, W] with P = number of
                patch weights; each patch weight is broadcast over its mask.
            oneSquare: all-ones tensor broadcastable to `probe`, giving the gallery's share.
            mode: 'deform_embedding', 'fully_connected' or 'feature_extraction'.

        Raises:
            ValueError: for any other `mode`.
        """
        if mode == 'deform_embedding':
            batch_size = probe.size(0)
            feature = self.deform(torch.cat((probe, gallery), 1))
            weight = self.patch(feature)  # [B, P]
            num_patches = weight.size(1)
            # [B, P] -> [B, P, 1, 1, 1] so each weight broadcasts over its patch mask,
            # then summing over P gives a per-pixel mixing map [B, C, H, W].
            patch_weight = weight.view(batch_size, num_patches, 1, 1, 1)
            patch_weight = torch.sum(patch_weight * fixSquare, dim=1)

            img_syn = patch_weight * probe + (oneSquare - patch_weight) * gallery
            syn_embedding = self.embedding(img_syn)
            return syn_embedding, weight, feature
        elif mode == 'fully_connected':
            return self.fc(syn_embedding)
        elif mode == 'feature_extraction':
            return self.embedding(probe)
        raise ValueError('Unknown IDeMeNet mode: {!r}'.format(mode))
        
# NOTE: rebinds the class name to the model instance; every later cell uses `IDeMeNet` as the model.
IDeMeNet = IDeMeNet()
    

In [None]:
# print("N.o GPU's using ",torch.cuda.device_count())
# IDeMeNet=nn.DataParallel(IDeMeNet)
# # IDeMeNet=IDeMeNet.cuda()


# Set the optimization parameters for training

In [None]:
# This notebook requires args.train_from_scratch to be set.
assert args.train_from_scratch == True

# Deformation sub-network and patch-weight head, both at the base learning rate.
optimizer_deform = torch.optim.Adam(
    [
        {'params': IDeMeNet.deform.parameters()},
        {'params': IDeMeNet.patch.parameters(), 'lr': args.LR},
    ],
    lr=args.LR,
    eps=1e-04,
)

# Embedding backbone at 0.1x and its fc classifier at 1x the base rate. Both groups set
# their own lr, so the 0.2x default below is only a fallback.
optimizer_classifer = torch.optim.Adam(
    [
        {'params': IDeMeNet.embedding.parameters(), 'lr': args.LR * 0.1},
        {'params': IDeMeNet.fc.parameters(), 'lr': args.LR},
    ],
    lr=args.LR * 0.2,
    eps=1e-05,
)


# The paper suggests using a learning-rate scheduler, which led to better performance

In [None]:
from torch.optim import lr_scheduler

# Each scheduler halves its optimizer's learning rates every 40 calls to .step().
deform_lr_scheduler, embedding_lr_scheduler = (
    lr_scheduler.StepLR(opt, step_size=40, gamma=0.5)
    for opt in (optimizer_deform, optimizer_classifer)
)



# Load the train, test and validation datasets

In [None]:
from data_loading import OneShot_Imagenet

def worker_init_fn(worker_id):
    """Give each DataLoader worker its own stream from NumPy's legacy global RNG.

    The new seed is the first word of the current MT19937 key plus `worker_id`, reduced
    modulo 2**32. np.random.seed only accepts seeds in [0, 2**32 - 1]. Without the
    modulo, a key near the top of that range overflows it, e.g. right after
    np.random.seed(2**32 - 1).
    """
    base_seed = int(np.random.get_state()[1][0])
    np.random.seed((base_seed + worker_id) % 2**32)

# Dataset root. Set the ONESHOT_DATA_DIR environment variable instead of editing this cell;
# the fallback is the original author's local Windows path.
DATA_DIR = os.environ.get('ONESHOT_DATA_DIR', r'C:\Users\vsankepa\Desktop\Untitled Folder 1')

# One dataset per split; the training split uses args.trainways, the others args.ways.
image_datasets = {
    split: OneShot_Imagenet(path=DATA_DIR, type_=split,
                            ways=args.trainways if split == 'train' else args.ways,
                            shots=args.shots, test_num=args.test_num,
                            epoch=args.epoch, gallery_img=args.gallery_img)
    for split in ['train', 'test', 'val']
}



In [None]:
# batch_size=1: presumably one episode per dataset item (train_model squeezes dim 0).
# Only the training split is shuffled. num_workers / worker_init_fn are deliberately
# not passed (they were disabled for GPU runs).
dataloaders = {
    split: torch.utils.data.DataLoader(image_datasets[split],
                                       batch_size=1,
                                       shuffle=(split == 'train'))
    for split in ('train', 'test', 'val')
}

In [None]:
#Load Gallery Images

In [None]:
# Gallery images and their embeddings (e.g. torch.Size([1920, 512])), read from the test split.
# Per the author, train and test expose the same gallery - not verified in this notebook.
_gallery_source = image_datasets['test']
gallery = _gallery_source.gallery
gallery_feature = _gallery_source.get_features(IDeMeNet, args.batch_size)

# Supporting Functions

In [None]:
def extract_feature(model, probe_images, requires_grad, batch_size=None):
    """Embed `probe_images` with `model` in mini-batches and concatenate the results.

    Args:
        model: called as model(batch, mode='feature_extraction').
        probe_images: tensor whose first dimension indexes images.
        requires_grad: whether each input batch is marked as requiring grad.
        batch_size: images per forward pass; defaults to args.batch_size.

    Returns:
        The per-batch features concatenated along dim 0.

    Raises:
        ValueError: if `probe_images` is empty.
    """
    if batch_size is None:
        batch_size = args.batch_size
    if len(probe_images) == 0:
        raise ValueError('extract_feature needs at least one image')
    # Collect per-batch outputs and concatenate once, instead of re-concatenating each step.
    features = [
        model(Variable(probe_images[start:start + batch_size], requires_grad=requires_grad),
              mode='feature_extraction')
        for start in range(0, len(probe_images), batch_size)
    ]
    return torch.cat(features, dim=0)

# Patch-wise linear combination of support (probe) and gallery images

## Creating a weight matrix 

In [None]:
######################################################################
# Weight matrix pre-process
#
# Split the 224x224 image into an args.patch_size x args.patch_size grid. Every cell is
# (224 // patch_size) pixels wide, except the last row/column, which absorbs the remainder.
# This reproduces the boundaries previously hard-coded for sizes 3 ([0,74,148,224]),
# 5 ([0,44,88,132,176,224]) and 7 ([0,32,...,192,224]), and works for any other size.
IMAGE_SIZE = 224
if not 1 <= args.patch_size <= IMAGE_SIZE:
    raise ValueError('patch_size must be between 1 and {}, got {}'.format(IMAGE_SIZE, args.patch_size))

patch_stride = IMAGE_SIZE // args.patch_size
point = [k * patch_stride for k in range(args.patch_size)] + [IMAGE_SIZE]

# Row (x) and column (y) bounds of each patch, in row-major patch order.
patch_xl = []
patch_xr = []
patch_yl = []
patch_yr = []
for i in range(args.patch_size):
    for j in range(args.patch_size):
        patch_xl.append(point[i])
        patch_xr.append(point[i + 1])
        patch_yl.append(point[j])
        patch_yr.append(point[j + 1])

# fixSquare[0, k] is a 0/1 mask selecting patch k on all 3 channels.
fixSquare = torch.zeros(1, args.patch_size * args.patch_size, 3, IMAGE_SIZE, IMAGE_SIZE, dtype=torch.float32)
for k in range(args.patch_size * args.patch_size):
    fixSquare[:, k, :, patch_xl[k]:patch_xr[k], patch_yl[k]:patch_yr[k]] = 1.00
# fixSquare = fixSquare.cuda()  # enable for GPU runs

oneSquare = torch.ones(1, 3, IMAGE_SIZE, IMAGE_SIZE, dtype=torch.float32)
# oneSquare = oneSquare.cuda()  # enable for GPU runs

In [None]:
def euclidean_dist(x, y):
    """Squared Euclidean distance between every row of `x` and every row of `y`.

    x: [N, D], y: [M, D]  ->  [N, M] with out[a, b] = sum_d (x[a, d] - y[b, d]) ** 2.
    No square root is taken.
    """
    rows, cols, dim = x.size(0), y.size(0), x.size(1)
    left = x.unsqueeze(1).expand(rows, cols, dim)   # [N, M, D]
    right = y.unsqueeze(0).expand(rows, cols, dim)  # [N, M, D]
    return (left - right).pow(2).sum(2)

# Because the support set is small, we augment it with gallery images chosen by their Euclidean distance to each support class

In [None]:
# Terminology: the "probe" images are the support images; they are called "support" before
# augmentation and "probe" after it.


def _nearest_gallery_indices(support_features, ways, num_chunks=10):
    """Indices of the args.chooseNum gallery items closest to each way's prototype.

    A way's prototype is the mean of its args.shots support features. Distances to the
    module-level `gallery_feature` are computed in about `num_chunks` slices to bound the
    memory used by euclidean_dist. Every gallery row is covered, including the remainder
    when the gallery size is not divisible by num_chunks.

    Returns a [ways, args.chooseNum] index tensor, nearest first.
    """
    # Only the indices are used, so no autograd graph is needed.
    with torch.no_grad():
        support_center = support_features.view(ways, args.shots, -1).mean(dim=1)  # [ways, D]
        chunk_size = max(1, len(gallery_feature) // num_chunks)
        dists = torch.cat(
            [euclidean_dist(chunk, support_center) for chunk in gallery_feature.split(chunk_size)],
            dim=0,
        )  # [num_gallery, ways]
        _, distance_ind = torch.topk(dists.transpose(0, 1), args.chooseNum, dim=1, largest=False)
    return distance_ind


def aug_images_basedOnDistance(support_images, support_features, support_group, support_class, ways):
    '''
    Expand every support image into a group of (1 + args.augnum) probe/gallery pairs.

    Slot 0 of a group pairs the support image with itself. Each of the args.augnum extra
    slots uses the support image as the probe, flipped along dim 2 (the width axis for
    [C, H, W] images) with probability 1/2. Its gallery partner is drawn uniformly from the
    args.chooseNum gallery images nearest to that image's way prototype. This distance-based
    choice is what makes the gallery and probe images differ; without it they would be
    identical.

    Returns:
        (probe_images, gallery_images, probe_group, probe_class), each with
        ways * args.shots * (1 + args.augnum) rows. Images are [N, 3, 224, 224];
        group/class labels are float32 [N, 1].
    '''
    group_size = 1 + args.augnum
    total = ways * args.shots * group_size
    probe_images = torch.empty(total, 3, 224, 224, dtype=torch.float32)
    gallery_images = torch.empty(total, 3, 224, 224, dtype=torch.float32)
    probe_group = torch.empty(total, 1, dtype=torch.float32)  # way index
    probe_class = torch.empty(total, 1, dtype=torch.float32)  # class label

    distance_ind = _nearest_gallery_indices(support_features, ways)

    for i in range(ways):
        for j in range(args.shots):
            src = i * args.shots + j   # position of this image in the support set
            base = src * group_size    # first slot of its group in the outputs

            probe_images[base] = support_images[src]
            gallery_images[base] = support_images[src]
            probe_group[base] = support_group[src]
            probe_class[base] = support_class[src]

            for k in range(args.augnum):
                slot = base + 1 + k  # slot 0 already holds the un-augmented pair
                # Augment from the support image itself (not another slot of probe_images).
                if np.random.randint(0, 2) == 0:
                    probe_images[slot] = torch.flip(support_images[src], [2])
                else:
                    probe_images[slot] = support_images[src]
                probe_group[slot] = support_group[src]
                probe_class[slot] = support_class[src]

                choose = np.random.randint(0, args.chooseNum)
                # train or test split does not matter: both expose the same gallery images
                gallery_images[slot] = image_datasets['test'].get_gallery_images(gallery[distance_ind[i][choose]])

    return probe_images, gallery_images, probe_group, probe_class
            
    

# Training Function

In [None]:
def _unwrap(model):
    """Return the module inside an nn.DataParallel wrapper, or `model` itself when unwrapped."""
    return model.module if isinstance(model, nn.DataParallel) else model


def _synthesize_in_batches(model, probe_images, gallery_images):
    """Run the 'deform_embedding' and 'fully_connected' heads over all probe/gallery pairs.

    Pairs are processed args.batch_size at a time (the last batch may be smaller). Returns
    the synthesized embeddings and the fc scores, each concatenated along dim 0.
    """
    num_patches = args.patch_size * args.patch_size
    syn_chunks, score_chunks = [], []
    for start in range(0, probe_images.size(0), args.batch_size):
        stop = start + args.batch_size
        probe_batch = probe_images[start:stop]
        n = probe_batch.size(0)
        syn_embedding, _patch_weight, _features = model(
            Variable(probe_batch, requires_grad=True),
            Variable(gallery_images[start:stop], requires_grad=True),
            fixSquare=Variable(fixSquare.expand(n, num_patches, 3, 224, 224), requires_grad=False),
            oneSquare=Variable(oneSquare.expand(n, 3, 224, 224), requires_grad=False),
            mode='deform_embedding',
        )
        scores = model(None, syn_embedding=syn_embedding, gallery=None, fixSquare=1, oneSquare=1,
                       mode='fully_connected')
        syn_chunks.append(syn_embedding)
        score_chunks.append(scores)
    return torch.cat(syn_chunks, dim=0), torch.cat(score_chunks, dim=0)


def _update_parameters(loss_val, entropy_loss, train_deform, train_classifier, classifier_params):
    """One optimisation step: the deformation branch on `loss_val`, the classifier on `entropy_loss`.

    Both gradients are computed before either optimizer steps. Stepping optimizer_deform
    first would modify, in place, weights that the retained graph still needs for the
    second backward pass, which autograd rejects.
    """
    if train_deform:
        optimizer_deform.zero_grad()
        # Keep the graph alive when the classifier gradient still has to be taken from it.
        loss_val.backward(retain_graph=train_classifier)
    if train_classifier:
        # autograd.grad leaves the deformation parameters' .grad untouched, so
        # optimizer_deform only ever sees the episodic-loss gradient.
        grads = torch.autograd.grad(entropy_loss, classifier_params, allow_unused=True)
        optimizer_classifer.zero_grad()
        for param, grad in zip(classifier_params, grads):
            param.grad = grad
        optimizer_classifer.step()
    if train_deform:
        optimizer_deform.step()


def train_model(model, num_epochs=25):
    """Episodic training of IDeMeNet; keeps the weights with the lowest test-phase episodic loss.

    Each epoch runs a 'train' phase over dataloaders['train'] and a 'test' phase over
    dataloaders['test']. For every episode the support set is augmented with deformed
    images (aug_images_basedOnDistance). Each way's prototype is the mean of its
    synthesized embeddings, and test images are scored by softmax over negative squared
    distances.

    Returns:
        (model, summary) where the best weights have been loaded into `model` and
        `summary['<epoch>-<phase>']` holds that phase's episodic and classifier metrics.
    """
    base_model = _unwrap(model)
    checkpoint_dir = os.path.join(os.getcwd(), 'saved_models')
    checkpoint_path = os.path.join(checkpoint_dir, str(args.tensorname) + '.t7')
    classifier_params = [p for group in optimizer_classifer.param_groups
                         for p in group['params'] if p.requires_grad]

    summary = dict()
    best_model_wts = copy.deepcopy(base_model.state_dict())
    best_loss = float('inf')
    emb_loss = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        print('Running Epoch -->', epoch)

        for phase in ['train', 'test']:
            is_train = phase == 'train'
            # BatchNorm statistics stay frozen in both phases (eval mode throughout).
            # NOTE(review): confirm this is intended during training.
            model.train(False)
            loss, acc = 0.0, 0.0
            classifier_loss, classifier_acc = 0.0, 0.0
            count = 0

            train_deform = is_train and args.fix_deform == 0
            # NOTE(review): the classifier is trained only when fix_emb is True, which reads
            # inverted relative to fix_deform - confirm against options.Parameters.
            train_classifier = is_train and epoch != 0 and args.fix_emb == True

            # No autograd graph is needed outside training.
            with torch.set_grad_enabled(is_train):
                for (support_images, support_group, support_class,
                     test_images, test_group, _test_class) in dataloaders[phase]:
                    count += 1
                    # Drop the DataLoader's leading batch dimension of size 1.
                    support_images = support_images.squeeze(0)  # e.g. [25, 3, 224, 224]
                    support_group = support_group.squeeze(0)    # e.g. [25, 1]
                    support_class = support_class.squeeze(0)    # e.g. [25, 1]
                    test_images = test_images.squeeze(0)
                    ways = int(support_images.size(0) // args.shots)

                    support_features = extract_feature(model, support_images, requires_grad=True)
                    test_features = extract_feature(model, test_images, requires_grad=True)
                    probe_images, gallery_images, _probe_group, probe_class = aug_images_basedOnDistance(
                        support_images, support_features, support_group, support_class, ways=ways)

                    all_syn_embedding, all_classes = _synthesize_in_batches(model, probe_images, gallery_images)

                    # One prototype per way: mean over its shots and their augmentations. [ways, D]
                    syn_embedding_mean = all_syn_embedding.view(ways, args.shots * (1 + args.augnum), -1).mean(1)
                    dists = euclidean_dist(test_features, syn_embedding_mean)  # [ways*test_num, ways]
                    log_prob = F.log_softmax(-dists, dim=1).view(ways, args.test_num, -1)

                    loss_val = -log_prob.gather(2, test_group.view(ways, args.test_num, 1)).view(-1).mean()
                    _, predicted_way = log_prob.max(2)
                    acc_val = torch.eq(predicted_way, test_group.view(ways, args.test_num)).float().mean()

                    loss += loss_val.item()
                    acc += acc_val.item()

                    if is_train:
                        probe_class = probe_class.view(probe_class.size(0))
                        entropy_loss = emb_loss(all_classes, probe_class.long())
                        _, class_pred = torch.max(all_classes, 1)  # (values, indices)
                        _update_parameters(loss_val, entropy_loss, train_deform, train_classifier,
                                           classifier_params)
                        classifier_loss += entropy_loss.item()
                        classifier_acc += torch.eq(class_pred, probe_class.view(-1)).float().mean().item()

            if is_train:
                # Schedulers step once per epoch, after that epoch's optimizer steps.
                deform_lr_scheduler.step()
                embedding_lr_scheduler.step()

            episodes = float(max(count, 1))
            epoch_loss = loss / episodes
            epoch_acc = acc / episodes
            epoch_classifier_loss = classifier_loss / episodes
            epoch_classifier_acc = classifier_acc / episodes

            summary[str(epoch) + '-' + phase] = {
                phase + 'loss': epoch_loss,
                phase + 'accuracy': epoch_acc,
                phase + '_classifier_loss': epoch_classifier_loss,
                phase + '_classifier_acc': epoch_classifier_acc,
            }

            print('{} Loss: {:.4f} Accuracy: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Keep a copy of the weights with the lowest test-phase episodic loss.
            if phase == 'test' and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(base_model.state_dict())

        if epoch % 2 == 0:
            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(best_model_wts, checkpoint_path)

    base_model.load_state_dict(best_model_wts)
    return model, summary






In [None]:
IDeMeNet, summary = train_model(IDeMeNet, num_epochs=120)

# Persist the per-epoch metrics returned by train_model.
with open('summary.txt', 'w') as f:
    print(summary, file=f)

# Save the bare module's weights so the checkpoint keys carry no 'module.' prefix. Checking
# for nn.DataParallel (not the GPU count) avoids an AttributeError on multi-GPU machines
# when the model was never wrapped.
_final_model = IDeMeNet.module if isinstance(IDeMeNet, nn.DataParallel) else IDeMeNet
_save_dir = os.path.join(os.getcwd(), 'saved_models')
os.makedirs(_save_dir, exist_ok=True)
torch.save(_final_model.state_dict(), os.path.join(_save_dir, str(args.tensorname) + '.t7'))

