In [None]:
import pandas as pd
import os

pred_method = "protenix"  ## folder containing predictions
# Store absolute paths: the DockQ step later calls os.chdir(result_path), after
# which relative paths such as "protenix" would resolve inside the results folder.
folder_path = os.path.abspath(os.path.expanduser(pred_method))
## Expected file hierarchy (every later cell hard-codes the `seed_1` folder name):
# protenix
# ├── 7ar0_B_A_px
# │   └── 7ar0_B_A_px
# │       └── seed_1
# │           └── predictions
# │               ├── *.cif                                 (models; the conversion cell adds a .pdb next to each)
# │               ├── *_summary_confidence_sample_<n>.json
# │               └── *full_data_sample_<n>.json
# ├── 7bnv_H_L_A_px
# │   └── ...                                               (same layout)

complex_list = [
    f for f in os.listdir(folder_path)
    if os.path.isdir(os.path.join(folder_path, f)) and not f.startswith(".")
]
print(complex_list)  # should print prediction folders, e.g. ['7ar0_B_A_px', '7bnv_H_L_A_px', ...]

original_directory = os.path.abspath(os.path.expanduser("path/to/native_PDBs"))  ## Native PDB directory

## Results folder (created if missing): destination for the CSV outputs below.
result_path = os.path.abspath(os.path.expanduser("path/to/results_folder"))
os.makedirs(result_path, exist_ok=True)
dockq_output = f"{pred_method}_dockq_fnat_scores.csv"  ## dockq output
pdockq_output = f"{pred_method}_pdockq2_fit.csv"  ## pdockq2 output
combined = f"{pred_method}_combined.csv"  ## combined DockQ + pDockQ2 + model confidence scores

In [None]:
### CIF -> PDB conversion: write a .pdb next to every .cif prediction
import pandas as pd
from Bio import PDB
import os

# Custom MMCIFParser that fills in missing B-factor and occupancy columns (chain renaming is done separately by rename_chains)
class CustomMMCIFParser(PDB.MMCIFParser):
    """
    mmCIF parser that tolerates files without B-factor or occupancy columns.

    Before delegating to the base class, any missing
    ``_atom_site.B_iso_or_equiv`` or ``_atom_site.occupancy`` column is added
    to the parsed mmCIF dictionary, filled with the placeholder string
    ``"0.0"`` once per ``_atom_site.group_PDB`` row (i.e. once per atom record).
    """

    def _build_structure(self, structure_id):
        cif = self._mmcif_dict
        for column in ("_atom_site.B_iso_or_equiv", "_atom_site.occupancy"):
            if column not in cif:
                # Presumably the base builder reads these columns unconditionally,
                # which is why they must exist before delegating.
                cif[column] = ["0.0"] * len(cif["_atom_site.group_PDB"])
        super()._build_structure(structure_id)

# Function to rename chain IDs
def rename_chains(structure):
    """
    Strip every ``'0'`` character from each chain ID, in place.

    Intended for IDs such as ``A0`` -> ``A``; note that *all* zeros are
    removed (``A10`` -> ``A1``). Every chain of every model is visited.

    Returns the same ``structure`` object, for call chaining.
    """
    every_chain = (chain for model in structure for chain in model)
    for chain in every_chain:
        chain.id = chain.id.replace('0', '')
    return structure

# Build the parser and writer once and reuse them for every file of every complex.
parser = CustomMMCIFParser(QUIET=True)  # reads .cif files
writer = PDB.PDBIO()                    # writes .pdb files

for f in complex_list:
    wdir = os.path.join(folder_path, f, f, "seed_1", "predictions")
    if not os.path.isdir(wdir):
        # Report and move on rather than aborting the whole batch.
        print(f"Skipping {f}: prediction folder not found at {wdir}")
        continue

    # Convert every .cif in the folder to a .pdb with the same stem, in the same folder.
    for file_name in sorted(os.listdir(wdir)):
        stem, ext = os.path.splitext(file_name)
        if ext != ".cif":
            continue
        input_path = os.path.join(wdir, file_name)
        output_path = os.path.join(wdir, stem + ".pdb")

        try:
            structure = rename_chains(parser.get_structure(file_name, input_path))
            writer.set_structure(structure)
            writer.save(output_path)
            print(f"Converted {file_name} to {output_path}")
        except Exception as e:
            # Deliberately best-effort: one unreadable model must not stop the batch.
            print(f"Error processing {file_name}: {e}")

    print(f"Conversion of {f} is complete.")

