# All Model saves here

## import

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import TensorDataset, DataLoader, random_split
import DeepMIMOv3
import numpy as np
from pprint import pprint
import matplotlib.pyplot as plt
import time


plt . rcParams [ 'figure.figsize' ]  =  [ 12 ,  8 ]  # 기본 플롯 크기 설정

## GPU Settings

In [2]:
# GPU 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
import torch
print(torch.version.cuda)                   
print(torch.backends.cudnn.version())       
print("CUDA available:", torch.cuda.is_available())  # True

12.6
90501
CUDA available: True


## DeepMIMOv3 dataset

In [4]:
parameters = DeepMIMOv3.default_params()

In [5]:
## Change parameters for the setup
# Scenario O1_60 extracted at the dataset_folder
#LWM dynamic senario
# parameters['dataset_folder'] = r'/content/drive/MyDrive/Colab Notebooks/LWM'
scene = 15 # scene 15
# change my linux route
parameters['dataset_folder'] = '/home/dlghdbs200/LWM'

# scnario = 02_dyn_3p5 <- download file
parameters['scenario'] = 'O2_dyn_3p5'
parameters['dynamic_scenario_scenes'] = np.arange(scene) #scene 0~9

# Up to 10 multipath paths per user-to-base station channel
parameters['num_paths'] = 10

# User rows 1-100
parameters['user_rows'] = np.arange(100)
# User subsampling
parameters['user_subsampling'] = 0.01

# Activate only the first basestation
parameters['active_BS'] = np.array([1])

parameters['activate_OFDM'] = 1

parameters['OFDM']['bandwidth'] = 0.05 # 50 MHz
parameters['OFDM']['subcarriers'] = 512 # OFDM with 512 subcarriers
parameters['OFDM']['selected_subcarriers'] = np.arange(0, 64, 1)
#parameters['OFDM']['subcarriers_limit'] = 64 # Keep only first 64 subcarriers

parameters['ue_antenna']['shape'] = np.array([1, 1]) # Single antenna
parameters['bs_antenna']['shape'] = np.array([1, 32]) # ULA of 32 elements
#parameters['bs_antenna']['rotation'] = np.array([0, 30, 90]) # ULA of 32 elements
#parameters['ue_antenna']['rotation'] = np.array([[0, 30], [30, 60], [60, 90]]) # ULA of 32 elements
#parameters['ue_antenna']['radiation_pattern'] = 'isotropic'
#parameters['bs_antenna']['radiation_pattern'] = 'halfwave-dipole'

In [6]:
## dataset setting (chunked on‑the‑fly generation)
import time, gc
from tqdm import tqdm

# 0~999 scene index , process 50 at that time
scene_indices = np.arange(scene)
chunk_size   = 5
all_data     = []

# Call generate_data for each scene chunk
for i in tqdm(range(0, len(scene_indices), chunk_size)):
    chunk = scene_indices[i : i+chunk_size].tolist()
    parameters['dynamic_scenario_scenes'] = chunk

    start = time.time()
    data_chunk = DeepMIMOv3.generate_data(parameters)
    print(f"Scenes {chunk[0]}–{chunk[-1]} generation time: {time.time() - start:.2f}s")

    # combine all_data or save in the Disk
    all_data.extend(data_chunk)

    # free memory 
    del data_chunk
    gc.collect()

# comvine Dataset
dataset = all_data


print(parameters['user_rows'])

  0%|                                                                                             | 0/3 [00:00<?, ?it/s]

The following parameters seem unnecessary:
{'activate_OFDM'}

Scene 1/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 355908.35it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 8400.00it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 6078.70it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 748.98it/s][A



Scene 2/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 387601.01it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 9145.07it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 5203.85it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 469.84it/s][A



Scene 3/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 385698.67it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 8647.26it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 4660.34it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 341.67it/s][A



Scene 4/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 398299.29it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 9321.90it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 7973.96it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1152.91it/s][A



Scene 5/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 381939.89it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 9416.99it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 8886.24it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 971.58it/s][A
 33%|████████████████████████████▎                                                        | 1/3 [00:06<00:13,  6.51s/it]

Scenes 0–4 generation time: 6.40s
The following parameters seem unnecessary:
{'activate_OFDM'}

Scene 1/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 383031.17it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 9183.91it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 7898.88it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1096.55it/s][A



Scene 2/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 399919.50it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 9386.44it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 8388.61it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1198.03it/s][A



Scene 3/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 389237.76it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 9479.88it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 6278.90it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1199.06it/s][A



Scene 4/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing:  47%|████████████████████████▎                           | 32244/69006 [00:00<00:00, 322420.32it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 329902.62it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 8783.14it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 8577.31it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 721.04it/s][A



Scene 5/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 397078.00it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 9078.93it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 8439.24it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1079.06it/s][A
 67%|████████████████████████████████████████████████████████▋                            | 2/3 [00:12<00:06,  6.41s/it]

Scenes 5–9 generation time: 6.22s
The following parameters seem unnecessary:
{'activate_OFDM'}

Scene 1/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 385129.51it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 8852.65it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 8019.70it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1024.50it/s][A



Scene 2/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 370065.30it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 8836.20it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 6875.91it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1151.02it/s][A



Scene 3/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 395740.78it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 8949.93it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 7244.05it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1299.75it/s][A



Scene 4/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 404278.31it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 9071.94it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 8612.53it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1264.87it/s][A



Scene 5/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 395406.66it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 8676.87it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 8594.89it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1229.28it/s][A
100%|█████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:19<00:00,  6.40s/it]

Scenes 10–14 generation time: 6.24s
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
 96 97 98 99]





## About Information
User : 737
UE antenna : 1
BS antenna : 32  Shape(a+bj)
subcarrier : 64

In [7]:
# Unmasked Data Model(gru
# separate maksed data and unmasked data

## Data Preprocessing

In [8]:
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import IterableDataset
import numpy as np
import torch

