In [None]:
from config import *
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, models, transforms
import torch.nn.functional as F
import os
from data.data_loader import *
from model import *
import torch.optim as optim
from  RetrievalTest import *
# Pin this process to GPU 1.
# NOTE(review): CUDA_VISIBLE_DEVICES is only honoured if CUDA has not been initialised
# yet; torch and the star-imported project modules above are already loaded, so confirm
# none of them touch CUDA at import time.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

batchSize = 1000  # batch size of the shuffled source/target training loaders

print("Loading the dataset")
# The boolean flag appears to select the training split (source images + labels,
# unlabelled target images) vs. the retrieval split (gallery, query) — inferred from
# the unpacked names; confirm against prepare_Data.
Source_x, Source_y, Target_x = prepare_Data(data_dir, True)
Gallery_x, Query_x = prepare_Data(data_dir, False)
# Label-similarity matrix, densified; it is later passed to cat_apcal for mAP.
similarity = csr_matrix(
    scipy.io.loadmat("data/cifar10/cifar10_Similarity.mat")['label_Similarity']
).todense()

print("Data loading finished")

# Retrieval loaders keep dataset order; training loaders are shuffled.
gallery = torch.utils.data.DataLoader(Gallery_x, batch_size=1000)
query = torch.utils.data.DataLoader(Query_x, batch_size=1000)

# Each source item is an (image, label) pair so batches unpack as (x, y).
source = torch.utils.data.DataLoader(
    [(Source_x[k], Source_y[k]) for k in range(len(Source_x))],
    batch_size=batchSize,
    shuffle=True,
)
target = torch.utils.data.DataLoader(Target_x, batch_size=batchSize, shuffle=True)

In [None]:
# Pretrained VGG16-BN backbone. Only its first top-level child (the convolutional
# feature stack) is used, cut into two consecutive stages for features_.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision in favour of `weights=`.
vggModel = models.vgg16_bn(pretrained=True)

_conv_stack = list(next(iter(vggModel.children())).children())
net1 = nn.Sequential(*_conv_stack[:24])    # layers 0..23
net2 = nn.Sequential(*_conv_stack[24:33])  # layers 24..32; any later layers are dropped

In [None]:
def intranorm(features, n_book):
    """Intra-normalise product-quantisation descriptors.

    The columns of ``features`` are cut into ``n_book`` consecutive sub-vectors (one
    per codebook). Each sub-vector is L2-normalised independently along dim 1, and the
    pieces are concatenated back in their original order.

    Args:
        features: 2-D tensor; its second dimension is presumably ``len_code * n_book``.
        n_book: number of codebooks (sub-vectors per row).

    Returns:
        Tensor with the same shape as ``features``.
    """
    # torch.chunk yields exactly ``n_book`` pieces. Tensor.split(n_book, 1) would yield
    # pieces *of size* n_book, which only coincides when len_code == n_book.
    sub_vectors = torch.chunk(features, n_book, dim=1)
    # p and dim are passed by keyword: F.normalize(x, 1) means p=1 (L1 norm), not dim=1.
    # Downstream code uses 1 - <x, z> as a cosine distance, which needs unit L2 norm.
    return torch.cat([F.normalize(v, p=2, dim=1) for v in sub_vectors], dim=1)

def shape_(inp):
    """Debug helper: print ``shape is: <shape>`` for each element of ``inp``.

    Args:
        inp: iterable of objects exposing a ``.shape`` attribute (e.g. tensors).
    """
    messages = (f"shape is: {item.shape}" for item in inp)
    for message in messages:
        print(message)
            
def Indexing_(Z,des,numSeg):
    """Hard-assign every descriptor sub-vector to its most similar codeword.

    Args:
        Z: codebook matrix, presumably (n_codewords, len_code * numSeg). It is
            intra-normalised here before use.
        des: descriptors, presumably (n_items, len_code * numSeg).
        numSeg: number of codebooks / sub-vectors.

    Returns:
        int64 tensor of shape (n_items, numSeg). Entry [n, i] is the index of the
        codeword in sub-codebook i with the largest dot product with sub-vector i
        of ``des[n]``.
    """
    Z = intranorm(Z, numSeg)
    # chunk -> exactly numSeg pieces (split(numSeg, 1) would give pieces of size numSeg).
    des_parts = torch.chunk(des, numSeg, dim=1)
    code_parts = torch.chunk(Z, numSeg, dim=1)
    # (n_items, len_code) @ (len_code, n_codewords) -> similarity matrix. This equals the
    # previous tile/permute/sum, without materialising an (n_items, len_code, n_codewords)
    # tensor per codebook.
    codes = [torch.argmax(d @ c.T, dim=1) for d, c in zip(des_parts, code_parts)]
    return torch.stack(codes, dim=1)

