In [1]:
# 1. CONFIG: use YAML
# BUILD DATALOADER
# 2. Generate xyz_27
# 3. (Commented version) Generate everything that needs to be fed into diffusion

In [2]:
import torch
import numpy as np

In [3]:
from for_test import process_target, construct_contig, get_idx0_hotspots, get_init_xyz


In [4]:
# Path of the structure file that is parsed by process_target in the next cell.
INPUT_PDB = "1a2y.pdb"

In [5]:
### Parse input pdb ###
# Keys noted on the returned dict:
#   'xyz_27', 'mask_27', 'seq', 'pdb_idx', 'xyz_het', 'info_het'
target_feats = process_target(
    INPUT_PDB,
    parse_hetatom=True,
    center=False,
)

In [6]:
# Length of the parsed sequence; displayed as the cell output.
len(target_feats['seq'])

352

In [7]:
# Mirrors the `contigmap` section of the YAML config, where every field
# defaults to null:
#   contigs, inpaint_seq, inpaint_str, provide_seq, length
CONTIG_CONF = dict(
    contigs=['352-352'],  # 352 is the sequence length displayed by the previous cell
    inpaint_seq=None,
    inpaint_str=None,
    provide_seq=None,
    length=None,
)


In [8]:
### Generate specific contig ###
# Pick one concrete contig out of the range of possibilities given in CONTIG_CONF.
contig_map = construct_contig(target_feats, CONTIG_CONF)
mappings = contig_map.get_mappings()

# Wrap the numpy inpaint masks as tensors and prepend an axis of size 1.
mask_seq = torch.from_numpy(contig_map.inpaint_seq).unsqueeze(0)
mask_str = torch.from_numpy(contig_map.inpaint_str).unsqueeze(0)

# binderlen
binderlen = len(contig_map.inpaint)

In [9]:
# PPI settings; no hotspot residues are specified.
PPI_CONF = dict(hotspot_res=None)

In [10]:
### Get Hotspots ###
# The original note records the result as None for this configuration
# (PPI_CONF['hotspot_res'] is None).
hotspot_0idx = get_idx0_hotspots(
    mappings,
    PPI_CONF,
    binderlen,
)

In [11]:
### Initialize other attributes ###

# BASE VAR
# Short aliases for the parsed target features used by the following cells.
xyz_27 = target_feats['xyz_27']
mask_27 = target_feats['mask_27']
seq_orig = target_feats['seq']  # length 352 for this input
L_mapped = len(contig_map.ref)
# (The former `contig_map = contig_map` self-assignment was a no-op and has been removed.)

# DIFFUSION VAR 4
diffusion_mask = mask_str  # original note: [[False]]
# The first `binderlen` positions are labelled chain 'A', any remaining ones chain 'B'.
chain_idx = ['A' if i < binderlen else 'B' for i in range(L_mapped)]

In [12]:
### Generate initial coordinates ###
# Full diffusion from points initialised at the origin; the coordinate tensor
# is sized according to the residue map.

# Start from an all-NaN tensor and copy the reference coordinates into the
# positions selected by the contig map.
xyz_mapped = torch.full((1, 1, L_mapped, 27, 3), np.nan)
xyz_mapped[:, :, contig_map.hal_idx0] = xyz_27[contig_map.ref_idx0]

# Keep a copy from before get_init_xyz is applied, and record the mean over
# residues of atom index 1 (presumably CA -- TODO confirm) for both tensors.
xyz_motif_prealign = xyz_mapped.clone()
motif_prealign_com = xyz_motif_prealign[0, 0, :, 1].mean(dim=0)
motif_com = xyz_27[contig_map.ref_idx0, 1].mean(dim=0)

# DIFFUSION VAR 1
xyz_mapped = get_init_xyz(xyz_mapped).squeeze()  # torch.Size([352, 27, 3])

# DIFFUSION VAR 3
# Atom mask resized to the mapped length; unmapped positions stay False.
atom_mask_mapped = torch.full((L_mapped, 27), False)
atom_mask_mapped[contig_map.hal_idx0] = mask_27[contig_map.ref_idx0]  # torch.Size([352, 27])

# Number of diffusion steps for the contig-mapped coordinates.
t_step_input = 50  # int(diffuser_conf.T)

# DIFFUSION VAR 5
# Timesteps 1..t_step_input inclusive.
t_list = np.arange(1, 1 + t_step_input)

In [13]:
### Generate initial sequence ###

# Read the integer residue tokens straight from target_feats rather than from
# seq_orig: this cell rebinds seq_orig to its one-hot encoding at the end, so
# indexing seq_orig here would break the cell when it is executed a second time.
seq_tokens = target_feats['seq']

# 1-D token vector, fully masked to begin with (21 is the mask token).
seq_t = torch.full((L_mapped,), 21, dtype=torch.long)
seq_t[contig_map.hal_idx0] = seq_tokens[contig_map.ref_idx0]

# # Unmask sequence if desired
# if self._conf.contigmap.provide_seq is not None:
#     seq_t[self.mask_seq.squeeze()] = seq_orig[self.mask_seq.squeeze()]

# Re-mask every position where mask_seq is False.
seq_t[~mask_seq.squeeze()] = 21

# DIFFUSION VAR 2
seq_t    = torch.nn.functional.one_hot(seq_t, num_classes=22).float() # [L,22]
seq_orig = torch.nn.functional.one_hot(seq_tokens, num_classes=22).float() # [L,22]

# Next step in the original sampler, kept for reference (not executed here):
# fa_stack, xyz_true = self.diffuser.diffuse_pose(
#     xyz_mapped,
#     torch.clone(seq_t),
#     atom_mask_mapped.squeeze(),
#     diffusion_mask=self.diffusion_mask.squeeze(),
#     t_list=t_list)
# xT = fa_stack[-1].squeeze()[:,:14,:]
# xt = torch.clone(xT)

