In [1]:

import pandas as pd
import numpy as np
import json
import re
from time import time
from copy import deepcopy
from pprint import pprint
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.svm import SVR
from sklearn.linear_model import LassoLarsCV
from sklearn.metrics import r2_score, mean_absolute_error
from sklearn.model_selection import cross_validate, RandomizedSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)

class ReactionDataLoader:
    def __init__(self, seed=1702913563):
        self.seed = seed
        np.random.seed(seed)
    
    def _extract_ids(self, df):
        """
        Extract unique catalyst and substrate IDs from reaction data
        
        Input DataFrame format:
        - Column 0: reaction_handle (e.g., '99_i_1_A')
        - Column 1: catalyst_id (e.g., '99_i')
        - Column 2: imine_id 
        - Column 3: thiol_id
        - Column 4: product_handle
        """
        # Extract unique catalyst IDs (e.g., '99_i')
        catalyst_ids = df.iloc[:, 1].unique()
        
        # Extract unique imine IDs and thiol IDs
        imine_ids = df.iloc[:, 2].unique()
        thiol_ids = df.iloc[:, 3].unique()
        
        print("\nExtracted IDs Summary:")
        print(f"Catalysts: {sorted(catalyst_ids)}")
        print(f"Imine IDs: {sorted(imine_ids)}")
        print(f"Thiol IDs: {sorted(thiol_ids)}")
        
        # For validation, print example mappings
        print("\nExample reaction mappings:")
        for i in range(min(5, len(df))):
            row = df.iloc[i]
            print(f"Reaction: {row.iloc[0]} -> Catalyst: {row.iloc[1]}, Imine: {row.iloc[2]}, Thiol: {row.iloc[3]}")
        
        return set(catalyst_ids), set(imine_ids), set(thiol_ids)
    
    def load_splits(self):
        """Load and process all CSV files with validation"""
        column_names = ['reaction_handle', 'catalyst_id', 'imine_id', 'thiol_id', 'product_handle']
        usecols = list(range(5))
        
        # Load datasets
        train_df = pd.read_csv('data/ER_DATA_train_set.csv', names=column_names, usecols=usecols)
        cat_test_df = pd.read_csv('data/ER_DATA_cat_test.csv', names=column_names, usecols=usecols)
        sub_test_df = pd.read_csv('data/ER_DATA_sub_test.csv', names=column_names, usecols=usecols)
        subcat_test_df = pd.read_csv('data/ER_DATA_subcat_test.csv', names=column_names, usecols=usecols)
        
        # Extract IDs from each dataset
        print("\nAnalyzing Training Set:")
        train_cats, train_imines, train_thiols = self._extract_ids(train_df)
        
        print("\nAnalyzing Catalyst Test Set:")
        test_cats, test_imines, test_thiols = self._extract_ids(cat_test_df)
        
        print("\nAnalyzing Substrate Test Set:")
        sub_cats, sub_imines, sub_thiols = self._extract_ids(sub_test_df)
        
        print("\nAnalyzing Catalyst and Substrate Test Set:")
        subcat_cats, subcat_imines, subcat_thiols = self._extract_ids(subcat_test_df)
        
        # Check for overlaps
        cat_overlap = train_cats.intersection(test_cats)
        imine_overlap = train_imines.intersection(test_imines)
        thiol_overlap = train_thiols.intersection(test_thiols)
        
        # Print validation results
        print("\nValidating splits:")
        if cat_overlap:
            print(f"Warning: Catalyst overlap between train and test: {cat_overlap}")
        if imine_overlap:
            print(f"Warning: Imine overlap between train and test: {imine_overlap}")
        if thiol_overlap:
            print(f"Warning: Thiol overlap between train and test: {thiol_overlap}")
        
        # Print dataset statistics
        print(f"\nDataset statistics:")
        print(f"Training set: {len(train_df)} reactions, {len(train_cats)} catalysts")
        print(f"Catalyst test set: {len(cat_test_df)} reactions, {len(test_cats)} catalysts")
        print(f"Substrate test set: {len(sub_test_df)} reactions")
        print(f"Subcat test set: {len(subcat_test_df)} reactions")
        
        return {
            'train_handles': train_df['reaction_handle'].tolist(),
            'unseen_cat_handles': cat_test_df['reaction_handle'].tolist(),
            'unseen_subs_handles': sub_test_df['reaction_handle'].tolist(),
            'unseen_cat_and_subs_handles': subcat_test_df['reaction_handle'].tolist(),
            'train_catalysts': train_cats,
            'train_substrates': {
                'imines': train_imines,
                'thiols': train_thiols
            },
            'test_catalysts': test_cats,
            'test_substrates': {
                'imines': test_imines,
                'thiols': test_thiols
            }
        }

def load_embeddings(file_path):
    """Load embeddings from JSON file and strip family prefix"""
    with open(file_path, 'r') as f:
        raw_embeddings = json.load(f)
    
    embeddings = {}
    family_pattern = re.compile(r'^family\d+_')
    for key, value in raw_embeddings.items():
        stripped_key = family_pattern.sub('', key)
        embeddings[stripped_key] = np.array(value)
    
    return embeddings

