In [None]:
import pandas as pd
import os

pred_method = "chai"  ## folder containing predictions
## file hierarchy should be like this
# chai
# ├── 7ar0_B_A_chai
# │   ├── xxx.pdb
# │   ├── ...
# ├── 7bnv_H_L_A_chai
# │   ├── xxx.pdb
# │   ├── ...

# Every path is made absolute (after expanding "~") right here. The DockQ cell and the
# final merge cell call os.chdir(), and any path left relative would afterwards resolve
# against the wrong directory (only the first complex would be found).
folder_path = os.path.abspath(os.path.expanduser(f"{pred_method}"))

# One sub-folder per predicted complex; sorted so the processing order is reproducible
# (os.listdir order is arbitrary).
complex_list = sorted(
    f for f in os.listdir(folder_path)
    if os.path.isdir(os.path.join(folder_path, f)) and not f.startswith(".")
)
print(complex_list)  # should print prediction folders ['7ar0_B_A_chai', '7bnv_H_L_A_chai', ...]

original_directory = os.path.abspath(os.path.expanduser("path/to/native_PDBs"))  ## Native PDB directory

## create results folder (the DockQ cell chdirs into it, so it must exist)
result_path = os.path.abspath(os.path.expanduser("path/to/results_folder"))
os.makedirs(result_path, exist_ok=True)
dockq_output = f"{pred_method}_dockq_fnat_scores.csv"  ## dockq output
pdockq_output = f"{pred_method}_pdockq2_fit.csv"  ## pdockq2 output
combined = f"{pred_method}_combined.csv"  ### combined results with dockq, pdockq2 and model scores

In [None]:
### cif conversion
import pandas as pd
from Bio import PDB
import os
import os
from Bio import PDB

parser = PDB.MMCIFParser(QUIET=True)
writer = PDB.PDBIO()

# Convert every *.cif model in each prediction folder into a .pdb with the same stem,
# written next to the source file. Conversion failures are reported and skipped so
# one bad model does not stop the batch.
for f in complex_list:
    fpath = os.path.join(folder_path, f)  # full dir path
    print("Scanning:", fpath)
    if not os.path.isdir(fpath):
        print("! Skipping, not a directory:", fpath)
        continue

    for file_name in os.listdir(fpath):
        # Guard clause: only mmCIF files (case-insensitive extension) are converted.
        if not file_name.lower().endswith(".cif"):
            continue

        input_path = os.path.join(fpath, file_name)
        stem = os.path.splitext(file_name)[0]
        output_path = os.path.join(fpath, stem + ".pdb")

        if not os.path.exists(input_path):
            print("! Not found:", input_path)
            continue

        try:
            structure = parser.get_structure(file_name, input_path)
            writer.set_structure(structure)
            writer.save(output_path)
            print(f"Converted {file_name} -> {output_path}")
        except Exception as e:
            print(f"! Failed on {input_path}: {e}")

In [None]:
## Prepare models for the DockQ and pDockQ2 cells: add a per-complex prefix to each PDB file name

### rename for unrelaxed
import os
for i in complex_list:
    print(i)
    fpath = f"{folder_path}/{i}/"
    print(fpath)
    # Per-complex prefix. The DockQ/pDockQ2 cells glob for "*_unrelaxed*.pdb" and parse
    # the complex name from the result, e.g. "7ar0_B_A_chai_unrelaxed" + "pred.model_idx_0.pdb".
    prefix = str(i) + "_unrelaxed"
    for filename in os.listdir(fpath):
        if not filename.endswith(".pdb"):
            continue
        if filename.startswith(prefix):
            # Already renamed by a previous run. Prefixing again would stack prefixes
            # ("<prefix><prefix>pred...") and break the model-id matching downstream.
            print(f"Already renamed, skipping: {filename}")
            continue
        old_path = os.path.join(fpath, filename)
        new_filename = prefix + filename
        new_path = os.path.join(fpath, new_filename)

        os.rename(old_path, new_path)
        print(f"Renamed: {filename} → {new_filename}")

    print("Renaming complete.")

In [None]:
## dockq Functions

import os
import sys
import csv
import glob

from DockQ.DockQ import load_PDB, run_on_all_native_interfaces
from statistics import mean

