In [1]:
# Standard library
from collections import defaultdict
from typing import List, Tuple
import importlib.resources as pkg_resources
from multiprocessing.pool import ThreadPool

# Third-party scientific stack
import numpy as np
import pandas as pd
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import eigsh

# RDKit
from rdkit import Chem

# PyTorch core
import torch
import torch.nn as nn
import torch.nn.functional as F

# PyTorch Geometric
from torch_geometric.data import Data, Batch, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.loader import DataLoader as PyGDataLoader

# Polyatomic complexes
from polyatomic_complexes.src.complexes.abstract_complex import AbstractComplex
from polyatomic_complexes.src.complexes import PolyatomicGeometrySMILE

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def set_seed(seed):
    """
    Fix all random seeds for reproducibility across Python, NumPy, and PyTorch.
    """
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(59)

In [3]:
def load_freesolv_data():
  data_path = pkg_resources.files('polyatomic_complexes.dataset.free_solv') / 'FreeSolv.csv'
  df = pd.read_csv(data_path)
  return df

In [4]:
data = load_freesolv_data()
data.dropna(inplace=True)
data.reset_index(drop=True, inplace=True)
data

Unnamed: 0,iupac,smiles,expt,calc
0,"4-methoxy-N,N-dimethyl-benzamide",CN(C)C(=O)c1ccc(cc1)OC,-11.01,-9.625
1,methanesulfonyl chloride,CS(=O)(=O)Cl,-4.87,-6.219
2,3-methylbut-1-ene,CC(C)C=C,1.83,2.452
3,2-ethylpyrazine,CCc1cnccn1,-5.45,-5.809
4,heptan-1-ol,CCCCCCCO,-4.21,-2.917
...,...,...,...,...
637,methyl octanoate,CCCCCCCC(=O)OC,-2.04,-3.035
638,pyrrolidine,C1CCNC1,-5.48,-4.278
639,4-hydroxybenzaldehyde,c1cc(ccc1C=O)O,-8.83,-10.050
640,1-chloroheptane,CCCCCCCCl,0.29,1.467


In [5]:
data

Unnamed: 0,iupac,smiles,expt,calc
0,"4-methoxy-N,N-dimethyl-benzamide",CN(C)C(=O)c1ccc(cc1)OC,-11.01,-9.625
1,methanesulfonyl chloride,CS(=O)(=O)Cl,-4.87,-6.219
2,3-methylbut-1-ene,CC(C)C=C,1.83,2.452
3,2-ethylpyrazine,CCc1cnccn1,-5.45,-5.809
4,heptan-1-ol,CCCCCCCO,-4.21,-2.917
...,...,...,...,...
637,methyl octanoate,CCCCCCCC(=O)OC,-2.04,-3.035
638,pyrrolidine,C1CCNC1,-5.48,-4.278
639,4-hydroxybenzaldehyde,c1cc(ccc1C=O)O,-8.83,-10.050
640,1-chloroheptane,CCCCCCCCl,0.29,1.467


In [6]:
data.columns

Index(['iupac', 'smiles', 'expt', 'calc'], dtype='object')

