In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import datasets, transforms
import torchvision
from torchvision.models import resnet18
import random
import itertools
import numpy as np
import matplotlib.pyplot as plt
import os
from inpainting import Encoder, Decoder, InpaintingModel, Discriminator, mask_image

In [2]:
# Choose where tensors and models live: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print('Using device:', device)

Using device: cuda


In [3]:
# Seed every RNG this notebook draws from (Python, NumPy, PyTorch CPU, and all CUDA devices)
# so that splits, augmentations and weight initialisation repeat across runs.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

In [4]:
class JigsawResNet(nn.Module):
    """Jigsaw-puzzle permutation classifier with a shared per-patch backbone.

    Every patch of a shuffled puzzle is encoded independently by the same backbone; the
    flattened per-patch features are concatenated in patch order and a two-layer MLP
    predicts which of ``num_permutations`` permutations was applied.

    Args:
        num_patches: Number of patches per puzzle (9 for a 3x3 grid).
        num_permutations: Number of permutation classes to predict.
        backbone: Optional per-patch feature extractor (e.g. the inpainting encoder). When
            ``None``, a randomly initialised ResNet-18 trunk (without avgpool/fc) is used.
            Whatever is used, its flattened output must have 512 features per patch to match
            the head's ``512 * num_patches`` input size.
    """

    def __init__(self, num_patches, num_permutations, backbone=None):
        super(JigsawResNet, self).__init__()
        if backbone is not None:
            self.backbone = backbone  # e.g. the encoder taken from the inpainting model
            print("Using pretrained backbone from inpainting model")
        else:
            # `resnet18` is imported in the first cell; the `models` alias only exists after a
            # later cell runs, so it must not be referenced here.
            resnet = resnet18(weights=None)
            # Keep the convolutional trunk: drop the final two children (avgpool, fc).
            self.backbone = nn.Sequential(*list(resnet.children())[:-2])
            print("Using new backbone")

        self.num_patches = num_patches
        self.num_permutations = num_permutations

        self.fc = nn.Sequential(
            nn.Linear(512 * num_patches, 1024),
            nn.ReLU(),
            nn.Linear(1024, num_permutations)
        )

    def forward(self, x):
        """Return permutation logits of shape ``(batch, num_permutations)``.

        ``x`` is unpacked as ``(batch, num_patches, channels, height, width)``.
        """
        batch_size, num_patches, channels, height, width = x.shape
        patches = []
        for i in range(num_patches):
            patch = x[:, i]
            features = self.backbone(patch)
            features = torch.flatten(features, start_dim=1)
            patches.append(features)
        # Concatenation order encodes patch position, which the head needs to infer the permutation.
        concatenated_features = torch.cat(patches, dim=1)
        output = self.fc(concatenated_features)
        return output


In [5]:
# Initialize the inpainting model and restore its pretrained generator weights.
inpainting_model = InpaintingModel().to(device)

inpainting_checkpoint_path = 'inpainting_model_gen_weights_epoch_100.pth'
# The checkpoint goes straight into load_state_dict, so it only needs tensors: load it with
# weights_only=True instead of unpickling arbitrary Python objects from the file.
inpainting_checkpoint = torch.load(inpainting_checkpoint_path, map_location=device, weights_only=True)
inpainting_model.load_state_dict(inpainting_checkpoint)

# Reuse the inpainting encoder as the jigsaw backbone. This is the same module object (not a
# copy), so training the jigsaw model also updates inpainting_model.encoder.
inpainting_encoder = inpainting_model.encoder

using new encoder


