# Model
The model implementation will be added after acceptance.

In [None]:
import torch
import torch.nn as nn
import math

class MultiFeatureTransformer(nn.Module):
    # NOTE(review): the class body is not included in this file, so this cell
    # does not run as-is.
    # Interface the rest of the file relies on:
    #   - constructed as MultiFeatureTransformer(input_dim, output_dim,
    #     seq_len_in, seq_len_out, d_model, nhead, num_layers)
    #   - called as model(src) with a sample-first batch of inputs; the result
    #     is compared against the target batch with nn.MSELoss, so it is
    #     presumably shaped (batch, seq_len_out, output_dim) -- TODO confirm.
    

# Train

## Data Preparation

In [None]:
import os
import re
import numpy as np
import pandas as pd
import torch
import joblib
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from torch.utils.data import TensorDataset, DataLoader, random_split

def extract_file_key(file_name):
    """Parse a ``rest``/``eis`` pickle file name into its pairing key.

    The pattern ``<rest|eis>_<num>_cap_<cap>.pkl`` is searched anywhere in
    ``file_name``.

    Args:
        file_name: File name (without directory) to inspect.

    Returns:
        ``(num, cap, file_type)`` where ``num`` is an int, ``cap`` a float and
        ``file_type`` is ``'rest'`` or ``'eis'``; ``None`` when the name does
        not match the pattern or the capacity text is not a valid number.
    """
    match = re.search(r'(rest|eis)_(\d+)_cap_([\d.]+)\.pkl', file_name)
    if match is None:
        return None

    file_type, num_text, cap_text = match.groups()
    try:
        cap = float(cap_text)
    except ValueError:
        # ``[\d.]+`` also accepts text such as "1.2.3" or ".."; treat those
        # names as non-matching instead of crashing the directory scan.
        return None
    return (int(num_text), cap, file_type)

def _pad_or_truncate(frame, length):
    """Return ``frame`` with exactly ``length`` rows (zero rows appended, or tail cut)."""
    missing = length - len(frame)
    if missing > 0:
        padding = pd.DataFrame(np.zeros((missing, frame.shape[1])),
                               columns=frame.columns)
        return pd.concat([frame, padding], ignore_index=True)
    return frame.iloc[:length]


def _fit_scale_stack(segments, scaler, n_features):
    """Fit ``scaler`` on all rows of ``segments`` and return a [seq_len, N, n_features] float tensor."""
    combined = np.vstack(segments)
    normalized = scaler.fit_transform(combined)
    # Every segment has the same number of rows, so the stacked rows can be
    # folded back into one slice per sample.
    per_sample = normalized.reshape(len(segments), -1, n_features)
    return torch.tensor(per_sample, dtype=torch.float32).permute(1, 0, 2)


