In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
cd drive/My \Drive/ML/

In [None]:
!pip install pytorch_fid
!pip install POT

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
from torchvision.models import inception_v3
from pytorch_fid import fid_score
from scipy.stats import entropy
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms.functional import resize
import ot
from sklearn.metrics import precision_recall_fscore_support
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# VAE model
class Encoder(nn.Module):
    """Convolutional encoder producing the parameters of a diagonal Gaussian.

    Three stride-2 convolutions (64 -> 128 -> 256 feature maps) feed two
    linear heads. The heads read a flattened 4x4x256 feature map, so the
    input resolution must reduce to 4x4 after the convolutions (a 28x28
    input does: 28 -> 14 -> 7 -> 4).

    Sub-module names and ordering are part of the checkpoint format
    (``main.<idx>``, ``fc_mu``, ``fc_logvar``) and must not change.

    Args:
        in_channels: Number of channels of the input images. When ``None``
            the module-level ``channels`` value is read at construction time.
        latent_size: Size of the latent code. When ``None`` the module-level
            ``latent_dim`` value is read at construction time.
    """

    def __init__(self, in_channels=None, latent_size=None):
        super().__init__()
        # Resolve defaults lazily so the class can be defined before the
        # module-level configuration and still be built with explicit sizes.
        if in_channels is None:
            in_channels = channels
        if latent_size is None:
            latent_size = latent_dim

        self.main = nn.Sequential(
            nn.Conv2d(in_channels, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 2, 64 * 4, 3, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.LeakyReLU(0.2, inplace=True),
        )
        flat_features = 4 * 4 * 64 * 4
        self.fc_mu = nn.Linear(flat_features, latent_size)
        self.fc_logvar = nn.Linear(flat_features, latent_size)

    def forward(self, x):
        """Return ``(mu, logvar)``, each of shape ``(batch, latent_size)``."""
        x = self.main(x)
        x = x.view(x.size(0), -1)  # flatten everything but the batch axis
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

# VAE Decoder
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping latent codes to images.

    A linear layer expands the latent code to a 4x4x256 feature map, which
    three stride-2 transposed convolutions upsample 4 -> 7 -> 14 -> 28.
    The final ``Tanh`` bounds the output to [-1, 1].

    Sub-module names and ordering are part of the checkpoint format
    (``fc``, ``main.<idx>``) and must not change.

    Args:
        out_channels: Number of channels of the generated images. When
            ``None`` the module-level ``channels`` value is read at
            construction time.
        latent_size: Size of the latent code. When ``None`` the module-level
            ``latent_dim`` value is read at construction time.
    """

    def __init__(self, out_channels=None, latent_size=None):
        super().__init__()
        # Resolve defaults lazily so the class can be defined before the
        # module-level configuration and still be built with explicit sizes.
        if out_channels is None:
            out_channels = channels
        if latent_size is None:
            latent_size = latent_dim

        self.fc = nn.Linear(latent_size, 4 * 4 * 64 * 4)
        self.main = nn.Sequential(
            nn.ConvTranspose2d(64 * 4, 64 * 2, 3, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(64 * 2, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, out_channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        """Decode latent codes ``(batch, latent_size)`` into images."""
        x = self.fc(x)
        x = x.view(x.size(0), 64 * 4, 4, 4)  # back to a 256-channel 4x4 map
        x = self.main(x)
        return x

# VAE
# Model configuration read by Encoder/Decoder when they are constructed.
latent_dim = 100
channels = 1


class VAE(nn.Module):
    """Variational auto-encoder pairing ``Encoder`` with ``Decoder``."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + sigma * eps`` with ``eps ~ N(0, I)``.

        ``logvar`` is the log-variance, hence ``sigma = exp(logvar / 2)``.
        """
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the input batch ``x``."""
        mu, logvar = self.encoder(x)
        latent = self.reparameterize(mu, logvar)
        return self.decoder(latent), mu, logvar

# Build the model and restore the trained weights.
vae = VAE().to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
state_dict = torch.load("./Weights/MNIST_VAE.pth", map_location=device)
vae.load_state_dict(state_dict)
# Evaluation only: make BatchNorm use its stored running statistics instead
# of the statistics of whatever batch is being sampled.
vae.eval()
generator = vae.decoder

# Fréchet Inception Distance (FID)
print('Calculating FID')
def generate_fake_images(num_images):
    """Sample ``num_images`` images from the decoder and return them as RGB.

    Latent codes are drawn from a standard normal of size ``latent_dim`` on
    ``device`` and decoded by ``generator``. The single output channel is
    repeated three times so the result has shape ``(num_images, 3, H, W)``.

    Sampling runs under ``torch.no_grad()`` because the images are only used
    for evaluation; the returned tensor carries no autograd history.
    """
    with torch.no_grad():
        noise = torch.randn(num_images, latent_dim, device=device)
        fake_images = generator(noise)
    # Convert grayscale images to RGB by duplicating the single channel three times
    fake_images_rgb = fake_images.repeat(1, 3, 1, 1)
    return fake_images_rgb

def _to_unit_range(image, value_range=(-1.0, 1.0)):
    """Linearly rescale ``image`` from ``value_range`` to [0, 1]."""
    low, high = value_range
    return (image - low) / (high - low)


def extract_real_images(data_loader, num_images, save_dir="./data/MNIST/real_images",
                        value_range=(-1.0, 1.0)):
    """Collect up to ``num_images`` images from ``data_loader`` and save them as PNGs.

    Args:
        data_loader: Iterable yielding ``(images, labels)`` batches.
        num_images: Maximum number of images to collect.
        save_dir: Directory receiving ``image_<index>.png`` files; created if
            missing.
        value_range: ``(low, high)`` range of the loader's pixel values. Images
            are rescaled from this range to [0, 1] before being written,
            because ``save_image`` clamps anything outside [0, 1]. Pass
            ``None`` to write the tensors unchanged.

    Returns:
        The collected images stacked into one tensor, with their original
        (un-rescaled) values. Fewer than ``num_images`` are returned if the
        loader runs out.

    Raises:
        RuntimeError: From ``torch.stack`` when no image was collected.
    """
    os.makedirs(save_dir, exist_ok=True)

    real_images = []
    for images, _ in data_loader:
        remaining = num_images - len(real_images)
        for image in images[:max(remaining, 0)]:
            count = len(real_images)
            to_save = image if value_range is None else _to_unit_range(image, value_range)
            torchvision.utils.save_image(to_save, f"{save_dir}/image_{count}.png")
            real_images.append(image)
        if len(real_images) >= num_images:
            break

    return torch.stack(real_images)

# MNIST dataset
batch_size = 128
image_size = 28
channels = 1
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    # Maps pixel values from [0, 1] to [-1, 1], matching the decoder's Tanh output.
    transforms.Normalize((0.5,), (0.5,))
])

