## 1) Let's make sure we have all our imports.

In [25]:
import numpy as np
import torch
from itertools import product
from typing import Sequence, Tuple, Optional
import time
import pandas as pd
import os
import glob
import math
import matplotlib.pyplot as plt
import json
import torch.nn as nn
from pathlib import Path
from typing import List, Tuple, Dict, Optional, Union
import logging
from dataclasses import dataclass


In [26]:
!pip install rdkit
!pip install tqdm



In [27]:
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors, Crippen
from tqdm import tqdm


## 2) Utility Functions

In [28]:
################################################################################
# Utility functions
################################################################################

def compute_translation(center: np.ndarray,
                        grid_size: int = 48,
                        voxel_size: float = 1.0) -> np.ndarray:
    """Return the translation vector that places *center* at the grid centre."""
    grid_center = ((grid_size - 1) / 2.0) * voxel_size
    return grid_center - center


def gaussian_kernel(radius: int = 2, sigma: float = 1.0) -> Tuple[np.ndarray, np.ndarray]:
    """Return neighbour offsets and weights for a 3‑D isotropic Gaussian.

    Parameters
    ----------
    radius
        Number of voxels in each direction to include.
    sigma
        Standard deviation (in voxels).

    Returns
    -------
    offsets : (K, 3) int array
        Relative integer offsets.
    weights : (K,) float array
        Corresponding weights ∈ [0, 1].
    """
    offsets = np.array(list(product(range(-radius, radius + 1),
                                    repeat=3)), dtype=np.int8)
    # exclude origin – it is handled separately
    offsets = offsets[(offsets != 0).any(axis=1)]
    distances2 = (offsets ** 2).sum(axis=1).astype(np.float32)
    weights = np.exp(-distances2 / (2 * sigma ** 2))
    return offsets, weights



In [29]:
################################################################################
# Core voxelisation
################################################################################

def voxelise_atoms(atom_coords: np.ndarray,
                   atom_feats: np.ndarray,
                   *,
                   grid_size: int = 48,
                   voxel_size: float = 1.0,
                   centre: Optional[np.ndarray] = None,
                   gaussian_radius: int = 2,
                   gaussian_sigma: float = 1.0
                   ) -> np.ndarray:
    """Convert atomic coordinates + features to a (C, N, N, N) voxel grid.

    Parameters
    ----------
    atom_coords : (N_atoms, 3) float32 array
        Cartesian XYZ coordinates in ångströms.
    atom_feats : (N_atoms, C) float32 array
        Per‑atom feature vectors (same channels *C* for every atom).
    centre
        If *None*, the coordinates’ centroid is used. Otherwise, supply a
        (3,) array – for example, the ligand centroid – that will be mapped
        to the geometric centre of the grid.
    gaussian_radius
        Radius, in voxels, of the isotropic Gaussian used to smooth/distribute
        the signal. Use 0 to disable smoothing.
    gaussian_sigma
        Standard deviation (in voxels) of the Gaussian.

    Returns
    -------
    grid : (C, grid_size, grid_size, grid_size) float32 array
    """
    assert atom_coords.shape[0] == atom_feats.shape[0], "Mismatched atoms/feats"
    C = atom_feats.shape[1]
    grid = np.zeros((C, grid_size, grid_size, grid_size), dtype=np.float32)

    if centre is None:
        centre = atom_coords.mean(axis=0, dtype=np.float32)

    # translate so that *centre* maps to the grid midpoint
    translation = compute_translation(centre, grid_size, voxel_size)
    shifted = atom_coords + translation[None, :]

    # integer voxel indices
    indices = np.floor(shifted / voxel_size).astype(np.int16)  # (N,3)

    # mask atoms that fall outside the grid
    valid = ((indices >= 0) & (indices < grid_size)).all(axis=1)
    indices, feats = indices[valid], atom_feats[valid]

    # pre‑build Gaussian neighbours
    if gaussian_radius > 0:
        neigh_offsets, neigh_weights = gaussian_kernel(gaussian_radius, gaussian_sigma)

    for (x, y, z), f in zip(indices, feats):
        # contribute to the central voxel
        grid[:, x, y, z] += f

        # optionally spread to neighbours
        if gaussian_radius > 0:
            neigh_idx = np.stack([x, y, z]) + neigh_offsets  # (K,3)
            # keep neighbours inside the grid boundaries
            valid_neigh = ((neigh_idx >= 0) & (neigh_idx < grid_size)).all(axis=1)
            neigh_idx = neigh_idx[valid_neigh]
            neigh_w   = neigh_weights[valid_neigh]

            for (nx, ny, nz), w in zip(neigh_idx, neigh_w):
                grid[:, nx, ny, nz] += f * w

    return grid


In [30]:
# 19-channel layout ------------------------------------------------------------
#  0-8 : one-hot element   (B, C, N, O, P, S, Se, halogen, metal)
#    9 : hybridisation     (1, 2, 3  → sp, sp2, sp3; 0 = “other”)
#   10 : heavy-atom bonds  (# neighbours with Z > 1)
#   11 : hetero-atom bonds (# neighbours with Z not in {1, 6})
# 12-16: one-hot structural (hydrophobic, aromatic, acceptor, donor, ring)
#   17 : partial charge    (Gasteiger)
#   18 : molecule type     (–1 protein, +1 ligand)
# ------------------------------------------------------------------------------
FEATURE_DIM = 19
ELEMENTS     = {5:0, 6:1, 7:2, 8:3, 15:4, 16:5, 34:6}       # explicit slots
HALOGENS     = {9, 17, 35, 53}
# every other Z > 1 that isn’t in ELEMENTS ∪ HALOGENS is considered “metal”

