In [14]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import random
from PIL import Image
import PIL.ImageOps

import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import torchvision.utils
import torch
import torch.nn as nn
from torchvision import models
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report
import os
from sklearn.metrics import f1_score

In [15]:
class SiameseNetworkDataset(Dataset):
    """Dataset of (anchor, positive, negative) image triplets.

    Every item is built for one randomly drawn class in ``0..3``:

    * anchor   -- the fixed reference image of that class (loaded once in
      ``__init__``),
    * positive -- a random sample with that label from ``imageFolderDataset``,
    * negative -- a random sample with that label from ``false_dataset``.

    The ``index`` passed to ``__getitem__`` is ignored; it only exists so the
    object can be used with a ``DataLoader``.

    Parameters:
        imageFolderDataset: indexable dataset of ``(PIL image, label)`` pairs
            used for the positives; also defines ``len(self)``.
        false_dataset: indexable dataset of ``(PIL image, label)`` pairs used
            for the negatives.
        transform: optional callable applied to every (orientation-normalised)
            PIL image. When ``None`` the images are resized to ``(150, 50)``
            and converted to tensors.
    """

    # Fixed anchor images: ANCHOR_FILES[c] lives in sub-folder str(c) of ANCHOR_ROOT.
    ANCHOR_ROOT = r"D:\PycharmProjects\AISS\siamese\dataset_true"
    ANCHOR_FILES = ("82_14491.jpg", "42_14314.jpg", "02_3176.jpg", "310_18578.jpg")

    def __init__(self,imageFolderDataset,false_dataset, transform=None):
        self.imageFolderDataset = imageFolderDataset
        self.false_dataset = false_dataset
        self.transform = transform

        # Anchors are prepared once; the attribute names are kept for
        # compatibility with code that reads them directly.
        (self.reference_imgs0,
         self.reference_imgs1,
         self.reference_imgs2,
         self.reference_imgs3) = [
            self.prepare_imgs(os.path.join(self.ANCHOR_ROOT, str(class_), file_name), anchor=True)
            for class_, file_name in enumerate(self.ANCHOR_FILES)
        ]

        # label -> sample indices, so a sample of a given class can be drawn
        # directly instead of decoding random images until the label matches.
        self._true_by_class = self._group_indices_by_class(imageFolderDataset)
        self._false_by_class = self._group_indices_by_class(false_dataset)

    @staticmethod
    def _group_indices_by_class(dataset):
        """Return ``{label: [sample indices]}`` for ``dataset``.

        Uses the dataset's ``targets`` attribute when it has one (no image is
        decoded in that case); otherwise reads the label of every sample.
        """
        labels = getattr(dataset, "targets", None)
        if labels is None:
            labels = [dataset[i][1] for i in range(len(dataset))]
        groups = {}
        for idx, label in enumerate(labels):
            groups.setdefault(int(label), []).append(idx)
        return groups

    def _pick(self, dataset, groups, class_):
        """Return a uniformly random ``(image, label)`` sample of ``class_``.

        Raises:
            ValueError: if ``dataset`` holds no sample of ``class_`` (the
                previous rejection loop would have spun forever here).
        """
        candidates = groups.get(class_)
        if not candidates:
            raise ValueError("dataset contains no sample of class {}".format(class_))
        return dataset[random.choice(candidates)]

    def prepare_imgs(self, p, anchor=False):
        """Turn an image into the tensor fed to the network.

        Parameters:
            p: a file path when ``anchor`` is true, otherwise a PIL image.
            anchor: whether ``p`` must first be opened from disk.

        Returns:
            The result of ``self.transform`` (or of the default
            resize-to-(150, 50) + ``ToTensor`` pipeline) applied to the image
            after landscape images have been transposed to portrait.
        """
        if anchor:
            # Close the file handle right away and force three channels so the
            # anchor matches the RGB images produced by ImageFolder.
            with Image.open(p) as opened:
                img = opened.convert("RGB")
        else:
            img = p
        width, height = img.size
        if width > height:
            # Make every image portrait before resizing.
            img = img.transpose(Image.TRANSPOSE)
        transform = self.transform
        if transform is None:
            transform = transforms.Compose([transforms.Resize((150, 50)),transforms.ToTensor()])
        return transform(img)

    def __getitem__(self,index):
        """Return ``(anchor, positive, negative)`` tensors for a random class."""
        # select a random stage
        class_ = random.randint(0,3)

        anchors = (self.reference_imgs0, self.reference_imgs1,
                   self.reference_imgs2, self.reference_imgs3)
        img0 = anchors[class_]

        # Same-stage sample from the "true" set and from the "false" set.
        img1_tuple = self._pick(self.imageFolderDataset, self._true_by_class, class_)
        img2_tuple = self._pick(self.false_dataset, self._false_by_class, class_)

        img1 = self.prepare_imgs(img1_tuple[0])
        img2 = self.prepare_imgs(img2_tuple[0])

        return img0,img1 ,img2

    def __len__(self):
        """Number of samples in the positive dataset (defines the epoch length)."""
        return len(self.imageFolderDataset)

