## import

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import TensorDataset, DataLoader, random_split
#LWM을 하기위한 라이브러리 가져오기
import DeepMIMOv3
import numpy as np
from pprint import pprint
import matplotlib.pyplot as plt
import time


plt . rcParams [ 'figure.figsize' ]  =  [ 12 ,  8 ]  # 기본 플롯 크기 설정

## GPU설정

In [2]:
# GPU 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


# DeepMIMOv3 다운

In [3]:
# pip install DeepMIMOv3 umap-learn

## 파라미터 수정

In [4]:
## Load and print the default parameters
# bandwith: 0.05GHz(50MHz 대역폭 사용)
parameters = DeepMIMOv3.default_params()
pprint(parameters)

{'OFDM': {'RX_filter': 0,
          'bandwidth': 0.05,
          'selected_subcarriers': array([0]),
          'subcarriers': 512},
 'OFDM_channels': 1,
 'active_BS': array([1]),
 'bs_antenna': {'FoV': array([360, 180]),
                'radiation_pattern': 'isotropic',
                'rotation': array([0, 0, 0]),
                'shape': array([8, 4]),
                'spacing': 0.5},
 'dataset_folder': './Raytracing_scenarios',
 'dynamic_scenario_scenes': array([1]),
 'enable_BS2BS': 1,
 'enable_doppler': 0,
 'enable_dual_polar': 0,
 'num_paths': 5,
 'scenario': 'O1_60',
 'ue_antenna': {'FoV': array([360, 180]),
                'radiation_pattern': 'isotropic',
                'rotation': array([0, 0, 0]),
                'shape': array([4, 2]),
                'spacing': 0.5},
 'user_rows': array([1]),
 'user_subsampling': 1}


In [11]:
## Change parameters for the setup
# Scenario O1_60 extracted at the dataset_folder
#LWM 동적 시나리오 불러오기
#자신의 LWM 파일 위치 경로 작성
# parameters['dataset_folder'] = r'/content/drive/MyDrive/Colab Notebooks/LWM'
scene = 10 # 장면 수
parameters['dataset_folder'] = r'C:\Users\dlghd\졸업프로젝트\LWM'

# scnario = 02_dyn_3p5 <- 다운받은 파일(동적시나리오)
parameters['scenario'] = 'O2_dyn_3p5'
parameters['dynamic_scenario_scenes'] = np.arange(scene) #scene 0~9

# 각 사용자-기지국 채널에 대해 최대 10개 멀티패스 경로 사용
parameters['num_paths'] = 10

# User rows 1-100
parameters['user_rows'] = np.arange(100)

# Activate only the first basestation
parameters['active_BS'] = np.array([1])

parameters['activate_OFDM'] = 1

parameters['OFDM']['bandwidth'] = 0.05 # 50 MHz
parameters['OFDM']['subcarriers'] = 512 # OFDM with 512 subcarriers
parameters['OFDM']['selected_subcarriers'] = np.arange(0, 64, 1)
#parameters['OFDM']['subcarriers_limit'] = 64 # Keep only first 64 subcarriers

parameters['ue_antenna']['shape'] = np.array([1, 1]) # Single antenna
parameters['bs_antenna']['shape'] = np.array([1, 32]) # ULA of 32 elements
#parameters['bs_antenna']['rotation'] = np.array([0, 30, 90]) # ULA of 32 elements
#parameters['ue_antenna']['rotation'] = np.array([[0, 30], [30, 60], [60, 90]]) # ULA of 32 elements
#parameters['ue_antenna']['radiation_pattern'] = 'isotropic'
#parameters['bs_antenna']['radiation_pattern'] = 'halfwave-dipole'

In [12]:
print(parameters)

{'dataset_folder': 'C:\\Users\\dlghd\\졸업프로젝트\\LWM', 'scenario': 'O2_dyn_3p5', 'dynamic_scenario_scenes': array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), 'num_paths': 10, 'active_BS': array([1]), 'user_rows': array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
       34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
       51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67,
       68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84,
       85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]), 'user_subsampling': 1, 'bs_antenna': {'shape': array([ 1, 32]), 'spacing': 0.5, 'rotation': array([0, 0, 0]), 'FoV': array([360, 180]), 'radiation_pattern': 'isotropic'}, 'ue_antenna': {'shape': array([1, 1]), 'spacing': 0.5, 'rotation': array([0, 0, 0]), 'FoV': array([360, 180]), 'radiation_pattern': 'isotropic'}, 'enable_doppler': 0, 'enable_dual_polar'

## dataset 구축

In [13]:
## dataset 구축 (chunked on‑the‑fly generation)
import time, gc
from tqdm import tqdm

# 0~999 씬 인덱스, 한 번에 50개씩 처리
scene_indices = np.arange(scene)
chunk_size   = 5
all_data     = []

# 씬 묶음(chunk)마다 generate_data 호출
for i in tqdm(range(0, len(scene_indices), chunk_size)):
    chunk = scene_indices[i : i+chunk_size].tolist()
    parameters['dynamic_scenario_scenes'] = chunk

    start = time.time()
    data_chunk = DeepMIMOv3.generate_data(parameters)
    print(f"Scenes {chunk[0]}–{chunk[-1]} generation time: {time.time() - start:.2f}s")

    # 바로 all_data에 합치거나, 디스크에 저장해도 OK
    all_data.extend(data_chunk)

    # 메모리 해제
    del data_chunk
    gc.collect()

# 마지막에 하나의 리스트로 합친 데이터셋
dataset = all_data


  0%|                                                                                            | 0/2 [00:00<?, ?it/s]

The following parameters seem unnecessary:
{'activate_OFDM'}