# ----------------------------------------------------------
# 2.  Atom-level feature builder
# ----------------------------------------------------------
def atom_to_feature_vec(atom, mol_type: int) -> np.ndarray:
    """Return a (19,) float32 feature vector for one RDKit atom."""
    vec = np.zeros(FEATURE_DIM, dtype=np.float32)

    # 1) element (9 one-hot channels)
    Z   = atom.GetAtomicNum()
    idx = ELEMENTS.get(Z, 7 if Z in HALOGENS else 8)
    vec[idx] = 1.0                           # channels 0-8

    # 2) hybridisation
    hyb = {Chem.rdchem.HybridizationType.SP: 1,
           Chem.rdchem.HybridizationType.SP2: 2,
           Chem.rdchem.HybridizationType.SP3: 3}.get(atom.GetHybridization(), 0)
    vec[9] = float(hyb)

    # 3) heavy-atom & hetero-atom bond counts
    heavy = hetero = 0
    for b in atom.GetBonds():
        Z_nbr = b.GetOtherAtom(atom).GetAtomicNum()
        if Z_nbr > 1:  heavy += 1
        if Z_nbr not in (1, 6): hetero += 1
    vec[10], vec[11] = float(heavy), float(hetero)

    # 4) structural 5-hot
    vec[12] = 1.0 if Z == 6 and not atom.GetIsAromatic() else 0.0   # hydrophobic
    vec[13] = 1.0 if atom.GetIsAromatic() else 0.0                 # aromatic
    vec[14] = 1.0 if atom.GetTotalNumHs() < atom.GetNumImplicitHs() else 0.0  # acceptor (very rough)
    vec[15] = 1.0 if atom.GetTotalNumHs() > 0 and Z in (7, 8) else 0.0        # donor   (rough)
    vec[16] = 1.0 if atom.IsInRing() else 0.0                                   # ring

    # 5) partial charge (need Gasteiger first)
    charge = 0.0 # Default charge
    if atom.HasProp('_GasteigerCharge'):
        try:
            charge = float(atom.GetProp('_GasteigerCharge'))
        except ValueError:
            print(f"Warning: Could not convert Gasteiger charge to float for atom {atom.GetIdx()} in molecule.")
            charge = 0.0 # Set to 0 if conversion fails

    vec[17] = charge

    # 6) molecule type
    vec[18] = float(mol_type)

    return vec

C/19: This is the number of channels or features for each voxel. These 19 channels correspond to different atomic properties, such as:

  * One-hot encoding for elements (B, C, N, O, P, S, Se, halogen, metal)
  * Hybridization
  * Heavy-atom bonds
  * Hetero-atom bonds
  * One-hot structural properties (hydrophobic, aromatic, acceptor,donor, ring)
  * Partial charge (Gasteiger)
  * Molecule type (protein or ligand)

# 🧬 Ligand and Pocket Processing

Now let's extend our analysis to include ligands and pockets as well. We'll create similar processors for these molecular components to enable comprehensive protein-ligand interaction analysis.

In [None]:
### make sure you're in bindingaffinity
data_dir = "./demo_data_PROCESSED"

In [34]:
# Find all ligand and pocket files
print(" Finding ligand and pocket files...")

all_ligand_files = []
all_pocket_files = []
ligand_ids = set()
pocket_ids = set()

for pdb_id in os.listdir(data_dir):
    complex_path = os.path.join(data_dir, pdb_id)
    if os.path.isdir(complex_path):
        ligand_file = os.path.join(complex_path, f"{pdb_id}_ligand.mol2")
        pocket_file = os.path.join(complex_path, f"{pdb_id}_pocket.mol2")
        
        if os.path.exists(ligand_file):
            all_ligand_files.append(ligand_file)
            ligand_ids.add(pdb_id)
        if os.path.exists(pocket_file):
            all_pocket_files.append(pocket_file)
            pocket_ids.add(pdb_id)

# Count pairs where both ligand and pocket exist
paired_ids = ligand_ids & pocket_ids

print(f" Found {len(all_ligand_files)} ligand files")
print(f" Found {len(all_pocket_files)} pocket files")
print(f" Found {len(paired_ids)} pairs with both ligand and pocket files")

# Show some examples
print(f"\n Example files:")
if all_ligand_files:
    print(f"   Ligand: {all_ligand_files[0]}")
if all_pocket_files:
    print(f"   Pocket: {all_pocket_files[0]}")

 Finding ligand and pocket files...
 Found 229 ligand files
 Found 229 pocket files
 Found 229 pairs with both ligand and pocket files

 Example files:
   Ligand: ./demo_data_PROCESSED/3nw9/3nw9_ligand.mol2
   Pocket: ./demo_data_PROCESSED/3nw9/3nw9_pocket.mol2


In [35]:
@dataclass
class LigandFeatureConfig:
    """Configuration for ligand featurization"""
    include_hydrogens: bool = False
    grid_size: int = 48
    voxel_size: float = 1.0
    channels: int = 19  # Same as protein for consistency

@dataclass
class PocketFeatureConfig:
    """Configuration for pocket featurization"""
    include_hydrogens: bool = False
    grid_size: int = 48
    voxel_size: float = 1.0
    channels: int = 19  # Same as protein for consistency
    cutoff: float = 6.0  # Distance from ligand to define pocket

# Create global configurations
LIGAND_CONFIG = LigandFeatureConfig()
POCKET_CONFIG = PocketFeatureConfig()

print(" Ligand and Pocket configurations initialized")
print(f"   - Ligand config: grid_size={LIGAND_CONFIG.grid_size}, channels={LIGAND_CONFIG.channels}")
print(f"   - Pocket config: grid_size={POCKET_CONFIG.grid_size}, channels={POCKET_CONFIG.channels}, cutoff={POCKET_CONFIG.cutoff}Å")

 Ligand and Pocket configurations initialized
   - Ligand config: grid_size=48, channels=19
   - Pocket config: grid_size=48, channels=19, cutoff=6.0Å


