# Baseline

## Load and prepare data

In [9]:
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data, Batch
import tqdm

In [10]:
train_file = np.load('data/train.npz')

train_data = train_file['data']
print("train_data's shape", train_data.shape)
test_file = np.load('data/test_input.npz')

test_data = test_file['data']
print("test_data's shape", test_data.shape)

train_data's shape (10000, 50, 110, 6)
test_data's shape (2100, 50, 50, 6)


# Data Loaders

In [11]:
class TrajectoryDatasetTrain(Dataset):
    def __init__(self, data, scale=10.0, augment=True):
        """
        data: Shape (N, 50, 110, 6) Training data
        scale: Scale for normalization (suggested to use 10.0 for Argoverse 2 data)
        augment: Whether to apply data augmentation (only for training)
        """
        self.data = data
        self.scale = scale
        self.augment = augment

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        scene = self.data[idx]
        # Getting 50 historical timestamps and 60 future timestamps
        hist = scene[:, :50, :].copy()    # (agents=50, time_seq=50, 6)
        future = torch.tensor(scene[0, 50:, :2].copy(), dtype=torch.float32)  # (60, 2)
        
        # Data augmentation(only for training)
        if self.augment:
            if np.random.rand() < 0.5:
                theta = np.random.uniform(-np.pi, np.pi)
                R = np.array([[np.cos(theta), -np.sin(theta)],
                              [np.sin(theta),  np.cos(theta)]], dtype=np.float32)
                # Rotate the historical trajectory and future trajectory
                hist[..., :2] = hist[..., :2] @ R
                hist[..., 2:4] = hist[..., 2:4] @ R
                future = future @ R
            if np.random.rand() < 0.5:
                hist[..., 0] *= -1
                hist[..., 2] *= -1
                future[:, 0] *= -1

        # Use the last timeframe of the historical trajectory as the origin
        origin = hist[0, 49, :2].copy()  # (2,)
        hist[..., :2] = hist[..., :2] - origin
        future = future - origin

        # Normalize the historical trajectory and future trajectory
        hist[..., :4] = hist[..., :4] / self.scale
        future = future / self.scale

        data_item = Data(
            x=torch.tensor(hist, dtype=torch.float32),
            y=future.type(torch.float32),
            origin=torch.tensor(origin, dtype=torch.float32).unsqueeze(0),
            scale=torch.tensor(self.scale, dtype=torch.float32),
        )

        return data_item
    

class TrajectoryDatasetTest(Dataset):
    def __init__(self, data, scale=10.0):
        """
        data: Shape (N, 50, 110, 6) Testing data
        scale: Scale for normalization (suggested to use 10.0 for Argoverse 2 data)
        """
        self.data = data
        self.scale = scale

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Testing data only contains historical trajectory
        scene = self.data[idx]  # (50, 50, 6)
        hist = scene.copy()
        
        origin = hist[0, 49, :2].copy()
        hist[..., :2] = hist[..., :2] - origin
        hist[..., :4] = hist[..., :4] / self.scale

        data_item = Data(
            x=torch.tensor(hist, dtype=torch.float32),
            origin=torch.tensor(origin, dtype=torch.float32).unsqueeze(0),
            scale=torch.tensor(self.scale, dtype=torch.float32),
        )
        return data_item

In [12]:
torch.manual_seed(251)
np.random.seed(42)

scale = 7.0

N = len(train_data)
val_size = int(0.1 * N)
train_size = N - val_size

train_dataset = TrajectoryDatasetTrain(train_data[:train_size], scale=scale, augment=True)
val_dataset = TrajectoryDatasetTrain(train_data[train_size:], scale=scale, augment=False)

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=lambda x: Batch.from_data_list(x))
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=lambda x: Batch.from_data_list(x))

# Set device for training speedup
if torch.backends.mps.is_available():
    device = torch.device('mps')
    print("Using Apple Silicon GPU")
elif torch.cuda.is_available():
    device = torch.device('cuda')
    print("Using CUDA GPU")
else:
    device = torch.device('cpu')

Using CUDA GPU


# LSTM

In [13]:
# Example of basic model that should work
class SimpleLSTM(nn.Module):
    def __init__(self,  hidden_dim=512, output_dim=60*2):
        super(SimpleLSTM, self).__init__()
        self.lstm1 = nn.LSTM(6, hidden_dim, batch_first=True,num_layers=6)
        self.lstm2 = nn.LSTM(64, int(hidden_dim/4), batch_first=True,num_layers=8)
        
        
        # Add multi-layer prediction head for better results
        self.compress1=nn.Linear(300, 512)
        self.compress2=nn.Linear(512, 512)
        self.compress3=nn.Linear(512, 512)
        self.compress4=nn.Linear(512, 256)
        self.compress5=nn.Linear(256, 128)
        self.compress6=nn.Linear(128, 64)


        
        self.fc1 = nn.Linear(hidden_dim+int(hidden_dim/4), hidden_dim*2)
        self.fc2 = nn.Linear(hidden_dim*2, hidden_dim*4)
        self.fc3 = nn.Linear(hidden_dim*4, hidden_dim*4)    
        self.fc4 = nn.Linear(hidden_dim*4, hidden_dim)
        self.fc5 = nn.Linear(hidden_dim, hidden_dim)
        
        self.dropout = nn.Dropout(0.1)  # Add dropout for regularization
        self.relu = nn.ReLU()
        self.fc6 = nn.Linear(hidden_dim, output_dim)
        
        # Initialize weights properly
        for name, param in self.named_parameters():
            if 'weight' in name:
                nn.init.xavier_normal_(param)
            elif 'bias' in name:
                nn.init.constant_(param, 0.0)
        
    def forward(self, data):
        x = data.x
        x = x.reshape(-1, 50, 50, 6)  # (batch_size, num_agents, seq_len, input_dim)

        x1=x[:,0,:,:]
      
        x2 = x.permute(0, 2, 1, 3) 
        shp=x2.shape
        x2 = x2.reshape(shp[0], shp[1], 300) 


        
        # Process through LSTM
        lstm_out1, _ = self.lstm1(x1)
        x2=self.compress1(x2)
        x2=self.relu(x2)
        x2=self.dropout(x2)    
        x2=self.compress2(x2)
        x2=self.relu(x2)
        x2=self.dropout(x2)
        x2=self.compress3(x2)
        x2=self.relu(x2)
        x2=self.dropout(x2)
        x2=self.compress4(x2)
        x2=self.relu(x2)
        x2=self.dropout(x2)
        x2=self.compress5(x2)
        x2=self.relu(x2)
        x2=self.dropout(x2)
        x2=self.compress6(x2)
        lstm_out2, _ = self.lstm2(x2)
        
        # Extract final hidden state
        features1 = lstm_out1[:, -1, :]
        features2 = lstm_out2[:, -1, :]
      
        features= torch.cat((features1, features2), dim=1)
        
        
        # Process through prediction head
        features = self.relu(self.fc1(features))
        features = self.dropout(features)
        features = self.relu(self.fc2(features))
        features = self.dropout(features)
        features = self.relu(self.fc3(features))
        features = self.dropout(features)
        features = self.dropout(features)
        features = self.relu(self.fc4(features))
        features = self.dropout(features)
        features = self.dropout(features)
        features = self.relu(self.fc5(features))
        features = self.dropout(features)
        out = self.fc6(features)
        
        # Reshape to (batch_size, 60, 2)
        return out.view(-1, 60, 2)

# Train