def multiplot_and_print(estimator, X_train, Y_train, comb_partitions, title, verbose=1, file_dpi=800):
    """Create performance plots and print metrics"""
    print(f"\n{title}")
    print(f"X_train shape: {X_train.shape}, Y_train shape: {Y_train.shape}")

    Y_train = Y_train.values.ravel()
    predicted_train = estimator.predict(X_train)
    r2_train = r2_score(Y_train, predicted_train)
    mae_train = mean_absolute_error(Y_train, predicted_train)

    if verbose:
        print(f"Train R^2: {r2_train:0.5f}, train MAE: {mae_train:0.5f}")

    fig, ax = plt.subplots(figsize=(10, 8))
    
    # Plot training data
    ax.scatter(
        Y_train,
        predicted_train,
        color="gray",
        alpha=0.5,
        label=f"Train (R^2= {r2_train:.3f}, MAE={mae_train:.3f})"
    )

    # Plot test sets
    for part_name, (X_test, Y_test, color) in comb_partitions.items():
        print(f"\nPart name: {part_name}")
        print(f"X_test shape: {X_test.shape}, Y_test shape: {Y_test.shape}")

        Y_test = Y_test.values.ravel()
        predicted_test = estimator.predict(X_test)
        r2_test = r2_score(Y_test, predicted_test)
        mae_test = mean_absolute_error(Y_test, predicted_test)

        if verbose:
            print(f"Test R^2: {r2_test:0.5f}, test MAE: {mae_test:0.5f}")

        ax.scatter(
            Y_test,
            predicted_test,
            color=color,
            alpha=0.7,
            label=f"{part_name} (R^2= {r2_test:.3f}, MAE={mae_test:.3f})"
        )
    
    ax.set_title(title)
    ax.set_xlabel("Observed $\Delta \Delta G^\u2021$ [kcal/mol]")
    ax.set_ylabel("Predicted $\Delta \Delta G^\u2021$ [kcal/mol]")
    ax.set_xlim(-3, 3)
    ax.set_ylim(-3, 3)
    ax.plot([-3, 3], [-3, 3], 'k--', alpha=0.5)

    plt.legend()
    plt.tight_layout()
    plt.savefig(f"{title}.png", dpi=file_dpi)
    plt.close()

def main():
    # Load embeddings
    embeddings = load_embeddings('/Users/utkarsh/MMLI/equicat/epoch_large/final_molecule_embeddings.json')
    print(f"Loaded embeddings for {len(embeddings)} entities")
    
    # Initialize data loader
    data_loader = ReactionDataLoader()
    splits = data_loader.load_splits()
    
    # Load Y data
    Y_df = pd.read_csv('/Users/utkarsh/MMLI/equicat/src/study 2/data/Y_DATA.csv', dtype={
        'catalyst_id': str,
        'imine_id': str,
        'thiol_id': str,
        'product_id': str
    })
    print(f"Loaded Y data with {len(Y_df)} rows")
    
    # Create X data using embeddings
    X_data = []
    Y_data = []
    valid_reaction_handles = []
    missing_ids = {}

    for _, row in Y_df.iterrows():
        catalyst_id = row['catalyst_id']
        imine_id = row['imine_id']
        thiol_id = row['thiol_id']
        product_id = row['product_id']
        reaction_handle = row['reaction_handle']
        
        # Check if all required embeddings exist
        required_ids = [catalyst_id, imine_id, thiol_id, product_id]
        if all(id in embeddings for id in required_ids):
            combined_embedding = np.concatenate([
                embeddings[catalyst_id],
                embeddings[imine_id],
                embeddings[thiol_id],
                embeddings[product_id]
            ])
            
            X_data.append(combined_embedding)
            Y_data.append(row['selectivity_ddGact_kcal'])
            valid_reaction_handles.append(reaction_handle)
        else:
            missing = [id for id in required_ids if id not in embeddings]
            missing_ids[reaction_handle] = missing
    
    # Print missing IDs summary
    if missing_ids:
        print("\nMissing embeddings summary:")
        for handle, missing in missing_ids.items():
            print(f"{handle}: Missing {missing}")

    # Convert to DataFrame
    X_df = pd.DataFrame(X_data, index=valid_reaction_handles)
    Y_series = pd.Series(Y_data, index=valid_reaction_handles)

    # Create train/test splits
    valid_train_handles = [h for h in splits['train_handles'] if h in X_df.index]
    if not valid_train_handles:
        raise ValueError("No valid training examples found!")
        
    X_train = X_df.loc[valid_train_handles]
    Y_train = Y_df.loc[Y_df['reaction_handle'].isin(valid_train_handles), ["selectivity_ddGact_kcal"]]

    comb_partitions = {}

    # Create test sets
    for split_name, handles, color in [
        ("Unseen substrates", splits['unseen_subs_handles'], "green"),
        ("Unseen catalysts", splits['unseen_cat_handles'], "purple"),
        ("Unseen subs and cats", splits['unseen_cat_and_subs_handles'], "blue")
    ]:
        valid_handles = [h for h in handles if h in X_df.index]
        if valid_handles:
            X_test = X_df.loc[valid_handles]
            Y_test = Y_df.loc[Y_df['reaction_handle'].isin(valid_handles), ["selectivity_ddGact_kcal"]]
            comb_partitions[split_name] = (X_test, Y_test, color)

    # Print dataset sizes
    print("\nDataset sizes:")
    print(f"Training set: {len(X_train)} samples")
    for name, (X_test, _, _) in comb_partitions.items():
        print(f"{name} test set: {len(X_test)} samples")

    # Train models
    base_pipeline = Pipeline([('scaler', StandardScaler())])
    
    models = {
        "GBR": GradientBoostingRegressor(n_estimators=1000, ccp_alpha=1e-3),
        "SVR": SVR(kernel='poly', degree=3, epsilon=0.05),
        "RF": RandomForestRegressor(n_estimators=100, n_jobs=-1, random_state=42),
        "LL": LassoLarsCV(max_iter=5000, cv=5, n_jobs=-1),
    }

    for model_name, model in models.items():
        pipe = deepcopy(base_pipeline)
        pipe.steps.append(('model', model))
        
        print(f"\nTraining {model_name}...")
        t0 = time()
        pipe.fit(X_train, Y_train)
        print(f"{model_name}: Fitting took {time() - t0:.3f}s.")

        scores = cross_validate(
            pipe, X_train, Y_train,
            cv=5, 
            scoring=['neg_mean_absolute_error', 'r2'],
            return_train_score=True
        )
        
        print(f"\n{model_name} cross-validation scores:")
        pprint(scores)
        print(f"{model_name} Q_2: {np.mean(scores['test_r2']):.5f}")

        multiplot_and_print(pipe, X_train, Y_train, comb_partitions, f"BPA_Combinatorial_{model_name}")

    print("\nAll models trained and evaluated successfully.")

