In [3]:
import gzip
import pickle
import pandas as pd
import numpy as np
import warnings

import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle, FancyBboxPatch
from decifer_refactored.utility import (
    extract_numeric_property,
    get_unit_cell_volume,
    extract_volume,
)

# Global font settings
plt.rcParams.update({
    "font.size": 12,
    "axes.titlesize": 14,
    "axes.labelsize": 14,
    "xtick.labelsize": 10,
    "ytick.labelsize": 14,
    "legend.fontsize": 12,
})

# Suppress warnings
warnings.filterwarnings("ignore")

def calculate_statistics(df, columns, is_boolean=False):
    """
    Calculate mean, 95% confidence interval, and standard deviation for specified columns.
    """
    if is_boolean:
        df[columns] = df[columns].astype(bool)
        mean_vals = df[columns].mean() * 100
    else:
        mean_vals = df[columns].mean()

    n = len(df)
    z = 1.96  # z-score for 95% confidence
    proportions = df[columns].mean() if is_boolean else None
    ci_vals = (
        z * np.sqrt((proportions * (1 - proportions)) / n) * 100
        if is_boolean
        else None
    )
    std_vals = df[columns].std() * 100 if is_boolean else df[columns].std()

    return pd.DataFrame({
        'mean (%)' if is_boolean else 'mean': mean_vals,
        '95% CI (%)': ci_vals,
        'std (%)' if is_boolean else 'std': std_vals,
    })

def extract_validity_stats(df):
    validity_columns = [
        'formula_validity',
        'spacegroup_validity',
        'bond_length_validity',
        'site_multiplicity_validity',
        'validity',
    ]
    return calculate_statistics(df, validity_columns, is_boolean=True)

def extract_metrics_stats(df):
    metrics_columns = ['rwp', 'wd']
    return calculate_statistics(df, metrics_columns)

def process_file(file_path):
    """Process a single file and extract relevant metrics."""
    with gzip.open(file_path, 'rb') as f:
        df = pd.DataFrame(pickle.load(f))

    metrics_list = []
    for _, entry in df.iterrows():
        try:
            metrics = process_entry(entry)
            metrics_list.append(metrics)
        except Exception:
            continue

    return pd.DataFrame(metrics_list)

def process_entry(entry):
    """Process a single entry to extract metrics and validity."""
    def extract_cell_params(cif):
        a = extract_numeric_property(cif, '_cell_length_a')
        b = extract_numeric_property(cif, '_cell_length_b')
        c = extract_numeric_property(cif, '_cell_length_c')
        alpha = extract_numeric_property(cif, '_cell_angle_alpha')
        beta = extract_numeric_property(cif, '_cell_angle_beta')
        gamma = extract_numeric_property(cif, '_cell_angle_gamma')
        implied_vol = get_unit_cell_volume(a, b, c, alpha, beta, gamma)
        gen_vol = extract_volume(cif)
        return a, b, c, alpha, beta, gamma, implied_vol, gen_vol

    sample_params = extract_cell_params(entry['cif_sample'])
    gen_params = extract_cell_params(entry['cif_gen'])

    validity = {
        'formula_validity': entry['formula_validity'],
        'bond_length_validity': entry['bond_length_validity'],
        'spacegroup_validity': entry['spacegroup_validity'],
        'site_multiplicity_validity': entry['site_multiplicity_validity'],
        'validity': all(
            [
                entry['formula_validity'],
                entry['bond_length_validity'],
                entry['spacegroup_validity'],
                entry['site_multiplicity_validity'],
            ]
        ),
    }

    metrics = {
        **{f'sample_{key}': val for key, val in zip(['a', 'b', 'c', 'alpha', 'beta', 'gamma', 'implied_vol', 'gen_vol'], sample_params)},
        **{f'gen_{key}': val for key, val in zip(['a', 'b', 'c', 'alpha', 'beta', 'gamma', 'implied_vol', 'gen_vol'], gen_params)},
        **validity,
        'rwp': entry['rwp'],
        'wd': entry['wd'],
    }

    return metrics

def process_and_print_stats(file_paths):
    for path in file_paths:
        print(f"Processing: {path}")
        df = process_file(path)
        print(extract_validity_stats(df))
        print(extract_metrics_stats(df))
        print()

