In [0]:
# Mount Google Drive at /content/drive (Colab only; google.colab is not available elsewhere).
from google.colab import drive
drive.mount('/content/drive')

In [0]:
# Move into the assignment folder on the mounted Drive and list its contents.
# NOTE(review): the path is relative, so re-running this cell from inside that folder will fail — confirm the starting directory.
%cd drive/My\ Drive/DL/end_sem_assgn
!ls

In [0]:
from torchvision import datasets
import torch
import torch.nn as nn
import os
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import math
import random
from random import shuffle
from torch.autograd import Variable
import numpy as np
import time
import copy
import h5py
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier as KNN
from sklearn.metrics import accuracy_score,average_precision_score
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device=torch.device('cuda')
else:
    device=torch.device('cpu')
print(device)

In [0]:
def logsumexp(x,dim=1):
    """Numerically stable ``log(sum(exp(x)))`` reduced over ``dim``.

    Args:
        x: input tensor.
        dim: dimension to reduce over (default 1).

    Returns:
        Tensor with ``dim`` removed, holding the log-sum-exp of ``x`` along it.
    """
    # torch.logsumexp subtracts the maximum along the requested dimension
    # internally, so the result is stable and correct for any ``dim``
    # (not only dim=1).
    return torch.logsumexp(x,dim=dim)

def supervised_loss(a,p,n,eps=1e-12):
    """Soft triplet loss on Euclidean distances between embeddings.

    Computes ``mean(log(exp(d(a,p)) + exp(d(a,n)))) - mean(d(a,n))``, which is
    small when each anchor is closer to its positive than to its negative.

    Args:
        a: anchor embeddings; distances are reduced over dim 1, so a
            (batch, features) layout is expected.
        p: positive embeddings, same shape as ``a``.
        n: negative embeddings, same shape as ``a``.
        eps: small constant added under the square root. The gradient of
            ``sqrt`` at exactly 0 is infinite and turns into NaN during
            backprop when an anchor coincides with its positive/negative.

    Returns:
        Scalar tensor holding the loss.
    """
    dist_pos=torch.sqrt(torch.sum((a-p)**2,dim=1)+eps)
    dist_neg=torch.sqrt(torch.sum((a-n)**2,dim=1)+eps)
    # One row per triplet: [distance to positive, distance to negative].
    dist_vec=torch.stack([dist_pos,dist_neg],dim=1)
    loss=torch.mean(torch.logsumexp(dist_vec,dim=1))-torch.mean(dist_neg)
    return loss

def unsupervised_loss(true,fake):
    """Real-vs-generated term of the discriminator objective.

    Both inputs are reduced with ``logsumexp`` over dim 1. The loss is
    ``0.5 * (mean(softplus(l_true)) + mean(softplus(l_fake)) - mean(l_true))``
    where ``l_*`` are those log-sum-exp values.

    Args:
        true: discriminator outputs for real samples.
        fake: discriminator outputs for generated samples.

    Returns:
        Scalar tensor holding the loss.
    """
    real_lse=logsumexp(true)
    fake_lse=logsumexp(fake)
    real_softplus_mean=torch.mean(F.softplus(real_lse))
    fake_softplus_mean=torch.mean(F.softplus(fake_lse))
    return 0.5*(real_softplus_mean+fake_softplus_mean-torch.mean(real_lse))

def discriminator_loss(a,p,n,true,fake):
    """Total discriminator loss: triplet term plus real/fake term.

    Args:
        a, p, n: anchor, positive and negative embeddings passed to ``supervised_loss``.
        true, fake: outputs for real and generated samples passed to ``unsupervised_loss``.

    Returns:
        Scalar tensor with the summed loss.
    """
    total=supervised_loss(a,p,n)+unsupervised_loss(true,fake)
    # mean() is kept from the original; it also reduces the value if the terms are not scalars.
    return total.mean()

