In [1]:
import sys
sys.path.append("../src")

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [2]:
import random

import torch
import torchvision.datasets as dataset_
import torchvision.transforms as transforms
from torch.optim import Adam
from torch.utils import data
from torch.utils.tensorboard import SummaryWriter

from dataset import TripletMNISTDataset
from losses import *
from model import *
from utills import *

In [87]:

"""
    Change this depending on the model
"""
# Size of the discriminator's output feature vector (passed to Discriminator(...)): 16 or 32.
out_dimension = 32 # 16 or 32
# Experiment tag used for the TensorBoard run directory and the checkpoint calls:
# "m_32_2" was used for the 32-feature model and "m_16" for the 16-feature model.
exp_name = "m_32_2" # i've used m_32_2 for feature 32 model and m_16 for feature 16 model



# The settings below stay the same for every experiment.
# Side length the MNIST images are resized to.
image_size = 28
# Number of channels of the noise tensor fed to the generator.
z_feautures = 100
# Number of passes over the triplet dataloader.
num_epochs = 50
# Adam learning rate, shared by both optimizers.
lr = 0.0002
# First Adam beta coefficient (the second one is fixed to 0.999 where the optimizers are built).
beta1 = 0.5
# Batch size used by both training dataloaders.
batch_size = 128

In [4]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Preprocessing applied to every MNIST image: resize to image_size, convert to a tensor,
# then normalize with mean 0.5 and std 0.5.
mnist_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((.5,), (.5,)),
])

# Plain MNIST training split, plus a triplet dataset built on top of it.
normal_mnist_dataset = dataset_.MNIST(root="./data", train=True, download=True, transform=mnist_transform)
triplet_mnist_dataset = TripletMNISTDataset(normal_mnist_dataset, num_samples=60000)

# Both loaders share the same batching / shuffling options.
loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=4)
normal_mnist_dataloader = data.DataLoader(normal_mnist_dataset, **loader_options)
triplet_mnist_dataloader = data.DataLoader(triplet_mnist_dataset, **loader_options)

In [6]:
# Build the two networks (generator first, then discriminator) and move them to the selected device.
generator_net = Generator(z_feautures)
generator_net = generator_net.to(device)
discriminator_net = Discriminator(out_dimension)
discriminator_net = discriminator_net.to(device)

# One Adam optimizer per network, both configured identically.
adam_options = {"lr": lr, "betas": (beta1, 0.999)}
discriminator_optimizer = Adam(discriminator_net.parameters(), **adam_options)
generator_optimizer = Adam(generator_net.parameters(), **adam_options)

# Fixed noise batch (32 samples), reused after every epoch so the generated images are comparable.
fixed_noise = torch.randn(32, z_feautures, 1, 1, device=device)

# TensorBoard writer for this experiment.
writer = SummaryWriter(f"./data/run/{exp_name}")

In [112]:
generator_net

Generator(
  (generator): Sequential(
    (0): ConvTranspose2d(100, 128, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.2, inplace=True)
    (6): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.2, inplace=True)
    (9): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): Tanh()
  )
)

In [113]:
discriminator_net

Discriminator(
  (discriminator): Sequential(
    (0): Conv2d(1, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Conv2d(16, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.2, inplace=True)
    (6): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.2, inplace=True)
    (9): Conv2d(128, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): Sigmoid()
  )
)

In [7]:
# Small fixed batch used to visualise the embeddings in TensorBoard:
# per_class_sample images are drawn (with replacement) for every class.
per_class_sample = 10
sampled_imgs = []
sampled_labels = []
for label in triplet_mnist_dataset.classes:
    sampled_imgs += random.choices(triplet_mnist_dataset.labels2images[label], k=per_class_sample)
    sampled_labels += [label] * per_class_sample

special_batch = {
    'imgs': torch.cat(sampled_imgs).reshape(-1, 1, image_size, image_size),
    'labels': sampled_labels,
    'features': [],  # filled in at the end of every training epoch
}

In [None]:
visualisation_imgs = []  # image grid generated from fixed_noise at the end of every epoch
generator_losses = []  # generator loss, one value per training step
discriminator_losses = []  # total discriminator loss, one value per training step
steps = 1


def _endless_batches(loader):
    """Yield batches from `loader` forever, starting a new pass whenever one is exhausted."""
    while True:
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            raise ValueError("dataloader produced no batches")


# A single long-lived stream of real batches. Reusing one iterator avoids rebuilding the
# DataLoader iterator (and its worker processes) on every training step just to read one batch.
real_batches = _endless_batches(normal_mnist_dataloader)

for epoch in range(num_epochs):
    for i, (anchor_img, positive_img, negative_img) in enumerate(triplet_mnist_dataloader):
        real_imgs, _ = next(real_batches)
        real_imgs = real_imgs.to(device)
        anchor_img = anchor_img.to(device)
        positive_img = positive_img.to(device)
        negative_img = negative_img.to(device)

        # --- Partially supervised discriminator loss: update the discriminator with the triplet loss.
        discriminator_net.zero_grad()
        anchor_out = discriminator_net(anchor_img)
        positive_out = discriminator_net(positive_img)
        negative_out = discriminator_net(negative_img)
        discriminator_loss_triplet = triplet_paper_loss(anchor_out, positive_out, negative_out)
        discriminator_loss_triplet.backward()
        discriminator_optimizer.step()

        # --- Unsupervised discriminator loss: update the discriminator on real vs. generated images.
        discriminator_net.zero_grad()
        fake_imgs = generator_net(torch.randn(real_imgs.size(0), z_feautures, 1, 1, device=device))
        real_output = discriminator_net(real_imgs)
        # detach() keeps this discriminator update from back-propagating into the generator.
        fake_output = discriminator_net(fake_imgs.detach())
        discriminator_loss_unsupervised = f_discriminator_unsupervised_loss(fake_output, real_output)
        discriminator_loss_unsupervised.backward()
        discriminator_optimizer.step()

        # Only used for logging, so sum detached values instead of tensors that still hold graphs.
        discriminator_loss_total = discriminator_loss_triplet.detach() + discriminator_loss_unsupervised.detach()

        # --- Generator update: feature matching between discriminator outputs on real and fake images.
        generator_net.zero_grad()
        discriminator_net.zero_grad()

        real_output = discriminator_net(real_imgs)
        fake_output = discriminator_net(fake_imgs)

        generator_loss = feature_matching_loss(fake_output, real_output)
        generator_loss.backward()
        generator_optimizer.step()

        # Read each loss back from the device once and reuse the Python floats below.
        generator_loss_value = generator_loss.item()
        discriminator_loss_value = discriminator_loss_total.item()

        generator_losses.append(generator_loss_value)
        discriminator_losses.append(discriminator_loss_value)

        writer.add_scalars('Loss', {
            'discriminator': discriminator_loss_value,
            'generator': generator_loss_value,
        }, steps)
        steps = steps + 1

        print(
            f"Epoch: {epoch:3d}; Iteration: {i:4d}; "
            f"Loss D: {discriminator_loss_value:0.4f}; "
            f"Loss G: {generator_loss_value:0.4f}; ",
            end="\r"
        )

    # Generate images from the fixed noise so progress can be compared across epochs.
    with torch.no_grad():
        fake_images_with_fixed_noise = generator_net(fixed_noise)
    grid = utils.make_grid(fake_images_with_fixed_noise.cpu(), padding=2, normalize=True)
    visualisation_imgs.append(grid)

    # Add the image grid to TensorBoard.
    writer.add_image("generated image", grid, steps)

    # Model graphs can be written to TensorBoard as well (disabled):
    #writer.add_graph(generator_net, fixed_noise)
    #writer.add_graph(discriminator_net, fake_images_with_fixed_noise)

    # Embed the visualisation batch in eval mode: this matches how get_features extracts
    # features later on and keeps this extra forward pass from updating the BatchNorm
    # running statistics. Training mode is restored right afterwards.
    discriminator_net.eval()
    with torch.no_grad():
        special_batch['features'] = discriminator_net(special_batch['imgs'].to(device)).cpu()
    discriminator_net.train()
    writer.add_embedding(special_batch['features'].reshape(-1, out_dimension),
                         global_step=steps,
                         metadata=special_batch['labels'],
                         label_img=special_batch['imgs'])

    # Save both networks at the end of every epoch.
    checkpoint(generator_net, "generator", exp_name)
    checkpoint(discriminator_net, "discriminator", exp_name)

# Drop the live DataLoader iterator held by the generator so its workers can shut down.
real_batches.close()

In [12]:
# Close the TensorBoard writer once training has finished.
writer.close()

In [None]:
# Plot the per-step generator and discriminator losses collected by the training loop.
# plot_losses is not defined in this notebook; presumably it comes from one of the star imports.
plot_losses(generator_losses, discriminator_losses)

In [22]:
# plot_animation(visualisation_imgs, exp_name)

## K-NN Classification

In [69]:
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import average_precision_score
from sklearn.preprocessing import label_binarize

In [70]:
# Preprocessing shared by both evaluation splits: resize, convert to tensor, normalize (mean 0.5, std 0.5).
eval_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((.5,), (.5,)),
])

# MNIST test split (scored by the classifier) and train split (used to fit it).
test_dataset = dataset_.MNIST(root="./data", train=False, download=True, transform=eval_transform)
train_dataset = dataset_.MNIST(root="./data", train=True, download=True, transform=eval_transform)

In [122]:
# Load the trained discriminator that belongs to the current experiment.
# Known checkpoints:
#   ./data/discriminator_m_32_2.model  -> feature vector of 32 (exp_name = "m_32_2")
#   ./data/discriminator_m_16.model    -> feature vector of 16 (exp_name = "m_16")
# The path is derived from exp_name so the evaluated model cannot drift from the configuration cell.
# map_location lets a checkpoint written on a GPU be opened on a CPU-only machine.
# NOTE(review): torch.load unpickles arbitrary objects - only open checkpoints you created yourself.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects a pickled full model;
# pass weights_only=False there.
discriminator_net = torch.load(f"./data/discriminator_{exp_name}.model", map_location=device).to(device)

In [123]:
def get_features(model, dataset, limit=None, batch_size=128, num_workers=4):
    """Embed `dataset` with `model` and return the feature vectors with their labels.

    Args:
        model: torch module exposing an `out_dimension` attribute; it is switched to eval mode.
        dataset: dataset yielding `(input, label)` pairs.
        limit: optional minimum number of samples to collect. Extraction stops after the
            first batch that brings the total to at least `limit`, so up to
            `batch_size - 1` extra samples may be returned. `None` processes everything.
        batch_size: number of samples per forward pass.
        num_workers: number of DataLoader worker processes.

    Returns:
        `(features, labels)`: a list of 1-D numpy arrays of length `model.out_dimension`
        and a list of labels in dataset order (plain Python values for tensor labels).
    """
    model.eval()

    # Use the device the model lives on rather than depending on a notebook-level global.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    dataloader = data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    features = []
    labels = []
    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for x, y in dataloader:
            batch_features = model(x.to(model_device)).reshape(-1, model.out_dimension)
            features.extend(batch_features.cpu().numpy())
            labels.extend(y.tolist() if hasattr(y, "tolist") else list(y))

            # Count what was actually collected instead of assuming a fixed batch size.
            if limit is not None and len(features) >= limit:
                break

    return features, labels

In [124]:
# Embed the whole test split (limit=None); for the training split stop once the limit of 200 is covered.
# That limit must be at least as large as N (set in a later cell), because the training
# features are sliced down to the first N samples before the classifier is fitted.
x_test, y_test = get_features(discriminator_net, test_dataset, None)
x_train, y_train = get_features(discriminator_net, train_dataset, 200)

In [131]:
# Number of training samples used to fit the k-NN classifier.
# Keep this no larger than the limit passed to get_features for the training split (200).
N = 200

In [132]:
# Keep only the first N training features and labels.
# NOTE: this overwrites x_train / y_train in place, so after raising N the feature-extraction
# cell has to be re-run first; otherwise the slice silently keeps the smaller set.
x_train, y_train = x_train[:N], y_train[:N]

In [133]:
# k-nearest-neighbours classifier that votes over the 9 closest training features.
knn= KNeighborsClassifier(n_neighbors=9)

In [134]:
# Fit the classifier on the N selected training features and their labels.
knn.fit(x_train, y_train)

KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
                     metric_params=None, n_jobs=None, n_neighbors=9, p=2,
                     weights='uniform')

In [135]:
# Mean accuracy of the classifier on the test-split features.
knn.score(x_test, y_test)

0.9903

In [None]:
# One-hot encode the test labels over the ten digit classes, then compute the
# sample-averaged average precision of the classifier's predicted class probabilities.
class_ids = list(range(10))
y_test_binary = label_binarize(y_test, classes=class_ids)
test_probabilities = knn.predict_proba(x_test)
average_precision_score(y_test_binary, test_probabilities, average="samples")

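### Results

Each entry reads `model: number of training samples (N): k-NN accuracy, average precision`,
where M32 and M16 are the models with 32- and 16-dimensional feature vectors.

- M32: 100: 0.9903, 0.992535
- M32: 200: 0.9911, 0.9908

- M16: 100: 0.989, 0.99151
- M16: 200: 0.9898, 0.99178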