In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader

In [2]:
!pip install scikit-learn



In [3]:
!pip install torchsummary



In [4]:
def ResNet18SimCLR():
    resnet = models.resnet18(pretrained=False)
    resnet.fc = nn.Identity()  # Remove the fully connected layer
    return resnet

In [5]:
def augment(x):
    transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
    ])
    x = torch.stack([transform(img) for img in x])
    return x

In [6]:
def projection_head(x_tuple, hidden_dim=256,device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    flattened_tensors = [tensor.flatten(start_dim=1) for tensor in x_tuple]
    # Concatenate the flattened tensors along dimension 1
    concatenated_tensor = torch.cat(flattened_tensors, dim=1)
    projection = nn.Sequential(
        nn.Linear(concatenated_tensor.shape[1], hidden_dim),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_dim, hidden_dim),
    ).to(device)  # Move projection head to the correct device
    return projection(concatenated_tensor)

In [7]:
def NTXentLoss(z1, z2, temperature=0.5):
    N = z1.size(0)
    z = torch.cat([z1, z2], dim=0)
    sim = torch.mm(z, z.t()) / (torch.norm(z, dim=1, keepdim=True) * torch.norm(z.t(), dim=0, keepdim=True))
    sim = sim / temperature
    sim_exp = torch.exp(sim)
    sim_exp = sim_exp - torch.eye(2 * N, device=z.device)
    sim_1_2 = sim_exp[:N, N:]
    sim_2_1 = sim_exp[N:, :N]
    sim_1_1 = sim_exp[:N, :N]
    sim_2_2 = sim_exp[N:, N:]

    loss = -(torch.log(sim_1_2.sum(dim=1) / (sim_1_1.sum(dim=1) + sim_1_2.sum(dim=1))).mean() +
             torch.log(sim_2_1.sum(dim=1) / (sim_2_1.sum(dim=1) + sim_2_2.sum(dim=1))).mean()) / 2
    return loss

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [9]:
transform = transforms.Compose([
    transforms.ToTensor()
])
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)

Files already downloaded and verified


In [10]:
resnet_model = ResNet18SimCLR().to(device)

# Optimizer
optimizer = optim.Adam(resnet_model.parameters(), lr=3e-4)

In [None]:
num_epochs = 15
temperature = 0.5
for epoch in range(num_epochs):
    total_loss = 0.0
    bs=0
    for i, (images, _) in enumerate(train_loader):
        images = images.to(device)

        # Forward pass
        images_aug1 = augment(images).to(device)
        images_aug2 = augment(images).to(device)

        features1 = projection_head((resnet_model(images_aug1),), device=device)
        features2 = projection_head((resnet_model(images_aug2),), device=device)

        # Calculate loss
        loss = NTXentLoss(features1, features2, temperature)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        print(f"batch [{bs}],Epoch going [{epoch}], Loss: {total_loss}")
        bs+=1
    # Print epoch loss
    print(f"epoch[{epoch+1}/{num_epochs}], Loss: {total_loss / len(train_loader):.4f}")
   
    #model checkpoint
    checkpoint_path = 'resnet_simclr_x-ray_checkpoint.pth'

    # Save the model checkpoint
    torch.save({
    'epoch': epoch,
    'model_state_dict': resnet_model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': total_loss / len(train_loader),
    }, checkpoint_path)

In [9]:
checkpoint = torch.load('resnet_simclr_checkpoint.pth')
model_state_dict = checkpoint['model_state_dict']
   

In [10]:

resnet_model = models.resnet18(pretrained=False)  # Adjust model type if different
resnet_model.fc = nn.Identity()

# Load the model state
resnet_model.load_state_dict(model_state_dict)

# Ensure the model is on the correct device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
resnet_model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [11]:
print("Final layer:", resnet_model.fc)

Final layer: Identity()


In [32]:

# Ensure the model is in eval mode if not training (important for BatchNorm and Dropout layers)
resnet_model.eval()

# Forward pass through the model up to the average pooling layer
# This approach utilizes the model's inherent structure, ensuring correct layer application