In [36]:
class LigandDataProcessor:
    """Process ligand MOL2 files into 3D grids"""
    
    def __init__(self, config: LigandFeatureConfig):
        self.config = config
        self.channels = config.channels
        self.grid_size = config.grid_size
        self.voxel_size = config.voxel_size
        
    def load_mol2(self, file_path):
        """Load MOL2 file and return molecule and conformer"""
        metadata = {
            'file': file_path,
            'errors': [],
            'loading_success': False,
            'sanitization_success': False
        }
        
        if not os.path.exists(file_path):
            metadata['errors'].append("File does not exist")
            return None, None, metadata
            
        try:
            # First try MOL2 format
            mol = Chem.MolFromMol2File(file_path, sanitize=False, removeHs=not self.config.include_hydrogens)
            
            # If MOL2 fails, try SDF format (common issue with .mol2 extensions)
            if mol is None:
                mol = Chem.MolFromMolFile(file_path, sanitize=False, removeHs=not self.config.include_hydrogens)
            
            if mol is None:
                metadata['errors'].append("Failed to load molecule from both MOL2 and SDF formats")
                return None, None, metadata
            
            metadata['loading_success'] = True
            metadata['num_atoms'] = mol.GetNumAtoms()
            metadata['num_heavy_atoms'] = mol.GetNumHeavyAtoms()
            
            # Attempt sanitization
            try:
                sanitize_flags = (Chem.SanitizeFlags.SANITIZE_ALL ^ 
                                Chem.SanitizeFlags.SANITIZE_KEKULIZE ^
                                Chem.SanitizeFlags.SANITIZE_SETAROMATICITY)
                Chem.SanitizeMol(mol, sanitizeOps=sanitize_flags)
                metadata['sanitization_success'] = True
            except Exception as e:
                print(f"Warning: Sanitization failed for {file_path}: {e}")
                metadata['errors'].append(f"Sanitization failed: {e}")
            
            # Get conformer
            conf = mol.GetConformer()
            return mol, conf, metadata
            
        except Exception as e:
            metadata['errors'].append(f"Exception during loading: {e}")
            return None, None, metadata
    
    def extract_atoms_from_mol(self, mol, conf):
        """Extract atomic coordinates and features from molecule"""
        coords, feats = [], []
        
        try:
            for i, atom in enumerate(mol.GetAtoms()):
                try:
                    pos = conf.GetAtomPosition(i)
                    p = np.array([pos.x, pos.y, pos.z], dtype=np.float32)
                    coords.append(p)
                    # Use molecule type +1 for ligand
                    feats.append(atom_to_feature_vec(atom, +1))
                except Exception as e:
                    print(f"Error extracting ligand atom {i}: {e}")
                    continue
                    
        except Exception as e:
            print(f"Error processing ligand molecule: {e}")
            return np.empty((0, 3), dtype=np.float32), np.empty((0, self.channels), dtype=np.float32)
        
        return np.asarray(coords, np.float32), np.asarray(feats, np.float32)
    
    def coords_to_grid(self, coords, feats, center=None):
        """Convert atomic coordinates and features to 3D grid"""
        if len(coords) == 0:
            return np.zeros((self.channels, self.grid_size, self.grid_size, self.grid_size), dtype=np.float32)
        
        # Use ligand centroid as center if not provided
        if center is None:
            center = coords.mean(axis=0)
        
        # Create grid
        grid = np.zeros((self.channels, self.grid_size, self.grid_size, self.grid_size), dtype=np.float32)
        
        # Grid bounds
        grid_min = center - (self.grid_size * self.voxel_size) / 2
        grid_max = center + (self.grid_size * self.voxel_size) / 2
        
        for coord, feat in zip(coords, feats):
            # Convert coordinates to grid indices
            grid_coord = (coord - grid_min) / self.voxel_size
            
            # Check if within grid bounds
            if np.all(grid_coord >= 0) and np.all(grid_coord < self.grid_size):
                x, y, z = grid_coord.astype(int)
                if 0 <= x < self.grid_size and 0 <= y < self.grid_size and 0 <= z < self.grid_size:
                    grid[:, x, y, z] += feat
        
        return grid
    
    def process_ligand_file(self, file_path):
        """Process a single ligand MOL2 file"""
        start_time = time.time()
        
        # Load molecule
        mol, conf, metadata = self.load_mol2(file_path)
        
        if mol is None or conf is None:
            metadata['success'] = False
            metadata['processing_time'] = time.time() - start_time
            return None, metadata
        
        # Extract atomic information
        coords, feats = self.extract_atoms_from_mol(mol, conf)
        
        if len(coords) == 0:
            metadata['errors'].append("No valid atoms extracted")
            metadata['success'] = False
            metadata['processing_time'] = time.time() - start_time
            return None, metadata
        
        # Convert to grid
        grid = self.coords_to_grid(coords, feats)
        
        # Update metadata
        metadata['grid_shape'] = list(grid.shape)
        metadata['grid_occupancy'] = np.count_nonzero(grid) / grid.size
        metadata['processing_time'] = time.time() - start_time
        metadata['success'] = True
        metadata['ligand_id'] = os.path.basename(os.path.dirname(file_path))
        
        return grid, metadata