def merge_chains(model, chains_to_merge):
    """Fold every chain listed in ``chains_to_merge`` into the first one, in place.

    Residues of the 2nd..nth chains get the first chain's id in the first slot of
    their id tuple (presumably so they cannot collide with the first chain's own
    residue ids). They are then appended to the first chain, and their donor chains
    are detached from the model. The surviving chain is renamed to the concatenation
    of all merged ids (e.g. ``["A", "B"]`` -> ``"AB"``).

    Args:
        model: Biopython-style model indexable by chain id.
        chains_to_merge: ordered chain ids; the first one receives the others.

    Returns:
        The same (mutated) ``model`` object.
    """
    print(f"Merging chains {chains_to_merge} in model")
    target_id = chains_to_merge[0]
    for donor_id in chains_to_merge[1:]:
        donor = model[donor_id]
        # Snapshot the residues first: re-keying/adding them while iterating the
        # live chain would be unsafe.
        for residue in list(donor):
            residue.id = (target_id, residue.id[1], residue.id[2])
            model[target_id].add(residue)
        model.detach_child(donor_id)
    model[target_id].id = "".join(chains_to_merge)
    return model

def calculate_dockq(model, native, chain_map):
    """Score every native interface covered by ``chain_map`` with DockQ.

    Args:
        model: structure loaded with DockQ's ``load_PDB`` (predicted model).
        native: structure loaded with DockQ's ``load_PDB`` (reference).
        chain_map: ``{native_chain_id: model_chain_id}`` mapping.

    Returns:
        The two values produced by ``run_on_all_native_interfaces``: the
        per-interface results (indexed by interface key by the callers) and the
        aggregate score the original code named ``dockq_score``.
    """
    print(f"Calculating DockQ score with chain_map: {chain_map}")
    outcome = run_on_all_native_interfaces(model, native, chain_map=chain_map)
    per_interface, total_score = outcome
    return per_interface, total_score

def _first_interface_row(model_id, results):
    """Build (model_id, DockQ, fnat, iRMSD, LRMSD, F1) from the first interface in a DockQ results dict."""
    first = next(iter(results.values()))
    return (model_id, first['DockQ'], first['fnat'],
            first['iRMSD'], first['LRMSD'], first['F1'])


def process_models(models):
    """Compute DockQ metrics for each (model, native) PDB pair.

    Chains are paired positionally (native chain k <-> model chain k). So model and
    native must have the same number of chains, listed in the same order; mismatched
    pairs are skipped instead of crashing or being silently mismapped.

    * 2 chains: the single interface is scored directly.
    * 3 chains: the first two chains are merged (in both model and native) and the
      merged chain is scored against the third (presumably antibody H+L vs
      antigen -- confirm against the input naming).
    * any other chain count: skipped.

    Args:
        models: iterable of ``(model_pdb_path, native_pdb_path)`` pairs.

    Returns:
        list of ``(model_id, DockQ, fnat, iRMSD, LRMSD, F1)`` tuples, where
        ``model_id`` is the model's file name.
    """
    results_list = []
    for model_file, native_file in models:
        print(f"Processing model: {model_file}, native: {native_file}")
        model_id = os.path.basename(model_file)
        model = load_PDB(model_file)
        native = load_PDB(native_file)

        chain_ids = list(model.child_dict.keys())
        print(chain_ids)
        native_chain_ids = list(native.child_dict.keys())

        if len(native_chain_ids) != len(chain_ids):
            print(f"Model {model_id} has {len(chain_ids)} chains {chain_ids} but native has "
                  f"{len(native_chain_ids)} {native_chain_ids}, skipping.")
            continue

        if len(chain_ids) == 3:
            print(f"Model {model_id} has 3 chains: {chain_ids}")
            # Merge the first two chains and score the merged chain against the third.
            model = merge_chains(model, chain_ids[:2])
            native = merge_chains(native, native_chain_ids[:2])
            chain_map = {native_chain_ids[2]: chain_ids[2],
                         "".join(native_chain_ids[:2]): "".join(chain_ids[:2])}
        elif len(chain_ids) == 2:
            print(f"Model {model_id} has 2 chains: {chain_ids}")
            chain_map = {native_chain_ids[0]: chain_ids[0],
                         native_chain_ids[1]: chain_ids[1]}
        else:
            print(f"Model {model_id} does not have 2 or 3 chains: {chain_ids}, skipping.")
            continue

        results, _dockq_score = calculate_dockq(model, native, chain_map)
        if not results:
            print(f"DockQ returned no interfaces for {model_id}, skipping.")
            continue
        # With the chain maps above there is exactly one interface, so take the first.
        results_list.append(_first_interface_row(model_id, results))

    return results_list

