In [1]:
from google.colab import drive
import os
drive.mount('/content/drive')
os.chdir('/content/drive/MyDrive/Colab Notebooks/final_project')

Mounted at /content/drive


In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.models import resnet50
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "mps")
device

device(type='cuda')

## method 2

In [3]:
class ResNet50(nn.Module):
    def __init__(self, projection_dim=128):
        super(ResNet50, self).__init__()
        self.resnet50 = models.resnet50(pretrained=False)
        self.resnet50.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.resnet50.maxpool = nn.Identity()
        feature_dim = self.resnet50.fc.in_features
        self.resnet50.fc = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Linear(256, projection_dim)
        )

    def forward(self, x):
        projection = self.resnet50(x)
        return projection

In [4]:
def color_distortion(s=0.5):
    color_jitter = transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)
    rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
    rnd_gray = transforms.RandomGrayscale(p=0.2)
    color_distort = transforms.Compose([rnd_color_jitter, rnd_gray])
    return color_distort

# Set the strength of color distortion
s = 0.5

# train dataset
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(),
    color_distortion(s),
    transforms.ToTensor(),
    # transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
])
# test_transform = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
# ])
test_transform = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(),
    color_distortion(s),
    transforms.ToTensor(),
    # transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
])

In [5]:
from torchvision.datasets import CIFAR10
from PIL import Image
# from dataset import CIFAR10Pair, test_CIFAR10Pair

class CIFAR10Pair(CIFAR10):
    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)
        # Apply transformations twice to get a pair of different augmentations
        img1 = train_transform(img)
        img2 = train_transform(img)
        return img1, img2, target

class test_CIFAR10Pair(CIFAR10):
    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)
        # Apply transformations twice to get a pair of different augmentations
        img1 = test_transform(img)
        img2 = test_transform(img)
        return img1, img2, target

# Initialize the CIFAR-10 Pair dataset
train_dataset = CIFAR10Pair(root='./cifar10', train=True, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)

# # Initialize the CIFAR-10 Pair dataset
test_dataset = test_CIFAR10Pair(root='./cifar10', train=False, download=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2, pin_memory=True, drop_last=True)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
def nt_xent_loss(z_i, z_j, temperature):
    """
    Compute the NT-Xent loss.

    Arguments:
    z_i, z_j -- Representations of positive pairs. Each should be of shape (batch_size, feature_size).
    temperature -- A temperature scaling parameter.

    Returns:
    Loss computed from the batch of representations.
    """
    N, Z = z_i.shape  # Batch size and feature dimension

    # Normalize the representations
    z_i = F.normalize(z_i, p=2, dim=1)
    z_j = F.normalize(z_j, p=2, dim=1)

    # Concatenate the representations
    representations = torch.cat([z_i, z_j], dim=0)

    # Compute cosine similarity
    similarity_matrix = torch.matmul(representations, representations.T)

    # Create the mask for positive samples
    l_pos = torch.diag(similarity_matrix, N)
    r_pos = torch.diag(similarity_matrix, -N)
    positives = torch.cat([l_pos, r_pos]).view(2 * N, 1)

    # Mask for removing the similarity of each element with itself
    diag_mask = ~(torch.eye(2 * N).bool())

    # Extract the negatives
    negatives = similarity_matrix[diag_mask].view(2 * N, -1)

    # Combine positives with negatives
    logits = torch.cat([positives, negatives], dim=1)

    # Apply temperature scaling
    logits /= temperature

    # Labels: positives are the first elements
    labels = torch.zeros(2 * N).to(z_i.device).long()

    # Calculate the cross-entropy loss
    loss = F.cross_entropy(logits, labels)

    return loss


In [7]:
def contrastive_accuracy(z_i, z_j, labels):
    with torch.no_grad():
        # Compute the cosine similarity
        similarity_matrix = F.cosine_similarity(z_i.unsqueeze(1), z_j.unsqueeze(0), dim=2)

        # Get the indices of the maximum values along each row
        max_indices = similarity_matrix.max(dim=1)[1]

        # Calculate accuracy
        correct = (labels == labels[max_indices]).float()
        return correct.mean()

In [9]:
model = ResNet50().to(device)

# Load the saved model state
model.load_state_dict(torch.load('simclr_resnet50_64_108to200ep.pt'))

# Optimizer setup
optimizer = optim.Adam(model.parameters(), lr=0.5)

# Number of additional epochs to train
num_epochs = 200
additional_epochs = 16
start_epoch = 185
end_epoch = start_epoch + additional_epochs

