In [21]:
import network as net
from config import config
import torch
import numpy as np
import os
from PIL import Image
from torchvision import transforms
from torchvision import utils as vutils
import random

seed = 7412

np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

def save_tile_images(images, save_path, nrow):
    images = images.detach().cpu()
    vutils.save_image(images, save_path, nrow=nrow)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# root = '../storage/PGGAN/results/gucci/2020_05_11_16_54_33' # 256 longer
root = '../storage/PGGAN/results/gucci/2020_05_06_23_44_31' # 256 shorter
# root = '../storage/PGGAN/results/gucci/2020_05_05_01_18_33' # 128
# root = '../storage/PGGAN/results/gucci/2020_05_02_20_55_49' # 64

model_root = os.path.join(root, 'models')
save_root = os.path.join(root, 'tile_images')

model_list = os.listdir(model_root)
model_list = sorted([m for m in model_list  if 'gen' in m])
print(model_list)

z = torch.FloatTensor(16, config.nz).normal_(0.0, 1.0).to(device)

G = net.Generator(config).to(device)
G.eval()
with torch.no_grad():
    old_resl = 2
    for m in model_list:
        resl = int(m[5])
        if resl > old_resl:
            G.grow_network(resl)
            G.flush_network()
            G.to(device)
            G.eval()
            old_resl = resl
        checkpoint = torch.load(os.path.join(model_root, m))
        G.load_state_dict(checkpoint['state_dict'])

        images = G(z)
        os.makedirs(save_root, exist_ok=True)
        save_tile_images(images, os.path.join(save_root, m+'.png'), 4)

['gen_R3_T305.pth.tar', 'gen_R3_T310.pth.tar', 'gen_R3_T315.pth.tar', 'gen_R3_T320.pth.tar', 'gen_R3_T325.pth.tar', 'gen_R3_T330.pth.tar', 'gen_R3_T335.pth.tar', 'gen_R3_T340.pth.tar', 'gen_R3_T345.pth.tar', 'gen_R3_T350.pth.tar', 'gen_R3_T355.pth.tar', 'gen_R3_T360.pth.tar', 'gen_R3_T365.pth.tar', 'gen_R3_T370.pth.tar', 'gen_R3_T375.pth.tar', 'gen_R3_T380.pth.tar', 'gen_R3_T385.pth.tar', 'gen_R3_T390.pth.tar', 'gen_R3_T395.pth.tar', 'gen_R3_T400.pth.tar', 'gen_R4_T505.pth.tar', 'gen_R4_T510.pth.tar', 'gen_R4_T515.pth.tar', 'gen_R4_T520.pth.tar', 'gen_R4_T525.pth.tar', 'gen_R4_T530.pth.tar', 'gen_R4_T535.pth.tar', 'gen_R4_T540.pth.tar', 'gen_R4_T545.pth.tar', 'gen_R4_T550.pth.tar', 'gen_R4_T555.pth.tar', 'gen_R4_T560.pth.tar', 'gen_R4_T565.pth.tar', 'gen_R4_T570.pth.tar', 'gen_R4_T575.pth.tar', 'gen_R4_T580.pth.tar', 'gen_R4_T585.pth.tar', 'gen_R4_T590.pth.tar', 'gen_R4_T595.pth.tar', 'gen_R4_T600.pth.tar', 'gen_R5_T705.pth.tar', 'gen_R5_T710.pth.tar', 'gen_R5_T715.pth.tar', 'gen_R5_T7

  if initializer == 'kaiming':    kaiming_normal(self.conv.weight, a=calculate_gain('conv2d'))


growing network[4x4 to 8x8]. It may take few seconds...
flushing network... It may take few seconds...




growing network[8x8 to 16x16]. It may take few seconds...
flushing network... It may take few seconds...
growing network[16x16 to 32x32]. It may take few seconds...
flushing network... It may take few seconds...
growing network[32x32 to 64x64]. It may take few seconds...
flushing network... It may take few seconds...
growing network[64x64 to 128x128]. It may take few seconds...
flushing network... It may take few seconds...
growing network[128x128 to 256x256]. It may take few seconds...
flushing network... It may take few seconds...


In [18]:
print(model_root)
checkpoint_path = os.path.join(model_root, 'gen_R6_T950.pth.tar')
n_intp = 20

np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

test_model = net.Generator(config).to(device)

for resl in range(3, 6+1):
    test_model.grow_network(resl)
    test_model.flush_network()
test_model.to(device)

print('load checkpoint form ... {}'.format(checkpoint_path))
checkpoint = torch.load(checkpoint_path)
test_model.load_state_dict(checkpoint['state_dict'])

save_root = os.path.join(root, 'morph_images')
os.makedirs(save_root, exist_ok=True)

for k in range(10):
    # interpolate between twe noise(z1, z2).
    z1 = torch.FloatTensor(1, config.nz).normal_(0.0, 1.0).to(device)
    z2 = torch.FloatTensor(1, config.nz).normal_(0.0, 1.0).to(device)
    z_intp = torch.FloatTensor(n_intp, config.nz).to(device)

    for i in range(n_intp):
        alpha = 1.0/float(n_intp+1)
        z_intp[i] = z1.mul_(alpha) + z2.mul_(1.0-alpha)
    fake_im = test_model(z_intp)
    save_tile_images(fake_im, os.path.join(save_root, str(k)+'.png'), 1)

../storage/PGGAN/results/gucci/2020_05_02_20_55_49/models


  if initializer == 'kaiming':    kaiming_normal(self.conv.weight, a=calculate_gain('conv2d'))


growing network[4x4 to 8x8]. It may take few seconds...
flushing network... It may take few seconds...
growing network[8x8 to 16x16]. It may take few seconds...
flushing network... It may take few seconds...
growing network[16x16 to 32x32]. It may take few seconds...
flushing network... It may take few seconds...
growing network[32x32 to 64x64]. It may take few seconds...
flushing network... It may take few seconds...
load checkpoint form ... ../storage/PGGAN/results/gucci/2020_05_02_20_55_49/models/gen_R6_T950.pth.tar