In [None]:
import os
import sys
import csv
import glob

from DockQ.DockQ import load_PDB, run_on_all_native_interfaces
from statistics import mean

def merge_chains(model, chains_to_merge):
    """
    Merge several chains of a model into the first one, in place.

    Residues of every chain after the first are moved into the first chain,
    the emptied chains are detached, and the surviving chain is renamed to the
    concatenation of all merged IDs (e.g. ``["H", "L"]`` -> ``"HL"``).

    Parameters
    ----------
    model : Bio.PDB model-level entity
        Object whose children are chains (as used by ``process_models``, which
        reads chain IDs from ``model.child_dict``). Modified in place.
    chains_to_merge : list of str
        Chain IDs to merge; the first receives the residues of the others.

    Returns
    -------
    model
        The same (mutated) object.
    """
    print(f"Merging chains {chains_to_merge} in model")
    target = model[chains_to_merge[0]]
    for source_id in chains_to_merge[1:]:
        for res in list(model[source_id]):
            # Put the *source* chain ID in the hetero-flag slot of the residue ID so
            # residues from different merged chains never collide on
            # (hetflag, resseq, icode). Using the target ID for every source chain
            # made merges of three or more chains fail with a duplicate-ID error.
            res.id = (source_id, res.id[1], res.id[2])
            target.add(res)
        model.detach_child(source_id)
    target.id = "".join(chains_to_merge)
    return model

def calculate_dockq(model, native, chain_map):
    """
    Run DockQ on every native interface covered by ``chain_map``.

    Thin wrapper around DockQ's ``run_on_all_native_interfaces`` that logs the
    chain mapping before delegating.

    Parameters
    ----------
    model : Bio.PDB.Structure
        Predicted structure.
    native : Bio.PDB.Structure
        Reference structure.
    chain_map : dict
        Native chain ID -> model chain ID.

    Returns
    -------
    tuple
        ``(results, dockq_score)`` as returned by ``run_on_all_native_interfaces``;
        ``results`` maps each interface to a dict of metrics (``DockQ``,
        ``fnat``, ``iRMSD``, ``LRMSD``, ``F1`` are read by ``process_models``).
    """
    print(f"Calculating DockQ score with chain_map: {chain_map}")
    per_interface, total_dockq = run_on_all_native_interfaces(
        model, native, chain_map=chain_map
    )
    return per_interface, total_dockq

def _first_interface_metrics(results):
    """Return ``(DockQ, fnat, iRMSD, LRMSD, F1)`` of the first interface in a DockQ
    result dict, or ``None`` when the dict is empty."""
    if not results:
        return None
    first = next(iter(results.values()))
    return (first['DockQ'], first['fnat'], first['iRMSD'], first['LRMSD'], first['F1'])


def process_models(models):
    """
    Score each (model, native) pair with DockQ.

    Two-chain models are scored directly, pairing native and model chains by
    their order in the files. For three-chain models the first two chains are
    merged on both sides (presumably the antibody heavy/light pair, given
    folder names like ``7bnv_H_L_A_px`` -- confirm) and the merged chain is
    scored against the third. These cases are skipped with a message:
    models with any other chain count, natives with fewer chains than the
    model, and pairs for which DockQ reports no interface.

    Parameters
    ----------
    models : list of tuple
        ``(model_file, native_file)`` path pairs.

    Returns
    -------
    results_list : list of tuple
        ``(model_id, DockQ, fnat, iRMSD, LRMSD, F1)`` per scored model, taken
        from the first interface DockQ reports; ``model_id`` is the model file
        name up to its first ``.``.
    """
    results_list = []
    for model_file, native_file in models:
        print(f"Processing model: {model_file}, native: {native_file}")
        model_id = os.path.basename(model_file).split(".")[0]
        model = load_PDB(model_file)
        native = load_PDB(native_file)

        chain_ids = list(model.child_dict.keys())
        print(chain_ids)
        native_chain_ids = list(native.child_dict.keys())

        if len(chain_ids) not in (2, 3):
            print(f"Model {model_id} does not have 2 or 3 chains: {chain_ids}, skipping.")
            continue
        if len(native_chain_ids) < len(chain_ids):
            # The chain maps below index the native chains by position.
            print(f"Native {native_file} has fewer chains ({native_chain_ids}) "
                  f"than model {model_id} ({chain_ids}), skipping.")
            continue

        if len(chain_ids) == 3:
            print(f"Model {model_id} has 3 chains: {chain_ids}")
            # Merge the first two chains on both sides and score them as one partner.
            model_merged = merge_chains(model, chain_ids[:2])
            print(model_merged)
            native_merged = merge_chains(native, native_chain_ids[:2])
            chain_map = {
                native_chain_ids[2]: chain_ids[2],
                "".join(native_chain_ids[:2]): "".join(chain_ids[:2]),
            }
            results, _ = calculate_dockq(model_merged, native_merged, chain_map)
        else:
            print(f"Model {model_id} has 2 chains: {chain_ids}")
            chain_map = {native_chain_ids[0]: chain_ids[0], native_chain_ids[1]: chain_ids[1]}
            results, _ = calculate_dockq(model, native, chain_map)

        metrics = _first_interface_metrics(results)
        if metrics is None:
            print(f"DockQ reported no interface for {model_id}, skipping.")
            continue
        results_list.append((model_id, *metrics))

    return results_list