In [7]:
def build_graph_from_smiles(smile: str, descriptors: List[float], topk_lap: int = 5):
    """
    Builds a PyG Data object with:
      - Node features: chain_0 values (shape [n,1]) + broadcasted descriptors
      - Edge structure and features from get_bonds()
      - Graph-level features: mean/std of chain_1..chain_k + topk eigenvalues of molecule laplacians
    """
    try:
        # 1) Abstract complex
        pg = PolyatomicGeometrySMILE(smile=smile, mode="abstract")
        ac = pg.smiles_to_geom_complex()
        assert isinstance(ac, AbstractComplex)

        # 2) Raw chains
        chains = ac.get_raw_k_chains()
        # chain_0: node-level
        chain0 = chains.get('chain_0')
        if chain0 is None:
            print(f"⚠️ Missing chain_0 for {smile}")
            return None
        n = len(chain0)
        x_node = torch.tensor(chain0, dtype=torch.float32).unsqueeze(1)  # [n,1]
        # graph-level chain stats
        g_stats = []
        for k, arr in chains.items():
            if k == 'chain_0': continue
            a = np.array(arr, dtype=np.float32)
            g_stats.extend([a.mean(), a.std()])
        g_stats = np.array(g_stats, dtype=np.float32)

        # 3) Laplacian eigenvalues (flatten all dims)
        laps = ac.get_laplacians().get('molecule_laplacians', [])
        lap_tuples = []
        for group in laps:
            if isinstance(group, list):
                for item in group:
                    if isinstance(item, tuple) and len(item)==2:
                        lap_tuples.append(item)
            elif isinstance(group, tuple) and len(group)==2:
                lap_tuples.append(group)
        lap_feats = []
        for dim, mat in lap_tuples:
            # convert to sparse if needed
            M = coo_matrix(mat) if not isinstance(mat, coo_matrix) else mat
            k = min(topk_lap, M.shape[0]-1)
            if k <= 0:
                vals = np.zeros(topk_lap, dtype=np.float32)
            else:
                try:
                    vals, _ = eigsh(M, k=k, return_eigenvectors=False)
                except Exception:
                    vals = np.zeros(k, dtype=np.float32)
                if len(vals) < topk_lap:
                    vals = np.pad(vals, (0, topk_lap-len(vals)))
            lap_feats.extend(vals.tolist())
        g_feats = torch.tensor(np.concatenate([g_stats, lap_feats]), dtype=torch.float32)

        # 4) 0-simplices -> node_ids ordering) 0-simplices -> node_ids ordering
        sk = ac.get_skeleta().get('molecule_skeleta', [[]])[0]
        zero = next((lst for dim,lst in sk if dim=='0'), [])
        node_ids = [next(iter(fz))[0] for fz in zero]
        if len(node_ids) != n:
            print(f"⚠️ Node count mismatch for {smile}: {len(node_ids)} vs {n}")
            return None

        # 5) Edges and edge features from bonds
        bonds = ac.get_bonds()  # list of (atom1, atom2, [type, order])
        # Map atom symbols to node indices
        atom_map = defaultdict(list)
        for idx, nid in enumerate(node_ids):
            sym = nid.split('_')[0]
            atom_map[sym].append(idx)

        edge_index_list = []
        edge_attr_list = []
        for a1, a2, (btype, order) in bonds:
            for i in atom_map.get(a1, []):
                for j in atom_map.get(a2, []):
                    edge_index_list.append([i, j])
                    t_int = {'SINGLE': 1, 'DOUBLE': 2, 'TRIPLE': 3}.get(btype, 0)
                    edge_attr_list.append([t_int, float(order)])

        # Fallback: if no bond edges, use 1-simplices from CW-complex
        if not edge_index_list:
            one_list = next((lst for dim, lst in sk if dim == '1'), [])
            for fz in one_list:
                ids = [nid for nid, _ in fz]
                if len(ids) == 2 and ids[0] in node_ids and ids[1] in node_ids:
                    i, j = node_ids[ids[0]], node_ids[ids[1]]
                    edge_index_list.extend([[i, j], [j, i]])
                    # dummy edge_attr for fallback edges
                    edge_attr_list.extend([[0, 0.0], [0, 0.0]])

        edge_index = torch.tensor(edge_index_list, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_attr_list, dtype=torch.float32)

        # 6) Combine node+graph descriptors
        desc = torch.tensor(descriptors, dtype=torch.float32)
        desc_feats = desc.unsqueeze(0).expand(n,-1)
        x = torch.cat([x_node, desc_feats], dim=1)

        # 7) Build Data
        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
        data.graph_feats = g_feats
        return data

    except Exception as e:
        print(f"Failed {smile}: {e}")
        return None

In [8]:
def decode_feature_from_frozenset(fzset):
    """
    Expects a frozenset containing a single tuple:
    ('P_3', ('proton', byte_blob))
    Returns a 1D numpy array by decoding the byte blob.
    """
    for _, (_, byte_blob) in fzset:
        try:
            arr = np.frombuffer(byte_blob, dtype=np.float32)
        except (ValueError, TypeError):
            arr = np.frombuffer(byte_blob, dtype=np.float64)
        return arr
    return None

def load_dataset():
    df = load_freesolv_data()
    graphs = []
    for _, row in df.iterrows():
        descriptors = []
        g = build_graph_from_smiles(row['smiles'], descriptors)
        if g is not None:
            g.y = torch.tensor([row['expt']], dtype=torch.float)
            graphs.append(g)
    return graphs


def process_row(args):
    idx, row = args
    try:
        descriptors = []
        graph = build_graph_from_smiles(row['smiles'], descriptors)
        if graph is not None:
            graph.y = torch.tensor([row['expt']], dtype=torch.float)
        print(f"Success for smile {row['smiles']} at index {idx}")
        return idx, graph
    except Exception as e:
        print(f"Failed on row {idx}: {e}")
        return idx, None

