# This notebook implements a Siamese network for object classification on the LV data

In [None]:
import numpy as np, glob, time
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms, models, datasets
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# This cell and the next one install two libraries that this notebook needs.
# They are required only because of the way the code here is implemented and are
# not necessary for the Siamese network itself.
!pip install torch_snippets

In [None]:
!pip install jsonlines

In [None]:
# This dataset yields triplets: the first two elements are images and the third is 0
# when both images come from the same class folder (e.g. both real or both fake) and 1 otherwise.
from torch_snippets import *

class SiameseNetworkDataset(Dataset):
    """Dataset of image pairs for a Siamese network.

    Items are discovered with the pattern ``{folder}/*/*``, i.e. ``folder`` is
    expected to hold one sub-directory per class with the image files inside.

    Each sample is a triplet ``(imgA, imgB, label)``. ``label`` is a
    one-element array containing ``0`` when both images come from the same
    sub-directory and ``1`` when they come from different sub-directories.
    The partner image is drawn at random on every access, so the same index
    can produce different pairs.

    Args:
        folder: Root directory of the split.
        transform: Optional callable applied independently to both images.
        should_invert: Unused; kept so existing callers keep working.
    """

    def __init__(self, folder, transform=None, should_invert=True):
        self.folder = folder
        self.items = Glob(f'{self.folder}/*/*')
        self.transform = transform
        # Names of every class directory that owns at least one item; used to
        # detect the case where a different-class partner cannot exist.
        self._classes = {fname(parent(item)) for item in self.items}
        # Per-class file listings, filled lazily so each class directory is
        # globbed once instead of once per sample.
        self._class_items = {}

    def _same_class_items(self, auth):
        """Return (and memoise) the file listing of class directory ``auth``."""
        if auth not in self._class_items:
            self._class_items[auth] = Glob(f'{self.folder}/{auth}/*', silent=True)
        return self._class_items[auth]

    def __getitem__(self, ix):
        itemA = self.items[ix]
        auth = fname(parent(itemA))
        # presumably a random 0/1 draw (torch_snippets `randint`) — a truthy
        # value means "pair with an image of the same class".
        same_auth = randint(2)
        if same_auth:
            itemB = choose(self._same_class_items(auth))
        else:
            if len(self._classes) < 2:
                # The rejection loop below could never terminate.
                raise ValueError(
                    f'cannot sample a different-class pair: {self.folder!r} '
                    'contains fewer than two class directories'
                )
            # Rejection sampling: redraw until the partner is from another class.
            while True:
                itemB = choose(self.items)
                if auth != fname(parent(itemB)):
                    break
        imgA = read(itemA)
        imgB = read(itemB)
        if self.transform:
            imgA = self.transform(imgA)
            imgB = self.transform(imgB)
        # Label convention: 0 = same class, 1 = different class.
        return imgA, imgB, np.array([1-same_auth])

    def __len__(self):
        """Number of anchor images (one sample per discovered file)."""
        return len(self.items)

In [None]:
from torchvision import transforms

def _build_tfms(augmentations=()):
    """Compose the image preprocessing pipeline shared by all splits.

    The steps are: convert to a PIL image, convert to grayscale, apply any
    ``augmentations``, resize to 256x256, convert to a float tensor and
    normalise with mean 0.5 and std 0.5.
    """
    return transforms.Compose([
        transforms.ToPILImage(),
        transforms.Grayscale(),
        *augmentations,
        transforms.Resize((256,256)),
        transforms.ToTensor(),
        transforms.ConvertImageDtype(torch.float),
        transforms.Normalize((0.5), (0.5)),
    ])


# Training pipeline. Augmentations tried earlier are currently disabled:
#   transforms.RandomHorizontalFlip()
#   transforms.RandomAffine(5, (0.01,0.2),scale=(0.9,1.1))
trn_tfms = _build_tfms()

# Validation pipeline: same deterministic preprocessing as training.
val_tfms = _build_tfms()

In [None]:
# Each split folder holds one sub-directory per class: `r` directories contain
# real images and `f` directories contain fake images.
trn_ds = SiameseNetworkDataset(folder="./drive/MyDrive/LV_data/train", transform=trn_tfms)
val_ds = SiameseNetworkDataset(folder="./drive/MyDrive/LV_data/val", transform=val_tfms)

