In [None]:
!pip install dgl-cu90 

In [1]:
#train.py
from tqdm import trange
import torch

from torch.utils.data import DataLoader

from logger import Logger
from modules.model_m import GeneratorFullModel, DiscriminatorFullModel

from torch.optim.lr_scheduler import MultiStepLR

from sync_batchnorm import DataParallelWithCallback

from frames_dataset import DatasetRepeater

In [2]:
#run.py 
import matplotlib

matplotlib.use('Agg')

import os, sys
import yaml
from argparse import ArgumentParser
from time import gmtime, strftime
from shutil import copy

from frames_dataset import FramesDataset

from modules.generator import OcclusionAwareGenerator
from modules.discriminator import MultiScaleDiscriminator
from modules.keypoint_detector import KPDetector

import torch

#from train import train
from reconstruction import reconstruction
from animate import animate

In [3]:
import numpy as np
from modules.stn_a import STN

# Build Model

由 run.py 改來

In [4]:
class opt:
    config = "config/mgif-256-m2.yaml"
    mode = "train"
    log_dir = 'log'
    checkpoint = None
    device_ids = "0" #"0,1,2,3"

In [5]:
fn = lambda x: list(map(int, x.split(',')))
opt.device_ids = fn(opt.device_ids)

In [6]:
with open(opt.config) as f:
    config = yaml.load(f)

if opt.checkpoint is not None:
    log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1])
else:
    log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0])
    log_dir += ' ' + strftime("%d_%m_%y_%H.%M.%S", gmtime())

  


In [7]:
generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
                                    **config['model_params']['common_params'])
discriminator = MultiScaleDiscriminator(**config['model_params']['discriminator_params'],
                                        **config['model_params']['common_params'])
kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                         **config['model_params']['common_params'])

In [10]:
# 定義初始圖
"""
#-1:背景  ![](https://i.imgur.com/CnW6Uy8.gif)
0:後半  ![](https://i.imgur.com/XGRK0ul.gif)
1:外前腳 ![](https://i.imgur.com/7ngPynF.gif)
2:耳朵  ![](https://i.imgur.com/fZjnGGQ.gif)
3:頭上  ![](https://i.imgur.com/uGS8Eu0.gif)
4:內後腳  ![](https://i.imgur.com/FGIHvpb.gif)
5:鼻子  ![](https://i.imgur.com/NoDo30g.gif)
6:內側腳(?)  ![](https://i.imgur.com/mA0hRHP.gif)
7:外後腳  ![](https://i.imgur.com/AS3d08V.gif)
8:軀幹  ![](https://i.imgur.com/Yw3K9qb.gif)
9:尾巴 ![](https://i.imgur.com/DxFcCtj.gif)
"""
adjmatrix_directed = torch.zeros((10,10))+0.2 #[10, 10]
select_edge =  [(3,2), (5,2)] #頭部
select_edge += [(1,8), (6,8)] #前半
select_edge += [(4,0), (7,0), (9,0)] #後半
select_edge += [(8,2), (8,0)] #總體
select_edge += [ (e[1],e[0]) for e in select_edge] #undirect
select_edge += [ (i,i) for i in range(10)] #to self
for e in select_edge:
    adjmatrix_directed[e[0],e[1]] = 0.8

In [13]:
torch.sum(adjmatrix_directed>0.8)

tensor(0)

In [14]:
import torch
from modules.graphattn_a1 import MultiHeadAttention
kp_refiner = MultiHeadAttention(2,2,1)
kp_refiner.assign_mask_weight(adjmatrix_directed)

#try
"""
bs = 5
x = torch.rand((bs,10,2))
kp_refiner.assign_mask_weight(adjmatrix_directed)
y, a, _ = kp_refiner(x, mask_type="soft")
print( y.shape, a.shape )
print( kp_refiner.mask_weight )
"""

'\nbs = 5\nx = torch.rand((bs,10,2))\nkp_refiner.assign_mask_weight(adjmatrix_directed)\ny, a, _ = kp_refiner(x, mask_type="soft")\nprint( y.shape, a.shape )\nprint( kp_refiner.mask_weight )\n'

