# Running prediction models

In [11]:
import matplotlib.pyplot as plt
import albumentations as A
import cv2
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

## Download Datasets

In [None]:
# Fetch the test datasets as zip archives and extract each into its own directory.
# The synthetic dataset is currently disabled; to use it, uncomment its two lines
# and set DATASET_PATH to "synthetic_data/" in the "Run Model" cell.

# Download synthetic test dataset
# !gdown 1-AtP2n5N0J7XTRzPlOz95eomtsLbuyvR -O synthetic_data.zip
# !unzip -d synthetic_data/ synthetic_data.zip # unzipping test data

# Download natural test dataset
!gdown 1-0aDjjgh0RCRlWFYmNxIApk-wJ6p8mq_ -O natural_data.zip
!unzip -d natural_data/ natural_data.zip # unzipping test data

## Download Models

In [None]:
# Download the pretrained model checkpoints.
# Download model 1A
!gdown 1T-NFKCLVBlSkbuR-eai4_fejw7AtslBK
# Download model 1B
!gdown 1MOeTTWnkc-vYj5Q16ubQIx0yZSRlA8Pi
# Download model 1C
# TODO: no download command for this model yet.

# Download model 2A
# TODO: no download command yet, although the "Run Model" cell loads model2A.pt.

# Download model 2B
# TODO: no download command for this model yet.

# Download model 2C
# TODO: no download command for this model yet.

## Dataset Class

In [4]:
class CustomDataset(Dataset):
    """Frame-pair dataset stored as per-sample ``.npy`` files.

    For every sample id listed in ``log_file`` the arrays are read from
    ``<root_dir>/<subdir>/<sample_id>.npy`` with ``subdir`` in ``frame0``,
    ``frame3``, ``coords`` and ``vis`` (plus ``frame7`` when ``check`` is
    true). Frames are resized to ``FRAME_SIZE``; ``coords`` and ``vis`` are
    returned exactly as loaded.
    """

    # cv2.resize takes dsize as (width, height).
    FRAME_SIZE = (640, 360)

    def __init__(self, log_file, root_dir, check=False, transform=None):
        """
        Args:
            log_file (string): Path to a text file with one sample id per line.
                Blank lines are ignored.
            root_dir (string): Directory with the frame0/frame3/coords/vis
                (and, if ``check`` is true, frame7) sub-directories.
            check (bool, optional): Also return the 7th frame as ``image7``
                for sanity checks.
            transform (callable, optional): Optional transform applied to each
                sample dict.
        """
        # Context manager closes the log file instead of leaking the handle.
        with open(log_file) as f:
            # Skip blank lines (e.g. an extra trailing newline): they would
            # otherwise become an empty sample id pointing at '<subdir>/.npy'.
            self.sample_names = [name for name in f.read().splitlines() if name]
        self.root_dir = root_dir
        self.check = check
        self.transform = transform

    def __len__(self):
        return len(self.sample_names)

    def _load(self, subdir, sample_id):
        """Load ``<root_dir>/<subdir>/<sample_id>.npy``."""
        return np.load(os.path.join(self.root_dir, subdir, sample_id) + '.npy')

    def _load_frame(self, subdir, sample_id):
        """Load a frame array and resize it to ``FRAME_SIZE``."""
        return cv2.resize(self._load(subdir, sample_id), dsize=self.FRAME_SIZE)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        sample_id = self.sample_names[idx]

        # Keys are inserted in the order id, image0, image3, [image7], coords, vis, shift.
        sample = {
            'id': sample_id,
            'image0': self._load_frame("frame0", sample_id),
            'image3': self._load_frame("frame3", sample_id),
        }
        if self.check:
            sample['image7'] = self._load_frame("frame7", sample_id)
        sample['coords'] = self._load("coords", sample_id)
        sample['vis'] = self._load("vis", sample_id)
        # No shift applied yet; ShiftData replaces this with its random offset.
        sample['shift'] = np.array((0, 0))

        if self.transform:
            sample = self.transform(sample)
        return sample

