In [1]:
import cv2
import os
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import STL10
from torch.utils.data import random_split
import torch.nn as nn
from torchvision.models import resnet18
import itertools
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict
from inpainting import Encoder, Decoder, InpaintingModel, Discriminator, mask_image  # Adjust path if needed
from jigsaw import create_jigsaw_puzzle, create_permutations, JigsawResNet, JigsawSTL10Dataset  # Adjust path if needed


In [2]:
# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [3]:
# Data augmentation and normalization for training
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
    transforms.RandomGrayscale(p=0.1),
    transforms.ToTensor(),
])

# Load STL10 dataset
train_dataset = STL10(root='./data', split='train+unlabeled', download=True, transform=transform)
train_size = int(0.9 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])


Files already downloaded and verified


In [4]:
# Set up the Jigsaw dataset and data loaders
jigsaw_train_dataset = JigsawSTL10Dataset(train_dataset)
jigsaw_val_dataset = JigsawSTL10Dataset(val_dataset)

jigsaw_train_loader = DataLoader(jigsaw_train_dataset, batch_size=64, shuffle=True, num_workers=12)
jigsaw_val_loader = DataLoader(jigsaw_val_dataset, batch_size=64, shuffle=False, num_workers=12)


In [9]:
# Load inpainting model as the base, modified to include jigsaw functionality
PATH = 'stl_jigsaw_model.pth' 
checkpoint = torch.load(PATH, map_location=device)
inpainting_model = InpaintingModel().to(device)
inpainting_model.load_state_dict(checkpoint)

  checkpoint = torch.load(PATH, map_location=device)


using new encoder


