In [4]:
import numpy as np
import pandas as pd
import os
from IPython.display import display, Latex

def extract_validity_stats(df):
    """
    Compute the percentage of valid entries for each validity metric.

    Each validity column is cast to bool and averaged, so every value in the
    result is the share of truthy entries for that metric, scaled to 0-100.
    The cast is applied to a selected sub-frame; the caller's DataFrame is
    left unmodified.

    Parameters:
        df (pd.DataFrame): DataFrame containing the four ``validity.*`` columns.

    Returns:
        pd.Series: Percentage of valid entries, indexed by validity column name.

    Raises:
        KeyError: If any of the validity columns is missing from ``df``.
    """
    validity_columns = [
        'validity.formula',
        'validity.spacegroup',
        'validity.bond_length',
        'validity.site_multiplicity',
    ]

    # Cast the selection instead of assigning back into ``df`` so the caller's
    # columns keep their original dtypes.
    # NOTE(review): a NaN in a float column casts to True and is therefore
    # counted as valid -- confirm that is the intended treatment of missing data.
    validity_flags = df[validity_columns].astype(bool)

    # The mean of a boolean column is the fraction of True values.
    return validity_flags.mean() * 100

def dataset_validity_table(eval_paths):
    """
    Extract validity statistics from multiple datasets and prepare a comparison table.

    Each path is read with ``pd.read_parquet`` and summarised with
    ``extract_validity_stats``. The dataset label is the file's base name up to
    its first ``.``.

    Parameters:
        eval_paths (list): List of paths to evaluation files.

    Returns:
        pd.DataFrame: One row per dataset with the columns ``Dataset``,
        ``validity.formula``, ``validity.spacegroup``, ``validity.bond_length``
        and ``validity.site_multiplicity``. An empty ``eval_paths`` yields an
        empty DataFrame that still carries these columns.
    """
    table_columns = [
        'Dataset',
        'validity.formula',
        'validity.spacegroup',
        'validity.bond_length',
        'validity.site_multiplicity',
    ]
    rows = []

    # Process each dataset
    for path in eval_paths:
        dataset_name = os.path.basename(path).split('.')[0]

        # Load dataset and calculate validity stats
        validity_stats = extract_validity_stats(pd.read_parquet(path))

        # Build a plain dict row: writing the name into the float Series would
        # force every metric to object dtype.
        rows.append({'Dataset': dataset_name, **validity_stats.to_dict()})

    # Explicit columns fix the column order and keep the empty case well-formed
    # (selecting columns from a column-less frame would raise KeyError).
    return pd.DataFrame(rows, columns=table_columns)

def display_latex_table(results_df):
    """
    Render the validity comparison table in the notebook as LaTeX.

    The table is assembled as an ``array`` inside an ``aligned`` environment,
    with one row per DataFrame row and every metric formatted to two decimals,
    then shown through IPython's ``display(Latex(...))``.

    Parameters:
        results_df (pd.DataFrame): Table with a ``Dataset`` column and the four
            ``validity.*`` percentage columns.
    """
    # Caption plus column headings; the array is closed after the data rows.
    header = r"""
\begin{aligned}
& \text{Table: Validity Metrics Comparison between Datasets (Percentage of Valid Entries)}\\
&\begin{array}{|c|c|c|c|c|}
\hline
\text{Dataset} & \text{Formula Validity (%)} & \text{Spacegroup Validity (%)} & \text{Bond Length Validity (%)} & \text{Site Multiplicity Validity (%)} \\
\hline
"""
    footer = r"\end{array}\end{aligned}"

    metric_keys = (
        'validity.formula',
        'validity.spacegroup',
        'validity.bond_length',
        'validity.site_multiplicity',
    )

    # Collect the row fragments and join once at the end.
    fragments = []
    for _, record in results_df.iterrows():
        cells = " & ".join(f"{record[key]:.2f}" for key in metric_keys)
        fragments.append(f"\\text{{{record['Dataset']}}} & {cells} \\\\\n")
        # Horizontal rule after every data row.
        fragments.append("\\hline\n")

    display(Latex(header + "".join(fragments) + footer))

# Example usage: evaluation files to compare (update with your eval file paths).
# Each path is loaded with pd.read_parquet inside dataset_validity_table, and
# its base name up to the first '.' becomes the row label in the table.
eval_paths = [
    '../cross-contamination/deciferdataset_experiment/boundarymasking/boundary_masking_100.eval',
    '../cross-contamination/deciferdataset_experiment/no_boundarymasking/no_boundary_masking_100.eval',
    '../nomodel/crystal_1k/nmax8_lmax5/crystal_train_1000.eval'
]

# Build the comparison table: one row per dataset, one column per validity metric
results_df = dataset_validity_table(eval_paths)

# Render the table in the notebook as LaTeX
display_latex_table(results_df)


<IPython.core.display.Latex object>