In [12]:
best_model = 'spatial_temporal_lstm_large.pth'

In [13]:
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data, Batch
import tqdm

In [14]:
train_file = np.load('../cse-251-b-2025/train.npz')

train_data = train_file['data']
# train_data = train_data[::2]
print("train_data's shape", train_data.shape)
test_file = np.load('../cse-251-b-2025/test_input.npz')

test_data = test_file['data']
print("test_data's shape", test_data.shape)

train_data's shape (10000, 50, 110, 6)
test_data's shape (2100, 50, 50, 6)


In [15]:
class TrajectoryDatasetTrain(Dataset):
    def __init__(self, data, scale=10.0, augment=True):
        """
        data: Shape (N, 50, 110, 6) Training data
        scale: Scale for normalization (suggested to use 10.0 for Argoverse 2 data)
        augment: Whether to apply data augmentation (only for training)
        """
        self.data = data
        self.scale = scale
        self.augment = augment

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        scene = self.data[idx]
        # Getting 50 historical timestamps and 60 future timestamps
        hist = scene[:, :50, :].copy()    # (agents=50, time_seq=50, 6)
        future = torch.tensor(scene[0, 50:, :2].copy(), dtype=torch.float32)  # (50, 60, 2)
        #add the feature of the scene number for each sample
       
 
        # Data augmentation(only for training)
        if self.augment:
            if np.random.rand() < 0.5:
                theta = np.random.uniform(-np.pi, np.pi)
                R = np.array([[np.cos(theta), -np.sin(theta)],
                              [np.sin(theta),  np.cos(theta)]], dtype=np.float32)
                # Rotate the historical trajectory and future trajectory
                hist[..., :2] = hist[..., :2] @ R
                hist[..., 2:4] = hist[..., 2:4] @ R
                future = future @ R
                # future[..., 2:4] = future[..., 2:4] @ R
            if np.random.rand() < 0.5:
                hist[..., 0] *= -1
                hist[..., 2] *= -1
                future[:, 0] *= -1
                # future[:, 2] *= -1

        # Use the last timeframe of the historical trajectory as the origin
        origin = hist[0, 49, :2].copy()  # (2,)
        hist[..., :2] = hist[..., :2] - origin
        # future[..., :2] = future[..., :2] - origin
        future = future - origin

        # Normalize the historical trajectory and future trajectory
        hist[..., :4] = hist[..., :4] / self.scale
        future = future / self.scale
        # hist[..., :2] = hist[..., :2] / self.scale
        # future[..., :2] = future[..., :2] / self.scale

        
        # print("hist's shape", hist.shape)
        data_item = Data(
            x=torch.tensor(hist, dtype=torch.float32),
            y=future.type(torch.float32),
            origin=torch.tensor(origin, dtype=torch.float32).unsqueeze(0), # (1,2)
            scale=torch.tensor(self.scale, dtype=torch.float32), # scalar e.g. 7.0
        )

        return data_item
    

class TrajectoryDatasetTest(Dataset):
    def __init__(self, data, scale=10.0):
        """
        data: Shape (N, 50, 110, 6) Testing data
        scale: Scale for normalization (suggested to use 10.0 for Argoverse 2 data)
        """
        self.data = data
        self.scale = scale

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Testing data only contains historical trajectory
        scene = self.data[idx]  # (50, 50, 6)
        hist = scene.copy()
        # hist = hist[...,]
        
        origin = hist[0, 49, :2].copy()
        hist[..., :2] = hist[..., :2] - origin
        hist[..., :4] = hist[..., :4] / self.scale
        hist[..., :2] = hist[..., :2] / self.scale

        data_item = Data(
            x=torch.tensor(hist, dtype=torch.float32),
            origin=torch.tensor(origin, dtype=torch.float32).unsqueeze(0),
            scale=torch.tensor(self.scale, dtype=torch.float32),
        )
        return data_item

In [16]:
torch.manual_seed(251)
np.random.seed(42)

scale = 10.0 #why not 10

N = len(train_data)
val_size = int(0.1 * N)
train_size = N - val_size

train_dataset = TrajectoryDatasetTrain(train_data[:train_size], scale=scale, augment=True)
val_dataset = TrajectoryDatasetTrain(train_data[train_size:], scale=scale, augment=False)

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=lambda x: Batch.from_data_list(x))
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=lambda x: Batch.from_data_list(x))

# Set device for training speedup
if torch.backends.mps.is_available():
    device = torch.device('mps')
    print("Using Apple Silicon GPU")
elif torch.cuda.is_available():
    device = torch.device('cuda')
    print("Using CUDA GPU")
else:
    device = torch.device('cpu')

Using Apple Silicon GPU


In [17]:
class AutoRegressiveLSTM(nn.Module):
    def __init__(self, A=50, T=50, input_dim=5, hidden_dim=512, hidden_dim_2 = 512, output_dim=2, num_layers=1, future_steps=60):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.future_steps = future_steps
        self.A, self.T = A, T
        self.input_proj = nn.Linear(input_dim, hidden_dim_2)
        self.spatial_embed = nn.Embedding(A, hidden_dim_2)
        self.temporal_embed = nn.Embedding(T, hidden_dim_2)
        # Encoder: takes in past trajectory
        self.input_ln = nn.LayerNorm(hidden_dim_2)
        self.encoder = nn.LSTM(input_size=hidden_dim_2, hidden_size=hidden_dim, num_layers=num_layers, dropout= 0.2, batch_first=True)

        # Decoder: predicts future positions one step at a time
        self.decoder = nn.LSTM(input_size=output_dim, hidden_size=hidden_dim, num_layers=num_layers, dropout= 0.2,  batch_first=True)
        self.decoder_ln = nn.LayerNorm(hidden_dim)

        self.out = nn.Linear(hidden_dim, output_dim)

    def forward(self, data, forcing_ratio = 0.5):

        x = data.x[..., :self.input_dim] 
        x = x.reshape(-1, 50, 50, self.input_dim) # (B*A, 50, 5)
        # x = x.reshape(-1, 50, self.input_dim)
        batch_size = x.size(0)

        B, A, T, F = x.shape
        x = x.view(B, A * T, F)  # [B, A*T, F]
        x = self.input_proj(x)  # [B, A*T, d_model]
        x = self.input_ln(x)  # [B, A*T, d_model]

        # Positional encodings
        agent_ids = torch.arange(A, device=x.device).repeat_interleave(T)  # [A*T]
        time_ids = torch.arange(T, device=x.device).repeat(A)              # [A*T]

        pe = self.spatial_embed(agent_ids) + self.temporal_embed(time_ids)  # [A*T, d_model]
        x = x + pe.unsqueeze(0)
        # print('attn done')
        if self.training:
            future = data.y.view(batch_size, 60, self.output_dim) # (batch, 60, 2)

        device = x.device

        # Encode past
        _, (hidden, cell) = self.encoder(x)
        # print('encoder done')

        # Initialize decoder input with last observed position
        # decoder_input = x.view(B,self.A,self.T,-1)[:, 0, -1, :self.output_dim]  # (batch, 1, 2)
        decoder_input = data.x.view(-1, 50, 50, self.input_dim)[:, 0, -1, :self.output_dim].unsqueeze(1)
        
        # print("decoder_input.shape - initial", decoder_input.shape)  # should be (batch, 1, 2)

        outputs = []

        for t in range(self.future_steps):
            output, (hidden, cell) = self.decoder(decoder_input, (hidden, cell))
            output = self.decoder_ln(output)  # Apply layer normalization to the decoder output
            pred = self.out(output)  # (batch, 1, 2)
            outputs.append(pred)

            # TODO: remove forcing ratio?
            if self.training and random.random() < forcing_ratio:
            # if self.training:
                decoder_input = future[:, t].unsqueeze(1)  # ground truth
                # print("decoder_input.shape - teacher forcing", decoder_input.shape)  # should be (batch, 1, 2)
            else:
                decoder_input = pred.detach()  # predicted output as next input
                # print("decoder_input.shape - autoreg", decoder_input.shape)  # should be (batch, 1, 2)

        outputs = torch.cat(outputs, dim=1)  # (batch, 60, 2)
        return outputs

