In [1]:
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn
import torch
import random
import numpy as np
import scipy.io as scp
import torch.optim as optim
import torchvision.models as models
from dataset import train_dataset, test_dataset, val_dataset
from torch.utils.data import Dataset  # Import the Dataset class
import torch.nn.functional as F

In [2]:

class TripletDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset
        self.labels = np.array([item[1] for item in dataset])
        self.label_to_indices = {label: np.where(self.labels == label)[0] for label in np.unique(self.labels)}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        anchor, anchor_label = self.dataset[idx]
        positive_idx = idx

        while positive_idx == idx:
            positive_idx = random.choice(self.label_to_indices[anchor_label])

        negative_label = random.choice(list(self.label_to_indices.keys()))
        while negative_label == anchor_label:
            negative_label = random.choice(list(self.label_to_indices.keys()))

        negative_idx = random.choice(self.label_to_indices[negative_label])

        positive, _ = self.dataset[positive_idx]
        negative, _ = self.dataset[negative_idx]

        return anchor, positive, negative

In [3]:
batch_size = 256  # Adjust the batch size as needed

train_triplet_dataset = TripletDataset(train_dataset)
train_triplet_loader = DataLoader(train_triplet_dataset, batch_size=batch_size, shuffle=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [4]:
def triplet_loss(anchor, positive, negative, margin=1.0):
    # Calculate Euclidean distances between the anchor, positive, and negative embeddings
    distance_positive = F.pairwise_distance(anchor, positive, p=2)
    distance_negative = F.pairwise_distance(anchor, negative, p=2)

    # Calculate the triplet loss
    loss = torch.clamp(distance_positive - distance_negative + margin, min=0.0)
    
    # Return the average triplet loss over the batch
    return torch.mean(loss)

In [5]:
from model import mobilenet

In [6]:
# def eval(dataloader, model, criterion, device):
#     model.eval()
#     correct = 0
#     total_loss = 0
#     with torch.no_grad():
#         for idx, (data, target) in enumerate(dataloader):
#             data, target = data.to(device), target.to(device)

#             output = model(data)
#             loss = criterion(output, target)
#             total_loss += loss.item()
#             pred = output.argmax(dim=1)

#             correct += pred.eq(target.view_as(pred)).sum().item() # compare predicted label to actual label
#     return correct / len(dataloader.dataset), total_loss / len(dataloader)

In [10]:
from common_utils import EarlyStopper
from train import train, eval
from triplet_loss import TripletLoss
from collections import OrderedDict

NUM_EPOCHS = 100
EARLY_STOP_THRESHOLD = 3
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
early_stopper = EarlyStopper(patience=3)

model,optimizer,criterion = mobilenet()
model.classifier = nn.Sequential(OrderedDict([ #mobilenet
                        ('dropout1', nn.Dropout(0.5)),
                        ('fc2', nn.Linear(1280, 102)),
                        ('output', nn.Linear(102, 2))
                        ]))
criterion = TripletLoss(DEVICE)
model.train()
for epoch in range(NUM_EPOCHS):
    for anchor, positive, negative in train_triplet_loader:
        optimizer.zero_grad()
        anchor_emb = model(anchor)
        positive_emb = model(positive)
        negative_emb = model(negative)
        loss = triplet_loss(anchor_emb, positive_emb, negative_emb)
        loss.backward()
        optimizer.step()
    accuracy, val_loss = eval(val_loader, model, criterion, DEVICE)
    print(f"Epoch {epoch + 1}: Triplet Loss: {loss.item()}, Val Accuracy: {accuracy}")
    if early_stopper.early_stop(accuracy):
        print("Early Stopping...")
        break

# early_stopper = EarlyStopper(patience=EARLY_STOP_THRESHOLD)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
# best_acc = 0
# early_stop_count = 0
# for epoch in range(1, NUM_EPOCHS+1):
#     train_loss = train(train_loader, model, criterion, optimizer, DEVICE)
#     accuracy, val_loss = eval(val_loader, model, criterion, DEVICE)
#     print(f'Epoch {epoch}, Train Loss: {train_loss}, Val Accuracy: {accuracy}, Val Loss: {val_loss}')
#     if early_stopper.early_stop(val_loss):
#         print("Early Stopping...")
#         break
#     scheduler.step()
# test_accuracy, _ = eval(test_loader, model, criterion, DEVICE)
# print(f'Test Accuracy: {test_accuracy}')

Epoch 1: Triplet Loss: 0.9912272691726685, Val Accuracy: 0.00980392156862745
Epoch 2: Triplet Loss: 0.9657003283500671, Val Accuracy: 0.00980392156862745


KeyboardInterrupt: 

In [None]:
test_accuracy, _ = eval(test_loader, model, criterion, DEVICE)
print(f'Test Accuracy: {test_accuracy}')