def save_results_to_csv(results, filename):
    """
    Write DockQ result tuples to ``filename`` as CSV, overwriting the file.

    Parameters
    ----------
    results : iterable of tuple
        ``(model_id, DockQ, fnat, iRMSD, LRMSD, F1)`` tuples, as produced by
        ``process_models``.
    filename : str
        Destination path.
    """
    print(f"Saving results to CSV file: {filename}")
    header = ['model_id', 'DockQ', 'fnat', "iRMSD", "LRMSD", "F1"]
    with open(filename, mode='w', newline='') as handle:
        out = csv.writer(handle)
        out.writerow(header)
        for row in results:
            model_id, dockq, fnat, irms, lrms, f1 = row
            print(f"Writing row: {model_id}, {dockq}, {fnat}, {irms},{lrms},{f1}")
            out.writerow([model_id, dockq, fnat, irms, lrms, f1])

def main(directory, original_directory):
    """
    Score every predicted model in ``directory`` against its native structure with DockQ.

    Model files are the ``*_seed_1*.pdb`` files in ``directory``. The native for
    a model is ``<original_directory>/<prefix>.pdb``, where ``<prefix>`` is the
    model file name up to its first ``_px``
    (``7ar0_B_A_px_seed_1_sample_0.pdb`` -> ``7ar0_B_A.pdb``). Models without a
    native are skipped.

    Side effects: changes the working directory to the module-level
    ``result_path`` and writes ``<prefix>_dockq_fnat_scores.csv`` there. Later
    cells write relative CSV paths and look for them under ``result_path``, so
    they rely on this working-directory change.

    Parameters
    ----------
    directory : str
        Directory containing the predicted PDB files to be analyzed.
    original_directory : str
        Directory containing the native PDB files for comparison.

    Returns
    -------
    None
    """
    # Resolve inputs before the os.chdir below; otherwise relative model/native
    # paths collected here would no longer resolve when process_models loads them.
    directory = os.path.abspath(directory)
    original_directory = os.path.abspath(original_directory)

    pdb_files = glob.glob(f"{directory}/*_seed_1*.pdb")
    if not pdb_files:
        print(f"No PDB files found in directory: {directory}")
        return

    print(f"Found PDB files: {pdb_files}")
    models = []

    for pdb_file in pdb_files:
        # Native file name = model file name up to the first "_px".
        pdb_id_chains = os.path.basename(pdb_file).partition("_px")[0]
        print(f"Searching for original files for: {pdb_id_chains}")
        original_pdb_files = glob.glob(f"{original_directory}/{pdb_id_chains}.pdb")
        print(f"Found original PDB files: {original_pdb_files} for {pdb_file}")
        if not original_pdb_files:
            print(f"No matching original PDB file found for {pdb_file}, skipping.")
            continue
        for original_pdb_file in original_pdb_files:
            models.append((pdb_file, original_pdb_file))
            print(f"Adding model-native pair: {pdb_file}, {original_pdb_file}")

    if not models:
        print("No valid model-native pairs found. Exiting.")
        return

    os.chdir(result_path)
    results = process_models(models)
    # Presumably all models in one prediction folder belong to the same complex,
    # so the last prefix names the output file.
    save_results_to_csv(results, str(pdb_id_chains) + '_dockq_fnat_scores.csv')

