In [1]:
from __future__ import print_function, division
import sys
sys.path.append('core')

import argparse, configparser
import os
import cv2
import time
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import DataLoader
# from torch.optim import Adam as AdamW
from torch.optim.adamw import AdamW
from core.onecyclelr import OneCycleLR

from tensorboardX import SummaryWriter


In [2]:

try:
    from torch.cuda.amp import GradScaler
except:
    # dummy GradScaler for PyTorch < 1.6
    class GradScaler:
        def __init__(self, enabled=False):
            pass
        def scale(self, loss):
            return loss
        def unscale_(self, optimizer):
            pass
        def step(self, optimizer):
            optimizer.step()
        def update(self):
            pass

In [3]:
# check if cuda is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [32]:
# define parameters
lr = 0.0002
# num_steps = 100000
epochs = 4
batch_size = 4
image_size = [160, 240]
gpuid = 0
iter = 12
wdecay = 0.0005
epsilon = 1e-8
clip = 1
dropout = 0
gamma = 0.8
add_noise = False

torch.set_num_threads(16)
torch.manual_seed(1234)
np.random.seed(1234)


In [33]:
# making directories for saving checkpoints and logs
if not os.path.exists('checkpoints'):
    os.mkdir('checkpoints')
if not os.path.exists('runs'):
    os.mkdir('runs')

In [34]:
# load dataset
import os
import math
import random
from glob import glob
import os.path as osp

from utils import frame_utils
from utils.augmentor import FlowAugmentor, SparseFlowAugmentor

import sys

In [7]:
class FlowDataset(torch.utils.data.Dataset):
    def __init__(self, aug_params=None, sparse=False):
        self.augmentor = None
        self.sparse = sparse
        if aug_params is not None:
            if sparse:
                self.augmentor = SparseFlowAugmentor(**aug_params)
            else:
                self.augmentor = FlowAugmentor(**aug_params)

        self.is_test = False
        self.is_validate = False
        self.init_seed = False
        self.flow_list = []
        self.image_list = []
        self.extra_info = []

    def __getitem__(self, index):
        # print('Index is {}'.format(index))
        # sys.stdout.flush()
        if self.is_test:
            img1 = frame_utils.read_gen(self.image_list[index][0])
            img2 = frame_utils.read_gen(self.image_list[index][1])
            img1 = np.array(img1).astype(np.uint8)[..., :3]
            img2 = np.array(img2).astype(np.uint8)[..., :3]
            img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
            img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
            return img1, img2, self.extra_info[index]

        # if not self.init_seed:
        #     worker_info = torch.utils.data.get_worker_info()
        #     if worker_info is not None:
        #         torch.manual_seed(worker_info.id)
        #         np.random.seed(worker_info.id)
        #         random.seed(worker_info.id)
        #         self.init_seed = True

        index = index % len(self.image_list)
        valid = None
        h,w = 600,800
        if self.sparse:
            flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
        else:
            flow = frame_utils.read_gen(self.flow_list[index],h = h,w = w)

        img1 = frame_utils.read_gen(self.image_list[index][0])
        img2 = frame_utils.read_gen(self.image_list[index][1])

        flow = np.array(flow).astype(np.float32)
        img1 = np.array(img1).astype(np.uint8)
        img2 = np.array(img2).astype(np.uint8)

        # grayscale images
        if len(img1.shape) == 2:
            img1 = np.tile(img1[...,None], (1, 1, 3))
            img2 = np.tile(img2[...,None], (1, 1, 3))
        else:
            img1 = img1[..., :3]
            img2 = img2[..., :3]

        if self.augmentor is not None:
            if self.sparse:
                img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
            else:
                img1, img2, flow = self.augmentor(img1, img2, flow)

        img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
        img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
        flow = torch.from_numpy(flow).permute(2, 0, 1).float()

        if valid is not None:
            valid = torch.from_numpy(valid)
        else:
            valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)

        if self.is_validate:
            return img1, img2, flow, valid.float(), self.extra_info[index]
        else:
            return img1, img2, flow, valid.float()

    def getDataWithPath(self, index):
        img1, img2, flow, valid = self.__getitem__(index)

        imgPath_1 = self.image_list[index][0]
        imgPath_2 = self.image_list[index][1]

        return img1, img2, flow, valid, imgPath_1, imgPath_2

    def __rmul__(self, v):
        self.flow_list = v * self.flow_list
        self.image_list = v * self.image_list
        return self

    def __len__(self):
        return len(self.image_list)