In [18]:
def train_improved_model(model, train_dataloader, val_dataloader, 
                         device, criterion=nn.MSELoss(), 
                         lr=0.001, epochs=100, patience=15):
    """
    Improved training function with better debugging and early stopping
    """
    # Initialize optimizer with smaller learning rate
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    # Exponential decay scheduler
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
    
    early_stopping_patience = patience
    best_val_loss = float('inf')
    no_improvement = 0
    
    # Save initial state for comparison
    initial_state_dict = {k: v.clone() for k, v in model.state_dict().items()}
    
    for epoch in tqdm.tqdm(range(epochs), desc="Epoch", unit="epoch"):
        # ---- Training ----
        model.train()
        train_loss = 0
        num_train_batches = 0
        forcing_ratio = max(0.0, 1.0 - epoch / 50)
        
        for batch in train_dataloader:
            batch = batch.to(device)
            pred = model(batch, forcing_ratio=forcing_ratio)
            y = batch.y.view(batch.num_graphs, 60, 2)
            
            # Check for NaN predictions
            if torch.isnan(pred).any():
                print(f"WARNING: NaN detected in predictions during training")
                continue
                
            loss = criterion(pred, y)
            
            # Check if loss is valid
            if torch.isnan(loss) or torch.isinf(loss):
                print(f"WARNING: Invalid loss value: {loss.item()}")
                continue
                
            optimizer.zero_grad()
            loss.backward()
            
            # More conservative gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            
            optimizer.step()
            train_loss += loss.item()
            num_train_batches += 1
        
        # Skip epoch if no valid batches
        if num_train_batches == 0:
            print("WARNING: No valid training batches in this epoch")
            continue
            
        train_loss /= num_train_batches
        
        # ---- Validation ----
        model.eval()
        val_loss = 0
        val_mae = 0
        val_mse = 0
        num_val_batches = 0
        
        # Sample predictions for debugging
        sample_input = None
        sample_pred = None
        sample_target = None
        
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_dataloader):
                batch = batch.to(device)
                pred = model(batch)
                y = batch.y.view(batch.num_graphs, 60, 2)
                
                # Store sample for debugging
                if batch_idx == 0 and sample_input is None:
                    sample_input = batch.x[0].cpu().numpy()
                    sample_pred = pred[0].cpu().numpy()
                    sample_target = y[0].cpu().numpy()
                
                # Skip invalid predictions
                if torch.isnan(pred).any():
                    print(f"WARNING: NaN detected in predictions during validation")
                    continue
                    
                batch_loss = criterion(pred, y).item()
                val_loss += batch_loss
                
                # Unnormalize for real-world metrics
                # batch.scale turns scale from 7.0 or (1,) shape i.e. scalar to (B,) shape
                # batch.origin turns origin from (1,2) shape to (B,2)
                
                # then .view(-1, 1, 1) turns scale from (B,) to (B, 1, 1)
                # then .unsqueeze(1) turns origin from (B, 2) to (B, 1, 2)
                # because pred and y have shapes (B, 60, 2) so these transformations make them compatible for the calculation
                
                pred_unnorm = pred * batch.scale.view(-1, 1, 1) + batch.origin.unsqueeze(1)
                y_unnorm = y * batch.scale.view(-1, 1, 1) + batch.origin.unsqueeze(1)
                
                val_mae += nn.L1Loss()(pred_unnorm, y_unnorm).item()
                val_mse += nn.MSELoss()(pred_unnorm, y_unnorm).item()
                
                num_val_batches += 1
        
        # Skip epoch if no valid validation batches
        if num_val_batches == 0:
            print("WARNING: No valid validation batches in this epoch")
            continue
            
        val_loss /= num_val_batches
        val_mae /= num_val_batches
        val_mse /= num_val_batches
        
        # Update learning rate
        scheduler.step()
        
        # Print with more details
        tqdm.tqdm.write(
            f"Epoch {epoch:03d} | LR {optimizer.param_groups[0]['lr']:.6f} | "
            f"Train MSE {train_loss:.4f} | Val MSE {val_loss:.4f} | "
            f"Val MAE {val_mae:.4f} | Val MSE {val_mse:.4f}"
        )
        
        # Debug output - first 3 predictions vs targets
        if epoch % 5 == 0:
            tqdm.tqdm.write(f"Sample pred first 3 steps: {sample_pred[:3]}")
            tqdm.tqdm.write(f"Sample target first 3 steps: {sample_target[:3]}")
            
            # Check if model weights are changing
            if epoch > 0:
                weight_change = False
                for name, param in model.named_parameters():
                    if param.requires_grad:
                        initial_param = initial_state_dict[name]
                        if not torch.allclose(param, initial_param, rtol=1e-4):
                            weight_change = True
                            break
                if not weight_change:
                    tqdm.tqdm.write("WARNING: Model weights barely changing!")
        
        # Relaxed improvement criterion - consider any improvement
        if val_loss < best_val_loss:
            tqdm.tqdm.write(f"Validation improved: {best_val_loss:.6f} -> {val_loss:.6f}")
            best_val_loss = val_loss
            no_improvement = 0
            torch.save(model.state_dict(), best_model)
        else:
            no_improvement += 1
            if no_improvement >= early_stopping_patience:
                print(f"Early stopping after {epoch+1} epochs without improvement")
                break
    
    # Load best model before returning
    model.load_state_dict(torch.load(best_model))
    return model

In [19]:
# Example usage
def train_and_evaluate_model():
    # Create model
    model = AutoRegressiveLSTM(input_dim=6, output_dim = 2)
    model = model.to(device)
    
    # Train with improved function
    train_improved_model(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        device=device,
        # lr = 0.007 => 8.946
        lr=0.007,  # Lower learning rate
        patience=20,  # More patience
        epochs=200
    )
    
    # Evaluate
    model.eval()
    test_mse = 0
    with torch.no_grad():
        for batch in val_dataloader:
            batch = batch.to(device)
            pred = model(batch)
            y = batch.y.view(batch.num_graphs, 60, 2)
            
            # Unnormalize
            pred = pred * batch.scale.view(-1, 1, 1) + batch.origin.unsqueeze(1)
            y = y * batch.scale.view(-1, 1, 1) + batch.origin.unsqueeze(1)
            
            test_mse += nn.MSELoss()(pred, y).item()
    
    test_mse /= len(val_dataloader)
    print(f"Val MSE: {test_mse:.4f}")
    
    return model

In [None]:
train_and_evaluate_model()

  future = future @ R
  future = future - origin
Epoch:   0%|          | 1/200 [02:44<9:04:02, 164.03s/epoch]

Epoch 000 | LR 0.006650 | Train MSE 0.2396 | Val MSE 18.8034 | Val MAE 34.4837 | Val MSE 1880.3400
Sample pred first 3 steps: [[0.05494022 0.02282789]
 [0.01379648 0.05335474]
 [0.08879995 0.12638728]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]
Validation improved: inf -> 18.803400


Epoch:   1%|          | 2/200 [05:26<8:58:59, 163.33s/epoch]

Epoch 001 | LR 0.006317 | Train MSE 0.0058 | Val MSE 8.8346 | Val MAE 22.3609 | Val MSE 883.4586
Validation improved: 18.803400 -> 8.834586


Epoch:   2%|▏         | 3/200 [08:13<9:00:53, 164.74s/epoch]

