In [None]:
# If GCS data mount is not populated perform following steps in vm:
"""
# Stop/Unmount any previous attempts
# Try to unmount the directory in case a previous attempt left it mounted
# If this fails, it's fine, it just means it wasn't mounted.
fusermount -u ~/gcs_pdbbind_mount

# If /gcs_pdbbind_mount not visible in file system create dir for local mount point:
mkdir ~/gcs_pdbbind_mount

# Replace BUCKET_NAME and DIRECTORY_PATH
BUCKET_NAME="cs224w-2025-mae-gnn-bucket"
DIRECTORY_PATH="data/GEMS_pytorch_datasets"

gcsfuse --only-dir "$DIRECTORY_PATH" -o allow_other --implicit-dirs "$BUCKET_NAME" ~/gcs_pdbbind_mount

#Verify the Mount:
ls -l ~/gcs_pdbbind_mount

# You should see your .pt files listed here, proving the mount worked.
"""

In [1]:
!pip install torch_geometric pandas matplotlib
!pip install "numpy<2"



In [2]:
# Imports
import sys
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# Set style for better visualization
sns.set_theme(style="whitegrid")

In [3]:
print("torch version:", torch.__version__)
print("numpy version:", np.__version__)
print("pandas version:", pd.__version__)

torch version: 2.7.1+cu118
numpy version: 1.26.4
pandas version: 2.3.3


In [4]:
# CONSTANTS

# Need to download Dataset.py from GEMS directory locally onto vm first
# git clone http://github.com/camlab-ethz/GEMS.git
# Define root path of cloned GEMS repo
GEMS_REPO_ROOT = os.path.expanduser('~/GEMS')

# Define the directory path where pre-processed .pt dataset files are located
DATA_DIR = os.path.join(os.path.expanduser('~'), 'gcs_pdbbind_mount')

ALL_ATOMS = ['B', 'C', 'N', 'O', 'P', 'S', 'Se', 'metal', 'halogen']

ATOM_HYBRIDIZATION_TYPES = ["HybridizationType.S", "HybridizationType.SP", "HybridizationType.SP2", "HybridizationType.SP2D", "HybridizationType.SP3", "HybridizationType.SP3D", "HybridizationType.SP3D2", "HybridizationType.UNSPECIFIED"]
TOTAL_NUM_H_S = ["Num_H.0", "Num_H.1", "Num_H.2", "Num_H.3", "Num_H.4"]
DEGREES = ['Degree.0', 'Degree.1', 'Degree.2', 'Degree.3', 'Degree.4', 'Degree.5', 'Degree.6', 'Degree.7', 'Degree.8', 'Degree.OTHER']
CHIRALITIES = ['Chirality.CHI_UNSPECIFIED', 'Chirality.CHI_TETRAHEDRAL_CW', 'Chirality.CHI_TETRAHEDRAL_CCW', 'Chirality.OTHER']

AMINO_ACIDS = ["ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE", "LEU",
            "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL"]

LIGAND_FEATURE_MAP = ALL_ATOMS + ["IsInRing"] + ATOM_HYBRIDIZATION_TYPES + \
                    ["FORMAL_CHARGE", "IS_AROMATIC", "MASS/100", ] + TOTAL_NUM_H_S + \
                    DEGREES + CHIRALITIES
FULL_FEATURE_MAP = LIGAND_FEATURE_MAP + AMINO_ACIDS

EDGE_FEATURE_MAP = ["COVALENT_BOND", "SELF_LOOP", "NON-COVALENT_BOND", "EDGE_LENGTH_TO_N/10", "EDGE_LENGTH_TO_CA/10", "EDGE_LENGTH_TO_C/10", "EDGE_LENGTH_TO_CB/10" ] + \
                   ["BOND_TYPE_0", "BOND_TYPE_1.0", "BOND_TYPE_1.5", "BOND_TYPE_2.0", "BOND_TYPE_3.0"] + \
                   ["IS_CONJUGATED", "IS_IN_RING", "BOND_STEREO.NONE", "BOND_STEREO.ANY", "BOND_STEREO.E", "BOND_STEREO.Z", "BOND_STEREO.CIS", "BOND_STEREO.TRANS"]

NEW_EDGE_FEATURE_MAP = ["COVALENT_BOND", "SELF_LOOP", "NON-COVALENT_BOND", ] + \
                   ["BOND_TYPE_0", "BOND_TYPE_1.0", "BOND_TYPE_1.5", "BOND_TYPE_2.0", "BOND_TYPE_3.0"] + \
                   ["IS_CONJUGATED", "IS_IN_RING", "BOND_STEREO.NONE", "BOND_STEREO.ANY", "BOND_STEREO.E", "BOND_STEREO.Z", "BOND_STEREO.CIS", "BOND_STEREO.TRANS"]


