# torch-structure-manipulation

Just a simple notebook to try out functions.


In [3]:
import sys
import time
from pathlib import Path

sys.path.insert(0, str(Path.cwd() / 'src'))

from torch_structure_manipulation.structure_loader import load_structure, StructureLoadOptions
from torch_structure_manipulation.structure_transforms import (
    apply_translation,
    return_atoms_by_radius,
    separate_protein_rna,
)

In [4]:
# 1. Load structure with bonding information
cif_file = "tests/4V6X.cif"

print("Loading structure with bonding information...")
start_time = time.time()

options = StructureLoadOptions(
    center_atoms=True,
    center_atoms_by_mass=False,
    center_point=(0.0, 0.0, 0.0),
    include_hydrogens=True,
    load_bonded_environment=True,
)
df = load_structure(cif_file, options=options)

load_time = time.time() - start_time

print(f"Loaded {len(df)} atoms in {load_time:.3f}s")
print(f"Elements: {sorted(df['element'].unique())}")
print(f"Chains: {len(df['chain'].unique())} chains")
print("\nBonded environments sample:")
print(df[['element', 'bonded_environment', 'molecule_type']].head(10))
print(f"\nMolecule types: {df['molecule_type'].unique()}")
display(df.head())


Loading structure with bonding information...
Loaded 237685 atoms in 7.278s
Elements: ['C', 'N', 'O', 'P', 'S']
Chains: 89 chains

Bonded environments sample:
  element bonded_environment molecule_type
0       N              N(CH)       protein
1       C            C(CCHN)       protein
2       C             C(CNO)       protein
3       O               O(C)       protein
4       C              C(CC)       protein
5       C             C(CNO)       protein
6       O               O(C)       protein
7       N               N(C)       protein
8       N             N(CCH)       protein
9       C            C(CCHN)       protein

Molecule types: ['protein' 'rna']


Unnamed: 0,model,chain,residue,residue_id,atom,element,atomic_number,atomic_weight,covalent_radius,van_der_waals_radius,heteroatom_flag,x,y,z,charge,occupancy,b_isotropic,bonded_environment,molecule_type
0,1,Az,ASN,3,N,N,7,14.0067,0.71,1.55,A,-12.022888,78.351494,-51.032963,0,1.0,10.0,N(CH),protein
1,1,Az,ASN,3,CA,C,6,12.0107,0.73,1.7,A,-12.529888,77.031487,-50.894966,0,1.0,10.0,C(CCHN),protein
2,1,Az,ASN,3,C,C,6,12.0107,0.73,1.7,A,-13.491888,77.359489,-49.845963,0,1.0,10.0,C(CNO),protein
3,1,Az,ASN,3,O,O,8,15.9994,0.66,1.52,A,-14.126888,76.555489,-49.210964,0,1.0,10.0,O(C),protein
4,1,Az,ASN,3,CB,C,6,12.0107,0.73,1.7,A,-11.419888,76.161491,-50.442963,0,1.0,10.0,C(CC),protein


In [5]:
# 2. Analyze bonded environments
print("Analyzing bonded environments...")

# Count unique bonded environments
bonded_env_counts = df['bonded_environment'].value_counts()
print(f"Total unique bonded environments: {len(bonded_env_counts)}")
print("\nTop 15 most common bonded environments:")
display(bonded_env_counts.head(15))

# Show examples by element
print("\nSample environments by element:")
for element in ['C', 'N', 'O', 'P']:
    element_envs = df[df['element'] == element]['bonded_environment'].value_counts().head(5)
    if len(element_envs) > 0:
        print(f"\n{element}:")
        display(element_envs)


Analyzing bonded environments...
Total unique bonded environments: 42

Top 15 most common bonded environments:


bonded_environment
O(C)       20701
C(CC)      18697
C(CCHO)    18283
C(CNO)     17239
N(CCH)     16964
O(P)       13014
C(CCHN)    12400
N(CC)       9969
C(C)        7314
O(CH)       6924
C(CHHO)     6512
O(CP)       6507
P(OOOO)     6498
O(CHP)      6498
C(CNN)      6016
Name: count, dtype: int64


Sample environments by element:

C:


bonded_environment
C(CC)      18697
C(CCHO)    18283
C(CNO)     17239
C(CCHN)    12400
C(C)        7314
Name: count, dtype: int64


N:


bonded_environment
N(CCH)     16964
N(CC)       9969
N(C)        5013
N(CHH)      4785
N(CCCH)     4250
Name: count, dtype: int64


O:


bonded_environment
O(C)      20701
O(P)      13014
O(CH)      6924
O(CP)      6507
O(CHP)     6498
Name: count, dtype: int64


P:


bonded_environment
P(OOOO)    6498
P(OOO)        9
Name: count, dtype: int64

In [6]:
# 3. Analyze molecule types and bonding patterns
print("Analyzing molecule types...")

molecule_type_counts = df['molecule_type'].value_counts()
print("Molecule type distribution:")
display(molecule_type_counts)

# Show bonding patterns by molecule type
print("\nTop bonded environments by molecule type:")
for mol_type in df['molecule_type'].unique():
    mol_df = df[df['molecule_type'] == mol_type]
    top_envs = mol_df['bonded_environment'].value_counts().head(10)
    print(f"\n{mol_type.upper()}:")
    display(top_envs)



