In [1]:
import torch
import torch.nn as nn
from torchvision import models
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
import json
from torch.utils.data import Dataset, DataLoader
import random
import numpy as np
from PIL import Image
from torch.optim.lr_scheduler import ReduceLROnPlateau
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run on the GPU when CUDA is available to PyTorch; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Transform for images: resize to 224x224, convert to a tensor, then normalize
# with mean 0.5 and std 0.5.
# NOTE(review): the model below starts from pretrained ResNet-50 weights; confirm
# whether these 0.5/0.5 statistics are intended rather than the statistics the
# pretrained weights were trained with.
_image_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]
transform = transforms.Compose(_image_steps)

def load_image(image_path):
    """
    Load one image from disk and return it as a preprocessed tensor.

    Parameters:
    - image_path: path of the image file to read

    Returns:
    - the result of applying the module-level `transform` to the RGB image
    """
    # Image.open keeps the underlying file open until the image is closed.
    # convert() produces an independent in-memory copy, so the file can be
    # released as soon as the conversion is done instead of waiting for GC.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    return transform(image)

# Load the dataset JSON file (a list of records, each read for its
# 'class' and 'image_path' keys below).
with open('/workspaces/finetune/AFINAL/resnet/output/train_data.json', 'r') as f:
    dataset = json.load(f)

# Dictionary to map class labels to image paths: every record's image path is
# appended to the list kept for that record's class.
class_to_images = {}
for item in dataset:
    cls = item['class']
    class_to_images.setdefault(cls, []).append(item['image_path'])

In [3]:
# Display the class -> image-paths mapping built in the previous cell.
class_to_images

