In [None]:
# If the GCS data mount is not populated, run the following steps in a shell on the VM.
# The steps are printed (not executed) so they can be copied verbatim.
GCS_MOUNT_INSTRUCTIONS = """
# Stop/unmount any previous attempt.
# If this fails, it's fine: it just means the directory wasn't mounted.
fusermount -u ~/gcs_pdbbind_mount

# Create the local mount point (-p: no error if ~/gcs_pdbbind_mount already exists).
mkdir -p ~/gcs_pdbbind_mount

# Bucket and directory inside the bucket to mount; replace as needed.
BUCKET_NAME="cs224w-2025-mae-gnn-bucket"
DIRECTORY_PATH="data/GEMS_pytorch_datasets"

gcsfuse --only-dir "$DIRECTORY_PATH" -o allow_other --implicit-dirs "$BUCKET_NAME" ~/gcs_pdbbind_mount

# Verify the mount:
ls -l ~/gcs_pdbbind_mount

# You should see your .pt files listed here, proving the mount worked.
"""
print(GCS_MOUNT_INSTRUCTIONS)

In [None]:
!pip install torch_geometric pandas matplotlib
!pip install "numpy<2"

In [None]:
# ---- Imports ----
# Standard library
import os
import sys

# Third-party (original relative order kept)
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# Apply seaborn's "whitegrid" theme globally to subsequent figures
sns.set_theme(style="whitegrid")

In [None]:
# Report the versions of the core libraries so results can be tied to an environment.
for _lib_label, _lib_module in (("torch", torch), ("numpy", np), ("pandas", pd)):
    print(f"{_lib_label} version:", _lib_module.__version__)

In [None]:
# CONSTANTS

# Dataset.py comes from the GEMS repository, which must be cloned onto the VM first:
#   git clone https://github.com/camlab-ethz/GEMS.git
# Root directory of that clone:
GEMS_REPO_ROOT = os.path.expanduser('~/GEMS')

# Directory holding the pre-processed .pt dataset files (the gcsfuse mount point)
DATA_DIR = os.path.expanduser(os.path.join('~', 'gcs_pdbbind_mount'))

# Split name -> dataset file to analyze
DATASET_FILES_DICT = {
    'Train': '00AEPL_train_cleansplit.pt',
    'Test': '00AEPL_casf2016.pt',
}

# ---- Node feature names ----
# NOTE(review): presumably listed in the column order of the node feature matrix — confirm against GEMS.
ALL_ATOMS = ['B', 'C', 'N', 'O', 'P', 'S', 'Se', 'metal', 'halogen']

ATOM_HYBRIDIZATION_TYPES = [
    "HybridizationType.S",
    "HybridizationType.SP",
    "HybridizationType.SP2",
    "HybridizationType.SP2D",
    "HybridizationType.SP3",
    "HybridizationType.SP3D",
    "HybridizationType.SP3D2",
    "HybridizationType.UNSPECIFIED",
]
TOTAL_NUM_H_S = ["Num_H.0", "Num_H.1", "Num_H.2", "Num_H.3", "Num_H.4"]
DEGREES = [
    'Degree.0', 'Degree.1', 'Degree.2', 'Degree.3', 'Degree.4',
    'Degree.5', 'Degree.6', 'Degree.7', 'Degree.8', 'Degree.OTHER',
]
CHIRALITIES = [
    'Chirality.CHI_UNSPECIFIED',
    'Chirality.CHI_TETRAHEDRAL_CW',
    'Chirality.CHI_TETRAHEDRAL_CCW',
    'Chirality.OTHER',
]

AMINO_ACIDS = [
    "ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
    "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL",
]

LIGAND_FEATURE_MAP = (
    ALL_ATOMS
    + ["IsInRing"]
    + ATOM_HYBRIDIZATION_TYPES
    + ["FORMAL_CHARGE", "IS_AROMATIC", "MASS/100"]
    + TOTAL_NUM_H_S
    + DEGREES
    + CHIRALITIES
)
FULL_FEATURE_MAP = LIGAND_FEATURE_MAP + AMINO_ACIDS