if __name__ == "__main__":
    main()


Loaded embeddings for 835 entities

Analyzing Training Set:

Extracted IDs Summary:
Catalysts: ['144_i', '157_i', '182_i', '202_i', '205_i', '207_i', '230_i', '242_i', '245_vi', '246_vi', '249_i', '251_vi', '253_i', '262_vi', '276_i', '286_vi', '328_i', '365_i', '71_vi', '72_i', '76_vi', '99_i', '99_vi']
Imine IDs: [1, 2, 3, 4]
Thiol IDs: ['A', 'B', 'C', 'D']

Example reaction mappings:
Reaction: 99_i_1_A -> Catalyst: 99_i, Imine: 1, Thiol: A
Reaction: 99_i_1_B -> Catalyst: 99_i, Imine: 1, Thiol: B
Reaction: 99_i_1_C -> Catalyst: 99_i, Imine: 1, Thiol: C
Reaction: 99_i_1_D -> Catalyst: 99_i, Imine: 1, Thiol: D
Reaction: 99_i_2_A -> Catalyst: 99_i, Imine: 2, Thiol: A

Analyzing Catalyst Test Set:

Extracted IDs Summary:
Catalysts: ['145_i', '147_i', '166_i', '1_i', '210_i', '223_i', '229_i', '229_vi', '242_vi', '245_i', '29_i', '365_vi', '371_i', '382_i', '5_i', '61_i', '73_i', '7_i', '87_i']
Imine IDs: [1, 2, 3, 4]
Thiol IDs: ['A', 'B', 'C', 'D']

Example reaction mappings:
Reaction: 1

  y = column_or_1d(y, warn=True)  # TODO: Is this still required?


GBR: Fitting took 11.535s.


  y = column_or_1d(y, warn=True)  # TODO: Is this still required?
  y = column_or_1d(y, warn=True)  # TODO: Is this still required?
  y = column_or_1d(y, warn=True)  # TODO: Is this still required?
  y = column_or_1d(y, warn=True)  # TODO: Is this still required?
  y = column_or_1d(y, warn=True)  # TODO: Is this still required?



GBR cross-validation scores:
{'fit_time': array([9.24051595, 9.23069692, 9.10565996, 8.9919939 , 9.04043126]),
 'score_time': array([0.00565386, 0.00317383, 0.00313592, 0.00306106, 0.00306892]),
 'test_neg_mean_absolute_error': array([-0.31479732, -0.62264991, -0.60872172, -0.64798513, -0.2541745 ]),
 'test_r2': array([ 0.59709079, -1.88427488,  0.24615078, -1.25467921,  0.30732066]),
 'train_neg_mean_absolute_error': array([-0.12515555, -0.12512943, -0.12152705, -0.11743514, -0.12523887]),
 'train_r2': array([0.94164304, 0.94003167, 0.9317802 , 0.94413928, 0.94763589])}
GBR Q_2: -0.39768

BPA_Combinatorial_GBR
X_train shape: (368, 768), Y_train shape: (368, 1)
Train R^2: 0.93440, train MAE: 0.13133

Part name: Unseen substrates
X_test shape: (207, 768), Y_test shape: (207, 1)
Test R^2: 0.86794, test MAE: 0.18634

Part name: Unseen catalysts
X_test shape: (304, 768), Y_test shape: (304, 1)
Test R^2: -0.26881, test MAE: 0.57460

Part name: Unseen subs and cats
X_test shape: (171, 768),

  y = column_or_1d(y, warn=True)
  y = column_or_1d(y, warn=True)
  y = column_or_1d(y, warn=True)
  y = column_or_1d(y, warn=True)
  y = column_or_1d(y, warn=True)
  y = column_or_1d(y, warn=True)



