# DR5 Prediction

## Dataset

In [None]:
from utils import make_train_val_datasets, BatteryRDRFullDataset, BatteryRDRFullDataset_SF, format_time
import time
import numpy as np

car_type = "P"
folder_dir = rf"data\{car_type}\train"

# Choose the dataset variant once so the stats filename below stays in sync with it.
# BatteryRDRFullDataset uses all features; BatteryRDRFullDataset_SF uses selected features.
dataset_class = BatteryRDRFullDataset
# SF stats are written to "global_stats_SF.npz" (the file the Test cell loads);
# full-feature stats keep the original "global_stats.npz" name.
stats_suffix = "_SF" if dataset_class is BatteryRDRFullDataset_SF else ""

start_data = time.time()
train_dataset, val_dataset, global_stats = make_train_val_datasets(
    folder_dir,
    val_ratio=0.2,
    window_size=30,
    step_size=10,
    normalize=True,
    dataset_class=dataset_class,
    num_workers=4
)

data_time = time.time() - start_data
print(f"time (data): {format_time(data_time)}")

# Persist the normalization stats returned by make_train_val_datasets so later
# cells can normalize other data with the same mean/std.
np.savez(
    f"global_stats{stats_suffix}.npz",
    lumped_mean=global_stats['lumped_mean'],
    lumped_std=global_stats['lumped_std'],
    point_mean=global_stats['point_mean'],
    point_std=global_stats['point_std']
)

print('train_dataset length:', len(train_dataset))
print('val_dataset length:', len(val_dataset))

# Fetch one sample once instead of re-indexing (and re-loading) it three times.
first_sample = train_dataset[0]
print('Label shape:', first_sample['label_future'].shape)

# Assumes 'lumped' is (time, features) and 'point' is (features,) — TODO confirm.
lumped_size = first_sample['lumped'].shape[1]
point_size = first_sample['point'].shape[0]
print('lumped_size: ', lumped_size)
print('point_size: ', point_size)


📁 Train files: 7, Val files: 2
✅ Loaded 71 samples from 7 files (parallel 4 workers)
✅ Computed global mean/std for lumped & point
✅ Loaded 53 samples from 2 files (parallel 4 workers)
time (data): 00:00:05.52
train_dataset length: 71
val_dataset length: 53
Label shape: ()
lumped_size:  17
point_size:  196


## Model Training

In [None]:
from utils import MidFusionMMR_Transformer, train_multimodal_regressor, format_time
import time
import torch
import os

save_dir = 'model_MFT256'
os.makedirs(save_dir, exist_ok=True)
# Fall back to CPU when no GPU is available; this value is passed to the trainer below
# (previously 'cuda' was hardcoded, which fails on CPU-only machines).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Input widths lumped_size / point_size come from the Dataset cell above.
model = MidFusionMMR_Transformer(
    lumped_embedding_size=lumped_size,
    point_embedding_size=point_size,
    hidden_size=256,
)

# Time only the training call itself.
start_train = time.time()
train_losses, val_losses = train_multimodal_regressor(
    model, save_dir,
    train_dataset, val_dataset,
    batch_size=64, epochs=100, lr=1e-3,
    device=device, patience=20, min_lr=1e-5,
    grad_clip=1.0, warmup_epochs=5
)
train_time = time.time() - start_train
print(f"time (training): {format_time(train_time)}")


## Test

In [None]:
import os
import time
import torch
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
# MidFusionMMR_Transformer is imported here so this cell runs on a fresh kernel.
from utils import make_test_datasets_car_folder, BatteryRDRFullDataset, BatteryRDRFullDataset_SF, test_seg, MidFusionMMR_Transformer

save_dir = "test_seg_Fusion_SM"
car_type = "P"
# NOTE(review): the training cells read rf"data\{car_type}\train"; confirm the
# extra "\P" directory level in this test path is intended.
file_path = rf"data\{car_type}\P\test"