In [6]:
# Pretext-task preprocessing: resize to 96x96 (divisible by the default 3x3 jigsaw grid),
# then random flip / colour jitter / grayscale before converting to a tensor.
augmentations = [
    transforms.Resize((96, 96)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
    transforms.RandomGrayscale(p=0.1),
    transforms.ToTensor(),
]
transform = transforms.Compose(augmentations)

# STL-10 labelled train split plus the unlabeled split (labels are ignored by the jigsaw task).
train_dataset = datasets.STL10(
    root='./data',
    split='train+unlabeled',
    download=True,
    transform=transform,
)

# 80/20 train/validation split (uses the globally seeded torch RNG).
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

# Define Jigsaw Dataset
class JigsawSTL10Dataset(Dataset):
    """Wrap an image dataset as a jigsaw-puzzle pretext task.

    Each item cuts the image into a ``grid_size x grid_size`` grid of patches (row-major),
    shuffles them with one permutation from a fixed, seeded permutation set, and returns the
    stacked patches together with the index of that permutation as the class label. The
    wrapped dataset's own labels are discarded.

    Args:
        dataset: Indexable dataset yielding ``(image, label)`` with ``image`` a ``(C, H, W)`` tensor.
        grid_size: Patches per side; a puzzle has ``grid_size ** 2`` patches.
        num_permutations: Size of the permutation set, i.e. the number of classes.
        seed: Seed for choosing the permutation set, so instances built with the same arguments
            share the same class-index -> permutation mapping (e.g. train and validation).
    """

    def __init__(self, dataset, grid_size=3, num_permutations=1000, seed=42):
        self.dataset = dataset
        self.grid_size = grid_size
        self.num_permutations = num_permutations
        self.seed = seed
        self.permutations = self.create_permutations()

    def create_permutations(self, num_patches=None, num_permutations=None):
        """Return ``num_permutations`` distinct patch orderings chosen with a private seeded RNG.

        ``num_patches`` defaults to ``grid_size ** 2`` and ``num_permutations`` to the value
        given at construction. A local ``random.Random`` is used so building a dataset does not
        reseed the global ``random`` module, which also drives per-sample permutation draws.

        Raises:
            ValueError: if more permutations are requested than ``num_patches!`` exist.
        """
        if num_patches is None:
            num_patches = self.grid_size ** 2
        if num_permutations is None:
            num_permutations = self.num_permutations
        all_permutations = list(itertools.permutations(range(num_patches)))
        rng = random.Random(self.seed)
        return rng.sample(all_permutations, num_permutations)

    def create_jigsaw_puzzle(self, image):
        """Cut ``image`` into grid patches, shuffle them, and return ``(patches, perm_class)``.

        ``patches`` has shape ``(grid_size**2, C, patch_h, patch_w)``; trailing rows/columns that
        do not fill a whole patch (when H or W is not divisible by ``grid_size``) are dropped.
        """
        _, height, width = image.shape
        grid_size = self.grid_size
        patch_h, patch_w = height // grid_size, width // grid_size
        patches = [
            image[:, row * patch_h:(row + 1) * patch_h, col * patch_w:(col + 1) * patch_w]
            for row in range(grid_size)
            for col in range(grid_size)
        ]
        # Per-sample draw from the global `random` module (not the private RNG above), so each
        # access yields a fresh permutation.
        perm_class = random.randrange(len(self.permutations))
        perm = self.permutations[perm_class]
        shuffled_patches = [patches[i] for i in perm]
        return torch.stack(shuffled_patches), torch.tensor(perm_class, dtype=torch.long)

    def __getitem__(self, index):
        img, _ = self.dataset[index]
        shuffled_patches, perm_class = self.create_jigsaw_puzzle(img)
        return shuffled_patches, perm_class

    def __len__(self):
        return len(self.dataset)

# Wrap each split as a jigsaw pretext dataset (default 3x3 grid, seeded permutation set).
jigsaw_train_dataset = JigsawSTL10Dataset(dataset=train_dataset)
jigsaw_val_dataset = JigsawSTL10Dataset(dataset=val_dataset)

# Training batches are shuffled; validation batches keep a fixed order.
jigsaw_train_loader = DataLoader(
    jigsaw_train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
)
jigsaw_val_loader = DataLoader(
    jigsaw_val_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=4,
)


Files already downloaded and verified


In [7]:
# Where the best (lowest validation loss) jigsaw weights are written.
checkpoint_dir = '/users/nladdha/my_env/model/'
os.makedirs(checkpoint_dir, exist_ok=True)

# Jigsaw model that reuses the inpainting encoder as its per-patch backbone.
num_permutations = 1000
num_patches = 9
jigsaw_model = JigsawResNet(
    num_patches=num_patches,
    num_permutations=num_permutations,
    backbone=inpainting_encoder,
).to(device)

# Fine-tune the backbone too (make sure none of its parameters are frozen).
jigsaw_model.backbone.requires_grad_(True)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(jigsaw_model.parameters(), lr=0.0001)

epochs = 10
early_stopping_patience = 10
best_val_loss = float('inf')
no_improve_epochs = 0

for epoch in range(epochs):
    # ---- training pass ----
    jigsaw_model.train()
    running_loss = 0.0
    for shuffled_patches, perm_class in jigsaw_train_loader:
        shuffled_patches = shuffled_patches.to(device)
        perm_class = perm_class.to(device)
        optimizer.zero_grad()
        loss = criterion(jigsaw_model(shuffled_patches), perm_class)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    avg_train_loss = running_loss / len(jigsaw_train_loader)
    print(f"Epoch [{epoch+1}/{epochs}], Training Loss: {avg_train_loss:.4f}")

    # ---- validation pass (mean of per-batch losses) ----
    jigsaw_model.eval()
    with torch.no_grad():
        val_loss = sum(
            criterion(jigsaw_model(batch_patches.to(device)), batch_labels.to(device)).item()
            for batch_patches, batch_labels in jigsaw_val_loader
        )
    avg_val_loss = val_loss / len(jigsaw_val_loader)
    print(f"Epoch [{epoch+1}/{epochs}], Validation Loss: {avg_val_loss:.4f}")

    # ---- checkpoint on improvement, otherwise count towards early stopping ----
    if avg_val_loss >= best_val_loss:
        no_improve_epochs += 1
    else:
        best_val_loss = avg_val_loss
        model_path = os.path.join(checkpoint_dir, 'inpaint_jigsaw_model_testing.pth')
        torch.save(jigsaw_model.state_dict(), model_path)
        print(f"Model saved at Epoch {epoch+1}")
        no_improve_epochs = 0

    if no_improve_epochs >= early_stopping_patience:
        print("Early stopping due to no improvement in validation loss.")
        break

Using pretrained backbone from inpainting model
Epoch [1/10], Training Loss: 6.7007
Epoch [1/10], Validation Loss: 7.2744
Model saved at Epoch 1
Epoch [2/10], Training Loss: 5.5953
Epoch [2/10], Validation Loss: 5.4044
Model saved at Epoch 2
Epoch [3/10], Training Loss: 4.7806
Epoch [3/10], Validation Loss: 5.0467
Model saved at Epoch 3
Epoch [4/10], Training Loss: 4.3170
Epoch [4/10], Validation Loss: 4.0588
Model saved at Epoch 4
Epoch [5/10], Training Loss: 4.0063
Epoch [5/10], Validation Loss: 3.7886
Model saved at Epoch 5
Epoch [6/10], Training Loss: 3.8144
Epoch [6/10], Validation Loss: 3.6389
Model saved at Epoch 6
Epoch [7/10], Training Loss: 3.6655
Epoch [7/10], Validation Loss: 3.5072
Model saved at Epoch 7
Epoch [8/10], Training Loss: 3.5444
Epoch [8/10], Validation Loss: 3.4443
Model saved at Epoch 8
Epoch [9/10], Training Loss: 3.4310
Epoch [9/10], Validation Loss: 3.6864
Epoch [10/10], Training Loss: 3.3321
Epoch [10/10], Validation Loss: 3.7105


In [23]:
import torchvision.models as models

# Rebuild the jigsaw architecture around a fresh ResNet-18 trunk and load pretrained jigsaw weights.
# NOTE(review): this loads 'inpaint_jigsaw_model.pth', while the jigsaw training cell above writes
# 'inpaint_jigsaw_model_testing.pth' — confirm which checkpoint is intended.
combined_model = JigsawResNet(num_patches=num_patches, num_permutations=num_permutations, backbone=None).to(device)

# The file is used only as a state_dict (iterated and passed to load_state_dict), so the
# weights-only loader suffices and avoids unpickling arbitrary objects.
checkpoint = torch.load('/users/nladdha/my_env/model/inpaint_jigsaw_model.pth', map_location=device, weights_only=True)
state_dict = checkpoint


def _to_resnet_backbone_key(key):
    """Map a saved jigsaw parameter name onto the names of the unwrapped `combined_model`."""
    # Checkpoints saved from an nn.DataParallel wrapper prefix every key with 'module.'.
    if key.startswith('module.'):
        key = key[len('module.'):]
    # Backbones taken from the inpainting model keep their layers under 'backbone.encoder.';
    # the fresh ResNet trunk here is registered directly as 'backbone.'. Only the prefix is
    # rewritten, so 'encoder.' occurring deeper in a key is left alone.
    if key.startswith('backbone.encoder.'):
        key = 'backbone.' + key[len('backbone.encoder.'):]
    return key


adjusted_state_dict = {_to_resnet_backbone_key(key): value for key, value in state_dict.items()}

# Load into the unwrapped model so checkpoints saved with or without DataParallel both fit,
# then wrap.
combined_model.load_state_dict(adjusted_state_dict)
combined_model = nn.DataParallel(combined_model)

# The backbone lives on the wrapped module.
combined_backbone = combined_model.module.backbone

Using new backbone


  checkpoint = torch.load('/users/nladdha/my_env/model/inpaint_jigsaw_model.pth', map_location=device)


In [27]:
class ClassificationModel(nn.Module):
    """Image classifier: an encoder, global average pooling, then a two-layer MLP head.

    Args:
        encoder: Module mapping images to a 4-D feature map; the head's first Linear layer
            expects 512 channels after pooling and flattening.
        num_classes: Number of output logits.
    """

    def __init__(self, encoder, num_classes=10):
        super().__init__()
        self.encoder = encoder
        # Collapse whatever spatial extent the encoder produces to 1x1.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        head_layers = [nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, num_classes)]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        pooled = self.avgpool(self.encoder(x))
        return self.fc(torch.flatten(pooled, 1))

# Downstream classifier built on the jigsaw-pretrained backbone.
classification_model = ClassificationModel(encoder=combined_backbone, num_classes=10).to(device)

# Set to True to train only the classification head; False (default) fine-tunes the encoder too.
FREEZE_ENCODER = False
# Configure the encoder BEFORE any DataParallel wrapping: the wrapper does not expose
# `.encoder`, so accessing it afterwards raises AttributeError on multi-GPU machines.
for param in classification_model.encoder.parameters():
    param.requires_grad = not FREEZE_ENCODER

if torch.cuda.device_count() > 1:
    classification_model = nn.DataParallel(classification_model)

criterion = nn.CrossEntropyLoss()
# Hand only trainable parameters to the optimizer so a frozen encoder is not registered at all.
optimizer = optim.Adam((p for p in classification_model.parameters() if p.requires_grad), lr=0.001)

In [28]:
# Define the transform
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
    transforms.ToTensor(),
])