In [37]:
class PocketDataProcessor:
    """Process pocket MOL2 files into 3D grids"""
    
    def __init__(self, config: PocketFeatureConfig):
        self.config = config
        self.channels = config.channels
        self.grid_size = config.grid_size
        self.voxel_size = config.voxel_size
        self.cutoff = config.cutoff
        
    def load_mol2(self, file_path):
        """Load MOL2 file and return molecule and conformer"""
        metadata = {
            'file': file_path,
            'errors': [],
            'loading_success': False,
            'sanitization_success': False
        }
        
        if not os.path.exists(file_path):
            metadata['errors'].append("File does not exist")
            return None, None, metadata
            
        try:
            # First try MOL2 format
            mol = Chem.MolFromMol2File(file_path, sanitize=False, removeHs=not self.config.include_hydrogens)
            
            # If MOL2 fails, try SDF format (common issue with .mol2 extensions)
            if mol is None:
                mol = Chem.MolFromMolFile(file_path, sanitize=False, removeHs=not self.config.include_hydrogens)
            
            if mol is None:
                metadata['errors'].append("Failed to load molecule from both MOL2 and SDF formats")
                return None, None, metadata
            
            metadata['loading_success'] = True
            metadata['num_atoms'] = mol.GetNumAtoms()
            metadata['num_heavy_atoms'] = mol.GetNumHeavyAtoms()
            
            # Attempt sanitization
            try:
                sanitize_flags = (Chem.SanitizeFlags.SANITIZE_ALL ^ 
                                Chem.SanitizeFlags.SANITIZE_KEKULIZE ^
                                Chem.SanitizeFlags.SANITIZE_SETAROMATICITY)
                Chem.SanitizeMol(mol, sanitizeOps=sanitize_flags)
                metadata['sanitization_success'] = True
            except Exception as e:
                print(f"Warning: Sanitization failed for {file_path}: {e}")
                metadata['errors'].append(f"Sanitization failed: {e}")
            
            # Get conformer
            conf = mol.GetConformer()
            return mol, conf, metadata
            
        except Exception as e:
            metadata['errors'].append(f"Exception during loading: {e}")
            return None, None, metadata
    
    def extract_atoms_from_mol(self, mol, conf, ligand_center=None):
        """Extract atomic coordinates and features from pocket molecule"""
        coords, feats = [], []
        
        try:
            for i, atom in enumerate(mol.GetAtoms()):
                try:
                    pos = conf.GetAtomPosition(i)
                    p = np.array([pos.x, pos.y, pos.z], dtype=np.float32)
                    
                    # If ligand center is provided, only include atoms within cutoff
                    if ligand_center is not None:
                        distance = np.linalg.norm(p - ligand_center)
                        if distance > self.cutoff:
                            continue
                    
                    coords.append(p)
                    # Use molecule type 0 for pocket (different from protein -1 and ligand +1)
                    feats.append(atom_to_feature_vec(atom, 0))
                except Exception as e:
                    print(f"Error extracting pocket atom {i}: {e}")
                    continue
                    
        except Exception as e:
            print(f"Error processing pocket molecule: {e}")
            return np.empty((0, 3), dtype=np.float32), np.empty((0, self.channels), dtype=np.float32)
        
        return np.asarray(coords, np.float32), np.asarray(feats, np.float32)
    
    def coords_to_grid(self, coords, feats, center):
        """Convert atomic coordinates and features to 3D grid"""
        if len(coords) == 0:
            return np.zeros((self.channels, self.grid_size, self.grid_size, self.grid_size), dtype=np.float32)
        
        # Create grid
        grid = np.zeros((self.channels, self.grid_size, self.grid_size, self.grid_size), dtype=np.float32)
        
        # Grid bounds
        grid_min = center - (self.grid_size * self.voxel_size) / 2
        grid_max = center + (self.grid_size * self.voxel_size) / 2
        
        for coord, feat in zip(coords, feats):
            # Convert coordinates to grid indices
            grid_coord = (coord - grid_min) / self.voxel_size
            
            # Check if within grid bounds
            if np.all(grid_coord >= 0) and np.all(grid_coord < self.grid_size):
                x, y, z = grid_coord.astype(int)
                if 0 <= x < self.grid_size and 0 <= y < self.grid_size and 0 <= z < self.grid_size:
                    grid[:, x, y, z] += feat
        
        return grid
    
    def process_pocket_file(self, pocket_path, ligand_path=None):
        """Process a single pocket MOL2 file, optionally using ligand for centering"""
        start_time = time.time()
        
        # Load pocket molecule
        mol, conf, metadata = self.load_mol2(pocket_path)
        
        if mol is None or conf is None:
            metadata['success'] = False
            metadata['processing_time'] = time.time() - start_time
            return None, metadata
        
        # Determine center (ligand center if available, otherwise pocket centroid)
        ligand_center = None
        if ligand_path and os.path.exists(ligand_path):
            try:
                ligand_processor = LigandDataProcessor(LIGAND_CONFIG)
                lig_mol, lig_conf, _ = ligand_processor.load_mol2(ligand_path)
                if lig_mol and lig_conf:
                    lig_coords = np.array([lig_conf.GetAtomPosition(i) for i in range(lig_mol.GetNumAtoms())])
                    ligand_center = lig_coords.mean(axis=0)
            except Exception as e:
                print(f"Warning: Could not load ligand for centering: {e}")
        
        # Extract atomic information
        coords, feats = self.extract_atoms_from_mol(mol, conf, ligand_center)
        
        if len(coords) == 0:
            metadata['errors'].append("No valid atoms extracted")
            metadata['success'] = False
            metadata['processing_time'] = time.time() - start_time
            return None, metadata
        
        # Use ligand center if available, otherwise pocket centroid
        center = ligand_center if ligand_center is not None else coords.mean(axis=0)
        
        # Convert to grid
        grid = self.coords_to_grid(coords, feats, center)
        
        # Update metadata
        metadata['grid_shape'] = list(grid.shape)
        metadata['grid_occupancy'] = np.count_nonzero(grid) / grid.size
        metadata['processing_time'] = time.time() - start_time
        metadata['success'] = True
        metadata['pocket_id'] = os.path.basename(os.path.dirname(pocket_path))
        metadata['used_ligand_center'] = ligand_center is not None
        
        return grid, metadata

In [38]:
# Process sample ligand files
print(" Testing ligand processing with sample files...")

ligand_processor = LigandDataProcessor(LIGAND_CONFIG)
sample_ligand_grids = []
sample_ligand_metadata = []

# Process first 5 ligand files as a test
sample_ligand_files = all_ligand_files[:5]
print(f"Processing {len(sample_ligand_files)} sample ligand files...")