best_val_loss = float('inf')

train_loss_list = []
train_accuracy_list = []
test_loss_list = []
test_accuracy_list = []


# Training loop
for epoch in range(start_epoch, end_epoch + 1):
    model.train()
    total_loss = 0
    total_accuracy = 0

    for batch_idx, (img1, img2, labels) in enumerate(train_loader):
        img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)

        # Forward pass
        z_i = model(img1)
        z_j = model(img2)

        # Compute NT-Xent Loss
        loss = nt_xent_loss(z_i, z_j, temperature=0.5)

        # Compute accuracy
        # train_accuracy = contrastive_accuracy(z_i, z_j)
        train_accuracy = contrastive_accuracy(z_i, z_j, labels)
        total_accuracy += train_accuracy.item()


        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Print loss and accuracy every 10 batches
        if (batch_idx + 1) % 200 == 0:
            print(f'Epoch [{epoch}/{num_epochs}], Batch [{batch_idx+1}/{len(train_loader)}], Train Loss: {loss.item():.4f}, Train Accuracy: {train_accuracy:.4f}')

    # Average loss and accuracy for this epoch
    avg_loss = total_loss / len(train_loader)
    avg_accuracy = total_accuracy / len(train_loader)

    # Evaluate on test set
    model.eval()
    total_test_loss = 0
    total_test_accuracy = 0
    with torch.no_grad():
        for batch_idx, (img1, img2, labels) in enumerate(test_loader):
            img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)
            z_i = model(img1)
            z_j = model(img2)
            # Assuming you have a way to calculate test loss, e.g., CrossEntropyLoss for classification
            test_loss = nt_xent_loss(z_i, z_j, temperature=0.5)
            total_test_loss += test_loss.item()
            # Compute accuracy
            # test_accuracy = contrastive_accuracy(z_i, z_j)
            test_accuracy = contrastive_accuracy(z_i, z_j, labels)

            total_test_accuracy += test_accuracy.item()
            avg_test_loss = total_test_loss / len(test_loader)
            avg_test_accuracy = total_test_accuracy / len(test_loader)
    # test_loss, test_accuracy = evaluate_on_test(model, test_loader, device)
    print(f'Epoch [{epoch}/{num_epochs}], Train Loss: {avg_loss:.4f}, Train Accuracy: {avg_accuracy:.4f}, Test Loss: {avg_test_loss:.4f}, Test Accuracy: {avg_test_accuracy:.4f}')
    train_loss_list.append(avg_loss)
    train_accuracy_list.append(avg_accuracy)
    test_loss_list.append(avg_test_loss)
    test_accuracy_list.append(avg_test_accuracy)

    # Save model if it has best test loss yet
    if avg_test_loss < best_val_loss:
        best_val_loss = avg_test_loss
        torch.save(model.state_dict(), f'simclr_resnet50_64_185to200ep.pt')
        print(f"Saved best model at epoch {epoch}, with test loss: {best_val_loss:.4f}, test accuracy: {avg_test_accuracy:.4f}")

print("Training Completed")



Epoch [185/200], Batch [200/781], Train Loss: 3.3530, Train Accuracy: 0.9375
Epoch [185/200], Batch [400/781], Train Loss: 3.3287, Train Accuracy: 0.9219
Epoch [185/200], Batch [600/781], Train Loss: 3.3439, Train Accuracy: 0.9062
Epoch [185/200], Train Loss: 3.3391, Train Accuracy: 0.9021, Test Loss: 3.3754, Test Accuracy: 0.8959
Saved best model at epoch 185, with test loss: 3.3754, test accuracy: 0.8959
Epoch [186/200], Batch [200/781], Train Loss: 3.3715, Train Accuracy: 0.8594
Epoch [186/200], Batch [400/781], Train Loss: 3.3926, Train Accuracy: 0.8750
Epoch [186/200], Batch [600/781], Train Loss: 3.2461, Train Accuracy: 0.9531
Epoch [186/200], Train Loss: 3.3398, Train Accuracy: 0.9004, Test Loss: 3.3741, Test Accuracy: 0.8872
Saved best model at epoch 186, with test loss: 3.3741, test accuracy: 0.8872
Epoch [187/200], Batch [200/781], Train Loss: 3.2846, Train Accuracy: 0.9531
Epoch [187/200], Batch [400/781], Train Loss: 3.3157, Train Accuracy: 0.8906
Epoch [187/200], Batch [60

In [12]:
max(test_accuracy_list)

0.9008413461538461