In [1]:
import os
from os import path
import time
from argparse import ArgumentParser

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from PIL import Image

from model.eval_network import STCN
from dataset.davis_test_dataset import DAVISTestDataset
from util.tensor_util import unpad
from inference_core import InferenceCore

from progressbar import progressbar
from easydict import EasyDict




"""
Arguments loading
"""
# Fixed evaluation settings; EasyDict exposes the entries as attributes
# (args.model, args.split, ...).
args = EasyDict({
    'model': 'saves/stcn.pth',
    'davis_path': '../DAVIS/2017',
    'top': 20,
    'split': 'val',
    'amp': False,
    'mem_every': 5,
    'include_last': True,
})

davis_path = args.davis_path

# Colour palette read from one DAVIS annotation PNG. No output directory is
# created here: this notebook does not write predicted masks to disk.
palette = Image.open(
    path.expanduser(davis_path + '/trainval/Annotations/480p/blackswan/00000.png')
).getpalette()

# Inference only: disable gradient tracking globally.
torch.autograd.set_grad_enabled(False)

# Dataset setup: pick the sub-directory and image-set file for the requested
# split, then build the dataset and loader once.
if args.split == 'val':
    split_dir, split_imset = '/trainval', '2017/val.txt'
elif args.split == 'testdev':
    split_dir, split_imset = '/test-dev', '2017/test-dev.txt'
else:
    raise NotImplementedError
test_dataset = DAVISTestDataset(davis_path + split_dir, imset=split_imset)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)

# Network used for propagation, on the GPU and in eval mode.
top_k = args.top
prop_model = STCN().cuda().eval()

# Input mapping so that a stage 0 model can be loaded: when the first
# value-encoder convolution has 4 input channels, append one zero-filled
# input channel before loading the weights.
prop_saved = torch.load(args.model)
conv1_key = 'value_encoder.conv1.weight'
if conv1_key in prop_saved and prop_saved[conv1_key].shape[1] == 4:
    zero_channel = torch.zeros((64,1,7,7), device=prop_saved[conv1_key].device)
    prop_saved[conv1_key] = torch.cat([prop_saved[conv1_key], zero_channel], 1)
prop_model.load_state_dict(prop_saved)

# Running counters (initialised here; not updated in this part of the notebook).
total_process_time = 0
total_frames = 0


In [2]:
from metrics import db_eval_iou,db_eval_boundary
def metric_frames(pred_msks,gt_msks):
    """Score predicted masks against ground truth, frame by frame.

    Both inputs are iterated along axis 0 (one entry per frame) and then
    along axis 1 (one entry per object). Each (prediction, ground truth)
    object pair is passed to ``db_eval_iou`` and ``db_eval_boundary``, and
    the per-object values are averaged within every frame.

    Args:
        pred_msks: array of predicted masks; must have the same shape as
            ``gt_msks``.
        gt_msks: array of ground-truth masks.

    Returns:
        A tuple ``(J, F)`` of two lists with one entry per frame: the mean
        ``db_eval_iou`` value and the mean ``db_eval_boundary`` value over
        the objects of that frame.

    Raises:
        AssertionError: if the two arrays differ in shape.
    """
    assert pred_msks.shape == gt_msks.shape

    def _per_frame_mean(metric):
        # Average `metric` over the objects of each frame.
        return [
            np.mean([
                metric(pred_object, gt_object)
                for pred_object, gt_object in zip(pred_frame, gt_frame)
            ])
            for pred_frame, gt_frame in zip(pred_msks, gt_msks)
        ]

    # Local names deliberately avoid `F`, which the notebook uses as the
    # alias for torch.nn.functional.
    region_scores = _per_frame_mean(db_eval_iou)
    boundary_scores = _per_frame_mean(db_eval_boundary)
    return region_scores, boundary_scores




In [3]:
def infer_video(data,index):
    """Run mask propagation over a selection of frames and score the result.

    The frames listed in ``index`` are gathered from the batch, the mask of
    the first selected frame is handed to ``InferenceCore.interact``, and the
    resulting per-frame predictions are compared with the ground truth via
    ``metric_frames``.

    Relies on the notebook-level ``args``, ``prop_model`` and ``top_k`` and
    requires a CUDA device.

    Args:
        data: one batch from ``test_loader``. Reads ``data['rgb']``,
            ``data['gt']`` and ``data['info']`` (keys ``'labels'`` and
            ``'size_480p'``).
        index: list of frame indices selected along axis 1 of the image and
            mask tensors. The first entry is the frame whose mask is passed
            to ``interact``.

    Returns:
        The ``(J, F)`` tuple produced by ``metric_frames``: one value per
        selected frame for each of the two metrics.
    """
    with torch.cuda.amp.autocast(enabled=args.amp):

        # assumes rgb is (1, T, C, H, W) and data['gt'][0] is (k, T, 1, H, W)
        # -- TODO confirm against DAVISTestDataset
        rgb = data['rgb'][:,index].cuda()
        msk = data['gt'][0][:,index].cuda()
        info = data['info']
        k = len(info['labels'][0])
        size = info['size_480p']

        torch.cuda.synchronize()

        processor = InferenceCore(prop_model, rgb, k, top_k=top_k, 
                        mem_every=args.mem_every, include_last=args.include_last)
        # Propagate from the mask of the first selected frame across all
        # selected frames.
        processor.interact(msk[:,0], 0, rgb.shape[1])

        # Do unpad -> upsample to original size 
        out_masks = torch.zeros((processor.t, 1, *size), dtype=torch.uint8, device='cuda')
        for ti in range(processor.t):
            prob = unpad(processor.prob[:,ti], processor.pad)
            prob = F.interpolate(prob, size, mode='bilinear', align_corners=False)
            # Index of the highest-probability channel per pixel.
            out_masks[ti] = torch.argmax(prob, dim=0)
        
        out_masks = (out_masks.detach().cpu().numpy()[:,0]).astype(np.uint8)
        torch.cuda.synchronize()

        # compute metrics
        gt_msk = msk.cpu().numpy().squeeze(axis=2).astype(np.uint8)
        # One binary mask per object id 1..k, stacked object-first.
        pred_msk = np.array([out_masks == i for i in range(1,1+k)]).astype(np.uint8)
        # Swap to frame-first order, which is how metric_frames iterates.
        gt_msk = gt_msk.swapaxes(0,1)
        pred_msk = pred_msk.swapaxes(0,1)

        # Drop the GPU-backed references before the CPU-side metric computation.
        del rgb
        del msk
        del processor
        return metric_frames(pred_msk,gt_msk)


