In [1]:
import os
import numpy as np
import pandas as pd
import xarray as xr
import torch 
from tqdm import tqdm
from darts import TimeSeries
from darts.models import TFTModel

  __import__("pkg_resources").declare_namespace(__name__)  # type: ignore
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def check_data_folder(folder):
    return os.path.exists(folder) and os.path.isdir(folder)

def generate_date_range(start_date, end_date):
    """
    Generate a list of dates from start_date to end_date.
    """
    return pd.date_range(start=start_date, end=end_date, freq='D').strftime('%Y%m%d').tolist()

def load_data(file_path):
    """
    Load data from a NetCDF file.
    """
    if os.path.exists(file_path):
        return xr.open_dataset(file_path)
    else:
        raise FileNotFoundError(f'File not found: {file_path}')


# main program
## check data folder
data_folder = f'nc4'
if check_data_folder(data_folder):
    print(f'Data folder found: {data_folder}')
else:
    raise FileNotFoundError(f'Data folder not found: {data_folder}')

## load data
start_date = '2024-01-01'
end_date = '2024-01-14'
date_list = generate_date_range(start_date, end_date)

## get location and shape
path = os.path.join(data_folder, f'M2T1NXFLX.5.12.4%3AMERRA2_400.tavg1_2d_flx_Nx.{date_list[0]}.nc4.dap.nc4')
nc4_data = load_data(path)
lat = nc4_data['lat'].values
lon = nc4_data['lon'].values
shape = nc4_data['TLML'].shape
total_locations = shape[1] * shape[2]

## combine data
# 預先讀取第一個檔案以取得 shape 資訊
sample_path = os.path.join(data_folder, f'M2T1NXFLX.5.12.4%3AMERRA2_400.tavg1_2d_flx_Nx.{date_list[0]}.nc4.dap.nc4')
sample_data = load_data(sample_path)
shape_per_file = sample_data['TLML'].shape   # e.g. (24, 361, 576)
time_per_file = len(sample_data['time'])

# 預先配置 array（假設每天 24 筆資料）
total_samples = len(date_list)
combined = np.empty((total_samples * shape_per_file[0], *shape_per_file[1:]), dtype=np.float32)
time_list = np.empty(total_samples * time_per_file, dtype=sample_data['time'].dtype)

# 批次載入並填入預分配的 array
for i, date in enumerate(tqdm(date_list, desc="Combining")):
    path = os.path.join(data_folder, f'M2T1NXFLX.5.12.4%3AMERRA2_400.tavg1_2d_flx_Nx.{date}.nc4.dap.nc4')
    nc4_data = load_data(path)

    start = i * shape_per_file[0]
    end = (i + 1) * shape_per_file[0]

    combined[start:end] = nc4_data['TLML'].values
    time_list[start:end] = nc4_data['time'].values

print(f'Combined data shape: {combined.shape}')

Data folder found: nc4


Combining: 100%|██████████| 14/14 [00:02<00:00,  6.27it/s]

Combined data shape: (336, 361, 576)





In [None]:
SEED = 123
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# to Rcode y1 cell × time
ntot, nlat, nlon = combined.shape
ncell = nlat * nlon
y1 = combined.reshape(ntot, ncell).T

# to Rcode gg cell × (lon, lat)
lon_grid, lat_grid = np.meshgrid(lon, lat)
gg = np.vstack([lon_grid.ravel(), lat_grid.ravel()]).T

# 抽樣
ncell, ntot = y1.shape
T = 200
m = 1000
pickm = np.random.choice(ncell, size=m, replace=False)
y_sub   = y1[pickm, :]      # (m, ntot)
gg_sub  = gg[pickm, :]      # (m, 2)

# DatetimeIndex
time_index = pd.to_datetime(time_list)

df = pd.DataFrame(
    y_sub.T,
    index=time_index,
    columns=[f"cell_{i}" for i in range(m)]
)

# 經緯度 DataFrame
static_df = pd.DataFrame(
    gg_sub,
    index=df.columns,
    columns=["lon","lat"]
)

# to TimeSeries
series = TimeSeries.from_dataframe(df, static_covariates=static_df)
train  = series[:T]   # T 1-200
true_201 = df.iloc[T].values  # 201期真值，shape=(m,)

In [None]:
from pytorch_lightning.callbacks import TQDMProgressBar, EarlyStopping, ModelCheckpoint, Callback

class EpochPrinter(Callback):
    def __init__(self, every=5, also_last=True):
        super().__init__()
        self.every = every
        self.also_last = also_last
    def on_train_epoch_end(self, trainer, pl_module):
        e = trainer.current_epoch
        total = trainer.max_epochs
        if ((e + 1) % self.every != 0) and not (self.also_last and e + 1 == total):
            return
        m  = trainer.callback_metrics
        tr = float(m.get("train_loss_epoch", m.get("train_loss", float("nan"))))
        va = float(m.get("val_loss", float("nan")))
        print(f"[Epoch {e+1}/{total}] train_loss={tr:.4g} | val_loss={va:.4g}")

progress_callback = TQDMProgressBar(refresh_rate=1)  
epoch_printer     = EpochPrinter(every=5)

checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    save_top_k=1,
    mode="min",
    filename="tft-{epoch:02d}-{val_loss:.2f}",
    dirpath="checkpoint"
)