def save_results_to_csv(results, filename):
    """Write DockQ rows to ``filename`` (overwriting it), echoing each row to stdout.

    Args:
        results: iterable of ``(model_id, DockQ, fnat, iRMSD, LRMSD, F1)`` tuples.
        filename: destination CSV path.
    """
    print(f"Saving results to CSV file: {filename}")
    header = ['model_id', 'DockQ', 'fnat', "iRMSD", "LRMSD", "F1"]
    with open(filename, mode='w', newline='') as handle:
        out = csv.writer(handle)
        out.writerow(header)
        for row in results:
            model_id, dockq, fnat, irms, lrms, f1 = row
            print(f"Writing row: {model_id}, {dockq}, {fnat}, {irms},{lrms},{f1}")
            out.writerow([model_id, dockq, fnat, irms, lrms, f1])

def main(directory, original_directory, method_suffix="_chai"):
    """Run DockQ for every model of one complex and write ``<complex>_dockq_fnat_scores.csv``.

    Args:
        directory: prediction folder of one complex (e.g. ``.../7ar0_B_A_chai``);
            its ``*_unrelaxed*.pdb`` files are the models.
        original_directory: folder holding native structures named ``<complex>.pdb``.
        method_suffix: suffix removed from the end of the folder name to obtain
            ``<complex>``. The default ``"_chai"`` matches the previous hard-coded behaviour.

    Side effects:
        Changes the working directory to the global ``result_path`` and writes the
        CSV there. Later cells write CSVs relative to the cwd and rely on this. It is
        only safe across repeated calls when ``directory`` and ``result_path`` are
        absolute paths.

    Returns:
        None. Returns early when no models or no native structure are found.
    """
    pdb_files = glob.glob(f"{directory}/*_unrelaxed*.pdb")
    if not pdb_files:
        print(f"No PDB files found in directory: {directory}")
        return

    print(f"Found PDB files: {pdb_files}")

    # The complex name and its native structure are the same for every model in the
    # folder, so resolve them once. normpath tolerates a trailing slash, and only a
    # trailing method suffix is removed (not every occurrence of it).
    folder_name = os.path.basename(os.path.normpath(directory))
    if method_suffix and folder_name.endswith(method_suffix):
        pdb_id_chains = folder_name[:-len(method_suffix)]
    else:
        pdb_id_chains = folder_name
    print(pdb_id_chains)
    print(f"Searching for original files for: {pdb_id_chains}")
    original_pdb_files = glob.glob(f"{original_directory}/{pdb_id_chains}.pdb")  # Find matching original PDB files
    print(f"Found original PDB files: {original_pdb_files}")
    if not original_pdb_files:
        print(f"No matching original PDB file found for {pdb_id_chains}, skipping.")
        print("No valid model-native pairs found. Exiting.")
        return

    models = []
    for pdb_file in pdb_files:
        for original_pdb_file in original_pdb_files:
            models.append((pdb_file, original_pdb_file))
            print(f"Adding model-native pair: {pdb_file}, {original_pdb_file}")

    os.chdir(result_path)
    results = process_models(models)
    # save_results_to_csv already echoes every row, so no second print loop is needed.
    save_results_to_csv(results, str(pdb_id_chains) + '_dockq_fnat_scores.csv')

In [None]:
#### Run DockQ/fnat for every complex, then combine the per-complex CSVs
import pandas as pd
import glob
import os

for i in complex_list:
    # Expand "~" in both paths, then score this complex's prediction folder with DockQ.
    original_directory = os.path.expanduser(original_directory)
    fpath = f"{folder_path}/{i}"
    directory = os.path.expanduser(fpath)
    main(directory, original_directory)

