In [1]:
from google.colab import drive
import os
drive.mount('/content/drive')
os.chdir('/content/drive/MyDrive/Colab Notebooks/final_project')

Mounted at /content/drive


In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.models import resnet50
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "mps")
device

device(type='cuda')

## method 2

In [3]:
class ResNet50(nn.Module):
    def __init__(self, projection_dim=128):
        super(ResNet50, self).__init__()
        self.resnet50 = models.resnet50(pretrained=False)
        self.resnet50.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.resnet50.maxpool = nn.Identity()
        feature_dim = self.resnet50.fc.in_features
        self.resnet50.fc = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Linear(256, projection_dim)
        )

    def forward(self, x):
        projection = self.resnet50(x)
        return projection

In [4]:
def color_distortion(s=0.5):
    color_jitter = transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)
    rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
    rnd_gray = transforms.RandomGrayscale(p=0.2)
    color_distort = transforms.Compose([rnd_color_jitter, rnd_gray])
    return color_distort

# Set the strength of color distortion
s = 0.5

# train dataset
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(),
    color_distortion(s),
    transforms.ToTensor(),
    # transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
])
# test_transform = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
# ])
test_transform = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(),
    color_distortion(s),
    transforms.ToTensor(),
    # transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
])

In [5]:
from torchvision.datasets import CIFAR10
from PIL import Image
# from dataset import CIFAR10Pair, test_CIFAR10Pair

class CIFAR10Pair(CIFAR10):
    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)
        # Apply transformations twice to get a pair of different augmentations
        img1 = train_transform(img)
        img2 = train_transform(img)
        return img1, img2, target

class test_CIFAR10Pair(CIFAR10):
    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)
        # Apply transformations twice to get a pair of different augmentations
        img1 = test_transform(img)
        img2 = test_transform(img)
        return img1, img2, target

# Initialize the CIFAR-10 Pair dataset
train_dataset = CIFAR10Pair(root='./cifar10', train=True, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)

# # Initialize the CIFAR-10 Pair dataset
test_dataset = test_CIFAR10Pair(root='./cifar10', train=False, download=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2, pin_memory=True, drop_last=True)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
def nt_xent_loss(z_i, z_j, temperature):
    """
    Compute the NT-Xent loss.

    Arguments:
    z_i, z_j -- Representations of positive pairs. Each should be of shape (batch_size, feature_size).
    temperature -- A temperature scaling parameter.

    Returns:
    Loss computed from the batch of representations.
    """
    N, Z = z_i.shape  # Batch size and feature dimension

    # Normalize the representations
    z_i = F.normalize(z_i, p=2, dim=1)
    z_j = F.normalize(z_j, p=2, dim=1)

    # Concatenate the representations
    representations = torch.cat([z_i, z_j], dim=0)

    # Compute cosine similarity
    similarity_matrix = torch.matmul(representations, representations.T)

    # Create the mask for positive samples
    l_pos = torch.diag(similarity_matrix, N)
    r_pos = torch.diag(similarity_matrix, -N)
    positives = torch.cat([l_pos, r_pos]).view(2 * N, 1)

    # Mask for removing the similarity of each element with itself
    diag_mask = ~(torch.eye(2 * N).bool())

    # Extract the negatives
    negatives = similarity_matrix[diag_mask].view(2 * N, -1)

    # Combine positives with negatives
    logits = torch.cat([positives, negatives], dim=1)

    # Apply temperature scaling
    logits /= temperature

    # Labels: positives are the first elements
    labels = torch.zeros(2 * N).to(z_i.device).long()

    # Calculate the cross-entropy loss
    loss = F.cross_entropy(logits, labels)

    return loss


In [7]:
def contrastive_accuracy(z_i, z_j, labels):
    with torch.no_grad():
        # Compute the cosine similarity
        similarity_matrix = F.cosine_similarity(z_i.unsqueeze(1), z_j.unsqueeze(0), dim=2)

        # Get the indices of the maximum values along each row
        max_indices = similarity_matrix.max(dim=1)[1]

        # Calculate accuracy
        correct = (labels == labels[max_indices]).float()
        return correct.mean()

