In [1]:
import os,time,pickle
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import transforms
from lib import networks,train_history,util,visualizer
import itertools
from torchvision import datasets

In [2]:
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
if torch.backends.cudnn.enabled:
    torch.backends.cudnn.benchmark = True
device

device(type='cuda', index=2)

In [3]:
display_id=1
display_winsize=256
display_ncols=5
display_server='http://localhost'
display_port=8097
display_env='main'
name = 'deepfakeanime1'
checkpoints_dir='checkpoints'

vis = visualizer.Visualizer(display_id,display_winsize,display_ncols,display_server,display_port,display_env,
                 name,checkpoints_dir)

create web directory checkpoints/persona_3/web...


In [4]:
load_size = 128

batch_size = 1

lr = 0.0002
beta1 = 0.5
beta2 = 0.999

train_epoch = 100

display_freq = 400
print_freq = 400
save_latest_freq = 400
update_html_freq = 800

In [5]:
auto_encoder = networks.AutoEncoder(False)
auto_encoder = auto_encoder.to(device)

print('---------- Networks initialized -------------')
num_params = 0
for param in auto_encoder.parameters():
    num_params += param.numel()
print(str.format('model has {} number of parameters', num_params))
print('-----------------------------------------------')

initialize network with normal
initialize network with normal
initialize network with normal
initialize network with normal
---------- Networks initialized -------------
G has 45620867 number of parameters
D has 2764737 number of parameters
-----------------------------------------------


In [6]:
# src_data_path='data/real'
# tgt_data_path = "data/manga"
src_data_path='/data/persona_cyclegan/real'
tgt_data_path =  '/data/persona_cyclegan/anime'

