In [1]:
import cv2
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import STL10
from torchvision.models import resnet18
import torch.nn as nn
import numpy as np
from skimage import color
import matplotlib.pyplot as plt
from collections import OrderedDict
from torch.optim import Adam

# Import custom modules from provided files
from colorization import Colorization, RGB2LabTransform, STL10ColorizationDataset, EarlyStopping
from inpainting import Encoder, Decoder, InpaintingModel, Discriminator, mask_image


In [2]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define a combined model class with both colorization and inpainting models
class CombinedModel(nn.Module):
    def __init__(self, backbone):
        super(CombinedModel, self).__init__()
        self.colorization_model = Colorization(backbone)
        self.inpainting_model = InpaintingModel()

    def forward_colorization(self, L_channel_rgb):
        return self.colorization_model(L_channel_rgb)

    def forward_inpainting(self, masked_images):
        return self.inpainting_model(masked_images)

# Initialize the backbone for both models
backbone = resnet18(pretrained=True)
backbone = nn.Sequential(*list(backbone.children())[:-2])  # Removing the last layers for feature extraction
combined_model = CombinedModel(backbone).to(device)



using new encoder


In [3]:
# Helper function to convert LAB to RGB
def lab_to_rgb(L_channel, AB_channels):
    L_channel = L_channel.squeeze().cpu().numpy() * 255
    AB_channels = AB_channels.squeeze().detach().cpu().numpy().transpose(1, 2, 0)
    AB_channels = (AB_channels * 255) - 128

    lab_image = np.concatenate((L_channel[:, :, np.newaxis], AB_channels), axis=-1)
    rgb_image = color.lab2rgb(lab_image)
    
    return rgb_image

In [4]:
# Load STL10 dataset and apply necessary transformations
transform = transforms.Compose([transforms.ToTensor()])

# Separate datasets for colorization and inpainting
stl10_colorization = STL10ColorizationDataset(root='../data', split='train+unlabeled', download=True, transform=transform)
stl10_inpainting = STL10(root='../data', split='train+unlabeled', download=True, transform=transform)

# DataLoaders for colorization and inpainting tasks
colorization_loader = DataLoader(stl10_colorization, batch_size=64, shuffle=True)
inpainting_loader = DataLoader(stl10_inpainting, batch_size=64, shuffle=True)

# Optimizer and loss functions
optimizer = Adam(combined_model.parameters(), lr=1e-3)
colorization_criterion = nn.MSELoss()
inpainting_criterion = nn.BCELoss()  # Assuming binary cross-entropy loss for inpainting


Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Early stopping configuration
class EarlyStopping:
    def __init__(self, patience=15, min_delta=1e-6):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, loss):
        if self.best_loss is None:
            self.best_loss = loss
        elif loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = loss
            self.counter = 0

early_stop = EarlyStopping(patience=15, min_delta=1e-6)


Both Tasks per Epoch: Each epoch includes both colorization and inpainting tasks back-to-back. The model first processes colorization updates and then inpainting updates within the same epoch, ensuring that both tasks influence each other immediately within each epoch. This setup allows the model to build upon the shared learning of both tasks within a single epoch.

In [6]:

'''
Both Tasks per Epoch: Each epoch includes both colorization and inpainting tasks back-to-back. 
The model first processes colorization updates and then inpainting updates within the same epoch, ensuring that both tasks influence each other immediately within each epoch. 
This setup allows the model to build upon the shared learning of both tasks within a single epoch.
'''
# # Set mask size for inpainting
# mask_size = 8  # Adjust as needed

# # Training loop with combined pretext tasks per epoch
# num_epochs = 100
# for epoch in range(num_epochs):
#     combined_model.train()
#     total_loss_colorization = 0.0
#     total_loss_inpainting = 0.0

#     # Colorization task within each epoch
#     for L_channel, AB_channels in colorization_loader:
#         L_channel, AB_channels = L_channel.to(device), AB_channels.to(device)
#         L_channel_rgb = L_channel.repeat(1, 3, 1, 1)  # Convert grayscale to RGB input format

#         # Forward pass for colorization
#         predicted_AB = combined_model.forward_colorization(L_channel_rgb)
#         loss_colorization = colorization_criterion(predicted_AB, AB_channels)

#         # Backward pass and optimization for colorization
#         optimizer.zero_grad()
#         loss_colorization.backward()
#         optimizer.step()
#         total_loss_colorization += loss_colorization.item()