def combine_csv_files(result_path, output_file=None, pattern="*dockq_fnat_scores.csv"):
    """Concatenate the per-complex CSVs in ``result_path`` that match ``pattern``.

    Args:
        result_path: folder to search.
        output_file: optional path to save the combined table to. A file with the
            same name in ``result_path`` is excluded from the inputs: the combined
            file itself matches the glob, and re-running the cell would otherwise
            fold the previous combined table back in and duplicate every row.
        pattern: glob pattern (relative to ``result_path``) of the files to combine.

    Returns:
        The combined DataFrame (rows in sorted file-name order).

    Raises:
        ValueError: if no input CSV is found.
    """
    csv_files = sorted(glob.glob(os.path.join(result_path, pattern)))
    if output_file:
        out_name = os.path.basename(output_file)
        csv_files = [path for path in csv_files if os.path.basename(path) != out_name]
    if not csv_files:
        raise ValueError(f"No CSV files matching {pattern!r} found in {result_path}")

    combined_df = pd.concat([pd.read_csv(path) for path in csv_files], ignore_index=True)

    if output_file:
        combined_df.to_csv(output_file, index=False)
        print(f"Combined CSV saved to {output_file}")
    return combined_df
# Merge every per-complex DockQ CSV in result_path into one table and save it as
# dockq_output (path relative to the current working directory).
combined_df = combine_csv_files(result_path, dockq_output)

In [None]:
#pdockq2 funcs
from Bio.PDB import PDBIO
from Bio.PDB.PDBParser import PDBParser
from Bio.PDB.Selection import unfold_entities

import numpy as np
import sys,os
import argparse
import pickle
import itertools
import pandas as pd
from scipy.optimize import curve_fit

def retrieve_IFplddt(structure, chain1, chain2_lst, max_dist):
    """Average interface pLDDT of ``chain1`` and the chains it contacts.

    A contact is a residue pair (res1 in ``chain1``, res2 in one of ``chain2_lst``)
    whose CA-CA distance is <= ``max_dist``. When either residue lacks a CA, the
    CB-CB distance is used instead, if both have a CB. For every contact, res1's
    atom B-factor (treated as pLDDT) is collected once per contacting res2.

    Args:
        structure: Biopython-style structure; ``structure[0][chain_id]`` yields residues.
        chain1: id of the chain whose interface is measured.
        chain2_lst: iterable of partner chain ids (any iterable; it is materialised once).
        max_dist: contact cutoff in the structure's coordinate units.

    Returns:
        ``(mean of collected B-factors or 0 if there are none, sorted unique list of
        partner chain ids that have at least one contact)``.
    """
    model = structure[0]
    partner_ids = list(chain2_lst)  # re-iterated for every residue of chain1

    ifplddt = []
    contact_chain_lst = []
    for res1 in model[chain1]:
        for chain2 in partner_ids:
            count = 0
            for res2 in model[chain2]:
                # Prefer CA-CA; fall back to CB-CB only when a CA is missing on either side.
                if res1.has_id('CA') and res2.has_id('CA'):
                    atom = 'CA'
                elif res1.has_id('CB') and res2.has_id('CB'):
                    atom = 'CB'
                else:
                    continue
                if abs(res1[atom] - res2[atom]) <= max_dist:
                    ifplddt.append(res1[atom].get_bfactor())
                    count += 1
            if count > 0:
                contact_chain_lst.append(chain2)

    IF_plddt_avg = np.mean(ifplddt) if ifplddt else 0
    return IF_plddt_avg, sorted(set(contact_chain_lst))