In [None]:
#### Run DockQ/fnat scoring for every complex, then combine the per-complex CSVs
import pandas as pd
import glob
import os

def combine_csv_files(result_path, output_file=None):
    """
    Concatenate the per-complex ``*dockq_fnat_scores.csv`` files in ``result_path``.

    ``output_file`` itself matches that pattern, so it is excluded from the
    inputs. Otherwise re-running this cell would fold the previous combined
    table back in and duplicate its rows.

    Parameters
    ----------
    result_path : str
        Folder searched for ``*dockq_fnat_scores.csv`` files.
    output_file : str, optional
        If given, the combined table is also written there (a relative path is
        resolved against the current working directory).

    Returns
    -------
    pandas.DataFrame
        The combined rows; empty (and nothing written) when no input file exists.
    """
    csv_files = sorted(glob.glob(os.path.join(result_path, "*dockq_fnat_scores.csv")))
    if output_file:
        output_name = os.path.basename(output_file)
        csv_files = [p for p in csv_files if os.path.basename(p) != output_name]
    if not csv_files:
        print(f"No *dockq_fnat_scores.csv files found in {result_path}; nothing combined.")
        return pd.DataFrame()

    combined_df = pd.concat([pd.read_csv(p) for p in csv_files], ignore_index=True)
    if output_file:
        combined_df.to_csv(output_file, index=False)
        print(f"Combined CSV saved to {output_file}")
    return combined_df


# main() chdirs into result_path, so resolve the input locations once, up front;
# relative paths would otherwise break for every complex after the first.
pred_root = os.path.abspath(os.path.expanduser(folder_path))
original_directory = os.path.abspath(os.path.expanduser(original_directory))
for i in complex_list:
    directory = os.path.join(pred_root, i, i, "seed_1", "predictions")
    main(directory, original_directory)

combined_df = combine_csv_files(result_path, dockq_output)

In [None]:
### pLDDT transfer: write per-atom pLDDT from the full_data JSON into the PDB B-factor column
import os
import glob
import json
import re

_TRAILING_NUMBER = re.compile(r'(\d+)$')


def extract_sample_number(filename):
    """
    Return the run of digits at the very end of ``filename``, as a string.

    ``"..._sample_12"`` -> ``"12"``. Returns ``None`` when the name does not
    end in a digit.
    """
    found = _TRAILING_NUMBER.search(filename)
    if found is None:
        return None
    return found.group(1)

def update_pdb_with_plddt(folder_path):
    """
    Overwrite the B-factor column of every PDB in ``folder_path`` with per-atom pLDDT.

    For each ``*.pdb`` file:

    - The trailing number of its file stem (``extract_sample_number``) selects
      the matching ``*full_data_sample_<n>.json`` in the same folder.
    - That JSON's ``atom_plddt`` list is written, value by value, into columns
      61-66 of the ATOM/HETATM records.
    - The PDB is rewritten in place.

    Files without a sample number, a matching JSON, or any ``atom_plddt``
    values are skipped with a message.

    NOTE(review): assumes the JSON lists atoms in the same order as the PDB's
    ATOM/HETATM records -- TODO confirm for the Protenix output. When the counts
    differ a warning is printed: surplus records keep their old B-factor and
    surplus pLDDT values are ignored.
    NOTE(review): values are written as-is (``%6.2f``). The pDockQ2 step
    multiplies B-factors by 100, which suggests ``atom_plddt`` is on a 0-1
    scale -- confirm.

    Parameters
    ----------
    folder_path : str
        Directory holding the converted ``.pdb`` models and their JSON files.
    """
    pdb_files = glob.glob(os.path.join(folder_path, "*.pdb"))

    for pdb_file in pdb_files:
        pdb_basename = os.path.basename(pdb_file).replace(".pdb", "")
        sample_number = extract_sample_number(pdb_basename)

        if not sample_number:
            print(f"⚠No sample number found in {pdb_basename}, skipping...")
            continue

        # Locate the corresponding JSON file (sorted so the choice is deterministic).
        json_pattern = os.path.join(folder_path, f"*full_data_sample_{sample_number}.json")
        json_files = sorted(glob.glob(json_pattern))
        if not json_files:
            print(f"No matching JSON found for {pdb_basename}")
            continue
        json_file = json_files[0]

        with open(json_file, "r") as f:
            data = json.load(f)
        plddt_values = data.get("atom_plddt", [])
        if not plddt_values:
            print(f"⚠No atom_plddt found in {json_file}, skipping...")
            continue

        # Replace the B-factor field (columns 61-66, i.e. line[60:66]) of each
        # ATOM/HETATM record with the next pLDDT value.
        pdb_lines = []
        plddt_index = 0
        n_atom_records = 0
        with open(pdb_file, "r") as f:
            for line in f:
                if line.startswith(("ATOM", "HETATM")):
                    n_atom_records += 1
                    if plddt_index < len(plddt_values):
                        line = f"{line[:60]}{plddt_values[plddt_index]:6.2f}{line[66:]}"
                        plddt_index += 1
                pdb_lines.append(line)

        if n_atom_records != len(plddt_values):
            # A mismatch means the per-atom mapping is unreliable; make it visible.
            print(f"⚠Atom count mismatch for {pdb_basename}.pdb: {n_atom_records} ATOM/HETATM "
                  f"records vs {len(plddt_values)} atom_plddt values in {os.path.basename(json_file)}")

        with open(pdb_file, "w") as f:
            f.writelines(pdb_lines)

        print(f"Updated {pdb_basename}.pdb with pLDDT values from {os.path.basename(json_file)}.")

