# Read data

In [1]:
import os
import typing
from tqdm import tqdm
from collections import Counter
import matplotlib.pyplot as plt
from copy import deepcopy
from scipy.interpolate import interp1d
import numpy as np
import json
import pandas as pd

In [2]:
def read_testcase_ids(dataset_path: str):
    ids = sorted([int(case_id) for case_id in os.listdir(dataset_path) if os.path.isdir(os.path.join(dataset_path, case_id))])
    return ids

def read_metadata(metadata_path: str):
    with open(metadata_path, 'r') as f:
        data = json.load(f)
    data['tires_front'] = data['tires']['front']
    data['tires_rear'] = data['tires']['rear']
    data.pop('tires')
    return data

def create_meta_mapping(metas):
    mapping = {
    'vehicle_id': [],
    'vehicle_model': [],
    'vehicle_model_modification': [],
    'location_reference_point_id': [],
    'tires_front': [],
    'tires_rear': [],
    }
    
    for meta in metas:
        for k in mapping:
            mapping[k].append(meta[k])
    
    for k, lst in mapping.items():
        unk = Counter(lst).most_common()[0][0]
    
        v = sorted(set(lst))
        map = {str(label): idx for idx, label in enumerate(v)}
        mapping[k] = {'map': map, 'unk': str(unk)}
    return mapping

In [3]:
ROOT_DATA_FOLDER = "./dataset"

TRAIN_DATASET_PATH = os.path.join(ROOT_DATA_FOLDER, "YaCupTrain")
TEST_DATASET_PATH = os.path.join(ROOT_DATA_FOLDER, "YaCupTest")

In [4]:
train_ids = read_testcase_ids(TRAIN_DATASET_PATH)
all_train_ids = deepcopy(train_ids)
len(train_ids)

42000

In [5]:
test_ids = read_testcase_ids(TEST_DATASET_PATH)
len(test_ids)

8000

In [6]:
def read_sample(dataset_path, sample_id, is_test=False):
    sample_id = str(sample_id)
    sample = {}
    sample['localization'] = pd.read_csv(os.path.join(dataset_path, sample_id, 'localization.csv'))
    sample['control'] = pd.read_csv(os.path.join(dataset_path, sample_id, 'control.csv'))
    sample['metadata'] = reaad_metadata(os.path.join(dataset_path, sample_id, 'metadata.json'))
    if is_test:
        sample['requested_stamps'] = pd.read_csv(os.path.join(dataset_path, sample_id, 'requested_stamps.csv'))
    return sample

In [9]:
testcase_ids = train_ids

In [10]:
# metas = []
# for testcase_id in train_ids:
#     meta = read_metadata(os.path.join(TRAIN_DATASET_PATH, str(testcase_id), 'metadata.json'))
#     metas.append(meta)
# mapping = create_meta_mapping(metas)
# with open('./dataset/mapping.json', 'w') as f:
#     json.dump(mapping, f)
mapping = json.load(open('./dataset/mapping.json', 'r'))
for k, v in mapping.items():
    mapping[k]['map'] = {int(k):int(v) for k,v in v['map'].items()}
    mapping[k]['unk'] = int(v['unk'])
    assert v['unk'] in mapping[k]['map']
mapping['vehicle_model']['acc_denum_map'] = {
    0: 3000,
    1: 300,
}

In [11]:
import torch
from torch.utils.data import Dataset
from copy import deepcopy

In [12]:
NS_TO_SEC = 1e-9
SEC_TO_NS = 1e9

In [13]:
import itertools
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

## LstmSimpleAttention

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LstmEncoderDecoderWithAttention(nn.Module):
    def __init__(self, vehicle_feature_sizes, embedding_dim, localization_input_size, control_input_size, hidden_size, num_layers):
        super(LstmEncoderDecoderWithAttention, self).__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = False 

        # Vehicle feature embeddings
        self.vehicle_id_embedding = nn.Embedding(num_embeddings=vehicle_feature_sizes['vehicle_id'], embedding_dim=embedding_dim)
        self.vehicle_model_embedding = nn.Embedding(num_embeddings=vehicle_feature_sizes['vehicle_model'], embedding_dim=embedding_dim)
        self.vehicle_model_modification_embedding = nn.Embedding(num_embeddings=vehicle_feature_sizes['vehicle_model_modification'], embedding_dim=embedding_dim)
        self.location_reference_point_id_embedding = nn.Embedding(num_embeddings=vehicle_feature_sizes['location_reference_point_id'], embedding_dim=embedding_dim)
        self.tires_front_embedding = nn.Embedding(num_embeddings=vehicle_feature_sizes['tires_front'], embedding_dim=embedding_dim)
        self.tires_rear_embedding = nn.Embedding(num_embeddings=vehicle_feature_sizes['tires_rear'], embedding_dim=embedding_dim)

        # Fully connected layer to combine vehicle features
        self.vehicle_fc = nn.Linear(embedding_dim * 6, hidden_size)

        # Attention layers
        self.attention = nn.Linear(hidden_size * 2, hidden_size) 
        self.attention_combine = nn.Linear(hidden_size * 3, hidden_size)

        # Encoder LSTM for input_localization_seq (однонаправленный)
        self.localization_encoder = nn.LSTM(
            input_size=localization_input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=self.bidirectional
        )

        # Encoder LSTM for input_control_seq (однонаправленный)
        self.control_encoder = nn.LSTM(
            input_size=control_input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=self.bidirectional
        )

        # Decoder LSTM (одинарный)
        self.decoder = nn.LSTM(
            input_size=control_input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True
        )

        # Output layer
        self.fc_out = nn.Linear(hidden_size, localization_input_size)

    def forward(self, vehicle_features, input_localization, input_control_sequence, output_control_sequence):
        batch_size = vehicle_features.size(0)
        # print(f"Batch size: {batch_size}")

        # Embed vehicle features
        vehicle_id = self.vehicle_id_embedding(vehicle_features[:, 0])
        vehicle_model = self.vehicle_model_embedding(vehicle_features[:, 1])
        vehicle_model_modification = self.vehicle_model_modification_embedding(vehicle_features[:, 2])
        location_reference_point_id = self.location_reference_point_id_embedding(vehicle_features[:, 3])
        tires_front = self.tires_front_embedding(vehicle_features[:, 4])
        tires_rear = self.tires_rear_embedding(vehicle_features[:, 5])

        vehicle_embedded = torch.cat([
            vehicle_id,
            vehicle_model,
            vehicle_model_modification,
            location_reference_point_id,
            tires_front,
            tires_rear
        ], dim=1)  # Shape: [batch_size, embedding_dim * 6]

        vehicle_features_encoded = self.vehicle_fc(vehicle_embedded)  # Shape: [batch_size, hidden_size]

        localization_output, (hidden_loc, cell_loc) = self.localization_encoder(input_localization)
        control_output, (hidden_ctrl, cell_ctrl) = self.control_encoder(input_control_sequence)
        hidden_enc = torch.cat((hidden_loc[:self.num_layers], hidden_ctrl[:self.num_layers]), dim=2)

        # Combine cell states by averaging
        cell_enc = (cell_ctrl[:self.num_layers] + cell_loc[:self.num_layers]) / 2

        # Attention mechanism
        attention_weights = F.softmax(self.attention(hidden_enc), dim=2)  # [num_layers, batch_size, hidden_size]

        combined_hidden = torch.cat((hidden_enc, attention_weights), dim=2)
        combined_hidden = self.attention_combine(combined_hidden)
        
        # Incorporate vehicle_features_encoded into hidden state (first layer)
        combined_hidden = combined_hidden.clone()
        combined_hidden[0] = combined_hidden[0] + vehicle_features_encoded.unsqueeze(0)

        # Decoder
        decoder_output, (hidden_dec, cell_dec) = self.decoder(output_control_sequence, (combined_hidden, cell_enc))

        # Output layer
        output_localization = self.fc_out(decoder_output)  # [batch_size, seq_len, localization_input_size]
        return output_localization