class ToTensor(object):
    """
    Convert the ndarrays in a sample dict to torch tensors.

    Images are transposed from numpy's H x W x C layout to torch's
    C x H x W layout. ``coords``, ``vis`` and ``shift`` are converted without
    reordering. ``image7`` is converted only when the sample contains it.
    Tensors share memory with the input arrays (``torch.from_numpy``).
    """

    @staticmethod
    def _to_chw_tensor(image):
        """Return an H x W x C ndarray as a C x H x W tensor."""
        return torch.from_numpy(image.transpose((2, 0, 1)))

    def __call__(self, sample):
        tensors = {'id': sample['id'],
                   'image0': self._to_chw_tensor(sample['image0']),
                   'image3': self._to_chw_tensor(sample['image3'])}
        # Test for the key rather than len(sample) == 7, so the optional frame is
        # handled correctly even when the sample carries extra entries.
        if 'image7' in sample:
            tensors['image7'] = self._to_chw_tensor(sample['image7'])
        for key in ('coords', 'vis', 'shift'):
            tensors[key] = torch.from_numpy(sample[key])
        return tensors

class AugmentData(object):
    """
    Augment image0 and image3 with ColorJitter, Gaussian noise and to-gray
    transformations.

    Both frames are cast to float32 and passed through a single albumentations
    ``Compose`` call, with ``image3`` registered as an additional ``'image'``
    target. ``image7`` (if present), ``coords``, ``vis`` and ``shift`` are
    forwarded unchanged.
    """
    def __call__(self, sample):
        trans = A.Compose(
            [
                A.augmentations.transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2, always_apply=False, p=0.2),
                A.augmentations.transforms.GaussNoise(var_limit=.05, mean=0, per_channel=True, always_apply=False, p=0.2),
                A.augmentations.transforms.ToGray(p=0.1),
            ],
            additional_targets={'image3': 'image'},
        )

        transformed = trans(image=sample['image0'].astype(np.float32),
                            image3=sample['image3'].astype(np.float32))

        augmented = {'id': sample['id'],
                     'image0': transformed['image'],
                     'image3': transformed['image3']}
        # Test for the key rather than len(sample) == 7, so the optional frame is
        # forwarded correctly even when the sample carries extra entries.
        if 'image7' in sample:
            augmented['image7'] = sample['image7']
        for key in ('coords', 'vis', 'shift'):
            augmented[key] = sample[key]
        return augmented

class ShiftData(object):
    """
    Shift image0 and image3 by a random integer translation.

    The same offsets are added to ``coords`` and reported in ``shift`` as
    ``np.array((random_x, random_y))``. ``image7`` (if present) and ``vis``
    are forwarded unchanged. The input sample is not modified.
    """
    def __call__(self, sample):
        # randint's upper bound is exclusive: x is drawn from [-5, 4] and
        # y from [-10, 9].
        random_x = np.random.randint(-5, 5)
        random_y = np.random.randint(-10, 10)

        # Work on a copy so the caller's coords array is not mutated in place.
        coords = np.copy(sample['coords'])
        # NOTE(review): channel 0 receives the y offset and channel 1 the x
        # offset; confirm this matches the intended coords layout.
        coords[:, :, 0] = coords[:, :, 0] + random_y
        coords[:, :, 1] = coords[:, :, 1] + random_x

        trans = A.Compose(
            [
                A.augmentations.geometric.transforms.Affine(scale=1, translate_percent=None, translate_px={'x': random_x, 'y': random_y}, shear=None, fit_output=False, keep_ratio=False, always_apply=False, p=1),
            ],
            additional_targets={'image3': 'image'},
        )

        transformed = trans(image=sample['image0'].astype(np.float32),
                            image3=sample['image3'].astype(np.float32))

        shifted = {'id': sample['id'],
                   'image0': transformed['image'],
                   'image3': transformed['image3']}
        # Test for the key rather than len(sample) == 7, so the optional frame is
        # forwarded correctly even when the sample carries extra entries.
        if 'image7' in sample:
            shifted['image7'] = sample['image7']
        shifted['coords'] = coords
        shifted['vis'] = sample['vis']
        shifted['shift'] = np.array((random_x, random_y))
        return shifted