Epoch 002 | LR 0.006002 | Train MSE 0.0061 | Val MSE 13.1151 | Val MAE 26.1410 | Val MSE 1311.5093


Epoch:   2%|▏         | 4/200 [10:59<8:59:55, 165.28s/epoch]

Epoch 003 | LR 0.005702 | Train MSE 0.0030 | Val MSE 19.8297 | Val MAE 33.6215 | Val MSE 1982.9686


Epoch:   2%|▎         | 5/200 [13:42<8:54:19, 164.41s/epoch]

Epoch 004 | LR 0.005416 | Train MSE 0.0012 | Val MSE 3.9946 | Val MAE 12.8873 | Val MSE 399.4601
Validation improved: 8.834586 -> 3.994601


Epoch:   3%|▎         | 6/200 [16:25<8:49:57, 163.90s/epoch]

Epoch 005 | LR 0.005146 | Train MSE 0.0032 | Val MSE 36.1791 | Val MAE 46.9553 | Val MSE 3617.9079
Sample pred first 3 steps: [[ 0.02806851  0.01300588]
 [ 0.12000111 -0.10407393]
 [ 0.11067808 -0.03585684]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]


Epoch:   4%|▎         | 7/200 [19:07<8:45:44, 163.44s/epoch]

Epoch 006 | LR 0.004888 | Train MSE 0.0037 | Val MSE 37.3550 | Val MAE 45.8764 | Val MSE 3735.4987


Epoch:   4%|▍         | 8/200 [21:51<8:43:18, 163.53s/epoch]

Epoch 007 | LR 0.004644 | Train MSE 0.0011 | Val MSE 6.9861 | Val MAE 19.0316 | Val MSE 698.6110


Epoch:   4%|▍         | 9/200 [24:35<8:40:46, 163.60s/epoch]

Epoch 008 | LR 0.004412 | Train MSE 0.0007 | Val MSE 8.6127 | Val MAE 20.6516 | Val MSE 861.2707


Epoch:   5%|▌         | 10/200 [27:18<8:37:31, 163.43s/epoch]

Epoch 009 | LR 0.004191 | Train MSE 0.0012 | Val MSE 22.3125 | Val MAE 35.4896 | Val MSE 2231.2521


Epoch:   6%|▌         | 11/200 [30:01<8:34:42, 163.40s/epoch]

Epoch 010 | LR 0.003982 | Train MSE 0.0032 | Val MSE 20.2503 | Val MAE 29.2011 | Val MSE 2025.0313
Sample pred first 3 steps: [[-0.01030119  0.02718644]
 [-0.03683445  0.01592194]
 [-0.07611015  0.02644658]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]


Epoch:   6%|▌         | 12/200 [32:45<8:32:52, 163.68s/epoch]

Epoch 011 | LR 0.003783 | Train MSE 0.0025 | Val MSE 21.0091 | Val MAE 33.2210 | Val MSE 2100.9137


Epoch:   6%|▋         | 13/200 [35:27<8:28:30, 163.16s/epoch]

Epoch 012 | LR 0.003593 | Train MSE 0.0024 | Val MSE 4.2677 | Val MAE 14.1071 | Val MSE 426.7670


Epoch:   7%|▋         | 14/200 [38:11<8:26:25, 163.36s/epoch]

Epoch 013 | LR 0.003414 | Train MSE 0.0008 | Val MSE 4.4498 | Val MAE 14.4275 | Val MSE 444.9769


Epoch:   8%|▊         | 15/200 [40:53<8:22:44, 163.05s/epoch]

Epoch 014 | LR 0.003243 | Train MSE 0.0011 | Val MSE 4.5018 | Val MAE 15.0848 | Val MSE 450.1789


Epoch:   8%|▊         | 16/200 [43:38<8:21:16, 163.46s/epoch]

Epoch 015 | LR 0.003081 | Train MSE 0.0011 | Val MSE 6.2685 | Val MAE 17.7532 | Val MSE 626.8479
Sample pred first 3 steps: [[ 0.04307085  0.02704006]
 [ 0.08728755 -0.00918118]
 [ 0.11479218 -0.00895115]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]


Epoch:   8%|▊         | 17/200 [46:23<8:20:02, 163.95s/epoch]

Epoch 016 | LR 0.002927 | Train MSE 0.0023 | Val MSE 24.9990 | Val MAE 36.8748 | Val MSE 2499.8967


Epoch:   9%|▉         | 18/200 [49:08<8:18:40, 164.40s/epoch]

Epoch 017 | LR 0.002781 | Train MSE 0.0016 | Val MSE 10.9923 | Val MAE 20.4470 | Val MSE 1099.2271


Epoch:  10%|▉         | 19/200 [51:52<8:15:15, 164.17s/epoch]

Epoch 018 | LR 0.002641 | Train MSE 0.0014 | Val MSE 5.2278 | Val MAE 16.1437 | Val MSE 522.7828


Epoch:  10%|█         | 20/200 [54:35<8:11:24, 163.80s/epoch]

Epoch 019 | LR 0.002509 | Train MSE 0.0020 | Val MSE 5.4450 | Val MAE 16.2159 | Val MSE 544.4980


Epoch:  10%|█         | 21/200 [57:18<8:08:08, 163.62s/epoch]

Epoch 020 | LR 0.002384 | Train MSE 0.0018 | Val MSE 2.8436 | Val MAE 10.0806 | Val MSE 284.3613
Sample pred first 3 steps: [[-0.02085766  0.01110913]
 [-0.03967725 -0.01223704]
 [-0.07659718 -0.00226747]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]
Validation improved: 3.994601 -> 2.843613


Epoch:  11%|█         | 22/200 [1:00:01<8:04:36, 163.35s/epoch]

Epoch 021 | LR 0.002265 | Train MSE 0.0010 | Val MSE 3.8503 | Val MAE 12.8145 | Val MSE 385.0322


Epoch:  12%|█▏        | 23/200 [1:02:42<8:00:10, 162.77s/epoch]

Epoch 022 | LR 0.002151 | Train MSE 0.0019 | Val MSE 3.7894 | Val MAE 12.7689 | Val MSE 378.9438


Epoch:  12%|█▏        | 24/200 [1:05:23<7:55:56, 162.25s/epoch]

Epoch 023 | LR 0.002044 | Train MSE 0.0012 | Val MSE 3.3622 | Val MAE 11.7717 | Val MSE 336.2155


Epoch:  12%|█▎        | 25/200 [1:23:18<21:11:37, 435.99s/epoch]

Epoch 024 | LR 0.001942 | Train MSE 0.0030 | Val MSE 4.1491 | Val MAE 13.6776 | Val MSE 414.9149


Epoch:  13%|█▎        | 26/200 [1:29:29<20:07:39, 416.43s/epoch]

Epoch 025 | LR 0.001845 | Train MSE 0.0018 | Val MSE 2.6780 | Val MAE 9.9195 | Val MSE 267.7980
Sample pred first 3 steps: [[0.00486567 0.00962381]
 [0.01451507 0.02260903]
 [0.00532266 0.03676297]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]
Validation improved: 2.843613 -> 2.677980


Epoch:  14%|█▎        | 27/200 [1:32:12<16:21:45, 340.50s/epoch]

Epoch 026 | LR 0.001752 | Train MSE 0.0018 | Val MSE 7.8537 | Val MAE 19.8888 | Val MSE 785.3737


Epoch:  14%|█▍        | 28/200 [1:34:54<13:42:44, 287.00s/epoch]

Epoch 027 | LR 0.001665 | Train MSE 0.0073 | Val MSE 1.9552 | Val MAE 10.0167 | Val MSE 195.5237
Validation improved: 2.677980 -> 1.955237


Epoch:  14%|█▍        | 29/200 [1:37:36<11:50:44, 249.38s/epoch]