In [4]:
# Pick the 'breakdance' sequence; `data` is left bound to that batch for the
# cells below that use it.
for data in test_loader:
    if data['info']['name'][0] == 'breakdance':
        break
else:
    # Without this, `data` would silently stay bound to the last video of the
    # split (or be undefined for an empty loader).
    raise ValueError("video 'breakdance' not found in test_loader")

In [4]:
from tqdm import tqdm
import pickle

In [7]:
# Pairwise comparison between any two frames
def infer_one_frame(data,result):
    """Score every ordered pair of frames of one video.

    For each pair ``(i, j)`` the two frames are passed to ``infer_video`` as
    ``[i, j]``; the metric values of the second frame are kept.

    Args:
        data: one batch from ``test_loader``; reads ``data['info']['name']``
            and ``data['info']['num_frames']``.
        result: dict updated in place. ``result[name]`` becomes an array of
            shape ``(2, frames, frames)`` where plane 0 holds the first
            metric returned by ``infer_video`` (J) and plane 1 the second (F).
    """
    info = data['info']
    video = info['name'][0]
    n_frames = info['num_frames'].numpy()[0]

    scores = np.zeros((2, n_frames, n_frames))
    result[video] = scores
    for src in tqdm(range(n_frames)):
        for dst in range(n_frames):
            # Each metric comes back as one value per selected frame; only
            # the value of the second frame is stored.
            (_, j_value), (_, f_value) = infer_video(data, [src, dst])
            scores[0, src, dst] = j_value
            scores[1, src, dst] = f_value

In [8]:
result = dict()
try:
    # Use a dedicated loop variable: the name `data` holds the sample picked
    # in an earlier cell and is passed to infer_dataset further down, so it
    # must not be overwritten here.
    for video_data in test_loader:
        infer_one_frame(video_data,result)
finally:
    # This loop runs for hours; write whatever has been collected even when
    # the run is interrupted or fails. The entry of the video being processed
    # at that moment may be only partially filled.
    with open('all_two.pkl','wb') as f:
        pickle.dump(result,f)

100%|██████████| 69/69 [22:44<00:00, 19.78s/it]
100%|██████████| 50/50 [07:10<00:00,  8.60s/it]
100%|██████████| 80/80 [28:01<00:00, 21.02s/it]
100%|██████████| 84/84 [20:12<00:00, 14.44s/it]
100%|██████████| 90/90 [23:15<00:00, 15.51s/it]
100%|██████████| 75/75 [16:11<00:00, 12.95s/it]
100%|██████████| 40/40 [04:36<00:00,  6.91s/it]
 70%|███████   | 73/104 [22:01<09:21, 18.10s/it]


KeyboardInterrupt: 

In [10]:
# Write the pairwise results gathered so far to disk.
with open('all_two.pkl','wb') as dump_file:
    pickle.dump(result, dump_file)

In [83]:
def infer_dataset(data,result,window=20):
    """Evaluate sliding windows of frames, each prefixed with frame 0.

    Two passes are made over the video, one with the frames in reverse order
    and one in natural order. In each pass, for every start offset in
    ``range(frames - 1)``, the index list ``[0] + order[start:start + window]``
    is passed to ``infer_video``. Windows near the end of the order are
    shorter than ``window``.

    Args:
        data: one batch from ``test_loader``; reads ``data['info']['name']``
            and ``data['info']['num_frames']``.
        result: dict updated in place. ``result[name]`` becomes a list of
            ``[index, infer_video(data, index)]`` entries, reverse-order
            windows first, then natural-order windows.
        window: maximum number of frames taken after the leading frame 0.
            Defaults to 20.
    """
    name = data['info']['name'][0]
    result[name] = []
    frames = data['info']['num_frames'].numpy()[0]

    # Build both frame orders once instead of on every iteration.
    backward = list(range(frames-1,-1,-1))
    forward = list(range(frames))

    for order in (backward, forward):
        for start in tqdm(range(frames-1)):
            index = [0] + order[start:start+window]
            result[name].append([index,infer_video(data,index)])


In [84]:
result = dict()
# The output file is named after 'breakdance', so make sure `data` really is
# that sequence; any other loop over test_loader that reuses the name `data`
# would leave a different video bound here.
selected_video = data['info']['name'][0]
if selected_video != 'breakdance':
    raise ValueError(
        "expected `data` to hold the 'breakdance' sequence, got %r" % (selected_video,)
    )
infer_dataset(data,result)
with open('breakdance_first.pkl','wb') as f:
    pickle.dump(result,f)

100%|██████████| 83/83 [01:39<00:00,  1.20s/it]
100%|██████████| 83/83 [01:42<00:00,  1.23s/it]