class UnMaskedChannelSeqDataset(IterableDataset):
    """
    IterableDataset for unmasked channel sequence data.

    - Predicts the next-step channel vector from a sequence of past vectors.
    - Applies power normalization and MinMax scaling to both inputs and targets.
    """
    def __init__(self, scenes, seq_len=5, eps=1e-9):
        super().__init__()
        self.scenes  = scenes
        self.seq_len = seq_len
        self.eps     = eps

        # determine dimensions: U users, A antennas, S subcarriers
        ch0 = scenes[0][0]['user']['channel']  # shape: (U, 1, A, S)
        self.U       = ch0.shape[0]
        self.A       = ch0.shape[2]
        self.S       = ch0.shape[3]
        self.vec_len = 2 * self.A              # real+imag concatenated

        # precompute MinMax scalers over all valid (seq, target) pairs
        X_list, y_list = [], []
        T = len(scenes)
        for t in range(self.seq_len, T):
            past         = scenes[t - self.seq_len : t]
            scene_target = scenes[t]
            for u in range(self.U):
                for s in range(self.S):
                    # build sequence array of shape (seq_len, vec_len)
                    seq_np = np.stack([
                        np.concatenate([
                            ps[0]['user']['channel'][u,0,:,s].real,
                            ps[0]['user']['channel'][u,0,:,s].imag
                        ])
                        for ps in past
                    ], axis=0).astype(np.float32)

                    # build target array of shape (vec_len,)
                    tgt_np = np.concatenate([
                        scene_target[0]['user']['channel'][u,0,:,s].real,
                        scene_target[0]['user']['channel'][u,0,:,s].imag
                    ]).astype(np.float32)

                    # skip if invalid (all zeros)
                    if not np.any(seq_np) or not np.any(tgt_np):
                        continue

                    X_list.append(seq_np.reshape(-1, self.vec_len))
                    y_list.append(tgt_np)

        X_all = np.vstack(X_list)  # (num_samples*seq_len, vec_len)
        y_all = np.stack(y_list)   # (num_samples, vec_len)
        self.scaler_x = MinMaxScaler().fit(X_all)
        self.scaler_y = MinMaxScaler().fit(y_all)

    def __iter__(self):
        """
        Yields:
            seq_tensor   (seq_len, vec_len): power-normalized & MinMax-scaled input sequence
            target_tensor(vec_len,)        : power-normalized & MinMax-scaled target
        """
        T = len(self.scenes)
        for t in range(self.seq_len, T):
            past         = self.scenes[t - self.seq_len : t]
            scene_target = self.scenes[t]

            for u in range(self.U):
                for s in range(self.S):
                    # power-normalize using helper
                    seq_np = np.stack([
                        self._power_norm(ps[0]['user']['channel'][u,0,:,s])
                        for ps in past
                    ], axis=0)
                    tgt_np = self._power_norm(scene_target[0]['user']['channel'][u,0,:,s])

                    if not np.any(seq_np) or not np.any(tgt_np):
                        continue

                    # MinMax scale and reshape back
                    N, D = seq_np.shape
                    seq_np = self.scaler_x.transform(seq_np.reshape(-1, D)).reshape(N, D)
                    tgt_np = self.scaler_y.transform(tgt_np.reshape(1, -1)).reshape(-1,)

                    # convert to torch tensors
                    seq_tensor   = torch.from_numpy(seq_np)
                    target_tensor= torch.from_numpy(tgt_np)

                    yield seq_tensor, target_tensor

    def _power_norm(self, h: np.ndarray) -> np.ndarray:
        """
        Convert complex vector to real+imag concat and normalize power to unity.
        """
        v = np.concatenate([h.real, h.imag]).astype(np.float32)
        power = np.mean(v * v) + self.eps
        return v / np.sqrt(power)

    def __len__(self):
        """
        Total number of valid (sequence, target) pairs.
        """
        return (len(self.scenes) - self.seq_len) * self.U * self.S


In [9]:
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import IterableDataset
import numpy as np
import torch
import random

