In [None]:
!pip install torch_geometric pandas numpy matplotlib



In [None]:
import sys
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# ---- Configuration constants ----

# Root of a local clone of the GEMS repository. Its Dataset.py provides the
# custom Dataset class imported further down, so clone it onto the VM first:
#   git clone https://github.com/camlab-ethz/GEMS.git
GEMS_REPO_ROOT = os.path.expanduser('~/GEMS')

# Directory holding the pre-processed .pt dataset files.
DATA_DIR = os.path.join(os.path.expanduser('~'), 'gcs_pdbbind_mount')

# Dataset files to analyze. All share the '00AEPL_' prefix (presumably the
# GEMS featurization variant -- TODO confirm) and differ only by split name.
DATASET_FILES = [
    f'00AEPL_{split}.pt'
    for split in (
        'casf2013',
        'casf2013_indep',
        'casf2016',
        'casf2016_indep',
        'train_cleansplit',
        'train_pdbbind',
    )
]


Added /home/jupyter/GEMS to sys.path
Successfully imported custom Dataset class


In [None]:
# Dynamic import of the custom GEMS Dataset class.
# The GEMS repo is not an installed package, so its root directory is put on
# sys.path and Dataset.py is imported from there.
print("Attempting to load custom Dataset class...")

# Only the sys.path mutation is guarded (avoids duplicate entries on re-run).
# The imports below must run unconditionally so GEMS_Dataset is also defined
# when GEMS_REPO_ROOT is already on sys.path (e.g. via PYTHONPATH).
if GEMS_REPO_ROOT not in sys.path:
    sys.path.append(GEMS_REPO_ROOT)
    print(f"Added {GEMS_REPO_ROOT} to sys.path")

try:
    # Aliased so it cannot be confused with torch_geometric's own Dataset class.
    from Dataset import Dataset as GEMS_Dataset
    print("Successfully imported custom Dataset class")

    # PyG graph container; the stored output shows the loaded datasets yield
    # torch_geometric Data objects.
    from torch_geometric.data import Data
except ImportError as e:
    print(f"FATAL ERROR importing GEMS Dataset class: {e}")
    # Re-raise rather than sys.exit(1): in Jupyter sys.exit only surfaces a
    # bare SystemExit, while re-raising keeps the real traceback and still
    # halts "Run All".
    raise


In [None]:
def extract_metrics(data_list):
    """Build a per-complex metrics table from a sequence of PyG graph objects.

    Args:
        data_list: Iterable of graph objects (here: the object returned by
            ``torch.load`` on a GEMS ``.pt`` file). Each item must provide
            ``x`` (2-D node-feature tensor, one row per node), ``edge_index``
            (2-D tensor, one column per edge) and a single-element ``y``; an
            optional ``id`` attribute is used as the PDB identifier.

    Returns:
        pandas.DataFrame with one row per graph and columns:
        ``PDB_ID`` (``id`` or ``Complex_<i>``), ``Affinity_pKi_pKd``
        (``y.item()``), ``Num_Atoms`` (rows of ``x``), ``Num_Interactions``
        (columns of ``edge_index``), ``Density`` (edge_index columns per node,
        NaN for a graph with no nodes), ``Count_Feature_10`` (column sum of
        feature 10, NaN if the graph has fewer than 11 feature columns) and
        one ``Feature_<j>`` column per node-feature column. An empty input
        yields an empty frame that still carries the fixed columns.

    Note:
        The stored output of this notebook shows y values around 0.2-0.4, so
        ``y`` may be a rescaled affinity rather than raw pKi/pKd -- TODO
        confirm against the GEMS preprocessing.
    """
    base_columns = ['PDB_ID', 'Affinity_pKi_pKd', 'Num_Atoms',
                    'Num_Interactions', 'Density', 'Count_Feature_10']
    records = []

    for i, data in enumerate(data_list):
        num_nodes = data.x.size(0)
        num_edges = data.edge_index.size(1)

        # Column-wise sum of the node-feature matrix. This equals per-feature
        # atom counts only if x is one-hot/binary -- TODO confirm encoding.
        feature_counts = data.x.sum(dim=0).tolist()

        record = {
            'PDB_ID': getattr(data, 'id', f'Complex_{i}'),
            'Affinity_pKi_pKd': data.y.item(),
            'Num_Atoms': num_nodes,
            'Num_Interactions': num_edges,
            # NaN instead of ZeroDivisionError for a graph with no nodes.
            'Density': num_edges / num_nodes if num_nodes else float('nan'),
            'Count_Feature_10': (feature_counts[10]
                                 if len(feature_counts) > 10 else float('nan')),
        }
        # One column per node feature, sized from this graph's own x.
        record.update({f'Feature_{j}': c for j, c in enumerate(feature_counts)})
        records.append(record)

    if not records:
        # Keep the fixed columns so downstream column lookups still work.
        return pd.DataFrame(columns=base_columns)
    return pd.DataFrame(records)

