In [9]:
import pandas as pd
import numpy as np
import yaml
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings

# Global notebook settings.
# NOTE(review): a blanket "ignore" also hides pandas FutureWarnings and
# deprecation notices; consider narrowing it to specific categories.
warnings.filterwarnings("ignore")
sns.set_style("whitegrid")
plt.rcParams.update({"figure.figsize": (12, 6)})

print("=" * 80, "RAW DATA VALIDATION & ANALYSIS", "=" * 80, sep="\n")

RAW DATA VALIDATION & ANALYSIS


In [10]:
# Resolve project paths relative to the current working directory (presumably
# the notebook's folder) and load both YAML configs with safe_load.
BASE_DIR = Path("..")
config_dir = BASE_DIR / "configs"

with open(config_dir / "optimise_config.yml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

with open(config_dir / "paths.yml", "r", encoding="utf-8") as f:
    paths = yaml.safe_load(f)

data_base = BASE_DIR / paths["data"]["base_dir"]
tau_csv = data_base / paths["data"]["input"]["tau_results_csv"]
parquet_dir = data_base / paths["data"]["input"]["parquet_dir"]
parquet_pattern = paths["data"]["input"]["parquet_pattern"]
images_dir = data_base / paths["data"]["input"]["images_dir"]
image_pattern = paths["data"]["input"]["image_pattern"]

TOTAL_SAMPLES = config["data"]["total_samples"]
PARAMS_PER_SAMPLE = config["data"]["params_per_sample"]
INPUT_FEATURES = config["data"]["input_features"]

# Fail fast on the inputs this notebook actually reads. A missing parquet_dir
# would otherwise make file discovery silently find 0 files and only fail later
# with an opaque pd.concat error.
if not tau_csv.is_file():
    raise FileNotFoundError(f"tau results CSV not found: {tau_csv}")
if not parquet_dir.is_dir():
    raise FileNotFoundError(f"parquet directory not found: {parquet_dir}")

print("Configuration loaded:")
print(f"  Total samples: {TOTAL_SAMPLES}")
print(f"  Params per sample: {PARAMS_PER_SAMPLE}")
print(f"  Expected total rows: {TOTAL_SAMPLES * PARAMS_PER_SAMPLE}")
print(f"  Input features: {len(INPUT_FEATURES)}")

Configuration loaded:
  Total samples: 100
  Params per sample: 60
  Expected total rows: 6000
  Input features: 10


In [11]:
# Load the tortuosity results table and show a quick overview of its shape.
print("Loading tau_results.csv...")
tau_df = pd.read_csv(tau_csv)

print(
    "\nTau results loaded:",
    f"  Rows: {len(tau_df)}",
    f"  Columns: {list(tau_df.columns)}",
    sep="\n",
)

print("\nFirst 5 rows:")
tau_df.head()

Loading tau_results.csv...

Tau results loaded:
  Rows: 100
  Columns: ['id', 'filename', 'porosity_measured', 'tau_factor', 'D_eff', 'error']

First 5 rows:


Unnamed: 0,id,filename,porosity_measured,tau_factor,D_eff,error
0,0,sample_0000.tif,0.519418,1.558766,0.333224,
1,1,sample_0001.tif,0.576339,2.387288,0.24142,
2,2,sample_0002.tif,0.525229,2.27283,0.23109,
3,3,sample_0003.tif,0.415018,2.483961,0.167079,
4,4,sample_0004.tif,0.438148,3.700079,0.118416,


In [12]:
print("=" * 80, "TAU RESULTS STATISTICS", "=" * 80, sep="\n")

# Descriptive statistics for the three numeric outputs of the tau run.
tau_stats = tau_df.loc[:, ["D_eff", "porosity_measured", "tau_factor"]].describe()
print(tau_stats)

# Per-column null counts across the whole results table.
print("\nMissing values in tau_results:")
print(tau_df.isna().sum())

TAU RESULTS STATISTICS
            D_eff  porosity_measured  tau_factor
count  100.000000         100.000000  100.000000
mean     0.210705           0.487339   27.388088
std      0.184230           0.180468  114.703536
min      0.000154           0.152606    1.179202
25%      0.041231           0.333235    1.801661
50%      0.164922           0.494411    3.085141
75%      0.348979           0.651169    7.657228
max      0.633595           0.815443  989.615723

Missing values in tau_results:
id                     0
filename               0
porosity_measured      0
tau_factor             0
D_eff                  0
error                100
dtype: int64


In [13]:
print("=" * 80)
print("DISCOVERING PARQUET FILES")
print("=" * 80)

# Sample IDs (not paths) whose per-sample parquet file exists on disk.
parquet_files = [
    sample_id
    for sample_id in range(TOTAL_SAMPLES)
    if (parquet_dir / parquet_pattern.format(sample_id)).exists()
]
# Report which expected files are absent so gaps are visible, not just counted.
missing_parquet_ids = sorted(set(range(TOTAL_SAMPLES)) - set(parquet_files))

print(f"Found {len(parquet_files)}/{TOTAL_SAMPLES} parquet files")
if missing_parquet_ids:
    print(f"Missing sample IDs: {missing_parquet_ids}")
print(f"Expected total rows: {len(parquet_files) * PARAMS_PER_SAMPLE}")

DISCOVERING PARQUET FILES
Found 100/100 parquet files
Expected total rows: 6000


In [14]:
print("\n" + "=" * 80)
print("LOADING ALL PARQUET FILES (This may take 30-60 seconds...)")
print("=" * 80)

all_data = []
load_errors = []

for sample_id in tqdm(parquet_files, desc="Loading parquet files"):
    pq_file = parquet_dir / parquet_pattern.format(sample_id)
    try:
        df = pd.read_parquet(pq_file)
    except Exception as e:
        # Best-effort load: record the failure, keep going, report below.
        load_errors.append({"sample_id": sample_id, "error": str(e)})
        continue
    # NOTE(review): the RAW DATA SCHEMA output lists `sample_id` as the first
    # column, which suggests the files already carry one; this line overwrites
    # it with the file index. Confirm the two always agree.
    df["sample_id"] = sample_id
    all_data.append(df)

if load_errors:
    print(f"\nFailed to load {len(load_errors)} parquet files")
    print(pd.DataFrame(load_errors))
else:
    print("\nAll parquet files loaded successfully")

# pd.concat([]) raises an unhelpful "No objects to concatenate"; fail clearly.
if not all_data:
    raise RuntimeError(
        f"No parquet files were loaded from {parquet_dir}; cannot build df_raw."
    )

df_raw = pd.concat(all_data, ignore_index=True)

expected_rows = len(parquet_files) * PARAMS_PER_SAMPLE
print("\nCombined raw data:")
print(f"  Total rows: {len(df_raw):,}")
print(f"  Expected: {expected_rows:,}")
print(f'  Unique samples: {df_raw["sample_id"].nunique()}')
print(f'  Unique params: {df_raw["param_id"].nunique()}')
print(f"  Columns: {len(df_raw.columns)}")
print(f"  Memory usage: {df_raw.memory_usage(deep=True).sum() / 1024**2:.1f} MB")
# Surface a row-count mismatch instead of leaving it to visual comparison.
if len(df_raw) != expected_rows:
    print(f"  ⚠️  Row count differs from expected by {len(df_raw) - expected_rows:+,}")


LOADING ALL PARQUET FILES (This may take 30-60 seconds...)


Loading parquet files: 100%|██████████| 100/100 [00:00<00:00, 181.69it/s]


All parquet files loaded successfully

Combined raw data:
  Total rows: 6,000
  Expected: 6,000
  Unique samples: 100
  Unique params: 60
  Columns: 27
  Memory usage: 6.2 MB





In [15]:
print("=" * 80, "RAW DATA SCHEMA", "=" * 80, sep="\n")

# Numbered listing of every column together with its pandas dtype.
for i, (col, dtype) in enumerate(df_raw.dtypes.items(), start=1):
    print(f"{i:2d}. {col:40s} ({dtype})")

# Configured model inputs plus the id / derived / performance columns.
EXPECTED_FIELDS = [
    *INPUT_FEATURES,
    "param_id",
    "bruggeman_derived",
    "nominal_capacity_Ah",
    "eol_cycle_measured",
    "capacity_trend_ah",
    "final_RUL",
]

print("\nEXPECTED FIELDS VALIDATION")
for field in EXPECTED_FIELDS:
    print("✓" if field in df_raw.columns else "✗", field)

RAW DATA SCHEMA
 1. sample_id                                (int64)
 2. param_id                                 (int64)
 3. filename                                 (object)
 4. bruggeman_derived                        (float64)
 5. input_param_id                           (int64)
 6. input_SEI kinetic rate constant [m.s-1]  (float64)
 7. input_Electrolyte diffusivity [m2.s-1]   (float64)
 8. input_Initial concentration in electrolyte [mol.m-3] (int64)
 9. input_Separator porosity                 (float64)
10. input_Separator Bruggeman coefficient (electrolyte) (float64)
11. input_Separator Bruggeman coefficient    (float64)
12. input_Positive particle radius [m]       (float64)
13. input_Negative particle radius [m]       (float64)
14. input_Positive electrode thickness [m]   (float64)
15. input_Negative electrode thickness [m]   (float64)
16. status                                   (object)
17. runtime_s                                (float64)
18. error_message                   

In [16]:
def calculate_comprehensive_stats(df, features):
    """Build a per-feature summary table (count, moments, quartiles, missingness).

    Parameters
    ----------
    df : pd.DataFrame
        Frame to summarise.
    features : iterable of str
        Column names to report on, in output order. Names that are not columns
        of ``df`` still get a row, reported as count 0 / 100% missing.

    Returns
    -------
    pd.DataFrame
        Indexed by ``feature`` with columns ``count, mean, std, min, 25%, 50%,
        75%, max, missing, missing_pct``. The moment/quantile columns are NaN
        for non-numeric columns. An empty ``features`` yields an empty frame
        with the same columns (previously this raised ``KeyError``).
    """
    numeric_stat_cols = ["mean", "std", "min", "25%", "50%", "75%", "max"]
    all_cols = ["count", *numeric_stat_cols, "missing", "missing_pct"]

    rows = []
    for feat in features:
        if feat not in df.columns:
            # Absent column: report it explicitly instead of dropping it.
            rows.append(
                {
                    "feature": feat,
                    "count": 0,
                    **dict.fromkeys(numeric_stat_cols, np.nan),
                    "missing": len(df),
                    "missing_pct": 100.0,
                }
            )
            continue

        col = df[feat]
        # Check the dtype once rather than once per statistic.
        if pd.api.types.is_numeric_dtype(col):
            numeric_stats = {
                "mean": col.mean(),
                "std": col.std(),
                "min": col.min(),
                "25%": col.quantile(0.25),
                "50%": col.quantile(0.5),
                "75%": col.quantile(0.75),
                "max": col.max(),
            }
        else:
            numeric_stats = dict.fromkeys(numeric_stat_cols, np.nan)

        is_missing = col.isnull()
        rows.append(
            {
                "feature": feat,
                "count": col.count(),
                **numeric_stats,
                "missing": is_missing.sum(),
                "missing_pct": is_missing.mean() * 100,
            }
        )

    # Explicit columns keep the schema (and set_index) valid even when `rows` is empty.
    return pd.DataFrame(rows, columns=["feature", *all_cols]).set_index("feature")


# Per-feature descriptive statistics for the configured model inputs.
input_stats = calculate_comprehensive_stats(df=df_raw, features=INPUT_FEATURES)
print(input_stats.to_string())

                                                      count          mean           std           min           25%           50%           75%           max  missing  missing_pct
feature                                                                                                                                                                            
input_SEI kinetic rate constant [m.s-1]                6000  1.449557e-13  2.271902e-13  1.080000e-15  5.907500e-15  3.090000e-14  1.735000e-13  9.230000e-13        0          0.0
input_Electrolyte diffusivity [m2.s-1]                 6000  3.001000e-10  5.789803e-11  2.000000e-10  2.515000e-10  3.010000e-10  3.497500e-10  3.980000e-10        0          0.0
input_Initial concentration in electrolyte [mol.m-3]   6000  1.000000e+03  0.000000e+00  1.000000e+03  1.000000e+03  1.000000e+03  1.000000e+03  1.000000e+03        0          0.0
input_Separator porosity                               6000  5.870974e-01  2.892336e-02  5.380924e-0

In [17]:
# Completeness of the output/performance columns.
# isnull() treats list/array cells as present, so for capacity_trend_ah this
# counts None entries only; empty arrays are not counted as missing.
performance_fields = ["nominal_capacity_Ah", "eol_cycle_measured", "capacity_trend_ah", "final_RUL"]

for field in performance_fields:
    if field in df_raw.columns:
        missing = df_raw[field].isnull().sum()
        print(f"{field}: missing {missing}/{len(df_raw)}")
    else:
        # Previously skipped silently; an absent column is the most severe gap.
        print(f"{field}: column not present in df_raw")

nominal_capacity_Ah: missing 734/6000
eol_cycle_measured: missing 2670/6000
capacity_trend_ah: missing 734/6000
final_RUL: missing 785/6000


In [18]:
def calculate_retention_raw(row):
    """Capacity retention (%) as the last over the first point of ``capacity_trend_ah``.

    Parameters
    ----------
    row : mapping-like (e.g. a DataFrame row passed by ``DataFrame.apply``)
        Anything supporting ``.get("capacity_trend_ah")``.

    Returns
    -------
    float
        ``trend[-1] / trend[0] * 100``. The result is NaN when the trend is
        absent, is a scalar (e.g. NaN), is empty, is not one-dimensional, or
        does not start at a positive value. No clipping is applied, so values
        above 100 are returned unchanged.
    """
    trend = row.get("capacity_trend_ah")
    # None / NaN / other scalars have no first and last point. The previous
    # version only caught None and raised IndexError on a scalar NaN.
    if trend is None or np.ndim(trend) == 0:
        return np.nan
    values = np.asarray(trend, dtype=float)
    if values.ndim != 1 or values.size == 0:
        return np.nan
    first, last = values[0], values[-1]
    # `NaN > 0` is False, so a NaN first point also yields NaN.
    if first > 0:
        return (last / first) * 100
    return np.nan


# Row-wise raw (unclipped) retention, computed by calculate_retention_raw.
df_raw.loc[:, "capacity_retention_percent_raw"] = df_raw.apply(
    calculate_retention_raw, axis=1
)

In [19]:
# Retention above 100% means the last capacity point exceeds the first one;
# these rows are listed for inspection.
over_100 = df_raw.loc[df_raw["capacity_retention_percent_raw"].gt(100)]
print(f"Retention > 100% rows: {len(over_100)}")
print(
    over_100.loc[:, ["sample_id", "param_id", "capacity_retention_percent_raw"]].head(20)
)

Retention > 100% rows: 81
      sample_id  param_id  capacity_retention_percent_raw
422           7         2                      135.304328
433           7        13                      130.269119
437           7        17                      113.460593
452           7        32                      120.163797
459           7        39                      130.719000
461           7        41                      123.482978
462           7        42                      149.947334
466           7        46                      121.961749
471           7        51                      126.605721
2821         47         1                      644.915027
2822         47         2                      202.137891
2825         47         5                      103.457117
2843         47        23                      812.293156
2845         47        25                      876.455893
2848         47        28                      108.799274
2857         47        37                     

In [20]:
# ============================================================================
# IDENTIFY MISSING EOL SAMPLES AND VERIFY IN SOURCE FILES
# ============================================================================

print("=" * 80)
print("IDENTIFYING MISSING EOL_CYCLE_MEASURED SAMPLES")
print("=" * 80)
print()

# Rows where eol_cycle_measured is missing (single .loc instead of chained indexing).
missing_eol = df_raw.loc[
    df_raw["eol_cycle_measured"].isnull(),
    ["sample_id", "param_id", "eol_cycle_measured"],
].copy()

# Guard the percentage against an empty df_raw.
missing_eol_pct = len(missing_eol) / len(df_raw) * 100 if len(df_raw) else 0.0
print(f"Total missing eol_cycle_measured: {len(missing_eol):,} rows")
print(f"Percentage: {missing_eol_pct:.2f}%")
print()

print("First 10 missing eol_cycle_measured samples:")
print(missing_eol.head(10).to_string(index=False))

# NOTE(review): this writes under BASE_DIR / "data" rather than the configured
# `data_base`; confirm the two are meant to be the same directory.
output_file = BASE_DIR / "data" / "missing_eol_cycle_samples.csv"
# DataFrame.to_csv does not create missing parent directories.
output_file.parent.mkdir(parents=True, exist_ok=True)
missing_eol.to_csv(output_file, index=False)
print(f"\n✅ Saved to: {output_file}")
print(f"   Total rows: {len(missing_eol):,}")

# ============================================================================
# VERIFY IN SOURCE PARQUET FILES
# ============================================================================

print("\n" + "=" * 80)
print("VERIFYING IN SOURCE PARQUET FILES")
print("=" * 80)
print()

# Spot-check the first few missing rows directly against their source files.
N_VERIFY = 5
sample_to_verify = missing_eol.head(N_VERIFY)

print(f"Verifying first {N_VERIFY} missing samples in source files...\n")

for _, row in sample_to_verify.iterrows():
    # iterrows may upcast the mixed int/float row to float; cast ids back.
    sample_id = int(row["sample_id"])
    param_id = int(row["param_id"])

    pq_file = parquet_dir / parquet_pattern.format(sample_id)
    if not pq_file.exists():
        print(
            f"Sample {sample_id}, Param {param_id}: ❌ Source file does not exist: {pq_file}\n"
        )
        continue

    df_source = pd.read_parquet(pq_file)
    param_row = df_source[df_source["param_id"] == param_id]
    if param_row.empty:
        print(
            f"Sample {sample_id}, Param {param_id}: ❌ Not found in source file\n"
        )
        continue

    eol_value = param_row["eol_cycle_measured"].values[0]
    eol_missing_in_source = pd.isna(eol_value)

    print(f"Sample {sample_id}, Param {param_id}:")
    print(f"  Source file: {pq_file.name}")
    print(f"  eol_cycle_measured in source: {eol_value}")
    print(f"  Is it NaN/None in source? {eol_missing_in_source}")

    # Other fields for context.
    if "nominal_capacity_Ah" in param_row.columns:
        nom_cap = param_row["nominal_capacity_Ah"].values[0]
        print(f"  nominal_capacity_Ah: {nom_cap}")

    if "capacity_trend_ah" in param_row.columns:
        cap_trend = param_row["capacity_trend_ah"].values[0]
        if isinstance(cap_trend, (list, np.ndarray)):
            print(f"  capacity_trend_ah length: {len(cap_trend)}")
        else:
            print(f"  capacity_trend_ah: {cap_trend}")

    # Only claim confirmation when the source really is missing; previously
    # this was printed unconditionally.
    if eol_missing_in_source:
        print("  ✅ CONFIRMED: Missing in source file!\n")
    else:
        print(
            "  ❌ MISMATCH: source has a value but df_raw does not - check the loading step!\n"
        )

# ============================================================================
# STATISTICS BY SAMPLE
# ============================================================================

print("=" * 80)
print("MISSING EOL STATISTICS BY SAMPLE")
print("=" * 80)
print()

# Missing-eol count per sample. Only samples with at least one missing row
# appear; ties are broken by sample_id so the "top 10" listing is reproducible.
missing_by_sample = (
    missing_eol.groupby("sample_id")
    .size()
    .reset_index(name="missing_count")
    .sort_values(["missing_count", "sample_id"], ascending=[False, True])
)

print(
    f'Samples with missing eol_cycle_measured: {len(missing_by_sample)} out of {df_raw["sample_id"].nunique()}'
)
print("\nTop 10 samples with most missing eol:")
print(missing_by_sample.head(10).to_string(index=False))

# describe() covers only samples with >=1 missing row, not every sample.
print("\nMissing eol distribution per sample (samples with >=1 missing only):")
print(missing_by_sample["missing_count"].describe())

# Samples where every parameter set is missing eol.
samples_all_missing = missing_by_sample[
    missing_by_sample["missing_count"] == PARAMS_PER_SAMPLE
]
if len(samples_all_missing) > 0:
    print(
        f"\n⚠️  Samples with ALL {PARAMS_PER_SAMPLE} params missing eol: {len(samples_all_missing)}"
    )
    print("Sample IDs:", samples_all_missing["sample_id"].tolist())
else:
    print("\n✅ No samples have ALL params missing eol")

# ============================================================================
# CROSS-CHECK: Do samples with missing eol have capacity data?
# ============================================================================

print("\n" + "=" * 80)
print("CROSS-CHECK: CAPACITY DATA FOR MISSING EOL SAMPLES")
print("=" * 80)
print()

missing_eol_full = df_raw[df_raw["eol_cycle_measured"].isnull()].copy()

if "capacity_trend_ah" in missing_eol_full.columns:
    # A trend counts as present only when it is a non-empty sequence. None,
    # scalar NaN and empty arrays are all "absent". The previous check only
    # caught None/empty, so a scalar NaN was counted as present and then
    # crashed len() in the example loop below.
    missing_eol_full["has_cap_trend"] = (
        missing_eol_full["capacity_trend_ah"]
        .apply(lambda x: x is not None and np.ndim(x) > 0 and len(x) > 0)
        .astype(bool)
    )

    print(
        f'Missing eol samples with capacity_trend_ah data: {missing_eol_full["has_cap_trend"].sum():,}'
    )
    print(
        f'Missing eol samples WITHOUT capacity_trend_ah: {(~missing_eol_full["has_cap_trend"]).sum():,}'
    )

    # Examples of rows missing eol that DO have capacity data. The variable
    # name is kept unchanged in case later cells reference it.
    print("\nSample rows (missing eol, but HAS capacity data):")
    has_both_issues = missing_eol_full[missing_eol_full["has_cap_trend"]].head(3)
    for _, row in has_both_issues.iterrows():
        cap_trend = row["capacity_trend_ah"]
        print(
            f'  Sample {row["sample_id"]}, Param {row["param_id"]}: capacity_trend length = {len(cap_trend)}'
        )

print("\n" + "=" * 80)
print("VERIFICATION COMPLETE")
print("=" * 80)

IDENTIFYING MISSING EOL_CYCLE_MEASURED SAMPLES

Total missing eol_cycle_measured: 2,670 rows
Percentage: 44.50%

First 10 missing eol_cycle_measured samples:
 sample_id  param_id  eol_cycle_measured
         0         7                 NaN
         0         8                 NaN
         0         9                 NaN
         0        10                 NaN
         0        13                 NaN
         0        14                 NaN
         0        15                 NaN
         0        21                 NaN
         0        25                 NaN
         0        34                 NaN

✅ Saved to: ../data/missing_eol_cycle_samples.csv
   Total rows: 2,670

VERIFYING IN SOURCE PARQUET FILES

Verifying first 5 missing samples in source files...

Sample 0, Param 7:
  Source file: results_rank_0.parquet
  eol_cycle_measured in source: nan
  Is it NaN/None in source? True
  nominal_capacity_Ah: 5.0
  capacity_trend_ah length: 26
  ✅ CONFIRMED: Missing in source file!

Sampl