In [28]:
class ResNet50WithLinear(nn.Module):
    def __init__(self, encoder):
        super(ResNet50WithLinear, self).__init__()
        self.encoder = encoder
        self.encoder.resnet50.fc = nn.Identity()
        self.linear = nn.Linear(2048, 10)
        for param in self.encoder.parameters():
            param.requires_grad = False

    def forward(self, x):
        out = self.encoder(x)
        out = self.linear(out)
        return out

In [9]:
train_transform = transforms.Compose(
    [transforms.Resize(size=(32, 32)),
    transforms.ToTensor()]
)
test_transform = transforms.Compose(
    [transforms.Resize(size=(32, 32)),
    transforms.ToTensor()]
)
trainset = CIFAR10(root='./cifar10', train=True, download=True, transform=train_transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True, num_workers=0)

testset = CIFAR10(root='./cifar10', train=False, download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=256, shuffle=False, num_workers=0)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [None]:
criterion = nn.CrossEntropyLoss()
from torch.optim import optimizer
from tqdm import tqdm_notebook as tqdm
rot_model = ResNet50()
rot_model.load_state_dict(torch.load("simclr_resnet50_256_200ep.pt"))
rot_model.cuda()
rot_linear_eval_model = ResNet50WithLinear(rot_model)
rot_linear_eval_model.cuda()
#rot_linear_eval_model.load_state_dict(torch.load("models/rot_model_semi_sup_30_128_0.001_0.1.pth"))

optimizer = torch.optim.Adam(rot_linear_eval_model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.2)

best_loss = 9999999
num_epoch = 30
for epoch_idx in range(num_epoch):
    epoch_losses = 0
    epoch_corrects = 0
    rot_linear_eval_model.train()
    for batch_idx, data in enumerate(tqdm(trainloader)):
        image, label = data
        image = image.cuda()
        label = label.cuda()

        rot_linear_eval_model.zero_grad()
        out = rot_linear_eval_model(image)

        loss = criterion(out, label)
        loss.backward()
        optimizer.step()

        epoch_losses += loss

        pred = torch.argmax(out, dim=1)
        epoch_corrects += torch.sum(pred == label).item()

    epoch_losses /= len(trainloader)
    epoch_corrects /= len(trainset)


    with torch.no_grad():
        rot_linear_eval_model.eval()

        test_epoch_losses = 0
        test_epoch_corrects = 0

        for batch_idx, data in enumerate(tqdm(testloader)):
            image, label = data
            image = image.cuda()
            label = label.cuda()

            out = rot_linear_eval_model(image)

            loss = criterion(out, label)

            test_epoch_losses += loss

            pred = torch.argmax(out, dim=1)
            test_epoch_corrects += torch.sum(pred == label).item()

        test_epoch_losses /= len(testloader)
        test_epoch_corrects /= len(testset)

        if test_epoch_losses < best_loss:
            best_loss = test_epoch_losses
            torch.save(rot_linear_eval_model.state_dict(), f'simclr_linear_eval_256_30ep.pth')

    scheduler.step()
    print(f'Train Loss {epoch_losses} Acc {epoch_corrects} ; Val Loss {test_epoch_losses} Acc {test_epoch_corrects}')

In [20]:

rot_model = ResNet50()
rot_model.load_state_dict(torch.load("simclr_resnet50_256_200ep.pt"))
rot_model.cuda()
rot_linear_eval_model = ResNet50WithLinear(rot_model)
rot_linear_eval_model.load_state_dict(torch.load('simclr_linear_eval_256_30ep.pth'))
rot_linear_eval_model.cuda()

