# Required library

In [1]:
import pandas as pd, numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from pathlib import Path
%matplotlib inline

In [2]:
import itertools as it
from pathlib import Path

pd.options.display.max_columns=305

In [3]:
import lightgbm  as lgb
import gc
from sklearn.model_selection import train_test_split
import time
from tqdm.notebook import tqdm

import re,json

# Reading the data

In [4]:
DATA_ROOT = Path("../input/lyft-train-as-parquet/train")

In [5]:
def get_scene_path(scene):
    meta = "meta_{}_{}.json".format(*re.search( r"scenes_(\d+)_(\d+)", scene.stem).groups())
    with open(DATA_ROOT/meta) as f:
        meta = json.load(f)
    frame = DATA_ROOT/meta["frames"]["results"]["filename"]
    agent = DATA_ROOT/meta["agents"]["results"]["filename"]
    return (scene, frame, agent)

In [6]:
SCENES = np.array(list(DATA_ROOT.glob("scenes_*.parquet.snappy")))
SCENES = SCENES[np.random.permutation(len(SCENES))]
print("NB SCENES:", len(SCENES))
scene = SCENES[0]
scene

NB SCENES: 33


PosixPath('../input/lyft-train-as-parquet/train/scenes_000500_001000_500.parquet.snappy')

In [7]:
get_scene_path(scene)

(PosixPath('../input/lyft-train-as-parquet/train/scenes_000500_001000_500.parquet.snappy'),
 PosixPath('../input/lyft-train-as-parquet/train/frames_00124167_00248301_124134.parquet.snappy'),
 PosixPath('../input/lyft-train-as-parquet/train/agents_0010019001_0020042294_10023293.parquet.snappy'))

In [8]:
reader = pd.read_parquet

In [9]:
def merge(scenes, frames, agents, shift, verbose=False):
    df = scenes.merge(frames, on = "scene_db_id")
    df = df.merge(agents, on="frame_db_id")

    shift_cols = [
            "centroid_x", "centroid_y",
            "yaw",
            "velocity_x","velocity_y",
            "nagents","nlights",
            'extent_x', 'extent_y','extent_z',
            'label_probabilities_PERCEPTION_LABEL_UNKNOWN',
            'label_probabilities_PERCEPTION_LABEL_CAR', 
            'label_probabilities_PERCEPTION_LABEL_CYCLIST', 
            'label_probabilities_PERCEPTION_LABEL_PEDESTRIAN',
    ]
    new_shift_cols = ["centroid_xs", "centroid_ys"] + shift_cols[2:]

    df[new_shift_cols] = df.groupby(["scene_db_id", "track_id"])[shift_cols].shift(shift)
    nulls = df[["centroid_xs", "centroid_ys", "yaw", "velocity_x","velocity_y"]].isnull().any(1)
    shape0 =  df.shape
    df = df[~nulls]
    
    if verbose:
        print("SHAPE0:", shape0)
        print("nulls ratio:", nulls.sum()/shape0[0])
    
    return df

In [10]:
def read_all(shift=1, max_len=1e5):
    """
    Read parquet files from the SCENES list until the df's size is greater than `max_len`.
    
    If you want better accuracy, you need to increase `max_len`.
    With `max_len=12e6` I got a score of 200.xxx.
    But note that training time increases as max_len increases.
    """
    dfs = None
    scenes = SCENES[np.random.permutation(len(SCENES))]
    for scene in scenes:
        SCENES_FILE,FRAMES_FILE,AGENTS_FILE = get_scene_path(scene)
        
        scenes = reader(SCENES_FILE)
        
        frames = reader(FRAMES_FILE)
        frames["nagents"] = frames["agent_index_interval_end"] - frames["agent_index_interval_start"]
        frames["nlights"] = frames["traffic_light_faces_index_interval_end"
                                  ] - frames["traffic_light_faces_index_interval_start"]
    
        agents = reader(AGENTS_FILE)
        agents.rename(columns = {"agent_id": "agent_db_id"}, inplace=True)
        
        df = merge(scenes, frames, agents, shift=shift)
        
        dfs = df if dfs is None else dfs.append(df)
        dfs.reset_index(inplace=True, drop=True)
        
        if len(dfs) > max_len:
            break
    
    return dfs

# LGBM trainer

In [11]:
def lgbm_trainer(shift=1, root=None, params=None):
    t0 = time.strftime("%Y%m%d%H%M%S")
    T00 = time.time()
    root = "model_{}".format(t0) if root is None else str(root)
    params = PARAMS if params is None else params
    
    df = read_all(shift=shift)
    print("df.shape:", df.shape)
    
    df_centroid = df[["centroid_x", "centroid_y"]]
    df = df[TRAIN_COLS]
    gc.collect()
    
    train_index, test_index = train_test_split(df.index.values.reshape((-1,1)),
                                           df.index.values, test_size = .20, random_state=177)[2:]
    print("\n")
    for suffix in ["x", "y"]:
        print("--> {}".format(suffix.upper()))
        target_name = "centroid_" + suffix
        target = (df_centroid[target_name] - df[target_name+"s"])
    
        train_data = lgb.Dataset(df.loc[train_index], label= target.loc[train_index])
        test_data = lgb.Dataset(df.loc[test_index], label= target.loc[test_index])

        clf = lgb.train(params,
                        train_data,
                        valid_sets = [train_data, test_data],
                        early_stopping_rounds=60, 
                        verbose_eval= 40
                       )
        
        clf.save_model("models/{}/lgbm_{}_shift_{:02d}.bin".format(root, suffix, shift))
        print('\n')
    print("elapsed: {:.5f} min".format((time.time()-T00)/60))

