In [1]:
from config import *
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, models, transforms
import torch.nn.functional as F
import os
from data.data_loader import *
from model import *
import torch.optim as optim
from  RetrievalTest import *
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

batchSize = 500

print("Loading the dataset")
Source_x, Source_y, Target_x = prepare_Data(data_dir, True)
Gallery_x, Query_x = prepare_Data(data_dir, False)
similarity = csr_matrix(scipy.io.loadmat("data/cifar10/cifar10_Similarity.mat")['label_Similarity']).todense()

print("Data loading finished")

gallery = torch.utils.data.DataLoader(Gallery_x,batch_size=1000)
query = torch.utils.data.DataLoader(Query_x,batch_size=1000)

source = torch.utils.data.DataLoader([(Source_x[i], Source_y[i]) for i in range(len(Source_x))],batch_size=batchSize, shuffle=True)
target = torch.utils.data.DataLoader(Target_x,batch_size=batchSize,shuffle=True)

CUDA available
Loading the dataset
--------------------Loading dataset--------------------
Source:  (5000, 32, 32, 3) (5000,)
Target:  (54000, 32, 32, 3)
--------------------Loading dataset--------------------
Gallery:  (54000, 32, 32, 3) (54000,)
Query:  (1000, 32, 32, 3) (1000,)
Data loading finished


In [2]:
vggModel=models.vgg16_bn(pretrained=True)
# vggModel.children

net1, net2 = [], []
for i in vggModel.children():
    for r, i in enumerate(i.children()):
        if r <=23: net1.append(i)
        elif r<= 32: net2.append(i)
    break
net1, net2  = nn.Sequential(*net1), nn.Sequential(*net2)

In [3]:
def intranorm(features,n_book):
    x = features.split(n_book,1)
    
    for b in range(n_book):
        if b==0: dummy = F.normalize(x[b],1)
        else:
            dummy = torch.cat((dummy,F.normalize(x[b],1)),1)
    return dummy

def shape_(inp):
     for i in inp:
            print(f"shape is: {i.shape}")
            
def Indexing_(Z,des,numSeg):
        Z = intranorm(Z,numSeg)
        x = torch.split(des,numSeg,1)
        y = torch.split(Z,numSeg,1)
        for i in range(numSeg):
            size_x = x[i].shape[0]
            size_y = y[i].shape[0]
            xx = x[i].unsqueeze(-1)

            dummy = torch.tensor(1)
            xx = xx.tile([1,1,size_y])
            yy = y[i].unsqueeze(-1)
            yy = yy.tile([1,1,size_x])
            yy = yy.permute(2,1,0)
            diff = torch.sum(torch.multiply(xx,yy),1)

            arg = torch.argmax(diff,1)
            max_idx = arg.reshape(-1,1)

            if i == 0: quant_idx = max_idx
            else: quant_idx = torch.cat((quant_idx,max_idx),1)

        return quant_idx

In [4]:
class softassignment_(nn.Module):
    def __init__(self,len_code, n_book, intn_word):
        super(softassignment_,self).__init__()
        self.Z = nn.Linear(len_code * n_book,intn_word, bias=False)

    def forward(self,features,n_book,alpha,):
        z_ = intranorm(self.Z.state_dict()['weight'], n_book).split(n_book,1)
        x_ = features.split(n_book,1)

        for i in range(n_book):
            size_z = z_[i].shape[0] # number of codewords
            size_x = x_[i].shape[0] # batch size
            xx = x_[i].unsqueeze(-1)
            xx = xx.repeat(1,1,size_z)
            zz = z_[i].unsqueeze(-1)
            zz = zz.repeat(1,1,size_x).T

            diff = 1 - torch.sum(torch.mul(xx,zz), 1) # 32,16
            softmax_diff = F.softmax(diff*(-alpha),1) #32,16
            soft_des_temp = torch.matmul(softmax_diff,z_[i]) # 32,12
            if i == 0: descriptor = soft_des_temp
            else: descriptor = torch.cat((descriptor,soft_des_temp),1)

        return intranorm(descriptor,n_book) # 32,144

