In [1]:
%cd ..

/home/azureuser/notebooks/sketch-to-artwork


In [3]:
# Isntall necessary packages
# !pip install torchmetrics[image]

Collecting torchmetrics
  Downloading torchmetrics-0.6.0-py3-none-any.whl (329 kB)
[K     |████████████████████████████████| 329 kB 26.3 MB/s eta 0:00:01
Installing collected packages: torchmetrics
Successfully installed torchmetrics-0.6.0


In [7]:
import os

import numpy as np

from torchmetrics import FID, LPIPS
import torchvision
import torch
import torch.utils.data
from tqdm import tqdm

from taming.data.wikiart import WikiartEdgesTrain, WikiartEdgesTest
from taming.data.base import ImagePaths

class CustomDataset(ImagePaths):
    def __init__(self, root):
        paths = os.listdir(root)
        paths = [os.path.join(root, fname) for fname in paths if fname[len(fname)-3:] == "png"]
        super().__init__(paths, size=256)

def convert_to_uint8(images_float):
    return (torch.clamp(images_float * 0.5 + 0.5, 0., 1.) * 255.).to(dtype=torch.uint8)

dataset_trn = WikiartEdgesTrain(256, "datasets/wikiart_train.txt")
dataset_gen = CustomDataset(root='datasets/wikiart_generated_256')
dataset_val = WikiartEdgesTest(256, "datasets/wikiart_val.txt")

gen_loader = torch.utils.data.DataLoader(dataset_gen, batch_size=4, num_workers=16)
trn_loader = torch.utils.data.DataLoader(dataset_trn, batch_size=4, num_workers=16)
val_loader = torch.utils.data.DataLoader(dataset_val, batch_size=4, num_workers=16)

In [8]:
# FID

fid_module = FID(feature=2048).to('cuda')

for batch in tqdm(trn_loader):
    imgs = batch['image'].permute(0, 3, 1, 2)
    imgs = convert_to_uint8(imgs).cuda()
    fid_module.update(imgs, real=True)

for batch in tqdm(gen_loader):
    imgs = batch['image'].permute(0, 3, 1, 2)
    imgs = convert_to_uint8(imgs).cuda()
    fid_module.update(imgs, real=False)

fid = fid_module.compute().item()
print(f'FID: {fid:.6f}')

100%|██████████| 587/587 [00:34<00:00, 17.05it/s]
100%|██████████| 35/35 [00:01<00:00, 24.27it/s]


FID: 194.244492


In [4]:
gen_loader = torch.utils.data.DataLoader(dataset_gen, batch_size=1, num_workers=16)
val_loader = torch.utils.data.DataLoader(dataset_val, batch_size=1, num_workers=16)

# LPIPS
lpips = LPIPS(net_type='vgg')
lpips_val = []
for batch1, batch2 in tqdm(zip(gen_loader, val_loader)):
    imgs = batch1['image'].permute(0, 3, 1, 2)
    imgs2 = batch2['image'].permute(0, 3, 1, 2)
    temp = lpips(imgs, imgs2).detach().cpu().numpy()
    lpips_val = np.append(lpips_val, temp)
    
print(f'LPIPS: {np.mean(lpips_val):.6f}')

138it [01:24,  1.63it/s]


LPIPS: 0.802710
