In [1]:
import os
import sys
import json
import torch
import numpy as np
import pandas as pd
from PIL import Image
from torchmetrics.image.fid import FrechetInceptionDistance as FID

In [2]:
def coco_dataset(data_path, split, num_images=1000):
    with open(os.path.join(data_path, f'annotations/captions_{split}2014.json')) as f:
        data = json.load(f)
    data = data['annotations']
    # select 30k images randomly
    np.random.seed(0)
    np.random.shuffle(data)
    data = data[:num_images]
    imgs = [os.path.join(data_path, f'{split}2014', 'COCO_' + split + '2014_' + str(ann['image_id']).zfill(12) + '.jpg') for ann in data]
    anns = [ann['caption'] for ann in data]
    return imgs, anns

In [3]:
# calculate FID between original images and generated images
imgs, anns = coco_dataset('../../COCO-vqa', 'val', 30000)
print("Evaluating on COCO dataset", len(imgs), len(anns))

Evaluating on COCO dataset 30000 30000


In [8]:
# read the images and generated images
concepts_to_remove = 'Monet'
fine_tuned_unet = 'union-timesteps'
dataset_type = 'coco'
output_path = f'../benchmarking results/{fine_tuned_unet}/{dataset_type}/{concepts_to_remove}'
# read the generated images
gen_imgs = []
orig_imgs = []

for i in range(len(imgs)):
    if i >= 10000:
        break
    print(i)
    img = Image.open(imgs[i]).convert('RGB')
    img = img.resize((512, 512))
    orig_imgs.append(np.array(img))

    img = Image.open(f'{output_path}/removed_{i}.png')
    gen_imgs.append(np.array(img))

# calculate FID
# stack the images
orig_imgs = np.stack(orig_imgs)
gen_imgs = np.stack(gen_imgs)
# convert to tensor
orig_imgs = torch.tensor(orig_imgs).permute(0, 3, 1, 2)
gen_imgs = torch.tensor(gen_imgs).permute(0, 3, 1, 2)
print(orig_imgs.shape, gen_imgs.shape)
print("Calculating FID")
# calculate FID
fid = FID(normalize=True)
fid.update(orig_imgs, real=True)
fid.update(gen_imgs, real=False)
fid = fid.compute()

print("FID:", fid)

torch.Size([10000, 3, 512, 512]) torch.Size([10000, 3, 512, 512])
Calculating FID
FID: tensor(32.2025)
