In [None]:
%matplotlib widget

import shutil
from pathlib import Path
from typing import Dict, Optional, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.metrics import mean_absolute_error, root_mean_squared_error
from tqdm import tqdm

import optuna

import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

from pytorch_forecasting import TimeSeriesDataSet
from pytorch_forecasting.metrics import RMSE, MAE
from pytorch_forecasting.models import TemporalFusionTransformer



In [None]:
def split_data_into_parts(data: pd.DataFrame, parts: int = 15) -> Dict[str, pd.DataFrame]:
    """Split ``data`` row-wise into ``parts`` consecutive, non-overlapping chunks.

    Chunks are keyed ``"1"`` .. ``str(parts)`` in row order. Every chunk has
    ``len(data) // parts`` rows, except the last one, which also receives the
    ``len(data) % parts`` leftover rows so that no row is dropped.

    Args:
        data: frame to split (split positionally via ``iloc``).
        parts: number of chunks to produce; must be >= 1.

    Returns:
        Dict mapping the 1-based chunk number (as a string) to its slice of ``data``.

    Raises:
        ValueError: if ``parts`` is smaller than 1.
    """
    if parts < 1:
        raise ValueError(f"parts must be >= 1, got {parts}")

    chunk_size = len(data) // parts
    # Start offsets of every chunk, plus len(data) as the end of the last chunk so the
    # remainder rows are kept instead of being silently discarded.
    bounds = [idx * chunk_size for idx in range(parts)] + [len(data)]
    return {f"{idx + 1}": data.iloc[bounds[idx]:bounds[idx + 1]] for idx in range(parts)}

def load_data(data_dir: str) -> Dict[str, pd.DataFrame]:
    """Load every ``df*.parquet`` file below ``data_dir`` and resample it to hourly means.

    The cell name is ``"C"`` + the file stem with ``"df_"`` removed
    (e.g. ``df_01.parquet`` -> ``"C01"``). For each file the timestamp, current,
    voltage, temperature, SOH_ZHU, Q_sum and EFC columns are kept, averaged into
    hourly bins, gaps are filled by linear interpolation, and SOH values that are
    still missing afterwards (linear interpolation leaves values before the first
    valid SOH as NaN) are set to 1.

    Args:
        data_dir: root directory, searched recursively.

    Returns:
        Dict mapping cell name -> hourly DataFrame, in sorted file-path order.
    """
    time_col = 'Absolute_Time[yyyy-mm-dd hh:mm:ss]'
    value_cols = ['Current[A]', 'Voltage[V]', 'Temperature[°C]', 'SOH_ZHU', 'Q_sum', 'EFC']
    data_path = Path(data_dir)
    all_data = {}

    # glob() order is filesystem-dependent; sorting keeps the dict order (and the seeded
    # cell split derived from it) reproducible across machines.
    parquet_files = sorted(data_path.glob("**/df*.parquet"))
    print(f"Found {len(parquet_files)} parquet files")

    for file_path in tqdm(parquet_files, desc="Processing cells", unit="cell"):
        # Cell number comes from the file name (not the directory): df_01.parquet -> C01
        cell_number = file_path.stem.replace('df_', '')
        cell_name = f'C{cell_number}'
        tqdm.write(f"Processing {cell_name} ...")

        data = pd.read_parquet(file_path)
        data[time_col] = pd.to_datetime(data[time_col])

        # Keep the relevant columns and average them into hourly bins
        data_hourly = (
            data[[time_col] + value_cols]
            .set_index(time_col)
            .resample('h')
            .mean()
            .reset_index()
        )

        # Fill hourly bins without measurements; NOTE(review): SOH 1 presumably means a
        # fresh cell before the first SOH estimate — confirm.
        data_hourly = data_hourly.interpolate(method='linear')
        data_hourly['SOH_ZHU'] = data_hourly['SOH_ZHU'].fillna(1)

        all_data[cell_name] = data_hourly

    return all_data
# Root of the preprocessed per-cell parquet files; load_data searches it recursively.
data_dir = "../01_Datenaufbereitung/Output/Calculated/"
all_data = load_data(data_dir=data_dir)

