In [1]:
import torch
import torch.nn as nn
from torchvision import models
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
import json
from torch.utils.data import Dataset, DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
# CustomResNet50 Model Definition
class CustomResNet50(nn.Module):
    """ResNet-50 trunk with an embedding layer and a classification layer on top.

    A torchvision ResNet-50 is created with ``pretrained=True`` and its last two
    child modules are discarded. The remaining trunk is followed by an adaptive
    average pool to 1x1, a linear projection to ``embedding_dim`` and a linear
    classifier over ``num_classes``.

    Args:
        num_classes: Output width of the classification layer.
        embedding_dim: Output width of the embedding layer.
    """

    def __init__(self, num_classes, embedding_dim=256):
        super().__init__()
        backbone = models.resnet50(pretrained=True)
        # The stock fc layer's input width is the feature width fed to the embedding layer.
        feature_dim = backbone.fc.in_features
        # Keep every child of the stock network except the final two.
        trunk_layers = list(backbone.children())[:-2]
        self.resnet = nn.Sequential(*trunk_layers)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.embedding = nn.Linear(feature_dim, embedding_dim)
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def forward(self, x, return_embedding=False):
        """Run the network on ``x``.

        Returns:
            The embedding alone when ``return_embedding`` is true, otherwise the
            tuple ``(classifier_output, embedding)``.
        """
        pooled = self.avgpool(self.resnet(x))
        # Flatten everything after the batch dimension before the linear layers.
        embedding = self.embedding(torch.flatten(pooled, 1))
        if not return_embedding:
            return self.classifier(embedding), embedding
        return embedding

In [12]:
# Triplet Loss Definition
class TripletLoss(nn.Module):
    """Module wrapper around ``nn.TripletMarginLoss``.

    Args:
        margin: Margin handed to ``nn.TripletMarginLoss``; also kept on ``self.margin``.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin
        self.triplet_loss = nn.TripletMarginLoss(margin=margin)

    def forward(self, anchor, positive, negative):
        """Return the wrapped criterion's loss for the given anchor/positive/negative tensors."""
        loss_value = self.triplet_loss(anchor, positive, negative)
        return loss_value

In [13]:
import numpy as np
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def generate_triplets(embeddings, targets):
    """
    Build (anchor, positive, negative) triplets from one batch.

    Every sample is tried once as an anchor. Its positive is a randomly chosen
    *other* sample with the same label and its negative is a randomly chosen
    sample with a different label. Samples that have no positive or no negative
    in the batch are skipped.

    The returned tensors are gathered from ``embeddings`` by index, so they stay
    attached to the autograd graph (a loss computed on them back-propagates into
    whatever produced ``embeddings``) and they live on the same device and have
    the same dtype as ``embeddings``.

    Parameters:
    - embeddings: torch.Tensor of shape (batch_size, embedding_dim)
    - targets: torch.Tensor of shape (batch_size,)

    Returns:
    - anchors, positives, negatives: torch.Tensors of shape
      (num_triplets, embedding_dim), where num_triplets <= batch_size.
      When no triplet can be formed all three have shape (0, embedding_dim).
    """
    # Labels are only used to pick indices, so they can safely leave the graph.
    labels = targets.detach().cpu().numpy()

    anchor_rows, positive_rows, negative_rows = [], [], []
    for i, anchor_label in enumerate(labels):
        # Candidates with the anchor's label, excluding the anchor itself.
        same_class = np.flatnonzero(labels == anchor_label)
        same_class = same_class[same_class != i]
        if same_class.size == 0:
            continue  # no positive available for this anchor

        different_class = np.flatnonzero(labels != anchor_label)
        if different_class.size == 0:
            continue  # no negative available for this anchor

        anchor_rows.append(i)
        positive_rows.append(int(np.random.choice(same_class)))
        negative_rows.append(int(np.random.choice(different_class)))

    def _gather(rows):
        # index_select keeps the gradient path; an empty index yields (0, embedding_dim).
        index = torch.as_tensor(rows, dtype=torch.long, device=embeddings.device)
        return embeddings.index_select(0, index)

    return _gather(anchor_rows), _gather(positive_rows), _gather(negative_rows)