## Metric

In [15]:
import numpy as np

SEGMENT_LENGTH = 1.0

def calculate_metric_on_batch(output_np, target_np, segment_length=1.0):
    """
    output_np: numpy array of shape [batch_size, seq_len, 4], predicted x, y, yaw
    target_np: numpy array of same shape, ground truth x, y, yaw

    Returns:
        metric: float, the average metric over the batch
    """
    x_pred, y_pred, yaw_pred = output_np[..., 0], output_np[..., 1], output_np[..., 2]
    x_gt, y_gt, yaw_gt = target_np[..., 0], target_np[..., 1], target_np[..., 2]

    # Compute c1 and c2 for predicted
    c1_pred = np.stack([x_pred, y_pred], axis=-1)
    c2_pred = c1_pred + segment_length * np.stack([np.cos(yaw_pred), np.sin(yaw_pred)], axis=-1)

    # Compute c1 and c2 for ground truth
    c1_gt = np.stack([x_gt, y_gt], axis=-1)
    c2_gt = c1_gt + segment_length * np.stack([np.cos(yaw_gt), np.sin(yaw_gt)], axis=-1)

    # Compute distances between corresponding points
    dist_c1 = np.linalg.norm(c1_pred - c1_gt, axis=-1)
    dist_c2 = np.linalg.norm(c2_pred - c2_gt, axis=-1)

    # Compute pose metric
    pose_metric = np.sqrt((dist_c1 ** 2 + dist_c2 ** 2) / 2.0)
    metric = np.mean(pose_metric)

    return metric

## Init model

In [16]:
import gc
try:
    del model
except:
    pass
torch.cuda.empty_cache()
gc.collect()

20

## Init

In [17]:
vehicle_feature_sizes = {k: len(v['map']) for k, v in mapping.items()}
embedding_dim = 16
localization_input_size = 6  # For example, x, y, z, roll, pitch, yaw
control_input_size = 2  # acceleration_level, steering
hidden_size = 256
num_layers = 2

In [18]:
device = 'cuda'

In [19]:
model = LstmEncoderDecoderWithAttention(
    vehicle_feature_sizes=vehicle_feature_sizes,
    embedding_dim=embedding_dim,
    localization_input_size=localization_input_size,
    control_input_size=control_input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    # bidirectional=bidir_encoder,
).to(device)

## Lightning

In [20]:
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

  warn(f"Failed to load image Python extension: {e}")


In [21]:
class WeightedLoss(nn.Module):
    def __init__(self, criterion, weights, device='cuda'):
        super(WeightedLoss, self).__init__()
        self.weights = torch.tensor(weights).float().unsqueeze(0).to(device)
        self.criterion = criterion
        
    def forward(self, input, target):
        loss = self.criterion(input, target) 
        loss = loss * self.weights
        return loss.mean()