pl_trainer_kwargs = {
    # 裝置
    "accelerator":                     "cpu",        # "cpu" / "gpu"  
    "devices":                         1,          # CPU:1 / GPU:[0]       

    # 訓練長度
    "max_epochs":                      300,          # 最多跑幾個 epoch
    "min_epochs":                      10,           # 最少跑幾個 epoch

    # batch 限制（提升除錯速度）
    "limit_train_batches":             1.0,          # 比例 (0.0~1.0) 或整數
    "limit_val_batches":               1.0,          # 同上
    "val_check_interval":              1.0,          # 每多少 epoch (float) 或 batch (float<1)

    # 日誌與 sanity check
    "log_every_n_steps":               10,           # 每多少 batch 印一次 log
    "num_sanity_val_steps":            2,            # 開始前跑多少驗證 batch

    # 梯度與更新
    "gradient_clip_val":               0.1,          # 梯度裁剪值
    "gradient_clip_algorithm":         "norm",       # "norm" 或 "value"
    "accumulate_grad_batches":         1,            # 幾個 batch 累積一次更新

    # 檢查點、Callback、Logger
    "enable_checkpointing":            True,         # 啟用 ModelCheckpoint
    "callbacks":                       [              # Lightning Callback list
        progress_callback,
        epoch_printer,
        checkpoint_callback
    ],

    # 其他選項
    "reload_dataloaders_every_n_epochs":  0,         # 幾 epoch 後重載 DataLoader
    "fast_dev_run":                      False,      # 同時跑 train/val/test 各一 batch
    "benchmark":                         False,      # cuDNN 找最佳配置 (需和 deterministic 搭配)
    "profiler":                          "simple",   # "simple" / "advanced" / 自訂
    "enable_model_summary":              True        # 啟動時印模型摘要
}

# TFTModel
model = TFTModel(
    input_chunk_length  = 24,  # 模型每次觀察過去  期資料
    output_chunk_length = 1,   # 預測下 1 期
    hidden_size = 64,
    lstm_layers = 1,
    num_attention_heads = 4,
    dropout = 0.1,
    batch_size = 20,
    random_state = SEED,
    add_relative_index = True,
    add_encoders = {
      'cyclic': {'past': ['hour']},
      'position': {'past': ['relative'], 'future': ['relative']}
    },
    pl_trainer_kwargs = pl_trainer_kwargs
)


In [5]:
# train and predict
train = series[:160]     
val = series[160:200]   
model.fit( train, val_series = val)
forecast_201 = model.predict(n=1)      # shape=(1,5000)
pred_201 = forecast_201.values()[0]

# MSPE
mask = ~np.isnan(true_201)
mspe_py  = np.mean((pred_201[mask] - true_201[mask])**2)
rmspe_py = np.sqrt(mspe_py)
print(f"Python TFTModel Next-step MSPE = {mspe_py:.6f} | RMSPE = {rmspe_py:.6f}")

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..
`Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..
`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..
c:\Users\user\AppData\Local\Programs\Python\Python312\Lib\site-packages\pytorch_lightning\callbacks\model_checkpoint.py:658: Checkpoint directory C:\Users\user\Desktop\TFTModel-use\test\checkpoint exists and is not empty.

   | Name                              | Type                             | Params | Mode 
------------------------------------------------------------------------------------------------
0  | train_metrics                     | MetricCollection                 | 0      | train
1  | val_metrics                       | MetricCollection                 | 0      | train
2  | 

Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]



Epoch 4: 100%|██████████| 7/7 [00:38<00:00,  0.18it/s, train_loss=2.33e+3, val_loss=2.32e+3][Epoch 5/300] train_loss=2328 | val_loss=2324
Epoch 9: 100%|██████████| 7/7 [00:38<00:00,  0.18it/s, train_loss=2.31e+3, val_loss=2.31e+3][Epoch 10/300] train_loss=2313 | val_loss=2309
Epoch 14: 100%|██████████| 7/7 [00:38<00:00,  0.18it/s, train_loss=2.3e+3, val_loss=2.29e+3] [Epoch 15/300] train_loss=2298 | val_loss=2294
Epoch 19: 100%|██████████| 7/7 [00:38<00:00,  0.18it/s, train_loss=2.28e+3, val_loss=2.28e+3][Epoch 20/300] train_loss=2283 | val_loss=2279
Epoch 24: 100%|██████████| 7/7 [00:39<00:00,  0.18it/s, train_loss=2.27e+3, val_loss=2.26e+3][Epoch 25/300] train_loss=2268 | val_loss=2264
Epoch 29: 100%|██████████| 7/7 [00:37<00:00,  0.18it/s, train_loss=2.25e+3, val_loss=2.25e+3][Epoch 30/300] train_loss=2251 | val_loss=2248
Epoch 34: 100%|██████████| 7/7 [00:36<00:00,  0.19it/s, train_loss=2.23e+3, val_loss=2.23e+3][Epoch 35/300] train_loss=2235 | val_loss=2231
Epoch 39: 100%|████████

`Trainer.fit` stopped: `max_epochs=300` reached.


Epoch 299: 100%|██████████| 7/7 [00:39<00:00,  0.18it/s, train_loss=7.590, val_loss=12.40]


FIT Profiler Report

-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                               	|  Mean duration (s)	|  Num calls      	|  Total time (s) 	|  Percentage %   	|
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                                                                                                                	|  -     

Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00,  1.96it/s]


PREDICT Profiler Report

---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                            	|  Mean duration (s)	|  Num calls      	|  Total time (s) 	|  Percentage %   	|
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                                                                                                             	|  -              	|  50   

Python TFTModel Next-step MSPE = 21.028786 | RMSPE = 4.585715


In [6]:
ctx200 = series[:200]
forecast_201 = model.predict(n=1, series=ctx200)
pred_201 = forecast_201.values()[0]

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..
`Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..
`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00,  2.02it/s]


PREDICT Profiler Report

---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                            	|  Mean duration (s)	|  Num calls      	|  Total time (s) 	|  Percentage %   	|
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                                                                                                             	|  -              	|  50   