SVR cross-validation scores:
{'fit_time': array([0.0216589 , 0.01919103, 0.01978302, 0.0206399 , 0.02373672]),
 'score_time': array([0.00657678, 0.00590801, 0.00576997, 0.00599885, 0.00706029]),
 'test_neg_mean_absolute_error': array([-0.35726496, -0.54442166, -0.6767849 , -0.56214006, -0.16707258]),
 'test_r2': array([ 0.42335811, -1.04427626,  0.14086591, -0.66100529,  0.71002448]),
 'train_neg_mean_absolute_error': array([-0.09139003, -0.08563362, -0.0578614 , -0.07326719, -0.08815888]),
 'train_r2': array([0.91435815, 0.91945196, 0.97784189, 0.93510517, 0.9211922 ])}
SVR Q_2: -0.08621

BPA_Combinatorial_SVR
X_train shape: (368, 768), Y_train shape: (368, 1)
Train R^2: 0.92882, train MAE: 0.08046

Part name: Unseen substrates
X_test shape: (207, 768), Y_test shape: (207, 1)
Test R^2: 0.38127, test MAE: 0.42714

Part name: Unseen catalysts
X_test shape: (304, 768), Y_test shape: (304, 1)
Test R^2: 0.15026, test MAE: 0.47640

Part name: Unseen subs and cats
X_test shape: (171, 768), 

  return fit_method(estimator, *args, **kwargs)


RF: Fitting took 0.647s.


  return fit_method(estimator, *args, **kwargs)
  return fit_method(estimator, *args, **kwargs)
  return fit_method(estimator, *args, **kwargs)
  return fit_method(estimator, *args, **kwargs)
  return fit_method(estimator, *args, **kwargs)



RF cross-validation scores:
{'fit_time': array([0.4804008 , 0.57101226, 0.42070127, 0.44428372, 0.4961772 ]),
 'score_time': array([0.03416491, 0.01754069, 0.01682091, 0.01505995, 0.01723504]),
 'test_neg_mean_absolute_error': array([-0.27137523, -0.56610082, -0.58558363, -0.65503168, -0.21024873]),
 'test_r2': array([ 0.66478876, -1.26622686,  0.26937927, -1.04697396,  0.57719931]),
 'train_neg_mean_absolute_error': array([-0.05434405, -0.05334687, -0.05490797, -0.04508637, -0.0518704 ]),
 'train_r2': array([0.98866319, 0.98870448, 0.98417252, 0.99160429, 0.9906733 ])}
RF Q_2: -0.16037

BPA_Combinatorial_RF
X_train shape: (368, 768), Y_train shape: (368, 1)
Train R^2: 0.99022, train MAE: 0.04974

Part name: Unseen substrates
X_test shape: (207, 768), Y_test shape: (207, 1)
Test R^2: 0.80193, test MAE: 0.23077

Part name: Unseen catalysts
X_test shape: (304, 768), Y_test shape: (304, 1)
Test R^2: -0.71711, test MAE: 0.65875

Part name: Unseen subs and cats
X_test shape: (171, 768), Y_

  y = column_or_1d(y, warn=True)


LL: Fitting took 2.201s.


  y = column_or_1d(y, warn=True)
  y = column_or_1d(y, warn=True)
  y = column_or_1d(y, warn=True)
  y = column_or_1d(y, warn=True)
  y = column_or_1d(y, warn=True)



LL cross-validation scores:
{'fit_time': array([0.58301973, 0.18602276, 0.53446794, 0.23432088, 0.22567701]),
 'score_time': array([0.00799608, 0.00437307, 0.00457501, 0.00413203, 0.00588179]),
 'test_neg_mean_absolute_error': array([-0.33213152, -0.57443644, -0.70303874, -0.58482683, -0.22057564]),
 'test_r2': array([ 0.46362135, -1.08709801,  0.08616037, -1.00528526,  0.48659229]),
 'train_neg_mean_absolute_error': array([-0.20164729, -0.29181004, -0.32382147, -0.16302418, -0.18564873]),
 'train_r2': array([0.83939166, 0.64712118, 0.53307245, 0.88346243, 0.88184804])}
LL Q_2: -0.21120

BPA_Combinatorial_LL
X_train shape: (368, 768), Y_train shape: (368, 1)
Train R^2: 0.86981, train MAE: 0.17479

Part name: Unseen substrates
X_test shape: (207, 768), Y_test shape: (207, 1)
Test R^2: 0.65548, test MAE: 0.31333

Part name: Unseen catalysts
X_test shape: (304, 768), Y_test shape: (304, 1)
Test R^2: -0.05518, test MAE: 0.60091

Part name: Unseen subs and cats
X_test shape: (171, 768), Y_

The below code is the one where only 5 and E are considered in the test set. 

In [2]:

import pandas as pd
import numpy as np
import json
import re
from time import time
from copy import deepcopy
from pprint import pprint
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.svm import SVR
from sklearn.linear_model import LassoLarsCV
from sklearn.metrics import r2_score, mean_absolute_error
from sklearn.model_selection import cross_validate, RandomizedSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
import yaml

pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)

