In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import sys
sys.path.append('/home/jupyter-yehlin/Pairformer/boltzdesign')

from boltzdesign_utils import *
from ligandmpnn_utils import *



In [2]:
## ccd_library 
import pickle
import rdkit
filename = '/home/jupyter-yehlin/.boltz/ccd.pkl'
ccd_lib = pickle.load(open(filename, 'rb'))

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
predict_args = {
    "recycling_steps": 0,  # Default value
    "sampling_steps": 200,  # Default value
    "diffusion_samples": 1,  # Default value
    "write_confidence_summary": True,
    "write_full_pae": False,
    "write_full_pde": False,
}

boltz_model = get_boltz_model('/home/jupyter-yehlin/.boltz/boltz1_conf.ckpt',predict_args,device)
boltz_model.train()

## Run BoltzDesign

In [4]:
##default config
config = {
    'mutation_rate': 1,
    'pre_iteration': 30,
    'soft_iteration': 75, 
    'temp_iteration': 45,
    'hard_iteration': 5,
    'semi_greedy_steps': 0,
    'learning_rate_pre': 0.2,
    'learning_rate': 0.1,
    'design_algorithm': '3stages',
    'set_train': True,
    'use_temp': True,
    'disconnect_feats': True,
    'disconnect_pairformer': False,
    'length': 150,
    'distogram_only': True,
    'binder_chain': 'A', ## A or B
    'small_molecule': True,
    'mask_ligand': False,
    'optimize_per_contact_per_binder_pos': True
}


In [5]:
yaml_dir = '/home/jupyter-yehlin/Pairformer/boltz/examples/rfdiffusion_small_molecule'
for yaml_path in os.listdir(yaml_dir):
    if yaml_path.endswith('.yaml'):
        print(yaml_path)

8vhp.yaml
6CZI.yaml
1HXD.yaml
3WC0.yaml


In [None]:
# Set up directories
main_dir = '/home/jupyter-yehlin/Pairformer/boltz/examples/rfdiffusion_small_molecule_designs'
os.makedirs(main_dir, exist_ok=True)

version_name = 'small_molecule'
design_samples = 5

loss_scales = {
    'con_loss': 1.0,
    'i_con_loss': 1.0, 
    'helix_loss': -0.2,
    'plddt_loss': 0.1,
    'pae_loss': 0.4,
    'i_pae_loss': 0.1,
    'rg_loss': 0.4,
}

run_boltz_design(
    main_dir=main_dir,
    yaml_dir=yaml_dir, 
    boltz_model=boltz_model,
    ccd_lib=ccd_lib,
    design_samples=design_samples,
    version_name=version_name,
    config=config, 
    loss_scales=loss_scales
)

## 2. LigandMPNN Redesign

In [2]:
boltzdesign_dir = main_dir + '/results_final'
pdb_save_dir = main_dir + '/pdb'
ligandmpnn_dir = main_dir + '/ligandmpnn'
ligandmpnn_config= '/home/jupyter-yehlin/Pairformer/LigandMPNN/run_ligandmpnn_logits_config.yaml'

os.makedirs(ligandmpnn_dir, exist_ok=True)
convert_cif_files_to_pdb(boltzdesign_dir, pdb_save_dir)
run_ligandmpnn_redesign(ligandmpnn_dir, pdb_save_dir, yaml_dir, ligandmpnn_config, top_k=1, cutoff=6, non_protein_ligand=True, binder_chain='A', target_chain='B')


In [32]:
import json
import glob
import os
import numpy as np

ligandmpnn_dir_boltz = ligandmpnn_dir + '/boltz_predictions_success_lmpnn'
for root in os.listdir(ligandmpnn_dir_boltz):
    root_path = os.path.join(ligandmpnn_dir_boltz, root, 'predictions')
    for subdir in os.listdir(root_path):
        json_path = os.path.join(root_path, subdir, f'confidence_{subdir}_model_0.json')
        pae_path = os.path.join(root_path, subdir, f'pae_{subdir}_model_0.npz')

        length = int(subdir[subdir.find('length')+6:subdir.find('_model')])
        
        if os.path.exists(json_path):
            with open(json_path, 'r') as f:
                data = json.load(f)
            design_name = json_path.split('/')[-2]
            try:
                pae_data = np.load(pae_path)
                pae_matrix = pae_data['pae']
                interface_pae = np.mean(((pae_matrix + pae_matrix.T)/2)[:length,length:])
                print(np.mean(pae_matrix[length:,:length]))
                print(np.mean(pae_matrix[:length,length:]))
                intra_pae = np.mean(pae_matrix[:length,:length])
                print(f"{design_name} length: {length} complex_plddt: {data['complex_plddt']:.2f} iptm: {data['iptm']:.2f} i-pae: {interface_pae:.2f} pae: {intra_pae:.2f}")
            except KeyError:
                print(f"{design_name} length: {length} complex_plddt: {data['complex_plddt']:.2f} iptm: {data['iptm']:.2f} PAE data not found")

20.980047
5.9006677
7v11_results_itr1_length150_model_0_1 length: 150 complex_plddt: 0.82 iptm: 0.61 i-pae: 13.44 pae: 2.93
17.219994
4.119095
5sdv_results_itr1_length150_model_0_1 length: 150 complex_plddt: 0.86 iptm: 0.80 i-pae: 10.67 pae: 3.13
19.964653
7.0432525
7bkc_results_itr1_length150_model_0_1 length: 150 complex_plddt: 0.76 iptm: 0.75 i-pae: 13.50 pae: 7.94