In [12]:
def get_time(format_="%Y-%m-%d %H:%M:%S"):
    return time.strftime(format_)

In [13]:
def train_50_shifts(root=None):
    root = root or  "model_{}".format(time.strftime("%Y%m%d%H%M%S"))
    Path("models").joinpath(root).mkdir(exist_ok=True, parents=True)
    params = PARAMS.copy()
    for shift in tqdm(list(range(50, 0, -1))):
        print('\n ******************* SHIFT {:02d} {} ***********\n'.format(shift,get_time()))
        
        if not (shift-1)%5:
            params["num_iterations"] = max(100, params["num_iterations"] - 20)
            params["num_leaves"] = max(31, params["num_leaves"] - 10)
        
        meta = {
            "TRAIN_COLS": TRAIN_COLS,
            "params": params,
            "shift": shift,
            "start": get_time(),
            "end": None
        }
        
            
        lgbm_trainer(root=root, shift=shift, params=params)
        meta["end"] = get_time()
        with open("models/{}/meta_shift_{:02d}.json".format(root, shift), "w") as f:
            json.dump(meta, f, indent=2)

# Params

**I fixed this phase.**

In [14]:
PARAMS = {
         'objective':'regression',
         'boosting': 'gbdt',
         'feature_fraction': 0.5 ,
         'scale_pos_weight' : 1/40., 
         'num_iterations' : 200,
         'learning_rate' :  0.7,
         'max_depth': 41,
         'min_data_in_leaf': 64,
         'num_leaves': 128,
         'bagging_freq' : 1,
         'bagging_fraction' : 0.8,
         'tree_learner': 'voting',
         'boost_from_average': True,
         'verbosity' : 0,
         'num_threads': 2,
         'metric' : ['mse'],
         'metric': [ "l1", "rmse"],
         "verbosity": 1,
         'reg_alpha': 0.1,
         'reg_lambda': 0.3
        }

In [15]:
# Uncomment the columns if you want more

TRAIN_COLS = [
     'ego_translation_x', 
     'ego_translation_y', 
     'ego_translation_z', 
     'ego_rotation_xx', 
     'ego_rotation_xy', 
     'ego_rotation_xz', 
     'ego_rotation_yx', 
     'ego_rotation_yy', 
     'ego_rotation_yz', 
     'ego_rotation_zx', 
     'ego_rotation_zy', 
     'ego_rotation_zz', 
    'extent_x', 
    'extent_y', 
    'extent_z', 
    'velocity_x', 
    'velocity_y', 
    'label_probabilities_PERCEPTION_LABEL_UNKNOWN', 
    'label_probabilities_PERCEPTION_LABEL_CAR', 
    'label_probabilities_PERCEPTION_LABEL_CYCLIST', 
    'label_probabilities_PERCEPTION_LABEL_PEDESTRIAN', 
    'yaw', 
    'nagents', 
    'nlights', 
    'centroid_xs', 
    'centroid_ys',
]

In [16]:
print("len(TRAIN_COLS):", len(TRAIN_COLS))

len(TRAIN_COLS): 26


# Training

In [17]:
%%time

# Train 50x2 lgbm models (50 time dimensions X 2 space dimensions)
# Save it as lgbm_{x or y}_shift_{i:02d}
# Each model has it's own meta_shift_{i:02d} which contains the model's params
# You can juts feed the ouputs as inputs to the inference kernel
# The inference kernel is at https://www.kaggle.com/kneroma/lgbm-on-lyft-tabular-data-inference

train_50_shifts("lyft_lgbm_model")

HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))


 ******************* SHIFT 50 2020-09-10 00:53:26 ***********

df.shape: (2471188, 42)


--> X




Training until validation scores don't improve for 60 rounds
[40]	training's l1: 1.07943	training's rmse: 2.90353	valid_1's l1: 1.11335	valid_1's rmse: 3.01955
[80]	training's l1: 0.985603	training's rmse: 2.61322	valid_1's l1: 1.03978	valid_1's rmse: 2.79593
[120]	training's l1: 0.934208	training's rmse: 2.41761	valid_1's l1: 1.00404	valid_1's rmse: 2.65498
[160]	training's l1: 0.895033	training's rmse: 2.28991	valid_1's l1: 0.978148	valid_1's rmse: 2.561
[200]	training's l1: 0.870624	training's rmse: 2.18883	valid_1's l1: 0.965831	valid_1's rmse: 2.50831
Did not meet early stopping. Best iteration is:
[200]	training's l1: 0.870624	training's rmse: 2.18883	valid_1's l1: 0.965831	valid_1's rmse: 2.50831


--> Y
Training until validation scores don't improve for 60 rounds
[40]	training's l1: 1.22952	training's rmse: 3.32966	valid_1's l1: 1.27323	valid_1's rmse: 3.46945
[80]	training's l1: 1.11372	training's rmse: 2.92384	valid_1's l1: 1.18006	valid_1's rmse: 3.1379
[120]	training's l1: 

