## 2024  09 18 framediff learning

### 1. PDB processing + AlphaFold mmcif object

In [1]:
import Bio
from Bio.PDB import PDBIO, MMCIFParser, PDBParser

import sys
sys.path.append("/home/sirius/PhD/software/se3_diffusion")
from data import errors, mmcif_parsing, parsers
from data import utils as du
import os
import dataclasses
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def process_mmcif_mimic_framediff(mmcif_path,max_resolution,max_len):
    metadata = {}
    mmcif_name = os.path.basename(mmcif_path).replace('.cif', '')
    metadata['pdb_name'] = mmcif_name
    with open(mmcif_path, 'r') as f:
        parsed_mmcif = mmcif_parsing.parse(
            file_id=mmcif_name, mmcif_string=f.read())
    metadata['raw_path'] = mmcif_path
    parsed_mmcif = parsed_mmcif.mmcif_object
    raw_mmcif = parsed_mmcif.raw_string
    raw_olig_count = raw_mmcif['_pdbx_struct_assembly.oligomeric_count']
    oligomeric_count = ','.join(raw_olig_count).lower()
    raw_olig_detail = raw_mmcif['_pdbx_struct_assembly.oligomeric_details']
    oligomeric_detail = ','.join(raw_olig_detail).lower()
    metadata['oligomeric_count'] = oligomeric_count
    metadata['oligomeric_detail'] = oligomeric_detail

    mmcif_header = parsed_mmcif.header
    mmcif_resolution = mmcif_header['resolution']
    metadata['resolution'] = mmcif_resolution
    metadata['structure_method'] = mmcif_header['structure_method']
    if mmcif_resolution >= max_resolution:
        raise errors.ResolutionError(
            f'Too high resolution {mmcif_resolution}')
    if mmcif_resolution == 0.0:
        raise errors.ResolutionError(
            f'Invalid resolution {mmcif_resolution}')

        # Extract all chains
    struct_chains = {
        chain.id.upper(): chain
        for chain in parsed_mmcif.structure.get_chains()}
    metadata['num_chains'] = len(struct_chains)

    # Extract features
    struct_feats = []
    all_seqs = set()
    for chain_id, chain in struct_chains.items():
        # Convert chain id into int
        chain_id = du.chain_str_to_int(chain_id)
        chain_prot = parsers.process_chain(chain, chain_id)
        chain_dict = dataclasses.asdict(chain_prot)
        chain_dict = du.parse_chain_feats(chain_dict)
        all_seqs.add(tuple(chain_dict['aatype']))
        struct_feats.append(chain_dict)
    if len(all_seqs) == 1:
        metadata['quaternary_category'] = 'homomer'
    else:
        metadata['quaternary_category'] = 'heteromer'
    complex_feats = du.concat_np_features(struct_feats, False)

    # Process geometry features
    complex_aatype = complex_feats['aatype']
    modeled_idx = np.where(complex_aatype != 20)[0]
    if np.sum(complex_aatype != 20) == 0:
        raise errors.LengthError('No modeled residues')
    min_modeled_idx = np.min(modeled_idx)
    max_modeled_idx = np.max(modeled_idx)
    metadata['seq_len'] = len(complex_aatype)
    metadata['modeled_seq_len'] = max_modeled_idx - min_modeled_idx + 1
    complex_feats['modeled_idx'] = modeled_idx
    return complex_feats, metadata

## ProteinMPNN chain and PDBTM chain not consistent

In [3]:
mmcif_path = "/home/sirius/Downloads/4a2n.cif"
max_resolution = 8.0
max_len = 2000
complex_feats, metadata = process_mmcif_mimic_framediff(mmcif_path,max_resolution,max_len)

In [8]:
with open(mmcif_path, 'r') as f:
    parsed_mmcif = mmcif_parsing.parse(
        file_id="test", mmcif_string=f.read())

In [14]:
parsed_mmcif.mmcif_object.header

{'structure_method': 'x-ray diffraction',
 'release_date': '2012-01-11',
 'resolution': 3.4}

In [6]:
complex_feats.keys()

dict_keys(['atom_positions', 'aatype', 'atom_mask', 'residue_index', 'chain_index', 'b_factors', 'bb_mask', 'bb_positions', 'modeled_idx'])

In [7]:
complex_feats['chain_index']

array([27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
       27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
       27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
       27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
       27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
       27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
       27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
       27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
       27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
       27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
       27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
       27, 27, 27, 27, 27, 27, 27, 27])