In [96]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LstmEncoderDecoderWithDualAttention(nn.Module):
    def __init__(self, vehicle_feature_sizes, embedding_dim, localization_input_size, control_input_size, hidden_size, num_layers):
        super(LstmEncoderDecoderWithDualAttention, self).__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Vehicle feature embeddings
        self.vehicle_id_embedding = nn.Embedding(num_embeddings=vehicle_feature_sizes['vehicle_id'], embedding_dim=embedding_dim)
        self.vehicle_model_embedding = nn.Embedding(num_embeddings=vehicle_feature_sizes['vehicle_model'], embedding_dim=embedding_dim)
        self.vehicle_model_modification_embedding = nn.Embedding(num_embeddings=vehicle_feature_sizes['vehicle_model_modification'], embedding_dim=embedding_dim)
        self.location_reference_point_id_embedding = nn.Embedding(num_embeddings=vehicle_feature_sizes['location_reference_point_id'], embedding_dim=embedding_dim)
        self.tires_front_embedding = nn.Embedding(num_embeddings=vehicle_feature_sizes['tires_front'], embedding_dim=embedding_dim)
        self.tires_rear_embedding = nn.Embedding(num_embeddings=vehicle_feature_sizes['tires_rear'], embedding_dim=embedding_dim)

        # Fully connected layer to combine vehicle features
        self.vehicle_fc = nn.Linear(embedding_dim * 6, hidden_size * 2)  # Выходной размер: hidden_size * 2

        # Attention layers (Custom Attention)
        self.custom_attention = nn.Linear(hidden_size * 2, hidden_size) 
        self.custom_attention_combine = nn.Linear(hidden_size * 3, hidden_size * 2)  # Выходной размер: hidden_size * 2

        # Encoder LSTM for input_localization_seq
        self.localization_encoder = nn.LSTM(
            input_size=localization_input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=False,
        )

        # Encoder LSTM for input_control_seq 
        self.control_encoder = nn.LSTM(
            input_size=control_input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=False,
        )

        # Decoder LSTM
        self.decoder = nn.LSTM(
            input_size=control_input_size,
            hidden_size=hidden_size * 2,  # Выходной размер: hidden_size * 2
            num_layers=num_layers,
            batch_first=True
        )

        # Output layer
        self.fc_out = nn.Linear(hidden_size * 2, localization_input_size)  # Выходной размер: localization_input_size

    def forward(self, vehicle_features, input_localization, input_control_sequence, output_control_sequence):
        batch_size = vehicle_features.size(0)
        
        # print(f"Batch size: {batch_size}")

        # Embed vehicle features
        vehicle_id = self.vehicle_id_embedding(vehicle_features[:, 0])
        
        # print(f"vehicle_id shape: {vehicle_id.shape}")  # [batch_size, embedding_dim]

        vehicle_model = self.vehicle_model_embedding(vehicle_features[:, 1])
        # print(f"vehicle_model shape: {vehicle_model.shape}")  # [batch_size, embedding_dim]

        vehicle_model_modification = self.vehicle_model_modification_embedding(vehicle_features[:, 2])
        # print(f"vehicle_model_modification shape: {vehicle_model_modification.shape}")  # [batch_size, embedding_dim]

        location_reference_point_id = self.location_reference_point_id_embedding(vehicle_features[:, 3])
        
        # print(f"location_reference_point_id shape: {location_reference_point_id.shape}")  # [batch_size, embedding_dim]

        tires_front = self.tires_front_embedding(vehicle_features[:, 4])
        
        # print(f"tires_front shape: {tires_front.shape}")  # [batch_size, embedding_dim]

        tires_rear = self.tires_rear_embedding(vehicle_features[:, 5])
        
        # print(f"tires_rear shape: {tires_rear.shape}")  # [batch_size, embedding_dim]

        # Concatenate vehicle features
        vehicle_embedded = torch.cat([
            vehicle_id,
            vehicle_model,
            vehicle_model_modification,
            location_reference_point_id,
            tires_front,
            tires_rear
        ], dim=1)  # Shape: [batch_size, embedding_dim * 6]
        
        # print(f"vehicle_embedded shape: {vehicle_embedded.shape}")  # [batch_size, embedding_dim * 6]

        # Encode vehicle features
        vehicle_features_encoded = self.vehicle_fc(vehicle_embedded)  # Shape: [batch_size, hidden_size * 2]
        
        # print(f"vehicle_features_encoded shape: {vehicle_features_encoded.shape}")  # [batch_size, hidden_size * 2]

        # Encoder for localization
        localization_output, (hidden_loc, cell_loc) = self.localization_encoder(input_localization)
        
        # print(f"localization_output shape: {localization_output.shape}")  # [batch_size, seq_len_enc, hidden_size]
        # print(f"hidden_loc shape: {hidden_loc.shape}")  # [num_layers, batch_size, hidden_size]
        # print(f"cell_loc shape: {cell_loc.shape}")      # [num_layers, batch_size, hidden_size]

        # Encoder for control sequence
        control_output, (hidden_ctrl, cell_ctrl) = self.control_encoder(input_control_sequence)
        
        # print(f"control_output shape: {control_output.shape}")  # [batch_size, seq_len_enc, hidden_size]
        # print(f"hidden_ctrl shape: {hidden_ctrl.shape}")        # [num_layers, batch_size, hidden_size]
        # print(f"cell_ctrl shape: {cell_ctrl.shape}")            # [num_layers, batch_size, hidden_size]

        # Combine encoder outputs
        encoder_outputs = torch.cat((localization_output, control_output), dim=1)  # [batch_size, seq_len_enc*2, hidden_size]
        
        # print(f"encoder_outputs shape: {encoder_outputs.shape}")  # [batch_size, seq_len_enc*2, hidden_size]

        # Initial decoder hidden and cell states (concatenate)
        hidden_enc = torch.cat((hidden_loc, hidden_ctrl), dim=2)  # [num_layers, batch_size, hidden_size * 2]
        # print(f"hidden_enc shape: {hidden_enc.shape}")  # [num_layers, batch_size, hidden_size * 2]

        cell_enc = torch.cat((cell_loc, cell_ctrl), dim=2)        # [num_layers, batch_size, hidden_size * 2]
        # print(f"cell_enc shape: {cell_enc.shape}")      # [num_layers, batch_size, hidden_size * 2]

        # Custom Attention mechanism
        attention_weights = F.softmax(self.custom_attention(hidden_enc), dim=2)  # [num_layers, batch_size, hidden_size]
        # print(f"custom_attention_weights shape: {attention_weights.shape}")  # [num_layers, batch_size, hidden_size]

        combined_hidden = torch.cat((hidden_enc, attention_weights), dim=2)  # [num_layers, batch_size, hidden_size * 3]
        # print(f"combined_hidden before custom_attention_combine shape: {combined_hidden.shape}")  # [num_layers, batch_size, hidden_size * 3]

        combined_hidden = self.custom_attention_combine(combined_hidden)  # [num_layers, batch_size, hidden_size * 2]
        # print(f"combined_hidden after custom_attention_combine shape: {combined_hidden.shape}")  # [num_layers, batch_size, hidden_size * 2]

        # Incorporate vehicle_features_encoded into hidden state (first layer)
        hidden_dec = hidden_enc.clone()
        hidden_dec[0] = hidden_dec[0] + vehicle_features_encoded  # [num_layers, batch_size, hidden_size * 2]
        # print(f"hidden_dec shape after adding vehicle_features_encoded: {hidden_dec.shape}")  # [num_layers, batch_size, hidden_size * 2]

        cell_dec = cell_enc
        # print(f"cell_dec shape: {cell_dec.shape}")  # [num_layers, batch_size, hidden_size * 2]

        # Decoder
        decoder_output, (hidden_dec, cell_dec) = self.decoder(output_control_sequence, (hidden_dec, cell_dec))
        # print(f"decoder_output shape: {decoder_output.shape}")  # [batch_size, seq_len_dec, hidden_size * 2]
        # print(f"hidden_dec (decoder) shape: {hidden_dec.shape}")  # [num_layers, batch_size, hidden_size * 2]
        # print(f"cell_dec (decoder) shape: {cell_dec.shape}")      # [num_layers, batch_size, hidden_size * 2]

        # Luong Attention mechanism
        # Project decoder outputs to hidden_size if needed
        decoder_output_projected = decoder_output[:, :, :self.hidden_size]  # [batch_size, seq_len_dec, hidden_size]
        # print(f"decoder_output_projected shape: {decoder_output_projected.shape}")  # [batch_size, seq_len_dec, hidden_size]

        # Compute attention scores (dot product)
        attn_scores = torch.bmm(decoder_output_projected, encoder_outputs.transpose(1, 2))  # [batch_size, seq_len_dec, seq_len_enc*2]
        # print(f"attn_scores shape: {attn_scores.shape}")  # [batch_size, seq_len_dec, seq_len_enc*2]

        # Apply softmax to get attention weights
        attn_weights = F.softmax(attn_scores, dim=2)  # [batch_size, seq_len_dec, seq_len_enc*2]
        # print(f"attn_weights shape: {attn_weights.shape}")  # [batch_size, seq_len_dec, seq_len_enc*2]

        # Compute context vectors
        context = torch.bmm(attn_weights, encoder_outputs)  # [batch_size, seq_len_dec, hidden_size]
        # print(f"context shape: {context.shape}")  # [batch_size, seq_len_dec, hidden_size]

        # Concatenate decoder output and context
        combined = torch.cat((decoder_output_projected, context), dim=2)  # [batch_size, seq_len_dec, hidden_size * 2]
        # print(f"combined shape: {combined.shape}")  # [batch_size, seq_len_dec, hidden_size * 2]

        # Output layer
        output_localization = self.fc_out(combined)  # [batch_size, seq_len_dec, localization_input_size]
        # print(f"output_localization shape: {output_localization.shape}")  # [batch_size, seq_len_dec, localization_input_size]

        return output_localization


In [32]:
class TrajectoryLightningModule(pl.LightningModule):
    def __init__(self, model, learning_rate=1e-3, weight_decay=5e-6):
        super(TrajectoryLightningModule, self).__init__()
        self.model = model
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        # self.criterion = nn.MSELoss(reduction='mean')  # We'll handle reduction manually due to variable lengths
        base_criterion =  nn.SmoothL1Loss(reduction='none', beta=1.0)
        weights = np.array([3, 3, 1, 1, 1, 3], dtype='float')
        weights /= np.linalg.norm(weights)
        self.criterion = WeightedLoss(base_criterion, weights)
        

    def forward(self, batch):
        vehicle_features = batch['vehicle_features']
        input_localization = batch['input_localization']
        input_control = batch['input_control']
        output_control = batch['output_control']
        
        predicted_output_localization = self.model(
            vehicle_features,
            input_localization,
            input_control,
            output_control,
        )
        
        
        return predicted_output_localization

    def training_step(self, batch, batch_idx):
        output_localization = batch['output_localization']
        predicted_output_localization = self.forward(batch)
        
        loss = self.criterion(predicted_output_localization, output_localization)

        predicted_x_y_yaw = predicted_output_localization[..., [0, 1, -1]].detach().cpu().numpy()
        gt_x_y_yaw = output_localization[..., [0, 1, -1]].detach().cpu().numpy()
        batch_metric = calculate_metric_on_batch(predicted_x_y_yaw, gt_x_y_yaw)

        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_metric', batch_metric, on_step=True, on_epoch=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        output_localization = batch['output_localization']
        predicted_output_localization = self.forward(batch)
        
        loss = self.criterion(predicted_output_localization, output_localization)

        predicted_x_y_yaw = predicted_output_localization[..., [0, 1, -1]].detach().cpu().numpy()
        gt_x_y_yaw = output_localization[..., [0, 1, -1]].detach().cpu().numpy()
        batch_metric = calculate_metric_on_batch(predicted_x_y_yaw, gt_x_y_yaw)

        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_metric', batch_metric, on_step=False, on_epoch=True, prog_bar=True)

        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=3, verbose=True)
        return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'val_loss'}


In [80]:
# vehicle_feature_sizes = {k: len(v['map']) for k, v in mapping.items()}
# embedding_dim = 16
# localization_input_size = 6  # For example, x, y, z, roll, pitch, yaw
# control_input_size = 2  # acceleration_level, steering
# hidden_size = 256
# num_layers = 3