for i, ligand_file in enumerate(sample_ligand_files):
    print(f"  Processing ligand {i+1}/{len(sample_ligand_files)}: {os.path.basename(os.path.dirname(ligand_file))}")
    
    try:
        grid, metadata = ligand_processor.process_ligand_file(ligand_file)
        
        if grid is not None:
            sample_ligand_grids.append(grid)
            sample_ligand_metadata.append(metadata)
            print(f"     Success: shape={grid.shape}, occupancy={metadata['grid_occupancy']:.4f}")
        else:
            sample_ligand_metadata.append(metadata)
            print(f"     Failed: {metadata.get('errors', ['Unknown error'])}")
            
    except Exception as e:
        print(f"     Exception: {e}")
        error_metadata = {
            'file': ligand_file,
            'error': str(e),
            'success': False
        }
        sample_ligand_metadata.append(error_metadata)

print(f"\n Ligand Processing Results:")
print(f"   - Successfully processed: {len(sample_ligand_grids)} ligands")
print(f"   - Failed: {len(sample_ligand_metadata) - len(sample_ligand_grids)} ligands")

if sample_ligand_grids:
    sample_grid = sample_ligand_grids[0]
    print(f"   - Sample grid shape: {sample_grid.shape}")
    print(f"   - Sample grid occupancy: {np.count_nonzero(sample_grid) / sample_grid.size:.4f}")
    print(f"   - Sample grid memory: {sample_grid.nbytes / 1024**2:.2f} MB")

 Testing ligand processing with sample files...
Processing 5 sample ligand files...
  Processing ligand 1/5: 3nw9
     Success: shape=(19, 48, 48, 48), occupancy=0.0001
  Processing ligand 2/5: 2v7a
     Success: shape=(19, 48, 48, 48), occupancy=0.0001
  Processing ligand 3/5: 2qnq
     Success: shape=(19, 48, 48, 48), occupancy=0.0002
  Processing ligand 4/5: 3u5j
     Success: shape=(19, 48, 48, 48), occupancy=0.0001
  Processing ligand 5/5: 4eor
     Success: shape=(19, 48, 48, 48), occupancy=0.0001

 Ligand Processing Results:
   - Successfully processed: 5 ligands
   - Failed: 0 ligands
   - Sample grid shape: (19, 48, 48, 48)
   - Sample grid occupancy: 0.0001
   - Sample grid memory: 8.02 MB


In [39]:
# Process sample pocket files
print(" Testing pocket processing with sample files...")

pocket_processor = PocketDataProcessor(POCKET_CONFIG)
sample_pocket_grids = []
sample_pocket_metadata = []

# Process first 5 pocket files as a test, with corresponding ligands for centering
sample_pocket_files = all_pocket_files[:5]
print(f"Processing {len(sample_pocket_files)} sample pocket files...")

for i, pocket_file in enumerate(sample_pocket_files):
    protein_id = os.path.basename(os.path.dirname(pocket_file))
    ligand_file = os.path.join(os.path.dirname(pocket_file), f"{protein_id}_ligand.mol2")
    
    print(f"  Processing pocket {i+1}/{len(sample_pocket_files)}: {protein_id}")
    
    try:
        # Use ligand file for centering if it exists
        ligand_path = ligand_file if os.path.exists(ligand_file) else None
        grid, metadata = pocket_processor.process_pocket_file(pocket_file, ligand_path)
        
        if grid is not None:
            sample_pocket_grids.append(grid)
            sample_pocket_metadata.append(metadata)
            used_ligand = "with ligand center" if metadata.get('used_ligand_center', False) else "self-centered"
            print(f"     Success: shape={grid.shape}, occupancy={metadata['grid_occupancy']:.4f} ({used_ligand})")
        else:
            sample_pocket_metadata.append(metadata)
            print(f"     Failed: {metadata.get('errors', ['Unknown error'])}")
            
    except Exception as e:
        print(f"     Exception: {e}")
        error_metadata = {
            'file': pocket_file,
            'error': str(e),
            'success': False
        }
        sample_pocket_metadata.append(error_metadata)

print(f"\n Pocket Processing Results:")
print(f"   - Successfully processed: {len(sample_pocket_grids)} pockets")
print(f"   - Failed: {len(sample_pocket_metadata) - len(sample_pocket_grids)} pockets")

if sample_pocket_grids:
    sample_grid = sample_pocket_grids[0]
    print(f"   - Sample grid shape: {sample_grid.shape}")
    print(f"   - Sample grid occupancy: {np.count_nonzero(sample_grid) / sample_grid.size:.4f}")
    print(f"   - Sample grid memory: {sample_grid.nbytes / 1024**2:.2f} MB")

 Testing pocket processing with sample files...
Processing 5 sample pocket files...
  Processing pocket 1/5: 3nw9
     Failed: ['Failed to load molecule from both MOL2 and SDF formats']
  Processing pocket 2/5: 2v7a
     Failed: ['Failed to load molecule from both MOL2 and SDF formats']
  Processing pocket 3/5: 2qnq
     Success: shape=(19, 48, 48, 48), occupancy=0.0001 (with ligand center)
  Processing pocket 4/5: 3u5j
     Success: shape=(19, 48, 48, 48), occupancy=0.0001 (with ligand center)
  Processing pocket 5/5: 4eor
     Success: shape=(19, 48, 48, 48), occupancy=0.0001 (with ligand center)

 Pocket Processing Results:
   - Successfully processed: 3 pockets
   - Failed: 2 pockets
   - Sample grid shape: (19, 48, 48, 48)
   - Sample grid occupancy: 0.0001
   - Sample grid memory: 8.02 MB


[20:04:03] Cannot convert 'PRO' to unsigned int on line 4
[20:04:03] Cannot convert 'PRO' to unsigned int on line 4


In [40]:
# Process all ligand files
print(" Processing ALL ligand files...")
print(f"Total ligand files to process: {len(all_ligand_files)}")

all_ligand_grids = []
all_ligand_metadata = []

batch_size = 10
total_batches = (len(all_ligand_files) + batch_size - 1) // batch_size

