In [1]:
import torch
import numpy as np
from lib.networks import anime_full_encoder, anime_eye_encoder
from lib.data import Data
from lib.train_history import train_history 
from PIL import Image
from torchvision import transforms
import torch.nn as nn
import torch.optim as optim
import os

In [2]:
# Preferred CUDA device index. Falls back to cuda:0 when the machine has fewer
# GPUs, and to the CPU when CUDA is unavailable.
GPU_INDEX = 1
if torch.cuda.is_available():
    if torch.cuda.device_count() > GPU_INDEX:
        device = torch.device(f'cuda:{GPU_INDEX}')
    else:
        device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
# Let cuDNN pick the fastest kernels for the fixed input sizes used below.
if torch.backends.cudnn.enabled:
    torch.backends.cudnn.benchmark = True
device

device(type='cuda', index=1)

In [3]:
# Adam settings: a separate learning rate for the full-face and the eye
# encoders, with (beta1, beta2) shared by all three optimizers.
lr_full, lr_eye = 1e-4, 2e-4
beta1, beta2 = 0.5, 0.999

In [4]:
model_folder = "model"  # latest checkpoints go here; periodic snapshots go in numbered sub-folders

In [5]:
def _resize_normalize(size):
    """Build a transform: resize to ``size`` x ``size``, convert to a tensor,
    then normalize each of the 3 channels with mean 0.5 and std 0.5."""
    return transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ])


full_transform = _resize_normalize(256)  # whole-face images
eye_transform = _resize_normalize(32)    # eye crops

full_path = "../Datasets/anime/anime_face_large_dif_pos"
leye_path, reye_path = "left_eye", "right_eye"
# Construction order (left eye, right eye, full face) is kept as originally written.
train_loader_leye = Data(leye_path, eye_transform)
train_loader_reye = Data(reye_path, eye_transform)
train_loader_full = Data(full_path, full_transform, shuffle=True)


In [6]:
# leye_image = Image.open("left_eye/42001-1-0.png")

# leye_image = eye_transform(leye_image)  # (`transform` is not defined in this notebook)
# leye_image = leye_image.unsqueeze(0)
# leye_image.size()

In [7]:
L1_loss = nn.L1Loss(reduction='mean')  # mean absolute error, used for every loss term in training

In [8]:
# Three networks on `device`: one for the full face and one per eye, each
# built with argument 3 (presumably the input channel count — confirm in lib.networks).
model_full, model_leye, model_reye = (
    encoder_cls(3).to(device)
    for encoder_cls in (anime_full_encoder, anime_eye_encoder, anime_eye_encoder)
)
# Checkpoint filename -> module; passed to save_models in the training loop.
all_models = dict([
    ('model_full.pth', model_full),
    ('model_leye.pth', model_leye),
    ('model_reye.pth', model_reye),
])

In [9]:
# One Adam optimizer per network; the eye networks share lr_eye.
adam_betas = (beta1, beta2)
optimizer_full, optimizer_leye, optimizer_reye = (
    optim.Adam(net.parameters(), lr=rate, betas=adam_betas)
    for net, rate in ((model_full, lr_full), (model_leye, lr_eye), (model_reye, lr_eye))
)

In [11]:
# Scalar losses tracked every step. The training loop passes values in this
# same order (presumably matched by position inside train_history — confirm).
loss_names = [
    'full_img_recon_loss',
    'leye_img_recon_loss',
    'reye_img_recon_loss',
    'leye_latent_loss',
    'reye_latent_loss',
]
train_hist = train_history(loss_names)

In [12]:
def save_models(models, folder):
    """Save every model's ``state_dict`` as ``folder/<filename>``.

    Args:
        models: dict mapping checkpoint filename (e.g. 'model_full.pth') to an nn.Module.
        folder: destination directory. It is created (with parents) if missing,
            so the first save of a fresh run does not fail with FileNotFoundError.
            An empty string means the current directory.
    """
    if folder:
        os.makedirs(folder, exist_ok=True)
    for filename, model in models.items():
        torch.save(model.state_dict(), os.path.join(folder, filename))

In [13]:
def load_models(models, folder, map_location='cpu'):
    """Load weights saved by ``save_models`` from ``folder`` into each model.

    Args:
        models: dict mapping checkpoint filename (e.g. 'model_full.pth') to an nn.Module.
        folder: directory containing the checkpoint files.
        map_location: forwarded to ``torch.load``. Defaults to 'cpu' so that
            checkpoints written from a specific GPU (this notebook uses cuda:1)
            can be read on machines without that device. ``load_state_dict``
            then copies the values onto each model's own device.
            Pass None to restore tensors onto their original devices.
    """
    for filename, model in models.items():
        state_dict = torch.load(os.path.join(folder, filename), map_location=map_location)
        model.load_state_dict(state_dict)

In [14]:
count = 0  # global step; the training cell never resets it, so re-running that cell continues the log/checkpoint cadence

In [None]:
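# Disabled linear LR-decay schedule. Re-enabling it needs `from torch.optim import lr_scheduler`
# and a `scheduler.step()` call per epoch in the training loop.
# epoch_count=1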
# niter = 5
# niter_decay = 100
# def lambda_rule(epoch):
#     lr_l = 1.0 - max(0, epoch + epoch_count - niter) / float(niter_decay + 1)
#     return lr_l
# schedulers = [lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) for optimizer in [optimizer_full,optimizer_leye,optimizer_reye]]

In [None]:
# Joint training: reconstruct the full face and both eye crops, and pull the
# full-face network's per-eye outputs towards the eye networks' outputs
# (gradients flow into both sides of the latent L1 terms).
N_STEPS = 100000        # iterations per run of this cell
LOG_EVERY = 100         # print running loss averages
SAVE_EVERY = 1000       # overwrite the latest checkpoint in model_folder
SNAPSHOT_EVERY = 10000  # keep a copy in model_folder/<count>

# count == 0 triggers a save on the first step, so the folder must already exist.
os.makedirs(model_folder, exist_ok=True)

for _ in range(N_STEPS):
    # The left-eye loader drives the pairing; the full face and right eye are
    # fetched by the same image name.
    leye_img, leye_img_name = train_loader_leye.next()
    full_img = train_loader_full.get(leye_img_name)
    reye_img = train_loader_reye.get(leye_img_name)
    # Skip names that have no matching full-face or right-eye image.
    if full_img is None or reye_img is None:
        continue

    leye_img = leye_img.to(device)
    reye_img = reye_img.to(device)
    full_img = full_img.to(device)

    optimizer_full.zero_grad()
    optimizer_leye.zero_grad()
    optimizer_reye.zero_grad()

    full_recon, full_result_l, full_result_r = model_full(full_img)
    leye_recon, lresult = model_leye(leye_img)
    reye_recon, rresult = model_reye(reye_img)

    # Reconstruction terms.
    full_img_recon_loss = L1_loss(full_recon, full_img)
    leye_img_recon_loss = L1_loss(leye_recon, leye_img)
    reye_img_recon_loss = L1_loss(reye_recon, reye_img)
    recon_loss = full_img_recon_loss + leye_img_recon_loss + reye_img_recon_loss

    # Latent-alignment terms between the full-face and the eye networks.
    leye_latent_loss = L1_loss(full_result_l, lresult)
    reye_latent_loss = L1_loss(full_result_r, rresult)
    latent_loss = leye_latent_loss + reye_latent_loss

    loss = recon_loss + latent_loss
    loss.backward()

    optimizer_full.step()
    optimizer_leye.step()
    optimizer_reye.step()

    # Detach so the history keeps plain values, not references into the autograd graph.
    # NOTE(review): saved outputs show leye_img_recon_loss == nan while the other
    # terms train normally (and lack reye_img_recon_loss) — re-check after a fresh run.
    train_hist.add_params([
        term.detach()
        for term in (full_img_recon_loss, leye_img_recon_loss, reye_img_recon_loss,
                     leye_latent_loss, reye_latent_loss)
    ])

    if count % LOG_EVERY == 0:
        losses = train_hist.check_current_avg()
        print(losses)
    if count % SAVE_EVERY == 0:
        save_models(all_models, model_folder)
    if count % SNAPSHOT_EVERY == 0:
        new_dir = os.path.join(model_folder, str(count))
        os.makedirs(new_dir, exist_ok=True)
        save_models(all_models, new_dir)
    count += 1
    

{'full_img_recon_loss': tensor(0.6785), 'leye_img_recon_loss': tensor(nan), 'leye_latent_loss': tensor(0.7456), 'reye_latent_loss': tensor(0.6819)}
{'full_img_recon_loss': tensor(0.5327), 'leye_img_recon_loss': tensor(nan), 'leye_latent_loss': tensor(0.5331), 'reye_latent_loss': tensor(0.5411)}
{'full_img_recon_loss': tensor(0.4384), 'leye_img_recon_loss': tensor(nan), 'leye_latent_loss': tensor(0.4164), 'reye_latent_loss': tensor(0.4431)}
{'full_img_recon_loss': tensor(0.4197), 'leye_img_recon_loss': tensor(nan), 'leye_latent_loss': tensor(0.3419), 'reye_latent_loss': tensor(0.3619)}
{'full_img_recon_loss': tensor(0.4127), 'leye_img_recon_loss': tensor(nan), 'leye_latent_loss': tensor(0.2813), 'reye_latent_loss': tensor(0.2969)}
{'full_img_recon_loss': tensor(0.4041), 'leye_img_recon_loss': tensor(nan), 'leye_latent_loss': tensor(0.2267), 'reye_latent_loss': tensor(0.2349)}
{'full_img_recon_loss': tensor(0.3843), 'leye_img_recon_loss': tensor(nan), 'leye_latent_loss': tensor(0.2107), 

{'full_img_recon_loss': tensor(0.3146), 'leye_img_recon_loss': tensor(nan), 'leye_latent_loss': tensor(0.1193), 'reye_latent_loss': tensor(0.1201)}
{'full_img_recon_loss': tensor(0.3128), 'leye_img_recon_loss': tensor(nan), 'leye_latent_loss': tensor(0.1203), 'reye_latent_loss': tensor(0.1196)}
{'full_img_recon_loss': tensor(0.2978), 'leye_img_recon_loss': tensor(nan), 'leye_latent_loss': tensor(0.1155), 'reye_latent_loss': tensor(0.1206)}
{'full_img_recon_loss': tensor(0.3215), 'leye_img_recon_loss': tensor(nan), 'leye_latent_loss': tensor(0.1216), 'reye_latent_loss': tensor(0.1225)}
{'full_img_recon_loss': tensor(0.2877), 'leye_img_recon_loss': tensor(nan), 'leye_latent_loss': tensor(0.1206), 'reye_latent_loss': tensor(0.1173)}
{'full_img_recon_loss': tensor(0.3032), 'leye_img_recon_loss': tensor(nan), 'leye_latent_loss': tensor(0.1143), 'reye_latent_loss': tensor(0.1235)}
{'full_img_recon_loss': tensor(0.3120), 'leye_img_recon_loss': tensor(nan), 'leye_latent_loss': tensor(0.1165), 

{'full_img_recon_loss': tensor(0.2632), 'leye_img_recon_loss': tensor(nan), 'leye_latent_loss': tensor(0.1102), 'reye_latent_loss': tensor(0.1134)}
{'full_img_recon_loss': tensor(0.2638), 'leye_img_recon_loss': tensor(nan), 'leye_latent_loss': tensor(0.1129), 'reye_latent_loss': tensor(0.1154)}
{'full_img_recon_loss': tensor(0.2621), 'leye_img_recon_loss': tensor(nan), 'leye_latent_loss': tensor(0.1111), 'reye_latent_loss': tensor(0.1093)}
{'full_img_recon_loss': tensor(0.2634), 'leye_img_recon_loss': tensor(nan), 'leye_latent_loss': tensor(0.1112), 'reye_latent_loss': tensor(0.1119)}
{'full_img_recon_loss': tensor(0.2598), 'leye_img_recon_loss': tensor(nan), 'leye_latent_loss': tensor(0.1102), 'reye_latent_loss': tensor(0.1081)}
{'full_img_recon_loss': tensor(0.2612), 'leye_img_recon_loss': tensor(nan), 'leye_latent_loss': tensor(0.1102), 'reye_latent_loss': tensor(0.1090)}
{'full_img_recon_loss': tensor(0.2648), 'leye_img_recon_loss': tensor(nan), 'leye_latent_loss': tensor(0.1097), 

In [None]:
# save_models(all_models,'model')