In [14]:
def train_improved_model(model, train_dataloader, val_dataloader, 
                         device, criterion=nn.MSELoss(), 
                         lr=0.001, epochs=200, patience=15):
    """
    Improved training function with better debugging and early stopping
    """
    # Initialize optimizer with smaller learning rate
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    # Exponential decay scheduler
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
    
    early_stopping_patience = patience
    best_val_loss = float('inf')
    no_improvement = 0
    
    # Save initial state for comparison
    initial_state_dict = {k: v.clone() for k, v in model.state_dict().items()}
    
    for epoch in tqdm.tqdm(range(epochs), desc="Epoch", unit="epoch"):
        # ---- Training ----
        model.train()
        train_loss = 0
        num_train_batches = 0
        
        for batch in train_dataloader:
            batch = batch.to(device)
            pred = model(batch)
            y = batch.y.view(batch.num_graphs, 60, 2)
            
            # Check for NaN predictions
            if torch.isnan(pred).any():
                print(f"WARNING: NaN detected in predictions during training")
                continue
                
            loss = criterion(pred, y)
            
            # Check if loss is valid
            if torch.isnan(loss) or torch.isinf(loss):
                print(f"WARNING: Invalid loss value: {loss.item()}")
                continue
                
            optimizer.zero_grad()
            loss.backward()
            
            # More conservative gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            
            optimizer.step()
            train_loss += loss.item()
            num_train_batches += 1
        
        # Skip epoch if no valid batches
        if num_train_batches == 0:
            print("WARNING: No valid training batches in this epoch")
            continue
            
        train_loss /= num_train_batches
        
        # ---- Validation ----
        model.eval()
        val_loss = 0
        val_mae = 0
        val_mse = 0
        num_val_batches = 0
        
        # Sample predictions for debugging
        sample_input = None
        sample_pred = None
        sample_target = None
        
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_dataloader):
                batch = batch.to(device)
                pred = model(batch)
                y = batch.y.view(batch.num_graphs, 60, 2)
                
                # Store sample for debugging
                if batch_idx == 0 and sample_input is None:
                    sample_input = batch.x[0].cpu().numpy()
                    sample_pred = pred[0].cpu().numpy()
                    sample_target = y[0].cpu().numpy()
                
                # Skip invalid predictions
                if torch.isnan(pred).any():
                    print(f"WARNING: NaN detected in predictions during validation")
                    continue
                    
                batch_loss = criterion(pred, y).item()
                val_loss += batch_loss
                
                # Unnormalize for real-world metrics
                pred_unnorm = pred * batch.scale.view(-1, 1, 1) + batch.origin.unsqueeze(1)
                y_unnorm = y * batch.scale.view(-1, 1, 1) + batch.origin.unsqueeze(1)
                
                val_mae += nn.L1Loss()(pred_unnorm, y_unnorm).item()
                val_mse += nn.MSELoss()(pred_unnorm, y_unnorm).item()
                
                num_val_batches += 1
        
        # Skip epoch if no valid validation batches
        if num_val_batches == 0:
            print("WARNING: No valid validation batches in this epoch")
            continue
            
        val_loss /= num_val_batches
        val_mae /= num_val_batches
        val_mse /= num_val_batches
        
        # Update learning rate
        scheduler.step()
        
        # Print with more details
        tqdm.tqdm.write(
            f"Epoch {epoch:03d} | LR {optimizer.param_groups[0]['lr']:.6f} | "
            f"Train MSE {train_loss:.4f} | Val MSE {val_loss:.4f} | "
            f"Val MAE {val_mae:.4f} | Val MSE {val_mse:.4f}"
        )
        
        # Debug output - first 3 predictions vs targets
        if epoch % 5 == 0:
            tqdm.tqdm.write(f"Sample pred first 3 steps: {sample_pred[:3]}")
            tqdm.tqdm.write(f"Sample target first 3 steps: {sample_target[:3]}")
            
            # Check if model weights are changing
            if epoch > 0:
                weight_change = False
                for name, param in model.named_parameters():
                    if param.requires_grad:
                        initial_param = initial_state_dict[name]
                        if not torch.allclose(param, initial_param, rtol=1e-4):
                            weight_change = True
                            break
                if not weight_change:
                    tqdm.tqdm.write("WARNING: Model weights barely changing!")
        
        # Relaxed improvement criterion - consider any improvement
        if val_loss < best_val_loss:
            tqdm.tqdm.write(f"Validation improved: {best_val_loss:.6f} -> {val_loss:.6f}")
            best_val_loss = val_loss
            no_improvement = 0
            torch.save(model.state_dict(), "best_model.pt")
        else:
            no_improvement += 1
            if no_improvement >= early_stopping_patience:
                print(f"Early stopping after {epoch+1} epochs without improvement")
                break
    
    # Load best model before returning
    model.load_state_dict(torch.load("best_model.pt"))
    return model

In [15]:
# Example usage
def train_and_evaluate_model():
    # Create model
    model = SimpleLSTM( hidden_dim=512)
    model = model.to(device)
    
    # Train with improved function
    train_improved_model(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        device=device,
        lr=0.001,  # Lower learning rate
        patience=30  # More patience
    )
    
    # Evaluate
    model.eval()
    test_mse = 0
    with torch.no_grad():
        for batch in val_dataloader:
            batch = batch.to(device)
            pred = model(batch)
            y = batch.y.view(batch.num_graphs, 60, 2)
            
            # Unnormalize
            pred = pred * batch.scale.view(-1, 1, 1) + batch.origin.unsqueeze(1)
            y = y * batch.scale.view(-1, 1, 1) + batch.origin.unsqueeze(1)
            
            test_mse += nn.MSELoss()(pred, y).item()
    
    test_mse /= len(val_dataloader)
    print(f"Val MSE: {test_mse:.4f}")
    
    return model

train_and_evaluate_model()

Epoch:   0%|          | 0/200 [00:16<?, ?epoch/s]

Epoch 000 | LR 0.000950 | Train MSE 5.4944 | Val MSE 5.0518 | Val MAE 10.3328 | Val MSE 247.5394
Sample pred first 3 steps: [[-0.09316703  0.11389622]
 [-0.25320682  0.02323009]
 [-0.29374775  0.01706527]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]
Validation improved: inf -> 5.051823


Epoch:   0%|          | 1/200 [00:31<55:29, 16.73s/epoch]

Epoch 001 | LR 0.000902 | Train MSE 4.9914 | Val MSE 4.8433 | Val MAE 9.6846 | Val MSE 237.3222
Validation improved: 5.051823 -> 4.843310


Epoch:   1%|          | 2/200 [00:47<52:51, 16.02s/epoch]

Epoch 002 | LR 0.000857 | Train MSE 4.3867 | Val MSE 3.6157 | Val MAE 8.0022 | Val MSE 177.1691
Validation improved: 4.843310 -> 3.615697


Epoch:   2%|▏         | 3/200 [01:04<52:17, 15.93s/epoch]

Epoch 003 | LR 0.000815 | Train MSE 3.7812 | Val MSE 3.2488 | Val MAE 7.8894 | Val MSE 159.1936
Validation improved: 3.615697 -> 3.248848


Epoch:   2%|▏         | 4/200 [01:20<52:40, 16.12s/epoch]

Epoch 004 | LR 0.000774 | Train MSE 3.0607 | Val MSE 2.5874 | Val MAE 7.1435 | Val MSE 126.7814
Validation improved: 3.248848 -> 2.587376


Epoch:   2%|▎         | 5/200 [01:36<52:14, 16.07s/epoch]

Epoch 005 | LR 0.000735 | Train MSE 2.7375 | Val MSE 2.5132 | Val MAE 6.9775 | Val MSE 123.1488
Sample pred first 3 steps: [[ 0.0009801  -0.0168315 ]
 [-0.00233477 -0.01848646]
 [-0.00042948 -0.02562985]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]
Validation improved: 2.587376 -> 2.513241


Epoch:   4%|▎         | 7/200 [01:51<50:59, 15.85s/epoch]

Epoch 006 | LR 0.000698 | Train MSE 2.7724 | Val MSE 3.0608 | Val MAE 7.5387 | Val MSE 149.9775


Epoch:   4%|▍         | 8/200 [02:07<50:34, 15.81s/epoch]

Epoch 007 | LR 0.000663 | Train MSE 2.7622 | Val MSE 2.7304 | Val MAE 6.8471 | Val MSE 133.7905


Epoch:   4%|▍         | 8/200 [02:23<50:34, 15.81s/epoch]

Epoch 008 | LR 0.000630 | Train MSE 2.1051 | Val MSE 1.5062 | Val MAE 5.0381 | Val MSE 73.8019
Validation improved: 2.513241 -> 1.506161


Epoch:   5%|▌         | 10/200 [02:39<49:58, 15.78s/epoch]

Epoch 009 | LR 0.000599 | Train MSE 1.7358 | Val MSE 1.6168 | Val MAE 5.3073 | Val MSE 79.2242


Epoch:   6%|▌         | 11/200 [02:54<49:33, 15.73s/epoch]

Epoch 010 | LR 0.000569 | Train MSE 1.6926 | Val MSE 1.5280 | Val MAE 5.0594 | Val MSE 74.8721
Sample pred first 3 steps: [[-0.00182919  0.00108631]
 [ 0.00037371 -0.00355579]
 [-0.00366303  0.0040546 ]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]


Epoch:   6%|▌         | 11/200 [03:07<49:33, 15.73s/epoch]

Epoch 011 | LR 0.000540 | Train MSE 1.6168 | Val MSE 1.3820 | Val MAE 4.9133 | Val MSE 67.7196
Validation improved: 1.506161 -> 1.382033


Epoch:   6%|▌         | 12/200 [03:20<47:02, 15.01s/epoch]

Epoch 012 | LR 0.000513 | Train MSE 1.5390 | Val MSE 1.1974 | Val MAE 4.5968 | Val MSE 58.6747
Validation improved: 1.382033 -> 1.197443


Epoch:   6%|▋         | 13/200 [03:36<44:56, 14.42s/epoch]

Epoch 013 | LR 0.000488 | Train MSE 1.2475 | Val MSE 1.0760 | Val MAE 4.1580 | Val MSE 52.7228
Validation improved: 1.197443 -> 1.075974


Epoch:   7%|▋         | 14/200 [03:52<45:42, 14.74s/epoch]

Epoch 014 | LR 0.000463 | Train MSE 1.1295 | Val MSE 0.9635 | Val MAE 3.9970 | Val MSE 47.2096
Validation improved: 1.075974 -> 0.963460


Epoch:   8%|▊         | 15/200 [04:08<46:25, 15.06s/epoch]

Epoch 015 | LR 0.000440 | Train MSE 1.0788 | Val MSE 0.9053 | Val MAE 3.7233 | Val MSE 44.3586
Sample pred first 3 steps: [[ 4.5065195e-03  2.5466410e-04]
 [ 7.6058977e-03 -6.2059262e-06]
 [ 5.7911947e-03  2.8394950e-03]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]