In [0]:
#implementation taken from: https://github.com/Sleepychord/ImprovedGAN-pytorch/blob/8aca3021edb771e2fc14bcbd409bd7f1cd453341/functional.py#L13
class WeightNorm(torch.nn.Module):
    """Linear layer whose weight rows are rescaled to a fixed or learnable L2 norm.

    The effective weight is ``weight * weight_scale / ||weight||_row``, so each
    output unit's weight vector has norm ``weight_scale``.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: if True, add a learnable bias initialised to zeros.
        weight_scale: ``None`` for a constant scale of 1 (no extra parameter),
            or an int/float initial value for a learnable per-output scale.
        weight_init_stdv: standard deviation of the normal weight initialisation.
    """
    def __init__(self,in_features,out_features,bias=True,weight_scale=None,weight_init_stdv=0.1):
        super(WeightNorm,self).__init__()
        self.in_features=in_features
        self.out_features=out_features
        self.weight=Parameter(torch.randn(out_features,in_features)*weight_init_stdv)
        if bias:
            self.bias=Parameter(torch.zeros(out_features))
        else:
            self.register_parameter('bias',None)
        if weight_scale is not None:
            # Any plain number is a valid initial scale (1 as well as 0.5);
            # bool is rejected because it is an int subclass but not a scale.
            assert isinstance(weight_scale,(int,float)) and not isinstance(weight_scale,bool),'weight_scale must be an int or float'
            self.weight_scale=Parameter(torch.ones(out_features,1)*weight_scale)
        else:
            self.weight_scale=1
    def forward(self,x):
        """Apply the normalised linear map to ``x`` (last dim ``in_features``)."""
        # Per-output-unit L2 norm, kept as a column so it broadcasts over in_features.
        row_norm=torch.sqrt(torch.sum(self.weight**2,dim=1,keepdim=True))
        W=self.weight*self.weight_scale/row_norm
        return F.linear(x,W,self.bias)
    def __repr__(self):
        return self.__class__.__name__ +'('+'in_features='+str(self.in_features)+', out_features='+str(self.out_features)+')'

In [0]:
class Generator(nn.Module):
    """MLP generator mapping uniform noise to flattened images with values in (0, 1).

    Args:
        training: initial mode; True puts the module (including its BatchNorm
            layers) in training mode, False in evaluation mode.
        noise_dim: size of the latent noise vector.
        output_dim: size of each generated sample (default 784).
    """
    def __init__(self,training,noise_dim,output_dim=784):
        super(Generator,self).__init__()
        self.noise_dim=noise_dim

        self.fc1=nn.Linear(noise_dim,500)
        self.bn1=nn.BatchNorm1d(500,affine=False,eps=1e-6,momentum=0.5)
        self.fc2=nn.Linear(500,500)
        self.bn2=nn.BatchNorm1d(500,affine=False,eps=1e-6,momentum=0.5)
        self.fc3=WeightNorm(500,output_dim,weight_scale=1)
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)

        # Assigning self.training directly only flips the flag on this module;
        # train() also switches the BatchNorm children, so training=False
        # really samples with the stored running statistics.
        self.train(bool(training))

    def forward(self,batch_size):
        """Generate ``batch_size`` samples; returns a (batch_size, output_dim) tensor."""
        # Draw the latent noise directly on the device the parameters live on.
        # Gradient tracking for inference is controlled by the caller through
        # torch.no_grad(); the old ``volatile`` flag no longer has any effect.
        noise=torch.rand(batch_size,self.noise_dim,device=self.fc1.weight.device)

        x=self.fc1(noise)
        x=F.softplus(x)
        x=self.bn1(x)

        x=self.fc2(x)
        x=F.softplus(x)
        x=self.bn2(x)

        x=self.fc3(x)
        x=torch.sigmoid(x)
        return x

class Discriminator(nn.Module):
    """Weight-normalised MLP that embeds 784-dim inputs and optionally scores real/fake.

    Args:
        output_dim: size of the embedding produced by ``fc6``.
        pretrain: if True, ``forward`` applies ReLU and the 1-unit ``out`` layer
            on top of the embedding and returns that logit instead.
        training: initial mode; Gaussian noise is injected only in training mode.
    """
    def __init__(self,output_dim,pretrain,training=True):
        super(Discriminator,self).__init__()
        self.input_dim=784
        self.pretrain=pretrain

        self.fc1=WeightNorm(self.input_dim,1000)
        self.fc2=WeightNorm(1000,500)
        self.fc3=WeightNorm(500,250)
        self.fc4=WeightNorm(250,250)
        self.fc5=WeightNorm(250,250)
        self.fc6=WeightNorm(250,output_dim,weight_scale=1)
        self.out=WeightNorm(output_dim,1,weight_scale=1)

        # Use train() rather than assigning self.training so the flag stays
        # consistent with nn.Module's own train()/eval() handling.
        self.train(bool(training))

    def _perturb(self,x,std):
        """Add zero-mean Gaussian noise with standard deviation ``std`` in training mode only."""
        if self.training:
            # randn_like samples on x's own device/dtype and carries no gradient.
            x=x+torch.randn_like(x)*std
        return x

    def forward(self,x):
        """Return ``(output, features)``.

        ``output`` is the ``fc6`` embedding, or the single logit from ``out``
        when ``pretrain`` is True. ``features`` is the noise-free ReLU
        activation of ``fc5``.
        """
        x=x.view(-1,self.input_dim)
        x=self._perturb(x,0.3)

        x=self._perturb(F.relu(self.fc1(x)),0.5)
        x=self._perturb(F.relu(self.fc2(x)),0.5)
        x=self._perturb(F.relu(self.fc3(x)),0.5)
        x=self._perturb(F.relu(self.fc4(x)),0.5)

        x1=F.relu(self.fc5(x))
        # Feed the ReLU output (not the pre-activation) into fc6, exactly like
        # every other hidden layer.
        x=self._perturb(x1,0.5)

        x=self.fc6(x)
        if self.pretrain:
            x=F.relu(x)
            x=self.out(x)
        return x,x1