class MaskedChannelSeqDataset(IterableDataset):
    """
    IterableDataset for masked channel sequence data.

    - Predicts the next-step channel vector from a sequence of past vectors.
    - Applies power normalization and MinMax scaling to both inputs and targets.
    - Masks 15% of the patches according to:
        * 80% chance: replace selected patch with zeros
        * 10% chance: replace selected patch with Gaussian noise
        * 10% chance: leave the selected patch unchanged
      The other 85% of samples are returned unmasked.
    """
    def __init__(self, scenes, seq_len=5, eps=1e-9, noise_std=1.0):
        super().__init__()
        self.scenes    = scenes
        self.seq_len   = seq_len
        self.eps       = eps
        self.noise_std = noise_std

        # Determine dimensions: U=users, A=antennas, S=subcarriers
        ch0 = scenes[0][0]['user']['channel']   # shape: (U, 1, A, S)
        self.U       = ch0.shape[0]
        self.A       = ch0.shape[2]
        self.S       = ch0.shape[3]
        self.vec_len = 2 * self.A               # real+imag concatenated

        # ----------------------------------------------------------------------
        # 1) PRECOMPUTE power-norm → fit MinMax scalers on exactly the same data
        # ----------------------------------------------------------------------
        X_list, y_list = [], []
        T = len(scenes)
        for t in range(self.seq_len, T):
            past         = scenes[t - self.seq_len : t]
            target_scene = scenes[t]
            for u in range(self.U):
                for s in range(self.S):
                    # power-normalize each time-slice in the sequence
                    seq_np = np.stack([
                        self._power_norm(ps[0]['user']['channel'][u,0,:,s])
                        for ps in past
                    ], axis=0).astype(np.float32)  # (seq_len, vec_len)

                    # power-normalize target
                    tgt_np = self._power_norm(
                        target_scene[0]['user']['channel'][u,0,:,s]
                    ).astype(np.float32)            # (vec_len,)

                    # skip if invalid
                    if not np.any(seq_np) or not np.any(tgt_np):
                        continue

                    # flatten sequence for scaler
                    X_list.append(seq_np.reshape(-1, self.vec_len))
                    y_list.append(tgt_np)

        # fit scalers on the exact same power-normalized data
        X_all = np.vstack(X_list)  # (num_samples*seq_len, vec_len)
        y_all = np.stack(y_list)   # (num_samples, vec_len)
        self.scaler_x = MinMaxScaler().fit(X_all)
        self.scaler_y = MinMaxScaler().fit(y_all)

        # prepare a zero-mask vector
        self.mask_value = torch.zeros(self.vec_len, dtype=torch.float32)

    def __iter__(self):
        """
        Yields:
            seq_tensor      (seq_len, vec_len) : normalized & scaled sequence
            masked_pos      (1,) tensor long   : index of masked time step
            target_tensor   (vec_len,)          : normalized & scaled target
        """
        T = len(self.scenes)
        mask_prob  = 0.15
        zero_prob  = mask_prob * 0.8
        noise_prob = mask_prob * 0.1

        for t in range(self.seq_len, T):
            past       = self.scenes[t - self.seq_len : t]
            scene_dict = self.scenes[t]

            for u in range(self.U):
                for s in range(self.S):
                    # power-normalize each slice
                    seq_np = np.stack([
                        self._power_norm(ps[0]['user']['channel'][u,0,:,s])
                        for ps in past
                    ], axis=0)
                    tgt_np = self._power_norm(
                        scene_dict[0]['user']['channel'][u,0,:,s]
                    )

                    if not np.any(seq_np) or not np.any(tgt_np):
                        continue

                    # MinMax transform (using the same scalers)
                    N, D = seq_np.shape
                    seq_np = self.scaler_x.transform(seq_np.reshape(-1, D)).reshape(N, D)
                    tgt_np = self.scaler_y.transform(tgt_np.reshape(1, -1)).reshape(-1,)

                    seq_tensor   = torch.from_numpy(seq_np)
                    target_tensor= torch.from_numpy(tgt_np)

                    # select mask position
                    mpos = random.randrange(self.seq_len)

                    r = random.random()
                    if r < zero_prob:
                        masked_seq = seq_tensor.clone()
                        masked_seq[mpos] = self.mask_value
                        yield masked_seq, torch.tensor([mpos]), target_tensor

                    elif r < zero_prob + noise_prob:
                        masked_seq = seq_tensor.clone()
                        masked_seq[mpos] = torch.randn(self.vec_len) * self.noise_std
                        yield masked_seq, torch.tensor([mpos]), target_tensor

                    elif r < mask_prob:
                        # masked-but-unchanged
                        yield seq_tensor, torch.tensor([mpos]), target_tensor

                    else:
                        # unmasked
                        yield seq_tensor, torch.tensor([mpos]), target_tensor

    def _power_norm(self, h: np.ndarray) -> np.ndarray:
        """
        Convert complex vector to real+imag concat and normalize power to 1.
        """
        v = np.concatenate([h.real, h.imag]).astype(np.float32)
        power = np.mean(v * v) + self.eps
        return v / np.sqrt(power)

    def __len__(self):
        return (len(self.scenes) - self.seq_len) * self.U * self.S


## Split Train/Val

In [10]:
# ❷ Train/Validation DataLoader split train : val = 6 : 4
seq_len      = 5
split_ratio  = 0.6
split_idx    = int(len(dataset) * split_ratio)

In [11]:
unmasked_train_ds = UnMaskedChannelSeqDataset(dataset[:split_idx], seq_len=seq_len)
unmasked_val_ds   = UnMaskedChannelSeqDataset(dataset[split_idx:], seq_len=seq_len)

# iterate over train_ds to compute min and max of features/targets

batch_size   = 32
unmasked_train_loader = DataLoader(unmasked_train_ds, batch_size=batch_size, shuffle=False)
unmasked_val_loader   = DataLoader(unmasked_val_ds,   batch_size=batch_size, shuffle=False)
# ─────────────────────────────────────────────


In [12]:
# ❷ Train/Validation DataLoader split train : val = 6 : 4

masked_train_ds = MaskedChannelSeqDataset(dataset[:split_idx], seq_len=seq_len)
masked_val_ds   = MaskedChannelSeqDataset(dataset[split_idx:], seq_len=seq_len)

# iterate over train_ds to compute min and max of features/targets

batch_size   = 32
masked_train_loader = DataLoader(masked_train_ds, batch_size=batch_size, shuffle=False)
masked_val_loader   = DataLoader(masked_val_ds,   batch_size=batch_size, shuffle=False)
# ─────────────────────────────────────────────


In [13]:
batch = next(iter(masked_train_loader))
seq, mpos, tgt = batch
print("seq:", seq.min().item(), "~", seq.max().item())
print("tgt:", tgt.min().item(), "~", tgt.max().item())


seq: 0.0 ~ 0.6701632738113403
tgt: 0.329348623752594 ~ 0.666236400604248


## Define Model

LWMWithHead: A wrapper class that uses a pre-trained LWM (Transformer encoder) as the backbone,
             and attaches a new fully-connected (FC) head for downstream tasks
             (regression, classification, etc.).

Changes:
- input_dim: Dimension of the actual input data (e.g., 64)
- patch_length: Patch length expected by the backbone (e.g., 16)
- Replaces the original element_length parameter with these two distinct parameters
- Applies a projection layer (self.input_proj) in forward()


In [14]:
import torch
import torch.nn as nn
from lwm_model import lwm

