In [1]:
#train.py
from tqdm import trange
import torch

from torch.utils.data import DataLoader

from logger import Logger
from modules.model_m import GeneratorFullModel, DiscriminatorFullModel

from torch.optim.lr_scheduler import MultiStepLR

from sync_batchnorm import DataParallelWithCallback

from frames_dataset import DatasetRepeater

#run.py 
import matplotlib

matplotlib.use('Agg')

import os, sys
import yaml
from argparse import ArgumentParser
from time import gmtime, strftime
from shutil import copy

from frames_dataset import FramesDataset

from modules.generator import OcclusionAwareGenerator
from modules.discriminator import MultiScaleDiscriminator
from modules.keypoint_detector import KPDetector

import torch

#from train import train
from reconstruction import reconstruction
from animate import animate

In [2]:
import numpy as np
from modules.graphattn_a1 import MultiHeadAttention
from visualization_a import *
from demo import make_animation
from visualization_kp_a import *

In [3]:
import imageio
import numpy as np
from skimage.transform import resize
from skimage import img_as_ubyte

# Build Model

In [4]:
def load_model(opt2, generator, kp_detector, discriminator, 
               model_no=1):
    checkpoint_path = opt2.checkpoints[model_no]

    #checkpoint.seed(0)
    if opt2.cpu:
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    else:
        checkpoint = torch.load(checkpoint_path)

    if not "dense_motion_network.kp_variance" in checkpoint['generator']:
        checkpoint['generator']["dense_motion_network.kp_variance"] = torch.tensor([0.01])
    else:
        print("kp_variance: ", checkpoint['generator']["dense_motion_network.kp_variance"])
        
    generator.load_state_dict(checkpoint['generator'])#,strict=False
    kp_detector.load_state_dict(checkpoint['kp_detector'])
    discriminator.load_state_dict(checkpoint['discriminator'])
    if 'kp_refiner' in checkpoint: 
        kp_refiner = MultiHeadAttention(2,2,1,alpha=0.3)
        kp_refiner.load_state_dict(checkpoint['kp_refiner'])
    else:
        kp_refiner = None
    return generator, kp_detector, discriminator, kp_refiner

In [5]:
def build_model(opt2, model_no=0):
    print("model_no: ", model_no)
    class opt:
        config = opt2.config
        mode = "train"
        log_dir = 'log'
        checkpoint = None
        device_ids = model_no # {0,1}

    with open(opt.config) as f:
        config = yaml.load(f)

    if opt.checkpoint is not None:
        log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1])
    else:
        log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0])
        log_dir += ' ' + strftime("%d_%m_%y_%H.%M.%S", gmtime())

    generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
                                        **config['model_params']['common_params'])
    discriminator = MultiScaleDiscriminator(**config['model_params']['discriminator_params'],
                                            **config['model_params']['common_params'])
    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])
    
    img_source_folder = f"{config['dataset_params']['root_dir']}/test/"

    generator, kp_detector, discriminator, kp_refiner = \
    load_model(opt2, generator, kp_detector, discriminator,  \
               model_no=model_no)
    if kp_refiner is None:
        print("No kp_refiner")
    print("generator.dense_motion_network.kp_variance", generator.dense_motion_network.kp_variance)
    return generator, kp_detector, discriminator, kp_refiner, img_source_folder 
  