Epoch 028 | LR 0.001582 | Train MSE 0.0005 | Val MSE 0.5890 | Val MAE 4.3819 | Val MSE 58.9003
Validation improved: 1.955237 -> 0.589003


Epoch:  15%|█▌        | 30/200 [1:40:18<10:32:47, 223.34s/epoch]

Epoch 029 | LR 0.001502 | Train MSE 0.0006 | Val MSE 2.6787 | Val MAE 10.6505 | Val MSE 267.8700


Epoch:  16%|█▌        | 31/200 [1:43:01<9:38:02, 205.22s/epoch] 

Epoch 030 | LR 0.001427 | Train MSE 0.0008 | Val MSE 0.5285 | Val MAE 4.3039 | Val MSE 52.8470
Sample pred first 3 steps: [[-0.00880465 -0.00278153]
 [-0.00347604  0.000955  ]
 [-0.00173968 -0.00324715]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]
Validation improved: 0.589003 -> 0.528470


Epoch:  16%|█▌        | 32/200 [1:45:45<8:59:42, 192.76s/epoch]

Epoch 031 | LR 0.001356 | Train MSE 0.0015 | Val MSE 1.1816 | Val MAE 6.9988 | Val MSE 118.1570


Epoch:  16%|█▋        | 33/200 [1:48:28<8:31:30, 183.78s/epoch]

Epoch 032 | LR 0.001288 | Train MSE 0.0068 | Val MSE 3.7237 | Val MAE 13.5524 | Val MSE 372.3740


Epoch:  17%|█▋        | 34/200 [1:51:11<8:11:29, 177.64s/epoch]

Epoch 033 | LR 0.001224 | Train MSE 0.0006 | Val MSE 0.4253 | Val MAE 4.0797 | Val MSE 42.5349
Validation improved: 0.528470 -> 0.425349


Epoch:  18%|█▊        | 35/200 [1:53:54<7:56:22, 173.23s/epoch]

Epoch 034 | LR 0.001163 | Train MSE 0.0028 | Val MSE 0.3867 | Val MAE 3.5514 | Val MSE 38.6710
Validation improved: 0.425349 -> 0.386710


Epoch:  18%|█▊        | 36/200 [1:56:37<7:45:04, 170.15s/epoch]

Epoch 035 | LR 0.001104 | Train MSE 0.0020 | Val MSE 1.3773 | Val MAE 7.6812 | Val MSE 137.7331
Sample pred first 3 steps: [[-0.00341757  0.00786195]
 [-0.00735325  0.00465561]
 [-0.00787984 -0.00308512]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]


Epoch:  18%|█▊        | 37/200 [1:59:21<7:37:12, 168.30s/epoch]

Epoch 036 | LR 0.001049 | Train MSE 0.0058 | Val MSE 0.9450 | Val MAE 6.4781 | Val MSE 94.5020


Epoch:  19%|█▉        | 38/200 [2:02:05<7:30:48, 166.97s/epoch]

Epoch 037 | LR 0.000997 | Train MSE 0.0027 | Val MSE 1.8384 | Val MAE 9.3904 | Val MSE 183.8388


Epoch:  20%|█▉        | 39/200 [3:48:31<90:54:04, 2032.58s/epoch]

Epoch 038 | LR 0.000947 | Train MSE 0.0041 | Val MSE 0.3130 | Val MAE 3.2051 | Val MSE 31.2966
Validation improved: 0.386710 -> 0.312966


Epoch:  20%|██        | 40/200 [3:51:12<65:23:35, 1471.35s/epoch]

Epoch 039 | LR 0.000900 | Train MSE 0.0058 | Val MSE 0.6629 | Val MAE 5.2151 | Val MSE 66.2930


Epoch:  20%|██        | 41/200 [3:53:55<47:38:40, 1078.74s/epoch]

Epoch 040 | LR 0.000855 | Train MSE 0.0097 | Val MSE 0.3244 | Val MAE 3.1234 | Val MSE 32.4363
Sample pred first 3 steps: [[ 0.00612723 -0.00376687]
 [ 0.00891236  0.00328803]
 [ 0.00786905  0.00469331]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]


Epoch:  21%|██        | 42/200 [3:56:40<35:18:43, 804.58s/epoch] 

Epoch 041 | LR 0.000812 | Train MSE 0.0079 | Val MSE 0.3460 | Val MAE 3.4378 | Val MSE 34.6040


Epoch:  22%|██▏       | 43/200 [3:59:28<26:45:45, 613.67s/epoch]

Epoch 042 | LR 0.000771 | Train MSE 0.0158 | Val MSE 1.1509 | Val MAE 7.3596 | Val MSE 115.0943


Epoch:  22%|██▏       | 44/200 [4:02:10<20:43:23, 478.23s/epoch]

Epoch 043 | LR 0.000733 | Train MSE 0.0129 | Val MSE 0.9171 | Val MAE 6.5409 | Val MSE 91.7057


Epoch:  22%|██▎       | 45/200 [4:04:53<16:30:31, 383.43s/epoch]

Epoch 044 | LR 0.000696 | Train MSE 0.0255 | Val MSE 0.4565 | Val MAE 4.1702 | Val MSE 45.6538


Epoch:  23%|██▎       | 46/200 [4:07:34<13:33:06, 316.79s/epoch]

Epoch 045 | LR 0.000661 | Train MSE 0.0294 | Val MSE 1.6788 | Val MAE 9.3548 | Val MSE 167.8801
Sample pred first 3 steps: [[ 0.00562045  0.01059959]
 [-0.01140896 -0.00632125]
 [-0.02517348 -0.03053093]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]


Epoch:  24%|██▎       | 47/200 [4:10:14<11:27:37, 269.66s/epoch]

Epoch 046 | LR 0.000628 | Train MSE 0.0581 | Val MSE 0.3239 | Val MAE 3.5249 | Val MSE 32.3880


Epoch:  24%|██▍       | 48/200 [4:12:52<9:58:31, 236.26s/epoch] 

Epoch 047 | LR 0.000597 | Train MSE 0.0857 | Val MSE 0.6762 | Val MAE 5.1241 | Val MSE 67.6206


Epoch:  24%|██▍       | 49/200 [4:15:30<8:55:29, 212.78s/epoch]

Epoch 048 | LR 0.000567 | Train MSE 0.1301 | Val MSE 0.3202 | Val MAE 3.0207 | Val MSE 32.0150


Epoch:  25%|██▌       | 50/200 [4:18:10<8:12:15, 196.90s/epoch]

Epoch 049 | LR 0.000539 | Train MSE 0.2350 | Val MSE 0.4232 | Val MAE 4.0281 | Val MSE 42.3192


Epoch:  26%|██▌       | 51/200 [4:20:50<7:41:35, 185.87s/epoch]

Epoch 050 | LR 0.000512 | Train MSE 0.4584 | Val MSE 0.4403 | Val MAE 4.0310 | Val MSE 44.0269
Sample pred first 3 steps: [[0.01531954 0.023463  ]
 [0.01538081 0.05585781]
 [0.02415663 0.09425094]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]


Epoch:  26%|██▌       | 52/200 [4:23:30<7:19:02, 177.99s/epoch]

Epoch 051 | LR 0.000486 | Train MSE 0.3410 | Val MSE 0.2810 | Val MAE 3.3698 | Val MSE 28.0955
Validation improved: 0.312966 -> 0.280955


Epoch:  26%|██▋       | 53/200 [4:26:12<7:05:01, 173.48s/epoch]

Epoch 052 | LR 0.000462 | Train MSE 0.2681 | Val MSE 0.2141 | Val MAE 2.6098 | Val MSE 21.4075
Validation improved: 0.280955 -> 0.214075