In [None]:
class softassignment_(nn.Module):
    """Soft (differentiable) product quantisation of feature vectors.

    The codebook is stored as the weight of a bias-free ``nn.Linear``. That weight has
    shape (intn_word, len_code * n_book): ``intn_word`` codewords per sub-codebook, with
    sub-codebook i occupying the i-th column slice.
    """

    def __init__(self,len_code, n_book, intn_word):
        super().__init__()
        self.Z = nn.Linear(len_code * n_book,intn_word, bias=False)

    def forward(self,features,n_book,alpha,):
        """Replace each sub-vector of ``features`` by a softmax-weighted codeword mix.

        Args:
            features: presumably (batch, len_code * n_book) intra-normalised features.
            n_book: number of codebooks.
            alpha: softmax sharpness; larger values approach hard assignment.

        Returns:
            Intra-normalised soft-quantised descriptors, same width as ``features``.
        """
        # Use the live Parameter: state_dict() returns *detached* tensors, which cut the
        # graph so the codebook never received gradients and soft_optim had nothing
        # to update.
        codebooks = torch.chunk(intranorm(self.Z.weight, n_book), n_book, dim=1)
        sub_features = torch.chunk(features, n_book, dim=1)

        soft_parts = []
        for x_i, z_i in zip(sub_features, codebooks):
            # 1 - dot product (a cosine distance when both sides are L2-normalised),
            # shape (batch, intn_word).
            dist = 1 - x_i @ z_i.T
            weights = F.softmax(-alpha * dist, dim=1)
            soft_parts.append(weights @ z_i)  # (batch, len_code)

        return intranorm(torch.cat(soft_parts, dim=1), n_book)

In [None]:
class classifier_(nn.Module):
    """Prototype classifier over intra-normalised PQ descriptors.

    One prototype per class is stored as a row of a bias-free ``nn.Linear`` weight of
    shape (n_CLASSES, len_code * n_book). Logits are dot products between the input and
    the intra-normalised prototypes scaled by ``beta``.
    """
    def __init__(self,n_CLASSES, len_code, n_book,):
        super().__init__()
        self.prototypes = nn.Linear(len_code * n_book, n_CLASSES, bias=False)
        self.n_book = n_book

    def forward(self, x):
        """Return (batch, n_CLASSES) logits for ``x`` of width len_code * n_book."""
        # Live Parameter, not state_dict() (which detaches), so the prototypes actually
        # receive gradients. Uses self.n_book rather than the module-level n_book.
        # NOTE(review): `beta` is a module-level global, presumably from config.
        prototypes = intranorm(self.prototypes.weight, self.n_book) * beta
        # The original summed per-codebook partial dot products over disjoint column
        # slices; that sum is exactly the full dot product.
        return x @ prototypes.T

In [None]:
class features_(nn.Module):
    """Two-stage backbone with global-average-pooled skip branch and linear head.

    ``net1``'s output is average-pooled and kept as a skip branch. It is also fed
    through ``net2``, whose pooled output is concatenated (net2 first, then the skip
    branch) and projected by a linear layer.

    Args:
        net1, net2: consecutive feature-extractor stages (stored by reference, not copied).
        in_features: total channels after concatenation. The default 768 presumably
            equals 512 (net2) + 256 (net1) for the VGG16-BN split used in this notebook.
        out_features: output width; defaults to the module-level ``len_code * n_book``.
    """
    def __init__(self, net1, net2, in_features=768, out_features=None):
        # Zero-arg super(): does not look up the global name `features_`, which may be
        # rebound elsewhere in the notebook.
        super().__init__()

        self.net1 = net1
        self.net2 = net2
        self.gavgp = nn.AdaptiveAvgPool2d(1)
        if out_features is None:
            out_features = len_code * n_book
        self.linear = nn.Linear(in_features, out_features)

    def forward(self,x):
        """Map an image batch (presumably NCHW) to (batch, out_features) features."""
        x = self.net1(x)
        x_branch = self.gavgp(x)        # (B, C1, 1, 1)
        x = self.gavgp(self.net2(x))    # (B, C2, 1, 1)

        x = torch.cat((x,x_branch),1)   # (B, C2 + C1, 1, 1)

        # flatten keeps the batch dimension; view(-1, 768) could silently mix samples
        # if the channel count ever differed from 768.
        return self.linear(torch.flatten(x, 1))
    