#     avg_loss_colorization = total_loss_colorization / len(colorization_loader)
#     print(f'Epoch [{epoch+1}/{num_epochs}] Colorization Loss: {avg_loss_colorization:.4f}')
    
#     # Inpainting task within the same epoch
#     for images, _ in inpainting_loader:
#         images = images.to(device)
#         # Apply masking with specified mask_size
#         masked_images, masks = mask_image(images, mask_size=mask_size)
#         masked_images, masks = masked_images.to(device), masks.to(device)

#         # Forward pass for inpainting
#         reconstructed_images = combined_model.forward_inpainting(masked_images)
#         loss_inpainting = inpainting_criterion(reconstructed_images, images)

#         # Backward pass and optimization for inpainting
#         optimizer.zero_grad()
#         loss_inpainting.backward()
#         optimizer.step()
#         total_loss_inpainting += loss_inpainting.item()

#     avg_loss_inpainting = total_loss_inpainting / len(inpainting_loader)
#     print(f'Epoch [{epoch+1}/{num_epochs}] Inpainting Loss: {avg_loss_inpainting:.4f}')

#     # Early stopping check based on combined average loss
#     combined_avg_loss = (avg_loss_colorization + avg_loss_inpainting) / 2
#     early_stop(combined_avg_loss)
#     if early_stop.early_stop:
#         print("Early Stopping Triggered")
#         break

# # Save the final combined model
# torch.save(combined_model.state_dict(), 'models/combined_inpaint_colorization_final.pth')



'''
Help by ChatGPT
Explanation of the Code
Alternating Tasks per Batch: Each epoch now alternates between colorization and inpainting tasks on a batch-by-batch basis. 
This way, the model receives an update from a colorization batch, which immediately influences the inpainting batch that follows within the same epoch.

Iterating Over Both Loaders Simultaneously: We use two iterators (colorization_iter and inpainting_iter) and call next() on each loader to retrieve batches for both tasks in each iteration.

Immediate Influence: With this setup, colorization gradients are applied before running the inpainting task in each iteration, allowing immediate influence from colorization to inpainting within the same epoch.

'''

# Set mask size for inpainting
mask_size = 8  # Adjust as needed

# Training loop with combined pretext tasks per epoch
num_epochs = 100
for epoch in range(num_epochs):
    combined_model.train()
    total_loss_colorization = 0.0
    total_loss_inpainting = 0.0

    colorization_iter = iter(colorization_loader)
    inpainting_iter = iter(inpainting_loader)

    # Alternate between colorization and inpainting batches within the epoch
    for i in range(min(len(colorization_loader), len(inpainting_loader))):
        # Colorization task
        try:
            L_channel, AB_channels = next(colorization_iter)
            L_channel, AB_channels = L_channel.to(device), AB_channels.to(device)
            L_channel_rgb = L_channel.repeat(1, 3, 1, 1)  # Convert grayscale to RGB input format

            # Forward pass for colorization
            predicted_AB = combined_model.forward_colorization(L_channel_rgb)
            loss_colorization = colorization_criterion(predicted_AB, AB_channels)

            # Backward pass and optimization for colorization
            optimizer.zero_grad()
            loss_colorization.backward()
            optimizer.step()
            total_loss_colorization += loss_colorization.item()
        except StopIteration:
            pass

        # Inpainting task
        try:
            images, _ = next(inpainting_iter)
            images = images.to(device)

            # Apply masking with specified mask_size
            masked_images, masks = mask_image(images, mask_size=mask_size)
            masked_images, masks = masked_images.to(device), masks.to(device)

            # Forward pass for inpainting
            reconstructed_images = combined_model.forward_inpainting(masked_images)
            loss_inpainting = inpainting_criterion(reconstructed_images, images)

            # Backward pass and optimization for inpainting
            optimizer.zero_grad()
            loss_inpainting.backward()
            optimizer.step()
            total_loss_inpainting += loss_inpainting.item()
        except StopIteration:
            pass

    # Calculate average losses for the epoch
    avg_loss_colorization = total_loss_colorization / len(colorization_loader)
    avg_loss_inpainting = total_loss_inpainting / len(inpainting_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}] Colorization Loss: {avg_loss_colorization:.4f}')
    print(f'Epoch [{epoch+1}/{num_epochs}] Inpainting Loss: {avg_loss_inpainting:.4f}')

    # Early stopping check based on combined average loss
    combined_avg_loss = (avg_loss_colorization + avg_loss_inpainting) / 2
    early_stop(combined_avg_loss)
    if early_stop.early_stop:
        print("Early Stopping Triggered")
        break