def train_tensor_build(directory_path, scaler_path, rest_len=14, eis_len=107):
    """Build normalized training tensors from paired ``rest``/``eis`` pickles.

    Walks ``directory_path``; inside each directory, files are paired by the
    ``(num, cap)`` key returned by ``extract_file_key``. For every complete
    pair the rest data becomes one input sample and the EIS data one output
    sample. Two ``StandardScaler`` objects are fitted (one on all input rows,
    one on all output rows) and written to ``scaler_path`` as
    ``scaler_input.pth`` and ``scaler_output.pth``.

    Args:
        directory_path: Root directory that is searched recursively.
        scaler_path: Directory that receives the fitted scalers (created if
            missing).
        rest_len: Number of rows kept per input sample; shorter files are
            padded with zero rows, longer ones are truncated.
        eis_len: Number of rows each EIS file is padded/truncated to; the
            first row is then dropped, giving ``eis_len - 1`` output rows.

    Returns:
        ``(input_tensor, output_tensor)`` shaped ``[rest_len, N, 4]`` and
        ``[eis_len - 1, N, 2]`` for ``N`` valid pairs.

    Raises:
        ValueError: If no valid pair was found.
    """
    input_cols = ['Ecell/V', '<I>/mA', 'Temperature/°C', 'time']
    output_cols = ['|Z|/Ohm', 'Phase(Z)/deg']

    scaler_input = StandardScaler()
    scaler_output = StandardScaler()

    input_data_list = []
    output_data_list = []

    for root, _dirs, files in os.walk(directory_path):
        file_pairs = {}
        for file in files:
            file_key = extract_file_key(file)
            if not file_key:
                continue
            num, cap, file_type = file_key
            pair = file_pairs.setdefault((num, cap), {'input': None, 'output': None})
            if file_type == 'rest':
                pair['input'] = os.path.join(root, file)
            elif file_type == 'eis':
                pair['output'] = os.path.join(root, file)

        for key in sorted(file_pairs):
            pair = file_pairs[key]
            if not pair['input'] or not pair['output']:
                continue

            # input --------------------------------------------------
            df_input = pd.read_pickle(pair['input'])
            # Rows with negative time are blanked and then zero-filled
            # (existing NaNs are zero-filled as well).
            df_input.loc[df_input['time'] < 0] = np.nan
            df_input.fillna(0, inplace=True)
            df_input = _pad_or_truncate(df_input, rest_len)
            input_segment = df_input[input_cols].values

            # output --------------------------------------------------
            df_output = _pad_or_truncate(pd.read_pickle(pair['output']), eis_len)
            if not set(output_cols).issubset(df_output.columns):
                # Skip the whole pair: appending only the input would shift
                # every later input against its output.
                continue
            output_segment = df_output[output_cols].iloc[1:eis_len].values

            input_data_list.append(input_segment)
            output_data_list.append(output_segment)

    if not input_data_list:
        raise ValueError("no valid data for tensor")

    # normalization --------------------------------------------------
    os.makedirs(scaler_path, exist_ok=True)

    input_tensor = _fit_scale_stack(input_data_list, scaler_input, len(input_cols))    # [rest_len, N, 4]
    joblib.dump(scaler_input, os.path.join(scaler_path, "scaler_input.pth"))

    output_tensor = _fit_scale_stack(output_data_list, scaler_output, len(output_cols))  # [eis_len - 1, N, 2]
    joblib.dump(scaler_output, os.path.join(scaler_path, "scaler_output.pth"))

    return input_tensor, output_tensor

def create_dataloaders(input_tensor, output_tensor, batch_size=32, train_ratio=0.8, shuffle=True):
    """Split paired tensors into a training and a validation ``DataLoader``.

    Both tensors have their first two dimensions swapped before being wrapped
    in a ``TensorDataset``, so the second dimension of each argument is the
    sample dimension. The split uses a generator seeded with 42.

    Args:
        input_tensor: Input data; dimension 1 indexes samples.
        output_tensor: Target data; dimension 1 indexes samples.
        batch_size: Batch size for both loaders.
        train_ratio: Fraction of samples assigned to training (rounded down).
        shuffle: Whether the training loader shuffles each epoch.

    Returns:
        ``(train_loader, val_loader)``. The training loader drops the last
        incomplete batch; the validation loader keeps it and never shuffles.
    """
    samples = TensorDataset(
        input_tensor.permute(1, 0, 2),
        output_tensor.permute(1, 0, 2),
    )

    n_total = len(samples)
    n_train = int(train_ratio * n_total)
    split_rng = torch.Generator().manual_seed(42)
    train_part, val_part = random_split(
        samples, [n_train, n_total - n_train], generator=split_rng
    )

    return (
        DataLoader(train_part, batch_size=batch_size, shuffle=shuffle, drop_last=True),
        DataLoader(val_part, batch_size=batch_size, shuffle=False, drop_last=False),
    )


## Training

In [None]:
import torch
from torch import nn, optim
import matplotlib.pyplot as plt

# Use CUDA when it is available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