In [None]:
def visualize_data(all_data):
    """Show one histogram per signal, pooling the rows of every cell in ``all_data``.

    Args:
        all_data: mapping of cell name -> DataFrame containing the temperature,
            current, voltage, SOH_ZHU, Q_sum and EFC columns plotted below.
    """
    # (column, bar colour, subplot title, x-axis label)
    panels = [
        ('Temperature[°C]', 'blue', 'Temperature Data Distribution', 'Temperature (°C)'),
        ('Current[A]', 'orange', 'Current Data Distribution', 'Current (A)'),
        ('Voltage[V]', 'green', 'Voltage Data Distribution', 'Voltage (V)'),
        ('SOH_ZHU', 'purple', 'SOH Data Distribution', 'SOH'),
        ('Q_sum', 'red', 'Q_sum Data Distribution', 'Q_sum'),
        ('EFC', 'cyan', 'EFC Data Distribution', 'EFC'),
    ]

    # Pool every cell into one frame so each histogram covers the whole dataset
    pooled = pd.concat(list(all_data.values()))

    # One stacked subplot per panel, 3 inches tall each (6 panels -> a 10 x 18 inch figure)
    fig, axes = plt.subplots(len(panels), 1, figsize=(10, 3 * len(panels)))
    for ax, (column, color, title, xlabel) in zip(axes, panels):
        ax.hist(pooled[column], bins=50, color=color, alpha=0.7)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Frequency')

    plt.tight_layout()
    plt.show()
    
# visualize_data(all_data)

In [None]:
def inspect_data_ranges(data_dict: Dict[str, pd.DataFrame]) -> None:
    """Print the time span and per-column min/max/row count of every frame in ``data_dict``.

    Args:
        data_dict: mapping of name -> DataFrame holding the timestamp column and the
            SOH_ZHU, current, voltage, temperature, Q_sum and EFC columns.
    """
    time_col = 'Absolute_Time[yyyy-mm-dd hh:mm:ss]'
    value_cols = ('SOH_ZHU', 'Current[A]', 'Voltage[V]', 'Temperature[°C]', 'Q_sum', 'EFC')

    for name, frame in data_dict.items():
        print(f"\n=== {name} ===")

        stamps = frame[time_col]
        start, end = stamps.min(), stamps.max()
        print(f"Time Range: {start} to {end}")

        for col in value_cols:
            series = frame[col]
            lo, hi = series.min(), series.max()
            print(f"\n{col}:")
            print(f"Value Range: {lo:.4f} to {hi:.4f}")
            print(f"Number of Data Points: {len(series)}")

# View all data ranges
# print("All Data Ranges:")
# inspect_data_ranges(all_data)

In [None]:
def split_cell_data(all_data: Dict[str, pd.DataFrame], train=13, val=1, test=1, parts=15) -> Tuple[Dict, Dict, Dict]:
    """Assign whole cells to train/validation/test sets and cut each training cell into parts.

    Cells are assigned by a seeded (773) shuffle of the *sorted* cell names, so the
    split depends only on which cells exist, not on the dict's insertion order, and is
    reproducible. Each training cell is divided into ``parts`` consecutive chunks via
    ``split_data_into_parts`` and stored under ``"<cell>_<part>"``. Validation and
    test cells are kept whole.

    Args:
        all_data: mapping of cell name -> DataFrame.
        train: number of cells for the training set.
        val: number of cells for the validation set.
        test: number of cells for the test set.
        parts: number of chunks each training cell is split into.

    Returns:
        ``(train_data, val_data, test_data)`` dicts of DataFrames.

    Raises:
        ValueError: if ``train + val + test`` exceeds the number of cells in ``all_data``.
    """
    requested = train + val + test
    if requested > len(all_data):
        raise ValueError(
            f"Requested {requested} cells (train={train}, val={val}, test={test}) "
            f"but only {len(all_data)} are available"
        )

    # Sort first: dict order follows file discovery order, which is not portable.
    cell_names = sorted(all_data)
    # A private RandomState with the same seed yields the same permutation as
    # np.random.seed(773) + np.random.shuffle, without resetting NumPy's global RNG.
    np.random.RandomState(773).shuffle(cell_names)

    train_cells = cell_names[:train]
    val_cells = cell_names[train:train + val]
    test_cells = cell_names[train + val:train + val + test]

    print("Cell split completed:")
    print(f"Training set: {len(train_cells)} cells")
    print(f"Validation set: {len(val_cells)} cells")
    print(f"Test set: {len(test_cells)} cells")

    train_data = {
        f"{cell}_{part_idx}": df_part
        for cell in train_cells
        for part_idx, df_part in split_data_into_parts(all_data[cell], parts=parts).items()
    }
    val_data = {cell: all_data[cell] for cell in val_cells}
    test_data = {cell: all_data[cell] for cell in test_cells}

    print("Final dataset sizes:")
    print(f"Training set: {len(train_data)} parts")
    print(f"Validation set: {len(val_data)} full cells")
    print(f"Test set: {len(test_data)} full cells")

    return train_data, val_data, test_data

