In [3]:


import copy
import os
import os.path as osp
import random

# Kept ahead of every non-stdlib import, as in the original cell
# (presumably it extends sys.path for the project packages -- TODO confirm).
import _init_paths

from PIL import Image
import numpy as np
import cv2
import mmcv

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import torch.nn.functional as F

from dall_e  import map_pixels, unmap_pixels, load_model
from IPython.display import display, display_markdown
from vcl.models.vqvae import *
from vcl.models.trackers import *
from vcl.utils import *
from mmcv.runner import get_dist_info, init_dist, load_checkpoint

import matplotlib.pyplot as plt
%matplotlib inline


# ---- Configuration ---------------------------------------------------------
target_image_size = 256
# Offset (in frames) between the sampled frame and its paired frame.
step = 3

output_dir = '/home/lr/project/vcl_output/vis_correspondence_ablations_local_step_3'

# Create the output directory unconditionally. The previous guard
# (`if os.path.exists(...)`) was inverted: it only ran makedirs when the
# directory already existed, so a missing directory was never created.
os.makedirs(output_dir, exist_ok=True)

list_path = '/home/lr/dataset/YouTube-VOS/2018/train/youtube2018_train.json'
video_dir = '/home/lr/dataset/YouTube-VOS/2018/train/JPEGImages_s256'
mask_dir = '/home/lr/dataset/YouTube-VOS/2018/train/Annotations_s256'

# ---- Sample list -----------------------------------------------------------
# `data` maps a video name to the list of its frame file names. For every
# video we record the frame paths, the matching mask paths (same file name,
# 'jpg' swapped for 'png') and the number of frames.
data = mmcv.load(list_path)
samples = []
for vname, frames in data.items():
    frames_path = [osp.join(video_dir, vname, frame) for frame in frames]
    masks_path = [osp.join(mask_dir, vname, frame.replace('jpg','png')) for frame in frames]

    sample = dict(
        frames_path=frames_path,
        masks_path=masks_path,
        num_frames=len(frames_path),
    )
    samples.append(sample)

visualizer = Correspondence_Visualizer(mode='pair', show_mode='none', radius=6, blend_color='jet')


def build_model(strides=(1,2,2,1), depth=18, pretrained=None, model_pre=None, pool_type='none'):
    """Build a ``VanillaTracker`` with a ResNet backbone, on GPU, in eval mode.

    Args:
        strides: Per-stage strides forwarded to the ResNet backbone config.
        depth: ResNet depth forwarded to the backbone config.
        pretrained: Optional path to a checkpoint containing a ``'state_dict'``
            entry. Only entries whose key contains ``'backbone'`` are loaded.
        model_pre: Optional value forwarded as the backbone's ``pretrained``
            config field.
        pool_type: Forwarded as the backbone's ``pool_type`` config field.

    Returns:
        The model after ``.cuda()`` and ``.eval()``.
    """
    model = VanillaTracker(
        backbone=dict(type='ResNet', depth=depth, strides=strides, out_indices=(2, ), pool_type=pool_type, pretrained=model_pre),
        test_cfg=dict(),
        train_cfg=dict()
        )

    if pretrained is not None:
        # Load onto the CPU first: without map_location, torch.load restores
        # tensors to the device they were saved from, which fails if that GPU
        # index does not exist here. The model is moved to the GPU below.
        params = torch.load(pretrained, map_location='cpu')
        backbone_state = {
            key: value
            for key, value in params['state_dict'].items()
            if 'backbone' in key
        }
        # strict=False: the filtered dict need not cover every model parameter.
        model.load_state_dict(backbone_state, strict=False)

    model = model.cuda()
    model.eval()
    return model
    

def main(x1, x2, model, model_name='final'):
    """Encode two frames with ``model.backbone``, visualize and save the result.

    NOTE: besides its arguments this function reads the notebook globals
    ``frame1``, ``frame2``, ``sample_idx``, ``video_name``, ``output_dir`` and
    ``visualizer``; they must be set by the calling cell before each call.

    Args:
        x1: Preprocessed tensor for the first frame (moved to GPU here).
        x2: Preprocessed tensor for the second frame (moved to GPU here).
        model: Model exposing a ``backbone`` callable.
        model_name: Sub-directory name under ``output_dir`` for this model.

    Returns:
        ``query`` and ``result`` images concatenated along axis 1.

    Side effects:
        Writes ``<video_name>.png``, ``target.jpg``, ``ref.jpg``,
        ``result.jpg`` and ``query.jpg`` into
        ``output_dir/model_name/video_name``.
    """
    model.eval()
    # Inference only: disable autograd so no graph/activations are retained
    # for the two forward passes.
    with torch.no_grad():
        enc1 = model.backbone(x1.cuda())
        enc2 = model.backbone(x2.cuda())

    # The first return value only needs to expose `savefig`; it is bound to a
    # local name that does not shadow the module-level pyplot import.
    canvas, f1, f2, result, query = visualizer.visualize(
        [frame1, frame2], [enc1, enc2], sample_idx, return_all=True)

    save_dir = os.path.join(output_dir, model_name, video_name)
    os.makedirs(save_dir, exist_ok=True)

    canvas.savefig(os.path.join(save_dir, video_name+'.png'))
    cv2.imwrite(os.path.join(save_dir, 'target.jpg'), f1)
    cv2.imwrite(os.path.join(save_dir, 'ref.jpg'), f2)
    cv2.imwrite(os.path.join(save_dir, 'result.jpg'), result)
    cv2.imwrite(os.path.join(save_dir, 'query.jpg'), query)

    out = np.concatenate([query, result], 1)
    return out
    

In [4]:
## sample frame
import os.path as osp
import glob