Analyzing molecule types...
Molecule type distribution:


molecule_type
rna        130839
protein    106846
Name: count, dtype: int64


Top bonded environments by molecule type:

PROTEIN:


bonded_environment
C(CC)             18697
O(C)              14960
C(CNO)            14220
N(CCH)            13945
C(CCHN)           12400
C(C)               7314
N(C)               5013
C(CN)              4751
C(CCCH)            2849
O(C, carboxyl)     2439
Name: count, dtype: int64


RNA:


bonded_environment
C(CCHO)    17628
O(P)       13014
N(CC)       9175
C(CHHO)     6512
O(CP)       6507
O(CHP)      6498
P(OOOO)     6498
C(CNN)      6016
O(CH)       5890
O(CC)       5876
Name: count, dtype: int64

In [7]:
# 4. Apply structure transformations
print("Applying structure transformations...")

# Note: Structure is already centered at origin from loading
print("\n1. Checking structure center...")
current_center = df[['x', 'y', 'z']].mean()
print(f"   Current center: ({current_center['x']:.2f}, {current_center['y']:.2f}, {current_center['z']:.2f})")

# Apply translation
print("\n2. Applying translation...")
translation = (2.0, -5.0, 10.0)  # (dz, dy, dx) order for structure_transforms
translated_df = apply_translation(df, translation)
print(f"   Translated by {translation} (dz, dy, dx)")

# Filter by radius from center
print("\n3. Filtering atoms by radius...")
center_point = (0.0, 0.0, 0.0)  # (z, y, x) order for structure_transforms
radius = 20.0
atoms_inside, atoms_outside = return_atoms_by_radius(df, center_point, radius=radius)
print(f"   Atoms within {radius}Å of {center_point}: {len(atoms_inside)}")
print(f"   Atoms outside {radius}Å: {len(atoms_outside)}")
print(f"   Total: {len(df)} → {len(atoms_inside)} inside")

# Structure analysis
print("\n4. Structure analysis...")
molecule_type_counts = df['molecule_type'].value_counts()
print("Molecule type distribution:")
for mol_type, count in molecule_type_counts.items():
    print(f"     {mol_type}: {count} atoms ({count/len(df)*100:.1f}%)")

# Separate protein/RNA
print("\n5. Separating protein and RNA...")
protein_df, rna_df = separate_protein_rna(df)
print(f"   Separated: {len(protein_df)} protein + {len(rna_df)} RNA atoms")

print(f"\nTotal processing time: {load_time:.3f}s for {len(df)} atoms")



Applying structure transformations...

1. Checking structure center...
   Current center: (0.00, 0.00, 0.00)

2. Applying translation...
   Translated by (2.0, -5.0, 10.0) (dz, dy, dx)

3. Filtering atoms by radius...
   Atoms within 20.0Å of (0.0, 0.0, 0.0): 1108
   Atoms outside 20.0Å: 236577
   Total: 237685 → 1108 inside

4. Structure analysis...
Molecule type distribution:
     rna: 130839 atoms (55.0%)
     protein: 106846 atoms (45.0%)

5. Separating protein and RNA...
   Separated: 106846 protein + 130839 RNA atoms

Total processing time: 7.278s for 237685 atoms


In [8]:
# 5. Additional examples: Working with coordinates
print("Working with coordinates and tensors...")

from torch_structure_manipulation.structure_loader import get_zyx_coords, df_params_to_tensors

# Extract coordinates as tensor
coords = get_zyx_coords(df)
print(f"Coordinates tensor shape: {coords.shape}")
print(f"Coordinate range: z=[{coords[:, 0].min():.2f}, {coords[:, 0].max():.2f}], "
      f"y=[{coords[:, 1].min():.2f}, {coords[:, 1].max():.2f}], "
      f"x=[{coords[:, 2].min():.2f}, {coords[:, 2].max():.2f}]")

# Extract all parameters as tensors/lists
atom_zyx, atom_id, atom_b_factor, atom_bonded_id, molecule_type = df_params_to_tensors(df)
print("\nExtracted parameters:")
print(f"  Coordinates: {atom_zyx.shape}")
print(f"  Atom IDs: {len(atom_id)} elements")
print(f"  B factors: {atom_b_factor.shape}")
print(f"  Bonded IDs: {len(atom_bonded_id)}")
print(f"  Molecule types: {len(molecule_type)}")
print(f"\nSample atom IDs: {atom_id[:10]}")
print(f"Sample bonded IDs: {atom_bonded_id[:10]}")


Working with coordinates and tensors...
Coordinates tensor shape: torch.Size([237685, 3])
Coordinate range: z=[-177.81, 140.94], y=[-137.96, 186.70], x=[-157.38, 128.06]

Extracted parameters:
  Coordinates: torch.Size([237685, 3])
  Atom IDs: 237685 elements
  B factors: torch.Size([237685])
  Bonded IDs: 237685
  Molecule types: 237685

Sample atom IDs: ['N', 'C', 'C', 'O', 'C', 'C', 'O', 'N', 'N', 'C']
Sample bonded IDs: ['N(CH)', 'C(CCHN)', 'C(CNO)', 'O(C)', 'C(CC)', 'C(CNO)', 'O(C)', 'N(C)', 'N(CCH)', 'C(CCHN)']
