In [21]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import random
from PIL import Image
import PIL.ImageOps

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import torchvision.utils
import torch
from torch.autograd import Variable
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import classification_report
from torchsummary import summary
from torchvision import models
import os

In [22]:
class SiameseNetworkDataset(Dataset):
    """Dataset of randomly drawn image pairs for training a Siamese network.

    Parameters
    ----------
    imageFolderDataset : indexable sequence of ``(PIL image, class_index)``
        The "true" images (e.g. an untransformed ``ImageFolder``). The first
        image of every pair (the anchor) is drawn from here.
    false_dataset : indexable sequence of ``(PIL image, class_index)``
        The "false" images; negative pairs take their second image from here.
        Class indices are compared against those of ``imageFolderDataset``, so
        both must use the same index-to-stage mapping.
    transform : callable, optional
        Applied to each image after RGB conversion / transposition.
    same_stage_negative_prob : float, default 0.75
        Probability that a negative pair's false image is drawn from the
        anchor's own stage instead of from any stage.

    Each item is ``(img0, img1, label)`` where ``label`` is a float32 tensor of
    shape ``(1,)``: ``0.0`` for a positive pair (two true images of the same
    stage) and ``1.0`` for a negative pair (true anchor + false image).
    The ``index`` argument is ignored; every access draws a fresh random pair.
    """

    def __init__(self, imageFolderDataset, false_dataset, transform=None,
                 same_stage_negative_prob=0.75):
        self.imageFolderDataset = imageFolderDataset
        self.false_dataset = false_dataset
        self.transform = transform
        self.same_stage_negative_prob = same_stage_negative_prob
        # Group sample indices by class once, so partners can be chosen by label
        # without decoding images just to look at their class.
        self._true_by_label = self._indices_by_label(imageFolderDataset)
        self._false_by_label = self._indices_by_label(false_dataset)

    @staticmethod
    def _indices_by_label(dataset):
        """Map each class index to the list of sample indices carrying it."""
        labels = getattr(dataset, "targets", None)  # ImageFolder exposes .targets
        if labels is None:
            # Generic sequence of (image, label) tuples: read the labels directly.
            labels = [item[1] for item in dataset]
        by_label = {}
        for idx, label in enumerate(labels):
            by_label.setdefault(label, []).append(idx)
        return by_label

    @staticmethod
    def _to_portrait_rgb(img):
        """Convert to RGB and transpose landscape images so that height >= width."""
        img = img.convert('RGB')
        width, height = img.size
        if width > height:
            img = img.transpose(Image.TRANSPOSE)
        return img

    def __getitem__(self, index):
        anchor_idx = random.randrange(len(self.imageFolderDataset))
        anchor_img, anchor_label = self.imageFolderDataset[anchor_idx]

        # Aim for roughly 50% positive pairs.
        if random.randint(0, 1):
            target = 1
            # Positive: a true image of the same stage (possibly the anchor itself).
            other_idx = random.choice(self._true_by_label[anchor_label])
            other_img = self.imageFolderDataset[other_idx][0]
        else:
            target = 0
            same_stage = self._false_by_label.get(anchor_label)
            if same_stage and random.random() < self.same_stage_negative_prob:
                # Hard negative: a false image from the anchor's own stage.
                other_idx = random.choice(same_stage)
            else:
                # A false image from any stage. Also the fallback when the false
                # set has no image of the anchor's stage (a search would never end).
                other_idx = random.randrange(len(self.false_dataset))
            other_img = self.false_dataset[other_idx][0]

        img0 = self._to_portrait_rgb(anchor_img)
        img1 = self._to_portrait_rgb(other_img)
        if self.transform is not None:
            img0 = self.transform(img0)
            img1 = self.transform(img1)

        # The emitted label is inverted (0 = same, 1 = different), which is the
        # convention ContrastiveLoss and evaluate_pair expect.
        return img0, img1, torch.from_numpy(np.array([1 - target], dtype=np.float32))

    def __len__(self):
        return len(self.imageFolderDataset) + len(self.false_dataset)

In [23]:
# Load the training and validation datasets.
# Root of the data folders; override with the SIAMESE_DATA_ROOT environment variable.
DATA_ROOT = os.environ.get("SIAMESE_DATA_ROOT", r"D:\PycharmProjects\AISS\siamese")

folder_dataset = datasets.ImageFolder(root=os.path.join(DATA_ROOT, "dataset_true"))
false_dataset = datasets.ImageFolder(root=os.path.join(DATA_ROOT, "dataset_false"))

folder_dataset_test = datasets.ImageFolder(root=os.path.join(DATA_ROOT, "dataset_true_val"))
false_dataset_test = datasets.ImageFolder(root=os.path.join(DATA_ROOT, "dataset_false_val"))

# SiameseNetworkDataset compares class indices across the true and false sets,
# so both must map the same stage folders to the same indices.
assert folder_dataset.classes == false_dataset.classes, "train true/false stage folders differ"
assert folder_dataset_test.classes == false_dataset_test.classes, "val true/false stage folders differ"

