In [None]:
%matplotlib inline
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from skimage import io, color

import torch
from torch.utils.data import random_split
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

from Color import Colorizer, train_model, generate_colored_image
from util import calculate_metrics, collect_image_paths,CustomDataset

In [None]:
# Training hyperparameters.
batch_size = 128
num_epochs = 30
learning_rate = 1e-3
use_gpu = True  # use CUDA when available; falls back to CPU otherwise

# Seed the global RNGs so torch-based randomness (weight initialisation,
# DataLoader shuffling, unseeded random_split) repeats under Restart & Run All.
seed = 42
torch.manual_seed(seed)  # seeds the default generator on all devices
np.random.seed(seed)  # in case helper modules draw from NumPy's global RNG

In [None]:
# Per-image transform: convert to a tensor, then force float32.
# ToTensor only rescales/converts uint8 (or PIL) input; other NumPy dtypes such as
# float64 are passed through unchanged, which is presumably why the explicit cast
# exists. NOTE(review): the original comment says this is for the L channel —
# confirm against CustomDataset in util.
data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.to(torch.float32)),
])

data_dir = "./data/stanford"  # Path to the 'stanford' directory
image_filenames = collect_image_paths(data_dir)
# Fail fast on a wrong/empty data directory instead of silently producing empty
# splits, an empty training loop and empty plots further down.
assert len(image_filenames) > 0, f"No images found under {data_dir!r}"

In [None]:
# Build the full dataset from the collected image paths.
dataset = CustomDataset(image_filenames, transform=data_transform)

# Optional cap for quick experiments: set to an int (e.g. 10000) to keep only the
# first `max_samples` images, or leave as None to use every image.
max_samples = None
if max_samples is not None:
    dataset = Subset(dataset, range(min(max_samples, len(dataset))))

# Split fractions; when given as fractions, random_split requires them to sum to 1.
train_size = 0.7  # 70% for training
validation_size = 0.15  # 15% for validation
test_size = 0.15  # 15% for testing

# Use a dedicated, fixed-seed generator so train/validation/test membership is
# identical on every run and independent of any other random draws in the notebook.
# Without it the split comes from the global RNG, so a model saved in one session
# could later be "tested" on images it was trained on.
split_seed = 42
split_generator = torch.Generator().manual_seed(split_seed)
train_dataset, validation_dataset, test_dataset = random_split(
    dataset,
    [train_size, validation_size, test_size],
    generator=split_generator,
)

In [None]:
# Only the training split is shuffled. Evaluation should not depend on sample order,
# and a fixed order keeps validation/test batches (including the final, smaller
# batch) identical across runs, so reported metrics are repeatable.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
validation_dataloader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# First CUDA device when requested and available, otherwise CPU.
device = torch.device("cuda:0" if use_gpu and torch.cuda.is_available() else "cpu")

In [None]:
# Instantiate the colorizer and move it to the chosen device before the optimizer
# is created, so the optimizer references the on-device parameters.
cnet = Colorizer().to(device)

# Adam; torch.optim.Adam applies weight_decay as an L2 term added to the gradient.
optimizer = torch.optim.Adam(
    params=cnet.parameters(),
    lr=learning_rate,
    weight_decay=1e-4,
)

In [None]:
# Train the model. train_model returns the model plus PSNR and SSIM histories that
# the plotting cells below draw against epochs (presumably validation-set values,
# since the validation loader is passed in — confirm in Color.train_model).
training_result = train_model(
    cnet,
    num_epochs,
    train_dataloader,
    validation_dataloader,
    device,
    optimizer,
)
cnet, psnr, ssim = training_result

In [None]:
# Evaluate on the held-out test split. The mode train_model leaves the model in is
# not visible here, so switch to inference explicitly: eval() gives any
# Dropout/BatchNorm layers inference behaviour and no_grad() skips autograd graph
# construction. Both are no-ops if calculate_metrics already does this internally.
cnet.eval()
with torch.no_grad():
    psnr_test, ssim_test = calculate_metrics(cnet, test_dataloader, device)
print('Test - PSNR: ',psnr_test, ' - SSIM: ',ssim_test )

In [None]:
# Checkpoint filename, derived from the hyperparameters so runs with different
# settings do not overwrite each other. Replace with a custom string if desired.
name = f"colorizer_bs{batch_size}_ep{num_epochs}_lr{learning_rate}.pt"

model_dir = Path("models")
model_dir.mkdir(parents=True, exist_ok=True)  # torch.save does not create missing directories
torch.save(cnet.state_dict(), model_dir / name)

In [None]:
# SSIM history returned by train_model, one point per recorded epoch.
fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(ssim)
ax.set(title='SSIM per epoch', xlabel='Epochs', ylabel='SSIM')
ax.grid(True)
plt.show()

In [None]:
# PSNR history returned by train_model, one point per recorded epoch.
fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(psnr)
ax.set(title='PSNR per epoch', xlabel='Epochs', ylabel='PSNR')
ax.grid(True)
plt.show()