In [None]:
import os
import torch
import torchvision
import torchvision.models as models
import torch.nn.functional as F
from torchvision.io import read_image
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import tqdm
import zipfile
import math
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from sklearn.metrics import accuracy_score, confusion_matrix
import seaborn as sns

In [None]:
# Root directory of the landscape dataset; LandscapeDataset reads the paired
# `color/` and `gray/` sub-folders (same file names) beneath it.
dataset_path = "/".join(["/kaggle/input/landscape", "landscape Images"])

In [None]:
# Global configuration shared by the cells below.
MANUAL_SEED = 42        # seeds the train/test split and torch's global RNG
BATCH_SIZE = 32
WIDTH = 150             # target image width in pixels
HEIGHT = 150            # target image height in pixels
SHUFFLE = True          # shuffle flag passed to the DataLoader(s)
TRAINING_SIZE = 0.8     # fraction of the images used for training

# Seed torch's global RNG (CPU and CUDA) so weight initialisation, random
# augmentation and DataLoader shuffling are reproducible on Restart & Run All.
# The trailing `;` suppresses the Generator repr in the cell output.
torch.manual_seed(MANUAL_SEED);

In [None]:
import numpy as np

class LandscapeDataset(Dataset):
    """Paired (color, grayscale) landscape images read from disk.

    ``root_dir`` must contain ``color/`` and ``gray/`` sub-directories holding
    images with identical file names; the listing of ``color/`` defines the
    dataset.

    Args:
        root_dir: Directory containing the ``color`` and ``gray`` folders.
        transform: Optional callable applied to both images of a pair. Random
            draws are replayed for the second image so both images of a pair
            receive the same random parameters (e.g. the same flip).
    """

    def __init__(self, root_dir, transform=None):
        self.dataroot = root_dir
        # Sorted so the index -> file mapping (and therefore the seeded
        # random_split) does not depend on the filesystem's listing order.
        self.images = sorted(os.listdir(f'{self.dataroot}/color'))
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def _paired_transform(self, color_img, gray_img):
        """Apply ``self.transform`` to both images with identical random draws."""
        # Assumes the transform samples from torch's global CPU RNG (as
        # torchvision's random transforms do). Replaying the same RNG state
        # keeps the pair geometrically aligned instead of flipping one image
        # but not the other.
        rng_state = torch.get_rng_state()
        color_img = self.transform(color_img)
        torch.set_rng_state(rng_state)
        gray_img = self.transform(gray_img)
        return color_img, gray_img

    def __getitem__(self, idx):
        """Return the ``(color_img, gray_img)`` pair at position ``idx``."""
        img_path = self.images[idx]

        color_img = Image.open(f'{self.dataroot}/color/{img_path}').convert('RGB')
        gray_img = Image.open(f'{self.dataroot}/gray/{img_path}').convert('L')

        if self.transform:
            color_img, gray_img = self._paired_transform(color_img, gray_img)

        return color_img, gray_img