In [5]:
# Dynamic Import of Custom Datasets Class
# Add the directory to Python's search path.
print(f"Attempting to load custom Dataset class...")

try:
  # Add GEMS root to sys.path
  if GEMS_REPO_ROOT not in sys.path:
    sys.path.append(GEMS_REPO_ROOT)
    print(f"Added {GEMS_REPO_ROOT} to sys.path")

    # Import custom Dataset class from the cloned repo
    from Dataset import Dataset as GEMS_Dataset
    print("Successfully imported custom Dataset class")

    # Import the necessarty PyG components
    from torch_geometric.data import Data #, Dataset

except ImportError as e:
    raise RuntimeError(f"FATAL ERROR importing GEMS Dataset class: {e}")


Attempting to load custom Dataset class...
Added /home/jupyter/GEMS to sys.path
Successfully imported custom Dataset class


In [54]:
#filename = "B6AE0L_train_cleansplit.pt"
#filename = "00AEPL_casf2016.pt"
filename = "00AEPL_train_cleansplit.pt"

example_filename = "example_dataset_no_dist_edge_feat.pt"
#example_filename = "example_dataset.pt"

In [21]:
if not os.path.exists(DATA_DIR):
    raise RuntimeError(f"FATAL ERROR: Path {DATA_DIR} does not exist.")

data_filepath = os.path.join(DATA_DIR, filename)
if not os.path.exists(data_filepath):
    print(f"WARNING: File not found: {data_filepath}. Skipping...")
    
try:
    # Load data_list
    data_list = torch.load(data_filepath, weights_only=False)
    example_data_list = torch.load(os.path.join(DATA_DIR, example_filename), weights_only=False)
    print(f"data object type: {type(data_list[0])}")
    print(f"Example data object type: {type(example_data_list[0])}")
except Exception as e:
        print(f"\nFAILED PROCESSING {filename}: {e}")



data object type: <class 'torch_geometric.data.data.Data'>
Example data object type: <class 'torch_geometric.data.data.Data'>


In [25]:
# create dictionary out of each dataset keyed by id:
data_dict = dict([(d.id, d) for d in data_list])
example_data_dict = dict([(d.id, d) for d in example_data_list])

In [26]:
old_ids = set(data_dict.keys())
new_ids = set(example_data_dict.keys())

# Find the intersection (the overlapping IDs)
overlapping_ids = old_ids.intersection(new_ids)

print(f"Total overlapping IDs found: {len(overlapping_ids)}")

Total overlapping IDs found: 77


In [27]:
match_results = {
    'match': 0,
    'mismatch': 0,
    'mismatch_ids': [],
    'mismatch_details': {} # Store details about *which* field mismatched
}

FIELDS_TO_COMPARE = ['x', 'edge_index', 'edge_attr', 'y', 'n_nodes']

for data_id in overlapping_ids:
    old_data = data_dict[data_id]
    new_data = example_data_dict[data_id]
    
    mismatch_found = False
    mismatched_fields = []

    # Iterate through the specific fields you defined
    for field in FIELDS_TO_COMPARE:
        # 1. Check if the field exists in both objects
        if hasattr(old_data, field) and hasattr(new_data, field):
            old_tensor = getattr(old_data, field)
            new_tensor = getattr(new_data, field)
            
            # 2. Check for shape consistency
            if old_tensor.shape != new_tensor.shape:
                mismatch_found = True
                mismatched_fields.append(f"{field} (Shape mismatch: {old_tensor.shape} vs {new_tensor.shape})")
                
            # 3. Check for element-wise value equality
            # .all() ensures all elements in the comparison tensor are True
            elif not (old_tensor == new_tensor).all():
                mismatch_found = True
                mismatched_fields.append(f"{field} (Value mismatch)")
                
        # Handle cases where a required field is missing in one or both objects
        elif field in old_data.keys() or field in new_data.keys():
             mismatch_found = True
             mismatched_fields.append(f"{field} (Presence mismatch)")
             
    # Tally results based on the loop outcome
    if mismatch_found:
        match_results['mismatch'] += 1
        match_results['mismatch_ids'].append(data_id)
        match_results['mismatch_details'][data_id] = mismatched_fields
    else:
        match_results['match'] += 1

print("\n--- Custom Comparison Summary ---")
print(f"Total matching objects (on specified fields): {match_results['match']}")
print(f"Total mismatched objects (on specified fields): {match_results['mismatch']}")

