In [1]:
import torch
# Sanity check: can this kernel see a CUDA-capable GPU?
has_cuda = torch.cuda.is_available()
print(has_cuda)

True


In [2]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [3]:
import torch
# Print the name of every CUDA device PyTorch can see, one per line.
for gpu_index in range(torch.cuda.device_count()):
    gpu_props = torch.cuda.get_device_properties(gpu_index)
    print(gpu_props.name)

Quadro RTX 6000
Quadro RTX 6000


In [4]:
import itertools
import math
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import datasets, transforms
from torchvision.models import resnet18

# Seed every random number generator the notebook uses so runs are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    # Explicitly seed all visible GPUs as well.
    torch.cuda.manual_seed_all(SEED)

def _nth_permutation(index, n):
    """Return the ``index``-th permutation of ``range(n)`` in lexicographic order.

    ``itertools.permutations(range(n))`` yields tuples in exactly this order, so this
    equals ``list(itertools.permutations(range(n)))[index]`` without building the
    n!-element list.
    """
    pool = list(range(n))
    perm = []
    for remaining in range(n, 0, -1):
        # Each choice of the next element is followed by (remaining - 1)! completions.
        pos, index = divmod(index, math.factorial(remaining - 1))
        perm.append(pool.pop(pos))
    return tuple(perm)


# Predefine a fixed set of permutations (out of the num_patches! possible ones); the
# position of a permutation in this list is its class label in the jigsaw task.
def create_permutations(num_patches=9, num_permutations=1000, seed=42):
    """Sample ``num_permutations`` distinct permutations of ``range(num_patches)``.

    Args:
        num_patches: number of puzzle pieces (9 for a 3x3 grid).
        num_permutations: how many distinct permutations to keep; must satisfy
            ``0 <= num_permutations <= num_patches!``.
        seed: seed for a private ``random.Random``. The result equals
            ``random.Random(seed).sample(list(itertools.permutations(range(num_patches))),
            num_permutations)``.

    Returns:
        list of ``num_permutations`` tuples, in sampling order.

    Raises:
        ValueError: if ``num_permutations`` is negative or exceeds ``num_patches!``.

    Notes:
        * Uses its own RNG rather than reseeding the global ``random`` module, so
          calling it does not disturb the notebook's global random state.
        * Samples indices into the lexicographic permutation order and decodes them
          instead of materialising all num_patches! permutations (362,880 tuples for
          9 patches; infeasible for a 4x4 grid). ``random.sample`` only uses the
          population's length and indexing, so the selection is unchanged.
    """
    total = math.factorial(num_patches)
    if not 0 <= num_permutations <= total:
        raise ValueError(
            f"num_permutations must be in [0, {total}] for {num_patches} patches, "
            f"got {num_permutations}"
        )
    rng = random.Random(seed)
    indices = rng.sample(range(total), num_permutations)
    return [_nth_permutation(i, num_patches) for i in indices]

permutations = create_permutations()

def create_jigsaw_puzzle(image, grid_size=3, permutations=permutations):
    """Cut ``image`` into a ``grid_size`` x ``grid_size`` grid and shuffle the patches.

    Args:
        image: 3-D tensor sliced as ``image[:, rows, cols]``; the first dimension is
            kept whole (presumably channels, i.e. the (C, H, W) output of ToTensor).
        grid_size: patches per side. Rows/columns beyond ``grid_size * (size // grid_size)``
            are dropped when the image size is not divisible by ``grid_size``.
        permutations: sequence of permutations of ``range(grid_size ** 2)``; the index
            of the chosen permutation is the classification label.

    Returns:
        ``(patches, label)``: ``patches`` has shape
        ``(grid_size ** 2, C, H // grid_size, W // grid_size)`` and slot ``k`` holds the
        original row-major patch ``perm[k]``; ``label`` is a 0-dim ``torch.long`` tensor.

    Raises:
        ValueError: if ``image`` is not 3-D, ``permutations`` is empty, or the chosen
            permutation's length differs from ``grid_size ** 2`` (previously this
            silently built a puzzle from the wrong number of patches).
    """
    if len(image.shape) != 3:
        raise ValueError(f"expected a 3-D image tensor, got shape {tuple(image.shape)}")
    if len(permutations) == 0:
        raise ValueError("permutations must not be empty")

    num_patches = grid_size * grid_size
    _, height, width = image.shape
    patch_h, patch_w = height // grid_size, width // grid_size

    # Row-major order: patches[r * grid_size + c] is grid cell (row r, column c).
    patches = [
        image[:, r * patch_h:(r + 1) * patch_h, c * patch_w:(c + 1) * patch_w]
        for r in range(grid_size)
        for c in range(grid_size)
    ]

    # Draw a permutation uniformly from the fixed table; its index is the label.
    perm_class = random.randrange(len(permutations))
    perm = permutations[perm_class]
    if len(perm) != num_patches:
        raise ValueError(
            f"permutation of length {len(perm)} does not match a {grid_size}x{grid_size} "
            f"grid ({num_patches} patches)"
        )

    # Slot k of the puzzle holds original patch perm[k].
    shuffled_patches = [patches[src] for src in perm]
    return torch.stack(shuffled_patches), torch.tensor(perm_class, dtype=torch.long)