# File paths
file_paths = [
    # "../experiments/model__conditioned_mlp_augmentation__context_3076__robust/comparison_files_fullXRD__robust/deCIFer-U_(None).pkl.gz",
    # "../experiments/model__conditioned_mlp_augmentation__context_3076__robust/comparison_files_fullXRD__robust/deCIFer-U_(Comp).pkl.gz",
    # "../experiments/model__conditioned_mlp_augmentation__context_3076__robust/comparison_files_fullXRD__robust/deCIFer-U_(CompSG).pkl.gz",
    
    # "../experiments/model__conditioned_mlp_augmentation__context_3076__robust/comparison_files_fullXRD__robust/deCIFer_(None_N-0p00_B-0p05).pkl.gz",
    # "../experiments/model__conditioned_mlp_augmentation__context_3076__robust/comparison_files_fullXRD__robust/deCIFer_(Comp_N-0p00_B-0p05).pkl.gz",
    # "../experiments/model__conditioned_mlp_augmentation__context_3076__robust/comparison_files_fullXRD__robust/deCIFer_(CompSG_N-0p00_B-0p05).pkl.gz",

    # "../experiments/model__conditioned_mlp_augmentation__context_3076__robust/comparison_files_fullXRD__robust/deCIFer_(Comp_N-0p05_B-0p05).pkl.gz",
    # "../experiments/model__conditioned_mlp_augmentation__context_3076__robust/comparison_files_fullXRD__robust/deCIFer_(Comp_N-0p00_B-0p10).pkl.gz",
    # "../experiments/model__conditioned_mlp_augmentation__context_3076__robust/comparison_files_fullXRD__robust/deCIFer_(Comp_N-0p05_B-0p10).pkl.gz",

    # "../experiments/model__conditioned_mlp_augmentation__context_3076__robust/comparison_files_fullXRD__robust/deCIFer_(Comp_N-0p10_B-0p05).pkl.gz",
    # "../experiments/model__conditioned_mlp_augmentation__context_3076__robust/comparison_files_fullXRD__robust/deCIFer_(Comp_N-0p00_B-0p20).pkl.gz",
    # "../experiments/model__conditioned_mlp_augmentation__context_3076__robust/comparison_files_fullXRD__robust/deCIFer_(Comp_N-0p10_B-0p20).pkl.gz",

    "../experiments/model__conditioned_mlp_augmentation__context_3076__robust/comparison_files_fullXRD__robust/deCIFer_chili_(Comp_N-0p00_B-0p05).pkl.gz",
    "../experiments/model__conditioned_mlp_augmentation__context_3076__robust/comparison_files_fullXRD__robust/deCIFer_chili_(Comp_N-0p05_B-0p10).pkl.gz",
    "../experiments/model__conditioned_mlp_augmentation__context_3076__robust/comparison_files_fullXRD__robust/deCIFer_chili_(Comp_N-0p10_B-0p20).pkl.gz",
]

# Process and print stats
process_and_print_stats(file_paths)


Processing: ../experiments/model__conditioned_mlp_augmentation__context_3076__robust/comparison_files_fullXRD__robust/deCIFer_chili_(Comp_N-0p00_B-0p05).pkl.gz
                             mean (%)  95% CI (%)    std (%)
formula_validity            89.769496    0.423233  30.305673
spacegroup_validity         98.547929    0.167065  11.962688
bond_length_validity        43.460601    0.692294  49.571777
site_multiplicity_validity  86.108855    0.483014  34.586289
validity                    42.313160    0.689991  49.406846
         mean 95% CI (%)       std
rwp  0.702010       None  0.401569
wd   0.211406       None  0.211561

Processing: ../experiments/model__conditioned_mlp_augmentation__context_3076__robust/comparison_files_fullXRD__robust/deCIFer_chili_(Comp_N-0p05_B-0p10).pkl.gz
                             mean (%)  95% CI (%)    std (%)
formula_validity            90.113238    0.416890  29.849177
spacegroup_validity         98.517240    0.168807  12.086558
bond_length_validity     

In [16]:
from decifer_refactored.tokenizer import Tokenizer

In [17]:
Tokenizer().vocab_size

372