class LWMWithHead(nn.Module):
    """
    LWMWithHead: A wrapper class that uses a pre-trained LWM (Transformer encoder) as the backbone,
                 and attaches a new fully-connected (FC) head for downstream tasks
                 (regression, classification, etc.).

    Changes:
    - input_dim: Dimension of the actual input data (e.g., 64)
    - patch_length: Patch length expected by the backbone (e.g., 16)
    - Replaces the original element_length parameter with these two distinct parameters
    - Applies a projection layer (self.input_proj) in forward()
    """
    def __init__(
        self,
        input_dim: int,                 # Dimension of the actual input data (e.g., 64)
        patch_length: int,              # Patch length expected by the backbone (e.g., 16)
        d_model: int = 64,              # LWM hidden size
        max_len: int = 129,             # Positional encoding max length
        n_layers: int = 12,             # Number of Transformer encoder layers
        hidden_dim: int = 256,          # FC head hidden dimension
        out_dim: int = 64,              # FC head output dimension
        freeze_backbone: bool = True,   # Whether to freeze the backbone
        checkpoint_path: str | None = "./model_weights.pth",
        device: str = "cuda"
    ):
        super().__init__()

        # apply a projection layer to match backbone's expected patch_length
        self.input_proj = nn.Linear(input_dim, patch_length)

        # initialize backbone
        if checkpoint_path is None:
            # randomly initialized backbone
            self.backbone = lwm(
                element_length=patch_length,
                d_model=d_model,
                max_len=max_len,
                n_layers=n_layers
            ).to(device)
        else:
            # load pre-trained weights
            self.backbone = lwm.from_pretrained(
                ckpt_name=checkpoint_path,
                device=device
            )

        # freeze backbone parameters if required
        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

        # attach a new fully-connected head for downstream tasks
        self.head = nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, input_ids: torch.Tensor, masked_pos: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input_ids: Tensor of shape (B, L, input_dim)
            masked_pos: Tensor of shape (B, num_mask)
        Returns:
            out: Tensor of shape (B, out_dim)
        """
        # project inputs to patch_length dimension
        x = self.input_proj(input_ids)

        # backbone forward: returns (logits_lm, enc_output)
        _, enc_output = self.backbone(x, masked_pos)

        # extract CLS token feature (first token)
        feat = enc_output[:, 0, :]

        # pass through FC head to get final output
        out = self.head(feat)
        return out


In [15]:
import torch
import torch.nn as nn

class GRUWithHead(nn.Module):
    """
    GRUWithHead: A wrapper class that uses a GRU backbone and attaches a fully-connected (FC) head
                 for downstream tasks (regression, classification, etc.).
    """
    def __init__(
        self,
        feat_dim: int = 16,           # Dimension of input features (patch_length / element_length)
        d_model: int = 64,            # GRU hidden size
        n_layers: int = 12,           # Number of GRU layers to stack
        bidirectional: bool = True,   # Whether to use a bidirectional GRU
        dropout: float = 0.1,         # Dropout probability between GRU layers
        hidden_dim: int = 256,        # FC head hidden dimension
        out_dim: int = 64,            # FC head output dimension
        freeze_backbone: bool = False # Whether to freeze GRU backbone weights
    ):
        super().__init__()

        # 1) GRU backbone
        self.backbone = nn.GRU(
            input_size   = feat_dim,
            hidden_size  = d_model,
            num_layers   = n_layers,
            batch_first  = True,
            bidirectional= bidirectional,
            dropout      = dropout if n_layers > 1 else 0.0
        )

        # 2) Optionally freeze backbone parameters
        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

        # Build FC head 
        gru_out_dim = d_model * (2 if bidirectional else 1)
        self.head = nn.Sequential(
            nn.Linear(gru_out_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (batch_size, seq_len, feat_dim)
        Returns:
            out: Tensor of shape (batch_size, out_dim)
        """
        # 1) Pass through GRU backbone
        out, _ = self.backbone(x)  # out shape: (B, seq_len, num_directions * d_model)

        # 2) Take the last time-step output as sequence representation
        feat = out[:, -1, :]       # shape: (B, gru_out_dim)

        # 3) Pass through FC head to get final output
        return self.head(feat)     # shape: (B, out_dim)


In [16]:
import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        # Create positional encoding matrix of shape (1, max_len, d_model)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div_term)
        pe[:, 1::2] = torch.cos(pos * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (batch_size, seq_len, d_model)
        Returns:
            Tensor: x plus positional encodings
        """
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]

class InputEmbedding(nn.Module):
    def __init__(self, feat_dim: int, d_model: int, max_len: int = 5000):
        super().__init__()
        # Optional linear projection from feat_dim to d_model
        self.proj = nn.Linear(feat_dim, d_model) if feat_dim != d_model else None
        self.pos_enc = PositionalEncoding(d_model, max_len)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (batch, seq_len, feat_dim)
        Returns:
            Tensor of shape (batch, seq_len, d_model)
        """
        if self.proj is not None:
            x = self.proj(x)
        return self.pos_enc(x)

class EncoderLayer(nn.Module):
    def __init__(self, d_model: int, n_heads: int, dim_ff: int, dropout: float = 0.1):
        super().__init__()
        # Multi-Head Self-Attention
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        # Position-wise Feed-Forward Network
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_ff),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(dim_ff, d_model)
        )
        # Layer Normalization and Dropout for residual connections
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        src_mask: torch.Tensor = None,
        src_key_padding_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (seq_len, batch, d_model)
            src_mask: Optional Tensor of shape (seq_len, seq_len)
            src_key_padding_mask: Optional Tensor of shape (batch, seq_len)
        Returns:
            Tensor of shape (seq_len, batch, d_model)
        """
        # Self-attention sublayer
        attn_out, _ = self.self_attn(x, x, x, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
        x = x + self.dropout1(attn_out)
        x = self.norm1(x)
        # Feed-forward sublayer
        ff_out = self.ff(x)
        x = x + self.dropout2(ff_out)
        x = self.norm2(x)
        return x

class TransformerEncoderCustom(nn.Module):
    def __init__(
        self,
        feat_dim: int,
        d_model: int,
        n_heads: int,
        dim_ff: int,
        n_layers: int,
        dropout: float = 0.1,
        max_len: int = 5000
    ):
        super().__init__()
        # Input embedding: feature projection + positional encoding
        self.input_embedding = InputEmbedding(feat_dim, d_model, max_len)
        # Stack of N encoder layers
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, n_heads, dim_ff, dropout)
            for _ in range(n_layers)
        ])

    def forward(
        self,
        x: torch.Tensor,
        src_mask: torch.Tensor = None,
        src_key_padding_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (batch, seq_len, feat_dim)
        Returns:
            Tensor of shape (seq_len, batch, d_model)
        """
        x = self.input_embedding(x)       # (batch, seq_len, d_model)
        x = x.transpose(0, 1)             # (seq_len, batch, d_model)
        for layer in self.layers:
            x = layer(x, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        return x

class DecoderLayer(nn.Module):
    def __init__(self, d_model: int, n_heads: int, dim_ff: int, dropout: float = 0.1):
        super().__init__()
        # Masked Self-Attention
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        # Encoder-Decoder Attention
        self.multihead_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        # Position-wise Feed-Forward Network
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_ff),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(dim_ff, d_model)
        )
        # Layer Normalizations and Dropouts
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: torch.Tensor = None,
        memory_mask: torch.Tensor = None,
        tgt_key_padding_mask: torch.Tensor = None,
        memory_key_padding_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Args:
            tgt: Tensor of shape (tgt_len, batch, d_model)
            memory: Tensor of shape (src_len, batch, d_model)
        Returns:
            Tensor of shape (tgt_len, batch, d_model)
        """
        # Masked self-attention sublayer
        attn1, _ = self.self_attn(
            tgt, tgt, tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask
        )
        tgt = tgt + self.dropout1(attn1)
        tgt = self.norm1(tgt)
        # Encoder-decoder attention sublayer
        attn2, _ = self.multihead_attn(
            tgt, memory, memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask
        )
        tgt = tgt + self.dropout2(attn2)
        tgt = self.norm2(tgt)
        # Feed-forward sublayer
        ff_out = self.ff(tgt)
        tgt = tgt + self.dropout3(ff_out)
        tgt = self.norm3(tgt)
        return tgt

class TransformerDecoderCustom(nn.Module):
    def __init__(
        self,
        feat_dim: int,
        d_model: int,
        n_heads: int,
        dim_ff: int,
        n_layers: int,
        dropout: float = 0.1,
        max_len: int = 5000
    ):
        super().__init__()
        # Input embedding for target sequence
        self.input_embedding = InputEmbedding(feat_dim, d_model, max_len)
        # Stack of N decoder layers
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, n_heads, dim_ff, dropout)
            for _ in range(n_layers)
        ])
        # Final projection back to feature dimension
        self.output_linear = nn.Linear(d_model, feat_dim)

    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: torch.Tensor = None,
        memory_mask: torch.Tensor = None,
        tgt_key_padding_mask: torch.Tensor = None,
        memory_key_padding_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Args:
            tgt: Tensor of shape (batch, tgt_len, feat_dim)
            memory: Tensor of shape (src_len, batch, d_model)
        Returns:
            Tensor of shape (batch, tgt_len, feat_dim)
        """
        x = self.input_embedding(tgt)       # (batch, tgt_len, d_model)
        x = x.transpose(0, 1)               # (tgt_len, batch, d_model)
        for layer in self.layers:
            x = layer(
                x,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask
            )
        x = x.transpose(0, 1)               # (batch, tgt_len, d_model)
        return self.output_linear(x)        # project back to feat_dim

