# <center style='color:deeppink'>Calculate `Inception Score (IS)` using PyTorch</center>

# 1. Import required libraries

In [1]:
import torch
print('PyTorch version:', torch.__version__)
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

import torchvision
print('Torchvision version:', torchvision.__version__)
from torchvision import datasets, transforms
from torchvision import models
from torchvision.models.inception import inception_v3

import os
import pathlib
from PIL import Image 
import numpy as np
from scipy.stats import entropy

PyTorch version: 2.4.1+cu121
Torchvision version: 0.19.1+cu121


# 2. Define `Inception` function

In [2]:
def calculate_inception_score(imgs, batch_size, resize, splits):
    
    N = len(imgs)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dataloader = DataLoader(imgs, batch_size=batch_size, shuffle=False)
    
    inception_model = inception_v3(weights=models.Inception_V3_Weights.DEFAULT, transform_input=False).to(device)
    inception_model.eval()
    up = nn.Upsample(size=(299, 299), mode='bilinear').to(device)
    
    def get_pred(x):
        if resize:
            x = up(x)
        x = inception_model(x)
        return F.softmax(x, dim=1).detach().cpu().numpy()
    
    # get predictions
    preds = np.zeros((N, 1000))
    
    for i, j in enumerate(dataloader):
        batch = j.to(device)
        batch_size_i = batch.size()[0]
        preds[i*batch_size:i*batch_size + batch_size_i] = get_pred(batch)
        
    # compute the mean kl-div
    split_scores = []
    
    for k in range(splits):
        part = preds[k * (N // splits): (k+1) * (N // splits), :]
        py = np.mean(part, axis=0)
        scores = []
        for i in range(part.shape[0]):
            pyx = part[i, :]
            scores.append(entropy(pyx, py))
        split_scores.append(np.exp(np.mean(scores)))
        
    return np.mean(split_scores)

# 3. Define `ImagePathDataset` class

In [3]:
class ImagePathDataset(Dataset):
    def __init__(self, files, transform=None):
        
        self.files = files
        self.transform = transform
        
    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, i):
        path = self.files[i]
        img = Image.open(path).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img

# 4. Calculate `Inception Score`

In [4]:
IMAGE_EXTENSIONS = {'jpg'}

In [5]:
%%time

gen_imgs_dir = os.getcwd() + '/gen'
print('Total images in gen_imgs_dir:', len(next(os.walk(gen_imgs_dir))[2]))

path = pathlib.Path(gen_imgs_dir)
files = sorted([file for ext in IMAGE_EXTENSIONS for file in path.glob('*.{}'.format(ext))])

dataset = ImagePathDataset(files, transform=transforms.ToTensor())
print("Inception Score (IS):", round(calculate_inception_score(dataset, batch_size=50, resize=True, splits=10), 3))

Total images in gen_imgs_dir: 60000
Inception Score (IS): 2.418
CPU times: user 15min 26s, sys: 2.79 s, total: 15min 28s
Wall time: 3min 56s