class ReactionDataLoader:
    def __init__(self, seed=1702913563):
        self.seed = seed
        np.random.seed(seed)
    
    def _extract_ids(self, df):
        """
        Extract unique catalyst and substrate IDs from reaction data
        
        Input DataFrame format:
        - Column 0: reaction_handle (e.g., '99_i_1_A')
        - Column 1: catalyst_id (e.g., '99_i')
        - Column 2: imine_id 
        - Column 3: thiol_id
        - Column 4: product_handle
        """
        catalyst_ids = df['catalyst_id'].unique()
        imine_ids = df['imine_id'].unique()
        thiol_ids = df['thiol_id'].unique()
        
        print("\nExtracted IDs Summary:")
        print(f"Catalysts: {sorted(catalyst_ids)}")
        print(f"Imine IDs: {sorted(imine_ids)}")
        print(f"Thiol IDs: {sorted(thiol_ids)}")
        
        print("\nExample reaction mappings:")
        for i in range(min(5, len(df))):
            row = df.iloc[i]
            print(f"Reaction: {row['reaction_handle']} -> "
                  f"Catalyst: {row['catalyst_id']}, "
                  f"Imine: {row['imine_id']}, "
                  f"Thiol: {row['thiol_id']}")
        
        return set(catalyst_ids), set(imine_ids), set(thiol_ids)

    def _create_strict_splits(self, df):
        """
        Create splits ensuring proper separation of catalysts and substrates.
        
        Training:
            - Catalysts: Training set catalysts
            - Substrates: Imines [1,2,3,4], Thiols [A,B,C,D]
        
        Test:
            - New Catalysts: Test set catalysts
            - New Substrates: ONLY Imine [5], ONLY Thiol [E]
        """
        # Define train/test splits
        train_imines = [1, 2, 3, 4]
        train_thiols = ['A', 'B', 'C', 'D']
        test_imines = [5]
        test_thiols = ['E']

        # Split catalysts (using deterministic split from dataset)
        train_catalysts = [
            '144_i', '157_i', '182_i', '202_i', '205_i', '207_i', '230_i', 
            '242_i', '245_vi', '246_vi', '249_i', '251_vi', '253_i', '262_vi', 
            '276_i', '286_vi', '328_i', '365_i', '71_vi', '72_i', '76_vi', 
            '99_i', '99_vi'
        ]
        test_catalysts = [cat for cat in df['catalyst_id'].unique() if cat not in train_catalysts]

        # Create masks for each quadrant
        # Training Set: Training catalysts + Training substrates only
        train_mask = (
            df['catalyst_id'].isin(train_catalysts) & 
            df['imine_id'].isin(train_imines) & 
            df['thiol_id'].isin(train_thiols)
        )
        
        # Catalyst Test Set: Test catalysts + Training substrates only
        cat_test_mask = (
            df['catalyst_id'].isin(test_catalysts) & 
            df['imine_id'].isin(train_imines) & 
            df['thiol_id'].isin(train_thiols)
        )
        
        # Substrate Test Set: Training catalysts + ONLY test substrates
        sub_test_mask = (
            df['catalyst_id'].isin(train_catalysts) & 
            df['imine_id'].isin(test_imines) &  # Only imine 5
            df['thiol_id'].isin(test_thiols)    # Only thiol E
        )
        
        # Combined Test Set: Test catalysts + ONLY test substrates
        subcat_test_mask = (
            df['catalyst_id'].isin(test_catalysts) & 
            df['imine_id'].isin(test_imines) &  # Only imine 5
            df['thiol_id'].isin(test_thiols)    # Only thiol E
        )

        splits = {
            'train': df[train_mask].copy(),
            'cat_test': df[cat_test_mask].copy(),
            'sub_test': df[sub_test_mask].copy(),
            'subcat_test': df[subcat_test_mask].copy()
        }

        # Verify splits
        print("\nDetailed Split Verification:")
        for name, split_df in splits.items():
            print(f"\n{name} split:")
            cats = sorted(split_df['catalyst_id'].unique())
            imines = sorted(split_df['imine_id'].unique())
            thiols = sorted(split_df['thiol_id'].unique())
            
            print(f"Number of reactions: {len(split_df)}")
            print(f"Catalysts: {cats}")
            print(f"Imines: {imines}")
            print(f"Thiols: {thiols}")
            
            if name in ['sub_test', 'subcat_test']:
                if set(imines) != set(test_imines) or set(thiols) != set(test_thiols):
                    raise ValueError(f"{name} contains non-test substrates!")
                    
                reactions = split_df.apply(
                    lambda x: f"{x['catalyst_id']}_{x['imine_id']}_{x['thiol_id']}", 
                    axis=1
                ).tolist()
                print("\nExample reactions in this split:")
                for r in reactions[:25]:
                    print(f"  {r}")

        return splits

    def load_splits(self):
        """Load and process all data with proper validation"""
        # Load full dataset
        df = pd.read_csv('data/Y_DATA.csv')
        
        # Create strict splits
        splits = self._create_strict_splits(df)
        
        # Extract and validate IDs for each split
        print("\nAnalyzing Training Set:")
        train_cats, train_imines, train_thiols = self._extract_ids(splits['train'])
        
        print("\nAnalyzing Catalyst Test Set:")
        test_cats, test_imines, test_thiols = self._extract_ids(splits['cat_test'])
        
        print("\nAnalyzing Substrate Test Set:")
        sub_cats, sub_imines, sub_thiols = self._extract_ids(splits['sub_test'])
        
        print("\nAnalyzing Catalyst and Substrate Test Set:")
        subcat_cats, subcat_imines, subcat_thiols = self._extract_ids(splits['subcat_test'])
        
        return {
            'train_handles': splits['train']['reaction_handle'].tolist(),
            'unseen_cat_handles': splits['cat_test']['reaction_handle'].tolist(),
            'unseen_subs_handles': splits['sub_test']['reaction_handle'].tolist(),
            'unseen_cat_and_subs_handles': splits['subcat_test']['reaction_handle'].tolist(),
            'train_catalysts': train_cats,
            'train_substrates': {
                'imines': train_imines,
                'thiols': train_thiols
            },
            'test_catalysts': test_cats,
            'test_substrates': {
                'imines': sub_imines,
                'thiols': sub_thiols
            }
        }