## End check

In [16]:
mmcif_path = "/home/sirius/PhD/software/se3_diffusion/data/1h2s.cif"
max_resolution = 8.0
max_len = 2000
complex_feats, metadata = process_mmcif_mimic_framediff(mmcif_path,max_resolution,max_len)

modeled_idx is which residue is not X

bb_mask is whether this residue has CA atom

a modeled_idx can also be bb_mask

In [21]:
for key in complex_feats.keys():
    print(key, complex_feats[key].shape)

atom_positions (327, 37, 3)
aatype (327,)
atom_mask (327, 37)
residue_index (327,)
chain_index (327,)
b_factors (327, 37)
bb_mask (327,)
bb_positions (327, 3)
modeled_idx (285,)


In [22]:
metadata

{'pdb_name': '1h2s',
 'raw_path': '/home/sirius/PhD/software/se3_diffusion/data/1h2s.cif',
 'oligomeric_count': '4',
 'oligomeric_detail': 'tetrameric',
 'resolution': 1.93,
 'structure_method': 'x-ray diffraction',
 'num_chains': 2,
 'quaternary_category': 'heteromer',
 'seq_len': 327,
 'modeled_seq_len': 326}

## 2. from framediff to obtain different chain (modeled residue)

In [1]:
token2seq_dict = {0:'A', 1:'C', 2:'D', 3:'E', 4:'F', 5:'G', 6:'H', 7:'I', 8:'K', 9:'L', 10:'M', 11:'N', 12:'P', 
                13:'Q', 14:'R', 15:'S', 16:'T', 17:'V', 18:'W', 19:'Y', 20:'X', 21:'O', 22:'U', 23:'B', 24:'Z', 25:'-', 26:'.', 
                27:'<mask>', 28: '<pad>',}


In [20]:
import numpy as np
random_length = 200
random_seq = ''.join([token2seq_dict[np.random.randint(0, 20)] for i in range(random_length)])
print("random_seq", random_seq)

random_seq GGLEPEPLIKNVMGMDMFAGQMGDTWRIMAHPQSFITIHDFLKFQQMVRDLNHTIFSYNKFERSLTLTLRLELFGEEESTNVKAHPLRTETMADGCNMLNFKRLMLLVANCYDHRFYYAQHHLTTAREKMGCVNAWGRWKQGVYMTIWWDENFKQRMDTFLESTLCSYCRTLLVVPAWHFCWCPKKAMYYYYDRTGRPSQ


In [33]:
def select_positions(n_mutations, seq, plddt, select_positions='plddt'):
    mutable_positions = []
    if select_positions == 'random':
        # Choose positions randomly.
        mutable_positions= np.random.choice(range(len(seq)), size=n_mutations, replace=False)


    elif select_positions == 'plddt':
        # Choose positions based on lowest plddt (taken across all states/oligomer for each protomer).
        # First/last three positions of each protomers are choice frequency adjusted to avoid picking N/C term every time (they tend to score much lower).


        mutate_plddt_quantile = 0.25 # default worst pLDDT quantile to mutate.


        proto_L = len(seq)

        # Weights associated with each position in the protomer.
        # to account for termini systematically scoring worse in pLDDT.
        weights = np.array([0.25, 0.5, 0.75] + [1] * (proto_L - 6) + [0.75, 0.5, 0.25])

        # Sub-select lowest % quantile of plddt positions.
        n_potential = round(proto_L * mutate_plddt_quantile)
        print("n_potential", n_potential)
        consensus_min = np.min(plddt, axis=0)
        print("consensus_min", consensus_min)
        potential_sites = np.argsort(consensus_min)[:n_potential]
        print("potential_sites", potential_sites)

        # Select mutable sites
        sub_w = weights[potential_sites]
        print("sub_w", sub_w)
        sub_w = [w/np.sum(sub_w) for w in sub_w]
        print("sub_w", sub_w)
        sites = np.random.choice(potential_sites, size=n_mutations, replace=False, p=sub_w)

        mutable_positions = sites
    return mutable_positions

In [34]:
consensus_min = np.min(plddt, axis=0)

In [35]:
n_potential = round(len(seq) * 0.25)
potential_sites = np.argsort(consensus_min)[:n_potential]

In [36]:
np.argsort(consensus_min)