## Helper functions

In [20]:
def run_model(device, model, dataloader):
    """Run ``model`` on every batch of ``dataloader`` and plot the reconstructions.

    For each batch, ``image0`` and ``image3`` are concatenated along dim 1 and
    fed to the model. ``model(inputs)['out']`` is split into channels 0-1
    (per-pixel coordinates) and channel 2 (visibility, binarised at 0). The
    results are passed with ``image0`` and ``image7`` to ``reconstruct_batch``.

    Args:
        device: device the inputs are moved to (e.g. ``'cuda'``).
        model: callable module returning a dict with an ``'out'`` tensor;
            switched to eval mode here.
        dataloader: yields dict batches with ``'id'``, ``'image0'``,
            ``'image3'`` and ``'image7'`` entries.

    Raises:
        KeyError: if a batch has no ``'image7'`` entry (dataset built
            without ``check=True``).
    """
    model.eval()

    with torch.no_grad():
        for sample_batched in dataloader:
            if 'image7' not in sample_batched:
                raise KeyError(
                    "run_model needs 'image7' (the last frame) in every batch; "
                    "build the dataset with check=True")
            batch_size = len(sample_batched['id'])

            input1 = sample_batched['image0']
            input2 = sample_batched['image3']
            last_imgs = sample_batched['image7']
            inputs = torch.cat((input1, input2), dim=1).to(device).float()  # [B, 6, H, W]

            outputs = model(inputs)['out']  # [B, 3, H, W]

            # Channels 0-1: move them last, then flatten the pixel grid.
            outputs_coords = outputs[:, :2, :, :].permute(0, 2, 3, 1)  # [B, H, W, 2]
            outputs_coords = outputs_coords.reshape(batch_size, 1, -1, 2)  # [B, 1, H*W, 2]

            # Channel 2: 1.0 where the score is positive, else 0.0.
            outputs_vis = (outputs[:, 2, :, :] > 0).float()
            outputs_vis = outputs_vis.reshape(batch_size, 1, -1)  # [B, 1, H*W]

            reconstruct_batch(outputs_coords.cpu().numpy(), outputs_vis.cpu().numpy(),
                              input1, last_imgs)

In [31]:
def reconstruct_batch(coords, vis, inputs, gts):
    """Reconstruct the last frame of each sample in a batch and plot it.

    For every sample ``i`` three figures are shown in turn: the first frame
    ``inputs[i]``, the ground-truth last frame ``gts[i]`` and the frame
    rebuilt by ``frame_reconstruction`` from ``inputs[i]``, ``coords[i]``
    and ``vis[i]``.

    Args:
        coords: predicted coordinates indexed by sample along axis 0
            (``run_model`` passes a (B, 1, H*W, 2) array).
        vis: visibility masks indexed by sample along axis 0
            (``run_model`` passes a (B, 1, H*W) array).
        inputs: first frames as a (B, C, H, W) tensor.
        gts: last frames as a (B, C, H, W) tensor.
    """
    for i in range(coords.shape[0]):
        # Named first_frame (not `input`) to avoid shadowing the builtin.
        first_frame = inputs[i].permute(1, 2, 0).cpu().numpy()  # H, W, C
        last_frame = gts[i].permute(1, 2, 0).cpu().numpy()  # H, W, C

        prediction = frame_reconstruction(first_frame, coords[i], vis[i])

        for image, title in ((first_frame, "First Frame"),
                             (last_frame, "Last Frame"),
                             (prediction, "Reconstructed Frame")):
            plt.imshow(image)
            plt.title(title)
            plt.show()