In [14]:
# Training and Validation Functions
def train(model, device, train_loader, optimizer, epoch, triplet_loss_fn, classifier_loss_fn):
    """Run one training epoch with a combined classification + triplet objective.

    For every batch the model is called once, the classification loss is computed
    on its first output and the triplet loss on triplets built from its second
    output (the embedding). The sum of both losses is back-propagated and the
    optimizer takes one step. The per-batch averages of both losses are printed
    at the end of the epoch.

    Args:
        model: Module whose forward returns ``(output, embedding)``.
        device: Device the batches are moved to.
        train_loader: Iterable of ``(data, target)`` batches that supports ``len()``.
        optimizer: Optimizer updating ``model``'s parameters.
        epoch: Epoch number, used only in the printed summary.
        triplet_loss_fn: Callable ``(anchor, positive, negative) -> loss``.
        classifier_loss_fn: Callable ``(output, target) -> loss``.

    Note:
        The triplet term only influences the update when ``generate_triplets``
        returns tensors that are still attached to the autograd graph.
    """
    model.train()
    total_triplet_loss = 0
    total_classifier_loss = 0

    for data, target in tqdm(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        # Forward pass
        output, embedding = model(data)
        classifier_loss = classifier_loss_fn(output, target)

        # Generate triplets
        anchor, positive, negative = generate_triplets(embedding, target)
        if anchor.shape[0] == 0:
            # No usable triplet in this batch (e.g. a final batch of size 1, or a
            # batch where no label repeats). A mean over zero triplets is NaN and
            # would corrupt the weights, so contribute a zero loss instead.
            triplet_loss = embedding.new_zeros(())
        else:
            triplet_loss = triplet_loss_fn(anchor, positive, negative)

        # Backward pass and optimize
        loss = classifier_loss + triplet_loss
        loss.backward()
        optimizer.step()

        total_triplet_loss += triplet_loss.item()
        total_classifier_loss += classifier_loss.item()

    print(f'Epoch: {epoch} Triplet Loss: {total_triplet_loss / len(train_loader)}, Classifier Loss: {total_classifier_loss / len(train_loader)}')


In [15]:
def validate(model, device, val_loader, classifier_loss_fn):
    """Evaluate ``model`` on ``val_loader`` and print a one-line summary.

    Args:
        model: Module whose forward returns a pair; only the first element
            (the class scores) is used here.
        device: Device the batches are moved to.
        val_loader: Iterable of ``(data, target)`` batches exposing ``len()``
            and a ``dataset`` attribute.
        classifier_loss_fn: Callable ``(output, target) -> loss``.

    Returns:
        ``(val_loss, accuracy)``: the mean of the per-batch losses and the
        percentage of samples whose arg-max prediction equals the target.
    """
    model.eval()
    loss_sum = 0
    correct = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            logits, _ = model(inputs)
            loss_sum += classifier_loss_fn(logits, labels).item()
            # Compare the arg-max class against the target, shaped alike.
            predicted = logits.argmax(dim=1, keepdim=True)
            correct += (predicted == labels.view_as(predicted)).sum().item()

    # Average over batches for the loss, over samples for the accuracy.
    val_loss = loss_sum / len(val_loader)
    accuracy = 100. * correct / len(val_loader.dataset)

    print(f'\nValidation set: Average loss: {val_loss:.4f}, Accuracy: {correct}/{len(val_loader.dataset)} ({accuracy:.2f}%)\n')
    return val_loss, accuracy


In [16]:
# Hyperparameters
# Training configuration. Keep these values in sync with what is actually passed
# to the DataLoader, optimizer and loss constructors in the cells below.
batch_size = 32  # same value the DataLoaders are built with
epochs = 10  # upper bound of the training loop: range(1, epochs + 1)
learning_rate = 0.001  # same value the SGD optimizer is built with
momentum = 0.9  # same value the SGD optimizer is built with
log_interval = 10  # NOTE(review): not referenced by any visible cell - confirm it is still needed
margin = 1.0  # same value the triplet loss is built with

In [17]:
# Custom Dataset class
class CustomDataset(Dataset):
    """Image-classification dataset described by a JSON manifest.

    The manifest is a JSON list of records, each with at least the keys
    ``'image_path'`` and ``'class'``.

    Args:
        json_file: Path to the JSON manifest.
        transform: Optional callable applied to the loaded image.
        classes: Optional ordered collection of class labels defining the
            label -> index mapping. Pass the training set's ``classes`` when
            building validation/test sets so every split shares one mapping.
            When omitted, the sorted distinct labels found in the manifest are
            used, which makes the mapping reproducible across runs.

    Attributes:
        data: The parsed manifest records.
        classes: Class labels, ordered by index.
        class_to_idx: Mapping from class label to integer index.
    """

    def __init__(self, json_file, transform=None, classes=None):
        with open(json_file, 'r') as f:
            self.data = json.load(f)
        self.transform = transform
        if classes is None:
            # A bare set has no stable order; sorting gives a deterministic mapping.
            classes = sorted({item['class'] for item in self.data})
        self.classes = list(classes)
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label_index)`` for record ``idx``."""
        record = self.data[idx]
        label = self.class_to_idx[record['class']]
        # `Image` is never imported in this notebook, so the image is loaded through
        # torchvision's default loader instead: it opens the file, converts the
        # image to RGB and closes the file handle.
        image = datasets.folder.default_loader(record['image_path'])

        if self.transform:
            image = self.transform(image)

        return image, label

# Transformations
# The backbone is created with ImageNet-pretrained weights, so inputs are normalised
# with the per-channel ImageNet statistics given in the torchvision model docs
# rather than a generic 0.5 / 0.5.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# Paths to your JSON files
train_json = '/workspaces/finetune/AFINAL/resnet/output/train_data.json'
val_json = '/workspaces/finetune/AFINAL/resnet/output/val_data.json'
test_json = '/workspaces/finetune/AFINAL/resnet/output/test_data.json'

# Create datasets and data loaders
train_dataset = CustomDataset(json_file=train_json, transform=transform)
val_dataset = CustomDataset(json_file=val_json, transform=transform)
test_dataset = CustomDataset(json_file=test_json, transform=transform)

# Each dataset derives its label -> index mapping from its own manifest, so the
# same class could receive different indices in different splits. Evaluation must
# use the indices the classifier is trained on: reuse the training mapping.
val_dataset.classes = test_dataset.classes = train_dataset.classes
val_dataset.class_to_idx = test_dataset.class_to_idx = train_dataset.class_to_idx

# Use the batch size declared with the other hyperparameters.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [18]:
# Model, Loss, Optimizer, and Scheduler
# One classifier output per class found in the training manifest.
model = CustomResNet50(num_classes=len(train_dataset.classes)).to(device)
# Loss and optimizer settings come from the hyperparameter cell instead of
# repeating the same literals here.
triplet_loss_fn = TripletLoss(margin=margin)
classifier_loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
# step_size=1 with gamma=0.7: each scheduler.step() multiplies the learning rate by 0.7.
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)


In [19]:
# Display the module's repr to inspect the assembled layers.
model

CustomResNet50(
  (resnet): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2

In [None]:

# Training and Validation Loop
# Each epoch: one training pass, one validation pass, then one scheduler step.
for epoch in range(1, epochs + 1):
    train(
        model=model,
        device=device,
        train_loader=train_loader,
        optimizer=optimizer,
        epoch=epoch,
        triplet_loss_fn=triplet_loss_fn,
        classifier_loss_fn=classifier_loss_fn,
    )
    validate(
        model=model,
        device=device,
        val_loader=val_loader,
        classifier_loss_fn=classifier_loss_fn,
    )
    scheduler.step()