# torch-structure-manipulation

Just a simple notebook to try out functions.


In [1]:
import sys
import pandas as pd
import torch
import time
from pathlib import Path

sys.path.insert(0, str(Path.cwd() / 'src'))

import mmdf
from torch_structure_manipulation import (
    FastCIFBondParser,
    FastAtomEnvironmentMapper,
    StructureTransforms
)

In [2]:
# 1. standard reading with mmdf
cif_file = "tests/4V6X.cif"

print("Reading CIF file...")
start_time = time.time()
mmdf_df = mmdf.read(cif_file)
read_time = time.time() - start_time

print(f"Read {len(mmdf_df)} atoms in {read_time:.3f}s")
print(f"Elements: {sorted(list(mmdf_df['element'].unique()))}")
print(f"Chains: {len(mmdf_df['chain'].unique())} chains")
display(mmdf_df.head()) 


Reading CIF file...
Read 237685 atoms in 2.558s
Elements: ['C', 'N', 'O', 'P', 'S']
Chains: 89 chains


Unnamed: 0,model,chain,residue,residue_id,atom,element,atomic_number,atomic_weight,covalent_radius,van_der_waals_radius,heteroatom_flag,x,y,z,charge,occupancy,b_isotropic
0,1,Az,ASN,3,N,N,7,14.0067,0.71,1.55,A,-12.173,76.285,-54.829,0,1.0,10.0
1,1,Az,ASN,3,CA,C,6,12.0107,0.73,1.7,A,-12.68,74.965,-54.691,0,1.0,10.0
2,1,Az,ASN,3,C,C,6,12.0107,0.73,1.7,A,-13.642,75.293,-53.642,0,1.0,10.0
3,1,Az,ASN,3,O,O,8,15.9994,0.66,1.52,A,-14.277,74.489,-53.007,0,1.0,10.0
4,1,Az,ASN,3,CB,C,6,12.0107,0.73,1.7,A,-11.57,74.095,-54.239,0,1.0,10.0


In [3]:
# 2. Extract ALL covalent bonds from CIF using proper mmCIF tables
print("Extracting covalent bonds from CIF...")
start_time = time.time()

parser = FastCIFBondParser()
bonds_df = parser.extract_bonds_for_mmdf(cif_file, mmdf_df)

bond_time = time.time() - start_time
print(f"Extracted {len(bonds_df)} covalent bonds in {bond_time:.3f}s")

if len(bonds_df) > 0:
    print(f"\nBond types:")
    bond_type_counts = bonds_df.groupby(['atom1_element', 'atom2_element']).size().sort_values(ascending=False)
    display(bond_type_counts.head(10))
    
    print(f"\nSample bonds:")
    display(bonds_df[['atom1_idx', 'atom2_idx', 'bond_order', 'atom1_element', 'atom2_element']].head(10))
else:
    print("No bonds found")


Extracting covalent bonds from CIF...
Mapped 234702 bonds to mmdf indices
Extracted 234702 covalent bonds in 17.175s

Bond types:


atom1_element  atom2_element
C              C                89764
               O                42433
N              C                39554
C              N                30247
P              O                19521
O              C                12388
C              S                  502
S              C                  293
dtype: int64


Sample bonds:


Unnamed: 0,atom1_idx,atom2_idx,bond_order,atom1_element,atom2_element
0,146910,146911,2.0,P,O
1,146910,146912,1.0,P,O
2,146910,146913,1.0,P,O
3,146913,146914,1.0,O,C
4,146914,146915,1.0,C,C
5,146915,146916,1.0,C,O
6,146915,146917,1.0,C,C
7,146916,146921,1.0,O,C
8,146917,146918,1.0,C,O
9,146917,146919,1.0,C,C


In [4]:
# 3. Generate atom environments based on covalent connectivity
print("Mapping atom environments...")
start_time = time.time()

env_mapper = FastAtomEnvironmentMapper()
mmdf_with_envs = env_mapper.map_environments(mmdf_df, bonds_df)

env_time = time.time() - start_time
print(f"Mapped environments for {len(mmdf_with_envs)} atoms in {env_time:.3f}s")

# Show sample environments
print(f"\nSample atom environments:")
display(mmdf_with_envs[['element', 'environment_id', 'coordination_number', 'bonded_elements']].head(10))

