-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
dd72201
commit dbd92ab
Showing
19 changed files
with
3,566 additions
and
0 deletions.
There are no files selected for viewing
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,209 @@ | ||
| import os | ||
|
|
||
| import numpy as np | ||
| from PIL import Image | ||
| import chainer | ||
| from chainer.dataset import dataset_mixin | ||
| from glob import glob | ||
|
|
||
|
|
||
class Cifar10Dataset(dataset_mixin.DatasetMixin):
    """CIFAR-10 images as a chainer dataset, without labels, scaled to [-1, 1]."""

    def __init__(self, test=False):
        # withlabel=False -> plain image arrays; scale=1.0 -> pixel values in [0, 1].
        train_ims, test_ims = chainer.datasets.get_cifar10(ndim=3, withlabel=False, scale=1.0)
        self.ims = test_ims if test else train_ims
        # Rescale [0, 1] -> [-1, 1], the range the rest of the pipeline expects.
        self.ims = self.ims * 2 - 1.0  # [-1.0, 1.0]
        print("load cifar-10. shape: ", self.ims.shape)

    def __len__(self):
        return self.ims.shape[0]

    def get_example(self, i):
        return self.ims[i]
|
|
||
|
|
||
def image_to_np(img):
    """Convert a PIL image to a float32 CHW numpy array scaled to [-1, 1]."""
    rgb = img.convert('RGB')
    arr = np.asarray(rgb, dtype=np.uint8)
    # HWC -> CHW, the layout chainer convolutions expect.
    arr = arr.transpose((2, 0, 1)).astype("f")
    if arr.shape[0] == 1:
        # Single-channel image: replicate the channel three times.
        arr = np.broadcast_to(arr, (3, arr.shape[1], arr.shape[2]))
    return (arr - 127.5) / 127.5
|
|
||
| ''' | ||
| def preprocess_image(img, crop_width=256, img2np=True): | ||
| wid = min(img.size[0], img.size[1]) | ||
| ratio = crop_width / (wid + 1e-4) | ||
| img = img.resize((int(ratio * img.size[0]), int(ratio * img.size[1])), Image.BILINEAR) | ||
| x_l = (img.size[0]) // 2 - crop_width // 2 | ||
| x_r = x_l + crop_width | ||
| y_u = 0 | ||
| y_d = y_u + crop_width | ||
| img = img.crop((x_l, y_u, x_r, y_d)) | ||
| if img2np: | ||
| img = image_to_np(img) | ||
| return img | ||
| ''' | ||
def preprocess_image(img, resize=128, img2np=True):
    """Center-crop *img* to a square and resize it to ``resize`` x ``resize``.

    Args:
        img: PIL image.
        resize: target side length in pixels.
        img2np: when True, also convert with :func:`image_to_np`.

    Returns:
        The processed PIL image, or a float32 CHW array in [-1, 1] when
        ``img2np`` is True.
    """
    # PIL reports size as (width, height).  The original code bound these to
    # locals named h and w (swapped), which was harmless for a square center
    # crop but misleading; the names are corrected here, behavior unchanged.
    width, height = img.size[0], img.size[1]
    side = min(width, height)
    left = int((width - side) / 2.0)
    right = left + side
    upper = int((height - side) / 2.0)
    lower = upper + side
    img = img.crop((left, upper, right, lower))
    img = img.resize((resize, resize), Image.BILINEAR)

    if img2np:
        img = image_to_np(img)
    return img
|
|
||
def preprocess_image_no_crop(img, resize=64, img2np=True):
    """Resize *img* so its shorter side equals ``resize``; no cropping."""
    shorter = min(img.size[0], img.size[1])
    scale = 1.0 * resize / shorter
    new_size = (int(scale * img.size[0]), int(scale * img.size[1]))
    img = img.resize(new_size, Image.BILINEAR)

    if img2np:
        img = image_to_np(img)
    return img
|
|
||
|
|
||
def find_all_files(directory):
    """Yield every directory path and file path under *directory*.

    http://qiita.com/suin/items/cdef17e447ceeff6e79d
    """
    for root, _dirs, filenames in os.walk(directory):
        yield root
        yield from (os.path.join(root, name) for name in filenames)