In [None]:
# Image folders: the "true" folders feed the positives, the "false" folders the negatives.
folder_dataset = datasets.ImageFolder(r"D:\PycharmProjects\AISS\siamese\dataset_true")
false_dataset = datasets.ImageFolder(r"D:\PycharmProjects\AISS\siamese\dataset_false")

# Validation counterparts of the two folders above.
folder_dataset_test = datasets.ImageFolder(r"D:\PycharmProjects\AISS\siamese\dataset_true_val")
false_dataset_test = datasets.ImageFolder(r"D:\PycharmProjects\AISS\siamese\dataset_false_val")

# Pre-processing: resize to 150x50 (height x width), then convert to a tensor.
transformation = transforms.Compose(
    [transforms.Resize((150, 50)), transforms.ToTensor()]
)
transformation_test = transforms.Compose(
    [transforms.Resize((150, 50)), transforms.ToTensor()]
)

# Triplet datasets for training and validation.
siamese_dataset = SiameseNetworkDataset(
    folder_dataset,
    false_dataset,
    transformation,
)
siamese_dataset_test = SiameseNetworkDataset(
    folder_dataset_test,
    false_dataset_test,
    transformation_test,
)

In [17]:
# Batch the triplet datasets; incomplete final batches are discarded in both loaders.
trainloader = DataLoader(
    dataset=siamese_dataset,
    batch_size=32,
    shuffle=True,
    drop_last=True,
)

# Validation batches are drawn without shuffling.
testloader = DataLoader(
    dataset=siamese_dataset_test,
    batch_size=32,
    shuffle=False,
    drop_last=True,
)

In [18]:
from torchvision.models import resnet18

class SiameseNetwork(nn.Module):
    """Embedding network used for all three branches of the triplet setup.

    A pretrained ResNet-18 whose final fully connected layer is replaced by a
    small head (``in_features -> 512 -> 10``), followed by a linear projection
    to a 2-dimensional embedding.
    """

    def __init__(self, backbone="resnet18"):
        '''
        Build the network.
            Parameters:
                    backbone (str): accepted for interface compatibility only.
                        The value is currently NOT used: the feature extractor
                        is always torchvision's pretrained ``resnet18``.
        '''

        super().__init__()
        # Create a backbone network from the pretrained models provided in torchvision.models
        self.Feature_Extractor = resnet18(pretrained=True)
        # Freeze the parameters of the first four child modules of the backbone
        # (presumably the ResNet stem -- confirm against the torchvision version in use).
        for j, child in enumerate(self.Feature_Extractor.children()):
            if j < 4:
                for param in child.parameters():
                    param.requires_grad = False
        num_filters = self.Feature_Extractor.fc.in_features

        # feature representation head
        self.Feature_Extractor.fc = nn.Sequential(
                  nn.Linear(num_filters,512),
                  nn.LeakyReLU(),
                  nn.Linear(512,10))
        # Final projection to the 2-d embedding compared by the triplet loss.
        # The attribute name is kept because saved state_dict keys depend on it.
        self.Triplet_Loss = nn.Sequential(
                  nn.Linear(10,2))

    def forward(self, img1):
        """Embed a batch of images.

        Parameters:
            img1: batch of image tensors accepted by the backbone.

        Returns:
            Tensor whose last dimension has size 2 (the embedding).
        """
        # Pass one image through the network and get the representation
        feat1 = self.Feature_Extractor(img1)
        output = self.Triplet_Loss(feat1)
        return output