# Training batches are reshuffled every epoch; validation order is kept fixed.
trn_dl = DataLoader(trn_ds, shuffle=True, batch_size=64)
val_dl = DataLoader(val_ds, shuffle=False, batch_size=64)

In [None]:
def convBlock(ni, no):
    """Build a dropout -> 3x3 convolution -> ReLU -> batch-norm stack.

    Args:
        ni: Number of input channels.
        no: Number of output channels.

    Returns:
        An ``nn.Sequential`` with four layers. The convolution uses a 3x3
        kernel with reflect padding of 1, so height and width are preserved.
    """
    layers = [
        nn.Dropout(p=0.2),
        nn.Conv2d(
            in_channels=ni,
            out_channels=no,
            kernel_size=3,
            padding=1,
            padding_mode='reflect',
        ),
        nn.ReLU(inplace=True),
        nn.BatchNorm2d(num_features=no),
    ]
    return nn.Sequential(*layers)

In [None]:
class SiameseNetwork(nn.Module):
    """Twin network that embeds two images with one shared encoder.

    The encoder is three ``convBlock`` stages (1 -> 4 -> 8 -> 8 channels)
    followed by a flatten and a three-layer MLP producing a 10-dimensional
    embedding. The first linear layer expects ``8*256*256`` features, i.e.
    single-channel 256x256 inputs (assuming ``convBlock`` keeps the spatial
    size).
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as they are applied, so the
        # Sequential indices (and therefore state-dict keys) stay stable.
        conv_stages = [convBlock(1,4), convBlock(4,8), convBlock(8,8)]
        embedding_head = [
            nn.Flatten(),
            nn.Linear(8*256*256, 500),
            nn.ReLU(inplace=True),
            nn.Linear(500, 500),
            nn.ReLU(inplace=True),
            nn.Linear(500, 10),
        ]
        self.features = nn.Sequential(*conv_stages, *embedding_head)

    def forward(self, input1, input2):
        """Embed both inputs with the shared encoder.

        Returns:
            Tuple ``(embedding_of_input1, embedding_of_input2)``.
        """
        return self.features(input1), self.features(input2)

In [None]:
# ContrastiveLoss is the criterion used for training; a newer alternative loss
# (additive-margin softmax) is implemented in the cell below.
class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss function.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    Label convention: ``0`` marks a pair that should be close (its squared
    distance is penalised) and ``1`` marks a pair that should be at least
    ``margin`` apart (the squared shortfall below the margin is penalised).

    Args:
        margin: Minimum distance required between embeddings of a
            label-1 pair before that pair stops contributing to the loss.
        threshold: Distance above which a pair is predicted as label ``1``
            when computing the accuracy. Defaults to the previously
            hard-coded ``0.6``.
    """

    def __init__(self, margin=2.0, threshold=0.6):
        super().__init__()
        self.margin = margin
        self.threshold = threshold

    def forward(self, output1, output2, label):
        """Compute the loss and the thresholded pair accuracy.

        Args:
            output1: Embeddings of the first images, one row per pair.
            output2: Embeddings of the second images, same shape.
            label: One 0/1 value per pair; any shape that can be reshaped
                to the per-pair distance tensor is accepted.

        Returns:
            Tuple ``(loss, acc)`` of scalar tensors.
        """
        euclidean_distance = F.pairwise_distance(output1, output2, keepdim=True)
        # Align the label with the (B, 1) distance tensor. Without this a
        # 1-D label would broadcast to (B, B) and silently corrupt both
        # the loss and the accuracy.
        label = label.reshape(euclidean_distance.shape)
        pull_term = (1 - label) * torch.pow(euclidean_distance, 2)
        push_term = label * torch.pow(
            torch.clamp(self.margin - euclidean_distance, min=0.0), 2
        )
        loss_contrastive = torch.mean(pull_term + push_term)
        acc = ((euclidean_distance > self.threshold) == label).float().mean()
        return loss_contrastive, acc

In [None]:
import math
from torch.nn import Parameter