# self.denoiser = self.construct_denoiser(len(self.contig_map.ref), visible=self.mask_seq.squeeze())

#         return xt, seq_t

In [14]:
# Display the one-hot encoded (22-class) seq_t built in the previous cell.
seq_t

tensor([[0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 1.]])

In [15]:
# Display the one-hot encoded (22-class) original sequence.
seq_orig

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 1., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [16]:
# #####################################
# ### Initialise Potentials Manager ###
# #####################################

# self.potential_manager = PotentialManager(self.potential_conf,
#                                             self.ppi_conf,
#                                             self.diffuser_conf,
#                                             self.inf_conf,
#                                             self.hotspot_0idx,
#                                             self.binderlen)


    
# ####################################
# ### Generate initial coordinates ###
# ####################################

# if self.diffuser_conf.partial_T: # None
#     assert xyz_27.shape[0] == L_mapped, f"there must be a coordinate in the input PDB for \
#             each residue implied by the contig string for partial diffusion.  length of \
#             input PDB != length of contig string: {xyz_27.shape[0]} != {L_mapped}"
#     assert contig_map.hal_idx0 == contig_map.ref_idx0, f'for partial diffusion there can \
#             be no offset between the index of a residue in the input and the index of the \
#             residue in the output, {contig_map.hal_idx0} != {contig_map.ref_idx0}'
#     # Partially diffusing from a known structure
#     xyz_mapped=xyz_27
#     atom_mask_mapped = mask_27
# else:
#     # Fully diffusing from points initialised at the origin
#     # adjust size of input xt according to residue map
#     xyz_mapped = torch.full((1,1,L_mapped,27,3), np.nan)
#     xyz_mapped[:, :, contig_map.hal_idx0, ...] = xyz_27[contig_map.ref_idx0,...]
#     xyz_motif_prealign = xyz_mapped.clone()
#     motif_prealign_com = xyz_motif_prealign[0,0,:,1].mean(dim=0)
#     self.motif_com = xyz_27[contig_map.ref_idx0,1].mean(dim=0)
#     xyz_mapped = get_init_xyz(xyz_mapped).squeeze()
#     # adjust the size of the input atom map
#     atom_mask_mapped = torch.full((L_mapped, 27), False)
#     atom_mask_mapped[contig_map.hal_idx0] = mask_27[contig_map.ref_idx0]

# # Diffuse the contig-mapped coordinates
# if self.diffuser_conf.partial_T:
#     assert self.diffuser_conf.partial_T <= self.diffuser_conf.T, "Partial_T must be less than T"
#     self.t_step_input = int(self.diffuser_conf.partial_T)
# else:
#     self.t_step_input = int(self.diffuser_conf.T)
# t_list = np.arange(1, self.t_step_input+1)

# #################################
# ### Generate initial sequence ###
# #################################

# seq_t = torch.full((1,L_mapped), 21).squeeze() # 21 is the mask token
# seq_t[contig_map.hal_idx0] = seq_orig[contig_map.ref_idx0]

# # Unmask sequence if desired
# if self._conf.contigmap.provide_seq is not None:
#     seq_t[self.mask_seq.squeeze()] = seq_orig[self.mask_seq.squeeze()]

# seq_t[~self.mask_seq.squeeze()] = 21
# seq_t    = torch.nn.functional.one_hot(seq_t, num_classes=22).float() # [L,22]
# seq_orig = torch.nn.functional.one_hot(seq_orig, num_classes=22).float() # [L,22]

# fa_stack, xyz_true = self.diffuser.diffuse_pose(
#     xyz_mapped,
#     torch.clone(seq_t),
#     atom_mask_mapped.squeeze(),
#     diffusion_mask=self.diffusion_mask.squeeze(),
#     t_list=t_list)
# xT = fa_stack[-1].squeeze()[:,:14,:]
# xt = torch.clone(xT)

# self.denoiser = self.construct_denoiser(len(self.contig_map.ref), visible=self.mask_seq.squeeze())

# ######################
# ### Apply Symmetry ###
# ######################

# if self.symmetry is not None:
#     xt, seq_t = self.symmetry.apply_symmetry(xt, seq_t)
# self._log.info(f'Sequence init: {seq2chars(torch.argmax(seq_t, dim=-1))}')

# self.msa_prev = None
# self.pair_prev = None
# self.state_prev = None

# #########################################
# ### Parse ligand for ligand potential ###
# #########################################

# if self.potential_conf.guiding_potentials is not None:
#     if any(list(filter(lambda x: "substrate_contacts" in x, self.potential_conf.guiding_potentials))):
#         assert len(self.target_feats['xyz_het']) > 0, "If you're using the Substrate Contact potential, \
#                 you need to make sure there's a ligand in the input_pdb file!"
#         het_names = np.array([i['name'].strip() for i in self.target_feats['info_het']])
#         xyz_het = self.target_feats['xyz_het'][het_names == self._conf.potentials.substrate]
#         xyz_het = torch.from_numpy(xyz_het)
#         assert xyz_het.shape[0] > 0, f'expected >0 heteroatoms from ligand with name {self._conf.potentials.substrate}'
#         xyz_motif_prealign = xyz_motif_prealign[0,0][self.diffusion_mask.squeeze()]
#         motif_prealign_com = xyz_motif_prealign[:,1].mean(dim=0)
#         xyz_het_com = xyz_het.mean(dim=0)
#         for pot in self.potential_manager.potentials_to_apply:
#             pot.motif_substrate_atoms = xyz_het
#             pot.diffusion_mask = self.diffusion_mask.squeeze()
#             pot.xyz_motif = xyz_motif_prealign
#             pot.diffuser = self.diffuser