# 13 training cells (each cut into 15 parts), 1 validation cell, 1 test cell
train_data, val_data, test_data = split_cell_data(all_data, train=13, val=1, test=1, parts=15)

# inspect_data_ranges(train_data)


In [None]:
def plot_dataset_soh(data_dict: dict, title: str, figsize=(10, 7)):
    """Plot SOH_ZHU over time for every entry of ``data_dict`` on one shared axes.

    Args:
        data_dict: mapping of name -> DataFrame with the timestamp and SOH_ZHU columns;
            each name becomes a legend entry.
        title: split name used in the figure title ("<title> Set SOH Curves").
        figsize: matplotlib figure size in inches.
    """
    time_col = 'Absolute_Time[yyyy-mm-dd hh:mm:ss]'
    fig, ax = plt.subplots(figsize=figsize)

    for name, frame in data_dict.items():
        ax.plot(frame[time_col], frame['SOH_ZHU'], label=name)

    ax.set_title(f'{title} Set SOH Curves')
    ax.set_xlabel('Time')
    ax.set_ylabel('SOH')
    ax.grid(True)
    ax.legend(loc='upper right')
    fig.tight_layout()
    plt.show()

# Plot all three datasets
# plot_dataset_soh(train_data, "Training")
# plot_dataset_soh(val_data, "Validation")
# plot_dataset_soh(test_data, "Test")

In [None]:
def scale_data_dicts(
    train_dict: Dict[str, pd.DataFrame],
    val_dict: Dict[str, pd.DataFrame],
    test_dict: Dict[str, pd.DataFrame]
) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
    """Fit feature scalers on the pooled training frames and apply them to all three splits.

    'Current[A]' is standardized with a StandardScaler. 'Temperature[°C]', 'Voltage[V]',
    'Q_sum' and 'EFC' are min-max scaled to [0, 1] with ``clip=True``, so validation or
    test values outside the training range are clipped to the bounds. All other columns
    (e.g. the timestamp and the SOH_ZHU target) pass through unchanged. The input frames
    are not modified; every returned frame is a scaled copy.

    Returns:
        ``(train_scaled, val_scaled, test_scaled)`` with the same keys as the inputs.
    """
    standard_cols = ['Current[A]']
    minmax_cols = ['Temperature[°C]', 'Voltage[V]', 'Q_sum', 'EFC']

    # Statistics come from the training split only
    fit_frame = pd.concat(train_dict.values(), axis=0)
    standard_scaler = StandardScaler().fit(fit_frame[standard_cols])
    minmax_scaler = MinMaxScaler(feature_range=(0, 1), clip=True).fit(fit_frame[minmax_cols])

    def _apply(frame: pd.DataFrame) -> pd.DataFrame:
        scaled = frame.copy()
        scaled[standard_cols] = standard_scaler.transform(scaled[standard_cols])
        scaled[minmax_cols] = minmax_scaler.transform(scaled[minmax_cols])
        return scaled

    train_scaled, val_scaled, test_scaled = (
        {name: _apply(frame) for name, frame in split.items()}
        for split in (train_dict, val_dict, test_dict)
    )
    return train_scaled, val_scaled, test_scaled