def l2_norm(input,axis=1):
    """Divide ``input`` by its L2 norm taken along ``axis``.

    The norm keeps the reduced dimension so the division broadcasts back over
    ``input``. No epsilon is added: a slice whose norm is zero yields NaN.
    """
    return input / input.norm(p=2, dim=axis, keepdim=True)

class Am_softmax(torch.nn.Module):
    """Additive-margin softmax head (https://arxiv.org/abs/1801.05599).

    Holds a learnable ``(embedding_size, classnum)`` kernel whose columns act
    as class directions. ``forward`` returns scaled logits, not a loss value.

    Args:
        embedding_size: Dimensionality of the incoming embeddings.
        classnum: Number of classes.
    """

    def __init__(self,embedding_size=10,classnum=2):
        super().__init__()
        self.classnum = classnum
        self.kernel = Parameter(torch.empty(embedding_size, classnum))
        # Initial kernel: uniform in [-1, 1], then each column is rescaled to
        # (approximately) unit L2 norm via renorm to 1e-5 followed by *1e5.
        self.kernel.data.uniform_(-1, 1).renorm_(2,1,1e-5).mul_(1e5)
        self.m = 0.35 # additive margin recommended by the paper
        self.s = 30. # see normface https://arxiv.org/abs/1704.06369

    def forward(self,embbedings,label):
        """Return margin-adjusted, scaled logits and an accuracy value.

        Args:
            embbedings: ``(B, embedding_size)`` tensor. Only the kernel is
                normalised here, so the product is a true cosine only when
                the embeddings already have unit norm.
            label: Integer class indices, one per row; reshaped to ``(B, 1)``.

        Returns:
            Tuple ``(output, acc)`` where ``output`` is ``(B, classnum)``.
        """
        kernel_norm = l2_norm(self.kernel,axis=0)
        cos_theta = torch.mm(embbedings,kernel_norm)
        cos_theta = cos_theta.clamp(-1,1) # for numerical stability
        phi = cos_theta - self.m
        label = label.view(-1,1) #size=(B,1)
        # One-hot boolean mask of the target class, size=(B,Classnum).
        # A bool mask replaces the deprecated uint8 (`.byte()`) mask indexing.
        target_mask = torch.zeros_like(cos_theta).scatter_(1, label, 1.0).bool()
        # Subtract the margin only from the target-class cosine.
        output = torch.where(target_mask, phi, cos_theta)
        # Scale up in order to make softmax work, first introduced in normface.
        output = output * self.s
        # NOTE(review): this compares every class column of `cos_theta > 0.9`
        # against the class index via broadcasting, which does not look like
        # a classification accuracy — confirm the intended metric.
        acc = ((cos_theta > 0.9) == label).float().mean()
        return output, acc

In [None]:
def train_batch(model, data, optimizer, criterion):
    """Run one optimisation step on a single batch.

    Args:
        model: Network called as ``model(imgsA, imgsB)`` and returning two
            embedding tensors.
        data: Iterable of three tensors ``(imgsA, imgsB, labels)``.
        optimizer: Optimizer that owns ``model``'s parameters.
        criterion: Callable ``criterion(codesA, codesB, labels)`` returning
            ``(loss, acc)`` tensors.

    Returns:
        Tuple ``(loss, acc)`` as Python numbers.
    """
    # Make sure dropout / batch-norm layers are in training mode even if a
    # previous evaluation step switched the model to eval mode.
    model.train()
    imgsA, imgsB, labels = [t.to(device) for t in data]
    optimizer.zero_grad()
    codesA, codesB = model(imgsA, imgsB)
    loss, acc = criterion(codesA, codesB, labels)
    loss.backward()
    optimizer.step()
    return loss.item(), acc.item()

@torch.no_grad()
def validate_batch(model, data, criterion):
    """Evaluate a single batch without tracking gradients.

    The model is put in eval mode for the forward pass so that dropout is
    disabled and batch-norm uses (and does not update) its running
    statistics. The previous train/eval mode is restored afterwards, so the
    caller's training loop is unaffected.

    Args:
        model: Network called as ``model(imgsA, imgsB)``.
        data: Iterable of three tensors ``(imgsA, imgsB, labels)``.
        criterion: Callable returning ``(loss, acc)`` tensors.

    Returns:
        Tuple ``(loss, acc)`` as Python numbers.
    """
    was_training = model.training
    model.eval()
    try:
        imgsA, imgsB, labels = [t.to(device) for t in data]
        codesA, codesB = model(imgsA, imgsB)
        loss, acc = criterion(codesA, codesB, labels)
    finally:
        # Hand the model back in the mode it arrived in.
        model.train(was_training)
    return loss.item(), acc.item()

