# All Model saves here

## import

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import TensorDataset, DataLoader, random_split
import DeepMIMOv3
import numpy as np
from pprint import pprint
import matplotlib.pyplot as plt
import time


plt . rcParams [ 'figure.figsize' ]  =  [ 12 ,  8 ]  # 기본 플롯 크기 설정

## GPU Settings

In [2]:
# GPU 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
import torch
print(torch.version.cuda)                   
print(torch.backends.cudnn.version())       
print("CUDA available:", torch.cuda.is_available())  # True

12.6
90501
CUDA available: True


## DeepMIMOv3 dataset

In [4]:
parameters = DeepMIMOv3.default_params()

In [5]:
## Change parameters for the setup
# Scenario O1_60 extracted at the dataset_folder
#LWM dynamic senario
# parameters['dataset_folder'] = r'/content/drive/MyDrive/Colab Notebooks/LWM'
scene = 15 # scene 15
# change my linux route
parameters['dataset_folder'] = '/home/dlghdbs200/LWM'

# scnario = 02_dyn_3p5 <- download file
parameters['scenario'] = 'O2_dyn_3p5'
parameters['dynamic_scenario_scenes'] = np.arange(scene) #scene 0~9

# Up to 10 multipath paths per user-to-base station channel
parameters['num_paths'] = 10

# User rows 1-100
parameters['user_rows'] = np.arange(100)
# User subsampling
parameters['user_subsampling'] = 0.01

# Activate only the first basestation
parameters['active_BS'] = np.array([1])

parameters['activate_OFDM'] = 1

parameters['OFDM']['bandwidth'] = 0.05 # 50 MHz
parameters['OFDM']['subcarriers'] = 512 # OFDM with 512 subcarriers
parameters['OFDM']['selected_subcarriers'] = np.arange(0, 64, 1)
#parameters['OFDM']['subcarriers_limit'] = 64 # Keep only first 64 subcarriers

parameters['ue_antenna']['shape'] = np.array([1, 1]) # Single antenna
parameters['bs_antenna']['shape'] = np.array([1, 32]) # ULA of 32 elements
#parameters['bs_antenna']['rotation'] = np.array([0, 30, 90]) # ULA of 32 elements
#parameters['ue_antenna']['rotation'] = np.array([[0, 30], [30, 60], [60, 90]]) # ULA of 32 elements
#parameters['ue_antenna']['radiation_pattern'] = 'isotropic'
#parameters['bs_antenna']['radiation_pattern'] = 'halfwave-dipole'

In [6]:
## dataset setting (chunked on‑the‑fly generation)
import time, gc
from tqdm import tqdm

# 0~999 scene index , process 50 at that time
scene_indices = np.arange(scene)
chunk_size   = 5
all_data     = []

# Call generate_data for each scene chunk
for i in tqdm(range(0, len(scene_indices), chunk_size)):
    chunk = scene_indices[i : i+chunk_size].tolist()
    parameters['dynamic_scenario_scenes'] = chunk

    start = time.time()
    data_chunk = DeepMIMOv3.generate_data(parameters)
    print(f"Scenes {chunk[0]}–{chunk[-1]} generation time: {time.time() - start:.2f}s")

    # combine all_data or save in the Disk
    all_data.extend(data_chunk)

    # free memory 
    del data_chunk
    gc.collect()

# comvine Dataset
dataset = all_data


print(parameters['user_rows'])

  0%|                                                                                             | 0/3 [00:00<?, ?it/s]

The following parameters seem unnecessary:
{'activate_OFDM'}