# Write pLDDT into the B-factor column of every model of every complex.
for complex_name in complex_list:
    prediction_dir = os.path.join(folder_path, complex_name, complex_name, "seed_1", "predictions")
    update_pdb_with_plddt(prediction_dir)

In [None]:
import os
import glob
import numpy as np
from Bio.PDB import PDBIO
from Bio.PDB.PDBParser import PDBParser
from Bio.PDB.Selection import unfold_entities

import numpy as np
import sys,os
import argparse
import pickle
import itertools
import pandas as pd
from scipy.optimize import curve_fit

def retrieve_IFplddt(structure, chain1, chain2_lst, max_dist):
    """
    Average interface pLDDT of ``chain1`` against the chains in ``chain2_lst``.

    A residue pair (r1 in ``chain1``, r2 in a partner chain) is a contact when
    their CA atoms are within ``max_dist``. If either residue lacks a CA, their
    CB atoms are compared instead, provided both have one. Each contact adds
    r1's atom B-factor x 100 to the average, so a residue touching several
    partner residues is counted several times. Only model 0 is used.

    Parameters
    ----------
    structure : Bio.PDB.Structure.Structure
    chain1 : str
        Chain whose interface residues are scored.
    chain2_lst : list of str
        Candidate partner chains.
    max_dist : float
        Contact cutoff (same units as the coordinates).

    Returns
    -------
    tuple
        ``(mean_plddt, partners)``. ``mean_plddt`` is ``0`` when no contact is
        found. ``partners`` is the sorted list of chains in ``chain2_lst``
        with at least one contact.
    """
    first_model = structure[0]
    interface_scores = []
    partners = set()

    for res1 in first_model[chain1]:
        for partner_id in chain2_lst:
            touched = False
            for res2 in first_model[partner_id]:
                if res1.has_id('CA') and res2.has_id('CA'):
                    atom_name = 'CA'
                elif res1.has_id('CB') and res2.has_id('CB'):
                    atom_name = 'CB'
                else:
                    continue
                if abs(res1[atom_name] - res2[atom_name]) <= max_dist:
                    interface_scores.append(res1[atom_name].get_bfactor() * 100)
                    touched = True
            if touched:
                partners.add(partner_id)

    mean_plddt = np.mean(interface_scores) if interface_scores else 0
    return mean_plddt, sorted(partners)