def train_model(directory_path, model_path, epochs=1000):
    """Train ``MultiFeatureTransformer`` and save the best and final weights.

    Training data is built with ``train_tensor_build`` (which also writes the
    fitted scalers into ``model_path``) and split by ``create_dataloaders``.
    After each epoch the validation loss is computed; whenever it improves,
    the weights are written to ``best_model.pth``. The weights after the last
    epoch are written to ``final_model.pth``.

    Args:
        directory_path: Directory containing the training pickles.
        model_path: Directory that receives scalers and model weights.
        epochs: Number of training epochs.

    Raises:
        ValueError: If the data yields no training or no validation batch
            (the training loader drops incomplete batches, so fewer training
            samples than one batch gives zero batches).
    """
    input_dim = 4
    output_dim = 2
    seq_len_in = 14     # change for different case
    seq_len_out = 106   # change for different case
    d_model = 256
    nhead = 4
    num_layers = 6

    input_tensor, output_tensor = train_tensor_build(directory_path, model_path)
    train_loader, val_loader = create_dataloaders(input_tensor, output_tensor)

    n_train_batches = len(train_loader)
    n_val_batches = len(val_loader)
    if n_train_batches == 0 or n_val_batches == 0:
        # Previously this surfaced as a ZeroDivisionError when averaging losses.
        raise ValueError(
            "not enough samples to form at least one training batch and one "
            f"validation batch (train batches: {n_train_batches}, "
            f"val batches: {n_val_batches})"
        )

    model = MultiFeatureTransformer(
        input_dim, output_dim, seq_len_in, seq_len_out,
        d_model, nhead, num_layers
        ).to(device)

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    best_model_path = os.path.join(model_path, "best_model.pth")
    best_val_loss = float('inf')
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for src, tgt in train_loader:
            src = src.to(device)
            tgt = tgt.to(device)
            optimizer.zero_grad()
            output = model(src)
            loss = criterion(output, tgt)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            train_loss += loss.item()

        model.eval()
        val_loss = 0
        with torch.no_grad():
            for src, tgt in val_loader:
                src = src.to(device)
                tgt = tgt.to(device)
                output = model(src)
                val_loss += criterion(output, tgt).item()

        avg_train_loss = train_loss / n_train_batches
        avg_val_loss = val_loss / n_val_batches
        print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

        # save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.to('cpu').state_dict(), best_model_path)  # save on CPU
            model.to(device)  # move back to the training device
            print(f"✅ Saved best model at epoch {epoch + 1}")

    # save final model
    torch.save(model.to('cpu').state_dict(), os.path.join(model_path, "final_model.pth"))


## main_train

In [None]:
# Training entry point: read pickles from data_120s/train and write scalers
# and model weights into ./model.
# os.path.join keeps the path valid on non-Windows systems, where a literal
# backslash is not a directory separator.
directory_path = os.path.join("data_120s", "train")
model_path = "model"

train_model(directory_path, model_path)

# Test

In [None]:
import os
import torch
import pickle
import re
import pandas as pd
import numpy as np
import joblib
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import warnings
import time

# Ignore warnings whose text begins with this message; `message` is treated
# as a regular expression matched against the start of the warning text.
warnings.filterwarnings(
    action="ignore",
    message="enable_nested_tensor is True, but self.use_nested_tensor is False",
)