def load_embeddings(file_path):
    """Load embeddings from JSON file and strip family prefix"""
    with open(file_path, 'r') as f:
        raw_embeddings = json.load(f)
    
    embeddings = {}
    family_pattern = re.compile(r'^family\d+_')
    for key, value in raw_embeddings.items():
        stripped_key = family_pattern.sub('', key)
        embeddings[stripped_key] = np.array(value)
    
    return embeddings

def multiplot_and_print(estimator, X_train, Y_train, comb_partitions, title, verbose=1, file_dpi=800):
    """Create performance plots and print metrics"""
    print(f"\n{title}")
    print(f"X_train shape: {X_train.shape}, Y_train shape: {Y_train.shape}")

    Y_train = Y_train.values.ravel()
    predicted_train = estimator.predict(X_train)
    r2_train = r2_score(Y_train, predicted_train)
    mae_train = mean_absolute_error(Y_train, predicted_train)

    if verbose:
        print(f"Train R^2: {r2_train:0.5f}, train MAE: {mae_train:0.5f}")

    fig, ax = plt.subplots(figsize=(10, 8))
    
    # Plot training data
    ax.scatter(
        Y_train,
        predicted_train,
        color="gray",
        alpha=0.5,
        label=f"Train (R^2= {r2_train:.3f}, MAE={mae_train:.3f})"
    )

    # Plot test sets
    for part_name, (X_test, Y_test, color) in comb_partitions.items():
        print(f"\nPart name: {part_name}")
        print(f"X_test shape: {X_test.shape}, Y_test shape: {Y_test.shape}")

        Y_test = Y_test.values.ravel()
        predicted_test = estimator.predict(X_test)
        r2_test = r2_score(Y_test, predicted_test)
        mae_test = mean_absolute_error(Y_test, predicted_test)

        if verbose:
            print(f"Test R^2: {r2_test:0.5f}, test MAE: {mae_test:0.5f}")

        ax.scatter(
            Y_test,
            predicted_test,
            color=color,
            alpha=0.7,
            label=f"{part_name} (R^2= {r2_test:.3f}, MAE={mae_test:.3f})"
        )
    
    ax.set_title(title)
    ax.set_xlabel("Observed $\Delta \Delta G^\u2021$ [kcal/mol]")
    ax.set_ylabel("Predicted $\Delta \Delta G^\u2021$ [kcal/mol]")
    ax.set_xlim(-3, 3)
    ax.set_ylim(-3, 3)
    ax.plot([-3, 3], [-3, 3], 'k--', alpha=0.5)

    plt.legend()
    plt.tight_layout()
    plt.savefig(f"{title}.png", dpi=file_dpi)
    plt.close()