In [None]:
# Training pipeline: resize plus random augmentation.
# torchvision's Resize expects (height, width).
transform = transforms.Compose([
    transforms.Resize((HEIGHT, WIDTH)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor()
])

# Evaluation pipeline: deterministic, no augmentation, so test metrics are
# measured on the images as they are.
eval_transform = transforms.Compose([
    transforms.Resize((HEIGHT, WIDTH)),
    transforms.ToTensor()
])

dataset = LandscapeDataset(root_dir=dataset_path, transform=transform)
eval_dataset = LandscapeDataset(root_dir=dataset_path, transform=eval_transform)
# Both views must index the same files for the shared split indices to be valid.
assert eval_dataset.images == dataset.images

# Seeded split into train and test indices.
train_len = int(TRAINING_SIZE * len(dataset))
train_set, test_split = random_split(
    dataset,
    [train_len, len(dataset) - train_len],
    generator=torch.Generator().manual_seed(MANUAL_SEED)
)
# The test subset uses the same indices but reads through the non-augmented view.
test_set = torch.utils.data.Subset(eval_dataset, test_split.indices)

In [None]:
# Shuffle the training data each epoch. Keep the test order fixed so
# evaluation runs are deterministic and comparable across runs.
trainloader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=SHUFFLE)
testloader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
def show_images(color, gray, n=5):
    """Plot grayscale inputs next to their color targets, one pair per row.

    Args:
        color: Batch of color images, shaped (B, 3, H, W).
        gray: Batch of grayscale images, shaped (B, 1, H, W), values in [0, 1].
        n: Maximum number of rows to draw; clipped to the batch size.
    """
    rows = min(n, len(color), len(gray))
    if rows == 0:
        return
    # squeeze=False keeps `axs` 2-D even when only one row is drawn.
    fig, axs = plt.subplots(rows, 2, figsize=(15, 3 * rows), squeeze=False)
    axs[0, 0].set_title('Grayscale')
    axs[0, 1].set_title('Color')
    for i in range(rows):
        # Drop the singleton channel; fix the [0, 1] scale so imshow does not
        # auto-stretch the contrast of each image.
        axs[i, 0].imshow(gray[i].squeeze(0), cmap='gray', vmin=0, vmax=1)
        axs[i, 0].axis('off')
        axs[i, 1].imshow(color[i].permute(1, 2, 0))
        axs[i, 1].axis('off')
    plt.show()

In [None]:
# Preview one training batch to check that the gray/color pairs line up.
(color, gray) = next(iter(trainloader))
show_images(gray=gray, color=color)

In [None]:
# Training hyper-parameters and the compute device used by every model below.
EPOCHS = 3
LEARNING_RATE = 1e-3
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

DEVICE

In [None]:
class EarlyStopping:
    """Stop training once a monitored accuracy stops improving.

    Each call records a new accuracy. The model's ``state_dict`` is saved to
    ``path`` whenever the accuracy reaches a new best. A call counts as
    non-improving when ``acc < best + delta``. ``early_stop`` becomes True in
    either case:

    * after ``patience`` consecutive non-improving calls; or
    * as soon as an accuracy ``>= max_accuracy`` is seen.

    Args:
        monitor: Name of the monitored metric (stored only; not used in logic).
        patience: Number of consecutive non-improving calls tolerated.
        verbose: If True, report progress through ``trace_func``.
        delta: Margin over the best score required to count as improvement.
        path: File the best checkpoint is written to.
        max_accuracy: Accuracy at or above which training stops immediately.
        trace_func: Callable used for verbose messages.
    """

    def __init__(self, monitor, patience=3, verbose=False, delta=0, path='checkpoint.pt', max_accuracy=0.95, trace_func=print):
        self.monitor = monitor
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.Inf was removed in NumPy 2.0; np.inf works on every version.
        self.acc_max = -np.inf
        self.delta = delta
        self.path = path
        self.max_accuracy = max_accuracy
        self.trace_func = trace_func

    def __call__(self, acc, model):
        """Record ``acc`` for ``model`` and update ``early_stop``."""
        score = acc

        if self.best_score is None:
            # First observation: always checkpoint.
            self.best_score = score
            self.save_checkpoint(acc, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(acc, model)
            self.counter = 0

        # Independent of improvement: a good-enough accuracy ends training.
        if acc >= self.max_accuracy:
            self.early_stop = True
            if self.verbose:
                self.trace_func(f'Maximum accuracy of {self.max_accuracy} reached. Stopping training.')

    def save_checkpoint(self, acc, model):
        """Write ``model.state_dict()`` to ``self.path`` and remember ``acc``."""
        if self.verbose:
            self.trace_func(f'Accuracy increased ({self.acc_max} --> {acc}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.acc_max = acc


In [None]:
class ColorAutoEncoder(nn.Module):
    """Grayscale-to-RGB autoencoder built on a pretrained DenseNet-121 encoder.

    Encoder: a 1 -> 3 channel convolution, so the single-channel input fits
    the ImageNet-pretrained DenseNet-121 feature layers, followed by all of
    those feature layers (1024 output channels).

    Decoder: five stride-2 transposed convolutions, each doubling H and W,
    that map back to 3 channels. A final sigmoid keeps outputs in [0, 1].
    """

    def __init__(self):
        super().__init__()
        backbone = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)

        # Input adapter (1 -> 3 channels) followed by every DenseNet feature layer.
        stem = nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1, bias=False)
        self.encoder = nn.Sequential(stem, *backbone.features.children())

        # Decoder stages: ConvTranspose2d + ReLU, except the last stage, whose
        # activation is a sigmoid instead of a ReLU.
        widths = (1024, 512, 256, 128, 64, 3)
        stages = []
        for c_in, c_out in zip(widths, widths[1:]):
            stages.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2,
                                             padding=1, output_padding=1))
            stages.append(nn.ReLU())
        stages[-1] = nn.Sigmoid()
        self.decoder = nn.Sequential(*stages)

    def forward(self, x):
        latent = self.encoder(x)
        return self.decoder(latent)

