In [2]:
import os
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
import math
import time

from BaselineDataset import BaselineDataset
from Model import EnDeWithPooling, EnDeConvLSTM_ws, SkipLSTMEnDe
from torchvision import transforms
from PIL import Image

In [3]:
def saveTransformedImages(imageTensor):
    """Display a tensor as a min-max normalised grayscale image.

    The tensor is converted to a PIL image with ``torchvision.transforms.ToPILImage``,
    rescaled so that its smallest pixel maps to 0 and its largest to 1, printed, and
    shown with matplotlib. Despite the name, nothing is written to disk.

    Args:
        imageTensor: Image tensor accepted by ``ToPILImage``.
    """
    to_pil = torchvision.transforms.ToPILImage()
    # Work on an explicit float array instead of doing arithmetic on the PIL object.
    pixels = np.asarray(to_pil(imageTensor), dtype=np.float64)
    mn, mx = pixels.min(), pixels.max()
    span = mx - mn
    if span == 0:
        # Constant image: avoid 0/0 (an all-NaN result) and show it as uniform zeros.
        im = np.zeros_like(pixels)
    else:
        im = (pixels - mn) / span
    print(im)
    plt.imshow(im, cmap='gray')
    plt.show()

In [4]:
def plotTrajectory(xValsGT, yValsGT, xValsPred, yValsPred, xValsPredMulti, yValsPredMulti, seqLen, im_path, numFrames=None):
    """Plot ground-truth and predicted trajectories and save the figure to ``im_path``.

    Each trajectory is given as two parallel coordinate lists. The ``y`` lists are drawn
    on the horizontal axis and the ``x`` lists on the vertical axis; both axes are fixed
    to the range [1, 512].

    Args:
        xValsGT, yValsGT: Ground-truth coordinates (red line).
        xValsPred, yValsPred: Single-prediction coordinates (green line).
        xValsPredMulti, yValsPredMulti: Multimodal-prediction coordinates (blue line).
        seqLen: Unused; kept so existing callers keep working.
        im_path: Destination passed to ``savefig``.
        numFrames: When given, the title becomes ``Trajectory (<numFrames // 10 - 2>s)``;
            otherwise it is just ``Trajectory``.
    """
    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        ax.plot(yValsGT, xValsGT, c='r', label='Ground Truth')
        ax.plot(yValsPred, xValsPred, c='g', label='Prediction')
        ax.plot(yValsPredMulti, xValsPredMulti, c='b', label='Multimodal Prediction', alpha=0.8)
        ax.set_xlim([1, 512])
        ax.set_ylim([1, 512])
        ax.set_xlabel('X-Axis')
        ax.set_ylabel('Y-Axis')
        ax.legend(loc='upper right')
        if numFrames is None:
            ax.set_title('Trajectory')
        else:
            ax.set_title('Trajectory (' + str(numFrames // 10 - 2) + "s)")
        fig.savefig(im_path)
    finally:
        # Close this specific figure even if saving fails, so figures do not pile up
        # when the function is called once per sequence.
        plt.close(fig)

In [5]:
def heatmapAccuracy(outputMap, labelMap, thr=1.5):
    """Compare the peak of a predicted heatmap with the peak of its label heatmap.

    Args:
        outputMap: Predicted heatmap (NumPy array).
        labelMap: Ground-truth heatmap (NumPy array).
        thr: Largest peak-to-peak distance that still counts as a hit.

    Returns:
        ``(hit, dist, pred, gt)``: ``hit`` is 1 when the Euclidean distance ``dist``
        between the two argmax positions is at most ``thr`` and 0 otherwise; ``pred``
        and ``gt`` are the first two index components of the respective peaks.
    """
    predPeak = np.unravel_index(outputMap.argmax(), outputMap.shape)
    gtPeak = np.unravel_index(labelMap.argmax(), labelMap.shape)
    predPoint = (predPeak[0], predPeak[1])
    gtPoint = (gtPeak[0], gtPeak[1])

    deltaRow = predPoint[0] - gtPoint[0]
    deltaCol = predPoint[1] - gtPoint[1]
    dist = math.sqrt(deltaRow ** 2 + deltaCol ** 2)

    hit = 1 if dist <= thr else 0
    return hit, dist, predPoint, gtPoint

In [6]:
def largest_indices(ary, n):
    """Return the indices of the ``n`` largest values in ``ary``, largest value first.

    Args:
        ary: NumPy array (any shape) to search.
        n: Number of entries to select. Values above ``ary.size`` are clamped to
            ``ary.size``; values of 0 or less select nothing.

    Returns:
        A tuple with one index array per dimension of ``ary`` (the layout produced by
        ``np.unravel_index``), ordered by descending value.
    """
    arr = np.asarray(ary)
    flat = arr.ravel()
    n = min(int(n), flat.size)
    if n <= 0:
        # ``[-0:]`` slices the *whole* array, so "select nothing" needs its own branch.
        return tuple(np.empty(0, dtype=np.intp) for _ in arr.shape)
    # argpartition puts the n largest entries in the last n slots, in no particular order.
    indices = np.argpartition(flat, -n)[-n:]
    # Order just those n candidates by descending value.
    indices = indices[np.argsort(-flat[indices])]
    return np.unravel_index(indices, arr.shape)

In [7]:
def multiAccuracy(outputMap, labelMap, topK=5, radius=4):
    """Score the best of the ``topK`` strongest heatmap peaks against the label peak.

    Args:
        outputMap: Predicted heatmap (2-D NumPy array).
        labelMap: Ground-truth heatmap (2-D NumPy array).
        topK: Number of highest-valued predicted positions to consider (at least 1).
        radius: Largest distance at which the closest candidate still counts as a hit.

    Returns:
        ``(0, min_val, pred, gt, within_radius)``: the first element is always 0 and is
        kept only so existing five-value unpacking keeps working; ``min_val`` is the
        smallest Euclidean distance between a candidate and the label peak; ``pred`` is
        that closest candidate; ``gt`` is the label peak; ``within_radius`` is 1 when
        ``min_val <= radius`` and 0 otherwise.

    Raises:
        ValueError: If ``topK`` is smaller than 1.
    """
    if topK < 1:
        raise ValueError("topK must be at least 1, got {}".format(topK))

    pred = largest_indices(outputMap, topK)
    gt = np.unravel_index(labelMap.argmax(), labelMap.shape)

    # Distance from every candidate peak to the ground-truth peak.
    dist_arr = [
        math.sqrt((row - gt[0]) ** 2 + (col - gt[1]) ** 2)
        for row, col in zip(pred[0], pred[1])
    ]

    min_val = np.min(dist_arr)
    min_idx = np.argmin(dist_arr)
    within_radius = 1 if min_val <= radius else 0
    return 0, min_val, (pred[0][min_idx], pred[1][min_idx]), (gt[0], gt[1]), within_radius

In [8]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(device)

cuda:0


In [9]:
# Make newly created tensors default to CUDA float tensors, but only when a GPU is
# present, so the notebook can still run on the CPU fallback selected for `device`.
if torch.cuda.is_available():
    torch.set_default_tensor_type(torch.cuda.FloatTensor)

### Set the Cross-Validation Fold Number

In [30]:
# Cross-validation fold to evaluate; used below to build fold-specific paths.
cv_num = 1
# Root of the code checkout. Set the INFER_REPO_DIR environment variable to override
# this machine-specific default without editing the notebook.
repo_dir = os.environ.get("INFER_REPO_DIR", "/home/fbd/rrc/submission/INFER-code")

In [31]:
# Checkpoint of the baseline model for the selected cross-validation fold. The folder is
# derived from cv_num so the fold only has to be changed in one place.
checkpoint_path = os.path.join(repo_dir, "models", "baseline", "cv-" + str(cv_num), "checkpoint_future.tar")

In [32]:
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
# map_location remaps the stored tensors onto `device`, which lets a checkpoint saved on
# a GPU be loaded when the notebook has fallen back to the CPU.
checkpoint = torch.load(checkpoint_path, map_location=device)
model = SkipLSTMEnDe(activation="relu", initType="default", numChannels=4, imageHeight=256, imageWidth=256, batchnorm=False, softmax=False)
model.load_state_dict(checkpoint["model_state_dict"])
# Place the model (and its convlstm attribute) on the selected device rather than
# assuming CUDA is present.
model = model.to(device)
model.convlstm = model.convlstm.to(device)

In [33]:
# Root of the KITTI data. Set the INFER_DATA_DIR environment variable to override this
# machine-specific default without editing the notebook.
data_dir = os.environ.get("INFER_DATA_DIR", "/home/fbd/rrc/submission/INFER-datasets/kitti")
# Despite its name, val_dir is the path of the fold's CSV file, not a directory.
val_dir = os.path.join(data_dir, "final-validation", "test" + str(cv_num) + ".csv")
val_dataset = BaselineDataset(data_dir, height=256, width=256, train=False, infoPath=val_dir, augmentation=False, groundTruth=True)

### Future Prediction (Final, Validation)

In [34]:
# Upsamples a predicted heatmap by a factor of 2. align_corners=False is the behaviour
# PyTorch already uses when the argument is omitted for bilinear mode; passing it
# explicitly keeps the results unchanged and avoids the nn.Upsample UserWarning.
upsample_512 = torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
# Converts a ground-truth PIL image into a tensor.
labelTransform = transforms.Compose([
    transforms.ToTensor()
])
targetGTDir = os.path.join(data_dir, 'targetGT')
# Per-sequence mean errors, filled in by the evaluation loop.
valLoss1, valLoss2, valLoss3, valLoss4, valLoss, valLoss8 = [], [], [], [], [], []
# Number of highest-valued heatmap positions considered by the multimodal metric.
topK = 5
# Counters used to compute the hit rate.
totalPreds = 0
hitPreds = 0

In [35]:
# Rolling state for the evaluation loop: prevOut is the previous predicted heatmap and
# state is the recurrent state handed back to the model; both start empty.
debug, prevOut, state = True, None, None
xValsGT, yValsGT, xValsPred, yValsPred, xValsPredMulti, yValsPredMulti = [], [], [], [], [], []
seqLoss, seqVals = [], []
seqNum, seqLen = 0, 0

start_time = time.time()
model.eval()

# Evaluation only: under no_grad() the recurrent state carried from frame to frame does
# not keep an ever-growing autograd graph alive across a sequence.
with torch.no_grad():
    for i in range(len(val_dataset)):
        grid, kittiSeqNum, vehicleId, frame1, frame2, endOfSequence, offset, numFrames, augmentation = val_dataset[i]

        # Skip samples with fewer than 60 frames before doing any tensor work for them.
        if numFrames < 60:
            continue

        # The Last Channel is the target frame and first n - 1 are source frames
        inp = grid[:-1, :].unsqueeze(0).to(device)
        currLabel = grid[-1:, :].unsqueeze(0).to(device)

        if offset >= 20:
            # From offset 20 onwards, channel 0 of the input is replaced by the model's
            # own previous output, min-max normalised.
            new_inp = inp.clone().squeeze(0)
            mn, mx = torch.min(prevOut), torch.max(prevOut)
            prevOut = (prevOut - mn) / (mx - mn)
            new_inp[0] = prevOut
            inp = new_inp.unsqueeze(0).to(device)

        # Forward the input and obtain the result
        out = model(inp, state)
        state = (model.h, model.c, model.h1, model.c1, model.h2, model.c2)
        currOutputMap = out.clone()
        newOutputMap = upsample_512(currOutputMap)

        # Load the ground-truth target image; the context manager closes the file handle
        # once it has been converted to a tensor.
        targetPath = os.path.join(targetGTDir, str(kittiSeqNum).zfill(4),
                                  str(frame2).zfill(6), str(vehicleId).zfill(6) + '.png')
        with Image.open(targetPath) as nextTargetImg:
            nextTargetTensor = labelTransform(nextTargetImg).unsqueeze(0)

        prevOut = currOutputMap.detach().cpu().squeeze(0).squeeze(0)
        currOutputMap = currOutputMap.detach().cpu().numpy().squeeze(0).squeeze(0)
        currLabel = currLabel.detach().cpu().numpy().squeeze(0).squeeze(0)

        # Upsampled outputs and inputs
        currOutputMap1 = newOutputMap.detach().cpu().numpy().squeeze(0).squeeze(0)
        currLabel1 = nextTargetTensor.detach().cpu().numpy().squeeze(0).squeeze(0)

        _, dist1, predCoordinates1, gtCoordinates1 = heatmapAccuracy(currOutputMap1, currLabel1)
        _, dist2, predCoordinates2, gtCoordinates2, within_radius = multiAccuracy(currOutputMap1, currLabel1, topK=topK)

        # Only frames at offset >= 20 contribute to the error and hit statistics.
        if offset >= 20:
            seqLoss.append(dist2)
            totalPreds += 1
            if within_radius == 1:
                hitPreds += 1

        # NOTE(review): seqLen is never reset, so seqVals stores a running frame count
        # rather than per-sequence lengths; confirm this is intended.
        seqLen += 1
        xValsGT.append(gtCoordinates1[0])
        yValsGT.append(gtCoordinates1[1])
        xValsPred.append(predCoordinates1[0])
        yValsPred.append(predCoordinates1[1])
        xValsPredMulti.append(predCoordinates2[0])
        yValsPredMulti.append(predCoordinates2[1])

        if endOfSequence:
            seqVals.append(seqLen)
            xValsGT, yValsGT, xValsPred, yValsPred, xValsPredMulti, yValsPredMulti = [], [], [], [], [], []
            seqNum += 1
            # Start the next sequence with a fresh recurrent state.
            state = None
            # Mean error over the whole sequence and over its first 8/10/20/30/40
            # counted frames.
            valLoss.append(np.mean(seqLoss))
            valLoss8.append(np.mean(seqLoss[:8]))
            valLoss1.append(np.mean(seqLoss[:10]))
            valLoss2.append(np.mean(seqLoss[:20]))
            valLoss3.append(np.mean(seqLoss[:30]))
            valLoss4.append(np.mean(seqLoss[:40]))
            print("SeqNum: {}, KittiSeqNum: {}, VehicleNum: {}, numFrames: {}, loss: {}, len(seqLoss): {}".format(seqNum, kittiSeqNum, vehicleId, numFrames, np.mean(seqLoss), len(seqLoss)))
            seqLoss = []

end_time = time.time()

  "See the documentation of nn.Upsample for details.".format(mode))


SeqNum: 1, KittiSeqNum: 8, VehicleNum: 8, numFrames: 61, loss: 5.921452857307328, len(seqLoss): 40
SeqNum: 2, KittiSeqNum: 8, VehicleNum: 13, numFrames: 61, loss: 6.488410788039642, len(seqLoss): 40
SeqNum: 3, KittiSeqNum: 20, VehicleNum: 3, numFrames: 61, loss: 2.3949155550700865, len(seqLoss): 40
SeqNum: 4, KittiSeqNum: 20, VehicleNum: 5, numFrames: 61, loss: 2.9852286300231343, len(seqLoss): 40
SeqNum: 5, KittiSeqNum: 20, VehicleNum: 12, numFrames: 61, loss: 2.581867413257071, len(seqLoss): 40
SeqNum: 6, KittiSeqNum: 20, VehicleNum: 12, numFrames: 61, loss: 1.5740997509714414, len(seqLoss): 40
SeqNum: 7, KittiSeqNum: 20, VehicleNum: 12, numFrames: 61, loss: 19.722267741527773, len(seqLoss): 40
SeqNum: 8, KittiSeqNum: 20, VehicleNum: 16, numFrames: 61, loss: 4.764257650269385, len(seqLoss): 40
SeqNum: 9, KittiSeqNum: 20, VehicleNum: 122, numFrames: 61, loss: 18.313226648084502, len(seqLoss): 40
SeqNum: 10, KittiSeqNum: 84, VehicleNum: 14, numFrames: 61, loss: 2.185481176663386, len(s

In [36]:
# Report the mean of each per-horizon loss list under its 1s-4s label.
horizonLosses = (valLoss1, valLoss2, valLoss3, valLoss4)
horizonMeans = [np.mean(losses) for losses in horizonLosses]
print("1s: {}, 2s: {}, 3s: {}, 4s: {}".format(*horizonMeans))

1s: 1.6652890070591397, 2s: 2.806251491477541, 3s: 3.7122240229439227, 4s: 4.55494582261985


In [31]:
# Fraction of counted predictions flagged as hits. Report NaN instead of raising
# ZeroDivisionError when no prediction was counted.
hitRate = hitPreds / totalPreds if totalPreds else float("nan")
print("HitPreds: {}, TotalPreds: {}, Hit Rate: {}".format(hitPreds, totalPreds, hitRate))

HitPreds: 513, TotalPreds: 1120, Hit Rate: 0.45803571428571427
