In [2]:
import ampal
from ampal import Polypeptide
import pandas as pd
import numpy as np
# show more rows in pandas
pd.set_option("display.max_rows", 200)
import os
import copy
import pickle

import yaml

import isambard

In [3]:
import isambard.modelling


fixed_backbone = ampal.load_pdb("/home/tadas/code/dreamer/outputs/5u4j/5u4j_fixed_backbone_0.pdb")
sequence = "MFEINPVNNRIQDLTERSDVLRGYLDYDAKKERLEEVNAELEQPDVWNEPERAQALGKERSSLEAVVDTLDQMKQGLEDVSGLLELAVEADDEETFNEAVAELDALEEKLAQLEFRRMFSGEYDSADCYLDIQAGSGGTEAQDWASMLERMYLRWAESRGFKTEIIEESEGEVAGIKSVTIKISGDYAYGWLRTETGVHRLVRKSPFDSGGRRHTSFSSAFVYPEVDDDIDIEINPADLRIDVYRTSGAGGQHVNRTESAVRITHIPTGIVTQCQNDRSQHKNKDQAMKQMKAKLYELEMQKKNAEKQAMEDNKSDIGWGSQIRSYVLDDSRIKDLRTGVETRNTQAVLDGSLDQFIEASLKAGL"
packed_fixed_sctructure=isambard.modelling.scwrl.pack_side_chains_scwrl(fixed_backbone, [sequence])


<Assembly (5u4j_fixed_backbone_0) containing 1 Polypeptide>

In [7]:
x[0][0]

<Residue containing 8 Atoms. Residue code: MET>

In [9]:
fixed_backbone.pdb

'HEADER ISAMBARD Model 5u4j_fixed_backbone_0                                            \nATOM      1  N   MET B   1      37.723  -7.816  -8.013  1.00  1.00              \nATOM      2  CA  MET B   1      36.267  -7.745  -8.044  1.00  1.00              \nATOM      3  C   MET B   1      35.789  -6.680  -9.022  1.00  1.00              \nATOM      4  O   MET B   1      36.236  -5.534  -8.974  1.00  1.00              \nATOM      5  N   PHE B   2      34.967  -7.067  -9.816  1.00  1.00              \nATOM      6  CA  PHE B   2      34.417  -6.166 -10.821  1.00  1.00              \nATOM      7  C   PHE B   2      33.012  -5.712 -10.445  1.00  1.00              \nATOM      8  O   PHE B   2      32.091  -6.523 -10.352  1.00  1.00              \nATOM      9  N   GLU B   3      32.847  -4.526 -10.329  1.00  1.00              \nATOM     10  CA  GLU B   3      31.572  -3.974  -9.886  1.00  1.00              \nATOM     11  C   GLU B   3      30.908  -3.165 -10.993  1.00  1.00              \nATOM    

In [193]:
# user may have any size assembly of polypeptides, ligands, DNA, RNA
# I want to be able to fix all polypeptides in the assembly
# so I need to loop over all polypeptides in the assembly and check if they need fixing. Then in context of assembly fix missing structure. 

# this is hard to do because will likely need different parameters for RFdiffusion, perhaps symmetry, different models, etc.
# easier task is to do if only one chain is present


def check_eligibility(assembly):
    polypeptides = [i for i in assembly if isinstance(i, Polypeptide)]
    assert len(polypeptides) == 1, "Expected only one polypeptide in the assembly. This feature only works for single-chain assemblies."


# def chunk_sequential_numbers(numbers):
#     if not numbers:
#         return []
#     chunks = []
#     current_chunk = [numbers[0]]
#     for i in range(1, len(numbers)):
#         if numbers[i] == numbers[i - 1] + 1:
#             current_chunk.append(numbers[i])
#         else:
#             chunks.append(current_chunk)
#             current_chunk = [numbers[i]]
#     chunks.append(current_chunk)
#     return chunks

def get_residue_ids_chunks(polypeptide):
    monomers = list(polypeptide.get_monomers())
    residue_ids = [int(residue.id) for residue in monomers]
    residue_ids_chunks = chunk_sequential_numbers(residue_ids)
    return residue_ids_chunks