Epoch:  27%|██▋       | 54/200 [4:28:56<6:54:50, 170.48s/epoch]

Epoch 053 | LR 0.000439 | Train MSE 0.2486 | Val MSE 0.2580 | Val MAE 3.1892 | Val MSE 25.7963


Epoch:  28%|██▊       | 55/200 [4:31:39<6:46:24, 168.17s/epoch]

Epoch 054 | LR 0.000417 | Train MSE 0.2473 | Val MSE 0.2243 | Val MAE 2.7154 | Val MSE 22.4328


Epoch:  28%|██▊       | 56/200 [4:34:20<6:38:52, 166.20s/epoch]

Epoch 055 | LR 0.000396 | Train MSE 0.2234 | Val MSE 0.2484 | Val MAE 3.0478 | Val MSE 24.8369
Sample pred first 3 steps: [[-0.00364922 -0.00329226]
 [-0.00476248  0.00060519]
 [-0.00642664  0.00552462]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]


Epoch:  28%|██▊       | 57/200 [4:37:02<6:32:41, 164.76s/epoch]

Epoch 056 | LR 0.000376 | Train MSE 0.2145 | Val MSE 0.1929 | Val MAE 2.3833 | Val MSE 19.2860
Validation improved: 0.214075 -> 0.192861


Epoch:  29%|██▉       | 58/200 [4:39:44<6:28:23, 164.11s/epoch]

Epoch 057 | LR 0.000357 | Train MSE 0.1995 | Val MSE 0.2255 | Val MAE 2.8485 | Val MSE 22.5455


Epoch:  30%|██▉       | 59/200 [4:42:25<6:22:58, 162.97s/epoch]

Epoch 058 | LR 0.000339 | Train MSE 0.1961 | Val MSE 0.1749 | Val MAE 2.3078 | Val MSE 17.4866
Validation improved: 0.192861 -> 0.174866


Epoch:  30%|███       | 60/200 [4:45:05<6:18:32, 162.23s/epoch]

Epoch 059 | LR 0.000322 | Train MSE 0.1909 | Val MSE 0.1988 | Val MAE 2.6225 | Val MSE 19.8839


Epoch:  30%|███       | 61/200 [4:47:44<6:13:39, 161.29s/epoch]

Epoch 060 | LR 0.000306 | Train MSE 0.1869 | Val MSE 0.1618 | Val MAE 2.1258 | Val MSE 16.1752
Sample pred first 3 steps: [[ 0.00292613 -0.02734635]
 [ 0.00912778 -0.00694495]
 [ 0.01469546  0.0050005 ]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]
Validation improved: 0.174866 -> 0.161752


Epoch:  31%|███       | 62/200 [4:50:22<6:08:26, 160.19s/epoch]

Epoch 061 | LR 0.000291 | Train MSE 0.1816 | Val MSE 0.1699 | Val MAE 2.2877 | Val MSE 16.9947


Epoch:  32%|███▏      | 63/200 [4:53:01<6:04:44, 159.74s/epoch]

Epoch 062 | LR 0.000276 | Train MSE 0.1795 | Val MSE 0.1620 | Val MAE 2.0857 | Val MSE 16.2048


Epoch:  32%|███▏      | 64/200 [5:03:36<11:25:34, 302.46s/epoch]

Epoch 063 | LR 0.000263 | Train MSE 0.1734 | Val MSE 0.1597 | Val MAE 2.1578 | Val MSE 15.9728
Validation improved: 0.161752 -> 0.159728


Epoch:  32%|███▎      | 65/200 [5:06:44<10:03:32, 268.24s/epoch]

Epoch 064 | LR 0.000250 | Train MSE 0.1707 | Val MSE 0.1708 | Val MAE 2.3356 | Val MSE 17.0841


Epoch:  33%|███▎      | 66/200 [5:26:37<20:18:37, 545.65s/epoch]

Epoch 065 | LR 0.000237 | Train MSE 0.1700 | Val MSE 0.1589 | Val MAE 2.1528 | Val MSE 15.8880
Sample pred first 3 steps: [[-1.8321358e-02 -3.7980992e-03]
 [-1.8433679e-02 -1.4505321e-02]
 [-1.0939536e-02  7.0361421e-05]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]
Validation improved: 0.159728 -> 0.158880


Epoch:  34%|███▎      | 67/200 [5:31:11<17:08:37, 464.04s/epoch]

Epoch 066 | LR 0.000225 | Train MSE 0.1661 | Val MSE 0.1665 | Val MAE 2.3098 | Val MSE 16.6482


Epoch:  34%|███▍      | 68/200 [6:00:15<31:05:40, 848.03s/epoch]

Epoch 067 | LR 0.000214 | Train MSE 0.1657 | Val MSE 0.1518 | Val MAE 2.0835 | Val MSE 15.1779
Validation improved: 0.158880 -> 0.151779


Epoch:  34%|███▍      | 69/200 [6:12:26<29:35:05, 813.02s/epoch]

Epoch 068 | LR 0.000203 | Train MSE 0.1622 | Val MSE 0.1480 | Val MAE 2.0385 | Val MSE 14.8034
Validation improved: 0.151779 -> 0.148034


Epoch:  35%|███▌      | 70/200 [6:30:20<32:11:08, 891.29s/epoch]

Epoch 069 | LR 0.000193 | Train MSE 0.1606 | Val MSE 0.1509 | Val MAE 2.0511 | Val MSE 15.0867


Epoch:  36%|███▌      | 71/200 [6:38:26<27:34:50, 769.69s/epoch]

Epoch 070 | LR 0.000183 | Train MSE 0.1614 | Val MSE 0.1452 | Val MAE 1.9645 | Val MSE 14.5189
Sample pred first 3 steps: [[-0.0114276  -0.03735918]
 [-0.01205    -0.01979084]
 [-0.01156558 -0.00728052]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]
Validation improved: 0.148034 -> 0.145189


Epoch:  36%|███▌      | 72/200 [6:45:47<23:51:25, 670.98s/epoch]

Epoch 071 | LR 0.000174 | Train MSE 0.1571 | Val MSE 0.1466 | Val MAE 2.0023 | Val MSE 14.6573


Epoch:  36%|███▋      | 73/200 [7:05:21<28:59:56, 822.02s/epoch]

Epoch 072 | LR 0.000166 | Train MSE 0.1569 | Val MSE 0.1529 | Val MAE 2.0509 | Val MSE 15.2853


Epoch:  37%|███▋      | 74/200 [7:08:37<22:11:48, 634.20s/epoch]

Epoch 073 | LR 0.000157 | Train MSE 0.1561 | Val MSE 0.1547 | Val MAE 2.0680 | Val MSE 15.4748


Epoch:  38%|███▊      | 75/200 [7:26:30<26:35:12, 765.70s/epoch]

Epoch 074 | LR 0.000149 | Train MSE 0.1528 | Val MSE 0.1440 | Val MAE 1.9688 | Val MSE 14.4036
Validation improved: 0.145189 -> 0.144036


Epoch:  38%|███▊      | 76/200 [7:49:06<32:28:33, 942.85s/epoch]

Epoch 075 | LR 0.000142 | Train MSE 0.1530 | Val MSE 0.1549 | Val MAE 2.2542 | Val MSE 15.4903
Sample pred first 3 steps: [[ 0.00134407 -0.02598702]
 [-0.00053961 -0.00824437]
 [ 0.01217157 -0.00350608]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]


Epoch:  38%|███▊      | 77/200 [8:07:22<33:47:15, 988.91s/epoch]

Epoch 076 | LR 0.000135 | Train MSE 0.1513 | Val MSE 0.1464 | Val MAE 1.9946 | Val MSE 14.6417


Epoch:  39%|███▉      | 78/200 [8:16:16<28:52:55, 852.26s/epoch]