In [0]:
# MNIST training split, stored in (and downloaded to, if missing) the current working directory.
train_set=datasets.MNIST(root=os.getcwd(),train=True,download=True)

In [0]:
# Take images and labels out of the dataset as NumPy arrays; dividing by 255.0 yields float pixel values.
x_train=train_set.data.numpy()/255.0
y_train=train_set.targets.numpy()
print(x_train.shape,y_train.shape)

In [0]:
# Hyper-parameters shared by the cells below.
batch_size=64        # samples per mini-batch in both training stages
noise_dim=100        # length of the generator's latent noise vector
output_dim=32        # size of the discriminator's fc6 embedding (also part of checkpoint file names)
learning_rate=0.003  # Adam learning rate for the second training stage (pre-training hard-codes its own)
n_classes=10         # number of label values iterated over when sampling and scoring
n_samples=10         # labelled samples kept per class for triplet construction
epochs=100           # epochs of the second training stage
img_size=28          # NOTE(review): not referenced in the visible cells — confirm it is still needed

In [0]:
# Number of full mini-batches per epoch; integer division drops any remainder that does not fill a batch.
n_batches_train=len(x_train)//batch_size

Pre-train GAN

In [0]:
# Fresh networks for adversarial pre-training; pretrain=True makes the discriminator output a single logit.
generator=Generator(noise_dim=noise_dim,training=True).to(device)
discriminator=Discriminator(output_dim=output_dim,pretrain=True,training=True).to(device)

# Print each architecture followed by its trainable parameter count, separated by a blank line.
for position,model in enumerate((generator,discriminator)):
    if position>0:
        print()
    print(model)
    total_params=sum(tensor.numel() for tensor in model.parameters() if tensor.requires_grad)
    print('total_params:',total_params)

In [0]:
# One Adam optimiser per network (lr 0.0003) and a logit-based binary cross-entropy for each side.
gen_optim,dis_optim=[torch.optim.Adam(net.parameters(),lr=0.0003) for net in (generator,discriminator)]
gen_criterion,dis_criterion=nn.BCEWithLogitsLoss(),nn.BCEWithLogitsLoss()

In [0]:
# Shuffle the training set once before pre-training.
train_idx=list(range(0,x_train.shape[0]))
shuffle(train_idx)
x_train=x_train[train_idx]
# Apply the same permutation to the labels so x_train[k] and y_train[k]
# keep referring to the same sample in later cells.
y_train=y_train[train_idx]
# as_tensor accepts both a NumPy array and an existing tensor, so the cell
# can be re-run without failing on an already-converted x_train.
x_train=torch.as_tensor(x_train,dtype=torch.float32,device=device)

