In [1]:
# Set to True when running on Google Colab with the FID assets (checkpoints,
# Inception weights) stored in Google Drive under MyDrive/fid-files/.
USE_COLAB_DRIVE = False

if USE_COLAB_DRIVE:
    from google.colab import drive
    drive.mount('/content/drive/')
    base_path = '/content/drive/MyDrive/fid-files/'
else:
    # Asset files are looked up relative to the current working directory.
    base_path = ''

In [1]:
import torch
from torch import nn
from torchvision.models import inception_v3
from torchvision.datasets import CelebA
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader

from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def display_gen_pred(imgs, size=None):
    '''
    Show up to four images from a batch as a 2x2 grid.
    Parameters:
        imgs: tensor of images with values in [-1, 1] (the range of the generator's Tanh
              output and of the dataset's Normalize((0.5,)*3, (0.5,)*3) transform)
        size: per-image shape (channels, height, width) used to unflatten `imgs`;
              when None, `imgs` must already be batched as (N, C, H, W)
    '''
    imgs = imgs.detach().cpu()
    if size is None:
        size = imgs.shape[1:]
    img_unflat = imgs.view(-1, *size)
    # Map [-1, 1] back to [0, 1]: imshow clips float images to [0, 1], so without this
    # every negative value would be rendered as black.
    img_unflat = ((img_unflat + 1) / 2).clamp(0, 1)
    img_grid = make_grid(img_unflat[:4], nrow=2)
    plt.imshow(img_grid.permute(1, 2, 0).squeeze())
    plt.axis('off')
    plt.show()

# Generator

In [4]:
class Generator(nn.Module):
    '''
    DCGAN-style generator that turns noise vectors into images.
    Values:
        z_dim: length of each input noise vector, a scalar
        im_chan: number of channels in the produced images, a scalar
              (3 for an RGB dataset such as CelebA)
        hidden_dim: base width of the hidden stages, a scalar; the four hidden
              stages are 8x, 4x, 2x and 1x this width
    '''
    def __init__(self, z_dim=10, im_chan=3, hidden_dim=64):
        super().__init__()
        self.z_dim = z_dim
        # Channel widths along the upsampling path, starting from the z_dim-channel 1x1 input.
        widths = [z_dim, hidden_dim * 8, hidden_dim * 4, hidden_dim * 2, hidden_dim]
        stages = [
            self.make_gen_block(c_in, c_out)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        # The last stage uses a 4x4 kernel and ends in Tanh rather than BatchNorm + ReLU.
        stages.append(self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True))
        # Attribute name and nesting determine state_dict keys ('gen.<stage>.<layer>.*'),
        # which must match saved checkpoints.
        self.gen = nn.Sequential(*stages)

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        '''
        Build one generator stage: a transposed convolution followed by either
        BatchNorm + ReLU (hidden stages) or Tanh (final stage).
        Parameters:
            input_channels: channel count of the incoming feature map
            output_channels: channel count of the produced feature map
            kernel_size: side length of the square transposed-convolution kernel
            stride: stride of the transposed convolution
            final_layer: True for the output stage (Tanh, no BatchNorm), False otherwise
        '''
        layers = [nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride)]
        if final_layer:
            layers.append(nn.Tanh())
        else:
            layers += [nn.BatchNorm2d(output_channels), nn.ReLU(inplace=True)]
        return nn.Sequential(*layers)

    def forward(self, noise):
        '''
        Generate a batch of images from noise.
        Parameters:
            noise: tensor of shape (n_samples, z_dim)
        Returns a tensor of shape (n_samples, im_chan, H, W) with values in [-1, 1];
        with the default kernel sizes and stride 2 the spatial size grows
        1 -> 3 -> 7 -> 15 -> 31 -> 64.
        '''
        # Treat each noise vector as a z_dim-channel 1x1 feature map.
        seed_maps = noise.view(len(noise), self.z_dim, 1, 1)
        return self.gen(seed_maps)

def get_noise(n_samples, z_dim, device='cpu'):
    '''
    Draw standard-normal noise vectors for the generator.
    Parameters:
        n_samples: how many vectors to draw, a scalar
        z_dim: length of each vector, a scalar
        device: device on which the tensor is allocated
    Returns a tensor of shape (n_samples, z_dim).
    '''
    shape = (n_samples, z_dim)
    return torch.randn(shape, device=device)

In [3]:
# IMG_SIZE: side length fed to Inception-v3; Z_DIM must match the pretrained generator
# checkpoint loaded below; DEVICE: where models and tensors live.
IMG_SIZE, Z_DIM, DEVICE = 299, 64, 'cpu'

In [None]:
# Real images: resize/crop to the Inception input size and scale to [-1, 1],
# the same range as the generator's Tanh output.
# NOTE(review): real images are resized straight to IMG_SIZE, while generated 64x64
# images are upsampled by `preprocess`; confirm this resolution mismatch is intended for FID.
transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])

dataset = CelebA(root='.', download=True, transform=transform)

# dataset = torch.Tensor(np.load(base_path+'fid_images_tensor.npz', allow_pickle=True)['arr_0'])
# dataset = torch.utils.data.TensorDataset(dataset)

# Initializing the generator, optimizer, and Inception feature extractor

In [10]:
# Build the generator and restore its pretrained CelebA weights (stored under the 'gen' key).
# NOTE(review): torch.load unpickles the file, so only load checkpoints from trusted sources.
gen = Generator(z_dim=Z_DIM)
gen = gen.to(DEVICE)
gen.load_state_dict(
    torch.load(base_path + 'pretrained_celeba.pth', map_location=torch.device(DEVICE))['gen']
)