array([144,  84, 142, 145, 136,  37,  56, 120, 167,  68, 173,  59,  49,
       151, 195, 106,   2,  73, 140,  46, 110,  95,  65,  12,  48,   8,
        82,  38,  17, 156, 131, 191,   3,  88, 166, 180, 114, 163, 168,
        93, 176, 158,  29,  99,  90,  97,  23,  63,  34,  33,  36,  28,
       181,  39,  60, 100,  24,   7,  89,  25, 192, 124, 121, 115,  35,
       105,  26,  52,  13, 102, 170,  10,  27, 116,  31, 113, 177,  58,
       109, 159,  44,   5,  30, 147,  94, 152,  32,  18,  45, 172, 150,
       174,  41, 185,  40, 108, 111, 197,  86,  43,  42, 133, 127,  14,
       162, 126, 129, 189, 123,  53,  79,  81,  47,  22, 143,  72,  78,
       137,  67, 190,  85,   6, 118, 117,  16,  61, 188, 139, 149,  77,
       178, 153,  98, 101, 164,  80,  15,  83,   1,  66, 187,  92, 169,
       165, 175,  21,  57, 119, 183,  75, 160, 148,  91, 196, 193, 199,
        74,  96, 130,  51,  87, 186,  69, 182,  64,  70, 128,  71, 122,
       112,   9, 171, 125,   4, 154, 104, 146, 157, 135, 161,  5

In [37]:
n_mutation = 3
seq = random_seq    
plddt = np.random.rand(1, len(seq))
select_positions(n_mutation, seq, plddt, select_positions='plddt')

n_potential 50
consensus_min [0.67991993 0.26378496 0.18109359 0.39947967 0.7163121  0.13783264
 0.8614685  0.08578992 0.77003964 0.54391734 0.64565287 0.82070195
 0.9552196  0.23292149 0.39489896 0.1182073  0.46144875 0.81987643
 0.98049278 0.25808346 0.73930927 0.7708222  0.41862547 0.17380925
 0.70748985 0.54950291 0.37264936 0.57239608 0.64097485 0.08048683
 0.90783048 0.71416334 0.02783704 0.93322387 0.67972782 0.06035076
 0.83755585 0.02235071 0.97216163 0.93606308 0.14000451 0.78488458
 0.34574262 0.88226199 0.64664325 0.79654754 0.4133937  0.21479152
 0.27035165 0.40062728 0.52380658 0.66069971 0.25267376 0.52290115
 0.21954544 0.58495492 0.7988917  0.32635508 0.43369932 0.27049779
 0.87199179 0.65525485 0.85770219 0.08791165 0.60426651 0.96960753
 0.70136095 0.970794   0.36076204 0.99772002 0.59892164 0.83775233
 0.920749   0.01257469 0.01009929 0.91493031 0.1061782  0.48316049
 0.49281014 0.82409957 0.12322792 0.42623725 0.41373013 0.43101214
 0.32987575 0.72117002 0.50032504

array([132,  32, 194])

In [31]:
plddt.shape

(200,)

In [19]:
np.min(plddt, axis=0)

0.004324417699140826

In [23]:
plddt

array([0.94466765, 0.35411377, 0.13133574, 0.62673324, 0.84319948,
       0.64793025, 0.6168724 , 0.59099241, 0.65483817, 0.45938927,
       0.19604578, 0.83483383, 0.81739394, 0.85594978, 0.34495433,
       0.71568908, 0.35878587, 0.569031  , 0.74240711, 0.98657093,
       0.73409419, 0.57650353, 0.89320624, 0.82503796, 0.77674735,
       0.13603442, 0.22548158, 0.13620494, 0.54673154, 0.02608368,
       0.5851039 , 0.42936689, 0.2882605 , 0.59071571, 0.7325343 ,
       0.26741469, 0.26729874, 0.99642178, 0.49246782, 0.60826715,
       0.2006435 , 0.4083844 , 0.06862138, 0.10747409, 0.58152769,
       0.92430111, 0.47993575, 0.21817594, 0.31861625, 0.97095885,
       0.47925291, 0.49145359, 0.74871333, 0.96650504, 0.94489188,
       0.86572996, 0.84917946, 0.0219251 , 0.63692467, 0.04678527,
       0.48505089, 0.21614411, 0.11926675, 0.46895851, 0.18889543,
       0.65679202, 0.95285333, 0.81676777, 0.33813803, 0.49673921,
       0.50772039, 0.06400917, 0.51053618, 0.94408299, 0.52686