In [1]:
import time
import numpy as np
import os
from RANet_lib import *
from RANet_lib.RANet_lib import *
from RANet_model import RANet as Net
from RANet_model import make_layer2, MS_Block, ResBlock2
import os
import os.path as osp
from glob import glob
import pickle

import matplotlib.pyplot as plt
from torchvision import transforms
import PIL.Image as Image

from vj_davis_17_loader import Custom_DAVIS2017_dataset
from torch.utils.data import DataLoader
from vj_loss_functions import *
from vj_data_parallel_model import *

import argparse
from math import log10
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torch.cuda as cuda
import torch.multiprocessing

from configs_vj.vj_config_ranet_iou_trnsfm import *

Number of files: 30
Dirs: 30
saving data in X_test
loading files from:  ../datasets/DAVIS/ImageSets/2017/train.txt


In [2]:
class support_model(nn.Module):
    '''
    Auxiliary mask head that predicts a 1-channel map per object from one
    feature level produced by RANet.

    Args:
        ip_channel: channel count of the selected feature level.
        level: index into each object's feature list choosing which level to use.
    '''
    def __init__(self, ip_channel, level = 0):
        super().__init__()
        # First layer never widens past 64 channels.
        hidden = min(64, ip_channel)
        layers = [
            make_layer2(ip_channel, hidden),
            make_layer2(hidden, 32),
            MS_Block(32, 16, d=[1, 3, 6]),
            ResBlock2(16, 8),
            nn.Conv2d(16, 1, 3, padding=1),
        ]
        # A single Sequential named `predictor` keeps the state_dict keys
        # (predictor.0.* ... predictor.4.*) compatible with saved checkpoints.
        self.predictor = nn.Sequential(*layers)
        self.level = level

    def forward(self, feat):
        '''
        Args:
            feat: nested sequence indexed as feat[image][object][level]; the
                selected entry must be a 4-D tensor (bilinear F.interpolate
                requires 4-D input).
        Returns:
            list with one tensor per image: the per-object predictions
            concatenated along dim 1.
        '''
        per_image = []
        for object_feats in feat:
            preds = []
            for levels in object_feats:
                # RANet upsamples right after computing features, so halve the
                # resolution again before predicting.
                half_res = F.interpolate(levels[self.level], scale_factor=1/2,
                                         mode='bilinear', align_corners=True)
                preds.append(self.predictor(half_res))
            per_image.append(torch.cat(preds, 1))
        return per_image
class Full_training_RAnet_multi_level(nn.Module):
    '''
    Wraps RANet plus two auxiliary support heads so that a single forward call
    returns the weighted training loss. Computing the loss inside forward lets
    nn.DataParallel evaluate it on every replica.

    Args:
        model: RANet instance providing `RANet_Multiple_forward_train_mult_lvl`
            (returns predictions and intermediate features) and `P2masks`.
        loss_classifier: callable(prediction, target) -> scalar tensor. Used for
            the main prediction AND both auxiliary levels.
        sup_model1, sup_model2: auxiliary heads applied to the features
            returned by `model`.
        lamda1: weight of the level-1 auxiliary loss.
        lamda2: weight of the level-2 auxiliary loss.
    '''
    def __init__(self, model, loss_classifier, sup_model1, sup_model2, lamda1=0.5, lamda2=0.5):
        super(Full_training_RAnet_multi_level, self).__init__()
        self.model = model
        self.loss_classifier = loss_classifier
        self.sup_model1 = sup_model1
        self.sup_model2 = sup_model2
        self.lamda1 = lamda1 # How much to weigh the 1st level mask prediction loss
        self.lamda2 = lamda2 # How much to weigh the 2nd level mask prediction loss

    def _stack_object_masks(self, preds, target_msk, template_msk, num_images):
        '''
        Flatten per-object predictions and their binary targets into two
        stacked tensors of shape (total_objects, pixels).

        For image idx, the object count is max(template_msk[idx, 0]) - 1. The
        label map target_msk[idx, 0] is shifted down by one (relu(x - 1)) and
        split by `model.P2masks`. Prediction channel i of preds[idx][0] is then
        paired with mask i + 1, skipping mask 0.

        Raises:
            RuntimeError (from torch.stack) if no image contains any object.
        '''
        prediction_single_masks = []
        target_single_masks = []
        for idx in range(num_images):
            max_obj = int(template_msk[idx, 0].max().item())
            target_msk_images = self.model.P2masks(F.relu(target_msk[idx, 0] - 1), max_obj - 1)
            for i in range(max_obj - 1):
                prediction_single_masks.append(preds[idx][0, i].reshape(-1))
                target_single_masks.append(target_msk_images[i + 1].reshape(-1))
        return torch.stack(prediction_single_masks), torch.stack(target_single_masks)

    def _aux_level_loss(self, preds, target_msk, template_msk, num_images):
        '''Loss of an auxiliary head, with target labels resized (nearest) to its output resolution.'''
        size = tuple(preds[0].shape[-2:])
        # Nearest interpolation keeps label values intact.
        target_resized = F.interpolate(target_msk, size=size, mode='nearest')
        pred_flat, target_flat = self._stack_object_masks(preds, target_resized, template_msk, num_images)
        return self.loss_classifier(pred_flat, target_flat)

    def forward(self, template, target, template_msk, target_msk,\
                               prev_mask=None):
        '''
        Returns:
            tensor of shape (1, 3): [total_loss, main loss, level1 + level2 loss
            (unweighted)], assuming loss_classifier returns 0-dim tensors. The
            leading dim lets DataParallel concatenate one row per replica.
        '''
        Out, feat = self.model.RANet_Multiple_forward_train_mult_lvl(template=template, target=target,
                               template_msk=template_msk, target_msk=target_msk, prev_mask=prev_mask)
        # The auxiliary levels iterate over the same images as the main output.
        num_images = len(Out)

        ############ Main mask prediction #############
        pred_flat, target_flat = self._stack_object_masks(Out, target_msk, template_msk, num_images)
        cls_loss = self.loss_classifier(pred_flat, target_flat)

        Out_lvl1 = self.sup_model1(feat)
        Out_lvl2 = self.sup_model2(feat)

        ########## Level 1 / Level 2 predictions ###############
        cls_loss_lvl1 = self._aux_level_loss(Out_lvl1, target_msk, template_msk, num_images)
        cls_loss_lvl2 = self._aux_level_loss(Out_lvl2, target_msk, template_msk, num_images)

        total_loss = cls_loss + self.lamda1 * cls_loss_lvl1 + self.lamda2 * cls_loss_lvl2

        return torch.stack((total_loss, cls_loss, cls_loss_lvl1 + cls_loss_lvl2)).unsqueeze(0)