<All keys matched successfully>

In [11]:
# Adam over the generator's parameters only; the Inception network is never optimized.
LR = 2e-4
BETAS = (0.5, 0.999)
optim = torch.optim.Adam(params=gen.parameters(), lr=LR, betas=BETAS)

In [12]:
# Inception-v3 is used purely as a fixed feature extractor for FID.
# NOTE(review): torch.load unpickles the file, so only load weights from trusted sources.
inception = inception_v3(pretrained=False)
# map_location mirrors the generator checkpoint load so the weights land on DEVICE.
inception.load_state_dict(
    torch.load(base_path + 'inception_v3_google-1a9a5a14.pth', map_location=torch.device(DEVICE))
)
inception.to(DEVICE)
inception = inception.eval()
# Freeze the weights: only the generator is optimized. Gradients still flow through the
# network to its inputs, so a loss on generated-image features can reach the generator.
inception.requires_grad_(False)

# Drop the classification head so forward() returns the pooled feature vector.
inception.fc = torch.nn.Identity()



# FID (Fréchet Inception Distance)

In [13]:
from scipy import linalg
def fid(x, y):
    '''
    Frechet distance between Gaussians fitted to two sets of feature vectors:

        FID = ||mu_x - mu_y||^2 + Tr(S_x + S_y - 2 (S_x S_y)^(1/2))

    Parameters:
        x: features of the first set, shape (n_x, d), one row per sample
        y: features of the second set, shape (n_y, d)
    Returns a 0-dim tensor on the device and with the dtype of the covariance matrices.

    NOTE(review): the matrix square root is computed by scipy on detached numpy data,
    so no gradient flows through that term. Used as a training loss, only the mean
    term and Tr(S_x + S_y) back-propagate.
    '''
    x_mean = torch.mean(x, dim=0)
    y_mean = torch.mean(y, dim=0)

    # torch.cov treats rows as variables, so transpose to get the (d, d) feature covariance.
    x_sig = torch.cov(x.T)
    y_sig = torch.cov(y.T)

    def mat_sqrt(m):
        root = linalg.sqrtm(m.cpu().detach().numpy())
        # sqrtm can return a complex array whose imaginary part is round-off noise.
        return torch.as_tensor(root.real, dtype=m.dtype, device=m.device)

    # Squared Euclidean distance between the means. It is always >= 0, and differences
    # in different dimensions cannot cancel each other out.
    mean_term = torch.sum((x_mean - y_mean) ** 2)
    cov_term = torch.trace(x_sig + y_sig - 2 * mat_sqrt(torch.matmul(x_sig, y_sig)))
    return mean_term + cov_term

In [15]:
def preprocess(imgs):
    '''Bilinearly resize a (N, C, H, W) image batch to IMG_SIZE x IMG_SIZE for Inception.'''
    resize = torch.nn.functional.interpolate
    target = (IMG_SIZE, IMG_SIZE)
    return resize(imgs, size=target, mode='bilinear', align_corners=False)

In [16]:
BATCH_SIZE = 10

# Reshuffled every epoch; the last batch may hold fewer than BATCH_SIZE images.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [17]:
def train(no_epochs):
    '''
    Iterate over `dataloader` for `no_epochs` epochs, pushing every real batch through
    `inception` to obtain its feature vectors.

    The FID-based generator update is not enabled yet. It is kept below as commented-out
    code (marked TODO) to switch on once real-feature extraction works.
    Parameters:
        no_epochs: number of passes over the dataloader, a non-negative int
    '''
    tqdm_obj = tqdm(range(no_epochs))
    no_of_batches = len(dataloader)

    for epoch in tqdm_obj:
        real_features_all = []
        fake_features_all = []

        for i, batch in enumerate(dataloader):
            # torchvision's CelebA yields (image, target) pairs, while a TensorDataset yields
            # 1-tuples. The images come first in both, so index instead of unpacking a fixed
            # arity (unpacking into `[real_sample]` fails on CelebA's 2-element batches).
            real_sample = batch[0] if isinstance(batch, (list, tuple)) else batch
            tqdm_obj.set_postfix({'Batch': f'{i}/{no_of_batches}'})
            # Real features are a fixed target, so no autograd graph is needed for them.
            with torch.no_grad():
                real_features = inception(real_sample.to(DEVICE))
            # real_features_all.append(real_features)

            # TODO: enable the generator update once real-feature extraction is verified.
            # fake_sample = gen(get_noise(len(real_sample), Z_DIM, device=DEVICE))
            # fake_sample = preprocess(fake_sample)
            # fake_features = inception(fake_sample)
            # fake_features_all.append(fake_features)

        # NOTE(review): accumulating graph-carrying fake features over a whole epoch keeps
        # every batch's autograd graph alive until backward(); consider per-batch updates.
        # fake_features = torch.cat(fake_features_all)
        # real_features = torch.cat(real_features_all)
        # loss = fid(real_features, fake_features)
        # optim.zero_grad()
        # loss.backward()
        # optim.step()

        # if epoch % 3 == 0:
        #     img = fake_sample[:4].detach().cpu()
        #     display_gen_pred(img, img.shape[1:])

In [18]:
# Single pass over the data (currently only extracts real-image features).
train(no_epochs=1)

  0%|                                                     | 0/1 [00:00<?, ?it/s]


RuntimeError: output with shape [1, 64, 64] doesn't match the broadcast shape [3, 64, 64]