forked from bilocq/Trust-the-Critics
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ttc_eval.py
227 lines (170 loc) · 8.89 KB
/
ttc_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
"""
Takes a list of critics and step sizes, and produces a sequence of pictures using TTC.
Optional evaluation of FID and MMD (the latter not used in paper).
This code uses critics trained with ttc.py to generate samples. Arguments whose names are repeated from
ttc.py should be set to the same values that were used in ttc.py when training the critics.
"""
import os, sys
sys.path.append(os.getcwd())
sys.path.append(os.path.join(os.getcwd(),'TTC_utils'))
import argparse
import time
import shutil
import random
import numpy as np
import torch
import pandas as pd
import pickle
from tqdm import tqdm
import matplotlib
matplotlib.use('Agg')
import dataloader
import networks
from generate_samples import generate_image
from generate_samples import save_individuals
from mmd import mmd
from steptaker import steptaker
from pytorch_fid import fid_score
from pytorch_fid import inception
#get command line args~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
parser = argparse.ArgumentParser('TTC Evaluation Code')
parser.add_argument('--target', type=str, default='cifar10', choices=['cifar10','mnist','fashion', 'celeba', 'monet', 'celebaHQ'])
# Fixed: the previous default ('cifar10') was not one of the declared choices.
# argparse does not validate defaults against `choices`, so the invalid value
# slipped through silently when --source was omitted.
parser.add_argument('--source', type=str, default='noise', choices=['noise', 'untrained_gen', 'photo'])
parser.add_argument('--temp_dir', type=str, required=True, help = 'directory where model state dicts are located')
parser.add_argument('--data', type=str, required=True, help = 'directory where data is located')
parser.add_argument('--model', type=str, default='dcgan', choices=['dcgan', 'infogan', 'arConvNet', 'sndcgan','bsndcgan'])
parser.add_argument('--dim', type=int, default=64, help = 'int determining network dimensions')
parser.add_argument('--seed', type=int, default=-1, help = 'Set random seed for reproducibility')
parser.add_argument('--bs', type=int, default=128, help = 'batch size')
parser.add_argument('--num_workers', type=int, default = 0, help = 'number of data loader processes')
parser.add_argument('--MMD', action= 'store_true', help = 'compute the MMD between args.bs samples of updated source and target')
parser.add_argument('--FID', action= 'store_true', help = 'compute the FID between generated examples and test set for each generator.')
parser.add_argument('--numsample', type=int, default = 0, help = 'how many pics to generate for FID')
parser.add_argument('--eval_freq', type=int, default = 5, help = 'frequency of MMD/FID evaluation')
parser.add_argument('--num_step', type=int, default=1, help = 'how many steps to use in gradient descent')
parser.add_argument('--commonfake', action= 'store_true', help = 'Use if you want a common source element to compare two models')
parser.add_argument('--translate_eval', type=int, default = -1, help = 'index for particular image you want translated')
args = parser.parse_args()

temp_dir = args.temp_dir  # directory used for loading models and saving outputs
# One critic per saved state dict file.
num_crit = len(os.listdir(os.path.join(temp_dir, 'model_dicts')))

# Deterministic behaviour when a non-default seed was supplied.
if args.seed != -1:
    torch.backends.cudnn.deterministic = True
    # benchmark=True would let cuDNN pick hardware-tuned convolution
    # algorithms, which is non-deterministic.
    torch.backends.cudnn.benchmark = False
    for seed_setter in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
        seed_setter(args.seed)
#get dataloader
# Build the target loader first: it supplies the image geometry
# (channels / height / width) that is attached to `args` and later used to
# construct both the source loader and the critic networks.
target_loader = getattr(dataloader, args.target)(args, train=False)
args.num_chan = target_loader.in_channels
args.hpix = target_loader.hpix
args.wpix = target_loader.wpix
source_loader = getattr(dataloader, args.source)(args, train=False)
if args.commonfake:
    # Freeze a single source minibatch so different models can be compared on
    # exactly the same inputs.
    gen = iter(source_loader)
    commonfake = next(gen)[0]
#begin definitions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Instantiate the pre-trained critics and load their saved weights.
# map_location='cpu' lets state dicts that were saved on a GPU load on a
# CPU-only machine; the critics are moved back to the GPU below when
# CUDA is available.
critic_list = []
for i in range(num_crit):
    critic = getattr(networks, args.model)(args.dim, args.num_chan, args.hpix, args.wpix)
    state_path = os.path.join(temp_dir, 'model_dicts', 'critic{}.pth'.format(i))
    critic.load_state_dict(torch.load(state_path, map_location='cpu'))
    critic_list.append(critic)

# Extract the list of step sizes from the training log produced by ttc.py
# (one step size per critic, in training order).
log = pd.read_pickle(os.path.join(temp_dir, 'log.pkl'))
steps_d = log['steps']
steps = [steps_d[key] for key in steps_d.keys()]

print('Arguments:')
for name, value in vars(args).items():
    print(' ', name + ': ', value)
print('\n')

use_cuda = torch.cuda.is_available()
if use_cuda:
    for i in range(num_crit):
        critic_list[i] = critic_list[i].cuda()
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#Sample source images and update them according to the critics and step sizes
gen = iter(source_loader)#make dataloaders into iterables
tgen = iter(target_loader)
# Number of minibatches to process: enough to cover args.numsample images for
# FID, otherwise a single batch.
if args.numsample>0:
    num_batch = args.numsample//args.bs
else:
    num_batch = 1
