In [2]:
import os
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
import math
import time

from KittiDataset import KittiDataset
from Model import EnDeWithPooling, EnDeConvLSTM_ws, SkipLSTMEnDe
from torchvision import transforms
from PIL import Image

In [3]:
def saveTransformedImages(imageTensor):
    """Show an image tensor in grayscale after min-max normalisation.

    Despite the name, nothing is written to disk: the tensor is converted to a
    PIL image, rescaled to the range [0, 1], printed and displayed with
    matplotlib.

    Args:
        imageTensor: image accepted by ``torchvision.transforms.ToPILImage``.
    """
    to_pil = torchvision.transforms.ToPILImage()
    # Work on an explicit float array instead of doing arithmetic on the PIL
    # image itself, which relies on implicit conversion and the pixel dtype.
    im = np.asarray(to_pil(imageTensor), dtype=np.float64)
    mn, mx = im.min(), im.max()
    span = mx - mn
    # A constant image has zero range; show zeros rather than 0/0 (NaN).
    im = (im - mn) / span if span > 0 else np.zeros_like(im)
    print(im)
    plt.imshow(im, cmap='gray')
    plt.show()

In [4]:
def plotTrajectory(xValsGT, yValsGT, xValsPred, yValsPred, seqLen, im_path, numFrames=None):
    """Scatter-plot ground-truth and predicted positions and save the figure.

    The ``y*`` values are drawn on the horizontal axis and the ``x*`` values on
    the vertical axis; both axes are fixed to the range [1, 512].

    Args:
        xValsGT, yValsGT: ground-truth coordinates (red circles).
        xValsPred, yValsPred: predicted coordinates (green crosses).
        seqLen: accepted for interface compatibility; not used by the plot.
        im_path: destination passed to ``savefig``.
        numFrames: when given, the title shows ``numFrames // 10 - 1`` seconds;
            otherwise the title is just ``'Trajectory'``.
    """
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.scatter(yValsGT, xValsGT, c='r', marker='o', label='Ground Truth')
    ax.scatter(yValsPred, xValsPred, c='g', marker='x', label='Prediction')
    ax.set_xlim([1, 512])
    ax.set_ylim([1, 512])
    ax.set_xlabel('X-Axis')
    ax.set_ylabel('Y-Axis')
    ax.legend(loc='upper right')
    # Identity test: `== None` is ambiguous for array-like arguments.
    if numFrames is None:
        ax.set_title('Trajectory')
    else:
        ax.set_title('Trajectory (' + str(numFrames // 10 - 1) + "s)")
    fig.savefig(im_path)
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

In [5]:
def heatmapAccuracy(outputMap, labelMap, thr=1.5):
    """Compare the peak of a predicted heatmap with the peak of the label map.

    Args:
        outputMap: predicted heatmap (array with at least two dimensions).
        labelMap: label heatmap (array with at least two dimensions).
        thr: maximum Euclidean distance for the prediction to count as a hit.

    Returns:
        ``(hit, dist, pred, gt)`` where ``hit`` is 1 if ``dist <= thr`` else 0,
        ``dist`` is the Euclidean distance between the two peaks, and ``pred``
        / ``gt`` are the first two index components of each peak.
    """
    predPeak = np.unravel_index(outputMap.argmax(), outputMap.shape)
    gtPeak = np.unravel_index(labelMap.argmax(), labelMap.shape)

    deltaRow = predPeak[0] - gtPeak[0]
    deltaCol = predPeak[1] - gtPeak[1]
    dist = math.sqrt(deltaRow ** 2 + deltaCol ** 2)

    hit = 1 if dist <= thr else 0
    return hit, dist, (predPeak[0], predPeak[1]), (gtPeak[0], gtPeak[1])

In [6]:
def largest_indices(ary, n):
    """Return the indices of the ``n`` largest values of an array.

    Args:
        ary: array (or array-like) to search.
        n: number of entries to return. Values larger than the array size are
            clamped to the size; ``n <= 0`` yields empty index arrays.

    Returns:
        A tuple with one index array per dimension of ``ary`` (the same layout
        as ``np.unravel_index``), ordered from the largest value downwards.
    """
    ary = np.asarray(ary)
    flat = ary.ravel()
    # Clamp n: argpartition raises for n > size, and for n == 0 the slice
    # "[-0:]" below would select every element instead of none.
    n = min(int(n), flat.size)
    if n <= 0:
        return tuple(np.empty(0, dtype=np.intp) for _ in ary.shape)
    # argpartition moves the n largest values into the last n slots (unordered);
    # the argsort then orders just those n candidates, largest first.
    indices = np.argpartition(flat, -n)[-n:]
    indices = indices[np.argsort(-flat[indices])]
    return np.unravel_index(indices, ary.shape)

In [7]:
def multiAccuracy(outputMap, labelMap, topK=5, radius=4):
    """Distance from the label peak to the closest of the top-K predicted peaks.

    Args:
        outputMap: predicted heatmap (2-D array).
        labelMap: label heatmap (2-D array).
        topK: number of highest-valued cells of ``outputMap`` to consider.
        radius: distance threshold used for the ``within_radius`` flag
            (defaults to the previously hard-coded value of 4).

    Returns:
        ``(0, min_val, pred, gt, within_radius)`` where ``min_val`` is the
        smallest Euclidean distance between any top-K candidate and the label
        peak, ``pred`` is the ``(row, col)`` of that candidate, ``gt`` is the
        ``(row, col)`` of the label peak and ``within_radius`` is 1 when
        ``min_val <= radius`` else 0. The first element is always 0; the
        callers in this notebook discard it.
    """
    pred = largest_indices(outputMap, topK)
    gt = np.unravel_index(labelMap.argmax(), labelMap.shape)
    dist_arr = [
        math.sqrt((row - gt[0]) ** 2 + (col - gt[1]) ** 2)
        for row, col in zip(pred[0], pred[1])
    ]

    min_val = np.min(dist_arr)
    min_idx = np.argmin(dist_arr)
    within_radius = 1 if min_val <= radius else 0
    return 0, min_val, (pred[0][min_idx], pred[1][min_idx]), (gt[0], gt[1]), within_radius

In [8]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [9]:
# Make newly created tensors CUDA floats by default, but only when a GPU is
# present: requesting the CUDA tensor type unconditionally can fail on
# CPU-only machines even though the device selection falls back to "cpu".
if torch.cuda.is_available():
    torch.set_default_tensor_type(torch.cuda.FloatTensor)

### Set correct repo path

In [42]:
# Root of the INFER code checkout. Set the INFER_REPO_PATH environment variable
# to override the default location without editing the notebook.
repo_path = os.environ.get("INFER_REPO_PATH", "/home/fbd/rrc/submission/INFER-code")

In [43]:
# Best checkpoint
checkpoint_path = os.path.join(repo_path, "models", "cityscapes-transfer", "checkpoint_future.tar")

# map_location lets the checkpoint load on a host whose devices differ from
# the machine it was saved on (for example a CPU-only host).
checkpoint = torch.load(checkpoint_path, map_location=device)
model = SkipLSTMEnDe(activation="relu", initType="default", numChannels=5, imageHeight=256, imageWidth=256, batchnorm=False, softmax=False)
model.load_state_dict(checkpoint["model_state_dict"])
# Follow the selected `device` instead of hard-coding CUDA.
model = model.to(device)
model.convlstm = model.convlstm.to(device)

In [44]:
# The dataset root defaults to the original location; set INFER_DATA_ROOT to
# point at a different copy of the INFER datasets.
data_dir = os.path.join(
    os.environ.get("INFER_DATA_ROOT", "/home/fbd/rrc/submission/INFER-datasets"),
    "cityscapes",
)
# NOTE: despite its name, val_dir is the path of the split CSV file.
val_dir = os.path.join(data_dir, "test.csv")

val_dataset = KittiDataset(
    data_dir,
    height=256,
    width=256,
    train=False,
    infoPath=val_dir,
    augmentation=False,
    groundTruth=True,
)

### Future Prediction (Final, Validation)

### 0.8s preconditioning, 1s future

In [45]:
import csv

# 2x bilinear upsampling of the model output. align_corners=False is already
# what bilinear nn.Upsample does when the argument is omitted; stating it
# explicitly avoids the align_corners warning seen in the recorded output.
upsample_512 = torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
# Converts the ground-truth target image (PIL) to a tensor.
labelTransform = transforms.Compose([
    transforms.ToTensor()
])
targetGTDir = os.path.join(data_dir, 'targetGT')
valLoss1, valLoss2, valLoss3, valLoss4, valLoss = [], [], [], [], []
# Frames with offset < futureFrames use the dataset input as-is; from this
# offset on, the evaluation loop feeds the model's previous output back in.
futureFrames = 14
# Number of highest-valued output cells considered by multiAccuracy.
topK = 5

In [46]:
debug, prevOut, state = True, None, None
prevChannels = None
xValsGT, yValsGT, xValsPred, yValsPred = [], [], [], []
seqLoss, seqVals = [], []
seqNum, seqLen = 0, 0

start_time = time.time()
model.eval()

# Inference only: no_grad stops autograd from recording a graph through the
# recurrent state that is carried from one frame to the next.
with torch.no_grad():
    for i in range(len(val_dataset)):
        grid, kittiSeqNum, vehicleId, frame1, frame2, endOfSequence, offset, numFrames, augmentation = val_dataset[i]

        # Skip frames with an even offset, but never the last frame of a
        # sequence (it triggers the per-sequence bookkeeping below).
        if not endOfSequence and int(offset) % 2 == 0:
            continue

        # The last channel is the target frame and the first n - 1 are source frames
        inp = grid[:-1, :].unsqueeze(0).to(device)
        currLabel = grid[-1:, :].unsqueeze(0).to(device)

        if offset < futureFrames:
            # Remember the most recent real input; its channel 4 is reused
            # once the loop switches to feeding back predictions.
            prevChannels = inp
        else:
            # Future prediction: channel 0 becomes the min-max normalised
            # previous output, channel 4 the last remembered real channel 4.
            new_inp = inp.clone().squeeze(0)
            mn, mx = torch.min(prevOut), torch.max(prevOut)
            span = mx - mn
            # A constant map has zero range; feed zeros instead of 0/0 (NaN).
            prevOut = (prevOut - mn) / span if span > 0 else torch.zeros_like(prevOut)
            new_inp[0] = prevOut
            new_inp[4] = prevChannels[0, 4, :, :]
            inp = new_inp.unsqueeze(0).to(device)

        # Forward the input and obtain the result
        out = model(inp, state)
        state = (model.h, model.c, model.h1, model.c1, model.h2, model.c2)
        currOutputMap = out.clone()
        newOutputMap = upsample_512(currOutputMap)
        targetPath = os.path.join(targetGTDir, str(kittiSeqNum).zfill(4),
                                  str(frame2).zfill(6), str(vehicleId).zfill(6) + '.png')
        # The context manager closes the file once the pixels are in a tensor.
        with Image.open(targetPath) as nextTargetImg:
            nextTargetTensor = labelTransform(nextTargetImg).unsqueeze(0)

        prevOut = currOutputMap.detach().cpu().squeeze(0).squeeze(0)
        currOutputMap = currOutputMap.detach().cpu().numpy().squeeze(0).squeeze(0)
        currLabel = currLabel.detach().cpu().numpy().squeeze(0).squeeze(0)
        _, dist, predCoordinates, gtCoordinates = heatmapAccuracy(currOutputMap, currLabel)

        # Upsampled outputs and inputs
        currOutputMap1 = newOutputMap.detach().cpu().numpy().squeeze(0).squeeze(0)
        currLabel1 = nextTargetTensor.detach().cpu().numpy().squeeze(0).squeeze(0)

        _, dist1, predCoordinates1, gtCoordinates1 = heatmapAccuracy(currOutputMap1, currLabel1)
        _, dist2, predCoordinates2, gtCoordinates2, within_radius = multiAccuracy(currOutputMap1, currLabel1, topK=topK)

        # Only frames predicted from fed-back outputs count towards the loss.
        if offset >= futureFrames:
            seqLoss.append(dist2)

        seqLen += 1
        xValsGT.append(gtCoordinates1[0])
        yValsGT.append(gtCoordinates1[1])
        xValsPred.append(predCoordinates1[0])
        yValsPred.append(predCoordinates1[1])

        if endOfSequence:
            seqVals.append(seqLen)
            xValsGT, yValsGT, xValsPred, yValsPred = [], [], [], []
            seqNum += 1
            # Reset per-sequence state so the next sequence starts clean;
            # without resetting seqLen, seqVals would hold running totals.
            seqLen = 0
            state = None
            valLoss.append(np.mean(seqLoss))
            print("SeqNum: {}, KittiSeqNum: {}, VehicleNum: {}, numFrames: {}, loss: {}, len(seqLoss): {}".format(seqNum, kittiSeqNum, vehicleId, numFrames, np.mean(seqLoss), len(seqLoss)))
            seqLoss = []

end_time = time.time()

  "See the documentation of nn.Upsample for details.".format(mode))


SeqNum: 1, KittiSeqNum: 0, VehicleNum: 3, numFrames: 30, loss: 4.5682446146884175, len(seqLoss): 8
SeqNum: 2, KittiSeqNum: 0, VehicleNum: 4, numFrames: 30, loss: 4.7279944297634255, len(seqLoss): 8
SeqNum: 3, KittiSeqNum: 18, VehicleNum: 6, numFrames: 30, loss: 4.86975096479574, len(seqLoss): 8
SeqNum: 4, KittiSeqNum: 18, VehicleNum: 7, numFrames: 30, loss: 5.034616995972669, len(seqLoss): 8
SeqNum: 5, KittiSeqNum: 26, VehicleNum: 21, numFrames: 30, loss: 1.375, len(seqLoss): 8
SeqNum: 6, KittiSeqNum: 32, VehicleNum: 14, numFrames: 30, loss: 2.4705114938006396, len(seqLoss): 8
SeqNum: 7, KittiSeqNum: 42, VehicleNum: 15, numFrames: 30, loss: 4.9992640141559415, len(seqLoss): 8
SeqNum: 8, KittiSeqNum: 50, VehicleNum: 2, numFrames: 30, loss: 4.3017766952966365, len(seqLoss): 8
SeqNum: 9, KittiSeqNum: 1, VehicleNum: 25, numFrames: 30, loss: 9.064940543450463, len(seqLoss): 8
SeqNum: 10, KittiSeqNum: 10, VehicleNum: 20, numFrames: 30, loss: 0.875, len(seqLoss): 8
SeqNum: 11, KittiSeqNum: 11

In [47]:
# Summary over all evaluated sequences: mean of the per-sequence losses, the
# same mean scaled by 0.25, and the number of sequences.
print("\n".join([
    "Avg Loss: {}".format(np.mean(valLoss)),
    "Avg Loss in m: {}".format(np.mean(valLoss) * 0.25),
    "Num Seq: {}".format(len(valLoss)),
]))

Avg Loss: 3.674949260378043
Avg Loss in m: 0.9187373150945107
Num Seq: 26


### Kitti Future Prediction

In [53]:
# NOTE(review): this section evaluates on KITTI but loads the checkpoint from
# the "cityscapes-transfer" folder — confirm this is the intended model.
checkpoint_path = os.path.join(repo_path, "models", "cityscapes-transfer", "checkpoint_future.tar")

# map_location lets the checkpoint load on a host whose devices differ from
# the machine it was saved on (for example a CPU-only host).
checkpoint = torch.load(checkpoint_path, map_location=device)
model = SkipLSTMEnDe(activation="relu", initType="default", numChannels=5, imageHeight=256, imageWidth=256, batchnorm=False, softmax=False)
model.load_state_dict(checkpoint["model_state_dict"])
# Follow the selected `device` instead of hard-coding CUDA.
model = model.to(device)
model.convlstm = model.convlstm.to(device)

In [55]:
# The dataset root defaults to the original location; set INFER_DATA_ROOT to
# point at a different copy of the INFER datasets.
data_dir = os.path.join(
    os.environ.get("INFER_DATA_ROOT", "/home/fbd/rrc/submission/INFER-datasets"),
    "kitti",
)
# NOTE: despite its name, val_dir is the path of the split CSV file.
val_dir = os.path.join(data_dir, "val.csv")
val_dataset = KittiDataset(
    data_dir,
    height=256,
    width=256,
    train=False,
    infoPath=val_dir,
    augmentation=False,
    groundTruth=True,
)

In [56]:
# 2x bilinear upsampling of the model output. align_corners=False is already
# what bilinear nn.Upsample does when the argument is omitted; stating it
# explicitly avoids the align_corners warning seen in the recorded output.
upsample_512 = torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
# Converts the ground-truth target image (PIL) to a tensor.
labelTransform = transforms.Compose([
    transforms.ToTensor()
])
targetGTDir = os.path.join(data_dir, 'targetGT')
valLoss1, valLoss2, valLoss3, valLoss4, valLoss, valLoss8, valLoss9 = [], [], [], [], [], [], []
# Frames with offset < futureFrames use the dataset input as-is; from this
# offset on, the evaluation loop feeds the model's previous output back in.
futureFrames = 8
# Number of highest-valued output cells considered by multiAccuracy.
topK = 5

In [57]:
debug, prevOut, state = True, None, None
xValsGT, yValsGT, xValsPred, yValsPred = [], [], [], []
prevChannels = None
seqLoss, seqVals = [], []
seqNum, seqLen = 0, 0

start_time = time.time()
model.eval()

# Inference only: no_grad stops autograd from recording a graph through the
# recurrent state that is carried from one frame to the next.
with torch.no_grad():
    for i in range(len(val_dataset)):
        grid, kittiSeqNum, vehicleId, frame1, frame2, endOfSequence, offset, numFrames, augmentation = val_dataset[i]

        # The last channel is the target frame and the first n - 1 are source frames
        inp = grid[:-1, :].unsqueeze(0).to(device)
        currLabel = grid[-1:, :].unsqueeze(0).to(device)

        if offset < futureFrames:
            # Remember the most recent real input; its channel 4 is reused
            # once the loop switches to feeding back predictions.
            prevChannels = inp
        else:
            # Future prediction: channel 0 becomes the min-max normalised
            # previous output, channel 4 the last remembered real channel 4.
            new_inp = inp.clone().squeeze(0)
            mn, mx = torch.min(prevOut), torch.max(prevOut)
            span = mx - mn
            # A constant map has zero range; feed zeros instead of 0/0 (NaN).
            prevOut = (prevOut - mn) / span if span > 0 else torch.zeros_like(prevOut)
            new_inp[0] = prevOut
            new_inp[4] = prevChannels[0, 4, :, :]
            inp = new_inp.unsqueeze(0).to(device)

        # Forward the input and obtain the result
        out = model(inp, state)
        state = (model.h, model.c, model.h1, model.c1, model.h2, model.c2)
        currOutputMap = out.clone()
        newOutputMap = upsample_512(currOutputMap)
        targetPath = os.path.join(targetGTDir, str(kittiSeqNum).zfill(4),
                                  str(frame2).zfill(6), str(vehicleId).zfill(6) + '.png')
        # The context manager closes the file once the pixels are in a tensor.
        with Image.open(targetPath) as nextTargetImg:
            nextTargetTensor = labelTransform(nextTargetImg).unsqueeze(0)

        prevOut = currOutputMap.detach().cpu().squeeze(0).squeeze(0)
        currOutputMap = currOutputMap.detach().cpu().numpy().squeeze(0).squeeze(0)
        currLabel = currLabel.detach().cpu().numpy().squeeze(0).squeeze(0)

        # Upsampled outputs and inputs
        currOutputMap1 = newOutputMap.detach().cpu().numpy().squeeze(0).squeeze(0)
        currLabel1 = nextTargetTensor.detach().cpu().numpy().squeeze(0).squeeze(0)

        # Compute the per-frame peak coordinates in this loop. They were
        # previously never assigned here, so the trajectory lists below were
        # filled with stale values left over from an earlier cell.
        _, dist1, predCoordinates1, gtCoordinates1 = heatmapAccuracy(currOutputMap1, currLabel1)
        _, dist2, predCoordinates2, gtCoordinates2, within_radius = multiAccuracy(currOutputMap1, currLabel1, topK=topK)

        # Only frames predicted from fed-back outputs count towards the loss.
        if offset >= futureFrames:
            seqLoss.append(dist2)

        seqLen += 1
        xValsGT.append(gtCoordinates1[0])
        yValsGT.append(gtCoordinates1[1])
        xValsPred.append(predCoordinates1[0])
        yValsPred.append(predCoordinates1[1])

        if endOfSequence:
            seqVals.append(seqLen)
            xValsGT, yValsGT, xValsPred, yValsPred = [], [], [], []
            seqNum += 1
            # Reset per-sequence state so the next sequence starts clean;
            # without resetting seqLen, seqVals would hold running totals.
            seqLen = 0
            state = None
            valLoss.append(np.mean(seqLoss))
            # Means over the first 8 / 9 / 10 predicted frames of the sequence.
            valLoss8.append(np.mean(seqLoss[:8]))
            valLoss9.append(np.mean(seqLoss[:9]))
            valLoss1.append(np.mean(seqLoss[:10]))
            print("SeqNum: {}, KittiSeqNum: {}, VehicleNum: {}, numFrames: {}, loss: {}, len(seqLoss): {}".format(seqNum, kittiSeqNum, vehicleId, numFrames, np.mean(seqLoss), len(seqLoss)))
            seqLoss = []

end_time = time.time()

  "See the documentation of nn.Upsample for details.".format(mode))