model = ColorAutoEncoder().to(DEVICE)

In [None]:
# Optionally warm-start the base model from a previously trained checkpoint.
model_path = '/kaggle/input/checkpoint/pytorch/v20/1/'
checkpoint_file = os.path.join(model_path, "checkpoint.pt")

# Test for the exact file: it tolerates a missing directory and avoids loading
# a non-existent checkpoint.pt from a directory that only holds other files.
if os.path.isfile(checkpoint_file):
    print("Loading model...")
    model.load_state_dict(torch.load(checkpoint_file, map_location=torch.device(DEVICE)))
    print(f"Model successfully loaded on {DEVICE}.")
else:
    print(f"No checkpoint at {checkpoint_file}; starting from ImageNet-initialised weights.")

# Same filename that EnhancedColorAutoEncoder loads (no trailing whitespace).
torch.save(model.state_dict(), 'color_autoencoder.pth')

In [None]:
class PerceptualLoss(nn.Module):
    """MSE between frozen VGG-16 feature maps of a prediction and a target.

    Args:
        feature_layer: Number of leading ``vgg16.features`` layers kept as the
            feature extractor.
        normalize: If True (default), apply the ImageNet mean/std
            normalisation VGG-16's pretrained weights expect before extracting
            features. Inputs are assumed to be RGB tensors in [0, 1]
            (N, 3, H, W).
    """

    def __init__(self, feature_layer=9, normalize=True):
        super(PerceptualLoss, self).__init__()
        vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
        self.feature_extractor = nn.Sequential(*list(vgg)[:feature_layer]).eval()
        # Frozen: the loss network is never trained.
        for param in self.feature_extractor.parameters():
            param.requires_grad = False
        self.normalize = normalize
        # Buffers (not parameters) so they follow .to(device) but are not trained.
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def _prepare(self, x):
        """Map [0, 1] RGB input to the distribution the pretrained VGG expects."""
        if self.normalize:
            return (x - self.mean) / self.std
        return x

    def forward(self, pred, target):
        """Return the feature-space MSE between ``pred`` and ``target``."""
        pred_features = self.feature_extractor(self._prepare(pred))
        target_features = self.feature_extractor(self._prepare(target))
        return F.mse_loss(pred_features, target_features)

In [None]:
# Report the model size, then build the loss, optimiser and LR schedule.
trainable = filter(lambda p: p.requires_grad, model.parameters())
total_params = sum(param.numel() for param in trainable)
print(f"Total Number of trainable parameters of this model are: {total_params:,}")

criterion = PerceptualLoss().to(DEVICE)
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
# Multiply the learning rate by 0.1 every 10 scheduler steps (one step per epoch).
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

In [None]:
# Per-epoch metric history of the base autoencoder (one entry per epoch).
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []

