In [3]:
import os, sys, time
import numpy as np
import pandas as pd
import scipy.sparse
import sklearn.preprocessing as sk_prep
import sklearn.model_selection as sk_model
import torch
from torch.utils.data import DataLoader, Dataset

import pandas as pd
import numpy as np
import ast

def _parse_literal(value):
    """Convert a stringified Python literal (e.g. ``"[1, 0, 1]"``) back to an object.

    Values that are not strings (already-parsed objects, or the NaN that a left
    merge leaves for unmatched rows) are returned unchanged.
    """
    return ast.literal_eval(value) if isinstance(value, str) else value


def load_ecotox_data(
    adore_path,
    chemicals_path,
    # -- Optional concentration choice --
    use_molar=True,
    mol_col="result_conc1_mean_mol_log",
    mg_col="result_conc1_mean_log",
    # -- Selfies options --
    use_selfies=False,
    selfies_path=None,
    selfies_col="selfies_embed",
    # -- Mol2Vec options --
    use_mol2vec=False,
    mol2vec_path=None,
    mol2vec_cols=None,
    # -- Fingerprint options --
    use_fingerprint=False,
    fingerprint_path=None,
    fp_col="morgan_fp",
    # -- Misc --
    shuffle=True,
    random_state=42
):
    """Load the ADORE table, join chemical properties, and centre the target.

    The two base CSVs are left-merged on ``test_cas``. The columns ``tax_gs``,
    ``test_cas`` and ``result_obs_duration_mean`` are renamed to ``species``,
    ``CAS`` and ``duration``, and the selected concentration column is renamed
    to ``conc``. Optional per-chemical representations are then left-merged on
    ``CAS``.

    Parameters
    ----------
    adore_path, chemicals_path : path-like
        CSV files; both must contain a ``test_cas`` column.
    use_molar : bool
        If True the target is ``mol_col``, otherwise ``mg_col``.
    mol_col, mg_col : str
        Candidate target column names in the merged frame.
    use_selfies, selfies_path, selfies_col
        When enabled, ``selfies_path`` is read and merged in full on ``CAS``;
        string entries of ``selfies_col`` are parsed with ``ast.literal_eval``.
    use_mol2vec, mol2vec_path, mol2vec_cols
        When enabled, ``mol2vec_cols`` (a column name or any sequence of column
        names) must be given. With ``mol2vec_path=None`` the columns must
        already exist in the merged frame; otherwise they are read from
        ``mol2vec_path`` and merged on ``CAS``.
    use_fingerprint, fingerprint_path, fp_col
        When enabled, ``fp_col`` is read from ``fingerprint_path``, merged on
        ``CAS`` and its string entries parsed with ``ast.literal_eval``.
    shuffle : bool
        If True, rows are permuted with ``DataFrame.sample(frac=1)``.
    random_state : int
        Seed passed to ``DataFrame.sample`` when shuffling.

    Returns
    -------
    (pandas.DataFrame, numpy.ndarray)
        The merged frame (including ``conc`` and ``conc_centered``) and the
        values of ``conc_centered``, i.e. ``conc`` minus its mean over every
        row of the returned frame.

    Raises
    ------
    ValueError
        If the target column is absent, if an enabled option lacks its path or
        column list, or if an auxiliary file lacks a required column.
    """
    # -------------------------------------------------------------------------
    # 1. Read base CSVs and standard merges
    # -------------------------------------------------------------------------
    adore = pd.read_csv(adore_path, low_memory=False)
    chemicals = pd.read_csv(chemicals_path, low_memory=False)

    df = adore.merge(chemicals, on='test_cas', how='left')
    df = df.rename(columns={
        'tax_gs': 'species',
        'test_cas': 'CAS',
        'result_obs_duration_mean': 'duration'
    })

    target_col = mol_col if use_molar else mg_col
    if target_col not in df.columns:
        raise ValueError(
            f"Chosen target_col='{target_col}' not found in the merged DataFrame. "
            "Check your ADORE/chemicals CSV for the correct column name."
        )
    df = df.rename(columns={target_col: 'conc'})

    # -------------------------------------------------------------------------
    # 2. Optionally merge SELFIES
    # -------------------------------------------------------------------------
    if use_selfies:
        if not selfies_path:
            raise ValueError("`use_selfies=True` but `selfies_path` was not provided.")
        df_selfies = pd.read_csv(selfies_path)
        # Same up-front validation as the fingerprint branch, so a malformed
        # file yields a descriptive ValueError instead of a bare KeyError.
        if 'CAS' not in df_selfies.columns or selfies_col not in df_selfies.columns:
            raise ValueError(
                f"SELFIES file must contain 'CAS' and '{selfies_col}' columns. "
                f"Found: {list(df_selfies.columns)}"
            )
        df = df.merge(df_selfies, on='CAS', how='left')
        df[selfies_col] = df[selfies_col].apply(_parse_literal)

    # -------------------------------------------------------------------------
    # 3. Optionally merge Mol2Vec
    # -------------------------------------------------------------------------
    if use_mol2vec:
        # Normalise to a list so tuples / Index objects / a single name all work
        # with the list concatenation and membership checks below.
        if mol2vec_cols is None:
            m2v_cols = []
        elif isinstance(mol2vec_cols, str):
            m2v_cols = [mol2vec_cols]
        else:
            m2v_cols = list(mol2vec_cols)

        if mol2vec_path is None:
            # Assume mol2vec columns are already present in the merged dataframe.
            if not m2v_cols:
                raise ValueError("When use_mol2vec=True and mol2vec_path is None, you must provide mol2vec_cols.")
            missing_cols = [col for col in m2v_cols if col not in df.columns]
            if missing_cols:
                raise ValueError(f"Mol2Vec columns missing in chemicals file: {missing_cols}")
        else:
            if not m2v_cols:
                raise ValueError("`use_mol2vec=True` requires mol2vec_cols.")
            df_m2v = pd.read_csv(mol2vec_path)
            columns_to_merge = ['CAS'] + m2v_cols
            missing_cols = [c for c in columns_to_merge if c not in df_m2v.columns]
            if missing_cols:
                raise ValueError(f"Mol2Vec file is missing columns: {missing_cols}")
            df = df.merge(df_m2v[columns_to_merge], on='CAS', how='left')

    # -------------------------------------------------------------------------
    # 4. Optionally merge Fingerprints
    # -------------------------------------------------------------------------
    if use_fingerprint:
        if not fingerprint_path:
            raise ValueError("`use_fingerprint=True` but `fingerprint_path` was not provided.")
        df_fp = pd.read_csv(fingerprint_path)
        if 'CAS' not in df_fp.columns or fp_col not in df_fp.columns:
            raise ValueError(f"Fingerprint file must contain 'CAS' and '{fp_col}' columns. Found: {list(df_fp.columns)}")
        df = df.merge(df_fp[['CAS', fp_col]], on='CAS', how='left')
        df[fp_col] = df[fp_col].apply(_parse_literal)

    # -------------------------------------------------------------------------
    # 5. Clean up DataFrame
    # -------------------------------------------------------------------------
    # Categoricals are created only after all merges, which join on plain 'CAS' values.
    df['species'] = pd.Categorical(df['species'])
    df['CAS'] = pd.Categorical(df['CAS'])
    df['duration'] = df['duration'].astype(float)

    if shuffle:
        df = df.sample(frac=1, random_state=random_state).reset_index(drop=True)

    # -------------------------------------------------------------------------
    # 6. Center the target
    # -------------------------------------------------------------------------
    # The mean is taken over the whole frame (no train/test distinction here).
    y_mean = df['conc'].mean()
    df['conc_centered'] = df['conc'] - y_mean

    return df, df['conc_centered'].values