def retrieve_IFPAEinter(structure, paeMat, contact_lst, max_dist):
    """Mean normalised interface PAE for every chain of a model (pDockQ2 term).

    For chain i and each chain j in ``contact_lst[i]``, every residue pair within
    ``max_dist`` contributes ``PAE[res_i, res_j]``. Distance is CA-CA, or CB-CB when
    a CA is missing on either side -- the same contact rule as retrieve_IFplddt.
    Pairs with no comparable atoms are skipped instead of raising KeyError. The
    collected values are mapped through ``1 / (1 + (pae / d)**2)`` with d = 10 and
    averaged.

    Args:
        structure: Biopython-style structure; ``structure[0]`` yields chains in file order.
        paeMat: PAE matrix (nested list or array). Its rows/columns are assumed to be
            the residues of all chains concatenated in that same order, since chain
            blocks are located by cumulative residue counts -- TODO confirm for
            models containing non-residue tokens.
        contact_lst: ``contact_lst[i]`` lists the chain ids that touch chain i. For
            example, a tetramer A,B,C,D with A/B A/C B/D C/D interfaces gives
            ``[['B','C'],['A','D'],['A','D'],['B','C']]``.
        max_dist: contact cutoff in the structure's coordinate units.

    Returns:
        list with one value per chain, in structure order (0 for chains without contacts).
    """
    d = 10  # normalisation distance (10 A)
    model = structure[0]
    chain_lst = [chain.id for chain in model]
    seqlen = [len(chain) for chain in model]
    pae = np.asarray(paeMat)  # convert once, not once per contact chain

    ifpae_avg = []
    for ch1_idx, ch1_id in enumerate(chain_lst):
        # Row range of chain 1 inside the full PAE matrix.
        ch1_sta = sum(seqlen[:ch1_idx])
        ch1_end = ch1_sta + seqlen[ch1_idx]
        ifpae_col = []
        for contact_ch in contact_lst[ch1_idx]:
            index = chain_lst.index(contact_ch)
            ch_sta = sum(seqlen[:index])
            ch_end = ch_sta + seqlen[index]
            remain_paeMatrix = pae[ch1_sta:ch1_end, ch_sta:ch_end]

            for mat_x, res1 in enumerate(model[ch1_id]):
                for mat_y, res2 in enumerate(model[contact_ch]):
                    if res1.has_id('CA') and res2.has_id('CA'):
                        dist = res1['CA'] - res2['CA']
                    elif res1.has_id('CB') and res2.has_id('CB'):
                        dist = res1['CB'] - res2['CB']
                    else:
                        continue  # no comparable atom pair
                    if dist <= max_dist:
                        ifpae_col.append(remain_paeMatrix[mat_x, mat_y])

        if not ifpae_col:
            ifpae_avg.append(0)
        else:
            norm_if_interpae = np.mean(1 / (1 + (np.array(ifpae_col) / d) ** 2))
            ifpae_avg.append(norm_if_interpae)
    return ifpae_avg

def calc_pmidockq(ifpae_norm, ifplddt):
    """Turn per-chain interface PAE and pLDDT into per-chain pDockQ2 scores.

    Args:
        ifpae_norm: per-chain mean normalised interface PAE.
        ifplddt: per-chain mean interface pLDDT (same length/order as ``ifpae_norm``).

    Returns:
        DataFrame with columns ``ifpae_norm``, ``ifplddt``, ``prot`` (their product)
        and ``pmidockq`` (sigmoid of ``prot``), one row per chain.
    """
    # Sigmoid parameters (L, x0, k, b) from the original pDockQ2 fit.
    fitpopt = [1.31034849e+00, 8.47326239e+01, 7.47157696e-02, 5.01886443e-03]

    df = pd.DataFrame()
    df['ifpae_norm'] = ifpae_norm
    df['ifplddt'] = ifplddt
    df['prot'] = df['ifpae_norm'] * df['ifplddt']
    df['pmidockq'] = sigmoid(df['prot'].values, *fitpopt)
    return df

def sigmoid(x, L ,x0, k, b):
    """Four-parameter logistic: ``L / (1 + exp(-k * (x - x0))) + b``."""
    return L / (1 + np.exp(-k * (x - x0))) + b

def process_pdb_file(pdb_file, json_file, distance, file_id, chains_part="", pae=None):
    """Compute interface pLDDT, normalised interface PAE and pDockQ2 for one model.

    Args:
        pdb_file: path of the model PDB; atom B-factors are read as pLDDT.
        json_file: path of the matching scores JSON. It is not read here -- the
            caller passes the already-loaded PAE via ``pae``. Kept for interface
            compatibility.
        distance: contact cutoff used for both the pLDDT and the PAE interface.
        file_id: stored verbatim under ``"pdb_id"`` in the result.
        chains_part: chain-id part of the name, used in ``"pdb_id_with_chains"``.
        pae: PAE matrix (nested list or array) covering all chains of the model.

    Returns:
        dict with the model path, ids, and per-model metrics. ``*_avg`` values are
        means over all chains. ``*_ag`` values are those of the LAST chain in the
        file (presumably the antigen -- confirm the chain order of the inputs).

    Raises:
        ValueError: if ``pae`` is None or the structure has no chains.
    """
    if pae is None:
        raise ValueError(f"No PAE matrix supplied for {pdb_file}")

    pdbp = PDBParser(QUIET=True)
    structure = pdbp.get_structure('', pdb_file)
    chains = [chain.id for chain in structure[0]]
    if not chains:
        raise ValueError(f"No chains found in {pdb_file}")

    remain_contact_lst = []
    plddt_lst = []
    for chain_id in chains:
        # All other chains, in file order. The former set(chains) - set(chain_id)
        # treated the id as a set of characters and gave a hash-seed-dependent order.
        partners = [other for other in chains if other != chain_id]
        IF_plddt, contact_lst = retrieve_IFplddt(structure, chain_id, partners, distance)
        plddt_lst.append(IF_plddt)
        remain_contact_lst.append(contact_lst)

    avgif_pae = retrieve_IFPAEinter(structure, pae, remain_contact_lst, distance)
    res = calc_pmidockq(avgif_pae, plddt_lst)

    pdb_id = os.path.basename(pdb_file).split('_')[0]

    return {
        "model_id": pdb_file,
        "pdb_id": file_id,
        "pdb_id_with_chains": '{0}_{1}'.format(pdb_id, chains_part),
        "ipae_norm_ag": res['ifpae_norm'].tolist()[-1],
        "ipae_norm_avg": np.mean(res['ifpae_norm']),
        "iplddt_ag": res['ifplddt'].tolist()[-1],
        "iplddt_avg": np.mean(res['ifplddt']),
        "pDockQ2_ag": res['pmidockq'].tolist()[-1],
        "pDockQ2_avg": np.mean(res['pmidockq'])}

