In [3]:
import cv2
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import STL10
from torchvision.models import resnet18
import torch.nn as nn
from colorization import Colorization, RGB2LabTransform, STL10ColorizationDataset, EarlyStopping
from inpainting import Encoder, Decoder, InpaintingModel, Discriminator, mask_image

In [3]:
import torch
from torchvision import datasets, transforms

# Data augmentation and normalization for training
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Load STL10 dataset
stl10_pretrain = STL10ColorizationDataset(root='../data', split='train+unlabeled', download=True, transform=transform)

# DataLoader to feed batches for training
pretrain_loader = DataLoader(stl10_pretrain, batch_size=64, shuffle=True)

for L_channel, AB_channels in pretrain_loader:
    print(f"L_channel shape: {L_channel.shape}, AB_channels shape: {AB_channels.shape}")
    break  # Just checking one batch


Files already downloaded and verified
L_channel shape: torch.Size([64, 1, 96, 96]), AB_channels shape: torch.Size([64, 2, 96, 96])


In [10]:
if torch.cuda.is_available():
    device = torch.device("cuda")  # Use NVIDIA GPU
    print('cuda')
elif torch.backends.mps.is_available():
    device = torch.device("mps")   # Use Apple's Metal (for M1/M2 Macs)
    print('mps')
else:
    device = torch.device("cpu") 
    print('cpu')

cuda


In [5]:
from torchvision.models import resnet18
import torch.nn as nn
import torch

PATH = '../inpainting/models/v3/inpainting_model_gen_weights_epoch_100.pth'
checkpoint= torch.load(PATH, map_location=torch.device('cuda'))

generator = InpaintingModel().to(device)  

# Load the state dictionary
# checkpoint = torch.load(path_to_weights)

# Load the weights into the generator
generator.load_state_dict(checkpoint)

backbone = generator.encoder


  checkpoint= torch.load(PATH, map_location=torch.device('cuda'))


using new encoder


In [6]:
backbone

