In [1]:
import os

# Run from the project root so the local `dinr` and `configs` packages are importable.
# The notebook directory is remembered on the first run, so re-running this cell does not
# keep climbing one directory higher each time.
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, '..'))

from collections import OrderedDict
from time import time

import numpy as np
import torch
import torch.autograd.profiler as profiler
import torch.nn as nn
from hydra.core.global_hydra import GlobalHydra
from hydra.experimental import compose, initialize
from omegaconf import OmegaConf

import dinr.modeling
from configs import trainer_conf
from dinr.modeling.modules.inr_generator import INRGenerator
from dinr.modeling.modules.noise import mixing_noise

# Hydra raises if initialize() is called twice in one process; only initialize once per kernel
# so this cell can be re-run safely.
if not GlobalHydra.instance().is_initialized():
    initialize(config_path="../configs", job_name="test_app")

# Fall back to CPU when no GPU is available instead of failing later on `.to('cuda')`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [7]:
from dinr.modeling.metrics.inception import InceptionV3Wrapper

# Dataset config to evaluate, and the local image root for each supported dataset config.
# NOTE(review): these are machine-specific absolute paths — adjust for your environment.
DATASET_ROOTS = {
    'lsun10': '/tmp/skoroki/data/lsun/lsun10',
    'mini_imagenet': '/tmp/skoroki/data/imagenet/mini-imagenet-128',
}
DATASET = 'lsun10'

# NOTE(review): only the *test* split root is overridden here, while the following cells build
# the *train* split — confirm the train split resolves to the intended data.
config = compose(config_name="config", overrides=[
    'model=stylegan2_inrgan',
    'system=stylegan2_noppl_nomix_inr',
    'dataloader.total_batch_size=128',
    f'datasets={DATASET}',
    f'datasets.test.0.dataset_class.root={DATASET_ROOTS[DATASET]}',
], return_hydra_config=True)

In [8]:
from dinr.data.build import build_datasets, build_loaders
from dinr.modeling.metrics.inception_score import calculate_inception_score
from dinr.modeling.metrics.fid import calculate_activation_statistics, calculate_frechet_distance

# Training split of the configured dataset; only the first loader returned is used.
ds_train = build_datasets(config, 'train')
loader_train = build_loaders(
    config,
    ds_train,
    'train',
)[0]

# Inception network wrapped for multi-GPU use, moved to `device` and put in eval mode
# (nn.Module.eval() returns the module itself, so it can be chained).
inception = nn.DataParallel(InceptionV3Wrapper()).to(device).eval()

In [9]:
from tqdm import tqdm
import warnings

NUM_IMAGES = 50000  # how many real images to embed for IS/FID

# Per-batch CPU tensors; concatenated once after the loop.
all_logits = []
all_feats = []
num_collected = 0

with torch.no_grad():
    for batch in tqdm(loader_train):
        img = batch['img'].to(device)
        # NOTE(review): assumes the loader yields images in [-1, 1] and the Inception
        # wrapper expects [0, 1] — confirm against the dataset transforms.
        img = (img + 1.0) / 2.0
        batch_feats, batch_logits = inception(img)
        # Store whole batches: extending a list with one tensor per row and then
        # torch.stack-ing ~50k rows is much slower than a single torch.cat.
        all_logits.append(batch_logits.cpu())
        all_feats.append(batch_feats.cpu())
        num_collected += batch_logits.shape[0]

        if num_collected >= NUM_IMAGES:
            break

if num_collected == 0:
    raise RuntimeError('loader_train yielded no images; cannot compute IS/FID.')
if num_collected < NUM_IMAGES:
    # The loader ran out before NUM_IMAGES was reached; make the shortfall visible
    # instead of silently scoring fewer images than requested.
    warnings.warn(f'Collected only {num_collected} of {NUM_IMAGES} requested images; '
                  'IS/FID are computed on this smaller sample.')

logits = torch.cat(all_logits)[:NUM_IMAGES]
feats = torch.cat(all_feats)[:NUM_IMAGES]

390it [01:52,  3.48it/s]


In [10]:
# Per-image class probabilities from the Inception logits.
# detach()/cpu() are defensive: `logits` was already built on the CPU inside torch.no_grad().
probs = logits.softmax(dim=1).detach().cpu().numpy()
# Slice with NUM_IMAGES rather than a duplicated literal, so the sample size has one source of truth.
calculate_inception_score(probs[:NUM_IMAGES], num_splits=10)

(9.926232, 0.18073939)

In [11]:
# Precomputed reference statistics; using the NpzFile as a context manager closes the file handle.
with np.load(config.env.fid_stats_path) as real_stats:
    mean_real = real_stats['mean']
    cov_real = real_stats['cov']
# "fake" here is just the comparison side: `feats` were embedded from the real training loader,
# so this cell sanity-checks the stored statistics against real data (FID should be near 0).
mean_fake, cov_fake = calculate_activation_statistics(feats.numpy())
fid = calculate_frechet_distance(mean_real, cov_real, mean_fake, cov_fake)
fid

0.421438506559241