# model = LstmEncoderDecoderWithAttention(
#     vehicle_feature_sizes=vehicle_feature_sizes,
#     embedding_dim=embedding_dim,
#     localization_input_size=localization_input_size,
#     control_input_size=control_input_size,
#     hidden_size=hidden_size,
#     num_layers=num_layers,
#     # bidirectional=bidir_encoder,
# ).to(device)

In [97]:
vehicle_feature_sizes = {k: len(v['map']) for k, v in mapping.items()}
embedding_dim = 16
localization_input_size = 6  # For example, x, y, z, roll, pitch, yaw
control_input_size = 2  # acceleration_level, steering
hidden_size = 196
num_layers = 3

# model = LstmEncoderDecoderWithAttention(
# model = GruEncoderDecoderWithAttention(
# model = LstmEncoderDecoderWithLuongAttention(
model = LstmEncoderDecoderWithDualAttention(
    vehicle_feature_sizes=vehicle_feature_sizes,
    embedding_dim=embedding_dim,
    localization_input_size=localization_input_size,
    control_input_size=control_input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    # bidirectional=bidir_encoder,
).to(device)

In [98]:
# lightning_module = TrajectoryLightningModule.load_from_checkpoint('lightning_logs/lstm_att_lagged_ds/weighted_loss/checkpoints/best_model.ckpt', model=model)
lightning_module = TrajectoryLightningModule.load_from_checkpoint('lightning_logs/dual_attn_lstm_3_196_random_ds/0/checkpoints/best_model.ckpt', model=model)


In [91]:
# torch.save(lightning_module.state_dict(), 'lstm_attn_256_3.ckpt')

In [99]:
def resample_data(original_timestamps, original_values, target_timestamps):
    interpolated = np.zeros((len(target_timestamps), original_values.shape[1]))

    for i in range(original_values.shape[1]):
        interpolated[:, i] = np.interp(
            target_timestamps,
            original_timestamps[:],
            original_values[:, i]
        )

        # Fill values outside the original timestamp range with nearest values
        interpolated[target_timestamps < original_timestamps[0], i] = original_values[:, i][0]
        interpolated[target_timestamps > original_timestamps[-1], i] = original_values[:, i][-1]

    return interpolated
class TestDataset(Dataset):
    def __init__(self, dataset_path: str, mapping: dict, testcase_ids=None, sliding_step: int = 125, training=True):
        self.data = []
        NS_TO_SEC = 1e-9
        self.mapping = mapping

        sampling_interval_ns = 4e7  # 0.04 seconds in nanoseconds
        initial_state_length = int(5 / 0.04)  # Steps for initial 5 seconds (125 steps)
        time_steps_localization = np.arange(0, 5 * 1e9, sampling_interval_ns)
        time_steps_control = np.arange(0, 20 * 1e9, sampling_interval_ns)
        
        if testcase_ids is None:
            testcase_ids = sorted([name for name in os.listdir(dataset_path) if os.path.isdir(os.path.join(dataset_path, name))])

        
        for testcase_id in tqdm(testcase_ids):
            testcase_id = str(testcase_id)
            testcase_path = os.path.join(dataset_path, testcase_id)
            requested_stamps = pd.read_csv(os.path.join(testcase_path, 'requested_stamps.csv'))['stamp_ns'].values

            metadata = read_metadata(os.path.join(testcase_path, 'metadata.json'))
            vehicle_features = self.encode_vehicle_features(metadata)
            control = pd.read_csv(os.path.join(testcase_path, 'control.csv'))
            localization = pd.read_csv(os.path.join(testcase_path, 'localization.csv'))
            
            localization_resampled = resample_data(
                localization['stamp_ns'].values,
                localization.drop(columns=['stamp_ns']).values,
                time_steps_localization
            )

            control_resampled = resample_data(
                control['stamp_ns'].values,
                control.drop(columns=['stamp_ns']).values,
                time_steps_control
            )

            input_localization = localization_resampled.copy()
                
            start_position = input_localization[0, :3].copy()
            input_localization[:, :3] -= start_position

            input_control_sequence = control_resampled[:initial_state_length].copy() 
            output_control_sequence = control_resampled[initial_state_length:].copy()
                
        
            self.data.append({
                'testcase_id': int(testcase_id),
                'requested_stamps': requested_stamps,
                'start_position': start_position,
                
                'vehicle_features': vehicle_features,
                'input_localization': input_localization,
                'input_control': input_control_sequence,
                'output_control': output_control_sequence,
            })


    def encode_vehicle_features(self, metadata):
        feats = []
        for k, map_n_unk in self.mapping.items():
            feats.append(int(map_n_unk['map'].get(metadata[k], map_n_unk['unk'])))
        return feats

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        tensor_dict = {}
        for k, v in sample.items():
            if k.startswith('vehicle'):
                tensor_dict[k] = torch.tensor(v, dtype=torch.long)
            elif k in ['testcase_id', 'requested_stamps', 'start_position']:
                tensor_dict[k] = v
            else:
                tensor_dict[k] = torch.tensor(v, dtype=torch.float32)
            
        return tensor_dict


In [75]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn_test(batch):
    """
    Custom collate function to handle variable-length sequences for seq2seq tasks.
    """
    # Extract vehicle_features and stack them
    vehicle_features = torch.stack([torch.tensor(sample['vehicle_features'], dtype=torch.long) for sample in batch], dim=0)  # [batch_size, num_vehicle_features]
    
    # Extract input_localization_seq sequences and lengths
    input_localization_seqs = [sample['input_localization_seq'] for sample in batch]
    input_localization_lengths = torch.tensor([len(seq) for seq in input_localization_seqs], dtype=torch.long)
    # Pad input_localization_seqs
    input_localization_seq_padded = pad_sequence(
        [torch.tensor(seq, dtype=torch.float32) for seq in input_localization_seqs],
        batch_first=True,
        padding_value=0.0
    )  # [batch_size, max_input_loc_length, localization_input_size]
    
    # Extract input_control_seq sequences and lengths
    input_control_seqs = [torch.tensor(sample['input_control_seq'], dtype=torch.float32) for sample in batch]
    input_control_lengths = torch.tensor([seq.size(0) for seq in input_control_seqs], dtype=torch.long)
    # Pad input_control_seqs
    input_control_seq_padded = pad_sequence(input_control_seqs, batch_first=True, padding_value=0.0)
    
    # Extract output_control_seq sequences and lengths
    output_control_seqs = [torch.tensor(sample['output_control_seq'], dtype=torch.float32) for sample in batch]
    output_control_lengths = torch.tensor([seq.size(0) for seq in output_control_seqs], dtype=torch.long)
    # Pad output_control_seqs
    output_control_seq_padded = pad_sequence(output_control_seqs, batch_first=True, padding_value=0.0)

    input_localization_seq_padded[..., :3] *= 1000
    # output_localization_seq_padded[..., :3] *= 1000

    batch_dict = {
        'testcase_id': torch.stack([torch.tensor(sample['testcase_id'], dtype=torch.long) for sample in batch], dim=0),
        'start_position': torch.stack([torch.tensor(sample['start_position'], dtype=torch.float32) for sample in batch], dim=0),
        'requested_stamps': torch.stack([torch.tensor(sample['requested_stamps'], dtype=torch.long) for sample in batch], dim=0),
        
        'vehicle_features': vehicle_features,  # [batch_size, num_vehicle_features]
        'input_localization_seq': input_localization_seq_padded,  # [batch_size, max_input_loc_length, localization_input_size]
        'input_localization_lengths': input_localization_lengths,  # [batch_size]
        'input_control_seq': input_control_seq_padded,  # [batch_size, max_input_control_length, control_input_size]
        'input_control_lengths': input_control_lengths,  # [batch_size]
        'output_control_seq': output_control_seq_padded,  # [batch_size, max_output_control_length, control_input_size]
        'output_control_lengths': output_control_lengths,  # [batch_size]
    }
    
    return batch_dict