class TransformerWithHead(nn.Module):
    def __init__(
        self,
        feat_dim: int,
        d_model: int = 256,
        n_heads: int = 8,
        dim_ff: int = 512,
        n_layers: int = 12,
        dropout: float = 0.1,
        hidden_dim: int = 256,
        out_dim: int = 64,
        max_len: int = 5000
    ):
        super().__init__()
        # Encoder-only backbone
        self.encoder = TransformerEncoderCustom(
            feat_dim, d_model, n_heads, dim_ff, n_layers, dropout, max_len
        )
        # Classification head
        self.head = nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(
        self,
        x: torch.Tensor,
        src_mask: torch.Tensor = None,
        src_key_padding_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (batch, seq_len, feat_dim)
        Returns:
            Tensor of shape (batch, out_dim)
        """
        # Encode input sequence
        enc_output = self.encoder(x, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        # Take the last token's representation
        last_token = enc_output[-1]         # (batch, d_model)
        return self.head(last_token)


In [17]:
class RNNWithHead(nn.Module):
    def __init__(
        self,
        feat_dim: int,
        hidden_size: int      = 128,
        num_layers: int       = 12,
        bidirectional: bool   = True,
        hidden_dim: int       = 256,
        out_dim: int          = 64,
        freeze_backbone: bool = False,
    ):
        super().__init__()
        # RNN backbone (no dropout)
        self.backbone = nn.RNN(
            input_size   = feat_dim,
            hidden_size  = hidden_size,
            num_layers   = num_layers,
            batch_first  = True,
            bidirectional= bidirectional,
            dropout      = 0.1
        )
        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

        # Compute RNN output dimension based on bi-directionality
        rnn_out_dim = hidden_size * (2 if bidirectional else 1)

        # Fully-connected head (dropout layer removed)
        self.head = nn.Sequential(
            nn.Linear(rnn_out_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x):
        # x: (batch_size, seq_len, feat_dim)
        out, _ = self.backbone(x)       # out: (batch_size, seq_len, rnn_out_dim)
        feat   = out[:, -1, :]          # last time step (batch_size, rnn_out_dim)
        return self.head(feat)          # (batch_size, out_dim)


In [18]:
class LSTMWithHead(nn.Module):
    def __init__(
        self,
        feat_dim: int,
        hidden_size: int = 128,
        num_layers: int = 1,
        bidirectional: bool = False,
        dropout: float = 0.0,
        hidden_dim: int = 256,
        out_dim: int = None,
        freeze_backbone: bool = False
    ):
        super().__init__()
        
        # Determine output dimension (default to input feature size if not specified)
        out_dim = feat_dim if out_dim is None else out_dim

        # LSTM backbone for sequence modeling
        # - input_size: number of features per time step
        # - hidden_size: dimensionality of LSTM hidden state
        # - num_layers: number of stacked LSTM layers
        # - bidirectional: whether to use a bi-directional LSTM
        # - dropout: dropout probability between LSTM layers (if num_layers > 1)
        self.backbone = nn.LSTM(
            input_size=feat_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=dropout if num_layers > 1 else 0.0
        )
        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

        # Compute the LSTM output dimension accounting for bidirectionality
        lstm_out_dim = hidden_size * (2 if bidirectional else 1)

        # Fully-connected head to map the final LSTM state to desired output
        self.head = nn.Sequential(
            nn.Linear(lstm_out_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (batch_size, seq_len, feat_dim)
        Returns:
            Tensor of shape (batch_size, out_dim)
        """
        # Pass the input through the LSTM backbone
        # out shape: (batch_size, seq_len, lstm_out_dim)
        out, _ = self.backbone(x)

        # Extract features from the last time step
        # feat shape: (batch_size, lstm_out_dim)
        feat = out[:, -1, :]

        # Compute final output via the head
        return self.head(feat)


## fine-tuning

In [19]:
from torch.optim import Adam

# shared hyper-parameters
INPUT_DIM     = 64   # raw feature dimension
PATCH_LENGTH  = 16   # backbone input dimension
D_MODEL       = 64   # hidden size for all backbones
N_LAYERS      = 12   # number of layers
HIDDEN_DIM    = 256  # head hidden dimension
OUT_DIM       = 64   # head output dimension
DROPOUT       = 0.1  # dropout for Transformer / GRU / LSTM
BIDIRECTIONAL = True # whether to use bidirectional RNNs
DEVICE        = "cuda"

# model classes
MODEL_CATALOG = {
    "LWM_freeze_backbone"     : LWMWithHead,
    "LWM_pretrained_Fine_tune": LWMWithHead,
    "LWM_Fine_tune"           : LWMWithHead,
    "gru"                     : GRUWithHead,
    "RNN"                     : RNNWithHead,
    "LSTM"                    : LSTMWithHead,
    "Transformer"             : TransformerWithHead,
}

# model‐specific init args
MODEL_PARAMS = {
    # LWM variants
    "LWM_freeze_backbone": {
        "input_dim"       : INPUT_DIM,
        "patch_length"    : PATCH_LENGTH,
        "d_model"         : D_MODEL,
        "max_len"         : 129,
        "n_layers"        : N_LAYERS,
        "hidden_dim"      : HIDDEN_DIM,
        "out_dim"         : OUT_DIM,
        "freeze_backbone" : True,
        "checkpoint_path" : "./model_weights.pth",
        "device"          : DEVICE,
    },
    "LWM_pretrained_Fine_tune": {
        "input_dim"       : INPUT_DIM,
        "patch_length"    : PATCH_LENGTH,
        "d_model"         : D_MODEL,
        "max_len"         : 129,
        "n_layers"        : N_LAYERS,
        "hidden_dim"      : HIDDEN_DIM,
        "out_dim"         : OUT_DIM,
        "freeze_backbone" : False,
        "checkpoint_path" : "./model_weights.pth",
        "device"          : DEVICE,
    },
    "LWM_Fine_tune": {
        "input_dim"       : INPUT_DIM,
        "patch_length"    : PATCH_LENGTH,
        "d_model"         : D_MODEL,
        "max_len"         : 129,
        "n_layers"        : N_LAYERS,
        "hidden_dim"      : HIDDEN_DIM,
        "out_dim"         : OUT_DIM,
        "freeze_backbone" : False,
        "checkpoint_path" : None,
        "device"          : DEVICE,
    },

    # GRU
    "gru": {
        "feat_dim"        : PATCH_LENGTH,  # matches LWM backbone
        "d_model"         : D_MODEL,
        "n_layers"        : N_LAYERS,
        "bidirectional"   : BIDIRECTIONAL,
        "dropout"         : DROPOUT,
        "hidden_dim"      : HIDDEN_DIM,
        "out_dim"         : OUT_DIM,
        "freeze_backbone" : False,
    },

    # Vanilla RNN
    "RNN": {
        "feat_dim"        : PATCH_LENGTH,
        "hidden_size"     : D_MODEL,
        "num_layers"      : N_LAYERS,
        "bidirectional"   : BIDIRECTIONAL,
        "hidden_dim"      : HIDDEN_DIM,
        "out_dim"         : OUT_DIM,
        "freeze_backbone" : False,
    },

    # LSTM
    "LSTM": {
        "feat_dim"        : PATCH_LENGTH,
        "hidden_size"     : D_MODEL,
        "num_layers"      : N_LAYERS,
        "bidirectional"   : BIDIRECTIONAL,
        "dropout"         : DROPOUT,
        "hidden_dim"      : HIDDEN_DIM,
        "out_dim"         : OUT_DIM,
        "freeze_backbone" : False,
    },

    # Transformer
    "Transformer": {
        "feat_dim"        : PATCH_LENGTH,
        "d_model"         : D_MODEL,
        "n_heads"         : 4,
        "dim_ff"          : 256,
        "n_layers"        : N_LAYERS,
        "dropout"         : DROPOUT,
        "hidden_dim"      : HIDDEN_DIM,
        "out_dim"         : OUT_DIM,
        "freeze_backbone" : False,
    },
}


## model evaluate

In [20]:
import torch
import torch.nn.functional as F

def rmse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Root-Mean-Squared Error
    """
    return torch.sqrt(F.mse_loss(pred, target, reduction="mean"))   # √MSE

def nmse(pred: torch.Tensor, target: torch.Tensor, eps : float = 1e-12) -> torch.Tensor:
    """
    Normalized MSE  =  E[‖ŷ − y‖²] / E[‖y‖²]
    """
    # (B, …) → (B,)  
    mse_per_sample   = ((pred - target)**2).view(pred.size(0), -1).sum(dim=1)
    power_per_sample = (target**2).view(target.size(0), -1).sum(dim=1) + eps
    return (mse_per_sample / power_per_sample).mean()



In [21]:
def masked_evaluate(model, loader, device="cuda"):
    """
    Validation loop for IterableDataset.
    Returns average RMSE and NMSE over all samples.
    """
    model.eval()
    total_rmse, total_nmse, total_samples = 0.0, 0.0, 0

    with torch.no_grad():
        for input_ids, masked_pos, target in loader:
            # Move to device
            input_ids, masked_pos, target = (
                input_ids.to(device),
                masked_pos.to(device),
                target.to(device),
            )
            # Batch size
            bs = input_ids.size(0)

            # Forward
            pred = model(input_ids, masked_pos)

            # Accumulate batch metrics
            total_rmse    += rmse(pred, target).item() * bs
            total_nmse    += nmse(pred, target).item() * bs
            total_samples += bs

    # Compute averages
    return {
        "RMSE": total_rmse / total_samples,
        "NMSE": total_nmse / total_samples
    }

In [22]:
def unmasked_evaluate(model, loader, device="cuda"):
    """
    Validation loop for IterableDataset.
    Returns average RMSE and NMSE over all samples.
    """
    model.eval()
    total_rmse, total_nmse, total_samples = 0.0, 0.0, 0

    with torch.no_grad():
        for input_ids, target in loader:
            # Move to device
            input_ids,target = (
                input_ids.to(device),
                target.to(device),
            )
            # Batch size
            bs = input_ids.size(0)

            # Forward
            pred = model(input_ids)

            # Accumulate batch metrics
            total_rmse    += rmse(pred, target).item() * bs
            total_nmse    += nmse(pred, target).item() * bs
            total_samples += bs

    # Compute averages
    return {
        "RMSE": total_rmse / total_samples,
        "NMSE": total_nmse / total_samples
    }

# Model Training

In [24]:
"""
Unified training / validation script
------------------------------------
* Handles multiple models listed in MODEL_CATALOG
* Uses separate masked / un-masked DataLoaders depending on model type
* Measures epoch-level train time and prints per-epoch + average timings
"""

from tqdm import tqdm
import time, math, torch
import torch.nn as nn

# ─────────────────────────────────────────────
# 0) Globals and hyper-parameters
# ─────────────────────────────────────────────
device            = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion         = nn.MSELoss().to(device)

NUM_EPOCHS        = 10
LR                = 1e-4                      # learning-rate
total_start_time  = time.time()               # wall-clock timer for all models
results           = {}                        # holds best-epoch NMSE_dB for each model

# ─────────────────────────────────────────────
# 1) Model training / validation loop
# ─────────────────────────────────────────────
for model_name, ModelCls in MODEL_CATALOG.items():

    print(f"\n=== Training {model_name} ===")

    # 1-A) instantiate model
    model_args = MODEL_PARAMS[model_name]
    model      = ModelCls(**model_args).to(device)

    # 1-B) collect trainable parameters
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    if len(trainable_params) == 0:
        print(f"⚠️  '{model_name}' has no trainable parameters — skipping training.")
        results[model_name] = float("nan")
        continue

    optimizer   = torch.optim.Adam(trainable_params, lr=LR)
    epoch_times = []                       # store per-epoch train times

    # choose loaders depending on model family
    uses_mask   = model_name.startswith("LWM_")
    tr_loader   = masked_train_loader   if uses_mask else unmasked_train_loader
    val_loader  = masked_val_loader    if uses_mask else unmasked_val_loader
    eval_fn     = masked_evaluate      if uses_mask else unmasked_evaluate

    # 1-C) epoch loop
    for epoch in range(1, NUM_EPOCHS + 1):

        # ── TRAIN ───────────────────────────────────────
        t0 = time.time()
        model.train()
        running_loss = 0.0

        pbar = tqdm(tr_loader,
                    desc=f"[{model_name} {epoch:02d}/{NUM_EPOCHS}] train",
                    leave=False)

        for b, batch in enumerate(pbar, 1):
            # unpack batch depending on masked / unmasked loader
            if uses_mask:
                inp, mpos, tgt = [x.to(device) for x in batch]
                pred = model(inp, mpos).squeeze(-1)
            else:
                inp, tgt      = [x.to(device) for x in batch]
                pred          = model(inp)

            loss = criterion(pred, tgt)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if b % 100 == 0:
                pbar.set_postfix(train_loss=running_loss / b)

        epoch_train_time = time.time() - t0
        epoch_times.append(epoch_train_time)
        avg_train_loss = running_loss / b

        # ── VALID ───────────────────────────────────────
        metrics      = eval_fn(model, val_loader, device)
        val_rmse     = metrics["RMSE"]
        val_nmse     = metrics["NMSE"]
        val_nmse_db  = 10 * torch.log10(torch.tensor(val_nmse)).item()

        print(
            f"[{epoch:02d}/{NUM_EPOCHS}] "
            f"TrainLoss: {avg_train_loss:.4f}  "
            f"Val RMSE: {val_rmse:.4f}  "
            f"Val NMSE: {val_nmse:.4e}  "
            f"Val NMSE_dB: {val_nmse_db:.1f} dB  "
            f"TrainTime: {epoch_train_time:.2f}s"
        )

    # 1-D) epoch-time statistics and final log
    avg_epoch_time = sum(epoch_times) / len(epoch_times)
    print(f"🕒 {model_name} - average train time / epoch: {avg_epoch_time:.2f}s")
    results[model_name] = val_nmse_db

# ─────────────────────────────────────────────
# 2) summary
# ─────────────────────────────────────────────
print("\n=== Summary of NMSE(dB) by model ===")
for name, nmse_db in results.items():
    print(f"{name:25s}: {nmse_db if not math.isnan(nmse_db) else 'skipped':>6}")

print(f"\nTotal training time for all models: {time.time() - total_start_time:.2f}s")



=== Training LWM_freeze_backbone ===
Model loaded successfully from ./model_weights.pth to cuda


                                                                                                                        

[01/10] TrainLoss: 0.0161  Val RMSE: 0.0938  Val NMSE: 4.6554e-02  Val NMSE_dB: -13.3 dB  TrainTime: 199.13s


                                                                                                                        

[02/10] TrainLoss: 0.0103  Val RMSE: 0.0921  Val NMSE: 4.2642e-02  Val NMSE_dB: -13.7 dB  TrainTime: 209.88s


                                                                                                                        

[03/10] TrainLoss: 0.0086  Val RMSE: 0.0927  Val NMSE: 3.9494e-02  Val NMSE_dB: -14.0 dB  TrainTime: 227.78s


                                                                                                                        

[04/10] TrainLoss: 0.0076  Val RMSE: 0.0925  Val NMSE: 3.8524e-02  Val NMSE_dB: -14.1 dB  TrainTime: 215.17s


                                                                                                                        

[05/10] TrainLoss: 0.0072  Val RMSE: 0.0929  Val NMSE: 3.8366e-02  Val NMSE_dB: -14.2 dB  TrainTime: 207.31s


                                                                                                                        

[06/10] TrainLoss: 0.0070  Val RMSE: 0.0918  Val NMSE: 3.7559e-02  Val NMSE_dB: -14.3 dB  TrainTime: 227.30s


                                                                                                                        

[07/10] TrainLoss: 0.0068  Val RMSE: 0.0916  Val NMSE: 3.7141e-02  Val NMSE_dB: -14.3 dB  TrainTime: 225.31s


                                                                                                                        

[08/10] TrainLoss: 0.0066  Val RMSE: 0.0910  Val NMSE: 3.6407e-02  Val NMSE_dB: -14.4 dB  TrainTime: 203.32s


                                                                                                                        

[09/10] TrainLoss: 0.0063  Val RMSE: 0.0908  Val NMSE: 3.6003e-02  Val NMSE_dB: -14.4 dB  TrainTime: 210.75s


                                                                                                                        

[10/10] TrainLoss: 0.0060  Val RMSE: 0.0900  Val NMSE: 3.5227e-02  Val NMSE_dB: -14.5 dB  TrainTime: 236.20s
🕒 LWM_freeze_backbone - average train time / epoch: 216.22s

=== Training LWM_pretrained_Fine_tune ===
Model loaded successfully from ./model_weights.pth to cuda


                                                                                                                        

[01/10] TrainLoss: 0.0124  Val RMSE: 0.0909  Val NMSE: 4.0770e-02  Val NMSE_dB: -13.9 dB  TrainTime: 268.43s


                                                                                                                        

[02/10] TrainLoss: 0.0068  Val RMSE: 0.0828  Val NMSE: 3.2098e-02  Val NMSE_dB: -14.9 dB  TrainTime: 264.16s


                                                                                                                        

[03/10] TrainLoss: 0.0051  Val RMSE: 0.0743  Val NMSE: 2.5223e-02  Val NMSE_dB: -16.0 dB  TrainTime: 268.02s


                                                                                                                        

[04/10] TrainLoss: 0.0039  Val RMSE: 0.0732  Val NMSE: 2.3721e-02  Val NMSE_dB: -16.2 dB  TrainTime: 263.31s


                                                                                                                        

[05/10] TrainLoss: 0.0034  Val RMSE: 0.0712  Val NMSE: 2.2017e-02  Val NMSE_dB: -16.6 dB  TrainTime: 254.25s


                                                                                                                        

[06/10] TrainLoss: 0.0030  Val RMSE: 0.0706  Val NMSE: 2.1667e-02  Val NMSE_dB: -16.6 dB  TrainTime: 251.07s


                                                                                                                        

[07/10] TrainLoss: 0.0029  Val RMSE: 0.0707  Val NMSE: 2.1740e-02  Val NMSE_dB: -16.6 dB  TrainTime: 271.41s


                                                                                                                        

[08/10] TrainLoss: 0.0027  Val RMSE: 0.0716  Val NMSE: 2.2120e-02  Val NMSE_dB: -16.6 dB  TrainTime: 273.08s


                                                                                                                        

[09/10] TrainLoss: 0.0026  Val RMSE: 0.0715  Val NMSE: 2.1988e-02  Val NMSE_dB: -16.6 dB  TrainTime: 295.92s


                                                                                                                        

[10/10] TrainLoss: 0.0026  Val RMSE: 0.0716  Val NMSE: 2.1980e-02  Val NMSE_dB: -16.6 dB  TrainTime: 264.81s
🕒 LWM_pretrained_Fine_tune - average train time / epoch: 267.45s

=== Training LWM_Fine_tune ===


                                                                                                                        

[01/10] TrainLoss: 0.0101  Val RMSE: 0.0870  Val NMSE: 3.4073e-02  Val NMSE_dB: -14.7 dB  TrainTime: 257.40s


                                                                                                                        

[02/10] TrainLoss: 0.0050  Val RMSE: 0.0742  Val NMSE: 2.4652e-02  Val NMSE_dB: -16.1 dB  TrainTime: 271.24s


                                                                                                                        

[03/10] TrainLoss: 0.0037  Val RMSE: 0.0750  Val NMSE: 2.4800e-02  Val NMSE_dB: -16.1 dB  TrainTime: 249.95s


                                                                                                                        

[04/10] TrainLoss: 0.0032  Val RMSE: 0.0742  Val NMSE: 2.4158e-02  Val NMSE_dB: -16.2 dB  TrainTime: 299.32s


                                                                                                                        

[05/10] TrainLoss: 0.0030  Val RMSE: 0.0736  Val NMSE: 2.3661e-02  Val NMSE_dB: -16.3 dB  TrainTime: 259.83s


                                                                                                                        

[06/10] TrainLoss: 0.0028  Val RMSE: 0.0737  Val NMSE: 2.3683e-02  Val NMSE_dB: -16.3 dB  TrainTime: 254.78s


                                                                                                                        

[07/10] TrainLoss: 0.0027  Val RMSE: 0.0730  Val NMSE: 2.3189e-02  Val NMSE_dB: -16.3 dB  TrainTime: 288.12s


                                                                                                                        

[08/10] TrainLoss: 0.0026  Val RMSE: 0.0723  Val NMSE: 2.2662e-02  Val NMSE_dB: -16.4 dB  TrainTime: 257.69s


                                                                                                                        

[09/10] TrainLoss: 0.0025  Val RMSE: 0.0709  Val NMSE: 2.1688e-02  Val NMSE_dB: -16.6 dB  TrainTime: 265.18s


                                                                                                                        

[10/10] TrainLoss: 0.0024  Val RMSE: 0.0714  Val NMSE: 2.1984e-02  Val NMSE_dB: -16.6 dB  TrainTime: 286.30s
🕒 LWM_Fine_tune - average train time / epoch: 268.98s

=== Training gru ===


                                                                                                                        

RuntimeError: input.size(-1) must be equal to input_size. Expected 16, got 64

In [None]:
# after you (re)create masked_train_loader
batch = next(iter(masked_train_loader))
seq, mpos, tgt = batch   # seq: (B, L, D), tgt: (B, D)
print(
    "seq: min", seq.min().item(), "max", seq.max().item(),
    "   tgt: min", tgt.min().item(), "max", tgt.max().item()
)