RuntimeError: Error(s) in loading state_dict for InpaintingModel:
	Missing key(s) in state_dict: "encoder.encoder.0.weight", "encoder.encoder.1.weight", "encoder.encoder.1.bias", "encoder.encoder.1.running_mean", "encoder.encoder.1.running_var", "encoder.encoder.4.0.conv1.weight", "encoder.encoder.4.0.bn1.weight", "encoder.encoder.4.0.bn1.bias", "encoder.encoder.4.0.bn1.running_mean", "encoder.encoder.4.0.bn1.running_var", "encoder.encoder.4.0.conv2.weight", "encoder.encoder.4.0.bn2.weight", "encoder.encoder.4.0.bn2.bias", "encoder.encoder.4.0.bn2.running_mean", "encoder.encoder.4.0.bn2.running_var", "encoder.encoder.4.1.conv1.weight", "encoder.encoder.4.1.bn1.weight", "encoder.encoder.4.1.bn1.bias", "encoder.encoder.4.1.bn1.running_mean", "encoder.encoder.4.1.bn1.running_var", "encoder.encoder.4.1.conv2.weight", "encoder.encoder.4.1.bn2.weight", "encoder.encoder.4.1.bn2.bias", "encoder.encoder.4.1.bn2.running_mean", "encoder.encoder.4.1.bn2.running_var", "encoder.encoder.5.0.conv1.weight", "encoder.encoder.5.0.bn1.weight", "encoder.encoder.5.0.bn1.bias", "encoder.encoder.5.0.bn1.running_mean", "encoder.encoder.5.0.bn1.running_var", "encoder.encoder.5.0.conv2.weight", "encoder.encoder.5.0.bn2.weight", "encoder.encoder.5.0.bn2.bias", "encoder.encoder.5.0.bn2.running_mean", "encoder.encoder.5.0.bn2.running_var", "encoder.encoder.5.0.downsample.0.weight", "encoder.encoder.5.0.downsample.1.weight", "encoder.encoder.5.0.downsample.1.bias", "encoder.encoder.5.0.downsample.1.running_mean", "encoder.encoder.5.0.downsample.1.running_var", "encoder.encoder.5.1.conv1.weight", "encoder.encoder.5.1.bn1.weight", "encoder.encoder.5.1.bn1.bias", "encoder.encoder.5.1.bn1.running_mean", "encoder.encoder.5.1.bn1.running_var", "encoder.encoder.5.1.conv2.weight", "encoder.encoder.5.1.bn2.weight", "encoder.encoder.5.1.bn2.bias", "encoder.encoder.5.1.bn2.running_mean", "encoder.encoder.5.1.bn2.running_var", "encoder.encoder.6.0.conv1.weight", "encoder.encoder.6.0.bn1.weight", "encoder.encoder.6.0.bn1.bias", "encoder.encoder.6.0.bn1.running_mean", "encoder.encoder.6.0.bn1.running_var", "encoder.encoder.6.0.conv2.weight", "encoder.encoder.6.0.bn2.weight", "encoder.encoder.6.0.bn2.bias", "encoder.encoder.6.0.bn2.running_mean", "encoder.encoder.6.0.bn2.running_var", "encoder.encoder.6.0.downsample.0.weight", "encoder.encoder.6.0.downsample.1.weight", "encoder.encoder.6.0.downsample.1.bias", "encoder.encoder.6.0.downsample.1.running_mean", "encoder.encoder.6.0.downsample.1.running_var", "encoder.encoder.6.1.conv1.weight", "encoder.encoder.6.1.bn1.weight", "encoder.encoder.6.1.bn1.bias", "encoder.encoder.6.1.bn1.running_mean", "encoder.encoder.6.1.bn1.running_var", "encoder.encoder.6.1.conv2.weight", "encoder.encoder.6.1.bn2.weight", "encoder.encoder.6.1.bn2.bias", "encoder.encoder.6.1.bn2.running_mean", "encoder.encoder.6.1.bn2.running_var", "encoder.encoder.7.0.conv1.weight", "encoder.encoder.7.0.bn1.weight", "encoder.encoder.7.0.bn1.bias", "encoder.encoder.7.0.bn1.running_mean", "encoder.encoder.7.0.bn1.running_var", "encoder.encoder.7.0.conv2.weight", "encoder.encoder.7.0.bn2.weight", "encoder.encoder.7.0.bn2.bias", "encoder.encoder.7.0.bn2.running_mean", "encoder.encoder.7.0.bn2.running_var", "encoder.encoder.7.0.downsample.0.weight", "encoder.encoder.7.0.downsample.1.weight", "encoder.encoder.7.0.downsample.1.bias", "encoder.encoder.7.0.downsample.1.running_mean", "encoder.encoder.7.0.downsample.1.running_var", "encoder.encoder.7.1.conv1.weight", "encoder.encoder.7.1.bn1.weight", "encoder.encoder.7.1.bn1.bias", "encoder.encoder.7.1.bn1.running_mean", "encoder.encoder.7.1.bn1.running_var", "encoder.encoder.7.1.conv2.weight", "encoder.encoder.7.1.bn2.weight", "encoder.encoder.7.1.bn2.bias", "encoder.encoder.7.1.bn2.running_mean", "encoder.encoder.7.1.bn2.running_var", "decoder.decoder.0.weight", "decoder.decoder.0.bias", "decoder.decoder.2.weight", "decoder.decoder.2.bias", "decoder.decoder.4.weight", "decoder.decoder.4.bias", "decoder.decoder.6.weight", "decoder.decoder.6.bias", "decoder.decoder.8.weight", "decoder.decoder.8.bias". 
	Unexpected key(s) in state_dict: "module.backbone.0.weight", "module.backbone.1.weight", "module.backbone.1.bias", "module.backbone.1.running_mean", "module.backbone.1.running_var", "module.backbone.1.num_batches_tracked", "module.backbone.4.0.conv1.weight", "module.backbone.4.0.bn1.weight", "module.backbone.4.0.bn1.bias", "module.backbone.4.0.bn1.running_mean", "module.backbone.4.0.bn1.running_var", "module.backbone.4.0.bn1.num_batches_tracked", "module.backbone.4.0.conv2.weight", "module.backbone.4.0.bn2.weight", "module.backbone.4.0.bn2.bias", "module.backbone.4.0.bn2.running_mean", "module.backbone.4.0.bn2.running_var", "module.backbone.4.0.bn2.num_batches_tracked", "module.backbone.4.1.conv1.weight", "module.backbone.4.1.bn1.weight", "module.backbone.4.1.bn1.bias", "module.backbone.4.1.bn1.running_mean", "module.backbone.4.1.bn1.running_var", "module.backbone.4.1.bn1.num_batches_tracked", "module.backbone.4.1.conv2.weight", "module.backbone.4.1.bn2.weight", "module.backbone.4.1.bn2.bias", "module.backbone.4.1.bn2.running_mean", "module.backbone.4.1.bn2.running_var", "module.backbone.4.1.bn2.num_batches_tracked", "module.backbone.5.0.conv1.weight", "module.backbone.5.0.bn1.weight", "module.backbone.5.0.bn1.bias", "module.backbone.5.0.bn1.running_mean", "module.backbone.5.0.bn1.running_var", "module.backbone.5.0.bn1.num_batches_tracked", "module.backbone.5.0.conv2.weight", "module.backbone.5.0.bn2.weight", "module.backbone.5.0.bn2.bias", "module.backbone.5.0.bn2.running_mean", "module.backbone.5.0.bn2.running_var", "module.backbone.5.0.bn2.num_batches_tracked", "module.backbone.5.0.downsample.0.weight", "module.backbone.5.0.downsample.1.weight", "module.backbone.5.0.downsample.1.bias", "module.backbone.5.0.downsample.1.running_mean", "module.backbone.5.0.downsample.1.running_var", "module.backbone.5.0.downsample.1.num_batches_tracked", "module.backbone.5.1.conv1.weight", "module.backbone.5.1.bn1.weight", "module.backbone.5.1.bn1.bias", "module.backbone.5.1.bn1.running_mean", "module.backbone.5.1.bn1.running_var", "module.backbone.5.1.bn1.num_batches_tracked", "module.backbone.5.1.conv2.weight", "module.backbone.5.1.bn2.weight", "module.backbone.5.1.bn2.bias", "module.backbone.5.1.bn2.running_mean", "module.backbone.5.1.bn2.running_var", "module.backbone.5.1.bn2.num_batches_tracked", "module.backbone.6.0.conv1.weight", "module.backbone.6.0.bn1.weight", "module.backbone.6.0.bn1.bias", "module.backbone.6.0.bn1.running_mean", "module.backbone.6.0.bn1.running_var", "module.backbone.6.0.bn1.num_batches_tracked", "module.backbone.6.0.conv2.weight", "module.backbone.6.0.bn2.weight", "module.backbone.6.0.bn2.bias", "module.backbone.6.0.bn2.running_mean", "module.backbone.6.0.bn2.running_var", "module.backbone.6.0.bn2.num_batches_tracked", "module.backbone.6.0.downsample.0.weight", "module.backbone.6.0.downsample.1.weight", "module.backbone.6.0.downsample.1.bias", "module.backbone.6.0.downsample.1.running_mean", "module.backbone.6.0.downsample.1.running_var", "module.backbone.6.0.downsample.1.num_batches_tracked", "module.backbone.6.1.conv1.weight", "module.backbone.6.1.bn1.weight", "module.backbone.6.1.bn1.bias", "module.backbone.6.1.bn1.running_mean", "module.backbone.6.1.bn1.running_var", "module.backbone.6.1.bn1.num_batches_tracked", "module.backbone.6.1.conv2.weight", "module.backbone.6.1.bn2.weight", "module.backbone.6.1.bn2.bias", "module.backbone.6.1.bn2.running_mean", "module.backbone.6.1.bn2.running_var", "module.backbone.6.1.bn2.num_batches_tracked", "module.backbone.7.0.conv1.weight", "module.backbone.7.0.bn1.weight", "module.backbone.7.0.bn1.bias", "module.backbone.7.0.bn1.running_mean", "module.backbone.7.0.bn1.running_var", "module.backbone.7.0.bn1.num_batches_tracked", "module.backbone.7.0.conv2.weight", "module.backbone.7.0.bn2.weight", "module.backbone.7.0.bn2.bias", "module.backbone.7.0.bn2.running_mean", "module.backbone.7.0.bn2.running_var", "module.backbone.7.0.bn2.num_batches_tracked", "module.backbone.7.0.downsample.0.weight", "module.backbone.7.0.downsample.1.weight", "module.backbone.7.0.downsample.1.bias", "module.backbone.7.0.downsample.1.running_mean", "module.backbone.7.0.downsample.1.running_var", "module.backbone.7.0.downsample.1.num_batches_tracked", "module.backbone.7.1.conv1.weight", "module.backbone.7.1.bn1.weight", "module.backbone.7.1.bn1.bias", "module.backbone.7.1.bn1.running_mean", "module.backbone.7.1.bn1.running_var", "module.backbone.7.1.bn1.num_batches_tracked", "module.backbone.7.1.conv2.weight", "module.backbone.7.1.bn2.weight", "module.backbone.7.1.bn2.bias", "module.backbone.7.1.bn2.running_mean", "module.backbone.7.1.bn2.running_var", "module.backbone.7.1.bn2.num_batches_tracked", "module.fc.0.weight", "module.fc.0.bias", "module.fc.2.weight", "module.fc.2.bias". 