In [None]:
import os
import glob
import json
import pandas as pd

def find_matching_json(pdb_file, directory):
    """Return the ``scores.model_idx_<N>.json`` in ``directory`` matching ``pdb_file``.

    ``N`` is taken from the ``model_idx_<N>`` token in the PDB file name. The
    function no longer needs the name to contain a ``chai`` token; the previous
    unused chain parsing raised ValueError without one.

    Args:
        pdb_file: model PDB path (only the base name is inspected).
        directory: folder to look for the JSON in.

    Returns:
        Path of the JSON file, or None if the index is missing or no file matches.
    """
    import re

    pdb_file_basename = os.path.basename(pdb_file)
    match = re.search(r'model_idx_(\d+)', pdb_file_basename)
    if not match:
        print(f"Could not extract model index from: {pdb_file_basename}")
        return None

    model_idx = match.group(1)
    print(f"Extracted model index: {model_idx}")

    json_pattern = f"scores.model_idx_{model_idx}.json"
    print(f"Looking for JSON file with pattern: {json_pattern}")
    # Escape the directory so characters like "[" in a path are not treated as wildcards.
    json_files = glob.glob(os.path.join(glob.escape(directory), json_pattern))
    print(f"Found JSON files: {json_files}")

    return json_files[0] if json_files else None
    
def run_processing(directory, method="chai", distance=8):
    """Compute pDockQ2 for every renamed model PDB in one prediction folder.

    For each ``*_unrelaxed*.pdb`` in ``directory``, the matching
    ``scores.model_idx_<N>.json`` is located and its ``pae`` matrix loaded, then
    ``process_pdb_file`` is run. All rows are written to
    ``<pdb_id>_<chains>_pdockq2_fit.csv`` in the current working directory.

    Args:
        directory: prediction folder of one complex.
        method: name token that follows the chain ids in a model file name
            (``7ar0_B_A_chai_unrelaxed...`` -> chains ``B_A``). Files without it
            are skipped instead of raising ValueError.
        distance: contact cutoff passed to ``process_pdb_file`` (previously a hard-coded 8).

    Returns:
        None.
    """
    pdb_files = sorted(glob.glob(os.path.join(directory, "*_unrelaxed*.pdb")))
    print(f"Found PDB files: {pdb_files}")

    results = []
    file_id = None
    for pdb_file in pdb_files:
        parts = os.path.basename(pdb_file).split('_')
        if method not in parts:
            print(f"Skipping {pdb_file}: no '{method}' token to delimit the chain ids.")
            continue
        pdb_id = parts[0]
        chains_part = "_".join(parts[1:parts.index(method)])

        json_file_name = find_matching_json(pdb_file, directory)
        if not json_file_name or not os.path.exists(json_file_name):
            print(f"JSON file for {pdb_file} not found.")
            continue

        with open(json_file_name, 'r') as json_file:
            pae = json.load(json_file).get('pae', None)
        if pae is None:
            print(f"PAE not found in JSON file for {pdb_file}.")
            continue

        # Per-complex CSV name; the same for every model of one complex.
        file_id = f"{pdb_id}_{chains_part}"
        result = process_pdb_file(pdb_file, json_file_name, distance, pdb_id, chains_part, pae)
        print(f"Result for {pdb_file}: {result}")
        results.append(result)

    if not results:
        print("No results to save.")
        return

    # NOTE: written relative to the cwd. The combine step looks for these files in
    # result_path, which relies on the DockQ cell having chdir'd into result_path.
    csv_file_path = f"{file_id}_pdockq2_fit.csv"
    pd.DataFrame(results).to_csv(csv_file_path, index=False)
    print(f"Data saved to {csv_file_path}")

