In [1]:
import os
import time
import torch
import datetime

import torch.nn as nn
from torch.autograd import Variable
from torchvision.utils import save_image

from sagan_models import Generator, Discriminator
from utils import *

In [2]:
from parameter import *
from trainer import Trainer
from data_loader import Data_Loader
from torch.backends import cudnn
from utils import make_folder
import torch_fidelity

In [3]:
# Parse the default hyper-parameters into `config`; the next cell overrides
# the fields needed for this evaluation run.
config = get_parameters()

In [4]:
# Evaluation overrides applied on top of the parsed defaults.
config.batch_size = 64           # images produced per generator call in Tester.test()
config.imsize = 64               # image size passed to Generator/Discriminator
config.train = False
config.pretrained_model = 996975  # placeholder; the evaluation loop reassigns it per checkpoint
config.dataset = "celeb"
config.version = "sagan_celeb"   # sub-directory under model_save_path holding the checkpoints
config.adv_loss = 'hinge'

In [None]:
# data_loader = Data_Loader(config.train, config.dataset, config.image_path, config.imsize,
#                          config.batch_size, shuf=config.train)

In [None]:
# data_iter=iter(data_loader.loader())
# for i in data_iter:
#     print(len(i))
#     print(i[0].shape)
#     break

In [None]:
# trainer = Trainer(data_loader.loader(), config)

In [None]:
# trainer.G

In [None]:
# from torchsummary import summary
# summary(trainer.G,(128,1))
# summary(trainer.D,(3,64,64))

In [5]:
class Tester(object):
    """Sample images from a trained SAGAN checkpoint for offline evaluation.

    All hyper-parameters are read from ``config``. The constructor builds G
    and D on the GPU and, when ``config.pretrained_model`` is truthy, loads
    ``<model_save_path>/<version>/<pretrained_model>_G.pth`` / ``_D.pth``.
    ``test()`` then writes generated images to
    ``./generated_dataset<pretrained_model>/``.
    """

    def __init__(self, config):
        # Model and loss settings
        self.model = config.model
        self.adv_loss = config.adv_loss

        # Model hyper-parameters
        self.imsize = config.imsize
        self.g_num = config.g_num
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.parallel = config.parallel

        # Training hyper-parameters. When sampling, only batch_size (images
        # per generator call) and the optimizer settings read by build_model
        # are consumed.
        self.lambda_gp = config.lambda_gp
        self.total_step = config.total_step
        self.d_iters = config.d_iters
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.lr_decay = config.lr_decay
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.pretrained_model = config.pretrained_model

        self.dataset = config.dataset
        self.use_tensorboard = config.use_tensorboard
        self.image_path = config.image_path
        self.log_path = config.log_path
        self.model_save_path = config.model_save_path
        self.sample_path = config.sample_path
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.version = config.version

        # Per-version directories (replace the raw paths assigned above)
        self.log_path = os.path.join(config.log_path, self.version)
        self.sample_path = os.path.join(config.sample_path, self.version)
        self.model_save_path = os.path.join(config.model_save_path, self.version)

        self.build_model()

        # No tensorboard logger is built. Sampling never logs, and this class
        # has no build_tensorboard() method, so calling it raised
        # AttributeError whenever use_tensorboard was enabled.

        # Start with trained model
        if self.pretrained_model:
            self.load_pretrained_model()

    def test(self, num_batches=1000):
        """Write ``num_batches * batch_size`` generated images to disk.

        Images go to ``./generated_dataset<pretrained_model>/`` and are named
        ``<k>_fake.png`` for k = 1 .. num_batches * batch_size. With the
        defaults used in this notebook that is 1000 * 64 = 64000 images.

        Args:
            num_batches: number of generator calls, each producing
                ``self.batch_size`` images.
        """
        folder_name = 'generated_dataset{}'.format(self.pretrained_model)
        make_folder('./', folder_name)
        out_dir = os.path.join('.', folder_name)

        # NOTE(review): G is never switched to eval() here, so any
        # normalization layers run in their default (training) mode. This is
        # left unchanged to keep results comparable; confirm it is intended.
        with torch.no_grad():  # sampling only: no autograd graph is needed
            for batch_idx in range(num_batches):
                # Must be torch.randn (standard normal), not torch.rand (uniform).
                rand_z = tensor2var(torch.randn(self.batch_size, self.z_dim))
                fake_images, _, _ = self.G(rand_z)
                for j in range(self.batch_size):
                    # Offset by the real batch size so file names stay unique
                    # and contiguous for any batch_size, not only 64.
                    image_id = batch_idx * self.batch_size + j + 1
                    save_image(denorm(fake_images[j]),
                               os.path.join(out_dir, '{}_fake.png'.format(image_id)))

    def build_model(self):
        """Create G and D on the GPU (optionally DataParallel) plus their optimizers."""
        self.G = Generator(self.batch_size, self.imsize, self.z_dim, self.g_conv_dim).cuda()
        self.D = Discriminator(self.batch_size, self.imsize, self.d_conv_dim).cuda()
        if self.parallel:
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)

        # Loss and optimizer. test() does not use these; they are kept so the
        # attributes remain available on the object.
        self.g_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.G.parameters()), self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.D.parameters()), self.d_lr, [self.beta1, self.beta2])

        self.c_loss = torch.nn.CrossEntropyLoss()

    def load_pretrained_model(self):
        """Load ``<step>_G.pth`` and ``<step>_D.pth`` state dicts from model_save_path."""
        self.G.load_state_dict(torch.load(os.path.join(
            self.model_save_path, '{}_G.pth'.format(self.pretrained_model))))
        self.D.load_state_dict(torch.load(os.path.join(
            self.model_save_path, '{}_D.pth'.format(self.pretrained_model))))
        print('loaded trained models (step: {})..!'.format(self.pretrained_model))