# Selected-feature (SF) normalization stats. Indexing an NpzFile materializes the
# arrays, so the archive can be closed as soon as they are copied out.
with np.load('global_stats_SF.npz') as stats:
    global_stats = {
        'lumped_mean': stats['lumped_mean'],
        'lumped_std': stats['lumped_std'],
        'point_mean': stats['point_mean'],
        'point_std': stats['point_std']
    }

start_data = time.time()
test_datasets = make_test_datasets_car_folder(
    # Use this cell's test path (previously the stale training `folder_dir` leaked in).
    file_path, global_stats, car_type,
    window_size=30, step_size=10, normalize=True,
    # SF features, to match the SF stats above and the 15/26 model input sizes below.
    dataset_class=BatteryRDRFullDataset_SF,
    num_workers=4
)
test_loader = DataLoader(test_datasets, batch_size=1, shuffle=False)
data_time = time.time() - start_data
print(f"time (data): {data_time:.4f}s")


model_name = 'MFT256'
# 15 lumped / 26 point features: the SF dataset widths (see the
# Transfer Learning Dataset output).
model = MidFusionMMR_Transformer(
    lumped_embedding_size=15,
    point_embedding_size=26,
    hidden_size=256
)
start_infer = time.time()
test_seg(model_name, model, test_loader, save_dir)
infer_time = time.time() - start_infer
print(f"time (infer): {infer_time:.4f}s")


## Transfer Learning

### Dataset

In [None]:
from utils import make_train_val_datasets, BatteryRDRFullDataset_SF, format_time
import time
import numpy as np
import os
import torch

car_type = "Q"      # or "R"
folder_dir = rf"data\{car_type}\train"

start_data = time.time()
# NOTE(review): car_type is passed positionally here but not in the source-domain
# Dataset cell's call to the same function; confirm the make_train_val_datasets
# signature expects it in this position.
train_dataset, val_dataset, global_stats = make_train_val_datasets(
    folder_dir,
    car_type,
    val_ratio=0.2,
    window_size=30,
    step_size=10,
    normalize=True,
    dataset_class=BatteryRDRFullDataset_SF,
    num_workers=4
)

data_time = time.time() - start_data
print(f"time (data): {format_time(data_time)}")

# The filename includes car_type, so stats for different target car types are kept separate.
np.savez(
    f"global_stats_SF_{car_type}.npz",
    lumped_mean=global_stats['lumped_mean'],
    lumped_std=global_stats['lumped_std'],
    point_mean=global_stats['point_mean'],
    point_std=global_stats['point_std']
)

print('train_dataset length:', len(train_dataset))
print('val_dataset length:', len(val_dataset))

# Fetch one sample once instead of re-indexing (and re-loading) it three times.
first_sample = train_dataset[0]
print('Label shape:', first_sample['label_future'].shape)

# Assumes 'lumped' is (time, features) and 'point' is (features,) — TODO confirm.
lumped_size = first_sample['lumped'].shape[1]
point_size = first_sample['point'].shape[0]
print('lumped_size: ', lumped_size)
print('point_size: ', point_size)

📁 Train files: 4192, Val files: 1049
✅ Loaded 130329 samples from 4192 files (parallel 4 workers)
✅ Computed global mean/std for lumped & point
✅ Loaded 33395 samples from 1049 files (parallel 4 workers)
数据获取时间：00:01:46.20
数据获取时间：106.1976s
train_dataset length: 130329
val_dataset length: 33395
Label shape: ()
lumped_size:  15
point_size:  26


### MFT256

In [None]:
# ================ Module groups for MFT model =================
# Each group names the submodules that a fine-tuning stage can freeze together.
TEMPORAL_ENCODER = ["rnn", "norm_rnn"]
POINT_ENCODER = ["point_enc"]
CROSS_ATTENTION = ["cross_attn", "norm_attn"]
MID_FUSION = ["fuse_gate", "norm_fused"]
TRANSFORMER = ["transformer"]
HEAD = ["output"]


def _freeze(*groups):
    """Return a new list with the module names of every group, in order."""
    frozen = []
    for group in groups:
        frozen.extend(group)
    return frozen


