In [1]:
from src.Metrics import batch_metrics
from src.Models import NetworkColor, load_model
from src.Datasets import CocoDataset
from torch.utils.data import DataLoader
from torchvision.transforms import v2
from skimage import color
from tqdm import tqdm
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device used : {device}")

device used : cuda


In [2]:
model = NetworkColor().to(device)
load_model(model, "models/NetworkColor.pt")

NetworkColor(
  (conv1): Sequential(
    (0): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv2): Sequential(
    (0): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv3): Sequential(
    (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv4): Sequential(
    (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv5): Sequential(
    (0): Conv2d(256, 256, kernel_size=(4, 4), stride=(1, 1), padding=(3, 3), dilation=(2, 2))
    (1): BatchNorm2d(256, eps=1e-05, 

In [3]:
root_dir = "./ressources"

transform = v2.Compose([
	color.rgb2lab,
	v2.ToImage(),
	v2.ToDtype(torch.float32),
	v2.Resize((128,128), antialias=True),
	v2.Normalize(((-126.,-126.,0.)), ((256.,256.,100.)))
])

UnNormalize = v2.Compose([
	v2.Normalize((0.,0.,0.), (1/256.,1/256.,1/100.)),
	v2.Normalize((126.,126.,0.), (1.,1.,1.))
])

train_dataset = CocoDataset(root=root_dir, split="train", transform=transform)
test_dataset = CocoDataset(root=root_dir, split="test", transform=transform)

batch_size = 256

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

In [4]:
from src.Metrics import compute_psnr

img1, img2 = train_dataset[1]
img1, img2 = img1.numpy(), img2.numpy()
print(img1.shape, img2.shape)
print(compute_psnr(img1, img2))

(1, 128, 128) (3, 128, 128)
58.44602971004501


In [5]:
def metrics_loop(dataloader, model, metric = "PSNR"):
	size = len(dataloader.dataset)
	model = model.to(device)
	model.eval()
	total = 0
	with torch.no_grad():
		for batch in tqdm(dataloader, desc="metrics"):
			total += batch_metrics(model, batch, device, metric)
	return total / size

In [6]:
# Compute metrics
PSNR_train = metrics_loop(train_loader, model, metric="PSNR")
PSNR_test = metrics_loop(test_loader, model, metric="PSNR")
SSIM_train = metrics_loop(train_loader, model, metric="SSIM")
SSIM_test = metrics_loop(test_loader, model, metric="SSIM")
MSE_train = metrics_loop(train_loader, model, metric="MSE")
MSE_test = metrics_loop(test_loader, model, metric="MSE")

print(f"PSNR train : {PSNR_train}")
print(f"PSNR test : {PSNR_test}")
print(f"SSIM train : {SSIM_train}")
print(f"SSIM test : {SSIM_test}")
print(f"MSE train : {MSE_train}")
print(f"MSE test : {MSE_test}")


metrics:   0%|          | 0/159 [00:11<?, ?it/s]


AttributeError: 'list' object has no attribute 'shape'