Validation improved: 0.963460 -> 0.905278


Epoch:   8%|▊         | 17/200 [04:23<46:57, 15.40s/epoch]

Epoch 016 | LR 0.000418 | Train MSE 1.0450 | Val MSE 0.9353 | Val MAE 3.8355 | Val MSE 45.8288


Epoch:   8%|▊         | 17/200 [04:39<46:57, 15.40s/epoch]

Epoch 017 | LR 0.000397 | Train MSE 0.9938 | Val MSE 0.8902 | Val MAE 3.7037 | Val MSE 43.6200
Validation improved: 0.905278 -> 0.890203


Epoch:   9%|▉         | 18/200 [04:54<46:37, 15.37s/epoch]

Epoch 018 | LR 0.000377 | Train MSE 0.9985 | Val MSE 0.8747 | Val MAE 3.5998 | Val MSE 42.8612
Validation improved: 0.890203 -> 0.874719


Epoch:  10%|▉         | 19/200 [05:09<46:15, 15.34s/epoch]

Epoch 019 | LR 0.000358 | Train MSE 0.9689 | Val MSE 0.7897 | Val MAE 3.4409 | Val MSE 38.6938
Validation improved: 0.874719 -> 0.789669


Epoch:  10%|█         | 21/200 [05:24<45:27, 15.24s/epoch]

Epoch 020 | LR 0.000341 | Train MSE 0.9688 | Val MSE 0.8121 | Val MAE 3.6639 | Val MSE 39.7934
Sample pred first 3 steps: [[ 0.01275276 -0.0016545 ]
 [ 0.01622185 -0.00195312]
 [ 0.01918722 -0.00222778]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]


Epoch:  11%|█         | 22/200 [05:39<44:58, 15.16s/epoch]

Epoch 021 | LR 0.000324 | Train MSE 0.9279 | Val MSE 0.8563 | Val MAE 3.7105 | Val MSE 41.9563


Epoch:  12%|█▏        | 23/200 [05:54<44:37, 15.13s/epoch]

Epoch 022 | LR 0.000307 | Train MSE 0.9135 | Val MSE 0.8064 | Val MAE 3.4967 | Val MSE 39.5143


Epoch:  12%|█▏        | 23/200 [06:09<44:37, 15.13s/epoch]

Epoch 023 | LR 0.000292 | Train MSE 0.8659 | Val MSE 0.7508 | Val MAE 3.4511 | Val MSE 36.7888
Validation improved: 0.789669 -> 0.750793


Epoch:  12%|█▏        | 24/200 [06:25<44:26, 15.15s/epoch]

Epoch 024 | LR 0.000277 | Train MSE 0.8614 | Val MSE 0.7433 | Val MAE 3.3510 | Val MSE 36.4234
Validation improved: 0.750793 -> 0.743335


Epoch:  12%|█▎        | 25/200 [06:40<44:13, 15.16s/epoch]

Epoch 025 | LR 0.000264 | Train MSE 0.8390 | Val MSE 0.7343 | Val MAE 3.2993 | Val MSE 35.9827
Sample pred first 3 steps: [[-0.00041553  0.00421946]
 [-0.00400456  0.00487614]
 [-0.00513709  0.0056532 ]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]
Validation improved: 0.743335 -> 0.734341


Epoch:  13%|█▎        | 26/200 [06:56<44:27, 15.33s/epoch]

Epoch 026 | LR 0.000250 | Train MSE 0.8458 | Val MSE 0.7129 | Val MAE 3.2776 | Val MSE 34.9296
Validation improved: 0.734341 -> 0.712850


Epoch:  14%|█▍        | 28/200 [07:12<44:19, 15.46s/epoch]

Epoch 027 | LR 0.000238 | Train MSE 0.9466 | Val MSE 1.2321 | Val MAE 4.4193 | Val MSE 60.3709


Epoch:  14%|█▍        | 28/200 [07:27<44:19, 15.46s/epoch]

Epoch 028 | LR 0.000226 | Train MSE 0.8702 | Val MSE 0.6144 | Val MAE 2.8631 | Val MSE 30.1067
Validation improved: 0.712850 -> 0.614423


Epoch:  15%|█▌        | 30/200 [07:43<43:37, 15.40s/epoch]

Epoch 029 | LR 0.000215 | Train MSE 0.7799 | Val MSE 0.6639 | Val MAE 3.1143 | Val MSE 32.5299


Epoch:  16%|█▌        | 31/200 [07:57<42:59, 15.26s/epoch]

Epoch 030 | LR 0.000204 | Train MSE 0.7514 | Val MSE 0.6254 | Val MAE 2.9196 | Val MSE 30.6467
Sample pred first 3 steps: [[ 0.0047074  -0.00137124]
 [ 0.00378647  0.00172291]
 [ 0.00290888  0.00135214]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]


Epoch:  16%|█▌        | 31/200 [08:13<42:59, 15.26s/epoch]

Epoch 031 | LR 0.000194 | Train MSE 0.7218 | Val MSE 0.5858 | Val MAE 2.8983 | Val MSE 28.7054
Validation improved: 0.614423 -> 0.585824


Epoch:  16%|█▋        | 33/200 [08:28<42:21, 15.22s/epoch]

Epoch 032 | LR 0.000184 | Train MSE 0.7052 | Val MSE 0.6123 | Val MAE 3.0292 | Val MSE 30.0006


Epoch:  16%|█▋        | 33/200 [08:43<42:21, 15.22s/epoch]

Epoch 033 | LR 0.000175 | Train MSE 0.6783 | Val MSE 0.5446 | Val MAE 2.6412 | Val MSE 26.6837
Validation improved: 0.585824 -> 0.544565


Epoch:  18%|█▊        | 35/200 [08:58<41:44, 15.18s/epoch]

Epoch 034 | LR 0.000166 | Train MSE 0.6671 | Val MSE 0.5804 | Val MAE 2.8038 | Val MSE 28.4396


Epoch:  18%|█▊        | 36/200 [09:13<41:25, 15.15s/epoch]

Epoch 035 | LR 0.000158 | Train MSE 0.6403 | Val MSE 0.5950 | Val MAE 2.8264 | Val MSE 29.1549
Sample pred first 3 steps: [[ 0.00016016 -0.00101989]
 [-0.00459554 -0.00208651]
 [-0.0044444  -0.00336589]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]


Epoch:  18%|█▊        | 37/200 [09:28<41:09, 15.15s/epoch]

Epoch 036 | LR 0.000150 | Train MSE 0.6434 | Val MSE 0.5522 | Val MAE 2.8075 | Val MSE 27.0565


Epoch:  19%|█▉        | 38/200 [09:44<41:05, 15.22s/epoch]

Epoch 037 | LR 0.000142 | Train MSE 0.6169 | Val MSE 0.5611 | Val MAE 2.8936 | Val MSE 27.4957


Epoch:  19%|█▉        | 38/200 [09:59<41:05, 15.22s/epoch]

Epoch 038 | LR 0.000135 | Train MSE 0.6060 | Val MSE 0.4800 | Val MAE 2.6246 | Val MSE 23.5211
Validation improved: 0.544565 -> 0.480024


Epoch:  20%|█▉        | 39/200 [10:14<40:50, 15.22s/epoch]

Epoch 039 | LR 0.000129 | Train MSE 0.5800 | Val MSE 0.4725 | Val MAE 2.5616 | Val MSE 23.1549
Validation improved: 0.480024 -> 0.472548


Epoch:  20%|██        | 40/200 [10:29<40:31, 15.20s/epoch]

Epoch 040 | LR 0.000122 | Train MSE 0.5693 | Val MSE 0.4570 | Val MAE 2.5096 | Val MSE 22.3925
Sample pred first 3 steps: [[-0.00090663 -0.00169602]
 [-0.00108433  0.00143298]
 [-0.0010651   0.00262947]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]
Validation improved: 0.472548 -> 0.456991


Epoch:  21%|██        | 42/200 [10:44<39:48, 15.12s/epoch]

Epoch 041 | LR 0.000116 | Train MSE 0.5398 | Val MSE 0.5020 | Val MAE 2.7297 | Val MSE 24.6003


Epoch:  22%|██▏       | 43/200 [10:59<39:25, 15.07s/epoch]

Epoch 042 | LR 0.000110 | Train MSE 0.5509 | Val MSE 0.4789 | Val MAE 2.6073 | Val MSE 23.4678


Epoch:  22%|██▏       | 44/200 [11:14<39:05, 15.04s/epoch]

Epoch 043 | LR 0.000105 | Train MSE 0.5398 | Val MSE 0.4793 | Val MAE 2.6301 | Val MSE 23.4854


Epoch:  22%|██▎       | 45/200 [11:29<38:46, 15.01s/epoch]

Epoch 044 | LR 0.000099 | Train MSE 0.5351 | Val MSE 0.4643 | Val MAE 2.6082 | Val MSE 22.7528


Epoch:  22%|██▎       | 45/200 [11:44<38:46, 15.01s/epoch]

Epoch 045 | LR 0.000094 | Train MSE 0.5283 | Val MSE 0.4323 | Val MAE 2.4749 | Val MSE 21.1826
Sample pred first 3 steps: [[0.00140218 0.00146125]
 [0.00562599 0.0027997 ]
 [0.00673854 0.00401624]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]