In [0]:
# Adversarial pre-training: alternate one discriminator update and one generator update per mini-batch.
for epoch in range(10):
    start_time=time.time()
    epoch_d_loss,epoch_g_loss=0,0
    for i in range(n_batches_train):
        batch_start=i*batch_size
        x_train_batch=x_train[batch_start:batch_start+batch_size]

        # Discriminator step, real half: target 0.9 instead of 1 (smoothed label).
        discriminator.zero_grad()
        label=(torch.ones(batch_size)*0.9).to(device)
        output=discriminator(x_train_batch)[0].view(-1)
        d_loss_real=dis_criterion(output,label)
        d_loss_real.backward()

        # Discriminator step, fake half: detach so no gradient reaches the generator.
        fake_imgs_batch=generator(batch_size)
        label=torch.zeros(batch_size).to(device)
        output=discriminator(fake_imgs_batch.detach())[0].view(-1)
        d_loss_fake=dis_criterion(output,label)
        d_loss_fake.backward()
        epoch_d_loss+=d_loss_real.item()+d_loss_fake.item()
        dis_optim.step()

        # Generator step: new samples, trained to be scored as real.
        generator.zero_grad()
        fake_imgs_batch=generator(batch_size)
        label=(torch.ones(batch_size)*0.9).to(device)
        output=discriminator(fake_imgs_batch)[0].view(-1)
        g_loss=gen_criterion(output,label)
        g_loss.backward()
        epoch_g_loss+=g_loss.item()
        gen_optim.step()

    # Report per-batch averages for the epoch.
    epoch_d_loss=epoch_d_loss/n_batches_train
    epoch_g_loss=epoch_g_loss/n_batches_train
    print(f'Epoch {epoch+1}/{10} epoch_duration: {time.time()-start_time}s g_loss: {epoch_g_loss} d_loss: {epoch_d_loss}')
    # Checkpoint both networks every fifth epoch (the file is overwritten each time).
    if (epoch+1)%5==0:
        for tag,net in (('generator',generator),('discriminator',discriminator)):
            torch.save(net.state_dict(),os.getcwd()+'/pre_train_'+tag+'_'+str(output_dim)+'.pt')

Pre-train end

In [0]:
# Rebuild both networks and restore the pre-trained weights; pretrain=False makes
# the discriminator return its fc6 embedding instead of the real/fake logit.
generator=Generator(noise_dim=noise_dim,training=True).to(device)
generator.load_state_dict(
    torch.load(os.getcwd()+'/pre_train_generator_'+str(output_dim)+'.pt')
)
discriminator=Discriminator(output_dim=output_dim,pretrain=False,training=True).to(device)
discriminator.load_state_dict(
    torch.load(os.getcwd()+'/pre_train_discriminator_'+str(output_dim)+'.pt')
)

# Print each architecture followed by its trainable parameter count, separated by a blank line.
for position,model in enumerate((generator,discriminator)):
    if position>0:
        print()
    print(model)
    total_params=sum(tensor.numel() for tensor in model.parameters() if tensor.requires_grad)
    print('total_params:',total_params)

In [0]:
# Second stage: Adam with beta1=0.5 for both networks; the generator matches
# discriminator features with MSE, the discriminator uses the combined loss.
gen_optim,dis_optim=[
    torch.optim.Adam(net.parameters(),lr=learning_rate,betas=(0.5,0.999))
    for net in (generator,discriminator)
]
gen_criterion=nn.MSELoss()
dis_criterion=discriminator_loss

In [0]:
# Rebuild the NumPy arrays straight from the dataset so this cell does not
# depend on what earlier cells did to x_train / y_train (the pre-training cell
# turns x_train into a tensor and reorders it).
x_train=train_set.data.numpy()/255.0
y_train=train_set.targets.numpy()

# Shuffle images and labels with one shared permutation.
train_idx=list(range(0,x_train.shape[0]))
shuffle(train_idx)
x_train=x_train[train_idx]
y_train=y_train[train_idx]

# Second, independently ordered copy of the images; the training loop feeds it
# to the generator's feature-matching loss.
shuffle(train_idx)
x_train_1=x_train[train_idx]

# Keep n_samples labelled images per class. x_train is already in random
# order, so the first n_samples of each class are a uniform random draw.
# (random.shuffle must not be applied to a multi-dimensional array: its
# element swaps go through views and end up duplicating rows.)
samples_x={}
for i in range(n_classes):
    samples_x[i]=x_train[y_train==i][:n_samples].copy()

x_train=torch.from_numpy(x_train).type(torch.FloatTensor).to(device)
y_train=torch.from_numpy(y_train).to(device)
x_train_1=torch.from_numpy(x_train_1).type(torch.FloatTensor).to(device)