In [100]:
TEST_DATASET_PATH

'./dataset/YaCupTest'

In [101]:
test_dataset = TestDataset(TEST_DATASET_PATH, mapping)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=None)

100%|██████████| 8000/8000 [00:34<00:00, 229.01it/s]


In [103]:
predictions = []

with torch.inference_mode():
    for sample in tqdm(test_loader):
        for k, v in sample.items():
            sample[k] = v.to('cuda')
            
        start_position = sample['start_position'].detach().cpu().numpy()[0]
        requested_stamps = sample['requested_stamps'].detach().cpu().numpy()[0]
        testcase_id = sample['testcase_id'].detach().cpu().item()
        
        predicted_output_localization = lightning_module(sample)
        
        time_steps = np.arange(5 * 1e9, 20 * 1e9, 4e7)
        
        predicted_output_localization = predicted_output_localization.detach().cpu().numpy()[0]
        predicted_output_localization[:, :3] += start_position
        
        # Извлечение предсказанных координат и углов
        yaw_pred = predicted_output_localization[:, -1]
        x_pred = predicted_output_localization[:, 0]
        y_pred = predicted_output_localization[:, 1]
        
        # Интерполяция координат x и y
        x_interp = np.interp(
            requested_stamps, 
            time_steps, 
            x_pred, 
            left=x_pred[0], 
            right=x_pred[-1]
        )
        y_interp = np.interp(
            requested_stamps, 
            time_steps, 
            y_pred, 
            left=y_pred[0], 
            right=y_pred[-1]
        )

        yaw_interp = np.interp(
            requested_stamps, 
            time_steps, 
            yaw_pred, 
            left=yaw_pred[0], 
            right=yaw_pred[-1]
        )
        

        
        # Проверка соответствия длин массивов
        assert len(requested_stamps) == len(x_interp)
        assert len(requested_stamps) == len(y_interp)
        assert len(requested_stamps) == len(yaw_interp)
        
        # Сбор предсказаний
        for stamp_ns, x, y, yaw in zip(requested_stamps, x_interp, y_interp, yaw_interp):
            predictions.append({
                'testcase_id': int(testcase_id),
                'stamp_ns': int(stamp_ns),
                'x': x,
                'y': y,
                'yaw': yaw
            })
predictions = pd.DataFrame(predictions)
predictions['testcase_id'] = predictions['testcase_id'].apply(int)
predictions = predictions.sort_values(by=['testcase_id', 'stamp_ns'])
predictions.to_csv('submissions/196_3_dual_attn_interp.csv.gz', index=False, compression='gzip')

100%|██████████| 8000/8000 [06:04<00:00, 21.93it/s]


In [105]:
import random

In [111]:
a = [1, 2, 3]

In [112]:
random.shuffle(a)

In [113]:
a

[2, 1, 3]

In [40]:
model.eval()
predictions = []

with torch.no_grad():
    for sample in tqdm(test_loader):
        for k, v in sample.items():
            sample[k] = v.to('cuda')
            
        start_position = sample['start_position'].detach().cpu().numpy()[0]
        requested_stamps = sample['requested_stamps'].detach().cpu().numpy()[0]
        testcase_id = sample['testcase_id'].detach().cpu().item()
        
        predicted_output_localization = lightning_module(sample)
        
        time_steps = np.arange(5 * 1e9, 20 * 1e9, 4e7)
        
        predicted_output_localization = predicted_output_localization.detach().cpu().numpy()[0]
        predicted_output_localization[:, :3] += start_position
        
        # Get x, y positions, yaw
        yaw_pred = predicted_output_localization[:, -1]
        x_pred = predicted_output_localization[:, 0]
        y_pred = predicted_output_localization[:, 1]

        indices = np.searchsorted(time_steps, requested_stamps)
        indices = np.clip(indices, 0, len(time_steps) - 1)
        x_pred = x_pred[indices]
        y_pred = y_pred[indices]
        yaw_pred = yaw_pred[indices]

        assert len(requested_stamps) == len(x_pred)

        # Collect predictions
        for stamp_ns, x, y, yaw in zip(requested_stamps, x_pred, y_pred, yaw_pred):
            predictions.append({
                'testcase_id': int(testcase_id),
                'stamp_ns': int(stamp_ns),
                'x': x,
                'y': y,
                'yaw': yaw
            })
        

100%|██████████| 8000/8000 [03:58<00:00, 33.60it/s]


In [41]:
predictions = pd.DataFrame(predictions)
predictions['testcase_id'] = predictions['testcase_id'].apply(int)
predictions = predictions.sort_values(by=['testcase_id', 'stamp_ns'])

In [42]:
predictions.isna().any()

testcase_id    False
stamp_ns       False
x              False
y              False
yaw            False
dtype: bool

In [43]:
# good_df = pd.read_csv('submissions/worked.csv')
# good_df = good_df.sort_values(by=['testcase_id', 'stamp_ns'])

In [44]:
predictions.shape
assert predictions.shape == (2998763, 5)

In [45]:
predictions.to_csv('submissions/256_2_lin_attn_old_alg.csv.gz', index=False, compression='gzip')

In [61]:
predictions = []

with torch.inference_mode():
    for sample in tqdm(test_loader):
        for k, v in sample.items():
            sample[k] = v.to('cuda')
            
        start_position = sample['start_position'].detach().cpu().numpy()[0]
        requested_stamps = sample['requested_stamps'].detach().cpu().numpy()[0]
        testcase_id = sample['testcase_id'].detach().cpu().item()
        
        predicted_output_localization = lightning_module(sample)
        
        time_steps = np.arange(5 * 1e9, 20 * 1e9, 4e7)
        
        predicted_output_localization = predicted_output_localization.detach().cpu().numpy()[0]
        predicted_output_localization[:, :3] += start_position
        
        # Извлечение предсказанных координат и углов
        yaw_pred = predicted_output_localization[:, -1]
        x_pred = predicted_output_localization[:, 0]
        y_pred = predicted_output_localization[:, 1]
        
        # Интерполяция координат x и y
        x_interp = np.interp(
            requested_stamps, 
            time_steps, 
            x_pred, 
            left=x_pred[0], 
            right=x_pred[-1]
        )
        y_interp = np.interp(
            requested_stamps, 
            time_steps, 
            y_pred, 
            left=y_pred[0], 
            right=y_pred[-1]
        )

        yaw_interp = np.interp(
            requested_stamps, 
            time_steps, 
            yaw_pred, 
            left=yaw_pred[0], 
            right=yaw_pred[-1]
        )
        

        
        # Проверка соответствия длин массивов
        assert len(requested_stamps) == len(x_interp)
        assert len(requested_stamps) == len(y_interp)
        assert len(requested_stamps) == len(yaw_interp)
        
        # Сбор предсказаний
        for stamp_ns, x, y, yaw in zip(requested_stamps, x_interp, y_interp, yaw_interp):
            predictions.append({
                'testcase_id': int(testcase_id),
                'stamp_ns': int(stamp_ns),
                'x': x,
                'y': y,
                'yaw': yaw
            })

100%|██████████| 8000/8000 [03:53<00:00, 34.19it/s]


In [62]:
predictions = pd.DataFrame(predictions)
predictions['testcase_id'] = predictions['testcase_id'].apply(int)
predictions = predictions.sort_values(by=['testcase_id', 'stamp_ns'])
predictions.to_csv('submissions/256_2_lin_attn_new_alg_165_epoch.csv.gz', index=False, compression='gzip')