def retrieve_IFPAEinter(structure, paeMat, contact_lst, max_dist):
    """
    Normalised interface PAE for every chain of model 0 (a pDockQ2 ingredient).

    For chain i, every residue pair (r1 in chain i, r2 in a chain listed in
    ``contact_lst[i]``) whose CA atoms are within ``max_dist`` contributes its
    PAE entry. The collected values are mapped through
    ``1 / (1 + (pae / d0)**2)`` with ``d0 = 10`` and averaged.

    PAE rows/columns are located assuming one matrix token per residue, with
    chains laid out contiguously in model order. NOTE(review): holds for
    protein chains; confirm for non-protein entities. Every compared residue
    must have a CA atom.

    Parameters
    ----------
    structure : Bio.PDB.Structure.Structure
    paeMat : array-like, square
        Token-pair PAE matrix (e.g. ``token_pair_pae`` from the full_data JSON).
    contact_lst : list of list of str
        For each chain (model order), the chains it has an interface with.
        E.g. a tetramer A,B,C,D with A/B, A/C, B/D and C/D interfaces gives
        ``[['B','C'], ['A','D'], ['A','D'], ['B','C']]``.
    max_dist : float
        CA-CA contact cutoff.

    Returns
    -------
    list
        One value per chain, in model order; ``0`` for chains with no contacts.
    """
    d = 10  # PAE normalisation distance d0 (10 A)
    chains = list(structure[0])
    chain_lst = [chain.id for chain in chains]
    seqlen = [len(chain) for chain in chains]
    # Offset of each chain's block in the PAE matrix, computed once.
    starts = [sum(seqlen[:i]) for i in range(len(seqlen))]
    # Convert once; the matrix was previously re-copied for every contact chain.
    pae = np.asarray(paeMat)

    ifpae_avg = []
    for ch1_idx, ch1_id in enumerate(chain_lst):
        ch1_sta = starts[ch1_idx]
        ch1_end = ch1_sta + seqlen[ch1_idx]
        ifpae_col = []
        for contact_ch in contact_lst[ch1_idx]:
            partner_idx = chain_lst.index(contact_ch)
            ch_sta = starts[partner_idx]
            ch_end = ch_sta + seqlen[partner_idx]
            block = pae[ch1_sta:ch1_end, ch_sta:ch_end]
            for mat_x, res1 in enumerate(structure[0][ch1_id]):
                for mat_y, res2 in enumerate(structure[0][contact_ch]):
                    if res1['CA'] - res2['CA'] <= max_dist:
                        ifpae_col.append(block[mat_x, mat_y])
        if not ifpae_col:
            ifpae_avg.append(0)
        else:
            ifpae_avg.append(np.mean(1 / (1 + (np.array(ifpae_col) / d) ** 2)))
    return ifpae_avg

def calc_pmidockq(ifpae_norm, ifplddt):
    """
    Combine per-chain normalised interface PAE and interface pLDDT into pDockQ2.

    Parameters
    ----------
    ifpae_norm : sequence of float
        Normalised interface PAE per chain.
    ifplddt : sequence of float
        Interface pLDDT per chain (same length and order).

    Returns
    -------
    pandas.DataFrame
        One row per chain with columns ``ifpae_norm``, ``ifplddt``, ``prot``
        (their product) and ``pmidockq`` (``sigmoid`` of ``prot`` with the
        fixed fit parameters below).
    """
    # Sigmoid parameters (L, x0, k, b) from the original pDockQ2 fit function.
    fit_params = [1.31034849e+00, 8.47326239e+01, 7.47157696e-02, 5.01886443e-03]
    table = pd.DataFrame({'ifpae_norm': ifpae_norm, 'ifplddt': ifplddt})
    table['prot'] = table['ifpae_norm'] * table['ifplddt']
    table['pmidockq'] = sigmoid(table['prot'].values, *fit_params)
    return table

def sigmoid(x, L, x0, k, b):
    """
    Shifted logistic curve ``L / (1 + exp(-k * (x - x0))) + b``.

    Works on scalars and element-wise on NumPy arrays.
    """
    exponent = -k * (x - x0)
    return L / (1 + np.exp(exponent)) + b

def find_matching_json(pdb_file, directory):
    """
    Return the ``*full_data_sample_<n>.json`` in ``directory`` that matches ``pdb_file``.

    ``<n>`` is the number at the very end of the PDB file's stem
    (``7ar0_B_A_px_seed_1_sample_3.pdb`` -> ``3``).

    Returns
    -------
    str or None
        The first matching path (sorted), or ``None`` when the stem has no
        trailing number or no JSON matches.
    """
    import re

    stem = os.path.splitext(os.path.basename(pdb_file))[0]
    # Only the trailing number identifies the sample. Digits elsewhere in the
    # name (PDB ID, seed) must not leak into the index.
    match = re.search(r'(\d+)$', stem)
    if match is None:
        return None

    json_pattern = f"*full_data_sample_{match.group(1)}.json"
    json_files = sorted(glob.glob(os.path.join(glob.escape(directory), json_pattern)))
    return json_files[0] if json_files else None