In [22]:
def frame_reconstruction(img0, coords, vis):
    """Warp ``img0`` with per-pixel displacements and return the reconstructed frame.

    Every pixel ``(row, col)`` of ``img0`` whose ``vis`` entry is positive is
    moved to ``(row + coords[k, 1], col + coords[k, 0])``, where ``k`` is the
    pixel's row-major index. The output is filled from those moved pixels by
    ``interpolate``.

    Args:
        img0: first frame, an (H, W, C) array.
        coords: displacements, any shape that flattens to (H*W, 2) in row-major
            pixel order (e.g. (1, H*W, 2)). Channel 0 is added to the column
            index and channel 1 to the row index.
        vis: visibility values, any shape that flattens to (H*W,); pixels with
            ``vis <= 0`` are dropped.

    Returns:
        The array produced by ``interpolate`` for the visible pixels.
    """
    # Derive the pixel grid from the image instead of hard-coding 360 x 640.
    height, width = img0.shape[:2]
    rows, cols = np.meshgrid(np.arange(height), np.arange(width), indexing="ij")
    rows = rows.ravel()
    cols = cols.ravel()

    # reshape (rather than squeeze) also copes with a single-pixel image.
    visible = np.reshape(vis, -1) > 0
    coords = np.reshape(coords, (-1, 2))

    rows = rows[visible]
    cols = cols[visible]
    target_rows = rows + coords[visible, 1]
    target_cols = cols + coords[visible, 0]

    return interpolate(img0, rows, cols, target_rows, target_cols)

In [32]:
from scipy.interpolate import NearestNDInterpolator
def interpolate(img0, original_x, original_y, coords_x, coords_y):
    """Resample ``img0`` onto its own pixel grid by nearest-neighbour lookup.

    Source pixel ``img0[original_x[k], original_y[k]]`` is treated as having
    landed at position ``(coords_x[k], coords_y[k])``. Each output pixel takes
    the RGB value of the nearest landed source pixel.

    Args:
        img0: source image of shape (H, W, C) with C >= 3; only the first
            three channels are used.
        original_x, original_y: row and column indices of the source pixels.
        coords_x, coords_y: row and column positions where those pixels land.

    Returns:
        Array of shape (H, W, 3) with the same dtype as ``img0`` (every value
        is a copy of a source pixel, so the cast back is lossless).

    Raises:
        ValueError: if no source pixels are given.
    """
    if len(original_x) == 0:
        raise ValueError("interpolate needs at least one source pixel (are all pixels invisible?)")

    height, width = img0.shape[:2]

    # Gather the RGB values of all source pixels at once instead of looping in Python.
    rows = np.asarray(original_x).astype(int)
    cols = np.asarray(original_y).astype(int)
    values = img0[rows, cols, :3]  # (N, 3)

    # A single vector-valued interpolator (one KD-tree) serves all three channels.
    landed = np.column_stack((coords_x, coords_y))
    interp = NearestNDInterpolator(landed, values)

    grid_rows, grid_cols = np.meshgrid(np.arange(height), np.arange(width), indexing="ij")
    # Some SciPy versions return float64 here; cast back so e.g. uint8 frames stay uint8.
    return interp(grid_rows, grid_cols).astype(img0.dtype, copy=False)

## Run Model

In [None]:
# Choose Dataset and Model Path
DATASET_PATH = "natural_data/"
# NOTE(review): the "Download models" cell has no download command for model 2A.
MODEL_PATH = "model2A.pt"

# Fall back to CPU when no GPU is available instead of failing on 'cuda'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The checkpoint is a whole pickled model (it is called directly below), so
# weights_only=False is required on recent torch releases. Unpickling can run
# arbitrary code: only load checkpoints from a source you trust.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
model = model.to(device)

dataset = CustomDataset(log_file=os.path.join(DATASET_PATH, 'sample_ids.txt'),
                        root_dir=DATASET_PATH,
                        check=True,  # run_model plots the 7th frame as ground truth
                        transform=transforms.Compose([ToTensor()]))

dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
run_model(device, model, dataloader)