# Save the final combined model
torch.save(combined_model.state_dict(), 'combined_inpaint_colorization_fina_1.pth')


Epoch [1/100] Colorization Loss: 0.0029
Epoch [1/100] Inpainting Loss: 0.5537
Epoch [2/100] Colorization Loss: 0.0024
Epoch [2/100] Inpainting Loss: 0.5375
Epoch [3/100] Colorization Loss: 0.0023
Epoch [3/100] Inpainting Loss: 0.5330
Epoch [4/100] Colorization Loss: 0.0022
Epoch [4/100] Inpainting Loss: 0.5306
Epoch [5/100] Colorization Loss: 0.0022
Epoch [5/100] Inpainting Loss: 0.5292
Epoch [6/100] Colorization Loss: 0.0021
Epoch [6/100] Inpainting Loss: 0.5279
Epoch [7/100] Colorization Loss: 0.0019
Epoch [7/100] Inpainting Loss: 0.5270
Epoch [8/100] Colorization Loss: 0.0018
Epoch [8/100] Inpainting Loss: 0.5265
Epoch [9/100] Colorization Loss: 0.0017
Epoch [9/100] Inpainting Loss: 0.5257
Epoch [10/100] Colorization Loss: 0.0016
Epoch [10/100] Inpainting Loss: 0.5250
Epoch [11/100] Colorization Loss: 0.0015
Epoch [11/100] Inpainting Loss: 0.5244
Epoch [12/100] Colorization Loss: 0.0015
Epoch [12/100] Inpainting Loss: 0.5240
Epoch [13/100] Colorization Loss: 0.0014
Epoch [13/100] In

In [7]:
# Visualization function to verify colorization predictions
def visualize_colorization(L_channel, predicted_AB, ground_truth_AB):
    batch_size = L_channel.shape[0]

    for i in range(batch_size):
        # Convert model's output (predicted_AB) to RGB for each sample in the batch
        colorized_image = lab_to_rgb(L_channel[i], predicted_AB[i])

        # Convert the ground truth to RGB for each sample
        ground_truth_rgb = lab_to_rgb(L_channel[i], ground_truth_AB[i])

        # Display the colorized image and the ground truth (visualization code)
        plt.subplot(1, 2, 1)
        plt.imshow(colorized_image)
        plt.title('Predicted Colorization')

        plt.subplot(1, 2, 2)
        plt.imshow(ground_truth_rgb)
        plt.title('Ground Truth')

        plt.show()
        print(f"Predicted AB min: {predicted_AB[i].min():.2f}, max: {predicted_AB[i].max():.2f}")
        print(f"Ground Truth AB min: {ground_truth_AB[i].min():.2f}, max: {ground_truth_AB[i].max():.2f}")


# Downstream epoch 50

In [8]:
import torch.nn as nn

class ClassificationNet(nn.Module):
    def __init__(self, backbone, num_classes):
        super(ClassificationNet, self).__init__()
        self.backbone = backbone
        self.classifier = nn.Linear(512, num_classes)  # Adjust input size based on backbone output

    def forward(self, x):
        features = self.backbone(x)
        pooled_features = nn.AdaptiveAvgPool2d((1, 1))(features)
        pooled_features = pooled_features.view(pooled_features.size(0), -1)
        output = self.classifier(pooled_features)
        return output

# Initialize the classification model with the backbone from your trained combined model
classification_model = ClassificationNet(combined_model.colorization_model.backbone, num_classes=10).to(device)


In [9]:
from torchvision.datasets import STL10
from torch.utils.data import DataLoader
from torchvision import transforms

classification_transform = transforms.Compose([
    transforms.Resize((96, 96)),
    transforms.ToTensor(),  # RGB for classification
])

# Load STL10 train and test datasets
stl10_train = STL10(root='../data', split='train', download=True, transform=classification_transform)
stl10_test = STL10(root='../data', split='test', download=True, transform=classification_transform)

# DataLoaders for training and testing
train_loader = DataLoader(stl10_train, batch_size=64, shuffle=True)
test_loader = DataLoader(stl10_test, batch_size=64, shuffle=False)