In [5]:
class classifier_(nn.Module):
    def __init__(self,n_CLASSES, len_code, n_book,):
        super(classifier_,self).__init__()
        self.prototypes = nn.Linear(len_code * n_book, n_CLASSES, bias=False)
        self.n_book = n_book
        
    def forward(self, x):
        x_ = x.split(self.n_book,1)
        c_ = (intranorm(self.prototypes.state_dict()['weight'], n_book)*beta).T.split(self.n_book,0)
        for i in range(self.n_book):
            sub_res = torch.matmul(x_[i], c_[i]).unsqueeze(-1)
            if i == 0: res = sub_res
            else: res = torch.cat((res,sub_res),2)
        
        return torch.sum(res, 2)

In [6]:
class features_(nn.Module):
    def __init__(self, net1, net2):
        super(features_,self).__init__()
        
        self.net1 = net1
        self.net2 = net2
        self.gavgp = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(768, len_code*n_book)
        
    def forward(self,x):
        x = self.net1(x) # shape: torch.Size([32, 3, 32,32])>torch.Size([32, 3, 4, 4])
        x_branch = self.gavgp(x)
        x = self.net2(x) # shape: torch.Size([32, 3, 4, 4])> torch.Size([32, 3, 4, 4])
        x = self.gavgp(x)
        
        x = torch.cat((x,x_branch),1)
        
        return self.linear(x.view(-1,768))
    
# x = torch.randn(32, 3, 32, 32, device='cpu')
# model = features_(net1, net2)
# out = model(x)# model
# out.shape, out.split(12,1)[0].shape, model.Z.shape, model.prototypes.shape

In [7]:
class flipGradient_(nn.Module):
    def forward(self,x,l=1.0):
        positivePath=(x*2).clone().detach().requires_grad_(False)
        negativePath=(-x).requires_grad_(True)
        return positivePath+negativePath

In [8]:
#models
model = features_(net1, net2).to(device)
forget_net = features_(net1, net2).to(device)
classifier = classifier_(n_CLASSES, len_code, n_book).to(device)
discriminator = classifier_(2, len_code, n_book).to(device)
softassignment = softassignment_(len_code, n_book, intn_word).to(device)
flipGradient = flipGradient_()

# optimizer
class_optim = optim.Adam(classifier.parameters(),lr=0.002,weight_decay=0.00001,amsgrad=True)
model_optim = optim.Adam(model.parameters(),lr=0.0002,weight_decay=0.00001,amsgrad=True)
soft_optim = optim.Adam(softassignment.parameters(),lr=0.002,weight_decay=0.00001,amsgrad=True)
forg_optim = optim.Adam(forget_net.parameters(),lr=0.002,weight_decay=0.00001,amsgrad=True)
disc_optim = optim.Adam(discriminator.parameters(),lr=0.002,weight_decay=0.00001,amsgrad=True)

class_optim.zero_grad()
model_optim.zero_grad()
soft_optim.zero_grad()
forg_optim.zero_grad()
disc_optim.zero_grad()

