In [None]:
import sys
sys.path.append(".")
from lib.sdwpf import SDWPFDataset
from lib.windff import WindFFModelManager, WindFFModelConfig

import torch
import logging
logging.basicConfig(level=logging.INFO)

import easydict as edict
# Experiment hyper-parameters, exposed with attribute access via EasyDict
# (e.g. `args.hidden_dim`).
# NOTE(review): the `6 * ...` factors suggest 10-minute samples (6 per hour),
# i.e. a 24 h input window and a 2 h forecast horizon — confirm against the
# SDWPF sampling interval.
args = {
    'input_win_sz': 6 * 24,
    'output_win_sz': 6 * 2,
    'hidden_dim': 16,
    'adj_weight_threshold': 0.8,
    'epochs': 50,
    'early_stop': True,
    'patience': 5,
    'batch_size': 32,
    'lr': 1e-3,
}
# Derive the hidden window from the input/output windows rather than repeating
# their literals, so changing either window size keeps this value in sync.
args['hidden_win_sz'] = (args['input_win_sz'] + args['output_win_sz']) // 2
args = edict.EasyDict(args)

# Load the SDWPF dataset; its feature and target dimensions are passed to the
# model configuration below.
dataset = SDWPFDataset()

# Model configuration: I/O sizes come from the dataset, window sizes and the
# hidden width come from `args`. All computation is requested in float64.
model_config = WindFFModelConfig(
    feat_dim=dataset.get_feat_dim(),
    target_dim=dataset.get_target_dim(),
    hidden_dim=args.hidden_dim,
    input_win_sz=args.input_win_sz,
    output_win_sz=args.output_win_sz,
    hidden_win_sz=args.hidden_win_sz,
    dtype=torch.float64,
)
model_manager = WindFFModelManager(model_config)

# Train on the first dataset split only. `val_ratio=0.2` presumably holds out
# 20% of that split for validation / early stopping — confirm in WindFF.
train_splits = [dataset[0]]
train_config = model_manager.TrainConfig(
    adj_weight_threshold=args.adj_weight_threshold,
    epochs=args.epochs,
    early_stop=args.early_stop,
    patience=args.patience,
    lr=args.lr,
    batch_sz=args.batch_size,
    val_ratio=0.2,
)
model_manager.train(train_splits, train_config)

# Evaluate the trained model on the second split (not used during training).
# The f-string is intentional: it formats the loss via `format()`, which for a
# 0-dim tensor would log the plain number rather than `tensor(...)`.
loss_1 = model_manager.evaluate(dataset[1])
logging.info(f"Validation loss 1: {loss_1}")

# Evaluation on the third split, currently disabled:
# loss_2 = model_manager.evaluate(dataset[2])
# logging.info(f"Validation loss 2: {loss_2}")