In [8]:
class Carla_Dataset(FlowDataset):
    def __init__(self, aug_params=None, split='training', root='/home/sushlok/new_approach/datasets/carla', 
                 seq= [
                    "SoftRainNight",
                    "ClearNoon",
                    "CloudyNoon"
                ], 
            setup_type = [ 'camera_0', 'camera_-1','camera_1'], is_validate=False):
        super(Carla_Dataset, self).__init__(aug_params, sparse=False)
        if split == 'testing':
            self.is_test = True

        self.is_validate = is_validate
        image_dirs = []
        # datasets/vkitti/vkitti_1.3.1_rgb
        # print(seq, setup_type)
        start = 0
        for s in seq:
            for t in setup_type:
                # print(sorted(glob(osp.join(root, '%s' %(s) ,'rgb_%s/*.png' % (t)))))
                image_dirs += sorted(glob(osp.join(root, '%s' %(s) ,'rgb_%s/*.png' % (t))))
                # print(image_dirs)
                for i in range(start,len(image_dirs)-1):
                    img1 = image_dirs[i]
                    img2 = image_dirs[i+1]  
                    self.image_list += [ [img2, img1] ]
                    self.extra_info += [ [img2.split('/')[-1]] ]
                start = len(image_dirs)
            
        if split == 'training':
            for s in seq:
                for t in setup_type:
                    self.flow_list += sorted(glob(osp.join(root, '%s' %(s) ,'flow_%s/flow_npz/*.npz' % (t))))

In [9]:
aug_params = {'crop_size': image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
train_dataset = Carla_Dataset(aug_params, split='training', seq= ["ClearNoon"], setup_type = ['camera_-10', 'camera_-9', 'camera_-7', 'camera_-6']) # , 'camera_1','camera_2','camera_3', 'camera_4'
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)

In [10]:
validation_datset = Carla_Dataset(aug_params, split='training', seq= ["ClearNoon"], setup_type = ['camera_-8'])
valid_loader = torch.utils.data.DataLoader(validation_datset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=True)

In [11]:
print(len(train_loader))

1000


In [12]:
# visualize data in dataloader using tensorboard
# from torch.utils.tensorboard import SummaryWriter
# writer = SummaryWriter()
# for i, (img1, img2, flow, valid) in enumerate(train_loader):
#     # print(img1.shape, img2.shape, flow.shape, valid.shape)
#     writer.add_images('img1', img1, i)
#     writer.add_images('img2', img2, i)
#     # writer.add_images('flow', flow, i)
#     # writer.add_images('valid', valid, i)
# writer.close()
    

In [13]:
# model
from core.update import BasicUpdateBlock
from core.extractor import BasicEncoder, BasicConvEncoder, Non_uniform_Encoder
from core.corr import CorrBlock, AlternateCorrBlock
from utils.utils import bilinear_sampler, coords_grid, upflow8
from core.swin_transformer import POLAUpdate, MixAxialPOLAUpdate

In [14]:
autocast = torch.cuda.amp.autocast

