In [1]:
import cv2
import numpy as np
import colour
from visualize_model import visualize
import model_predict as mdl
import torch
import glob
from dataset import make_lab_dataloaders, make_oklab_dataloaders
from color_spaces import lab_to_rgb
import tqdm.notebook as tqdm

In [2]:
torch.manual_seed(42)
np.random.seed(42)

In [3]:
VAL_SIZE = 0.2
TEST_SIZE = 0.2
BATCH_SIZE = 16

In [4]:
from sklearn.model_selection import train_test_split

def get_test():
    paths = glob.glob("dataset/*.jpg") # Your path for your dataset
    paths_subset = np.random.choice(paths, len(paths), replace=False) # choosing 1000 images randomly

    train_paths, test_paths = train_test_split(paths_subset,test_size=TEST_SIZE,shuffle=True)
    train_paths, val_paths = train_test_split(train_paths,test_size=VAL_SIZE / (1- TEST_SIZE),shuffle=True)

    print(len(train_paths), len(val_paths), len(test_paths))
    
    test_dl = make_oklab_dataloaders(BATCH_SIZE,paths=test_paths, split='test')
    return test_dl

test_dl = get_test()

5904 1968 1968


In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
def delta_e(pair):
    predicted, target = pair
    predicted_lab = cv2.cvtColor(predicted, cv2.COLOR_RGB2Lab)
    target_lab = cv2.cvtColor(target, cv2.COLOR_RGB2Lab)
    delta_E = colour.delta_E(predicted_lab, target_lab)
    return np.mean(delta_E)

In [7]:
def psnr(pair):
    predicted, target = pair
    predicted_lab = cv2.cvtColor(predicted, cv2.COLOR_RGB2Lab)
    target_lab = cv2.cvtColor(target, cv2.COLOR_RGB2Lab)
    return cv2.PSNR(predicted_lab, target_lab)

In [8]:
def visual_eval(model, test_dl):
    data = next(iter(test_dl))
    visualize(model, data, save=True)

In [9]:
def get_model(path):
    model = mdl.MainModel(net_G=mdl.build_res_unet(n_input=1, n_output=2, size=256))
    model.load_state_dict(torch.load(path, map_location=device))
    return model
model = get_model('models/oklab/res_net_unet_gan.pt')

model initialized with norm initialization


In [10]:
def calculate_metrics(path):
    test_dl = get_test()
    model = get_model(path)

    true_batches, pred_batches = [], []
    for data in tqdm.tqdm(test_dl):
        preds = mdl.colorize(model, data['L'])
        
        true_batches.append(data)
        pred_batches.append(preds)

    psnr_results, delta_e_results = [], []
    for i, true_batch in enumerate(true_batches):
        pred_batch = pred_batches[i]
        true_batch = lab_to_rgb(true_batch["L"], true_batch["ab"], space="Lab")

        pairs = list(zip(pred_batch, true_batch))

        psnr_results.extend(list(map(psnr, pairs.copy())))
        delta_e_results.extend(list(map(delta_e, pairs.copy())))

    return np.mean(psnr_results), np.mean(delta_e_results)

In [11]:
psnr_results_oklab, delta_e_results_oklab = calculate_metrics('models/oklab/res_net_unet_gan.pt')
psnr_results_oklab, delta_e_results_oklab

5904 1968 1968
model initialized with norm initialization


  0%|          | 0/123 [00:00<?, ?it/s]

(38.698163104637246, 4.5642122413832045)

In [12]:
psnr_results_lab, delta_e_results_lab = calculate_metrics('models/lab/res_net_unet_gan.pt')
psnr_results_lab, delta_e_results_lab

5904 1968 1968
model initialized with norm initialization


  0%|          | 0/123 [00:00<?, ?it/s]

(32.429419575599375, 7.8789199569279749)