Epoch 077 | LR 0.000128 | Train MSE 0.1510 | Val MSE 0.1392 | Val MAE 1.9529 | Val MSE 13.9191
Validation improved: 0.144036 -> 0.139191


Epoch:  40%|███▉      | 79/200 [8:33:54<30:43:05, 913.93s/epoch]

Epoch 078 | LR 0.000122 | Train MSE 0.1512 | Val MSE 0.1461 | Val MAE 2.0121 | Val MSE 14.6064


Epoch:  40%|████      | 80/200 [8:36:55<23:08:30, 694.25s/epoch]

Epoch 079 | LR 0.000116 | Train MSE 0.1475 | Val MSE 0.1382 | Val MAE 1.8950 | Val MSE 13.8154
Validation improved: 0.139191 -> 0.138154


Epoch:  40%|████      | 81/200 [8:45:37<21:14:08, 642.43s/epoch]

Epoch 080 | LR 0.000110 | Train MSE 0.1462 | Val MSE 0.1384 | Val MAE 1.9276 | Val MSE 13.8428
Sample pred first 3 steps: [[-0.000191   -0.01640825]
 [-0.00887659 -0.00050079]
 [-0.00569648  0.00509479]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]


Epoch:  41%|████      | 82/200 [8:48:19<16:20:15, 498.44s/epoch]

Epoch 081 | LR 0.000104 | Train MSE 0.1469 | Val MSE 0.1382 | Val MAE 1.9252 | Val MSE 13.8186


Epoch:  42%|████▏     | 83/200 [8:51:03<12:56:01, 397.96s/epoch]

Epoch 082 | LR 0.000099 | Train MSE 0.1444 | Val MSE 0.1464 | Val MAE 2.0501 | Val MSE 14.6386


Epoch:  42%|████▏     | 84/200 [8:53:44<10:32:00, 326.90s/epoch]

Epoch 083 | LR 0.000094 | Train MSE 0.1463 | Val MSE 0.1407 | Val MAE 2.0509 | Val MSE 14.0669


Epoch:  42%|████▎     | 85/200 [8:56:25<8:51:04, 277.08s/epoch] 

Epoch 084 | LR 0.000089 | Train MSE 0.1461 | Val MSE 0.1414 | Val MAE 1.9968 | Val MSE 14.1411


Epoch:  43%|████▎     | 86/200 [8:59:06<7:40:35, 242.41s/epoch]

Epoch 085 | LR 0.000085 | Train MSE 0.1436 | Val MSE 0.1363 | Val MAE 1.9107 | Val MSE 13.6334
Sample pred first 3 steps: [[-0.00551261 -0.0303377 ]
 [-0.00400737 -0.00249134]
 [-0.00171119 -0.03129618]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]
Validation improved: 0.138154 -> 0.136334


Epoch:  44%|████▎     | 87/200 [9:01:48<6:50:49, 218.14s/epoch]

Epoch 086 | LR 0.000081 | Train MSE 0.1419 | Val MSE 0.1389 | Val MAE 1.9137 | Val MSE 13.8884


Epoch:  44%|████▍     | 88/200 [9:04:34<6:18:04, 202.54s/epoch]

Epoch 087 | LR 0.000077 | Train MSE 0.1422 | Val MSE 0.1436 | Val MAE 1.9814 | Val MSE 14.3603


Epoch:  44%|████▍     | 89/200 [9:07:18<5:53:11, 190.91s/epoch]

Epoch 088 | LR 0.000073 | Train MSE 0.1416 | Val MSE 0.1363 | Val MAE 1.8927 | Val MSE 13.6286
Validation improved: 0.136334 -> 0.136286


Epoch:  45%|████▌     | 90/200 [9:10:02<5:35:23, 182.94s/epoch]

Epoch 089 | LR 0.000069 | Train MSE 0.1393 | Val MSE 0.1373 | Val MAE 1.9052 | Val MSE 13.7314


Epoch:  46%|████▌     | 91/200 [9:12:47<5:22:38, 177.60s/epoch]

Epoch 090 | LR 0.000066 | Train MSE 0.1390 | Val MSE 0.1322 | Val MAE 1.8503 | Val MSE 13.2155
Sample pred first 3 steps: [[ 0.00144635 -0.00949545]
 [ 0.00269883  0.00770124]
 [ 0.01841464  0.0162251 ]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]
Validation improved: 0.136286 -> 0.132155


Epoch:  46%|████▌     | 92/200 [9:15:32<5:12:46, 173.76s/epoch]

Epoch 091 | LR 0.000062 | Train MSE 0.1388 | Val MSE 0.1348 | Val MAE 1.8512 | Val MSE 13.4800


Epoch:  46%|████▋     | 93/200 [9:18:16<5:04:36, 170.81s/epoch]

Epoch 092 | LR 0.000059 | Train MSE 0.1382 | Val MSE 0.1368 | Val MAE 1.9237 | Val MSE 13.6771


Epoch:  47%|████▋     | 94/200 [9:20:59<4:57:43, 168.53s/epoch]

Epoch 093 | LR 0.000056 | Train MSE 0.1389 | Val MSE 0.1345 | Val MAE 1.8735 | Val MSE 13.4490


Epoch:  48%|████▊     | 95/200 [9:23:44<4:53:01, 167.44s/epoch]

Epoch 094 | LR 0.000054 | Train MSE 0.1370 | Val MSE 0.1336 | Val MAE 1.8468 | Val MSE 13.3647


Epoch:  48%|████▊     | 96/200 [9:26:28<4:48:20, 166.35s/epoch]

Epoch 095 | LR 0.000051 | Train MSE 0.1372 | Val MSE 0.1358 | Val MAE 1.8478 | Val MSE 13.5849
Sample pred first 3 steps: [[ 0.01159187 -0.01153686]
 [ 0.00754298  0.00157583]
 [ 0.01684639  0.01112962]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]


Epoch:  48%|████▊     | 97/200 [9:29:09<4:43:07, 164.92s/epoch]

Epoch 096 | LR 0.000048 | Train MSE 0.1369 | Val MSE 0.1326 | Val MAE 1.8317 | Val MSE 13.2624


Epoch:  49%|████▉     | 98/200 [9:31:53<4:39:50, 164.61s/epoch]

Epoch 097 | LR 0.000046 | Train MSE 0.1355 | Val MSE 0.1328 | Val MAE 1.8408 | Val MSE 13.2848


Epoch:  50%|████▉     | 99/200 [9:34:37<4:36:46, 164.43s/epoch]

Epoch 098 | LR 0.000044 | Train MSE 0.1372 | Val MSE 0.1323 | Val MAE 1.8706 | Val MSE 13.2301


Epoch:  50%|█████     | 100/200 [9:37:21<4:33:55, 164.36s/epoch]

Epoch 099 | LR 0.000041 | Train MSE 0.1358 | Val MSE 0.1307 | Val MAE 1.8197 | Val MSE 13.0696
Validation improved: 0.132155 -> 0.130696


Epoch:  50%|█████     | 101/200 [9:40:05<4:31:00, 164.25s/epoch]

Epoch 100 | LR 0.000039 | Train MSE 0.1351 | Val MSE 0.1325 | Val MAE 1.8529 | Val MSE 13.2533
Sample pred first 3 steps: [[ 0.00271462 -0.01766692]
 [ 0.00208454 -0.00059029]
 [ 0.0124524   0.01050809]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]


Epoch:  51%|█████     | 102/200 [9:42:48<4:27:17, 163.65s/epoch]

Epoch 101 | LR 0.000037 | Train MSE 0.1354 | Val MSE 0.1316 | Val MAE 1.8225 | Val MSE 13.1613


Epoch:  52%|█████▏    | 103/200 [9:45:29<4:23:19, 162.88s/epoch]

