In [1]:
import cv2
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import STL10
from torchvision.models import resnet18
import torch.nn as nn
import numpy as np
from skimage import color
import matplotlib.pyplot as plt
from collections import OrderedDict
from torch.optim import Adam

# Import custom modules from provided files
from colorization import Colorization, RGB2LabTransform, STL10ColorizationDataset, EarlyStopping
from inpainting import Encoder, Decoder, InpaintingModel, Discriminator, mask_image


In [2]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define a combined model class with both colorization and inpainting models
class CombinedModel(nn.Module):
    def __init__(self, shared_encoder):
        super(CombinedModel, self).__init__()
        self.encoder = shared_encoder  # Shared encoder for both tasks
        self.colorization_head = Colorization(self.encoder)
        self.inpainting_head = InpaintingModel(self.encoder)

    def forward_colorization(self, L_channel_rgb):
        # Pass through colorization head
        return self.colorization_head(L_channel_rgb)

    def forward_inpainting(self, masked_images):
        # Pass through inpainting head
        return self.inpainting_head(masked_images)


# Initialize a shared encoder (ResNet18 backbone as an example)
shared_encoder = resnet18(pretrained=True)
shared_encoder = nn.Sequential(*list(shared_encoder.children())[:-2])  # Remove final layers for feature extraction

# Initialize the combined model with shared encoder
combined_model = CombinedModel(shared_encoder).to(device)




using pre-trained encoder


In [3]:
# Helper function to convert LAB to RGB
def lab_to_rgb(L_channel, AB_channels):
    L_channel = L_channel.squeeze().cpu().numpy() * 255
    AB_channels = AB_channels.squeeze().detach().cpu().numpy().transpose(1, 2, 0)
    AB_channels = (AB_channels * 255) - 128

    lab_image = np.concatenate((L_channel[:, :, np.newaxis], AB_channels), axis=-1)
    rgb_image = color.lab2rgb(lab_image)
    
    return rgb_image

In [4]:
# Load STL10 dataset and apply necessary transformations
transform = transforms.Compose([transforms.ToTensor()])

# Separate datasets for colorization and inpainting
stl10_colorization = STL10ColorizationDataset(root='../data', split='train+unlabeled', download=True, transform=transform)
stl10_inpainting = STL10(root='../data', split='train+unlabeled', download=True, transform=transform)

# DataLoaders for colorization and inpainting tasks
colorization_loader = DataLoader(stl10_colorization, batch_size=64, shuffle=True)
inpainting_loader = DataLoader(stl10_inpainting, batch_size=64, shuffle=True)

# Optimizer and loss functions
optimizer = Adam(combined_model.parameters(), lr=1e-3)
colorization_criterion = nn.MSELoss()
inpainting_criterion = nn.BCELoss()  # Assuming binary cross-entropy loss for inpainting


Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Early stopping configuration
class EarlyStopping:
    def __init__(self, patience=15, min_delta=1e-6):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, loss):
        if self.best_loss is None:
            self.best_loss = loss
        elif loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = loss
            self.counter = 0

early_stop = EarlyStopping(patience=15, min_delta=1e-6)


In [6]:
# Set mask size for inpainting
mask_size = 8  # Adjust as needed

# Training loop with combined pretext tasks per epoch
num_epochs = 100
for epoch in range(num_epochs):
    combined_model.train()
    total_loss_colorization = 0.0
    total_loss_inpainting = 0.0

    colorization_iter = iter(colorization_loader)
    inpainting_iter = iter(inpainting_loader)

    # Alternate between colorization and inpainting batches within the epoch
    for i in range(min(len(colorization_loader), len(inpainting_loader))):
        # Colorization task
        try:
            L_channel, AB_channels = next(colorization_iter)
            L_channel, AB_channels = L_channel.to(device), AB_channels.to(device)
            L_channel_rgb = L_channel.repeat(1, 3, 1, 1)  # Convert grayscale to RGB input format

            # Forward pass for colorization
            predicted_AB = combined_model.forward_colorization(L_channel_rgb)
            loss_colorization = colorization_criterion(predicted_AB, AB_channels)

            # Backward pass and optimization for colorization
            optimizer.zero_grad()
            loss_colorization.backward()
            optimizer.step()
            total_loss_colorization += loss_colorization.item()
        except StopIteration:
            pass

        # Inpainting task
        try:
            images, _ = next(inpainting_iter)
            images = images.to(device)

            # Apply masking with specified mask_size
            masked_images, masks = mask_image(images, mask_size=mask_size)
            masked_images, masks = masked_images.to(device), masks.to(device)

            # Forward pass for inpainting
            reconstructed_images = combined_model.forward_inpainting(masked_images)
            loss_inpainting = inpainting_criterion(reconstructed_images, images)

            # Backward pass and optimization for inpainting
            optimizer.zero_grad()
            loss_inpainting.backward()
            optimizer.step()
            total_loss_inpainting += loss_inpainting.item()
        except StopIteration:
            pass

    # Calculate average losses for the epoch
    avg_loss_colorization = total_loss_colorization / len(colorization_loader)
    avg_loss_inpainting = total_loss_inpainting / len(inpainting_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}] Colorization Loss: {avg_loss_colorization:.4f}')
    print(f'Epoch [{epoch+1}/{num_epochs}] Inpainting Loss: {avg_loss_inpainting:.4f}')

    # Early stopping check based on combined average loss
    combined_avg_loss = (avg_loss_colorization + avg_loss_inpainting) / 2
    early_stop(combined_avg_loss)
    if early_stop.early_stop:
        print("Early Stopping Triggered")
        break