In [303]:
# predictions.to_csv('submissions/auged_data_75_step_lstm_2_256.csv', index=False)

## Ensemble

In [441]:
from sklearn.model_selection import KFold
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping


In [443]:
num_folds = 5
early_stopping_patience = 10 
kfold = KFold(n_splits=num_folds, shuffle=True, random_state=42)

In [459]:
fold_results = []
for fold, (train_idx, val_idx) in enumerate(kfold.split(train_ids)):
    try:
        del model
    except:
        pass
    torch.cuda.empty_cache()
    gc.collect()

    model = LstmEncoderDecoder(
        vehicle_feature_sizes=vehicle_feature_sizes,
        embedding_dim=embedding_dim,
        localization_input_size=localization_input_size,
        control_input_size=control_input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        # bidirectional=bidir_encoder,
    ).to(device)
    
    print(f"Training fold {fold + 1}/{num_folds}...")

    # Датасеты для текущего фолда
    train_identifiers = [train_ids[i] for i in train_idx]
    val_identifiers = [train_ids[i] for i in val_idx]

    train_dataset = TrajectoryDataset(TRAIN_DATASET_PATH, mapping, testcase_ids=train_identifiers, sliding_step=75, training=True)
    val_dataset = TrajectoryDataset(TRAIN_DATASET_PATH, mapping, testcase_ids=val_identifiers, sliding_step=125, training=False)

        # DataLoader
    collate_fn = None  # Определите, если нужно
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=collate_fn)
    logger = CSVLogger("lightning_logs", name=f"lstm_fold_{fold + 1}")
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        filename=f'best_model_fold_{fold + 1}',
        save_top_k=2,
        mode='min'
    )

    early_stopping_callback = EarlyStopping(
        monitor='val_loss',
        patience=early_stopping_patience,
        mode='min'
    )

    callbacks=[checkpoint_callback, early_stopping_callback, pl.callbacks.LearningRateMonitor(logging_interval='epoch')]

    lightning_module = TrajectoryLightningModule(model=model, learning_rate=1e-3, weight_decay=5e-6)
    
    trainer = pl.Trainer(
        logger=logger,
        max_epochs=250,
        val_check_interval=500,
        log_every_n_steps=100,
        callbacks=callbacks,
    
    )
    trainer.fit(lightning_module, train_loader, val_loader)
    fold_results.append({
        "fold": fold,
        "checkpoint_path": checkpoint_callback.best_model_path,
        "val_loss": checkpoint_callback.best_model_score
    })
    