In [9]:
dummy_target = torch.cat((torch.tensor(1).tile((batchSize,)), torch.tensor(0).tile((batchSize,))),0).to(device)
target_ = iter(target)
score = 0
save = ""
scoore = ''
for epoch in tqdm(range(50)):
    m_,n,o,p = 0,0,0,0
    model.train()
    classifier.train()
    softassignment.train()
    discriminator.train()
    forget_net.train()
    for df, batch in enumerate(source):
        x, y = batch
        if x.shape[0] < batchSize: break
        x = torch.tensor(data_augmentation(x)).to(device)
        y = y.to(device)
        try: 
            xu = next(target_)
            xu = torch.tensor(data_augmentation(xu)).to(device)
        except:
            target_ = iter(target)
            xu = next(target_)
            xu = torch.tensor(data_augmentation(xu)).to(device)

        # adversarial forgetting
        for g in range(1):
            input_ = torch.cat((x,xu),0)
            z = intranorm(model(input_.permute(0,3,1,2)), n_book)
            m = intranorm(forget_net(input_.permute(0,3,1,2)), n_book)
            z_ = z * m
            
            pred = discriminator(z_)
            dummy_loss = torch.nn.functional.cross_entropy(pred,dummy_target)
        
            disc_optim.zero_grad()
            forg_optim.zero_grad()
            dummy_loss.backward()
            disc_optim.step()

        class_optim.zero_grad()
        model_optim.zero_grad()
        soft_optim.zero_grad()
        forg_optim.zero_grad()
        disc_optim.zero_grad()
        
        input_ = torch.cat((x,xu),0)
        features = intranorm(model(x.permute(0,3,1,2)), n_book)
        features_ = intranorm(model(xu.permute(0,3,1,2)), n_book)
        z = torch.cat((features,features_),0)
        m = intranorm(forget_net(input_.permute(0,3,1,2)), n_book)
        z_ = z * m

        pred = discriminator(z_)
        quanta = softassignment(z_[:batchSize],n_book,alpha)
        logits = classifier(z_[:batchSize] *beta)
        
        entropy_loss = torch.nn.functional.cross_entropy(pred,torch.flip(dummy_target,(0,))) *0.1
        mask_regulariser_loss = (m * (1-m)).mean() 
        
        entropy_loss.backward(retain_graph=True)
        model_optim.zero_grad()
        
        cls_loss = torch.nn.functional.cross_entropy(logits,y)

        y = torch.eye(numClasses)[y].to(device)
        y_ = torch.matmul(y,y.T)
        y_ = y_/torch.sum(y_, axis=1, keepdims=True)
        hash_loss = NPQLoss(y_,z_[:batchSize], quanta,n_book)  
        
        final_loss = hash_loss + cls_loss*0.1  + mask_regulariser_loss*0.1
        final_loss.backward()
        
        o += cls_loss.item()
        m_ += final_loss.item()
        n += hash_loss.item()
        p += entropy_loss.item()
#         print(final_loss.item())
        
        model_optim.step()
        soft_optim.step()
        class_optim.step()
        forg_optim.step()

    print(f"epoch:{epoch+1}\t{m_}\t{n}\t{o}\t{p}\n")
    save += f"{m_/10}\t{n/10}\t{o/10}\t{p/10}\n"
    with open("analysis.txt", "w") as f:
        f.write(save)
#     continue
    
    if epoch % 5 == 0:
        
        with torch.no_grad():
            model.eval()
            forget_net.eval()
            
            for r, i in tqdm(enumerate(gallery)):
                dfg = intranorm(model(i.to(device).permute(0,3,1,2)),n_book) * intranorm(forget_net(i.to(device).permute(0,3,1,2)), n_book)
                if r == 0: 
                    dummy = dfg
                else: 
                    dummy = torch.cat((dummy, dfg), 0)

            for r, i in tqdm(enumerate(query)):
                dfg = intranorm(model(i.to(device).permute(0,3,1,2)),n_book) * intranorm(forget_net(i.to(device).permute(0,3,1,2)), n_book)
                if r == 0: 
                    query_x = dfg
                else: 
                    query_x = torch.cat((query_x, dfg), 0)


        dummy = Indexing_(softassignment.Z.state_dict()['weight'].cpu(), dummy.cpu(), n_book)
        gallery_x = dummy.numpy().astype(int)
        quantizedDist = pqDist(intranorm(softassignment.Z.state_dict()['weight'].cpu(), n_book), n_book,gallery_x, query_x.cpu().numpy()).T
        Rank = np.argsort(quantizedDist, axis=0)
        mean_average_precision=cat_apcal(similarity,Rank,54000)
        scoore += f"{mean_average_precision}\n"
        if mean_average_precision > score:
            score = mean_average_precision
            stateToBeSaved={
                'modelStateDict': [model.state_dict(),classifier.state_dict(), softassignment.state_dict(), forget_net.state_dict()],
                'score': mean_average_precision,}
            torch.save(stateToBeSaved,f"/home/hemant/mod2_hemant.pth")
        print(mean_average_precision)
        with open("score.txt", "w") as f:
            f.write(save)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))