In [5]:
# Custom Dataset for Jigsaw Pretext Task
from torchvision.datasets import STL10
class JigsawSTL10Dataset(Dataset):
    """Wrap an (image, label) dataset so every item becomes a jigsaw-puzzle sample.

    The wrapped dataset's label is discarded; the returned target is the index of the
    permutation used to shuffle the patches (see ``create_jigsaw_puzzle``). Because
    ``create_jigsaw_puzzle`` draws a permutation on every call, the same index yields a
    different puzzle each time it is accessed.
    """

    def __init__(self, dataset, grid_size=3, permutations=None):
        """
        Args:
            dataset: indexable dataset whose items are ``(image, label)`` pairs.
            grid_size: patches per side, forwarded to ``create_jigsaw_puzzle``.
            permutations: optional permutation table forwarded to
                ``create_jigsaw_puzzle``; it must contain permutations of
                ``range(grid_size ** 2)``. ``None`` keeps that function's default table,
                so only the default ``grid_size=3`` works without it.
        """
        self.dataset = dataset
        self.grid_size = grid_size
        self.permutations = permutations

    def __getitem__(self, index):
        """Return ``(shuffled_patches, perm_class)`` for item ``index``."""
        img, _ = self.dataset[index]
        if self.permutations is None:
            return create_jigsaw_puzzle(img, self.grid_size)
        return create_jigsaw_puzzle(img, self.grid_size, self.permutations)

    def __len__(self):
        return len(self.dataset)

# Pretext-task augmentation, applied to the whole image before it is cut into patches.
transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        # Random brightness / contrast / saturation changes, each with strength 0.4.
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
        # With probability 0.1, drop colour information entirely.
        transforms.RandomGrayscale(p=0.1),
        transforms.ToTensor(),
    ]
)

# Dataset: STL-10 'train+unlabeled' split; its labels are ignored by JigsawSTL10Dataset.
train_dataset = STL10(root='./data',  split='train+unlabeled', download=True, transform=transform)

# Hold out 10% for validation / early stopping. A dedicated generator seeded with 42 makes
# the split independent of whatever else has consumed the global torch RNG, so re-running
# this cell reproduces the same train/val partition.
# NOTE(review): val_dataset shares `transform`, so validation images are augmented too.
train_size = int(0.9 * len(train_dataset))
val_size = len(train_dataset) - train_size
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    train_dataset, [train_size, val_size], generator=split_generator
)

jigsaw_train_dataset = JigsawSTL10Dataset(train_dataset)
jigsaw_val_dataset = JigsawSTL10Dataset(val_dataset)

jigsaw_train_loader = DataLoader(jigsaw_train_dataset, batch_size=64, shuffle=True, num_workers=12)
jigsaw_val_loader = DataLoader(jigsaw_val_dataset, batch_size=64, shuffle=False, num_workers=12)

Files already downloaded and verified


