# Syn-motifs Generation Pipeline

This notebook implements an open-source version of the Syn-motifs generation pipeline
used to design Synmask, an intrinsically disordered, hydrophilic and low-immunogenicity
polypeptide used to extend protein half-life without perturbing its intrinsic function.

The pipeline consists of 5 main steps:

1. Screening disordered fragments from natural proteins
2. Removing fragments containing unreasonable motifs
3. Removing fragments containing unreasonable amino acids
4. Removing fragments predicted to contain secondary structures
5. Amino acid modification and re-prediction

Each section of this notebook corresponds to one step in the Methods and can be run
independently, assuming the corresponding input files from the previous step are available.

# Step 1 Screening disordered fragments from natural proteins
## Step 1.1 Find all disordered fragments

In [52]:
## function used in step1.1
import os
import json
import csv
from collections import defaultdict
from tqdm.auto import tqdm
import pathlib

def generate_dssp_shell(pdb_dir, dssp_out_dir, shell_path="run_dssp.sh"):
    """
    Generate a shell script that runs DSSP for each PDB/mmCIF file in `pdb_dir`.

    One line is written per ``*.pdb`` / ``*.cif`` file (sorted by file name),
    e.g.::

        dssp input.pdb output.dssp

    The output file keeps the input's stem with a ``.dssp`` extension. Both
    paths are shell-quoted so names containing spaces or shell metacharacters
    do not break the generated script.

    Parameters
    ----------
    pdb_dir : str
        Path to directory containing PDB/mmCIF files.
    dssp_out_dir : str
        Directory to store DSSP output files (created if it does not exist).
    shell_path : str
        Output shell script path.
    """
    import shlex

    os.makedirs(dssp_out_dir, exist_ok=True)

    with open(shell_path, "w") as f:
        for file in sorted(os.listdir(pdb_dir)):
            if not file.endswith((".pdb", ".cif")):
                continue
            in_path = os.path.join(pdb_dir, file)
            out_name = pathlib.Path(file).stem + ".dssp"
            out_path = os.path.join(dssp_out_dir, out_name)
            f.write(f"dssp {shlex.quote(in_path)} {shlex.quote(out_path)}\n")

    print(f"[OK] DSSP shell written to: {shell_path}")

def parse_dssp_file(dssp_path):
    """
    Parse a classic-format DSSP file and extract (chain, resnum, aa, ss) information.

    Only lines after the ``  #  RESIDUE ...`` header are read, and only those at
    least 120 characters long. Column layout (0-based indices):

    * 5:10  residue number (lines whose field is not an integer, e.g. chain
      breaks marked ``!``, are skipped; negative numbers are kept)
    * 11    chain id (blank -> ``"_"``)
    * 13    one-letter residue. DSSP writes disulfide-bonded cysteines as
      lowercase letters (a, b, ...); those are reported as ``"C"`` so that a
      later ``.upper()`` does not turn them into unrelated residues.
    * 16    secondary-structure code (blank -> ``"-"``)

    Returns
    -------
    dict: mapping
        { chain: [ (resnum, aa, ss), ... ] }
    """
    chain_map = defaultdict(list)
    header_passed = False

    with open(dssp_path) as f:
        for line in f:
            # The residue-table header is indented ("  #  RESIDUE AA ..."), hence strip().
            if line.strip().startswith("#"):
                header_passed = True
                continue

            if not header_passed:
                continue

            parts = line.rstrip("\n")
            if len(parts) < 120:
                continue

            try:
                resnum = int(parts[5:10].strip())
            except ValueError:
                continue  # chain-break ('!') or otherwise non-residue line

            chain = parts[11].strip() or "_"
            aa = parts[13]
            if aa.islower():
                aa = "C"  # SS-bridge cysteine
            ss = parts[16] if parts[16] != " " else "-"

            chain_map[chain].append((resnum, aa, ss))
    return chain_map

def extract_disordered_fragments(chain_res_list, min_len=20, disordered_states=("-", "T", "S")):
    """
    Extract continuous disordered fragments from one chain.

    A residue is considered disordered if its DSSP code is in
    `disordered_states` (default: '-' coil, 'T' turn, 'S' bend).

    A fragment ends when:
        - the residue number is not consecutive with the previous residue, or
        - a residue with a non-disordered code is met.

    Parameters
    ----------
    chain_res_list : list of (resnum, aa, ss)
        Residues of one chain in file order (as produced by `parse_dssp_file`).
    min_len : int
        Minimum fragment length to report.
    disordered_states : iterable of str
        DSSP codes counted as disordered.

    Returns
    -------
    list of dicts:
        { "start": int, "end": int, "sequence": str }
    """
    frags = []

    def _flush(start, end, seq):
        # Record the run that just ended if it is long enough.
        if seq and len(seq) >= min_len:
            frags.append({"start": start, "end": end, "sequence": "".join(seq)})

    cur_seq = []
    cur_start = None
    prev_resnum = None

    for resnum, aa, ss in chain_res_list:
        if ss in disordered_states:
            contiguous = (
                cur_start is not None
                and prev_resnum is not None
                and resnum == prev_resnum + 1
            )
            if not contiguous:
                # Numbering gap (or no open run): close the old run, open a new one.
                _flush(cur_start, prev_resnum, cur_seq)
                cur_start = resnum
                cur_seq = []
            cur_seq.append(aa)
        else:
            # Structured residue closes any open run.
            _flush(cur_start, prev_resnum, cur_seq)
            cur_seq = []
            cur_start = None

        prev_resnum = resnum

    _flush(cur_start, prev_resnum, cur_seq)
    return frags

def extract_all_disordered_fragments(dssp_dir, output_csv, min_len=20):
    """
    Process all ``*.dssp`` files in a directory and extract contiguous
    disordered fragments of at least `min_len` residues.

    Each file is parsed with `parse_dssp_file`, and every chain is scanned with
    `extract_disordered_fragments`. The file stem is used as ``pdb_id``.

    Output CSV columns:
        pdb_id, chain, start, end, sequence
    """
    files = sorted(f for f in os.listdir(dssp_dir) if f.endswith(".dssp"))

    # newline="" lets the csv module control line endings (the csv docs
    # require it; otherwise blank rows appear on some platforms).
    with open(output_csv, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["pdb_id", "chain", "start", "end", "sequence"])

        for file in tqdm(files):
            pdb_id = pathlib.Path(file).stem
            chain_map = parse_dssp_file(os.path.join(dssp_dir, file))

            for chain, residues in chain_map.items():
                for frag in extract_disordered_fragments(residues, min_len=min_len):
                    writer.writerow([
                        pdb_id,
                        chain,
                        frag["start"],
                        frag["end"],
                        frag["sequence"],
                    ])

    print(f"[OK] Step 1 result written to {output_csv}")

def deduplicate_fragments(input_csv, output_csv, min_len=20):
    """
    Filter the Step 1.1 fragment table and drop duplicate sequences.

    A row is kept only if its sequence has at least `min_len` residues,
    contains no ``X``, and has not been seen before (first occurrence wins).

    Input columns:  pdb_id, chain, start, end, sequence
    Output columns: PDB_id (``<pdb_id>_<chain>``), Sequence, Length
    """
    seen = set()
    out = []

    with open(input_csv, newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip header
        for row in reader:
            if not row:
                continue
            pdb_id, chain, _start, _end, seq = row
            seq = seq.strip()

            if len(seq) < min_len or "X" in seq or seq in seen:
                continue
            seen.add(seq)
            out.append([f"{pdb_id}_{chain}", seq, len(seq)])

    with open(output_csv, "w", newline="") as wf:
        writer = csv.writer(wf)
        writer.writerow(["PDB_id", "Sequence", "Length"])
        writer.writerows(out)

    print(f"[OK] Deduplicated fragments written to {output_csv}")

In [None]:
# Step 1.1a: write one `dssp` command per structure file into a shell script.
# NOTE(review): this cell reads '/mnt/data/public/PDB_split' while Step 1.2 reads
# '/mnt/data/public/split_PDB' — confirm whether these are meant to be the same directory.
pdb_dir = '/mnt/data/public/PDB_split'
dssp_output = './dssp_results/'  # also read by the next cell
bash_file = './run_dssp.sh'
generate_dssp_shell(pdb_dir, dssp_output, bash_file)

# Run the generated script in a shell before executing the next cell:
# bash ./run_dssp.sh

In [None]:
# Step 1.1b: collect disordered fragments (default min length 20) from the DSSP
# outputs, then drop short, X-containing and duplicate sequences.
output_file = './all_disordered.csv'
extract_all_disordered_fragments(dssp_output, output_file)

du_output_file = './step1_disordered.csv'
deduplicate_fragments(output_file, du_output_file)

## Step 1.2 Find all missing fragments

In [51]:
import os
import csv
from collections import defaultdict
from Bio.PDB import PDBParser
from Bio import pairwise2
from tqdm.auto import tqdm


# === 3-letter → 1-letter mapping ===
three_to_one = dict(
    ALA='A', ARG='R', ASN='N', ASP='D', CYS='C', GLU='E', GLN='Q',
    GLY='G', HIS='H', ILE='I', LEU='L', LYS='K', MET='M', PHE='F',
    PRO='P', SER='S', THR='T', TRP='W', TYR='Y', VAL='V',
    MSE='M',  # selenomethionine is reported as methionine
)

# Residue names accepted as amino acids when reading coordinate records.
VALID_AA_3 = set(three_to_one)


# ==========================================================
#                 Load SEQRES FASTA
# ==========================================================
def load_seqres(seqres_path):
    """
    Load a SEQRES FASTA such as ``pdb_seqres.txt``.

    The record name is the first whitespace-separated token of the header
    without the leading ``>`` (e.g. ``>3nch_C mol:protein ...`` -> ``3nch_C``).
    Sequence lines are concatenated. Blank lines and any lines before the first
    header are ignored (previously the latter raised a KeyError). A repeated
    header replaces the earlier record.

    Return dict: { '3nch_C': 'SEQUENCE', ... }
    """
    chunks = {}
    name = None
    with open(seqres_path) as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            if line.startswith(">"):
                name = line.split()[0][1:]
                chunks[name] = []
            elif name is not None:
                chunks[name].append(line)
    # Join once per record instead of repeated string concatenation.
    return {key: "".join(parts) for key, parts in chunks.items()}



# ==========================================================
#          Extract protein sequence from ATOM records
# ==========================================================
def get_observed_sequence(pdb_file):
    """
    Return the one-letter sequence of residues that have coordinates, **protein only**.

    Only the first chain of the first model is read (files here are presumably
    one chain each, named like ``3nch_C.pdb`` — confirm). A residue is kept
    when its name is in `VALID_AA_3` and it has a CA atom. Residues are in the
    order they appear in the chain.

    Biopython marks HETATM residues with a hetero flag such as ``'H_MSE'``.
    Selenomethionine (MSE) is normally written as HETATM, so it is accepted
    explicitly. Otherwise every MSE would look like a missing residue in the
    SEQRES alignment.
    """
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("pdb", pdb_file)
    seq = []

    for model in structure:
        for chain in model:
            for res in chain.get_residues():
                hetflag = res.id[0]
                resname = res.resname.upper()

                # standard polymer residues (ATOM records) plus MSE (HETATM);
                # DNA/RNA are excluded because they are not in VALID_AA_3
                is_polymer_residue = hetflag == " " or resname == "MSE"
                if is_polymer_residue and resname in VALID_AA_3 and "CA" in res.child_dict:
                    seq.append(three_to_one[resname])
            break  # first chain only
        break  # first model only
    return "".join(seq)



# ==========================================================
#            Align SEQRES ↔ ATOM sequence
# ==========================================================
def align_sequences(seq_full, seq_observed):
    """
    Globally align the SEQRES sequence with the observed (ATOM) sequence.

    Scoring: match +2, mismatch -1, gap open -10, gap extend -1; only the single
    best alignment is computed.

    Return: (aligned_seqres, aligned_observed) as gapped strings.
    """
    match_score, mismatch_score = 2, -1
    gap_open, gap_extend = -10, -1
    best = pairwise2.align.globalms(
        seq_full,
        seq_observed,
        match_score,
        mismatch_score,
        gap_open,
        gap_extend,
        one_alignment_only=True,
    )[0]
    aligned_seqres = best.seqA
    aligned_observed = best.seqB
    return aligned_seqres, aligned_observed



# ==========================================================
#            Identify missing internal segments
# ==========================================================
def find_missing_segments_from_alignment(aln_seqres, aln_observed, pdbid_chain):
    """
    Find runs of SEQRES residues that have no counterpart in the ATOM sequence.

    A missing residue is an alignment column where
        aln_seqres[i] != '-' AND aln_observed[i] == '-'
    and consecutive missing columns form one segment.

    Parameters
    ----------
    aln_seqres : str
        Aligned SEQRES sequence (gaps as '-').
    aln_observed : str
        Aligned ATOM sequence, same length as `aln_seqres`.
    pdbid_chain : str
        Identifier copied into every result.

    Returns
    -------
    list of dict with keys pdbid, seqres_start, seqres_end (1-based SEQRES
    positions, inclusive), pdb_start, pdb_end (always None), missing_seq,
    seq_len, and at_headtail ("yes" if the segment touches the first or last
    SEQRES residue, else "no").
    """
    seqres_len = len(aln_seqres.replace("-", ""))

    segments = []  # (start, end, sequence)
    seqres_pos = 0  # 1-based SEQRES index of the current column (after increment)
    cur_start = None
    cur_end = None
    cur_seq = []

    for aa_seqres, aa_obs in zip(aln_seqres, aln_observed):
        if aa_seqres != "-":
            seqres_pos += 1

        if aa_seqres != "-" and aa_obs == "-":
            if cur_start is None:
                cur_start = seqres_pos
            cur_end = seqres_pos
            cur_seq.append(aa_seqres)
        elif cur_seq:
            segments.append((cur_start, cur_end, "".join(cur_seq)))
            cur_start = None
            cur_end = None
            cur_seq = []

    if cur_seq:
        segments.append((cur_start, cur_end, "".join(cur_seq)))

    final = []
    for start, end, seq in segments:
        at_headtail = "yes" if (start == 1 or end == seqres_len) else "no"
        final.append({
            "pdbid": pdbid_chain,
            "seqres_start": start,
            "seqres_end": end,
            "pdb_start": None,
            "pdb_end": None,
            "missing_seq": seq,
            "seq_len": len(seq),
            "at_headtail": at_headtail,
        })
    return final



# ==========================================================
#                   Step 1.2 main function
# ==========================================================
def extract_missing_regions(pdb_dir, seqres_path, output_csv):
    """
    Step 1.2: list SEQRES residues that are absent from each PDB file's coordinates.

    For every ``<pdbid_chain>.pdb`` in `pdb_dir` whose stem is a key of the
    SEQRES FASTA, the full SEQRES sequence is globally aligned to the observed
    sequence. Each run of unobserved SEQRES residues becomes one CSV row. Files
    with no matching SEQRES entry, or with no observed protein residues, are
    skipped.

    Output CSV columns:
        pdbid, seqres_start, seqres_end, pdb_start, pdb_end,
        missing_seq, seq_len, at_headtail
    (pdb_start / pdb_end are always written empty.)
    """
    seqres_map = load_seqres(seqres_path)
    columns = [
        "pdbid", "seqres_start", "seqres_end",
        "pdb_start", "pdb_end",
        "missing_seq", "seq_len", "at_headtail",
    ]

    # newline="" as required by the csv module.
    with open(output_csv, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(columns)

        for file in tqdm(sorted(os.listdir(pdb_dir))):
            if not file.endswith(".pdb"):
                continue

            pdbid_chain = file[: -len(".pdb")]
            if pdbid_chain not in seqres_map:
                continue

            seq_full = seqres_map[pdbid_chain]
            seq_obs = get_observed_sequence(os.path.join(pdb_dir, file))
            if not seq_obs:
                continue  # no protein residues with coordinates

            aln_seqres, aln_obs = align_sequences(seq_full, seq_obs)
            for seg in find_missing_segments_from_alignment(aln_seqres, aln_obs, pdbid_chain):
                writer.writerow([seg[col] for col in columns])

    print(f"[OK] Missing regions written to {output_csv}")



# ==========================================================
#             Deduplication & fragment filtering
# ==========================================================
def deduplicate_missing(input_csv, output_csv, min_len=20):
    """
    Filter the Step 1.2 missing-segment table.

    A segment is kept only if it is internal (``at_headtail == "no"``), has at
    least `min_len` residues, contains no ``X``, and its sequence has not been
    seen before (first occurrence wins).

    Input columns are those written by `extract_missing_regions`.
    Output columns: PDB_id, Sequence, Length
    """
    seen = set()

    # Open the input first so a missing input does not leave an empty output behind.
    with open(input_csv, newline="") as f, open(output_csv, "w", newline="") as wf:
        reader = csv.DictReader(f)
        writer = csv.writer(wf)
        writer.writerow(["PDB_id", "Sequence", "Length"])

        for row in reader:
            pdbid = row["pdbid"]
            seq = row["missing_seq"]
            length = int(row["seq_len"])

            if row["at_headtail"] != "no":  # remove N/C-terminal missing regions
                continue
            if length < min_len or "X" in seq or seq in seen:
                continue

            seen.add(seq)
            writer.writerow([pdbid, seq, length])

    print(f"[OK] Deduplicated file written to {output_csv}")

In [None]:
# Step 1.2: find SEQRES residues without coordinates, then keep only internal
# (non-terminal), >= 20-residue, X-free, unique segments.
# NOTE(review): the directory name differs from Step 1.1 ('PDB_split') — confirm
# which per-chain PDB directory is intended.
pdb_dir = '/mnt/data/public/split_PDB'
seq_res_file = '/mnt/data/public/pdb_seqres.txt'
output_file = 'all_missing.csv'
extract_missing_regions(pdb_dir, seq_res_file, output_file)
du_output_file = "./step1_missing.csv"
deduplicate_missing(output_file,du_output_file)

# Step 2 Removing fragments containing unreasonable motifs


In [64]:
import csv
import re
from tqdm.auto import tqdm


# -------------------------
# 1. Homopolymer ≥ 4
# -------------------------
def has_homopolymer(seq, run_len=4):
    """
    Check (case-insensitively) for a run of one repeated residue.

    Returns True as soon as a residue has been repeated so that the current
    streak of identical letters reaches `run_len` (e.g. 'SSSS' for run_len=4);
    sequences shorter than `run_len` always return False.
    """
    text = seq.upper()
    if len(text) < run_len:
        return False

    streak = 1
    for prev, cur in zip(text, text[1:]):
        if cur != prev:
            streak = 1
            continue
        streak += 1
        if streak >= run_len:
            return True
    return False


# -------------------------
# 2. Linker motifs:
#    - explicit 'GGGGS'
#    - any 'GGGX' (GGG + any residue)
# -------------------------
def contains_linker_motif(seq):
    """
    Report (case-insensitively) whether `seq` looks like a synthetic Gly linker.

    Flags either the literal 'GGGGS' or 'GGG' followed by at least one more
    residue (so a trailing 'GGG' at the very end is not flagged). 'GGGGS' is
    already covered by the second pattern; both are listed to mirror the rule.
    """
    upper_seq = seq.upper()
    return "GGGGS" in upper_seq or re.search(r"GGG.", upper_seq) is not None


# -------------------------
# 3. Non-canonical amino acids
# -------------------------
def contains_noncanonical(seq):
    """
    True if any residue (case-insensitive) falls outside the 20 standard
    amino acids ACDEFGHIKLMNPQRSTVWY; 'X', 'U', 'B', etc. all count as
    non-canonical. An empty sequence returns False.
    """
    canonical = frozenset("ACDEFGHIKLMNPQRSTVWY")
    return not set(seq.upper()).issubset(canonical)



def contains_short_repeats(seq, min_len=3, max_len=6, min_repeats=3):
    """
    Detect short motifs that recur within a sequence (case-insensitive).

    For each motif width from `min_len` to `max_len`, every substring of that
    width is counted across the whole sequence, with overlaps allowed and
    occurrences not required to be adjacent. Returns True as soon as one motif
    reaches `min_repeats` occurrences.

    Example: 'GSSGSSGSS' is flagged because 'GSS' occurs 3 times.
    """
    text = seq.upper()
    length = len(text)
    if length < min_len:
        return False

    for width in range(min_len, max_len + 1):
        counts = {}
        for start in range(length - width + 1):
            motif = text[start:start + width]
            counts[motif] = counts.get(motif, 0) + 1
            if counts[motif] >= min_repeats:
                return True

    return False


# -------------------------
# Step 2 main function
# -------------------------

def is_valid_protein_sequence(seq):
    """
    True when every residue (case-insensitive) is one of the 20 canonical
    amino acids. An empty sequence counts as valid. Note that A, C, G and T are
    valid amino-acid letters, so this check does not detect DNA strings.
    """
    allowed = frozenset("ACDEFGHIKLMNPQRSTVWY")
    return set(seq.upper()) <= allowed


def step2_filter_unreasonable_motifs(input_csv, output_csv):
    """
    Step 2: drop fragments containing unreasonable motifs.

    A row of `input_csv` (columns PDB_id, Sequence, Length) is kept only if its
    stripped, upper-cased sequence:
      * contains only the 20 canonical amino acids (this also covers the
        "non-canonical amino acid" rule),
      * has no homopolymer run of 4 or more identical residues,
      * has no linker-like motif ('GGGGS' or 'GGG' followed by any residue),
      * has no 3–6-residue motif occurring 3 or more times anywhere in it.

    Kept rows are written unchanged to `output_csv`.
    """
    kept_rows = []

    with open(input_csv, newline="") as f:
        for row in csv.DictReader(f):
            seq = row["Sequence"].strip().upper()

            # Canonical-residue check (rejects X, U, B, Z, O, ...). It does not
            # catch nucleic acids: A/C/G/T are valid amino-acid letters.
            if not is_valid_protein_sequence(seq):
                continue

            # Rule 1: homopolymer runs
            if has_homopolymer(seq):
                continue

            # Rule 2: linker-like GGGX patterns
            if contains_linker_motif(seq):
                continue

            # Rule 3 (non-canonical residues) is already enforced above.

            # Rule 4: short repeat motifs (3–6 aa repeated >=3)
            if contains_short_repeats(seq):
                continue

            kept_rows.append(row)

    with open(output_csv, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["PDB_id", "Sequence", "Length"])
        writer.writeheader()
        writer.writerows(kept_rows)

    print(f"[OK] Step2 complete: kept {len(kept_rows)} → {output_csv}")

In [None]:
# Step 2 inputs: deduplicated fragments from Step 1.1 (disordered) and Step 1.2 (missing).
step1_missing_file = './step1_missing.csv'
step1_disordered_file = './step1_disordered.csv'

# Step 2 outputs, read by the Step 3 cell.
step2_missing_file = './step2_missing.csv'
step2_disordered_file = './step2_disordered.csv'

# Both branches go through the same motif filter.
step2_filter_unreasonable_motifs(input_csv=step1_disordered_file, output_csv=step2_disordered_file)
step2_filter_unreasonable_motifs(input_csv=step1_missing_file, output_csv=step2_missing_file)


# Step 3 Removing fragments containing unreasonable amino acids

In [67]:
import csv
from tqdm.auto import tqdm

# ============================================================
# Amino-acid class definitions
# ============================================================

# AA7: preferred hydrophilic & flexible amino acids
AA7 = set("AGTSPED")

# AA4: tolerated but should not dominate the sequence
AA4 = set("HLVY")

# AA9: undesirable residues (hydrophobic, unstable, or positively charged)
AA9 = set("RNCQIKMFW")

# Union of the three classes, i.e. all 20 canonical amino acids
VALID_AA = AA7.union(AA4, AA9)


# ============================================================
# Step 3: Filtering based on amino-acid composition
# ============================================================

def step3_filter_amino_acids(
    input_csv,
    output_csv,
    aa7_min_ratio=0.70,       # Minimum required ratio for AA7 residues
    aa9_max_ratio=0.10,       # Maximum allowed ratio for AA9 residues
    agtspe_max_ratio=0.85     # Maximum allowed ratio for AGTSPE residues
):
    """
    Filter peptide fragments based on amino-acid class composition.

    Conditions (defaults):
        - AA7 (AGTSPED) ≥ 70%
        - AA9 (RNCQIKMFW) ≤ 10%
        - AGTSPE ≤ 85% (to avoid overly homogeneous composition)

    Empty sequences and sequences with residues outside AA7 ∪ AA4 ∪ AA9 are
    dropped.

    Input/Output format:
        CSV with columns: PDB_id, Sequence, Length (kept rows are written unchanged)
    """
    agtspe_set = set("AGTSPE")  # AA7 without D
    kept_rows = []

    with open(input_csv, newline="") as f:
        reader = csv.DictReader(f)
        for row in tqdm(reader):
            seq = row["Sequence"].strip().upper()
            length = len(seq)

            # Empty sequence would divide by zero below.
            if length == 0:
                continue

            # Skip sequences containing non-standard residues (should rarely happen)
            if not all(aa in VALID_AA for aa in seq):
                continue

            aa7_ratio = sum(aa in AA7 for aa in seq) / length
            aa9_ratio = sum(aa in AA9 for aa in seq) / length
            agtspe_ratio = sum(aa in agtspe_set for aa in seq) / length

            # AA7 must be the dominant class
            if aa7_ratio < aa7_min_ratio:
                continue

            # AA9 (undesirable residues) must remain low
            if aa9_ratio > aa9_max_ratio:
                continue

            # Avoid overly homogeneous AGTSPE-only sequences
            if agtspe_ratio > agtspe_max_ratio:
                continue

            kept_rows.append(row)

    with open(output_csv, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["PDB_id", "Sequence", "Length"])
        writer.writeheader()
        writer.writerows(kept_rows)

    print(f"[OK] Step 3 complete: kept {len(kept_rows)} fragments → {output_csv}")

In [None]:
# Step 3 inputs: outputs of the Step 2 cell.
step2_missing_file = './step2_missing.csv'
step2_disordered_file = './step2_disordered.csv'

# Step 3 outputs; Step 4 below reads step3_missing.csv.
step3_missing_file = './step3_missing.csv'
step3_disordered_file = './step3_disordered.csv'

step3_filter_amino_acids(input_csv=step2_missing_file, output_csv=step3_missing_file)
step3_filter_amino_acids(input_csv=step2_disordered_file, output_csv=step3_disordered_file)

# Step 4 Removing fragments predicted to contain secondary structures

In [None]:
# Step 4 — S4PRED: build merged FASTA
import pandas as pd
from tqdm.auto import tqdm

def make_fasta_from_step3(step3_csv, out_fasta):
    """
    Write the Sequence column of a Step 3 CSV (PDB_id, Sequence, Length) as FASTA
    for S4PRED. Each record is named by its 0-based row position (``>0``,
    ``>1``, ...), which later Step 4 cells use as ``id``.

    Returns the number of records written.
    """
    sequences = pd.read_csv(step3_csv)["Sequence"].tolist()
    fasta_text = "".join(
        f">{position}\n{sequence}\n" for position, sequence in enumerate(sequences)
    )

    with open(out_fasta, "w") as handle:
        handle.write(fasta_text)

    print(f"[OK] FASTA written → {out_fasta}")
    return len(sequences)

# Example usage: build the S4PRED input for the "missing" branch only.
# NOTE(review): Step 4 in this notebook is run only on step3_missing.csv;
# presumably step3_disordered.csv would follow the same cells — confirm.
make_fasta_from_step3("step3_missing.csv","step4_missing_s4pred_input.fasta")

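Run S4PRED

Use the repository: https://github.com/psipred/s4pred

After installation, run the following command (replace `/path/to/s4pred` with your local clone). The parser in the next cell expects `--outfmt fas` output with three lines per record: header, sequence, and predicted secondary structure.

python /path/to/s4pred/run_model.py --device gpu --outfmt fas ./step4_missing_s4pred_input.fasta > ./step4_missing.fasta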

In [None]:
# Parse S4PRED output: each 3 lines → (seq, ss, percent_C)
import pandas as pd
from tqdm.auto import tqdm

def parse_s4pred(s4pred_fasta, out_csv):
    """
    Parse S4PRED ``--outfmt fas`` output into a CSV.

    The file is read as consecutive 3-line records: header, sequence, predicted
    secondary structure (H/E/C). Records are numbered by position, which
    matches the ``>{i}`` names written by `make_fasta_from_step3` as long as
    S4PRED keeps the input order.

    Output columns: id, Sequence, s4_ss, s4_coil_percent (fraction of 'C').
    An empty structure line gives a coil fraction of 0.0 instead of raising
    ZeroDivisionError.
    """
    with open(s4pred_fasta) as f:
        lines = f.readlines()

    seqs = [line.strip() for line in lines[1::3]]
    ss_list = [line.strip() for line in lines[2::3]]
    percents = [ss.count("C") / len(ss) if ss else 0.0 for ss in ss_list]

    df = pd.DataFrame({
        "id": range(len(seqs)),
        "Sequence": seqs,
        "s4_ss": ss_list,
        "s4_coil_percent": percents,
    })
    df.to_csv(out_csv, index=False)
    print(f"[OK] S4PRED parsed → {out_csv}")

# Filter: keep coils ≥ threshold
import pandas as pd

# Filter: keep coils ≥ threshold
def filter_s4pred_by_coil(s4_csv, fasta_input, fasta_output, coil_th=0.99):
    """
    Keep the FASTA records whose S4PRED coil fraction is >= `coil_th`.

    Parameters
    ----------
    s4_csv : str
        CSV from `parse_s4pred` (needs columns ``id`` and ``s4_coil_percent``).
    fasta_input : str
        FASTA whose headers are ``>{id}``. Only the first non-empty line after
        each header is used as the sequence. This works for the 2-line S4PRED
        input FASTA and also for the 3-line S4PRED ``fas`` output (whose
        structure line is ignored); the old 2-tokens-per-record parsing crashed
        on the latter.
    fasta_output : str
        Destination FASTA, two lines per kept record.
    coil_th : float
        Minimum coil fraction to keep a record.
    """
    df = pd.read_csv(s4_csv)
    # set for O(1) membership tests
    keep_ids = set(df.loc[df["s4_coil_percent"] >= coil_th, "id"].to_list())

    seq_out = []
    header = None
    with open(fasta_input) as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            if line.startswith(">"):
                header = line.split()[0]
                continue
            if header is None:
                continue  # extra line of the current record (e.g. predicted SS)
            if int(header[1:]) in keep_ids:
                seq_out.extend([header, line])
            header = None

    with open(fasta_output, "w") as f:
        f.write("\n".join(seq_out) + "\n")

    print(f"[OK] S4PRED filter done → {fasta_output}")

# Parse the raw S4PRED output, then keep records whose predicted coil fraction
# passes the threshold. Sequences are taken from the 2-line S4PRED *input*
# FASTA (headers ">{id}"), because the raw S4PRED output has 3 lines per record
# (header, sequence, structure) and cannot be read as header/sequence pairs.
parse_s4pred('step4_missing.fasta','step4_missing_S4PRED.csv')
filter_s4pred_by_coil('step4_missing_S4PRED.csv','step4_missing_s4pred_input.fasta','step4_missing_S4PRED_output.fasta')

In [None]:
# ProtBert secondary structure prediction
from transformers import AutoTokenizer, AutoModelForTokenClassification, TokenClassificationPipeline
import re
import pandas as pd
from collections import Counter
from tqdm.auto import tqdm

def protbert_predict(fasta, out_csv, device=0):
    """
    Predict 3-state secondary structure with ProtBert (Rostlab/prot_bert_bfd_ss3)
    and record the coil ('C') fraction for each sequence.

    Parameters
    ----------
    fasta : str
        FASTA with headers ``>{id}`` (e.g. the S4PRED-filtered file). The
        integer after '>' is written as ``id``, so rows can be joined with the
        other Step 4 tables. Previously ``id`` was the row position inside this
        (already filtered) file, which does not match those ids. The first
        non-empty line after each header is the sequence.
    out_csv : str
        Output CSV with columns id, Sequence, protbert_ss, protbert_coil_percent.
    device : int
        Device passed to the transformers pipeline (0 = first GPU, -1 = CPU).
    """
    ids, seqs = [], []
    header = None
    with open(fasta) as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            if line.startswith(">"):
                header = line.split()[0]
            elif header is not None:
                ids.append(int(header[1:]))
                seqs.append(line)
                header = None

    # ProtBert expects space-separated residues; rare residues U/Z/O/B -> X.
    spaced = [" ".join(s) for s in seqs]
    seqs_clean = [re.sub(r"[UZOB]", "X", s) for s in spaced]

    model_name = "Rostlab/prot_bert_bfd_ss3"
    pipeline = TokenClassificationPipeline(
        model=AutoModelForTokenClassification.from_pretrained(model_name),
        tokenizer=AutoTokenizer.from_pretrained(model_name),
        device=device,
    )
    results = pipeline(seqs_clean) if seqs_clean else []

    ss_strings = []
    coil_ratio = []
    for res in results:
        ss = "".join(x["entity"] for x in res)
        ss_strings.append(ss)
        coil_ratio.append(ss.count("C") / len(ss) if ss else 0.0)

    df = pd.DataFrame({
        "id": ids,
        "Sequence": seqs,
        "protbert_ss": ss_strings,
        "protbert_coil_percent": coil_ratio,
    })
    df.to_csv(out_csv, index=False)
    print(f"[OK] ProtBert results → {out_csv}")

protbert_predict('step4_missing_S4PRED_output.fasta','step4_missing_ProtBert.csv')  # only records that passed the S4PRED filter

In [None]:
def filter_protbert(s4_filtered_fasta, protbert_csv, out_fasta, th=0.97):
    """
    Keep the FASTA records whose ProtBert coil fraction is >= `th`.

    Parameters
    ----------
    s4_filtered_fasta : str
        FASTA with headers ``>{id}``; the first non-empty line after each header
        is the sequence.
    protbert_csv : str
        CSV with columns ``id`` and ``protbert_coil_percent``. The ids must be the
        same header ids as in `s4_filtered_fasta` (not row positions).
    out_fasta : str
        Destination FASTA, two lines per kept record.
    th : float
        Minimum coil fraction to keep a record.
    """
    df = pd.read_csv(protbert_csv)
    keep_ids = set(df.loc[df["protbert_coil_percent"] >= th, "id"].to_list())

    out = []
    header = None
    with open(s4_filtered_fasta) as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            if line.startswith(">"):
                header = line.split()[0]
                continue
            if header is None:
                continue  # extra line of the current record
            if int(header[1:]) in keep_ids:
                out.extend([header, line])
            header = None

    with open(out_fasta, "w") as f:
        f.write("\n".join(out) + "\n")

    print(f"[OK] ProtBert filtered → {out_fasta}")

filter_protbert("step4_missing_S4PRED_output.fasta","step4_missing_ProtBert.csv","step4_missing_ProtBert_output.fasta")  # survivors go to AlphaFold2

In [None]:
# Create AF2 input directory: each seq → one FASTA
import os
import shutil

def split_fasta_for_af2(fasta, out_dir):
    """
    Write each record of a ``>{id}`` FASTA to its own file for AlphaFold2.

    `out_dir` is deleted and recreated. Record ``>{id}`` becomes
    ``{out_dir}/seq{id}.fasta`` containing ``>seq{id}`` followed by the sequence
    (the first non-empty line after the header).
    """
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)

    header = None
    with open(fasta) as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            if line.startswith(">"):
                header = line.split()[0]
                continue
            if header is None:
                continue  # extra line of the current record
            idx = int(header[1:])
            with open(os.path.join(out_dir, f"seq{idx}.fasta"), "w") as out:
                out.write(f">seq{idx}\n{line}\n")
            header = None

    print(f"[OK] AF2 input FASTA files created → {out_dir}")

split_fasta_for_af2("step4_missing_ProtBert_output.fasta",'./step4_missing_af2_input')  # one FASTA per sequence for the AF2 run below

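Run AF2

Run AlphaFold2 without MSAs on every per-sequence FASTA in `./step4_missing_af2_input`. Replace `/path/to/alphafold` with your local checkout, which must provide `run_single_without_msa.py`, and set `CUDA_VISIBLE_DEVICES` to an available GPU. The next cell expects one `seq{id}/ranked_0.pdb` per input in the output directory.

XLA_PYTHON_CLIENT_PREALLOCATE=false CUDA_VISIBLE_DEVICES=3 python3 /path/to/alphafold/run_single_without_msa.py --input_type=dir --input=./step4_missing_af2_input --output_dir=./step4_missing_af2_output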

In [None]:
def extract_rank0(raw_dir, out_dir):
    """
    Copy AlphaFold2's top-ranked model of every run into one flat directory.

    For each entry ``seq{id}`` of `raw_dir`, ``ranked_0.pdb`` is copied to
    ``{out_dir}/{id}.pdb``; `out_dir` is deleted and recreated first. Entries
    whose name (after the first 3 characters) is not an integer, or that have
    no readable ``ranked_0.pdb`` (e.g. a failed prediction), are skipped and
    reported. Only these expected failures are caught; a bare ``except:`` would
    also swallow KeyboardInterrupt and genuine bugs.
    """
    import shutil

    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)

    skipped = []
    for name in sorted(os.listdir(raw_dir)):
        try:
            idx = int(name[3:])
            shutil.copy(
                os.path.join(raw_dir, name, "ranked_0.pdb"),
                os.path.join(out_dir, f"{idx}.pdb"),
            )
        except (ValueError, OSError):
            skipped.append(name)

    if skipped:
        print(f"[WARN] skipped {len(skipped)} entries without a usable ranked_0.pdb: {skipped[:5]}")
    print(f"[OK] AF2 rank0 PDB extracted → {out_dir}")
extract_rank0("./step4_missing_af2_output","./step4_missing_af2_output_analysis")  # gather ranked_0 models for DSSP

In [None]:
# Generate DSSP bash
def build_dssp_sh(pdb_dir, ss_dir, sh_file):
    """
    Write a shell script with one ``dssp <model>.pdb <ss_dir>/<model>.txt`` line
    per ``*.pdb`` file in `pdb_dir` (sorted by name; other files are ignored).

    `ss_dir` is created if needed. Paths are shell-quoted.
    NOTE(review): `parse_dssp_txt` expects classic DSSP text. If `dssp` on PATH
    is mkdssp 4.x, confirm it writes classic format for a ``.txt`` target.
    """
    import shlex

    os.makedirs(ss_dir, exist_ok=True)
    with open(sh_file, "w") as f:
        for p in sorted(os.listdir(pdb_dir)):
            if not p.endswith(".pdb"):
                continue
            src = os.path.join(pdb_dir, p)
            tgt = os.path.join(ss_dir, p[: -len(".pdb")] + ".txt")
            f.write(f"dssp {shlex.quote(src)} {shlex.quote(tgt)}\n")
    print(f"[OK] DSSP run script → {sh_file}")


# Parse DSSP output
def parse_dssp_txt(dssp_txt):
    """
    Read residue letters and secondary-structure codes from a classic DSSP file.

    Lines after the residue-table header (``  #  RESIDUE AA ...``, which is
    indented; the old ``line.startswith("#")`` test never matched it) are
    residue records. Index 13 holds the one-letter residue and index 16 the
    secondary-structure code (blank -> '-').

    Chain-break records ('!') and lines too short to contain both fields are
    skipped, so the two returned strings always stay aligned. Lowercase residue
    letters (DSSP's notation for disulfide-bonded cysteines) are reported as 'C'.

    Returns
    -------
    (str, str)
        Residue sequence and the matching secondary-structure string.
    """
    aa_chars, ss_chars = [], []
    below = False
    with open(dssp_txt) as f:
        for line in f:
            if line.lstrip().startswith("#"):
                below = True
                continue
            if not below or len(line.rstrip("\n")) < 17:
                continue
            aa = line[13]
            if not aa.strip() or aa == "!":
                continue
            if aa.islower():
                aa = "C"
            aa_chars.append(aa)
            ss_chars.append(line[16] if line[16] != " " else "-")
    return "".join(aa_chars), "".join(ss_chars)


def collect_af2_dssp(ss_dir, out_csv):
    """
    Summarise DSSP results for the AlphaFold2 models into one CSV.

    Every ``{id}.txt`` in `ss_dir` is parsed with `parse_dssp_txt`. Residues
    coded H, E, G, B or I count as structured; all other codes ('-', T, S, ...)
    count as coil. Files with no parsed residues are skipped with a warning
    instead of raising ZeroDivisionError.

    Output columns: id, Sequence, ss, af2_coil_percent (rows sorted by id).
    """
    structured = set("HEGBI")
    rows = []
    empty = []
    for txt in os.listdir(ss_dir):
        if not txt.endswith(".txt"):
            continue
        aa, ss = parse_dssp_txt(os.path.join(ss_dir, txt))
        if not ss:
            empty.append(txt)
            continue
        length = len(ss)
        non_coil = sum(code in structured for code in ss)
        coil_percent = (length - non_coil) / length
        rows.append([int(txt.split(".")[0]), aa, ss, coil_percent])

    rows.sort(key=lambda r: r[0])
    if empty:
        print(f"[WARN] {len(empty)} DSSP files had no residues and were skipped: {empty[:5]}")

    df = pd.DataFrame(rows, columns=["id", "Sequence", "ss", "af2_coil_percent"])
    df.to_csv(out_csv, index=False)
    print(f"[OK] AF2-DSSP collected → {out_csv}")

build_dssp_sh('step4_missing_af2_output_analysis','step4_missing_ss','step4_missing_run.sh')
# Run the generated script in a shell before the next call, which reads its .txt outputs:
# bash step4_missing_run.sh
collect_af2_dssp("step4_missing_ss","step4_missing_af2.csv")

In [None]:
# Merge three secondary-structure predictions and filter
def merge_and_filter_final(
    step3_csv, s4_csv, prot_csv, af2_csv,
    output_csv,
    th_s4=0.99, th_prot=0.97, th_af2=0.96
):
    """
    Join the three secondary-structure predictions onto the Step 3 table and keep
    fragments predicted as coil by all of them.

    Step 4 ids are the 0-based row positions of `step3_csv`. Each prediction
    table is inner-joined on ``id``, and only its prediction columns are taken:
      * S4PRED:   s4_ss, s4_coil_percent
      * ProtBert: protbert_ss, protbert_coil_percent
      * AF2+DSSP: ss (renamed af2_ss), af2_coil_percent
    Every prediction table also carries a ``Sequence`` column. Merging those
    repeatedly produced clashing Sequence_x / Sequence_y suffixes, which is
    deprecated in pandas 1.3 and an error in pandas 2.x, and left no plain
    ``Sequence`` column for Step 5. The Step 3 ``Sequence`` is kept instead.

    A row is kept when all three coil fractions reach their thresholds.
    """
    df = pd.read_csv(step3_csv)
    df["id"] = range(len(df))

    df_s4 = pd.read_csv(s4_csv)[["id", "s4_ss", "s4_coil_percent"]]
    df_prot = pd.read_csv(prot_csv)[["id", "protbert_ss", "protbert_coil_percent"]]
    df_af2 = (
        pd.read_csv(af2_csv)[["id", "ss", "af2_coil_percent"]]
        .rename(columns={"ss": "af2_ss"})
    )

    for table in (df_s4, df_prot, df_af2):
        df = df.merge(table, on="id", how="inner")

    df_final = df[
        (df["s4_coil_percent"] >= th_s4) &
        (df["protbert_coil_percent"] >= th_prot) &
        (df["af2_coil_percent"] >= th_af2)
    ]

    df_final.to_csv(output_csv, index=False)
    print(f"[OK] Step 4 final filtered motifs → {output_csv}")

# Inputs: the Step 3 table plus the three per-predictor tables built above.
step3_csv = 'step3_missing.csv'
s4_csv = 'step4_missing_S4PRED.csv'
prot_csv = 'step4_missing_ProtBert.csv'
af2_csv = 'step4_missing_af2.csv'
output_csv = 'step4_missing.csv'  # read by Step 5
# Default thresholds: S4PRED >= 0.99, ProtBert >= 0.97, AF2/DSSP >= 0.96 coil fraction.
merge_and_filter_final(step3_csv, s4_csv, prot_csv, af2_csv,output_csv)

# Step 5 Amino acid modification and re-prediction

In [None]:
import pandas as pd
import random

# --- Amino acid classes ---
CLASS1 = set("RNCQIKMFW")  # undesirable: same residues as Step 3's AA9
CLASS2 = list("AGTSPED")   # preferred: same residues as Step 3's AA7 (order fixes random.choice draws)


def remove_class1(seq):
    """Return `seq` with every Class1 (undesirable) residue deleted; other residues keep their order."""
    return "".join(filter(lambda residue: residue not in CLASS1, seq))


def replace_class1(seq, rng=None):
    """
    Replace each Class1 residue with a residue drawn uniformly from CLASS2.

    Parameters
    ----------
    seq : str
        Input sequence; non-Class1 residues are kept unchanged.
    rng : random.Random, optional
        Source of randomness. Defaults to the module-level `random` functions,
        which is the previous behaviour. Pass ``random.Random(seed)`` to get
        reproducible candidates without touching the global RNG.
    """
    choose = random.choice if rng is None else rng.choice
    return "".join(choose(CLASS2) if aa in CLASS1 else aa for aa in seq)


def step5_generate_modified_sequences(input_csv, output_csv, seed=None):
    """
    Step 5: for each sequence, generate two modified variants:
       - remove version: delete all Class1 AAs (`remove_class1`)
       - replace version: replace each Class1 AA with a random Class2 AA
         (`replace_class1`)

    Parameters
    ----------
    input_csv : str
        CSV with a ``Sequence`` column and optionally an ``id`` column (the row
        index is used when ``id`` is absent).
    output_csv : str
        Output CSV with columns id, orig_seq, remove_seq, replace_seq.
    seed : int, optional
        If given, ``random.seed(seed)`` is called first so the replace variants
        are reproducible. Note that this reseeds the global ``random`` module.
    """
    if seed is not None:
        random.seed(seed)

    df = pd.read_csv(input_csv)

    rows = []
    for i, row in df.iterrows():
        seq = row["Sequence"]
        rows.append([
            row["id"] if "id" in row else i,
            seq,
            remove_class1(seq),
            replace_class1(seq),
        ])

    outdf = pd.DataFrame(rows, columns=["id", "orig_seq", "remove_seq", "replace_seq"])
    outdf.to_csv(output_csv, index=False)
    print(f"[OK] Step5 finished → {output_csv}")



In [None]:
step4_missing_file = './step4_missing.csv'  # output of the Step 4 merge cell
# NOTE(review): the variable says step5 but the file is named step4_missing_candidates.csv;
# confirm the intended name before renaming, since later cells may read this path.
step5_missing_candidates_file = './step4_missing_candidates.csv'
step5_generate_modified_sequences(step4_missing_file, step5_missing_candidates_file)

Secondary Filtering Based on Step 4 Results

Based on the filters applied in Step 4, we perform an additional round of screening on the resulting candidate sequences. 