# ---- Edge feature names ----
# Edge-type flags, shared by both edge layouts
_EDGE_TYPE_FEATURES = ["COVALENT_BOND", "SELF_LOOP", "NON-COVALENT_BOND"]
# Distance features: only present in EDGE_FEATURE_MAP
_EDGE_LENGTH_FEATURES = [
    "EDGE_LENGTH_TO_N/10", "EDGE_LENGTH_TO_CA/10", "EDGE_LENGTH_TO_C/10", "EDGE_LENGTH_TO_CB/10",
]
_BOND_TYPE_FEATURES = ["BOND_TYPE_0", "BOND_TYPE_1.0", "BOND_TYPE_1.5", "BOND_TYPE_2.0", "BOND_TYPE_3.0"]
_BOND_PROPERTY_FEATURES = [
    "IS_CONJUGATED", "IS_IN_RING",
    "BOND_STEREO.NONE", "BOND_STEREO.ANY", "BOND_STEREO.E",
    "BOND_STEREO.Z", "BOND_STEREO.CIS", "BOND_STEREO.TRANS",
]

EDGE_FEATURE_MAP = _EDGE_TYPE_FEATURES + _EDGE_LENGTH_FEATURES + _BOND_TYPE_FEATURES + _BOND_PROPERTY_FEATURES
# Same layout as EDGE_FEATURE_MAP with the distance features removed
NEW_EDGE_FEATURE_MAP = _EDGE_TYPE_FEATURES + _BOND_TYPE_FEATURES + _BOND_PROPERTY_FEATURES


In [None]:
# Dynamic import of the custom GEMS Dataset class.
# Dataset.py sits at the root of the cloned GEMS repo, so that directory must be on sys.path.
print("Attempting to load custom Dataset class...")

# Add the GEMS root to sys.path only once, so re-running this cell does not duplicate the entry.
if GEMS_REPO_ROOT not in sys.path:
    sys.path.append(GEMS_REPO_ROOT)
    print(f"Added {GEMS_REPO_ROOT} to sys.path")

try:
    # Kept outside the `if` above so the import also runs when the path was
    # already present (e.g. when the cell is re-executed).
    from Dataset import Dataset as GEMS_Dataset
    print("Successfully imported custom Dataset class")

    # Import the necessary PyG components
    from torch_geometric.data import Data
except ImportError as e:
    raise RuntimeError(f"FATAL ERROR importing GEMS Dataset class: {e}") from e


In [None]:
# Load the pre-processed dataset and the example dataset (without distance edge
# features) so they can be compared complex by complex in the cells below.
if not os.path.exists(DATA_DIR):
    raise RuntimeError(f"FATAL ERROR: Path {DATA_DIR} does not exist.")

# Alternative file used during development: "B6AE0L_train_cleansplit.pt"
filename = DATASET_FILES_DICT['Test']  # "00AEPL_casf2016.pt"
example_filename = "example_dataset_no_dist_edge_feat.pt"
data_filepath = os.path.join(DATA_DIR, filename)
example_filepath = os.path.join(DATA_DIR, example_filename)