DataFrame created successfully
  PDB_ID  Affinity_pKi_pKd  Num_Atoms  Num_Interactions   Density  \
0   3f3c          0.376250         32               278  8.687500   
1   1w3l          0.392500         51               431  8.450980   
2   2hb1          0.237500         33               269  8.151515   
3   2v00          0.228750         32               252  7.875000   
4   1os0          0.376875         60               520  8.666667   

   Count_Feature_10  
0               6.0  
1              18.0  
2               5.0  
3              12.0  
4              18.0  


In [None]:
def run_data_analysis(df, dataset_name):
    """Print summary statistics for a metrics DataFrame and plot distributions.

    Args:
        df: DataFrame as produced by ``extract_metrics``. Uses the columns
            'PDB_ID', 'Affinity_pKi_pKd', 'Num_Atoms', 'Num_Interactions',
            'Density' and any 'Feature_<j>' columns.
        dataset_name: Label used in printed headers and plot titles.

    Returns:
        None. Output is printed and plotted. An empty ``df`` is reported and
        skipped.
    """
    print(f"Analyzing {dataset_name} dataset (N={len(df)})...")
    if df.empty:
        print("No records to analyze. Skipping...")
        return

    size_cols = ['Num_Atoms', 'Num_Interactions', 'Density']

    # [1] Affinity distribution
    print("\n[1] Affinity Distribution")
    print(df['Affinity_pKi_pKd'].describe())

    # [2] Graph size metrics
    print("\n[2] Graph Size Metrics")
    print(df[size_cols].describe())

    # One explicit Axes per histogram. DataFrame.hist without ax= opens its
    # own figure, which previously left the second subplot empty.
    fig, axes = plt.subplots(1, 1 + len(size_cols), figsize=(16, 4))
    df['Affinity_pKi_pKd'].hist(bins=50, ax=axes[0])
    axes[0].set(title=f'{dataset_name}: Affinity Distribution',
                xlabel='Binding Affinity (pKi/pKd)', ylabel='Frequency')
    for ax, col in zip(axes[1:], size_cols):
        df[col].hist(bins=50, ax=ax)
        ax.set(title=f'{dataset_name}: {col}', xlabel=col, ylabel='Frequency')
    fig.tight_layout()
    plt.show()
    plt.close(fig)  # avoid accumulating figures when called in a loop

    # [3] Redundancy by PDB ID. Distinct PDB IDs count complexes, not proteins:
    # one protein can appear in many PDB entries (TODO: group by protein).
    unique_complexes = df['PDB_ID'].nunique()
    ratio = len(df) / unique_complexes if unique_complexes else float('nan')
    print(f"\n[3] Protein Diversity: {unique_complexes} unique complexes (by PDB_ID)")
    print(f"Redundancy Ratio (Total N / Unique N): {ratio:.2f}")

    # [4] Top 5 node features by total column sum across all complexes.
    feature_cols = [col for col in df.columns if col.startswith('Feature_')]
    top_5_features = df[feature_cols].sum().nlargest(5).index.tolist()

    print("\n[4] Node Feature Frequency")
    print(f"Top 5 Features: {', '.join(top_5_features)}")
    # TODO: map feature indices to actual atom types


# Backward-compatible alias for the original (misspelled) function name.
run_data_analysi = run_data_analysis


In [None]:
# Main loop: load every configured dataset file, extract per-complex metrics
# and run the analysis on each.
if not os.path.exists(DATA_DIR):
    # Raise instead of sys.exit(1): in Jupyter sys.exit only surfaces a bare
    # SystemExit, while a real exception names the problem and still stops
    # "Run All".
    raise FileNotFoundError(f"FATAL ERROR: Path {DATA_DIR} does not exist.")

for filename in DATASET_FILES:
    data_filepath = os.path.join(DATA_DIR, filename)

    if not os.path.exists(data_filepath):
        print(f"WARNING: File not found: {data_filepath}. Skipping...")
        continue

    print(f"Loading {filename}...")
    try:
        # weights_only=False is needed to unpickle PyG/GEMS objects: since
        # torch 2.6 torch.load defaults to weights_only=True and rejects them.
        # SECURITY: full unpickling can execute arbitrary code -- only load
        # trusted files (assumed: this project's own preprocessed datasets).
        data_list = torch.load(data_filepath, weights_only=False)
        print(f"Loaded {len(data_list)} complexes")
        print(f"Example data object type: {type(data_list[0])}")

        # Run analysis
        df = extract_metrics(data_list)
        run_data_analysis(df, filename)

    except Exception as e:
        # Best-effort: report and move on to the next file. The exception type
        # is included so NameError/KeyError/etc. are distinguishable.
        print(f"\nFAILED PROCESSING {filename}: {type(e).__name__}: {e}")

print("--------------| Data analysis completed! |------------")

Loaded 194 complexes
Example data object type: <class 'torch_geometric.data.data.Data'>
Number of features: 60