# x = torch.randn(32, 3, 32, 32, device='cpu')
# model = features_(net1, net2)
# out = model(x)# model
# out.shape, out.split(12,1)[0].shape, model.Z.shape, model.prototypes.shape

In [None]:
class flipGradient_(nn.Module):
    """Gradient-reversal layer: identity forward, gradient multiplied by ``-l`` backward."""
    def forward(self,x,l=1.0):
        # Forward value: (1 + l) * x - l * x == x.
        # Backward: only the second term is in the graph, so d(out)/dx = -l.
        # The previous version hard-coded l = 1 and ignored the `l` argument.
        positivePath = (x * (1 + l)).detach()
        negativePath = -x * l
        return positivePath + negativePath

In [None]:
# --- models ---
# NOTE(review): `model` and `forget_net` are built from the SAME net1/net2 module
# objects, so their convolutional stages share weights (only the linear heads differ),
# and both model_optim and forg_optim update those shared weights. Confirm this is
# intended; otherwise give forget_net its own copy of the backbone.
model = features_(net1, net2).to(device)
forget_net = features_(net1, net2).to(device)
classifier = classifier_(n_CLASSES, len_code, n_book).to(device)
discriminator = classifier_(2, len_code, n_book).to(device)  # 2-way head on masked features
softassignment = softassignment_(len_code, n_book, intn_word).to(device)

# --- optimisers ---
# All AMSGrad Adam with identical weight decay; `model` uses a 10x smaller lr.
_ADAM_KW = {"weight_decay": 0.00001, "amsgrad": True}
class_optim = optim.Adam(classifier.parameters(), lr=0.002, **_ADAM_KW)
model_optim = optim.Adam(model.parameters(), lr=0.0002, **_ADAM_KW)
soft_optim = optim.Adam(softassignment.parameters(), lr=0.002, **_ADAM_KW)
forg_optim = optim.Adam(forget_net.parameters(), lr=0.002, **_ADAM_KW)
disc_optim = optim.Adam(discriminator.parameters(), lr=0.002, **_ADAM_KW)

for _optimizer in (class_optim, model_optim, soft_optim, forg_optim, disc_optim):
    _optimizer.zero_grad()

In [None]:
# Domain labels for a concatenated [source; target] batch: class 1 for the first
# `batchSize` (source) rows, class 0 for the next `batchSize` (target) rows.
dummy_target = torch.cat((torch.tensor(1).tile((batchSize,)), torch.tensor(0).tile((batchSize,))),0).to(device)


def _encode_masked(loader):
    """Encode every batch of `loader` as intranorm(model(x)) * intranorm(forget_net(x)).

    Batches are moved to `device` and permuted NHWC -> NCHW first. Results are
    concatenated once at the end rather than with a torch.cat per batch.
    """
    parts = []
    for batch in tqdm(loader):
        batch = batch.to(device).permute(0,3,1,2)
        parts.append(intranorm(model(batch), n_book) * intranorm(forget_net(batch), n_book))
    return torch.cat(parts, 0)