def _load_graph_list(path):
    """Load a saved .pt dataset and make sure it is usable.

    Args:
        path: Full path to the .pt file.

    Returns:
        The loaded, non-empty, indexable collection of graph objects.

    Raises:
        FileNotFoundError: if ``path`` does not exist (later cells need both
            datasets, so there is nothing sensible to skip to).
        RuntimeError: if ``torch.load`` fails or the loaded collection is empty.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"FATAL ERROR: File not found: {path}")
    try:
        # weights_only=False: the files hold full Python objects (graphs with
        # .id/.x attributes), not just tensors, and torch>=2.6 defaults to
        # weights_only=True, which refuses them. Only load trusted files this way.
        loaded = torch.load(path, weights_only=False)
    except Exception as e:
        raise RuntimeError(f"FAILED PROCESSING {os.path.basename(path)}: {e}") from e
    if len(loaded) == 0:
        raise RuntimeError(f"FAILED PROCESSING {os.path.basename(path)}: dataset is empty")
    return loaded


data_list = _load_graph_list(data_filepath)
example_data_list = _load_graph_list(example_filepath)
print(f"data object type: {type(data_list[0])}")
print(f"Example data object type: {type(example_data_list[0])}")



In [None]:
# Index each dataset by the `id` attribute of its graphs so the two can be joined on ID.
# (If an ID repeats, the last graph with that ID wins.)
data_dict = {graph.id: graph for graph in data_list}
example_data_dict = {graph.id: graph for graph in example_data_list}

In [None]:
# Only IDs present in both datasets can be compared directly.
old_ids = set(data_dict)
new_ids = set(example_data_dict)
overlapping_ids = old_ids & new_ids

print(f"Total overlapping IDs found: {len(overlapping_ids)}")

In [None]:
# Compare every complex present in both datasets on a fixed set of fields and
# tally how many agree exactly.
match_results = {
    'match': 0,
    'mismatch': 0,
    'mismatch_ids': [],
    'mismatch_details': {},  # data_id -> list of mismatch descriptions for that complex
}

FIELDS_TO_COMPARE = ['x', 'edge_index', 'edge_attr', 'y', 'n_nodes']


def _describe_field_mismatch(field, old_value, new_value):
    """Describe how one field differs between the two versions of a graph.

    Args:
        field: Attribute name; used only in the returned message.
        old_value: Value from the original dataset, or None if absent.
        new_value: Value from the example dataset, or None if absent.

    Returns:
        None when the values agree (including when both are absent), otherwise
        a "<field> (... mismatch ...)" description.
    """
    # hasattr() is not enough to detect absence: PyG `Data` returns None for
    # unset standard attributes such as x / edge_index, so test for None here.
    if old_value is None or new_value is None:
        if old_value is None and new_value is None:
            return None
        return f"{field} (Presence mismatch)"

    if torch.is_tensor(old_value) or torch.is_tensor(new_value):
        # as_tensor also covers a plain number (e.g. an int n_nodes) on one side.
        old_tensor = torch.as_tensor(old_value)
        new_tensor = torch.as_tensor(new_value)
        if old_tensor.shape != new_tensor.shape:
            return f"{field} (Shape mismatch: {old_tensor.shape} vs {new_tensor.shape})"
        # Shapes are equal, so == is purely element-wise (no broadcasting).
        if not bool((old_tensor == new_tensor).all()):
            return f"{field} (Value mismatch)"
        return None

    # Neither side is a tensor: plain equality.
    return None if old_value == new_value else f"{field} (Value mismatch)"


# Sorted so the ID order (and therefore the sample printed below and the first
# mismatched ID used later) is reproducible; set iteration order of string IDs
# depends on per-process hash randomization.
for data_id in sorted(overlapping_ids):
    old_data = data_dict[data_id]
    new_data = example_data_dict[data_id]

    mismatched_fields = []
    for field in FIELDS_TO_COMPARE:
        problem = _describe_field_mismatch(
            field, getattr(old_data, field, None), getattr(new_data, field, None)
        )
        if problem is not None:
            mismatched_fields.append(problem)

    if mismatched_fields:
        match_results['mismatch'] += 1
        match_results['mismatch_ids'].append(data_id)
        match_results['mismatch_details'][data_id] = mismatched_fields
    else:
        match_results['match'] += 1

print("\n--- Custom Comparison Summary ---")
print(f"Total matching objects (on specified fields): {match_results['match']}")
print(f"Total mismatched objects (on specified fields): {match_results['mismatch']}")

if match_results['mismatch'] > 0:
    print("\nSample IDs with Detailed Mismatches:")
    for data_id in match_results['mismatch_ids'][:5]:
        print(f"ID {data_id}: Failed on fields: {', '.join(match_results['mismatch_details'][data_id])}")

In [None]:
# Side-by-side view of one specific complex from both datasets.
key = '1bzc'
spot_check_pair = (data_dict[key], example_data_dict[key])
spot_check_pair

In [None]:
# Drill into the first mismatched complex (if any) and check 'pos' (atom
# coordinates) and 'x' (node features) directly.


def _optional_tensors_differ(old_value, new_value):
    """Return True if two optional tensors differ in presence, shape, or any element.

    Args:
        old_value: Tensor from the original dataset, or None if absent.
        new_value: Tensor from the example dataset, or None if absent.
    """
    if old_value is None or new_value is None:
        # Absent on both sides counts as equal; absent on one side is a difference.
        return (old_value is None) != (new_value is None)
    if old_value.shape != new_value.shape:
        # Checked first so == below never broadcasts or raises on incompatible shapes.
        return True
    return not bool((old_value == new_value).all())


if not match_results['mismatch_ids']:
    print("No mismatched IDs to inspect.")
else:
    mismatched_id = match_results['mismatch_ids'][0]

    old_data = data_dict[mismatched_id]
    new_data = example_data_dict[mismatched_id]

    for tensor_name in ('pos', 'x'):
        if _optional_tensors_differ(getattr(old_data, tensor_name, None), getattr(new_data, tensor_name, None)):
            print(f"ID {mismatched_id}: Mismatch found in the '{tensor_name}' tensor.")