epoch:1	62.78375577926636	61.16508674621582	16.630850315093994	0.6969852223992348



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


0.37777540640559476
epoch:2	60.96410942077637	59.8322811126709	11.892013430595398	0.7045582234859467

epoch:3	60.20009422302246	59.242629528045654	10.249484360218048	0.7138472124934196

epoch:4	59.56916093826294	58.777169704437256	8.660336136817932	0.7193053513765335

epoch:5	59.215954303741455	58.51450252532959	7.796039164066315	0.7262647077441216

epoch:6	58.93836975097656	58.29346418380737	7.2573288679122925	0.7300125285983086



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


0.5448718099297396
epoch:7	58.68287229537964	58.10357332229614	6.632027983665466	0.729028731584549

epoch:8	58.53906774520874	57.99400568008423	6.316668689250946	0.7326974272727966

epoch:9	58.400973320007324	57.87692642211914	6.125210285186768	0.7360685467720032

epoch:10	58.2462944984436	57.77209949493408	5.644714444875717	0.7397869974374771

epoch:11	58.109615325927734	57.66983366012573	5.31342950463295	0.7400111109018326



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


0.6035012485833238
epoch:12	57.99783277511597	57.5770525932312	5.149858981370926	0.7414542213082314

epoch:13	57.94028043746948	57.53266382217407	5.029498755931854	0.7446908578276634

epoch:14	57.880213260650635	57.49641036987305	4.796201109886169	0.7436771541833878

epoch:15	57.76840162277222	57.41624975204468	4.485266000032425	0.7496164813637733

epoch:16	57.692097663879395	57.35496807098389	4.3489925265312195	0.7427384406328201



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


0.5796998266601024
epoch:17	57.63182020187378	57.31801080703735	4.113497674465179	0.7483472302556038

epoch:18	57.54409456253052	57.24330186843872	3.991314470767975	0.7432256415486336

epoch:19	57.50168228149414	57.234424114227295	3.6486834287643433	0.7415945529937744

epoch:20	57.40179395675659	57.15614032745361	3.44854599237442	0.7462454661726952

epoch:21	57.284188747406006	57.07548666000366	3.0882182717323303	0.7461762726306915



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


0.6480535067431707
epoch:22	57.362011432647705	57.121007442474365	3.418461889028549	0.7439207583665848

epoch:23	57.244842529296875	57.03714990615845	3.081064522266388	0.7391294836997986

epoch:24	57.27329444885254	57.05602979660034	3.17857825756073	0.7443123683333397

epoch:25	57.2189154624939	57.01193952560425	3.0780717730522156	0.742102600634098

epoch:26	57.09269905090332	56.93277883529663	2.613275423645973	0.7451115325093269



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


0.676259841238009
epoch:27	57.103707790374756	56.92757749557495	2.780221089720726	0.7428247407078743

epoch:28	57.13498258590698	56.949249267578125	2.872847616672516	0.7456517294049263

epoch:29	57.04120111465454	56.88070201873779	2.631280407309532	0.7441138848662376

epoch:30	56.99705219268799	56.848122119903564	2.518518701195717	0.7445415928959846

epoch:31	56.9514946937561	56.81932497024536	2.3514041751623154	0.7439640313386917



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


0.6385932859326546
epoch:32	56.99540376663208	56.84275817871094	2.5561110377311707	0.7431143671274185

epoch:33	56.91618204116821	56.7791314125061	2.400209069252014	0.7380931824445724

