In [None]:
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import utils as vutils

import os
import random
import argparse
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt

from models import Generator
from helper_function import denormalized


def load_params(model, new_param):
    """Copy a flat sequence of tensors into ``model``'s parameters, in order.

    Intended for checkpoints that store weights as an ordered list of tensors
    rather than a state dict (e.g. the commented-out ``checkpoint['g_ema']``
    usage in this notebook).

    Args:
        model: ``nn.Module`` whose ``parameters()`` are overwritten in place.
        new_param: iterable of tensors, one per model parameter, in the same
            order as ``model.parameters()``.

    Raises:
        ValueError: if the number of tensors differs from the number of model
            parameters. ``zip`` would otherwise stop at the shorter sequence
            and silently leave some parameters un-loaded.
    """
    params = list(model.parameters())
    new_params = list(new_param)
    if len(params) != len(new_params):
        raise ValueError(
            "load_params: model has %d parameters but %d tensors were given"
            % (len(params), len(new_params))
        )
    # In-place writes to leaf tensors that require grad must happen outside autograd.
    with torch.no_grad():
        for p, new_p in zip(params, new_params):
            p.copy_(new_p)

def resize(img, size=256, mode='nearest'):
    """Resize images with ``F.interpolate``.

    Args:
        img: tensor forwarded unchanged to ``F.interpolate`` (for 2-D images
            this is expected to be batched, i.e. (N, C, H, W)).
        size: target spatial size; an int gives a square output, a tuple sets
            each spatial dimension.
        mode: interpolation mode forwarded to ``F.interpolate``. Defaults to
            ``'nearest'``, the only mode available before this parameter was
            added. Smoother modes such as ``'bilinear'`` or ``'area'`` usually
            suit photographic images better.

    Returns:
        The resized tensor.
    """
    return F.interpolate(img, size=size, mode=mode)

def batch_generate(zs, netG, batch=8):
    """Run ``netG`` over ``zs`` in chunks of ``batch`` and return all outputs.

    Every full chunk is processed first. Any leftover rows at the end are sent
    through ``netG`` as one final, smaller chunk. Each chunk's output is moved
    to the CPU, and all outputs are concatenated along dim 0. Inference runs
    under ``torch.no_grad()``.
    """
    total = len(zs)
    n_full, leftover = divmod(total, batch)
    outputs = []
    with torch.no_grad():
        for k in range(n_full):
            lo = k * batch
            outputs.append(netG(zs[lo:lo + batch]).cpu())
        if leftover > 0:
            outputs.append(netG(zs[total - leftover:]).cpu())
    return torch.cat(outputs)

def batch_save(images, folder_name):
    """Save each image in ``images`` to ``<folder_name>/<index>.jpg``.

    Args:
        images: iterable of image tensors. Each is mapped with ``(x + 1) * 0.5``
            before saving, so values are presumably in [-1, 1] (TODO confirm
            against the generator's output activation).
        folder_name: output directory. It is created, including any missing
            parent directories, if it does not already exist.
    """
    # makedirs(exist_ok=True) creates intermediate directories and has no
    # check-then-create race, unlike os.path.exists() followed by os.mkdir().
    os.makedirs(folder_name, exist_ok=True)
    for i, image in enumerate(images):
        vutils.save_image(image.add(1).mul(0.5), os.path.join(folder_name, '%d.jpg' % i))

In [None]:
# if __name__ == "__main__":
#     parser = argparse.ArgumentParser(
#         description='generate images'
#     )
#     parser.add_argument('--ckpt', type=str)
#     parser.add_argument('--artifacts', type=str, default=".", help='path to artifacts.')
#     parser.add_argument('--cuda', type=int, default=0, help='index of gpu to use')
#     parser.add_argument('--start_iter', type=int, default=6)
#     parser.add_argument('--end_iter', type=int, default=10)