In [9]:
def load_dataset_parallel(df: pd.DataFrame, num_workers=4):
    with ThreadPool(num_workers) as pool:
        results = pool.map(process_row, [(i, row) for i, row in df.iterrows()])
    results = sorted([r for r in results if r[1] is not None], key=lambda x: x[0])
    return [g for _, g in results]

In [10]:
loading = False
if not loading:
    all_data = load_dataset_parallel(data, num_workers=8)
    torch.save(all_data, 'all_data_freesolv.pt')

Success for smile CCCC at index 84
Success for smile CCl at index 85
Success for smile CCc1ccccc1O at index 147
Success for smile [C@@H](C(F)(F)F)(OC(F)F)Cl at index 105
Success for smile CC(C)CBr at index 86
Success for smile COC(=O)c1ccc(cc1)O at index 63
Success for smile c1ccc2cc(ccc2c1)O at index 42
Success for smile CN(C)C(=O)c1ccc(cc1)OC at index 0
Success for smile CC(C)(C)Cl at index 148
Success for smile c1ccc2c(c1)ccc3c2cccc3 at index 126
Success for smile C=CCCC=C at index 106
Success for smile CC(=C)C=C at index 149
Success for smile CC(C)SC(C)C at index 87
Success for smile CI at index 127
Success for smile CS(=O)(=O)Cl at index 1
Success for smile CCCCCc1ccccc1 at index 64
Success for smile c1cc(c(cc1Cl)Cl)Cl at index 43
Success for smile Cc1cccc(c1)C at index 107
Success for smile CCCCCCC at index 88
Success for smile CC(F)F at index 65
Success for smile CC(C)C=C at index 2
Success for smile Cc1ccc(cc1)C(C)C at index 150
Success for smile CC(=O)OC at index 108
Success f

  g_stats.extend([a.mean(), a.std()])
  ret = ret.dtype.type(ret / rcount)
  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  arrmean = um.true_divide(arrmean, div, out=arrmean,
  ret = ret.dtype.type(ret / rcount)


Success for smile CC(=O)OCCOC(=O)C at index 20
Success for smile CC(C)CO at index 143
Success for smile CC(=O)NC at index 40
Success for smile Cn1cnc2c1c(=O)n(c(=O)n2C)C at index 121
Success for smile CCc1ccccc1 at index 166
Success for smile CCCCCCCCBr at index 81
Success for smile CNC at index 122
Success for smile CCN(CC)CC at index 168
Success for smile CCCOC(=O)CC at index 144
Success for smile CCCCCCCC=C at index 41
Success for smile C(=C(Cl)Cl)Cl at index 167
Success for smile CN(CC(F)(F)F)c1ccccc1 at index 104
Success for smile c1ccc(cc1)CO at index 82
Success for smile c1ccc(c(c1)C(F)(F)F)C(F)(F)F at index 62
Success for smile C(=C(F)F)(C(F)(F)F)F at index 123
Success for smile CCCOCCO at index 210
Success for smile CC(C)NC(C)C at index 231
Success for smile C(C(Cl)(Cl)Cl)(Cl)(Cl)Cl at index 145
Success for smile c1cc(c(c(c1)Cl)Cl)Cl at index 189
Success for smile c1cc(ccc1O)Cl at index 124
Success for smile CCCCCCI at index 252
Success for smile c1c(c(=O)[nH]c(=O)[nH]1)Br at 

  g_stats.extend([a.mean(), a.std()])
  ret = ret.dtype.type(ret / rcount)
  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  arrmean = um.true_divide(arrmean, div, out=arrmean,
  ret = ret.dtype.type(ret / rcount)


Success for smile CC(OC)(OC)OC at index 179
Success for smile CC(C)(C)OC at index 221
Success for smile CCCc1ccc(cc1)O at index 300
Success for smile CCOCC at index 196
Success for smile CC(C)(C)C(=O)OC at index 241
Success for smile c1cc(ccc1C#N)O at index 323
Success for smile CC=C(C)C at index 301
Success for smile CC(=O)c1cccnc1 at index 280
Success for smile CCCCCI at index 261
Success for smile CC#C at index 281
Success for smile CCCCc1ccccc1 at index 180
Success for smile C(CCl)Cl at index 302
Success for smile CS(=O)(=O)C at index 324
Success for smile COC(OC)OC at index 262
Success for smile C([N+](=O)[O-])(Cl)(Cl)Cl at index 242
Success for smile CCC(C)(C)CC at index 303
Success for smile CCCCCCCCC=O at index 282
Success for smile CN(C)c1ccccc1 at index 181
Success for smile CCNc1nc(nc(n1)SC)NC(C)C at index 197
Success for smile CCc1cccc(c1)O at index 325
Success for smile CCC(=O)O at index 283
Success for smile CC(C)OC at index 182
Success for smile CCCCCCCCCC at index 263
S

  g_stats.extend([a.mean(), a.std()])
  ret = ret.dtype.type(ret / rcount)
  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  arrmean = um.true_divide(arrmean, div, out=arrmean,
  ret = ret.dtype.type(ret / rcount)


Success for smile COC(=O)C(F)(F)F at index 224
Success for smile Cc1cccc(n1)C at index 305
Success for smile Cc1cccc(c1)O at index 245
Success for smile CCCCCCOC(=O)C at index 199
Success for smile c1ccc(cc1)CCl at index 287
Success for smile c1cc(ccc1Br)Br at index 327
Success for smile COC(C(Cl)Cl)(F)F at index 306
Success for smile Cc1c[nH]c2c1cccc2 at index 266
Success for smile c1ccc2ccccc2c1 at index 225
Success for smile C1CCC(=O)C1 at index 200
Success for smile CC1CCCCC1 at index 288
Success for smile c12c(c(c(c(c1Cl)Cl)Cl)Cl)Oc3c(c(c(c(c3Cl)Cl)Cl)Cl)O2 at index 183
Success for smile CCOCCOC(=O)C at index 307
Success for smile CCCCC(=O)O at index 201
Success for smile Cc1cccs1 at index 289
Success for smile COc1c(ccc(c1C(=O)O)Cl)Cl at index 328
Success for smile CCBr at index 202
Success for smile c1ccc2c(c1)C(=O)c3c(ccc(c3C2=O)O)N at index 246
Success for smile CC/C=C\C at index 329
Success for smile COc1cccc(c1)N at index 308
Success for smile COP(=O)([C@H](C(Cl)(Cl)Cl)O)OC 

In [11]:
class PolyatomicNet(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim=64, output_dim=1):
        super().__init__()
        self.embed   = nn.Linear(input_dim, hidden_dim)
        self.conv1   = GCNConv(hidden_dim,   hidden_dim)
        self.conv2   = GCNConv(hidden_dim,   hidden_dim)
        self.readout = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, edge_index, batch):
        if edge_index.numel() == 0 or edge_index.size(1) == 0:
            n = x.size(0)
            edge_index = torch.stack([torch.arange(n), torch.arange(n)], dim=0).to(x.device)
        x = F.relu(self.embed(x))
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)
        return self.readout(x).view(-1)

In [12]:
def collate_with_graph_feats(batch_list):
    graph_feats = torch.stack([data.graph_feats for data in batch_list], dim=0)
    for data in batch_list:
        del data.graph_feats
    batched = Batch.from_data_list(batch_list)
    batched.graph_feats = graph_feats
    return batched

In [13]:
def train(model, loader, opt, loss_fn):
    model.train()
    total = 0.0
    for batch in loader:
        opt.zero_grad()
        out = model(batch.x, batch.edge_index, batch.batch)
        loss = loss_fn(out, batch.y.view(-1))
        loss.backward()
        opt.step()
        total += loss.item() * batch.num_graphs
    return total / len(loader.dataset)

import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error
def compute_metrics_with_ci(trues, preds, n_boot=2000, alpha=0.05, seed=42):
    """
    Compute MAE and RMSE with bootstrap 95% confidence intervals.
    """
    trues = np.array(trues)
    preds = np.array(preds)
    mae = mean_absolute_error(trues, preds)
    rmse = np.sqrt(mean_squared_error(trues, preds))

    rng = np.random.RandomState(seed)
    mae_samples = []
    rmse_samples = []
    n = len(trues)
    for _ in range(n_boot):
        idx = rng.randint(0, n, n)
        t = trues[idx]
        p = preds[idx]
        mae_samples.append(mean_absolute_error(t, p))
        rmse_samples.append(np.sqrt(mean_squared_error(t, p)))

    lower = 100 * (alpha/2)
    upper = 100 * (1 - alpha/2)
    mae_ci = (np.percentile(mae_samples, lower), np.percentile(mae_samples, upper))
    rmse_ci = (np.percentile(rmse_samples, lower), np.percentile(rmse_samples, upper))
    return {'mae': mae, 'mae_ci': mae_ci, 'rmse': rmse, 'rmse_ci': rmse_ci}

def evaluate(model, loader):
    model.eval()
    preds, trues = [], []
    with torch.no_grad():
        for batch in loader:
            out = model(batch.x, batch.edge_index, batch.batch)
            preds.append(out)
            trues.append(batch.y.view(-1))
    preds = torch.cat(preds)
    trues = torch.cat(trues)
    return torch.sqrt(torch.mean((preds - trues)**2)).item()


def evaluate_with_ci(model, loader):
    model.eval()
    preds, trues = [], []
    with torch.no_grad():
        for batch in loader:
            out = model(batch.x, batch.edge_index, batch.batch)
            y_true = batch.y.view(-1).cpu().tolist()
            y_pred = out.view(-1).cpu().tolist()
            trues.extend(y_true)
            preds.extend(y_pred)
    metrics = compute_metrics_with_ci(trues, preds)
    return metrics


In [14]:
loading = True
if loading:
    all_data = torch.load('all_data_freesolv.pt', weights_only=False)
    from sklearn.preprocessing import StandardScaler
    scaler = StandardScaler()
    ys = np.array([d.y.item() for d in all_data]).reshape(-1, 1)
    ys_scaled = scaler.fit_transform(ys)
    for i, d in enumerate(all_data):
        d.y = torch.tensor([ys_scaled[i][0]], dtype=torch.float32)

In [15]:
data_list = all_data
train_n = int(0.8*len(data_list))
train_ds, test_ds = torch.utils.data.random_split(data_list, [train_n,len(data_list)-train_n])
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
test_loader  = DataLoader(test_ds, batch_size=32)



input_dim = all_data[0].x.size(1)
model     = PolyatomicNet(input_dim=input_dim, hidden_dim=64, output_dim=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn   = nn.MSELoss()

history = {}

for epoch in range(1, 21):
    tr_loss = train(model, train_loader, optimizer, loss_fn)
    metrics = evaluate_with_ci(model, test_loader)
    print(f"Epoch {epoch:02d} | Train Loss: {tr_loss:.4f} | "
          f"Test MAE: {metrics['mae']:.4f} (95% CI [{metrics['mae_ci'][0]:.4f}, {metrics['mae_ci'][1]:.4f}]) | "
          f"Test RMSE: {metrics['rmse']:.4f} (95% CI [{metrics['rmse_ci'][0]:.4f}, {metrics['rmse_ci'][1]:.4f}])")
    history[epoch] = metrics
    



Epoch 01 | Train Loss: 1.0336 | Test MAE: 0.7253 (95% CI [0.6252, 0.8301]) | Test RMSE: 0.9391 (95% CI [0.7919, 1.0960])
Epoch 02 | Train Loss: 1.0342 | Test MAE: 0.7086 (95% CI [0.6081, 0.8115]) | Test RMSE: 0.9269 (95% CI [0.7741, 1.0910])
Epoch 03 | Train Loss: 1.0265 | Test MAE: 0.7049 (95% CI [0.6039, 0.8072]) | Test RMSE: 0.9239 (95% CI [0.7693, 1.0890])
Epoch 04 | Train Loss: 1.0246 | Test MAE: 0.7067 (95% CI [0.6053, 0.8093]) | Test RMSE: 0.9240 (95% CI [0.7726, 1.0873])
Epoch 05 | Train Loss: 1.0179 | Test MAE: 0.7054 (95% CI [0.6032, 0.8078]) | Test RMSE: 0.9223 (95% CI [0.7720, 1.0838])
Epoch 06 | Train Loss: 1.0186 | Test MAE: 0.6991 (95% CI [0.5976, 0.8012]) | Test RMSE: 0.9175 (95% CI [0.7637, 1.0809])
Epoch 07 | Train Loss: 1.0110 | Test MAE: 0.6939 (95% CI [0.5918, 0.7969]) | Test RMSE: 0.9141 (95% CI [0.7598, 1.0776])
Epoch 08 | Train Loss: 1.0083 | Test MAE: 0.6836 (95% CI [0.5802, 0.7886]) | Test RMSE: 0.9090 (95% CI [0.7515, 1.0761])
Epoch 09 | Train Loss: 1.0107 | 

In [16]:
final = history[20]
print("*"*20)
print(f"Test MAE: {final['mae']:.4f} (95% CI [{final['mae_ci'][0]:.4f}, {final['mae_ci'][1]:.4f}])")
print(f"Test RMSE: {final['rmse']:.4f} (95% CI [{final['rmse_ci'][0]:.4f}, {final['rmse_ci'][1]:.4f}])")
print("*"*20)

********************
Test MAE: 0.6749 (95% CI [0.5740, 0.7770])
Test RMSE: 0.8974 (95% CI [0.7448, 1.0564])
********************