epoch:34	56.97185277938843	56.8327693939209	2.4324663430452347	0.7402466386556625

epoch:35	57.01722955703735	56.85948038101196	2.618986129760742	0.7351617440581322

epoch:36	56.965492248535156	56.823612689971924	2.45974263548851	0.7402656525373459



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


0.6516968103772098
epoch:37	56.90149736404419	56.76744842529297	2.3821196854114532	0.734370730817318

epoch:38	56.78257894515991	56.690874099731445	1.9734761863946915	0.7341937869787216

epoch:39	56.86160469055176	56.74326753616333	2.241103634238243	0.733466275036335

epoch:40	56.84748840332031	56.731242179870605	2.217312216758728	0.7338187992572784

epoch:41	56.73908185958862	56.65461301803589	1.907302439212799	0.7360842898488045



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


0.6919166817030709
epoch:42	56.83247137069702	56.72854042053223	2.1036475896835327	0.7307396605610847

epoch:43	56.889057636260986	56.757164001464844	2.38005830347538	0.7338637858629227

epoch:44	56.866347789764404	56.74502515792847	2.274666577577591	0.7340978607535362

epoch:45	56.777050495147705	56.68388891220093	1.9999576061964035	0.7367803975939751

epoch:46	56.652095317840576	56.59731483459473	1.621795006096363	0.7340838611125946



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


0.6765871418680022
epoch:47	56.61803579330444	56.57230567932129	1.5367155075073242	0.7333464249968529

epoch:48	56.63983774185181	56.585062980651855	1.6261011883616447	0.7336803451180458

epoch:49	56.632543087005615	56.58513021469116	1.5546146109700203	0.7323544397950172

epoch:50	56.709335803985596	56.63968515396118	1.7696180120110512	0.7305288016796112




# THE END

In [10]:
#tsne plots

from sklearn.manifold import TSNE
import numpy as np
import matplotlib.pyplot as plt
def tsne_plot(labelled_feature,unlabelled_features,labelled_classes):
  
  target_classes = np.argmax(labelled_classes, axis = 1)

  target_ids=range(len(labelled_classes))

  tsne = TSNE(n_components=2)
  colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k', 'w', 'orange', 'purple']
  plt.figure(figsize=(6, 5))
  labelled_tsne=tsne.fit_transform(labelled_feature)
  unlabelled_tsne=tsne.fit_transform(unlabelled_features)
  for i, c, label in zip(target_ids, colors,target_classes ):
      plt.scatter(labelled_tsne[target_classes == i, 0], labelled_tsne[target_classes == i, 1], c=c, label=i)

  unlabelled_tsne=tsne.fit_transform(unlabelled_features)
  plt.scatter(x=unlabelled_tsne[:,0],y=unlabelled_tsne[:,1],c='gray',label='Unlabelled Data')


  plt.legend()
  plt.show()



In [11]:
tsne_plot(source_x.cpu().numpy(),gallery_x.cpu().numpy(),torch.eye(numClasses)[y_].cpu().numpy())

NameError: name 'source_x' is not defined

In [None]:
unlabelled_tsne=tsne.fit_transform(unlabelled_features)
  plt.scatter(x=unlabelled_tsne[:,0],y=unlabelled_tsne[:,1],c='gray',label='Unlabelled Data')


  plt.legend()
  plt.show()


In [None]:
with torch.no_grad():
    for r, i in tqdm(enumerate(gallery)):
        dfg = intranorm(model(i.to(device).permute(0,3,1,2)),n_book).cpu() 
        if r == 0: 
            gallery_x = dfg
        else: 
            gallery_x = torch.cat((gallery_x, dfg), 0)
    
    for r, batch in enumerate(source):
        i, y = batch
        dfg = intranorm(model(i.to(device).permute(0,3,1,2)),n_book) 
        if r == 0: 
            source_x = dfg
            y_ = y
        else: 
            source_x = torch.cat((source_x, dfg), 0)
            y_ = torch.cat((y_,y),0)