SeqNum: 1, KittiSeqNum: 0, VehicleNum: 0, numFrames: 21, loss: 17.120213422955146, len(seqLoss): 12
SeqNum: 2, KittiSeqNum: 4, VehicleNum: 2, numFrames: 21, loss: 0.985702260395516, len(seqLoss): 12
SeqNum: 3, KittiSeqNum: 4, VehicleNum: 7, numFrames: 21, loss: 8.89134361076261, len(seqLoss): 12
SeqNum: 4, KittiSeqNum: 4, VehicleNum: 8, numFrames: 21, loss: 7.216872177929858, len(seqLoss): 12
SeqNum: 5, KittiSeqNum: 8, VehicleNum: 8, numFrames: 21, loss: 0.4166666666666667, len(seqLoss): 12
SeqNum: 6, KittiSeqNum: 8, VehicleNum: 17, numFrames: 21, loss: 4.00455705869119, len(seqLoss): 12
SeqNum: 7, KittiSeqNum: 16, VehicleNum: 1, numFrames: 21, loss: 0.0, len(seqLoss): 12
SeqNum: 8, KittiSeqNum: 16, VehicleNum: 3, numFrames: 21, loss: 0.0, len(seqLoss): 12
SeqNum: 9, KittiSeqNum: 20, VehicleNum: 0, numFrames: 21, loss: 0.0, len(seqLoss): 12
SeqNum: 10, KittiSeqNum: 20, VehicleNum: 0, numFrames: 21, loss: 3.75, len(seqLoss): 12
SeqNum: 11, KittiSeqNum: 20, VehicleNum: 4, numFrames: 21, 

In [58]:
# Mean loss over the first 8, 9 and 10 predicted frames of every sequence,
# each scaled by 0.25.
print("\n".join([
    "0.8s: {}".format(np.mean(valLoss8) * 0.25),
    "0.9s: {}".format(np.mean(valLoss9) * 0.25),
    "1s: {}".format(np.mean(valLoss1) * 0.25),
]))

0.8s: 0.4919861827300455
0.9s: 0.537135809450425
1s: 0.5770326002443692