{'3000000033': ['/workspaces/finetune/AFINAL/clip/combined_classes_augmented/3000000033/3000000033_original_3000000033(1).jpg_7c870b02-71dd-4821-826c-8b00dbd02661.jpg',
  '/workspaces/finetune/AFINAL/clip/combined_classes_augmented/3000000033/3000000033_original_3000000033(1).jpg_12224dc4-8b60-44e7-8aaa-cba2ae75211e.jpg',
  '/workspaces/finetune/AFINAL/clip/combined_classes_augmented/3000000033/3000000033(1).jpg',
  '/workspaces/finetune/AFINAL/clip/combined_classes_augmented/3000000033/3000000033_original_3000000033(1).jpg_df349ba6-78e4-4f6f-94b5-b015d991657c.jpg',
  '/workspaces/finetune/AFINAL/clip/combined_classes_augmented/3000000033/3000000033_original_3000000033(1).jpg_542e2679-6367-44af-a2b8-e2e88f367b6d.jpg',
  '/workspaces/finetune/AFINAL/clip/combined_classes_augmented/3000000033/3000000033_original_3000000033(1).jpg_3081d009-ed33-44fe-878b-1a6425b107b4.jpg'],
 '2997330284': ['/workspaces/finetune/AFINAL/clip/combined_classes_augmented/2997330284/2997330284_original_29973302

In [4]:
class CustomResNet50(nn.Module):
    """
    Pretrained ResNet-50 trunk followed by an embedding layer and a classifier.

    Parameters:
    - num_classes: number of output logits produced by the classifier head
    - embedding_dim: size of the embedding vector (default 256)

    Attributes:
    - resnet: the ResNet-50 child modules with the last two removed
    - avgpool: adaptive average pooling down to a 1x1 spatial size
    - embedding: linear projection from the pooled features to `embedding_dim`
    - classifier: linear layer mapping the embedding to `num_classes` logits
    """

    def __init__(self, num_classes, embedding_dim=256):
        super().__init__()
        backbone = models.resnet50(pretrained=True)
        # The input width of the stock fc layer is the size of the pooled feature vector.
        feature_dim = backbone.fc.in_features
        # Drop the last two child modules of the stock network
        # (presumably its own pooling and fc layers - verify against torchvision).
        trunk_layers = list(backbone.children())[:-2]
        self.resnet = nn.Sequential(*trunk_layers)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.embedding = nn.Linear(feature_dim, embedding_dim)
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def forward(self, x, return_embedding=False):
        """
        Compute the embedding and, unless `return_embedding` is set, the class logits.

        Returns:
        - the embedding alone when `return_embedding` is True
        - otherwise the tuple (logits, embedding)
        """
        pooled = self.avgpool(self.resnet(x))
        embedding = self.embedding(torch.flatten(pooled, 1))
        if not return_embedding:
            return self.classifier(embedding), embedding
        return embedding

In [5]:
# Triplet Loss Definition
class TripletLoss(nn.Module):
    """
    Module wrapper around `nn.TripletMarginLoss` that also stores the margin.

    Parameters:
    - margin: margin passed to `nn.TripletMarginLoss` (default 1.0)
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin
        self.triplet_loss = nn.TripletMarginLoss(margin=margin)

    def forward(self, anchor, positive, negative):
        """Return the triplet margin loss for the given anchor/positive/negative tensors."""
        loss = self.triplet_loss(anchor, positive, negative)
        return loss

In [6]:
def generate_triplets(embeddings, targets, model, dataset):
    """
    Generate triplets (anchor, positive, negative) from the embeddings and targets.

    The anchors are the batch embeddings themselves. For every anchor one
    image of the same class and one image of a different class are drawn at
    random from the module-level `class_to_images` mapping, loaded with
    `load_image`, and embedded with `model`.

    All three returned tensors stay attached to the autograd graph whenever
    gradients are enabled, so a loss computed from them back-propagates into
    the model. (Round-tripping the embeddings through NumPy would cut that
    connection and leave the triplet loss without any gradient.) Under
    `torch.no_grad()` no graph is built.

    Positives and negatives are embedded as one batch each rather than image
    by image; this keeps batch-dependent layers from seeing single-image
    batches, at the cost of holding two extra batches of activations while
    gradients are enabled.

    Caveats:
    - The positive is drawn from all images of the anchor's class, so it may
      be the anchor image itself.
    - Class names come from `dataset`, image paths from `class_to_images`; a
      class missing from `class_to_images` raises KeyError.
    - With a single class there is no negative candidate and `random.choice`
      raises IndexError.

    Parameters:
    - embeddings: torch.Tensor of shape (batch_size, embedding_dim)
    - targets: torch.Tensor of shape (batch_size,)
    - model: PyTorch model to compute embeddings for positive and negative samples
    - dataset: Dataset providing `idx_to_class` and `class_to_idx`

    Returns:
    - anchors, positives, negatives: torch.Tensors of shape (batch_size, embedding_dim)
    """
    anchor_labels = targets.detach().cpu().tolist()
    if not anchor_labels:
        # Nothing to sample for an empty batch; hand back the empty embeddings.
        return embeddings, embeddings, embeddings

    class_names = list(dataset.class_to_idx.keys())

    positive_images = []
    negative_images = []
    for anchor_label in anchor_labels:
        anchor_class = dataset.idx_to_class[anchor_label]

        # Find positive example (same class as anchor)
        positive_path = random.choice(class_to_images[anchor_class])
        positive_images.append(load_image(positive_path))

        # Find negative example (different class than anchor)
        negative_class = random.choice([cls for cls in class_names if cls != anchor_class])
        negative_path = random.choice(class_to_images[negative_class])
        negative_images.append(load_image(negative_path))

    # Embed each group in a single forward pass on the same device as the anchors.
    positive_batch = torch.stack(positive_images).to(embeddings.device)
    negative_batch = torch.stack(negative_images).to(embeddings.device)
    positives = model(positive_batch, return_embedding=True)
    negatives = model(negative_batch, return_embedding=True)

    return embeddings, positives, negatives


In [7]:
# Training and Validation Functions
def train(model, device, train_loader, optimizer, epoch, triplet_loss_fn, classifier_loss_fn, dataset):
    """
    Run one training epoch and report the averaged metrics.

    For every batch the classifier loss and the triplet loss (on triplets from
    `generate_triplets`) are added together, the sum is back-propagated and
    the optimizer takes one step.

    Parameters:
    - model: network returning (logits, embedding) for a batch
    - device: device the batches are moved to
    - train_loader: iterable of (data, target) batches; must support len()
    - optimizer: optimizer updating the model parameters
    - epoch: epoch number, used for the progress bar and the printed summary
    - triplet_loss_fn: callable taking (anchor, positive, negative)
    - classifier_loss_fn: callable taking (logits, target)
    - dataset: forwarded to `generate_triplets`

    Returns:
    - (average triplet loss per batch, average classifier loss per batch,
       accuracy in percent over all samples seen)
    """
    model.train()
    triplet_running = 0
    classifier_running = 0
    num_correct = 0
    num_seen = 0

    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch}"):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()

        # Forward pass
        logits, embeddings = model(images)
        classifier_loss = classifier_loss_fn(logits, labels)

        # Generate triplets
        anchors, positives, negatives = generate_triplets(embeddings, labels, model, dataset)
        triplet_loss = triplet_loss_fn(anchors, positives, negatives)

        # Backward pass and optimize
        (classifier_loss + triplet_loss).backward()
        optimizer.step()

        triplet_running += triplet_loss.item()
        classifier_running += classifier_loss.item()

        # Track accuracy
        num_correct += (logits.argmax(dim=1) == labels).sum().item()
        num_seen += labels.size(0)

    num_batches = len(train_loader)
    avg_triplet_loss = triplet_running / num_batches
    avg_classifier_loss = classifier_running / num_batches
    accuracy = 100. * num_correct / num_seen

    print(f'Epoch: {epoch} Triplet Loss: {avg_triplet_loss}, Classifier Loss: {avg_classifier_loss}, Accuracy: {accuracy}%')
    return avg_triplet_loss, avg_classifier_loss, accuracy

In [8]:
def validate(model, device, val_loader, classifier_loss_fn, triplet_loss_fn, dataset):
    """
    Evaluate the model on a data loader without computing gradients.

    Parameters:
    - model: network returning (logits, embedding) for a batch
    - device: device the batches are moved to
    - val_loader: iterable of (data, target) batches with len() and a `.dataset`
    - classifier_loss_fn: callable taking (logits, target)
    - triplet_loss_fn: callable taking (anchor, positive, negative)
    - dataset: forwarded to `generate_triplets`

    Returns:
    - (average triplet loss per batch, average classifier loss per batch,
       accuracy in percent relative to len(val_loader.dataset))
    """
    model.eval()
    triplet_sum = 0
    classifier_sum = 0
    correct = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)

            logits, embeddings = model(images)
            classifier_sum += classifier_loss_fn(logits, labels).item()

            # Generate triplets
            anchors, positives, negatives = generate_triplets(embeddings, labels, model, dataset)
            triplet_sum += triplet_loss_fn(anchors, positives, negatives).item()

            correct += (logits.argmax(dim=1) == labels).sum().item()

    num_batches = len(val_loader)
    avg_val_triplet_loss = triplet_sum / num_batches
    avg_val_classifier_loss = classifier_sum / num_batches
    accuracy = 100. * correct / len(val_loader.dataset)

    print(f'\nValidation set: Average Triplet Loss: {avg_val_triplet_loss:.4f}, Average Classifier Loss: {avg_val_classifier_loss:.4f}, Accuracy: {correct}/{len(val_loader.dataset)} ({accuracy:.2f}%)\n')
    return avg_val_triplet_loss, avg_val_classifier_loss, accuracy


In [9]:
# Hyperparameters
batch_size = 32        # samples per mini-batch
epochs = 10            # NOTE(review): the training loop iterates range(1, 51); confirm which epoch count is intended
learning_rate = 1e-3   # optimizer step size
momentum = 0.9         # SGD momentum coefficient
log_interval = 10      # NOTE(review): not referenced by any cell visible in this notebook
margin = 1.0           # triplet-loss margin

In [10]:
class CustomDataset(Dataset):
    """
    Image-classification dataset backed by a JSON list of records.

    Each record is read for its 'class' and 'image_path' keys.

    Parameters:
    - json_file: path of the JSON file to load
    - transform: optional callable applied to each RGB PIL image
    - classes: optional ordered sequence of class names; a name's position is
      its integer label. Pass the same sequence to every split (for example
      `train_dataset.classes`) so labels agree between splits. When omitted,
      the class names found in the file are used in sorted order.

    Raises:
    - ValueError: if `classes` is given and the file contains a class that is
      not listed in it

    Attributes:
    - data: the loaded list of records
    - classes: ordered list of class names
    - class_to_idx / idx_to_class: mappings between class name and label
    """

    def __init__(self, json_file, transform=None, classes=None):
        with open(json_file, 'r') as f:
            self.data = json.load(f)
        self.transform = transform

        found = {item['class'] for item in self.data}
        if classes is None:
            # Set iteration order is not stable between interpreter runs, so an
            # unsorted list would give a different label mapping on every run
            # (breaking a reloaded checkpoint) and possibly between splits.
            self.classes = sorted(found)
        else:
            self.classes = list(classes)
            unknown = found - set(self.classes)
            if unknown:
                raise ValueError(
                    f"{json_file} contains classes missing from `classes`: {sorted(unknown, key=str)}"
                )
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}
        self.idx_to_class = {idx: cls_name for cls_name, idx in self.class_to_idx.items()}

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return (image, label) for record `idx`; the image is transformed when a transform is set."""
        record = self.data[idx]
        label = self.class_to_idx[record['class']]
        # Close the file as soon as the RGB copy exists instead of leaving the
        # handle open until garbage collection.
        with Image.open(record['image_path']) as raw_image:
            image = raw_image.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

# Paths to your JSON files
train_json = '/workspaces/finetune/AFINAL/resnet/output/train_data.json'
val_json = '/workspaces/finetune/AFINAL/resnet/output/val_data.json'
test_json = '/workspaces/finetune/AFINAL/resnet/output/test_data.json'

# Create datasets and data loaders
train_dataset = CustomDataset(json_file=train_json, transform=transform)
val_dataset = CustomDataset(json_file=val_json, transform=transform)
test_dataset = CustomDataset(json_file=test_json, transform=transform)

# Use the `batch_size` hyperparameter instead of a repeated literal so that
# changing it in the hyperparameter cell actually takes effect.
# Only the training loader shuffles; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [11]:
# Model, Loss, Optimizer, and Scheduler
# The classifier head gets one output per class found in the training set.
model = CustomResNet50(num_classes=len(train_dataset.classes)).to(device)
# Read margin / learning rate / momentum from the hyperparameter cell rather
# than repeating the literals, so tuning those values actually takes effect.
triplet_loss_fn = TripletLoss(margin=margin)
classifier_loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
# Multiply the learning rate by 0.1 once the monitored value has not improved
# for 5 consecutive scheduler steps.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)



In [12]:
# Training and Validation Loop
# Early-stopping bookkeeping: lowest validation loss so far, the number of
# epochs without improvement, and how many such epochs are tolerated.
best_val_loss, early_stopping_counter = float('inf'), 0
early_stopping_patience = 10

# Per-epoch metric histories, one entry appended per completed epoch.
train_triplet_losses, train_classifier_losses, train_accuracies = [], [], []
val_triplet_losses, val_classifier_losses, val_accuracies = [], [], []

In [13]:
for epoch in range(1, 51):
    # One pass over the training data, then score the validation split.
    train_triplet_loss, train_classifier_loss, train_accuracy = train(
        model, device, train_loader, optimizer, epoch,
        triplet_loss_fn, classifier_loss_fn, train_dataset,
    )
    val_triplet_loss, val_classifier_loss, val_accuracy = validate(
        model, device, val_loader, classifier_loss_fn, triplet_loss_fn, val_dataset,
    )

    # Record this epoch's metrics for the plots.
    train_triplet_losses.append(train_triplet_loss)
    train_classifier_losses.append(train_classifier_loss)
    train_accuracies.append(train_accuracy)
    val_triplet_losses.append(val_triplet_loss)
    val_classifier_losses.append(val_classifier_loss)
    val_accuracies.append(val_accuracy)

    # Adjust learning rate based on validation classifier loss
    scheduler.step(val_classifier_loss)

    # Early stopping: count epochs without a new best validation classifier
    # loss and stop once the patience is used up.
    if not (val_classifier_loss < best_val_loss):
        early_stopping_counter += 1
        if early_stopping_counter >= early_stopping_patience:
            print("Early stopping triggered")
            break
        continue

    # New best: reset the counter and save the best model.
    best_val_loss = val_classifier_loss
    early_stopping_counter = 0
    torch.save(model.state_dict(), 'best_model.pth')

Epoch 1:   1%|          | 1/113 [00:11<20:34, 11.02s/it]

In [None]:
# Load the best model for testing
# map_location lets a checkpoint saved on one device be restored on the current one
# (for example a GPU-trained model evaluated on a CPU-only machine).
best_state = torch.load('best_model.pth', map_location=device)
model.load_state_dict(best_state)
# validate() returns three values: (avg triplet loss, avg classifier loss, accuracy).
test_triplet_loss, test_classifier_loss, test_accuracy = validate(
    model, device, test_loader, classifier_loss_fn, triplet_loss_fn, test_dataset
)


In [None]:
# Plotting the results
# A dedicated name is used for the x-axis values so that the `epochs`
# hyperparameter is not overwritten with a range object.
epoch_axis = range(1, len(train_triplet_losses) + 1)

fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: all four loss curves on a shared axis.
loss_ax.plot(epoch_axis, train_triplet_losses, label='Train Triplet Loss')
loss_ax.plot(epoch_axis, val_triplet_losses, label='Validation Triplet Loss')
loss_ax.plot(epoch_axis, train_classifier_losses, label='Train Classifier Loss')
loss_ax.plot(epoch_axis, val_classifier_losses, label='Validation Classifier Loss')
loss_ax.set_xlabel('Epochs')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Loss per epoch')
loss_ax.legend()

# Right panel: training vs. validation accuracy.
acc_ax.plot(epoch_axis, train_accuracies, label='Train Accuracy')
acc_ax.plot(epoch_axis, val_accuracies, label='Validation Accuracy')
acc_ax.set_xlabel('Epochs')
acc_ax.set_ylabel('Accuracy')
acc_ax.set_title('Accuracy per epoch')
acc_ax.legend()

fig.tight_layout()
plt.show()