for batch_idx in range(total_batches):
    start_idx = batch_idx * batch_size
    end_idx = min((batch_idx + 1) * batch_size, len(all_ligand_files))
    batch_files = all_ligand_files[start_idx:end_idx]
    
    print(f"Processing batch {batch_idx + 1}/{total_batches} ({len(batch_files)} files)")
    
    for i, ligand_file in enumerate(batch_files):
        protein_id = os.path.basename(os.path.dirname(ligand_file))
        
        if (start_idx + i) % 20 == 0:  # Progress update every 20 files
            print(f"  Progress: {start_idx + i + 1}/{len(all_ligand_files)} files processed")
        
        try:
            grid, metadata = ligand_processor.process_ligand_file(ligand_file)
            
            if grid is not None:
                all_ligand_grids.append(grid)
                all_ligand_metadata.append(metadata)
            else:
                all_ligand_metadata.append(metadata)
                print(f"     Failed {protein_id}: {metadata.get('errors', ['Unknown'])}")
                
        except Exception as e:
            print(f"     Exception {protein_id}: {e}")
            error_metadata = {
                'file': ligand_file,
                'ligand_id': protein_id,
                'error': str(e),
                'success': False
            }
            all_ligand_metadata.append(error_metadata)

print(f"\n LIGAND PROCESSING COMPLETE!")
print(f"   - Total processed: {len(all_ligand_grids)} ligands")
print(f"   - Total failed: {len(all_ligand_metadata) - len(all_ligand_grids)} ligands")
print(f"   - Success rate: {len(all_ligand_grids) / len(all_ligand_files) * 100:.1f}%")

if all_ligand_grids:
    # Check memory requirements before conversion
    estimated_memory_gb = len(all_ligand_grids) * 19 * 64 * 64 * 64 * 4 / (1024**3)
    print(f"   - Estimated memory needed: {estimated_memory_gb:.2f} GB")
    
    if estimated_memory_gb > 3.0:  # If > 3GB, keep as list to save memory
        print(f"   - Keeping as list to avoid memory issues")
        print(f"   - Number of ligand grids: {len(all_ligand_grids)}")
        print(f"   - Individual grid shape: {all_ligand_grids[0].shape}")
    else:
        # Convert to numpy array only if memory allows
        try:
            all_ligand_grids = np.array(all_ligand_grids)
            print(f"   - Final array shape: {all_ligand_grids.shape}")
            print(f"   - Memory usage: {all_ligand_grids.nbytes / 1024**3:.2f} GB")
        except MemoryError:
            print(f"   - Memory error: keeping as list instead")
            print(f"   - Number of ligand grids: {len(all_ligand_grids)}")

 Processing ALL ligand files...
Total ligand files to process: 229
Processing batch 1/23 (10 files)
  Progress: 1/229 files processed
Processing batch 2/23 (10 files)
Processing batch 3/23 (10 files)
  Progress: 21/229 files processed
Processing batch 4/23 (10 files)
Processing batch 5/23 (10 files)
  Progress: 41/229 files processed
Processing batch 6/23 (10 files)
Processing batch 7/23 (10 files)
  Progress: 61/229 files processed
Processing batch 8/23 (10 files)
Processing batch 9/23 (10 files)
  Progress: 81/229 files processed
Processing batch 10/23 (10 files)
Processing batch 11/23 (10 files)
  Progress: 101/229 files processed
Processing batch 12/23 (10 files)
Processing batch 13/23 (10 files)
  Progress: 121/229 files processed
Processing batch 14/23 (10 files)
Processing batch 15/23 (10 files)
  Progress: 141/229 files processed
Processing batch 16/23 (10 files)
Processing batch 17/23 (10 files)
  Progress: 161/229 files processed
Processing batch 18/23 (10 files)
Processing b

In [41]:
# Process all pocket files
print(" Processing ALL pocket files...")
print(f"Total pocket files to process: {len(all_pocket_files)}")

all_pocket_grids = []
all_pocket_metadata = []

batch_size = 10
total_batches = (len(all_pocket_files) + batch_size - 1) // batch_size

for batch_idx in range(total_batches):
    start_idx = batch_idx * batch_size
    end_idx = min((batch_idx + 1) * batch_size, len(all_pocket_files))
    batch_files = all_pocket_files[start_idx:end_idx]
    
    print(f" Processing batch {batch_idx + 1}/{total_batches} ({len(batch_files)} files)")
    
    for i, pocket_file in enumerate(batch_files):
        protein_id = os.path.basename(os.path.dirname(pocket_file))
        ligand_file = os.path.join(os.path.dirname(pocket_file), f"{protein_id}_ligand.mol2")
        
        if (start_idx + i) % 20 == 0:  # Progress update every 20 files
            print(f"  Progress: {start_idx + i + 1}/{len(all_pocket_files)} files processed")
        
        try:
            # Use ligand file for centering if it exists
            ligand_path = ligand_file if os.path.exists(ligand_file) else None
            grid, metadata = pocket_processor.process_pocket_file(pocket_file, ligand_path)
            
            if grid is not None:
                all_pocket_grids.append(grid)
                all_pocket_metadata.append(metadata)
            else:
                all_pocket_metadata.append(metadata)
                print(f"     Failed {protein_id}: {metadata.get('errors', ['Unknown'])}")
                
        except Exception as e:
            print(f"     Exception {protein_id}: {e}")
            error_metadata = {
                'file': pocket_file,
                'pocket_id': protein_id,
                'error': str(e),
                'success': False
            }
            all_pocket_metadata.append(error_metadata)

print(f"\n POCKET PROCESSING COMPLETE!")
print(f"   - Total processed: {len(all_pocket_grids)} pockets")
print(f"   - Total failed: {len(all_pocket_metadata) - len(all_pocket_grids)} pockets")
print(f"   - Success rate: {len(all_pocket_grids) / len(all_pocket_files) * 100:.1f}%")

if all_pocket_grids:
    # Convert to numpy array
    all_pocket_grids = np.array(all_pocket_grids)
    print(f"   - Final array shape: {all_pocket_grids.shape}")
    print(f"   - Memory usage: {all_pocket_grids.nbytes / 1024**3:.2f} GB")

 Processing ALL pocket files...