In [None]:
# Using Encoder from InpaintingModel in JigsawResNet
backbone = inpainting_model.encoder

In [None]:
# Initialize the JigsawResNet model using the inpainting model's encoder
num_patches = 9
num_permutations = 1000
jigsaw_model = JigsawResNet(num_patches=num_patches, num_permutations=num_permutations)
jigsaw_model.backbone = backbone  # Set inpainting backbone to jigsaw model
jigsaw_model = jigsaw_model.to(device)

# Early stopping class
class EarlyStopping:
    def __init__(self, patience=10, min_delta=0.0001):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

# Classification model using JigsawResNet backbone
class ClassificationNet(nn.Module):
    def __init__(self, backbone, num_classes):
        super(ClassificationNet, self).__init__()
        self.backbone = backbone
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x):
        features = self.backbone(x)
        pooled_features = nn.AdaptiveAvgPool2d((1, 1))(features)
        pooled_features = pooled_features.view(pooled_features.size(0), -1)
        output = self.classifier(pooled_features)
        return output

classification_model = ClassificationNet(jigsaw_model.backbone, num_classes=10).to(device)


In [None]:
# Training and evaluation loop setup
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(classification_model.parameters(), lr=1e-3)
early_stop = EarlyStopping(patience=15, min_delta=0.000001)
num_epochs = 150

