In [20]:
import argparse
import random
import math

from tqdm import tqdm
import numpy as np
from PIL import Image

import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.autograd import Variable, grad
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils

In [9]:
#%pip install import_ipynb==0.1.3  # pinned to the version in the install log below; %pip targets this kernel's environment

Collecting import_ipynb
  Downloading import-ipynb-0.1.3.tar.gz (4.0 kB)
Building wheels for collected packages: import-ipynb
  Building wheel for import-ipynb (setup.py): started
  Building wheel for import-ipynb (setup.py): finished with status 'done'
  Created wheel for import-ipynb: filename=import_ipynb-0.1.3-py3-none-any.whl size=2975 sha256=07e86dd1ace146484249e4ab1ab7e6a868b8502f01e1b455513a90a184079edc
  Stored in directory: c:\users\shane\appdata\local\pip\cache\wheels\06\7e\ad\1cb03e935234186825cefc7e2c8f3451b4f654b5bc72232a7b
Successfully built import-ipynb
Installing collected packages: import-ipynb
Successfully installed import-ipynb-0.1.3


In [42]:
from Dataset import *
from Model2 import StyledGenerator, Discriminator

In [43]:
def requires_grad(model, flag=True):
    """Assign ``flag`` to ``requires_grad`` on every parameter of ``model``.

    The training loop uses this to freeze one network while the other one
    is being optimised.

    Args:
        model: an ``nn.Module``; its parameters are modified in place.
        flag: the value to store in each parameter's ``requires_grad``.
    """
    enabled = flag
    for _name, weight in model.named_parameters():
        weight.requires_grad = enabled


def accumulate(model1, model2, decay=0.999):
    """Blend ``model2``'s parameters into ``model1`` as an exponential moving average.

    For every parameter name ``k`` of ``model1``::

        model1[k] <- decay * model1[k] + (1 - decay) * model2[k]

    Args:
        model1: module updated in place (the averaged copy).
        model2: module whose current parameters are blended in; it must
            contain every parameter name that ``model1`` has.
        decay: fraction of ``model1``'s current value that is kept.
    """
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())

    # The average is bookkeeping, not part of the graph: keep autograd out of it.
    with torch.no_grad():
        for k, p1 in par1.items():
            # Keyword ``alpha`` replaces the deprecated add_(Number, Tensor) overload.
            p1.mul_(decay).add_(par2[k], alpha=1 - decay)