In [15]:
class GMFlowNetModel(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args

        self.hidden_dim = hdim = 128
        self.context_dim = cdim = 128
        args.corr_levels = 4
        args.corr_radius = 4
        args.dropout = 0.0
        args.use_mix_attn = False
        args.mixed_precision = True
        if not hasattr(self.args, 'dropout'):
            self.args.dropout = 0

        if not hasattr(self.args, 'alternate_corr'):
            self.args.alternate_corr = False

        # feature network, context network, and update block
        if self.args.use_mix_attn:
            self.fnet = nn.Sequential(
                            # BasicConvEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout),
                            Non_uniform_Encoder(output_dim=256, norm_fn='instance', dropout=args.dropout),
                            MixAxialPOLAUpdate(embed_dim=256, depth=6, num_head=8, window_size=7)
                        )
        else:
            self.fnet = nn.Sequential(
                Non_uniform_Encoder(output_dim=256, norm_fn='instance', dropout=args.dropout),
                # BasicConvEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout),
                POLAUpdate(embed_dim=256, depth=6, num_head=8, window_size=7, neig_win_num=1)
            )

        self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
        self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim, input_dim=cdim)

    def freeze_bn(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def initialize_flow(self, img):
        """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
        N, C, H, W = img.shape
        coords0 = coords_grid(N, H//8, W//8).to(img.device)
        coords1 = coords_grid(N, H//8, W//8).to(img.device)

        # optical flow computed as difference: flow = coords1 - coords0
        return coords0, coords1

    def upsample_flow(self, flow, mask):
        """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
        N, _, H, W = flow.shape
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)

        up_flow = F.unfold(8 * flow, [3,3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8*H, 8*W)

    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
        """ Estimate optical flow between pair of frames """

        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0

        image1 = image1.contiguous()
        image2 = image2.contiguous()

        hdim = self.hidden_dim
        cdim = self.context_dim

        # run the feature network
        with autocast(enabled=self.args.mixed_precision):
            fmaps, cache = self.fnet([image1, image2])
        fmap1,fmap2 = fmaps
        fmap1 = fmap1.float()
        fmap2 = fmap2.float()

        # # Self-attention update
        # fmap1 = self.transEncoder(fmap1)
        # fmap2 = self.transEncoder(fmap2)
        # print("feature_volume:")
        # # print(fmap1, fmap2)
        # print(fmap1.shape)
        # print("------------------------------------------------")
        if self.args.alternate_corr:
            corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
        else:
            corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

        # print("corr_fn:")
        # # print(corr_fn.corrMap)
        # print("------------------------------------------------")
        # print(corr_fn.corrMap.shape)
        # run the context network
        with autocast(enabled=self.args.mixed_precision):
            cnet = self.cnet(image1)
            net, inp = torch.split(cnet, [hdim, cdim], dim=1)
            net = torch.tanh(net)
            inp = torch.relu(inp)

        coords0, coords1 = self.initialize_flow(image1)

        # Correlation as initialization
        N, fC, fH, fW = fmap1.shape
        corrMap = corr_fn.corrMap

        #_, coords_index = torch.max(corrMap, dim=-1) # no gradient here
        softCorrMap = F.softmax(corrMap, dim=2) * F.softmax(corrMap, dim=1) # (N, fH*fW, fH*fW)

        # print("softCorrMap:")
        # # print(softCorrMap)
        # print("------------------------------------------------")
        # print(softCorrMap.shape)
        if flow_init is not None:
            coords1 = coords1 + flow_init
        else:
            # print('matching as init')
            # mutual match selection
            match12, match_idx12 = softCorrMap.max(dim=2) # (N, fH*fW)
            match21, match_idx21 = softCorrMap.max(dim=1)

            for b_idx in range(N):
                match21_b = match21[b_idx,:]
                match_idx12_b = match_idx12[b_idx,:]
                match21[b_idx,:] = match21_b[match_idx12_b]

            matched = (match12 - match21) == 0  # (N, fH*fW)
            coords_index = torch.arange(fH*fW).unsqueeze(0).repeat(N,1).to(softCorrMap.device)
            coords_index[matched] = match_idx12[matched]

            # matched coords
            coords_index = coords_index.reshape(N, fH, fW)
            coords_x = coords_index % fW
            coords_y = coords_index // fW

            coords_xy = torch.stack([coords_x, coords_y], dim=1).float()
            coords1 = coords_xy
        # print('coords1:')
        # # print(coords1)
        # print(coords1.shape)
        # print("------------------------------------------------")
        
        # Iterative update
        flow_predictions = []
        # print("iter:",iter)
        for itr in range(iters):
            # print(itr)
            coords1 = coords1.detach()
            corr = corr_fn(coords1) # index correlation volume

            flow = coords1 - coords0
            # print("flow:")
            # print(flow)
            # print(flow.shape)
            # print("------------------------------------------------")
            with autocast(enabled=self.args.mixed_precision):
                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

            # F(t+1) = F(t) + \Delta(t)
            
            coords1 = coords1 + delta_flow
            # print("coords1:")
            # # print(coords1)
            # print(coords1.shape)
            # print("------------------------------------------------")
            
            # upsample predictions
            if up_mask is None:
                flow_up = upflow8(coords1 - coords0)
            else:
                flow_up = self.upsample_flow(coords1 - coords0, up_mask)
            # print(flow_up.shape)
            flow_predictions.append(flow_up)

        if test_mode:
            return coords1 - coords0, flow_up

        return flow_predictions, softCorrMap


In [29]:
# define args 
args = argparse.Namespace()
model = nn.DataParallel(GMFlowNetModel(args), device_ids=[gpuid])
# model = GMFlowNetModel(args)

In [30]:
# load model
# model.load_state_dict(torch.load("pretrained_models/new_model.pth"), strict=True)
model.load_state_dict(torch.load("checkpoints/carla_new_model.pth"), strict=True)


<All keys matched successfully>

In [31]:
# print layers and parameters names and set model parameters to non-require grad
for name, param in model.named_parameters():
    if(name.split('.')[1] != 'fnet'):
        param.requires_grad = False

In [37]:
# visualize the model using tensorboard
# test_image = 255*torch.randn(1, 3, 64, 64)
# writer.add_graph(model, (test_image,test_image.transpose(2,3)))
# writer.close()
    

In [38]:
# defining optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wdecay, eps=epsilon)

In [39]:
# define scheduler
from core.onecyclelr import OneCycleLR
scheduler = OneCycleLR(optimizer, lr, steps_per_epoch=len(train_loader), epochs=epochs,
        pct_start=0.05, cycle_momentum=False, anneal_strategy='linear')

In [40]:
scaler = GradScaler(enabled=args.mixed_precision)

In [41]:
from core.loss import compute_supervision_coarse, compute_coarse_loss, backwarp

In [42]:
MAX_FLOW = 400
SUM_FREQ = 100
VAL_FREQ = 5000

In [43]:
# define loss function
def sequence_loss(train_outputs, image1, image2, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW, use_matching_loss=False):
    """ Loss function defined over sequence of flow predictions """
    flow_preds, softCorrMap = train_outputs

    # original RAFT loss
    n_predictions = len(flow_preds)
    flow_loss = 0.0

    # exclude invalid pixels and extremely large displacements
    mag = torch.sum(flow_gt**2, dim=1).sqrt()
    valid = (valid >= 0.5) & (mag < max_flow)

    for i in range(n_predictions):
        i_weight = gamma**(n_predictions - i - 1)
        i_loss = (flow_preds[i] - flow_gt).abs()
        flow_loss += i_weight * (valid[:, None].float()  * i_loss).mean()

    epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt()
    epe = epe.view(-1)[valid.view(-1)]

    metrics = {
        'epe': epe.mean().item(),
        '1px': (epe < 1).float().mean().item(),
        '3px': (epe < 3).float().mean().item(),
        '5px': (epe < 5).float().mean().item(),
    }

    if use_matching_loss:
        # enable global matching loss. Try to use it in late stages of the trianing
        img_2back1 = backwarp(image2, flow_gt)
        occlusionMap = (image1 - img_2back1).mean(1, keepdims=True) #(N, H, W)
        occlusionMap = torch.abs(occlusionMap) > 20
        occlusionMap = occlusionMap.float()

        conf_matrix_gt = compute_supervision_coarse(flow_gt, occlusionMap, 8) # 8 from RAFT downsample

        matchLossCfg = configparser.ConfigParser()
        matchLossCfg.POS_WEIGHT = 1
        matchLossCfg.NEG_WEIGHT = 1
        matchLossCfg.FOCAL_ALPHA = 0.25
        matchLossCfg.FOCAL_GAMMA = 2.0
        matchLossCfg.COARSE_TYPE = 'cross_entropy'
        match_loss = compute_coarse_loss(softCorrMap, conf_matrix_gt, matchLossCfg)

        flow_loss = flow_loss + 0.01*match_loss

    return flow_loss, metrics

In [44]:
# setting model to train mode
model.cuda()
model.train()

DataParallel(
  (module): GMFlowNetModel(
    (fnet): Sequential(
      (0): Non_uniform_Encoder(
        (norm1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (norm2): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (norm3): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (conv0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (residual_layer1): Diff_ResidualBlock(
          (conv0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (relu): ReLU(inplace=True)
          (norm1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine

In [45]:
def weight_histograms(writer, step, model):
  for name, param in model.named_parameters():
    writer.add_histogram()

In [46]:
add_noise = True
use_mix_attn = True
writer = SummaryWriter("runs/test_run3")
for epoch in range(epochs):
    # weight_histograms(writer, epoch, model)
    for i_batch, data_blob in enumerate(train_loader):
        optimizer.zero_grad()
        image1, image2, flow, valid = [x.cuda() for x in data_blob]

        if add_noise:
            stdv = np.random.uniform(0.0, 5.0)
            image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0)
            image2 = (image2 + stdv * torch.randn(*image2.shape).cuda()).clamp(0.0, 255.0)

        flow_predictions = model(image1, image2, iters=iter)

        loss, metrics = sequence_loss(flow_predictions, image1, image2, flow, valid, gamma=gamma, use_matching_loss=use_mix_attn)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        scaler.step(optimizer)
        scheduler.step()
        scaler.update()
        print("Epoch: {}, Iter: {}, Loss: {}".format(epoch, i_batch, loss.item()))
        # add loss and metric on tensorboard scalar
        writer.add_scalar('Loss/train', loss.item(), epoch * len(train_loader) + i_batch)
        for k, v in metrics.items():
            writer.add_scalar('metrics/' + k, v, epoch * len(train_loader) + i_batch)
PATH = 'checkpoints/carla_new_model.pth'
torch.save(model.state_dict(), PATH)
writer.close()

Epoch: 0, Iter: 0, Loss: 8.891362190246582
Epoch: 0, Iter: 1, Loss: 3.609867811203003
Epoch: 0, Iter: 2, Loss: 6.381996154785156
Epoch: 0, Iter: 3, Loss: 6.591984272003174
Epoch: 0, Iter: 4, Loss: 6.784656524658203
Epoch: 0, Iter: 5, Loss: 5.969427585601807
Epoch: 0, Iter: 6, Loss: 9.13219165802002
Epoch: 0, Iter: 7, Loss: 3.9313809871673584
Epoch: 0, Iter: 8, Loss: 6.451406955718994
Epoch: 0, Iter: 9, Loss: 4.644724369049072
Epoch: 0, Iter: 10, Loss: 5.081079006195068
Epoch: 0, Iter: 11, Loss: 2.681011438369751
Epoch: 0, Iter: 12, Loss: 4.05045747756958
Epoch: 0, Iter: 13, Loss: 5.848642349243164
Epoch: 0, Iter: 14, Loss: 2.9176547527313232
Epoch: 0, Iter: 15, Loss: 10.234559059143066
Epoch: 0, Iter: 16, Loss: 5.062490463256836
Epoch: 0, Iter: 17, Loss: 7.866201877593994
Epoch: 0, Iter: 18, Loss: 8.815914154052734
Epoch: 0, Iter: 19, Loss: 5.644184589385986
Epoch: 0, Iter: 20, Loss: 10.839179039001465
Epoch: 0, Iter: 21, Loss: 5.240859508514404
Epoch: 0, Iter: 22, Loss: 6.305705547332