|
|
||
|
|
||
class CelebADataset(dataset_mixin.DatasetMixin):
    """Aligned CelebA face images, center-cropped and resized on the fly.

    Args:
        resize: side length of the square output images.
        path: glob pattern locating the jpg files.  Parameterized so the
            dataset is usable outside the original author's machine; the
            default preserves the historical behavior.
    """

    def __init__(self, resize=128, path='/home/yasin/sharedLocal/data/celeba/img_align_celeba/*.jpg'):
        self.resize = resize
        self.image_files = glob(path)
        print(len(self.image_files))

    def __len__(self):
        return len(self.image_files)

    def get_example(self, i):
        # Re-seed from the OS so data-loader worker processes do not share
        # identical RNG state after fork.
        np.random.seed()
        fn = self.image_files[i]  # hoisted: the except-print needs fn bound
        img = None
        # NOTE(review): retries the same file forever on failure, which hangs
        # on a persistently corrupt image — original behavior kept, but each
        # failed attempt is logged.
        while img is None:
            try:
                img = Image.open(fn)
            except Exception as e:
                print(i, fn, str(e))
        return preprocess_image(img, resize=self.resize)
|
|
||
class ImagenetDataset(dataset_mixin.DatasetMixin):
    """ImageNet images from an explicit file list, center-cropped and resized.

    Args:
        file_list: list of image file paths.
        crop_width: side length of the square output images.
    """

    def __init__(self, file_list, crop_width=256):
        self.crop_width = crop_width
        self.image_files = file_list
        print(len(self.image_files))

    def __len__(self):
        return len(self.image_files)

    def get_example(self, i):
        # Re-seed from the OS so data-loader worker processes do not share
        # identical RNG state after fork.
        np.random.seed()
        fn = self.image_files[i]
        img = None
        while img is None:
            try:
                img = Image.open(fn)
            except Exception as e:
                print(i, fn, str(e))
        # BUG FIX: the active preprocess_image() signature is
        # (img, resize=..., img2np=...); the old crop_width= keyword belonged
        # to the commented-out variant and raised TypeError at runtime.
        return preprocess_image(img, resize=self.crop_width)
|
|
||
class Imagenet64Dataset(dataset_mixin.DatasetMixin):
    """Pre-resized 64x64 ImageNet PNGs loaded straight from disk."""

    def __init__(self):
        self.image_files = glob('/home/yasin/sharedLocal/train_64x64/*.png')
        print(len(self.image_files))

    def __len__(self):
        return len(self.image_files)

    def get_example(self, i):
        # Fresh OS seed so loader workers diverge after fork.
        np.random.seed()
        fn = str(self.image_files[i])
        img = None
        while img is None:
            try:
                img = Image.open(fn)
            except Exception as e:
                print(i, fn, str(e))
        return image_to_np(img)
|
|
||
class Imagenet32Dataset(dataset_mixin.DatasetMixin):
    """Pre-resized 32x32 ImageNet PNGs loaded straight from disk."""

    def __init__(self):
        self.image_files = glob('/home/users/ntu/yasin001/project/train_32x32/*.png')
        print(len(self.image_files))

    def __len__(self):
        return len(self.image_files)

    def get_example(self, i):
        # Fresh OS seed so loader workers diverge after fork.
        np.random.seed()
        fn = str(self.image_files[i])
        img = None
        while img is None:
            try:
                img = Image.open(fn)
            except Exception as e:
                print(i, fn, str(e))
        return image_to_np(img)
|
|
||
class Stl10Dataset(dataset_mixin.DatasetMixin):
    """STL-10 jpgs, resized (shorter side, no crop) to ``resize`` pixels."""

    def __init__(self, resize=64):
        self.resize = resize
        self.image_files = glob('/home/users/ntu/yasin001/project/stl10/*.jpg')
        print(len(self.image_files))

    def __len__(self):
        return len(self.image_files)

    def get_example(self, i):
        # Fresh OS seed so loader workers diverge after fork.
        np.random.seed()
        fn = str(self.image_files[i])
        img = None
        while img is None:
            try:
                img = Image.open(fn)
            except Exception as e:
                print(i, fn, str(e))
        return preprocess_image_no_crop(img, resize=self.resize)
|
|
||
class Stl10_48_Dataset(dataset_mixin.DatasetMixin):
    """Pre-resized 48x48 STL-10 jpgs, converted to arrays with no resizing."""

    def __init__(self):
        self.image_files = glob('/home/yasin/sharedLocal/data/stl10-48/*.jpg')
        print(len(self.image_files))

    def __len__(self):
        return len(self.image_files)

    def get_example(self, i):
        # Fresh OS seed so loader workers diverge after fork.
        np.random.seed()
        fn = str(self.image_files[i])
        img = None
        while img is None:
            try:
                img = Image.open(fn)
            except Exception as e:
                print(i, fn, str(e))
        return image_to_np(img)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,188 @@ | ||