[40]	training's l1: 1.07837	training's rmse: 2.95606	valid_1's l1: 1.11205	valid_1's rmse: 3.08554
[80]	training's l1: 0.975647	training's rmse: 2.64039	valid_1's l1: 1.02589	valid_1's rmse: 2.83151
[120]	training's l1: 0.926137	training's rmse: 2.45303	valid_1's l1: 0.990443	valid_1's rmse: 2.6957
[160]	training's l1: 0.890204	training's rmse: 2.33053	valid_1's l1: 0.965535	valid_1's rmse: 2.60766
Did not meet early stopping. Best iteration is:
[180]	training's l1: 0.876156	training's rmse: 2.27435	valid_1's l1: 0.957669	valid_1's rmse: 2.57028


elapsed: 1.94163 min

 ******************* SHIFT 44 2020-09-10 01:04:55 ***********

df.shape: (2575502, 42)


--> X
Training until validation scores don't improve for 60 rounds
[40]	training's l1: 0.924798	training's rmse: 2.48672	valid_1's l1: 0.960232	valid_1's rmse: 2.62063
[80]	training's l1: 0.858858	training's rmse: 2.25912	valid_1's l1: 0.90857	valid_1's rmse: 2.45073
[120]	training's l1: 0.818206	training's rmse: 2.0883	valid_1's l1:

[80]	training's l1: 0.841421	training's rmse: 2.28047	valid_1's l1: 0.881937	valid_1's rmse: 2.44195
[120]	training's l1: 0.801813	training's rmse: 2.15477	valid_1's l1: 0.852689	valid_1's rmse: 2.34936
[160]	training's l1: 0.776653	training's rmse: 2.06096	valid_1's l1: 0.835878	valid_1's rmse: 2.28387
Did not meet early stopping. Best iteration is:
[160]	training's l1: 0.776653	training's rmse: 2.06096	valid_1's l1: 0.835878	valid_1's rmse: 2.28387


elapsed: 2.00380 min

 ******************* SHIFT 38 2020-09-10 01:16:48 ***********

df.shape: (3213335, 42)


--> X
Training until validation scores don't improve for 60 rounds
[40]	training's l1: 0.706112	training's rmse: 1.98216	valid_1's l1: 0.725745	valid_1's rmse: 2.06991
[80]	training's l1: 0.651132	training's rmse: 1.80335	valid_1's l1: 0.681357	valid_1's rmse: 1.93453
[120]	training's l1: 0.622435	training's rmse: 1.70492	valid_1's l1: 0.661221	valid_1's rmse: 1.86741
[160]	training's l1: 0.602493	training's rmse: 1.62815	valid_

Did not meet early stopping. Best iteration is:
[140]	training's l1: 0.571142	training's rmse: 1.52072	valid_1's l1: 0.605089	valid_1's rmse: 1.66715


--> Y
Training until validation scores don't improve for 60 rounds
[40]	training's l1: 0.741816	training's rmse: 2.07502	valid_1's l1: 0.761184	valid_1's rmse: 2.16063
[80]	training's l1: 0.687281	training's rmse: 1.88116	valid_1's l1: 0.714152	valid_1's rmse: 1.9969
[120]	training's l1: 0.657753	training's rmse: 1.77633	valid_1's l1: 0.692263	valid_1's rmse: 1.9188
Did not meet early stopping. Best iteration is:
[140]	training's l1: 0.647908	training's rmse: 1.74177	valid_1's l1: 0.684998	valid_1's rmse: 1.89303


elapsed: 1.99705 min

 ******************* SHIFT 31 2020-09-10 01:30:24 ***********

df.shape: (3524461, 42)


--> X
Training until validation scores don't improve for 60 rounds
[40]	training's l1: 0.608243	training's rmse: 1.69922	valid_1's l1: 0.618185	valid_1's rmse: 1.73927
[80]	training's l1: 0.570354	training's rmse: 1.

Did not meet early stopping. Best iteration is:
[100]	training's l1: 0.519616	training's rmse: 1.41431	valid_1's l1: 0.534958	valid_1's rmse: 1.48254


elapsed: 1.85467 min

 ******************* SHIFT 24 2020-09-10 01:43:18 ***********

df.shape: (4139582, 42)


--> X
Training until validation scores don't improve for 60 rounds
[40]	training's l1: 0.486322	training's rmse: 1.32617	valid_1's l1: 0.494493	valid_1's rmse: 1.37082
[80]	training's l1: 0.462866	training's rmse: 1.24498	valid_1's l1: 0.475256	valid_1's rmse: 1.30592
Did not meet early stopping. Best iteration is:
[100]	training's l1: 0.454841	training's rmse: 1.22395	valid_1's l1: 0.469007	valid_1's rmse: 1.29098


--> Y
Training until validation scores don't improve for 60 rounds
[40]	training's l1: 0.568633	training's rmse: 1.57282	valid_1's l1: 0.580051	valid_1's rmse: 1.6241
[80]	training's l1: 0.533855	training's rmse: 1.44473	valid_1's l1: 0.549677	valid_1's rmse: 1.51474
Did not meet early stopping. Best iteration is:


Training until validation scores don't improve for 60 rounds
[40]	training's l1: 0.387699	training's rmse: 1.05894	valid_1's l1: 0.390721	valid_1's rmse: 1.07406
[80]	training's l1: 0.36745	training's rmse: 0.989396	valid_1's l1: 0.372358	valid_1's rmse: 1.0129
Did not meet early stopping. Best iteration is:
[100]	training's l1: 0.361787	training's rmse: 0.969595	valid_1's l1: 0.367775	valid_1's rmse: 0.996887


elapsed: 2.10880 min

 ******************* SHIFT 15 2020-09-10 02:01:01 ***********

