# <center>Calculate FID distance using samples generated with `Linear` scheduler</center>

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.model_zoo import load_url as load_state_dict_from_url

import torchvision
from torchvision import models, transforms

import os
from PIL import Image
import numpy as np
from scipy import linalg
import pathlib
from tqdm import tqdm

In [2]:
# URL of the checkpoint that `fid_inception_v3` downloads and loads into the patched Inception network.
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'

In [3]:
class InceptionV3(nn.Module):
    '''InceptionV3 feature extractor that returns the outputs of selected blocks.

    The network is cut into at most four sequential blocks (indices 0-3). `forward` returns a list holding the
    output of every block whose index is listed in `output_blocks`, in ascending block order.
    '''

    # Block returned when the caller does not request specific blocks.
    DEFAULT_BLOCK_INDEX = 3
    # Maps a feature dimensionality to the index of the block selected for it.
    BLOCK_INDEX_BY_DIM = {64: 0, 192: 1, 768: 2, 2048: 3}

    def __init__(self, output_blocks=(DEFAULT_BLOCK_INDEX,), resize_input=True, normalize_input=True, requires_grad=False,
                 use_fid_inception=True):
        '''Assemble the blocks needed to produce the requested outputs.

        Args:
            output_blocks: indices of the blocks whose outputs `forward` returns.
            resize_input: if true, inputs are bilinearly resized to 299x299 in `forward`.
            normalize_input: if true, inputs are rescaled as `2 * x - 1` in `forward`.
            requires_grad: value assigned to `requires_grad` of every parameter.
            use_fid_inception: if true, build the network with `fid_inception_v3()`; otherwise use
                `_inception_v3(weights='DEFAULT')`.
        '''
        super().__init__()

        self.resize_input = resize_input
        self.normalize_input = normalize_input
        self.output_blocks = sorted(output_blocks)
        self.last_needed_block = max(output_blocks)

        assert self.last_needed_block <= 3, 'Last possible output block index is 3'

        net = fid_inception_v3() if use_fid_inception else _inception_v3(weights='DEFAULT')

        # Names of the Inception sub-modules that make up each block, in execution order.
        layer_names_per_block = (
            ('Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3'),
            ('Conv2d_3b_1x1', 'Conv2d_4a_3x3'),
            ('Mixed_5b', 'Mixed_5c', 'Mixed_5d', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 'Mixed_6e'),
            ('Mixed_7a', 'Mixed_7b', 'Mixed_7c'),
        )

        self.blocks = nn.ModuleList()
        for index, layer_names in enumerate(layer_names_per_block):
            # Block 0 is always built; later blocks only when some requested output depends on them.
            if index >= 1 and self.last_needed_block < index:
                break

            layers = [getattr(net, layer_name) for layer_name in layer_names]
            if index < 2:
                # The first two blocks end with a max pool.
                layers.append(nn.MaxPool2d(kernel_size=3, stride=2))
            elif index == 3:
                # The last block ends with a pool down to a 1x1 spatial map.
                layers.append(nn.AdaptiveAvgPool2d(output_size=(1, 1)))
            self.blocks.append(nn.Sequential(*layers))

        for weight in self.parameters():
            weight.requires_grad = requires_grad

    def forward(self, input_):
        '''Return a list with the output of each requested block for the batch `input_`.'''
        features = input_

        if self.resize_input:
            features = F.interpolate(features, size=(299, 299), mode='bilinear', align_corners=False)
        if self.normalize_input:
            # Maps values in [0, 1] to [-1, 1].
            features = 2 * features - 1

        collected = []
        for index, block in enumerate(self.blocks):
            features = block(features)
            if index in self.output_blocks:
                collected.append(features)
            # Nothing past the last requested block is needed.
            if index == self.last_needed_block:
                break
        return collected

In [4]:
def _inception_v3(*args, **kwargs):
    '''Call `torchvision.models.inception_v3`, adapting keyword arguments to the installed torchvision version.

    - For torchvision >= 0.6, `init_weights=False` is always passed.
      NOTE(review): presumably to skip the default weight initialisation - confirm against torchvision docs.
    - For torchvision < 0.13, a `weights` keyword is translated to the older `pretrained` flag
      ('DEFAULT' -> True, None -> False).

    Raises:
        ValueError: on torchvision < 0.13 when `weights` is neither 'DEFAULT' nor None.
    '''
    major_minor = torchvision.__version__.split('.')[:2]
    try:
        version = tuple(int(part) for part in major_minor)
    except ValueError:
        # Version string could not be parsed; treat it as older than every threshold below.
        version = (0,)

    if version >= (0, 6):
        kwargs['init_weights'] = False

    if version < (0, 13) and 'weights' in kwargs:
        weights = kwargs.pop('weights')
        if weights == 'DEFAULT':
            kwargs['pretrained'] = True
        elif weights is None:
            kwargs['pretrained'] = False
        else:
            raise ValueError('weights=={} not supported in torchvision {}'.format(weights, torchvision.__version__))

    return torchvision.models.inception_v3(*args, **kwargs)