In [3]:
# ---------------- Data ----------------
# NOTE(review): `batch_size`, `net_type`, `lamda1`, `lamda2`, `loss_fn` and the
# Random* transforms are not defined in this notebook; they presumably come from
# the star imports in the first cell (likely configs_vj / RANet_lib) — verify.
dataset='17train'
inSize1 = 480
inSize2 = 864
root = '../datasets/DAVIS'
img_mode = '480p'
img_shape = (inSize1,inSize2)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
gpus = list(range(torch.cuda.device_count()))
print('using GPUs ID: {}'.format(gpus))
# The loader batch size (batch_size * len(gpus)) and the final .cuda() both need
# a GPU; fail early with a clear message instead of a DataLoader batch_size=0 error.
if not gpus:
    raise RuntimeError('No CUDA device found: this notebook needs at least one GPU')
torch.multiprocessing.set_sharing_strategy('file_system')

# NOTE(review): these augmentations are built, but the dataset below is created
# with trnsfm_common=None and trnsfm_piecewise=None, so no augmentation is applied.
# Confirm whether that is intended.
trnsfm_crop = RandomCrop(min_size=0.75, prob=0.3)
trnsfm_rotate = RandomRotate(max_angle=np.pi/6, prob=0.3)
trnsfm_flip = RandomFlip(vprob=0.2, hprob=0.2)
trnsfm_piecewise = transforms.Compose([trnsfm_crop, trnsfm_rotate])
trnsfm_common = transforms.Compose([trnsfm_flip])
get_prev_mask = True
use_std_template = True

img_dataset = Custom_DAVIS2017_dataset(root=root, img_shape=img_shape, img_mode=img_mode,
                            get_prev_mask=get_prev_mask, use_std_template=use_std_template,
                            trnsfm_common=None, trnsfm_piecewise=None, loader_type='train')

# One `batch_size` chunk per GPU; nn.DataParallel splits the batch across devices.
img_loader = DataLoader(dataset=img_dataset, num_workers=0,
                        batch_size=batch_size*len(gpus), shuffle=False, pin_memory=True)

# ---------------- Model ----------------
print('===> Building model')
# Other checkpoints used with this notebook:
# 'RANet_encoder_retrain_epoch1.pth', 'RANet_multi_basic_train_epoch1.pth'.
params = 'RANet_video_multi.pth'
model = Net(pretrained=False, type=net_type)
model.set_type(net_type)
checkpoint_load('../models/' + params, model)

# lr=0: Adam steps on this optimizer leave the RANet weights unchanged.
optimizer_model = torch.optim.Adam(model.parameters(), lr=0)

####### Support models #########
sup_model1 = support_model(ip_channel=128, level = 0)
sup_model2 = support_model(ip_channel=32, level = 1)

# Resume each head independently when its checkpoint exists. Only a missing file
# falls back to fresh weights. A corrupt or architecture-mismatched checkpoint
# raises instead of being silently ignored, because these files are overwritten
# at the end of every training epoch.
for sup_model, sup_ckpt in ((sup_model1, '../models/sup_model1.pth'),
                            (sup_model2, '../models/sup_model2.pth')):
    if osp.isfile(sup_ckpt):
        sup_model.load_state_dict(torch.load(sup_ckpt, map_location='cpu'))
    else:
        print("Creating support model from scratch:", sup_ckpt)