Validation improved: 0.456991 -> 0.432298


Epoch:  24%|██▎       | 47/200 [12:00<38:34, 15.13s/epoch]

Epoch 046 | LR 0.000090 | Train MSE 0.5255 | Val MSE 0.4444 | Val MAE 2.4218 | Val MSE 21.7779


Epoch:  24%|██▎       | 47/200 [12:15<38:34, 15.13s/epoch]

Epoch 047 | LR 0.000085 | Train MSE 0.5110 | Val MSE 0.4118 | Val MAE 2.3211 | Val MSE 20.1782
Validation improved: 0.432298 -> 0.411800


Epoch:  24%|██▍       | 49/200 [12:30<38:10, 15.17s/epoch]

Epoch 048 | LR 0.000081 | Train MSE 0.4988 | Val MSE 0.4134 | Val MAE 2.3961 | Val MSE 20.2571


Epoch:  25%|██▌       | 50/200 [12:45<37:51, 15.15s/epoch]

Epoch 049 | LR 0.000077 | Train MSE 0.5004 | Val MSE 0.4238 | Val MAE 2.3886 | Val MSE 20.7643


Epoch:  26%|██▌       | 51/200 [13:00<37:29, 15.09s/epoch]

Epoch 050 | LR 0.000073 | Train MSE 0.4992 | Val MSE 0.4367 | Val MAE 2.4783 | Val MSE 21.3976
Sample pred first 3 steps: [[-0.00497061 -0.00055871]
 [-0.00665316 -0.0028048 ]
 [-0.00879699 -0.00330591]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]


Epoch:  26%|██▌       | 52/200 [13:15<37:08, 15.06s/epoch]

Epoch 051 | LR 0.000069 | Train MSE 0.4779 | Val MSE 0.4390 | Val MAE 2.4295 | Val MSE 21.5091


Epoch:  26%|██▋       | 53/200 [13:30<36:50, 15.03s/epoch]

Epoch 052 | LR 0.000066 | Train MSE 0.4771 | Val MSE 0.4414 | Val MAE 2.4862 | Val MSE 21.6267


Epoch:  27%|██▋       | 54/200 [13:45<36:33, 15.02s/epoch]

Epoch 053 | LR 0.000063 | Train MSE 0.4655 | Val MSE 0.4202 | Val MAE 2.3901 | Val MSE 20.5898


Epoch:  28%|██▊       | 55/200 [14:00<36:18, 15.02s/epoch]

Epoch 054 | LR 0.000060 | Train MSE 0.4691 | Val MSE 0.4396 | Val MAE 2.4771 | Val MSE 21.5418


Epoch:  28%|██▊       | 55/200 [14:15<36:18, 15.02s/epoch]

Epoch 055 | LR 0.000057 | Train MSE 0.4585 | Val MSE 0.3893 | Val MAE 2.2428 | Val MSE 19.0758
Sample pred first 3 steps: [[-0.00178396  0.00196259]
 [-0.0042082   0.00260555]
 [-0.00459071  0.00253312]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]
Validation improved: 0.411800 -> 0.389303


Epoch:  28%|██▊       | 57/200 [14:30<35:52, 15.05s/epoch]

Epoch 056 | LR 0.000054 | Train MSE 0.4434 | Val MSE 0.4077 | Val MAE 2.3797 | Val MSE 19.9768


Epoch:  28%|██▊       | 57/200 [14:46<35:52, 15.05s/epoch]

Epoch 057 | LR 0.000051 | Train MSE 0.4392 | Val MSE 0.3848 | Val MAE 2.2520 | Val MSE 18.8562
Validation improved: 0.389303 -> 0.384821


Epoch:  29%|██▉       | 58/200 [15:01<35:52, 15.16s/epoch]

Epoch 058 | LR 0.000048 | Train MSE 0.4361 | Val MSE 0.3710 | Val MAE 2.2127 | Val MSE 18.1813
Validation improved: 0.384821 -> 0.371047


Epoch:  30%|███       | 60/200 [15:16<35:31, 15.23s/epoch]

Epoch 059 | LR 0.000046 | Train MSE 0.4334 | Val MSE 0.3777 | Val MAE 2.2568 | Val MSE 18.5080


Epoch:  30%|███       | 61/200 [15:31<35:03, 15.14s/epoch]

Epoch 060 | LR 0.000044 | Train MSE 0.4253 | Val MSE 0.3949 | Val MAE 2.3163 | Val MSE 19.3507
Sample pred first 3 steps: [[-0.00303947  0.0025835 ]
 [-0.00535966  0.00212732]
 [-0.00590168  0.00283048]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]


Epoch:  30%|███       | 61/200 [15:46<35:03, 15.14s/epoch]

Epoch 061 | LR 0.000042 | Train MSE 0.4261 | Val MSE 0.3701 | Val MAE 2.1973 | Val MSE 18.1372
Validation improved: 0.371047 -> 0.370148


Epoch:  32%|███▏      | 63/200 [16:02<34:33, 15.14s/epoch]

Epoch 062 | LR 0.000039 | Train MSE 0.4162 | Val MSE 0.3748 | Val MAE 2.2373 | Val MSE 18.3632


Epoch:  32%|███▏      | 64/200 [16:18<34:49, 15.36s/epoch]

Epoch 063 | LR 0.000038 | Train MSE 0.4119 | Val MSE 0.3847 | Val MAE 2.2310 | Val MSE 18.8493


Epoch:  32%|███▏      | 64/200 [16:34<34:49, 15.36s/epoch]

Epoch 064 | LR 0.000036 | Train MSE 0.4189 | Val MSE 0.3580 | Val MAE 2.1583 | Val MSE 17.5408
Validation improved: 0.370148 -> 0.357975


Epoch:  33%|███▎      | 66/200 [16:50<35:07, 15.73s/epoch]

Epoch 065 | LR 0.000034 | Train MSE 0.4043 | Val MSE 0.3586 | Val MAE 2.1294 | Val MSE 17.5719
Sample pred first 3 steps: [[-5.8460166e-04  1.4988417e-03]
 [-1.3539868e-03 -3.8844533e-05]
 [-2.4385760e-03 -2.6863907e-04]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]


Epoch:  34%|███▎      | 67/200 [17:05<34:34, 15.60s/epoch]

Epoch 066 | LR 0.000032 | Train MSE 0.4072 | Val MSE 0.3669 | Val MAE 2.2545 | Val MSE 17.9766


Epoch:  34%|███▍      | 68/200 [17:20<33:57, 15.43s/epoch]

Epoch 067 | LR 0.000031 | Train MSE 0.3985 | Val MSE 0.3733 | Val MAE 2.2579 | Val MSE 18.2918


Epoch:  34%|███▍      | 68/200 [17:35<33:57, 15.43s/epoch]

Epoch 068 | LR 0.000029 | Train MSE 0.4021 | Val MSE 0.3559 | Val MAE 2.2000 | Val MSE 17.4387
Validation improved: 0.357975 -> 0.355892


Epoch:  34%|███▍      | 69/200 [17:50<33:36, 15.39s/epoch]

Epoch 069 | LR 0.000028 | Train MSE 0.3870 | Val MSE 0.3535 | Val MAE 2.1269 | Val MSE 17.3201
Validation improved: 0.355892 -> 0.353472


Epoch:  35%|███▌      | 70/200 [18:06<33:17, 15.36s/epoch]

Epoch 070 | LR 0.000026 | Train MSE 0.3904 | Val MSE 0.3518 | Val MAE 2.1390 | Val MSE 17.2394
Sample pred first 3 steps: [[ 0.0002446   0.00086039]
 [ 0.00094784 -0.00051179]
 [ 0.00088959 -0.00086601]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]
Validation improved: 0.353472 -> 0.351824


Epoch:  36%|███▌      | 72/200 [18:21<32:33, 15.26s/epoch]

Epoch 071 | LR 0.000025 | Train MSE 0.3862 | Val MSE 0.3611 | Val MAE 2.2236 | Val MSE 17.6935


Epoch:  36%|███▌      | 72/200 [18:36<32:33, 15.26s/epoch]

Epoch 072 | LR 0.000024 | Train MSE 0.3825 | Val MSE 0.3450 | Val MAE 2.1838 | Val MSE 16.9044
Validation improved: 0.351824 -> 0.344987


Epoch:  37%|███▋      | 74/200 [18:52<32:01, 15.25s/epoch]

Epoch 073 | LR 0.000022 | Train MSE 0.3850 | Val MSE 0.3551 | Val MAE 2.1761 | Val MSE 17.4017


Epoch:  38%|███▊      | 75/200 [19:07<31:37, 15.18s/epoch]

Epoch 074 | LR 0.000021 | Train MSE 0.3783 | Val MSE 0.3612 | Val MAE 2.1855 | Val MSE 17.6977


Epoch:  38%|███▊      | 75/200 [19:22<31:37, 15.18s/epoch]

Epoch 075 | LR 0.000020 | Train MSE 0.3783 | Val MSE 0.3389 | Val MAE 2.1595 | Val MSE 16.6049
Sample pred first 3 steps: [[0.00030915 0.00111041]
 [0.00182978 0.00085519]
 [0.00279294 0.00178951]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]