In [7]:
# data_loader
transform = transforms.Compose([
        transforms.Resize((load_size, load_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.1,0.1,0.1,0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

train_loader_A = torch.utils.data.DataLoader(datasets.ImageFolder(src_data_path, transform), batch_size=batch_size, shuffle=True, drop_last=True)
train_loader_B = torch.utils.data.DataLoader(datasets.ImageFolder(tgt_data_path, transform), batch_size=batch_size, shuffle=True, drop_last=True)


In [11]:
# loss
L1_loss = nn.L1Loss().to(device)

In [9]:
optimizer = optim.Adam(auto_encoder.parameters(), lr=lr, betas=(beta1, beta2))

In [10]:
epoch_count=1
niter = 3
niter_decay = 100
def lambda_rule(epoch):
    lr_l = 1.0 - max(0, epoch + epoch_count - niter) / float(niter_decay + 1)
    return lr_l
schedulers = [lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) for optimizer in [G_optimizer,D_A_optimizer,D_B_optimizer]]

In [12]:
train_hist = train_history.train_history(['real_loss',
                                          'anime_loss'
                                          ])

In [13]:
starting_epoch=0

In [None]:
## print('training start!')
for epoch in range(train_epoch):
    epoch_start_time = time.time()

    for scheduler in schedulers:
        scheduler.step()
        print('learning rate = %.7f'%G_optimizer.param_groups[0]['lr'])
    data_size = min(len(train_loader_A), len(train_loader_B))
    
    for i,((real,_),(anime,_)) in enumerate(zip(train_loader_A, train_loader_B)):
        iter_start_time = time.time()
        # input image data
        real = real.to(device)
        anime = anime.to(device)

        # Train generator G
        optimizer.zero_grad()
        recon_real = auto_encoder(real, 'real')
        recon_anime = auto_encoder(anime, 'anime')
        loss_real = L1_loss(recon_real, real)
        loss_anime = L1_loss(recon_anime, anime)
        loss = loss_real + loss_anime
        loss.backward()
        optimizer.step()
        
        if i % display_freq == 0:
            with torch.no_grad():
                gen_anime = auto_encoder(real, 'anime')
            save_result = i % update_html_freq == 0
            vis.display_current_results([real,recon_real,anime,recon_anime,gen_anime],starting_epoch+epoch, i,save_result)

        train_hist.add_params([loss_real,loss_anime])

        if i % print_freq == 0:
            t_data = iter_start_time - epoch_start_time
            losses = train_hist.check_current_avg()
            t = (time.time() - iter_start_time) / batch_size
            vis.print_current_losses(starting_epoch+epoch, i, losses, t, t_data)
            if display_id > 0:
                vis.plot_current_losses(starting_epoch+epoch, float(i) / data_size, losses)
        
        if i % save_latest_freq == 0:
            torch.save(auto_encoder.state_dict(), os.path.join(checkpoints_dir,name, 'auto_encoder.pkl'))
            train_hist.save_train(os.path.join(checkpoints_dir,name,  'train_hist.pkl'))
            
    if (epoch+starting_epoch)%5 == 0:
        torch.save(auto_encoder.state_dict(), os.path.join(checkpoints_dir,name, str(epoch+starting_epoch)+'auto_encoder.pkl'))


learning rate = 0.0002000
learning rate = 0.0002000
learning rate = 0.0002000
(epoch: 0, iters: 0, time: 1.446, data: 0.039) G_gan_loss: 3.985 G_cycle_loss: 16.153 D_A_fake_loss: 2.067 D_A_real_loss: 2.005 D_B_fake_loss: 1.635 D_B_real_loss: 2.375 
(epoch: 0, iters: 400, time: 0.688, data: 87.727) G_gan_loss: 1.073 G_cycle_loss: 6.795 D_A_fake_loss: 0.488 D_A_real_loss: 0.420 D_B_fake_loss: 0.415 D_B_real_loss: 0.383 
(epoch: 0, iters: 800, time: 0.728, data: 176.547) G_gan_loss: 0.802 G_cycle_loss: 5.340 D_A_fake_loss: 0.251 D_A_real_loss: 0.272 D_B_fake_loss: 0.233 D_B_real_loss: 0.232 
(epoch: 0, iters: 1200, time: 0.612, data: 265.380) G_gan_loss: 0.799 G_cycle_loss: 5.067 D_A_fake_loss: 0.239 D_A_real_loss: 0.269 D_B_fake_loss: 0.244 D_B_real_loss: 0.232 
(epoch: 0, iters: 1600, time: 0.705, data: 353.813) G_gan_loss: 0.796 G_cycle_loss: 4.780 D_A_fake_loss: 0.238 D_A_real_loss: 0.263 D_B_fake_loss: 0.235 D_B_real_loss: 0.234 
(epoch: 0, iters: 2000, time: 0.666, data: 442.605) G_

(epoch: 0, iters: 18800, time: 0.658, data: 4166.151) G_gan_loss: 1.094 G_cycle_loss: 3.320 D_A_fake_loss: 0.179 D_A_real_loss: 0.198 D_B_fake_loss: 0.158 D_B_real_loss: 0.158 
(epoch: 0, iters: 19200, time: 0.756, data: 4254.520) G_gan_loss: 1.059 G_cycle_loss: 3.278 D_A_fake_loss: 0.177 D_A_real_loss: 0.198 D_B_fake_loss: 0.155 D_B_real_loss: 0.163 
(epoch: 0, iters: 19600, time: 0.648, data: 4342.943) G_gan_loss: 1.120 G_cycle_loss: 3.294 D_A_fake_loss: 0.186 D_A_real_loss: 0.199 D_B_fake_loss: 0.147 D_B_real_loss: 0.154 
(epoch: 0, iters: 20000, time: 0.729, data: 4431.548) G_gan_loss: 1.110 G_cycle_loss: 3.340 D_A_fake_loss: 0.185 D_A_real_loss: 0.197 D_B_fake_loss: 0.148 D_B_real_loss: 0.164 
(epoch: 0, iters: 20400, time: 0.665, data: 4520.066) G_gan_loss: 1.127 G_cycle_loss: 3.367 D_A_fake_loss: 0.180 D_A_real_loss: 0.192 D_B_fake_loss: 0.141 D_B_real_loss: 0.154 
learning rate = 0.0002000
learning rate = 0.0002000
learning rate = 0.0002000
(epoch: 1, iters: 0, time: 0.439, dat