In [None]:
# Instantiate the network on the selected device and pair it with the
# contrastive loss and an Adam optimizer over the network's parameters.
model = SiameseNetwork().to(device)
criterion = ContrastiveLoss()
# Alternative criterion, currently disabled.
# NOTE(review): Am_softmax.forward takes (embeddings, label) and returns
# (logits, acc), which does not match the criterion(codesA, codesB, labels)
# -> (loss, acc) call made by train_batch/validate_batch.
#criterion = Am_softmax()
optimizer = optim.Adam(model.parameters(),lr = 0.001)

In [None]:
from torch.optim import lr_scheduler

# Initial training run: every epoch trains on all training batches and then
# evaluates on all validation batches, logging per-batch metrics.
n_epochs = 500
log = Report(n_epochs)
scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

for epoch in range(n_epochs):
    N = len(trn_dl)
    for i, data in enumerate(trn_dl):
        loss, acc = train_batch(model, data, optimizer, criterion)
        # Log at a fractional epoch position: epoch + progress through the loader.
        log.record(epoch+(1+i)/N, trn_loss=loss, trn_acc=acc, end='\r')
    N = len(val_dl)
    for i, data in enumerate(val_dl):
        loss, acc = validate_batch(model, data, criterion)
        log.record(epoch+(1+i)/N, val_loss=loss, val_acc=acc, end='\r')
    # Every 50 epochs: report averaged metrics (presumably printed by torch_snippets' Report).
    if (epoch+1)%50==0: log.report_avgs(epoch+1)
    # Every 100 epochs: checkpoint the whole model object.
    # NOTE(review): the file name says "Amloss" although `criterion` is
    # ContrastiveLoss in this notebook — confirm the intended checkpoint name.
    if (epoch+1)%100==0: torch.save(model, '/content/drive/MyDrive/Rebag_Siamese/Rebag_Siamese_Amloss.pth')
    # The scheduler is stepped only every 100 epochs, so the learning rate is
    # multiplied by gamma=0.9 once per 100 epochs rather than once per epoch.
    if (epoch+1)%100==0: scheduler.step()
    #if epoch==10: optimizer = optim.Adam(model.parameters(), lr=0.0005)


In [None]:
# Test-set results: TP = 10, FP = 10, TN = 59, FN = 48 -> sensitivity = 17%, specificity = 85%

In [None]:
# Persist the whole model object (not just its state_dict) to the Drive path below.
torch.save(model, '/content/drive/MyDrive/Rebag_Siamese/Rebag_Siamese.pth')

In [None]:
# Reload the pickled model. `map_location=device` lets a checkpoint saved on a
# GPU be restored on a CPU-only runtime as well.
# NOTE: this unpickles a full model object — only load trusted files. Newer
# PyTorch releases may additionally require `weights_only=False` for this kind
# of checkpoint; confirm against the installed version.
Rebag_model = torch.load('/content/drive/MyDrive/Rebag_Siamese/Rebag_Siamese.pth', map_location=device)

In [None]:
# Continue training the reloaded model for another 300 epochs.
n_epochs = 300
log = Report(n_epochs)

# The optimizer must own the parameters of the model that is being trained.
# The earlier `optimizer` was built from `model.parameters()`, so stepping it
# would never update `Rebag_model`; build a dedicated optimizer instead.
rebag_optimizer = optim.Adam(Rebag_model.parameters(), lr=0.001)
scheduler = lr_scheduler.ExponentialLR(rebag_optimizer, gamma=0.9)