Total pocket files to process: 229
 Processing batch 1/23 (10 files)
  Progress: 1/229 files processed
     Failed 3nw9: ['Failed to load molecule from both MOL2 and SDF formats']
     Failed 2v7a: ['Failed to load molecule from both MOL2 and SDF formats']
     Failed 3d4z: ['Failed to load molecule from both MOL2 and SDF formats']
     Failed 3ehy: ['Failed to load molecule from both MOL2 and SDF formats']
 Processing batch 2/23 (10 files)


[20:04:04] Cannot convert 'PRO' to unsigned int on line 4
[20:04:04] Cannot convert 'PRO' to unsigned int on line 4
[20:04:04] Cannot convert 'PRO' to unsigned int on line 4
[20:04:04] Cannot convert 'PRO' to unsigned int on line 4
[20:04:04] Cannot convert 'PRO' to unsigned int on line 4
[20:04:04] Cannot convert 'PRO' to unsigned int on line 4


     Failed 3dx2: ['Failed to load molecule from both MOL2 and SDF formats']
 Processing batch 3/23 (10 files)
  Progress: 21/229 files processed
     Failed 1z9g: ['Failed to load molecule from both MOL2 and SDF formats']
 Processing batch 4/23 (10 files)
 Processing batch 5/23 (10 files)
  Progress: 41/229 files processed
 Processing batch 6/23 (10 files)
     Failed 4gr0: ['Failed to load molecule from both MOL2 and SDF formats']
     Failed 3tsk: ['Failed to load molecule from both MOL2 and SDF formats']


[20:04:06] Cannot convert 'PRO' to unsigned int on line 4
[20:04:06] Cannot convert 'PRO' to unsigned int on line 4


 Processing batch 7/23 (10 files)
  Progress: 61/229 files processed
 Processing batch 8/23 (10 files)


[20:04:06] Cannot convert 'PRO' to unsigned int on line 4


     Failed 3oe5: ['Failed to load molecule from both MOL2 and SDF formats']
 Processing batch 9/23 (10 files)
  Progress: 81/229 files processed
 Processing batch 10/23 (10 files)
     Failed 3oe4: ['Failed to load molecule from both MOL2 and SDF formats']


[20:04:07] Cannot convert 'PRO' to unsigned int on line 4


 Processing batch 11/23 (10 files)
  Progress: 101/229 files processed
 Processing batch 12/23 (10 files)
 Processing batch 13/23 (10 files)
  Progress: 121/229 files processed
 Processing batch 14/23 (10 files)
     Failed 3nx7: ['Failed to load molecule from both MOL2 and SDF formats']
     Failed 1qf1: ['Failed to load molecule from both MOL2 and SDF formats']
     Failed 1ps3: ['Failed to load molecule from both MOL2 and SDF formats']


[20:04:08] Cannot convert 'PRO' to unsigned int on line 4
[20:04:08] Cannot convert 'PRO' to unsigned int on line 4
[20:04:08] Cannot convert 'PRO' to unsigned int on line 4


 Processing batch 15/23 (10 files)
  Progress: 141/229 files processed
     Failed 3uuo: ['Failed to load molecule from both MOL2 and SDF formats']


[20:04:08] Cannot convert 'PRO' to unsigned int on line 4


 Processing batch 16/23 (10 files)


[20:04:09] Cannot convert 'PRO' to unsigned int on line 4


     Failed 3fcq: ['Failed to load molecule from both MOL2 and SDF formats']
 Processing batch 17/23 (10 files)
  Progress: 161/229 files processed
 Processing batch 18/23 (10 files)
     Failed 3ozs: ['Failed to load molecule from both MOL2 and SDF formats']
     Failed 3ozt: ['Failed to load molecule from both MOL2 and SDF formats']
     Failed 3ui7: ['Failed to load molecule from both MOL2 and SDF formats']


[20:04:09] Cannot convert 'PRO' to unsigned int on line 4
[20:04:09] Cannot convert 'PRO' to unsigned int on line 4
[20:04:10] Cannot convert 'PRO' to unsigned int on line 4


 Processing batch 19/23 (10 files)
  Progress: 181/229 files processed
 Processing batch 20/23 (10 files)


[20:04:10] Cannot convert 'PRO' to unsigned int on line 4


     Failed 3lka: ['Failed to load molecule from both MOL2 and SDF formats']
 Processing batch 21/23 (10 files)
  Progress: 201/229 files processed
 Processing batch 22/23 (10 files)
 Processing batch 23/23 (9 files)
  Progress: 221/229 files processed

 POCKET PROCESSING COMPLETE!
   - Total processed: 210 pockets
   - Total failed: 19 pockets
   - Success rate: 91.7%
   - Final array shape: (210, 19, 48, 48, 48)
   - Memory usage: 1.64 GB


In [42]:
# Check pocket processing results
print(" Pocket processing completed!")
if 'all_pocket_grids' in globals():
    print(f" Successfully processed: {len(all_pocket_grids)} pockets")
    if len(all_pocket_grids) > 0:
        print(f"   Grid shape: {all_pocket_grids.shape}")
        print(f"   Memory usage: {all_pocket_grids.nbytes / 1024**3:.2f} GB")

if 'all_pocket_metadata' in globals():
    successful = [m for m in all_pocket_metadata if m.get('success', False)]
    failed = [m for m in all_pocket_metadata if not m.get('success', False)]
    print(f"   Successful: {len(successful)}, Failed: {len(failed)}")
    print(f"   Success rate: {len(successful) / len(all_pocket_metadata) * 100:.1f}%")

 Pocket processing completed!
 Successfully processed: 210 pockets
   Grid shape: (210, 19, 48, 48, 48)
   Memory usage: 1.64 GB
   Successful: 210, Failed: 19
   Success rate: 91.7%