In [15]:
#modified # added 20200616
from modules.stn_a import STN
heatmap_stn = STN((1,64,64))


In [16]:
#kp_refiner=None
heatmap_stn=None
oval_heatmap=False if (heatmap_stn is None) else True

# Load Model

由 demo.py 改來

In [17]:
!mkdir "demo/202006271200/"
img_save_folder = "demo/202006271200/"

mkdir: cannot create directory ‘demo/202006271200/’: File exists


In [18]:
class opt2:
    config = opt.config
    checkpoint = '../../public/first-order-model/checkpoint/mgif-cpk.pth.tar'
    #checkpoint =  "demo/202006221000/mgif_cpk_newattngraph_300.tar"
    cpu = True
    relative = True
    adapt_scale = True
    find_best_frame = False

In [19]:
# load model
checkpoint_path = opt2.checkpoint

#checkpoint.seed(0)
if opt2.cpu:
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
else:
    checkpoint = torch.load(checkpoint_path)

generator.load_state_dict(checkpoint['generator'], strict=False)
kp_detector.load_state_dict(checkpoint['kp_detector'])
discriminator.load_state_dict(checkpoint['discriminator'])
if 'kp_refiner' in checkpoint: 
    kp_refiner.load_state_dict(checkpoint['kp_refiner'])

# Fix Orig Model

In [20]:
def freeze_model(model):
    model.eval()
    for params in model.parameters():
        params.requires_grad = False
        
def unfreeze_model(model):
    model.train()
    for params in model.parameters():
        params.requires_grad = True

In [21]:
#只訓練dense_motion_network
freeze_generator = True
freeze_discriminator = False
freeze_kp_detector = True
freeze_model(kp_detector)
freeze_model(generator)
unfreeze_model(generator.dense_motion_network)

# Train

由 run.py 改來

In [22]:
generator.train()
discriminator.train()
kp_detector.train()
if not heatmap_stn is None: heatmap_stn.train()
if not kp_refiner is None: kp_refiner.train()
if torch.cuda.is_available():
    generator.to(opt.device_ids[0])
    discriminator.to(opt.device_ids[0])
    kp_detector.to(opt.device_ids[0])
    if not heatmap_stn is None:heatmap_stn.to(opt.device_ids[0])
    if not kp_refiner  is None:kp_refiner.to(opt.device_ids[0])

In [23]:
!rm -rf ../../public/first-order-model/moving-gif/train/.ipynb_checkpoints
!rm -rf ../../public/first-order-model/moving-gif/test/.ipynb_checkpoints
dataset = FramesDataset(is_train=(opt.mode == 'train'), **config['dataset_params'])

Use predefined train-test split.


In [24]:
if not os.path.exists(log_dir):
    os.makedirs(log_dir)
if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))):
    copy(opt.config, log_dir)

由 train.py 改來

In [25]:
checkpoint = opt.checkpoint
device_ids = opt.device_ids

In [26]:
train_params = config['train_params']

optimizer_generator = torch.optim.Adam(generator.parameters(), lr=train_params['lr_generator'], betas=(0.5, 0.999))
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=train_params['lr_discriminator']*2, betas=(0.5, 0.999))
optimizer_kp_detector = torch.optim.Adam(kp_detector.parameters(), lr=train_params['lr_kp_detector'], betas=(0.5, 0.999))

In [27]:
start_epoch = 0

In [28]:
scheduler_generator = MultiStepLR(optimizer_generator, train_params['epoch_milestones'], gamma=0.1,
                                  last_epoch=start_epoch - 1)
scheduler_discriminator = MultiStepLR(optimizer_discriminator, train_params['epoch_milestones'], gamma=0.1,
                                      last_epoch=start_epoch - 1)
scheduler_kp_detector = MultiStepLR(optimizer_kp_detector, train_params['epoch_milestones'], gamma=0.1,
                                    last_epoch=-1 + start_epoch * (train_params['lr_kp_detector'] != 0))