df.shape: (5115979, 42)


--> X
Training until validation scores don't improve for 60 rounds
[40]	training's l1: 0.32214	training's rmse: 0.853925	valid_1's l1: 0.326127	valid_1's rmse: 0.876814
[80]	training's l1: 0.312226	training's rmse: 0.810789	valid_1's l1: 0.317966	valid_1's rmse: 0.840881
Did not meet early stopping. Best iteration is:
[100]	training's l1: 0.307885	training's rmse: 0.792572	valid_1's l1: 0.314299	valid_1's rmse: 0.825564


--> Y
Training until validation scores don't im

df.shape: (6833344, 42)


--> X
Training until validation scores don't improve for 60 rounds
[40]	training's l1: 0.196028	training's rmse: 0.480523	valid_1's l1: 0.197984	valid_1's rmse: 0.488858
[80]	training's l1: 0.192027	training's rmse: 0.46142	valid_1's l1: 0.194649	valid_1's rmse: 0.473037
Did not meet early stopping. Best iteration is:
[100]	training's l1: 0.191004	training's rmse: 0.455379	valid_1's l1: 0.193918	valid_1's rmse: 0.468475


--> Y
Training until validation scores don't improve for 60 rounds
[40]	training's l1: 0.220376	training's rmse: 0.552069	valid_1's l1: 0.222662	valid_1's rmse: 0.561401
[80]	training's l1: 0.211409	training's rmse: 0.521152	valid_1's l1: 0.214354	valid_1's rmse: 0.533527
Did not meet early stopping. Best iteration is:
[100]	training's l1: 0.209761	training's rmse: 0.513909	valid_1's l1: 0.213026	valid_1's rmse: 0.527639


elapsed: 2.65637 min

 ******************* SHIFT 06 2020-09-10 02:22:18 ***********

df.shape: (7129324, 42)


--> X
Trai

# Loading the test set as CSV

In [18]:
# Here, I'm gonna load the test, it contains `71122` rows as expected
df = pd.read_csv("../input/lyft-test-set-as-csv/Lyft_test_set.csv")
print("df.shape:", df.shape)
df.head(10)

df.shape: (71122, 45)


Unnamed: 0,frame_index_interval_start,frame_index_interval_end,host,start_time,end_time,scene_db_id,timestamp,agent_index_interval_start,agent_index_interval_end,traffic_light_faces_index_interval_start,traffic_light_faces_index_interval_end,ego_translation_x,ego_translation_y,ego_translation_z,ego_rotation_xx,ego_rotation_xy,ego_rotation_xz,ego_rotation_yx,ego_rotation_yy,ego_rotation_yz,ego_rotation_zx,ego_rotation_zy,ego_rotation_zz,frame_db_id,frame_rank,nagents,nlights,centroid_x,centroid_y,extent_x,extent_y,extent_z,yaw,velocity_x,velocity_y,track_id,label_probabilities_PERCEPTION_LABEL_UNKNOWN,label_probabilities_PERCEPTION_LABEL_CAR,label_probabilities_PERCEPTION_LABEL_CYCLIST,label_probabilities_PERCEPTION_LABEL_PEDESTRIAN,agent_db_id,nframes,centroid_xs,centroid_ys,yaw_s
0,0,100,host-a118,1578605997692823296,1578606022692823296,0,1578606007801600134,8302,8385,48,54,536.820496,-2387.457275,288.328339,0.440643,-0.897361,0.024005,0.897383,0.441029,0.014044,-0.02319,0.015353,0.999613,99,99,83,6,540.265198,-2399.707275,4.566169,2.11553,1.517801,2.021556,-0.653267,1.878533,2,0.0,1.0,0.0,0.0,0,1,540.265198,-2399.707275,2.021556
1,100,200,host-a118,1578606022692823296,1578606047692823296,1,1578606032802467516,16433,16522,609,615,624.529236,-2269.654053,285.190308,0.549425,-0.834528,0.041163,0.83552,0.548374,-0.03454,0.006252,0.053369,0.998555,199,99,89,6,611.142456,-2279.240479,4.403694,1.703298,1.464434,-2.121913,-6.067945,-9.134871,4,0.0,1.0,0.0,0.0,1,1,611.142456,-2279.240479,-2.121913
2,100,200,host-a118,1578606022692823296,1578606047692823296,1,1578606032802467516,16433,16522,609,615,624.529236,-2269.654053,285.190308,0.549425,-0.834528,0.041163,0.83552,0.548374,-0.03454,0.006252,0.053369,0.998555,199,99,89,6,618.16571,-2279.276123,3.317446,2.120635,1.576786,0.99017,1.571839,2.408892,5,0.0,1.0,0.0,0.0,2,1,618.16571,-2279.276123,0.99017
3,100,200,host-a118,1578606022692823296,1578606047692823296,1,1578606032802467516,16433,16522,609,615,624.529236,-2269.654053,285.190308,0.549425,-0.834528,0.041163,0.83552,0.548374,-0.03454,0.006252,0.053369,0.998555,199,99,89,6,640.655945,-2233.32373,3.985592,2.013386,1.44821,-2.153157,-3.022248,-4.823001,81,0.0,1.0,0.0,0.0,3,1,640.655945,-2233.32373,-2.153157
4,100,200,host-a118,1578606022692823296,1578606047692823296,1,1578606032802467516,16433,16522,609,615,624.529236,-2269.654053,285.190308,0.549425,-0.834528,0.041163,0.83552,0.548374,-0.03454,0.006252,0.053369,0.998555,199,99,89,6,597.227966,-2236.942627,4.591854,2.620991,2.097353,-0.571951,-0.19607,-0.110046,130,0.0,1.0,0.0,0.0,4,1,597.227966,-2236.942627,-0.571951
5,200,300,host-a118,1578606047692823296,1578606072692823296,2,1578606057802432986,24831,24926,615,615,780.240417,-2032.252686,293.68866,0.536894,-0.839494,0.083634,0.840752,0.540624,0.029371,-0.069871,0.054546,0.996064,299,99,95,0,765.511963,-2054.821533,2.404521,2.111213,1.467414,1.007352,7.138762,10.837037,1,0.0,1.0,0.0,0.0,5,1,765.511963,-2054.821533,1.007352
6,300,400,host-a118,1578606072692823296,1578606097692823296,3,1578606082801997166,31374,31419,615,615,963.404419,-1752.821533,269.857697,0.538929,-0.840125,0.061198,0.841174,0.540593,0.01361,-0.044517,0.044143,0.998033,399,99,45,0,950.780701,-1772.435181,2.932318,2.122027,1.70645,1.000485,7.356675,11.042892,1,0.0,1.0,0.0,0.0,6,1,950.780701,-1772.435181,1.000485
7,300,400,host-a118,1578606072692823296,1578606097692823296,3,1578606082801997166,31374,31419,615,615,963.404419,-1752.821533,269.857697,0.538929,-0.840125,0.061198,0.841174,0.540593,0.01361,-0.044517,0.044143,0.998033,399,99,45,0,973.369202,-1748.802002,5.550737,2.02397,2.039209,2.27994,-0.055759,-0.040675,707,0.0,1.0,0.0,0.0,7,1,973.369202,-1748.802002,2.27994
8,400,500,host-a118,1578606097692823296,1578606122692823296,4,1578606107802708166,38282,38341,615,615,938.855713,-1526.349243,267.135193,-0.849012,-0.52823,0.012313,0.528097,-0.847585,0.052084,-0.017076,0.050723,0.998567,499,99,59,0,950.528564,-1533.730347,3.302427,1.962232,1.541971,2.580472,-8.717098,5.503423,1,0.0,1.0,0.0,0.0,8,1,950.528564,-1533.730347,2.580472
9,500,600,host-a118,1578606122692823296,1578606147692823296,5,1578606132802907546,46257,46301,1506,1515,729.780579,-1393.805176,267.060516,-0.848714,-0.528826,0.005275,0.528796,-0.848435,0.023091,-0.007736,0.022387,0.999719,599,99,44,9,751.761108,-1407.906738,2.567837,2.094555,1.429697,2.581488,-9.093983,5.740898,1,0.0,1.0,0.0,0.0,9,1,751.761108,-1407.906738,2.581488