if match_results['mismatch'] > 0:
    print(f"\nSample IDs with Detailed Mismatches: {len(match_results['mismatch_ids'])}/{len(overlapping_ids)}")
    for data_id in match_results['mismatch_ids'][:5]:
        print(f"ID {data_id}: Failed on fields: {', '.join(match_results['mismatch_details'][data_id])}")


--- Custom Comparison Summary ---
Total matching objects (on specified fields): 0
Total mismatched objects (on specified fields): 77

Sample IDs with Detailed Mismatches: 77/77
ID 1bp0: Failed on fields: x (Value mismatch), edge_index (Shape mismatch: torch.Size([2, 313]) vs torch.Size([2, 309])), edge_attr (Shape mismatch: torch.Size([313, 20]) vs torch.Size([309, 16]))
ID 1a4w: Failed on fields: x (Value mismatch), edge_index (Shape mismatch: torch.Size([2, 538]) vs torch.Size([2, 530])), edge_attr (Shape mismatch: torch.Size([538, 20]) vs torch.Size([530, 16]))
ID 1atl: Failed on fields: x (Value mismatch), edge_index (Shape mismatch: torch.Size([2, 375]) vs torch.Size([2, 389])), edge_attr (Shape mismatch: torch.Size([375, 20]) vs torch.Size([389, 16]))
ID 1bnq: Failed on fields: x (Shape mismatch: torch.Size([45, 60]) vs torch.Size([46, 60])), edge_index (Shape mismatch: torch.Size([2, 333]) vs torch.Size([2, 356])), edge_attr (Shape mismatch: torch.Size([333, 20]) vs torch.Size(

In [30]:
key = '1bp0'
(data_dict[key], example_data_dict[key])

(Data(x=[37, 60], edge_index=[2, 313], edge_attr=[313, 20], y=0.3375000059604645, n_nodes=[3], lig_emb=[1, 384], id='1bp0'),
 Data(x=[37, 60], edge_index=[2, 309], edge_attr=[309, 16], y=0.3375000059604645, pos=[37, 3], n_nodes=[3], id='1bp0'))

In [None]:
mismatched_id = match_results['mismatch_ids'][0]

old_data = data_dict[mismatched_id]
new_data = example_data_dict[mismatched_id]

# Example check for the 'pos' tensor (atom coordinates)
if not (old_data.pos == new_data.pos).all():
    print(f"ID {mismatched_id}: Mismatch found in the 'pos' tensor.")
    
# Example check for the 'x' tensor (node features)
if not (old_data.x == new_data.x).all():
    print(f"ID {mismatched_id}: Mismatch found in the 'x' tensor.")

In [52]:
def find_mismatch(a, b, col_keys=None):
    # 1. Perform element-wise comparison
    comparison_mask = (a == b)

    # 2. Invert the mask to highlight Mismatches (where the value is False)
    mismatch_mask = ~comparison_mask

    # Using the mismatch_mask from the previous example:
    # mismatch_mask is torch.tensor([[False, True], [False, False]])

    # Use torch.nonzero() to get the row/column indices where the value is True (mismatch)
    mismatch_indices = torch.nonzero(mismatch_mask)

    # Iterate through each mismatch index (e.g., [0, 1])
    for index in mismatch_indices:
        # Use tuple indexing to access the specific element
        index_tuple = tuple(index)
        old_value = a[index_tuple]
        new_value = b[index_tuple]

        # Format the index for printing (e.g., [0, 1])
        index_str = str(tuple(index.tolist()))
        
        if col_keys:
            index_str = f"({str(index_tuple[0].item())}, {str(index_tuple[1].item())}({col_keys[index_tuple[1]]}))"

        print(f"Index {index_str}: Old Value = {old_value.item():.4f}, New Value = {new_value.item():.4f}")
        # Output: Index (0, 1): Old Value = 2.0000, New Value = 5.0000

In [33]:
key = '1bp0'
find_mismatch(data_dict[key].x, example_data_dict[key].x)

Index (0, 2): Old Value = 0.0000, New Value = 1.0000
Index (0, 4): Old Value = 1.0000, New Value = 0.0000
Index (0, 9): Old Value = 0.0000, New Value = 1.0000
Index (0, 12): Old Value = 0.0000, New Value = 1.0000
Index (0, 14): Old Value = 1.0000, New Value = 0.0000
Index (0, 19): Old Value = 0.0000, New Value = 1.0000
Index (0, 20): Old Value = 0.3097, New Value = 0.1401
Index (0, 29): Old Value = 0.0000, New Value = 1.0000
Index (0, 30): Old Value = 1.0000, New Value = 0.0000
Index (1, 1): Old Value = 0.0000, New Value = 1.0000
Index (1, 3): Old Value = 1.0000, New Value = 0.0000
Index (1, 9): Old Value = 0.0000, New Value = 1.0000
Index (1, 19): Old Value = 0.0000, New Value = 1.0000
Index (1, 20): Old Value = 0.1600, New Value = 0.1201
Index (1, 27): Old Value = 1.0000, New Value = 0.0000
Index (1, 29): Old Value = 0.0000, New Value = 1.0000
Index (2, 2): Old Value = 0.0000, New Value = 1.0000
Index (2, 3): Old Value = 1.0000, New Value = 0.0000
Index (2, 9): Old Value = 0.0000, Ne

In [46]:
import torch

# Assume tensor_a and tensor_b are two PyTorch tensors (e.g., node features 'x' or 'pos')

def are_tensors_row_permutation_equal(tensor_a, tensor_b, atol=1e-5, verbose=False, col_keys=None):
    """
    Checks if two tensors are equal up to row permutation using lexicographical sort.
    """
    if verbose:
        print(f"size a: {tensor_a.shape}, size b: {tensor_b.shape}")
        
    if tensor_a.shape != tensor_b.shape:
        # Tensors must have the same shape to be permutations of each other
        return False

    # 1. Sort the rows of the first tensor
    # torch.sort can only sort by one column. We use torch.lexsort for stable sorting, 
    
    # Convert to NumPy for a reliable, multi-column stable sort
    a_np = tensor_a.cpu().numpy()
    b_np = tensor_b.cpu().numpy()
    
    # Find the indices that would sort the rows lexicographically
    # numpy.lexsort sorts based on the sequence of keys (columns) provided.
    # We sort by the last column first, then the second-to-last, etc.
    sort_keys = [a_np[:, i] for i in range(a_np.shape[1] - 1, -1, -1)]
    sorted_a_indices = np.lexsort(sort_keys)
    
    sort_keys = [b_np[:, i] for i in range(b_np.shape[1] - 1, -1, -1)]
    sorted_b_indices = np.lexsort(sort_keys)
    
    # Apply the sorting indices to the original tensors
    sorted_a = torch.from_numpy(a_np[sorted_a_indices]).to(tensor_a.device)
    sorted_b = torch.from_numpy(b_np[sorted_b_indices]).to(tensor_b.device)
    
    # 2. Compare the sorted tensors with tolerance
    # Use torch.allclose for floating-point comparison
    allclose = torch.allclose(sorted_a, sorted_b, atol=atol)
    if not allclose:
        find_mismatch(sorted_a, sorted_b, col_keys=col_keys)

In [53]:
key = '1bp0'
old_data_tensor = data_dict[key].x
new_data_tensor = example_data_dict[key].x
col_keys = FULL_FEATURE_MAP
if are_tensors_row_permutation_equal(old_data_tensor, new_data_tensor, verbose=True, col_keys=col_keys):
     print("The positional coordinates match up to row ordering.")
else: 
    print("not permutation equal")

size a: torch.Size([37, 60]), size b: torch.Size([37, 60])
Index (17, 14(HybridizationType.SP3)): Old Value = 1.0000, New Value = 0.0000
Index (17, 15(HybridizationType.SP3D)): Old Value = 0.0000, New Value = 1.0000
Index (17, 21(Num_H.0)): Old Value = 1.0000, New Value = 0.0000
Index (17, 22(Num_H.1)): Old Value = 0.0000, New Value = 1.0000
Index (18, 18(FORMAL_CHARGE)): Old Value = -1.0000, New Value = 0.0000
Index (18, 21(Num_H.0)): Old Value = 1.0000, New Value = 0.0000
Index (18, 22(Num_H.1)): Old Value = 0.0000, New Value = 1.0000
Index (19, 18(FORMAL_CHARGE)): Old Value = -1.0000, New Value = 0.0000
Index (19, 21(Num_H.0)): Old Value = 1.0000, New Value = 0.0000
Index (19, 22(Num_H.1)): Old Value = 0.0000, New Value = 1.0000
Index (21, 21(Num_H.0)): Old Value = 1.0000, New Value = 0.0000
Index (21, 22(Num_H.1)): Old Value = 0.0000, New Value = 1.0000
Index (21, 27(Degree.1)): Old Value = 0.0000, New Value = 1.0000
Index (21, 28(Degree.2)): Old Value = 1.0000, New Value = 0.0000