In [19]:
class TripletLoss(nn.Module):
    """Hinge-style triplet loss on squared Euclidean distances.

    For each row: ``relu(d(anchor, positive) - d(anchor, negative) + margin)``,
    averaged over the batch.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def calc_euclidean(self, x1, x2):
        """Return the row-wise sum of squared differences between ``x1`` and ``x2``."""
        difference = x1 - x2
        return torch.sum(torch.pow(difference, 2), dim=1)

    def forward(self, anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor) -> torch.Tensor:
        """Return the mean triplet loss of the batch as a scalar tensor."""
        to_positive = self.calc_euclidean(anchor, positive)
        to_negative = self.calc_euclidean(anchor, negative)
        # A triplet only contributes while the negative is not at least
        # `margin` farther from the anchor than the positive.
        violation = to_positive - to_negative + self.margin
        return torch.mean(torch.relu(violation))

In [None]:
# Training configuration
NUM_EPOCHS = 300
SAVE_EVERY_N_EPOCHS = 1
CHECKPOINT_DIR = os.path.join("siamese","models_new")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SiameseNetwork()
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = TripletLoss()

# torch.save does not create missing directories.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

for epoch in range(NUM_EPOCHS):
    print("[{} / {}]".format(epoch, NUM_EPOCHS))

    # Training Loop Start
    model.train()
    losses = []
    for img1, img2, img3 in trainloader:
        optimizer.zero_grad()
        img1, img2, img3 = (batch.to(device) for batch in (img1, img2, img3))

        # The same network embeds anchor, positive and negative.
        anchor_out = model(img1)
        positive_out = model(img2)
        negative_out = model(img3)

        loss = criterion(anchor_out, positive_out, negative_out)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    print("\tTraining: Loss={:.2f}\t ".format(np.mean(losses)))
    # Training Loop End

    # start evaluation
    model.eval()
    val_loss = []
    # No gradients are needed for validation; skipping the autograd graph
    # saves memory and time without changing the loss values.
    with torch.no_grad():
        for img1, img2, img3 in testloader:
            img1, img2, img3 = (batch.to(device) for batch in (img1, img2, img3))

            anchor_out = model(img1)
            positive_out = model(img2)
            negative_out = model(img3)

            loss = criterion(anchor_out, positive_out, negative_out)
            val_loss.append(loss.item())

    mean_val_loss = np.mean(val_loss)
    print("Validation loss :{}".format(mean_val_loss))

    # Save model
    if (epoch + 1) % SAVE_EVERY_N_EPOCHS == 0:
        torch.save(
            {
                "epoch": epoch + 1,
                "model_state_dict": model.state_dict(),
                "backbone": "resnet-18",
                "optimizer_state_dict": optimizer.state_dict()
            },
            os.path.join(CHECKPOINT_DIR, "epoch_{}_{}.pth".format(epoch + 1, 2))
        )

In [None]:
# load in model
# Pick the device the same way the training cell does, so the notebook also
# runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SiameseNetwork()
# map_location lets a checkpoint that was saved on a GPU be restored on CPU.
checkpoint = torch.load("siamese/models/epoch_35_0.8938460690362686.pth", map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)

In [None]:
# predict with anchor image
model.eval()

def _load_portrait_rgb(path):
    """Open an image, force RGB and transpose landscape images to portrait."""
    with PIL.Image.open(path) as opened:
        img = opened.convert("RGB")
    width, height = img.size
    if width > height:
        img = img.transpose(Image.TRANSPOSE)
    return img

# img2 is the anchor of the stage, img1 the image to compare against it.
img2 = _load_portrait_rgb(r"D:\PycharmProjects\AISS\siamese\dataset_true\1\10_2830.jpg")
img1 = _load_portrait_rgb(r"D:\PycharmProjects\AISS\testbild.jpg")
#img1 = PIL.Image.open(r"D:\PycharmProjects\AISS_Seminar\yolov5\pred_folder\generated\5\vlcsnap-2023-06-28-20h53m36s380_1.jpg")

transform = transforms.Compose([transforms.Resize((150, 50)),transforms.ToTensor()])
# Run on whatever device the model currently lives on.
model_device = next(model.parameters()).device
# Add the batch dimension expected by the network.
img1 = transform(img1).to(model_device)[None, :]
img2 = transform(img2).to(model_device)[None, :]

with torch.no_grad():
    outputs = model(img1)
    anchor_outputs = model(img2)
    # The network returns a 2-d embedding per image, so a single number only
    # exists for a *pair*: the squared Euclidean distance, i.e. the same
    # metric TripletLoss.calc_euclidean uses during training.
    distance = (outputs - anchor_outputs).pow(2).sum(1)
print(distance.item())