Training fold 1/5...



  0%|          | 0/33600 [00:00<?, ?it/s][A
  0%|          | 4/33600 [00:00<14:43, 38.01it/s][A
  0%|          | 8/33600 [00:00<14:39, 38.20it/s][A
  0%|          | 12/33600 [00:00<14:45, 37.92it/s][A
  0%|          | 16/33600 [00:00<14:42, 38.08it/s][A
  0%|          | 20/33600 [00:00<14:38, 38.21it/s][A
  0%|          | 24/33600 [00:00<14:38, 38.22it/s][A
  0%|          | 28/33600 [00:00<14:42, 38.05it/s][A
  0%|          | 32/33600 [00:00<14:35, 38.33it/s][A
  0%|          | 36/33600 [00:00<14:32, 38.45it/s][A
  0%|          | 40/33600 [00:01<14:31, 38.50it/s][A
  0%|          | 44/33600 [00:01<14:35, 38.34it/s][A
  0%|          | 48/33600 [00:01<14:32, 38.46it/s][A
  0%|          | 52/33600 [00:01<14:30, 38.54it/s][A
  0%|          | 56/33600 [00:01<14:27, 38.65it/s][A
  0%|          | 60/33600 [00:01<14:32, 38.46it/s][A
  0%|          | 64/33600 [00:01<14:31, 38.49it/s][A
  0%|          | 68/33600 [00:01<14:29, 38.55it/s][A
  0%|          | 72/33600 [00:01<14:29

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=49` in the `DataLoader` to improve performance.
/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=49` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training fold 2/5...



  0%|          | 0/33600 [00:00<?, ?it/s][A
  0%|          | 14/33600 [00:00<04:04, 137.27it/s][A
  0%|          | 29/33600 [00:00<03:59, 140.11it/s][A
  0%|          | 44/33600 [00:00<03:56, 141.96it/s][A
  0%|          | 59/33600 [00:00<03:55, 142.13it/s][A
  0%|          | 74/33600 [00:00<03:56, 141.74it/s][A
  0%|          | 89/33600 [00:00<03:55, 142.29it/s][A
  0%|          | 104/33600 [00:00<03:55, 142.13it/s][A
  0%|          | 119/33600 [00:00<03:54, 142.54it/s][A
  0%|          | 134/33600 [00:00<03:54, 142.73it/s][A
  0%|          | 149/33600 [00:01<03:54, 142.47it/s][A
  0%|          | 164/33600 [00:01<03:54, 142.81it/s][A
  1%|          | 179/33600 [00:01<03:53, 143.20it/s][A
  1%|          | 194/33600 [00:01<03:54, 142.50it/s][A
  1%|          | 209/33600 [00:01<03:52, 143.49it/s][A
  1%|          | 224/33600 [00:01<03:52, 143.80it/s][A
  1%|          | 239/33600 [00:01<03:52, 143.36it/s][A
  1%|          | 254/33600 [00:01<03:52, 143.47it/s][A
  1%|   

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=49` in the `DataLoader` to improve performance.
/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=49` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training fold 3/5...



  0%|          | 0/33600 [00:00<?, ?it/s][A
  0%|          | 4/33600 [00:00<16:54, 33.13it/s][A
  0%|          | 8/33600 [00:00<16:28, 34.00it/s][A
  0%|          | 12/33600 [00:00<16:26, 34.04it/s][A
  0%|          | 16/33600 [00:00<16:19, 34.29it/s][A
  0%|          | 20/33600 [00:00<16:20, 34.26it/s][A
  0%|          | 24/33600 [00:00<16:18, 34.32it/s][A
  0%|          | 28/33600 [00:00<16:18, 34.31it/s][A
  0%|          | 32/33600 [00:00<16:22, 34.18it/s][A
  0%|          | 36/33600 [00:01<16:18, 34.29it/s][A
  0%|          | 40/33600 [00:01<16:22, 34.16it/s][A
  0%|          | 44/33600 [00:01<16:16, 34.36it/s][A
  0%|          | 48/33600 [00:01<16:13, 34.45it/s][A
  0%|          | 52/33600 [00:01<16:16, 34.35it/s][A
  0%|          | 56/33600 [00:01<16:13, 34.46it/s][A
  0%|          | 60/33600 [00:01<16:17, 34.31it/s][A
  0%|          | 64/33600 [00:01<16:13, 34.45it/s][A
  0%|          | 68/33600 [00:01<16:08, 34.62it/s][A
  0%|          | 72/33600 [00:02<16:08

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=49` in the `DataLoader` to improve performance.
/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=49` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training fold 4/5...



  0%|          | 0/33600 [00:00<?, ?it/s][A
  0%|          | 13/33600 [00:00<04:19, 129.31it/s][A
  0%|          | 26/33600 [00:00<04:19, 129.29it/s][A
  0%|          | 40/33600 [00:00<04:16, 131.07it/s][A
  0%|          | 54/33600 [00:00<04:15, 131.17it/s][A
  0%|          | 68/33600 [00:00<04:14, 132.00it/s][A
  0%|          | 82/33600 [00:00<04:12, 132.70it/s][A
  0%|          | 96/33600 [00:00<04:12, 132.69it/s][A
  0%|          | 110/33600 [00:00<04:13, 132.02it/s][A
  0%|          | 124/33600 [00:00<04:11, 133.17it/s][A
  0%|          | 138/33600 [00:01<04:11, 133.05it/s][A
  0%|          | 152/33600 [00:01<04:08, 134.35it/s][A
  0%|          | 166/33600 [00:01<04:07, 135.09it/s][A
  1%|          | 180/33600 [00:01<04:07, 134.89it/s][A
  1%|          | 194/33600 [00:01<04:06, 135.79it/s][A
  1%|          | 208/33600 [00:01<04:04, 136.59it/s][A
  1%|          | 222/33600 [00:01<04:05, 136.20it/s][A
  1%|          | 236/33600 [00:01<04:03, 137.05it/s][A
  1%|    

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=49` in the `DataLoader` to improve performance.
/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=49` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training fold 5/5...



  0%|          | 0/33600 [00:00<?, ?it/s][A
  0%|          | 4/33600 [00:00<18:21, 30.50it/s][A
  0%|          | 8/33600 [00:00<18:02, 31.02it/s][A
  0%|          | 12/33600 [00:00<18:06, 30.91it/s][A
  0%|          | 16/33600 [00:00<17:57, 31.17it/s][A
  0%|          | 20/33600 [00:00<18:04, 30.97it/s][A
  0%|          | 24/33600 [00:00<17:56, 31.20it/s][A
  0%|          | 28/33600 [00:00<17:51, 31.32it/s][A
  0%|          | 32/33600 [00:01<17:54, 31.23it/s][A
  0%|          | 36/33600 [00:01<17:51, 31.33it/s][A
  0%|          | 40/33600 [00:01<18:01, 31.04it/s][A
  0%|          | 44/33600 [00:01<17:56, 31.17it/s][A
  0%|          | 48/33600 [00:01<17:57, 31.15it/s][A
  0%|          | 52/33600 [00:01<17:53, 31.25it/s][A
  0%|          | 56/33600 [00:01<17:52, 31.27it/s][A
  0%|          | 60/33600 [00:01<17:47, 31.41it/s][A
  0%|          | 64/33600 [00:02<17:43, 31.52it/s][A
  0%|          | 68/33600 [00:02<17:48, 31.40it/s][A
  0%|          | 72/33600 [00:02<17:45

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=49` in the `DataLoader` to improve performance.
/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=49` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

In [474]:
checkpoint_callback.best_model_path

'lightning_logs/lstm_fold_5/version_0/checkpoints/best_model_fold_5.ckpt'

In [466]:

for fold in range(num_folds):
    del lightning_module
    torch.cuda.empty_cache()
    gc.collect()
    with torch.no_grad():
        torch.cuda.empty_cache()
    gc.collect() 
    
    dir = os.path.join("lightning_logs", f"lstm_fold_{fold + 1}", "version_0", "checkpoints")
    ckpt_path = f"{dir}/best_model_fold_{fold + 1}.ckpt"
    lightning_module = TrajectoryLightningModule.load_from_checkpoint(ckpt_path, model=model)
    lightning_module.eval()
    lightning_module.model.eval()
    lightning_module.model.to('cuda')

    predictions = []
    
    with torch.inference_mode():
        for sample in tqdm(test_loader):
            for k, v in sample.items():
                sample[k] = v.to('cuda')
                
            start_position = sample['start_position'].detach().cpu().numpy()[0]
            requested_stamps = sample['requested_stamps'].detach().cpu().numpy()[0]
            testcase_id = sample['testcase_id'].detach().cpu().item()
            
            predicted_output_localization = lightning_module(sample)
            
            time_steps = np.arange(5 * 1e9, 20 * 1e9, 4e7)
            
            predicted_output_localization = predicted_output_localization.detach().cpu().numpy()[0]
            predicted_output_localization[:, :3] += start_position
            
            # Get x, y positions, yaw
            yaw_pred = predicted_output_localization[:, -1]
            x_pred = predicted_output_localization[:, 0]
            y_pred = predicted_output_localization[:, 1]
    
            indices = np.searchsorted(time_steps, requested_stamps)
            indices = np.clip(indices, 0, len(time_steps) - 1)
            x_pred = x_pred[indices]
            y_pred = y_pred[indices]
            yaw_pred = yaw_pred[indices]
    
            assert len(requested_stamps) == len(x_pred)
    
            # Collect predictions
            for stamp_ns, x, y, yaw in zip(requested_stamps, x_pred, y_pred, yaw_pred):
                predictions.append({
                    'testcase_id': int(testcase_id),
                    'stamp_ns': int(stamp_ns),
                    'x': x,
                    'y': y,
                    'yaw': yaw
                })
        predictions = pd.DataFrame(predictions)
        predictions['testcase_id'] = predictions['testcase_id'].apply(int)
        predictions = predictions.sort_values(by=['testcase_id', 'stamp_ns'])
        predictions.to_csv(f'submissions/lstm_fold_{fold + 1}.csv.gz', index=False, compression='gzip')


  tensor_dict[k] = torch.tensor(v, dtype=torch.float32)

  0%|          | 6/8000 [00:00<02:34, 51.84it/s][A
  0%|          | 12/8000 [00:00<02:32, 52.26it/s][A
  0%|          | 18/8000 [00:00<02:33, 52.09it/s][A
  0%|          | 24/8000 [00:00<02:32, 52.34it/s][A
  0%|          | 30/8000 [00:00<02:32, 52.15it/s][A
  0%|          | 36/8000 [00:00<02:32, 52.32it/s][A
  1%|          | 42/8000 [00:00<02:32, 52.18it/s][A
  1%|          | 48/8000 [00:00<02:32, 52.25it/s][A
  1%|          | 54/8000 [00:01<02:31, 52.40it/s][A
  1%|          | 60/8000 [00:01<02:31, 52.34it/s][A
  1%|          | 66/8000 [00:01<02:31, 52.46it/s][A
  1%|          | 72/8000 [00:01<02:31, 52.29it/s][A
  1%|          | 78/8000 [00:01<02:31, 52.41it/s][A
  1%|          | 84/8000 [00:01<02:31, 52.19it/s][A
  1%|          | 90/8000 [00:01<02:31, 52.31it/s][A
  1%|          | 96/8000 [00:01<02:31, 52.27it/s][A
  1%|▏         | 102/8000 [00:01<02:31, 52.29it/s][A
  1%|▏         | 108/8000 [00:02<02:30, 5

In [480]:
sample = lightning_module._apply_batch_transfer_handler(sample)

In [478]:
sample = next(iter(test_loader))

  tensor_dict[k] = torch.tensor(v, dtype=torch.float32)


In [482]:
sample.keys()

dict_keys(['testcase_id', 'requested_stamps', 'start_position', 'vehicle_features', 'input_localization', 'output_localization', 'input_control', 'output_control'])

In [465]:
TrajectoryLightningModule.load_from_checkpoint(ckpt_path, model=model)

TrajectoryLightningModule(
  (model): LstmEncoderDecoder(
    (vehicle_id_embedding): Embedding(131, 16)
    (vehicle_model_embedding): Embedding(2, 16)
    (vehicle_model_modification_embedding): Embedding(6, 16)
    (location_reference_point_id_embedding): Embedding(3, 16)
    (tires_front_embedding): Embedding(14, 16)
    (tires_rear_embedding): Embedding(14, 16)
    (vehicle_fc): Linear(in_features=96, out_features=256, bias=True)
    (localization_encoder): LSTM(6, 256, num_layers=2, batch_first=True)
    (control_encoder): LSTM(2, 256, num_layers=2, batch_first=True)
    (decoder): LSTM(2, 256, num_layers=2, batch_first=True)
    (fc_out): Sequential(
      (0): Linear(in_features=256, out_features=128, bias=True)
      (1): GELU(approximate='none')
      (2): Dropout(p=0.3, inplace=False)
      (3): Linear(in_features=128, out_features=6, bias=True)
    )
  )
  (criterion): SmoothL1Loss()
)

In [468]:
predictions

Unnamed: 0,testcase_id,stamp_ns,x,y,yaw
0,0,5000888836,-1491.080444,-1311.885254,1.786414
1,0,5040043013,-1491.333984,-1310.805054,1.840383
2,0,5079989560,-1491.333984,-1310.805054,1.840383
3,0,5120797471,-1491.258667,-1311.200195,1.798719
4,0,5165218288,-1491.452637,-1310.674438,1.812457
...,...,...,...,...,...
2915536,7999,19800025528,-2822.793701,-686.651001,0.283858
2915537,7999,19840700774,-2822.794434,-686.651428,0.283877
2915538,7999,19880004766,-2822.794922,-686.651855,0.283897
2915539,7999,19920766768,-2822.795654,-686.652283,0.283916


In [469]:
fold_predictions = []
for fold in range(num_folds):
    predictions = pd.read_csv(f'submissions/lstm_fold_{fold + 1}.csv.gz', compression='gzip')
    fold_predictions.append(predictions)

In [471]:
merged_predictions = pd.concat(fold_predictions).groupby(['testcase_id', 'stamp_ns']).mean().reset_index()

In [472]:
merged_predictions.shape

(2998763, 5)

In [473]:
merged_predictions.to_csv(f'submissions/lstm_folds_merged.csv.gz', index=False, compression='gzip')

## modern inference

In [483]:
fold = 0 

In [484]:
dir = os.path.join("lightning_logs", f"lstm_fold_{fold + 1}", "version_0", "checkpoints")
ckpt_path = f"{dir}/best_model_fold_{fold + 1}.ckpt"
lightning_module = TrajectoryLightningModule.load_from_checkpoint(ckpt_path, model=model)
lightning_module.eval()
lightning_module.model.eval()
lightning_module.model.to('cuda')

LstmEncoderDecoder(
  (vehicle_id_embedding): Embedding(131, 16)
  (vehicle_model_embedding): Embedding(2, 16)
  (vehicle_model_modification_embedding): Embedding(6, 16)
  (location_reference_point_id_embedding): Embedding(3, 16)
  (tires_front_embedding): Embedding(14, 16)
  (tires_rear_embedding): Embedding(14, 16)
  (vehicle_fc): Linear(in_features=96, out_features=256, bias=True)
  (localization_encoder): LSTM(6, 256, num_layers=2, batch_first=True)
  (control_encoder): LSTM(2, 256, num_layers=2, batch_first=True)
  (decoder): LSTM(2, 256, num_layers=2, batch_first=True)
  (fc_out): Sequential(
    (0): Linear(in_features=256, out_features=128, bias=True)
    (1): GELU(approximate='none')
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=128, out_features=6, bias=True)
  )
)

In [485]:
predictions = []

with torch.inference_mode():
    for sample in tqdm(test_loader):
        for k, v in sample.items():
            sample[k] = v.to('cuda')
            
        start_position = sample['start_position'].detach().cpu().numpy()[0]
        requested_stamps = sample['requested_stamps'].detach().cpu().numpy()[0]
        testcase_id = sample['testcase_id'].detach().cpu().item()
        
        predicted_output_localization = lightning_module(sample)
        
        time_steps = np.arange(5 * 1e9, 20 * 1e9, 4e7)
        
        predicted_output_localization = predicted_output_localization.detach().cpu().numpy()[0]
        predicted_output_localization[:, :3] += start_position
        
        # Извлечение предсказанных координат и углов
        yaw_pred = predicted_output_localization[:, -1]
        x_pred = predicted_output_localization[:, 0]
        y_pred = predicted_output_localization[:, 1]
        
        # Интерполяция координат x и y
        x_interp = np.interp(
            requested_stamps, 
            time_steps, 
            x_pred, 
            left=x_pred[0], 
            right=x_pred[-1]
        )
        y_interp = np.interp(
            requested_stamps, 
            time_steps, 
            y_pred, 
            left=y_pred[0], 
            right=y_pred[-1]
        )
        
        # Интерполяция угла yaw с учётом циклической природы
        sin_yaw = np.sin(yaw_pred)
        cos_yaw = np.cos(yaw_pred)
        sin_interp = np.interp(
            requested_stamps, 
            time_steps, 
            sin_yaw, 
            left=sin_yaw[0], 
            right=sin_yaw[-1]
        )
        cos_interp = np.interp(
            requested_stamps, 
            time_steps, 
            cos_yaw, 
            left=cos_yaw[0], 
            right=cos_yaw[-1]
        )
        yaw_interp = np.arctan2(sin_interp, cos_interp)
        
        # Проверка соответствия длин массивов
        assert len(requested_stamps) == len(x_interp)
        assert len(requested_stamps) == len(y_interp)
        assert len(requested_stamps) == len(yaw_interp)
        
        # Сбор предсказаний
        for stamp_ns, x, y, yaw in zip(requested_stamps, x_interp, y_interp, yaw_interp):
            predictions.append({
                'testcase_id': int(testcase_id),
                'stamp_ns': int(stamp_ns),
                'x': x,
                'y': y,
                'yaw': yaw
            })
    predictions = pd.DataFrame(predictions)
    predictions['testcase_id'] = predictions['testcase_id'].apply(int)
    predictions = predictions.sort_values(by=['testcase_id', 'stamp_ns'])
    predictions.to_csv(f'submissions/interpt_lstm_fold_{fold + 1}.csv.gz', index=False, compression='gzip')


  tensor_dict[k] = torch.tensor(v, dtype=torch.float32)

  0%|          | 6/8000 [00:00<02:35, 51.25it/s][A
  0%|          | 12/8000 [00:00<02:35, 51.28it/s][A
  0%|          | 18/8000 [00:00<02:34, 51.71it/s][A
  0%|          | 24/8000 [00:00<02:34, 51.79it/s][A
  0%|          | 30/8000 [00:00<02:33, 52.04it/s][A
  0%|          | 36/8000 [00:00<02:32, 52.17it/s][A
  1%|          | 42/8000 [00:00<02:32, 52.03it/s][A
  1%|          | 48/8000 [00:00<02:32, 52.15it/s][A
  1%|          | 54/8000 [00:01<02:32, 52.03it/s][A
  1%|          | 60/8000 [00:01<02:32, 52.18it/s][A
  1%|          | 66/8000 [00:01<02:33, 51.72it/s][A
  1%|          | 72/8000 [00:01<02:34, 51.22it/s][A
  1%|          | 78/8000 [00:01<02:34, 51.16it/s][A
  1%|          | 84/8000 [00:01<02:35, 50.93it/s][A
  1%|          | 90/8000 [00:01<02:33, 51.37it/s][A
  1%|          | 96/8000 [00:01<02:33, 51.57it/s][A
  1%|▏         | 102/8000 [00:01<02:32, 51.89it/s][A
  1%|▏         | 108/8000 [00:02<02:31, 5

KeyError: 'testcase_id'

In [497]:
predictions = pd.DataFrame(predictions)
predictions['testcase_id'] = predictions['testcase_id'].apply(int)
predictions = predictions.sort_values(by=['testcase_id', 'stamp_ns'])
predictions.to_csv(f'submissions/interpt_lstm_fold_{fold + 1}.csv.gz', index=False, compression='gzip')

In [498]:
predictions.shape

(2998763, 5)