In [None]:
# Train the base ColorAutoEncoder, checkpointing on the best *validation* accuracy.
# "Accuracy" is the fraction of pixels whose dominant channel (argmax over
# RGB) in the prediction matches the dominant channel of the target.
model_name = 'v21_checkpoint.pt'
early_stopper = EarlyStopping(monitor='accuracy', max_accuracy=0.8, patience=60, verbose=True, path=model_name)
MAX_EPOCHS = 0


def _predict_resized(net, gray_batch):
    """Run ``net`` on ``gray_batch`` and resize the output to (HEIGHT, WIDTH)."""
    out = net(gray_batch)
    # F.interpolate's `size` is (height, width).
    return F.interpolate(out, size=(HEIGHT, WIDTH), mode='bilinear', align_corners=True)


def _dominant_channel_hits(pred, target):
    """Return (#pixels whose argmax channel matches, #pixels compared)."""
    predicted_classes = torch.argmax(pred, dim=1)
    true_classes = torch.argmax(target, dim=1)
    return (predicted_classes == true_classes).sum().item(), true_classes.numel()


for epoch in range(EPOCHS):
    # ---- training ----
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for color_img, gray_img in tqdm(trainloader):
        color_img = color_img.to(DEVICE)
        gray_img = gray_img.float().to(DEVICE)

        predictions = _predict_resized(model, gray_img)

        optimizer.zero_grad()
        loss = criterion(predictions, color_img)  # (prediction, target) order
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        hits, count = _dominant_channel_hits(predictions, color_img)
        correct += hits
        total += count

    train_losses.append(running_loss / len(trainloader))
    train_accuracies.append(correct / total)

    # ---- validation ----
    model.eval()
    val_running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for color_img, gray_img in tqdm(testloader):
            color_img = color_img.to(DEVICE)
            gray_img = gray_img.float().to(DEVICE)

            predictions = _predict_resized(model, gray_img)
            val_running_loss += criterion(predictions, color_img).item()

            hits, count = _dominant_channel_hits(predictions, color_img)
            correct += hits
            total += count

    val_losses.append(val_running_loss / len(testloader))
    val_accuracies.append(correct / total)

    print(
        f"Epoch: {epoch + 1} / {EPOCHS}, "
        f"Train Acc: {train_accuracies[-1]}, "
        f"Val Acc: {val_accuracies[-1]}, "
        f"Train Loss: {train_losses[-1]}, "
        f"Val Loss: {val_losses[-1]}"
    )

    MAX_EPOCHS = epoch + 1
    # Stop and checkpoint on held-out accuracy. Monitoring training accuracy
    # would reward overfitting and ignore the validation data entirely.
    early_stopper(val_accuracies[-1], model)

    if early_stopper.early_stop:
        print("Early stopping")
        break

    scheduler.step()

# Restore the best checkpoint written by the early stopper. The file only
# exists if at least one epoch ran.
if os.path.isfile(model_name):
    model.load_state_dict(torch.load(model_name, map_location=torch.device(DEVICE)))
print('Training Finished!')
torch.save(model.state_dict(), 'color_autoencoder.pth')