Files already downloaded and verified
Files already downloaded and verified


In [10]:
import torch.optim as optim

# Define loss and optimizer for classification
criterion = nn.CrossEntropyLoss()  # Suitable for multi-class classification
optimizer = optim.Adam(classification_model.parameters(), lr=1e-3)

# Training loop for downstream classification task
num_epochs = 50  # Adjust as needed
for epoch in range(num_epochs):
    classification_model.train()  # Set model to training mode
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = classification_model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}")

    # Optionally save checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        torch.save(classification_model.state_dict(), f'models/classification_model_weights_epoch_{epoch+1}.pth')

# Save the final classification model
torch.save(classification_model.state_dict(), 'models/classification_model_weights_final.pth')


Epoch [1/50], Loss: 1.7707
Epoch [2/50], Loss: 0.7277
Epoch [3/50], Loss: 0.2356
Epoch [4/50], Loss: 0.0495
Epoch [5/50], Loss: 0.0140
Epoch [6/50], Loss: 0.0066
Epoch [7/50], Loss: 0.0089
Epoch [8/50], Loss: 0.0116
Epoch [9/50], Loss: 0.0168
Epoch [10/50], Loss: 0.0162
Epoch [11/50], Loss: 0.0077
Epoch [12/50], Loss: 0.0114
Epoch [13/50], Loss: 0.0125
Epoch [14/50], Loss: 0.0243
Epoch [15/50], Loss: 0.0248
Epoch [16/50], Loss: 0.0215
Epoch [17/50], Loss: 0.0239
Epoch [18/50], Loss: 0.0175
Epoch [19/50], Loss: 0.0199
Epoch [20/50], Loss: 0.0045
Epoch [21/50], Loss: 0.0030
Epoch [22/50], Loss: 0.0047
Epoch [23/50], Loss: 0.0126
Epoch [24/50], Loss: 0.0287
Epoch [25/50], Loss: 0.0338
Epoch [26/50], Loss: 0.0288
Epoch [27/50], Loss: 0.0102
Epoch [28/50], Loss: 0.0124
Epoch [29/50], Loss: 0.0053
Epoch [30/50], Loss: 0.0095
Epoch [31/50], Loss: 0.0068
Epoch [32/50], Loss: 0.0067
Epoch [33/50], Loss: 0.0044
Epoch [34/50], Loss: 0.0168
Epoch [35/50], Loss: 0.0118
Epoch [36/50], Loss: 0.0028
E

In [11]:
# Evaluation on test data
classification_model.eval()  # Set model to evaluation mode
correct = 0
total = 0
top_3_correct = 0
top_5_correct = 0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = classification_model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Top-3 and Top-5 accuracy calculations
        _, predicted_3 = torch.topk(outputs, k=3, dim=1)
        correct_3 = predicted_3.eq(labels.view(-1, 1).expand_as(predicted_3))
        top_3_correct += correct_3.any(dim=1).sum().item()

        _, predicted_5 = torch.topk(outputs, k=5, dim=1)
        correct_5 = predicted_5.eq(labels.view(-1, 1).expand_as(predicted_5))
        top_5_correct += correct_5.any(dim=1).sum().item()

# Calculate and print accuracies
accuracy = 100 * correct / total
top_3_accuracy = 100 * top_3_correct / total
top_5_accuracy = 100 * top_5_correct / total
print(f'Top-1 Accuracy: {accuracy:.2f}%')
print(f'Top-3 Accuracy: {top_3_accuracy:.2f}%')
print(f'Top-5 Accuracy: {top_5_accuracy:.2f}%')


Top-1 Accuracy: 71.54%
Top-3 Accuracy: 91.64%
Top-5 Accuracy: 96.90%


# Downstream Epoch 100

In [12]:
import torch.nn as nn

class ClassificationNet(nn.Module):
    def __init__(self, backbone, num_classes):
        super(ClassificationNet, self).__init__()
        self.backbone = backbone
        self.classifier = nn.Linear(512, num_classes)  # Adjust input size based on backbone output

    def forward(self, x):
        features = self.backbone(x)
        pooled_features = nn.AdaptiveAvgPool2d((1, 1))(features)
        pooled_features = pooled_features.view(pooled_features.size(0), -1)
        output = self.classifier(pooled_features)
        return output

# Initialize the classification model with the backbone from your trained combined model
classification_model = ClassificationNet(combined_model.colorization_model.backbone, num_classes=10).to(device)