# Scalers are fitted on the training parts and reused for validation and test
train_scaled, val_scaled, test_scaled = scale_data_dicts(
    train_dict=train_data, val_dict=val_data, test_dict=test_data
)
# inspect_data_ranges(val_scaled)


In [None]:
def prepare_dataset(
    data_dict: Dict[str, pd.DataFrame],
    max_encoder_length: int = 24,
    max_prediction_length: int = 1,
    reference_dataset: Optional[TimeSeriesDataSet] = None,
) -> TimeSeriesDataSet:
    """Stack the frames of ``data_dict`` into one pytorch_forecasting ``TimeSeriesDataSet``.

    Each dict entry becomes its own series. Its key is stored in ``group_id``, and
    ``time_idx`` numbers its rows 0..n-1 after sorting by timestamp (rows are assumed
    to be equally spaced, as produced by the hourly resampling upstream). The target
    is ``SOH_ZHU``; current, voltage and temperature are time-varying known reals.

    Args:
        data_dict: mapping of series name -> DataFrame with the timestamp, feature and
            target columns.
        max_encoder_length: number of past time steps given to the encoder.
        max_prediction_length: number of time steps to predict.
        reference_dataset: optional dataset (normally the training set) whose
            configuration and fitted encoders/normalizers are reused via
            ``TimeSeriesDataSet.from_dataset``. Intended for validation/test data so
            they are normalized with training statistics. When given, the window
            lengths come from the reference and the two length arguments are ignored.

    Returns:
        The assembled ``TimeSeriesDataSet``.

    Raises:
        ValueError: if ``data_dict`` is empty.
    """
    if not data_dict:
        raise ValueError("prepare_dataset() received an empty data_dict")

    time_col = 'Absolute_Time[yyyy-mm-dd hh:mm:ss]'
    frames = []
    for part_name, df in data_dict.items():
        frame = (
            df.assign(**{time_col: pd.to_datetime(df[time_col])})
            .sort_values(time_col)
            .assign(group_id=part_name)
        )
        # Positional counter per series (assigned after sorting)
        frame['time_idx'] = np.arange(len(frame))
        frames.append(frame)

    big_df = pd.concat(frames, ignore_index=True)

    if reference_dataset is not None:
        # Reuse the reference's fitted normalizers; deterministic windows for evaluation data
        return TimeSeriesDataSet.from_dataset(reference_dataset, big_df, stop_randomization=True)

    # NOTE(review): SOH_ZHU is not in time_varying_unknown_reals, so the encoder does not
    # appear to receive past SOH values as input — confirm this is intended.
    # NOTE(review): scalers={} presumably lets pytorch_forecasting fit a default scaler per
    # real on *this* data; pass reference_dataset for val/test to reuse training statistics.
    # Feature sets tried previously: ["EFC"] and
    # ["Current[A]", "Voltage[V]", "Temperature[°C]", "Q_sum", "EFC"].
    return TimeSeriesDataSet(
        big_df,
        time_idx="time_idx",
        group_ids=["group_id"],
        target="SOH_ZHU",
        time_varying_unknown_reals=[],
        time_varying_known_reals=["Current[A]", "Voltage[V]", "Temperature[°C]"],
        max_encoder_length=max_encoder_length,
        max_prediction_length=max_prediction_length,
        scalers={}
    )