# Save the final combined model
torch.save(combined_model.state_dict(), '/users/soh62/SSL/image_colorization_simultaneous/models2/combined_inpaint_colorization_final.pth')


Epoch [1/100] Colorization Loss: 0.0031
Epoch [1/100] Inpainting Loss: 0.5606
Epoch [2/100] Colorization Loss: 0.0025
Epoch [2/100] Inpainting Loss: 0.5426
Epoch [3/100] Colorization Loss: 0.0024
Epoch [3/100] Inpainting Loss: 0.5370
Epoch [4/100] Colorization Loss: 0.0024
Epoch [4/100] Inpainting Loss: 0.5329
Epoch [5/100] Colorization Loss: 0.0024
Epoch [5/100] Inpainting Loss: 0.5307
Epoch [6/100] Colorization Loss: 0.0023
Epoch [6/100] Inpainting Loss: 0.5291
Epoch [7/100] Colorization Loss: 0.0023
Epoch [7/100] Inpainting Loss: 0.5277
Epoch [8/100] Colorization Loss: 0.0023
Epoch [8/100] Inpainting Loss: 0.5267
Epoch [9/100] Colorization Loss: 0.0023
Epoch [9/100] Inpainting Loss: 0.5258
Epoch [10/100] Colorization Loss: 0.0023
Epoch [10/100] Inpainting Loss: 0.5250
Epoch [11/100] Colorization Loss: 0.0023
Epoch [11/100] Inpainting Loss: 0.5242
Epoch [12/100] Colorization Loss: 0.0022
Epoch [12/100] Inpainting Loss: 0.5236
Epoch [13/100] Colorization Loss: 0.0022
Epoch [13/100] In

In [7]:
# Visualization function to verify colorization predictions
def visualize_colorization(L_channel, predicted_AB, ground_truth_AB):
    batch_size = L_channel.shape[0]

    for i in range(batch_size):
        # Convert model's output (predicted_AB) to RGB for each sample in the batch
        colorized_image = lab_to_rgb(L_channel[i], predicted_AB[i])

        # Convert the ground truth to RGB for each sample
        ground_truth_rgb = lab_to_rgb(L_channel[i], ground_truth_AB[i])

        # Display the colorized image and the ground truth (visualization code)
        plt.subplot(1, 2, 1)
        plt.imshow(colorized_image)
        plt.title('Predicted Colorization')

        plt.subplot(1, 2, 2)
        plt.imshow(ground_truth_rgb)
        plt.title('Ground Truth')

        plt.show()
        print(f"Predicted AB min: {predicted_AB[i].min():.2f}, max: {predicted_AB[i].max():.2f}")
        print(f"Ground Truth AB min: {ground_truth_AB[i].min():.2f}, max: {ground_truth_AB[i].max():.2f}")


# Downstream Epoch 50

In [8]:
import torch.nn as nn

class ClassificationNet(nn.Module):
    def __init__(self, backbone, num_classes):
        super(ClassificationNet, self).__init__()
        self.backbone = backbone
        self.classifier = nn.Linear(512, num_classes)  # Adjust input size based on backbone output

    def forward(self, x):
        features = self.backbone(x)
        pooled_features = nn.AdaptiveAvgPool2d((1, 1))(features)
        pooled_features = pooled_features.view(pooled_features.size(0), -1)
        output = self.classifier(pooled_features)
        return output

# Initialize the classification model with the backbone from your trained combined model
classification_model = ClassificationNet(combined_model.encoder, num_classes=10).to(device)


In [9]:
from torchvision.datasets import STL10
from torch.utils.data import DataLoader
from torchvision import transforms