mnist_data = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
data_loader = DataLoader(mnist_data, batch_size=batch_size, shuffle=True)

print('Generating Data')
num_images = 1000
real_images = extract_real_images(data_loader, num_images)
fake_images = generate_fake_images(num_images)

# Save generated images. They are in [-1, 1] (Tanh output), so rescale them to
# [0, 1] exactly like the real images; otherwise save_image would clamp every
# negative pixel to black and FID would be computed on clipped images.
os.makedirs("./data/MNIST/fake_images", exist_ok=True)
for idx, img in enumerate(fake_images.detach()):
    torchvision.utils.save_image(_to_unit_range(img), f"./data/MNIST/fake_images/image_{idx}.png")

def calculate_fid(real_images_path, fake_images_path, batch_size=128):
    """Return the FID between two image directories via ``pytorch_fid``.

    Args:
        real_images_path: Directory holding the reference images.
        fake_images_path: Directory holding the generated images.
        batch_size: Batch size forwarded to ``calculate_fid_given_paths``.

    The computation runs on the module-level ``device`` with 2048-dimensional
    Inception v3 features.
    """
    feature_dims = 2048
    image_dirs = [real_images_path, fake_images_path]
    return fid_score.calculate_fid_given_paths(image_dirs, batch_size, device, feature_dims)

# paths for real and fake images
# Directories the real and generated PNGs were written to above.
real_images_path = "./data/MNIST/real_images"
fake_images_path = "./data/MNIST/fake_images"