In [13]:
from torchvision.datasets import STL10
from torch.utils.data import DataLoader
from torchvision import transforms

classification_transform = transforms.Compose([
    transforms.Resize((96, 96)),
    transforms.ToTensor(),  # RGB for classification
])

# Load STL10 train and test datasets
stl10_train = STL10(root='../data', split='train', download=True, transform=classification_transform)
stl10_test = STL10(root='../data', split='test', download=True, transform=classification_transform)

# DataLoaders for training and testing
train_loader = DataLoader(stl10_train, batch_size=64, shuffle=True)
test_loader = DataLoader(stl10_test, batch_size=64, shuffle=False)


Files already downloaded and verified
Files already downloaded and verified


In [14]:
import torch.optim as optim

# Define loss and optimizer for classification
criterion = nn.CrossEntropyLoss()  # Suitable for multi-class classification
optimizer = optim.Adam(classification_model.parameters(), lr=1e-3)

# Training loop for downstream classification task
num_epochs = 100  # Adjust as needed
for epoch in range(num_epochs):
    classification_model.train()  # Set model to training mode
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = classification_model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}")

    # Optionally save checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        torch.save(classification_model.state_dict(), f'models/classification_model_weights_epoch_150_{epoch+1}.pth')

# Save the final classification model
torch.save(classification_model.state_dict(), 'models/classification_model_weights_final_epoch150.pth')


Epoch [1/100], Loss: 0.3744
Epoch [2/100], Loss: 0.0170
Epoch [3/100], Loss: 0.0082
Epoch [4/100], Loss: 0.0080
Epoch [5/100], Loss: 0.0077
Epoch [6/100], Loss: 0.0040
Epoch [7/100], Loss: 0.0036
Epoch [8/100], Loss: 0.0105
Epoch [9/100], Loss: 0.0164
Epoch [10/100], Loss: 0.0043
Epoch [11/100], Loss: 0.0102
Epoch [12/100], Loss: 0.0521
Epoch [13/100], Loss: 0.0160
Epoch [14/100], Loss: 0.0236
Epoch [15/100], Loss: 0.0127
Epoch [16/100], Loss: 0.0029
Epoch [17/100], Loss: 0.0015
Epoch [18/100], Loss: 0.0017
Epoch [19/100], Loss: 0.0009
Epoch [20/100], Loss: 0.0077
Epoch [21/100], Loss: 0.0217
Epoch [22/100], Loss: 0.0128
Epoch [23/100], Loss: 0.0075
Epoch [24/100], Loss: 0.0035
Epoch [25/100], Loss: 0.0015
Epoch [26/100], Loss: 0.0006
Epoch [27/100], Loss: 0.0011
Epoch [28/100], Loss: 0.0153
Epoch [29/100], Loss: 0.0252
Epoch [30/100], Loss: 0.0263
Epoch [31/100], Loss: 0.0060
Epoch [32/100], Loss: 0.0019
Epoch [33/100], Loss: 0.0088
Epoch [34/100], Loss: 0.0016
Epoch [35/100], Loss: 0

In [15]:
# Evaluation on test data
classification_model.eval()  # Set model to evaluation mode
correct = 0
total = 0
top_3_correct = 0
top_5_correct = 0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = classification_model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Top-3 and Top-5 accuracy calculations
        _, predicted_3 = torch.topk(outputs, k=3, dim=1)
        correct_3 = predicted_3.eq(labels.view(-1, 1).expand_as(predicted_3))
        top_3_correct += correct_3.any(dim=1).sum().item()

        _, predicted_5 = torch.topk(outputs, k=5, dim=1)
        correct_5 = predicted_5.eq(labels.view(-1, 1).expand_as(predicted_5))
        top_5_correct += correct_5.any(dim=1).sum().item()

# Calculate and print accuracies
accuracy = 100 * correct / total
top_3_accuracy = 100 * top_3_correct / total
top_5_accuracy = 100 * top_5_correct / total
print(f'Top-1 Accuracy: {accuracy:.2f}%')
print(f'Top-3 Accuracy: {top_3_accuracy:.2f}%')
print(f'Top-5 Accuracy: {top_5_accuracy:.2f}%')


Top-1 Accuracy: 70.92%
Top-3 Accuracy: 90.30%
Top-5 Accuracy: 95.60%