In [None]:
class EnhancedColorAutoEncoder(nn.Module):
    """U-Net style grayscale-to-RGB network: 4 downsampling convs, 4 transposed
    convs with skip connections, and a sigmoid output in [0, 1].

    The weights in ``color_autoencoder.pth`` are loaded into a frozen
    ``ColorAutoEncoder`` stored as ``self.color_autoencoder``.

    NOTE(review): ``forward`` never calls ``self.color_autoencoder`` or
    ``self.up0``, so the pretrained model currently has no effect on the
    output. Confirm the intended wiring.

    The skip connections only line up for certain input sizes. With the
    150x150 single-channel inputs used in this notebook the feature maps are
    74 -> 37 -> 19 -> 10 and the output is 150x150.
    """

    def __init__(self):
        super(EnhancedColorAutoEncoder, self).__init__()
        # Pre-trained ColorAutoEncoder, frozen to retain its learned features.
        self.color_autoencoder = ColorAutoEncoder()
        self.color_autoencoder.load_state_dict(torch.load('color_autoencoder.pth', map_location=DEVICE))
        for param in self.color_autoencoder.parameters():
            param.requires_grad = False

        # Encoder (stride 2 each).
        self.down1 = nn.Conv2d(1, 64, 3, stride=2)
        self.down2 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
        self.down3 = nn.Conv2d(128, 256, 3, stride=2, padding=1)
        self.down4 = nn.Conv2d(256, 512, 3, stride=2, padding=1)

        # No trailing comma: `self.up0 = layer,` would store a tuple, and the
        # layer would not be registered as a submodule (no .to(), no state_dict).
        self.up0 = nn.ConvTranspose2d(3, 512, kernel_size=3, stride=1, padding=1)
        # Decoder; input channels of up2..up4 include the concatenated skip tensor.
        self.up1 = nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1)
        self.up2 = nn.ConvTranspose2d(512, 128, 3, stride=2, padding=1)
        self.up3 = nn.ConvTranspose2d(256, 64, 3, stride=2, padding=1, output_padding=1)
        self.up4 = nn.ConvTranspose2d(128, 3, 3, stride=2, output_padding=1)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        d1 = self.relu(self.down1(x))
        d2 = self.relu(self.down2(d1))
        d3 = self.relu(self.down3(d2))
        d4 = self.relu(self.down4(d3))

        # Each decoder stage concatenates the matching encoder output (skip).
        u1 = self.relu(self.up1(d4))
        u2 = self.relu(self.up2(torch.cat((u1, d3), dim=1)))
        u3 = self.relu(self.up3(torch.cat((u2, d2), dim=1)))
        u4 = self.sigmoid(self.up4(torch.cat((u3, d1), dim=1)))

        return u4

# Initialize the new model
enhanced_model = EnhancedColorAutoEncoder().to(DEVICE)

In [None]:
# Define a loss function and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(enhanced_model.parameters(), lr=0.001)

# Training loop
for epoch in range(1):
    enhanced_model.train()
    running_loss = 0.0
    
    for idx, (color_img, gray_img) in tqdm(enumerate(trainloader), total=len(trainloader)):
        color_img = color_img.to(DEVICE)
        gray_img = gray_img.float().to(DEVICE)
        
        optimizer.zero_grad()
        
        outputs = enhanced_model(gray_img)
        outputs = F.interpolate(
                outputs, 
                size=(WIDTH, HEIGHT), 
                mode='bilinear', 
                align_corners=True
            )
        
        loss = criterion(outputs, color_img)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    print(f"Epoch [{epoch+1}/{EPOCHS}], Loss: {running_loss / len(trainloader)}")

# Save the enhanced model
torch.save(enhanced_model.state_dict(), 'enhanced_color_autoencoder.pth')


In [None]:
def calculate_psnr(img1, img2, data_range=1.0):
    """Peak signal-to-noise ratio (dB) between a reference and a test image.

    Args:
        img1: Reference (ground-truth) array.
        img2: Array to compare, same shape as ``img1``.
        data_range: Spread between the minimum and maximum *possible* pixel
            values. Defaults to 1.0, matching the [0, 1] tensors produced by
            ``ToTensor`` and the sigmoid outputs in this notebook. Pass
            ``None`` to fall back to the observed range of ``img2``. That
            range depends on the prediction and is 0 for a constant image.

    Returns:
        The PSNR value computed by ``skimage.metrics.peak_signal_noise_ratio``.
    """
    if data_range is None:
        data_range = img2.max() - img2.min()
    return peak_signal_noise_ratio(img1, img2, data_range=data_range)

In [None]:
# Running totals and per-pixel class lists filled by the evaluation loop.
total_loss, total_psnr, total_ssim = 0.0, 0.0, 0.0
all_true_classes, all_predicted_classes = [], []

In [None]:
# Evaluate the enhanced model on the held-out set. Accumulators are reset here
# so re-running this cell does not double-count.
total_loss = 0.0
total_psnr = 0.0
all_true_classes = []
all_predicted_classes = []