In [29]:
# added 20200616
if not kp_refiner is None:
    optimizer_kp_refiner = torch.optim.Adam(kp_refiner.parameters(), lr=train_params['lr_kp_refiner'], betas=(0.5, 0.999))
    scheduler_kp_refiner = MultiStepLR(optimizer_kp_detector, train_params['epoch_milestones'], gamma=0.1,
                                        last_epoch=-1 + start_epoch * (train_params['lr_kp_detector'] != 0))

if not heatmap_stn is None:
    optimizer_heatmap_stn = torch.optim.Adam(heatmap_stn.parameters(), lr=train_params['lr_heatmap_stn'], betas=(0.5, 0.999))
    scheduler_heatmap_stn = MultiStepLR(optimizer_kp_detector, train_params['epoch_milestones'], gamma=0.1,
                                        last_epoch=-1 + start_epoch * (train_params['lr_kp_detector'] != 0))

In [30]:
if 'num_repeats' in train_params or train_params['num_repeats'] != 1:
    dataset = DatasetRepeater(dataset, train_params['num_repeats'])
dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True, num_workers=6, drop_last=True)

In [31]:
generator_full = GeneratorFullModel(kp_detector, generator, discriminator, train_params)
discriminator_full = DiscriminatorFullModel(kp_detector, generator, discriminator, train_params)

In [32]:
if torch.cuda.is_available():
    generator_full = DataParallelWithCallback(generator_full, device_ids=device_ids)
    discriminator_full = DataParallelWithCallback(discriminator_full, device_ids=device_ids)

In [33]:
kp_refiner

MultiHeadAttention(
  in_features=2, head_num=1, bias=True, activation=<function relu at 0x7f54158329e0>
  (linear_q): Linear(in_features=2, out_features=2, bias=True)
  (linear_k): Linear(in_features=2, out_features=2, bias=True)
  (linear_v): Linear(in_features=2, out_features=2, bias=True)
  (linear_o): Linear(in_features=2, out_features=2, bias=True)
)

In [34]:
print("has kp_refiner:{}, \nhas heatmap_stn:{}, oval_heatmap:{}" \
      .format(not kp_refiner is None, not heatmap_stn is None, oval_heatmap))