for epoch in range(n_epochs):
    N = len(trn_dl)
    for i, data in enumerate(trn_dl):
        loss, acc = train_batch(Rebag_model, data, rebag_optimizer, criterion)
        log.record(epoch+(1+i)/N, trn_loss=loss, trn_acc=acc, end='\r')
    N = len(val_dl)
    for i, data in enumerate(val_dl):
        loss, acc = validate_batch(Rebag_model, data, criterion)
        log.record(epoch+(1+i)/N, val_loss=loss, val_acc=acc, end='\r')
    if (epoch+1)%50==0: log.report_avgs(epoch+1)
    # Checkpoint the model trained in this loop (previously `model` was saved,
    # which discarded this run's progress).
    if (epoch+1)%100==0: torch.save(Rebag_model, '/content/drive/MyDrive/Rebag_Siamese/Rebag_Siamese.pth')
    # Learning rate decays by gamma once every 100 epochs.
    if (epoch+1)%100==0: scheduler.step()

EPOCH: 50.000	trn_loss: 0.041	trn_acc: 0.973	val_loss: 1.240	val_acc: 0.583	(372.29s - 1861.47s remaining)
EPOCH: 100.000	trn_loss: 0.099	trn_acc: 0.873	val_loss: 1.323	val_acc: 0.437	(743.08s - 1486.16s remaining)
EPOCH: 150.000	trn_loss: 0.125	trn_acc: 0.899	val_loss: 1.257	val_acc: 0.611	(1119.37s - 1119.37s remaining)
EPOCH: 200.000	trn_loss: 0.083	trn_acc: 0.953	val_loss: 1.402	val_acc: 0.516	(1487.80s - 743.90s remaining)
EPOCH: 250.000	trn_loss: 0.074	trn_acc: 0.919	val_loss: 1.244	val_acc: 0.539	(1862.93s - 372.59s remaining)
EPOCH: 300.000	trn_loss: 0.103	trn_acc: 0.907	val_loss: 1.291	val_acc: 0.543	(2232.61s - 0.00s remaining)


In [None]:
# Test-set results: TP = 17, FP = 17, TN = 51, FN = 42 -> sensitivity = 28%, specificity = 75%

In [None]:
# Reload the latest checkpoint. `map_location=device` lets a checkpoint saved
# on a GPU be restored on a CPU-only runtime as well.
# NOTE: this unpickles a full model object — only load trusted files.
Rebag_model800 = torch.load('/content/drive/MyDrive/Rebag_Siamese/Rebag_Siamese.pth', map_location=device)

In [None]:
# Continue training the reloaded model for another 700 epochs.
n_epochs = 700
log = Report(n_epochs)

# The optimizer must own the parameters of the model that is being trained.
# The earlier `optimizer` was built from `model.parameters()`, so stepping it
# would never update `Rebag_model800`; build a dedicated optimizer instead.
rebag800_optimizer = optim.Adam(Rebag_model800.parameters(), lr=0.001)
scheduler = lr_scheduler.ExponentialLR(rebag800_optimizer, gamma=0.9)

for epoch in range(n_epochs):
    N = len(trn_dl)
    for i, data in enumerate(trn_dl):
        loss, acc = train_batch(Rebag_model800, data, rebag800_optimizer, criterion)
        log.record(epoch+(1+i)/N, trn_loss=loss, trn_acc=acc, end='\r')
    N = len(val_dl)
    for i, data in enumerate(val_dl):
        loss, acc = validate_batch(Rebag_model800, data, criterion)
        log.record(epoch+(1+i)/N, val_loss=loss, val_acc=acc, end='\r')
    if (epoch+1)%50==0: log.report_avgs(epoch+1)
    # Checkpoint the model trained in this loop (previously `model` was saved,
    # which discarded this run's progress).
    if (epoch+1)%100==0: torch.save(Rebag_model800, '/content/drive/MyDrive/Rebag_Siamese/Rebag_Siamese.pth')
    # Learning rate decays by gamma once every 100 epochs.
    if (epoch+1)%100==0: scheduler.step()

