In [3]:
import csv
import copy
import numpy as np
import pandas as pd

from scipy.stats import pearsonr, spearmanr

In [4]:
def _crds_piecewise_contrib(value, max_value, threshold, exp_up, exp_down):
    """Max-normalized power-law contribution with a softer exponent above a threshold.

    For value <= threshold: (value / max_value) ** exp_up.
    Above threshold the curve continues as
    (threshold / max_value) ** exp_up * (value / threshold) ** exp_down,
    which equals the first branch at value == threshold (continuous switch).
    """
    if value <= threshold:
        return (value / max_value) ** exp_up
    return (threshold / max_value) ** exp_up * (value / threshold) ** exp_down


def calculate_crds(datasets_ini, alpha=1, beta_up=1.3, beta_down=0.3, gamma_up=1.6, gamma_down=0.6, T_g=120, T_n=5000):
    """Compute a CRDS score for each dataset summary row and insert it at index 3.

    Each row starts with three counts ``[R_t, R_g, R_n, ...]`` (reactants, reagents,
    reactions). Every count is normalized by the maximum of its column across all rows,
    and the score is the product

        (R_t / max R_t) ** alpha
        * piecewise(R_g; max R_g, T_g, beta_up, beta_down)
        * piecewise(R_n; max R_n, T_n, gamma_up, gamma_down)

    where ``piecewise`` uses exponent ``*_up`` up to the threshold and continues with
    exponent ``*_down`` above it (continuous at the threshold).

    Args:
        datasets_ini: sequence of rows (lists or tuples). It is deep-copied and never modified.
        alpha: exponent for the reactant term.
        beta_up, beta_down: reagent exponents below/above ``T_g``.
        gamma_up, gamma_down: reaction exponents below/above ``T_n``.
        T_g, T_n: reagent/reaction thresholds where the exponent switches.

    Returns:
        A new list of lists: each input row with the score inserted at index 3, so a
        fourth original field (e.g. the dataset name) moves to index 4. Empty input
        returns ``[]``.

    Raises:
        ValueError: if any count is negative, or a count column is entirely zero
            (normalizing by its maximum would divide by zero).
    """
    # Deep copy so the caller's rows are untouched; list() lets tuple rows through too.
    datasets = [list(row) for row in copy.deepcopy(datasets_ini)]
    if not datasets:
        return []

    # Per-column maxima, validated so the normalization below is well defined.
    maxima = []
    for idx, label in enumerate(("reactant", "reagent", "reaction")):
        column = [row[idx] for row in datasets]
        if min(column) < 0:
            # A negative base with a fractional exponent yields a complex number in Python 3.
            raise ValueError(f"{label} counts must be non-negative, got {min(column)}")
        if max(column) == 0:
            raise ValueError(f"all {label} counts are zero; cannot normalize by the maximum")
        maxima.append(max(column))
    max_reactants, max_reagents, max_reactions = maxima

    for dataset in datasets:
        R_t, R_g, R_n = dataset[0], dataset[1], dataset[2]

        R_t_contrib = (R_t / max_reactants) ** alpha
        R_g_contrib = _crds_piecewise_contrib(R_g, max_reagents, T_g, beta_up, beta_down)
        R_n_contrib = _crds_piecewise_contrib(R_n, max_reactions, T_n, gamma_up, gamma_down)

        score = R_t_contrib * R_g_contrib * R_n_contrib
        dataset.insert(3, score)

    return datasets

def _summarize_frame(frame, label):
    """Return [unique (aryl halide, amine) pairs, unique (catalyst, solvent, base) combos, row count, label]."""
    unique_pairs = len(frame[['Aryl Halide SMILES', 'Amine SMILES']].drop_duplicates())
    unique_conditions = len(frame[['Catalyst SMILES', 'Solvent SMILES', 'Base SMILES']].drop_duplicates())
    return [unique_pairs, unique_conditions, len(frame), label]


def summarize_sources_with_combined(df, excluded_source="JNJ HTE 2024"):
    """Summarize reaction counts per data source plus two combined rows.

    Each summary row is ``[unique (Aryl Halide, Amine) pairs,
    unique (Catalyst, Solvent, Base) combinations, total rows, label]``.

    Args:
        df: DataFrame with a 'Source' column and the five SMILES columns above.
        excluded_source: source left out of the "All except ..." combined row.

    Returns:
        A list with one row per source (in ``groupby`` sort order), then
        ``"All except <excluded_source>"``, then ``"All sources"``.
    """
    # groupby('Source') iterates in sorted key order; rows with a missing Source
    # are skipped here but still counted in the "All sources" row.
    summary_list = [_summarize_frame(group, source) for source, group in df.groupby('Source')]

    df_excluding = df[df['Source'] != excluded_source]
    summary_list.append(_summarize_frame(df_excluding, f"All except {excluded_source}"))
    summary_list.append(_summarize_frame(df, "All sources"))

    return summary_list

def calculate_correlations(dataframe, reference_column, columns_to_compare, include_pvalues=False):
    """
    Calculate Pearson and Spearman correlations between a reference column 
    and a list of other columns in a dataframe.

    Parameters:
        dataframe (pd.DataFrame): The dataframe containing the data.
        reference_column (str): The column name for the reference column.
        columns_to_compare (list of str): Column names to compare against the reference
            column. The reference column itself is skipped if listed.
        include_pvalues (bool): If True, also report the two-sided p-values returned by
            scipy as "Pearson p-value" and "Spearman p-value" (useful with few samples).

    Returns:
        pd.DataFrame: One row per compared column with "Column", "Pearson Correlation" and
        "Spearman Correlation" (plus the p-value columns when requested). The columns are
        present even when there is nothing to compare, so the result can always be indexed
        by column name.
    """
    result_columns = ["Column", "Pearson Correlation", "Spearman Correlation"]
    if include_pvalues:
        result_columns += ["Pearson p-value", "Spearman p-value"]

    results = []
    for column in columns_to_compare:
        if column == reference_column:
            continue
        pearson_corr, pearson_p = pearsonr(dataframe[reference_column], dataframe[column])
        spearman_corr, spearman_p = spearmanr(dataframe[reference_column], dataframe[column])
        record = {"Column": column, "Pearson Correlation": pearson_corr, "Spearman Correlation": spearman_corr}
        if include_pvalues:
            record["Pearson p-value"] = pearson_p
            record["Spearman p-value"] = spearman_p
        results.append(record)

    # Explicit columns keep the schema stable for an empty result list.
    return pd.DataFrame(results, columns=result_columns)

def load_dataset_summary(csv_file):
    """
    Read a dataset summary CSV and return a list of ``[int, int, int, name]`` rows.

    The first row is treated as a header and skipped. Every remaining non-blank row must
    have at least four fields: three integer counts followed by the dataset name (extra
    fields are ignored). Blank lines are skipped, and an empty file yields ``[]``.

    NOTE(review): in this notebook the three counts are consumed by calculate_crds as
    reactant / reagent / reaction counts (not train / validation / test sample counts);
    confirm the column meaning against the CSV header.

    Raises:
        ValueError: if a row has fewer than four fields or a count is not an integer.
    """
    datasets = []

    # newline="" is what the csv module expects, so quoted fields parse correctly.
    with open(csv_file, mode="r", encoding="utf-8", newline="") as file:
        reader = csv.reader(file)
        next(reader, None)  # Skip header row; tolerate a completely empty file.

        for row in reader:
            if not any(cell.strip() for cell in row):
                continue  # csv.reader yields [] (or whitespace cells) for blank lines
            if len(row) < 4:
                raise ValueError(
                    f"{csv_file}: line {reader.line_num}: expected at least 4 fields, got {len(row)}"
                )
            counts = [int(value) for value in row[:3]]
            datasets.append(counts + [row[3]])

    return datasets


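### Correlation: CRDS vs OOD performance?

Pearson and Spearman correlations between each training source's CRDS (tuned hyperparameters) and its best OOD metrics from `results/Data_Source_best.csv`.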

In [22]:
dataset_summary = load_dataset_summary("data/dataset_summary.csv")

# Tuned CRDS hyperparameters. calculate_crds deep-copies its input itself,
# so dataset_summary stays untouched without an extra copy here.
datasets_w_scores = calculate_crds(dataset_summary, alpha=1, beta_up=0.2, beta_down=0.1,
                                   gamma_up=0.9, gamma_down=0.2, T_g=4000, T_n=5000)

best_performance = pd.read_csv("results/Data_Source_best.csv", index_col=0)

# calculate_crds inserts the score at index 3, pushing the dataset name to index 4.
dic = {"Train Data": [row[4] for row in datasets_w_scores],
       "CRDS": [row[3] for row in datasets_w_scores]}

df = pd.DataFrame(dic)
final = best_performance.merge(df, on=["Train Data"])
# The inner merge silently drops sources whose names differ between the two files,
# which would quietly shrink n for the correlations below, so report it.
if len(final) != len(best_performance):
    print(f"Warning: merge kept {len(final)} of {len(best_performance)} rows; "
          "unmatched 'Train Data' names are excluded from the correlations.")

correlations_df = calculate_correlations(final, "CRDS", ['ROC AUC Avg', 'Balanced Accuracy Avg', 'F1 Score Avg', 'AU-PR-C Avg'])
print(correlations_df)

                  Column  Pearson Correlation  Spearman Correlation
0            ROC AUC Avg             0.789956              0.452381
1  Balanced Accuracy Avg             0.770371              0.428571
2           F1 Score Avg             0.308120              0.309524
3            AU-PR-C Avg             0.170448              0.285714


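### Correlation: Simplified CRDS vs OOD performance?

All exponents are set to 1 here, so the thresholds cancel out and the score reduces to the product of the three max-normalized counts.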

In [21]:
dataset_summary = load_dataset_summary("data/dataset_summary.csv")

# Simplified CRDS: with every exponent equal to 1, the piecewise terms give
# (T / max) * (R / T) == R / max, so the thresholds have no effect and the score is
# just the product of the three max-normalized counts.
# calculate_crds deep-copies its input, so no extra copy is needed here.
datasets_w_scores = calculate_crds(dataset_summary, alpha=1, beta_up=1, beta_down=1,
                                   gamma_up=1.0, gamma_down=1, T_g=4000000, T_n=500000)

best_performance = pd.read_csv("results/Data_Source_best.csv", index_col=0)

# calculate_crds inserts the score at index 3, pushing the dataset name to index 4.
dic = {"Train Data": [row[4] for row in datasets_w_scores],
       "CRDS": [row[3] for row in datasets_w_scores]}

df = pd.DataFrame(dic)
final = best_performance.merge(df, on=["Train Data"])
# The inner merge silently drops sources whose names differ between the two files,
# which would quietly shrink n for the correlations below, so report it.
if len(final) != len(best_performance):
    print(f"Warning: merge kept {len(final)} of {len(best_performance)} rows; "
          "unmatched 'Train Data' names are excluded from the correlations.")

correlations_df = calculate_correlations(final, "CRDS", ['ROC AUC Avg', 'Balanced Accuracy Avg', 'F1 Score Avg', 'AU-PR-C Avg'])
print(correlations_df)

                  Column  Pearson Correlation  Spearman Correlation
0            ROC AUC Avg             0.534217              0.261905
1  Balanced Accuracy Avg             0.558231              0.214286
2           F1 Score Avg             0.263111              0.428571
3            AU-PR-C Avg             0.381183              0.404762


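### Correlation: Dataset Size vs OOD performance?

Baseline: correlate the raw training-set size of each source (hard-coded in the cell below) with the same OOD metrics.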

In [8]:
# index_col=0 to read the results file the same way as the CRDS cells above.
best_performance = pd.read_csv("results/Data_Source_best.csv", index_col=0)

# One mapping so source names and sizes cannot drift out of alignment.
# NOTE(review): these sizes are hard-coded; presumably they are the total counts per
# source from data/dataset_summary.csv — confirm before reusing.
dataset_sizes = {
    'Chem.S.23': 750,
    'Sci.15': 768,
    'Chem.S.16': 144,
    'Sci.18': 4312,
    'Nat.C.24': 2632,
    "JACS.25": 4204,
    'JnJ25': 11328,
    'Sci.23': 3359,
}
dic = {"Train Data": list(dataset_sizes), "dataset_size": list(dataset_sizes.values())}

df = pd.DataFrame(dic)
final = best_performance.merge(df, on=["Train Data"])
# The inner merge silently drops sources missing from the mapping above; report it.
if len(final) != len(best_performance):
    print(f"Warning: merge kept {len(final)} of {len(best_performance)} rows; "
          "unmatched 'Train Data' names are excluded from the correlations.")
correlations_df = calculate_correlations(final, "dataset_size", ['ROC AUC Avg', 'Balanced Accuracy Avg', 'F1 Score Avg', 'AU-PR-C Avg'])
print(correlations_df)

                  Column  Pearson Correlation  Spearman Correlation
0            ROC AUC Avg             0.603495              0.476190
1  Balanced Accuracy Avg             0.540669              0.404762
2           F1 Score Avg             0.286920             -0.214286
3            AU-PR-C Avg             0.062482             -0.309524