Epoch 102 | LR 0.000036 | Train MSE 0.1353 | Val MSE 0.1319 | Val MAE 1.8518 | Val MSE 13.1856


Epoch:  52%|█████▏    | 104/200 [9:48:13<4:21:11, 163.25s/epoch]

Epoch 103 | LR 0.000034 | Train MSE 0.1355 | Val MSE 0.1345 | Val MAE 1.8889 | Val MSE 13.4498


Epoch:  52%|█████▎    | 105/200 [9:50:56<4:18:30, 163.27s/epoch]

Epoch 104 | LR 0.000032 | Train MSE 0.1349 | Val MSE 0.1308 | Val MAE 1.8172 | Val MSE 13.0754


Epoch:  53%|█████▎    | 106/200 [9:53:40<4:16:03, 163.44s/epoch]

Epoch 105 | LR 0.000030 | Train MSE 0.1342 | Val MSE 0.1335 | Val MAE 1.8281 | Val MSE 13.3495
Sample pred first 3 steps: [[-3.2511968e-03 -1.9158486e-02]
 [ 2.4417546e-03 -5.3245574e-05]
 [ 1.1650266e-02  1.0915946e-02]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]


Epoch:  54%|█████▎    | 107/200 [9:56:23<4:13:01, 163.24s/epoch]

Epoch 106 | LR 0.000029 | Train MSE 0.1335 | Val MSE 0.1314 | Val MAE 1.8068 | Val MSE 13.1392


Epoch:  54%|█████▍    | 108/200 [9:59:06<4:10:27, 163.34s/epoch]

Epoch 107 | LR 0.000027 | Train MSE 0.1329 | Val MSE 0.1314 | Val MAE 1.8749 | Val MSE 13.1361


Epoch:  55%|█████▍    | 109/200 [10:01:51<4:08:13, 163.66s/epoch]

Epoch 108 | LR 0.000026 | Train MSE 0.1345 | Val MSE 0.1341 | Val MAE 1.8927 | Val MSE 13.4073


Epoch:  55%|█████▌    | 110/200 [10:04:35<4:05:33, 163.71s/epoch]

Epoch 109 | LR 0.000025 | Train MSE 0.1339 | Val MSE 0.1319 | Val MAE 1.8433 | Val MSE 13.1910


Epoch:  56%|█████▌    | 111/200 [10:07:19<4:03:04, 163.88s/epoch]

Epoch 110 | LR 0.000024 | Train MSE 0.1340 | Val MSE 0.1322 | Val MAE 1.8269 | Val MSE 13.2205
Sample pred first 3 steps: [[-0.00420317 -0.02133872]
 [ 0.00238288  0.00031481]
 [ 0.01124811  0.0121459 ]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]


Epoch:  56%|█████▌    | 112/200 [10:10:03<4:00:14, 163.80s/epoch]

Epoch 111 | LR 0.000022 | Train MSE 0.1319 | Val MSE 0.1311 | Val MAE 1.8094 | Val MSE 13.1114


Epoch:  56%|█████▋    | 113/200 [10:12:46<3:57:17, 163.65s/epoch]

Epoch 112 | LR 0.000021 | Train MSE 0.1345 | Val MSE 0.1312 | Val MAE 1.8765 | Val MSE 13.1189


Epoch:  57%|█████▋    | 114/200 [10:15:33<3:56:12, 164.79s/epoch]

Epoch 113 | LR 0.000020 | Train MSE 0.1324 | Val MSE 0.1292 | Val MAE 1.8051 | Val MSE 12.9223
Validation improved: 0.130696 -> 0.129223


Epoch:  57%|█████▊    | 115/200 [10:18:25<3:56:32, 166.97s/epoch]

Epoch 114 | LR 0.000019 | Train MSE 0.1314 | Val MSE 0.1302 | Val MAE 1.8017 | Val MSE 13.0238


Epoch:  58%|█████▊    | 116/200 [10:21:24<3:58:50, 170.61s/epoch]

Epoch 115 | LR 0.000018 | Train MSE 0.1314 | Val MSE 0.1295 | Val MAE 1.7852 | Val MSE 12.9475
Sample pred first 3 steps: [[-0.00428537 -0.02397801]
 [ 0.00138649 -0.00174612]
 [ 0.01128761  0.00723887]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]


Epoch:  58%|█████▊    | 117/200 [10:44:35<12:22:27, 536.72s/epoch]

Epoch 116 | LR 0.000017 | Train MSE 0.1313 | Val MSE 0.1303 | Val MAE 1.8205 | Val MSE 13.0255


Epoch:  59%|█████▉    | 118/200 [11:04:14<16:36:46, 729.35s/epoch]

Epoch 117 | LR 0.000016 | Train MSE 0.1307 | Val MSE 0.1308 | Val MAE 1.8295 | Val MSE 13.0785


Epoch:  60%|█████▉    | 119/200 [11:22:44<18:58:37, 843.43s/epoch]

Epoch 118 | LR 0.000016 | Train MSE 0.1307 | Val MSE 0.1305 | Val MAE 1.8149 | Val MSE 13.0489


Epoch:  60%|██████    | 120/200 [11:41:23<20:35:03, 926.29s/epoch]

Epoch 119 | LR 0.000015 | Train MSE 0.1319 | Val MSE 0.1311 | Val MAE 1.8441 | Val MSE 13.1147


Epoch:  60%|██████    | 121/200 [11:59:15<21:16:57, 969.84s/epoch]

Epoch 120 | LR 0.000014 | Train MSE 0.1308 | Val MSE 0.1296 | Val MAE 1.7859 | Val MSE 12.9595
Sample pred first 3 steps: [[-8.2850456e-05 -1.9394137e-02]
 [ 3.8771313e-03  1.8518679e-03]
 [ 1.6447516e-02  1.1503212e-02]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]


Epoch:  61%|██████    | 122/200 [12:25:59<25:08:05, 1160.08s/epoch]

Epoch 121 | LR 0.000013 | Train MSE 0.1316 | Val MSE 0.1300 | Val MAE 1.8298 | Val MSE 12.9952


Epoch:  62%|██████▏   | 123/200 [12:44:33<24:30:57, 1146.20s/epoch]

Epoch 122 | LR 0.000013 | Train MSE 0.1308 | Val MSE 0.1305 | Val MAE 1.8119 | Val MSE 13.0546


Epoch:  62%|██████▏   | 124/200 [12:47:12<17:56:55, 850.21s/epoch] 

Epoch 123 | LR 0.000012 | Train MSE 0.1315 | Val MSE 0.1298 | Val MAE 1.8182 | Val MSE 12.9822


Epoch:  62%|██████▎   | 125/200 [13:06:10<19:30:32, 936.43s/epoch]

Epoch 124 | LR 0.000011 | Train MSE 0.1313 | Val MSE 0.1306 | Val MAE 1.8092 | Val MSE 13.0565


Epoch:  63%|██████▎   | 126/200 [13:26:54<21:08:35, 1028.59s/epoch]

Epoch 125 | LR 0.000011 | Train MSE 0.1315 | Val MSE 0.1292 | Val MAE 1.7895 | Val MSE 12.9179
Sample pred first 3 steps: [[-0.00027575 -0.02447014]
 [ 0.00380212 -0.00183129]
 [ 0.01764185  0.00737524]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]
Validation improved: 0.129223 -> 0.129179


Epoch:  64%|██████▎   | 127/200 [13:44:50<21:08:57, 1042.98s/epoch]

Epoch 126 | LR 0.000010 | Train MSE 0.1300 | Val MSE 0.1290 | Val MAE 1.7915 | Val MSE 12.8954
Validation improved: 0.129179 -> 0.128954


