In [2]:
import matplotlib.pyplot as plt
import polars as pl
import numpy as np

from pymongo import MongoClient


def replace_id(df, replace_name = 'InChI_key'):
    """Replace the ``id`` column of ``df`` with a field looked up in MongoDB.

    Every value of ``df['id']`` is used as the ``_id`` of a document in the
    ``compounds`` collection of the local ``lotus_mines`` database, and the
    document's ``replace_name`` field becomes the new ``id`` value.

    Parameters
    ----------
    df : polars.DataFrame
        Frame with an ``id`` column holding compound ``_id`` values.
    replace_name : str
        Name of the document field whose value replaces the id.

    Returns
    -------
    polars.DataFrame
        ``df`` with the same rows in the same order and ``id`` overwritten.
        Rows whose compound is not in the collection, or whose document lacks
        ``replace_name``, get a null ``id`` so that rows stay aligned with
        their other columns.
    """
    ids = df['id'].to_list()

    # Nothing to look up; returning early also keeps the dtype of the empty
    # `id` column untouched.
    if not ids:
        return df

    # One batched `$in` query instead of one round trip per row. The context
    # manager closes the client (and its connection pool) when done.
    with MongoClient('mongodb://localhost:27017/') as client:
        compounds_collection = client['lotus_mines']['compounds']
        cursor = compounds_collection.find(
            {'_id': {'$in': ids}},
            {replace_name: 1},  # only fetch the field we need (plus _id)
        )
        found = {doc['_id']: doc.get(replace_name) for doc in cursor}

    replaced = []
    for _id in ids:
        if _id not in found:
            print(f"No compound found for id: {_id}")
        # Missing compounds map to None, so the output has exactly one value
        # per input row and can never be misaligned or broadcast.
        replaced.append(found.get(_id))

    return df.with_columns(pl.Series("id", replaced))


# Load the reaction table; the trailing bare expression renders it as the
# cell output.
df = pl.read_parquet(
    "../data/MINES/reactions_compounds_list_full.parquet"
)
df