Output size: torch.Size([1, 512, 1, 1])
Output size: torch.Size([1, 512, 1, 1])
Features shape just before the final layer: torch.Size([1, 512])


In [38]:

# Initialize parameters for the linear layer
num_features = 512  # This is typical for ResNet-18's final layer output
num_classes = 10
linear_layer = nn.Linear(num_features, num_classes).to(device)

criterion = nn.CrossEntropyLoss()
# Optimizer for the linear layer
optimizer = torch.optim.Adam(linear_layer.parameters(), lr=0.001)

In [13]:
transform = transforms.Compose([
    transforms.ToTensor()  # Converts image to Tensor and scales pixels between 0 and 1
])

In [14]:
from torchvision import datasets, transforms

train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)

In [15]:
def setup_model(original_model, num_features, num_classes, device):
    # Set the original fully connected layer to Identity, disabling it
    
    # Create a new classifier layer
    classifier = nn.Linear(num_features, num_classes).to(device)
    
    return original_model, classifier

num_features = 512  # Output size from the last pooling layer of ResNet
num_classes = 10    # Define the number of classes for your task
resnet_model, classifier = setup_model(resnet_model, num_features, num_classes, device)

In [16]:
def forward_pass(images, feature_extractor, classifier):
    # Forward through feature extractor
    features = feature_extractor(images)
    
    # Forward through classifier
    outputs = classifier(features)
    
    return outputs

In [20]:
import torch.optim as optim

# Setup optimizer for the classifier only
optimizer = optim.Adam(classifier.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

num_epochs = 15
for epoch in range(num_epochs):
    running_loss = 0.0
    correct = 0
    total = 0
    
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Assuming resizing and channel replication if necessary
        images = torch.nn.functional.interpolate(images, size=(32, 32))
        images = images.repeat(1, 3, 1, 1)

        # Compute outputs
        outputs = forward_pass(images, resnet_model, classifier)
        
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        print(f'Epoch Going {epoch+1}: Loss = {running_loss:.4f}, Accuracy = {100 * (correct / total)}%')
        
    epoch_loss = running_loss / len(train_loader)
    epoch_accuracy = 100 * correct / total
    print(f'Epoch {epoch+1}: Loss = {epoch_loss:.4f}, Accuracy = {epoch_accuracy:.2f}%')

Epoch Going 1: Loss = 1.2339, Accuracy = 59.375%
Epoch Going 1: Loss = 2.1694, Accuracy = 64.84375%
Epoch Going 1: Loss = 3.1997, Accuracy = 64.0625%
Epoch Going 1: Loss = 4.3987, Accuracy = 63.671875%
Epoch Going 1: Loss = 5.6626, Accuracy = 61.25000000000001%
Epoch Going 1: Loss = 7.0170, Accuracy = 60.416666666666664%
Epoch Going 1: Loss = 7.9288, Accuracy = 61.16071428571429%
Epoch Going 1: Loss = 8.9152, Accuracy = 61.9140625%
Epoch Going 1: Loss = 9.9543, Accuracy = 61.458333333333336%
Epoch Going 1: Loss = 11.1449, Accuracy = 61.71875%
Epoch Going 1: Loss = 12.1161, Accuracy = 61.93181818181818%
Epoch Going 1: Loss = 13.1333, Accuracy = 61.71875%
Epoch Going 1: Loss = 14.1060, Accuracy = 62.25961538461539%
Epoch Going 1: Loss = 15.0588, Accuracy = 62.38839285714286%
Epoch Going 1: Loss = 16.2242, Accuracy = 61.979166666666664%
Epoch Going 1: Loss = 17.2858, Accuracy = 62.3046875%
Epoch Going 1: Loss = 18.2051, Accuracy = 62.5%
Epoch Going 1: Loss = 18.9797, Accuracy = 63.1944444

In [24]:
checkpoint_path = 'resnet_fashion_checkpoint.pth'

    # Save the model checkpoint
torch.save({
    'epoch': epoch,
    'model_state_dict': resnet_model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': running_loss / len(train_loader),
    }, checkpoint_path)