Epoch:  64%|██████▍   | 128/200 [14:03:57<21:29:02, 1074.20s/epoch]

Epoch 127 | LR 0.000010 | Train MSE 0.1311 | Val MSE 0.1290 | Val MAE 1.7912 | Val MSE 12.8983


Epoch:  64%|██████▍   | 129/200 [14:23:12<21:39:49, 1098.45s/epoch]

Epoch 128 | LR 0.000009 | Train MSE 0.1315 | Val MSE 0.1297 | Val MAE 1.8007 | Val MSE 12.9679


Epoch:  65%|██████▌   | 130/200 [14:27:53<16:35:20, 853.15s/epoch] 

Epoch 129 | LR 0.000009 | Train MSE 0.1297 | Val MSE 0.1298 | Val MAE 1.7993 | Val MSE 12.9773


Epoch:  66%|██████▌   | 131/200 [14:47:48<18:19:08, 955.78s/epoch]

Epoch 130 | LR 0.000008 | Train MSE 0.1321 | Val MSE 0.1287 | Val MAE 1.7845 | Val MSE 12.8721
Sample pred first 3 steps: [[-0.00112596 -0.02294551]
 [ 0.00372867  0.00046162]
 [ 0.01808438  0.01116892]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]
Validation improved: 0.128954 -> 0.128721


Epoch:  66%|██████▌   | 132/200 [15:06:41<19:03:34, 1009.04s/epoch]

Epoch 131 | LR 0.000008 | Train MSE 0.1308 | Val MSE 0.1286 | Val MAE 1.7953 | Val MSE 12.8561
Validation improved: 0.128721 -> 0.128561


Epoch:  66%|██████▋   | 133/200 [15:25:33<19:27:44, 1045.73s/epoch]

Epoch 132 | LR 0.000008 | Train MSE 0.1288 | Val MSE 0.1301 | Val MAE 1.8248 | Val MSE 13.0148


Epoch:  67%|██████▋   | 134/200 [15:47:25<20:38:19, 1125.76s/epoch]

Epoch 133 | LR 0.000007 | Train MSE 0.1290 | Val MSE 0.1294 | Val MAE 1.8099 | Val MSE 12.9365


Epoch:  68%|██████▊   | 135/200 [16:06:23<20:23:26, 1129.33s/epoch]

Epoch 134 | LR 0.000007 | Train MSE 0.1288 | Val MSE 0.1284 | Val MAE 1.7967 | Val MSE 12.8366
Validation improved: 0.128561 -> 0.128366


Epoch:  68%|██████▊   | 136/200 [16:09:02<14:54:07, 838.24s/epoch] 

Epoch 135 | LR 0.000007 | Train MSE 0.1302 | Val MSE 0.1288 | Val MAE 1.7951 | Val MSE 12.8825
Sample pred first 3 steps: [[-0.00140648 -0.02129337]
 [ 0.00225315 -0.0012337 ]
 [ 0.01402512  0.00857172]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]


Epoch:  68%|██████▊   | 137/200 [16:22:28<14:30:02, 828.62s/epoch]

Epoch 136 | LR 0.000006 | Train MSE 0.1288 | Val MSE 0.1290 | Val MAE 1.7985 | Val MSE 12.8984


Epoch:  69%|██████▉   | 138/200 [16:25:10<10:49:27, 628.51s/epoch]

Epoch 137 | LR 0.000006 | Train MSE 0.1298 | Val MSE 0.1293 | Val MAE 1.7982 | Val MSE 12.9284


Epoch:  70%|██████▉   | 139/200 [16:27:52<8:16:44, 488.59s/epoch] 

Epoch 138 | LR 0.000006 | Train MSE 0.1302 | Val MSE 0.1289 | Val MAE 1.7847 | Val MSE 12.8911


Epoch:  70%|███████   | 140/200 [16:30:33<6:30:20, 390.35s/epoch]

Epoch 139 | LR 0.000005 | Train MSE 0.1294 | Val MSE 0.1286 | Val MAE 1.7899 | Val MSE 12.8578


Epoch:  70%|███████   | 141/200 [16:34:55<5:46:04, 351.94s/epoch]

Epoch 140 | LR 0.000005 | Train MSE 0.1301 | Val MSE 0.1292 | Val MAE 1.7961 | Val MSE 12.9240
Sample pred first 3 steps: [[-0.00040821 -0.02154307]
 [ 0.00141034  0.0002645 ]
 [ 0.01413239  0.01143052]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]


Epoch:  71%|███████   | 142/200 [16:37:45<4:47:21, 297.28s/epoch]

Epoch 141 | LR 0.000005 | Train MSE 0.1294 | Val MSE 0.1293 | Val MAE 1.7959 | Val MSE 12.9335


Epoch:  72%|███████▏  | 143/200 [16:40:27<4:03:46, 256.61s/epoch]

Epoch 142 | LR 0.000005 | Train MSE 0.1301 | Val MSE 0.1285 | Val MAE 1.7925 | Val MSE 12.8508


Epoch:  72%|███████▏  | 144/200 [16:43:09<3:32:59, 228.20s/epoch]

Epoch 143 | LR 0.000004 | Train MSE 0.1293 | Val MSE 0.1285 | Val MAE 1.7990 | Val MSE 12.8540


Epoch:  72%|███████▎  | 145/200 [16:45:50<3:10:53, 208.25s/epoch]

Epoch 144 | LR 0.000004 | Train MSE 0.1301 | Val MSE 0.1301 | Val MAE 1.8031 | Val MSE 13.0104


Epoch:  73%|███████▎  | 146/200 [16:48:32<2:54:56, 194.38s/epoch]

Epoch 145 | LR 0.000004 | Train MSE 0.1311 | Val MSE 0.1288 | Val MAE 1.7886 | Val MSE 12.8830
Sample pred first 3 steps: [[-0.00144453 -0.02129232]
 [ 0.00281385 -0.0009856 ]
 [ 0.01508722  0.01012544]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]


Epoch:  74%|███████▎  | 147/200 [16:51:14<2:43:03, 184.60s/epoch]

Epoch 146 | LR 0.000004 | Train MSE 0.1302 | Val MSE 0.1289 | Val MAE 1.7888 | Val MSE 12.8922


Epoch:  74%|███████▍  | 148/200 [16:53:55<2:33:55, 177.60s/epoch]

Epoch 147 | LR 0.000004 | Train MSE 0.1291 | Val MSE 0.1284 | Val MAE 1.7780 | Val MSE 12.8422


Epoch:  74%|███████▍  | 149/200 [16:56:35<2:26:28, 172.31s/epoch]

Epoch 148 | LR 0.000003 | Train MSE 0.1286 | Val MSE 0.1286 | Val MAE 1.7770 | Val MSE 12.8604


Epoch:  75%|███████▌  | 150/200 [16:59:16<2:20:34, 168.70s/epoch]

Epoch 149 | LR 0.000003 | Train MSE 0.1295 | Val MSE 0.1289 | Val MAE 1.7836 | Val MSE 12.8930


Epoch:  76%|███████▌  | 151/200 [17:01:55<2:15:30, 165.92s/epoch]

Epoch 150 | LR 0.000003 | Train MSE 0.1294 | Val MSE 0.1285 | Val MAE 1.7781 | Val MSE 12.8455
Sample pred first 3 steps: [[-0.0004233  -0.0185684 ]
 [ 0.00309229  0.00094085]
 [ 0.01553952  0.01244266]]
Sample target first 3 steps: [[2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]
 [2.3088824e-06 4.8631200e-06]]


Epoch:  76%|███████▌  | 152/200 [17:17:59<5:24:08, 405.18s/epoch]

Epoch 151 | LR 0.000003 | Train MSE 0.1287 | Val MSE 0.1287 | Val MAE 1.7837 | Val MSE 12.8673