# Loading the LGBM models

In [19]:
def get_model_name(filename):
    return re.search("^(lgbm_[x,y]_shift_\d+)", filename).group(1)

In [20]:
def get_models(path):
    models = {}
    path = Path(path)
    for model in path.glob("lgbm*"):
        model_name = get_model_name(model.stem)
        shift = int(model_name.split("shift_")[1])
        meta = path.joinpath("meta_shift_{:02d}.json".format(shift))
        with meta.open() as f:
            train_cols = json.load(f)["TRAIN_COLS"]
        models[model_name] = {"model": model.as_posix(), "train_cols": train_cols}
    return models

In [21]:
models = get_models("./models/lyft_lgbm_model")
len(models)

100

In [22]:
next(iter(models.items()))

('lgbm_y_shift_16',
 {'model': 'models/lyft_lgbm_model/lgbm_y_shift_16.bin',
  'train_cols': ['ego_translation_x',
   'ego_translation_y',
   'ego_translation_z',
   'ego_rotation_xx',
   'ego_rotation_xy',
   'ego_rotation_xz',
   'ego_rotation_yx',
   'ego_rotation_yy',
   'ego_rotation_yz',
   'ego_rotation_zx',
   'ego_rotation_zy',
   'ego_rotation_zz',
   'extent_x',
   'extent_y',
   'extent_z',
   'velocity_x',
   'velocity_y',
   'label_probabilities_PERCEPTION_LABEL_UNKNOWN',
   'label_probabilities_PERCEPTION_LABEL_CAR',
   'label_probabilities_PERCEPTION_LABEL_CYCLIST',
   'label_probabilities_PERCEPTION_LABEL_PEDESTRIAN',
   'yaw',
   'nagents',
   'nlights',
   'centroid_xs',
   'centroid_ys']})

## Make prediction for the test set

In [23]:
def make_colnames():
    xcols = ["coord_x{}{}".format(step, rank) for step in range(3) for rank in range(50)]
    ycols = ["coord_y{}{}".format(step, rank) for step in range(3) for rank in range(50)]
    cols = ["timestamp", "track_id"] + ["conf_0", "conf_1", "conf_2"] + list(it.chain(*zip(xcols, ycols)))
    return cols