with torch.no_grad():
    rot_linear_eval_model.eval()

    test_epoch_losses = 0
    test_epoch_corrects = 0

    for batch_idx, data in enumerate(tqdm(trainloader)):
        image, label = data
        image = image.cuda()
        label = label.cuda()

        out = rot_linear_eval_model(image)

        loss = criterion(out, label)

        test_epoch_losses += loss

        pred = torch.argmax(out, dim=1)
        test_epoch_corrects += torch.sum(pred == label).item()

    test_epoch_losses /= len(trainloader)
    test_epoch_corrects /= len(trainset)

print(f'Test Loss {test_epoch_losses} Acc {test_epoch_corrects}')

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for batch_idx, data in enumerate(tqdm(trainloader)):


  0%|          | 0/391 [00:00<?, ?it/s]

Test Loss 1.1886874437332153 Acc 0.9032


In [None]:
trainset = CIFAR10(root='./cifar10', train=True, download=True, transform=train_transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=0)

testset = CIFAR10(root='./cifar10', train=False, download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=0)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
criterion = nn.CrossEntropyLoss()
from torch.optim import optimizer
from tqdm import tqdm_notebook as tqdm
rot_model = ResNet50()
rot_model.load_state_dict(torch.load("simclr_resnet50_128_200ep.pt"))
rot_model.cuda()
rot_linear_eval_model = ResNet50WithLinear(rot_model)
rot_linear_eval_model.cuda()
#rot_linear_eval_model.load_state_dict(torch.load("models/rot_model_semi_sup_30_128_0.001_0.1.pth"))

optimizer = torch.optim.Adam(rot_linear_eval_model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.2)

best_loss = 9999999
num_epoch = 30
for epoch_idx in range(num_epoch):
    epoch_losses = 0
    epoch_corrects = 0
    rot_linear_eval_model.train()
    for batch_idx, data in enumerate(tqdm(trainloader)):
        image, label = data
        image = image.cuda()
        label = label.cuda()

        rot_linear_eval_model.zero_grad()
        out = rot_linear_eval_model(image)

        loss = criterion(out, label)
        loss.backward()
        optimizer.step()

        epoch_losses += loss

        pred = torch.argmax(out, dim=1)
        epoch_corrects += torch.sum(pred == label).item()

    epoch_losses /= len(trainloader)
    epoch_corrects /= len(trainset)


    with torch.no_grad():
        rot_linear_eval_model.eval()

        test_epoch_losses = 0
        test_epoch_corrects = 0

        for batch_idx, data in enumerate(tqdm(testloader)):
            image, label = data
            image = image.cuda()
            label = label.cuda()

            out = rot_linear_eval_model(image)

            loss = criterion(out, label)

            test_epoch_losses += loss

            pred = torch.argmax(out, dim=1)
            test_epoch_corrects += torch.sum(pred == label).item()

        test_epoch_losses /= len(testloader)
        test_epoch_corrects /= len(testset)

        if test_epoch_losses < best_loss:
            best_loss = test_epoch_losses
            torch.save(rot_linear_eval_model.state_dict(), f'simclr_linear_eval_128_30ep.pth')

    scheduler.step()
    print(f'Train Loss {epoch_losses} Acc {epoch_corrects} ; Val Loss {test_epoch_losses} Acc {test_epoch_corrects}')

In [22]:
rot_model = ResNet50()
rot_model.load_state_dict(torch.load("simclr_resnet50_128_200ep.pt"))
rot_model.cuda()
rot_linear_eval_model = ResNet50WithLinear(rot_model)
rot_linear_eval_model.load_state_dict(torch.load('simclr_linear_eval_128_30ep.pth'))
rot_linear_eval_model.cuda()

with torch.no_grad():
    rot_linear_eval_model.eval()

    test_epoch_losses = 0
    test_epoch_corrects = 0

    for batch_idx, data in enumerate(tqdm(trainloader)):
        image, label = data
        image = image.cuda()
        label = label.cuda()

        out = rot_linear_eval_model(image)

        loss = criterion(out, label)

        test_epoch_losses += loss

        pred = torch.argmax(out, dim=1)
        test_epoch_corrects += torch.sum(pred == label).item()

    test_epoch_losses /= len(trainloader)
    test_epoch_corrects /= len(trainset)