In [0]:
# Second training stage: the discriminator learns a triplet embedding from the
# few labelled samples plus a real/fake term, and the generator matches
# discriminator features of real images.
g_loss_list=[]
d_loss_list=[]
for epoch in range(epochs):
    start_time=time.time()
    anchor=[]
    pos=[]
    neg=[]
    epoch_g_loss=0
    epoch_d_loss=0

    # ---- build this epoch's triplets from the labelled samples ----
    for i in range(n_classes):
        class_samples=samples_x[i]
        # Every labelled sample that belongs to a different class.
        neg_pool=np.array([sample for j in range(n_classes) if j!=i for sample in samples_x[j]])
        for draw in range(class_samples.shape[0]*(n_classes*10-10)):
            # Draw n_samples negatives from the FULL pool on every iteration;
            # the pool itself is never overwritten, so all other-class samples
            # stay available for later draws.
            neg_idx=list(range(0,neg_pool.shape[0]))
            shuffle(neg_idx)
            # Positives are the same class samples in a random order.
            pos_idx=list(range(0,n_samples))
            shuffle(pos_idx)
            anchor.append(class_samples)
            pos.append(class_samples[pos_idx])
            neg.append(neg_pool[neg_idx[:n_samples]])

    anchor=np.concatenate(anchor,axis=0)
    pos=np.concatenate(pos,axis=0)
    neg=np.concatenate(neg,axis=0)
    # Shuffle the triplets and keep as many as there are training images, so
    # triplet batches and image batches can be indexed in lock-step.
    train_idx=list(range(0,anchor.shape[0]))
    shuffle(train_idx)
    train_idx=train_idx[0:x_train.shape[0]]
    anchor=anchor[train_idx]
    pos=pos[train_idx]
    neg=neg[train_idx]

    anchor=torch.from_numpy(anchor).type(torch.FloatTensor).to(device)
    pos=torch.from_numpy(pos).type(torch.FloatTensor).to(device)
    neg=torch.from_numpy(neg).type(torch.FloatTensor).to(device)

    for i in range(n_batches_train):
        anchor_batch=anchor[i*batch_size:(i+1)*batch_size]
        pos_batch=pos[i*batch_size:(i+1)*batch_size]
        neg_batch=neg[i*batch_size:(i+1)*batch_size]
        x_train_batch=x_train[i*batch_size:(i+1)*batch_size]
        x_train_1_batch=x_train_1[i*batch_size:(i+1)*batch_size]

        # ---- discriminator step ----
        # Detach the fake batch: the generator is not updated here, so there
        # is no need to backpropagate through it.
        fake_imgs_batch=generator(batch_size).detach()

        anchor_batch_net,_=discriminator(anchor_batch)
        pos_batch_net,_=discriminator(pos_batch)
        neg_batch_net,_=discriminator(neg_batch)
        real_batch_net,_=discriminator(x_train_batch)
        fake_batch_net,_=discriminator(fake_imgs_batch)
        d_loss=dis_criterion(anchor_batch_net,pos_batch_net,neg_batch_net,real_batch_net,fake_batch_net)
        dis_optim.zero_grad()
        d_loss.backward()
        dis_optim.step()
        epoch_d_loss+=d_loss.item()

        # ---- generator step: match fc5 features of fake and real images ----
        fake_imgs_batch=generator(batch_size)
        _,fake_imgs_batch_net=discriminator(fake_imgs_batch)
        _,x_train_1_batch_net=discriminator(x_train_1_batch)
        g_loss=gen_criterion(fake_imgs_batch_net,x_train_1_batch_net)
        gen_optim.zero_grad()
        dis_optim.zero_grad()
        g_loss.backward()
        gen_optim.step()
        epoch_g_loss+=g_loss.item()

    epoch_d_loss=epoch_d_loss/n_batches_train
    epoch_g_loss=epoch_g_loss/n_batches_train
    print('Epoch '+str(epoch+1)+'/'+str(epochs)+' epoch_duration: '+str(time.time()-start_time)+'s'+' g_loss: '+str(epoch_g_loss)+' d_loss: '+str(epoch_d_loss))
    g_loss_list.append(epoch_g_loss)
    d_loss_list.append(epoch_d_loss)
    # Every fifth epoch: overwrite the checkpoints and the loss-history file.
    if (epoch+1)%5==0:
        torch.save(generator.state_dict(),os.getcwd()+'/generator_'+str(n_samples*n_classes)+'_'+str(output_dim)+'.pt')
        torch.save(discriminator.state_dict(),os.getcwd()+'/discriminator_'+str(n_samples*n_classes)+'_'+str(output_dim)+'.pt')
        with h5py.File('metrics_list_'+str(n_samples*n_classes)+'_'+str(output_dim)+'.h5','w') as out:
            out.create_dataset("generator",data=np.array(g_loss_list))
            out.create_dataset("discriminator",data=np.array(d_loss_list))