def structure_has_issues(assembly):
    missing_beginning = False
    structure_is_broken = False
    polypeptide = [i for i in assembly if isinstance(i, Polypeptide)][0]
    residues = list(polypeptide.get_monomers())
    residue_ids = [int(residue.id) for residue in residues]
    if residue_ids[0] != 1:
        missing_beginning = True

    residue_chunks = get_residue_ids_chunks(polypeptide)
    gap_number = len(residue_chunks) - 1
    if gap_number > 0:
        structure_is_broken = True
    # report 
    message = ""
    if missing_beginning:
        message += "Missing chain start. "
    if structure_is_broken:
        message += "Structure has breaks. "
    if message:
        print(f"Structure check outcome: {message}")
    
    structure_has_issues = missing_beginning or structure_is_broken
    return structure_has_issues

def find_full_seq_alignment_indices(full_sequence, chain_chunks):
    alignment = []
    start = 0

    for chunk in chain_chunks:
        index = full_sequence.find(chunk, start)
        if index == -1:
            return None  # If a chunk is not found, return None indicating no perfect alignment
        alignment.append(index)
        start = index + len(chunk)

    return alignment

def get_missing_chains(full_sequence, missing_residue_chunks, alignment_indices):

    missing_chains=[]
    for indice, missing_residue_chunk in zip(alignment_indices, missing_residue_chunks):
        start=indice-len(missing_residue_chunk)
        end = indice
        missing_chain=full_sequence[start:end]
        missing_chains.append(missing_chain)
    
    missing_chains.append(full_sequence[end+len(missing_residue_chunk):])

    return missing_chains


def get_chain_chunks(polypeptide, residue_ids_chunks):
    start=0
    stop=0
    chain_chunks = []
    for chunk in residue_ids_chunks:
        stop+=len(chunk)
        chain_chunk = polypeptide.sequence[start:stop]
        start+=len(chunk)
        chain_chunks.append(chain_chunk)
    assert ("".join(chain_chunks)==polypeptide.sequence)
    return chain_chunks

def find_missing_chunks(chunks, start, end):
    missing_chunks = []

    # Add missing chunks from start to the start of the first chunk if needed
    if chunks[0][0] > start:
        missing_chunks.append(list(range(start, chunks[0][0])))

    for i in range(1, len(chunks)):
        previous_end = chunks[i-1][-1]
        current_start = chunks[i][0]
        if previous_end + 1 < current_start:
            missing_chunks.append(list(range(previous_end + 1, current_start)))
    
    # Add missing chunks from the end of the last chunk to end if needed
    if chunks[-1][-1] < end:
        missing_chunks.append(list(range(chunks[-1][-1] + 1, end + 1)))
    
    return missing_chunks







In [202]:
# def find_best_alignment_extend(df, full_sequence):
#     seq_length = len(full_sequence)
#     max_res_id = df.index.max()

#     best_alignment = None
#     best_start_pos = None
#     best_match_count = 0
#     total_non_na = df['mol_letter'].notna().sum()
    
#     # Allow extension from -seq_length to max_res_id + seq_length
#     for start_pos in range(-seq_length, max_res_id + 1):
#         extended_index = range(min(start_pos, 1), max(start_pos + seq_length, max_res_id + 1))
#         temp_df = pd.DataFrame(index=extended_index).join(df)
#         temp_df['complete_mol_letter'] = np.nan
        
#         match_count = 0
#         for i, letter in enumerate(full_sequence):
#             pos = start_pos + i
#             if pos in temp_df.index:
#                 if pd.notna(temp_df.at[pos, 'mol_letter']) and temp_df.at[pos, 'mol_letter'] != letter:
#                     break
#                 temp_df.at[pos, 'complete_mol_letter'] = letter
#                 if pd.notna(temp_df.at[pos, 'mol_letter']):
#                     match_count += 1
#         else:
#             if match_count > best_match_count:
#                 best_alignment = temp_df
#                 best_start_pos = start_pos
#                 best_match_count = match_count

#     if best_alignment is None or best_match_count == 0:
#         raise ValueError("No ideal match found between the provided sequence and the fragmented sequence with known gaps.")
#     elif best_match_count == total_non_na:
#         print("All fragmented residues matched perfectly to input sequence.")
#     else:
#         print(f"Best alignment found at start position: {best_start_pos}, matching {best_match_count} out of {total_non_na} known positions.")
    
#     return best_alignment, best_start_pos


# def clip_residue_ids(residue_df,first,last):
#     return residue_df.loc[first:last]


# def get_contigmap_contigs(residue_df:pd.DataFrame, chain_id:str):
#     assert "complete_mol_letter" in residue_df.columns, "residue_df must contain a 'complete_mol_letter' column."
#     # find first index at which complete_mol_letter is not na
#     start_index = residue_df['complete_mol_letter'].first_valid_index()
#     end_index = residue_df['complete_mol_letter'].last_valid_index()
#     contigs_string = f"{chain_id}{start_index}-{end_index}"
#     return contigs_string#f'contigmap.contigs=[{contigs_string}]'
    