# Training: resize, augment with random flips, convert to tensors.
transformation = transforms.Compose([
    transforms.Resize((100, 100)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
])

# Validation: deterministic preprocessing only (no random augmentation), so the
# validation metrics are not perturbed by flips and match prediction-time inputs.
transformation_test = transforms.Compose([
    transforms.Resize((100, 100)),
    transforms.ToTensor(),
])

# Wrap the folders in pair-sampling datasets.
siamese_dataset = SiameseNetworkDataset(imageFolderDataset=folder_dataset, false_dataset=false_dataset,
                                        transform=transformation)

siamese_dataset_test = SiameseNetworkDataset(imageFolderDataset=folder_dataset_test, false_dataset=false_dataset_test,
                                             transform=transformation_test)

In [24]:
# Batch both pair datasets into groups of 16; any final partial batch is dropped.
trainloader, testloader = (
    DataLoader(pair_dataset, batch_size=16, drop_last=True)
    for pair_dataset in (siamese_dataset, siamese_dataset_test)
)

In [25]:
class SiameseNetwork(nn.Module):
    """Twin-branch embedding network whose two branches share all weights.

    Each input passes through a pretrained torchvision ResNet-18 (its final
    classification layer is kept) and then a small MLP head ending in a
    1024-dimensional embedding.
    """

    def __init__(self):
        super().__init__()
        # Pretrained ResNet-18 backbone, loaded with torchvision's default weights.
        self.backbone = models.resnet18(weights='DEFAULT', progress=True)
        backbone_dim = self.backbone.fc.out_features  # width of the backbone's last layer

        # Layer order is kept fixed so saved state_dict keys (fc1.0, fc1.2, fc1.4) still match.
        self.fc1 = nn.Sequential(
            nn.Linear(backbone_dim, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
        )

    def forward_once(self, x):
        """Embed one batch of images: backbone followed by the MLP head."""
        return self.fc1(self.backbone(x))

    def forward(self, input1, input2):
        """Embed both inputs with the shared weights; return (embedding1, embedding2)."""
        return self.forward_once(input1), self.forward_once(input2)

In [27]:
class ContrastiveLoss(torch.nn.Module):
    """Contrastive loss over pairs of embeddings.

    Label convention (as produced by SiameseNetworkDataset): 0 = similar pair,
    penalised by its squared distance; 1 = dissimilar pair, penalised by
    ``max(0, margin - distance)**2``. No 1/2 factor is applied.

    Args:
        margin: distance beyond which dissimilar pairs contribute no loss.
    """

    def __init__(self, margin=1.5):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        """Return the batch-mean contrastive loss as a scalar tensor.

        Args:
            output1, output2: embeddings of shape (B, D).
            label: per-pair labels, shape (B,) or (B, 1).
        """
        euclidean_distance = F.pairwise_distance(output1, output2)  # shape (B,)
        # Flatten the label to (B,) so each pair is weighted by its own label.
        # A (B, 1) label would broadcast against the (B,) distances to a (B, B)
        # matrix, mixing every pair's label with every other pair's distance.
        label = label.reshape(euclidean_distance.shape).to(euclidean_distance.dtype)
        pos = (1 - label) * torch.pow(euclidean_distance, 2)
        neg = label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
        return torch.mean(pos + neg)

In [28]:
def evaluate_pair(output1, output2, target, threshold):
    """Tally threshold-based same/different decisions for a batch of pairs.

    A pair is predicted "same" when the Euclidean distance between its two
    embeddings is strictly below ``threshold``. ``target`` follows the dataset
    convention: a zero entry marks a positive (same) pair and a non-zero entry
    marks a negative (different) pair.

    Returns ``(pos_acc, pos_sum, neg_acc, neg_sum)`` as plain ints. These are
    counts, not ratios: the number of positive pairs predicted "same", the
    number of positive pairs, the number of negative pairs predicted
    "different", and the number of negative pairs.
    """
    predicted_same = F.pairwise_distance(output1, output2) < threshold
    pos_acc = pos_sum = neg_acc = neg_sum = 0
    for same_pred, label in zip(predicted_same, target):
        if label:  # non-zero label: negative pair
            neg_sum += 1
            neg_acc += 0 if same_pred else 1
        else:  # zero label: positive pair
            pos_sum += 1
            pos_acc += 1 if same_pred else 0
    return pos_acc, pos_sum, neg_acc, neg_sum

In [29]:


## Seed every RNG used here (pair sampling via `random`, head init / flips via torch)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## Initialize network
model = SiameseNetwork().cuda()

## Initialize optimizer
LEARNING_RATE = 0.005
# NOTE: this rebinds `optim`, shadowing the `from torch import optim` module import.
# The training cell uses this name for the optimizer instance, so keep it.
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

## Initialize loss
criterion = ContrastiveLoss()


In [None]:
NUM_EPOCHS = 1000
# Distance cut-offs: a pair counts as predicted "same" when its embedding distance
# is below the threshold. The validation print below reports exactly these two.
EVAL_THRESHOLDS = (0.5, 0.35)
CHECKPOINT_DIR = "siamese"
CHECKPOINT_EVERY = 10  # epochs between checkpoints


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float("nan")


def _train_one_epoch():
    """Run one optimisation pass over ``trainloader`` and return the mean batch loss."""
    model.train()
    total_loss, num_batches = 0.0, 0
    for input1, input2, target in trainloader:
        optim.zero_grad()
        # Exactly one forward pass per batch.
        output1, output2 = model(input1.cuda(), input2.cuda())
        loss = criterion(output1, output2, target.cuda())
        loss.backward()
        optim.step()
        total_loss += loss.item()
        num_batches += 1
    return _safe_ratio(total_loss, num_batches)


@torch.no_grad()  # evaluation needs no autograd graph; saves memory and time
def _validate():
    """Evaluate ``model`` on ``testloader``.

    Returns ``(mean_batch_loss, accuracies)``, where ``accuracies`` holds one
    ``(positive_accuracy, negative_accuracy)`` tuple per entry of
    EVAL_THRESHOLDS. An accuracy is NaN when no pair of that kind was seen.
    """
    model.eval()
    total_loss, num_batches = 0.0, 0
    # One [pos_correct, pos_total, neg_correct, neg_total] tally per threshold.
    tallies = [[0, 0, 0, 0] for _ in EVAL_THRESHOLDS]
    for input1, input2, target in testloader:
        target = target.cuda()
        output1, output2 = model(input1.cuda(), input2.cuda())
        total_loss += criterion(output1, output2, target).item()
        num_batches += 1
        for tally, threshold in zip(tallies, EVAL_THRESHOLDS):
            counts = evaluate_pair(output1, output2, target, threshold)
            for k, count in enumerate(counts):
                tally[k] += count
    accuracies = [
        (_safe_ratio(pos_correct, pos_total), _safe_ratio(neg_correct, neg_total))
        for pos_correct, pos_total, neg_correct, neg_total in tallies
    ]
    return _safe_ratio(total_loss, num_batches), accuracies


train_loss = []
valid_loss = []
os.makedirs(CHECKPOINT_DIR, exist_ok=True)  # torch.save does not create missing directories
for epoch in range(NUM_EPOCHS):
    train_epoch_loss = _train_one_epoch()
    train_loss.append(train_epoch_loss)
    print("Epoch [{}/{}] ----> Training loss :{} \n".format(epoch+1,NUM_EPOCHS,train_epoch_loss))

    valid_epoch_loss, accuracies = _validate()
    valid_loss.append(valid_epoch_loss)
    (val_pos_accuracy, val_neg_accuracy), (val_pos_accuracy2, val_neg_accuracy2) = accuracies
    print("Validation loss :{} \t\t\t P Acc : {}, N Acc: {} P Acc : {}, N Acc : {}\n".format(valid_epoch_loss,val_pos_accuracy,val_neg_accuracy, val_pos_accuracy2, val_neg_accuracy2))

    # Periodic checkpoint of model and optimizer state.
    if (epoch + 1) % CHECKPOINT_EVERY == 0:
        torch.save(
            {
                "epoch": epoch + 1,
                "model_state_dict": model.state_dict(),
                "backbone": "resnet-18",
                "optimizer_state_dict": optim.state_dict()
            },
            os.path.join(CHECKPOINT_DIR, "epoch_{}.pth".format(epoch + 1))
        )

In [None]:
# predict: embed two images and print the Euclidean distance between their embeddings
# (a smaller distance means the model considers the images more similar).
img1_path = r"D:\PycharmProjects\AISS_Seminar\yolov5\pred_folder\true\vlcsnap-2023-06-27-23h11m53s160.png"
img2_path = r"D:\PycharmProjects\AISS_Seminar\yolov5\pred_folder\generated\5\vlcsnap-2023-06-27-23h06m59s656_2.jpg"
transform = transforms.Compose([transforms.Resize((100, 100)), transforms.ToTensor()])


def _load_for_prediction(path):
    """Load one image with the same preprocessing SiameseNetworkDataset applies.

    Converts to RGB and transposes landscape images (width > height), as the
    training pairs were, then resizes, converts to a tensor and adds a batch
    dimension. Returns a (1, 3, 100, 100) tensor on the GPU.
    """
    with Image.open(path) as raw:
        img = raw.convert('RGB')  # also drops an alpha channel a PNG may carry
    width, height = img.size
    if width > height:
        img = img.transpose(Image.TRANSPOSE)
    return transform(img).unsqueeze(0).cuda()


img1 = _load_for_prediction(img1_path)
img2 = _load_for_prediction(img2_path)

model.eval()  # BatchNorm must use its running statistics, not the 1-image batch's
with torch.no_grad():
    output1, output2 = model(img1, img2)
print(F.pairwise_distance(output1, output2))