### run for all predictions: compute pDockQ2 for every complex folder
for i in complex_list:
    directory = "{}/{}".format(folder_path, i)
    run_processing(directory)

def combine_csv_files(result_path, output_file=None, pattern="*_pdockq2_fit.csv"):
    """Concatenate the per-complex pDockQ2 CSVs in ``result_path`` matching ``pattern``.

    Args:
        result_path: folder to search.
        output_file: optional path to save the combined table to. A file with the
            same name in ``result_path`` is excluded from the inputs: the combined
            file itself matches the glob (e.g. ``chai_pdockq2_fit.csv``), and
            re-running the cell would otherwise duplicate every row.
        pattern: glob pattern (relative to ``result_path``) of the files to combine.

    Returns:
        The combined DataFrame (rows in sorted file-name order).

    Raises:
        ValueError: if no input CSV is found.
    """
    csv_files = sorted(glob.glob(os.path.join(result_path, pattern)))
    if output_file:
        out_name = os.path.basename(output_file)
        csv_files = [path for path in csv_files if os.path.basename(path) != out_name]
    if not csv_files:
        raise ValueError(f"No CSV files matching {pattern!r} found in {result_path}")

    combined_df = pd.concat([pd.read_csv(path) for path in csv_files], ignore_index=True)
    if output_file:
        combined_df.to_csv(output_file, index=False)
        print(f"Combined CSV saved to {output_file}")
    return combined_df
    
# Merge every per-complex pDockQ2 CSV in result_path into one table and save it as
# pdockq_output (path relative to the current working directory).
combined_df = combine_csv_files(result_path, pdockq_output)

In [None]:
import os
import pandas as pd

x = []  # per-complex Chai metrics tables

for i in complex_list:
    # Read metrics.csv by path. The former os.chdir() into the folder was unnecessary
    # and made this relative read fail whenever folder_path was relative.
    df = pd.read_csv(f"{folder_path}/{i}/metrics.csv")
    # "scores.model_idx_0.json" -> "<complex>_model_0", matching the ids built below.
    # This uses the loop variable `i`. The previous code used `f`, a stale variable
    # left over from the CIF-conversion loop, so every complex got the last folder's
    # name and only that complex's rows survived the merges.
    df['model_id'] = (
        df['filename']
        .str.replace("scores.", f"{i}_", regex=False)
        .str.replace("_idx", "", regex=False)
        .str.replace(".json", "", regex=False)
    )
    df = df[["model_id", 'aggregate_score', 'ptm', 'iptm']]
    x.append(df)

dfy = pd.concat(x, axis=0)

# pDockQ2 table: model_id holds the model's path -> reduce to "<complex>_model_<N>".
# regex=False everywhere: "." must be literal (older pandas defaulted to regex=True).
df = pd.read_csv(f"{result_path}/{pdockq_output}")
df['model_id'] = (
    df['model_id']
    .str.replace(".pdb", "", regex=False)
    .str.replace("unrelaxedpred.", "", regex=False)
    .str.replace("_idx", "", regex=False)
)
df['model_id'] = df['model_id'].str.split("/").str[-1]

# DockQ table: model_id holds the model's file name -> same normalisation.
df2 = pd.read_csv(f"{result_path}/{dockq_output}")
df2['model_id'] = (
    df2['model_id']
    .str.replace(".pdb", "", regex=False)
    .str.replace("unrelaxedpred.", "", regex=False)
    .str.replace("_idx", "", regex=False)
)
print(df2.head())
dfx = pd.merge(df, df2, on="model_id")

df_fin = pd.merge(dfx, dfy, on="model_id")
df_fin = df_fin.drop_duplicates()

# AntiConf = 0.3 * pDockQ2_ag + 0.7 * pTM
df_fin["AntiConf"] = (0.3 * df_fin["pDockQ2_ag"]) + (0.7 * df_fin["ptm"])
df_fin.to_csv(f"{result_path}/{combined}", index=False)