In [24]:
def predict(models, df):
    sub = np.empty((len(df), 305))
    sub.fill(np.nan)
    sub = pd.DataFrame(sub, columns = make_colnames())
    sub[["timestamp", "track_id"]] = df[["timestamp", "track_id"]]
    sub["conf_0"] = 1.0
    
    for shift in range(1, 51):
        for suffix in ["x", "y"]:
            model_info = models["lgbm_{}_shift_{:02d}".format(suffix, shift)]
                
            model = lgb.Booster(model_file= model_info["model"])
            pred = model.predict(df[model_info["train_cols"]])
            
            sub["coord_{}0{}".format(suffix, shift-1)] = pred

        if not shift%10:
            print("shift: {}".format(shift))
    
    sub.fillna(0., inplace=True)
    
    return sub

In [25]:
sub = predict(models, df)

shift: 10
shift: 20
shift: 30
shift: 40
shift: 50


In [26]:
sub.iloc[:50, :105]

Unnamed: 0,timestamp,track_id,conf_0,conf_1,conf_2,coord_x00,coord_y00,coord_x01,coord_y01,coord_x02,coord_y02,coord_x03,coord_y03,coord_x04,coord_y04,coord_x05,coord_y05,coord_x06,coord_y06,coord_x07,coord_y07,coord_x08,coord_y08,coord_x09,coord_y09,coord_x010,coord_y010,coord_x011,coord_y011,coord_x012,coord_y012,coord_x013,coord_y013,coord_x014,coord_y014,coord_x015,coord_y015,coord_x016,coord_y016,coord_x017,coord_y017,coord_x018,coord_y018,coord_x019,coord_y019,coord_x020,coord_y020,coord_x021,coord_y021,coord_x022,coord_y022,coord_x023,coord_y023,coord_x024,coord_y024,coord_x025,coord_y025,coord_x026,coord_y026,coord_x027,coord_y027,coord_x028,coord_y028,coord_x029,coord_y029,coord_x030,coord_y030,coord_x031,coord_y031,coord_x032,coord_y032,coord_x033,coord_y033,coord_x034,coord_y034,coord_x035,coord_y035,coord_x036,coord_y036,coord_x037,coord_y037,coord_x038,coord_y038,coord_x039,coord_y039,coord_x040,coord_y040,coord_x041,coord_y041,coord_x042,coord_y042,coord_x043,coord_y043,coord_x044,coord_y044,coord_x045,coord_y045,coord_x046,coord_y046,coord_x047,coord_y047,coord_x048,coord_y048,coord_x049,coord_y049
0,1578606007801600134,2,1.0,0.0,0.0,-0.048895,0.179803,-0.04567,0.377115,-0.144884,0.515959,-0.065796,0.748473,-0.19598,0.739119,-0.373246,0.928482,-0.306198,1.38315,-0.286839,2.305614,-0.362774,1.866656,-0.220183,1.432894,-0.359163,1.811864,-0.554155,2.355518,-0.550837,2.28819,-0.328491,2.73084,0.354419,2.822388,-0.646804,3.234439,-1.738817,4.780672,-1.394794,3.591307,-1.008364,5.347291,-0.989527,2.209443,0.187658,3.373502,6.102296,2.852093,-1.764036,2.984548,3.058961,3.4637,-0.827757,3.94188,-2.393538,2.086826,1.850307,2.830377,2.408549,6.25111,-1.260731,7.029381,-0.491816,8.195442,-0.797193,7.45278,1.817005,7.988579,-0.648016,5.52593,3.179374,5.762448,1.020652,10.172624,1.060045,7.054455,9.059905,4.724983,12.980371,6.848873,1.979442,7.644498,6.351219,7.972337,-0.329545,2.410224,4.462005,11.283508,2.842552,5.895009,7.557912,9.672128,3.381531,19.037956,2.893692,8.829442,13.438974,6.199169,8.49032,-0.703562,2.064135,16.18582,-0.977369,8.780082
1,1578606032802467516,4,1.0,0.0,0.0,-0.559025,-0.857375,-1.159505,-2.110268,-1.700762,-2.918715,-2.196199,-4.092315,-2.825352,-4.597891,-3.576659,-7.405145,-4.249512,-5.925467,-4.49686,-6.503882,-4.77579,-9.280376,-6.081302,-10.384481,-6.236717,-9.952437,-6.560545,-9.297873,-7.608847,-14.348619,-7.862379,-13.522077,-6.714513,-14.877366,-7.175815,-14.171018,-7.123585,-22.87233,-7.455385,-20.299564,-10.350966,-17.17858,-10.720386,-23.931099,-11.133752,-18.604854,-10.794114,-20.655236,-12.14693,-24.623833,-10.824507,-25.649701,-12.227751,-21.167439,-12.746836,-22.95176,-14.814852,-19.821964,-9.055152,-19.225835,-15.949231,-26.951203,-17.241254,-34.361443,-15.366002,-17.342,-16.607844,-28.909818,-18.099169,-36.152203,-13.420998,-38.131996,-21.026717,-47.43775,-12.211731,-27.532714,-13.827487,-26.175475,-16.998927,-26.518629,-9.436468,-38.918101,-17.043013,-44.908545,-15.52663,-28.522702,-9.374104,-45.352994,-17.359855,-20.944808,-18.081517,-41.589194,-6.641307,-35.227119,-17.728567,-37.454451,2.645494,-54.502114,-27.49616,-36.88193,-12.3264,-30.117283,-11.458929,-23.084154
2,1578606032802467516,5,1.0,0.0,0.0,0.122341,0.216781,0.20818,0.404829,0.403633,0.783736,0.72406,0.669929,0.551755,1.060938,0.920436,1.576718,0.981454,1.48815,0.969786,1.846385,1.258034,1.020465,1.211023,2.171476,1.789011,2.230304,1.397672,3.129937,1.770164,2.059197,3.048321,2.022437,1.83162,2.084304,3.194683,2.300816,2.940655,4.131199,3.037212,5.028245,3.497013,2.388504,2.620856,2.559893,4.686587,6.345338,4.504562,1.86502,3.928072,5.593215,3.774285,5.80106,2.918834,3.107344,3.182279,5.506003,2.315717,5.620843,3.279808,6.429001,2.352812,5.809637,3.532318,4.298925,5.279031,5.793444,5.295102,7.068222,4.432748,3.057317,8.558746,3.404915,5.077504,5.35855,5.244388,3.197626,10.228386,5.234457,3.955007,7.012596,2.123226,5.73259,3.697463,16.097822,0.186766,8.198029,3.949363,6.457021,5.222399,16.767973,0.451438,13.09751,4.132546,5.271446,8.288266,-0.950029,3.350661,5.69348,2.303098,5.62166,12.484113,36.139691,4.989013,-2.726769
3,1578606032802467516,81,1.0,0.0,0.0,-0.29128,-0.570898,-0.623835,-0.977753,-0.821743,-1.28539,-1.076496,-2.038424,-1.449349,-1.851331,-1.314473,-3.832267,-2.068941,-3.438714,-2.151094,-4.165846,-2.322308,-3.467725,-2.593183,-4.185826,-2.501591,-3.850775,-2.534167,-3.148353,-3.764109,-5.70715,-2.63193,-6.394526,-3.30659,-4.471523,-3.806168,-5.679148,-4.115186,-7.513892,-2.205181,-5.469051,-2.394459,-7.532274,-3.065418,-7.443585,-4.544061,-9.588273,-4.674595,-6.067659,-4.85891,-10.0432,-1.662201,-10.418001,-1.466,-6.176119,-5.051021,-11.874823,-8.227838,-10.852094,-0.501815,-3.275828,0.446262,-7.613092,-7.027772,-5.908853,-7.810498,-7.516358,-3.682968,-15.531998,-6.639178,-15.363721,-2.987972,-8.438939,-3.53752,-13.625877,-5.71675,-11.557963,-1.521655,-3.75394,3.3552,-3.591871,-5.349416,-13.181064,1.627762,-0.773298,-5.246123,-10.845508,-1.698526,-23.805406,-1.297945,-0.493956,0.226942,-9.368753,-2.980547,-14.702935,5.230012,-23.522101,-5.613129,-9.065028,-7.695289,-12.539351,0.594069,2.72312,13.167135,6.305532
4,1578606032802467516,130,1.0,0.0,0.0,0.005855,0.006238,0.013019,-0.000842,0.027652,0.02091,0.026302,0.046529,0.013168,-0.016068,0.094944,0.027748,0.025978,-0.196384,0.00483,0.047663,0.377515,0.086367,0.14257,0.087209,0.128125,-0.001243,-0.037274,-0.019255,1.890705,-0.238282,0.217228,-0.625712,0.245979,-0.008668,-0.151698,-0.251578,0.065712,1.664764,0.320609,0.414878,0.234629,0.108891,3.131761,-0.109296,1.059539,0.102701,-0.165428,-1.313488,-0.277174,-0.183782,-0.156079,0.384843,0.480236,3.21799,0.111543,-6.859628,-0.960128,-1.186441,0.09389,-1.551257,0.541307,0.568208,0.512416,-2.100219,-0.181047,0.382011,-2.789418,1.242455,-3.274055,-0.722798,-0.097156,1.181872,-1.745813,-4.569001,3.124139,2.443337,9.400385,24.423621,-1.413519,1.63306,-0.314316,4.706463,0.025402,0.95466,-0.245963,-0.515337,2.481488,-1.850436,-0.470462,0.675025,-1.44003,-2.333931,-0.667884,-0.239008,3.544885,2.031381,2.77747,-0.941926,1.887734,1.098844,4.547872,7.882355,0.616521,-1.159526
5,1578606057802432986,1,1.0,0.0,0.0,0.685172,1.08702,1.451207,2.136413,2.066756,3.17066,2.850837,4.179632,3.480826,5.343474,4.338201,7.012667,5.032257,8.071541,5.626432,8.305968,6.300646,9.08349,6.823949,11.117105,8.017012,11.900879,8.478852,13.253293,10.244643,15.142612,9.450481,12.569777,10.449458,16.222447,11.843123,18.18587,12.157187,18.735992,12.789302,19.67828,13.560249,18.860032,14.560822,22.05875,15.163713,22.466379,17.074557,24.505877,16.266387,23.762301,17.254392,27.706612,17.596574,27.408821,18.528798,29.677111,18.430761,27.659706,20.704213,27.946345,19.802769,31.829615,21.421697,33.129776,23.067866,34.557657,24.433246,38.561309,24.08545,38.316322,26.256986,36.266147,25.646705,38.45713,26.804205,42.683258,26.987257,40.722269,25.960909,42.744812,28.332236,43.995546,30.68871,44.010864,28.952267,45.118631,29.501521,47.296125,31.350948,51.202853,31.28337,48.143758,32.498159,48.37226,33.959443,50.191503,36.097492,52.736075,35.778007,49.43804,38.058017,56.384211,36.654179,58.394013
6,1578606082801997166,1,1.0,0.0,0.0,0.76478,1.090124,1.451207,2.472431,2.227972,3.126823,2.908345,4.350925,3.802306,5.685184,4.304634,6.109752,5.180475,7.804081,5.88933,9.013562,6.536467,9.571542,6.971075,10.857978,7.967629,12.713835,8.961798,13.754778,9.381886,16.017237,10.197285,15.286583,10.910009,18.427836,11.539618,17.330828,12.384454,18.553633,13.081989,24.458749,13.566873,19.440531,14.567851,22.701278,14.165462,22.413628,16.56541,24.313901,16.074244,22.21137,17.562061,27.556688,18.120245,27.729466,18.920059,28.83633,19.762089,29.099082,21.63312,30.510003,24.472182,32.532259,22.516102,33.757265,22.384278,34.620017,23.830706,37.103776,26.092194,35.079423,25.57171,39.042148,26.053182,36.846143,28.060315,39.369048,28.49598,31.056588,24.335421,41.324971,28.469773,45.271915,28.905234,42.958826,32.328393,44.864559,30.874786,42.505445,30.265656,49.204556,30.877892,46.134928,31.631272,48.744176,37.386532,55.158708,36.158416,46.574825,35.31161,48.937572,34.354631,50.751105,36.294581,57.01709
7,1578606082801997166,707,1.0,0.0,0.0,-0.001034,0.056029,-0.001485,0.614208,0.004965,0.130047,-0.020829,0.579237,-0.001103,0.20272,0.024493,0.155795,-0.905594,0.203741,-0.201275,0.812225,0.026989,-0.483712,0.075984,0.539007,-0.294516,0.466386,0.09901,1.431776,0.195963,2.337538,0.074655,0.241772,-0.692015,-1.021136,3.431986,0.566798,-0.063579,0.20712,-0.2832,1.751893,0.592066,-0.22819,1.041103,0.52275,0.444472,0.433934,1.176579,0.767137,5.306924,-0.03274,0.599298,1.831277,-0.727705,1.947498,0.124232,4.356371,-2.155862,2.551245,3.168683,3.760351,5.98311,1.527561,1.418219,5.280186,-4.751219,5.492207,2.320739,3.756216,-0.003145,0.604196,-1.35406,3.22041,1.252581,5.189746,1.138575,0.804833,-0.991817,9.296036,-0.590445,2.256178,1.237062,4.054149,3.228264,19.890878,-0.833867,1.315201,3.668942,7.622151,1.536567,1.068037,-1.599017,20.936801,3.098573,5.570529,8.035646,19.112801,-3.309123,12.7651,2.977716,4.812453,-15.972722,11.836547,1.455936,0.984835
8,1578606107802708166,1,1.0,0.0,0.0,-0.853549,0.675151,-1.790587,1.432095,-2.599071,1.81336,-3.489244,2.434728,-4.256335,2.72501,-4.993893,3.520512,-6.188589,4.727529,-6.876923,5.768535,-8.105277,4.525105,-7.831268,4.755929,-9.677116,6.864786,-9.888159,6.59588,-11.722921,9.12609,-12.093123,8.109848,-12.649303,6.271169,-15.257378,11.862856,-15.285931,8.961327,-16.541152,9.533418,-15.411371,9.295754,-17.152314,11.792563,-21.593095,13.815706,-18.172736,12.440085,-22.830175,14.549207,-22.163327,15.248718,-22.766415,13.122268,-20.055905,18.834407,-23.872769,15.485266,-24.477742,22.363357,-24.568554,18.669571,-27.397857,15.38819,-30.112174,16.831095,-23.442412,19.808789,-28.844826,18.388544,-32.78608,18.526314,-31.027713,19.025498,-31.454484,19.424061,-37.527439,23.020668,-38.564418,16.190201,-40.236289,16.85475,-38.153719,19.930415,-34.80553,23.346115,-34.692321,22.081154,-43.519452,25.93052,-38.163908,19.079319,-36.268322,25.321607,-47.31953,18.621353,-43.665544,22.153034,-42.392502,29.107474,-40.673129,27.056603,-51.987565,25.759943
9,1578606132802907546,1,1.0,0.0,0.0,-0.899686,0.62449,-1.867622,1.637785,-2.697706,1.656493,-3.489244,1.772372,-4.256335,2.879409,-5.542234,3.697323,-6.083922,4.538214,-7.001304,4.951432,-8.383515,4.710537,-8.685777,5.97431,-9.652912,5.339196,-10.372789,6.039274,-10.930331,4.602924,-12.35168,7.431601,-13.227059,8.339168,-14.029262,8.463667,-15.500352,9.43051,-16.091877,9.929401,-16.324591,9.154336,-16.815621,11.209941,-17.161797,9.964109,-18.462997,12.392828,-17.931497,12.06151,-23.004169,9.664185,-22.362102,12.622881,-22.043252,13.589195,-23.444147,14.457031,-23.30388,15.123038,-27.214568,18.054175,-26.207277,15.767657,-27.232948,17.337881,-27.198229,17.470994,-28.828123,16.479928,-28.67156,14.736647,-31.87853,17.661693,-29.554606,19.326357,-36.605668,19.825692,-33.581335,20.25911,-29.897321,17.662457,-29.906734,24.562749,-33.696336,25.41728,-31.089343,25.246164,-29.094098,15.747251,-35.630292,23.433195,-37.659245,21.750105,-38.284651,24.109161,-36.078008,25.332685,-37.502778,24.400676,-38.633091,20.085384,-35.945891,22.139162


In [27]:
sub.to_csv("submission.csv", index=False)