In [None]:
# !python -m pytorch_fid ./generated_dataset996975 ./data/CelebA/img_align_celeba/img_align_celeba --device cuda:0

In [None]:
# FID=torch.zeros(10)
# FID[0]=98.36766537081229
# FID[5]=99.15743182237708
# FID[9]=106.43966820133824
# print(FID)

In [6]:
# Evaluate every saved checkpoint with torch-fidelity: Inception Score on the
# generated samples and FID of the generated samples against the real CelebA images.
REAL_IMAGES_DIR = './data/CelebA/img_align_celeba/img_align_celeba'
# Naming input2 makes torch-fidelity cache its Inception features, so the
# ~200k real images are encoded once instead of once per checkpoint.
# The cache is keyed by this name only: change it if REAL_IMAGES_DIR's contents change.
REAL_FEATURES_CACHE = 'celeba_img_align_celeba'

ISCs = []
FIDs = []
pretrained_models = [101280, 202560, 300675, 401955, 500070,
                     601350, 702630, 800745, 902025, 996975]
for step in pretrained_models:
    config.pretrained_model = step
    generated_dir = './generated_dataset{}'.format(step)

    # Generate samples only once per checkpoint; an existing folder is reused
    # as-is (NOTE: including a partial one left by an interrupted run).
    if not os.path.exists(generated_dir):
        tester = Tester(config)
        tester.test()

    metrics_dict = torch_fidelity.calculate_metrics(
        input1=generated_dir,
        input2=REAL_IMAGES_DIR,
        input2_cache_name=REAL_FEATURES_CACHE,
        cuda=True,
        fid=True,
        isc=True,
    )

    print(metrics_dict)
    FIDs.append(metrics_dict['frechet_inception_distance'])
    ISCs.append(metrics_dict['inception_score_mean'])

loaded trained models (step: 101280)..!


Creating feature extractor "inception-v3-compat" with features ['2048', 'logits_unbiased']
Downloading: "https://github.com/toshas/torch-fidelity/releases/download/v0.2.0/weights-inception-2015-12-05-6726825d.pth" to C:\Users\Owner/.cache\torch\hub\checkpoints\weights-inception-2015-12-05-6726825d.pth
100%|██████████| 91.2M/91.2M [00:02<00:00, 37.9MB/s]
Extracting features from input1
Looking for samples non-recursivelty in "./generated_dataset101280" with extensions png,jpg,jpeg
Found 64000 samples
  img = torch.ByteTensor(torch.ByteStorage.from_buffer(img.tobytes())).view(height, width, 3)