classification_transform = transforms.Compose([
    transforms.Resize((96, 96)),
    transforms.ToTensor(),  # RGB for classification
])

# Load STL10 train and test datasets
stl10_train = STL10(root='../data', split='train', download=True, transform=classification_transform)
stl10_test = STL10(root='../data', split='test', download=True, transform=classification_transform)

# DataLoaders for training and testing
train_loader = DataLoader(stl10_train, batch_size=64, shuffle=True)
test_loader = DataLoader(stl10_test, batch_size=64, shuffle=False)


Files already downloaded and verified
Files already downloaded and verified


In [10]:
import torch.optim as optim

# Define loss and optimizer for classification
criterion = nn.CrossEntropyLoss()  # Suitable for multi-class classification
optimizer = optim.Adam(classification_model.parameters(), lr=1e-3)

# Training loop for downstream classification task
num_epochs = 50  # Adjust as needed
for epoch in range(num_epochs):
    classification_model.train()  # Set model to training mode
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = classification_model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}")

    # Optionally save checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        torch.save(classification_model.state_dict(), f'models2/classification_model_weights2_epoch_{epoch+1}.pth')

# Save the final classification model
torch.save(classification_model.state_dict(), 'models2/classification_model_weights_final2.pth')


Epoch [1/50], Loss: 1.8470
Epoch [2/50], Loss: 1.3864
Epoch [3/50], Loss: 1.1259
Epoch [4/50], Loss: 0.8858
Epoch [5/50], Loss: 0.6475
Epoch [6/50], Loss: 0.4143
Epoch [7/50], Loss: 0.2225
Epoch [8/50], Loss: 0.0978
Epoch [9/50], Loss: 0.0557
Epoch [10/50], Loss: 0.0200
Epoch [11/50], Loss: 0.0192
Epoch [12/50], Loss: 0.0393
Epoch [13/50], Loss: 0.0212
Epoch [14/50], Loss: 0.0151
Epoch [15/50], Loss: 0.0359
Epoch [16/50], Loss: 0.0529
Epoch [17/50], Loss: 0.0124
Epoch [18/50], Loss: 0.0047
Epoch [19/50], Loss: 0.0034
Epoch [20/50], Loss: 0.0028
Epoch [21/50], Loss: 0.0019
Epoch [22/50], Loss: 0.0015
Epoch [23/50], Loss: 0.0050
Epoch [24/50], Loss: 0.0550
Epoch [25/50], Loss: 0.0282
Epoch [26/50], Loss: 0.0388
Epoch [27/50], Loss: 0.0109
Epoch [28/50], Loss: 0.0083
Epoch [29/50], Loss: 0.0315
Epoch [30/50], Loss: 0.0674
Epoch [31/50], Loss: 0.0411
Epoch [32/50], Loss: 0.0065
Epoch [33/50], Loss: 0.0034
Epoch [34/50], Loss: 0.0139
Epoch [35/50], Loss: 0.0274
Epoch [36/50], Loss: 0.0238
E

In [11]:
# Evaluation on test data
classification_model.eval()  # Set model to evaluation mode
correct = 0
total = 0
top_3_correct = 0
top_5_correct = 0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = classification_model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Top-3 and Top-5 accuracy calculations
        _, predicted_3 = torch.topk(outputs, k=3, dim=1)
        correct_3 = predicted_3.eq(labels.view(-1, 1).expand_as(predicted_3))
        top_3_correct += correct_3.any(dim=1).sum().item()

        _, predicted_5 = torch.topk(outputs, k=5, dim=1)
        correct_5 = predicted_5.eq(labels.view(-1, 1).expand_as(predicted_5))
        top_5_correct += correct_5.any(dim=1).sum().item()

# Calculate and print accuracies
accuracy = 100 * correct / total
top_3_accuracy = 100 * top_3_correct / total
top_5_accuracy = 100 * top_5_correct / total
print(f'Top-1 Accuracy: {accuracy:.2f}%')
print(f'Top-3 Accuracy: {top_3_accuracy:.2f}%')
print(f'Top-5 Accuracy: {top_5_accuracy:.2f}%')


Top-1 Accuracy: 62.16%
Top-3 Accuracy: 87.97%
Top-5 Accuracy: 95.24%


# Downstream Epoch 100

In [12]:
import torch.nn as nn

class ClassificationNet(nn.Module):
    def __init__(self, backbone, num_classes):
        super(ClassificationNet, self).__init__()
        self.backbone = backbone
        self.classifier = nn.Linear(512, num_classes)  # Adjust input size based on backbone output

    def forward(self, x):
        features = self.backbone(x)
        pooled_features = nn.AdaptiveAvgPool2d((1, 1))(features)
        pooled_features = pooled_features.view(pooled_features.size(0), -1)
        output = self.classifier(pooled_features)
        return output