Validation improved: 0.344987 -> 0.338876


Epoch:  38%|███▊      | 77/200 [19:37<31:05, 15.16s/epoch]

Epoch 076 | LR 0.000019 | Train MSE 0.3822 | Val MSE 0.3402 | Val MAE 2.1237 | Val MSE 16.6686


Epoch:  39%|███▉      | 78/200 [19:52<30:41, 15.09s/epoch]

Epoch 077 | LR 0.000018 | Train MSE 0.3663 | Val MSE 0.3574 | Val MAE 2.1935 | Val MSE 17.5123


Epoch:  39%|███▉      | 78/200 [20:07<30:41, 15.09s/epoch]

Epoch 078 | LR 0.000017 | Train MSE 0.3718 | Val MSE 0.3251 | Val MAE 2.0344 | Val MSE 15.9281
Validation improved: 0.338876 -> 0.325064


Epoch:  40%|████      | 80/200 [20:22<30:11, 15.10s/epoch]

Epoch 079 | LR 0.000017 | Train MSE 0.3749 | Val MSE 0.3259 | Val MAE 2.0556 | Val MSE 15.9676


Epoch:  40%|████      | 81/200 [20:37<29:53, 15.07s/epoch]

Epoch 080 | LR 0.000016 | Train MSE 0.3639 | Val MSE 0.3387 | Val MAE 2.1222 | Val MSE 16.5959
Sample pred first 3 steps: [[-9.0635614e-04 -4.1782390e-05]
 [-1.6203634e-03 -2.8885435e-05]
 [-2.0147082e-03  1.0576285e-03]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]


Epoch:  41%|████      | 82/200 [20:52<29:36, 15.06s/epoch]

Epoch 081 | LR 0.000015 | Train MSE 0.3632 | Val MSE 0.3275 | Val MAE 2.0933 | Val MSE 16.0479


Epoch:  42%|████▏     | 83/200 [21:07<29:21, 15.05s/epoch]

Epoch 082 | LR 0.000014 | Train MSE 0.3633 | Val MSE 0.3262 | Val MAE 2.0695 | Val MSE 15.9862


Epoch:  42%|████▏     | 84/200 [21:22<29:05, 15.05s/epoch]

Epoch 083 | LR 0.000013 | Train MSE 0.3644 | Val MSE 0.3255 | Val MAE 2.0402 | Val MSE 15.9518


Epoch:  42%|████▏     | 84/200 [21:37<29:05, 15.05s/epoch]

Epoch 084 | LR 0.000013 | Train MSE 0.3525 | Val MSE 0.3232 | Val MAE 2.0095 | Val MSE 15.8363
Validation improved: 0.325064 -> 0.323190


Epoch:  42%|████▎     | 85/200 [21:52<28:54, 15.08s/epoch]

Epoch 085 | LR 0.000012 | Train MSE 0.3601 | Val MSE 0.3218 | Val MAE 2.0522 | Val MSE 15.7684
Sample pred first 3 steps: [[ 0.00086512 -0.00019423]
 [ 0.00109792 -0.00037012]
 [ 0.00110378  0.0006399 ]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]
Validation improved: 0.323190 -> 0.321805


Epoch:  43%|████▎     | 86/200 [22:08<28:48, 15.16s/epoch]

Epoch 086 | LR 0.000012 | Train MSE 0.3532 | Val MSE 0.3186 | Val MAE 2.0021 | Val MSE 15.6119
Validation improved: 0.321805 -> 0.318611


Epoch:  44%|████▍     | 88/200 [22:23<28:22, 15.20s/epoch]

Epoch 087 | LR 0.000011 | Train MSE 0.3567 | Val MSE 0.3195 | Val MAE 2.0559 | Val MSE 15.6556


Epoch:  44%|████▍     | 88/200 [22:38<28:22, 15.20s/epoch]

Epoch 088 | LR 0.000010 | Train MSE 0.3550 | Val MSE 0.3154 | Val MAE 2.0415 | Val MSE 15.4526
Validation improved: 0.318611 -> 0.315358


Epoch:  45%|████▌     | 90/200 [22:53<27:46, 15.15s/epoch]

Epoch 089 | LR 0.000010 | Train MSE 0.3565 | Val MSE 0.3330 | Val MAE 2.1301 | Val MSE 16.3176


Epoch:  46%|████▌     | 91/200 [23:09<27:27, 15.12s/epoch]

Epoch 090 | LR 0.000009 | Train MSE 0.3506 | Val MSE 0.3284 | Val MAE 2.1203 | Val MSE 16.0930
Sample pred first 3 steps: [[-0.00051707  0.00014582]
 [-0.00182811 -0.0009081 ]
 [-0.00298635 -0.00039294]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]


Epoch:  46%|████▌     | 91/200 [23:24<27:27, 15.12s/epoch]

Epoch 091 | LR 0.000009 | Train MSE 0.3528 | Val MSE 0.3143 | Val MAE 2.0154 | Val MSE 15.4017
Validation improved: 0.315358 -> 0.314320


Epoch:  46%|████▋     | 93/200 [23:39<27:05, 15.19s/epoch]

Epoch 092 | LR 0.000008 | Train MSE 0.3561 | Val MSE 0.3152 | Val MAE 2.0471 | Val MSE 15.4434


Epoch:  47%|████▋     | 94/200 [23:54<26:44, 15.14s/epoch]

Epoch 093 | LR 0.000008 | Train MSE 0.3476 | Val MSE 0.3198 | Val MAE 2.0513 | Val MSE 15.6720


Epoch:  47%|████▋     | 94/200 [24:09<26:44, 15.14s/epoch]

Epoch 094 | LR 0.000008 | Train MSE 0.3525 | Val MSE 0.3079 | Val MAE 2.0043 | Val MSE 15.0891
Validation improved: 0.314320 -> 0.307940


Epoch:  48%|████▊     | 96/200 [24:24<26:14, 15.14s/epoch]

Epoch 095 | LR 0.000007 | Train MSE 0.3471 | Val MSE 0.3191 | Val MAE 2.0622 | Val MSE 15.6344
Sample pred first 3 steps: [[ 0.00097728 -0.00051585]
 [ 0.00116705 -0.00163357]
 [ 0.00129827 -0.00197452]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]


Epoch:  48%|████▊     | 97/200 [24:39<25:55, 15.10s/epoch]

Epoch 096 | LR 0.000007 | Train MSE 0.3455 | Val MSE 0.3088 | Val MAE 2.0080 | Val MSE 15.1307


Epoch:  49%|████▉     | 98/200 [24:55<25:41, 15.11s/epoch]

Epoch 097 | LR 0.000007 | Train MSE 0.3498 | Val MSE 0.3150 | Val MAE 2.0295 | Val MSE 15.4336


Epoch:  50%|████▉     | 99/200 [25:10<25:28, 15.14s/epoch]

Epoch 098 | LR 0.000006 | Train MSE 0.3474 | Val MSE 0.3153 | Val MAE 2.0317 | Val MSE 15.4477


Epoch:  50%|█████     | 100/200 [25:25<25:12, 15.13s/epoch]

Epoch 099 | LR 0.000006 | Train MSE 0.3529 | Val MSE 0.3211 | Val MAE 2.0747 | Val MSE 15.7343


Epoch:  50%|█████     | 101/200 [25:40<24:55, 15.11s/epoch]

Epoch 100 | LR 0.000006 | Train MSE 0.3411 | Val MSE 0.3134 | Val MAE 2.0268 | Val MSE 15.3572
Sample pred first 3 steps: [[ 1.5621632e-04 -2.4371105e-04]
 [ 8.0689322e-05 -8.8694412e-04]
 [ 1.2816489e-04 -8.3865877e-04]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]


Epoch:  51%|█████     | 102/200 [25:55<24:38, 15.09s/epoch]

Epoch 101 | LR 0.000005 | Train MSE 0.3428 | Val MSE 0.3162 | Val MAE 2.0431 | Val MSE 15.4922


Epoch:  52%|█████▏    | 103/200 [26:10<24:20, 15.06s/epoch]

Epoch 102 | LR 0.000005 | Train MSE 0.3426 | Val MSE 0.3176 | Val MAE 2.0398 | Val MSE 15.5607


Epoch:  52%|█████▏    | 104/200 [26:25<24:04, 15.04s/epoch]

Epoch 103 | LR 0.000005 | Train MSE 0.3492 | Val MSE 0.3186 | Val MAE 2.0584 | Val MSE 15.6115


Epoch:  52%|█████▎    | 105/200 [26:40<23:45, 15.00s/epoch]

Epoch 104 | LR 0.000005 | Train MSE 0.3478 | Val MSE 0.3123 | Val MAE 2.0236 | Val MSE 15.3019


Epoch:  53%|█████▎    | 106/200 [26:55<23:30, 15.01s/epoch]

Epoch 105 | LR 0.000004 | Train MSE 0.3438 | Val MSE 0.3111 | Val MAE 2.0120 | Val MSE 15.2415
Sample pred first 3 steps: [[ 9.1344118e-05  9.7067095e-06]
 [-6.3031237e-04 -1.1428203e-03]
 [-9.8265521e-04 -1.1828663e-03]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]