#     parser.add_argument('--dist', type=str, default='.')
#     parser.add_argument('--size', type=int, default=256)
#     parser.add_argument('--batch', default=16, type=int, help='batch size')
#     parser.add_argument('--n_sample', type=int, default=2000)
#     parser.add_argument('--big', action='store_true')
#     parser.add_argument('--im_size', type=int, default=1024)
#     parser.add_argument('--multiplier', type=int, default=10000, help='multiplier for model number')
#     parser.set_defaults(big=False)
#     args = parser.parse_args()
class Args:
    """Lightweight stand-in for the ``argparse.Namespace`` the script version built.

    Every constructor keyword becomes an attribute of the same name, so notebook
    cells can read ``args.<name>`` just as they would after ``parse_args()``.
    Several defaults here (``start_iter``, ``size``, ``im_size``) differ from the
    commented-out parser definition.
    """

    def __init__(self, ckpt=None, artifacts=".", cuda=0, start_iter=0, end_iter=10,
                 dist='.', size=512, batch=16, n_sample=2000, big=False,
                 im_size=512, multiplier=10000):
        # Kept in the original declaration order so vars(self) lists them the same way.
        settings = dict(
            ckpt=ckpt,
            artifacts=artifacts,
            cuda=cuda,
            start_iter=start_iter,
            end_iter=end_iter,
            dist=dist,
            size=size,
            batch=batch,
            n_sample=n_sample,
            big=big,
            im_size=im_size,
            multiplier=multiplier,
        )
        for name, value in settings.items():
            setattr(self, name, value)

In [None]:
# Settings for a quick smoke-test run: 10 samples, one image per batch.
args = Args(
    multiplier=10,
    artifacts="train_results/result",
    # artifacts="../train_results/test1",  # alternative run directory
    n_sample=10,
    batch=1,
)

In [None]:
# Latent size fed to the generator; presumably has to match the nz used when
# the checkpoint was trained, or load_state_dict will fail.
noise_dim = 256
# Fall back to CPU when CUDA is unavailable instead of failing at `.to(device)`.
device = torch.device('cuda:%d' % args.cuda if torch.cuda.is_available() else 'cpu')

net_ig = Generator(ngf=64, nz=noise_dim, nc=3, im_size=args.im_size)  # , big=args.big )
net_ig.to(device)
print("success")

success


In [None]:
# epoch = 1000
# ckpt = f"{args.artifacts}/models/{epoch}.pth"
# checkpoint = torch.load(ckpt, map_location=lambda a,b: a)
# # Remove prefix `module`.
# checkpoint['g'] = {k.replace('module.', ''): v for k, v in checkpoint['g'].items()}
# net_ig.load_state_dict(checkpoint['g'])
# #load_params(net_ig, checkpoint['g_ema'])

# #net_ig.eval()
# print('load checkpoint success, epoch %d'%epoch)

# net_ig.to(device)

# # del checkpoint

# dest = 'eval_%d'%(epoch)
# dest = os.path.join(args.artifacts,dest, 'img')
# os.makedirs(dest, exist_ok=True)

# with torch.no_grad():
#     dpi = 100
#     for i in tqdm(range(args.n_sample//args.batch)):
#         noise = torch.randn(args.batch, noise_dim).to(device)
#         g_imgs = net_ig(noise)[0]
#         # g_imgs = resize(g_imgs,args.im_size) # resize the image using given dimension
#         # print(g_imgs.shape)
#         fig = plt.figure(figsize=(512/dpi, 512/dpi), dpi=dpi)
#         ax = plt.Axes(fig, [0., 0., 1., 1.])
#         ax.set_axis_off()
#         fig.add_axes(ax)
#         img = np.transpose(denormalized(g_imgs).squeeze().cpu(),(1,2,0))
#         plt.imshow(img)
#         plt.savefig(f'{dest}/generated_{i}.jpg',pad_inches = 0)
#         # plt.show()
# #         # for j, g_img in enumerate( g_imgs ):
# #         #     vutils.save_image(g_img.add(1).ml(0.5),
# #         #         os.path.join(dist, '%d.png'%(i*args.batch+j)))#, normalize=True, range=(-1,1))

In [None]:
# Display the artifacts directory the current `args` points at.
args.artifacts

'train_results/result'

# Generating images from trained checkpoints

In [None]:
# Configuration for this run: 5000 samples from the epoch-60000 checkpoint.
args = Args()
args.multiplier = 10
args.artifacts = "CR Scale RIS"
args.n_sample = 5000
args.batch = 1
noise_dim = 256
epoch = 60000
seed = 0  # fixed so re-running the cell reproduces the same samples

# Fall back to CPU when CUDA is unavailable instead of failing at `.to(device)`.
device = torch.device('cuda:%d' % args.cuda if torch.cuda.is_available() else 'cpu')

net_ig = Generator(ngf=64, nz=noise_dim, nc=3, im_size=args.im_size)  # , big=args.big )
net_ig.to(device)