Scene 1/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                   | 0/69040 [00:00<?, ?it/s][A
Reading ray-tracing:   4%|██                                                   | 2688/69040 [00:00<00:02, 26597.08it/s][A
Reading ray-tracing:   8%|████▍                                                | 5798/69040 [00:00<00:02, 28622.78it/s][A
Reading ray-tracing:  13%|██████▋                                              | 8659/69040 [00:00<00:02, 27791.54it/s][A
Reading ray-tracing:  17%|████████▋                                           | 11565/69040 [00:00<00:02, 27920.26it/s][A
Reading ray-tracing:  21%|██████████▊                                         | 14358/69040 [00:00<00:02, 26840.46it/s][A
Reading ray-tracing:  25%|████████████▉                                       | 17114/69040 [00:00<00:01, 26819.21it/s][A
Reading ray-tracing:  29%|██████████████▉                                     | 19800/69040 [00:00<00:01, 26378.16it/s][A
Reading ray-tra

Generating channels:  34%|██████████████████▏                                  | 23742/69040 [00:05<00:10, 4443.33it/s][A
Generating channels:  35%|██████████████████▌                                  | 24236/69040 [00:05<00:09, 4545.98it/s][A
Generating channels:  36%|███████████████████▏                                 | 24938/69040 [00:05<00:08, 5179.50it/s][A
Generating channels:  37%|███████████████████▌                                 | 25475/69040 [00:05<00:09, 4576.90it/s][A
Generating channels:  38%|███████████████████▉                                 | 25957/69040 [00:05<00:09, 4442.65it/s][A
Generating channels:  39%|████████████████████▌                                | 26726/69040 [00:05<00:08, 5256.96it/s][A
Generating channels:  40%|████████████████████▉                                | 27275/69040 [00:05<00:08, 4771.11it/s][A
Generating channels:  40%|█████████████████████▎                               | 27775/69040 [00:06<00:09, 4555.98it/s][A
Generating chann


BS-BS Channels



Reading ray-tracing: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<?, ?it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 247.74it/s][A



Scene 2/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                   | 0/69040 [00:00<?, ?it/s][A
Reading ray-tracing:   4%|██▎                                                  | 3069/69040 [00:00<00:02, 29550.10it/s][A
Reading ray-tracing:   9%|████▊                                                | 6270/69040 [00:00<00:02, 30745.29it/s][A
Reading ray-tracing:  14%|███████▏                                             | 9374/69040 [00:00<00:01, 30638.15it/s][A
Reading ray-tracing:  18%|█████████▎                                          | 12439/69040 [00:00<00:01, 30140.15it/s][A
Reading ray-tracing:  22%|███████████▋                                        | 15534/69040 [00:00<00:01, 30110.03it/s][A
Reading ray-tracing:  27%|█████████████▉                                      | 18546/69040 [00:00<00:01, 29633.91it/s][A
Reading ray-tracing:  31%|████████████████▎                                   | 21617/69040 [00:00<00:01, 29828.42it/s][A
Reading ray-tra

Generating channels:  39%|████████████████████▌                                | 26781/69040 [00:05<00:07, 5383.06it/s][A
Generating channels:  40%|████████████████████▉                                | 27344/69040 [00:05<00:08, 4754.55it/s][A
Generating channels:  40%|█████████████████████▍                               | 27847/69040 [00:05<00:09, 4560.18it/s][A
Generating channels:  41%|█████████████████████▉                               | 28627/69040 [00:05<00:07, 5326.28it/s][A
Generating channels:  42%|██████████████████████▍                              | 29185/69040 [00:06<00:08, 4781.30it/s][A
Generating channels:  43%|██████████████████████▊                              | 29689/69040 [00:06<00:08, 4579.34it/s][A
Generating channels:  44%|███████████████████████▍                             | 30460/69040 [00:06<00:07, 5342.57it/s][A
Generating channels:  45%|███████████████████████▊                             | 31020/69040 [00:06<00:07, 4848.80it/s][A
Generating chann


BS-BS Channels



Reading ray-tracing: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<?, ?it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 246.06it/s][A



Scene 3/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                   | 0/69040 [00:00<?, ?it/s][A
Reading ray-tracing:   4%|██▍                                                  | 3094/69040 [00:00<00:02, 30903.34it/s][A
Reading ray-tracing:   9%|████▋                                                | 6185/69040 [00:00<00:02, 30874.42it/s][A
Reading ray-tracing:  13%|███████                                              | 9273/69040 [00:00<00:01, 30246.12it/s][A
Reading ray-tracing:  18%|█████████▎                                          | 12300/69040 [00:00<00:01, 29121.27it/s][A
Reading ray-tracing:  22%|███████████▍                                        | 15218/69040 [00:00<00:01, 28607.85it/s][A
Reading ray-tracing:  26%|█████████████▌                                      | 18083/69040 [00:00<00:01, 27990.41it/s][A
Reading ray-tracing:  31%|███████████████▉                                    | 21189/69040 [00:00<00:01, 28926.12it/s][A
Reading ray-tra

Generating channels:  39%|████████████████████▌                                | 26742/69040 [00:05<00:07, 5377.75it/s][A
Generating channels:  40%|████████████████████▉                                | 27301/69040 [00:05<00:08, 4865.25it/s][A
Generating channels:  40%|█████████████████████▎                               | 27810/69040 [00:05<00:08, 4595.00it/s][A
Generating channels:  41%|█████████████████████▉                               | 28554/69040 [00:05<00:07, 5308.27it/s][A
Generating channels:  42%|██████████████████████▎                              | 29108/69040 [00:05<00:08, 4890.04it/s][A
Generating channels:  43%|██████████████████████▋                              | 29618/69040 [00:06<00:08, 4533.36it/s][A
Generating channels:  44%|███████████████████████▏                             | 30248/69040 [00:06<00:07, 4976.08it/s][A
Generating channels:  45%|███████████████████████▋                             | 30795/69040 [00:06<00:07, 5056.89it/s][A
Generating chann


BS-BS Channels



Reading ray-tracing: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<?, ?it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 236.15it/s][A



Scene 4/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                   | 0/69040 [00:00<?, ?it/s][A
Reading ray-tracing:   4%|██▎                                                  | 2981/69040 [00:00<00:02, 29371.86it/s][A
Reading ray-tracing:   9%|████▊                                                | 6197/69040 [00:00<00:02, 30831.12it/s][A
Reading ray-tracing:  13%|███████▏                                             | 9307/69040 [00:00<00:01, 30443.10it/s][A
Reading ray-tracing:  18%|█████████▍                                          | 12483/69040 [00:00<00:01, 30533.69it/s][A
Reading ray-tracing:  23%|███████████▊                                        | 15645/69040 [00:00<00:01, 30500.43it/s][A
Reading ray-tracing:  27%|██████████████                                      | 18696/69040 [00:00<00:01, 30077.13it/s][A
Reading ray-tracing:  31%|████████████████▎                                   | 21705/69040 [00:00<00:01, 29877.80it/s][A
Reading ray-tra

Generating channels:  39%|████████████████████▍                                | 26656/69040 [00:05<00:08, 5287.11it/s][A
Generating channels:  39%|████████████████████▉                                | 27208/69040 [00:05<00:08, 4906.13it/s][A
Generating channels:  40%|█████████████████████▎                               | 27718/69040 [00:05<00:09, 4481.45it/s][A
Generating channels:  41%|█████████████████████▊                               | 28370/69040 [00:05<00:08, 4964.04it/s][A
Generating channels:  42%|██████████████████████▏                              | 28888/69040 [00:06<00:08, 5005.64it/s][A
Generating channels:  43%|██████████████████████▌                              | 29404/69040 [00:06<00:08, 4516.59it/s][A
Generating channels:  43%|██████████████████████▉                              | 29887/69040 [00:06<00:08, 4571.16it/s][A
Generating channels:  44%|███████████████████████▍                             | 30580/69040 [00:06<00:07, 5165.33it/s][A
Generating chann


BS-BS Channels



Reading ray-tracing: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<?, ?it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 235.71it/s][A



Scene 5/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                   | 0/69040 [00:00<?, ?it/s][A
Reading ray-tracing:   5%|██▍                                                  | 3148/69040 [00:00<00:02, 30257.69it/s][A
Reading ray-tracing:   9%|████▉                                                | 6378/69040 [00:00<00:02, 30732.62it/s][A
Reading ray-tracing:  14%|███████▍                                             | 9636/69040 [00:00<00:01, 31455.91it/s][A
Reading ray-tracing:  19%|█████████▋                                          | 12783/69040 [00:00<00:01, 30739.13it/s][A
Reading ray-tracing:  23%|███████████▉                                        | 15879/69040 [00:00<00:01, 30519.25it/s][A
Reading ray-tracing:  27%|██████████████▎                                     | 18933/69040 [00:00<00:01, 30080.64it/s][A
Reading ray-tracing:  32%|████████████████▌                                   | 21943/69040 [00:00<00:01, 29243.92it/s][A
Reading ray-tra

Generating channels:  38%|████████████████████▎                                | 26444/69040 [00:05<00:08, 4905.74it/s][A
Generating channels:  39%|████████████████████▋                                | 27002/69040 [00:05<00:08, 5061.52it/s][A
Generating channels:  40%|█████████████████████▏                               | 27523/69040 [00:05<00:09, 4532.30it/s][A
Generating channels:  41%|█████████████████████▌                               | 28011/69040 [00:05<00:08, 4597.98it/s][A
Generating channels:  42%|██████████████████████                               | 28683/69040 [00:05<00:07, 5129.74it/s][A
Generating channels:  42%|██████████████████████▍                              | 29211/69040 [00:06<00:08, 4632.17it/s][A
Generating channels:  43%|██████████████████████▊                              | 29692/69040 [00:06<00:08, 4400.10it/s][A
Generating channels:  44%|███████████████████████▍                             | 30454/69040 [00:06<00:07, 5183.14it/s][A
Generating chann


BS-BS Channels



Reading ray-tracing: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<?, ?it/s][A

Generating channels: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<?, ?it/s][A
 50%|██████████████████████████████████████████                                          | 1/2 [01:19<01:19, 79.10s/it]

Scenes 0–4 generation time: 78.86s
The following parameters seem unnecessary:
{'activate_OFDM'}

Scene 1/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                   | 0/69040 [00:00<?, ?it/s][A
Reading ray-tracing:   4%|██                                                   | 2721/69040 [00:00<00:02, 26305.58it/s][A
Reading ray-tracing:   8%|████▎                                                | 5640/69040 [00:00<00:02, 27736.41it/s][A
Reading ray-tracing:  12%|██████▌                                              | 8533/69040 [00:00<00:02, 28254.93it/s][A
Reading ray-tracing:  17%|████████▋                                           | 11470/69040 [00:00<00:02, 28533.69it/s][A
Reading ray-tracing:  21%|██████████▊                                         | 14325/69040 [00:00<00:02, 27242.87it/s][A
Reading ray-tracing:  25%|████████████▊                                       | 17059/69040 [00:00<00:02, 25292.99it/s][A
Reading ray-tracing:  28%|██████████████▊                                     | 19614/69040 [00:00<00:02, 22878.10it/s][A
Reading ray-tra

Generating channels:  34%|█████████████████▊                                   | 23132/69040 [00:05<00:10, 4545.15it/s][A
Generating channels:  34%|██████████████████                                   | 23599/69040 [00:05<00:11, 4102.31it/s][A
Generating channels:  35%|██████████████████▍                                  | 24025/69040 [00:05<00:11, 4036.03it/s][A
Generating channels:  36%|██████████████████▉                                  | 24725/69040 [00:05<00:09, 4824.38it/s][A
Generating channels:  37%|███████████████████▎                                 | 25225/69040 [00:05<00:09, 4746.61it/s][A
Generating channels:  37%|███████████████████▋                                 | 25712/69040 [00:06<00:09, 4334.33it/s][A
Generating channels:  38%|████████████████████▏                                | 26289/69040 [00:06<00:09, 4666.54it/s][A
Generating channels:  39%|████████████████████▌                                | 26850/69040 [00:06<00:08, 4921.35it/s][A
Generating chann

Generating channels:  83%|███████████████████████████████████████████▊         | 57060/69040 [00:12<00:02, 4628.47it/s][A
Generating channels:  83%|████████████████████████████████████████████▏        | 57530/69040 [00:13<00:02, 4425.97it/s][A
Generating channels:  84%|████████████████████████████████████████████▌        | 57979/69040 [00:13<00:02, 4195.48it/s][A
Generating channels: 100%|█████████████████████████████████████████████████████| 69040/69040 [00:13<00:00, 5175.79it/s][A



BS-BS Channels



Reading ray-tracing: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<?, ?it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 229.31it/s][A



Scene 2/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                   | 0/69040 [00:00<?, ?it/s][A
Reading ray-tracing:   5%|██▍                                                  | 3159/69040 [00:00<00:02, 31021.49it/s][A
Reading ray-tracing:   9%|████▉                                                | 6355/69040 [00:00<00:01, 31566.49it/s][A
Reading ray-tracing:  14%|███████▎                                             | 9513/69040 [00:00<00:01, 31294.47it/s][A
Reading ray-tracing:  18%|█████████▌                                          | 12643/69040 [00:00<00:01, 30355.45it/s][A
Reading ray-tracing:  23%|███████████▊                                        | 15683/69040 [00:00<00:01, 29932.23it/s][A
Reading ray-tracing:  27%|██████████████                                      | 18679/69040 [00:00<00:01, 29838.37it/s][A
Reading ray-tracing:  31%|████████████████▎                                   | 21675/69040 [00:00<00:01, 29524.22it/s][A
Reading ray-tra

Generating channels:  38%|███████████████████▉                                 | 25919/69040 [00:05<00:09, 4498.36it/s][A
Generating channels:  39%|████████████████████▍                                | 26649/69040 [00:05<00:08, 5227.60it/s][A
Generating channels:  39%|████████████████████▊                                | 27189/69040 [00:05<00:08, 4900.40it/s][A
Generating channels:  40%|█████████████████████▎                               | 27694/69040 [00:05<00:09, 4445.76it/s][A
Generating channels:  41%|█████████████████████▋                               | 28277/69040 [00:05<00:08, 4761.93it/s][A
Generating channels:  42%|██████████████████████                               | 28817/69040 [00:05<00:08, 4889.25it/s][A
Generating channels:  42%|██████████████████████▌                              | 29319/69040 [00:06<00:09, 4363.91it/s][A
Generating channels:  43%|██████████████████████▊                              | 29773/69040 [00:06<00:09, 4250.68it/s][A
Generating chann


BS-BS Channels



Reading ray-tracing: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<?, ?it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 248.98it/s][A



Scene 3/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                   | 0/69040 [00:00<?, ?it/s][A
Reading ray-tracing:   4%|██▎                                                  | 3093/69040 [00:00<00:02, 30147.36it/s][A
Reading ray-tracing:   9%|████▋                                                | 6177/69040 [00:00<00:02, 30504.11it/s][A
Reading ray-tracing:  14%|███████▎                                             | 9450/69040 [00:00<00:01, 31424.19it/s][A
Reading ray-tracing:  18%|█████████▍                                          | 12593/69040 [00:00<00:01, 30914.31it/s][A
Reading ray-tracing:  23%|███████████▊                                        | 15686/69040 [00:00<00:01, 30867.46it/s][A
Reading ray-tracing:  27%|██████████████▏                                     | 18774/69040 [00:00<00:01, 30551.06it/s][A
Reading ray-tracing:  32%|████████████████▍                                   | 21830/69040 [00:00<00:01, 29461.19it/s][A
Reading ray-tra

Generating channels:  37%|███████████████████▋                                 | 25600/69040 [00:05<00:09, 4508.43it/s][A
Generating channels:  38%|████████████████████                                 | 26067/69040 [00:05<00:09, 4490.22it/s][A
Generating channels:  39%|████████████████████▌                                | 26800/69040 [00:05<00:08, 5206.87it/s][A
Generating channels:  40%|████████████████████▉                                | 27335/69040 [00:05<00:08, 4696.38it/s][A
Generating channels:  40%|█████████████████████▎                               | 27822/69040 [00:05<00:09, 4522.70it/s][A
Generating channels:  41%|█████████████████████▉                               | 28512/69040 [00:05<00:07, 5105.51it/s][A
Generating channels:  42%|██████████████████████▎                              | 29038/69040 [00:06<00:08, 4794.85it/s][A
Generating channels:  43%|██████████████████████▋                              | 29531/69040 [00:06<00:09, 4376.10it/s][A
Generating chann

Generating channels: 100%|█████████████████████████████████████████████████████| 69040/69040 [00:12<00:00, 5366.16it/s][A



BS-BS Channels



Reading ray-tracing: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<?, ?it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 247.52it/s][A



Scene 4/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                   | 0/69040 [00:00<?, ?it/s][A
Reading ray-tracing:   4%|██▎                                                  | 2984/69040 [00:00<00:02, 29485.37it/s][A
Reading ray-tracing:   9%|████▊                                                | 6283/69040 [00:00<00:02, 31366.71it/s][A
Reading ray-tracing:  14%|███████▎                                             | 9588/69040 [00:00<00:01, 31651.87it/s][A
Reading ray-tracing:  18%|█████████▌                                          | 12753/69040 [00:00<00:01, 30872.20it/s][A
Reading ray-tracing:  23%|███████████▉                                        | 15843/69040 [00:00<00:01, 30485.02it/s][A
Reading ray-tracing:  27%|██████████████▏                                     | 18893/69040 [00:00<00:01, 30136.94it/s][A
Reading ray-tracing:  32%|████████████████▌                                   | 21908/69040 [00:00<00:01, 29573.93it/s][A
Reading ray-tra

Generating channels:  36%|███████████████████▏                                 | 25004/69040 [00:05<00:09, 4768.39it/s][A
Generating channels:  37%|███████████████████▌                                 | 25492/69040 [00:05<00:10, 4276.47it/s][A
Generating channels:  38%|███████████████████▉                                 | 25935/69040 [00:05<00:10, 4258.53it/s][A
Generating channels:  39%|████████████████████▍                                | 26592/69040 [00:05<00:08, 4874.99it/s][A
Generating channels:  39%|████████████████████▊                                | 27095/69040 [00:05<00:08, 4673.82it/s][A
Generating channels:  40%|█████████████████████▏                               | 27574/69040 [00:05<00:09, 4232.34it/s][A
Generating channels:  41%|█████████████████████▌                               | 28108/69040 [00:06<00:09, 4485.35it/s][A
Generating channels:  42%|██████████████████████                               | 28702/69040 [00:06<00:08, 4842.72it/s][A
Generating chann

Generating channels:  83%|███████████████████████████████████████████▊         | 57152/69040 [00:12<00:02, 4727.00it/s][A
Generating channels:  83%|████████████████████████████████████████████▏        | 57631/69040 [00:12<00:02, 4294.61it/s][A
Generating channels: 100%|█████████████████████████████████████████████████████| 69040/69040 [00:13<00:00, 5252.47it/s][A



BS-BS Channels



Reading ray-tracing: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<?, ?it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 238.31it/s][A



Scene 5/5

Basestation 1

UE-BS Channels



Reading ray-tracing:   0%|                                                                   | 0/69040 [00:00<?, ?it/s][A
Reading ray-tracing:   4%|██▎                                                  | 2979/69040 [00:00<00:02, 29149.20it/s][A
Reading ray-tracing:   9%|████▋                                                | 6092/69040 [00:00<00:02, 29845.89it/s][A
Reading ray-tracing:  13%|██████▉                                              | 9077/69040 [00:00<00:02, 29766.05it/s][A
Reading ray-tracing:  18%|█████████                                           | 12111/69040 [00:00<00:01, 29708.55it/s][A
Reading ray-tracing:  22%|███████████▎                                        | 15082/69040 [00:00<00:01, 29285.94it/s][A
Reading ray-tracing:  26%|█████████████▌                                      | 18012/69040 [00:00<00:01, 29196.15it/s][A
Reading ray-tracing:  30%|███████████████▊                                    | 21024/69040 [00:00<00:01, 29263.10it/s][A
Reading ray-tra

Generating channels:  33%|█████████████████▋                                   | 23116/69040 [00:05<00:09, 4704.60it/s][A
Generating channels:  34%|██████████████████                                   | 23598/69040 [00:05<00:10, 4281.12it/s][A
Generating channels:  35%|██████████████████▍                                  | 24040/69040 [00:05<00:10, 4164.42it/s][A
Generating channels:  36%|██████████████████▉                                  | 24663/69040 [00:05<00:09, 4712.05it/s][A
Generating channels:  36%|███████████████████▎                                 | 25148/69040 [00:05<00:09, 4636.31it/s][A
Generating channels:  37%|███████████████████▋                                 | 25621/69040 [00:05<00:10, 4241.59it/s][A
Generating channels:  38%|████████████████████                                 | 26075/69040 [00:05<00:09, 4310.97it/s][A
Generating channels:  39%|████████████████████▌                                | 26727/69040 [00:05<00:08, 4904.07it/s][A
Generating chann

Generating channels:  79%|█████████████████████████████████████████▊           | 54472/69040 [00:12<00:03, 4173.05it/s][A
Generating channels:  80%|██████████████████████████████████████████▏          | 55002/69040 [00:12<00:03, 4461.16it/s][A
Generating channels:  80%|██████████████████████████████████████████▌          | 55454/69040 [00:12<00:03, 4381.51it/s][A
Generating channels:  81%|██████████████████████████████████████████▉          | 55896/69040 [00:12<00:03, 4059.23it/s][A
Generating channels:  82%|███████████████████████████████████████████▏         | 56337/69040 [00:12<00:03, 4132.49it/s][A
Generating channels:  82%|███████████████████████████████████████████▋         | 56873/69040 [00:13<00:02, 4453.39it/s][A
Generating channels:  83%|████████████████████████████████████████████         | 57343/69040 [00:13<00:02, 4487.46it/s][A
Generating channels:  84%|████████████████████████████████████████████▎        | 57796/69040 [00:13<00:02, 4108.84it/s][A
Generating chann


BS-BS Channels



Reading ray-tracing: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<?, ?it/s][A

Generating channels: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 239.24it/s][A
100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [02:42<00:00, 81.07s/it]


Scenes 5–9 generation time: 82.88s


# 사용자 접근 데이터

In [17]:
user_data = dataset[0][0]['user']
print(user_data.keys())

dict_keys(['paths', 'LoS', 'location', 'distance', 'pathloss', 'channel'])


# 사용자 채널 정보 확인

In [18]:
# subcarries = 나눈 각각의 주파수 채널
# Channel = H <- 채널 벡터
# 채널 형태
# (user, UE antenna, Bs antenna, subcarrier)
channel = dataset[0][0]['user']['channel']
print(channel.shape)  

(69040, 1, 32, 64)


In [19]:
print(dataset[0][0]['user']['channel'][100])

[[[-4.9276509e-06+6.0179661e-07j -4.7681883e-06+1.3829425e-06j
   -4.4857170e-06+2.1285541e-06j ...  4.3711116e-06-2.4074948e-06j
    3.9297033e-06-3.0764131e-06j  3.3868639e-06-3.6660977e-06j]
  [-4.7825752e-06+1.6998159e-06j -4.4491662e-06+2.4436672e-06j
   -4.0009481e-06+3.1246111e-06j ...  3.8229477e-06-3.3767817e-06j
    3.2333687e-06-3.9455144e-06j  2.5602983e-06-4.4125736e-06j]
  [-4.3861578e-06+2.7614337e-06j -3.8878534e-06+3.4282084e-06j
   -3.2891933e-06+4.0066625e-06j ...  3.0543217e-06-4.2173128e-06j
    2.3400164e-06-4.6522073e-06j  1.5652473e-06-4.9671735e-06j]
  ...
  [-5.9705985e-06+1.8381670e-06j -5.5992200e-06+2.7702260e-06j
   -5.0834296e-06+3.6307945e-06j ...  4.8462462e-06-3.9304277e-06j
    4.1544481e-06-4.6554678e-06j  3.3555179e-06-5.2603964e-06j]
  [-5.3850690e-06+3.0986103e-06j -4.8194606e-06+3.9206407e-06j
   -4.1295657e-06+4.6415066e-06j ...  3.8328362e-06-4.8787911e-06j
    3.0023016e-06-5.4293314e-06j  2.0943648e-06-5.8398036e-06j]
  [-4.5357333e-06+4.1878

In [20]:
print(len(dataset[0][0]['user']['channel'][100]))

1


In [21]:
print(channel[10000][0][0])

[ 1.14151417e-05-4.36109121e-06j  1.13104134e-05-4.99916769e-06j
  1.11716581e-05-5.67118514e-06j  1.09873472e-05-6.37486755e-06j
  1.07464039e-05-7.10507175e-06j  1.04388646e-05-7.85394786e-06j
  1.00564839e-05-8.61123954e-06j  9.59326280e-06-9.36472679e-06j
  9.04586159e-06-1.01007672e-05j  8.41387282e-06-1.08049298e-05j
  7.69995040e-06-1.14626673e-05j  6.90977913e-06-1.20600052e-05j
  6.05189052e-06-1.25842089e-05j  5.13732675e-06-1.30243880e-05j
  4.17918318e-06-1.33720132e-05j  3.19203900e-06-1.36213148e-05j
  2.19131471e-06-1.37695433e-05j  1.19258925e-06-1.38170753e-05j
  2.10908027e-07-1.37673669e-05j -7.39877976e-07-1.36267463e-05j
 -1.64770893e-06-1.34040656e-05j -2.50281732e-06-1.31102279e-05j
 -3.29810814e-06-1.27576113e-05j -4.02939668e-06-1.23594264e-05j
 -4.69549514e-06-1.19290335e-05j -5.29813951e-06-1.14792647e-05j
 -5.84176632e-06-1.10217779e-05j -6.33314448e-06-1.05664840e-05j
 -6.78089191e-06-1.01210744e-05j -7.19489299e-06-9.69066787e-06j
 -7.58565420e-06-9.277610

In [22]:
print(channel[1][0][0])

[0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j
 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j
 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j
 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j
 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j
 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j
 0.+0.j 0.+0.j 0.+0.j 0.+0.j]


# 사용자 위치 정보

In [23]:
location = dataset[0][0]['user']['location']
print(location.shape)      # (사용자 수, 3)
print(location[0:4])         # 첫 번째 사용자의 (x, y, z)

(69040, 3)
[[-91.03330231 -15.57629967   1.        ]
 [-90.83329773 -15.57629967   1.        ]
 [-90.63330078 -15.57629967   1.        ]
 [-90.4332962  -15.57629967   1.        ]]


# 경로정보

In [24]:
paths = dataset[0][0]['user']['paths']
#사용자 수
print(len(paths))
# 첫 번째 사용자 경로 정보
print(paths[0])

69040
{'num_paths': 0, 'DoD_phi': [], 'DoD_theta': [], 'DoA_phi': [], 'DoA_theta': [], 'phase': [], 'ToA': [], 'power': [], 'LoS': []}


# 기지국 정보

In [25]:
bs_data = dataset[0][0]['basestation']
print(bs_data.keys())


dict_keys(['paths', 'LoS', 'location', 'distance', 'pathloss', 'channel'])


# Scene 및 사용자 수

In [26]:
for i, scene in enumerate(dataset[0]):
    user_locs = scene['user']['location']
    print(f"Scene {i}: {len(user_locs)} users")

Scene 0: 69040 users


# 채널 수

In [27]:
len(dataset[0][0]['user']['channel'])

69040

In [28]:
print(dataset[0][0]['user']['paths'][0])

{'num_paths': 0, 'DoD_phi': [], 'DoD_theta': [], 'DoA_phi': [], 'DoA_theta': [], 'phase': [], 'ToA': [], 'power': [], 'LoS': []}


In [29]:
scene = dataset[0][0] # scene 0
ue_idx = 0 # 첫 번째 사용자
channel = scene['user']['channel'][ue_idx]
print(channel.shape)

(1, 32, 64)


# channel CIR mat 정보 가져오기

In [30]:
import scipy.io as sio

file_path = r'C:\Users\dlghd\졸업프로젝트\LWM\O2_dyn_3p5\scene_0\O2_dyn_3p5.1.CIR.mat'
mat_data = sio.loadmat(file_path)

# 파일 안의 key 확인
print(mat_data.keys())




dict_keys(['__header__', '__version__', '__globals__', 'CIR_array_full'])


In [31]:
# 일반적으로 CIR key는 'CIR' 또는 'cir' 같은 이름일 가능성 높음
H_cir = mat_data['__header__']  
print(H_cir)

b'MATLAB 5.0 MAT-file, Platform: PCWIN64, Created on: Wed Jun 30 11:33:01 2021'


# Time-Prediction 시작
## Time Series 형태로 변환
### 단일사용자 채널 예측

In [32]:
# print(dataset[0][0]['user']['channel'][150][0][3])

count = 0
for h in dataset[0][0]['user']['channel'][100][0]:
#     h = h.squeeze(0)
    h_real = h.real
    h_imag = h.imag
    if np.sum(np.abs(h_real)) ==0:
        count+=1
    elif np.sum(np.abs(h_imag)) == 0:
        count+=1

print("0이 존재하는 채널 개수",count)

0이 존재하는 채널 개수 0


In [33]:
import numpy as np

# 1) (user, ue_port, bs_ant, subc) → (bs_ant, subc) 로 squeeze
H = dataset[0][0]['user']['channel'][100, 0]   # shape: (32, 64), complex

# 2) BS 안테나 인덱스 3의 서브캐리어 벡터 (64,)
print("Antenna #3 subcarriers:", H[3])

# 3) 전체 서브캐리어(32×64) 중 값이 정확히 0인 요소 개수
zero_elements = np.sum(H == 0)
print("0+0j인 서브캐리어 개수:", zero_elements)

# 4) 서브캐리어 전부가 0인 안테나 포트(행) 개수
zero_ports = np.sum(np.all(H == 0, axis=1))
print("완전 0+0j 안테나 포트 개수:", zero_ports)

# 5) 만약 “값이 하나도 0이 아닌” 서브캐리어 요소 개수를 보고 싶다면
nonzero_elements = np.sum(np.abs(H) > 0)
print("0이 아닌 서브캐리어 개수:", nonzero_elements)


Antenna #3 subcarriers: [-3.7481711e-06+3.7276918e-06j -3.1033380e-06+4.2799297e-06j
 -2.3783671e-06+4.7218546e-06j -1.5919456e-06+5.0420513e-06j
 -7.6434952e-07+5.2322434e-06j  8.3081005e-08+5.2875071e-06j
  9.2849081e-07+5.2063970e-06j  1.7500737e-06+4.9909859e-06j
  2.5266352e-06+4.6468099e-06j  3.2381383e-06+4.1827284e-06j
  3.8662224e-06+3.6106942e-06j  4.3946743e-06+2.9454461e-06j
  4.8098500e-06+2.2041299e-06j  5.1010238e-06+1.4058554e-06j
  5.2606679e-06+5.7120445e-07j  5.2846453e-06-2.7829972e-07j
  5.1723187e-06-1.1207479e-06j  4.9265650e-06-1.9344095e-06j
  4.5537045e-06-2.6982937e-06j  4.0633354e-06-3.3926899e-06j
  3.4680893e-06-3.9996776e-06j  2.7833046e-06-4.5035881e-06j
  2.0266314e-06-4.8914089e-06j  1.2175759e-06-5.1531201e-06j
  3.7699922e-07-5.2819537e-06j -4.7342218e-07-5.2745672e-06j
 -1.3117545e-06-5.1311317e-06j -2.1163728e-06-4.8553270e-06j
 -2.8665186e-06-4.4542485e-06j -3.5428352e-06-3.9382230e-06j
 -4.1278677e-06-3.3205442e-06j -4.6065129e-06-2.6171292e-06j


## 결측치 제거 및 dataload

In [34]:
from torch.utils.data import Dataset, DataLoader

class ChannelDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self,idx):
        scene = self.dataset[idx]
        H = scene[0]['user']['channel'].squeeze(1)
        
        mask_valid_sc = ~np.all(H == 0+0j, axis=(0,1))
        
        H = H[:,:,mask_valid_sc]
        
        H_real = H.real
        H_imag = H.imag
        H_concat  = np.concatenate([H_real, H_imag], 1)
        H_sc_first = H_concat.transpose(0,2,1)
        
        return torch.from_numpy(H_sc_first.astype(np.float32))   
            
            

In [35]:
# ─────────────────────────────────────────────
# ❶ IterableDataset: 모든 유저·서브캐리어를 스트리밍
import torch
from torch.utils.data import IterableDataset, DataLoader
import numpy as np

class ChannelSeqDataset(IterableDataset):
    """
    • seq_len 개의 과거 채널 벡터(real 64 + imag 64 → 128) → 다음 시점 벡터 예측
    • 벡터는 평균전력 1 로 power‑normalize 후 반환
    """
    def __init__(self, scenes, seq_len: int = 5, eps: float = 1e-9):
        super().__init__()
        self.scenes   = scenes
        self.seq_len  = seq_len
        self.eps      = eps                        # 0 division 방지용
        ch0           = scenes[0][0]['user']['channel']
        self.U        = ch0.shape[0]               # 사용자 수
        self.A        = ch0.shape[2]               # 안테나 32
        self.S        = ch0.shape[3]               # 서브캐리어 64
        self.vec_len  = 2 * self.A                 # 64

    def _vec(self, scene, u: int, sc: int) -> torch.Tensor:
        """(32,) complex → (64,) float32  +  power norm"""
        h = scene[0]['user']['channel'][u, 0, :, sc]          # (32,)
        v = np.concatenate([h.real, h.imag]).astype(np.float32)
        p = np.mean(v * v) + self.eps                         # 평균 전력
        v /= np.sqrt(p)                                       # 정규화
        return torch.from_numpy(v)                            # (64,)

    def __iter__(self):
        T = len(self.scenes)
        for t in range(self.seq_len, T):                      # 타깃 시점
            past_scenes = self.scenes[t - self.seq_len : t]
            tgt_scene   = self.scenes[t]
            for u in range(self.U):
                for s in range(self.S):
                    seq = torch.stack([self._vec(ps, u, s) for ps in past_scenes])
                    if not torch.any(seq):                    # 전부 0 이면 skip
                        continue
                    target     = self._vec(tgt_scene, u, s)
                    masked_pos = torch.tensor([self.seq_len - 2], dtype=torch.long)
                    yield seq, masked_pos, target             # shapes: (5,64) / (1,) / (64,)
# ─────────────────────────────────────────────
# ❷ 학습·검증 DataLoader
seq_len      = 5
split_ratio  = 0.8
split_idx    = int(len(dataset) * split_ratio)

train_ds = ChannelSeqDataset(dataset[:split_idx], seq_len=seq_len)
val_ds   = ChannelSeqDataset(dataset[split_idx:], seq_len=seq_len)

batch_size   = 32
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=False)
val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False)
# ─────────────────────────────────────────────


# 아래 코드 구조
┌──────────────────────────────────────────────────────────────┐
│ input_ids  (B, seq_len, element_length)  ─┐                 │
│ masked_pos (B, num_mask)                  ├─>  LWM backbone │
│                                           │    (12-층 트랜스포머)  
└────────────────────────────────────────────┘         │
            logits_lm  (B, num_mask, element_length)  │   enc_output (B, seq_len, d_model)
                                                      ▼
                        ┌─[풀링]───────────────┐      ←── feat (B, d_model)
                        │ 첫 토큰(0번) 선택    │
                        │   or 평균/최대 풀링 │
                        └──────────────────────┘
                                      ▼
                       FC 헤드  (d_model → hidden_dim → out_dim)
                                      ▼
                                out (B, out_dim)

# 시각적비유

[패치 프로젝터]──▶[Transformer ×12]──▶[LayerNorm]──┐
                                                  ├─▶ 64-차 벡터 (CLS 또는 풀링) ─▶ MLP ─▶ out                                                
[Positional Embedding]─────────────────────────────┘


In [36]:
"""
LWMWithHead: 사전학습된 LWM(Transformer encoder)을 ‘백본(backbone)’으로 사용하고,
그 뒤에 새로운 완전연결(FC) 헤드(head)를 붙여 다운스트림 작업(회귀·분류 등)에
사용할 수 있도록 만든 래퍼(wrapper) 클래스입니다.
"""

import torch
import torch.nn as nn
from lwm_model import lwm   # 기존 LWM 모델 클래스 (import 경로는 프로젝트 구조에 맞게 조정)

class LWMWithHead(nn.Module):
    """
    Args
    ----
    element_length : int
        LWM 입력 패치의 길이. 예) 64*64 = 4096 (H_real + H_imag)
    d_model        : int
        Transformer 모델 차원(=LWM hidden size).
    max_len        : int
        포지셔널 임베딩 최대 길이(시퀀스 길이).
    n_layers       : int
        Transformer 인코더 층 수.
    hidden_dim     : int
        새 FC 헤드의 중간 차원.
    out_dim        : int
        최종 출력 차원. 1 → 회귀/이진분류, k → k-클래스 분류.
    freeze_backbone: bool
        True면 백본을 동결(freeze)하여 헤드만 학습.
    ckpt_path      : str | None
        사전학습 가중치(.pth) 경로. None이면 랜덤 초기화.
    device         : str
        'cuda' / 'cpu' 등 모델을 올릴 장치.
    """

    def __init__(
        self,
        element_length: int,
        d_model: int = 64,
        max_len: int = 129,
        n_layers: int = 12,
        hidden_dim: int = 256,
        out_dim: int = 64, 
        freeze_backbone: bool = False,
        ckpt_path: str | None = None,
        device: str = "cuda",
    ):
        super().__init__()

        # ────────────────────────────
        # 1) 백본(backbone) 초기화
        # ────────────────────────────
        if ckpt_path is None:
            # 가중치 없이 새로 생성
            self.backbone = lwm(
                element_length=element_length,
                d_model=d_model,
                max_len=max_len,
                n_layers=n_layers
            ).to(device)
        else:
            # 사전학습 가중치 로드
            self.backbone = lwm.from_pretrained(
                ckpt_name=ckpt_path,
                device=device
            )

        # 백본 동결(선택)
        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

        # ────────────────────────────
        # 2) 헤드(head) 정의
        # ────────────────────────────
        self.head = nn.Sequential(
            nn.Linear(d_model, 64),  # 첫 FC
            nn.ReLU(),                       # 활성화
            nn.Linear(64, out_dim)   # 최종 FC
        )

    # ────────────────────────────
    # forward
    # ────────────────────────────
    def forward(self, input_ids, masked_pos):
        """
        Parameters
        ----------
        input_ids : Tensor  (B, seq_len, element_length)
            LWM 입력 시퀀스 (패치/토큰 단위 실수·복소수 채널값 등).
        masked_pos : Tensor  (B, num_mask)
            LWM의 마스크드 채널 모델링용 인덱스 (백본 규격 유지용).

        Returns
        -------
        out : Tensor  (B, out_dim)
            헤드에서 계산된 다운스트림 작업용 로짓/예측값.
        """

        # 기존 LWM forward:
        #   logits_lm : (B, num_mask, element_length)  ← 사용 안 함
        #   enc_output: (B, seq_len, d_model)
        _, enc_output = self.backbone(input_ids, masked_pos)

        # 특징 추출(feat)
        # ① 첫 토큰 벡터 사용 (CLS 개념) ─────────────
        feat = enc_output[:, 0, :]           # (B, d_model)

        # ② 평균 풀링 예시 (필요 시 주석 해제) ─────
        # feat = enc_output.mean(dim=1)       # (B, d_model)

        # ③ Max 풀링 예시 (필요 시 주석 해제) ─────
        # feat, _ = enc_output.max(dim=1)     # (B, d_model)

        # 헤드 통과 → 최종 출력
        out = self.head(feat)                # (B, out_dim)
        return out


In [37]:
print(train_ds.U)

69040


## input_size
input_size = (scene - seq_len) * U * S -> (10-5)+69040*64 = 22092800
batch_size = 32
배치 수 = input_size / batch_size = 690400배치

In [38]:
import torch
import torch.nn as nn
from torch.optim import Adam

# 디바이스 설정(GPU/CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


# 모델 초기화
seq_len = 5
element_length=train_ds.A*2



model = LWMWithHead(
    element_length=element_length,  # 예: 64
    d_model=64,
    max_len=seq_len,                # 예: 5
    n_layers=12,
    hidden_dim=256,
    out_dim=element_length,         # 예: 64
    freeze_backbone=False,  
    ckpt_path=None,  
    device=device
).to(device)

# 손실함수
criterion = nn.MSELoss()

# 옵티마이저 설정
optimizer = Adam(model.parameters(), lr=1e-4)

Using device: cuda


In [39]:
# import time
# import sys

# #  ❶ quick sanity checks
# #    If train_loader is empty, this will print 0 or raise
# try:
#     n_batches = len(train_loader)
# except TypeError:
#     # IterableDataset → no __len__
#     n_batches = sum(1 for _ in train_loader)
# print(f"→ training batches: {n_batches}")

# # ❷ actual training loop with per-epoch print
# start = time.time()
# num_epochs = 10

# for epoch in range(num_epochs):
#     model.train()
#     running_loss = 0.0

#     # loop over batches
#     for i, (input_ids, masked_pos, target) in enumerate(train_loader, 1):
#         input_ids   = input_ids.to(device)
#         masked_pos  = masked_pos.to(device)
#         target      = target.to(device)

#         optimizer.zero_grad()
#         pred = model(input_ids, masked_pos).squeeze(-1)
#         loss = criterion(pred, target)
#         loss.backward()
#         optimizer.step()

#         running_loss += loss.item()

#     # ❸ now we know at least one epoch happened
#     avg_loss = running_loss / i if i>0 else float('nan')
#     print(f"Epoch {epoch+1}/{num_epochs},  avg loss: {avg_loss:.6f}", flush=True)

# end = time.time()
# print(f"Total training time: {end - start:.2f}s")


In [None]:
from tqdm import tqdm
import time

num_epochs = 10
start = time.time()
for epoch in range(1, num_epochs+1):
    model.train()
    running_loss = 0.0

    # tqdm 으로 배치 단위 진행바 감싸기
    for batch_idx, (input_ids, masked_pos, target) in enumerate(
            tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}"), 1):

        input_ids, masked_pos, target = (
            input_ids.to(device),
            masked_pos.to(device),
            target.to(device),
        )
        optimizer.zero_grad()
        pred = model(input_ids, masked_pos).squeeze(-1)
        loss = criterion(pred, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        # (선택) 100 배치마다 로그
        if batch_idx % 100 == 0:
            tqdm.write(f"  [Batch {batch_idx}]  loss: {running_loss/batch_idx:.6f}")

    avg_loss = running_loss / batch_idx
    print(f"→ Epoch {epoch}/{num_epochs} done, avg_loss: {avg_loss:.6f}")

print(f"Total training time: {time.time() - start:.2f}s")


Epoch 1/10: 102it [00:06, 20.57it/s]

  [Batch 100]  loss: 0.006478


Epoch 1/10: 203it [00:11, 19.76it/s]

  [Batch 200]  loss: 0.011726


Epoch 1/10: 296it [00:16, 17.30it/s]

In [None]:
print(input_ids.shape)
print(masked_pos.shape)
print(target.shape)
print(pred.shape)

In [None]:
# 모델 평가 방법
import torch.nn.functional as F

def rmse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    RMSE = Root MSE
    RMSE = {1/n*sum((y^-y)**2)}**1/2
    """
    return torch.sqrt(F.mse_loss(pred, target, reduction="mean"))   # √MSE

def nmse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Normalized MSE  =  E[‖ŷ − y‖²] / E[‖y‖²]
    returns: 스칼라 (배치 평균)
    """
    # (B, …) → (B,)  : 각 샘플별 제곱합
    mse_per_sample   = ((pred - target)**2).view(pred.size(0), -1).sum(dim=1)
    power_per_sample = (target**2).view(target.size(0), -1).sum(dim=1)
    return (mse_per_sample / power_per_sample).mean()


In [None]:
# 모델 평가 함수
import torch
import torch.nn.functional as F

# ─────────────────────────────────────────
# 1. 배치 단위 RMSE, NMSE 함수
# ─────────────────────────────────────────
def rmse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Root-Mean-Squared Error
    returns: 스칼라 (배치 평균)
    """
    return torch.sqrt(F.mse_loss(pred, target, reduction="mean"))   # √MSE

def nmse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Normalized MSE  =  E[‖ŷ − y‖²] / E[‖y‖²]
      · 채널 예측 분야에서 흔히 쓰는 지표
    returns: 스칼라 (배치 평균)
    """
    # (B, …) → (B,)  : 각 샘플별 제곱합
    mse_per_sample   = ((pred - target)**2).view(pred.size(0), -1).sum(dim=1)
    power_per_sample = (target**2).view(target.size(0), -1).sum(dim=1)
    return (mse_per_sample / power_per_sample).mean()


# ─────────────────────────────────────────
# 2. 검증 루프 예시
# ─────────────────────────────────────────
def evaluate(model, loader, device="cuda"):
    model.eval()
    total_rmse, total_nmse = 0.0, 0.0

    with torch.no_grad():
        for input_ids, masked_pos, target in loader:
            input_ids  = input_ids.to(device)
            masked_pos = masked_pos.to(device)
            target     = target.to(device)

            pred = model(input_ids, masked_pos)

            total_rmse += rmse(pred, target).item() * input_ids.size(0)
            total_nmse += nmse(pred, target).item() * input_ids.size(0)

    N = len(loader.dataset)
    return {
        "RMSE": total_rmse / N,
        "NMSE": total_nmse / N
    }


In [None]:
num_epochs = 30
patience_counter = 0

for epoch in range(1, num_epochs + 1):
    # ──────────── ① 학습 ────────────
    model.train()
    for input_ids, masked_pos, target in train_loader:
        input_ids  = input_ids.to(device)
        masked_pos = masked_pos.to(device)
        target     = target.to(device)

        pred = model(input_ids, masked_pos)
        loss = criterion(pred, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # ──────────── ② 검증 ────────────
    metrics = evaluate(model, val_loader, device)
    rmse_val, nmse_val = metrics["RMSE"], metrics["NMSE"]

    # ──────────── ③ 로그 출력 ────────────
    print(f"[Epoch {epoch:02d}]  "
          f"val_RMSE = {rmse_val:.4f}   "
          f"val_NMSE = {nmse_val:.4e}   "
          f"val_NMSE(dB) = {10*torch.log10(torch.tensor(nmse_val)):.2f} dB")