| import os | ||
| import sys | ||
| import math | ||
|
|
||
| import numpy as np | ||
| from PIL import Image | ||
| import scipy.linalg | ||
|
|
||
| import chainer | ||
| import chainer.cuda | ||
| from chainer import Variable | ||
| from chainer import serializers | ||
| import chainer.functions as F | ||
|
|
||
| sys.path.append(os.path.dirname(__file__)) | ||
| from inception.inception_score import inception_score, Inception | ||
|
|
||
|
|
||
def sample_generate_light(gen, dst, rows=5, cols=5, seed=0):
    """Trainer extension: save a rows x cols sample grid to
    ``<dst>/preview/image_latest.png``, overwriting it every call."""
    @chainer.training.make_extension()
    def make_image(trainer):
        np.random.seed(seed)
        n_images = rows * cols
        xp = gen.xp
        z = Variable(xp.asarray(gen.make_hidden(n_images)))
        with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):
            x = gen(z)
        x = chainer.cuda.to_cpu(x.data)
        np.random.seed()  # restore nondeterministic state for training

        # [-1, 1] floats -> [0, 255] uint8, then tile the batch into one image.
        x = np.asarray(np.clip(x * 127.5 + 127.5, 0.0, 255.0), dtype=np.uint8)
        _, _, H, W = x.shape
        grid = x.reshape((rows, cols, 3, H, W))
        grid = grid.transpose(0, 3, 1, 4, 2)
        grid = grid.reshape((rows * H, cols * W, 3))

        preview_dir = '{}/preview'.format(dst)
        if not os.path.exists(preview_dir):
            os.makedirs(preview_dir)
        Image.fromarray(grid).save(preview_dir + '/image_latest.png')

    return make_image
|
|
||
|
|
||
def sample_generate(gen, dst, rows=10, cols=10, seed=0):
    """Visualization of rows*cols images randomly generated by the generator."""
    @chainer.training.make_extension()
    def make_image(trainer):
        np.random.seed(seed)
        n_images = rows * cols
        xp = gen.xp
        # Deliberately samples from the generator's stored fixed latent batch
        # (gen.fix_z) instead of make_hidden(n_images), so successive preview
        # grids are directly comparable across iterations.
        z = Variable(xp.asarray(gen.fix_z))
        with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):
            x = gen(z)
        x = chainer.cuda.to_cpu(x.data)
        np.random.seed()  # restore nondeterministic state for training

        x = np.asarray(np.clip(x * 127.5 + 127.5, 0.0, 255.0), dtype=np.uint8)
        _, _, h, w = x.shape
        grid = x.reshape((rows, cols, 3, h, w))
        grid = grid.transpose(0, 3, 1, 4, 2)
        grid = grid.reshape((rows * h, cols * w, 3))

        preview_dir = '{}/preview'.format(dst)
        preview_path = preview_dir + '/image{:0>8}.png'.format(trainer.updater.iteration)
        if not os.path.exists(preview_dir):
            os.makedirs(preview_dir)
        Image.fromarray(grid).save(preview_path)

    return make_image
|
|
||
|
|
||
def load_inception_model():
    """Load the pretrained Inception net used for IS/FID and move it to GPU."""
    model_path = "%s/inception/inception_score.model" % os.path.dirname(__file__)
    model = Inception()
    serializers.load_hdf5(model_path, model)
    model.to_gpu()
    return model
|
|
||
|
|
||
def calc_inception(gen, batchsize=100):
    """Trainer extension: report the Inception Score of `gen`.

    Generates 50,000 images in batches, feeds them to the pretrained
    Inception model, and reports mean/std of the score via chainer.reporter.
    """
    @chainer.training.make_extension()
    def evaluation(trainer):
        model = load_inception_model()

        ims = []
        xp = gen.xp

        # 50k samples is the standard sample count for Inception Score.
        n_ims = 50000
        for i in range(0, n_ims, batchsize):
            z = Variable(xp.asarray(gen.make_hidden(batchsize)))
            with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):
                x = gen(z)
            x = chainer.cuda.to_cpu(x.data)
            # [-1, 1] float output -> [0, 255] uint8 images.
            x = np.asarray(np.clip(x * 127.5 + 127.5, 0.0, 255.0), dtype=np.uint8)
            ims.append(x)
        # Stacking the per-batch arrays yields shape
        # (n_batches, batchsize, 3, h, w); flatten the first two axes below.
        ims = np.asarray(ims)
        _, _, _, h, w = ims.shape
        ims = ims.reshape((n_ims, 3, h, w)).astype("f")

        mean, std = inception_score(model, ims)

        chainer.reporter.report({
            'inception_mean': mean,
            'inception_std': std
        })

    return evaluation