# Compare the two directories and report the score.
fid = calculate_fid(
    real_images_path=real_images_path,
    fake_images_path=fake_images_path,
)
print("FID:", fid)

# EMD
print('Calculating EMD')
print('Generating Data')
num_images = 1000
# Draw a fresh real/fake sample; extract_real_images also rewrites the PNGs in
# its save directory as a side effect.
real_images = extract_real_images(data_loader, num_images)
fake_images = generate_fake_images(num_images)
# Collapse the RGB copies back to one channel with luma weights. The three
# channels are duplicates and the weights sum to 1, so this recovers the
# grayscale image with shape (N, H, W).
fake_images = (
    fake_images[:, 0] * 0.299
    + fake_images[:, 1] * 0.587
    + fake_images[:, 2] * 0.114
)

def calculate_emd(real_images, fake_images, metric='sqeuclidean'):
    """Return the optimal-transport cost between two image sets.

    Each image is flattened to a vector and both sets are treated as uniform
    empirical distributions (the empty weight lists passed to ``ot.emd2``
    request uniform weights).

    Args:
        real_images: Tensor whose first axis indexes the real images.
        fake_images: Tensor whose first axis indexes the generated images.
        metric: Ground metric forwarded to ``ot.dist``. The default,
            ``'sqeuclidean'``, is ``ot.dist``'s own default and keeps earlier
            results comparable, but note it yields a *squared* 2-Wasserstein
            cost. Pass ``'euclidean'`` for the classical Earth Mover's
            Distance.

    Returns:
        The transport cost as returned by ``ot.emd2``.
    """
    # reshape (not view) so non-contiguous inputs are accepted as well.
    real_flat = real_images.reshape(real_images.size(0), -1).detach().cpu().numpy()
    fake_flat = fake_images.reshape(fake_images.size(0), -1).detach().cpu().numpy()
    cost_matrix = ot.dist(real_flat, fake_flat, metric=metric)
    emd = ot.emd2([], [], cost_matrix)
    return emd
emd = calculate_emd(real_images, fake_images)
print("EMD:", emd)

# PRF Metrics
print('Calculating PRF')
def extract_features(images, model, batch_size=32):
    """Run ``model`` over ``images`` in batches and return its outputs as a NumPy array.

    Each batch is moved to ``device``, expanded from one channel to three by
    repetition, resized to 299x299 and passed through ``model`` in eval mode
    with gradients disabled.

    Args:
        images: Single-channel image tensor of shape ``(N, 1, H, W)``.
        model: Network applied to each ``(B, 3, 299, 299)`` batch.
        batch_size: Number of images per forward pass.

    Returns:
        Array of shape ``(N, D)`` where ``D`` is the flattened size of one
        model output.

    Raises:
        ValueError: From ``np.concatenate`` when ``images`` is empty.
    """
    model.eval()  # Set the model to evaluation mode

    # Resize images to match the Inception model's input size
    resize_transform = transforms.Resize((299, 299))

    num_images = len(images)
    features_list = []

    with torch.no_grad():  # Disable gradient computation to save memory
        for start_idx in range(0, num_images, batch_size):
            # Upsample one batch at a time: resizing everything up front
            # would hold N x 3 x 299 x 299 floats in memory at once.
            batch_images = images[start_idx:start_idx + batch_size].to(device)
            # Convert grayscale images to RGB by duplicating the single channel three times
            batch_images = resize_transform(batch_images.repeat(1, 3, 1, 1))
            batch_features = model(batch_images)
            # Flatten per sample rather than squeeze(): squeeze() would also
            # drop the batch axis of a final batch holding a single image.
            batch_features = batch_features.reshape(batch_features.size(0), -1)
            features_list.append(batch_features.cpu().numpy())

    return np.concatenate(features_list, axis=0)

# Load pre-trained Inception v3 model
# Pre-trained Inception v3 used as the feature extractor, placed on the
# evaluation device and switched to eval mode (eval() returns the module).
inception = inception_v3(pretrained=True, transform_input=True).to(device).eval()

print('Generating Data')
# extract_features expects an explicit channel axis: (N, 1, 28, 28).
fake_images = fake_images.reshape(-1, 1, 28, 28)
real_features = extract_features(images=real_images, model=inception)
fake_features = extract_features(images=fake_images, model=inception)

