# Code for running metrics on the generated images

## Setup

In [None]:
import torch
import numpy as np
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchvision.transforms.functional import normalize
from torchvision.models.inception import inception_v3
from torchmetrics.image.fid import FrechetInceptionDistance
from scipy.stats import entropy
import matplotlib.pyplot as plt
import cv2
from tqdm import tqdm
import time
from PIL import Image

In [None]:
# Path of the serialized dataset holding the captions, the ground-truth
# images and the generated images.
DATASET_PATH = './dataset/image_caption_dataset_with_generated_images.pt'

# weights_only=False lets torch.load unpickle arbitrary Python objects,
# so only point this at a file you trust.
data = torch.load(DATASET_PATH, weights_only=False)

# Show the first record to inspect its layout.
data[0]

The ground-truth images are cropped and upscaled in batches of 9 (one cell per batch) to prevent out-of-memory issues.

In [None]:
# Batch 1 (indices 0-8): crop each ground-truth image to the
# [70:290, 115:335] window, move the leading axis last, and upscale
# the result to 4096x4096 with Lanczos interpolation.
for step in tqdm(range(9), desc="Progress"):
    time.sleep(0.1)
    sample = data[step]
    cropped = sample['image'][:, 70:290, 115:335].permute(1, 2, 0)
    sample['image'] = np.array(cropped)
    sample['image'] = cv2.resize(
        sample['image'],
        (4096, 4096),
        interpolation=cv2.INTER_LANCZOS4,
    )
    done = step + 1
    tqdm.write(f"Done: {done}, Left: {len(data) - done}")

In [None]:
# Batch 2 (indices 9-17): crop each ground-truth image to the
# [70:290, 115:335] window, move the leading axis last, and upscale
# the result to 4096x4096 with Lanczos interpolation.
batch_start = 9
for i in tqdm(range(9), desc="Progress"):
    time.sleep(0.1)
    idx = batch_start + i
    data[idx]['image'] = np.array(data[idx]['image'][:, 70:290, 115:335].permute(1, 2, 0))
    data[idx]['image'] = cv2.resize(data[idx]['image'], (4096, 4096), interpolation=cv2.INTER_LANCZOS4)
    # Report progress against the whole dataset: the counts must include
    # the batch offset, otherwise "Done"/"Left" restart from 1 in every cell.
    tqdm.write(f"Done: {idx + 1}, Left: {len(data) - (idx + 1)}")

In [None]:
# Batch 3 (indices 18-26): crop each ground-truth image to the
# [70:290, 115:335] window, move the leading axis last, and upscale
# the result to 4096x4096 with Lanczos interpolation.
batch_start = 18
for i in tqdm(range(9), desc="Progress"):
    time.sleep(0.1)
    idx = batch_start + i
    data[idx]['image'] = np.array(data[idx]['image'][:, 70:290, 115:335].permute(1, 2, 0))
    data[idx]['image'] = cv2.resize(data[idx]['image'], (4096, 4096), interpolation=cv2.INTER_LANCZOS4)
    # Report progress against the whole dataset: the counts must include
    # the batch offset, otherwise "Done"/"Left" restart from 1 in every cell.
    tqdm.write(f"Done: {idx + 1}, Left: {len(data) - (idx + 1)}")

In [None]:
# Sanity check: dataset size, plus a look at the last image touched by
# the preprocessing cells (index 26) after the crop and 4096x4096 resize.
print(len(data))
last_processed = data[26]['image']
plt.imshow(last_processed)

# Shape of the first preprocessed ground-truth image.
data[0]['image'].shape

In [None]:
# Show the generated counterpart of sample 26 and report its array shape.
generated_sample = data[26]['generated_image']
gen_image = np.array(generated_sample)

plt.imshow(gen_image)
gen_image.shape

In [None]:
def load_dataset(file_path, is_gen):
    """Load a serialized dataset and return one of its image collections.

    Args:
        file_path: Path of a file written with ``torch.save``.
        is_gen: When truthy, return the ``'generated_image'`` entry;
            otherwise return the ``'image'`` entry.

    Returns:
        Whatever the loaded object stores under the selected key.

    NOTE(review): elsewhere in this notebook the dataset is indexed as
    ``data[i]['image']`` (a sequence of per-sample records), whereas this
    function indexes the loaded object directly by key. Confirm which
    file layout this helper is meant for.
    """
    # weights_only=False matches the notebook's other torch.load call. It
    # unpickles arbitrary Python objects, so only load trusted files.
    data = torch.load(file_path, weights_only=False)
    key = 'generated_image' if is_gen else 'image'
    return data[key]

In [None]:
def preprocess_images(images, is_tensor = False, image_size=299):
    """Resize, tensorize and normalize a sequence of images into one batch.

    Args:
        images: Indexable sequence of images accepted by the torchvision
            transforms below, or of tensors when ``is_tensor`` is true.
        is_tensor: When true, each element is converted with
            ``Image.fromarray(img.numpy())`` before the transforms run.
        image_size: Side length of the square output images.

    Returns:
        A tensor stacking the transformed images along a new first axis.

    Raises:
        IndexError: If ``images`` is empty.
    """
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    if is_tensor:
        images = [Image.fromarray(img.numpy()) for img in images]

    # Transform every image exactly once. The debug prints and the return
    # value all reuse these results instead of re-running the pipeline.
    transformed = [transform(img) for img in images]
    print(type(transformed[0]))
    print(transformed[0].shape)

    batch = torch.stack(transformed)
    print(type(batch))
    print(batch.shape)
    return batch