In [4]:
# Run on the GPU when torch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# -----------------------------------------------------------------------
# 1. Load + minimal extra preprocessing
# -----------------------------------------------------------------------
# Input CSVs, given relative to the notebook's working directory.
chemicals_path = '../data_files/ecotox_properties_with-oecd-function.csv'
adore_path = '../data_files/ecotox_mortality_processed.csv'

# No SELFIES / Mol2Vec / fingerprint columns are merged in this run; rows are
# shuffled with a fixed seed.
full_data, y_centered = load_ecotox_data(
    adore_path,
    chemicals_path,
    shuffle=True,
    random_state=42,
    use_fingerprint=False,
    use_mol2vec=False,
    use_selfies=False,
)

Using device: cpu


In [5]:
# Show the merged, shuffled frame returned by load_ecotox_data as this cell's output.
full_data

Unnamed: 0,test_id,reference_number,CAS,test_location,test_exposure_type,test_control_type,test_media_type,test_application_freq_unit,test_organism_lifestage,result_id,...,chem_mordred_WPol,chem_mordred_Zagreb1,chem_mordred_Zagreb2,chem_mordred_mZagreb2,DTXSID,oecd_function_prevalent,oecd_function_all,funcuse_prevalent,funcuse_all,conc_centered
0,1076392,5940,552-89-6,LAB,F,I,FW,CON,NR,115475,...,14,50.0,56.0,2.638889,,,,,,1.294115
1,1147015,12427,2702-72-9,LAB,S,C,FW,X,NR,101573,...,15,60.0,65.0,2.944444,,,,,,1.643887
2,2291766,863,123-86-4,LAB,S,C,SW,X,NR,2680607,...,5,28.0,26.0,2.083333,DTXSID3021982,solvent - surfactant,fragrance - solvent - surfactant,solubilizer - solvent,cosmetic - flavoring - fragrance - odorant - s...,2.355409
3,1001083,182,94-09-7,LAB,S,I,FW,X,NR,112388,...,14,54.0,59.0,2.861111,DTXSID8021804,UV stabilizer - other,UV stabilizer - other,active - oral pain reliever - analgesic,active oral pain reliever - active - oral pai...,1.413871
4,2236363,7199,13171-21-6,LAB,S,V,FW,X,NR,2507183,...,27,80.0,90.0,4.458333,DTXSID7021156,biocide,biocide,insecticide,insecticide,2.454773
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
70665,1035679,2188,1114-71-2,LAB,S,NR,SW,X,JV,185392,...,14,50.0,52.0,3.444444,,,,,,0.640951
70666,2059172,344,95-06-7,LAB,S,K,FW,X,NC,2100640,...,13,48.0,50.0,3.027778,,,,,,-0.135830
70667,1304436,115741,63-25-2,LAB,R,M,SW,DLY,NR,761019,...,21,74.0,85.0,3.472222,DTXSID9020247,biocide,biocide,insecticide,insecticide - pesticide,-1.786213
70668,1178837,15570,563-12-2,LAB,NR,NR,FW,NR,NR,59617,...,24,82.0,88.0,4.750000,DTXSID2024086,biocide,biocide,insecticide,insecticide,-0.390165