optimizer_sup_model1 = torch.optim.Adam(sup_model1.parameters(), lr=1e-04)
optimizer_sup_model2 = torch.optim.Adam(sup_model2.parameters(), lr=1e-04)
print("Support modules ready")
############## Data parallel Model ###############
full_model = Full_training_RAnet_multi_level(model, loss_classifier=loss_fn,
                    sup_model1=sup_model1, sup_model2=sup_model2, lamda1=lamda1, lamda2=lamda2)
# .cuda() moves parameters in place, so the optimizers above still track them.
full_model = nn.DataParallel(full_model, device_ids=gpus).cuda()
print("memory usage :", cuda.memory_allocated(0) /(1024*1024))


using GPUs ID: [0]
loading files from:  ../datasets/DAVIS/ImageSets/2017/train.txt
===> Building model
Multi-object mode
=> Loaded checkpoint '../models/RANet_video_multi.pth'
Support modules ready
memory usage : 246.8291015625


In [5]:
# Training switches: the RANet backbone is not stepped; the support heads are.
train_model, train_sup_model = False, True
epoch = 1  # placeholder; the training loop rebinds `epoch`
loss_per_epoch = []  # one mean [total, cls, level] row is appended per epoch
max_memory = cuda.memory_allocated(0) / (1024 * 1024)  # running peak, MiB

In [6]:
# Passes over img_loader performed by this cell.
NUM_EPOCHS = 20

for epoch in range(NUM_EPOCHS):
    start_time = time.perf_counter()
    loss_per_batch = []
    model_train_time = 0
    # Module.train(mode) is equivalent to train()/eval() chosen by the flag.
    # NOTE(review): with train_model False, backward() still computes gradients
    # for RANet parameters (they are zeroed every step and never applied).
    model.train(train_model)
    sup_model1.train(train_sup_model)
    sup_model2.train(train_sup_model)

    for iteration, batch in enumerate(img_loader, 1):
        template, template_mask, target, target_mask, prev_mask = batch

        optimizer_model.zero_grad()
        optimizer_sup_model1.zero_grad()
        optimizer_sup_model2.zero_grad()

        start_time_model = time.perf_counter()
        loss = full_model(template=template, target=target, template_msk=template_mask,
                          target_msk=target_mask, prev_mask=prev_mask)
        # DataParallel gathers one (1, 3) row per replica; average over replicas.
        total_loss, cls_loss, cls_loss_lvls = loss.mean(dim=0)

        mem_now = cuda.memory_allocated(0) / (1024 * 1024)
        if mem_now > max_memory:
            print("New max memory!:", mem_now, "iteration:", iteration)
            max_memory = mem_now
        total_loss.backward()
        if train_model:
            optimizer_model.step()
        if train_sup_model:
            optimizer_sup_model1.step()
            optimizer_sup_model2.step()

        loss_per_batch.append([total_loss.item(), cls_loss.item(), cls_loss_lvls.item()])
        # Drop references so the graph and batch tensors can be freed before the next step.
        del total_loss, cls_loss, cls_loss_lvls, loss, template, template_mask, target, target_mask, prev_mask
        model_train_time += time.perf_counter() - start_time_model

    if train_sup_model:
        torch.save(sup_model1.state_dict(), '../models/sup_model1.pth')
        torch.save(sup_model2.state_dict(), '../models/sup_model2.pth')

    loss_per_batch = np.array(loss_per_batch)
    loss_per_epoch.append(np.mean(loss_per_batch, axis=0))
    memory = cuda.memory_allocated(0) / (1024 * 1024)
    end_time = time.perf_counter()

    print("epoch:", epoch, "level loss:", loss_per_epoch[-1][2], "epoch time:", end_time - start_time,
          "time spent on model running:", model_train_time, "memory used", memory)


  "See the documentation of nn.Upsample for details.".format(mode))


New max memory!: 3398.55810546875 iteration: 1
New max memory!: 4358.0888671875 iteration: 2
New max memory!: 5052.8505859375 iteration: 4
New max memory!: 5053.1005859375 iteration: 9
New max memory!: 6442.375 iteration: 16
New max memory!: 6443.625 iteration: 19
New max memory!: 8525.4091796875 iteration: 29
epoch: 0 level loss: 0.3402410996456941 Time for mini batch: 76.62341135973111 time spend on model running: 34.61787291103974 memory used 511.59814453125
epoch: 1 level loss: 0.3674257849653562 Time for mini batch: 70.62205483717844 time spend on model running: 34.70563895441592 memory used 511.59814453125
epoch: 2 level loss: 0.3774564017852147 Time for mini batch: 70.2120343260467 time spend on model running: 34.72170540271327 memory used 511.59814453125
epoch: 3 level loss: 0.3975088447332382 Time for mini batch: 70.47939829900861 time spend on model running: 34.734846849925816 memory used 511.59814453125
epoch: 4 level loss: 0.3860672796765963 Time for mini batch: 69.66295719

In [6]:
# Largest cuda.memory_allocated(0) value (MiB) seen at the training loop's per-iteration checks.
max_memory

8533.2841796875