Epoch:  54%|█████▎    | 107/200 [27:10<23:13, 14.99s/epoch]

Epoch 106 | LR 0.000004 | Train MSE 0.3418 | Val MSE 0.3153 | Val MAE 2.0462 | Val MSE 15.4480


Epoch:  54%|█████▍    | 108/200 [27:25<22:58, 14.98s/epoch]

Epoch 107 | LR 0.000004 | Train MSE 0.3415 | Val MSE 0.3113 | Val MAE 2.0303 | Val MSE 15.2555


Epoch:  55%|█████▍    | 109/200 [27:40<22:41, 14.96s/epoch]

Epoch 108 | LR 0.000004 | Train MSE 0.3448 | Val MSE 0.3120 | Val MAE 2.0209 | Val MSE 15.2868


Epoch:  55%|█████▍    | 109/200 [27:55<22:41, 14.96s/epoch]

Epoch 109 | LR 0.000004 | Train MSE 0.3399 | Val MSE 0.3034 | Val MAE 1.9701 | Val MSE 14.8656
Validation improved: 0.307940 -> 0.303379


Epoch:  56%|█████▌    | 111/200 [28:10<22:16, 15.01s/epoch]

Epoch 110 | LR 0.000003 | Train MSE 0.3413 | Val MSE 0.3070 | Val MAE 1.9874 | Val MSE 15.0416
Sample pred first 3 steps: [[-3.2181456e-04 -6.3979300e-05]
 [-7.1780174e-04 -7.6398812e-04]
 [-9.6686790e-04 -5.7006627e-04]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]


Epoch:  56%|█████▌    | 112/200 [28:25<21:59, 14.99s/epoch]

Epoch 111 | LR 0.000003 | Train MSE 0.3409 | Val MSE 0.3104 | Val MAE 2.0294 | Val MSE 15.2087


Epoch:  56%|█████▋    | 113/200 [28:40<21:43, 14.98s/epoch]

Epoch 112 | LR 0.000003 | Train MSE 0.3416 | Val MSE 0.3163 | Val MAE 2.0214 | Val MSE 15.5003


Epoch:  57%|█████▋    | 114/200 [28:55<21:32, 15.02s/epoch]

Epoch 113 | LR 0.000003 | Train MSE 0.3385 | Val MSE 0.3048 | Val MAE 2.0017 | Val MSE 14.9345


Epoch:  57%|█████▊    | 115/200 [29:10<21:17, 15.03s/epoch]

Epoch 114 | LR 0.000003 | Train MSE 0.3380 | Val MSE 0.3088 | Val MAE 2.0145 | Val MSE 15.1310


Epoch:  58%|█████▊    | 116/200 [29:25<21:04, 15.05s/epoch]

Epoch 115 | LR 0.000003 | Train MSE 0.3407 | Val MSE 0.3072 | Val MAE 2.0270 | Val MSE 15.0506
Sample pred first 3 steps: [[ 6.6948123e-05  2.4849735e-04]
 [-7.2163763e-05 -2.4272129e-05]
 [-2.4799444e-04  3.8562622e-04]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]


Epoch:  58%|█████▊    | 117/200 [29:40<20:47, 15.03s/epoch]

Epoch 116 | LR 0.000002 | Train MSE 0.3367 | Val MSE 0.3053 | Val MAE 1.9909 | Val MSE 14.9605


Epoch:  59%|█████▉    | 118/200 [29:55<20:36, 15.07s/epoch]

Epoch 117 | LR 0.000002 | Train MSE 0.3339 | Val MSE 0.3062 | Val MAE 1.9925 | Val MSE 15.0050


Epoch:  60%|█████▉    | 119/200 [30:10<20:17, 15.03s/epoch]

Epoch 118 | LR 0.000002 | Train MSE 0.3379 | Val MSE 0.3071 | Val MAE 1.9968 | Val MSE 15.0470


Epoch:  60%|██████    | 120/200 [30:25<20:01, 15.02s/epoch]

Epoch 119 | LR 0.000002 | Train MSE 0.3315 | Val MSE 0.3047 | Val MAE 1.9950 | Val MSE 14.9319


Epoch:  60%|██████    | 120/200 [30:40<20:01, 15.02s/epoch]

Epoch 120 | LR 0.000002 | Train MSE 0.3388 | Val MSE 0.3032 | Val MAE 1.9932 | Val MSE 14.8572
Sample pred first 3 steps: [[-2.1946500e-04  2.0264229e-04]
 [-6.6447491e-04 -3.6621839e-04]
 [-9.0735406e-04  1.3708137e-05]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]
Validation improved: 0.303379 -> 0.303207


Epoch:  61%|██████    | 122/200 [30:55<19:30, 15.00s/epoch]

Epoch 121 | LR 0.000002 | Train MSE 0.3378 | Val MSE 0.3080 | Val MAE 2.0141 | Val MSE 15.0942


Epoch:  62%|██████▏   | 123/200 [31:10<19:14, 14.99s/epoch]

Epoch 122 | LR 0.000002 | Train MSE 0.3345 | Val MSE 0.3079 | Val MAE 2.0140 | Val MSE 15.0885


Epoch:  62%|██████▏   | 123/200 [31:25<19:14, 14.99s/epoch]

Epoch 123 | LR 0.000002 | Train MSE 0.3414 | Val MSE 0.3026 | Val MAE 1.9811 | Val MSE 14.8272
Validation improved: 0.303207 -> 0.302595


Epoch:  62%|██████▎   | 125/200 [31:40<18:43, 14.98s/epoch]

Epoch 124 | LR 0.000002 | Train MSE 0.3357 | Val MSE 0.3052 | Val MAE 1.9999 | Val MSE 14.9547


Epoch:  63%|██████▎   | 126/200 [31:55<18:26, 14.95s/epoch]

Epoch 125 | LR 0.000002 | Train MSE 0.3315 | Val MSE 0.3060 | Val MAE 2.0091 | Val MSE 14.9956
Sample pred first 3 steps: [[ 9.4552292e-05  6.2930863e-05]
 [-2.8391019e-04 -7.0445566e-04]
 [-4.0623406e-04 -4.1570328e-04]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]


Epoch:  64%|██████▎   | 127/200 [32:10<18:09, 14.92s/epoch]

Epoch 126 | LR 0.000001 | Train MSE 0.3350 | Val MSE 0.3051 | Val MAE 2.0026 | Val MSE 14.9488


Epoch:  64%|██████▍   | 128/200 [32:25<17:53, 14.91s/epoch]

Epoch 127 | LR 0.000001 | Train MSE 0.3404 | Val MSE 0.3075 | Val MAE 2.0016 | Val MSE 15.0660


Epoch:  64%|██████▍   | 129/200 [32:40<17:39, 14.92s/epoch]

Epoch 128 | LR 0.000001 | Train MSE 0.3345 | Val MSE 0.3063 | Val MAE 1.9960 | Val MSE 15.0076


Epoch:  64%|██████▍   | 129/200 [32:55<17:39, 14.92s/epoch]

Epoch 129 | LR 0.000001 | Train MSE 0.3318 | Val MSE 0.3021 | Val MAE 1.9819 | Val MSE 14.8020
Validation improved: 0.302595 -> 0.302082


Epoch:  66%|██████▌   | 131/200 [33:10<17:16, 15.02s/epoch]

Epoch 130 | LR 0.000001 | Train MSE 0.3380 | Val MSE 0.3034 | Val MAE 1.9765 | Val MSE 14.8648
Sample pred first 3 steps: [[-1.2518745e-04  9.2418399e-05]
 [-5.3275959e-04 -6.0000131e-04]
 [-6.9707911e-04 -2.0813290e-04]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]


Epoch:  66%|██████▌   | 132/200 [33:25<16:59, 15.00s/epoch]

Epoch 131 | LR 0.000001 | Train MSE 0.3307 | Val MSE 0.3092 | Val MAE 2.0121 | Val MSE 15.1503


Epoch:  66%|██████▋   | 133/200 [33:40<16:41, 14.96s/epoch]

Epoch 132 | LR 0.000001 | Train MSE 0.3341 | Val MSE 0.3065 | Val MAE 2.0066 | Val MSE 15.0176


Epoch:  66%|██████▋   | 133/200 [33:55<16:41, 14.96s/epoch]

Epoch 133 | LR 0.000001 | Train MSE 0.3390 | Val MSE 0.3010 | Val MAE 1.9784 | Val MSE 14.7501
Validation improved: 0.302082 -> 0.301023


Epoch:  68%|██████▊   | 135/200 [34:10<16:15, 15.00s/epoch]

Epoch 134 | LR 0.000001 | Train MSE 0.3333 | Val MSE 0.3016 | Val MAE 1.9747 | Val MSE 14.7763


Epoch:  68%|██████▊   | 136/200 [34:25<15:58, 14.98s/epoch]

Epoch 135 | LR 0.000001 | Train MSE 0.3318 | Val MSE 0.3038 | Val MAE 1.9861 | Val MSE 14.8885
Sample pred first 3 steps: [[-0.00031199 -0.00030073]
 [-0.00077044 -0.00108006]
 [-0.00093085 -0.00074896]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]