ckpt = f"{args.artifacts}/models/{epoch}.pth"
# torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint = torch.load(ckpt, map_location='cpu')
# Strip a *leading* `module.` (presumably added by nn.DataParallel during training).
# str.replace would also rewrite names that merely contain "module." elsewhere,
# e.g. "submodule.weight" -> "subweight", and break load_state_dict.
prefix = 'module.'
checkpoint['g'] = {
    (k[len(prefix):] if k.startswith(prefix) else k): v
    for k, v in checkpoint['g'].items()
}
net_ig.load_state_dict(checkpoint['g'])
# load_params(net_ig, checkpoint['g_ema'])  # alternative: load the EMA weights
print('load checkpoint success, epoch %d' % epoch)

dest = os.path.join(args.artifacts, 'eval_%d' % epoch, 'img')
os.makedirs(dest, exist_ok=True)

# The loop saves exactly one image per iteration. With args.batch > 1 it would
# either silently drop samples or fail at the transpose, depending on what
# Generator returns (TODO confirm), so require a batch size of 1.
assert args.batch == 1, "image-saving loop below assumes args.batch == 1"
save_px = args.im_size  # target side length of each saved JPEG, in pixels
dpi = 100

torch.manual_seed(seed)
net_ig.eval()
with torch.no_grad():
    for i in tqdm(range(args.n_sample // args.batch)):
        noise = torch.randn(args.batch, noise_dim, device=device)
        g_imgs = net_ig(noise)[0]

        # Full-bleed axes with no ticks, so the saved file contains only the image.
        fig = plt.figure(figsize=(save_px / dpi, save_px / dpi), dpi=dpi)
        ax = fig.add_axes([0., 0., 1., 1.])
        ax.set_axis_off()

        img = np.transpose(denormalized(g_imgs).squeeze().cpu().numpy(), (1, 2, 0))
        ax.imshow(img)

        fig.savefig(f'{dest}/generated_{i}.jpg', pad_inches=0, bbox_inches='tight')
        plt.close(fig)  # avoid accumulating thousands of open figures

In [None]:
# Configuration for this run: 100 samples from the epoch-62300 checkpoint.
args = Args()
args.multiplier = 10
args.artifacts = "train_results/CR Scale RIS"
args.n_sample = 100
args.batch = 1
noise_dim = 256
epoch = 62300
seed = 0  # fixed so re-running the cell reproduces the same samples

# Fall back to CPU when CUDA is unavailable instead of failing at `.to(device)`.
device = torch.device('cuda:%d' % args.cuda if torch.cuda.is_available() else 'cpu')

net_ig = Generator(ngf=64, nz=noise_dim, nc=3, im_size=args.im_size)  # , big=args.big )
net_ig.to(device)

ckpt = f"{args.artifacts}/models/{epoch}.pth"
# torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint = torch.load(ckpt, map_location='cpu')
# Strip a *leading* `module.` (presumably added by nn.DataParallel during training).
# str.replace would also rewrite names that merely contain "module." elsewhere,
# e.g. "submodule.weight" -> "subweight", and break load_state_dict.
prefix = 'module.'
checkpoint['g'] = {
    (k[len(prefix):] if k.startswith(prefix) else k): v
    for k, v in checkpoint['g'].items()
}
net_ig.load_state_dict(checkpoint['g'])
# load_params(net_ig, checkpoint['g_ema'])  # alternative: load the EMA weights
print('load checkpoint success, epoch %d' % epoch)

dest = os.path.join(args.artifacts, 'eval_%d' % epoch, 'img')
os.makedirs(dest, exist_ok=True)

# The loop saves exactly one image per iteration. With args.batch > 1 it would
# either silently drop samples or fail at the transpose, depending on what
# Generator returns (TODO confirm), so require a batch size of 1.
assert args.batch == 1, "image-saving loop below assumes args.batch == 1"
save_px = args.im_size  # target side length of each saved JPEG, in pixels
dpi = 100

torch.manual_seed(seed)
net_ig.eval()
with torch.no_grad():
    for i in tqdm(range(args.n_sample // args.batch)):
        noise = torch.randn(args.batch, noise_dim, device=device)
        g_imgs = net_ig(noise)[0]

        # Full-bleed axes with no ticks, so the saved file contains only the image.
        fig = plt.figure(figsize=(save_px / dpi, save_px / dpi), dpi=dpi)
        ax = fig.add_axes([0., 0., 1., 1.])
        ax.set_axis_off()

        img = np.transpose(denormalized(g_imgs).squeeze().cpu().numpy(), (1, 2, 0))
        ax.imshow(img)

        fig.savefig(f'{dest}/generated_{i}.jpg', pad_inches=0, bbox_inches='tight')
        plt.close(fig)  # avoid accumulating open figures across iterations

success
load checkpoint success, epoch 62300


100%|█████████████████████████████████████████| 100/100 [00:08<00:00, 11.72it/s]