In [None]:
def objective(trial: optuna.Trial) -> float:
    """Optuna objective: train one sampled TFT configuration and return its best validation loss.

    Samples model hyperparameters, window lengths and batch size from ``trial``, then
    builds train/validation datasets from the module-level ``train_scaled`` and
    ``val_scaled`` dicts. These use the same feature scaling that is later applied to
    the test set. The model is trained with early stopping on ``val_loss``.

    Returns:
        The lowest ``val_loss`` (the model's RMSE loss) reached during training.
    """
    hidden_size = trial.suggest_int("hidden_size", 8, 64, step=8)
    lstm_layers = trial.suggest_int("lstm_layers", 1, 4)
    dropout = trial.suggest_float("dropout", 0, 0.5, step=0.1)
    attention_head_size = trial.suggest_int("attention_head_size", 2, 8, step=2)
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True)

    max_encoder_length = trial.suggest_int("max_encoder_length", 12, 48, step=12)
    max_prediction_length = trial.suggest_int("max_prediction_length", 1, 12, step=1)
    batch_size = trial.suggest_int("batch_size", 16, 64, step=16)

    # Train/validate on the *scaled* splits: the test set is evaluated on `test_scaled`,
    # so tuning on unscaled features would select a model for a different input scale.
    train_dataset = prepare_dataset(train_scaled, max_encoder_length=max_encoder_length, max_prediction_length=max_prediction_length)
    val_dataset = prepare_dataset(val_scaled, max_encoder_length=max_encoder_length, max_prediction_length=max_prediction_length)

    train_dataloader = train_dataset.to_dataloader(train=True, batch_size=batch_size, num_workers=4, persistent_workers=True)
    val_dataloader = val_dataset.to_dataloader(train=False, batch_size=batch_size, num_workers=4, persistent_workers=True)

    # Create model with sampled hyperparameters
    tft = TemporalFusionTransformer.from_dataset(
        train_dataset,
        learning_rate=learning_rate,
        hidden_size=hidden_size,
        lstm_layers=lstm_layers,
        dropout=dropout,
        attention_head_size=attention_head_size,
        loss=RMSE(),
        logging_metrics=[MAE(), RMSE()],
        optimizer="adam",
    )

    early_stopping = EarlyStopping(monitor="val_loss", patience=20)
    trainer = Trainer(
        max_epochs=100,  # upper bound; early stopping usually ends a trial sooner
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
        gradient_clip_val=0.1,
        callbacks=[early_stopping],
        logger=TensorBoardLogger("tft_optuna_logs")
    )

    trainer.fit(tft, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

    # callback_metrics["val_loss"] holds the *last* epoch's value, which after early stopping
    # is up to `patience` epochs past the optimum; report the best value seen instead.
    return early_stopping.best_score.item()


def tune_hyperparameters(n_trials=100, seed: Optional[int] = None):
    """Run an Optuna study that minimizes ``objective`` and return the best hyperparameters.

    Args:
        n_trials: number of Optuna trials (each one trains a TFT).
        seed: optional seed for the TPE sampler. Pass an int to make the sequence of
            sampled configurations reproducible; ``None`` matches Optuna's default
            unseeded TPE sampler.

    Returns:
        Dict mapping hyperparameter names to the best values found.
    """
    # TPESampler is Optuna's default single-objective sampler; it is created explicitly
    # only so that a seed can be supplied.
    sampler = optuna.samplers.TPESampler(seed=seed)
    study = optuna.create_study(direction="minimize", sampler=sampler)
    study.optimize(objective, n_trials=n_trials)

    print("Best hyperparameters:", study.best_params)
    return study.best_params

best_params = tune_hyperparameters(
    n_trials=100,  # every trial trains a TFT for up to 100 epochs: this is the slow cell
)
print(str(best_params))

In [None]:
# Final model: rebuild the datasets with the tuned window lengths, using the *scaled*
# splits so training matches the feature scaling of the test set evaluated below.
max_encoder_length = best_params['max_encoder_length']
max_prediction_length = best_params['max_prediction_length']
train_dataset = prepare_dataset(train_scaled, max_encoder_length=max_encoder_length, max_prediction_length=max_prediction_length)
val_dataset = prepare_dataset(val_scaled, max_encoder_length=max_encoder_length, max_prediction_length=max_prediction_length)

# Use the batch size selected by the Optuna search rather than a fixed value
batch_size = best_params["batch_size"]
train_dataloader = train_dataset.to_dataloader(train=True, batch_size=batch_size, num_workers=4, persistent_workers=True)
val_dataloader = val_dataset.to_dataloader(train=False, batch_size=batch_size, num_workers=4, persistent_workers=True)

# Train final model with best hyperparameters
tft_best = TemporalFusionTransformer.from_dataset(
    train_dataset,
    learning_rate=best_params["learning_rate"],
    hidden_size=best_params["hidden_size"],
    lstm_layers=best_params["lstm_layers"],
    dropout=best_params["dropout"],
    attention_head_size=best_params["attention_head_size"],
    loss=RMSE(),
    logging_metrics=[MAE(), RMSE()],
    optimizer="adam",
)

# EarlyStopping halts `patience` epochs after the best epoch, so the weights left in the
# model when fit() returns are not the best ones; keep the lowest-val_loss checkpoint.
checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)