Epoch:  68%|██████▊   | 136/200 [34:40<15:58, 14.98s/epoch]

Epoch 136 | LR 0.000001 | Train MSE 0.3306 | Val MSE 0.3007 | Val MAE 1.9795 | Val MSE 14.7330
Validation improved: 0.301023 -> 0.300674


Epoch:  69%|██████▉   | 138/200 [34:55<15:31, 15.02s/epoch]

Epoch 137 | LR 0.000001 | Train MSE 0.3346 | Val MSE 0.3049 | Val MAE 1.9928 | Val MSE 14.9418


Epoch:  70%|██████▉   | 139/200 [35:10<15:15, 15.01s/epoch]

Epoch 138 | LR 0.000001 | Train MSE 0.3252 | Val MSE 0.3043 | Val MAE 1.9871 | Val MSE 14.9087


Epoch:  70%|██████▉   | 139/200 [35:25<15:15, 15.01s/epoch]

Epoch 139 | LR 0.000001 | Train MSE 0.3308 | Val MSE 0.3004 | Val MAE 1.9673 | Val MSE 14.7209
Validation improved: 0.300674 -> 0.300427


Epoch:  70%|███████   | 141/200 [35:41<14:53, 15.14s/epoch]

Epoch 140 | LR 0.000001 | Train MSE 0.3377 | Val MSE 0.3052 | Val MAE 1.9932 | Val MSE 14.9532
Sample pred first 3 steps: [[-2.4493784e-05 -3.5407022e-05]
 [-4.0842476e-04 -8.6999591e-04]
 [-5.4350309e-04 -6.6339690e-04]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]


Epoch:  71%|███████   | 142/200 [35:56<14:46, 15.28s/epoch]

Epoch 141 | LR 0.000001 | Train MSE 0.3349 | Val MSE 0.3048 | Val MAE 1.9974 | Val MSE 14.9334


Epoch:  72%|███████▏  | 143/200 [36:12<14:34, 15.34s/epoch]

Epoch 142 | LR 0.000001 | Train MSE 0.3382 | Val MSE 0.3064 | Val MAE 1.9988 | Val MSE 15.0130


Epoch:  72%|███████▏  | 144/200 [36:27<14:23, 15.42s/epoch]

Epoch 143 | LR 0.000001 | Train MSE 0.3294 | Val MSE 0.3025 | Val MAE 1.9828 | Val MSE 14.8229


Epoch:  72%|███████▎  | 145/200 [36:43<14:09, 15.44s/epoch]

Epoch 144 | LR 0.000001 | Train MSE 0.3346 | Val MSE 0.3029 | Val MAE 1.9921 | Val MSE 14.8420


Epoch:  73%|███████▎  | 146/200 [36:58<13:54, 15.46s/epoch]

Epoch 145 | LR 0.000001 | Train MSE 0.3382 | Val MSE 0.3052 | Val MAE 2.0058 | Val MSE 14.9528
Sample pred first 3 steps: [[-0.0001607   0.00016004]
 [-0.00053812 -0.00048666]
 [-0.00065087 -0.00019469]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]


Epoch:  74%|███████▎  | 147/200 [37:14<13:40, 15.48s/epoch]

Epoch 146 | LR 0.000001 | Train MSE 0.3323 | Val MSE 0.3043 | Val MAE 1.9989 | Val MSE 14.9111


Epoch:  74%|███████▍  | 148/200 [37:29<13:25, 15.50s/epoch]

Epoch 147 | LR 0.000001 | Train MSE 0.3311 | Val MSE 0.3030 | Val MAE 1.9924 | Val MSE 14.8446


Epoch:  74%|███████▍  | 149/200 [37:45<13:12, 15.54s/epoch]

Epoch 148 | LR 0.000000 | Train MSE 0.3321 | Val MSE 0.3036 | Val MAE 1.9877 | Val MSE 14.8775


Epoch:  75%|███████▌  | 150/200 [38:01<12:58, 15.58s/epoch]

Epoch 149 | LR 0.000000 | Train MSE 0.3294 | Val MSE 0.3032 | Val MAE 1.9902 | Val MSE 14.8577


Epoch:  76%|███████▌  | 151/200 [38:16<12:35, 15.42s/epoch]

Epoch 150 | LR 0.000000 | Train MSE 0.3323 | Val MSE 0.3047 | Val MAE 1.9935 | Val MSE 14.9313
Sample pred first 3 steps: [[-0.00015335  0.00010495]
 [-0.00077667 -0.00063694]
 [-0.00093852 -0.00039839]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]


Epoch:  76%|███████▌  | 152/200 [38:30<12:12, 15.26s/epoch]

Epoch 151 | LR 0.000000 | Train MSE 0.3309 | Val MSE 0.3034 | Val MAE 1.9921 | Val MSE 14.8677


Epoch:  76%|███████▋  | 153/200 [38:45<11:53, 15.17s/epoch]

Epoch 152 | LR 0.000000 | Train MSE 0.3341 | Val MSE 0.3012 | Val MAE 1.9812 | Val MSE 14.7589


Epoch:  77%|███████▋  | 154/200 [39:01<11:36, 15.14s/epoch]

Epoch 153 | LR 0.000000 | Train MSE 0.3354 | Val MSE 0.3025 | Val MAE 1.9834 | Val MSE 14.8238


Epoch:  78%|███████▊  | 155/200 [39:16<11:18, 15.09s/epoch]

Epoch 154 | LR 0.000000 | Train MSE 0.3375 | Val MSE 0.3034 | Val MAE 1.9897 | Val MSE 14.8652


Epoch:  78%|███████▊  | 156/200 [39:30<11:02, 15.06s/epoch]

Epoch 155 | LR 0.000000 | Train MSE 0.3335 | Val MSE 0.3034 | Val MAE 1.9887 | Val MSE 14.8653
Sample pred first 3 steps: [[-1.8348824e-04  9.5112482e-05]
 [-6.1050593e-04 -6.1955862e-04]
 [-6.7408988e-04 -3.9240252e-04]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]


Epoch:  78%|███████▊  | 157/200 [39:46<10:47, 15.06s/epoch]

Epoch 156 | LR 0.000000 | Train MSE 0.3307 | Val MSE 0.3036 | Val MAE 1.9888 | Val MSE 14.8745


Epoch:  79%|███████▉  | 158/200 [40:01<10:32, 15.06s/epoch]

Epoch 157 | LR 0.000000 | Train MSE 0.3246 | Val MSE 0.3033 | Val MAE 1.9911 | Val MSE 14.8618


Epoch:  80%|███████▉  | 159/200 [40:16<10:16, 15.03s/epoch]

Epoch 158 | LR 0.000000 | Train MSE 0.3364 | Val MSE 0.3020 | Val MAE 1.9813 | Val MSE 14.7986


Epoch:  80%|████████  | 160/200 [40:30<09:58, 14.97s/epoch]

Epoch 159 | LR 0.000000 | Train MSE 0.3377 | Val MSE 0.3015 | Val MAE 1.9789 | Val MSE 14.7743


Epoch:  80%|████████  | 160/200 [40:45<09:58, 14.97s/epoch]

Epoch 160 | LR 0.000000 | Train MSE 0.3277 | Val MSE 0.3001 | Val MAE 1.9743 | Val MSE 14.7040
Sample pred first 3 steps: [[-6.8911584e-05  3.1688483e-05]
 [-4.4203503e-04 -6.4203236e-04]
 [-4.9608992e-04 -4.2981282e-04]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]
Validation improved: 0.300427 -> 0.300082


Epoch:  81%|████████  | 162/200 [41:01<09:30, 15.01s/epoch]

Epoch 161 | LR 0.000000 | Train MSE 0.3286 | Val MSE 0.3031 | Val MAE 1.9876 | Val MSE 14.8535


Epoch:  82%|████████▏ | 163/200 [41:15<09:14, 14.98s/epoch]

Epoch 162 | LR 0.000000 | Train MSE 0.3349 | Val MSE 0.3038 | Val MAE 1.9951 | Val MSE 14.8858


Epoch:  82%|████████▏ | 164/200 [41:30<08:59, 14.98s/epoch]

Epoch 163 | LR 0.000000 | Train MSE 0.3336 | Val MSE 0.3020 | Val MAE 1.9846 | Val MSE 14.7993


Epoch:  82%|████████▎ | 165/200 [41:45<08:43, 14.96s/epoch]

Epoch 164 | LR 0.000000 | Train MSE 0.3331 | Val MSE 0.3029 | Val MAE 1.9844 | Val MSE 14.8402


Epoch:  83%|████████▎ | 166/200 [42:01<08:31, 15.05s/epoch]

Epoch 165 | LR 0.000000 | Train MSE 0.3339 | Val MSE 0.3026 | Val MAE 1.9866 | Val MSE 14.8290
Sample pred first 3 steps: [[-3.3544376e-05  1.3011019e-04]
 [-3.2788748e-04 -5.1451148e-04]
 [-3.5643065e-04 -3.0552037e-04]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]


Epoch:  84%|████████▎ | 167/200 [42:16<08:15, 15.02s/epoch]