def process_pdb_file(pdb_file, json_file, distance, file_id, chains_part=""):
    """
    Compute interface pLDDT, interface PAE and pDockQ2 summaries for one model.

    Parameters
    ----------
    pdb_file : str
        Model PDB. Its B-factors are expected to hold per-atom pLDDT.
    json_file : str
        ``full_data`` JSON providing the ``token_pair_pae`` matrix.
    distance : float
        Contact cutoff passed to the interface helpers.
    file_id : str
        Stored under ``"pdb_id"`` in the result.
    chains_part : str
        Chain label used to build ``"pdb_id_with_chains"``.

    Returns
    -------
    dict or None
        ``None`` (with a warning) when the JSON has no ``token_pair_pae``;
        otherwise one metrics row. ``*_ag`` entries are the last chain's
        values. That chain is presumably the antigen, given the
        ``<pdb>_<H>_<L>_<antigen>`` naming -- confirm. ``*_avg`` entries are
        means over all chains.
    """
    import json

    # json.load accepts both compact and pretty-printed JSON files.
    with open(json_file, "r") as fh:
        full_data = json.load(fh)
    if "token_pair_pae" not in full_data:
        print(f"Warning: 'token_pair_pae' not found in {json_file}")
        return None
    pae_matrix = full_data["token_pair_pae"]

    structure = PDBParser(QUIET=True).get_structure('', pdb_file)
    chains = [chain.id for chain in structure[0]]

    plddt_lst = []
    remain_contact_lst = []
    for chain_id in chains:
        # Partner chains in model order, so results are reproducible run to run.
        partners = [other for other in chains if other != chain_id]
        IF_plddt, contact_lst = retrieve_IFplddt(structure, chain_id, partners, distance)
        plddt_lst.append(IF_plddt)
        remain_contact_lst.append(contact_lst)

    avgif_pae = retrieve_IFPAEinter(structure, pae_matrix, remain_contact_lst, distance)
    res = calc_pmidockq(avgif_pae, plddt_lst)

    pdb_id = os.path.basename(pdb_file).split('_')[0]
    return {
        "model_id": pdb_file,
        "pdb_id": file_id,
        "pdb_id_with_chains": '{0}_{1}'.format(pdb_id, chains_part),
        "ipae_norm_ag": res['ifpae_norm'].tolist()[-1],
        "ipae_norm_avg": np.mean(res['ifpae_norm']),
        "iplddt_ag": res['ifplddt'].tolist()[-1],
        "iplddt_avg": np.mean(res['ifplddt']),
        "pDockQ2_ag": res['pmidockq'].tolist()[-1],
        "pDockQ2_avg": np.mean(res['pmidockq']),
    }

def run_processing(directory):
    """
    Compute pDockQ2 metrics for every ``*_px_seed_1*.pdb`` model in ``directory``.

    Each model is paired with its ``full_data`` JSON via ``find_matching_json``
    and scored by ``process_pdb_file`` with a contact cutoff of 8. All rows are
    written to ``<pdb_id>_<chains>_pdockq2_fit.csv``, a path relative to the
    current working directory. Nothing is written when no model could be
    scored.
    """
    results = []
    file_id = None

    for pdb_file in glob.glob(os.path.join(directory, "*_px_seed_1*.pdb")):
        json_file_name = find_matching_json(pdb_file, directory)
        if not (json_file_name and os.path.exists(json_file_name)):
            print(f"No matching JSON file found for {pdb_file}")
            continue

        # File names look like <pdb>_<chain>_..._px_seed_1_...; chains sit between pdb and "px".
        name_parts = os.path.basename(pdb_file).split('_')
        pdb_id = name_parts[0]
        chains_part = "_".join(name_parts[1:name_parts.index('px')])
        print(f"Processing: {pdb_id}_{chains_part}")
        file_id = f"{pdb_id}_{chains_part}"

        row = process_pdb_file(pdb_file, json_file_name, 8, pdb_id, chains_part)
        if row:
            results.append(row)

    if not results:
        print("No results to save.")
        return

    csv_file_path = f"{file_id}_pdockq2_fit.csv"
    pd.DataFrame(results).to_csv(csv_file_path, index=False)
    print(f"Data saved to {csv_file_path}")