def main():
    # Load embeddings
    embeddings = load_embeddings('/Users/utkarsh/MMLI/equicat/epoch_large/final_molecule_embeddings.json')
    print(f"Loaded embeddings for {len(embeddings)} entities")
    
    # Initialize data loader
    data_loader = ReactionDataLoader()
    splits = data_loader.load_splits()
    
    # Load Y data
    Y_df = pd.read_csv('/Users/utkarsh/MMLI/equicat/src/study 2/data/Y_DATA.csv', dtype={
        'catalyst_id': str,
        'imine_id': str,
        'thiol_id': str,
        'product_id': str
    })
    print(f"Loaded Y data with {len(Y_df)} rows")
    
    # Create X data using embeddings
    X_data = []
    Y_data = []
    valid_reaction_handles = []
    missing_ids = {}

    for _, row in Y_df.iterrows():
        catalyst_id = row['catalyst_id']
        imine_id = row['imine_id']
        thiol_id = row['thiol_id']
        product_id = row['product_id']
        reaction_handle = row['reaction_handle']
        
        # Check if all required embeddings exist
        required_ids = [catalyst_id, imine_id, thiol_id, product_id]
        if all(id in embeddings for id in required_ids):
            combined_embedding = np.concatenate([
                embeddings[catalyst_id],
                embeddings[imine_id],
                embeddings[thiol_id],
                embeddings[product_id]
            ])
            
            X_data.append(combined_embedding)
            Y_data.append(row['selectivity_ddGact_kcal'])
            valid_reaction_handles.append(reaction_handle)
        else:
            missing = [id for id in required_ids if id not in embeddings]
            missing_ids[reaction_handle] = missing
    
    # Print missing IDs summary
    if missing_ids:
        print("\nMissing embeddings summary:")
        for handle, missing in missing_ids.items():
            print(f"{handle}: Missing {missing}")

    # Convert to DataFrame
    X_df = pd.DataFrame(X_data, index=valid_reaction_handles)
    Y_series = pd.Series(Y_data, index=valid_reaction_handles)

    # Create train/test splits
    valid_train_handles = [h for h in splits['train_handles'] if h in X_df.index]
    if not valid_train_handles:
        raise ValueError("No valid training examples found!")
        
    X_train = X_df.loc[valid_train_handles]
    Y_train = Y_df.loc[Y_df['reaction_handle'].isin(valid_train_handles), ["selectivity_ddGact_kcal"]]

    comb_partitions = {}

    # Create test sets
    for split_name, handles, color in [
        ("Unseen substrates", splits['unseen_subs_handles'], "green"),
        ("Unseen catalysts", splits['unseen_cat_handles'], "purple"),
        ("Unseen subs and cats", splits['unseen_cat_and_subs_handles'], "blue")
    ]:
        valid_handles = [h for h in handles if h in X_df.index]
        if valid_handles:
            X_test = X_df.loc[valid_handles]
            Y_test = Y_df.loc[Y_df['reaction_handle'].isin(valid_handles), ["selectivity_ddGact_kcal"]]
            comb_partitions[split_name] = (X_test, Y_test, color)

    # Print dataset sizes
    print("\nDataset sizes:")
    print(f"Training set: {len(X_train)} samples")
    for name, (X_test, _, _) in comb_partitions.items():
        print(f"{name} test set: {len(X_test)} samples")

    # Train models
    base_pipeline = Pipeline([('scaler', StandardScaler())])
    
    models = {
        "GBR": GradientBoostingRegressor(n_estimators=500, ccp_alpha=1e-3),
        "SVR": SVR(kernel='poly', degree=3, epsilon=0.05),
        "RF": RandomForestRegressor(n_estimators=500, n_jobs=-1, random_state=7),
        "LL": LassoLarsCV(max_iter=5000, cv=5, n_jobs=-1),
    }

    for model_name, model in models.items():
        pipe = deepcopy(base_pipeline)
        pipe.steps.append(('model', model))
        
        print(f"\nTraining {model_name}...")
        t0 = time()
        pipe.fit(X_train, Y_train)
        print(f"{model_name}: Fitting took {time() - t0:.3f}s.")

        scores = cross_validate(
            pipe, X_train, Y_train,
            cv=5, 
            scoring=['neg_mean_absolute_error', 'r2'],
            return_train_score=True
        )
        
        print(f"\n{model_name} cross-validation scores:")
        pprint(scores)
        print(f"{model_name} Q_2: {np.mean(scores['test_r2']):.5f}")

        multiplot_and_print(pipe, X_train, Y_train, comb_partitions, f"BPA_Combinatorial_{model_name}")

    print("\nAll models trained and evaluated successfully.")

if __name__ == "__main__":
    main()

Loaded embeddings for 835 entities

Detailed Split Verification:

train split:
Number of reactions: 368
Catalysts: ['144_i', '157_i', '182_i', '202_i', '205_i', '207_i', '230_i', '242_i', '245_vi', '246_vi', '249_i', '251_vi', '253_i', '262_vi', '276_i', '286_vi', '328_i', '365_i', '71_vi', '72_i', '76_vi', '99_i', '99_vi']
Imines: [1, 2, 3, 4]
Thiols: ['A', 'B', 'C', 'D']

cat_test split:
Number of reactions: 304
Catalysts: ['145_i', '147_i', '166_i', '1_i', '210_i', '223_i', '229_i', '229_vi', '242_vi', '245_i', '29_i', '365_vi', '371_i', '382_i', '5_i', '61_i', '73_i', '7_i', '87_i']
Imines: [1, 2, 3, 4]
Thiols: ['A', 'B', 'C', 'D']

sub_test split:
Number of reactions: 23
Catalysts: ['144_i', '157_i', '182_i', '202_i', '205_i', '207_i', '230_i', '242_i', '245_vi', '246_vi', '249_i', '251_vi', '253_i', '262_vi', '276_i', '286_vi', '328_i', '365_i', '71_vi', '72_i', '76_vi', '99_i', '99_vi']
Imines: [5]
Thiols: ['E']

Example reactions in this split:
  144_i_5_E
  157_i_5_E
  182_i_5

  y = column_or_1d(y, warn=True)  # TODO: Is this still required?


GBR: Fitting took 5.816s.


  y = column_or_1d(y, warn=True)  # TODO: Is this still required?
  y = column_or_1d(y, warn=True)  # TODO: Is this still required?
  y = column_or_1d(y, warn=True)  # TODO: Is this still required?
  y = column_or_1d(y, warn=True)  # TODO: Is this still required?
  y = column_or_1d(y, warn=True)  # TODO: Is this still required?



GBR cross-validation scores:
{'fit_time': array([4.63873219, 5.04385495, 4.61571193, 4.4706459 , 4.5038023 ]),
 'score_time': array([0.00356388, 0.00384521, 0.00338197, 0.00370812, 0.00318694]),
 'test_neg_mean_absolute_error': array([-0.6153812 , -0.59782428, -0.71044225, -0.63896888, -0.17463152]),
 'test_r2': array([-0.5424633 , -1.71449897, -0.12026193, -0.76512685,  0.71643678]),
 'train_neg_mean_absolute_error': array([-0.12217467, -0.12178404, -0.1222905 , -0.11563356, -0.12225058]),
 'train_r2': array([0.94560402, 0.94368367, 0.92947555, 0.94740727, 0.94995083])}
GBR Q_2: -0.48518

BPA_Combinatorial_GBR
X_train shape: (368, 768), Y_train shape: (368, 1)
Train R^2: 0.94142, train MAE: 0.12383

