In [1]:
import sys
from pathlib import Path
import shutil
from tqdm import tqdm

import numpy as np
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from core.datasets import SampleFaceDataset, MergeDataset, UniformYawBatchSampler
from core.imagelib import DSSIM, gaussian_blur, style_loss, total_variation_loss
from core.models import DF, LIAE, LIAE_multi, UNetPatchDiscriminator
from core.models import contextual_loss as cl
from core.options import read_yaml, write_yaml, DictAsMember
from core.loglib import write_losses, write_images, save_weights, load_weights

In [2]:
# ---- Run configuration -------------------------------------------------
# Project folder name and the two identity folders (source / target).
project = "proj_temp"
src, dst = "jeou", "ryu"

# Name of the workspace folder for this run.
# NOTE: the name `model` is rebound to the network object in a later cell.
model = f"finetune-{src}-{dst}_newmodel2"
#saved_model = "df-wf-288_torch"
configs = "stage01.yaml"
finetune_start = False

# GPUs to use; the first entry is the primary device, and multi-GPU mode
# is enabled as soon as more than one index is listed.
gpu_idxs = [1,2,3]
device = f"cuda:{gpu_idxs[0]}"
parallel = len(gpu_idxs) >= 2

In [3]:
# Face-sample folders for the source and target identities.
src_path = Path(f"../data-source/{project}/{src}")
dst_path = Path(f"../data-target/{project}/{dst}")

# Workspace folder of this run plus its preview-history and backup sub-folders.
# (A disabled variant of this cell copied a pretrained folder from
#  ../workspace/saved_models/{saved_model} into model_path and set
#  finetune_start = True when model_path did not exist yet.)
model_path  = Path(f"../workspace/{project}/{model}")
debug_path  = model_path / "history"
backup_path = model_path / "autobackups"
for out_dir in (model_path, debug_path, backup_path):
    # exist_ok makes this a no-op for folders that are already there
    out_dir.mkdir(parents=True, exist_ok=True)

# Training configuration, split into its three sections.
# (model options used to come from model_path / "model_opt.yaml"; they are
#  defined inline in the next cell instead.)
config_path = Path(f"../workspace/saved_configs/{configs}")
config = read_yaml(config_path)
data_dict, optim_dict, train_dict = (
    config.data_params, config.optim_params, config.train_params
)

# The configured batch size is multiplied by the GPU count in multi-GPU mode.
if parallel:
    train_dict.batch_size *= len(gpu_idxs)

In [4]:
# Architecture options for the network built below, exposed through
# attribute access (model_dict.resolution, model_dict.ae_dims, ...).
model_dict = DictAsMember(dict(
    resolution=224,
    face_type="whole_face",
    model_type="liae-u",
    ae_dims=256,
    e_dims=80,
    d_dims=80,
    d_mask_dims=22,
    likeness=True,
    double_res=False,
))

In [5]:
# Both identities are loaded with exactly the same sampling options.
sample_args = (
    model_dict.resolution,
    model_dict.face_type,
    data_dict.random_warp,
    data_dict.random_flip,
)
src_dataset = SampleFaceDataset(src_path, *sample_args)
dst_dataset = SampleFaceDataset(dst_path, *sample_args)
mrg_dataset = MergeDataset(src_dataset, dst_dataset)

if not data_dict.uniform_yaw:
    # Plain fixed-size batching over the merged dataset.
    train_loader = DataLoader(mrg_dataset, batch_size=train_dict.batch_size)
else:
    # Batches are drawn by a sampler built from the source set's yaw distribution.
    src_yaw_dist = src_dataset.get_yaw_dist()
    src_sampler = UniformYawBatchSampler(src_yaw_dist, batch_size=train_dict.batch_size)
    train_loader = DataLoader(mrg_dataset, batch_sampler=src_sampler)

100%|██████████| 3493/3493 [00:13<00:00, 262.41it/s]
100%|██████████| 1321/1321 [00:05<00:00, 258.40it/s]