enhanced_model.eval()  # put every submodule in inference mode
with torch.no_grad():
    for color_img, gray_img in tqdm(testloader):
        color_img = color_img.to(DEVICE)
        gray_img = gray_img.float().to(DEVICE)

        prediction = enhanced_model(gray_img)
        # F.interpolate's `size` is (height, width).
        prediction = F.interpolate(
            prediction,
            size=(HEIGHT, WIDTH),
            mode='bilinear',
            align_corners=True
        )

        loss = criterion(prediction, color_img)
        total_loss += loss.item()

        psnr = calculate_psnr(color_img.cpu().numpy(), prediction.cpu().numpy())
        total_psnr += psnr

        # Dominant-channel (argmax over RGB) class of every pixel.
        predicted_classes = torch.argmax(prediction, dim=1)
        true_classes = torch.argmax(color_img, dim=1)
        all_true_classes.extend(true_classes.cpu().numpy().flatten())
        all_predicted_classes.extend(predicted_classes.cpu().numpy().flatten())


In [None]:
# Average the per-batch totals over the number of test batches.
n_test_batches = len(testloader)
avg_psnr = total_psnr / n_test_batches

print(f"Total Testing loss is: {total_loss / n_test_batches}")
print(f"Average PSNR: {avg_psnr}")

In [None]:
# Compute confusion matrix
conf_matrix = confusion_matrix(all_true_classes, all_predicted_classes)

# Compute accuracy
accuracy = accuracy_score(all_true_classes, all_predicted_classes)

print(f"Confusion Matrix:\n{conf_matrix}")
print(f"Accuracy: {accuracy}")

# Plot confusion matrix
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, cmap="Blues", fmt="d", xticklabels=True, yticklabels=True)
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.title('Confusion Matrix')
plt.show()

In [None]:
def show_predictions(color, gray, pred, n=5):
    """Plot grayscale input, color target and model prediction side by side.

    Args:
        color: Batch of target RGB images, (B, 3, H, W).
        gray: Batch of grayscale inputs, (B, 1, H, W), values in [0, 1].
        pred: Batch of predicted RGB images, (B, 3, H, W).
        n: Maximum number of rows to draw; clipped to the batch size.
    """
    rows = min(n, len(color), len(gray), len(pred))
    if rows == 0:
        return
    # squeeze=False keeps `axs` 2-D even when only one row is drawn.
    fig, axs = plt.subplots(rows, 3, figsize=(15, 3 * rows), squeeze=False)
    for ax, title in zip(axs[0], ('Grayscale', 'Color', 'Predicted')):
        ax.set_title(title)
    for i in range(rows):
        # Fixed [0, 1] scale so imshow does not auto-stretch gray contrast.
        axs[i, 0].imshow(gray[i].squeeze(0), cmap='gray', vmin=0, vmax=1)
        axs[i, 1].imshow(color[i].permute(1, 2, 0))
        axs[i, 2].imshow(pred[i].permute(1, 2, 0))
        for ax in axs[i]:
            ax.axis('off')
    plt.show()

show_predictions(color_img.detach().cpu(), gray_img.detach().cpu(), prediction.detach().cpu())

In [None]:
# Learning curves of the base autoencoder. The x-axis is derived from the
# recorded history so it always matches the list lengths. MAX_EPOCHS only
# reflects the most recent training run, while the lists accumulate across re-runs.
epochs = range(1, len(train_accuracies) + 1)

fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(15, 5))

ax_acc.plot(epochs, train_accuracies, label='Train Set')
ax_acc.plot(epochs, val_accuracies, label='Val Set')
ax_acc.set(title='Model Accuracy', xlabel='Epochs', ylabel='Accuracy')
ax_acc.legend()

ax_loss.plot(epochs, train_losses, label='Train Set')
ax_loss.plot(epochs, val_losses, label='Val Set')
ax_loss.set(title='Model Loss', xlabel='Epochs', ylabel='Loss')
ax_loss.legend()

plt.show()