In [6]:
# ResNet-based Jigsaw Puzzle Solver
class JigsawResNet(nn.Module):
    """Classify which permutation was used to shuffle a set of image patches.

    Every patch goes through the same ResNet-18 trunk. The per-patch embeddings are
    concatenated in slot order, and an MLP maps them to ``num_permutations`` logits.

    Args:
        num_patches: patches per puzzle (9 for a 3x3 grid); fixes the MLP input width.
        num_permutations: number of permutation classes.
    """

    def __init__(self, num_patches, num_permutations):
        super().__init__()
        # Randomly initialised ResNet-18 with its last two children (global average pool
        # and fully connected classifier) removed; the trunk outputs a 512-channel map.
        self.backbone = nn.Sequential(*list(resnet18(weights=None).children())[:-2])
        # Pool each patch's feature map to 1x1 so every patch embedding is exactly 512-d,
        # whatever the patch size. Plain flattening only matched the fc input when the
        # trunk happened to produce a 1x1 map. Parameter-free: state_dict keys unchanged.
        self.pool = nn.AdaptiveAvgPool2d(1)

        # Fully connected layers for jigsaw permutation classification
        self.fc = nn.Sequential(
            nn.Linear(512 * num_patches, 1024),  # one 512-d embedding per patch
            nn.ReLU(),
            nn.Linear(1024, num_permutations),
        )

    def forward(self, x):
        """Return permutation logits of shape (batch, num_permutations).

        Args:
            x: (batch, num_patches, channels, height, width) tensor of shuffled patches,
               e.g. a collated batch of ``create_jigsaw_puzzle`` outputs.

        Raises:
            ValueError: if ``x`` is not 5-D.
        """
        if x.dim() != 5:
            raise ValueError(
                f"expected a 5-D (batch, patches, C, H, W) tensor, got shape {tuple(x.shape)}"
            )
        # One trunk call per patch slot. This is deliberately not batched over
        # batch * num_patches: BatchNorm statistics stay computed over each slot's patches
        # separately, and a single batched call would change training behaviour.
        patch_features = [
            torch.flatten(self.pool(self.backbone(x[:, i])), start_dim=1)
            for i in range(x.shape[1])
        ]
        # (batch, 512 * num_patches), slot 0 first.
        return self.fc(torch.cat(patch_features, dim=1))

def weights_init(m):
    """Re-initialise Conv2d/Linear weights with kaiming_normal_ (default args); zero biases."""
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    nn.init.kaiming_normal_(m.weight)
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)


# Build the jigsaw model (9 patches -> 1000 permutation classes) on the chosen device.
num_patches = 9
num_permutations = 1000
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = JigsawResNet(num_patches=num_patches, num_permutations=num_permutations).to(device)

# Apply the custom init to every submodule; the returned model is the cell's output.
model.apply(weights_init)