In [5]:
def fid_inception_v3():
    '''Build the Inception network used for FID and load its weights from `FID_WEIGHTS_URL`.

    A torchvision Inception v3 is created without weights (1008 classes, no auxiliary logits), several of its
    mixed blocks are swapped for the `FIDInception*` variants defined in this file, and the downloaded state
    dict is then loaded into the patched network.
    '''
    net = _inception_v3(num_classes=1008, aux_logits=False, weights=None)

    # Attribute name on the torchvision model -> replacement module.
    patched_blocks = {
        'Mixed_5b': FIDInceptionA(192, pool_features=32),
        'Mixed_5c': FIDInceptionA(256, pool_features=64),
        'Mixed_5d': FIDInceptionA(288, pool_features=64),
        'Mixed_6b': FIDInceptionC(768, channels_7x7=128),
        'Mixed_6c': FIDInceptionC(768, channels_7x7=160),
        'Mixed_6d': FIDInceptionC(768, channels_7x7=160),
        'Mixed_6e': FIDInceptionC(768, channels_7x7=192),
        'Mixed_7b': FIDInceptionE_1(1280),
        'Mixed_7c': FIDInceptionE_2(2048),
    }
    for attribute, module in patched_blocks.items():
        setattr(net, attribute, module)

    net.load_state_dict(load_state_dict_from_url(FID_WEIGHTS_URL, progress=True))
    return net

In [6]:
class FIDInceptionA(models.inception.InceptionA):
    '''InceptionA block whose pooling branch averages without counting padded positions.

    NOTE(review): presumably needed so the block matches the network the FID weights were produced with - confirm.
    '''

    def __init__(self, in_channels, pool_features):
        super().__init__(in_channels, pool_features)

    def forward(self, x):
        '''Run the four parallel branches on `x` and concatenate their outputs along dim 1.'''
        out_1x1 = self.branch1x1(x)

        out_5x5 = self.branch5x5_2(self.branch5x5_1(x))

        out_3x3dbl = x
        for layer in (self.branch3x3dbl_1, self.branch3x3dbl_2, self.branch3x3dbl_3):
            out_3x3dbl = layer(out_3x3dbl)

        # count_include_pad=False: zero padding is left out of each window's average.
        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
        out_pool = self.branch_pool(pooled)

        return torch.cat([out_1x1, out_5x5, out_3x3dbl, out_pool], dim=1)

In [7]:
class FIDInceptionC(models.inception.InceptionC):
    '''InceptionC block whose pooling branch averages without counting padded positions.

    NOTE(review): presumably needed so the block matches the network the FID weights were produced with - confirm.
    '''

    def __init__(self, in_channels, channels_7x7):
        super().__init__(in_channels, channels_7x7)

    def forward(self, x):
        '''Run the four parallel branches on `x` and concatenate their outputs along dim 1.'''
        out_1x1 = self.branch1x1(x)

        out_7x7 = x
        for layer in (self.branch7x7_1, self.branch7x7_2, self.branch7x7_3):
            out_7x7 = layer(out_7x7)

        out_7x7dbl = x
        for layer in (self.branch7x7dbl_1, self.branch7x7dbl_2, self.branch7x7dbl_3, self.branch7x7dbl_4,
                      self.branch7x7dbl_5):
            out_7x7dbl = layer(out_7x7dbl)

        # count_include_pad=False: zero padding is left out of each window's average.
        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
        out_pool = self.branch_pool(pooled)

        return torch.cat([out_1x1, out_7x7, out_7x7dbl, out_pool], dim=1)

In [8]:
class FIDInceptionE_1(models.inception.InceptionE):
    '''InceptionE block whose pooling branch averages without counting padded positions.

    NOTE(review): presumably needed so the block matches the network the FID weights were produced with - confirm.
    '''

    def __init__(self, in_channels):
        super().__init__(in_channels)

    def forward(self, x):
        '''Run the four parallel branches on `x` and concatenate their outputs along dim 1.'''
        out_1x1 = self.branch1x1(x)

        # The 3x3 branch forks into two parallel convolutions whose outputs are concatenated.
        stem_3x3 = self.branch3x3_1(x)
        out_3x3 = torch.cat([self.branch3x3_2a(stem_3x3), self.branch3x3_2b(stem_3x3)], dim=1)

        # Same fork-and-concatenate pattern after two stacked convolutions.
        stem_3x3dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        out_3x3dbl = torch.cat([self.branch3x3dbl_3a(stem_3x3dbl), self.branch3x3dbl_3b(stem_3x3dbl)], dim=1)

        # count_include_pad=False: zero padding is left out of each window's average.
        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
        out_pool = self.branch_pool(pooled)

        return torch.cat([out_1x1, out_3x3, out_3x3dbl, out_pool], dim=1)