EPOCH: 50.000	trn_loss: 0.063	trn_acc: 0.949	val_loss: 1.289	val_acc: 0.484	(370.66s - 4818.57s remaining)
EPOCH: 100.000	trn_loss: 0.054	trn_acc: 0.965	val_loss: 1.209	val_acc: 0.500	(738.30s - 4429.83s remaining)
EPOCH: 150.000	trn_loss: 0.072	trn_acc: 0.941	val_loss: 1.135	val_acc: 0.575	(1115.36s - 4089.64s remaining)
EPOCH: 200.000	trn_loss: 0.129	trn_acc: 0.849	val_loss: 1.194	val_acc: 0.497	(1484.26s - 3710.65s remaining)
EPOCH: 250.000	trn_loss: 0.071	trn_acc: 0.945	val_loss: 1.208	val_acc: 0.500	(1868.93s - 3364.07s remaining)
EPOCH: 300.000	trn_loss: 0.175	trn_acc: 0.934	val_loss: 1.144	val_acc: 0.484	(2240.48s - 2987.30s remaining)
EPOCH: 350.000	trn_loss: 0.059	trn_acc: 0.957	val_loss: 1.186	val_acc: 0.528	(2615.06s - 2615.06s remaining)
EPOCH: 400.000	trn_loss: 0.169	trn_acc: 0.899	val_loss: 1.340	val_acc: 0.473	(2987.25s - 2240.44s remaining)
EPOCH: 450.000	trn_loss: 0.154	trn_acc: 0.857	val_loss: 1.232	val_acc: 0.516	(3370.10s - 1872.28s remaining)
EPOCH: 500.000	trn_los

In [None]:
# Held-out test split, preprocessed with the validation transforms.
test_ds = SiameseNetworkDataset(folder="./drive/MyDrive/LV_data/test", transform=val_tfms)

# Batch size 1 so that each pair's prediction and label can be read with `.item()`.
test_dl = DataLoader(test_ds, shuffle=False, batch_size=1)

In [None]:
# Load the final checkpoint for evaluation. `map_location=device` puts the
# model on the same device the test tensors are moved to, and lets a
# GPU-saved checkpoint load on a CPU-only runtime.
# NOTE: this unpickles a full model object — only load trusted files.
Rebag_model1500 = torch.load('/content/drive/MyDrive/Rebag_Siamese/Rebag_Siamese.pth', map_location=device)

In [None]:
# Score every test pair: a pair is predicted "different" (1) when the distance
# between its two embeddings exceeds 0.6, otherwise "same" (0).
# Note: the dataset draws the partner image at random, so the collected
# predictions/labels vary from run to run.
labels = []
preds = []

# Inference must run in eval mode (dropout off, batch-norm using its running
# statistics — important with batch_size=1) and without autograd bookkeeping.
Rebag_model1500.eval()
with torch.no_grad():
    for imgA, imgB, label in test_dl:
        imgA, imgB, label = imgA.to(device), imgB.to(device), label.to(device)
        outA, outB = Rebag_model1500(imgA, imgB)
        euclidean_distance = F.pairwise_distance(outA, outB, keepdim = True)
        pred = (euclidean_distance > 0.6).float()
        preds.append(pred.item())
        labels.append(label.item())


In [None]:
# Display the ground-truth labels collected from the test loader.
labels

tensor([[1]], device='cuda:0')

In [None]:
def confusion(prediction, truth):
    """Count confusion-matrix entries for binary 0/1 tensors.

    Args:
        prediction: Tensor of predicted values, each 1 or 0.
        truth: Tensor of ground-truth values, each 1 or 0, same shape.

    Returns:
        Tuple ``(true_positives, false_positives, true_negatives,
        false_negatives)`` of Python ints, where the cases are
        - prediction 1, truth 1 (True Positive)
        - prediction 1, truth 0 (False Positive)
        - prediction 0, truth 0 (True Negative)
        - prediction 0, truth 1 (False Negative)
    """
    # Element-wise division maps each (prediction, truth) pair to a value
    # that identifies its case:
    #   1/1 -> 1    (True Positive)
    #   1/0 -> inf  (False Positive)
    #   0/0 -> nan  (True Negative)
    #   0/1 -> 0    (False Negative)
    ratio = prediction / truth

    case_masks = (
        ratio == 1,
        ratio == float('inf'),
        torch.isnan(ratio),
        ratio == 0,
    )
    return tuple(torch.sum(mask).item() for mask in case_masks)

In [None]:
# Confusion counts on the test set, returned as (TP, FP, TN, FN).
confusion(torch.Tensor(preds), torch.Tensor(labels))

(10, 8, 57, 52)