# def get_contigmap_inpaint_str(residue_df:pd.DataFrame, chain_id:str):
#     assert "complete_mol_letter" in residue_df.columns and "mol_letter" in residue_df.columns, "residue_df must contain a 'complete_mol_letter' and 'mol_letter' column."
#     # find all row index numbers where mol_letter is na and complete_mol_letter is not na
#     inpaint_indices = residue_df[residue_df['mol_letter'].isna() & residue_df['complete_mol_letter'].notna()].index
#     inpaint_chunks = chunk_sequential_numbers(list(inpaint_indices))
#     inpaint_str = f""
#     for inpaint_chunk in inpaint_chunks:
#         chunk_inpaint_str = f"/{chain_id}{inpaint_chunk[0]}-{inpaint_chunk[-1]}"
#         inpaint_str += chunk_inpaint_str
#     inpaint_str = inpaint_str.strip("/")
#     return inpaint_str # f'contigmap.inpaint_str=[{inpaint_str}]'

# def generate_rfdiffusion_command(
#     residue_df,
#     input_path,
#     output_prefix,
#     rf_diffusion_install_dir,
#     num_designs=1,
#     deterministic=True,
# ):
#     contigmap_contigs_string = get_contigmap_contigs(residue_df, polypeptide.id)
#     contigmap_inpaint_str_string = get_contigmap_inpaint_str(residue_df, polypeptide.id)

#     script = (
#         f"python '{os.path.join(rf_diffusion_install_dir,'scripts', 'run_inference.py')}' "
#         f"inference.output_prefix='{output_prefix}' "
#         f"inference.input_pdb='{input_path}' "
#         f"'contigmap.contigs=[{contigmap_contigs_string}]' "
#         f"'contigmap.inpaint_str=[{contigmap_inpaint_str_string}]' "
#         f"inference.num_designs={int(num_designs)} "
#         + ("inference.deterministic=True " if deterministic else "")
#     )

#     return script

# def create_and_save_residue_presets(monomers):
#     residue_dict = {}
#     keys = set([x.mol_letter for x in monomers])
#     for key in keys:
#         residue_dict[key] = copy.deepcopy([x for x in monomers if x.mol_letter == key][0])

#     residue_dict
#     # save residue_dict as pickle
#     import pickle
#     with open('/home/tadas/code/dreamer/data/ampal_residue_presets', 'wb') as handle:
#         pickle.dump(residue_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

# def load_residue_presets():
#     with open('/home/tadas/code/dreamer/data/ampal_residue_presets', 'rb') as handle:
#         residue_dict = pickle.load(handle)
#     return residue_dict

# def save_commplete_chain_pdb(residue_df, chain_id, outfile_path):
#     residue_presets = load_residue_presets()
#     complete_residues = []
#     for index, row in residue_df.iterrows():
#         if pd.isna(row['mol_letter']):
#             # residue = Residue(atoms = list(residue_dict[row['complete_mol_letter']].get_atoms()),mol_code=row['complete_mol_letter'], monomer_id=index, )
#             residue = copy.deepcopy(residue_presets[row["complete_mol_letter"]])
#             residue.id = int(index)

#         else:
#             # find real monomer with the index
#             # print(row)
#             residue = [mon for mon in monomers if int(mon.id) == index][0]
#         complete_residues.append(residue)

#     complete_polypeptide = Polypeptide(complete_residues,polymer_id=chain_id)
#     with open(outfile_path, "w") as f:
#         f.write(complete_polypeptide.pdb)


In [204]:
input_pdb_path = "/home/tadas/code/dreamer/inputs/example.pdb"
assembly = ampal.load_pdb(input_pdb_path)
polypeptide = assembly[0]
monomers=list(polypeptide.get_monomers())

mol_letters = [residue.mol_letter for residue in monomers]
res_ids = [int(residue.id) for residue in monomers]
# make a dataframe with sequence and res_ids where res_ids will be the index
residue_df = pd.DataFrame({"mol_letter": mol_letters},index=res_ids)
max_res_id = max(res_ids)


# Create a full index from 1 to max_res_id.
full_index = range(1, max_res_id + 1)
residue_df = residue_df.reindex(full_index)