with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'], checkpoint_freq=train_params['checkpoint_freq']) as logger:
    for epoch in trange(start_epoch, train_params['num_epochs']):#每一回合
        if epoch>0: break
        for j, x in enumerate(dataloader):
            if j > 300: break
            #print(0)
            losses_generator, generated, (adj_source, adj_weights_source, adj_driving, adj_weights_driving) \
                = generator_full(x, kp_refiner=kp_refiner, 
                                 )
            #loss_adj = torch.mean(masks_source * preserve_non_topk(masks_source, k=9) + \
            #                      masks_driving * preserve_non_topk(masks_driving, k=9))
            #print(1)
            loss_values = [val.mean() for val in losses_generator.values()]
            loss = sum(loss_values)#+100*loss_adj
            
            #print(generator.dense_motion_network.kp_variance)

            loss.backward()
            if not freeze_generator:
                optimizer_generator.step()
                optimizer_generator.zero_grad()
            if not heatmap_stn is None:
                optimizer_heatmap_stn.step()
                optimizer_heatmap_stn.zero_grad()  
            if not kp_refiner is None:
                optimizer_kp_refiner.step()
                optimizer_kp_refiner.zero_grad()
            if not freeze_kp_detector:
                optimizer_kp_detector.step()
                optimizer_kp_detector.zero_grad()
            
            #print
            if j%60 == 0:
                for i in range(11):  
                    r_ = torch.sum(adj_weights_source>=(i/10))/np.product(adj_weights_source.shape)
                    print("    #edge_weights(>={}): {:.2f} %".format(i/10, r_)  )
            if j%30 == 0:
                if not kp_refiner is None:
                    adj = np.array(adj_source.cpu())
                    #G = kp_refiner.get_graph_by_adjmatrix(adj_source, draw=True)
                    #draw_graph(G,save_pth=f'nx_fig/epo{epoch}-{j}-nx.png') #圖片存檔
                    #np.save(f'nx_fig/epo{epoch}-{j}-adj.npy', adj  )
                    print("[{}-{}]loss: {:.4f}, adj #0: {}, adj #1:{}" \
                          .format(epoch,j,loss.cpu().item(), np.sum(adj==0), np.sum(adj==1)))
                else:
                    print("[{}-{}]loss: {:.4f}" \
                      .format(epoch,j,loss.cpu().item()))
    
            if train_params['loss_weights']['generator_gan'] != 0:
                optimizer_discriminator.zero_grad()
                losses_discriminator = discriminator_full(x, generated)
                loss_values = [val.mean() for val in losses_discriminator.values()]
                loss = sum(loss_values)

                loss.backward()
                if not freeze_discriminator:
                    optimizer_discriminator.step()
                    optimizer_discriminator.zero_grad()

            else:
                losses_discriminator = {}
 
            
            losses_generator.update(losses_discriminator)
            losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()}
            logger.log_iter(losses=losses)

        
        scheduler_generator.step()
        scheduler_discriminator.step()
        scheduler_kp_detector.step()
        if not kp_refiner is None: 
            scheduler_kp_refiner.step()
        if not heatmap_stn is None: 
            scheduler_heatmap_stn.step()

        logger.log_epoch(epoch, {'generator': generator,
                                 'discriminator': discriminator,
                                 'kp_detector': kp_detector,
                                 'optimizer_generator': optimizer_generator,
                                 'optimizer_discriminator': optimizer_discriminator,
                                 'optimizer_kp_detector': optimizer_kp_detector}, inp=x, out=generated)

  0%|          | 0/100 [00:00<?, ?it/s]

has kp_refiner:True, 
has heatmap_stn:False, oval_heatmap:False


  "See the documentation of nn.Upsample for details.".format(mode))


    #edge_weights(>=0.0): 1.00 %
    #edge_weights(>=0.1): 1.00 %
    #edge_weights(>=0.2): 1.00 %
    #edge_weights(>=0.3): 0.28 %
    #edge_weights(>=0.4): 0.28 %
    #edge_weights(>=0.5): 0.28 %
    #edge_weights(>=0.6): 0.28 %
    #edge_weights(>=0.7): 0.28 %
    #edge_weights(>=0.8): 0.28 %
    #edge_weights(>=0.9): 0.00 %
    #edge_weights(>=1.0): 0.00 %
[0-0]loss: 136.2640, adj #0: 72, adj #1:28
[0-30]loss: 143.2298, adj #0: 72, adj #1:28
    #edge_weights(>=0.0): 1.00 %
    #edge_weights(>=0.1): 1.00 %
    #edge_weights(>=0.2): 0.65 %
    #edge_weights(>=0.3): 0.28 %
    #edge_weights(>=0.4): 0.28 %
    #edge_weights(>=0.5): 0.28 %
    #edge_weights(>=0.6): 0.28 %
    #edge_weights(>=0.7): 0.28 %
    #edge_weights(>=0.8): 0.13 %
    #edge_weights(>=0.9): 0.00 %
    #edge_weights(>=1.0): 0.00 %
[0-60]loss: 92.6280, adj #0: 72, adj #1:28
[0-90]loss: 120.4336, adj #0: 72, adj #1:28
    #edge_weights(>=0.0): 1.00 %
    #edge_weights(>=0.1): 1.00 %
    #edge_weights(>=0.2): 0.68 %
 

  1%|          | 1/100 [00:50<1:22:35, 50.06s/it]


In [35]:
#save model
op_checkpoint = {
    'generator' : generator.state_dict(),
    'kp_detector': kp_detector.state_dict(),
    'discriminator': discriminator.state_dict(),
    'kp_refiner':kp_refiner.state_dict(),
    
}
torch.save(op_checkpoint, 
           img_save_folder+f"mgif_cpk_newattngraph_300(only_dm).tar")