# All Model saves here

## import

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import TensorDataset, DataLoader, random_split
import DeepMIMOv3
import numpy as np
from pprint import pprint
import matplotlib.pyplot as plt
import time


plt . rcParams [ 'figure.figsize' ]  =  [ 12 ,  8 ]  # 기본 플롯 크기 설정

## GPU Settings

In [2]:
# GPU 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
import torch
print(torch.version.cuda)                   
print(torch.backends.cudnn.version())       
print("CUDA available:", torch.cuda.is_available())  # True

12.6
90501
CUDA available: True


## DeepMIMOv3 dataset

In [4]:
parameters = DeepMIMOv3.default_params()

In [5]:
## Change parameters for the setup
# Scenario O1_60 extracted at the dataset_folder
#LWM dynamic senario
# parameters['dataset_folder'] = r'/content/drive/MyDrive/Colab Notebooks/LWM'
scene = 15 # scene 15
# change my linux route
parameters['dataset_folder'] = '/home/dlghdbs200/LWM'

# scnario = 02_dyn_3p5 <- download file
parameters['scenario'] = 'O2_dyn_3p5'
parameters['dynamic_scenario_scenes'] = np.arange(scene) #scene 0~9

# Up to 10 multipath paths per user-to-base station channel
parameters['num_paths'] = 10

# User rows 1-100
parameters['user_rows'] = np.arange(100)
# User subsampling
parameters['user_subsampling'] = 0.01

# Activate only the first basestation
parameters['active_BS'] = np.array([1])

parameters['activate_OFDM'] = 1

parameters['OFDM']['bandwidth'] = 0.05 # 50 MHz
parameters['OFDM']['subcarriers'] = 512 # OFDM with 512 subcarriers
parameters['OFDM']['selected_subcarriers'] = np.arange(0, 64, 1)
#parameters['OFDM']['subcarriers_limit'] = 64 # Keep only first 64 subcarriers

parameters['ue_antenna']['shape'] = np.array([1, 1]) # Single antenna
parameters['bs_antenna']['shape'] = np.array([1, 32]) # ULA of 32 elements
#parameters['bs_antenna']['rotation'] = np.array([0, 30, 90]) # ULA of 32 elements
#parameters['ue_antenna']['rotation'] = np.array([[0, 30], [30, 60], [60, 90]]) # ULA of 32 elements
#parameters['ue_antenna']['radiation_pattern'] = 'isotropic'
#parameters['bs_antenna']['radiation_pattern'] = 'halfwave-dipole'

In [6]:
## dataset setting (chunked on‑the‑fly generation)
import time, gc
from tqdm import tqdm

# 0~999 scene index , process 50 at that time
scene_indices = np.arange(scene)
chunk_size   = 5
all_data     = []

# Call generate_data for each scene chunk
for i in tqdm(range(0, len(scene_indices), chunk_size)):
    chunk = scene_indices[i : i+chunk_size].tolist()
    parameters['dynamic_scenario_scenes'] = chunk

    start = time.time()
    data_chunk = DeepMIMOv3.generate_data(parameters)
    print(f"Scenes {chunk[0]}–{chunk[-1]} generation time: {time.time() - start:.2f}s")

    # combine all_data or save in the Disk
    all_data.extend(data_chunk)

    # free memory 
    del data_chunk
    gc.collect()

# comvine Dataset
dataset = all_data


print(parameters['user_rows'])

  0%|                                                                                             | 0/3 [00:00<?, ?it/s]

The following parameters seem unnecessary:
{'activate_OFDM'}