In [9]:
class FIDInceptionE_2(models.inception.InceptionE):
    '''InceptionE block whose pooling branch uses max pooling instead of average pooling.

    NOTE(review): presumably mirrors the network the FID weights were produced with - confirm.
    '''

    def __init__(self, in_channels):
        super().__init__(in_channels)

    def forward(self, x):
        '''Run the four parallel branches on `x` and concatenate their outputs along dim 1.'''
        out_1x1 = self.branch1x1(x)

        # The 3x3 branch forks into two parallel convolutions whose outputs are concatenated.
        stem_3x3 = self.branch3x3_1(x)
        out_3x3 = torch.cat([self.branch3x3_2a(stem_3x3), self.branch3x3_2b(stem_3x3)], dim=1)

        # Same fork-and-concatenate pattern after two stacked convolutions.
        stem_3x3dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        out_3x3dbl = torch.cat([self.branch3x3dbl_3a(stem_3x3dbl), self.branch3x3dbl_3b(stem_3x3dbl)], dim=1)

        # Unlike FIDInceptionE_1, the pool branch here is a max pool.
        pooled = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
        out_pool = self.branch_pool(pooled)

        return torch.cat([out_1x1, out_3x3, out_3x3dbl, out_pool], dim=1)

In [10]:
# File extensions (without the leading dot) that `compute_statistics_of_path` globs for when given a folder.
IMAGE_EXTENSIONS = {'jpg'}

In [11]:
class ImagePathDataset(Dataset):
    '''Dataset that opens the image at each stored path as RGB and optionally applies a transform.'''

    def __init__(self, files, transform=None):
        '''Store the sequence of image paths and the optional per-image transform.'''
        self.files, self.transform = files, transform

    def __len__(self):
        '''Return the number of stored paths.'''
        return len(self.files)

    def __getitem__(self, i):
        '''Open the `i`-th image, convert it to RGB and return it (transformed when a transform is set).'''
        image = Image.open(self.files[i]).convert('RGB')
        return image if self.transform is None else self.transform(image)

In [12]:
def get_activations(files, model, batch_size, dims, device='cpu'):
    '''Run `model` over the images in `files` and return their activations.

    Args:
        files: sequence of image paths, read through `ImagePathDataset` with a `ToTensor` transform.
        model: callable module; the first element of its output is used as the activation tensor.
        batch_size: number of images per batch (clamped to `len(files)`).
        dims: number of features per image, i.e. the second dimension of the returned array.
        device: device each batch is moved to before the forward pass.

    Returns:
        NumPy array of shape `(len(files), dims)`; row `k` holds the activations of `files[k]`.
    '''
    model.eval()

    # A batch cannot be larger than the whole dataset.
    batch_size = min(batch_size, len(files))

    loader = DataLoader(ImagePathDataset(files, transform=transforms.ToTensor()), batch_size=batch_size,
                        shuffle=False)

    activations = np.empty((len(files), dims))
    offset = 0

    for images in tqdm(loader):
        with torch.inference_mode():
            features = model(images.to(device))[0]

        # Blocks that do not end in a 1x1 map are pooled down so each image yields one feature vector.
        if (features.size(2), features.size(3)) != (1, 1):
            features = F.adaptive_avg_pool2d(features, output_size=(1, 1))

        rows = features.squeeze(3).squeeze(2).cpu().numpy()
        stop = offset + rows.shape[0]
        activations[offset:stop] = rows
        offset = stop

    return activations

In [13]:
def calculate_frechet_distance(mu1, mu2, sigma1, sigma2, eps=1e-6):
    '''Compute the Frechet distance between the Gaussians N(mu1, sigma1) and N(mu2, sigma2).

    The returned value is
        ||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 * Tr(sqrt(sigma1 . sigma2)).

    Args:
        mu1, mu2: mean vectors (scalars are promoted to 1-D arrays).
        sigma1, sigma2: covariance matrices (promoted to at least 2-D arrays).
        eps: value added to the diagonal of both covariances when the matrix square root of their product
            contains non-finite entries.

    Returns:
        The Frechet distance as a NumPy floating-point scalar.

    Raises:
        AssertionError: if the means or the covariances have mismatching shapes.
        ValueError: if the matrix square root has a diagonal imaginary part that is not close to zero
            (absolute tolerance 1e-3).
    '''
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, 'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, 'Training and test covariances have different dimensions'

    diff = mu1 - mu2

    # Called without the `disp` keyword (deprecated in recent SciPy): the plain call returns just the matrix,
    # which is all that is used here, and matches the fallback call below.
    covmean = linalg.sqrtm(sigma1.dot(sigma2))

    if not np.isfinite(covmean).all():
        # The product is (numerically) singular; retry with a small offset on both diagonals.
        print('fid calculation produces singular product; adding %s to diagonal cov estimates' % eps)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    if np.iscomplexobj(covmean):
        # Small imaginary parts are treated as numerical noise and dropped; large ones are an error.
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean

In [14]:
def calculate_activation_statistics(files, model, batch_size, dims, device='cpu'):
    '''Return the mean vector and covariance matrix of the activations of the images in `files`.

    Activations come from `get_activations`, which receives all arguments unchanged; its rows are treated as
    observations and its columns as variables.

    Returns:
        Tuple `(mu, sigma)`: the per-feature mean and the feature covariance matrix.
    '''
    activations = get_activations(files, model, batch_size, dims, device)
    return np.mean(activations, axis=0), np.cov(activations, rowvar=False)

In [15]:
def compute_statistics_of_path(path, model, batch_size, dims, device='cpu'):
    '''Return the activation statistics `(mu, sigma)` for `path`.

    Args:
        path: either a `.npz` file holding precomputed `mu` and `sigma` arrays, or a folder of images.
            Both `str` and `pathlib.Path` values are accepted.
        model, batch_size, dims, device: forwarded to `calculate_activation_statistics` when `path` is a
            folder; unused for `.npz` files.

    Returns:
        Tuple `(mu, sigma)`.

    Raises:
        ValueError: if `path` is a folder that contains no file with an extension from `IMAGE_EXTENSIONS`.
    '''
    # Normalise to str so that pathlib.Path inputs work with the suffix check as well.
    path_str = str(path)

    if path_str.endswith('.npz'):
        with np.load(path_str) as f:
            return f['mu'][:], f['sigma'][:]

    folder = pathlib.Path(path_str)
    files = sorted(file for ext in IMAGE_EXTENSIONS for file in folder.glob('*.{}'.format(ext)))
    if not files:
        # Fail here with a message that names the folder instead of failing later while building batches.
        raise ValueError('No images with extensions {} found in {}'.format(sorted(IMAGE_EXTENSIONS), path_str))

    return calculate_activation_statistics(files, model, batch_size, dims, device)

In [16]:
def calculate_fid_given_paths(path1, path2, batch_size, dims, device='cpu'):
    '''Compute, print and return the FID between the image sets at `path1` and `path2`.

    Args:
        path1, path2: image folders or `.npz` statistics files, as accepted by `compute_statistics_of_path`.
        batch_size: batch size used when extracting activations.
        dims: feature dimensionality; must be a key of `InceptionV3.BLOCK_INDEX_BY_DIM`.
        device: device the Inception model and the batches are placed on.

    Returns:
        The unrounded FID value. The value rounded to three decimals is also printed.

    Raises:
        KeyError: if `dims` is not a key of `InceptionV3.BLOCK_INDEX_BY_DIM`.
    '''
    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
    model = InceptionV3([block_idx]).to(device)

    mu1, sigma1 = compute_statistics_of_path(path1, model, batch_size, dims, device)
    mu2, sigma2 = compute_statistics_of_path(path2, model, batch_size, dims, device)

    fid_value = calculate_frechet_distance(mu1, mu2, sigma1, sigma2)
    print('FID distance:', round(fid_value, 3))

    # Hand the value back as well, so callers can store or compare it instead of only reading the printout.
    return fid_value

In [17]:
# Run configuration: use the GPU when one is available, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
batch_size = 50
dims = 2048

# Folders holding the reference images and the images generated with the Linear scheduler.
src_path = '../SRC'
gen_path = '../GEN_Linear'

# Report how many files sit directly inside each folder (sub-folders are not descended into).
_, _, src_files = next(os.walk(src_path))
print('Total images in src_path:', len(src_files))
_, _, gen_files = next(os.walk(gen_path))
print('Total images in gen_path:', len(gen_files))

calculate_fid_given_paths(path1=src_path, path2=gen_path, batch_size=batch_size, dims=dims, device=device)

Total images in src_path: 11000
Total images in gen_path: 11000


100%|█████████████████████████████████████████████████████████████████████████████████| 220/220 [00:30<00:00,  7.13it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 220/220 [00:30<00:00,  7.22it/s]


FID distance: 36.833