In [43]:
for m in all_ligand_metadata:
    print(f" Ligand ID: {m.get('ligand_id', 'Unknown')}, Success: {m.get('success', False)}, Errors: {m.get('errors', [])}")

 Ligand ID: 3nw9, Success: True, Errors: []
 Ligand ID: 2v7a, Success: True, Errors: []
 Ligand ID: 2qnq, Success: True, Errors: []
 Ligand ID: 3u5j, Success: True, Errors: []
 Ligand ID: 4eor, Success: True, Errors: []
 Ligand ID: 2p4y, Success: True, Errors: []
 Ligand ID: 3d4z, Success: True, Errors: []
 Ligand ID: 1z6e, Success: True, Errors: []
 Ligand ID: 1owh, Success: True, Errors: []
 Ligand ID: 3ehy, Success: True, Errors: []
 Ligand ID: 4cig, Success: True, Errors: []
 Ligand ID: 4j3l, Success: True, Errors: []
 Ligand ID: 3n76, Success: True, Errors: []
 Ligand ID: 4eo8, Success: True, Errors: []
 Ligand ID: 3gnw, Success: True, Errors: []
 Ligand ID: 2qbr, Success: True, Errors: []
 Ligand ID: 4gfm, Success: True, Errors: []
 Ligand ID: 3rlr, Success: True, Errors: []
 Ligand ID: 3b27, Success: True, Errors: []
 Ligand ID: 3dx2, Success: True, Errors: []
 Ligand ID: 3n7a, Success: True, Errors: []
 Ligand ID: 1z9g, Success: True, Errors: []
 Ligand ID: 4f09, Success: True,

In [44]:
for m in all_pocket_metadata:
    print(f" Pocket ID: {m.get('pocket_id', 'Unknown')}, Success: {m.get('success', False)}, Errors: {m.get('errors', [])}")

 Pocket ID: Unknown, Success: False, Errors: ['Failed to load molecule from both MOL2 and SDF formats']
 Pocket ID: Unknown, Success: False, Errors: ['Failed to load molecule from both MOL2 and SDF formats']
 Pocket ID: 2qnq, Success: True, Errors: []
 Pocket ID: 3u5j, Success: True, Errors: []
 Pocket ID: 4eor, Success: True, Errors: []
 Pocket ID: 2p4y, Success: True, Errors: []
 Pocket ID: Unknown, Success: False, Errors: ['Failed to load molecule from both MOL2 and SDF formats']
 Pocket ID: 1z6e, Success: True, Errors: []
 Pocket ID: 1owh, Success: True, Errors: []
 Pocket ID: Unknown, Success: False, Errors: ['Failed to load molecule from both MOL2 and SDF formats']
 Pocket ID: 4cig, Success: True, Errors: []
 Pocket ID: 4j3l, Success: True, Errors: []
 Pocket ID: 3n76, Success: True, Errors: []
 Pocket ID: 4eo8, Success: True, Errors: []
 Pocket ID: 3gnw, Success: True, Errors: []
 Pocket ID: 2qbr, Success: True, Errors: []
 Pocket ID: 4gfm, Success: True, Errors: []
 Pocket ID: 

In [None]:
# os.chdir('step4_outputs')
os.getcwd()

'/Users/sanskriti/Documents/GitHub/bindingaffinity/step4_outputs'

In [None]:
print("Saving matched ligand-pocket pairs...")

# Ensure output directories exist
os.makedirs("model_ready_data/ligand_data", exist_ok=True)
os.makedirs("model_ready_data/pocket_data", exist_ok=True)

# Create a map from ligand and pocket metadata
ligand_map = {meta.get('ligand_id', "Unknown"): (grid, meta) for grid, meta in zip(all_ligand_grids, all_ligand_metadata)}
pocket_map = {meta.get('pocket_id', "Unknown"): (grid, meta) for grid, meta in zip(all_pocket_grids, all_pocket_metadata)}

# Intersect names
common_keys = sorted(set(ligand_map.keys()) & set(pocket_map.keys()))

print(f"Found {len(common_keys)} matched pairs out of {len(all_ligand_grids)} ligands and {len(all_pocket_grids)} pockets")

# Prepare filtered lists
filtered_ligand_grids = []
filtered_pocket_grids = []
filtered_ligand_metadata = []
filtered_pocket_metadata = []

for key in common_keys:
    l_grid, l_meta = ligand_map[key]
    p_grid, p_meta = pocket_map[key]

    filtered_ligand_grids.append(l_grid)
    filtered_ligand_metadata.append(l_meta)
    filtered_pocket_grids.append(p_grid)
    filtered_pocket_metadata.append(p_meta)

# Save filtered ligand data
ligand_array = np.array(filtered_ligand_grids)
np.save("model_ready_data/ligand_data/ligand_data.npz", ligand_array)
with open("model_ready_data/ligand_data/ligand_metadata.json", "w") as f:
    json.dump(filtered_ligand_metadata, f, indent=2, default=str)
print(f"Saved {len(filtered_ligand_grids)} ligand entries")

# Save filtered pocket data
pocket_array = np.array(filtered_pocket_grids)
np.save("model_ready_data/pocket_data/pocket_grids.npz", pocket_array)
with open("model_ready_data/pocket_data/pocket_metadata.json", "w") as f:
    json.dump(filtered_pocket_metadata, f, indent=2, default=str)
print(f"Saved {len(filtered_pocket_grids)} pocket entries")

print("\nFINAL SUMMARY:")
print(f"   Ligand-pocket pairs saved: {len(common_keys)}")


Saving matched ligand-pocket pairs...
Found 191 matched pairs out of 229 ligands and 210 pockets
Saved 191 ligand entries
Saved 191 pocket entries

FINAL SUMMARY:
   Ligand-pocket pairs saved: 191


In [47]:

# Save common_keys (matched PDB IDs) to a txt file
with open("matched_pdb_ids.txt", "w") as f:
    for pdb_id in common_keys:
        f.write(f"{pdb_id}\n")

print(f"Saved {len(common_keys)} matched PDB IDs to matched_pdb_ids.txt")

Saved 191 matched PDB IDs to matched_pdb_ids.txt


In [51]:
os.getcwd()

'/Users/sanskriti/Documents/GitHub/bindingaffinity/step4_outputs'