In [None]:
# Set directory paths
import pandas as pd
import glob
import os
# Run pDockQ2 scoring (functions defined in the previous cell) for every complex
# Score every complex's models with pDockQ2 (one CSV per complex).
for complex_name in complex_list:
    prediction_dir = os.path.join(folder_path, complex_name, complex_name, "seed_1", "predictions")
    run_processing(prediction_dir)

def combine_csv_files(result_path, output_file=None):
    """
    Concatenate the per-complex ``*_pdockq2_fit.csv`` files in ``result_path``.

    ``output_file`` itself matches that pattern, so it is excluded from the
    inputs. Otherwise re-running this cell would fold the previous combined
    table back in and duplicate its rows.

    Parameters
    ----------
    result_path : str
        Folder searched for ``*_pdockq2_fit.csv`` files.
    output_file : str, optional
        If given, the combined table is also written there (a relative path is
        resolved against the current working directory).

    Returns
    -------
    pandas.DataFrame
        The combined rows; empty (and nothing written) when no input file exists.
    """
    csv_files = sorted(glob.glob(os.path.join(result_path, "*_pdockq2_fit.csv")))
    if output_file:
        output_name = os.path.basename(output_file)
        csv_files = [p for p in csv_files if os.path.basename(p) != output_name]
    if not csv_files:
        print(f"No *_pdockq2_fit.csv files found in {result_path}; nothing combined.")
        return pd.DataFrame()

    combined_df = pd.concat([pd.read_csv(p) for p in csv_files], ignore_index=True)
    if output_file:
        combined_df.to_csv(output_file, index=False)
        print(f"Combined CSV saved to {output_file}")
    return combined_df


combined_df = combine_csv_files(result_path, pdockq_output)

In [None]:
## Protenix confidence metrics, merged with the pDockQ2 and DockQ results
import os
import json
import pandas as pd

def extract_data_from_json(file_path):
    """
    Read a summary-confidence JSON file and pick out its headline scores.

    Returns a dict holding the file's basename under ``"filename"``, followed
    by the ``plddt``, ``gpde``, ``ptm`` and ``iptm`` entries (``None`` for any
    key the file lacks).
    """
    with open(file_path, 'r') as handle:
        payload = json.load(handle)
    record = {"filename": os.path.basename(file_path)}
    for key in ("plddt", "gpde", "ptm", "iptm"):
        record[key] = payload.get(key)
    return record

import glob

# Collect the summary-confidence scores of every sample present on disk,
# rather than assuming exactly 25 samples per complex.
summary_records = []
for i in complex_list:
    pred_dir = os.path.join(folder_path, i, i, "seed_1", "predictions")
    pattern = os.path.join(glob.escape(pred_dir),
                           f"{glob.escape(i)}_seed_1_summary_confidence_sample_*.json")
    summary_files = sorted(glob.glob(pattern))
    if not summary_files:
        print(f"No summary_confidence JSON files found for {i} in {pred_dir}")
    summary_records.extend(extract_data_from_json(p) for p in summary_files)

dfy = pd.DataFrame(summary_records, columns=["filename", "plddt", "gpde", "ptm", "iptm"])
# "<name>_seed_1_summary_confidence_sample_<n>.json" -> "<name>_seed_1_sample_<n>",
# the same stem as the model PDB. regex=False: "." must be literal on every pandas version.
dfy['model_id'] = (dfy['filename']
                   .str.replace('_summary_confidence', "", regex=False)
                   .str.replace('.json', "", regex=False))
dfy = dfy[['model_id', 'plddt', 'gpde', 'ptm', 'iptm']]

# The pDockQ2 table stores the full PDB path as model_id; reduce it to the file stem.
df = pd.read_csv(f"{result_path}/{pdockq_output}")
df['model_id'] = df['model_id'].str.replace(".pdb", "", regex=False).str.split("/").str[-1]

df2 = pd.read_csv(f"{result_path}/{dockq_output}")
df2['model_id'] = df2['model_id'].str.replace(".pdb", "", regex=False)

# Inner joins: a model missing from any of the three tables is dropped.
dfx = pd.merge(df, df2, on="model_id")
df_fin = pd.merge(dfx, dfy, on="model_id").drop_duplicates()

# AntiConf = 0.3 * pDockQ2_ag + 0.7 * pTM
df_fin["AntiConf"] = (0.3 * df_fin["pDockQ2_ag"]) + (0.7 * df_fin["ptm"])
df_fin.to_csv(f"{result_path}/{combined}", index=False)