In [1]:
from pyrosetta import *
from rosetta import *
init( '-in:file:silent_struct_type binary -mute all' )

┌──────────────────────────────────────────────────────────────────────────────┐
│                                 PyRosetta-4                                  │
│              Created in JHU by Sergey Lyskov and PyRosetta Team              │
│              (C) Copyright Rosetta Commons Member Institutions               │
│                                                                              │
│ NOTE: USE OF PyRosetta FOR COMMERCIAL PURPOSES REQUIRE PURCHASE OF A LICENSE │
│         See LICENSE.PyRosetta.md or email license@uw.edu for details         │
└──────────────────────────────────────────────────────────────────────────────┘
PyRosetta-4 2024 [Rosetta PyRosetta4.conda.linux.cxx11thread.serialization.CentOS.python38.Release 2024.19+release.a34b73c40fe9c61558d566d6a63f803cfb15a4fc 2024-05-02T16:22:03] retrieved from: http://www.pyrosetta.org


  from rosetta import *


In [6]:
tag = "/home/sirius/PhD/scripts/oligomer_1120/TMP_012.pdb"
pose = pose_from_pdb(tag)
chain_splits = pose.split_by_chain()

In [16]:
import sys
import numpy as np
from typing import Tuple
import collections
atom_types = [
    'N', 'CA', 'C', 'CB', 'O', 'CG', 'CG1', 'CG2', 'OG', 'OG1', 'SG', 'CD',
    'CD1', 'CD2', 'ND1', 'ND2', 'OD1', 'OD2', 'SD', 'CE', 'CE1', 'CE2', 'CE3',
    'NE', 'NE1', 'NE2', 'OE1', 'OE2', 'CH2', 'NH1', 'NH2', 'OH', 'CZ', 'CZ2',
    'CZ3', 'NZ', 'OXT'
]
atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
atom_type_num = len(atom_types)  # := 37.
def af2_get_atom_positions(pose, tmp_fn) -> Tuple[np.ndarray, np.ndarray]:
    '''
    Given a pose, return the AF2 atom positions array and atom mask array for the protein.
    '''

    # write pose to pdb file
    pose.dump_pdb(tmp_fn)

    with open(tmp_fn, 'r') as pdb_file:
        lines = pdb_file.readlines()

    # Delete the temporary file
    os.remove(tmp_fn)

    # indices of residues observed in the structure
    idx_s = [int(l[22:26]) for l in lines if l[:4]=="ATOM" and l[12:16].strip()=="CA"]
    num_res = len(idx_s)

    all_positions = np.zeros([num_res, atom_type_num, 3])
    all_positions_mask = np.zeros([num_res, atom_type_num],
                                dtype=np.int64)

    residues = collections.defaultdict(list)
    # 4 BB + up to 10 SC atoms
    xyz = np.full((len(idx_s), 14, 3), np.nan, dtype=np.float32)
    for l in lines:
        if l[:4] != "ATOM":
            continue
        resNo, atom, aa = int(l[22:26]), l[12:16], l[17:20]

        residues[ resNo ].append( ( atom.strip(), aa, [float(l[30:38]), float(l[38:46]), float(l[46:54])] ) )

    for resNo in residues:

        pos = np.zeros([atom_type_num, 3], dtype=np.float32)
        mask = np.zeros([atom_type_num], dtype=np.float32)

        for atom in residues[ resNo ]:
            atom_name = atom[0]
            x, y, z = atom[2]
            if atom_name in atom_order.keys():
                pos[atom_order[atom_name]] = [x, y, z]
                mask[atom_order[atom_name]] = 1.0
            elif atom_name.upper() == 'SE' and res.get_resname() == 'MSE':
                # Put the coordinates of the selenium atom in the sulphur column.
                pos[atom_order['SD']] = [x, y, z]
                mask[atom_order['SD']] = 1.0

        idx = idx_s.index(resNo) # This is the order they show up in the pdb
        all_positions[idx] = pos
        all_positions_mask[idx] = mask

    return all_positions, all_positions_mask

In [17]:
all_positions, all_positions_mask = af2_get_atom_positions(pose, "./test.pdb")

In [21]:
list_all_atom_positions = np.split(all_positions, all_positions.shape[0])

In [32]:
seq = pose.sequence()
binderlen = 100
residue_mask = [int(i) > binderlen for i in range(len(seq))]

In [42]:
residue_mask[100]

False

In [35]:
np.sum(residue_mask)

691

In [36]:
len(seq)

792

In [None]:
def parse_initial_guess(all_atom_positions) -> jnp.ndarray:
    '''
    Given a numpy array of all atom positions, return a jax array of the initial guess
    '''

    list_all_atom_positions = np.split(all_atom_positions, all_atom_positions.shape[0])

    templates_all_atom_positions = []

    # Initially fill with zeros
    for _ in list_all_atom_positions:
        templates_all_atom_positions.append(jnp.zeros((residue_constants.atom_type_num, 3)))

    for idx in range(len(list_all_atom_positions)):
        templates_all_atom_positions[idx] = list_all_atom_positions[idx][0] 

    return jnp.array(templates_all_atom_positions)