Epoch 166 | LR 0.000000 | Train MSE 0.3320 | Val MSE 0.3033 | Val MAE 1.9850 | Val MSE 14.8594


Epoch:  84%|████████▍ | 168/200 [42:31<08:05, 15.18s/epoch]

Epoch 167 | LR 0.000000 | Train MSE 0.3273 | Val MSE 0.3028 | Val MAE 1.9836 | Val MSE 14.8350


Epoch:  84%|████████▍ | 169/200 [42:47<07:54, 15.31s/epoch]

Epoch 168 | LR 0.000000 | Train MSE 0.3345 | Val MSE 0.3022 | Val MAE 1.9828 | Val MSE 14.8066


Epoch:  85%|████████▌ | 170/200 [43:02<07:41, 15.38s/epoch]

Epoch 169 | LR 0.000000 | Train MSE 0.3226 | Val MSE 0.3014 | Val MAE 1.9775 | Val MSE 14.7710


Epoch:  86%|████████▌ | 171/200 [43:18<07:27, 15.44s/epoch]

Epoch 170 | LR 0.000000 | Train MSE 0.3323 | Val MSE 0.3017 | Val MAE 1.9780 | Val MSE 14.7809
Sample pred first 3 steps: [[-0.00016326  0.00010219]
 [-0.00055592 -0.00057562]
 [-0.00062888 -0.0003458 ]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]


Epoch:  86%|████████▌ | 172/200 [43:33<07:14, 15.50s/epoch]

Epoch 171 | LR 0.000000 | Train MSE 0.3329 | Val MSE 0.3023 | Val MAE 1.9828 | Val MSE 14.8150


Epoch:  86%|████████▋ | 173/200 [43:49<06:59, 15.54s/epoch]

Epoch 172 | LR 0.000000 | Train MSE 0.3305 | Val MSE 0.3022 | Val MAE 1.9819 | Val MSE 14.8063


Epoch:  87%|████████▋ | 174/200 [44:05<06:44, 15.56s/epoch]

Epoch 173 | LR 0.000000 | Train MSE 0.3321 | Val MSE 0.3037 | Val MAE 1.9898 | Val MSE 14.8799


Epoch:  88%|████████▊ | 175/200 [44:20<06:29, 15.56s/epoch]

Epoch 174 | LR 0.000000 | Train MSE 0.3341 | Val MSE 0.3034 | Val MAE 1.9880 | Val MSE 14.8674


Epoch:  88%|████████▊ | 176/200 [44:36<06:14, 15.60s/epoch]

Epoch 175 | LR 0.000000 | Train MSE 0.3270 | Val MSE 0.3029 | Val MAE 1.9858 | Val MSE 14.8420
Sample pred first 3 steps: [[-0.00024054  0.00011597]
 [-0.00068239 -0.00060324]
 [-0.00078892 -0.00036438]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]


Epoch:  88%|████████▊ | 177/200 [44:52<05:59, 15.62s/epoch]

Epoch 176 | LR 0.000000 | Train MSE 0.3253 | Val MSE 0.3016 | Val MAE 1.9817 | Val MSE 14.7766


Epoch:  89%|████████▉ | 178/200 [45:07<05:43, 15.59s/epoch]

Epoch 177 | LR 0.000000 | Train MSE 0.3303 | Val MSE 0.3012 | Val MAE 1.9780 | Val MSE 14.7569


Epoch:  90%|████████▉ | 179/200 [45:23<05:27, 15.59s/epoch]

Epoch 178 | LR 0.000000 | Train MSE 0.3306 | Val MSE 0.3021 | Val MAE 1.9810 | Val MSE 14.8045


Epoch:  90%|█████████ | 180/200 [45:38<05:09, 15.47s/epoch]

Epoch 179 | LR 0.000000 | Train MSE 0.3263 | Val MSE 0.3024 | Val MAE 1.9826 | Val MSE 14.8178


Epoch:  90%|█████████ | 181/200 [45:53<04:51, 15.32s/epoch]

Epoch 180 | LR 0.000000 | Train MSE 0.3324 | Val MSE 0.3016 | Val MAE 1.9810 | Val MSE 14.7803
Sample pred first 3 steps: [[-0.00020939  0.00013338]
 [-0.00064006 -0.00054861]
 [-0.00074368 -0.00030851]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]


Epoch:  91%|█████████ | 182/200 [46:08<04:33, 15.21s/epoch]

Epoch 181 | LR 0.000000 | Train MSE 0.3293 | Val MSE 0.3012 | Val MAE 1.9784 | Val MSE 14.7568


Epoch:  92%|█████████▏| 183/200 [46:23<04:17, 15.15s/epoch]

Epoch 182 | LR 0.000000 | Train MSE 0.3337 | Val MSE 0.3013 | Val MAE 1.9796 | Val MSE 14.7659


Epoch:  92%|█████████▏| 184/200 [46:38<04:01, 15.07s/epoch]

Epoch 183 | LR 0.000000 | Train MSE 0.3378 | Val MSE 0.3018 | Val MAE 1.9825 | Val MSE 14.7876


Epoch:  92%|█████████▎| 185/200 [46:53<03:45, 15.02s/epoch]

Epoch 184 | LR 0.000000 | Train MSE 0.3337 | Val MSE 0.3019 | Val MAE 1.9823 | Val MSE 14.7908


Epoch:  93%|█████████▎| 186/200 [47:08<03:29, 14.99s/epoch]

Epoch 185 | LR 0.000000 | Train MSE 0.3319 | Val MSE 0.3018 | Val MAE 1.9810 | Val MSE 14.7868
Sample pred first 3 steps: [[-0.00026315  0.00013175]
 [-0.00074306 -0.00059069]
 [-0.00086887 -0.00033654]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]


Epoch:  94%|█████████▎| 187/200 [47:23<03:14, 14.98s/epoch]

Epoch 186 | LR 0.000000 | Train MSE 0.3300 | Val MSE 0.3009 | Val MAE 1.9772 | Val MSE 14.7443


Epoch:  94%|█████████▍| 188/200 [47:37<02:59, 14.97s/epoch]

Epoch 187 | LR 0.000000 | Train MSE 0.3302 | Val MSE 0.3011 | Val MAE 1.9784 | Val MSE 14.7552


Epoch:  94%|█████████▍| 189/200 [47:52<02:44, 14.97s/epoch]

Epoch 188 | LR 0.000000 | Train MSE 0.3288 | Val MSE 0.3018 | Val MAE 1.9817 | Val MSE 14.7897


Epoch:  95%|█████████▌| 190/200 [48:07<02:29, 14.96s/epoch]

Epoch 189 | LR 0.000000 | Train MSE 0.3318 | Val MSE 0.3021 | Val MAE 1.9821 | Val MSE 14.8042


Epoch:  95%|█████████▌| 190/200 [48:22<02:32, 15.28s/epoch]


Epoch 190 | LR 0.000000 | Train MSE 0.3352 | Val MSE 0.3014 | Val MAE 1.9794 | Val MSE 14.7690
Sample pred first 3 steps: [[-0.00020279  0.00013119]
 [-0.00065784 -0.00057035]
 [-0.00076703 -0.00032313]]
Sample target first 3 steps: [[3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]
 [3.2984033e-06 6.9473140e-06]]
Early stopping after 191 epochs without improvement
Val MSE: 14.7040


SimpleLSTM(
  (lstm1): LSTM(6, 512, num_layers=6, batch_first=True)
  (lstm2): LSTM(64, 128, num_layers=8, batch_first=True)
  (compress1): Linear(in_features=300, out_features=512, bias=True)
  (compress2): Linear(in_features=512, out_features=512, bias=True)
  (compress3): Linear(in_features=512, out_features=512, bias=True)
  (compress4): Linear(in_features=512, out_features=256, bias=True)
  (compress5): Linear(in_features=256, out_features=128, bias=True)
  (compress6): Linear(in_features=128, out_features=64, bias=True)
  (fc1): Linear(in_features=640, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=2048, bias=True)
  (fc3): Linear(in_features=2048, out_features=2048, bias=True)
  (fc4): Linear(in_features=2048, out_features=512, bias=True)
  (fc5): Linear(in_features=512, out_features=512, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (relu): ReLU()
  (fc6): Linear(in_features=512, out_features=120, bias=True)
)

# Final Pred

In [16]:
test_dataset = TrajectoryDatasetTest(test_data, scale=scale)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False,
                         collate_fn=lambda xs: Batch.from_data_list(xs))

best_model = torch.load("best_model.pt")
model = SimpleLSTM().to(device)
# optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.25) # You can try different schedulers
# criterion = nn.MSELoss()

model.load_state_dict(best_model)
model.eval()

pred_list = []
with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)
        pred_norm = model(batch)
        
        # Reshape the prediction to (N, 60, 2)
        pred = pred_norm * batch.scale.view(-1,1,1) + batch.origin.unsqueeze(1)
        pred_list.append(pred.cpu().numpy())
pred_list = np.concatenate(pred_list, axis=0)  # (N,60,2)
pred_output = pred_list.reshape(-1, 2)  # (N*60, 2)
output_df = pd.DataFrame(pred_output, columns=['x', 'y'])
output_df.index.name = 'index'
output_df.to_csv('submission_lstm_simple_auto2.csv', index=True)