def test_result_get_single(model_path, cell_path, input_file):
    """Predict the EIS response for one rest-measurement file.

    The input is prepared the same way ``train_tensor_build`` prepares
    training inputs: rows with negative ``time`` are zeroed, NaNs are replaced
    by 0, and the segment is cut or zero-padded to ``rest_len`` rows before
    being scaled with the saved input scaler. The model output is mapped back
    to physical units with the saved output scaler.

    Args:
        model_path: Directory holding ``scaler_input.pth``,
            ``scaler_output.pth`` and ``best_model.pth``.
        cell_path: Directory containing ``input_file``.
        input_file: Name of the rest-measurement pickle.

    Returns:
        NumPy array of de-normalized predictions with the batch dimension
        removed; column 0 is |Z| and column 1 the phase (the order used when
        the output scaler was fitted).
    """
    rest_len = 14       # change for different case
    feature_cols = ['Ecell/V', '<I>/mA', 'Temperature/°C', 'time']

    # Data preparation
    scaler = joblib.load(os.path.join(model_path, "scaler_input.pth"))
    df = pd.read_pickle(os.path.join(cell_path, input_file))
    df.loc[df['time'] < 0] = np.nan
    df = df.fillna(0)
    segment = df[feature_cols].iloc[:rest_len].values
    if len(segment) < rest_len:
        # Zero-pad the raw values before scaling, as done for training inputs.
        padding = np.zeros((rest_len - len(segment), len(feature_cols)))
        segment = np.vstack([segment, padding])
    input_scaled = scaler.transform(segment)
    input_tensor = torch.tensor(input_scaled, dtype=torch.float32).unsqueeze(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    input_dim = 4
    output_dim = 2
    seq_len_in = 14     # change for different case
    seq_len_out = 106   # change for different case
    d_model = 256
    nhead = 4
    num_layers = 6

    model = MultiFeatureTransformer(
        input_dim, output_dim, seq_len_in, seq_len_out,
        d_model, nhead, num_layers
        ).to(device)
    # map_location keeps loading independent of the device the file was saved from.
    state_dict = torch.load(os.path.join(model_path, "best_model.pth"), map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    with torch.no_grad():
        output = model(input_tensor.to(device))

    scaler_output = joblib.load(os.path.join(model_path, "scaler_output.pth"))
    output = output.cpu().numpy()
    # The scaler works on 2-D (rows, 2) data; restore the original shape afterwards.
    output = scaler_output.inverse_transform(output.reshape(-1, 2)).reshape(output.shape)
    prediction = output.squeeze(0)
    return prediction

def save_single_eis_results_as_pkl(prediction, result_file, cell_path, output_file):
    """Store predicted and measured EIS values side by side in a pickle.

    The measured file is read from ``cell_path/output_file``; its first row is
    skipped so the remaining rows line up with the rows of ``prediction``.

    Args:
        prediction: 2-D array; column 0 is the predicted magnitude, column 1
            the predicted phase.
        result_file: Path of the pickle to write. Missing parent directories
            are created; a bare file name is written to the current directory.
        cell_path: Directory containing the measured EIS pickle.
        output_file: File name of the measured EIS pickle.

    Raises:
        ValueError: Raised by pandas when ``prediction`` does not have one row
            fewer than the measured file.
    """
    measured = pd.read_pickle(os.path.join(cell_path, output_file))

    result_df = pd.DataFrame({
        'freq': measured['freq/Hz'].values[1:],
        'pred_mag': prediction[:, 0],
        'pred_ph': prediction[:, 1],
        'meas_mag': measured['|Z|/Ohm'].values[1:],
        'meas_ph': measured['Phase(Z)/deg'].values[1:],
    })

    directory = os.path.dirname(result_file)
    if directory:
        # os.makedirs('') raises, so only create a directory when one is named.
        os.makedirs(directory, exist_ok=True)

    with open(result_file, 'wb') as f:
        pickle.dump(result_df, f)

def extract_number_eis(file_name):
    """Return the integer following ``eis_`` in ``file_name``, or -1 if absent."""
    found = re.search(r'eis_(\d+)_', file_name)
    if not found:
        # No ``eis_<digits>_`` part in the name.
        return -1
    return int(found.group(1))

def extract_number_rest(file_name):
    """Return the integer following ``rest_`` in ``file_name``, or -1 if absent."""
    found = re.search(r'rest_(\d+)_', file_name)
    if not found:
        # No ``rest_<digits>_`` part in the name.
        return -1
    return int(found.group(1))

def extract_info(file_name):
    """Extract the EIS index and the capacity value from a file name.

    Returns:
        ``(eis_number, cap_value)``: the int after ``eis_`` and the float
        between ``cap_`` and ``.pkl``. Each element is ``None`` when its part
        is not found in ``file_name``.
    """
    eis_number = None
    cap_value = None

    number_hit = re.search(r'eis_(\d+)_', file_name)
    if number_hit:
        eis_number = int(number_hit.group(1))

    capacity_hit = re.search(r'cap_(\d+\.?\d*)\.pkl', file_name)
    if capacity_hit:
        cap_value = float(capacity_hit.group(1))

    return eis_number, cap_value


## Save Predicted EIS

In [None]:
# Run the trained model on every cell directory under the test data and store
# one result pickle per rest/eis file pair under result/<cell>/.
# Paths are built with os.path.join so they are valid on any OS.
data_path = os.path.join("data_120s", "test")
model_path = "model"

start_time = time.time()
for cell in os.listdir(data_path):
    cell_path = os.path.join(data_path, cell)
    if not os.path.isdir(cell_path):
        continue

    print(f"Running test for cell: {cell}")

    pkl_files = [f for f in os.listdir(cell_path) if f.endswith(".pkl")]

    input_files = [f for f in pkl_files if 'rest_' in f]
    sorted_input_files = sorted(input_files, key=extract_number_rest)

    output_files = [f for f in pkl_files if 'eis_' in f]
    sorted_output_files = sorted(output_files, key=extract_number_eis)

    # Files are paired by position after sorting on their index, so a missing
    # file shifts every later pair; make that situation visible.
    if len(sorted_input_files) != len(sorted_output_files):
        print(
            f"Warning: {cell} has {len(sorted_input_files)} rest files but "
            f"{len(sorted_output_files)} eis files; pairs may be misaligned"
        )

    for input_file, output_file in zip(sorted_input_files, sorted_output_files):
        test_preds = test_result_get_single(model_path, cell_path, input_file)

        eis_number, cap_value = extract_info(output_file)
        result_file = os.path.join(
            "result", cell, f"predicted_eis_{eis_number}_cap_{cap_value}.pkl"
        )

        save_single_eis_results_as_pkl(test_preds, result_file, cell_path, output_file)

end_time = time.time()
test_duration = end_time - start_time
print(f"Time: {test_duration:.2f} seconds")