In [20]:
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn
import torch
import random
import numpy as np
import scipy.io as scp
import torch.optim as optim
import torchvision.models as models
from dataset import train_dataset, test_dataset, val_dataset
from torch.utils.data import Dataset  # Import the Dataset class
import torch.nn.functional as F

In [21]:
# Standard ImageNet per-channel mean / std (torchvision's usual normalization
# constants), shared by the training and evaluation pipelines.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: augmentation (random rotation of up to +/-30 degrees, a
# random-resized 224x224 crop, random horizontal flip), then tensor conversion
# and normalization.
train_transform = transforms.Compose(
    [
        transforms.RandomRotation(30),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)

# Validation / test pipeline: deterministic resize (shorter side -> 256) and a
# centred 224x224 crop, so evaluation inputs match the training crop size.
testval_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)

def _load_flowers102(split, transform):
    """Return one Flowers102 split rooted at ./data (downloaded if not already present)."""
    return torchvision.datasets.Flowers102(
        root='./data',
        split=split,
        download=True,
        transform=transform,
    )


# NOTE: these rebind the train/val/test dataset names imported from `dataset`
# in the imports cell; the objects created here are the ones used afterwards.
train_dataset = _load_flowers102('train', train_transform)
val_dataset = _load_flowers102('val', testval_transform)
test_dataset = _load_flowers102('test', testval_transform)

In [22]:

class TripletDataset(Dataset):
    """Wrap a labelled dataset so that each item is an (anchor, positive, negative) triple.

    For index ``idx`` the anchor is ``dataset[idx][0]``. The positive is another
    sample with the same label, chosen uniformly. The negative is drawn by
    picking a different label uniformly, then a sample of that label uniformly.

    Args:
        dataset: indexable dataset whose items are ``(sample, label)`` pairs.
        labels: optional per-index labels (same order and length as ``dataset``).
            When omitted, labels are taken from ``dataset.targets`` or
            ``dataset._labels`` if present and no ``target_transform`` is set.
            For example, torchvision's Flowers102 keeps its labels in ``_labels``.
            Only as a last resort does it iterate over every item, which loads
            (and transforms) every sample once and is slow for image datasets.

    Raises:
        ValueError: if ``labels`` is given with a length different from ``dataset``,
            or (from ``__getitem__``) if the data contains fewer than two labels,
            so no negative can be drawn.
    """

    def __init__(self, dataset, labels=None):
        self.dataset = dataset
        if labels is None:
            labels = self._infer_labels(dataset)
        elif len(labels) != len(dataset):
            raise ValueError(
                f"labels has length {len(labels)} but dataset has length {len(dataset)}"
            )
        self.labels = np.asarray(labels)
        self.label_to_indices = {
            label: np.where(self.labels == label)[0] for label in np.unique(self.labels)
        }
        # Hoisted so __getitem__ does not rebuild it for every sample.
        self._classes = list(self.label_to_indices.keys())

    @staticmethod
    def _infer_labels(dataset):
        """Return per-index labels, avoiding a full pass over the data when possible."""
        # A target_transform would make the stored labels differ from what
        # dataset[idx][1] returns, so only trust cached labels without one.
        if getattr(dataset, "target_transform", None) is None:
            for attr in ("targets", "_labels"):
                cached = getattr(dataset, attr, None)
                if cached is not None and len(cached) == len(dataset):
                    return np.asarray(cached)
        return [item[1] for item in dataset]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        anchor, anchor_label = self.dataset[idx]

        # Positive: any *other* index sharing the anchor's label. A class with a
        # single sample has no such index; use the anchor itself instead of
        # rejection-sampling forever.
        same_label = self.label_to_indices[anchor_label]
        candidates = same_label[same_label != idx]
        positive_idx = random.choice(candidates) if len(candidates) else idx

        # Negative: uniform over the other labels, then uniform within that label.
        other_labels = [label for label in self._classes if label != anchor_label]
        if not other_labels:
            raise ValueError("TripletDataset needs at least two distinct labels to draw a negative")
        negative_idx = random.choice(self.label_to_indices[random.choice(other_labels)])

        positive, _ = self.dataset[positive_idx]
        negative, _ = self.dataset[negative_idx]

        return anchor, positive, negative

In [16]:
# Every batch carries three image tensors (anchor, positive, negative), and the
# training loop runs the model on each, so one step processes 3 * batch_size
# images -- reduce this if memory is tight.
batch_size = 256

train_triplet_dataset = TripletDataset(train_dataset)
train_triplet_loader = DataLoader(
    train_triplet_dataset,
    batch_size=batch_size,
    shuffle=True,  # reshuffle anchors every epoch
)

In [23]:
def triplet_loss(anchor, positive, negative, margin=1.0):
    """Mean hinge-style triplet loss over a batch of embeddings.

    Per row this computes
    ``max(d(anchor, positive) - d(anchor, negative) + margin, 0)``
    where ``d`` is ``F.pairwise_distance`` with ``p=2`` (Euclidean, with torch's
    default small eps added to the difference), and returns the batch mean.

    Args:
        anchor, positive, negative: embedding tensors of matching shape.
        margin: required gap between the negative and positive distances.

    Returns:
        A 0-d tensor holding the mean loss.
    """
    pos_dist = F.pairwise_distance(anchor, positive, p=2)
    neg_dist = F.pairwise_distance(anchor, negative, p=2)
    hinge = (pos_dist - neg_dist + margin).clamp(min=0.0)
    return hinge.mean()

In [24]:
from model import mobilenet

In [25]:
# Train the embedding network with the custom triplet objective.
# NOTE(review): assumes `mobilenet()` returns (nn.Module, optimizer, criterion);
# `criterion` is unused here because `triplet_loss` replaces it -- confirm.
num_epochs = 1  # one pass over the data (the original behaviour); raise for real training

model, optimizer, criterion = mobilenet()
# Send batches to wherever the model's parameters live (a no-op on CPU), so the
# loop still works if mobilenet() places the model on a GPU.
device = next(model.parameters()).device
model.train()  # training mode for layers such as dropout / batch norm, if present

for epoch in range(num_epochs):
    running_loss = 0.0
    samples_seen = 0
    for anchor, positive, negative in train_triplet_loader:
        anchor = anchor.to(device)
        positive = positive.to(device)
        negative = negative.to(device)

        # Forward pass to get embeddings for anchor, positive, and negative samples
        anchor_emb = model(anchor)
        positive_emb = model(positive)
        negative_emb = model(negative)

        # Calculating triplet loss using the custom triplet loss function
        loss = triplet_loss(anchor_emb, positive_emb, negative_emb)

        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # triplet_loss returns a batch mean; weight by batch size for an epoch mean.
        running_loss += loss.item() * anchor.size(0)
        samples_seen += anchor.size(0)

    epoch_loss = running_loss / samples_seen if samples_seen else float("nan")
    print(f"epoch {epoch + 1}/{num_epochs}: mean triplet loss {epoch_loss:.4f}")

    # TODO: implement the evaluation here -- once per epoch, outside the batch
    # loop (switch to model.eval() and torch.no_grad(), then back to model.train()).