Scene 1/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing:  44%|███████████████████████▏                            | 30697/69006 [00:00<00:00, 306945.41it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 313376.96it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 8015.21it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 4634.59it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 739.48it/s][A



Scene 2/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing:  47%|████████████████████████▎                           | 32287/69006 [00:00<00:00, 322846.45it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 326412.38it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 7942.68it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 5899.16it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 460.91it/s][A



Scene 3/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 348279.96it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 7695.45it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 7796.10it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 375.16it/s][A



Scene 4/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing:  50%|█████████████████████████▉                          | 34413/69006 [00:00<00:00, 344117.20it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 331206.18it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 7342.02it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 5468.45it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1018.03it/s][A



Scene 5/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing:  49%|█████████████████████████▎                          | 33563/69006 [00:00<00:00, 335601.52it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 326403.54it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 7649.90it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 3266.59it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 578.60it/s][A
 33%|████████████████████████████▎                                                        | 1/3 [00:07<00:14,  7.10s/it]

Scenes 0–4 generation time: 6.93s
The following parameters seem unnecessary:
{'activate_OFDM'}

Scene 1/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing:  50%|█████████████████████████▊                          | 34186/69006 [00:00<00:00, 341839.14it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 337904.94it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 8456.82it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 7294.44it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 953.68it/s][A



Scene 2/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 349790.61it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 8430.45it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 5468.45it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 735.46it/s][A



Scene 3/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 352216.49it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 8299.43it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 6584.46it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1091.98it/s][A



Scene 4/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 361264.09it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 7574.05it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 4755.45it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1045.18it/s][A



Scene 5/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing:  47%|████████████████████████▍                           | 32506/69006 [00:00<00:00, 325040.94it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 311281.62it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 7637.24it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 7182.03it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 736.36it/s][A
 67%|████████████████████████████████████████████████████████▋                            | 2/3 [00:13<00:06,  6.90s/it]

Scenes 5–9 generation time: 6.61s
The following parameters seem unnecessary:
{'activate_OFDM'}

Scene 1/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 351877.78it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 8166.33it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 7810.62it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 979.29it/s][A



Scene 2/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 353596.90it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 8034.77it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 8081.51it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 751.67it/s][A



Scene 3/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 358623.53it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 7496.79it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 6043.67it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1055.97it/s][A



Scene 4/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 353583.07it/s][A

Generating channels:   0%|                                                                      | 0/727 [00:00<?, ?it/s][A
Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 7069.89it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 7145.32it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 690.65it/s][A



Scene 5/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 341961.54it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 7645.76it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 6204.59it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 977.92it/s][A
100%|█████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:20<00:00,  6.88s/it]

Scenes 10–14 generation time: 6.66s
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
 96 97 98 99]





## About Information
User : 737
UE antenna : 1
BS antenna : 32  Shape(a+bj)
subcarrier : 64

In [7]:
# Unmasked Data Model(gru
# separate maksed data and unmasked data

## Data Preprocessing

In [8]:
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import IterableDataset
import numpy as np
import torch

class UnMaskedChannelSeqDataset(IterableDataset):
    """
    IterableDataset (un-masked version)

    * Task : predict the next-step channel vector from the past `seq_len` steps.
    * Pipeline
        1. **Power-normalise** each complex channel vector → real + imag concat.
        2. **Min–Max scale** inputs and targets with *one* shared scaler
           (fitted on the same power-normalised data).
        3. Yield `(sequence, target)` pairs as `torch.FloatTensor`.
    * Optionally reuse externally provided scalers (train/val split consistency).
    """
    def __init__(
        self,
        scenes,
        seq_len: int = 5,
        eps: float   = 1e-9,
        scalers: tuple[MinMaxScaler, MinMaxScaler] | None = None,
    ):
        super().__init__()
        self.scenes  = scenes
        self.seq_len = seq_len
        self.eps     = eps

        # --- channel tensor dimensions --------------------------------------
        ch0          = scenes[0][0]['user']['channel']        # (U, 1, A, S)
        self.U       = ch0.shape[0]                           # users
        self.A       = ch0.shape[2]                           # BS antennas
        self.S       = ch0.shape[3]                           # sub-carriers
        self.vec_len = 2 * self.A                            # real + imag

        # --------------------------------------------------------------------
        # 1) Fit / reuse MinMax scalers on power-normalised data
        # --------------------------------------------------------------------
        if scalers is None:
            X_list, y_list = [], []
            T = len(scenes)
            for t in range(self.seq_len, T):
                past  = scenes[t - self.seq_len : t]
                s_tgt = scenes[t]

                for u in range(self.U):
                    for s in range(self.S):
                        # power-normalised sequence (seq_len, vec_len)
                        seq_np = np.stack([
                            self._power_norm(p[0]['user']['channel'][u, 0, :, s])
                            for p in past
                        ], axis=0).astype(np.float32)

                        # power-normalised target   (vec_len,)
                        tgt_np = self._power_norm(
                            s_tgt[0]['user']['channel'][u, 0, :, s]
                        ).astype(np.float32)

                        # skip if empty (all zeros)
                        if not np.any(seq_np) or not np.any(tgt_np):
                            continue

                        X_list.append(seq_np.reshape(-1, self.vec_len))
                        y_list.append(tgt_np)

            X_all = np.vstack(X_list)          # (N*seq_len, vec_len)
            y_all = np.stack(y_list)           # (N, vec_len)
            self.scaler_x = MinMaxScaler().fit(X_all)
            self.scaler_y = MinMaxScaler().fit(y_all)
        else:
            # use pre-computed scalers (train/val share the same)
            self.scaler_x, self.scaler_y = scalers

    # ------------------------------------------------------------------------
    # iterator
    # ------------------------------------------------------------------------
    def __iter__(self):
        """
        Yields:
            seq_tensor   : FloatTensor (seq_len, vec_len)
            target_tensor: FloatTensor (vec_len,)
        Both tensors are power-normalised **and** Min–Max scaled.
        """
        T = len(self.scenes)
        for t in range(self.seq_len, T):
            past  = self.scenes[t - self.seq_len : t]
            s_tgt = self.scenes[t]

            for u in range(self.U):
                for s in range(self.S):
                    seq_np = np.stack([
                        self._power_norm(p[0]['user']['channel'][u, 0, :, s])
                        for p in past
                    ], axis=0)
                    tgt_np = self._power_norm(
                        s_tgt[0]['user']['channel'][u, 0, :, s]
                    )

                    if not np.any(seq_np) or not np.any(tgt_np):
                        continue

                    # identical scalers for train / val
                    N, D = seq_np.shape
                    seq_np = self.scaler_x.transform(seq_np.reshape(-1, D)).reshape(N, D)
                    tgt_np = self.scaler_y.transform(tgt_np.reshape(1, -1)).reshape(-1,)

                    yield torch.from_numpy(seq_np), torch.from_numpy(tgt_np)

    # ------------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------------
    def _power_norm(self, h: np.ndarray) -> np.ndarray:
        """
        Convert complex vector → real|imag concatenation
        and force average power to 1.
        """
        v      = np.concatenate([h.real, h.imag]).astype(np.float32)
        power  = np.mean(v * v) + self.eps
        return v / np.sqrt(power)

    def __len__(self):
        """Rough size estimate (not used by IterableDataset)."""
        return (len(self.scenes) - self.seq_len) * self.U * self.S


In [9]:
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import IterableDataset
import numpy as np
import torch
import random

class MaskedChannelSeqDataset(IterableDataset):
    """
    IterableDataset for masked channel sequence data.

    - Predicts the next-step channel vector from a sequence of past vectors.
    - Applies power normalization and MinMax scaling to both inputs and targets.
    - Masks 15% of the patches according to:
        * 80% chance: replace selected patch with zeros
        * 10% chance: replace selected patch with Gaussian noise
        * 10% chance: leave the selected patch unchanged
      The other 85% of samples are returned unmasked.
    """
    def __init__(self, scenes, seq_len=5, eps=1e-9, noise_std=1.0):
        super().__init__()
        self.scenes    = scenes
        self.seq_len   = seq_len
        self.eps       = eps
        self.noise_std = noise_std

        # Determine dimensions: U=users, A=antennas, S=subcarriers
        ch0 = scenes[0][0]['user']['channel']   # shape: (U, 1, A, S)
        self.U       = ch0.shape[0]
        self.A       = ch0.shape[2]
        self.S       = ch0.shape[3]
        self.vec_len = 2 * self.A               # real+imag concatenated

        # ----------------------------------------------------------------------
        # 1) PRECOMPUTE power-norm → fit MinMax scalers on exactly the same data
        # ----------------------------------------------------------------------
        X_list, y_list = [], []
        T = len(scenes)
        for t in range(self.seq_len, T):
            past         = scenes[t - self.seq_len : t]
            target_scene = scenes[t]
            for u in range(self.U):
                for s in range(self.S):
                    # power-normalize each time-slice in the sequence
                    seq_np = np.stack([
                        self._power_norm(ps[0]['user']['channel'][u,0,:,s])
                        for ps in past
                    ], axis=0).astype(np.float32)  # (seq_len, vec_len)

                    # power-normalize target
                    tgt_np = self._power_norm(
                        target_scene[0]['user']['channel'][u,0,:,s]
                    ).astype(np.float32)            # (vec_len,)

                    # skip if invalid
                    if not np.any(seq_np) or not np.any(tgt_np):
                        continue

                    # flatten sequence for scaler
                    X_list.append(seq_np.reshape(-1, self.vec_len))
                    y_list.append(tgt_np)

        # fit scalers on the exact same power-normalized data
        X_all = np.vstack(X_list)  # (num_samples*seq_len, vec_len)
        y_all = np.stack(y_list)   # (num_samples, vec_len)
        self.scaler_x = MinMaxScaler().fit(X_all)
        self.scaler_y = MinMaxScaler().fit(y_all)

        # prepare a zero-mask vector
        self.mask_value = torch.zeros(self.vec_len, dtype=torch.float32)

    def __iter__(self):
        """
        Yields:
            seq_tensor      (seq_len, vec_len) : normalized & scaled sequence
            masked_pos      (1,) tensor long   : index of masked time step
            target_tensor   (vec_len,)          : normalized & scaled target
        """
        T = len(self.scenes)
        mask_prob  = 0.15
        zero_prob  = mask_prob * 0.8
        noise_prob = mask_prob * 0.1

        for t in range(self.seq_len, T):
            past       = self.scenes[t - self.seq_len : t]
            scene_dict = self.scenes[t]

            for u in range(self.U):
                for s in range(self.S):
                    # power-normalize each slice
                    seq_np = np.stack([
                        self._power_norm(ps[0]['user']['channel'][u,0,:,s])
                        for ps in past
                    ], axis=0)
                    tgt_np = self._power_norm(
                        scene_dict[0]['user']['channel'][u,0,:,s]
                    )

                    if not np.any(seq_np) or not np.any(tgt_np):
                        continue

                    # MinMax transform (using the same scalers)
                    N, D = seq_np.shape
                    seq_np = self.scaler_x.transform(seq_np.reshape(-1, D)).reshape(N, D)
                    tgt_np = self.scaler_y.transform(tgt_np.reshape(1, -1)).reshape(-1,)

                    seq_tensor   = torch.from_numpy(seq_np)
                    target_tensor= torch.from_numpy(tgt_np)

                    # select mask position
                    mpos = random.randrange(self.seq_len)

                    r = random.random()
                    if r < zero_prob:
                        masked_seq = seq_tensor.clone()
                        masked_seq[mpos] = self.mask_value
                        yield masked_seq, torch.tensor([mpos]), target_tensor

                    elif r < zero_prob + noise_prob:
                        masked_seq = seq_tensor.clone()
                        masked_seq[mpos] = torch.randn(self.vec_len) * self.noise_std
                        yield masked_seq, torch.tensor([mpos]), target_tensor

                    elif r < mask_prob:
                        # masked-but-unchanged
                        yield seq_tensor, torch.tensor([mpos]), target_tensor

                    else:
                        # unmasked
                        yield seq_tensor, torch.tensor([mpos]), target_tensor

    def _power_norm(self, h: np.ndarray) -> np.ndarray:
        """
        Convert complex vector to real+imag concat and normalize power to 1.
        """
        v = np.concatenate([h.real, h.imag]).astype(np.float32)
        power = np.mean(v * v) + self.eps
        return v / np.sqrt(power)

    def __len__(self):
        return (len(self.scenes) - self.seq_len) * self.U * self.S


## Split Train/Val

In [10]:
# ❷ Train/Validation DataLoader split train : val = 6 : 4
seq_len      = 5
split_ratio  = 0.6
split_idx    = int(len(dataset) * split_ratio)

In [11]:
unmasked_train_ds = UnMaskedChannelSeqDataset(dataset[:split_idx], seq_len=seq_len)
unmasked_val_ds   = UnMaskedChannelSeqDataset(dataset[split_idx:], seq_len=seq_len)

# iterate over train_ds to compute min and max of features/targets

batch_size   = 32
unmasked_train_loader = DataLoader(unmasked_train_ds, batch_size=batch_size, shuffle=False)
unmasked_val_loader   = DataLoader(unmasked_val_ds,   batch_size=batch_size, shuffle=False)
# ─────────────────────────────────────────────


In [12]:
# after you (re)create masked_train_loader
batch = next(iter(unmasked_train_loader))
seq,  tgt = batch   # seq: (B, L, D), tgt: (B, D)
print(
    "seq: min", seq.min().item(), "max", seq.max().item(),
    "   tgt: min", tgt.min().item(), "max", tgt.max().item()
)


seq: min 0.33174002170562744 max 0.6701632738113403    tgt: min 0.329348623752594 max 0.666236400604248


In [13]:
# ❷ Train/Validation DataLoader split train : val = 6 : 4

masked_train_ds = MaskedChannelSeqDataset(dataset[:split_idx], seq_len=seq_len)
masked_val_ds   = MaskedChannelSeqDataset(dataset[split_idx:], seq_len=seq_len)

# iterate over train_ds to compute min and max of features/targets

batch_size   = 32
masked_train_loader = DataLoader(masked_train_ds, batch_size=batch_size, shuffle=False)
masked_val_loader   = DataLoader(masked_val_ds,   batch_size=batch_size, shuffle=False)
# ─────────────────────────────────────────────


In [14]:
batch = next(iter(masked_train_loader))
seq, mpos, tgt = batch
print("seq:", seq.min().item(), "~", seq.max().item())
print("tgt:", tgt.min().item(), "~", tgt.max().item())


seq: 0.0 ~ 0.6701632738113403
tgt: 0.329348623752594 ~ 0.666236400604248


## Define Model

LWMWithHead: A wrapper class that uses a pre-trained LWM (Transformer encoder) as the backbone,
             and attaches a new fully-connected (FC) head for downstream tasks
             (regression, classification, etc.).

Changes:
- input_dim: Dimension of the actual input data (e.g., 64)
- patch_length: Patch length expected by the backbone (e.g., 16)
- Replaces the original element_length parameter with these two distinct parameters
- Applies a projection layer (self.input_proj) in forward()


In [15]:
import torch
import torch.nn as nn
from lwm_model import lwm

class LWMWithHead(nn.Module):
    """
    LWMWithHead: A wrapper class that uses a pre-trained LWM (Transformer encoder) as the backbone,
                 and attaches a new fully-connected (FC) head for downstream tasks
                 (regression, classification, etc.).

    Changes:
    - input_dim: Dimension of the actual input data (e.g., 64)
    - patch_length: Patch length expected by the backbone (e.g., 16)
    - Replaces the original element_length parameter with these two distinct parameters
    - Applies a projection layer (self.input_proj) in forward()
    """
    def __init__(
        self,
        input_dim: int,                 # Dimension of the actual input data (e.g., 64)
        patch_length: int,              # Patch length expected by the backbone (e.g., 16)
        d_model: int = 64,              # LWM hidden size
        max_len: int = 129,             # Positional encoding max length
        n_layers: int = 12,             # Number of Transformer encoder layers
        hidden_dim: int = 256,          # FC head hidden dimension
        out_dim: int = 64,              # FC head output dimension
        freeze_backbone: bool = True,   # Whether to freeze the backbone
        checkpoint_path: str | None = "./model_weights.pth",
        device: str = "cuda"
    ):
        super().__init__()

        # apply a projection layer to match backbone's expected patch_length
        self.input_proj = nn.Linear(input_dim, patch_length)

        # initialize backbone
        if checkpoint_path is None:
            # randomly initialized backbone
            self.backbone = lwm(
                element_length=patch_length,
                d_model=d_model,
                max_len=max_len,
                n_layers=n_layers
            ).to(device)
        else:
            # load pre-trained weights
            self.backbone = lwm.from_pretrained(
                ckpt_name=checkpoint_path,
                device=device
            )

        # freeze backbone parameters if required
        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

        # attach a new fully-connected head for downstream tasks
        self.head = nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, input_ids: torch.Tensor, masked_pos: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input_ids: Tensor of shape (B, L, input_dim)
            masked_pos: Tensor of shape (B, num_mask)
        Returns:
            out: Tensor of shape (B, out_dim)
        """
        # project inputs to patch_length dimension
        x = self.input_proj(input_ids)

        # backbone forward: returns (logits_lm, enc_output)
        _, enc_output = self.backbone(x, masked_pos)

        # extract CLS token feature (first token)
        feat = enc_output[:, 0, :]

        # pass through FC head to get final output
        out = self.head(feat)
        return out


In [16]:
import torch
import torch.nn as nn

class GRUWithHead(nn.Module):
    """
    GRUWithHead (projected):
      • Projects the raw feature dimension (input_dim) to a smaller patch_length
        so every backbone receives the same patch-sized input (like LWM).
      • Stacks N GRU layers, then an FC head for downstream tasks.
    """
    def __init__(
        self,
        input_dim: int    = 64,   # raw feature dimension coming from the DataLoader
        patch_length: int = 16,   # target dimension fed to the GRU backbone
        d_model: int      = 64,   # GRU hidden size
        n_layers: int     = 12,   # number of stacked GRU layers
        bidirectional: bool = True,
        dropout: float      = 0.1,
        hidden_dim: int     = 256, # FC-head hidden size
        out_dim: int        = 64,  # FC-head output size
        freeze_backbone: bool = False
    ):
        super().__init__()

        # 0) Project raw_dim → patch_length (64 → 16)
        self.input_proj = nn.Linear(input_dim, patch_length)

        # 1) GRU backbone that expects 'patch_length' features per time step
        self.backbone = nn.GRU(
            input_size     = patch_length,
            hidden_size    = d_model,
            num_layers     = n_layers,
            batch_first    = True,
            bidirectional  = bidirectional,
            dropout        = dropout if n_layers > 1 else 0.0
        )

        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

        # 2) Fully-connected head
        gru_out_dim = d_model * (2 if bidirectional else 1)
        self.head = nn.Sequential(
            nn.Linear(gru_out_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x : Tensor of shape (batch, seq_len, input_dim) – raw features
        Returns:
            Tensor of shape (batch, out_dim)
        """
        # project raw features to patch_length
        x_proj = self.input_proj(x)                 # (B, seq_len, patch_length)

        # sequence modelling with GRU
        out, _ = self.backbone(x_proj)              # (B, seq_len, num_dirs*d_model)

        # use the last time-step representation
        feat = out[:, -1, :]                        # (B, gru_out_dim)

        # downstream head
        return self.head(feat)                      # (B, out_dim)


In [17]:
import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        # Create positional encoding matrix of shape (1, max_len, d_model)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div_term)
        pe[:, 1::2] = torch.cos(pos * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (batch_size, seq_len, d_model)
        Returns:
            Tensor: x plus positional encodings
        """
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]

class InputEmbedding(nn.Module):
    def __init__(self, feat_dim: int, d_model: int, max_len: int = 5000):
        super().__init__()
        # Optional linear projection from feat_dim to d_model
        self.proj = nn.Linear(feat_dim, d_model) if feat_dim != d_model else None
        self.pos_enc = PositionalEncoding(d_model, max_len)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (batch, seq_len, feat_dim)
        Returns:
            Tensor of shape (batch, seq_len, d_model)
        """
        if self.proj is not None:
            x = self.proj(x)
        return self.pos_enc(x)

class EncoderLayer(nn.Module):
    def __init__(self, d_model: int, n_heads: int, dim_ff: int, dropout: float = 0.1):
        super().__init__()
        # Multi-Head Self-Attention
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        # Position-wise Feed-Forward Network
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_ff),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(dim_ff, d_model)
        )
        # Layer Normalization and Dropout for residual connections
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        src_mask: torch.Tensor = None,
        src_key_padding_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (seq_len, batch, d_model)
            src_mask: Optional Tensor of shape (seq_len, seq_len)
            src_key_padding_mask: Optional Tensor of shape (batch, seq_len)
        Returns:
            Tensor of shape (seq_len, batch, d_model)
        """
        # Self-attention sublayer
        attn_out, _ = self.self_attn(x, x, x, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
        x = x + self.dropout1(attn_out)
        x = self.norm1(x)
        # Feed-forward sublayer
        ff_out = self.ff(x)
        x = x + self.dropout2(ff_out)
        x = self.norm2(x)
        return x

class TransformerEncoderCustom(nn.Module):
    def __init__(
        self,
        feat_dim: int,
        d_model: int,
        n_heads: int,
        dim_ff: int,
        n_layers: int,
        dropout: float = 0.1,
        max_len: int = 5000
    ):
        super().__init__()
        # Input embedding: feature projection + positional encoding
        self.input_embedding = InputEmbedding(feat_dim, d_model, max_len)
        # Stack of N encoder layers
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, n_heads, dim_ff, dropout)
            for _ in range(n_layers)
        ])

    def forward(
        self,
        x: torch.Tensor,
        src_mask: torch.Tensor = None,
        src_key_padding_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (batch, seq_len, feat_dim)
        Returns:
            Tensor of shape (seq_len, batch, d_model)
        """
        x = self.input_embedding(x)       # (batch, seq_len, d_model)
        x = x.transpose(0, 1)             # (seq_len, batch, d_model)
        for layer in self.layers:
            x = layer(x, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        return x

class DecoderLayer(nn.Module):
    def __init__(self, d_model: int, n_heads: int, dim_ff: int, dropout: float = 0.1):
        super().__init__()
        # Masked Self-Attention
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        # Encoder-Decoder Attention
        self.multihead_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        # Position-wise Feed-Forward Network
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_ff),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(dim_ff, d_model)
        )
        # Layer Normalizations and Dropouts
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: torch.Tensor = None,
        memory_mask: torch.Tensor = None,
        tgt_key_padding_mask: torch.Tensor = None,
        memory_key_padding_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Args:
            tgt: Tensor of shape (tgt_len, batch, d_model)
            memory: Tensor of shape (src_len, batch, d_model)
        Returns:
            Tensor of shape (tgt_len, batch, d_model)
        """
        # Masked self-attention sublayer
        attn1, _ = self.self_attn(
            tgt, tgt, tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask
        )
        tgt = tgt + self.dropout1(attn1)
        tgt = self.norm1(tgt)
        # Encoder-decoder attention sublayer
        attn2, _ = self.multihead_attn(
            tgt, memory, memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask
        )
        tgt = tgt + self.dropout2(attn2)
        tgt = self.norm2(tgt)
        # Feed-forward sublayer
        ff_out = self.ff(tgt)
        tgt = tgt + self.dropout3(ff_out)
        tgt = self.norm3(tgt)
        return tgt

class TransformerDecoderCustom(nn.Module):
    def __init__(
        self,
        feat_dim: int,
        d_model: int,
        n_heads: int,
        dim_ff: int,
        n_layers: int,
        dropout: float = 0.1,
        max_len: int = 5000
    ):
        super().__init__()
        # Input embedding for target sequence
        self.input_embedding = InputEmbedding(feat_dim, d_model, max_len)
        # Stack of N decoder layers
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, n_heads, dim_ff, dropout)
            for _ in range(n_layers)
        ])
        # Final projection back to feature dimension
        self.output_linear = nn.Linear(d_model, feat_dim)

    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: torch.Tensor = None,
        memory_mask: torch.Tensor = None,
        tgt_key_padding_mask: torch.Tensor = None,
        memory_key_padding_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Args:
            tgt: Tensor of shape (batch, tgt_len, feat_dim)
            memory: Tensor of shape (src_len, batch, d_model)
        Returns:
            Tensor of shape (batch, tgt_len, feat_dim)
        """
        x = self.input_embedding(tgt)       # (batch, tgt_len, d_model)
        x = x.transpose(0, 1)               # (tgt_len, batch, d_model)
        for layer in self.layers:
            x = layer(
                x,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask
            )
        x = x.transpose(0, 1)               # (batch, tgt_len, d_model)
        return self.output_linear(x)        # project back to feat_dim

        
class TransformerWithHead(nn.Module):
    def __init__(
        self,
        input_dim: int    = 64,   # raw feature dimension
        patch_length: int = 16,   # what the encoder consumes
        d_model: int      = 64,   # hidden size inside transformer
        n_heads: int      = 4,
        dim_ff: int       = 256,
        n_layers: int     = 12,
        dropout: float    = 0.1,
        hidden_dim: int   = 256,
        out_dim: int      = 64,
        max_len: int      = 5000,
        freeze_backbone: bool = False,
    ):
        super().__init__()

        # 0) project raw 64-dim → 16-dim patch
        self.input_proj = nn.Linear(input_dim, patch_length)

        # 1) encoder expects patch_length features
        self.encoder = TransformerEncoderCustom(
            feat_dim = patch_length,
            d_model  = d_model,
            n_heads  = n_heads,
            dim_ff   = dim_ff,
            n_layers = n_layers,
            dropout  = dropout,
            max_len  = max_len,
        )
        if freeze_backbone:
            for p in self.encoder.parameters():
                p.requires_grad = False

        # 2) task head
        self.head = nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(
        self,
        x: torch.Tensor,                 # (batch, seq_len, 64)
        src_mask: torch.Tensor = None,
        src_key_padding_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        # raw → patch_length
        x = self.input_proj(x)           # (batch, seq_len, 16)

        # encode (InputEmbedding inside encoder adds 16→64 + positional encoding)
        enc_out = self.encoder(
            x, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask
        )                                # (seq_len, batch, d_model)

        last_token = enc_out[-1]         # (batch, d_model)
        return self.head(last_token)     # (batch, out_dim)



In [18]:
import torch
import torch.nn as nn

class RNNWithHead(nn.Module):
    """
    RNNWithHead (projected):
      • Projects raw feature vectors from `input_dim` to `patch_length`
      • Feeds the projected sequence to an RNN backbone
      • Maps the last hidden state through an FC head
    """
    def __init__(
        self,
        input_dim: int    = 64,   # raw feature dimension coming from DataLoader
        patch_length: int = 16,   # dimension consumed by the RNN backbone
        hidden_size: int  = 64,   # RNN hidden size
        num_layers: int   = 12,   # number of stacked RNN layers
        bidirectional: bool = True,
        dropout: float      = 0.1,
        hidden_dim: int     = 256, # FC-head hidden size
        out_dim: int        = 64,  # FC-head output size
        freeze_backbone: bool = False,
    ):
        super().__init__()

        # 0) project raw 64-dim → 16-dim
        self.input_proj = nn.Linear(input_dim, patch_length)

        # 1) RNN backbone
        self.backbone = nn.RNN(
            input_size     = patch_length,
            hidden_size    = hidden_size,
            num_layers     = num_layers,
            batch_first    = True,
            bidirectional  = bidirectional,
            dropout        = dropout if num_layers > 1 else 0.0,
        )

        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

        # 2) FC head
        rnn_out_dim = hidden_size * (2 if bidirectional else 1)
        self.head = nn.Sequential(
            nn.Linear(rnn_out_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (batch, seq_len, input_dim=64)
        returns: (batch, out_dim)
        """
        x_proj = self.input_proj(x)           # (batch, seq_len, 16)
        out, _ = self.backbone(x_proj)        # (batch, seq_len, rnn_out_dim)
        feat   = out[:, -1, :]                # take last time step
        return self.head(feat)                # (batch, out_dim)


In [19]:
import torch
import torch.nn as nn

class LSTMWithHead(nn.Module):
    """
    LSTMWithHead (projected):
      • Projects raw feature vectors from `input_dim` to a compact `patch_length`
      • Feeds the projected sequence to an LSTM backbone
      • Uses the last hidden state to drive an FC head for the downstream task
    """
    def __init__(
        self,
        input_dim: int    = 64,   # raw feature dimension (e.g., 64)
        patch_length: int = 16,   # dimension consumed by the LSTM backbone
        hidden_size: int  = 64,   # LSTM hidden size
        num_layers: int   = 12,   # number of stacked LSTM layers
        bidirectional: bool = True,
        dropout: float      = 0.1,
        hidden_dim: int     = 256, # FC-head hidden size
        out_dim: int        = 64,  # FC-head output size
        freeze_backbone: bool = False,
    ):
        super().__init__()

        # 0) Raw 64-dim → 16-dim patch projection
        self.input_proj = nn.Linear(input_dim, patch_length)

        # 1) LSTM backbone that expects `patch_length` features
        self.backbone = nn.LSTM(
            input_size     = patch_length,
            hidden_size    = hidden_size,
            num_layers     = num_layers,
            batch_first    = True,
            bidirectional  = bidirectional,
            dropout        = dropout if num_layers > 1 else 0.0,
        )
        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

        # 2) FC head
        lstm_out_dim = hidden_size * (2 if bidirectional else 1)
        self.head = nn.Sequential(
            nn.Linear(lstm_out_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (batch, seq_len, input_dim=64)
        returns: (batch, out_dim)
        """
        # project raw features to patch_length
        x_proj = self.input_proj(x)             # (B, seq_len, 16)

        # sequence modeling with LSTM
        out, _ = self.backbone(x_proj)          # (B, seq_len, lstm_out_dim)

        # take the last time-step representation
        feat = out[:, -1, :]                    # (B, lstm_out_dim)

        # downstream head
        return self.head(feat)                  # (B, out_dim)


## fine-tuning

In [20]:
from torch.optim import Adam

# ──────────────────────────
# Shared hyper-parameters
# ──────────────────────────
INPUT_DIM     = 64     # raw feature dimension
PATCH_LENGTH  = 16     # dimension fed to every backbone
D_MODEL       = 64     # internal hidden size (GRU/LSTM/Transformer)
N_LAYERS      = 12     # stacked layers
HIDDEN_DIM    = 256    # head hidden dimension
OUT_DIM       = 64     # head output dimension
DROPOUT       = 0.1    # dropout for recurrent / transformer blocks
BIDIRECTIONAL = True   # use bidirectional RNNs
DEVICE        = "cuda"

# ──────────────────────────
# Model class catalog
# ──────────────────────────
MODEL_CATALOG = {
    "LWM_freeze_backbone"     : LWMWithHead,
    "LWM_pretrained_Fine_tune": LWMWithHead,
    "LWM_Fine_tune"           : LWMWithHead,
    "gru"                     : GRUWithHead,
    "RNN"                     : RNNWithHead,
    "LSTM"                    : LSTMWithHead,
    "Transformer"             : TransformerWithHead,
}

# ──────────────────────────
# Per-model constructor kwargs
# ──────────────────────────
MODEL_PARAMS = {
    # ── LWM variants ─────────────────────────────
    "LWM_freeze_backbone": {
        "input_dim"       : INPUT_DIM,
        "patch_length"    : PATCH_LENGTH,
        "d_model"         : D_MODEL,
        "max_len"         : PATCH_LENGTH + 1,
        "n_layers"        : N_LAYERS,
        "hidden_dim"      : HIDDEN_DIM,
        "out_dim"         : OUT_DIM,
        "freeze_backbone" : True,
        "checkpoint_path" : "./model_weights.pth",
        "device"          : DEVICE,
    },
    "LWM_pretrained_Fine_tune": {
        "input_dim"       : INPUT_DIM,
        "patch_length"    : PATCH_LENGTH,
        "d_model"         : D_MODEL,
        "max_len"         : PATCH_LENGTH + 1,
        "n_layers"        : N_LAYERS,
        "hidden_dim"      : HIDDEN_DIM,
        "out_dim"         : OUT_DIM,
        "freeze_backbone" : False,
        "checkpoint_path" : "./model_weights.pth",
        "device"          : DEVICE,
    },
    "LWM_Fine_tune": {
        "input_dim"       : INPUT_DIM,
        "patch_length"    : PATCH_LENGTH,
        "d_model"         : D_MODEL,
        "max_len"         : PATCH_LENGTH + 1,
        "n_layers"        : N_LAYERS,
        "hidden_dim"      : HIDDEN_DIM,
        "out_dim"         : OUT_DIM,
        "freeze_backbone" : False,
        "checkpoint_path" : None,
        "device"          : DEVICE,
    },

    # ── GRU (projected) ──────────────────────────
    "gru": {
        "input_dim"       : INPUT_DIM,     # 64 → project → 16
        "patch_length"    : PATCH_LENGTH,
        "d_model"         : D_MODEL,
        "n_layers"        : N_LAYERS,
        "bidirectional"   : BIDIRECTIONAL,
        "dropout"         : DROPOUT,
        "hidden_dim"      : HIDDEN_DIM,
        "out_dim"         : OUT_DIM,
        "freeze_backbone" : False,
    },

    # ── Vanilla RNN (projected) ──────────────────
    "RNN": {
        "input_dim"       : INPUT_DIM,
        "patch_length"    : PATCH_LENGTH,
        "hidden_size"     : D_MODEL,
        "num_layers"      : N_LAYERS,
        "bidirectional"   : BIDIRECTIONAL,
        "dropout"         : 0.0,
        "hidden_dim"      : HIDDEN_DIM,
        "out_dim"         : OUT_DIM,
        "freeze_backbone" : False,
    },

    # ── LSTM (projected) ─────────────────────────
    "LSTM": {
        "input_dim"       : INPUT_DIM,
        "patch_length"    : PATCH_LENGTH,
        "hidden_size"     : D_MODEL,
        "num_layers"      : N_LAYERS,
        "bidirectional"   : BIDIRECTIONAL,
        "dropout"         : DROPOUT,
        "hidden_dim"      : HIDDEN_DIM,
        "out_dim"         : OUT_DIM,
        "freeze_backbone" : False,
    },

    # ── Transformer (projected) ──────────────────
    "Transformer": {
        "input_dim"       : INPUT_DIM,
        "patch_length"    : PATCH_LENGTH,
        "d_model"         : D_MODEL,
        "n_heads"         : 4,
        "dim_ff"          : 256,
        "n_layers"        : N_LAYERS,
        "dropout"         : DROPOUT,
        "hidden_dim"      : HIDDEN_DIM,
        "out_dim"         : OUT_DIM,
        "max_len"         : PATCH_LENGTH + 1,
        "freeze_backbone" : False,
    },
}


## model evaluate

In [21]:
import torch
import torch.nn.functional as F

def rmse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Root-Mean-Squared Error
    """
    return torch.sqrt(F.mse_loss(pred, target, reduction="mean"))   # √MSE

def nmse(pred: torch.Tensor, target: torch.Tensor, eps : float = 1e-12) -> torch.Tensor:
    """
    Normalized MSE  =  E[‖ŷ − y‖²] / E[‖y‖²]
    """
    # (B, …) → (B,)  
    mse_per_sample   = ((pred - target)**2).view(pred.size(0), -1).sum(dim=1)
    power_per_sample = (target**2).view(target.size(0), -1).sum(dim=1) + eps
    return (mse_per_sample / power_per_sample).mean()



In [22]:
def masked_evaluate(model, loader, device="cuda"):
    """
    Validation loop for IterableDataset.
    Returns average RMSE and NMSE over all samples.
    """
    model.eval()
    total_rmse, total_nmse, total_samples = 0.0, 0.0, 0

    with torch.no_grad():
        for input_ids, masked_pos, target in loader:
            # Move to device
            input_ids, masked_pos, target = (
                input_ids.to(device),
                masked_pos.to(device),
                target.to(device),
            )
            # Batch size
            bs = input_ids.size(0)

            # Forward
            pred = model(input_ids, masked_pos)

            # Accumulate batch metrics
            total_rmse    += rmse(pred, target).item() * bs
            total_nmse    += nmse(pred, target).item() * bs
            total_samples += bs

    # Compute averages
    return {
        "RMSE": total_rmse / total_samples,
        "NMSE": total_nmse / total_samples
    }

In [23]:
def unmasked_evaluate(model, loader, device="cuda"):
    """
    Validation loop for IterableDataset.
    Returns average RMSE and NMSE over all samples.
    """
    model.eval()
    total_rmse, total_nmse, total_samples = 0.0, 0.0, 0

    with torch.no_grad():
        for input_ids, target in loader:
            # Move to device
            input_ids,target = (
                input_ids.to(device),
                target.to(device),
            )
            # Batch size
            bs = input_ids.size(0)

            # Forward
            pred = model(input_ids)

            # Accumulate batch metrics
            total_rmse    += rmse(pred, target).item() * bs
            total_nmse    += nmse(pred, target).item() * bs
            total_samples += bs

    # Compute averages
    return {
        "RMSE": total_rmse / total_samples,
        "NMSE": total_nmse / total_samples
    }

# Model Training

In [24]:
"""
Unified training / validation script
------------------------------------
* Handles multiple models listed in MODEL_CATALOG
* Uses separate masked / un-masked DataLoaders depending on model type
* Measures epoch-level train time and prints per-epoch + average timings
"""

from tqdm import tqdm
import time, math, torch
import torch.nn as nn

# ─────────────────────────────────────────────
# 0) Globals and hyper-parameters
# ─────────────────────────────────────────────
device            = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion         = nn.MSELoss().to(device)

NUM_EPOCHS        = 10
LR                = 1e-4                      # learning-rate
total_start_time  = time.time()               # wall-clock timer for all models
results           = {}                        # holds best-epoch NMSE_dB for each model

# ─────────────────────────────────────────────
# 1) Model training / validation loop
# ─────────────────────────────────────────────
for model_name, ModelCls in MODEL_CATALOG.items():

    print(f"\n=== Training {model_name} ===")

    # 1-A) instantiate model
    model_args = MODEL_PARAMS[model_name]
    model      = ModelCls(**model_args).to(device)

    # 1-B) collect trainable parameters
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    if len(trainable_params) == 0:
        print(f"⚠️  '{model_name}' has no trainable parameters — skipping training.")
        results[model_name] = float("nan")
        continue

    optimizer   = torch.optim.Adam(trainable_params, lr=LR)
    epoch_times = []                       # store per-epoch train times

    # choose loaders depending on model family
    uses_mask   = model_name.startswith("LWM_")
    tr_loader   = masked_train_loader   if uses_mask else unmasked_train_loader
    val_loader  = masked_val_loader    if uses_mask else unmasked_val_loader
    eval_fn     = masked_evaluate      if uses_mask else unmasked_evaluate

    # 1-C) epoch loop
    for epoch in range(1, NUM_EPOCHS + 1):

        # ── TRAIN ───────────────────────────────────────
        t0 = time.time()
        model.train()
        running_loss = 0.0

        pbar = tqdm(tr_loader,
                    desc=f"[{model_name} {epoch:02d}/{NUM_EPOCHS}] train",
                    leave=False)

        for b, batch in enumerate(pbar, 1):
            # unpack batch depending on masked / unmasked loader
            if uses_mask:
                inp, mpos, tgt = [x.to(device) for x in batch]
                pred = model(inp, mpos).squeeze(-1)
            else:
                inp, tgt      = [x.to(device) for x in batch]
                pred          = model(inp)

            loss = criterion(pred, tgt)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if b % 100 == 0:
                pbar.set_postfix(train_loss=running_loss / b)

        epoch_train_time = time.time() - t0
        epoch_times.append(epoch_train_time)
        avg_train_loss = running_loss / b

        # ── VALID ───────────────────────────────────────
        metrics      = eval_fn(model, val_loader, device)
        val_rmse     = metrics["RMSE"]
        val_nmse     = metrics["NMSE"]
        val_nmse_db  = 10 * torch.log10(torch.tensor(val_nmse)).item()

        print(
            f"[{epoch:02d}/{NUM_EPOCHS}] "
            f"TrainLoss: {avg_train_loss:.4f}  "
            f"Val RMSE: {val_rmse:.4f}  "
            f"Val NMSE: {val_nmse:.4e}  "
            f"Val NMSE_dB: {val_nmse_db:.1f} dB  "
            f"TrainTime: {epoch_train_time:.2f}s"
        )

    # 1-D) epoch-time statistics and final log
    avg_epoch_time = sum(epoch_times) / len(epoch_times)
    print(f"🕒 {model_name} - average train time / epoch: {avg_epoch_time:.2f}s")
    results[model_name] = val_nmse_db

# ─────────────────────────────────────────────
# 2) summary
# ─────────────────────────────────────────────
print("\n=== Summary of NMSE(dB) by model ===")
for name, nmse_db in results.items():
    print(f"{name:25s}: {nmse_db if not math.isnan(nmse_db) else 'skipped':>6}")

print(f"\nTotal training time for all models: {time.time() - total_start_time:.2f}s")



=== Training LWM_freeze_backbone ===
Model loaded successfully from ./model_weights.pth to cuda


                                                                                                                        

[01/10] TrainLoss: 0.0157  Val RMSE: 0.0937  Val NMSE: 4.6457e-02  Val NMSE_dB: -13.3 dB  TrainTime: 206.49s


                                                                                                                        

[02/10] TrainLoss: 0.0100  Val RMSE: 0.0926  Val NMSE: 4.2636e-02  Val NMSE_dB: -13.7 dB  TrainTime: 178.40s


                                                                                                                        

[03/10] TrainLoss: 0.0086  Val RMSE: 0.0907  Val NMSE: 3.8790e-02  Val NMSE_dB: -14.1 dB  TrainTime: 179.94s


                                                                                                                        

[04/10] TrainLoss: 0.0076  Val RMSE: 0.0888  Val NMSE: 3.6792e-02  Val NMSE_dB: -14.3 dB  TrainTime: 202.28s


                                                                                                                        

[05/10] TrainLoss: 0.0072  Val RMSE: 0.0875  Val NMSE: 3.5754e-02  Val NMSE_dB: -14.5 dB  TrainTime: 203.23s


                                                                                                                        

[06/10] TrainLoss: 0.0070  Val RMSE: 0.0870  Val NMSE: 3.5194e-02  Val NMSE_dB: -14.5 dB  TrainTime: 208.71s


                                                                                                                        

[07/10] TrainLoss: 0.0068  Val RMSE: 0.0858  Val NMSE: 3.4261e-02  Val NMSE_dB: -14.7 dB  TrainTime: 211.54s


                                                                                                                        

[08/10] TrainLoss: 0.0066  Val RMSE: 0.0849  Val NMSE: 3.3518e-02  Val NMSE_dB: -14.7 dB  TrainTime: 205.50s


                                                                                                                        

[09/10] TrainLoss: 0.0064  Val RMSE: 0.0841  Val NMSE: 3.2760e-02  Val NMSE_dB: -14.8 dB  TrainTime: 199.86s


                                                                                                                        

[10/10] TrainLoss: 0.0062  Val RMSE: 0.0828  Val NMSE: 3.1742e-02  Val NMSE_dB: -15.0 dB  TrainTime: 219.54s
🕒 LWM_freeze_backbone - average train time / epoch: 201.55s

=== Training LWM_pretrained_Fine_tune ===
Model loaded successfully from ./model_weights.pth to cuda


                                                                                                                        

[01/10] TrainLoss: 0.0115  Val RMSE: 0.0855  Val NMSE: 3.6221e-02  Val NMSE_dB: -14.4 dB  TrainTime: 276.97s


                                                                                                                        

[02/10] TrainLoss: 0.0069  Val RMSE: 0.0834  Val NMSE: 3.2972e-02  Val NMSE_dB: -14.8 dB  TrainTime: 245.39s


                                                                                                                        

[03/10] TrainLoss: 0.0061  Val RMSE: 0.0825  Val NMSE: 3.1170e-02  Val NMSE_dB: -15.1 dB  TrainTime: 242.71s


                                                                                                                        

[04/10] TrainLoss: 0.0050  Val RMSE: 0.0863  Val NMSE: 3.1857e-02  Val NMSE_dB: -15.0 dB  TrainTime: 255.61s


                                                                                                                        

[05/10] TrainLoss: 0.0042  Val RMSE: 0.0819  Val NMSE: 2.9014e-02  Val NMSE_dB: -15.4 dB  TrainTime: 250.56s


                                                                                                                        

[06/10] TrainLoss: 0.0038  Val RMSE: 0.0789  Val NMSE: 2.6393e-02  Val NMSE_dB: -15.8 dB  TrainTime: 260.04s


                                                                                                                        

[07/10] TrainLoss: 0.0033  Val RMSE: 0.0757  Val NMSE: 2.4076e-02  Val NMSE_dB: -16.2 dB  TrainTime: 271.34s


                                                                                                                        

[08/10] TrainLoss: 0.0030  Val RMSE: 0.0755  Val NMSE: 2.3930e-02  Val NMSE_dB: -16.2 dB  TrainTime: 267.70s


                                                                                                                        

[09/10] TrainLoss: 0.0028  Val RMSE: 0.0768  Val NMSE: 2.4643e-02  Val NMSE_dB: -16.1 dB  TrainTime: 283.97s


                                                                                                                        

[10/10] TrainLoss: 0.0027  Val RMSE: 0.0765  Val NMSE: 2.4444e-02  Val NMSE_dB: -16.1 dB  TrainTime: 263.95s
🕒 LWM_pretrained_Fine_tune - average train time / epoch: 261.82s

=== Training LWM_Fine_tune ===


                                                                                                                        

[01/10] TrainLoss: 0.0103  Val RMSE: 0.0879  Val NMSE: 3.4356e-02  Val NMSE_dB: -14.6 dB  TrainTime: 256.01s


                                                                                                                        

[02/10] TrainLoss: 0.0052  Val RMSE: 0.0824  Val NMSE: 2.9784e-02  Val NMSE_dB: -15.3 dB  TrainTime: 278.70s


                                                                                                                        

[03/10] TrainLoss: 0.0041  Val RMSE: 0.0797  Val NMSE: 2.7388e-02  Val NMSE_dB: -15.6 dB  TrainTime: 269.49s


                                                                                                                        

[04/10] TrainLoss: 0.0035  Val RMSE: 0.0792  Val NMSE: 2.6363e-02  Val NMSE_dB: -15.8 dB  TrainTime: 263.49s


                                                                                                                        

[05/10] TrainLoss: 0.0032  Val RMSE: 0.0790  Val NMSE: 2.6550e-02  Val NMSE_dB: -15.8 dB  TrainTime: 264.37s


                                                                                                                        

[06/10] TrainLoss: 0.0030  Val RMSE: 0.0759  Val NMSE: 2.4467e-02  Val NMSE_dB: -16.1 dB  TrainTime: 290.80s


                                                                                                                        

[07/10] TrainLoss: 0.0028  Val RMSE: 0.0767  Val NMSE: 2.5073e-02  Val NMSE_dB: -16.0 dB  TrainTime: 283.41s


                                                                                                                        

[08/10] TrainLoss: 0.0027  Val RMSE: 0.0764  Val NMSE: 2.4780e-02  Val NMSE_dB: -16.1 dB  TrainTime: 252.23s


                                                                                                                        

[09/10] TrainLoss: 0.0026  Val RMSE: 0.0760  Val NMSE: 2.4948e-02  Val NMSE_dB: -16.0 dB  TrainTime: 258.61s


                                                                                                                        

[10/10] TrainLoss: 0.0025  Val RMSE: 0.0768  Val NMSE: 2.5728e-02  Val NMSE_dB: -15.9 dB  TrainTime: 282.60s
🕒 LWM_Fine_tune - average train time / epoch: 269.97s

=== Training gru ===


                                                                                                                        

[01/10] TrainLoss: 0.0132  Val RMSE: 0.0957  Val NMSE: 4.5185e-02  Val NMSE_dB: -13.5 dB  TrainTime: 104.55s


                                                                                                                        

[02/10] TrainLoss: 0.0095  Val RMSE: 0.0971  Val NMSE: 4.2676e-02  Val NMSE_dB: -13.7 dB  TrainTime: 108.50s


                                                                                                                        

[03/10] TrainLoss: 0.0078  Val RMSE: 0.0903  Val NMSE: 3.7457e-02  Val NMSE_dB: -14.3 dB  TrainTime: 106.34s


                                                                                                                        

[04/10] TrainLoss: 0.0070  Val RMSE: 0.0904  Val NMSE: 3.6410e-02  Val NMSE_dB: -14.4 dB  TrainTime: 98.97s


                                                                                                                        

[05/10] TrainLoss: 0.0060  Val RMSE: 0.0846  Val NMSE: 3.1739e-02  Val NMSE_dB: -15.0 dB  TrainTime: 98.00s


                                                                                                                        

[06/10] TrainLoss: 0.0047  Val RMSE: 0.0818  Val NMSE: 2.8369e-02  Val NMSE_dB: -15.5 dB  TrainTime: 98.02s


                                                                                                                        

[07/10] TrainLoss: 0.0044  Val RMSE: 0.0827  Val NMSE: 2.8442e-02  Val NMSE_dB: -15.5 dB  TrainTime: 102.88s


                                                                                                                        

[08/10] TrainLoss: 0.0042  Val RMSE: 0.0829  Val NMSE: 2.8391e-02  Val NMSE_dB: -15.5 dB  TrainTime: 98.36s


                                                                                                                        

[09/10] TrainLoss: 0.0041  Val RMSE: 0.0809  Val NMSE: 2.7217e-02  Val NMSE_dB: -15.7 dB  TrainTime: 97.93s


                                                                                                                        

[10/10] TrainLoss: 0.0040  Val RMSE: 0.0808  Val NMSE: 2.7079e-02  Val NMSE_dB: -15.7 dB  TrainTime: 98.96s
🕒 gru - average train time / epoch: 101.25s

=== Training RNN ===


                                                                                                                        

[01/10] TrainLoss: 0.0105  Val RMSE: 0.0879  Val NMSE: 3.6127e-02  Val NMSE_dB: -14.4 dB  TrainTime: 74.73s


                                                                                                                        

[02/10] TrainLoss: 0.0057  Val RMSE: 0.0817  Val NMSE: 2.9307e-02  Val NMSE_dB: -15.3 dB  TrainTime: 75.23s


                                                                                                                        

[03/10] TrainLoss: 0.0042  Val RMSE: 0.0745  Val NMSE: 2.4948e-02  Val NMSE_dB: -16.0 dB  TrainTime: 75.87s


                                                                                                                        

[04/10] TrainLoss: 0.0036  Val RMSE: 0.0707  Val NMSE: 2.2895e-02  Val NMSE_dB: -16.4 dB  TrainTime: 74.30s


                                                                                                                        

[05/10] TrainLoss: 0.0033  Val RMSE: 0.0686  Val NMSE: 2.1485e-02  Val NMSE_dB: -16.7 dB  TrainTime: 84.11s


                                                                                                                        

[06/10] TrainLoss: 0.0030  Val RMSE: 0.0681  Val NMSE: 2.1101e-02  Val NMSE_dB: -16.8 dB  TrainTime: 78.63s


                                                                                                                        

[07/10] TrainLoss: 0.0028  Val RMSE: 0.0673  Val NMSE: 2.0503e-02  Val NMSE_dB: -16.9 dB  TrainTime: 73.38s


                                                                                                                        

[08/10] TrainLoss: 0.0026  Val RMSE: 0.0671  Val NMSE: 2.0288e-02  Val NMSE_dB: -16.9 dB  TrainTime: 70.05s


                                                                                                                        

[09/10] TrainLoss: 0.0025  Val RMSE: 0.0675  Val NMSE: 2.0327e-02  Val NMSE_dB: -16.9 dB  TrainTime: 73.61s


                                                                                                                        

[10/10] TrainLoss: 0.0025  Val RMSE: 0.0676  Val NMSE: 2.0284e-02  Val NMSE_dB: -16.9 dB  TrainTime: 68.00s
🕒 RNN - average train time / epoch: 74.79s

=== Training LSTM ===


                                                                                                                        

[01/10] TrainLoss: 0.0141  Val RMSE: 0.0988  Val NMSE: 4.6593e-02  Val NMSE_dB: -13.3 dB  TrainTime: 95.89s


                                                                                                                        

[02/10] TrainLoss: 0.0098  Val RMSE: 0.1003  Val NMSE: 4.7084e-02  Val NMSE_dB: -13.3 dB  TrainTime: 101.55s


                                                                                                                        

[03/10] TrainLoss: 0.0094  Val RMSE: 0.0925  Val NMSE: 4.2195e-02  Val NMSE_dB: -13.7 dB  TrainTime: 94.13s


                                                                                                                        

[04/10] TrainLoss: 0.0090  Val RMSE: 0.0907  Val NMSE: 4.1144e-02  Val NMSE_dB: -13.9 dB  TrainTime: 97.70s


                                                                                                                        

[05/10] TrainLoss: 0.0088  Val RMSE: 0.0904  Val NMSE: 4.0404e-02  Val NMSE_dB: -13.9 dB  TrainTime: 97.71s


                                                                                                                        

[06/10] TrainLoss: 0.0086  Val RMSE: 0.0910  Val NMSE: 4.0542e-02  Val NMSE_dB: -13.9 dB  TrainTime: 101.97s


                                                                                                                        

[07/10] TrainLoss: 0.0085  Val RMSE: 0.0974  Val NMSE: 4.4175e-02  Val NMSE_dB: -13.5 dB  TrainTime: 93.94s


                                                                                                                        

[08/10] TrainLoss: 0.0080  Val RMSE: 0.0958  Val NMSE: 4.3562e-02  Val NMSE_dB: -13.6 dB  TrainTime: 104.29s


                                                                                                                        

[09/10] TrainLoss: 0.0078  Val RMSE: 0.0884  Val NMSE: 3.8428e-02  Val NMSE_dB: -14.2 dB  TrainTime: 94.67s


                                                                                                                        

[10/10] TrainLoss: 0.0076  Val RMSE: 0.0877  Val NMSE: 3.7055e-02  Val NMSE_dB: -14.3 dB  TrainTime: 94.77s
🕒 LSTM - average train time / epoch: 97.66s

=== Training Transformer ===


                                                                                                                        

[01/10] TrainLoss: 0.0102  Val RMSE: 0.0822  Val NMSE: 2.9728e-02  Val NMSE_dB: -15.3 dB  TrainTime: 190.83s


                                                                                                                        

[02/10] TrainLoss: 0.0044  Val RMSE: 0.0771  Val NMSE: 2.5726e-02  Val NMSE_dB: -15.9 dB  TrainTime: 172.56s


                                                                                                                        

[03/10] TrainLoss: 0.0035  Val RMSE: 0.0740  Val NMSE: 2.3903e-02  Val NMSE_dB: -16.2 dB  TrainTime: 177.54s


                                                                                                                        

[04/10] TrainLoss: 0.0031  Val RMSE: 0.0738  Val NMSE: 2.4129e-02  Val NMSE_dB: -16.2 dB  TrainTime: 184.68s


                                                                                                                        

[05/10] TrainLoss: 0.0029  Val RMSE: 0.0733  Val NMSE: 2.3701e-02  Val NMSE_dB: -16.3 dB  TrainTime: 189.74s


                                                                                                                        

[06/10] TrainLoss: 0.0028  Val RMSE: 0.0735  Val NMSE: 2.3474e-02  Val NMSE_dB: -16.3 dB  TrainTime: 178.84s


                                                                                                                        

[07/10] TrainLoss: 0.0026  Val RMSE: 0.0714  Val NMSE: 2.2195e-02  Val NMSE_dB: -16.5 dB  TrainTime: 187.54s


                                                                                                                        

[08/10] TrainLoss: 0.0025  Val RMSE: 0.0704  Val NMSE: 2.1493e-02  Val NMSE_dB: -16.7 dB  TrainTime: 175.68s


                                                                                                                        

[09/10] TrainLoss: 0.0024  Val RMSE: 0.0707  Val NMSE: 2.1681e-02  Val NMSE_dB: -16.6 dB  TrainTime: 171.46s


                                                                                                                        

[10/10] TrainLoss: 0.0024  Val RMSE: 0.0704  Val NMSE: 2.1441e-02  Val NMSE_dB: -16.7 dB  TrainTime: 180.34s
🕒 Transformer - average train time / epoch: 180.92s

=== Summary of NMSE(dB) by model ===
LWM_freeze_backbone      : -14.983714818954468
LWM_pretrained_Fine_tune : -16.11828923225403
LWM_Fine_tune            : -15.895888805389404
gru                      : -15.673699378967285
RNN                      : -16.92837953567505
LSTM                     : -14.311479330062866
Transformer              : -16.68756127357483

Total training time for all models: 13304.01s


In [24]:
# after you (re)create masked_train_loader
batch = next(iter(masked_train_loader))
seq, mpos, tgt = batch   # seq: (B, L, D), tgt: (B, D)
print(
    "seq: min", seq.min().item(), "max", seq.max().item(),
    "   tgt: min", tgt.min().item(), "max", tgt.max().item()
)


seq: min 0.0 max 0.6701632738113403    tgt: min 0.329348623752594 max 0.666236400604248