starttime = time.time()
if args.MMD:
    mmdvals = torch.zeros(1+(num_crit//args.eval_freq), num_batch).cuda()#where MMD measurements will be stored
# NOTE(review): indentation reconstructed from a whitespace-stripped source —
# num_samp is used unconditionally when visualizing below, so it is placed
# outside the MMD branch; confirm against the upstream repository.
num_samp = min(args.bs, 128)
print('Using max num_samp = 128')
#repeating seed selection again here to get same noise sequence
#code to get deterministic behaviour
if args.seed != -1: #if non-default seed
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark=False #If true, optimizes convolution for hardware, but gives non-deterministic behaviour
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
# Main TTC loop: for each minibatch, push the source images through the whole
# sequence of critics, optionally visualizing / measuring MMD / saving images
# every args.eval_freq critics. (Indentation reconstructed from a
# whitespace-stripped source.)
for b_idx in tqdm(range(num_batch)):
    fake = next(gen)[0]
    if args.commonfake:
        # Reuse the frozen minibatch drawn above so runs are comparable.
        fake = commonfake
    fake = fake.cuda()
    if args.translate_eval >= 0:
        indices = [0, args.translate_eval]#code requires minibatch size of > 1
        fake = fake[indices,:,:,:]
        # Keep an untouched copy of the selected image for side-by-side display.
        orig = fake.detach().clone()
    if args.MMD:
        tbatch = next(tgen)[0]#target data only necessary here if computing MMD
        tbatch = tbatch.cuda()
    if b_idx ==0:
        generate_image('00',fake[0:num_samp,:,:,:].detach().cpu(), 'jpg', temp_dir)#visualize initial images
    if args.MMD:#record MMD for training phase 0 and minibatch b_idx
        mmdvals[0, b_idx] = mmd(fake.view(args.bs, -1), tbatch.view(args.bs,-1), alph = [10**n for n in range(3,-11,-1)])
    #save minibatch if doing FID computation
    if args.numsample>0:
        save_individuals(b_idx, 0, fake, 'jpg', temp_dir, to_rgb = True if args.num_chan == 1 else False)
    for i in range(num_crit):#apply the steps of TTC
        # One (or args.num_step) gradient step(s) against critic i with the
        # step size recorded during training.
        eps = torch.tensor(steps[i]).cuda()
        fake = steptaker(fake, critic_list[i], eps, num_step = args.num_step)
        if ((i+1) % args.eval_freq == 0):
            if b_idx == 0:#only visualize if on the first batch
                if args.translate_eval >= 0:
                    img_bank = torch.stack((orig[1,:,:,:], fake[1,:,:,:]))#puts orig and fake next to each other
                    generate_image(i, img_bank.detach().cpu(), 'pdf', temp_dir)
                else:
                    generate_image(i, fake[0:num_samp, :,:,:].detach().cpu(), 'pdf', temp_dir)
            if args.MMD:#record MMD for training phase i and minibatch b_idx
                mmdvals[(i+1)//args.eval_freq, b_idx] = mmd(fake.view(args.bs, -1), tbatch.view(args.bs,-1), alph = [10**n for n in range(3,-11,-1)])
            #save minibatch if doing big sample
            if args.numsample>0:
                save_individuals(b_idx, i+1, fake, 'jpg', temp_dir, to_rgb = True if args.num_chan ==1 else False)
if args.numsample>0:#save zip file of pics
    shutil.make_archive(os.path.join(temp_dir,'pics'), 'zip', os.path.join(temp_dir,'pics'))

print('time required: {}'.format(time.time() - starttime))

if args.MMD or args.FID:
    metrics = {}

#save mmd vals
if args.MMD:
    metrics['mmd'] = np.array(mmdvals.detach().cpu())
    print(mmdvals)

#compute FID and save
if args.FID:
    device = torch.device('cuda') if use_cuda else torch.device('cpu')
    # Fixed: the accumulator was created with an unconditional .cuda() even
    # though a CPU/GPU device is selected right here — allocate it on that
    # device so CPU-only evaluation also works.
    fidvals = torch.zeros(1+(num_crit//args.eval_freq), device=device)
    test_data_path = os.path.join(temp_dir, args.target + 'test')
    block_idx = inception.InceptionV3.BLOCK_INDEX_BY_DIM[2048]
    inception_net = inception.InceptionV3([block_idx]).to(device)
    for c_idx in range(1+num_crit//args.eval_freq):
        crit_num = c_idx*args.eval_freq  # critic index this FID measurement corresponds to
        real_mean, real_cov = fid_score.compute_statistics_of_path(test_data_path, inception_net, 100, 2048, device)
        # Cache the real-data statistics as .npz after the first pass so later
        # iterations reload them instead of recomputing.
        if not test_data_path.split('.')[-1] == 'npz':
            test_data_path = os.path.join(temp_dir, 'test_data_stats.npz')
            np.savez(test_data_path, mu=real_mean, sigma=real_cov)
        fake_mean, fake_cov = fid_score.compute_statistics_of_path(os.path.join(temp_dir, 'pics/timestamp{}'.format(crit_num)), inception_net, 100, 2048, device)
        fidvals[c_idx] = fid_score.calculate_frechet_distance(fake_mean, fake_cov, real_mean, real_cov)
    metrics['fid'] = np.array(fidvals.detach().cpu())
    print(fidvals)

if args.MMD or args.FID:
    # Consistency: use os.path.join like the rest of the file instead of
    # string concatenation.
    with open(os.path.join(temp_dir, 'metrics.pkl'), 'wb') as f:
        pickle.dump(metrics, f, pickle.HIGHEST_PROTOCOL)#save metrics