In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader

In [2]:
!pip install scikit-learn



In [3]:
!pip install torchsummary



In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
checkpoint = torch.load('resnet_simclr_checkpoint.pth')
model_state_dict = checkpoint['model_state_dict']

In [6]:

resnet_model = models.resnet18(pretrained=False)  # Adjust model type if different
resnet_model.fc = nn.Identity()

# Load the model state
resnet_model.load_state_dict(model_state_dict)

# Ensure the model is on the correct device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
resnet_model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [7]:
print("Final layer:", resnet_model.fc)

Final layer: Identity()


In [8]:

# Initialize parameters for the linear layer
num_features = 512  # This is typical for ResNet-18's final layer output
num_classes = 10
linear_layer = nn.Linear(num_features, num_classes).to(device)

criterion = nn.CrossEntropyLoss()
# Optimizer for the linear layer
optimizer = torch.optim.Adam(linear_layer.parameters(), lr=0.001)

In [13]:

# Initialize parameters for the linear layer
num_features = 512  # This is typical for ResNet-18's final layer output
num_classes = 10
linear_layer = nn.Linear(num_features, num_classes).to(device)

# Optimizer for the linear layer
optimizer = torch.optim.Adam(linear_layer.parameters(), lr=0.001)

In [14]:
def apply_linear_classifier(features, linear_layer):
    return linear_layer(features)

In [11]:
transform = transforms.Compose([
    transforms.ToTensor()
])
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)

Files already downloaded and verified


In [15]:
num_epochs = 15
for epoch in range(num_epochs):
    total_loss = 0
    correct = 0
    total = 0
    
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        
        # Assuming 'resnet_model' is your pre-trained ResNet model and it's already loaded
        with torch.no_grad():
            features = resnet_model(images)
        
        # Reset gradient
        optimizer.zero_grad()
        
        # Apply linear classifier
        outputs = apply_linear_classifier(features, linear_layer)
        
        # Calculate loss
        loss = criterion(outputs, labels)
        loss.backward()
        
        # Update weights
        optimizer.step()
        
        total_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
    # Print statistics
    print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {total_loss / len(train_loader):.4f} Acc: {100 * correct / total:.2f}%")

Epoch [1/15] Loss: 2.0794 Acc: 22.93%
Epoch [2/15] Loss: 2.0143 Acc: 25.55%
Epoch [3/15] Loss: 2.0024 Acc: 26.17%
Epoch [4/15] Loss: 1.9929 Acc: 26.69%
Epoch [5/15] Loss: 1.9843 Acc: 26.88%
Epoch [6/15] Loss: 1.9828 Acc: 27.45%
Epoch [7/15] Loss: 1.9782 Acc: 27.36%
Epoch [8/15] Loss: 1.9795 Acc: 27.12%
Epoch [9/15] Loss: 1.9750 Acc: 27.24%
Epoch [10/15] Loss: 1.9737 Acc: 27.47%
Epoch [11/15] Loss: 1.9720 Acc: 27.92%
Epoch [12/15] Loss: 1.9713 Acc: 27.69%
Epoch [13/15] Loss: 1.9677 Acc: 28.14%
Epoch [14/15] Loss: 1.9703 Acc: 27.78%
Epoch [15/15] Loss: 1.9699 Acc: 27.79%


In [17]:
checkpoint_resnet_linear = {
   'resnet_state_dict': resnet_model.state_dict(),
   'linear_state_dict': linear_layer.state_dict(),
}
torch.save(checkpoint_resnet_linear, 'linear_resnet_classifier.pth')

In [18]:
checkpoint = torch.load('linear_resnet_classifier.pth', map_location=device)
resnet_model.load_state_dict(checkpoint['resnet_state_dict'])
linear_layer.load_state_dict(checkpoint['linear_state_dict'])

<All keys matched successfully>

In [20]:
def forward_pass(images):
    resnet_model.eval()
    linear_layer.eval()
    with torch.no_grad():
        features = resnet_model(images)
        outputs = linear_layer(features)
    return outputs

In [21]:
test_transform = transforms.Compose([
    transforms.ToTensor()
])

In [22]:
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True,transform=test_transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)
test_size = len(test_dataset)
print(f"Size of the test set: {test_size} samples")

Files already downloaded and verified
Size of the test set: 10000 samples


In [None]:
criterion = nn.CrossEntropyLoss()

running_loss = 0.0
total_correct = 0
total_images = 0
for images, labels in test_loader:
    images = images.to(device)
    labels = labels.to(device)
    
    outputs = forward_pass(images)

    loss = criterion(outputs, labels)
    running_loss += loss.item()
    
    _, predicted = torch.max(outputs, 1)
    total_correct += (predicted == labels).sum().item()
    total_images += labels.size(0)

average_loss = running_loss / len(test_loader)
accuracy = total_correct / total_images * 100
print(f'Test Accuracy: {accuracy:.2f}%')