target_ = iter(target)
score = 0    # best mAP seen so far
save = ""    # accumulated per-epoch loss log (analysis.txt)
scoore = ''  # accumulated per-evaluation mAP log (score.txt)
for epoch in tqdm(range(50)):
    m_,n,o,p = 0,0,0,0  # running sums: final, hash, classification, entropy losses
    model.train()
    classifier.train()
    softassignment.train()
    discriminator.train()
    forget_net.train()
    for df, batch in enumerate(source):
        x, y = batch
        # dummy_target is sized for full batches, so stop at the final partial one.
        if x.shape[0] < batchSize: break
        x = torch.tensor(data_augmentation(x)).to(device)
        y = y.to(device)

        # Next unlabelled target batch. Restart the target loader when it is exhausted,
        # or when only a partial batch is left (it would not line up with dummy_target).
        # Other errors now propagate instead of being swallowed by a bare `except:`.
        xu = next(target_, None)
        if xu is None or xu.shape[0] < batchSize:
            target_ = iter(target)
            xu = next(target_)
        xu = torch.tensor(data_augmentation(xu)).to(device)

        # adversarial forgetting: one discriminator update on masked features
        for g in range(1):
            input_ = torch.cat((x,xu),0)
            # inputs are permuted NHWC -> NCHW before the conv backbone
            z = intranorm(model(input_.permute(0,3,1,2)), n_book)
            m = intranorm(forget_net(input_.permute(0,3,1,2)), n_book)
            z_ = z * m

            pred = discriminator(z_)
            dummy_loss = torch.nn.functional.cross_entropy(pred,dummy_target)

            disc_optim.zero_grad()
            forg_optim.zero_grad()
            dummy_loss.backward()
            disc_optim.step()

        class_optim.zero_grad()
        model_optim.zero_grad()
        soft_optim.zero_grad()
        forg_optim.zero_grad()
        disc_optim.zero_grad()

        input_ = torch.cat((x,xu),0)
        features = intranorm(model(x.permute(0,3,1,2)), n_book)
        # Named target_features (not `features_`) so the loop does not overwrite the
        # `features_` model class with a tensor, which would break re-running earlier cells.
        target_features = intranorm(model(xu.permute(0,3,1,2)), n_book)
        z = torch.cat((features,target_features),0)
        m = intranorm(forget_net(input_.permute(0,3,1,2)), n_book)
        z_ = z * m

        pred = discriminator(z_)
        quanta = softassignment(features,n_book,alpha)
        logits = classifier(features *beta)

        # Discriminator loss against the flipped domain labels.
        entropy_loss = torch.nn.functional.cross_entropy(pred,torch.flip(dummy_target,(0,))) *0.1
        # Minimised when every mask entry is 0 or 1.
        mask_regulariser_loss = (m * (1-m)).mean()

        entropy_loss.backward(retain_graph=True)
        # Discard the gradient entropy_loss left on `model`'s parameters.
        # NOTE(review): if `model` and `forget_net` share backbone modules, this also
        # clears forget_net's backbone gradients from entropy_loss — confirm intended.
        model_optim.zero_grad()

        cls_loss = torch.nn.functional.cross_entropy(logits,y)

        # Row-normalised pairwise label-similarity matrix. The identity table is created
        # on `device` so it can be indexed by `y` (already on `device`); indexing a CPU
        # tensor with device indices fails.
        y_onehot = torch.eye(numClasses, device=device)[y]
        y_ = torch.matmul(y_onehot,y_onehot.T)
        y_ = y_/torch.sum(y_, axis=1, keepdims=True)
        hash_loss = NPQLoss(y_,features, quanta,n_book)

        final_loss = hash_loss + cls_loss*0.1  + mask_regulariser_loss*0.1
        final_loss.backward()

        o += cls_loss.item()
        m_ += final_loss.item()
        n += hash_loss.item()
        p += entropy_loss.item()

        model_optim.step()
        soft_optim.step()
        class_optim.step()
        forg_optim.step()

    print(f"epoch:{epoch+1}\t{m_}\t{n}\t{o}\t{p}\n")
    # NOTE(review): the sums are divided by a hard-coded 10, not by the number of batches.
    save += f"{m_/10}\t{n/10}\t{o/10}\t{p/10}\n"
    with open("analysis.txt", "w") as f:
        f.write(save)

    # Retrieval evaluation on epochs 0, 5, 10, ...
    if epoch % 5 == 0:

        with torch.no_grad():
            model.eval()
            forget_net.eval()

            gallery_feats = _encode_masked(gallery)
            query_x = _encode_masked(query)

        gallery_codes = Indexing_(softassignment.Z.state_dict()['weight'].cpu(), gallery_feats.cpu(), n_book)
        gallery_x = gallery_codes.numpy().astype(int)
        quantizedDist = pqDist(intranorm(softassignment.Z.state_dict()['weight'].cpu(), n_book), n_book,gallery_x, query_x.cpu().numpy()).T
        Rank = np.argsort(quantizedDist, axis=0)
        # NOTE(review): 54000 is a hard-coded argument to cat_apcal — confirm its meaning.
        mean_average_precision=cat_apcal(similarity,Rank,54000)
        scoore += f"{mean_average_precision}\n"
        if mean_average_precision > score:
            score = mean_average_precision
            stateToBeSaved={
                'modelStateDict': [model.state_dict(),classifier.state_dict(), softassignment.state_dict(), forget_net.state_dict()],
                'score': mean_average_precision,}
            torch.save(stateToBeSaved,f"{epoch}_mod2_hemant.pth")
        print(mean_average_precision)
    # Persist the mAP history (this previously wrote the loss log `save` by mistake).
    with open("score.txt", "w") as f:
        f.write(scoore)

# THE END