# Load the STL-10 training dataset
full_train_dataset = torchvision.datasets.STL10(root='./data', split='train', download=True, transform=transform)

# Split the training dataset into train and validation sets (80% train, 20% validation)
train_size = int(0.8 * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size
train_dataset, val_dataset = random_split(full_train_dataset, [train_size, val_size])

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)

# Test dataset
test_dataset = datasets.STL10(root='./data', split='test', download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

Files already downloaded and verified
Files already downloaded and verified


In [29]:
# Training settings
epochs = 150
early_stopping_patience = 10
best_val_loss = float('inf')
no_improve_epochs = 0
# Best (lowest validation loss) weights are written here.
model_save_path = '/users/nladdha/my_env/model/stl_inpaint_jigsaw_test.pth'

# Training Loop for the downstream task
for epoch in range(epochs):
    classification_model.train()
    running_loss = 0.0

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = classification_model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_train_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {avg_train_loss:.4f}")

    # Validation pass drives checkpointing and early stopping.
    classification_model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for val_inputs, val_targets in val_loader:
            val_inputs, val_targets = val_inputs.to(device), val_targets.to(device)
            val_outputs = classification_model(val_inputs)
            val_loss += criterion(val_outputs, val_targets).item()

    avg_val_loss = val_loss / len(val_loader)
    print(f"Epoch [{epoch+1}/{epochs}], Validation Loss: {avg_val_loss:.4f}")

    # Model checkpointing based on validation loss
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        # Save the unwrapped module's weights so the file loads into a plain ClassificationModel.
        # Check the wrapper itself rather than the GPU count: `.module` exists only on DataParallel.
        if isinstance(classification_model, nn.DataParallel):
            model_to_save = classification_model.module
        else:
            model_to_save = classification_model
        torch.save(model_to_save.state_dict(), model_save_path)
        print(f"Model saved at Epoch {epoch+1}")
        no_improve_epochs = 0
    else:
        no_improve_epochs += 1

    # Early stopping
    if no_improve_epochs >= early_stopping_patience:
        print("Early stopping")
        break

# Evaluation with Top-1, Top-3, and Top-5 Accuracy
# Restore the checkpoint with the lowest validation loss before testing. The training loop only
# saves it; the in-memory weights are from the last epoch run, which may be up to
# `early_stopping_patience` epochs past the best one.
best_state_dict = torch.load(model_save_path, map_location=device, weights_only=True)
if isinstance(classification_model, nn.DataParallel):
    # The checkpoint holds the unwrapped module's weights, so load it back into `.module`.
    classification_model.module.load_state_dict(best_state_dict)
else:
    classification_model.load_state_dict(best_state_dict)

classification_model.eval()
correct = 0
top_3_correct = 0
top_5_correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = classification_model(images)
        # Indices of the 5 highest-scoring classes per sample, best first.
        _, predicted = outputs.topk(5, dim=1)
        # hits[i, r] is True when the rank-r prediction for sample i equals its label.
        hits = predicted == labels.unsqueeze(1)

        total += labels.size(0)
        correct += hits[:, 0].sum().item()
        top_3_correct += hits[:, :3].any(dim=1).sum().item()
        top_5_correct += hits.any(dim=1).sum().item()

accuracy = 100 * correct / total
top_3_accuracy = 100 * top_3_correct / total
top_5_accuracy = 100 * top_5_correct / total

print(f'Top-1 Accuracy: {accuracy:.2f}%')
print(f'Top-3 Accuracy: {top_3_accuracy:.2f}%')
print(f'Top-5 Accuracy: {top_5_accuracy:.2f}%')

Epoch [1/150], Train Loss: 1.8735
Epoch [1/150], Validation Loss: 1.6125
Model saved at Epoch 1
Epoch [2/150], Train Loss: 1.5183
Epoch [2/150], Validation Loss: 1.4248
Model saved at Epoch 2
Epoch [3/150], Train Loss: 1.3550
Epoch [3/150], Validation Loss: 1.2626
Model saved at Epoch 3
Epoch [4/150], Train Loss: 1.1815
Epoch [4/150], Validation Loss: 1.2608
Model saved at Epoch 4
Epoch [5/150], Train Loss: 1.0828
Epoch [5/150], Validation Loss: 1.1562
Model saved at Epoch 5
Epoch [6/150], Train Loss: 0.9924
Epoch [6/150], Validation Loss: 1.1294
Model saved at Epoch 6
Epoch [7/150], Train Loss: 0.9322
Epoch [7/150], Validation Loss: 1.0382
Model saved at Epoch 7
Epoch [8/150], Train Loss: 0.8469
Epoch [8/150], Validation Loss: 1.1035
Epoch [9/150], Train Loss: 0.7955
Epoch [9/150], Validation Loss: 1.1346
Epoch [10/150], Train Loss: 0.7258
Epoch [10/150], Validation Loss: 1.0857
Epoch [11/150], Train Loss: 0.6733
Epoch [11/150], Validation Loss: 1.1291
Epoch [12/150], Train Loss: 0.65