In [6]:
def inference(source_name, driving_name, opt2,
              generator, kp_detector, discriminator, kp_refiner, img_source_folder, model_no,
              gen_anim=False, gen_anim_kp=False, graph_with_weight=True, gen_anim_kp_gif=True):
    
        
    generator.eval()
    discriminator.eval()
    kp_detector.eval()
    if not kp_refiner is None: kp_refiner.eval()
        
    if opt2.cpu:
        generator.to('cpu')
        discriminator.to('cpu')
        kp_detector.to('cpu')
        if not kp_refiner is None: kp_refiner.to('cpu')
        
    source_video_pth = f"{img_source_folder}{source_name}.gif"
    source_img_pth = f"{img_save_folder}source_image_{source_name}.png"
    driving_video_pth = f"{img_source_folder}{driving_name}.gif"
    save_name = f"_sc{source_name}_dr{driving_name}_m{model_no}"
    result_video_save_pth = f"{img_save_folder}result{save_name}.gif"    

    get_one_frame_in_gif_file(source_video_pth, 
                              source_img_pth, 
                              i=0)    

    class opt3:
        source_image = source_img_pth
        driving_video = driving_video_pth
        result_video = result_video_save_pth

    source_image = imageio.imread(opt3.source_image)
    reader = imageio.get_reader(opt3.driving_video)
    fps = 15  #reader.get_meta_data()['fps']
    driving_video = []
    try:
        for im in reader:
            driving_video.append(im)
    except RuntimeError:
        pass
    reader.close()

    source_image = resize(source_image, (256, 256))[..., :3]
    driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
    
    print("make prediction...")
    predictions, masks, heatmaps, sparse_deformeds, kp_source_list, \
    kp_driving_list, kp_norm_list, adj_source, adj_driving, adj_weights_source, adj_weights_driving = \
        make_animation_split_kp(source_image, driving_video, generator, kp_detector, relative=opt2.relative, \
                                adapt_movement_scale=opt2.adapt_scale, cpu=opt2.cpu, stn=None, \
                                oval_heatmap=False, kp_refiner=kp_refiner)
    
        
    #plot result
    if gen_anim:
        #plot result 動畫
        print("save result animation...")
        imageio.mimsave(opt3.result_video, [img_as_ubyte(frame) for frame in predictions], fps=fps)
        print(f"save [{opt3.result_video}]!")
        
    
    if gen_anim_kp:
        #plot heatmap all frame 動畫
        print("save result animation with kp...")
        draw_heatmap_all_frame_w_graph_on_fig(
            kp_detector, save_name, source_image, driving_video, predictions, masks, heatmaps, \
            sparse_deformeds, kp_source_list, kp_driving_list, \
            kp_norm_list, adj_source, adj_driving, adj_weights_source, adj_weights_driving, \
            save_pth=img_save_folder, fps=15, 
            graph_with_weight=graph_with_weight,
            gen_kp_gif=gen_anim_kp_gif,
        )
        
    return predictions, adj_weights_driving

In [7]:
class opt2:
    config = "config/mgif-256-m2.yaml"
    checkpoints = [
        '../../public/first-order-model/checkpoint/mgif-cpk.pth.tar',
        'demo/202006221000/mgif_cpk_newattngraph_300.tar'
    ]
    cpu=True
    relative = True
    adapt_scale = True
    find_best_frame = False

In [8]:
generator, kp_detector, discriminator, kp_refiner, img_source_folder  = build_model(opt2, model_no=0)
generator_m, kp_detector_m, discriminator_m, kp_refiner_m, _ = build_model(opt2, model_no=1)

model_no:  0


  # This is added back by InteractiveShellApp.init_path()


No kp_refiner
generator.dense_motion_network.kp_variance Parameter containing:
tensor(0.0100, requires_grad=True)
model_no:  1
kp_variance:  tensor([0.0116])
generator.dense_motion_network.kp_variance Parameter containing:
tensor(0.0116, requires_grad=True)


In [9]:
source_name = "00002"
driving_name = "00001"
!mkdir "demo/202006271600/"
img_save_folder = "demo/202006271600/"

mkdir: cannot create directory ‘demo/202006271600/’: File exists


In [None]:
import time
for i in range(0,25):
    st = time.time()
    source_name, driving_name = str(i+1).zfill(5), str(i).zfill(5)
    print("*"*64+ f"\nsource_name: {source_name}, driving_name:{driving_name}"  )
    
    predictions, adj_weights_driving = \
    inference(source_name, driving_name, opt2,
              generator, kp_detector, discriminator, kp_refiner, img_source_folder, 
              model_no=0, gen_anim=True, gen_anim_kp=True, graph_with_weight=False, gen_anim_kp_gif=True)
    
    #if i<9: continue
    predictions, adj_weights_driving = \
    inference(source_name, driving_name, opt2,
              generator_m, kp_detector_m, discriminator_m, kp_refiner_m, img_source_folder, 
              model_no=1, gen_anim=True, gen_anim_kp=True, graph_with_weight=True, gen_anim_kp_gif=True)
    
    print("time cost: {:.2f}".format(time.time()-st))

****************************************************************
source_name: 00001, driving_name:00000
Save file [demo/202006271600/source_image_00001.png]
make prediction...