JigsawResNet(
  (backbone): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats

In [7]:
import os

# Where the best pretext checkpoint is written; set JIGSAW_CHECKPOINT_DIR to override the
# hard-coded user path without editing the notebook.
checkpoint_dir = os.environ.get('JIGSAW_CHECKPOINT_DIR', '/users/nladdha/my_env/model/')
os.makedirs(checkpoint_dir, exist_ok=True)
model_path = os.path.join(checkpoint_dir, 'stl_jigsaw_model.pth')

# Spread batches over all visible GPUs. The isinstance guard keeps the cell re-runnable;
# re-executing it would otherwise wrap the model in DataParallel a second time.
if torch.cuda.device_count() > 1 and not isinstance(model, nn.DataParallel):
    model = nn.DataParallel(model)

# The bare network underneath any DataParallel wrapper. Checkpoints are saved from and
# restored into it, so their state_dict keys carry no "module." prefix and also load on
# a single device.
raw_model = model.module if isinstance(model, nn.DataParallel) else model

# Training Jigsaw Model with Validation for Early Stopping
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

epochs = 100  # maximum number of pretext epochs
best_val_loss = float('inf')
early_stopping_patience = 10  # epochs without a new best validation loss before stopping
no_improve_epochs = 0

for epoch in range(epochs):
    model.train()
    running_loss = 0.0

    # Training loop
    for shuffled_patches, perm_class in jigsaw_train_loader:
        shuffled_patches, perm_class = shuffled_patches.to(device), perm_class.to(device)

        optimizer.zero_grad()
        outputs = model(shuffled_patches)
        loss = criterion(outputs, perm_class)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean of per-batch losses (a smaller final batch is weighted like a full one).
    avg_train_loss = running_loss / len(jigsaw_train_loader)
    print(f"Epoch [{epoch+1}/{epochs}], Training Loss: {avg_train_loss:.4f}")

    # Validation loop
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for val_patches, val_perm_class in jigsaw_val_loader:
            val_patches, val_perm_class = val_patches.to(device), val_perm_class.to(device)
            val_outputs = model(val_patches)
            val_loss += criterion(val_outputs, val_perm_class).item()

    avg_val_loss = val_loss / len(jigsaw_val_loader)
    print(f"Epoch [{epoch+1}/{epochs}], Validation Loss: {avg_val_loss:.4f}")

    # Model checkpointing based on validation loss
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(raw_model.state_dict(), model_path)
        print(f"Model saved at Epoch {epoch+1}")
        no_improve_epochs = 0
    else:
        no_improve_epochs += 1

    # Early stopping based on validation loss
    if no_improve_epochs >= early_stopping_patience:
        print("Early stopping due to no improvement in validation loss.")
        break

# After the loop `model` holds the final-epoch weights, which can be up to
# early_stopping_patience epochs past the best ones. Restore the best checkpoint so the
# downstream classifier, which reuses this model's backbone, starts from it.
if best_val_loss < float('inf'):
    raw_model.load_state_dict(torch.load(model_path, map_location=device))
    print(f"Restored best model (validation loss {best_val_loss:.4f}) from {model_path}")

Epoch [1/100], Training Loss: 6.9357
Epoch [1/100], Validation Loss: 6.9078
Model saved at Epoch 1
Epoch [2/100], Training Loss: 6.9082
Epoch [2/100], Validation Loss: 6.9059
Model saved at Epoch 2
Epoch [3/100], Training Loss: 6.3515
Epoch [3/100], Validation Loss: 5.2587
Model saved at Epoch 3
Epoch [4/100], Training Loss: 4.7777
Epoch [4/100], Validation Loss: 4.2391
Model saved at Epoch 4
Epoch [5/100], Training Loss: 4.1442
Epoch [5/100], Validation Loss: 3.8913
Model saved at Epoch 5
Epoch [6/100], Training Loss: 3.8409
Epoch [6/100], Validation Loss: 3.6581
Model saved at Epoch 6
Epoch [7/100], Training Loss: 3.6283
Epoch [7/100], Validation Loss: 3.4874
Model saved at Epoch 7
Epoch [8/100], Training Loss: 3.4455
Epoch [8/100], Validation Loss: 3.2870
Model saved at Epoch 8
Epoch [9/100], Training Loss: 3.2531
Epoch [9/100], Validation Loss: 3.1073
Model saved at Epoch 9
Epoch [10/100], Training Loss: 3.0156
Epoch [10/100], Validation Loss: 2.8327
Model saved at Epoch 10
Epoch [

In [8]:
# ResNet Classifier for STL-10 (Downstream Task)
class ResNetClassifier(nn.Module):
    """Image classifier that reuses the pretext (jigsaw) model's convolutional trunk.

    The trunk is taken by reference, not copied: ``self.backbone`` is the same module
    object as the pretrained model's ``backbone``. Freezing or updating it here therefore
    also affects the jigsaw model.

    Args:
        pretrained_model: a model exposing ``.backbone`` (e.g. JigsawResNet), either bare
            or wrapped in ``nn.DataParallel`` (which exposes the wrapped network as
            ``.module``).
        num_classes: number of output classes (10 for STL-10).
    """

    def __init__(self, pretrained_model, num_classes=10):
        super().__init__()
        # Unwrap only when a wrapper is present. The model is wrapped in DataParallel only
        # when more than one GPU is visible; reading `.module` unconditionally raised
        # AttributeError on CPU or single-GPU runs.
        base_model = getattr(pretrained_model, 'module', pretrained_model)
        self.backbone = base_model.backbone
        # Collapse the trunk's spatial feature map to a single vector per image.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for an image batch ``x``."""
        x = self.backbone(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # (batch, C, 1, 1) -> (batch, C); fc expects C == 512
        return self.fc(x)

# Attach a fresh classification head to the in-memory jigsaw model's trunk (the trunk is
# shared, not copied) and move everything to the training device.
classification_model = ResNetClassifier(model, num_classes=10).to(device)

# Freeze every trunk parameter so only the new head is trained.
# NOTE(review): requires_grad=False does not stop BatchNorm layers in the trunk from
# updating their running statistics while classification_model is in train() mode.
classification_model.backbone.requires_grad_(False)

In [9]:
# STL-10 DataLoaders for the downstream classification task.
# The training split reuses the pretext augmentation (`transform`). The test split only
# converts images to tensors, so the reported accuracy is not perturbed by random flips,
# colour jitter or grayscale conversion.
# NOTE: this rebinds `train_dataset`, which earlier held the pretext-task training subset.
eval_transform = transforms.ToTensor()
train_dataset = datasets.STL10(root='./data', split='train', download=True, transform=transform)
classification_train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_dataset = datasets.STL10(root='./data', split='test', download=True, transform=eval_transform)
classification_test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

Files already downloaded and verified
Files already downloaded and verified


In [10]:
# Fine-tuning loop: trains the classification head on top of the frozen trunk.
# There is no validation split for this task, so checkpointing and early stopping below
# track the *training* loss (the variable keeps the name best_val_loss).
classification_criterion = nn.CrossEntropyLoss()
# Give Adam only the trainable parameters. Frozen parameters never receive gradients, so
# the updates are the same; this just makes explicit what is actually optimised.
classification_optimizer = optim.Adam(
    [p for p in classification_model.parameters() if p.requires_grad], lr=0.001
)

classification_epochs = 150
best_val_loss = float('inf')  # best per-epoch *training* loss seen so far
early_stopping_patience = 10
no_improve_epochs = 0
classification_model_path = 'stl_classification_model.pth'

for epoch in range(classification_epochs):
    classification_model.train()
    running_loss = 0.0
    for inputs, targets in classification_train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        classification_optimizer.zero_grad()
        outputs = classification_model(inputs)
        loss = classification_criterion(outputs, targets)
        loss.backward()
        classification_optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(classification_train_loader)
    print(f"Epoch [{epoch+1}/{classification_epochs}], Loss: {avg_loss:.4f}")

    # Checkpoint whenever the epoch's training loss improves.
    if avg_loss < best_val_loss:
        best_val_loss = avg_loss
        torch.save(classification_model.state_dict(), classification_model_path)
        print(f"Classification model saved at Epoch {epoch+1}")
        no_improve_epochs = 0
    else:
        no_improve_epochs += 1

    # Early stopping
    if no_improve_epochs >= early_stopping_patience:
        print("Early stopping")
        break

# Evaluate the checkpointed (lowest-training-loss) weights rather than whatever the final
# epoch left in memory.
if best_val_loss < float('inf'):
    classification_model.load_state_dict(
        torch.load(classification_model_path, map_location=device)
    )

# Evaluation with Top-1 and Top-5 Accuracy on the test split.
classification_model.eval()
total = 0
correct = 0
top_5_correct = 0

with torch.no_grad():
    for images, labels in classification_test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = classification_model(images)
        total += labels.size(0)

        # Top-1: the highest-scoring class must equal the label.
        predicted = outputs.argmax(dim=1)
        correct += (predicted == labels).sum().item()

        # Top-5: the label must be among the five highest-scoring classes.
        predicted_5 = outputs.topk(k=5, dim=1).indices
        top_5_correct += (predicted_5 == labels.unsqueeze(1)).any(dim=1).sum().item()

accuracy = 100 * correct / total
top_5_accuracy = 100 * top_5_correct / total
print(f'Top-1 Accuracy: {accuracy:.2f}%')
print(f'Top-5 Accuracy: {top_5_accuracy:.2f}%')

Epoch [1/150], Loss: 2.0103
Classification model saved at Epoch 1
Epoch [2/150], Loss: 1.7831
Classification model saved at Epoch 2
Epoch [3/150], Loss: 1.6614
Classification model saved at Epoch 3
Epoch [4/150], Loss: 1.6058
Classification model saved at Epoch 4
Epoch [5/150], Loss: 1.5534
Classification model saved at Epoch 5
Epoch [6/150], Loss: 1.5446
Classification model saved at Epoch 6
Epoch [7/150], Loss: 1.5001
Classification model saved at Epoch 7
Epoch [8/150], Loss: 1.4928
Classification model saved at Epoch 8
Epoch [9/150], Loss: 1.4673
Classification model saved at Epoch 9
Epoch [10/150], Loss: 1.4194
Classification model saved at Epoch 10
Epoch [11/150], Loss: 1.4043
Classification model saved at Epoch 11
Epoch [12/150], Loss: 1.3988
Classification model saved at Epoch 12
Epoch [13/150], Loss: 1.3796
Classification model saved at Epoch 13
Epoch [14/150], Loss: 1.3632
Classification model saved at Epoch 14
Epoch [15/150], Loss: 1.3611
Classification model saved at Epoch 1