# ================= Fine-tuning strategies for MFT model =================
# strategy name -> list of stages; each stage is the list of module names to freeze.
# A single stage is one fine-tuning run; several stages are run progressively.
FINE_TUNE_STRATEGIES_MFT = {
    # Strategy 1: Linear Probing (every group except HEAD is frozen)
    "LP": [_freeze(TEMPORAL_ENCODER, POINT_ENCODER, CROSS_ATTENTION, MID_FUSION, TRANSFORMER)],
    # Strategy 2: Fusion-level fine-tuning (cross-attention, mid-fusion and head train)
    "CM": [_freeze(TEMPORAL_ENCODER, POINT_ENCODER, TRANSFORMER)],
    # Strategy 3: Non-encoder fine-tuning (only the two encoders stay frozen)
    "NE": [_freeze(TEMPORAL_ENCODER, POINT_ENCODER)],
    # Strategy 4: Full fine-tuning (nothing frozen)
    "FF": [_freeze()],
    # Strategy 5: Progressive fine-tuning, unfreezing more modules at each stage
    "PG": [
        # Stage 1: heads only
        _freeze(TEMPORAL_ENCODER, POINT_ENCODER, CROSS_ATTENTION, MID_FUSION, TRANSFORMER),
        # Stage 2: + mid-fusion & cross-attention
        _freeze(TEMPORAL_ENCODER, POINT_ENCODER, TRANSFORMER),
        # Stage 3: + transformer
        _freeze(TEMPORAL_ENCODER, POINT_ENCODER),
        # Stage 4: full
        _freeze(),
    ],
}

In [None]:
from utils import fine_tune, MidFusionMMR_Transformer
import os
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Every strategy starts from this checkpoint in the source-domain training
# save_dir ('model_MFT256'). Loop-invariant, so it is defined once here.
pretrained_path = "model_MFT256/final_model.pth"

for strategy_name, freeze_plan in FINE_TUNE_STRATEGIES_MFT.items():

    print("\n==============================")
    print(f"🚀 Fine-tuning Strategy: {strategy_name}")
    print("==============================")

    save_root = os.path.join(f"model_TL_MFT256_{car_type}", strategy_name)
    os.makedirs(save_root, exist_ok=True)

    # A fresh model per strategy, sized with lumped_size / point_size from the
    # Transfer Learning Dataset cell.
    model = MidFusionMMR_Transformer(
        lumped_embedding_size=lumped_size,
        point_embedding_size=point_size,
        hidden_size=256,
    ).to(device)

    # ---------- Single-stage strategies ----------
    if len(freeze_plan) == 1:

        print(f"---- Strategy {strategy_name} | Single-stage ----")
        print(f"Frozen modules: {freeze_plan[0]}")

        fine_tune(
            pretrained_path=pretrained_path,
            save_dir=save_root,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            model=model,
            freeze_layers_list=freeze_plan[0],
            device=device,
        )

    # ---------- Multi-stage (progressive, PG) ----------
    else:
        # Each stage starts from the previous stage's checkpoint; the same model
        # object is passed to every stage.
        prev_ckpt = pretrained_path

        for stage_num, freeze_list in enumerate(freeze_plan, start=1):

            print(f"\n---- Strategy {strategy_name} | Stage {stage_num} ----")
            print(f"Frozen modules: {freeze_list}")

            stage_dir = os.path.join(save_root, f"S{stage_num}")
            os.makedirs(stage_dir, exist_ok=True)

            fine_tune(
                pretrained_path=prev_ckpt,
                save_dir=stage_dir,
                train_dataset=train_dataset,
                val_dataset=val_dataset,
                model=model,
                freeze_layers_list=freeze_list,
                device=device,
            )

            # NOTE(review): assumes fine_tune writes "best_model.pth" into save_dir
            # (the initial checkpoint is "final_model.pth") — confirm.
            prev_ckpt = os.path.join(stage_dir, "best_model.pth")