Part name: Unseen substrates
X_test shape: (23, 768), Y_test shape: (23, 1)
Test R^2: 0.90516, test MAE: 0.15958

Part name: Unseen catalysts
X_test shape: (304, 768), Y_test shape: (304, 1)
Test R^2: -0.18898, test MAE: 0.59198

Part name: Unseen subs and cats
X_test shape: (19, 768), Y_

  y = column_or_1d(y, warn=True)
  y = column_or_1d(y, warn=True)
  y = column_or_1d(y, warn=True)
  y = column_or_1d(y, warn=True)
  y = column_or_1d(y, warn=True)
  y = column_or_1d(y, warn=True)



SVR cross-validation scores:
{'fit_time': array([0.02000809, 0.02057695, 0.01967502, 0.02121425, 0.02110505]),
 'score_time': array([0.00643492, 0.00637388, 0.00600505, 0.00581098, 0.00675797]),
 'test_neg_mean_absolute_error': array([-0.40736169, -0.61583882, -0.61748674, -0.60082689, -0.15776011]),
 'test_r2': array([ 0.22324295, -1.47389108,  0.23356486, -0.52376831,  0.71324786]),
 'train_neg_mean_absolute_error': array([-0.08933459, -0.08464323, -0.07278762, -0.08061528, -0.09560427]),
 'train_r2': array([0.92749925, 0.95041027, 0.94837473, 0.92687885, 0.92393719])}
SVR Q_2: -0.16552

BPA_Combinatorial_SVR
X_train shape: (368, 768), Y_train shape: (368, 1)
Train R^2: 0.93667, train MAE: 0.08395

Part name: Unseen substrates
X_test shape: (23, 768), Y_test shape: (23, 1)
Test R^2: 0.29357, test MAE: 0.44740

Part name: Unseen catalysts
X_test shape: (304, 768), Y_test shape: (304, 1)
Test R^2: 0.05840, test MAE: 0.49889

Part name: Unseen subs and cats
X_test shape: (19, 768), Y_t

  return fit_method(estimator, *args, **kwargs)


RF: Fitting took 2.794s.


  return fit_method(estimator, *args, **kwargs)
  return fit_method(estimator, *args, **kwargs)
  return fit_method(estimator, *args, **kwargs)
  return fit_method(estimator, *args, **kwargs)
  return fit_method(estimator, *args, **kwargs)



RF cross-validation scores:
{'fit_time': array([2.29848814, 2.59468699, 2.61975384, 2.08915424, 1.9942348 ]),
 'score_time': array([0.02935982, 0.02769995, 0.02965808, 0.02953482, 0.02982497]),
 'test_neg_mean_absolute_error': array([-0.67493749, -0.6138876 , -0.75506086, -0.62120238, -0.2060198 ]),
 'test_r2': array([-0.82021452, -2.08282323, -0.2603671 , -0.58580517,  0.60277473]),
 'train_neg_mean_absolute_error': array([-0.05345899, -0.05427587, -0.05091813, -0.0438752 , -0.05107932]),
 'train_r2': array([0.98936609, 0.98823555, 0.98617939, 0.99175201, 0.99063841])}
RF Q_2: -0.62929

BPA_Combinatorial_RF
X_train shape: (368, 768), Y_train shape: (368, 1)
Train R^2: 0.99053, train MAE: 0.04826

Part name: Unseen substrates
X_test shape: (23, 768), Y_test shape: (23, 1)
Test R^2: 0.88799, test MAE: 0.17059

Part name: Unseen catalysts
X_test shape: (304, 768), Y_test shape: (304, 1)
Test R^2: -0.10128, test MAE: 0.53501

Part name: Unseen subs and cats
X_test shape: (19, 768), Y_tes

  y = column_or_1d(y, warn=True)


LL: Fitting took 1.804s.


  y = column_or_1d(y, warn=True)
  y = column_or_1d(y, warn=True)
  y = column_or_1d(y, warn=True)
  y = column_or_1d(y, warn=True)
  y = column_or_1d(y, warn=True)



LL cross-validation scores:
{'fit_time': array([0.51868105, 0.48132825, 0.24281001, 0.22717595, 0.1388948 ]),
 'score_time': array([0.00515485, 0.00534487, 0.00302696, 0.002985  , 0.00297427]),
 'test_neg_mean_absolute_error': array([-0.50714383, -0.6399915 , -0.5768531 , -0.66618261, -0.36561611]),
 'test_r2': array([-0.41316624, -1.42800753,  0.34473119, -0.67334051, -0.01997947]),
 'train_neg_mean_absolute_error': array([-0.1870848 , -0.39168412, -0.24037052, -0.21852668, -0.58181791]),
 'train_r2': array([0.87409413, 0.41198834, 0.7253533 , 0.79005288, 0.        ])}
LL Q_2: -0.43795

BPA_Combinatorial_LL
X_train shape: (368, 768), Y_train shape: (368, 1)
Train R^2: 0.64060, train MAE: 0.30586

Part name: Unseen substrates
X_test shape: (23, 768), Y_test shape: (23, 1)
Test R^2: 0.55817, test MAE: 0.34283

Part name: Unseen catalysts
X_test shape: (304, 768), Y_test shape: (304, 1)
Test R^2: -0.13710, test MAE: 0.54839

Part name: Unseen subs and cats
X_test shape: (19, 768), Y_tes