# Initialize the classification model with the backbone from your trained combined model
classification_model = ClassificationNet(combined_model.encoder, num_classes=10).to(device)


In [13]:
from torchvision.datasets import STL10
from torch.utils.data import DataLoader
from torchvision import transforms

classification_transform = transforms.Compose([
    transforms.Resize((96, 96)),
    transforms.ToTensor(),  # RGB for classification
])

# Load STL10 train and test datasets
stl10_train = STL10(root='../data', split='train', download=True, transform=classification_transform)
stl10_test = STL10(root='../data', split='test', download=True, transform=classification_transform)

# DataLoaders for training and testing
train_loader = DataLoader(stl10_train, batch_size=64, shuffle=True)
test_loader = DataLoader(stl10_test, batch_size=64, shuffle=False)


Files already downloaded and verified
Files already downloaded and verified


In [14]:
import torch.optim as optim

# Define loss and optimizer for classification
criterion = nn.CrossEntropyLoss()  # Suitable for multi-class classification
optimizer = optim.Adam(classification_model.parameters(), lr=1e-3)

# Training loop for downstream classification task
num_epochs = 50  # Adjust as needed
for epoch in range(num_epochs):
    classification_model.train()  # Set model to training mode
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = classification_model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}")

    # Optionally save checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        torch.save(classification_model.state_dict(), f'models2/classification_model_weights_final2_150_{epoch+1}.pth')

# Save the final classification model
torch.save(classification_model.state_dict(), 'models2/classification_model_weights_final2_150.pth')


Epoch [1/50], Loss: 0.4989
Epoch [2/50], Loss: 0.0306
Epoch [3/50], Loss: 0.0147
Epoch [4/50], Loss: 0.0061
Epoch [5/50], Loss: 0.0039
Epoch [6/50], Loss: 0.0035
Epoch [7/50], Loss: 0.0027
Epoch [8/50], Loss: 0.0065
Epoch [9/50], Loss: 0.0235
Epoch [10/50], Loss: 0.0358
Epoch [11/50], Loss: 0.0058
Epoch [12/50], Loss: 0.0036
Epoch [13/50], Loss: 0.0060
Epoch [14/50], Loss: 0.0219
Epoch [15/50], Loss: 0.0320
Epoch [16/50], Loss: 0.0039
Epoch [17/50], Loss: 0.0017
Epoch [18/50], Loss: 0.0009
Epoch [19/50], Loss: 0.0025
Epoch [20/50], Loss: 0.0152
Epoch [21/50], Loss: 0.0091
Epoch [22/50], Loss: 0.0405
Epoch [23/50], Loss: 0.0196
Epoch [24/50], Loss: 0.0040
Epoch [25/50], Loss: 0.0028
Epoch [26/50], Loss: 0.0296
Epoch [27/50], Loss: 0.0120
Epoch [28/50], Loss: 0.0289
Epoch [29/50], Loss: 0.0081
Epoch [30/50], Loss: 0.0047
Epoch [31/50], Loss: 0.0017
Epoch [32/50], Loss: 0.0014
Epoch [33/50], Loss: 0.0008
Epoch [34/50], Loss: 0.0009
Epoch [35/50], Loss: 0.0008
Epoch [36/50], Loss: 0.0003
E

In [15]:
# Evaluation on test data
classification_model.eval()  # Set model to evaluation mode
correct = 0
total = 0
top_3_correct = 0
top_5_correct = 0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = classification_model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Top-3 and Top-5 accuracy calculations
        _, predicted_3 = torch.topk(outputs, k=3, dim=1)
        correct_3 = predicted_3.eq(labels.view(-1, 1).expand_as(predicted_3))
        top_3_correct += correct_3.any(dim=1).sum().item()

        _, predicted_5 = torch.topk(outputs, k=5, dim=1)
        correct_5 = predicted_5.eq(labels.view(-1, 1).expand_as(predicted_5))
        top_5_correct += correct_5.any(dim=1).sum().item()

# Calculate and print accuracies
accuracy = 100 * correct / total
top_3_accuracy = 100 * top_3_correct / total
top_5_accuracy = 100 * top_5_correct / total
print(f'Top-1 Accuracy: {accuracy:.2f}%')
print(f'Top-3 Accuracy: {top_3_accuracy:.2f}%')
print(f'Top-5 Accuracy: {top_5_accuracy:.2f}%')


Top-1 Accuracy: 61.00%
Top-3 Accuracy: 86.15%
Top-5 Accuracy: 93.91%