full_sequence = "HHHHHHSAALEVLFQGPGMFEINPVNNRIQDLTERSDVLRGYLDYDAKKERLEEVNAELEQPDVWNEPERAQALGKERSSLEAVVDTLDQMKQGLEDVSGLLELAVEADDEETFNEAVAELDALEEKLAQLEFRRMFSGEYDSADCYLDIQAGSGGTEAQDWASMLERMYLRWAESRGFKTEIIEESEGEVAGIKSVTIKISGDYAYGWLRTETGVHRLVRKSPFDSGGRRHTSFSSAFVYPEVDDDIDIEINPADLRIDVYRTSGAGGQHVNRTESAVRITHIPTGIVTQCQNDRSQHKNKDQAMKQMKAKLYELEMQKKNAEKQAMEDNKSDIGWGSQIRSYVLDDSRIKDLRTGVETRNTQAVLDGSLDQFIEASLKAGL"
first = 1
last = max_res_id + 2

residue_df, start_pos = find_best_alignment_extend(residue_df, full_sequence)
residue_df = clip_residue_ids(residue_df,first,last)


all_res_pdb_path= "/home/tadas/code/dreamer/inputs/example_complete.pdb"
save_commplete_chain_pdb(residue_df, polypeptide.id, all_res_pdb_path)

output_dir = "/home/tadas/code/dreamer/outputs"
rf_diffusion_install_dir = "/home/tadas/code/RFdiffusion"
output_prefix = f"{output_dir}/{os.path.splitext(os.path.basename(all_res_pdb_path))[0]}"

generate_rfdiffusion_command(
    residue_df,
    all_res_pdb_path,
    output_prefix,
    rf_diffusion_install_dir,
    num_designs=1,
    deterministic=True,
)


All fragmented residues matched perfectly to input sequence.


"python '/home/tadas/code/RFdiffusion/scripts/run_inference.py' inference.output_prefix='/home/tadas/code/dreamer/outputs/example_complete' inference.input_pdb='/home/tadas/code/dreamer/inputs/example_complete.pdb' 'contigmap.contigs=[v1-329]' 'contigmap.inpaint_str=[v1-126/v227-295/v328-329]' inference.num_designs=1 inference.deterministic=True "

In [2]:
assembly = ampal.load_pdb("/home/tadas/code/dreamer/inputs/example.pdb")

check_eligibility(assembly)

if not structure_has_issues(assembly):
    print("Structure looks good.")

print("structure looks bad. Provide a protein sequence to fix it.")

# prompt user for sequence 
full_sequence = "HHHHHHSAALEVLFQGPGMFEINPVNNRIQDLTERSDVLRGYLDYDAKKERLEEVNAELEQPDVWNEPERAQALGKERSSLEAVVDTLDQMKQGLEDVSGLLELAVEADDEETFNEAVAELDALEEKLAQLEFRRMFSGEYDSADCYLDIQAGSGGTEAQDWASMLERMYLRWAESRGFKTEIIEESEGEVAGIKSVTIKISGDYAYGWLRTETGVHRLVRKSPFDSGGRRHTSFSSAFVYPEVDDDIDIEINPADLRIDVYRTSGAGGQHVNRTESAVRITHIPTGIVTQCQNDRSQHKNKDQAMKQMKAKLYELEMQKKNAEKQAMEDNKSDIGWGSQIRSYVLDDSRIKDLRTGVETRNTQAVLDGSLDQFIEASLKAGL"


NameError: name 'check_eligibility' is not defined

In [179]:
polypeptide = assembly[0]


residue_ids_chunks = get_residue_ids_chunks(polypeptide)


print("Residue id chunks",[(i[0], i[-1]) for i in residue_ids_chunks])
missing_residue_ids_chunks = find_missing_chunks(residue_ids_chunks, start=0, end=len(full_sequence))
print("Missing residue id chunks",[(i[0], i[-1]) for i in missing_residue_ids_chunks])

chain_chunks = get_chain_chunks(polypeptide, residue_ids_chunks)

display(chain_chunks)

alignment_indices = find_full_seq_alignment_indices(full_sequence, chain_chunks)

missing_chains = get_missing_chains(full_sequence, missing_residue_ids_chunks, alignment_indices)

assert len(missing_chains) == len(residue_ids_chunks)+1

display(missing_chains)

Residue id chunks [(127, 226), (296, 327)]
Missing residue id chunks [(0, 126), (227, 295), (328, 383)]


['DCYLDIQAGSGGTEAQDWASMLERMYLRWAESRGFKTEIIEESEGEVAGIKSVTIKISGDYAYGWLRTETGVHRLVRKSPFDSGGRRHTSFSSAFVYPEV',
 'YELEMQKKNAEKQAMEDNKSDIGWGSQIRSYV']

