# FID with GAN generation

In [1]:
from __future__ import print_function

import argparse
import math
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
from torchfusion.gan.applications import DCGANDiscriminator

from data_loader import MarioDataset
from models.custom import Generator

import csv

from image_gen.asset_map import get_asset_map
from image_gen.fixer import PipeFixer
from image_gen.image_gen import GameImageGenerator
from tqdm import tqdm

from get_level import GetLevel as getLevel
from scipy.linalg import sqrtm

  from .autonotebook import tqdm as notebook_tqdm


### Functions for FID

In [2]:
torch.cuda.is_available()

True

In [3]:
# Function to preprocess the matrices
def preprocess_matrices(matrices):
    # Normalize values to the range [0, 255]
    normalized_matrices = (matrices - np.min(matrices)) / (np.max(matrices) - np.min(matrices))
    normalized_matrices = normalized_matrices * 255
    return normalized_matrices.astype(np.uint8)

# Function to compute mean and covariance of features
def compute_statistics(matrices):
    # Flatten matrices into vectors
    flattened_matrices = matrices.reshape((matrices.shape[0], -1))
    # Compute mean and covariance
    mean = np.mean(flattened_matrices, axis=0)
    covariance = np.cov(flattened_matrices, rowvar=False)

    return mean, covariance

# Function to compute Fréchet distance
def compute_frechet_distance(real_mean, real_cov, generated_mean, generated_cov):
    epsilon = 1e-6  # Small constant to avoid numerical instability
    sqrt_cov_product = sqrtm(real_cov.dot(generated_cov))
    fid_score = np.linalg.norm(real_mean - generated_mean) + np.trace(real_cov + generated_cov - 2 * sqrt_cov_product)

    return fid_score

## Batches for Real samples

In [4]:
org_data = MarioDataset()
ref_idx = torch.randperm(len(org_data))
prev_frame, curr_frame = (org_data[:].prev_frame, org_data[:].curr_frame)
complete_frame = torch.cat((prev_frame,curr_frame),dim=3)
#convert one-hot encoding back to 2D matrices
complete_frame = torch.argmax(complete_frame, dim = 1) 

In [5]:
# Preprocess real data
complete_frame_np = complete_frame.detach().numpy()
real_matrices = preprocess_matrices(complete_frame_np)
# Compute statistics for real matrices
real_mean, real_cov = compute_statistics(real_matrices)


### Batches for generated data

In [40]:
conditional_channels = [0,1,6,7]
def Gen_sample(conditional_channels,ini_data=120):
    dataset = MarioDataset()
    netG = Generator(
            latent_size=(len(conditional_channels) + 1, 14, 14), out_size=(13, 32, 32)
        )
    netG.load_state_dict(torch.load("./trained_models/netG_epoch_300000_0_32.pth"))
        # 300000
    mario_map = get_asset_map(game="mario")
    gen = GameImageGenerator(asset_map=mario_map)
    prev_frame, curr_frame = dataset[[ini_data]]
    fixer = PipeFixer()

    level_gen = getLevel(netG, gen, fixer, prev_frame, curr_frame, conditional_channels)
    var = 1
    #noise = np.rand((1, 1, 14, 14)).normal_(0, var)
    noise = np.random.normal(0,var,size=(14,14))
    level = level_gen.generate_frames(noise, var=var, frame_count=1) # generated matrix without padded
    # convert to onehot encoding
    np.set_printoptions(threshold=np.inf)
    # onehot = np.eye(13, dtype="uint8")[level]  # create a one hot mapping for the features
    # onehot = np.rollaxis(onehot, 2, 0)  # (num_samples, chann.=13, h=14, w=28)
    # padded = np.full((1, onehot.shape[0], 32, 32), 0.0)
    # padded[:, :, 9:-9, 2:-2] = onehot
    # padded = torch.from_numpy(padded)
    # return padded

    padded = torch.zeros(32,32)
    padded[9:-9,2:-2] = torch.from_numpy(level)
    level = padded

    return level
    # this is just for visualization
    #level_gen.gen.save_gen_level(img_name="test_fuc_gen")

In [41]:
# number of generated samples 
n_gen = 100

# generate samples and stack together
#level_gen = torch.zeros(n_gen,13,32,32)
level_gen = torch.zeros(n_gen,32,32)
for i in range(n_gen):
    level = Gen_sample(conditional_channels,ini_data=120)
    level_gen[i,:,:] = level
    #level_gen[i,:,:,:] = level
    

In [42]:
# Preprocess generated data
generated_frame_np = level_gen.detach().numpy()
generated_matrices = preprocess_matrices(generated_frame_np)
# Compute statistics for real matrices
gen_mean, gen_cov = compute_statistics(generated_matrices)

In [43]:
# Compute Fréchet distance
fid_score = compute_frechet_distance(real_mean, real_cov, gen_mean, gen_cov)
print("FID score:", fid_score)

FID score: (151142.54590961675-0.0003873447056625816j)