Encoder(
  (encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

In [7]:
colorization_model = Colorization(backbone)
colorization_model = colorization_model.to(device)

In [9]:
# Training loop
num_epochs = 100
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(colorization_model.parameters(), lr=1e-3)

early_stop = EarlyStopping(paitence=15,min_delta=0.000001)


for epoch in range(num_epochs):
    total_loss = 0.0  # Initialize total loss for this epoch
    num_batches = 0   # Keep track of the number of batches

    # Training loop for the current epoch
    for L_channel, AB_channels in pretrain_loader:
        # Move data to the same device as the model
        L_channel = L_channel.to(device)
        AB_channels = AB_channels.to(device)
        
        L_channel_rgb = L_channel.repeat(1, 3, 1, 1)  # Shape: [batch_size, 3, 96, 96]

        # Forward pass
        predicted_AB = colorization_model(L_channel_rgb)

        # Compute loss
        loss = criterion(predicted_AB, AB_channels)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate loss and count batches
        total_loss += loss.item()
        num_batches += 1

    # Calculate average loss for the epoch
    average_loss = total_loss / num_batches
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {average_loss}')

    early_stop(average_loss)
    # print(early_stop.counter)
    if early_stop.early_stop:
        print("Early Stopping Triggered. No improves in Loss for the last 10 epochs")
        break
    
    if (epoch + 1) % 10 == 0:
        torch.save(colorization_model.state_dict(), f'models/inpaint_colorization_model_weights_epoch_{epoch+1}.pth')

    if average_loss < early_stop.best_loss:
        best_model_weights = colorization_model.state_dict()

torch.save(colorization_model.state_dict(), 'models/inpaint_colorization_model_weights_final.pth')
torch.save(best_model_weights, 'models/inpaint_colorization_best_model_weights_final.pth')


Epoch [1/100], Loss: 0.0037259581041369657
Epoch [2/100], Loss: 0.002531736677056797
Epoch [3/100], Loss: 0.0024126278343488444
Epoch [4/100], Loss: 0.002355252119393578
Epoch [5/100], Loss: 0.0023185373907500173
Epoch [6/100], Loss: 0.0022954427377371335
Epoch [7/100], Loss: 0.0022696789941766926
Epoch [8/100], Loss: 0.0022274922458358484
Epoch [9/100], Loss: 0.002180603119992469
Epoch [10/100], Loss: 0.0021143020319398824
Epoch [11/100], Loss: 0.0020320808487569167
Epoch [12/100], Loss: 0.0019357744639819984
Epoch [13/100], Loss: 0.0018537565733781135
Epoch [14/100], Loss: 0.0017719569898083905
Epoch [15/100], Loss: 0.0017110153739434246
Epoch [16/100], Loss: 0.0016657576189245313
Epoch [17/100], Loss: 0.001617744766058835
Epoch [18/100], Loss: 0.001570979261794681
Epoch [19/100], Loss: 0.0015394842611242139
Epoch [20/100], Loss: 0.0015125334373375062
Epoch [21/100], Loss: 0.001482569837944422
Epoch [22/100], Loss: 0.0014625433207410663
Epoch [23/100], Loss: 0.0014393441648648724
Epo

In [1]:
import numpy as np
from skimage import color
import matplotlib.pyplot as plt

def visualize_colorization(L_channel, predicted_AB, ground_truth_AB):
    batch_size = L_channel.shape[0]

    for i in range(batch_size):
        # Convert model's output (predicted_AB) to RGB for each sample in the batch
        colorized_image = lab_to_rgb(L_channel[i], predicted_AB[i])

        # Convert the ground truth to RGB for each sample
        ground_truth_rgb = lab_to_rgb(L_channel[i], ground_truth_AB[i])

        # Display the colorized image and the ground truth (visualization code)
        plt.subplot(1, 2, 1)
        plt.imshow(colorized_image)
        plt.title('Predicted Colorization')

        plt.subplot(1, 2, 2)
        plt.imshow(ground_truth_rgb)
        plt.title('Ground Truth')

        plt.show()
        print(f"Predicted AB min: {predicted_AB[i].min():.2f}, max: {predicted_AB[i].max():.2f}")
        print(f"Ground Truth AB min: {ground_truth_AB[i].min():.2f}, max: {ground_truth_AB[i].max():.2f}")

# Convert LAB to RGB
def lab_to_rgb(L_channel, AB_channels):
    # Ensure L_channel has shape [96, 96] and scale it appropriately
    L_channel = L_channel.squeeze().cpu().numpy() * 255
    L_channel = L_channel 

    # Ensure AB_channels has shape [2, 96, 96] and transpose it to [96, 96, 2]
    AB_channels = AB_channels.squeeze().detach().cpu().numpy().transpose(1, 2, 0)
    AB_channels = (AB_channels * 255) - 128

    # Concatenate L and AB channels to form LAB image
    lab_image = np.concatenate((L_channel[:, :, np.newaxis], AB_channels), axis=-1)

    # Convert LAB to RGB using a library like skimage
    rgb_image = color.lab2rgb(lab_image)
    
    return rgb_image


In [2]:
visualize_colorization(L_channel, predicted_AB, AB_channels)


NameError: name 'L_channel' is not defined

In [8]:
from torchvision.models import resnet18
import torch.nn as nn
import torch
from collections import OrderedDict

PATH = 'models/inpaint_colorization_model_weights_epoch_100.pth'

checkpoint = torch.load(PATH, map_location=torch.device('cuda'))

new_state_dict = OrderedDict()
for k, v in checkpoint.items():
    new_key = k.replace("backbone.encoder.","backbone." )  # Modify based on your structure
    new_state_dict[new_key] = v



backbone = resnet18(weights=None)
backbone = nn.Sequential(*list(backbone.children())[:-2])

colorization_model = Colorization(backbone)
# colorization_model = colorization_model.to(device)
colorization_model.load_state_dict(new_state_dict)



  checkpoint = torch.load(PATH, map_location=torch.device('cuda'))


<All keys matched successfully>

In [11]:
class ClassificationNet(nn.Module):
    def __init__(self, backbone, num_classes):
        super(ClassificationNet, self).__init__()
        self.backbone = backbone
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x):
        features = self.backbone(x)
        pooled_features = nn.AdaptiveAvgPool2d((1, 1))(features)
        pooled_features = pooled_features.view(pooled_features.size(0), -1)
        output = self.classifier(pooled_features)
        return output

classification_model = ClassificationNet(colorization_model.backbone, num_classes=10).to(device)

In [12]:
classification_transform = transforms.Compose([
    # transforms.RandomResizedCrop(96),
    # transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),  # RGB for classification
])

In [13]:
stl10_train = STL10(root='../data', split='train', download=True, transform=classification_transform)
stl10_test = STL10(root='../data', split='test', download=True, transform=classification_transform)

# Fine-tuning: Load training data for classification task
train_loader = DataLoader(stl10_train, batch_size=64, shuffle=True)

# Testing: Load test data for final evaluation
test_loader = DataLoader(stl10_test, batch_size=64, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [14]:
criterion = nn.CrossEntropyLoss()  # Suitable for multi-class classification
optimizer = torch.optim.Adam(classification_model.parameters(), lr=1e-3)

# Training Loop
num_epochs = 150
for epoch in range(num_epochs):
    classification_model.train()  # Set model to training mode
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = classification_model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}")

    if (epoch + 1) % 10 == 0:
        torch.save(classification_model.state_dict(), f'models/downstream/classification_model_weights_epoch_{epoch+1}.pth')



PATH = 'models/downstream/classification_model_weights_final.pth'
torch.save(classification_model.state_dict(), PATH)

Epoch [1/150], Loss: 1.8439
Epoch [2/150], Loss: 0.9399
Epoch [3/150], Loss: 0.4614
Epoch [4/150], Loss: 0.1248
Epoch [5/150], Loss: 0.0304
Epoch [6/150], Loss: 0.0102
Epoch [7/150], Loss: 0.0116
Epoch [8/150], Loss: 0.0204
Epoch [9/150], Loss: 0.0215
Epoch [10/150], Loss: 0.0067
Epoch [11/150], Loss: 0.0164
Epoch [12/150], Loss: 0.0400
Epoch [13/150], Loss: 0.0211
Epoch [14/150], Loss: 0.0203
Epoch [15/150], Loss: 0.0151
Epoch [16/150], Loss: 0.0440
Epoch [17/150], Loss: 0.0414
Epoch [18/150], Loss: 0.0090
Epoch [19/150], Loss: 0.0233
Epoch [20/150], Loss: 0.0267
Epoch [21/150], Loss: 0.0709
Epoch [22/150], Loss: 0.0388
Epoch [23/150], Loss: 0.0106
Epoch [24/150], Loss: 0.0183
Epoch [25/150], Loss: 0.0407
Epoch [26/150], Loss: 0.0158
Epoch [27/150], Loss: 0.0067
Epoch [28/150], Loss: 0.0250
Epoch [29/150], Loss: 0.0273
Epoch [30/150], Loss: 0.0080
Epoch [31/150], Loss: 0.0065
Epoch [32/150], Loss: 0.0177
Epoch [33/150], Loss: 0.0142
Epoch [34/150], Loss: 0.0250
Epoch [35/150], Loss: 0

In [15]:
# Evaluation
classification_model.eval()  # Set model to evaluation mode
correct = 0
top_5_correct = 0
top_3_correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = classification_model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        _, predicted_3 = torch.topk(outputs.data, k=3, dim=1)
        correct_3 = predicted_3.eq(labels.unsqueeze(1).expand_as(predicted_3))
        top_3_correct += correct_3.any(dim=1).sum().item()

        _, predicted_5 = torch.topk(outputs.data, k=5, dim=1)
        correct_5 = predicted_5.eq(labels.unsqueeze(1).expand_as(predicted_5))
        top_5_correct += correct_5.any(dim=1).sum().item()



accuracy = 100 * correct / total
top_5 = 100 * top_5_correct / total
top_3 = 100 * top_3_correct / total
print(f'Top-1 Accuracy of the model on the test set: {accuracy:.2f}%')
print(f'Top-5 Accuracy of the model on the test set: {top_5:.2f}%')
print(f'Top-3 Accuracy of the model on the test set: {top_3:.2f}%')

Top-1 Accuracy of the model on the test set: 64.50%
Top-5 Accuracy of the model on the test set: 94.97%
Top-3 Accuracy of the model on the test set: 88.01%