In [0]:
# Read the saved loss histories back as plain Python lists. The context manager
# closes the HDF5 file even if one of the reads raises.
with h5py.File('metrics_list_'+str(n_samples*n_classes)+'_'+str(output_dim)+'.h5','r') as metrics:
    g_loss_list=np.array(metrics['generator']).tolist()
    d_loss_list=np.array(metrics['discriminator']).tolist()

In [0]:
# One figure per loss curve, generator first, then discriminator.
for curve_name,curve in (('gen_loss',g_loss_list),('dis_loss',d_loss_list)):
    plt.figure()
    plt.xlabel('epochs')
    plt.ylabel('loss')
    plt.title(curve_name+' vs epochs')
    plt.plot(curve)

Evaluation

In [0]:
# Embedding-only discriminator in evaluation mode, restored from the latest second-stage checkpoint.
# It is not moved to `device`, so it runs wherever it was constructed.
discriminator=Discriminator(output_dim=output_dim,pretrain=False,training=False)
discriminator.load_state_dict(
    torch.load(
        os.getcwd()+'/discriminator_'+str(n_samples*n_classes)+'_'+str(output_dim)+'.pt',
        map_location=device,
    )
)

In [0]:
# Restore the trained generator and draw 64 samples without tracking gradients.
generator=Generator(noise_dim=noise_dim,training=False).to(device)
generator.load_state_dict(
    torch.load(
        os.getcwd()+'/generator_'+str(n_samples*n_classes)+'_'+str(output_dim)+'.pt',
        map_location=device,
    )
)
with torch.no_grad():
    images=generator(64)
images=images.detach().cpu().numpy()

# Show the samples as an 8x8 grid of 28x28 grayscale images, filled row by row.
fig,ax=plt.subplots(8,8)
for count,panel in enumerate(ax.flat):
    panel.imshow(images[count].reshape((28,28)),cmap='gray')

In [0]:
# Reload the training images (as a float tensor after dividing by 255.0) and labels, in dataset order, for evaluation.
x_train=train_set.data/255.0
y_train=train_set.targets
print(x_train.shape,y_train.shape)

In [0]:
# Embed the training images with the discriminator, batch by batch.
discriminator.eval()
batch_size=50
feature_chunks=[]
# no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    # Stepping by batch_size also covers a final partial batch, so the number
    # of feature rows always equals the number of images.
    for start in range(0,len(x_train),batch_size):
        features_pred,_=discriminator(x_train[start:start+batch_size])
        feature_chunks.append(features_pred.cpu().numpy())
# Concatenate once at the end instead of re-copying the growing array each batch.
features_train=np.concatenate(feature_chunks,axis=0)

print(features_train.shape)

In [0]:
# 9-nearest-neighbour classifier fitted on the training embeddings and their labels.
knn=KNN(n_neighbors=9)
knn.fit(X=features_train,y=y_train.detach().cpu().numpy())

In [0]:
# MNIST test split from the working directory; images become a float tensor after dividing by 255.0.
test_set=datasets.MNIST(root=os.getcwd(),train=False,download=True)
x_test=test_set.data/255.0
y_test=test_set.targets
print(x_test.shape,y_test.shape)

In [0]:
# Embed the test images with the discriminator, batch by batch.
feature_chunks=[]
# no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    # Stepping by batch_size also covers a final partial batch, so the number
    # of feature rows always equals the number of images.
    for start in range(0,len(x_test),batch_size):
        features_pred,_=discriminator(x_test[start:start+batch_size])
        feature_chunks.append(features_pred.cpu().numpy())
# Concatenate once at the end instead of re-copying the growing array each batch.
features_test=np.concatenate(feature_chunks,axis=0)

print(features_test.shape)

In [0]:
# Convert the test labels from a tensor to a NumPy array for scikit-learn.
# NOTE(review): not idempotent — on a second run y_test is already an ndarray and has no .detach().
y_test=y_test.detach().cpu().numpy()

In [0]:
# Classify the test embeddings and score them.
pred=knn.predict(features_test)
acc=accuracy_score(y_test,pred)

# Mean average precision over one-vs-rest problems. For class i both the truth
# and the hard prediction are turned into 0/1 indicator vectors, so AP is
# computed on binary predictions rather than on continuous scores.
mAP=0
for i in range(n_classes):
    y2=(y_test==i).astype(y_test.dtype)
    y1=(pred==i).astype(pred.dtype)
    ap=average_precision_score(y2,y1)
    mAP+=ap

print('Accuracy: '+str(acc*100)+'%')
print('mAP: '+str(mAP/n_classes))