Processing samples                                                              
Extracting features from input2
Looking for samples non-recursivelty in "./data/CelebA/img_align_celeba/img_align_celeba" with extensions png,jpg,jpeg
Found 202599 samples, some are lossy-compressed - this may affect metrics
Processing samples                                                                
Inception Sc

{'inception_score_mean': 1.9817197514882445, 'inception_score_std': 0.012518461887355354, 'frechet_inception_distance': 93.19936273210743}
loaded trained models (step: 202560)..!


Creating feature extractor "inception-v3-compat" with features ['2048', 'logits_unbiased']
Extracting features from input1
Looking for samples non-recursivelty in "./generated_dataset202560" with extensions png,jpg,jpeg
Found 64000 samples
Processing samples                                                              
Extracting features from input2
Looking for samples non-recursivelty in "./data/CelebA/img_align_celeba/img_align_celeba" with extensions png,jpg,jpeg
Found 202599 samples, some are lossy-compressed - this may affect metrics
Processing samples                                                                
Inception Score: 2.001442248183252 ± 0.014230098780839487
Frechet Inception Distance: 96.31385320822477


{'inception_score_mean': 2.001442248183252, 'inception_score_std': 0.014230098780839487, 'frechet_inception_distance': 96.31385320822477}
loaded trained models (step: 300675)..!


Creating feature extractor "inception-v3-compat" with features ['2048', 'logits_unbiased']
Extracting features from input1
Looking for samples non-recursivelty in "./generated_dataset300675" with extensions png,jpg,jpeg
Found 64000 samples
Processing samples                                                              
Extracting features from input2
Looking for samples non-recursivelty in "./data/CelebA/img_align_celeba/img_align_celeba" with extensions png,jpg,jpeg
Found 202599 samples, some are lossy-compressed - this may affect metrics
Processing samples                                                                
Inception Score: 1.983259409584378 ± 0.011687999878698471
Frechet Inception Distance: 100.26028524625696


{'inception_score_mean': 1.983259409584378, 'inception_score_std': 0.011687999878698471, 'frechet_inception_distance': 100.26028524625696}
loaded trained models (step: 401955)..!


Creating feature extractor "inception-v3-compat" with features ['2048', 'logits_unbiased']
Extracting features from input1
Looking for samples non-recursivelty in "./generated_dataset401955" with extensions png,jpg,jpeg
Found 64000 samples
Processing samples                                                              
Extracting features from input2
Looking for samples non-recursivelty in "./data/CelebA/img_align_celeba/img_align_celeba" with extensions png,jpg,jpeg
Found 202599 samples, some are lossy-compressed - this may affect metrics
Processing samples                                                                
Inception Score: 1.980664936185562 ± 0.007506401083250908
Frechet Inception Distance: 98.23548190791382


{'inception_score_mean': 1.980664936185562, 'inception_score_std': 0.007506401083250908, 'frechet_inception_distance': 98.23548190791382}
loaded trained models (step: 500070)..!


Creating feature extractor "inception-v3-compat" with features ['2048', 'logits_unbiased']
Extracting features from input1
Looking for samples non-recursivelty in "./generated_dataset500070" with extensions png,jpg,jpeg
Found 64000 samples
Processing samples                                                              
Extracting features from input2
Looking for samples non-recursivelty in "./data/CelebA/img_align_celeba/img_align_celeba" with extensions png,jpg,jpeg
Found 202599 samples, some are lossy-compressed - this may affect metrics
Processing samples                                                                
Inception Score: 1.8681145732343416 ± 0.016126393387134886
Frechet Inception Distance: 107.43194621176099


{'inception_score_mean': 1.8681145732343416, 'inception_score_std': 0.016126393387134886, 'frechet_inception_distance': 107.43194621176099}
loaded trained models (step: 601350)..!


Creating feature extractor "inception-v3-compat" with features ['2048', 'logits_unbiased']
Extracting features from input1
Looking for samples non-recursivelty in "./generated_dataset601350" with extensions png,jpg,jpeg
Found 64000 samples
Processing samples                                                              
Extracting features from input2
Looking for samples non-recursivelty in "./data/CelebA/img_align_celeba/img_align_celeba" with extensions png,jpg,jpeg
Found 202599 samples, some are lossy-compressed - this may affect metrics
Processing samples                                                                
Inception Score: 1.9287016672082309 ± 0.012127940698830808
Frechet Inception Distance: 101.87529504188609


