### Predict next track using RNN from Steve

In [1]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
from trackml.dataset import load_event
from trackml.randomize import shuffle_hits
from trackml.score import score_event

In [3]:
import os
import numpy as np
import pandas as pd
import glob
import math


import time
from utils import timeSince
from tqdm import tqdm

In [70]:
class HitGausPredictor(nn.Module):
    """
    A PyTorch module for particle track state estimation and hit prediction.

    This module is an RNN which takes a sequence of hits and produces a
    Gaussian shaped prediction for the location of the next hit.
    """

    def __init__(self, hidden_dim=5, batch_size=64, device=None):
        super(HitGausPredictor, self).__init__()
        input_dim = 3
        output_dim = 2
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        out_size = int(output_dim * (output_dim + 3) / 2)
        self.fc = nn.Linear(hidden_dim, out_size)
        self.device = device
        self.batch_size = batch_size
        

    def forward(self, x):
        """Might want to accept also the radius of the target layer."""
        input_size = x.size()

        # Initialize the LSTM hidden state
        h = (torch.zeros(self.lstm.num_layers, self.batch_size, self.lstm.hidden_size, device=self.device),
             torch.zeros(self.lstm.num_layers, self.batch_size, self.lstm.hidden_size, device=self.device))
        # Apply the LSTM module
        x, h = self.lstm(x, h)
        # Squash layer axis into batch axis
        x = x.contiguous().view(-1, x.size(-1))
#         x = x.contiguous().view(self.batch_size, -1, x.size(-1))
        # Apply linear layer
        output = self.fc(x)

        # Extract and transform the gaussian parameters
        means = output[:, :2]
        variances = output[:, 2:4].exp()
        correlations = output[:, 4].tanh()

        # Construct the covariance matrix
        covs = torch.bmm(variances[:, :, None], variances[:, None, :]).sqrt()
#         covs = torch.bmm(variances, variances.view(self.batch_size, 2, -1)).sqrt()
#         print(covs.size())
#         print(correlations.size())
        covs[:, 0, 1] = covs[:, 0, 1].clone() * correlations
        covs[:, 1, 0] = covs[:, 1, 0].clone() * correlations

        # Expand the layer axis again, just for consistency/interpretability
        means = means.contiguous().view(self.batch_size, -1, 2)
        covs = covs.contiguous().view(self.batch_size, -1, 2, 2)
        return means, covs

In [76]:
def gaus_llh_loss(outputs, targets):
    """Custom gaussian log-likelihood loss function"""
    means, covs = outputs
    means = means[:, 1:, :]
    covs = covs[:, 1:, :, :]
    print(means.size())
    print(covs.size())
    # Flatten layer axis into batch axis to use batch matrix operations
    means = means.contiguous().view(means.size(0)*means.size(1), means.size(2))
    covs = covs.contiguous().view(covs.size(0)*covs.size(1),
                                  covs.size(2), covs.size(3))
    targets = targets.contiguous().view(targets.size(0)*targets.size(1),
                                        targets.size(2))
    # Calculate the inverses of the covariance matrices
    inv_covs = torch.stack([cov.inverse() for cov in covs])
    # Calculate the residual error
    res = targets - means
    # Calculate the residual error term
    res_right = torch.bmm(inv_covs, res.unsqueeze(-1)).squeeze(-1)
    res_term = torch.bmm(res[:,None,:], res_right[:,:,None]).squeeze()
    # For the determinant term, we first have to compute the cholesky roots.
    # Testing out new differentiable functionality in pytorch 0.3
    diag_chols = torch.stack([torch.potrf(cov).diag() for cov in covs])
    #diag_chols = torch.stack([Cholesky.apply(cov).diag() for cov in covs])
    log_det = diag_chols.log().sum(1) * 2
    gllh_loss = (res_term + log_det).sum()
    return gllh_loss

In [6]:
event_prefix = 'event000001000'
hits, cells, particles, truth = load_event(os.path.join('input/train_1', event_prefix))

In [7]:
pIDs = particles[particles['nhits'] == 10]['particle_id']

In [8]:
len(pIDs)

1037

In [9]:
hits_truth = pd.merge(hits, truth, on=['hit_id'])

In [10]:
from utils import get_features
hits_truth = get_features(hits_truth)

In [10]:
hits_truth = hits_truth[ (hits_truth['eta'] < 1) & (hits_truth['eta'] > -1) ]

In [11]:
hits_truth.shape

(36655, 24)

In [15]:
ten_hits = pd.merge(truth, particles[particles['nhits'] == 10], on=['particle_id'])

In [16]:
hits_truth = pd.merge(hits, truth, on=['hit_id'])

In [20]:
hits_truth = get_features(hits_truth)

In [16]:
particles[particles['particle_id'] == 4504699138998272]

Unnamed: 0,particle_id,vx,vy,vz,px,py,pz,q,nhits
12,4504699138998272,-0.009288,0.009861,-0.077879,0.079537,0.451322,-5.25235,1,10


In [20]:
hits_truth[hits_truth['particle_id'] == 4504699138998272].values.shape[0]

0

In [17]:
pIDs = np.unique(hits_truth['particle_id'])

In [11]:
track_list = []
for pID in pIDs:
    if pID == 0:
        continue
    this_track = hits_truth[hits_truth['particle_id'] == pID][['r', 'phi', 'z']].values
    track_list.append(this_track)


In [50]:
np.array(track_list[0:64])[:, 1:, :2].shape

(64, 9, 2)

In [25]:
input_t = torch.from_numpy(np.array(track_list[0:64]))

In [51]:
target_t = torch.from_numpy(np.array(track_list[0:64])[:, 1:, :2])

In [26]:
input_t.size()

torch.Size([64, 10, 3])

In [80]:
target_t.size()

torch.Size([64, 9, 2])

In [32]:
loss_func = gaus_llh_loss

In [72]:
model = HitGausPredictor()
output = model(input_t)
print(output[0].size())
print(output[1].size())

torch.Size([64, 10, 2])
torch.Size([64, 10, 2, 2])


In [78]:
gaus_llh_loss(output, target_t)

torch.Size([64, 9, 2])
torch.Size([64, 9, 2, 2])


tensor(4.9554e+08)

In [None]:
particles[particles['nhits'] == 10]