['GMFEINPVNNRIQDLTERSDVLRGYLDYDAKKERLEEVNAELEQPDVWNEPERAQALGKERSSLEAVVDTLDQMKQGLEDVSGLLELAVEADDEETFNEAVAELDALEEKLAQLEFRRMFSGEYDSA',
 'DDDIDIEINPADLRIDVYRTSGAGGQHVNRTESAVRITHIPTGIVTQCQNDRSQHKNKDQAMKQMKAKL',
 'L']

[[3, 4, 5], [7, 8], [11]]


In [164]:
outcome_chain = ""
for i in range(len(missing_chains)-1):
    outcome_chain = outcome_chain + missing_chains[i]
    outcome_chain = outcome_chain+ chain_chunks[i]

outcome_chain = outcome_chain + missing_chains[-1]

assert outcome_chain in full_sequence


AssertionError: 

In [156]:
chain_chunks

['DCYLDIQAGSGGTEAQDWASMLERMYLRWAESRGFKTEIIEESEGEVAGIKSVTIKISGDYAYGWLRTETGVHRLVRKSPFDSGGRRHTSFSSAFVYPEV',
 'YELEMQKKNAEKQAMEDNKSDIGWGSQIRSYV']

In [157]:
outcome_chain

'YDAKKERLEEVNAELEQPDVWNEPERAQALGKERSSLEAVVDTLDQMKQGLEDVSGLLELAVEADDEETFNEAVAELDALEEKLAQLEFRRMFSGEYDSADCYLDIQAGSGGTEAQDWASMLERMYLRWAESRGFKTEIIEESEGEVAGIKSVTIKISGDYAYGWLRTETGVHRLVRKSPFDSGGRRHTSFSSAFVYPEVTHIPTGIVTQCQNDRSQHKNKDQAMKQMKAKLYELEMQKKNAEKQAMEDNKSDIGWGSQIRSYVLDDSRIKDLRTGVETRNTQAVLDGSLDQFIEASLKAGL'

In [158]:
full_sequence

'HHHHHHSAALEVLFQGPGMFEINPVNNRIQDLTERSDVLRGYLDYDAKKERLEEVNAELEQPDVWNEPERAQALGKERSSLEAVVDTLDQMKQGLEDVSGLLELAVEADDEETFNEAVAELDALEEKLAQLEFRRMFSGEYDSADCYLDIQAGSGGTEAQDWASMLERMYLRWAESRGFKTEIIEESEGEVAGIKSVTIKISGDYAYGWLRTETGVHRLVRKSPFDSGGRRHTSFSSAFVYPEVDDDIDIEINPADLRIDVYRTSGAGGQHVNRTESAVRITHIPTGIVTQCQNDRSQHKNKDQAMKQMKAKLYELEMQKKNAEKQAMEDNKSDIGWGSQIRSYVLDDSRIKDLRTGVETRNTQAVLDGSLDQFIEASLKAGL'

In [None]:
                                            YDAKKERLEEVNAELEQPDVWNEPERAQALGKERSSLEAVVDTLDQMKQGLEDVSGLLELAVEADDEETFNEAVAELDALEEKLAQLEFRRMFSGEYDSADCYLDIQAGSGGTEAQDWASMLERMYLRWAESRGFKTEIIEESEGEVAGIKSVTIKISGDYAYGWLRTETGVHRLVRKSPFDSGGRRHTSFSSAFVYPEVTHIPTGIVTQCQNDRSQHKNKDQAMKQMKAKLYELEMQKKNAEKQAMEDNKSDIGWGSQIRSYVLDDSRIKDLRTGVETRNTQAVLDGSLDQFIEASLKAGL
HHHHHHSAALEVLFQGPGMFEINPVNNRIQDLTERSDVLRGYLDYDAKKERLEEVNAELEQPDVWNEPERAQALGKERSSLEAVVDTLDQMKQGLEDVSGLLELAVEADDEETFNEAVAELDALEEKLAQLEFRRMFSGEYDSADCYLDIQAGSGGTEAQDWASMLERMYLRWAESRGFKTEIIEESEGEVAGIKSVTIKISGDYAYGWLRTETGVHRLVRKSPFDSGGRRHTSFSSAFVYPEVDDDIDIEINPADLRIDVYRTSGAGGQHVNRTESAVRITHIPTGIVTQCQNDRSQHKNKDQAMKQMKAKLYELEMQKKNAEKQAMEDNKSDIGWGSQIRSYVLDDSRIKDLRTGVETRNTQAVLDGSLDQFIEASLKAGL