{'inception_score_mean': 1.9287016672082309, 'inception_score_std': 0.012127940698830808, 'frechet_inception_distance': 101.87529504188609}
loaded trained models (step: 702630)..!


Creating feature extractor "inception-v3-compat" with features ['2048', 'logits_unbiased']
Extracting features from input1
Looking for samples non-recursivelty in "./generated_dataset702630" with extensions png,jpg,jpeg
Found 64000 samples
Processing samples                                                              
Extracting features from input2
Looking for samples non-recursivelty in "./data/CelebA/img_align_celeba/img_align_celeba" with extensions png,jpg,jpeg
Found 202599 samples, some are lossy-compressed - this may affect metrics
Processing samples                                                                
Inception Score: 1.8775383581699998 ± 0.010941680038397314
Frechet Inception Distance: 101.23040864875799


{'inception_score_mean': 1.8775383581699998, 'inception_score_std': 0.010941680038397314, 'frechet_inception_distance': 101.23040864875799}
loaded trained models (step: 800745)..!


Creating feature extractor "inception-v3-compat" with features ['2048', 'logits_unbiased']
Extracting features from input1
Looking for samples non-recursivelty in "./generated_dataset800745" with extensions png,jpg,jpeg
Found 64000 samples
Processing samples                                                              
Extracting features from input2
Looking for samples non-recursivelty in "./data/CelebA/img_align_celeba/img_align_celeba" with extensions png,jpg,jpeg
Found 202599 samples, some are lossy-compressed - this may affect metrics
Processing samples                                                                
Inception Score: 1.952939847577786 ± 0.00837152962010387
Frechet Inception Distance: 97.94280041983791


{'inception_score_mean': 1.952939847577786, 'inception_score_std': 0.00837152962010387, 'frechet_inception_distance': 97.94280041983791}
loaded trained models (step: 902025)..!


Creating feature extractor "inception-v3-compat" with features ['2048', 'logits_unbiased']
Extracting features from input1
Looking for samples non-recursivelty in "./generated_dataset902025" with extensions png,jpg,jpeg
Found 64000 samples
Processing samples                                                              
Extracting features from input2
Looking for samples non-recursivelty in "./data/CelebA/img_align_celeba/img_align_celeba" with extensions png,jpg,jpeg
Found 202599 samples, some are lossy-compressed - this may affect metrics
Processing samples                                                                
Inception Score: 2.0940204584227744 ± 0.009608713965648719
Frechet Inception Distance: 98.12563405271786


{'inception_score_mean': 2.0940204584227744, 'inception_score_std': 0.009608713965648719, 'frechet_inception_distance': 98.12563405271786}
loaded trained models (step: 996975)..!


Creating feature extractor "inception-v3-compat" with features ['2048', 'logits_unbiased']
Extracting features from input1
Looking for samples non-recursivelty in "./generated_dataset996975" with extensions png,jpg,jpeg
Found 64000 samples
Processing samples                                                              
Extracting features from input2
Looking for samples non-recursivelty in "./data/CelebA/img_align_celeba/img_align_celeba" with extensions png,jpg,jpeg
Found 202599 samples, some are lossy-compressed - this may affect metrics
Processing samples                                                                
Inception Score: 2.0738501509172895 ± 0.01245419441592559
Frechet Inception Distance: 94.7430723149806


{'inception_score_mean': 2.0738501509172895, 'inception_score_std': 0.01245419441592559, 'frechet_inception_distance': 94.7430723149806}


In [7]:
# FID for each checkpoint, in the same order as `pretrained_models`.
FIDs

[93.19936273210743,
 96.31385320822477,
 100.26028524625696,
 98.23548190791382,
 107.43194621176099,
 101.87529504188609,
 101.23040864875799,
 97.94280041983791,
 98.12563405271786,
 94.7430723149806]

In [8]:
# Mean Inception Score for each checkpoint, in the same order as `pretrained_models`.
ISCs

[1.9817197514882445,
 2.001442248183252,
 1.983259409584378,
 1.980664936185562,
 1.8681145732343416,
 1.9287016672082309,
 1.8775383581699998,
 1.952939847577786,
 2.0940204584227744,
 2.0738501509172895]