best_model_weights = classification_model.state_dict()

for epoch in range(num_epochs):
    classification_model.train()
    running_loss = 0.0
    for images, labels in jigsaw_train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = classification_model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    avg_loss = running_loss / len(jigsaw_train_loader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}")

    # Validation
    classification_model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for val_images, val_labels in jigsaw_val_loader:
            val_images, val_labels = val_images.to(device), val_labels.to(device)
            val_outputs = classification_model(val_images)
            val_loss += criterion(val_outputs, val_labels).item()
    
    avg_val_loss = val_loss / len(jigsaw_val_loader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Validation Loss: {avg_val_loss:.4f}")

    early_stop(avg_val_loss)
    if early_stop.early_stop:
        print("Early stopping due to no improvement in validation loss.")
        break

    # Save best model based on validation loss
    if avg_val_loss < early_stop.best_loss:
        best_model_weights = classification_model.state_dict()

        # Ensure directory structure exists
        os.makedirs('models/downstream', exist_ok=True)
        
        # Save model
        model_path = f'models/downstream/classification_model_weights_epoch_{epoch+1}.pth'
        torch.save(classification_model.state_dict(), model_path)
        print(f"Model saved at Epoch {epoch + 1}")

# Final model saving
torch.save(classification_model.state_dict(), 'models/downstream/classification_model_weights_final.pth')
torch.save(best_model_weights, 'models/downstream/classification_best_model_weights_final.pth')