vname = None

# One kwargs dict per model, in the order the models are evaluated later.
model_specs = [
    dict(model_pre='/home/lr/models/ssl/image_based/moco_v2_res18_ep200_lab.pth'),
    dict(strides=(1,2,1,1), depth=50, model_pre='/home/lr/models/ssl/image_based/detco_200ep_AA.pth', pool_type='max'),
    dict(pretrained='/home/lr/expdir/VCL/group_vqvae_tracker/mast_d4_l2_base_5/epoch_1600.pth'),
    dict(pretrained='/home/lr/expdir/VCL/group_vqvae_tracker/mast_d4_l2_pyramid_dis_16/epoch_3200.pth'),
    dict(depth=50, pretrained='/home/lr/expdir/VCL/group_vqvae_tracker/final_framework_v2_9/epoch_1600.pth'),
]

# Build the models sequentially, then keep the individual handles available
# under their original names.
models = [build_model(**spec) for spec in model_specs]
model1, model2, model3, model4, model5 = models


# Output sub-directory name for each entry of `models`; order must match.
model_names = ['moco_v2_lab_res18', 'detco', 'rec', 'temporal', 'final']
assert len(model_names) == len(models), 'model_names and models must align'

# For every video: sample a frame pair `step` frames apart, pick a query
# location (inside the object mask when one exists), run every model on the
# pair and save the stacked per-model visualizations as one image.
#
# NOTE: `frame1`, `frame2`, `sample_idx` and `video_name` are read by `main`
# as globals, so these names must not be changed here.
for i, sample in enumerate(samples):

    num_frames = sample['num_frames']
    if num_frames == 0:
        print('skip sample {}: no frames'.format(i))
        continue

    # Clamp the upper bound so videos with `step` or fewer frames do not make
    # randint raise ValueError; the paired index is clamped below as well.
    frame_idx = random.randint(0, max(num_frames - step - 1, 0))
    ref_idx = min(frame_idx + step, len(sample['frames_path']) - 1)

    video_name = sample['frames_path'][0].split('/')[-2]

    # cv2 loads BGR; reverse the channel axis to get RGB.
    frame1 = cv2.imread(sample['frames_path'][frame_idx])[:,:,::-1]
    frame2 = cv2.imread(sample['frames_path'][ref_idx])[:,:,::-1]

    frame1 = cv2.resize(frame1, (256,256))
    frame2 = cv2.resize(frame2, (256,256))

    print(sample['masks_path'][frame_idx])
    mask = mmcv.imread(sample['masks_path'][frame_idx], flag='unchanged', backend='pillow')
    mask = (mask > 0).astype(np.uint8)

    # Flattened 32x32 foreground map; the query index is drawn from it.
    m = mmcv.imresize(mask, (32,32), interpolation='nearest').reshape(-1)

    idxs = np.nonzero(m)[0].tolist()

    if len(idxs) > 0:
        sample_idx = random.choice(idxs)
    else:
        # No foreground at this resolution: fall back to any grid location.
        sample_idx = random.randint(0, 32 ** 2 -1)

    print('sample frames from {}'.format(sample['frames_path'][frame_idx]), i)

    x1_rgb = preprocess_(frame1, mode='rgb')
    x2_rgb = preprocess_(frame2, mode='rgb')

    x1_lab = preprocess_(frame1, mode='lab')
    x2_lab = preprocess_(frame2, mode='lab')

    outs = []
    for model_name, model in zip(model_names, models):
        # Compare by value: `is 'detco'` tested object identity against a
        # literal, which is implementation-dependent (and a SyntaxWarning).
        if model_name == 'detco':
            out = main(x1_rgb, x2_rgb, model, model_name=model_name)
        else:
            out = main(x1_lab, x2_lab, model, model_name=model_name)

        outs.append(out)

    # Stack the per-model rows and write the combined image once per video
    # (it used to be rebuilt and rewritten after every model).
    result = np.concatenate(outs, 0)
    cv2.imwrite(osp.join(output_dir, video_name)+'.jpg', result)

    

2022-05-16 17:25:56,784 - vcl - INFO - Loading /home/lr/models/ssl/image_based/moco_v2_res18_ep200_lab.pth as torchvision
2022-05-16 17:25:56,801 - vcl - INFO - These parameters in pretrained checkpoint are not loaded: {'fc.0.weight', 'fc.2.weight', 'fc.2.bias', 'fc.0.bias'}


load checkpoint from local path: /home/lr/models/ssl/image_based/moco_v2_res18_ep200_lab.pth


2022-05-16 17:25:57,081 - vcl - INFO - Loading /home/lr/models/ssl/image_based/detco_200ep_AA.pth as torchvision


load checkpoint from local path: /home/lr/models/ssl/image_based/detco_200ep_AA.pth


  if model_name is 'detco':


/home/lr/dataset/YouTube-VOS/2018/train/Annotations_s256/5da6b2dc5d/00055.png
sample frames from /home/lr/dataset/YouTube-VOS/2018/train/JPEGImages_s256/5da6b2dc5d/00055.jpg 0
/home/lr/dataset/YouTube-VOS/2018/train/Annotations_s256/8aad0591eb/00105.png
sample frames from /home/lr/dataset/YouTube-VOS/2018/train/JPEGImages_s256/8aad0591eb/00105.jpg 1
/home/lr/dataset/YouTube-VOS/2018/train/Annotations_s256/e8962324e3/00110.png
sample frames from /home/lr/dataset/YouTube-VOS/2018/train/JPEGImages_s256/e8962324e3/00110.jpg 2


KeyboardInterrupt: 

<Figure size 432x288 with 0 Axes>