In [6]:
# An earlier (disabled) variant of this cell chose DF or LIAE depending on
# whether model_dict.model_type starts with 'df', then resumed through
#   model, log_history = load_weights(model_path, model, finetune_start=finetune_start)
#   start_iters = log_history["current_iters"]
# This run builds a new LIAE_multi network instead and counts iterations from zero.
arch_dims = (
    model_dict.resolution,
    model_dict.ae_dims,
    model_dict.e_dims,
    model_dict.d_dims,
    model_dict.d_mask_dims,
)
model = LIAE_multi(
    *arch_dims,
    likeness=model_dict.likeness,
    double_res=model_dict.double_res,
)
model = model.to(device)
start_iters = 0

In [7]:
# Loss modules used by the loss-step functions defined further down.
mseloss = nn.MSELoss()
dssimloss = DSSIM(filter_size = int(model_dict.resolution // 11.6)).to(device)
clloss = cl.ContextualLoss(use_vgg=True, vgg_layer='relu5_4').to(device)
l1loss = nn.L1Loss().to(device)
#kldivloss = nn.KLDivLoss()

# Optional adversarial branch: patch discriminator plus its logit-based BCE loss.
if optim_dict.use_gan:
    discriminator = UNetPatchDiscriminator(
        optim_dict.gan_patch_size, 3, optim_dict.gan_dims
    ).to(device)
    D_Loss = nn.BCEWithLogitsLoss()

if parallel:
    # `model` is created in an earlier cell, so re-running this cell must not
    # wrap it a second time: a nested DataParallel would make `model.module`
    # another wrapper instead of the bare network (which the checkpointing
    # code relies on). The loss modules are rebuilt above on every run, so
    # they can be wrapped unconditionally.
    if not isinstance(model, nn.DataParallel):
        model = nn.DataParallel(model, device_ids=gpu_idxs)
    mseloss = nn.DataParallel(mseloss, device_ids=gpu_idxs)
    dssimloss = nn.DataParallel(dssimloss, device_ids=gpu_idxs)
    clloss = nn.DataParallel(clloss, device_ids=gpu_idxs)
    l1loss = nn.DataParallel(l1loss, device_ids=gpu_idxs)
    if optim_dict.use_gan:
        discriminator = nn.DataParallel(discriminator, device_ids=gpu_idxs)

In [None]:
# Load pretrained decoder weights and freeze the decoder.
#
# `model` may already be wrapped in nn.DataParallel by the previous cell; the
# wrapper does not forward attribute access, so the decoder has to be reached
# through `.module` in that case.
base_model = model.module if isinstance(model, nn.DataParallel) else model

# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
decoder_state = torch.load(
    "../workspace/saved_models/liaeu-wf-224_torch/new_SAEHD_decoder.pkl",
    map_location=device,
)
base_model.decoder.load_state_dict(decoder_state)

# Assigning `decoder.requires_grad = False` only creates a plain attribute on
# the module and leaves every parameter trainable; requires_grad_() is the
# call that actually switches gradient tracking off for all decoder parameters.
base_model.decoder.requires_grad_(False)

In [8]:
# Adam for the autoencoder, plus a second Adam for the discriminator when
# the adversarial branch is enabled. Both share the configured learning rate.
model_opt = optim.Adam(model.parameters(), lr=optim_dict.learning_rate)
if optim_dict.use_gan:
    disc_opt = optim.Adam(discriminator.parameters(), lr=optim_dict.learning_rate)

# Optional step decay: the learning rate is multiplied by 0.1 every 600
# scheduler steps, for each optimizer that exists.
if optim_dict.lr_schedule:
    scheduler = optim.lr_scheduler.StepLR(model_opt, step_size=600, gamma=0.1)
    if optim_dict.use_gan:
        disc_scheduler = optim.lr_scheduler.StepLR(disc_opt, step_size=600, gamma=0.1)

In [11]:
def customloss_step(src, dst, result, optim_dict):
    """Compute the training loss for one batch of the multi-output model.

    Args:
        src, dst: per-identity batch dicts; this function reads the
            "target", "mask" and "blur_mask" entries.
        result: model output dict; this function reads "prd_src_src",
            "prd_dst_dst", "prd_src_dst", "prd_src_res", "prd_res_dst",
            "prd_src_srcm", "prd_dst_dstm", "prd_src_dstm", "src_code"
            and "res_id_code".
        optim_dict: options object; only `masked_training` is read here.

    Returns:
        (tot_lossval, losses, images) where `tot_lossval` is the un-reduced
        total loss to back-propagate, `losses` maps display names to
        mean-reduced values for logging, and `images` maps display names to
        tensors for preview output.
    """
    src_blur = src["blur_mask"]
    dst_blur = dst["blur_mask"]

    # Reconstruction targets/predictions are masked only in masked-training mode.
    if optim_dict.masked_training:
        target_src_opt  = src["target"] * src_blur
        target_dst_opt  = dst["target"] * dst_blur
        prd_src_src_opt = result["prd_src_src"] * src_blur
        prd_dst_dst_opt = result["prd_dst_dst"] * dst_blur
    else:
        target_src_opt  = src["target"]
        target_dst_opt  = dst["target"]
        prd_src_src_opt = result["prd_src_src"]
        prd_dst_dst_opt = result["prd_dst_dst"]

    # The swapped prediction is multiplied by the dst blur mask in both modes.
    prd_src_dst_opt = result["prd_src_dst"] * dst_blur

    # src -> src reconstruction: DSSIM + MSE on the image, MSE on the mask.
    src_dssim_val1 = dssimloss(target_src_opt, prd_src_src_opt)
    src_mse_val    = mseloss(target_src_opt,  prd_src_src_opt)
    src_mask_val   = mseloss(src["mask"], result["prd_src_srcm"])
    src_lossval    = src_dssim_val1 + src_mse_val + src_mask_val

    # dst -> dst reconstruction: same three terms.
    dst_dssim_val1 = dssimloss(target_dst_opt, prd_dst_dst_opt)
    dst_mse_val    = mseloss(target_dst_opt,  prd_dst_dst_opt)
    dst_mask_val   = mseloss(dst["mask"], result["prd_dst_dstm"])
    dst_lossval    = dst_dssim_val1 + dst_mse_val + dst_mask_val

    # Swapped output: contextual loss against the dst target plus the MSE of
    # its own predicted mask. (This previously added dst_mask_val a second
    # time, leaving the swapped mask prediction without any loss term.)
    res_cl_val1    = clloss(target_dst_opt, prd_src_dst_opt)
    res_mask_val   = mseloss(dst["mask"], result["prd_src_dstm"])
    res_lossval    = res_cl_val1 + res_mask_val

    # L1 between the second half (along dim 1) of the src code and the
    # code stored under "res_id_code".
    halfline    = result["src_code"].shape[1] // 2
    src_id_code = result["src_code"][:, halfline:, :, :]
    prd_id_code = result["res_id_code"]
    l1_lossval  = l1loss(src_id_code, prd_id_code)

    tot_lossval = src_lossval + dst_lossval + res_lossval + l1_lossval

    losses = {
        "Total_Loss"    : tot_lossval.mean(),
        "Src_Loss"      : src_lossval.mean(),
        "Dst_Loss"      : dst_lossval.mean(),
        "Res_loss"      : res_lossval.mean(),
        "L1_Loss"       : l1_lossval.mean()
    }

    # Cross predictions are previewed unmasked.
    images = {
        "Target SRC"      : target_src_opt,
        "Target DST"      : target_dst_opt,
        "Predict SRC"     : prd_src_src_opt,
        "Predict DST"     : prd_dst_dst_opt,
        "Predict SRC-DST" : result["prd_src_dst"],
        "Predict SRC-RES" : result["prd_src_res"],
        "Predict RES-DST" : result["prd_res_dst"]
    }
    return tot_lossval, losses, images

In [None]:
# Main training loop: runs until `train_dict.target_iter` samples have been
# processed (the counter advances by the batch size on every step), writing
# TensorBoard logs, preview images, checkpoints and numbered backups on the way.
writer = SummaryWriter(log_dir=str(model_path))
curr_iters = 0


def _interval_hit(prev_count, curr_count, interval):
    """Return True when a multiple of `interval` lies in (prev_count, curr_count].

    The counter advances in steps of the batch size, so testing
    `curr_count % interval == 0` silently skips every event whose interval is
    not a multiple of the batch size. When the batch size does divide the
    interval, this check fires on exactly the same steps as the modulo test.
    """
    return curr_count // interval > prev_count // interval


pbar = tqdm(total=train_dict.target_iter)
finished = False

try:
    while not finished:
        # With uniform-yaw sampling, the batch sampler is rebuilt for every
        # pass over the data.
        if data_dict.uniform_yaw:
            src_yaw_dist = src_dataset.get_yaw_dist()
            src_sampler = UniformYawBatchSampler(src_yaw_dist, batch_size=train_dict.batch_size)
            train_loader = DataLoader(mrg_dataset, batch_sampler=src_sampler)

        for data in train_loader:
            prev_iters = curr_iters
            curr_iters += train_dict.batch_size
            #log_history["current_iter"] = start_iters+curr_iters

            # Move every entry except the "filename" one to the training device.
            for group in data.values():
                for field, value in group.items():
                    if field != "filename":
                        group[field] = value.to(device)

            # ---- autoencoder update ----
            model_opt.zero_grad()
            result = model(data["src"]["warped"], data["dst"]["warped"])
            tot_lossval, losses, images = customloss_step(data["src"], data["dst"], result, optim_dict)
            tot_lossval.sum().backward()
            model_opt.step()
            if optim_dict.lr_schedule:
                scheduler.step()

            # ---- discriminator update ----
            if optim_dict.use_gan:
                # NOTE(review): customloss_step does not put a "Disc_Loss"
                # entry into `losses`; with use_gan enabled this lookup needs
                # a loss step that does -- confirm which one is intended.
                # Gradients are cleared here, not at the top of the step, so
                # anything the generator backward pass left on the
                # discriminator parameters is not applied by disc_opt.
                disc_opt.zero_grad()
                losses["Disc_Loss"].backward()
                disc_opt.step()
                if optim_dict.lr_schedule:
                    disc_scheduler.step()

            pbar.update(train_dict.batch_size)
            pbar.set_postfix(src=f"{losses['Src_Loss']:.5f}", dst=f"{losses['Dst_Loss']:.5f}")

            if _interval_hit(prev_iters, curr_iters, train_dict.debug_iter):
                write_losses(writer, losses, start_iters+curr_iters)

                #log_history["src_loss_history"][curr_iters] = losses['Src_Loss'].item()
                #log_history["dst_loss_history"][curr_iters] = losses['Dst_Loss'].item()

            if _interval_hit(prev_iters, curr_iters, train_dict.preview_iter):
                write_images(writer, images, debug_path, start_iters+curr_iters)

            if _interval_hit(prev_iters, curr_iters, train_dict.save_iter):
                save_weights(model_path, model.module if parallel else model,
                             config=config)

            if _interval_hit(prev_iters, curr_iters, train_dict.backup_iter):
                save_weights(backup_path.joinpath(f"{curr_iters:05d}"), model.module if parallel else model,
                             config=config)

            if curr_iters >= train_dict.target_iter:
                # Final checkpoint + backup, then leave both loops. This used
                # to call sys.exit(0), which raises SystemExit inside the
                # notebook and never reached the cleanup code.
                save_weights(model_path, model.module if parallel else model,
                             config=config)
                save_weights(backup_path.joinpath(f"{curr_iters:05d}"), model.module if parallel else model,
                             config=config)
                finished = True
                break
finally:
    # Runs on normal completion, on errors and on a manual interrupt alike.
    pbar.close()
    writer.close()


  0%|          | 0/12000 [00:00<?, ?it/s][A
  0%|          | 12/12000 [00:05<1:27:05,  2.29it/s][A
  0%|          | 12/12000 [00:05<1:27:05,  2.29it/s, dst=0.35006, src=0.33851][A
  0%|          | 24/12000 [00:08<1:17:40,  2.57it/s, dst=0.35006, src=0.33851][A
  0%|          | 24/12000 [00:08<1:17:40,  2.57it/s, dst=0.39394, src=0.36464][A
  0%|          | 36/12000 [00:11<1:10:30,  2.83it/s, dst=0.39394, src=0.36464][A
  0%|          | 36/12000 [00:11<1:10:30,  2.83it/s, dst=0.36501, src=0.36361][A
  0%|          | 48/12000 [00:15<1:05:51,  3.02it/s, dst=0.36501, src=0.36361][A
  0%|          | 48/12000 [00:15<1:05:51,  3.02it/s, dst=0.40570, src=0.41067][A
  0%|          | 60/12000 [00:18<1:02:03,  3.21it/s, dst=0.40570, src=0.41067][A
  0%|          | 60/12000 [00:18<1:02:03,  3.21it/s, dst=0.38611, src=0.35862][A
  1%|          | 72/12000 [00:21<59:48,  3.32it/s, dst=0.38611, src=0.35862]  [A
  1%|          | 72/12000 [00:21<59:48,  3.32it/s, dst=0.33664, src=0.34145][

  5%|▌         | 612/12000 [03:07<1:30:32,  2.10it/s, dst=0.17711, src=0.17807][A
  5%|▌         | 612/12000 [03:07<1:30:32,  2.10it/s, dst=0.18386, src=0.21301][A
  5%|▌         | 624/12000 [03:10<1:18:38,  2.41it/s, dst=0.18386, src=0.21301][A
  5%|▌         | 624/12000 [03:10<1:18:38,  2.41it/s, dst=0.16311, src=0.20382][A
  5%|▌         | 636/12000 [03:13<1:09:57,  2.71it/s, dst=0.16311, src=0.20382][A
  5%|▌         | 636/12000 [03:13<1:09:57,  2.71it/s, dst=0.17300, src=0.23486][A
  5%|▌         | 648/12000 [03:16<1:03:28,  2.98it/s, dst=0.17300, src=0.23486][A
  5%|▌         | 648/12000 [03:16<1:03:28,  2.98it/s, dst=0.22232, src=0.23223][A
  6%|▌         | 660/12000 [03:19<59:29,  3.18it/s, dst=0.22232, src=0.23223]  [A
  6%|▌         | 660/12000 [03:19<59:29,  3.18it/s, dst=0.19489, src=0.20607][A
  6%|▌         | 672/12000 [03:22<56:20,  3.35it/s, dst=0.19489, src=0.20607][A
  6%|▌         | 672/12000 [03:22<56:20,  3.35it/s, dst=0.18769, src=0.25040][A
  6%|▌    

 10%|█         | 1212/12000 [06:10<1:28:16,  2.04it/s, dst=0.15942, src=0.16304][A
 10%|█         | 1212/12000 [06:10<1:28:16,  2.04it/s, dst=0.16479, src=0.15396][A
 10%|█         | 1224/12000 [06:13<1:15:55,  2.37it/s, dst=0.16479, src=0.15396][A
 10%|█         | 1224/12000 [06:13<1:15:55,  2.37it/s, dst=0.12670, src=0.14421][A
 10%|█         | 1236/12000 [06:16<1:07:27,  2.66it/s, dst=0.12670, src=0.14421][A
 10%|█         | 1236/12000 [06:16<1:07:27,  2.66it/s, dst=0.12098, src=0.13848][A
 10%|█         | 1248/12000 [06:20<1:01:08,  2.93it/s, dst=0.12098, src=0.13848][A
 10%|█         | 1248/12000 [06:20<1:01:08,  2.93it/s, dst=0.13651, src=0.15987][A
 10%|█         | 1260/12000 [06:23<56:28,  3.17it/s, dst=0.13651, src=0.15987]  [A
 10%|█         | 1260/12000 [06:23<56:28,  3.17it/s, dst=0.12825, src=0.15083][A
 11%|█         | 1272/12000 [06:26<53:36,  3.34it/s, dst=0.12825, src=0.15083][A
 11%|█         | 1272/12000 [06:26<53:36,  3.34it/s, dst=0.10481, src=0.17296][

In [None]:
def ganloss_step(src, result, optim_dict):
    """Compute discriminator and generator adversarial losses for the src branch.

    Args:
        src: source batch dict; "target" and "blur_mask" are read.
        result: model output dict; "prd_src_src" is read.
        optim_dict: options object; only `masked_training` is read here.

    Returns:
        (disc_loss, g_loss): the loss for the discriminator (computed on a
        detached prediction) and the loss for the generator (computed on the
        attached prediction, so it back-propagates through the discriminator
        into the model).
    """
    if optim_dict.masked_training:
        target_src_opt       = src["target"] * src["blur_mask"]
        prd_src_src_opt      = result["prd_src_src"] * src["blur_mask"]
        # Complement of the mask; used for the extra generator term below.
        target_src_anti_opt  = src["target"] * (1.0 - src["blur_mask"])
        prd_src_src_anti_opt = result["prd_src_src"] * (1.0 - src["blur_mask"])
    else:
        target_src_opt  = src["target"]
        prd_src_src_opt = result["prd_src_src"]

    # The discriminator returns two outputs per input; both are supervised.
    # detach() keeps the discriminator loss from reaching the generator.
    prd_src_src_d1, prd_src_src_d2 = discriminator(prd_src_src_opt.detach())
    tgt_src_d1, tgt_src_d2 = discriminator(target_src_opt)

    prd_src_src_d1_ones  = torch.ones_like(prd_src_src_d1)
    prd_src_src_d1_zeros = torch.zeros_like(prd_src_src_d1)
    prd_src_src_d2_ones  = torch.ones_like(prd_src_src_d2)
    prd_src_src_d2_zeros = torch.zeros_like(prd_src_src_d2)

    tgt_src_d1_ones  = torch.ones_like(tgt_src_d1)
    tgt_src_d2_ones  = torch.ones_like(tgt_src_d2)

    # BCEWithLogitsLoss takes (input_logits, target). The arguments used to be
    # passed the other way round, which treated the constant labels as logits
    # and the discriminator outputs as targets.
    # Discriminator objective: real -> 1, generated -> 0.
    disc_loss_1 = D_Loss(tgt_src_d1, tgt_src_d1_ones) + D_Loss(prd_src_src_d1, prd_src_src_d1_zeros)
    disc_loss_2 = D_Loss(tgt_src_d2, tgt_src_d2_ones) + D_Loss(prd_src_src_d2, prd_src_src_d2_zeros)
    disc_loss = (disc_loss_1 + disc_loss_2) * 0.5

    # Second pass without detach() so the generator receives gradients.
    prd_src_src_d1, prd_src_src_d2 = discriminator(prd_src_src_opt)

    # Generator objective: generated -> 1.
    g_loss_1 = D_Loss(prd_src_src_d1, prd_src_src_d1_ones)
    g_loss_2 = D_Loss(prd_src_src_d2, prd_src_src_d2_ones)
    g_loss = (g_loss_1 + g_loss_2)

    if optim_dict.masked_training:
        # Extra terms in masked mode: total-variation loss on the raw
        # prediction and a small squared-error term on the region outside the mask.
        g_loss += total_variation_loss(result["prd_src_src"])
        g_loss += 0.02 * torch.square(prd_src_src_anti_opt-target_src_anti_opt).mean(axis=[0,1,2,3])

    return disc_loss, g_loss

def loss_step(src, dst, result, optim_dict, radius=model_dict.resolution):
    """Compute the reconstruction loss (plus optional extras) for one batch.

    Always included: DSSIM + MSE on the image and MSE on the mask, for the
    src->src and dst->dst reconstructions. Optional, controlled by
    `optim_dict`: an eye-mask-weighted L1 term (`eyes_mouth_prio`), a style
    loss on the dst reconstruction (`true_style_pow > 0`) and an adversarial
    term (`use_gan`).

    Returns:
        (tot_lossval, losses, images): the total loss to back-propagate, a
        dict of named values for logging and a dict of named tensors for
        preview output.
    """
    use_eyes  = optim_dict.eyes_mouth_prio
    use_style = optim_dict.true_style_pow > 0.0
    use_gan   = optim_dict.use_gan

    def _recon_loss(target, predicted, mask_true, mask_predicted):
        # image DSSIM + image MSE + mask MSE, summed in that order
        return (
            dssimloss(target, predicted)
            + mseloss(target, predicted)
            + mseloss(mask_true, mask_predicted)
        )

    # Start from the raw tensors; apply the blur masks in masked-training mode.
    target_src_opt, prd_src_src_opt = src["target"], result["prd_src_src"]
    target_dst_opt, prd_dst_dst_opt = dst["target"], result["prd_dst_dst"]
    if optim_dict.masked_training:
        target_src_opt  = target_src_opt * src["blur_mask"]
        target_dst_opt  = target_dst_opt * dst["blur_mask"]
        prd_src_src_opt = prd_src_src_opt * src["blur_mask"]
        prd_dst_dst_opt = prd_dst_dst_opt * dst["blur_mask"]

    src_lossval = _recon_loss(
        target_src_opt, prd_src_src_opt, src["mask"], result["prd_src_srcm"]
    )
    dst_lossval = _recon_loss(
        target_dst_opt, prd_dst_dst_opt, dst["mask"], result["prd_dst_dstm"]
    )

    if use_eyes:
        # L1 restricted to the "eyemask" region, weighted by 30, is added to
        # each side's loss.
        src_eye_val = F.l1_loss(
            result["prd_src_src"] * src["eyemask"], src["target"] * src["eyemask"]
        ) * 30
        dst_eye_val = F.l1_loss(
            result["prd_dst_dst"] * dst["eyemask"], dst["target"] * dst["eyemask"]
        ) * 30
        src_lossval += src_eye_val
        dst_lossval += dst_eye_val

    tot_lossval = src_lossval + dst_lossval

    if use_style:
        # Style loss between the dst reconstruction and its target, both
        # multiplied by a blurred copy of the dst mask.
        dst_styleblur_mask = gaussian_blur(dst["mask"], (radius // 32), use_cpu=False)
        raw_style = style_loss(
            result["prd_dst_dst"] * dst_styleblur_mask,
            dst["target"] * dst_styleblur_mask,
            (radius // 16),
        ).sum()
        tot_stylelossval = optim_dict.true_style_pow * raw_style
        tot_lossval += tot_stylelossval

    if use_gan:
        # Both adversarial losses are scaled by gan_pow; only the generator
        # part joins the total.
        disc_loss, g_loss = ganloss_step(src, result, optim_dict)
        disc_loss = optim_dict.gan_pow * disc_loss
        g_loss = optim_dict.gan_pow * g_loss
        tot_lossval += g_loss

    # Logged values, in a fixed order.
    losses = {
        "Total_Loss"    : tot_lossval.mean(),
        "Src_Loss"      : src_lossval.mean(),
        "Dst_Loss"      : dst_lossval.mean(),
    }
    if use_eyes:
        losses.update({
            "Src_Loss_Eyes" : src_eye_val.mean(),
            "Dst_Loss_Eyes" : dst_eye_val.mean(),
        })
    if use_style:
        losses["Style_Loss"] = tot_stylelossval.mean()
    if use_gan:
        losses.update({"GAN_Loss": g_loss, "Disc_Loss": disc_loss})

    images = {
        "Target SRC"      : target_src_opt,
        "Target DST"      : target_dst_opt,
        "Predict SRC"     : prd_src_src_opt,
        "Predict DST"     : prd_dst_dst_opt,
        "Predict SRC-DST" : result["prd_src_dst"],
    }
    return tot_lossval, losses, images