trainer = Trainer(
    max_epochs=100,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,  # same as the tuning runs; devices=2 needs two devices and multi-process training
    gradient_clip_val=0.1,
    callbacks=[EarlyStopping(monitor="val_loss", patience=10), checkpoint_callback],
    logger=TensorBoardLogger("tft_final_logs")
)

trainer.fit(tft_best, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

# Reload the best-validation weights and store a copy under the expected file name
best_model_path = checkpoint_callback.best_model_path
tft_best = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
shutil.copyfile(best_model_path, "TFT_model_best.ckpt")

In [None]:
# Evaluate the final model on the held-out test cell(s).
# To evaluate a saved checkpoint instead of the model trained above:
# model_path = r'E:\00_Thesis\TFT\results\00basic\TFT_model.ckpt'
# tft_best = TemporalFusionTransformer.load_from_checkpoint(model_path)

# Test windows must use the same encoder/prediction lengths the model was trained with.
# NOTE(review): prepare_dataset builds an independent TimeSeriesDataSet here, so its
# normalizers are fitted on the test frames; consider deriving it from the training dataset.
test_dataset = prepare_dataset(test_scaled, max_encoder_length=max_encoder_length, max_prediction_length=max_prediction_length)
test_dataloader = test_dataset.to_dataloader(train=False, batch_size=32, num_workers=4, persistent_workers=True, drop_last=False)

# `tft` only exists inside objective(); the model trained in the previous cell is `tft_best`.
test_predictions = tft_best.predict(test_dataloader, mode="prediction")
# Some pytorch_forecasting versions wrap the tensor in a Prediction tuple; unwrap if so.
test_predictions = getattr(test_predictions, "output", test_predictions)

# Collect targets and their time_idx window by window from the same unshuffled dataloader,
# so they line up one-to-one with the predictions. test_dataset.data["target"] holds every
# row of the series (including encoder-only rows), which misaligns it with the predictions.
target_batches, time_batches = [], []
for x, (y, _weight) in iter(test_dataloader):
    target_batches.append(y)
    time_batches.append(x["decoder_time_idx"])

y_pred = test_predictions.cpu().numpy().flatten()
y_true = torch.cat(target_batches).cpu().numpy().flatten()
time_index = torch.cat(time_batches).cpu().numpy().flatten()
print(f"y_pred shape: {len(y_pred)}")
print(f"y_true shape: {len(y_true)}")
assert len(y_pred) == len(y_true) == len(time_index), "predictions and targets are misaligned"

# Evaluate the model
mae = mean_absolute_error(y_true, y_pred)
rmse = root_mean_squared_error(y_true, y_pred)
print(f" MAE = {mae:.3e}, RMSE = {rmse:.3e}")

# Visualize the prediction. NOTE: time_idx restarts at 0 for every test cell, so with more
# than one test cell the curves of different cells overlap on this axis.
pred_df = (
    pd.DataFrame({"time_idx": time_index, "true_SOH_ZHU": y_true, "predicted_SOH_ZHU": y_pred})
    .sort_values("time_idx", kind="stable")
)

fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(pred_df["time_idx"], pred_df["true_SOH_ZHU"], label="True SOH_ZHU")
ax.plot(pred_df["time_idx"], pred_df["predicted_SOH_ZHU"], label="Predicted SOH_ZHU")
ax.set_xlabel("Time Index")
ax.set_ylabel("SOH_ZHU")
ax.set_title("TFT Model - SOH_ZHU Test Prediction")
ax.text(0.80, 0.85, f"MAE = {mae:.3e}\nRMSE = {rmse:.3e}",
        transform=ax.transAxes, fontsize=12,
        verticalalignment='top', bbox=dict(facecolor='white', alpha=0.7))
ax.legend()
ax.grid()
plt.show()