print(f'Test Loss {test_epoch_losses} Acc {test_epoch_corrects}')

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for batch_idx, data in enumerate(tqdm(trainloader)):


  0%|          | 0/391 [00:00<?, ?it/s]

Test Loss 3.3800535202026367 Acc 0.8962


In [None]:
trainset = CIFAR10(root='./cifar10', train=True, download=True, transform=train_transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=0)

testset = CIFAR10(root='./cifar10', train=False, download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=0)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
criterion = nn.CrossEntropyLoss()
from torch.optim import optimizer
from tqdm import tqdm_notebook as tqdm
rot_model = ResNet50()
rot_model.load_state_dict(torch.load("simclr_resnet50_64_200ep.pt"))
rot_model.cuda()
rot_linear_eval_model = ResNet50WithLinear(rot_model)
rot_linear_eval_model.cuda()
#rot_linear_eval_model.load_state_dict(torch.load("models/rot_model_semi_sup_30_128_0.001_0.1.pth"))

optimizer = torch.optim.Adam(rot_linear_eval_model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.2)

best_loss = 9999999
num_epoch = 30
for epoch_idx in range(num_epoch):
    epoch_losses = 0
    epoch_corrects = 0
    rot_linear_eval_model.train()
    for batch_idx, data in enumerate(tqdm(trainloader)):
        image, label = data
        image = image.cuda()
        label = label.cuda()

        rot_linear_eval_model.zero_grad()
        out = rot_linear_eval_model(image)

        loss = criterion(out, label)
        loss.backward()
        optimizer.step()

        epoch_losses += loss

        pred = torch.argmax(out, dim=1)
        epoch_corrects += torch.sum(pred == label).item()

    epoch_losses /= len(trainloader)
    epoch_corrects /= len(trainset)


    with torch.no_grad():
        rot_linear_eval_model.eval()

        test_epoch_losses = 0
        test_epoch_corrects = 0

        for batch_idx, data in enumerate(tqdm(testloader)):
            image, label = data
            image = image.cuda()
            label = label.cuda()

            out = rot_linear_eval_model(image)

            loss = criterion(out, label)

            test_epoch_losses += loss

            pred = torch.argmax(out, dim=1)
            test_epoch_corrects += torch.sum(pred == label).item()

        test_epoch_losses /= len(testloader)
        test_epoch_corrects /= len(testset)

        if test_epoch_losses < best_loss:
            best_loss = test_epoch_losses
            torch.save(rot_linear_eval_model.state_dict(), f'simclr_linear_eval_64_30ep.pth')

    scheduler.step()
    print(f'Train Loss {epoch_losses} Acc {epoch_corrects} ; Val Loss {test_epoch_losses} Acc {test_epoch_corrects}')

In [30]:
rot_model = ResNet50()
rot_model.load_state_dict(torch.load("simclr_resnet50_64_200ep.pt"))
rot_model.cuda()
rot_linear_eval_model = ResNet50WithLinear(rot_model)
rot_linear_eval_model.load_state_dict(torch.load('simclr_linear_eval_64_30ep.pth'))
rot_linear_eval_model.cuda()

with torch.no_grad():
    rot_linear_eval_model.eval()

    test_epoch_losses = 0
    test_epoch_corrects = 0

    for batch_idx, data in enumerate(tqdm(trainloader)):
        image, label = data
        image = image.cuda()
        label = label.cuda()

        out = rot_linear_eval_model(image)

        loss = criterion(out, label)

        test_epoch_losses += loss

        pred = torch.argmax(out, dim=1)
        test_epoch_corrects += torch.sum(pred == label).item()

    test_epoch_losses /= len(trainloader)
    test_epoch_corrects /= len(trainset)

print(f'Test Loss {test_epoch_losses} Acc {test_epoch_corrects}')

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for batch_idx, data in enumerate(tqdm(trainloader)):


  0%|          | 0/782 [00:00<?, ?it/s]

Test Loss 8.981148719787598 Acc 0.87276