# Train a classifier
# Assemble a real-vs-fake classification problem: label 1 for features of real
# images, label 0 for features of generated images.
X = np.concatenate([real_features, fake_features], axis=0)
y = np.concatenate([np.ones(len(real_features)), np.zeros(len(fake_features))], axis=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42
)
# n_jobs=-1 uses all available CPU cores
clf = RandomForestClassifier(random_state=0, n_jobs=-1)
clf.fit(X_train, y_train)

# Class probabilities for the held-out split
y_pred_proba = clf.predict_proba(X_test)

# Column 1 is the probability of the positive ("real") class
y_pred_scores = y_pred_proba[:, 1]

# Threshold at 0.5 and score the classifier on the positive class
precision, recall, f1_score, _ = precision_recall_fscore_support(
    y_test, y_pred_scores > 0.5, average='binary'
)
print("Precision:", precision)
print("Recall:", recall)
print("F1-score:", f1_score)

# Inception Score (IS)
print('Calculating IS')
def generate_fake_images(num_images):
    """Sample ``num_images`` images and return them as 299x299 RGB tensors.

    This definition replaces the earlier ``generate_fake_images`` in this
    file; it additionally upsamples the images to the 299x299 resolution that
    Inception v3 expects. The result has shape ``(num_images, 3, 299, 299)``
    and lives on ``device``.

    Everything runs under ``torch.no_grad()``: the images are only scored, so
    keeping an autograd graph for a tensor of this size would waste memory.
    """
    with torch.no_grad():
        noise = torch.randn(num_images, latent_dim, device=device)
        fake_images = generator(noise)
        # Convert grayscale images to RGB by duplicating the single channel three times
        fake_images_rgb = fake_images.repeat(1, 3, 1, 1)
        # Resize images to 299x299; resize() accepts batched (..., H, W)
        # tensors, so no per-image loop is needed.
        fake_images_rgb_resized = resize(fake_images_rgb, (299, 299))
    return fake_images_rgb_resized

def inception_score(images, n_splits=10, model=None):
    """Compute the Inception Score of ``images``.

    The score is ``exp(E_x[KL(p(y|x) || p(y))])``, which equals
    ``exp(H(y) - E_x[H(y|x)])``: the entropy of the marginal class
    distribution minus the mean entropy of the per-image predictions. It is
    evaluated on ``n_splits`` consecutive, equally sized chunks, and the
    per-chunk log-scores are averaged before exponentiating.

    Args:
        images: Batch of inputs accepted by ``model``; the first axis indexes
            images. Trailing images that do not fill a whole chunk are ignored.
        n_splits: Number of chunks.
        model: Classifier returning logits of shape ``(batch, classes)``. When
            ``None`` a pre-trained Inception v3 is loaded onto ``device``.

    Returns:
        The Inception Score as a NumPy float (>= 1 up to rounding).

    Raises:
        ValueError: If ``n_splits`` is not positive or there are fewer images
            than splits.
    """
    if n_splits < 1:
        raise ValueError("n_splits must be a positive integer")
    n_total = images.shape[0]
    chunk_size = n_total // n_splits
    if chunk_size == 0:
        raise ValueError(
            f"need at least n_splits={n_splits} images, got {n_total}"
        )

    if model is None:
        # Load pre-trained Inception model
        model = inception_v3(pretrained=True, transform_input=True).to(device)
    model.eval()

    scores = []
    for k in range(n_splits):
        images_chunk = images[k * chunk_size: (k + 1) * chunk_size]
        with torch.no_grad():
            logits = model(images_chunk)
        p_yx = torch.softmax(logits, dim=1).cpu().numpy()
        p_y = np.mean(p_yx, axis=0)
        # axis=1 gives one entropy per image; scipy's default axis=0 would mix
        # predictions across images. H(y) comes first so the difference is the
        # (non-negative) mean KL divergence.
        scores.append(entropy(p_y) - entropy(p_yx, axis=1).mean())
    return np.exp(np.mean(scores))

print('Generating Data')
fake_images = generate_fake_images(1000)
# Keep the result under its own name: assigning it to `inception_score` would
# overwrite the function and make this cell fail when run a second time.
is_value = inception_score(fake_images)
print("Inception Score:", is_value)