# Usage
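Put this notebook into the repository root of a clone of https://github.com/clovaai/stargan-v2 so that it can import that repository's evaluation code (`metrics.fid`, `metrics.lpips`, `core.data_loader` and `core.utils`).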

In [1]:
import os

In [2]:
from metrics.fid import calculate_fid_given_paths, InceptionV3, frechet_distance
from metrics.lpips import calculate_lpips_given_images
from core.data_loader import get_eval_loader
from core import utils
from collections import OrderedDict
import torch
from tqdm import tqdm
import numpy as np
import sys 
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image
import shutil

In [None]:
def load_image(filename):
    """Load an image file as a normalized, batched tensor.

    The image is converted to 3-channel RGB before the transform is applied,
    so grayscale, palette and RGBA files do not clash with the 3-channel
    ``Normalize`` statistics used below.

    Args:
        filename: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The transformed image with an extra leading batch dimension of size 1.
        For 8-bit inputs ``ToTensor`` scales to [0, 1] and the 0.5/0.5
        normalization maps that range to [-1, 1].
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # The context manager closes the file handle promptly; convert() returns
    # a fully loaded copy, so the image stays usable after the file is closed.
    with Image.open(filename) as img:
        img = img.convert('RGB')
    return transform(img).unsqueeze(dim=0)

In [3]:
@torch.no_grad()
def _inception_statistics(loader, inception, device):
    """Return the (mean, covariance) of Inception activations over ``loader``."""
    actvs = []
    for x in tqdm(loader, total=len(loader)):
        # Move every batch of activations to the CPU right away so that they
        # do not pile up in GPU memory for large image folders.
        actvs.append(inception(x.to(device)).detach().cpu())
    if not actvs:
        raise ValueError('Cannot compute FID statistics: the data loader is empty.')
    actvs = torch.cat(actvs, dim=0).numpy()
    return np.mean(actvs, axis=0), np.cov(actvs, rowvar=False)


@torch.no_grad()
def calculate_fid_given_paths(paths, img_size=256, batch_size=50, real_loader=None, real_mu=None, real_cov=None):
    """Compute the FID between a real and a generated image folder.

    NOTE: this definition shadows the function of the same name imported
    from ``metrics.fid`` at the top of the notebook. It differs by letting the
    caller pass in (and get back) the statistics of the real images, so they
    are computed only once when several generated folders are compared
    against the same real folder.

    Args:
        paths: Sequence ``[real_dir, fake_dir]``.
        img_size: Image size forwarded to ``get_eval_loader``.
        batch_size: Batch size forwarded to ``get_eval_loader``.
        real_loader: Loader for ``real_dir`` from a previous call, or None.
        real_mu: Cached activation mean of the real images, or None.
        real_cov: Cached activation covariance of the real images, or None.

    Returns:
        Tuple ``(fid_value, real_loader, real_mu, real_cov)``; the last three
        can be passed back into the next call to reuse the real statistics.

    Raises:
        ValueError: If a folder that has to be processed yields no batches.
    """
    print('Calculating FID given paths %s and %s...' % (paths[0], paths[1]))
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    inception = InceptionV3().eval().to(device)

    # Reuse the caller's loader when there is one; otherwise build it so the
    # function always hands back a usable loader.
    if real_loader is None:
        real_loader = get_eval_loader(paths[0], img_size, batch_size)

    # The cache is only valid when BOTH statistics are present. Checking the
    # loader alone would send None into frechet_distance when a loader is
    # supplied without its statistics.
    if real_mu is None or real_cov is None:
        real_mu, real_cov = _inception_statistics(real_loader, inception, device)

    fake_loader = get_eval_loader(paths[1], img_size, batch_size)
    fake_mu, fake_cov = _inception_statistics(fake_loader, inception, device)

    fid_value = frechet_distance(real_mu, real_cov, fake_mu, fake_cov)
    return fid_value, real_loader, real_mu, real_cov

In [4]:
def calculate_fid_for_all_tasks(path_real, path_fakes, modes, task, val_batch_size=50, save_path='./'):
    """Compute the FID of one task for every compared method and save a report.

    For each entry of ``path_fakes`` the generated images are read from the
    sub-folder ``<path_fake>/<task>`` and compared against ``path_real``. The
    statistics of the real images are computed on the first iteration and
    reused for the remaining methods.

    Args:
        path_real: Folder with the real images used as the FID reference.
        path_fakes: Root result folders, one per method.
        modes: Method suffixes aligned by index with ``path_fakes``; used to
            build the report keys ``FID_<task><mode>``.
        task: Name of the task sub-folder inside every ``path_fakes`` entry.
        val_batch_size: Batch size used when extracting activations.
        save_path: Folder that receives ``FID_<task>.json``; created if it
            does not exist.

    Returns:
        OrderedDict mapping ``FID_<task><mode>`` to the FID value.
    """
    print('Calculating FID for all tasks...')
    fid_values = OrderedDict()
    real_loader, real_mu, real_cov = None, None, None
    for ii, path_fake in enumerate(path_fakes):
        name = task + modes[ii]
        print('Calculating FID for %s...' % name)
        fid_value, real_loader, real_mu, real_cov = calculate_fid_given_paths(
            # os.path.join also works when path_fake has no trailing separator
            paths=[path_real, os.path.join(path_fake, task)],
            img_size=256,
            batch_size=val_batch_size,
            real_loader=real_loader, real_mu=real_mu, real_cov=real_cov)
        fid_values['FID_%s' % name] = fid_value
        print('FID for %s is %.3f' % (name, fid_value))

    # report FID values
    if save_path:
        # Writing the JSON file would fail if the result folder is missing.
        os.makedirs(save_path, exist_ok=True)
    filename = os.path.join(save_path, 'FID_%s.json' % (task))
    utils.save_json(fid_values, filename)
    return fid_values

# FID

In [5]:
# Every pair is (task name, folder with real images for that task); the two
# resulting lists are therefore aligned by index.
tasks, path_reals = map(list, zip(
    ('male2female', '../../data/celeba_hq/train/female/'),
    ('female2male', '../../data/celeba_hq/train/male/'),
    ('dog2cat', '../../data/afhq/images512x512/train/cat/'),
    ('cat2dog', '../../data/afhq/images512x512/train/dog/'),
    ('cat2face', '../../data/linkdataset_for_starganv2/face2cat/train/1001_face/'),
    ('face2cat', '../../data/afhq/images512x512/train/cat/'),
    ('dog2bird', '../../data/linkdataset_for_starganv2/birds/train/'),
    ('bird2dog', '../../data/linkdataset_for_starganv2/dogs/train/'),
    ('car2bird', '../../data/linkdataset_for_starganv2/birds/train/'),
    ('bird2car', '../../data/linkdataset_for_starganv2/cars/train/'),
))

# Every pair is (root folder with a method's results, suffix used to label
# that method in the printed and saved metrics).
path_fakes, modes = map(list, zip(
    ('../I2I/comparison/GP-UNIT/', '_GPUNIT'),
    ('../I2I/comparison/munit/', '_MUNIT'),
    ('../I2I/comparison/stargan/', '_StarGAN2'),
    ('../I2I/comparison/cocofunit/', '_COCOFUNIT'),
    ('../I2I/comparison/travelgan/', '_TraVeLGAN'),
))

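### Folder structure of the generated results
Each method folder listed in `path_fakes` holds one sub-folder per task; files are named `<test image index>_<style index>.jpg`: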
- GP-UNIT/bird2car/
    - 0000_0.jpg - 0000_9.jpg    # results of test image #1 on 10 random styles
    - 0001_0.jpg - 0001_9.jpg
    - ...
    - 0199_0.jpg - 0199_9.jpg
- GP-UNIT/car2bird/
    - 0000_0.jpg - 0000_9.jpg   
    - 0001_0.jpg - 0001_9.jpg
    - ...
    - 0199_0.jpg - 0199_9.jpg  
- ...
- GP-UNIT/male2female/
    - 0000_0.jpg - 0000_9.jpg    
    - 0001_0.jpg - 0001_9.jpg
    - ...
    - 0999_0.jpg - 0999_9.jpg  

In [6]:
# Sanity check: for every task and every method, print how many result files
# exist and the name of the first one (in sorted order).
for task in tasks:
    print('='*50)
    print(task)
    for i, path in enumerate(path_fakes):
        # os.path.join also works when the method path has no trailing separator
        files = sorted(os.listdir(os.path.join(path, task)))
        # An empty result folder would otherwise raise IndexError on files[0].
        first_file = files[0] if files else '<no files>'
        print('%s: %04d, %s'%(modes[i], len(files), first_file))

male2female
_GPUNIT: 10000, 0000_0.jpg
_MUNIT: 10000, 0000_0.jpg
_StarGAN2: 10000, 0000_0.jpg
_COCOFUNIT: 10000, 0000_0.jpg
_TraVeLGAN: 1000, 0000_0.jpg
female2male
_GPUNIT: 10000, 0000_0.jpg
_MUNIT: 10000, 0000_0.jpg
_StarGAN2: 10000, 0000_0.jpg
_COCOFUNIT: 10000, 0000_0.jpg
_TraVeLGAN: 1000, 0000_0.jpg
dog2cat
_GPUNIT: 5000, 0000_0.jpg
_MUNIT: 5000, 0000_0.jpg
_StarGAN2: 5000, 0000_0.jpg
_COCOFUNIT: 5000, 0000_0.jpg
_TraVeLGAN: 0500, 0000_0.jpg
cat2dog
_GPUNIT: 5000, 0000_0.jpg
_MUNIT: 5000, 0000_0.jpg
_StarGAN2: 5000, 0000_0.jpg
_COCOFUNIT: 5000, 0000_0.jpg
_TraVeLGAN: 0500, 0000_0.jpg
cat2face
_GPUNIT: 5000, 0000_0.jpg
_MUNIT: 5000, 0000_0.jpg
_StarGAN2: 5000, 0000_0.jpg
_COCOFUNIT: 5000, 0000_0.jpg
_TraVeLGAN: 0500, 0000_0.jpg
face2cat
_GPUNIT: 10000, 0000_0.jpg
_MUNIT: 10000, 0000_0.jpg
_StarGAN2: 10000, 0000_0.jpg
_COCOFUNIT: 10000, 0000_0.jpg
_TraVeLGAN: 1000, 0000_0.jpg
dog2bird
_GPUNIT: 2000, 0000_0.jpg
_MUNIT: 2000, 0000_0.jpg
_StarGAN2: 2000, 0000_0.jpg
_COCOFUNIT: 2000, 00

In [46]:
# Every pair is (task sub-folder name, folder with the real images that serve
# as the FID reference for that task); the two lists are aligned by index.
tasks, path_reals = map(list, zip(
    ('male2female_Waug', '../../data/celeba_hq/train/female/'),
    ('female2male_Waug', '../../data/celeba_hq/train/male/'),
    ('dog2cat_Waug', '../../data/afhq/images512x512/train/cat/'),
    ('cat2dog_Waug', '../../data/afhq/images512x512/train/dog/'),
    ('cat2face_Waug', '../../data/linkdataset_for_starganv2/face2cat/train/1001_face/'),
    ('face2cat_Waug', '../../data/afhq/images512x512/train/cat/'),
    ('dog2bird', '../../data/linkdataset_for_starganv2/birds/train/'),
    ('bird2dog', '../../data/linkdataset_for_starganv2/dogs/train/'),
    ('car2bird', '../../data/linkdataset_for_starganv2/birds/train/'),
    ('bird2car', '../../data/linkdataset_for_starganv2/cars/train/'),
))

# Only one method is evaluated here: (results root folder, label suffix).
path_fakes, modes = map(list, zip(
    ('../I2I/comparison/GP-UNIT/', '_GPUNIT'),
))

In [47]:
# Run the FID evaluation task by task; tasks and path_reals are aligned by
# index, and each call writes its own FID_<task>.json report.
for ii, task in enumerate(tasks):
    calculate_fid_for_all_tasks(
        path_real=path_reals[ii],
        path_fakes=path_fakes,
        modes=modes,
        task=task,
        save_path='../I2I/result/FID/',
    )

Calculating FID for all tasks...
Calculating FID for male2female_Waug_GPUNIT...
Calculating FID given paths ../../data/celeba_hq/train/female/ and ../I2I/comparison/GP-UNIT/male2female_Waug...
Preparing DataLoader for the evaluation phase...
Preparing DataLoader for the evaluation phase...


100%|██████████| 359/359 [02:38<00:00,  2.26it/s]
100%|██████████| 200/200 [00:29<00:00,  6.70it/s]


FID for male2female_Waug_GPUNIT is 12.602
Calculating FID for all tasks...
Calculating FID for female2male_Waug_GPUNIT...
Calculating FID given paths ../../data/celeba_hq/train/male/ and ../I2I/comparison/GP-UNIT/female2male_Waug...
Preparing DataLoader for the evaluation phase...
Preparing DataLoader for the evaluation phase...


100%|██████████| 202/202 [01:28<00:00,  2.28it/s]
100%|██████████| 200/200 [00:29<00:00,  6.78it/s]


FID for female2male_Waug_GPUNIT is 16.664
Calculating FID for all tasks...
Calculating FID for dog2cat_Waug_GPUNIT...
Calculating FID given paths ../../data/afhq/images512x512/train/cat/ and ../I2I/comparison/GP-UNIT/dog2cat_Waug...
Preparing DataLoader for the evaluation phase...
Preparing DataLoader for the evaluation phase...


100%|██████████| 104/104 [00:23<00:00,  4.47it/s]
100%|██████████| 100/100 [00:15<00:00,  6.63it/s]


FID for dog2cat_Waug_GPUNIT is 8.504
Calculating FID for all tasks...
Calculating FID for cat2dog_Waug_GPUNIT...
Calculating FID given paths ../../data/afhq/images512x512/train/dog/ and ../I2I/comparison/GP-UNIT/cat2dog_Waug...
Preparing DataLoader for the evaluation phase...
Preparing DataLoader for the evaluation phase...


100%|██████████| 95/95 [00:21<00:00,  4.48it/s]
100%|██████████| 100/100 [00:15<00:00,  6.52it/s]


FID for cat2dog_Waug_GPUNIT is 22.077
Calculating FID for all tasks...
Calculating FID for cat2face_Waug_GPUNIT...
Calculating FID given paths ../../data/linkdataset_for_starganv2/face2cat/train/1001_face/ and ../I2I/comparison/GP-UNIT/cat2face_Waug...
Preparing DataLoader for the evaluation phase...


  0%|          | 0/580 [00:00<?, ?it/s]

Preparing DataLoader for the evaluation phase...


100%|██████████| 580/580 [01:30<00:00,  6.39it/s]
100%|██████████| 100/100 [00:14<00:00,  6.67it/s]


FID for cat2face_Waug_GPUNIT is 16.878
Calculating FID for all tasks...
Calculating FID for face2cat_Waug_GPUNIT...
Calculating FID given paths ../../data/afhq/images512x512/train/cat/ and ../I2I/comparison/GP-UNIT/face2cat_Waug...
Preparing DataLoader for the evaluation phase...
Preparing DataLoader for the evaluation phase...


100%|██████████| 104/104 [00:23<00:00,  4.38it/s]
100%|██████████| 200/200 [00:30<00:00,  6.50it/s]


FID for face2cat_Waug_GPUNIT is 9.200
Calculating FID for all tasks...
Calculating FID for dog2bird_GPUNIT...
Calculating FID given paths ../../data/linkdataset_for_starganv2/birds/train/ and ../I2I/comparison/GP-UNIT/dog2bird...


  0%|          | 0/48 [00:00<?, ?it/s]

Preparing DataLoader for the evaluation phase...
Preparing DataLoader for the evaluation phase...


100%|██████████| 48/48 [00:06<00:00,  7.31it/s]
100%|██████████| 40/40 [00:06<00:00,  6.32it/s]


FID for dog2bird_GPUNIT is 6.223
Calculating FID for all tasks...
Calculating FID for bird2dog_GPUNIT...
Calculating FID given paths ../../data/linkdataset_for_starganv2/dogs/train/ and ../I2I/comparison/GP-UNIT/bird2dog...


  0%|          | 0/48 [00:00<?, ?it/s]

Preparing DataLoader for the evaluation phase...
Preparing DataLoader for the evaluation phase...


100%|██████████| 48/48 [00:06<00:00,  7.30it/s]
100%|██████████| 40/40 [00:06<00:00,  6.42it/s]


FID for bird2dog_GPUNIT is 16.347
Calculating FID for all tasks...
Calculating FID for car2bird_GPUNIT...
Calculating FID given paths ../../data/linkdataset_for_starganv2/birds/train/ and ../I2I/comparison/GP-UNIT/car2bird...


  0%|          | 0/48 [00:00<?, ?it/s]

Preparing DataLoader for the evaluation phase...
Preparing DataLoader for the evaluation phase...


100%|██████████| 48/48 [00:05<00:00,  8.25it/s]
100%|██████████| 40/40 [00:06<00:00,  6.53it/s]


FID for car2bird_GPUNIT is 6.518
Calculating FID for all tasks...
Calculating FID for bird2car_GPUNIT...
Calculating FID given paths ../../data/linkdataset_for_starganv2/cars/train/ and ../I2I/comparison/GP-UNIT/bird2car...


  0%|          | 0/48 [00:00<?, ?it/s]

Preparing DataLoader for the evaluation phase...
Preparing DataLoader for the evaluation phase...


100%|██████████| 48/48 [00:06<00:00,  7.37it/s]
100%|██████████| 40/40 [00:06<00:00,  6.56it/s]


FID for bird2car_GPUNIT is 21.350


# LPIPS

In [5]:
# Task sub-folders (inside every results root) evaluated with LPIPS.
tasks = [
    'male2female_Waug',
    'female2male_Waug',
    'dog2cat_Waug',
    'cat2dog_Waug',
    'cat2face_Waug',
    'face2cat_Waug',
    'dog2bird',
    'bird2dog',
    'car2bird',
    'bird2car',
]

# Only one method is evaluated here: (results root folder, label suffix).
path_fakes, modes = map(list, zip(
    ('../I2I/comparison/GP-UNIT/', '_GPUNIT'),
))

In [6]:
# LPIPS evaluation: for every task and method, average the LPIPS score of
# each batch of results and write one JSON report per task.
lpips_dict = OrderedDict()   # all results of this run, across tasks
save_path='../I2I/result/LPIPS/'
# Writing the JSON reports would fail if the result folder is missing.
os.makedirs(save_path, exist_ok=True)
# Pick the device the same way the FID code in this notebook does, instead of
# requiring CUDA unconditionally.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
for task in tasks:
    print('='*50)
    print(task)
    # Results of the current task only, so that LPIPS_<task>.json does not
    # also contain the values of every task processed before it.
    task_lpips = OrderedDict()
    for ii, path_fake in enumerate(path_fakes):
        print('Calculating LPIPS for %s...' % (task+modes[ii]))
        # this loader will load 10 images sequentially, i.e. 10 results from the same input image
        # NOTE(review): this relies on the loader returning files in sorted
        # name order and on exactly 10 results per input image -- confirm.
        loader = get_eval_loader(os.path.join(path_fake, task), img_size=256, batch_size=10,
                                 imagenet_normalize=False, shuffle=False)
        lpips_values = []
        for x in tqdm(loader, total=len(loader)):
            lpips_value = calculate_lpips_given_images(x.to(device))
            lpips_values.append(lpips_value)
        lpips_mean = np.array(lpips_values).mean()
        task_lpips['LPIPS_%s' % (task+modes[ii])] = lpips_mean
        print('LPIPS for %s is %.3f' % (task+modes[ii], lpips_mean))

    lpips_dict.update(task_lpips)

    # report LPIPS values
    filename = os.path.join(save_path, 'LPIPS_%s.json' % (task))
    utils.save_json(task_lpips, filename)

male2female_Waug
Calculating LPIPS for male2female_Waug_GPUNIT...
Preparing DataLoader for the evaluation phase...


100%|██████████| 1000/1000 [18:39<00:00,  1.12s/it]


LPIPS for male2female_Waug_GPUNIT is 0.355
female2male_Waug
Calculating LPIPS for female2male_Waug_GPUNIT...
Preparing DataLoader for the evaluation phase...


100%|██████████| 1000/1000 [18:46<00:00,  1.13s/it]
  0%|          | 0/500 [00:00<?, ?it/s]

LPIPS for female2male_Waug_GPUNIT is 0.393
dog2cat_Waug
Calculating LPIPS for dog2cat_Waug_GPUNIT...
Preparing DataLoader for the evaluation phase...


100%|██████████| 500/500 [09:26<00:00,  1.13s/it]


LPIPS for dog2cat_Waug_GPUNIT is 0.489
cat2dog_Waug
Calculating LPIPS for cat2dog_Waug_GPUNIT...
Preparing DataLoader for the evaluation phase...


100%|██████████| 500/500 [09:26<00:00,  1.13s/it]
  0%|          | 0/500 [00:00<?, ?it/s]

LPIPS for cat2dog_Waug_GPUNIT is 0.533
cat2face_Waug
Calculating LPIPS for cat2face_Waug_GPUNIT...
Preparing DataLoader for the evaluation phase...


100%|██████████| 500/500 [09:31<00:00,  1.14s/it]


LPIPS for cat2face_Waug_GPUNIT is 0.455
face2cat_Waug
Calculating LPIPS for face2cat_Waug_GPUNIT...
Preparing DataLoader for the evaluation phase...


100%|██████████| 1000/1000 [18:54<00:00,  1.13s/it]
  0%|          | 0/200 [00:00<?, ?it/s]

LPIPS for face2cat_Waug_GPUNIT is 0.521
dog2bird
Calculating LPIPS for dog2bird_GPUNIT...
Preparing DataLoader for the evaluation phase...


100%|██████████| 200/200 [03:46<00:00,  1.13s/it]
  0%|          | 0/200 [00:00<?, ?it/s]

LPIPS for dog2bird_GPUNIT is 0.625
bird2dog
Calculating LPIPS for bird2dog_GPUNIT...
Preparing DataLoader for the evaluation phase...


100%|██████████| 200/200 [03:45<00:00,  1.13s/it]
  0%|          | 0/200 [00:00<?, ?it/s]

LPIPS for bird2dog_GPUNIT is 0.580
car2bird
Calculating LPIPS for car2bird_GPUNIT...
Preparing DataLoader for the evaluation phase...


100%|██████████| 200/200 [03:48<00:00,  1.14s/it]
  0%|          | 0/200 [00:00<?, ?it/s]

LPIPS for car2bird_GPUNIT is 0.627
bird2car
Calculating LPIPS for bird2car_GPUNIT...
Preparing DataLoader for the evaluation phase...


100%|██████████| 200/200 [03:46<00:00,  1.13s/it]

LPIPS for bird2car_GPUNIT is 0.587