# Environment statistics
env_stats = env_mapper.get_environment_statistics(mmdf_with_envs)
print(f"\nTop 15 atom environments:")
display(env_stats.head(15))



Mapping atom environments...
Mapped environments for 237685 atoms in 1.569s

Sample atom environments:


Unnamed: 0,element,environment_id,coordination_number,bonded_elements
0,N,N(C),1,[C]
1,C,C(CCN),3,"[C, C, N]"
2,C,C(CO),2,"[C, O]"
3,O,O(C),1,[C]
4,C,C(CC),2,"[C, C]"
5,C,C(CNO),3,"[C, N, O]"
6,O,O(C),1,[C]
7,N,N(C),1,[C]
8,N,N(C),1,[C]
9,C,C(CCN),3,"[C, C, N]"



Top 15 atom environments:


Unnamed: 0,environment_id,count,percentage
0,O(C),36562,15.38
1,N(C),22573,9.5
2,C(CC),22469,9.45
3,C(CO),21080,8.87
4,C(CCO),18684,7.86
5,C(CCN),16013,6.74
6,N(CC),14800,6.23
7,O(P),13014,5.48
8,C(CNO),9862,4.15
9,C(CN),7468,3.14


In [5]:
# This part is not really useful:
#  I just put the basic functions and how I'd use them, also as sort of sanity check they kind of work but I need to add some test in here.

# 4. Apply structure transformations
print("Applying structure transformations...")

transforms = StructureTransforms()

# Center structure at origin
print("\n1. Centering structure...")
centered_df = transforms.center_structure(mmdf_df)
original_center = mmdf_df[['x', 'y', 'z']].mean()
new_center = centered_df[['x', 'y', 'z']].mean()
print(f"   Original center: ({original_center['x']:.2f}, {original_center['y']:.2f}, {original_center['z']:.2f})")
print(f"   New center: ({new_center['x']:.2f}, {new_center['y']:.2f}, {new_center['z']:.2f})")

# Apply translation
print("\n2. Applying translation...")
translation = (10.0, -5.0, 2.0)
translated_df = transforms.apply_translation(mmdf_df, translation)
print(f"   Translated by {translation}")

# Filter by radius from center
print("\n3. Filtering atoms by radius...")
center_point = (0, 0, 0)
radius = 20.0
filtered_df = transforms.remove_atoms_by_radius(centered_df, center_point, radius=radius, keep_inside=True)
print(f"   Kept atoms within {radius}Å of {center_point}: {len(mmdf_df)} → {len(filtered_df)} atoms")

# Remove sidechains (if protein atoms present)
print("\n4. Structure analysis...")
protein_residues = ['ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLU', 'GLN', 'GLY', 'HIS', 'ILE', 
                   'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL']
rna_residues = ['A', 'U', 'G', 'C', 'DA', 'DT', 'DG', 'DC']

protein_atoms = mmdf_df[mmdf_df['residue'].isin(protein_residues)]
rna_atoms = mmdf_df[mmdf_df['residue'].isin(rna_residues)]

print(f"   Protein atoms: {len(protein_atoms)} ({len(protein_atoms)/len(mmdf_df)*100:.1f}%)")
print(f"   RNA/DNA atoms: {len(rna_atoms)} ({len(rna_atoms)/len(mmdf_df)*100:.1f}%)")
print(f"   Other atoms: {len(mmdf_df) - len(protein_atoms) - len(rna_atoms)}")

total_time = read_time + bond_time + env_time
print(f"\nTotal processing time: {total_time:.3f}s for {len(mmdf_df)} atoms")

# Separate protein/RNA
protein_df, rna_df = transforms.separate_protein_rna(mmdf_df)
print(f"Separated: {len(protein_df)} protein + {len(rna_df)} RNA")



Applying structure transformations...

1. Centering structure...
   Original center: (-0.15, -2.07, -3.80)
   New center: (0.00, 0.00, 0.00)

2. Applying translation...
   Translated by (10.0, -5.0, 2.0)

3. Filtering atoms by radius...
   Kept atoms within 20.0Å of (0, 0, 0): 237685 → 1108 atoms

4. Structure analysis...
   Protein atoms: 106846 (45.0%)
   RNA/DNA atoms: 130839 (55.0%)
   Other atoms: 0

Total processing time: 21.301s for 237685 atoms
Separated: 106846 protein + 130839 RNA
