# Libraries

In [36]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from PIL import Image
import torchvision.transforms as T

# Network Architecture

In [37]:
class EmbeddingNet(nn.Module):
    """Convolutional encoder mapping a batch of RGB images to embedding vectors.

    Five Conv2d+ReLU stages widen the channels 3 -> 32 -> 64 -> 128 -> 256 -> 512.
    The first four stages each halve the spatial size with 2x max-pooling; the
    last ends in global average pooling, so any input of at least 16x16 pixels
    works (this notebook feeds 640x640 images).

    Args:
        embedding_dim: Size of the output embedding (default 128).
    """

    def __init__(self, embedding_dim: int = 128):
        super().__init__()
        self.convnet = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 640x640 input -> 320x320

            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2),  # -> 160x160

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # -> 80x80

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # -> 40x40

            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),  # -> (batch, 512, 1, 1)
        )
        self.fc = nn.Linear(512, embedding_dim)

    def forward(self, x):
        """Embed ``x`` of shape (batch, 3, H, W) into (batch, embedding_dim)."""
        features = self.convnet(x)
        features = torch.flatten(features, 1)  # (batch, 512)
        return self.fc(features)  # (batch, embedding_dim)

In [38]:
class TripletNet(nn.Module):
    """Siamese wrapper that embeds an (anchor, positive, negative) triple.

    The same ``embedding_net`` instance processes all three inputs, so the
    three branches share weights.
    """

    def __init__(self, embedding_net):
        super().__init__()
        self.embedding_net = embedding_net

    def forward(self, anchor, positive, negative):
        """Return ``(embed(anchor), embed(positive), embed(negative))``."""
        encode = self.embedding_net
        return encode(anchor), encode(positive), encode(negative)

In [39]:
class TripletLoss(nn.Module):
    """Hinge triplet loss on Euclidean (L2) embedding distances.

    Per row it computes ``max(0, ||a - p|| - ||a - n|| + margin)`` and returns
    the batch mean, the same quantity as ``nn.TripletMarginLoss(margin, p=2)``.

    Args:
        margin: Required gap between the negative and positive distances.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        """Return the scalar mean triplet loss for the batch."""
        pos_dist, neg_dist = (
            F.pairwise_distance(anchor, other, p=2) for other in (positive, negative)
        )
        violation = pos_dist - neg_dist + self.margin
        return F.relu(violation).mean()

# Dataset Design

In [40]:
class ImageTripletDataset(Dataset):
    """Loads (anchor, positive, negative) image triples from a directory tree.

    Expected layout (file names must match across the three folders)::

        root_dir/
            anchor/<name>
            positive/<name>
            negative/<name>

    Item ``i`` is built from the i-th name in the sorted listing of ``anchor/``.

    Args:
        root_dir: Directory containing ``anchor``, ``positive`` and ``negative``.
        transform: Optional callable applied to each RGB PIL image
            (e.g. a ``torchvision.transforms.Compose``).
    """

    def __init__(self, root_dir, transform=None):
        self.anchor_dir = os.path.join(root_dir, "anchor")
        self.positive_dir = os.path.join(root_dir, "positive")
        self.negative_dir = os.path.join(root_dir, "negative")

        # Keep only visible regular files: entries such as `.ipynb_checkpoints/`
        # or `.DS_Store` can appear in image folders and would make
        # __getitem__ fail when it tries to open them as images.
        self.filenames = sorted(
            name
            for name in os.listdir(self.anchor_dir)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(self.anchor_dir, name))
        )
        self.transform = transform

    def __len__(self):
        return len(self.filenames)

    @staticmethod
    def _load_rgb(path):
        """Open ``path``, return a fully decoded RGB copy, and close the file."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, idx):
        """Return the (optionally transformed) ``(anchor, positive, negative)`` triple."""
        name = self.filenames[idx]
        triple = tuple(
            self._load_rgb(os.path.join(folder, name))
            for folder in (self.anchor_dir, self.positive_dir, self.negative_dir)
        )
        if self.transform is not None:
            triple = tuple(self.transform(img) for img in triple)
        return triple

In [41]:
# Data pipeline configuration. `transform` is also used by the evaluation cell,
# so training and inference share the same preprocessing.
IMAGE_SIZE = (640, 640)  # (height, width) fed to EmbeddingNet
BATCH_SIZE = 2
DATA_ROOT = "triplet_data"
SEED = 42

# ToTensor converts PIL images to float tensors in [0, 1]; no mean/std
# normalisation is applied.
transform = T.Compose([
    T.Resize(IMAGE_SIZE),
    T.ToTensor(),
])

dataset = ImageTripletDataset(root_dir=DATA_ROOT, transform=transform)

# A dedicated seeded generator makes the shuffle order reproducible across
# kernel restarts (Restart & Run All).
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

# Training Loop

In [42]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed before constructing the network so the initial weights are reproducible.
torch.manual_seed(42)

# TripletNet shares this single embedding_net across its three branches.
# Module.to() moves embedding_net in place, and the evaluation cell uses
# embedding_net directly on single images.
embedding_net = EmbeddingNet()
model = TripletNet(embedding_net).to(device)

# Built-in loss, equivalent to the custom TripletLoss class (L2 distance, margin 1.0).
triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [43]:
num_epochs = 32

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    num_samples = 0

    for anchor, positive, negative in loader:
        anchor = anchor.to(device)
        positive = positive.to(device)
        negative = negative.to(device)

        optimizer.zero_grad()
        anchor_out, positive_out, negative_out = model(anchor, positive, negative)

        loss = triplet_loss(anchor_out, positive_out, negative_out)
        loss.backward()
        optimizer.step()

        # `loss` is a per-batch mean; weight it by batch size so a smaller
        # final batch does not skew the epoch average.
        n_batch = anchor.size(0)
        running_loss += loss.item() * n_batch
        num_samples += n_batch

    # max(..., 1) avoids a ZeroDivisionError when the dataset is empty.
    avg_loss = running_loss / max(num_samples, 1)
    print(f"Epoch [{epoch+1}/{num_epochs}] - Loss: {avg_loss:.4f}")

Epoch [1/32] - Loss: 0.9991
Epoch [2/32] - Loss: 0.9938
Epoch [3/32] - Loss: 0.9852
Epoch [4/32] - Loss: 0.9788
Epoch [5/32] - Loss: 0.9634
Epoch [6/32] - Loss: 0.9483
Epoch [7/32] - Loss: 0.9304
Epoch [8/32] - Loss: 0.8912
Epoch [9/32] - Loss: 0.8507
Epoch [10/32] - Loss: 0.8047
Epoch [11/32] - Loss: 0.7052
Epoch [12/32] - Loss: 0.6207
Epoch [13/32] - Loss: 0.4879
Epoch [14/32] - Loss: 0.3568
Epoch [15/32] - Loss: 0.2534
Epoch [16/32] - Loss: 0.2347
Epoch [17/32] - Loss: 0.1953
Epoch [18/32] - Loss: 0.1233
Epoch [19/32] - Loss: 0.0027
Epoch [20/32] - Loss: 0.0000
Epoch [21/32] - Loss: 0.0000
Epoch [22/32] - Loss: 0.0243
Epoch [23/32] - Loss: 0.0000
Epoch [24/32] - Loss: 0.0000
Epoch [25/32] - Loss: 0.0453
Epoch [26/32] - Loss: 0.0128
Epoch [27/32] - Loss: 0.0000
Epoch [28/32] - Loss: 0.0000
Epoch [29/32] - Loss: 0.0000
Epoch [30/32] - Loss: 0.0000
Epoch [31/32] - Loss: 0.0000
Epoch [32/32] - Loss: 0.0000


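# Evaluation

Quick sanity check: embed an anchor image and a test image, and call them the same class when their Euclidean distance falls below a threshold. The pair used below comes from the training data (`triplet_data`), so this does not measure generalisation to unseen images.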

In [68]:
def classify_by_distance(anchor_path, test_path, model, threshold=0.8, device=None, preprocess=None):
    """Decide whether two images belong to the same class by embedding distance.

    Both images are loaded as RGB, preprocessed, embedded with ``model`` and
    compared by Euclidean distance. The distance is printed as a side effect.

    Args:
        anchor_path: Path to the reference image.
        test_path: Path to the image compared against the anchor.
        model: Single-input embedding network (e.g. ``embedding_net``), not the
            three-input ``TripletNet``.
        threshold: The pair is "same class" when the distance is strictly below it.
        device: Device to run on. ``None`` (default) uses the device of the
            model's parameters, so the call also works on CPU-only machines.
            An explicit device moves ``model`` there.
        preprocess: Callable mapping a PIL image to a (C, H, W) tensor. ``None``
            uses the notebook-level ``transform`` from the data pipeline.

    Returns:
        "Positive (Same class)" or "Negative (Different class)".
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model = model.to(device)
    if preprocess is None:
        preprocess = transform  # same resize/ToTensor pipeline as training

    def _embed(path):
        # The context manager closes the file handle once the pixels are decoded.
        with Image.open(path) as img:
            tensor = preprocess(img.convert("RGB"))
        return model(tensor.unsqueeze(0).to(device))

    # eval() is a no-op for the current EmbeddingNet, but it is required once
    # dropout/batch-norm layers exist. Restore the caller's mode afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            anchor_embed = _embed(anchor_path)
            test_embed = _embed(test_path)
            dist = F.pairwise_distance(anchor_embed, test_embed).item()
    finally:
        model.train(was_training)

    print(f"Distance: {dist:.4f}")
    if dist < threshold:
        return "Positive (Same class)"
    else:
        return "Negative (Different class)"

# Compare training anchor #4 with its negative. The threshold equals the
# training margin (1.0).
anchor_image_path = "triplet_data/anchor/4.jpg"
query_image_path = "triplet_data/negative/4.jpg"
result = classify_by_distance(
    anchor_image_path,
    query_image_path,
    embedding_net,
    threshold=1.0,
)
print("Result:", result)

Distance: 5.2442
Result: Negative (Different class)