reaction_id,reactants,products
str,list[struct[2]],list[struct[2]]
"""R5d8539f1d9a5e857189956bad8eb4…","[{""Ce8bc5cd3aa30776ab6d35fdc2bcc4707f4ac2919"",""Starting Compound""}, {""X0eb45233dd43ecacb9fb1e31140450e1dace01c5"",""Coreactant""}, {""X8dc023d8052d83fb6feadf8541387e57c199cad0"",""Coreactant""}]","[{""C878f017efe6de2805a953d0ca9b8491274a29290"",""Predicted""}, {""X73bc8ef21db580aefe4dbc0af17d4013961d9d17"",""Coreactant""}]"
"""R6a89aaf90529aa474f537c081d71f…","[{""Ce8bc5cd3aa30776ab6d35fdc2bcc4707f4ac2919"",""Starting Compound""}, {""X0eb45233dd43ecacb9fb1e31140450e1dace01c5"",""Coreactant""}, {""X8dc023d8052d83fb6feadf8541387e57c199cad0"",""Coreactant""}]","[{""C54130f1c76aaa5380fa631a6a659121284978c5d"",""Predicted""}, {""X73bc8ef21db580aefe4dbc0af17d4013961d9d17"",""Coreactant""}]"
"""Rf0c39549766c89963dbb1a98f8f1d…","[{""Ce8bc5cd3aa30776ab6d35fdc2bcc4707f4ac2919"",""Starting Compound""}, {""X0eb45233dd43ecacb9fb1e31140450e1dace01c5"",""Coreactant""}, {""X8dc023d8052d83fb6feadf8541387e57c199cad0"",""Coreactant""}]","[{""Cfa5e885b86c8c37a465cad5238ed62672498a45d"",""Predicted""}, {""X73bc8ef21db580aefe4dbc0af17d4013961d9d17"",""Coreactant""}]"
"""Rad931d485dae8cbc9aa07c1301163…","[{""Ce8bc5cd3aa30776ab6d35fdc2bcc4707f4ac2919"",""Starting Compound""}, {""X0eb45233dd43ecacb9fb1e31140450e1dace01c5"",""Coreactant""}, {""X8dc023d8052d83fb6feadf8541387e57c199cad0"",""Coreactant""}]","[{""C57a73b796ef9de341670ad4f895779c4ce0d4623"",""Predicted""}, {""X73bc8ef21db580aefe4dbc0af17d4013961d9d17"",""Coreactant""}]"
"""R02b3ad62ed7e42f819c6931a7d392…","[{""Ce8bc5cd3aa30776ab6d35fdc2bcc4707f4ac2919"",""Starting Compound""}, {""X0eb45233dd43ecacb9fb1e31140450e1dace01c5"",""Coreactant""}, {""X8dc023d8052d83fb6feadf8541387e57c199cad0"",""Coreactant""}]","[{""C8e1b680b68eec30be34c6b4857d630e2c245759d"",""Predicted""}, {""X73bc8ef21db580aefe4dbc0af17d4013961d9d17"",""Coreactant""}]"
…,…,…
"""Rc52c2d326695e43c104ecd4a11b20…","[{""C92af25813ba38f4e1f385dc92b0e253dc7a01f6c"",""Starting Compound""}]","[{""C5e8022bf3b0f1ee82b5ec45f5c89f69ea2ab5a29"",""Predicted""}, {""X05622481c18728cacf9317e9a57ce6d421315aac"",""Coreactant""}]"
"""Rb993c3707a42c88a143334e22d4af…","[{""C326b3a90bc4bff0d2cf249e40ce517af77a62d0d"",""Starting Compound""}]","[{""Ca4ef9345ff26e4fec21cc7fb72fa9dca2f86fcc5"",""Predicted""}, {""X05622481c18728cacf9317e9a57ce6d421315aac"",""Coreactant""}]"
"""Rc1fdf76ef8f34b7c3a7538fcb9924…","[{""C62f5460835a65fcb75c6c8da0f7f6f90c060d900"",""Starting Compound""}]","[{""Ce9a6758cc6c58ca618292cde869df568a1c34cb8"",""Predicted""}, {""X05622481c18728cacf9317e9a57ce6d421315aac"",""Coreactant""}]"
"""Rbe1a7403def3e4c0a1dac139d28a0…","[{""C2001018bdb89adc25ce61dd084300beb082370ca"",""Starting Compound""}]","[{""C58fecaa1205bf8dee7e07040978bc26c49e0339c"",""Starting Compound""}, {""X05622481c18728cacf9317e9a57ce6d421315aac"",""Coreactant""}]"


In [None]:
%%time

# Turn each list-of-struct column into one row per compound, spread the
# struct fields into their own columns, and tag the side of the reaction
# the compound came from (`reactant` is True for reactants, False for products).
reactants_exploded = (
    df.select(["reaction_id", "reactants"])
    .explode("reactants")
    .unnest("reactants")
    .with_columns(pl.lit(True).alias('reactant'))
)
products_exploded = (
    df.select(["reaction_id", "products"])
    .explode("products")
    .unnest("products")
    .with_columns(pl.lit(False).alias('reactant'))
)

reactants_exploded, products_exploded

In [None]:
# Stack the reactant rows on top of the product rows into one long frame.
df_comp = pl.concat(
    [reactants_exploded, products_exploded],
    how="vertical",
)

print(df_comp)

In [None]:
# Keep only the rows whose role in a reaction is "Starting Compound"
starting_compounds = df_comp.filter(pl.col('type') == 'Starting Compound')

# Count the occurrences of each starting compound ID, most frequent first.
# `id` is a secondary sort key: group_by does not guarantee an output order,
# so without it the top-20 cut could change between runs when counts tie.
starting_count = starting_compounds.group_by('id').count()
starting_count = starting_count.sort(['count', 'id'], descending=[True, False])

# Swap the internal ids of the 20 most frequent compounds for the value of
# this document field, which is what the x axis then shows.
starting_label_field = "structure_nameTraditional"
starting_count = replace_id(starting_count.head(20), replace_name=starting_label_field)

# Plotting
plt.figure(figsize=(10, 6))
plt.bar(starting_count['id'], starting_count['count'], color='blue')
# The label is derived from the field actually used, so it cannot drift out
# of sync with the plotted values (it previously claimed to show InChI keys).
plt.xlabel(f'{starting_label_field} of Starting Compound')
plt.ylabel('Amount of Predicted Compounds')
plt.title('Starting Compound')
plt.xticks(rotation=90)
plt.show()

print(starting_count)

In [None]:
# Keep only the rows whose role in a reaction is "Predicted"
predicted_compounds = df_comp.filter(pl.col('type') == 'Predicted')

# Count the occurrences of each predicted compound ID, most frequent first.
# `id` is a secondary sort key: group_by does not guarantee an output order,
# so without it the top-20 cut could change between runs when counts tie.
predicted_count = predicted_compounds.group_by('id').count()
predicted_count = predicted_count.sort(['count', 'id'], descending=[True, False])

# Swap the internal ids of the 20 most frequent compounds for the value of
# this document field, which is what the x axis then shows.
predicted_label_field = "structure_nameTraditional"
predicted_count = replace_id(predicted_count.head(20), replace_name=predicted_label_field)

# Plotting
plt.figure(figsize=(10, 6))
plt.bar(predicted_count['id'], predicted_count['count'], color='blue')
# The label is derived from the field actually used, so it cannot drift out
# of sync with the plotted values (it previously claimed to show InChI keys).
plt.xlabel(f'{predicted_label_field} of Predicted Compounds')
# Each bar is a single predicted compound; its height is how many rows of
# df_comp list that compound with type "Predicted".
plt.ylabel('Number of occurrences')
plt.title('Predicted Compounds')
plt.xticks(rotation=90)
plt.show()


print(predicted_count)

In [None]:
# NOTE(review): this rebinds `df`, which earlier held the raw reactions frame;
# re-running the explode cell after this one will fail. Consider using
# `df_comp` directly once no other cell depends on the alias.
df = df_comp

# One row per reaction with three list columns, one per compound role.
# Within a reaction, every compound contributes an entry to each list: its
# id where the role matches, null otherwise.
result = df.group_by("reaction_id").agg(
    [
        pl.when(pl.col("type") == role).then(pl.col("id")).otherwise(None).alias(column)
        for role, column in (
            ("Coreactant", "coreactant"),
            ("Starting Compound", "starting_compound"),
            ("Predicted", "predicted"),
        )
    ]
)

def remove_nulls(lst):
    """Return a new list with the items of ``lst`` that are not ``None``.

    Only ``None`` is dropped; other falsy values (``0``, ``""``, ``False``)
    are kept, and the original order is preserved.
    """
    kept = []
    for item in lst:
        if item is None:
            continue
        kept.append(item)
    return kept

# Drop the null placeholders from each role list with a native list
# expression. This replaces the per-row Python callback passed to
# `Expr.apply`, which is deprecated in recent polars releases and runs
# row by row in Python; the resulting columns are the same lists of strings.
df_cleaned = result.with_columns(
    pl.col(["coreactant", "starting_compound", "predicted"])
    .list.eval(pl.element().drop_nulls())
)

print(df_cleaned)

# Display the intermediate aggregation (lists still contain the null placeholders)
print(result)