## Code for Inception Score

In [None]:
def compute_inception_score(images, batch_size=8, splits=10, device=None):
    """Compute the Inception Score of a batch of images.

    For each split the score is ``exp(mean_x KL(p(y|x) || p(y)))``, where
    ``p(y|x)`` is the Inception-v3 class distribution of one image and
    ``p(y)`` is the mean distribution over the split.

    Args:
        images: Tensor of images. The code divides by 255, resizes with a
            2-D target size and normalizes with three per-channel values,
            so it expects an ``(N, 3, H, W)`` tensor with values in 0-255.
        batch_size: Number of images per forward pass.
        splits: Number of chunks the predictions are divided into. It is
            capped at the number of images so no chunk is empty.
        device: Device for the Inception model; defaults to CUDA when
            available, otherwise CPU.

    Returns:
        Tuple ``(mean, std)`` of the per-split scores.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    inception = inception_v3(pretrained=True, transform_input=False).to(device)
    inception.eval()

    def get_pred(x):
        """Return softmax class probabilities for one batch as a NumPy array."""
        with torch.no_grad():
            x = x.to(torch.float32) / 255.0

            x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)

            x = normalize(x, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            x = x.to(device)
            preds = inception(x).softmax(dim=1)
        return preds.cpu().numpy()

    loader = DataLoader(TensorDataset(images), batch_size=batch_size, shuffle=False)

    preds = np.concatenate([get_pred(batch[0]) for batch in loader], axis=0)

    # np.array_split yields empty chunks when there are more splits than
    # samples, and the mean of an empty chunk is NaN. Cap the split count.
    n_splits = max(1, min(splits, len(preds)))

    split_scores = []
    for chunk in np.array_split(preds, n_splits):
        p_y = np.mean(chunk, axis=0)
        # scipy's entropy(pk, qk) is KL(pk || qk). The Inception Score needs
        # KL(p(y|x) || p(y)), so the per-image distribution goes first.
        kl_divs = [entropy(p_yx, p_y) for p_yx in chunk]
        split_scores.append(np.exp(np.mean(kl_divs)))

    return np.mean(split_scores), np.std(split_scores)

## Code for FID score

In [None]:
def compute_fid(real_images, fake_images, device='cpu', batch_size=None):
    """Compute the Frechet Inception Distance between two image sets.

    Args:
        real_images: Tensor of reference images. In this notebook the
            callers pass uint8 ``(N, 3, H, W)`` tensors.
        fake_images: Tensor of generated images, in the same format.
        device: Device the metric runs on. Defaults to ``'cpu'``, the
            value that was previously hard-coded.
        batch_size: When given, each image set is fed to the metric in
            chunks of this many images to bound peak memory. ``None``
            (default) feeds each set in a single ``update`` call.

    Returns:
        The FID value as a Python float.
    """
    fid = FrechetInceptionDistance(feature=2048).to(device)

    def _feed(images, real):
        """Send ``images`` to the metric, optionally in chunks."""
        if batch_size is None:
            fid.update(images.to(device), real=real)
            return
        # The metric accumulates statistics across update calls, so chunked
        # feeding should match a single call up to floating-point rounding.
        for start in range(0, len(images), batch_size):
            fid.update(images[start:start + batch_size].to(device), real=real)

    _feed(real_images, True)
    _feed(fake_images, False)

    return fid.compute().item()

In [None]:
# Batch the ground-truth images: scale each by 255, clip to [0, 255],
# cast to uint8, stack, then move the last axis to position 1
# (presumably channels-last -> channels-first).
ground_truth_images = torch.tensor(
    np.array([
        (sample['image'] * 255).clip(0, 255).astype('uint8')
        for sample in data
    ]),
    dtype=torch.uint8,
).permute(0, 3, 1, 2)

# Display dtype and shape of the batch and of its first element.
(
    ground_truth_images.dtype,
    ground_truth_images.shape,
    ground_truth_images[0].shape,
    ground_truth_images[0].dtype,
)

In [None]:
# Batch the generated images: convert each to a NumPy array, stack them
# as uint8, then move the last axis to position 1
# (presumably channels-last -> channels-first).
generated_images = torch.tensor(
    np.array([np.array(sample['generated_image']) for sample in data]),
    dtype=torch.uint8,
).permute(0, 3, 1, 2)

# Display dtype and shape of the batch and of its first element.
(
    generated_images.dtype,
    generated_images.shape,
    generated_images[0].shape,
    generated_images[0].dtype,
)

In [None]:
# Free as much memory as possible before running the metrics: run a
# garbage-collection pass, then ask PyTorch to release cached CUDA memory.
import gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
# FID between the preprocessed ground-truth images and the generated ones.
fid_value = compute_fid(
    real_images=ground_truth_images,
    fake_images=generated_images,
)
print(f"FID Score: {fid_value}")

In [None]:
# Inception Score of the generated images: mean and standard deviation
# over the splits.
inception_stats = compute_inception_score(generated_images)
inception_mean, inception_std = inception_stats
print(f"Inception Score: {inception_mean} ± {inception_std}")