# Inpainting -> Jigsaw Multi-Pretext Task

In [16]:
import cv2
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import STL10
from torchvision.models import resnet18
from torchvision import datasets, transforms
from torch.utils.data import random_split
import torch.nn as nn
import itertools
from torchvision.models import resnet18
from jigsaw import JigsawResNet, create_jigsaw_puzzle, create_permutations, JigsawSTL10Dataset
from inpainting import Encoder, Decoder, InpaintingModel, Discriminator, mask_image

In [2]:

# Data augmentation and normalization for training
transform = transforms.Compose([
    # transforms.Resize(255),
    transforms.RandomHorizontalFlip(),     # Random horizontal flip
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),  # Adjust brightness, contrast, and saturation
    transforms.RandomGrayscale(p=0.1),     # Randomly convert to grayscale with a probability of 0.1
    transforms.ToTensor(),                 # Convert to tensor
])

# # Load STL10 dataset
# stl10_pretrain = STL10ColorizationDataset(root='../data', split='train+unlabeled', download=True, transform=transform)
train_dataset = STL10(root='./data',  split='train+unlabeled', download=True, transform=transform)




# # DataLoader to feed batches for training
# pretrain_loader = DataLoader(stl10_pretrain, batch_size=64, shuffle=True)

# Split dataset into training and validation sets
train_size = int(0.9 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

jigsaw_train_dataset = JigsawSTL10Dataset(train_dataset)
jigsaw_val_dataset = JigsawSTL10Dataset(val_dataset)

jigsaw_train_loader = DataLoader(jigsaw_train_dataset, batch_size=64, shuffle=True, num_workers=12)
jigsaw_val_loader = DataLoader(jigsaw_val_dataset, batch_size=64, shuffle=False, num_workers=12)




Files already downloaded and verified


In [3]:
if torch.cuda.is_available():
    device = torch.device("cuda")  # Use NVIDIA GPU
    print('cuda')
elif torch.backends.mps.is_available():
    device = torch.device("mps")   # Use Apple's Metal (for M1/M2 Macs)
    print('mps')
else:
    device = torch.device("cpu") 
    print('cpu')

cuda


In [4]:
from torchvision.models import resnet18
import torch.nn as nn
import torch

PATH = '../inpaint_jigsaw/models/inpainting_model_gen_weights_epoch_100.pth'
checkpoint= torch.load(PATH, map_location=torch.device('cuda'))

generator = InpaintingModel().to(device)  

# Load the state dictionary
# checkpoint = torch.load(path_to_weights)

# Load the weights into the generator
generator.load_state_dict(checkpoint)

backbone = generator.encoder


  checkpoint= torch.load(PATH, map_location=torch.device('cuda'))


using new encoder


In [5]:
backbone

Encoder(
  (encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

In [6]:
import torchvision
print(torchvision.__version__)  # This should print the torchvision version without errors
from torchvision.models import resnet18  # Confirm import works


0.19.1+cu121


In [7]:


# Initialize the JigsawResNet model
jigsaw_model = JigsawResNet(num_patches=9, num_permutations=1000)
jigsaw_model = jigsaw_model.to(device)


In [8]:
# # Training loop
# num_epochs = 100
# criterion = nn.MSELoss()
# optimizer = torch.optim.Adam(colorization_model.parameters(), lr=1e-3)

# early_stop = EarlyStopping(paitence=15,min_delta=0.000001)


# for epoch in range(num_epochs):
#     total_loss = 0.0  # Initialize total loss for this epoch
#     num_batches = 0   # Keep track of the number of batches

#     # Training loop for the current epoch
#     for L_channel, AB_channels in pretrain_loader:
#         # Move data to the same device as the model
#         L_channel = L_channel.to(device)
#         AB_channels = AB_channels.to(device)
        
#         L_channel_rgb = L_channel.repeat(1, 3, 1, 1)  # Shape: [batch_size, 3, 96, 96]

#         # Forward pass
#         predicted_AB = colorization_model(L_channel_rgb)

#         # Compute loss
#         loss = criterion(predicted_AB, AB_channels)
        
#         # Backward pass and optimization
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         # Accumulate loss and count batches
#         total_loss += loss.item()
#         num_batches += 1

#     # Calculate average loss for the epoch
#     average_loss = total_loss / num_batches
#     print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {average_loss}')

#     early_stop(average_loss)
#     # print(early_stop.counter)
#     if early_stop.early_stop:
#         print("Early Stopping Triggered. No improves in Loss for the last 10 epochs")
#         break
    
#     if (epoch + 1) % 10 == 0:
#         torch.save(colorization_model.state_dict(), f'models/inpaint_jigsaw_model_weights_epoch_{epoch+1}.pth')

#     if average_loss < early_stop.best_loss:
#         best_model_weights = colorization_model.state_dict()

# torch.save(colorization_model.state_dict(), 'models/inpaint_jigsaw_model_weights_final.pth')
# torch.save(best_model_weights, 'models/inpaint_jigsaw_best_model_weights_final.pth')
class EarlyStopping:
    def __init__(self, patience=10, min_delta=0.0001):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        # Check if there's an improvement in validation loss
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0  # Reset counter if there's an improvement
        else:
            self.counter += 1  # Increment counter if no improvement
            if self.counter >= self.patience:
                self.early_stop = True


# Training loop for Jigsaw Model
num_epochs = 100
criterion = nn.CrossEntropyLoss()  # For classification tasks
optimizer = torch.optim.Adam(jigsaw_model.parameters(), lr=1e-3)
early_stop = EarlyStopping(patience=15, min_delta=0.000001)  # Set patience based on task requirements

# Initialize best_model_weights with the initial model state
best_model_weights = jigsaw_model.state_dict()

for epoch in range(num_epochs):
    jigsaw_model.train()
    running_loss = 0.0  # Initialize total loss for this epoch
    num_batches = 0     # Keep track of the number of batches

    # Training loop for the current epoch
    for shuffled_patches, perm_class in jigsaw_train_loader:
        # Move data to the same device as the model
        shuffled_patches = shuffled_patches.to(device)
        perm_class = perm_class.to(device)
        
        # Forward pass
        optimizer.zero_grad()
        outputs = jigsaw_model(shuffled_patches)
        loss = criterion(outputs, perm_class)
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Accumulate loss and count batches
        running_loss += loss.item()
        num_batches += 1

    # Calculate average loss for the epoch
    avg_train_loss = running_loss / num_batches
    print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_train_loss:.4f}')

    # Early stopping and model checkpointing based on validation loss
    jigsaw_model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for val_patches, val_perm_class in jigsaw_val_loader:
            val_patches = val_patches.to(device)
            val_perm_class = val_perm_class.to(device)
            val_outputs = jigsaw_model(val_patches)
            val_loss += criterion(val_outputs, val_perm_class).item()

    avg_val_loss = val_loss / len(jigsaw_val_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Validation Loss: {avg_val_loss:.4f}")

    early_stop(avg_val_loss)
    if early_stop.early_stop:
        print("Early stopping due to no improvement in validation loss.")
        break

    # Save best model based on validation loss
    if avg_val_loss < early_stop.best_loss:
        best_model_weights = jigsaw_model.state_dict()
        model_path = os.path.join(checkpoint_dir, f'inpaint_jigsaw_model_epoch_{epoch+1}.pth')
        torch.save(jigsaw_model.state_dict(), model_path)
        print(f"Model saved at Epoch {epoch+1}")

# Save final model weights
torch.save(jigsaw_model.state_dict(), 'models/inpaint_jigsaw_best_model_weights_final.pth')
torch.save(best_model_weights, 'models/inpaint_jigsaw_best_model_weights_final.pth')



Epoch [1/100], Training Loss: 6.9134
Epoch [1/100], Validation Loss: 6.9085
Epoch [2/100], Training Loss: 6.9086
Epoch [2/100], Validation Loss: 6.9085
Epoch [3/100], Training Loss: 6.9085
Epoch [3/100], Validation Loss: 6.9092
Epoch [4/100], Training Loss: 6.9088
Epoch [4/100], Validation Loss: 6.9092
Epoch [5/100], Training Loss: 6.9086
Epoch [5/100], Validation Loss: 6.9086
Epoch [6/100], Training Loss: 6.9088
Epoch [6/100], Validation Loss: 6.9088
Epoch [7/100], Training Loss: 6.9087
Epoch [7/100], Validation Loss: 6.9088
Epoch [8/100], Training Loss: 6.9089
Epoch [8/100], Validation Loss: 6.9091
Epoch [9/100], Training Loss: 6.9087
Epoch [9/100], Validation Loss: 6.9090
Epoch [10/100], Training Loss: 6.9090
Epoch [10/100], Validation Loss: 6.9087
Epoch [11/100], Training Loss: 6.9087
Epoch [11/100], Validation Loss: 6.9088
Epoch [12/100], Training Loss: 6.9087
Epoch [12/100], Validation Loss: 6.9084
Epoch [13/100], Training Loss: 6.9086
Epoch [13/100], Validation Loss: 6.9092
Epoc

In [9]:
# import numpy as np
# from skimage import color
# import matplotlib.pyplot as plt

# def visualize_colorization(L_channel, predicted_AB, ground_truth_AB):
#     batch_size = L_channel.shape[0]

#     for i in range(batch_size):
#         # Convert model's output (predicted_AB) to RGB for each sample in the batch
#         colorized_image = lab_to_rgb(L_channel[i], predicted_AB[i])

#         # Convert the ground truth to RGB for each sample
#         ground_truth_rgb = lab_to_rgb(L_channel[i], ground_truth_AB[i])

#         # Display the colorized image and the ground truth (visualization code)
#         plt.subplot(1, 2, 1)
#         plt.imshow(colorized_image)
#         plt.title('Predicted Colorization')

#         plt.subplot(1, 2, 2)
#         plt.imshow(ground_truth_rgb)
#         plt.title('Ground Truth')

#         plt.show()
#         print(f"Predicted AB min: {predicted_AB[i].min():.2f}, max: {predicted_AB[i].max():.2f}")
#         print(f"Ground Truth AB min: {ground_truth_AB[i].min():.2f}, max: {ground_truth_AB[i].max():.2f}")

# # Convert LAB to RGB
# def lab_to_rgb(L_channel, AB_channels):
#     # Ensure L_channel has shape [96, 96] and scale it appropriately
#     L_channel = L_channel.squeeze().cpu().numpy() * 255
#     L_channel = L_channel 

#     # Ensure AB_channels has shape [2, 96, 96] and transpose it to [96, 96, 2]
#     AB_channels = AB_channels.squeeze().detach().cpu().numpy().transpose(1, 2, 0)
#     AB_channels = (AB_channels * 255) - 128

#     # Concatenate L and AB channels to form LAB image
#     lab_image = np.concatenate((L_channel[:, :, np.newaxis], AB_channels), axis=-1)

#     # Convert LAB to RGB using a library like skimage
#     rgb_image = color.lab2rgb(lab_image)
    
#     return rgb_image
import numpy as np
import matplotlib.pyplot as plt

def visualize_jigsaw(shuffled_patches, predicted_class, ground_truth_class):
    batch_size = shuffled_patches.shape[0]

    for i in range(batch_size):
        # Display the shuffled patches for each sample in the batch
        plt.figure(figsize=(10, 5))
        
        # Display the shuffled patches in a 3x3 grid (assuming 9 patches per image)
        num_patches = shuffled_patches.shape[1]
        grid_size = int(np.sqrt(num_patches))
        
        # Display each patch in its shuffled position
        for idx, patch in enumerate(shuffled_patches[i]):
            plt.subplot(grid_size, grid_size, idx + 1)
            plt.imshow(patch.permute(1, 2, 0).cpu().numpy())  # Assuming patch shape [C, H, W]
            plt.axis('off')
        
        # Titles for prediction and ground truth
        plt.suptitle(f"Predicted Class: {predicted_class[i]}, Ground Truth Class: {ground_truth_class[i]}")
        plt.show()
        
        print(f"Sample {i+1} - Predicted Class: {predicted_class[i]}, Ground Truth Class: {ground_truth_class[i]}")



In [10]:
# visualize_jigsaw(shuffled_patches, predicted_class, ground_truth_class)


NameError: name 'predicted_class' is not defined

In [4]:
PATH = '/users/soh62/SSL/inpaint_jigsaw/models/inpaint_jigsaw_best_model_weights_final.pth'
import os
print(os.path.exists(PATH))  # This should print True if the file exists


True


In [10]:
# from torchvision.models import resnet18
# import torch.nn as nn
# import torch
# from collections import OrderedDict

# PATH = 'models/inpaint_jigsaw_model_weights_epoch_100.pth'

# checkpoint = torch.load(PATH, map_location=torch.device('cuda'))

# new_state_dict = OrderedDict()
# for k, v in checkpoint.items():
#     new_key = k.replace("backbone.encoder.","backbone." )  # Modify based on your structure
#     new_state_dict[new_key] = v



# backbone = resnet18(weights=None)
# backbone = nn.Sequential(*list(backbone.children())[:-2])

# colorization_model = Colorization(backbone)
# # colorization_model = colorization_model.to(device)
# colorization_model.load_state_dict(new_state_dict)

from torchvision.models import resnet18
import torch.nn as nn
import torch
from collections import OrderedDict
from jigsaw import JigsawResNet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

PATH = 'SSL/inpaint_jigsaw/models/inpaint_jigsaw_best_model_weights_final.pth'
checkpoint = torch.load('/users/soh62/SSL/inpaint_jigsaw/models/inpaint_jigsaw_best_model_weights_final.pth', map_location=torch.device('cuda'))

# Load the saved checkpoint
# checkpoint = torch.load(PATH, map_location=torch.device('cuda'), weights_only=True)


# Prepare the state dictionary for loading
new_state_dict = OrderedDict()
for k, v in checkpoint.items():
    new_key = k.replace("backbone.encoder.", "backbone.")  # Adjust key names as needed
    new_state_dict[new_key] = v

# Set up the ResNet18 backbone
backbone = resnet18(weights=None)
backbone = nn.Sequential(*list(backbone.children())[:-2])  # Removing the last two layers

# Create the jigsaw model using the ResNet backbone
num_patches = 9
num_permutations = 1000  # Or however many permutations you've defined
jigsaw_model = JigsawResNet(num_patches=num_patches, num_permutations=num_permutations)
jigsaw_model = jigsaw_model.to(device)

# Load the modified state dictionary into the jigsaw model
jigsaw_model.load_state_dict(new_state_dict)


  checkpoint = torch.load('/users/soh62/SSL/inpaint_jigsaw/models/inpaint_jigsaw_best_model_weights_final.pth', map_location=torch.device('cuda'))


<All keys matched successfully>

In [12]:
class ClassificationNet(nn.Module):
    def __init__(self, backbone, num_classes):
        super(ClassificationNet, self).__init__()
        self.backbone = backbone
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x):
        features = self.backbone(x)
        pooled_features = nn.AdaptiveAvgPool2d((1, 1))(features)
        pooled_features = pooled_features.view(pooled_features.size(0), -1)
        output = self.classifier(pooled_features)
        return output

classification_model = ClassificationNet(jigsaw_model.backbone, num_classes=10).to(device)

In [14]:
from torchvision import transforms

classification_transform = transforms.Compose([
    # transforms.RandomResizedCrop(96),
    # transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),  # RGB for classification
])

In [17]:
stl10_train = STL10(root='../data', split='train', download=True, transform=classification_transform)
stl10_test = STL10(root='../data', split='test', download=True, transform=classification_transform)

# Fine-tuning: Load training data for classification task
train_loader = DataLoader(stl10_train, batch_size=64, shuffle=True)

# Testing: Load test data for final evaluation
test_loader = DataLoader(stl10_test, batch_size=64, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [20]:
criterion = nn.CrossEntropyLoss()  # Suitable for multi-class classification
optimizer = torch.optim.Adam(classification_model.parameters(), lr=1e-3)

# Training Loop
num_epochs = 150
for epoch in range(num_epochs):
    classification_model.train()  # Set model to training mode
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = classification_model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}")

    if (epoch + 1) % 10 == 0:
        torch.save(classification_model.state_dict(), f'models/downstream/classification_model_weights_epoch_{epoch+1}.pth')



PATH = 'models/downstream/classification_model_weights_final.pth'
torch.save(classification_model.state_dict(), PATH)

Epoch [1/150], Loss: 0.0170
Epoch [2/150], Loss: 0.0210
Epoch [3/150], Loss: 0.0238
Epoch [4/150], Loss: 0.0318
Epoch [5/150], Loss: 0.0432
Epoch [6/150], Loss: 0.0444
Epoch [7/150], Loss: 0.0460
Epoch [8/150], Loss: 0.0537
Epoch [9/150], Loss: 0.0359
Epoch [10/150], Loss: 0.0043
Epoch [11/150], Loss: 0.0018
Epoch [12/150], Loss: 0.0007
Epoch [13/150], Loss: 0.0005
Epoch [14/150], Loss: 0.0005
Epoch [15/150], Loss: 0.0007
Epoch [16/150], Loss: 0.0086
Epoch [17/150], Loss: 0.0363
Epoch [18/150], Loss: 0.0099
Epoch [19/150], Loss: 0.0215
Epoch [20/150], Loss: 0.0089
Epoch [21/150], Loss: 0.0016
Epoch [22/150], Loss: 0.0010
Epoch [23/150], Loss: 0.0009
Epoch [24/150], Loss: 0.0024
Epoch [25/150], Loss: 0.0013
Epoch [26/150], Loss: 0.0005
Epoch [27/150], Loss: 0.0005
Epoch [28/150], Loss: 0.0003
Epoch [29/150], Loss: 0.0006
Epoch [30/150], Loss: 0.0016
Epoch [31/150], Loss: 0.1210
Epoch [32/150], Loss: 0.1111
Epoch [33/150], Loss: 0.0848
Epoch [34/150], Loss: 0.0139
Epoch [35/150], Loss: 0

In [21]:
# Evaluation
classification_model.eval()  # Set model to evaluation mode
correct = 0
top_5_correct = 0
top_3_correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = classification_model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        _, predicted_3 = torch.topk(outputs.data, k=3, dim=1)
        correct_3 = predicted_3.eq(labels.unsqueeze(1).expand_as(predicted_3))
        top_3_correct += correct_3.any(dim=1).sum().item()

        _, predicted_5 = torch.topk(outputs.data, k=5, dim=1)
        correct_5 = predicted_5.eq(labels.unsqueeze(1).expand_as(predicted_5))
        top_5_correct += correct_5.any(dim=1).sum().item()



accuracy = 100 * correct / total
top_5 = 100 * top_5_correct / total
top_3 = 100 * top_3_correct / total
print(f'Top-1 Accuracy of the model on the test set: {accuracy:.2f}%')
print(f'Top-5 Accuracy of the model on the test set: {top_5:.2f}%')
print(f'Top-3 Accuracy of the model on the test set: {top_3:.2f}%')

Top-1 Accuracy of the model on the test set: 54.89%
Top-5 Accuracy of the model on the test set: 91.84%
Top-3 Accuracy of the model on the test set: 81.54%