|
|
||
|
|
||
def get_mean_cov(model, ims, batch_size=100):
    """Compute mean and covariance of Inception pool features over `ims`.

    Args:
        model: Inception network supporting ``model(x, get_feature=True)``
            that returns 2048-dim pooled features.
        ims: float array of images shaped (n, 3, w, h).
        batch_size: images per forward pass.

    Returns:
        (mean, cov): 2048-dim feature mean and 2048x2048 covariance, both
        as host (CPU) numpy arrays.
    """
    n, c, w, h = ims.shape
    n_batches = int(math.ceil(float(n) / float(batch_size)))

    xp = model.xp

    print('Batch size:', batch_size)
    print('Total number of images:', n)
    print('Total number of batches:', n_batches)

    # Feature buffer on the model's device (GPU when xp is cupy).
    ys = xp.empty((n, 2048), dtype=xp.float32)

    for i in range(n_batches):
        #print('Running batch', i + 1, '/', n_batches, '...')
        batch_start = (i * batch_size)
        # The last batch may be smaller than batch_size.
        batch_end = min((i + 1) * batch_size, n)

        ims_batch = ims[batch_start:batch_end]
        ims_batch = xp.asarray(ims_batch)  # To GPU if using CuPy
        ims_batch = Variable(ims_batch)

        # Resize image to the shape expected by the inception module
        if (w, h) != (299, 299):
            ims_batch = F.resize_images(ims_batch, (299, 299))  # bilinear

        # Feed images to the inception module to get the features
        with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):
            y = model(ims_batch, get_feature=True)
        ys[batch_start:batch_end] = y.data

    # .get() copies from device to host; np.cov expects features in columns,
    # hence the transpose.
    mean = xp.mean(ys, axis=0).get()
    # cov = F.cross_covariance(ys, ys, reduce="no").data.get()
    cov = np.cov(ys.get().T)

    return mean, cov
|
|
||
def FID(m0, c0, m1, c1):
    """Frechet distance between two Gaussians N(m0, c0) and N(m1, c1).

    ||m0 - m1||^2 + Tr(c0 + c1 - 2*(c0 c1)^(1/2)); the real part is taken
    because sqrtm can return a matrix with tiny imaginary components.
    """
    diff = m0 - m1
    covmean = scipy.linalg.sqrtm(np.dot(c0, c1))
    dist = np.sum(diff * diff) + np.trace(c0) + np.trace(c1) - 2.0 * np.trace(covmean)
    return np.real(dist)
|
|
||
def calc_FID(gen, dataset='cifar10', batchsize=100):
    """Frechet Inception Distance proposed by https://arxiv.org/abs/1706.08500

    Trainer extension: generates 10,000 images from `gen`, computes their
    Inception feature statistics, and reports the FID against precomputed
    real-data statistics for the chosen dataset ('cifar10' or 'stl10').
    """
    @chainer.training.make_extension()
    def evaluation(trainer):
        # Precomputed mean/cov of Inception features on the real dataset.
        if dataset == 'cifar10':
            stat_file = "%s/cifar-10-fid.npz" % os.path.dirname(__file__)
        elif dataset == 'stl10':
            stat_file = '/home/yasin/sharedLocal/fid_stats_stl10.npz'
        else:
            # BUG FIX: the exception was instantiated but never raised, so an
            # unknown dataset fell through to an undefined stat_file
            # (NameError) instead of a clear error.
            raise NotImplementedError('no such dataset')
        model = load_inception_model()
        stat = np.load(stat_file)

        n_ims = 10000
        xp = gen.xp
        xs = []
        for i in range(0, n_ims, batchsize):
            z = Variable(xp.asarray(gen.make_hidden(batchsize)))
            with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):
                x = gen(z)
            x = chainer.cuda.to_cpu(x.data)
            # [-1, 1] float output -> [0, 255] float images.
            x = np.asarray(np.clip(x * 127.5 + 127.5, 0.0, 255.0), dtype="f")
            xs.append(x)
        # Shape (n_batches, batchsize, 3, h, w); flattened to (n_ims, 3, h, w)
        # before feature extraction.
        xs = np.asarray(xs)
        _, _, _, h, w = xs.shape

        with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):
            mean, cov = get_mean_cov(model, np.asarray(xs).reshape((-1, 3, h, w)))
        fid = FID(stat["mean"], stat["cov"], mean, cov)

        chainer.reporter.report({
            'FID': fid,
        })

    return evaluation
Oops, something went wrong.