Scene 1/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing:  39%|████████████████████▍                               | 27179/69006 [00:00<00:00, 271750.74it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 272716.01it/s][A

Generating channels:   0%|                                                                      | 0/727 [00:00<?, ?it/s][A
Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 6164.60it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 5203.85it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 592.75it/s][A



Scene 2/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing:  40%|████████████████████▊                               | 27596/69006 [00:00<00:00, 275943.82it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 277390.08it/s][A

Generating channels:   0%|                                                                      | 0/727 [00:00<?, ?it/s][A
Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 6172.08it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 4650.00it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 213.37it/s][A



Scene 3/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing:  40%|████████████████████▊                               | 27620/69006 [00:00<00:00, 276181.17it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 276871.32it/s][A

Generating channels:   0%|                                                                      | 0/727 [00:00<?, ?it/s][A
Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 6335.37it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 2400.86it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 557.16it/s][A



Scene 4/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing:  38%|███████████████████▋                                | 26091/69006 [00:00<00:00, 260879.77it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 269747.83it/s][A

Generating channels:   0%|                                                                      | 0/727 [00:00<?, ?it/s][A
Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 6210.00it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1912.59it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 538.98it/s][A



Scene 5/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing:  39%|████████████████████▌                               | 27229/69006 [00:00<00:00, 272267.54it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 273997.84it/s][A

Generating channels:   0%|                                                                      | 0/727 [00:00<?, ?it/s][A
Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 6019.40it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 4686.37it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 264.78it/s][A
 33%|████████████████████████████▎                                                        | 1/3 [00:07<00:15,  7.77s/it]

Scenes 0–4 generation time: 7.56s
The following parameters seem unnecessary:
{'activate_OFDM'}

Scene 1/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing:  39%|████████████████████                                | 26610/69006 [00:00<00:00, 266069.80it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 260655.75it/s][A

Generating channels:   0%|                                                                      | 0/727 [00:00<?, ?it/s][A
Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 5934.86it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 5562.74it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 288.41it/s][A



Scene 2/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing:  38%|███████████████████▉                                | 26417/69006 [00:00<00:00, 264145.69it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 263237.15it/s][A

Generating channels:   0%|                                                                      | 0/727 [00:00<?, ?it/s][A
Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 5835.36it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 2139.95it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 516.16it/s][A



Scene 3/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing:  40%|████████████████████▊                               | 27586/69006 [00:00<00:00, 275841.85it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 270010.80it/s][A

Generating channels:   0%|                                                                      | 0/727 [00:00<?, ?it/s][A
Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 5793.83it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 4457.28it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 358.89it/s][A



Scene 4/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing:   8%|████▍                                                  | 5495/69006 [00:02<00:33, 1902.54it/s][A
Reading ray-tracing:  45%|████████████████████████                             | 31265/69006 [00:02<00:02, 13958.32it/s][A
Reading ray-tracing: 100%|█████████████████████████████████████████████████████| 69006/69006 [00:03<00:00, 22048.38it/s][A

Generating channels:   0%|                                                                      | 0/727 [00:00<?, ?it/s][A
Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 5914.88it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 4733.98it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 646.27it/s][A



Scene 5/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing:  39%|████████████████████▍                               | 27176/69006 [00:00<00:00, 271736.29it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 268112.88it/s][A

Generating channels:   0%|                                                                      | 0/727 [00:00<?, ?it/s][A
Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 5487.77it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 2176.60it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 613.38it/s][A
 67%|████████████████████████████████████████████████████████▋                            | 2/3 [00:18<00:09,  9.27s/it]

Scenes 5–9 generation time: 10.12s
The following parameters seem unnecessary:
{'activate_OFDM'}

Scene 1/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing:  39%|████████████████████                                | 26654/69006 [00:00<00:00, 266523.10it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 261302.18it/s][A

Generating channels:   0%|                                                                      | 0/727 [00:00<?, ?it/s][A
Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 5809.61it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 3039.35it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 686.92it/s][A



Scene 2/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing:  39%|████████████████████▏                               | 26844/69006 [00:00<00:00, 268421.06it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 268846.56it/s][A

Generating channels:   0%|                                                                      | 0/727 [00:00<?, ?it/s][A
Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 5714.71it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 4185.93it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 779.90it/s][A



Scene 3/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing:  38%|███████████████████▉                                | 26428/69006 [00:00<00:00, 264262.61it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 263775.01it/s][A

Generating channels:   0%|                                                                      | 0/727 [00:00<?, ?it/s][A
Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 5713.83it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 3137.10it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 688.72it/s][A



Scene 4/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing:  39%|████████████████████▏                               | 26828/69006 [00:00<00:00, 268250.20it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 271252.53it/s][A

Generating channels:   0%|                                                                      | 0/727 [00:00<?, ?it/s][A
Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 5658.98it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 5809.29it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 786.63it/s][A



Scene 5/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                    | 0/69006 [00:00<?, ?it/s][A
Reading ray-tracing:  40%|████████████████████▌                               | 27347/69006 [00:00<00:00, 273447.44it/s][A
Reading ray-tracing: 100%|████████████████████████████████████████████████████| 69006/69006 [00:00<00:00, 271154.69it/s][A

Generating channels:   0%|                                                                      | 0/727 [00:00<?, ?it/s][A
Generating channels: 100%|██████████████████████████████████████████████████████████| 727/727 [00:00<00:00, 5763.44it/s][A



BS-BS Channels



Reading ray-tracing: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 6105.25it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 686.69it/s][A
100%|█████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:25<00:00,  8.51s/it]

Scenes 10–14 generation time: 7.24s
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
 96 97 98 99]





## About Information
User : 737
UE antenna : 1
BS antenna : 32  Shape(a+bj)
subcarrier : 64

In [7]:
# Unmasked Data Model(gru
# separate maksed data and unmasked data

## Data Preprocessing

In [8]:
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import IterableDataset, DataLoader
import numpy as np
import torch

class UnMaskedChannelSeqDataset(IterableDataset):
    """
    IterableDataset for masked channel sequence data.
    - Predicts the next-step channel vector from a sequence of past vectors.
    - Applies power normalization and MinMax scaling to both inputs and targets.
    """
    def __init__(self, scenes, seq_len=5, eps=1e-9):
        super().__init__()
        self.scenes = scenes
        self.seq_len = seq_len
        self.eps = eps

        # Determine dimensions: users (U), antennas (A), subcarriers (S), and vector length
        ch0 = scenes[0][0]['user']['channel']  # Example shape: (U, 1, A, S), complex values
        self.U = ch0.shape[0]                  # Number of users
        self.A = ch0.shape[2]                  # Number of antennas
        self.S = ch0.shape[3]                  # Number of subcarriers
        self.vec_len = 2 * self.A              # Real+imag length after concatenation

        # ----------------------------------------------------------------------
        # Precompute MinMax scaler on entire dataset
        # ----------------------------------------------------------------------
        X_list, y_list = [], []
        T = len(scenes)
        # Slide over time index to collect sequences and targets
        for t in range(self.seq_len, T):
            past = scenes[t - self.seq_len : t]
            target = scenes[t]
            for u in range(self.U):
                for s in range(self.S):
                    # Build numpy sequence of shape (seq_len, vec_len)
                    seq_np = np.stack([
                        np.concatenate([
                            ps[0]['user']['channel'][u, 0, :, s].real,
                            ps[0]['user']['channel'][u, 0, :, s].imag
                        ])
                        for ps in past
                    ], axis=0).astype(np.float32)

                    # Build numpy target of shape (vec_len,)
                    target_np = np.concatenate([
                        target[0]['user']['channel'][u, 0, :, s].real,
                        target[0]['user']['channel'][u, 0, :, s].imag
                    ]).astype(np.float32)

                    # Skip if all zeros (invalid data)
                    if not np.any(seq_np) or not np.any(target_np):
                        continue

                    # Flatten sequence for fitting scaler
                    X_list.append(seq_np.reshape(-1, self.vec_len))
                    y_list.append(target_np)

        # Stack all data for fitting the MinMax scaler
        X_all = np.vstack(X_list)  # Shape: (num_samples*seq_len, vec_len)
        y_all = np.stack(y_list)   # Shape: (num_samples, vec_len)

        # Fit MinMax scalers for inputs and targets
        self.scaler_x = MinMaxScaler().fit(X_all)
        self.scaler_y = MinMaxScaler().fit(y_all)

    def __iter__(self):
        """
        Yield power-normalized and MinMax-scaled sequences, mask positions, and targets.
        Each item: (seq_tensor, masked_pos_tensor, target_tensor)
        Shapes: seq_tensor (seq_len, vec_len), masked_pos_tensor (1,), target_tensor (vec_len,)
        """
        T = len(self.scenes)
        for t in range(self.seq_len, T):
            past = self.scenes[t - self.seq_len : t]
            target = self.scenes[t]
            for u in range(self.U):
                for s in range(self.S):
                    # Compute power-normalized numpy arrays
                    seq_np = np.stack([
                        self._power_norm(ps[0]['user']['channel'][u, 0, :, s])
                        for ps in past
                    ], axis=0)
                    target_np = self._power_norm(target[0]['user']['channel'][u, 0, :, s])

                    # Skip sequences or targets that are all zero
                    if not np.any(seq_np) or not np.any(target_np):
                        continue

                    # Apply MinMax scaling: reshape, transform, and reshape back
                    N, D = seq_np.shape
                    seq_np = self.scaler_x.transform(seq_np.reshape(-1, D)).reshape(N, D)
                    target_np = self.scaler_y.transform(target_np.reshape(1, -1)).reshape(-1,)

                    # Convert to torch tensors and yield with masked position
                    seq = torch.from_numpy(seq_np)
                    target = torch.from_numpy(target_np)
                    yield seq, target

    def _power_norm(self, h: np.ndarray) -> np.ndarray:
        """
        Convert complex-valued vector to concatenated real-imag vector and normalize power to 1.
        """
        v = np.concatenate([h.real, h.imag]).astype(np.float32)
        power = np.mean(v * v) + self.eps
        return v / np.sqrt(power)

    def __len__(self):
        """
        Total number of valid (sequence, target) pairs in the dataset.
        """
        return (len(self.scenes) - self.seq_len) * self.U * self.S


In [9]:
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import IterableDataset, DataLoader
import numpy as np
import torch
import random

class MaskedChannelSeqDataset(IterableDataset):
    """
    IterableDataset for masked channel sequence data.
    - Predicts the next-step channel vector from a sequence of past vectors.
    - Applies power normalization and MinMax scaling to both inputs and targets.
    - MCM is 15% about the all data
    - MCM 
      :80% probability: replace the selected patch entirely with a fixed mask vector m (e.g., a vector of zeros)
      :10% probability: replace it with a random noise vector sampled from a normal distribution (e.g., N(0, σ²))
      :10% probability: leave the original patch unchanged
    """
    def __init__(self, scenes, seq_len=5, eps=1e-9, noise_std = 1.0):
        super().__init__()
        self.scenes = scenes
        self.seq_len = seq_len
        self.eps = eps

        # Determine dimensions: users (U), antennas (A), subcarriers (S), and vector length
        ch0 = scenes[0][0]['user']['channel']  # Example shape: (U, 1, A, S), complex values
        self.U = ch0.shape[0]                  # Number of users
        self.A = ch0.shape[2]                  # Number of antennas
        self.S = ch0.shape[3]                  # Number of subcarriers
        self.vec_len = 2 * self.A              # Real+imag length after concatenation

        # masked parameter
        self.mask_value = torch.zeros(self.vec_len, dtype=torch.float32)  
        self.noise_std = noise_std

        # ----------------------------------------------------------------------
        # Precompute MinMax scaler on entire dataset
        # ----------------------------------------------------------------------
        X_list, y_list = [], []
        T = len(scenes)
        # Slide over time index to collect sequences and targets
        for t in range(self.seq_len, T):
            past = scenes[t - self.seq_len : t]
            target = scenes[t]
            mpos = random.randrange(self.seq_len)
            for u in range(self.U):
                for s in range(self.S):
                    # Build numpy sequence of shape (seq_len, vec_len)
                    seq_np = np.stack([
                        np.concatenate([
                            ps[0]['user']['channel'][u, 0, :, s].real,
                            ps[0]['user']['channel'][u, 0, :, s].imag
                        ])
                        for ps in past
                    ], axis=0).astype(np.float32)

                    # Build numpy target of shape (vec_len,)
                    target_np = np.concatenate([
                        target[0]['user']['channel'][u, 0, :, s].real,
                        target[0]['user']['channel'][u, 0, :, s].imag
                    ]).astype(np.float32)

                    # Skip if all zeros (invalid data)
                    if not np.any(seq_np) or not np.any(target_np):
                        continue

                    # Flatten sequence for fitting scaler
                    X_list.append(seq_np.reshape(-1, self.vec_len))
                    y_list.append(target_np)

        # Stack all data for fitting the MinMax scaler
        X_all = np.vstack(X_list)  # Shape: (num_samples*seq_len, vec_len)
        y_all = np.stack(y_list)   # Shape: (num_samples, vec_len)

        # Fit MinMax scalers for inputs and targets
        self.scaler_x = MinMaxScaler().fit(X_all)
        self.scaler_y = MinMaxScaler().fit(y_all)

    def __iter__(self):
        """
        Yield power-normalized and MinMax-scaled sequences, mask positions, and targets.
        Each item: (seq_tensor, masked_pos_tensor, target_tensor)
        Shapes: seq_tensor (seq_len, vec_len), masked_pos_tensor (1,), target_tensor (vec_len,)
        """
        T = len(self.scenes)
        for t in range(self.seq_len, T):
            past = self.scenes[t - self.seq_len : t]
            target = self.scenes[t]
            for u in range(self.U):
                for s in range(self.S):
                    # Compute power-normalized numpy arrays
                    seq_np = np.stack([
                        self._power_norm(ps[0]['user']['channel'][u, 0, :, s])
                        for ps in past
                    ], axis=0)
                    target_np = self._power_norm(target[0]['user']['channel'][u, 0, :, s])

                    # Skip sequences or targets that are all zero
                    if not np.any(seq_np) or not np.any(target_np):
                        continue

                    # Apply MinMax scaling: reshape, transform, and reshape back
                    N, D = seq_np.shape
                    seq_np = self.scaler_x.transform(seq_np.reshape(-1, D)).reshape(N, D)
                    target_np = self.scaler_y.transform(target_np.reshape(1, -1)).reshape(-1,)

                    # Convert to torch tensors and yield with masked position
                    seq = torch.from_numpy(seq_np)
                    target = torch.from_numpy(target_np)

                    # select mask position
                    mpos = random.randrange(self.seq_len)

                    # 80/10/10 rules
                    if random.random() < 0.15:
                        # select mpos position
                        mpos = random.randrange(self.seq_len)

                        # 80/10/10
                        r = random.random()
                        seq_masked = seq.clone()
                        
                        if r < 0.8:
                            # 80% full masked
                            seq_masked[mpos] = self.mask_value
                        elif r < 0.9:
                            # 10% random noise -> std
                            seq_masked[mpos] = torch.randn(self.vec_len) * self.noise_std
                        
                        yield seq_masked, torch.tensor([mpos], dtype=torch.long), target

    def _power_norm(self, h: np.ndarray) -> np.ndarray:
        """
        Convert complex-valued vector to concatenated real-imag vector and normalize power to 1.
        """
        v = np.concatenate([h.real, h.imag]).astype(np.float32)
        power = np.mean(v * v) + self.eps
        return v / np.sqrt(power)

    def __len__(self):
        """
        Total number of valid (sequence, target) pairs in the dataset.
        """
        return (len(self.scenes) - self.seq_len) * self.U * self.S


## Split Train/Val

In [10]:
# ❷ Train/Validation DataLoader split train : val = 6 : 4
seq_len      = 5
split_ratio  = 0.6
split_idx    = int(len(dataset) * split_ratio)

In [11]:
unmasked_train_ds = UnMaskedChannelSeqDataset(dataset[:split_idx], seq_len=seq_len)
unmasked_val_ds   = UnMaskedChannelSeqDataset(dataset[split_idx:], seq_len=seq_len)

# iterate over train_ds to compute min and max of features/targets

batch_size   = 32
unmasked_train_loader = DataLoader(unmasked_train_ds, batch_size=batch_size, shuffle=False)
unmasked_val_loader   = DataLoader(unmasked_val_ds,   batch_size=batch_size, shuffle=False)
# ─────────────────────────────────────────────


In [12]:
# ❷ Train/Validation DataLoader split train : val = 6 : 4

masked_train_ds = MaskedChannelSeqDataset(dataset[:split_idx], seq_len=seq_len)
masked_val_ds   = MaskedChannelSeqDataset(dataset[split_idx:], seq_len=seq_len)

# iterate over train_ds to compute min and max of features/targets

batch_size   = 32
masked_train_loader = DataLoader(masked_train_ds, batch_size=batch_size, shuffle=False)
masked_val_loader   = DataLoader(masked_val_ds,   batch_size=batch_size, shuffle=False)
# ─────────────────────────────────────────────


## Define Model

LWMWithHead: A wrapper class that uses a pre-trained LWM (Transformer encoder) as the backbone,
             and attaches a new fully-connected (FC) head for downstream tasks
             (regression, classification, etc.).

Changes:
- input_dim: Dimension of the actual input data (e.g., 64)
- patch_length: Patch length expected by the backbone (e.g., 16)
- Replaces the original element_length parameter with these two distinct parameters
- Applies a projection layer (self.input_proj) in forward()


In [13]:
import torch
import torch.nn as nn
from lwm_model import lwm

class LWMWithHead(nn.Module):
    """
    LWMWithHead: A wrapper class that uses a pre-trained LWM (Transformer encoder) as the backbone,
                 and attaches a new fully-connected (FC) head for downstream tasks
                 (regression, classification, etc.).

    Changes:
    - input_dim: Dimension of the actual input data (e.g., 64)
    - patch_length: Patch length expected by the backbone (e.g., 16)
    - Replaces the original element_length parameter with these two distinct parameters
    - Applies a projection layer (self.input_proj) in forward()
    """
    def __init__(
        self,
        input_dim: int,                 # Dimension of the actual input data (e.g., 64)
        patch_length: int,              # Patch length expected by the backbone (e.g., 16)
        d_model: int = 64,              # LWM hidden size
        max_len: int = 129,             # Positional encoding max length
        n_layers: int = 12,             # Number of Transformer encoder layers
        hidden_dim: int = 256,          # FC head hidden dimension
        out_dim: int = 64,              # FC head output dimension
        freeze_backbone: bool = True,   # Whether to freeze the backbone
        checkpoint_path: str | None = "./model_weights.pth",
        device: str = "cuda"
    ):
        super().__init__()

        # apply a projection layer to match backbone's expected patch_length
        self.input_proj = nn.Linear(input_dim, patch_length)

        # initialize backbone
        if checkpoint_path is None:
            # randomly initialized backbone
            self.backbone = lwm(
                element_length=patch_length,
                d_model=d_model,
                max_len=max_len,
                n_layers=n_layers
            ).to(device)
        else:
            # load pre-trained weights
            self.backbone = lwm.from_pretrained(
                ckpt_name=checkpoint_path,
                device=device
            )

        # freeze backbone parameters if required
        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

        # attach a new fully-connected head for downstream tasks
        self.head = nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, input_ids: torch.Tensor, masked_pos: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input_ids: Tensor of shape (B, L, input_dim)
            masked_pos: Tensor of shape (B, num_mask)
        Returns:
            out: Tensor of shape (B, out_dim)
        """
        # project inputs to patch_length dimension
        x = self.input_proj(input_ids)

        # backbone forward: returns (logits_lm, enc_output)
        _, enc_output = self.backbone(x, masked_pos)

        # extract CLS token feature (first token)
        feat = enc_output[:, 0, :]

        # pass through FC head to get final output
        out = self.head(feat)
        return out


In [14]:
import torch
import torch.nn as nn

class GRUWithHead(nn.Module):
    """
    GRUWithHead: A wrapper class that uses a GRU backbone and attaches a fully-connected (FC) head
                 for downstream tasks (regression, classification, etc.).
    """
    def __init__(
        self,
        feat_dim: int = 16,           # Dimension of input features (patch_length / element_length)
        d_model: int = 64,            # GRU hidden size
        n_layers: int = 12,           # Number of GRU layers to stack
        bidirectional: bool = True,   # Whether to use a bidirectional GRU
        dropout: float = 0.1,         # Dropout probability between GRU layers
        hidden_dim: int = 256,        # FC head hidden dimension
        out_dim: int = 64,            # FC head output dimension
        freeze_backbone: bool = False # Whether to freeze GRU backbone weights
    ):
        super().__init__()

        # 1) GRU backbone
        self.backbone = nn.GRU(
            input_size   = feat_dim,
            hidden_size  = d_model,
            num_layers   = n_layers,
            batch_first  = True,
            bidirectional= bidirectional,
            dropout      = dropout if n_layers > 1 else 0.0
        )

        # 2) Optionally freeze backbone parameters
        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

        # Build FC head 
        gru_out_dim = d_model * (2 if bidirectional else 1)
        self.head = nn.Sequential(
            nn.Linear(gru_out_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (batch_size, seq_len, feat_dim)
        Returns:
            out: Tensor of shape (batch_size, out_dim)
        """
        # 1) Pass through GRU backbone
        out, _ = self.backbone(x)  # out shape: (B, seq_len, num_directions * d_model)

        # 2) Take the last time-step output as sequence representation
        feat = out[:, -1, :]       # shape: (B, gru_out_dim)

        # 3) Pass through FC head to get final output
        return self.head(feat)     # shape: (B, out_dim)


In [15]:
import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        # Create positional encoding matrix of shape (1, max_len, d_model)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div_term)
        pe[:, 1::2] = torch.cos(pos * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (batch_size, seq_len, d_model)
        Returns:
            Tensor: x plus positional encodings
        """
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]

class InputEmbedding(nn.Module):
    def __init__(self, feat_dim: int, d_model: int, max_len: int = 5000):
        super().__init__()
        # Optional linear projection from feat_dim to d_model
        self.proj = nn.Linear(feat_dim, d_model) if feat_dim != d_model else None
        self.pos_enc = PositionalEncoding(d_model, max_len)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (batch, seq_len, feat_dim)
        Returns:
            Tensor of shape (batch, seq_len, d_model)
        """
        if self.proj is not None:
            x = self.proj(x)
        return self.pos_enc(x)

class EncoderLayer(nn.Module):
    def __init__(self, d_model: int, n_heads: int, dim_ff: int, dropout: float = 0.1):
        super().__init__()
        # Multi-Head Self-Attention
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        # Position-wise Feed-Forward Network
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_ff),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(dim_ff, d_model)
        )
        # Layer Normalization and Dropout for residual connections
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        src_mask: torch.Tensor = None,
        src_key_padding_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (seq_len, batch, d_model)
            src_mask: Optional Tensor of shape (seq_len, seq_len)
            src_key_padding_mask: Optional Tensor of shape (batch, seq_len)
        Returns:
            Tensor of shape (seq_len, batch, d_model)
        """
        # Self-attention sublayer
        attn_out, _ = self.self_attn(x, x, x, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
        x = x + self.dropout1(attn_out)
        x = self.norm1(x)
        # Feed-forward sublayer
        ff_out = self.ff(x)
        x = x + self.dropout2(ff_out)
        x = self.norm2(x)
        return x

class TransformerEncoderCustom(nn.Module):
    def __init__(
        self,
        feat_dim: int,
        d_model: int,
        n_heads: int,
        dim_ff: int,
        n_layers: int,
        dropout: float = 0.1,
        max_len: int = 5000
    ):
        super().__init__()
        # Input embedding: feature projection + positional encoding
        self.input_embedding = InputEmbedding(feat_dim, d_model, max_len)
        # Stack of N encoder layers
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, n_heads, dim_ff, dropout)
            for _ in range(n_layers)
        ])

    def forward(
        self,
        x: torch.Tensor,
        src_mask: torch.Tensor = None,
        src_key_padding_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (batch, seq_len, feat_dim)
        Returns:
            Tensor of shape (seq_len, batch, d_model)
        """
        x = self.input_embedding(x)       # (batch, seq_len, d_model)
        x = x.transpose(0, 1)             # (seq_len, batch, d_model)
        for layer in self.layers:
            x = layer(x, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        return x

class DecoderLayer(nn.Module):
    def __init__(self, d_model: int, n_heads: int, dim_ff: int, dropout: float = 0.1):
        super().__init__()
        # Masked Self-Attention
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        # Encoder-Decoder Attention
        self.multihead_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        # Position-wise Feed-Forward Network
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_ff),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(dim_ff, d_model)
        )
        # Layer Normalizations and Dropouts
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: torch.Tensor = None,
        memory_mask: torch.Tensor = None,
        tgt_key_padding_mask: torch.Tensor = None,
        memory_key_padding_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Args:
            tgt: Tensor of shape (tgt_len, batch, d_model)
            memory: Tensor of shape (src_len, batch, d_model)
        Returns:
            Tensor of shape (tgt_len, batch, d_model)
        """
        # Masked self-attention sublayer
        attn1, _ = self.self_attn(
            tgt, tgt, tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask
        )
        tgt = tgt + self.dropout1(attn1)
        tgt = self.norm1(tgt)
        # Encoder-decoder attention sublayer
        attn2, _ = self.multihead_attn(
            tgt, memory, memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask
        )
        tgt = tgt + self.dropout2(attn2)
        tgt = self.norm2(tgt)
        # Feed-forward sublayer
        ff_out = self.ff(tgt)
        tgt = tgt + self.dropout3(ff_out)
        tgt = self.norm3(tgt)
        return tgt

class TransformerDecoderCustom(nn.Module):
    def __init__(
        self,
        feat_dim: int,
        d_model: int,
        n_heads: int,
        dim_ff: int,
        n_layers: int,
        dropout: float = 0.1,
        max_len: int = 5000
    ):
        super().__init__()
        # Input embedding for target sequence
        self.input_embedding = InputEmbedding(feat_dim, d_model, max_len)
        # Stack of N decoder layers
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, n_heads, dim_ff, dropout)
            for _ in range(n_layers)
        ])
        # Final projection back to feature dimension
        self.output_linear = nn.Linear(d_model, feat_dim)

    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: torch.Tensor = None,
        memory_mask: torch.Tensor = None,
        tgt_key_padding_mask: torch.Tensor = None,
        memory_key_padding_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Args:
            tgt: Tensor of shape (batch, tgt_len, feat_dim)
            memory: Tensor of shape (src_len, batch, d_model)
        Returns:
            Tensor of shape (batch, tgt_len, feat_dim)
        """
        x = self.input_embedding(tgt)       # (batch, tgt_len, d_model)
        x = x.transpose(0, 1)               # (tgt_len, batch, d_model)
        for layer in self.layers:
            x = layer(
                x,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask
            )
        x = x.transpose(0, 1)               # (batch, tgt_len, d_model)
        return self.output_linear(x)        # project back to feat_dim

class TransformerWithHead(nn.Module):
    def __init__(
        self,
        feat_dim: int,
        d_model: int = 256,
        n_heads: int = 8,
        dim_ff: int = 512,
        n_layers: int = 12,
        dropout: float = 0.1,
        hidden_dim: int = 256,
        out_dim: int = 64,
        max_len: int = 5000
    ):
        super().__init__()
        # Encoder-only backbone
        self.encoder = TransformerEncoderCustom(
            feat_dim, d_model, n_heads, dim_ff, n_layers, dropout, max_len
        )
        # Classification head
        self.head = nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(
        self,
        x: torch.Tensor,
        src_mask: torch.Tensor = None,
        src_key_padding_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (batch, seq_len, feat_dim)
        Returns:
            Tensor of shape (batch, out_dim)
        """
        # Encode input sequence
        enc_output = self.encoder(x, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        # Take the last token's representation
        last_token = enc_output[-1]         # (batch, d_model)
        return self.head(last_token)


In [16]:
class RNNWithHead(nn.Module):
    def __init__(
        self,
        feat_dim: int,
        hidden_size: int      = 128,
        num_layers: int       = 12,
        bidirectional: bool   = True,
        hidden_dim: int       = 256,
        out_dim: int          = 64,
        freeze_backbone: bool = False,
    ):
        super().__init__()
        # RNN backbone (no dropout)
        self.backbone = nn.RNN(
            input_size   = feat_dim,
            hidden_size  = hidden_size,
            num_layers   = num_layers,
            batch_first  = True,
            bidirectional= bidirectional,
            dropout      = 0.1
        )
        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

        # Compute RNN output dimension based on bi-directionality
        rnn_out_dim = hidden_size * (2 if bidirectional else 1)

        # Fully-connected head (dropout layer removed)
        self.head = nn.Sequential(
            nn.Linear(rnn_out_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x):
        # x: (batch_size, seq_len, feat_dim)
        out, _ = self.backbone(x)       # out: (batch_size, seq_len, rnn_out_dim)
        feat   = out[:, -1, :]          # last time step (batch_size, rnn_out_dim)
        return self.head(feat)          # (batch_size, out_dim)


In [17]:
class LSTMWithHead(nn.Module):
    def __init__(
        self,
        feat_dim: int,
        hidden_size: int = 128,
        num_layers: int = 1,
        bidirectional: bool = False,
        dropout: float = 0.0,
        hidden_dim: int = 256,
        out_dim: int = None,
        freeze_backbone: bool = False
    ):
        super().__init__()
        # Determine output dimension (default to input feature size if not specified)
        out_dim = feat_dim if out_dim is None else out_dim

        # LSTM backbone for sequence modeling
        # - input_size: number of features per time step
        # - hidden_size: dimensionality of LSTM hidden state
        # - num_layers: number of stacked LSTM layers
        # - bidirectional: whether to use a bi-directional LSTM
        # - dropout: dropout probability between LSTM layers (if num_layers > 1)
        self.backbone = nn.LSTM(
            input_size=feat_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=dropout if num_layers > 1 else 0.0
        )
        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

        # Compute the LSTM output dimension accounting for bidirectionality
        lstm_out_dim = hidden_size * (2 if bidirectional else 1)

        # Fully-connected head to map the final LSTM state to desired output
        self.head = nn.Sequential(
            nn.Linear(lstm_out_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (batch_size, seq_len, feat_dim)
        Returns:
            Tensor of shape (batch_size, out_dim)
        """
        # Pass the input through the LSTM backbone
        # out shape: (batch_size, seq_len, lstm_out_dim)
        out, _ = self.backbone(x)

        # Extract features from the last time step
        # feat shape: (batch_size, lstm_out_dim)
        feat = out[:, -1, :]

        # Compute final output via the head
        return self.head(feat)


## fine-tuning

In [18]:
from torch.optim import Adam

INPUT_DIM = 64 # real data feature dimension
PATCH_LENGTH = 16
D_MODEL = 64
N_LAYERS = 12 
HIDDEN_DIM = 256 # Head Hidden dim
OUT_DIM = 64
DROPOUT = 0.1 # the same LWM dropout
BIDIRECTIONAL = True #  
DEVICE = "cuda"

# model catalog
MODEL_CATALOG = {
    "LWM_freeze_backbone" : LWMWithHead, # freeze backbone
    "LWM_pretrained_Fine_tune" : LWMWithHead, # not freeze backbone
    "LWM_Fine_tune" : LWMWithHead, # not pretrained not backbone
    "gru" : GRUWithHead, # gru Model
    "RNN" : RNNWithHead, # RNN Model
    "LSTM" : LSTMWithHead, # LSTM Model
    "Transformer" : TransformerWithHead # Transformer Model
}

MODEL_PARAMS = {
    "LWM_freeze_backbone" : {
        input_dim = INPUT_DIM,
        patch_length = PATCH_LENGTH,
        d_model = D_MODEL,
        hidden_dim = HIDDEN_DIM,
        out_dim = OUT_DIM,
        freeze_backbone = True,
        checkpoint_path = "./model_weights.pth",
        device = DEVICE,
    }
}


NameError: name 'GRU' is not defined