In [None]:
import sys

# Prepend the vendored l5kit snapshot to the import path so it is found before
# any other installed copy of l5kit.
sys.path[:0] = ['/kaggle/input/l5kit-may31/l5kit/']

In [None]:
# from IPython.core.debugger import set_trace

In [None]:
import numpy as np
import os
import psutil
import torch

from torch import nn, optim
from torch.utils.data import DataLoader, random_split
from torchvision.models.resnet import resnet50
from tqdm.notebook import tqdm
from typing import Dict
from pprint import pprint

from l5kit.configs import load_config_data
from l5kit.data import LocalDataManager, ChunkedDataset
from l5kit.dataset import AgentDataset, EgoDataset
from l5kit.evaluation import write_pred_csv
from l5kit.geometry import transform_points
from l5kit.rasterization import build_rasterizer

In [None]:
# Root of the competition data; dataset keys such as "scenes/test.zarr" and
# the agent mask under "scenes/" are resolved relative to it.
INPUT_DIR = "/kaggle/input/lyft-motion-prediction-autonomous-vehicles"
# Trained ResNet-50 state_dict (presumably produced by a separate training notebook).
WEIGHTS_FILE = "/kaggle/input/cs535-resnet50-training/cs535_resnet50.pth"

In [None]:
# LocalDataManager(None) takes its data root from the L5KIT_DATA_FOLDER
# environment variable (per the l5kit docs), so export it before constructing.
os.environ.update({"L5KIT_DATA_FOLDER": INPUT_DIR})
dm = LocalDataManager(None)

In [None]:
# Stock agent-motion-prediction config shipped with the vendored l5kit snapshot;
# selected fields are overridden below.
cfg = load_config_data(
    "/kaggle/input/l5kit-may31/examples/agent_motion_prediction/"
    "agent_motion_config.yaml"
)

In [None]:
# Inference overrides. history_num_frames sizes the model's input channels in
# build_model, so it must match the value the weights were trained with.
# The "val" loader settings are repointed at the test split.
cfg["model_params"]["history_num_frames"] = 10
cfg["val_data_loader"].update(
    batch_size=12,
    num_workers=4,
    key="scenes/test.zarr",
)

## Init test dataset

In [None]:
# ===== INIT DATASET
test_cfg = cfg["val_data_loader"]

# Rasterizer
rasterizer = build_rasterizer(cfg, dm)

# Test dataset/dataloader
test_zarr = ChunkedDataset(dm.require(test_cfg["key"])).open()
# np.load on an .npz returns an NpzFile that keeps the file handle open until it
# is closed. Indexing it reads the array into memory, so the array can be pulled
# out inside a context manager and the handle released immediately.
with np.load(f"{INPUT_DIR}/scenes/mask.npz") as mask_npz:
    test_mask = mask_npz["arr_0"]
test_dataset = AgentDataset(cfg, test_zarr, rasterizer, agents_mask=test_mask)

# Set to a small int (e.g. 100) for a quick smoke run on a random subset of agents.
# Leave as None to run inference on every masked agent.
N_DEBUG_AGENTS = None
if N_DEBUG_AGENTS:
    test_dataset, _ = random_split(
        test_dataset, [N_DEBUG_AGENTS, len(test_dataset) - N_DEBUG_AGENTS]
    )

test_dataloader = DataLoader(test_dataset,
                             shuffle=test_cfg["shuffle"],
                             batch_size=test_cfg["batch_size"],
                             num_workers=test_cfg["num_workers"])

print(f"test agents:  {len(test_dataset)}")
print(f"test batches: {len(test_dataloader)}")

## Build model

In [None]:
def build_model(cfg: Dict) -> torch.nn.Module:
    """Build a ResNet-50 that regresses flattened (x, y) future positions.

    The stem convolution is rebuilt to accept the rasterizer's channel count.
    The classification head is replaced by a linear layer with
    ``2 * future_num_frames`` outputs.

    Args:
        cfg: l5kit config dict. Reads ``model_params.history_num_frames`` and
            ``model_params.future_num_frames``.

    Returns:
        A randomly initialised model (no pretrained weights). Trained weights
        must be loaded separately via ``load_state_dict``.
    """
    # pretrained=False: no ImageNet weights are fetched; the trained checkpoint
    # is loaded after construction.
    model = resnet50(pretrained=False)

    # Change the number of input channels to match the rasterizer's output.
    # Presumably 2 channels per frame (agents + ego) for each history frame,
    # including the current one, plus 3 map channels.
    # TODO confirm the 3 against the configured rasterizer type.
    num_history_channels = (cfg["model_params"]["history_num_frames"] + 1) * 2
    num_in_channels = 3 + num_history_channels
    # Only the channel count changes; kernel/stride/padding mirror the stock stem.
    model.conv1 = nn.Conv2d(
        num_in_channels,
        model.conv1.out_channels,
        kernel_size=model.conv1.kernel_size,
        stride=model.conv1.stride,
        padding=model.conv1.padding,
        bias=False,
    )
    # Change the output size to (X, Y) * number of future states. The head's
    # input width is read from the backbone rather than hard-coding 2048.
    num_targets = 2 * cfg["model_params"]["future_num_frames"]
    model.fc = nn.Linear(in_features=model.fc.in_features, out_features=num_targets)

    return model

In [None]:
# ==== INIT MODEL
# Use the first GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model = build_model(cfg)
model = model.to(device)
# map_location remaps the checkpoint's tensors onto the chosen device.
# NOTE(review): torch.load unpickles the file, so only point WEIGHTS_FILE at
# trusted checkpoints.
model.load_state_dict(torch.load(WEIGHTS_FILE, map_location=device))

## Inference loop

In [None]:
# Switch layers such as BatchNorm/Dropout to inference behaviour.
model.eval()

# Per-batch results; concatenated into flat arrays when the submission is written.
future_coords_offsets_pd = []
timestamps = []
agent_ids = []

with torch.no_grad():
    # Iterating the DataLoader directly lets tqdm read its length for an ETA.
    pbar = tqdm(test_dataloader)
    for batch in pbar:
        inputs = batch["image"].to(device)
        # The model emits a flat 2 * future_num_frames vector per agent. Only the
        # *shape* of target_positions is needed to unflatten it, so targets (and
        # target_availabilities) are never copied to the device.
        outputs = model(inputs).reshape(batch["target_positions"].shape)

        # convert agent coordinates into world offsets
        agents_coords = outputs.cpu().numpy().copy()
        world_from_agents = batch["world_from_agent"].numpy()
        centroids = batch["centroid"].numpy()
        coords_offset = transform_points(agents_coords, world_from_agents) - centroids[:, None, :2]

        future_coords_offsets_pd.append(coords_offset)
        timestamps.append(batch["timestamp"].numpy().copy())
        agent_ids.append(batch["track_id"].numpy().copy())

        pbar.set_description(f'RAM used: {psutil.virtual_memory().percent}%')

## Get submission file

In [None]:
# Flatten the per-batch lists gathered during inference into single arrays
# and write them out in l5kit's prediction-CSV format.
write_pred_csv(
    "submission.csv",
    coords=np.concatenate(future_coords_offsets_pd),
    timestamps=np.concatenate(timestamps),
    track_ids=np.concatenate(agent_ids),
)