# Libraries

In [18]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from PIL import Image
import torchvision.transforms as T

# Network Architecture

In [19]:
class EmbeddingNet(nn.Module):
    """Convolutional encoder that maps a batch of images to embedding vectors.

    Five conv + ReLU stages (the first four each followed by 2x2 max pooling)
    are closed by a global average pool, so the spatial size of the input is
    not fixed; it only has to survive four halvings (at least 16x16). The
    pooled 512-dimensional feature vector is projected linearly to
    ``embedding_dim``.

    Args:
        in_channels: Number of channels of the input images.
        embedding_dim: Size of the embedding produced per image.
    """

    def __init__(self, in_channels: int = 3, embedding_dim: int = 128):
        super().__init__()
        # Spatial sizes in the comments below are for a 640x640 input.
        self.convnet = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 320x320

            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 160x160

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 80x80

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 40x40

            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),  # 1x1, independent of input size
        )
        self.fc = nn.Linear(512, embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return embeddings of shape ``(batch_size, embedding_dim)``.

        Args:
            x: Image batch of shape ``(batch_size, in_channels, H, W)``.
        """
        features = self.convnet(x)  # (batch_size, 512, 1, 1)
        # flatten(…, 1) keeps the batch dimension and, unlike view(N, -1),
        # stays well-defined when the batch dimension is zero.
        features = torch.flatten(features, 1)  # (batch_size, 512)
        return self.fc(features)  # (batch_size, embedding_dim)

In [20]:
class TripletNet(nn.Module):
    """Apply one shared embedding network to every member of a triplet.

    Args:
        embedding_net: Module used to embed the anchor, positive and negative
            inputs; the same instance (and therefore the same weights) is
            used for all three.
    """

    def __init__(self, embedding_net):
        super().__init__()
        self.embedding_net = embedding_net

    def forward(self, anchor, positive, negative):
        """Return ``(anchor_out, positive_out, negative_out)`` embeddings."""
        # Evaluated in anchor -> positive -> negative order.
        anchor_out, positive_out, negative_out = (
            self.embedding_net(member)
            for member in (anchor, positive, negative)
        )
        return anchor_out, positive_out, negative_out

In [21]:
class TripletLoss(nn.Module):
    """Margin-based triplet loss on Euclidean (p=2) pairwise distances.

    For each triplet the loss is
    ``relu(d(anchor, positive) - d(anchor, negative) + margin)``;
    the per-triplet values are averaged over the batch.

    Args:
        margin: Amount by which the negative must be farther from the anchor
            than the positive before a triplet stops contributing.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        """Return the mean triplet loss as a scalar tensor."""
        dist_to_positive = F.pairwise_distance(anchor, positive, p=2)
        dist_to_negative = F.pairwise_distance(anchor, negative, p=2)
        # Positive exactly when the margin constraint is violated.
        violation = dist_to_positive - dist_to_negative + self.margin
        return F.relu(violation).mean()

# Dataset Design

In [22]:
class ImageTripletDataset(Dataset):
    """Dataset of (anchor, positive, negative) image triplets stored on disk.

    ``root_dir`` must contain three sub-directories named ``anchor``,
    ``positive`` and ``negative``. The members of one triplet share the same
    file name across the three directories; the list of triplets is taken from
    the contents of ``anchor``.

    Args:
        root_dir: Directory holding the three sub-directories.
        transform: Optional callable applied independently to each of the
            three images (in anchor, positive, negative order).
    """

    def __init__(self, root_dir, transform=None):
        self.anchor_dir = os.path.join(root_dir, "anchor")
        self.positive_dir = os.path.join(root_dir, "positive")
        self.negative_dir = os.path.join(root_dir, "negative")

        # Keep only regular, non-hidden files: sub-directories and metadata
        # entries such as ".ipynb_checkpoints" or ".DS_Store" are not images
        # and would otherwise fail when the item is loaded.
        self.filenames = sorted(
            name
            for name in os.listdir(self.anchor_dir)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(self.anchor_dir, name))
        )
        self.transform = transform

    def __len__(self):
        """Return the number of triplets."""
        return len(self.filenames)

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and return an RGB copy of it."""
        # convert() returns a new, fully loaded image, so the underlying file
        # can be closed as soon as the ``with`` block exits.
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, idx):
        """Return the ``(anchor, positive, negative)`` triplet at ``idx``.

        Raises:
            IndexError: If ``idx`` is out of range.
            FileNotFoundError: If the positive or negative directory has no
                file with the anchor's name.
        """
        name = self.filenames[idx]
        anchor, positive, negative = (
            self._load_rgb(os.path.join(folder, name))
            for folder in (self.anchor_dir, self.positive_dir, self.negative_dir)
        )

        # Explicit None check: a transform object that happens to be falsy
        # (e.g. one defining __len__) must still be applied.
        if self.transform is not None:
            anchor = self.transform(anchor)
            positive = self.transform(positive)
            negative = self.transform(negative)

        return anchor, positive, negative

In [23]:
# Preprocessing applied to every image: resize to 640x640, then convert the
# PIL image to a tensor.
transform = T.Compose([T.Resize((640, 640)), T.ToTensor()])

# Triplets are read from ./triplet_data/{anchor,positive,negative}.
dataset = ImageTripletDataset("triplet_data", transform=transform)

# Batches of two triplets, drawn in a fresh random order on each pass.
loader = DataLoader(dataset=dataset, batch_size=2, shuffle=True)