In [44]:
def sample_data(dataset, batch_size, image_size=4, num_workers=1):
    """Build a shuffling DataLoader over ``dataset`` at a square resolution.

    The function overwrites ``dataset.transform`` (the dataset object is
    mutated). The new transform does the following to each image:
    - resizes it;
    - centre-crops it to ``image_size`` x ``image_size``;
    - randomly flips it horizontally;
    - converts it to a tensor;
    - normalises it with mean 0.5 / std 0.5 per channel.

    Args:
        dataset: dataset object that applies ``dataset.transform`` to its
            samples (assumed; the Dataset class is defined elsewhere).
        batch_size: images per batch. An incomplete final batch is dropped.
        image_size: target side length in pixels.
        num_workers: number of DataLoader worker processes (default 1, as before).

    Returns:
        A ``torch.utils.data.DataLoader`` over ``dataset``.
    """
    transform = transforms.Compose([
        transforms.Resize(image_size),          # shorter side -> image_size
        transforms.CenterCrop(image_size),      # square crop
        transforms.RandomHorizontalFlip(),      # cheap data augmentation
        transforms.ToTensor(),
        # NOTE(review): three-channel statistics assume RGB tensors. Confirm
        # against the Dataset class if the PNG slices are loaded as grayscale.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    dataset.transform = transform
    return DataLoader(
        dataset,
        shuffle=True,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=True,
    )

In [45]:
def reset_lr(optimizer, lr):
    """Set the learning rate of every param group of ``optimizer``.

    Each group's rate becomes ``lr`` scaled by that group's optional
    ``'mult'`` entry, which is treated as 1 when the key is absent.
    """
    for param_group in optimizer.param_groups:
        scale = param_group['mult'] if 'mult' in param_group else 1
        param_group['lr'] = lr * scale

In [34]:
# NOTE(review): this cell only evaluates a string literal, so `imshow` is never
# defined. The code is disabled and kept as text; the string's value is echoed
# as the cell output. To re-enable it:
# - import `matplotlib.pyplot as plt`, which is not imported in this notebook;
# - run the settings cell first, because the function reads `save_folder_path`
#   from it;
# - note that `grid = tensor[0]` is a view, so the in-place
#   clamp_/add_/div_/mul_ calls would modify the caller's tensor. Clone first
#   if that matters.
'''
def imshow(tensor, i):
    grid = tensor[0]
    grid.clamp_(-1, 1).add_(1).div_(2)
    # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
    ndarr = grid.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    img = Image.fromarray(ndarr)
    img.save(f'{save_folder_path}sample-iter{i}.png')
    plt.imshow(img)
    plt.show()
    
'''

"\ndef imshow(tensor, i):\n    grid = tensor[0]\n    grid.clamp_(-1, 1).add_(1).div_(2)\n    # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer\n    ndarr = grid.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()\n    img = Image.fromarray(ndarr)\n    img.save(f'{save_folder_path}sample-iter{i}.png')\n    plt.imshow(img)\n    plt.show()\n    \n"

In [39]:
"""
Settings:

init size:default=8
step:Train step 1,2 3,....
resolution:resolution of generated pictures

"""
init_size=8
step = int(math.log2(init_size)) - 2
resolution = 4 * 2 ** step
batch_default=32
n_gpu             = 1
device            = torch.device('cuda:0')

learning_rate     = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
batch_size_1gpu   = {4: 128, 8: 128, 16: 64, 32: 32, 64: 16, 128: 16}
mini_batch_size_1 = 8
batch_size        = {4: 256, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16}
mini_batch_size   = 8
batch_size_4gpus  = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32}
mini_batch_size_4 = 16
batch_size_8gpus  = {4: 512, 8: 256, 16: 128, 32: 64}
mini_batch_size_8 = 32
n_fc              = 8
dim_latent        = 512
dim_input         = 4
n_sample          = 120000
DGR               = 1
n_show_loss       = 500



max_size          =1024 #Max image size 
max_step          = int(math.log2(args.max_size)) - 2 # Maximum step (8 for 1024^2)

style_mixing      = [] # Waiting to implement
image_folder_path = '/content/drive/MyDrive/Dataset/keras_png_slices_data'
save_folder_path  = '/content/drive/MyDrive/SG_siskon/results/'

low_steps         = [0, 1, 2]
# style_mixing    += low_steps
mid_steps         = [3, 4, 5]
# style_mixing    += mid_steps
hig_steps         = [6, 7, 8]
# style_mixing    += hig_steps

# Used to continue training from last checkpoint
startpoint        = 0
used_sample       = 0
alpha             = 0

# Mode: Evaluate? Train?
is_train          = True

# How to start training?
# True for start from saved model
# False for retrain from the very beginning
is_continue       = True
d_losses          = [float('inf')]
g_losses          = [float('inf')]
inputs, outputs = [], []


In [36]:
# Train function


def train(generator, discriminator, g_optim, d_optim, dataset, step, startpoint=0, used_sample=0,
          d_losses=None, g_losses=None, alpha=0, phase=None, ckpt=None, g_running=None):
    """Progressive-growing training loop (work in progress).

    Training starts at ``resolution = 4 * 2 ** step``. The fade-in factor
    ``alpha`` is recomputed every iteration. Once more than ``2 * phase``
    samples have been used, the loop does the following:
    - advances ``step`` (capped at ``max_step``);
    - rebuilds the data loader;
    - writes ``checkpoint/train_step-{ckpt_step}.model``;
    - re-applies the learning rate for the new resolution.

    The loop relies on these notebook globals: batch_size, batch_default,
    learning_rate, init_size, max_step, n_sample, sample_data, reset_lr,
    requires_grad and tqdm.

    Args:
        generator, discriminator: networks being trained. They may be plain
            modules or ``nn.DataParallel`` wrappers; ``.module`` is unwrapped
            only when present.
        g_optim, d_optim: optimizers of the generator / discriminator.
        dataset: dataset handed to ``sample_data``.
        step: initial progressive-growing step.
        startpoint: iteration index the progress bar resumes from.
        used_sample: samples already consumed in the current phase (resume).
        d_losses, g_losses: loss histories. A fresh list is created when None,
            which avoids sharing a mutable default between calls.
        alpha: initial fade-in factor (overwritten at each iteration).
        phase: samples per fade-in phase. Defaults to the global ``n_sample``
            (assumption; TODO confirm this is the intended budget).
        ckpt: checkpoint being resumed from, or None for a fresh run. In a
            fresh run the initial resolution is trained with ``alpha = 1``.
        g_running: optional running-average generator to include in
            checkpoints.

    NOTE(review): in the visible loop, no batch is drawn, no loss is computed
    and ``used_sample`` is never incremented, so the schedule never advances.
    TODO confirm whether the optimisation step is missing from this cell.
    """
    import os

    if d_losses is None:
        d_losses = []
    if g_losses is None:
        g_losses = []
    if phase is None:
        phase = n_sample

    def make_loader(res):
        # One batch-size lookup shared by the initial and every later resolution.
        return sample_data(dataset, batch_size.get(res, batch_default), res)

    # The resolution must be known before the loader and learning rates are set up.
    resolution = 4 * 2 ** step
    loader = make_loader(resolution)
    data_loader = iter(loader)

    reset_lr(g_optim, learning_rate.get(resolution, 0.001))
    reset_lr(d_optim, learning_rate.get(resolution, 0.001))

    progress_bar = tqdm(range(startpoint, 1000000))

    requires_grad(generator, False)
    requires_grad(discriminator, True)

    # Loss trackers.
    disc_loss_val = 0
    gen_loss_val = 0
    grad_loss_val = 0

    final_progress = False

    # Train
    for i in progress_bar:
            discriminator.zero_grad()

            # Linear fade-in over one phase worth of samples.
            alpha = min(1, 1 / phase * (used_sample + 1))

            # No fade-in for the first resolution of a fresh run, nor after the final step.
            if (resolution == init_size and ckpt is None) or final_progress:
                alpha = 1

            if used_sample > phase * 2:
                used_sample = 0
                step += 1

                if step > max_step:
                    step = max_step
                    final_progress = True
                    ckpt_step = step + 1

                else:
                    alpha = 0
                    ckpt_step = step

                resolution = 4 * 2 ** step

                loader = make_loader(resolution)
                data_loader = iter(loader)

                checkpoint = {
                    'generator': getattr(generator, 'module', generator).state_dict(),
                    'discriminator': getattr(discriminator, 'module', discriminator).state_dict(),
                    'g_optimizer': g_optim.state_dict(),
                    'd_optimizer': d_optim.state_dict(),
                }
                if g_running is not None:
                    checkpoint['g_running'] = g_running.state_dict()
                os.makedirs('checkpoint', exist_ok=True)
                torch.save(checkpoint, f'checkpoint/train_step-{ckpt_step}.model')

                # The learning-rate table is keyed by resolution: apply the new stage's rate.
                reset_lr(g_optim, learning_rate.get(resolution, 0.001))
                reset_lr(d_optim, learning_rate.get(resolution, 0.001))

    
