In [5]:
def AChar(i, x: str):
    return i, x[0]


def _Atom(i, j, x: str):
    return i, f"%{j-i+1}s" % x


def Character(i, x: str):
    return i, x[0]


def Integer(i, j, x: int):
    return i, f"%{j-i+1}d" % x


def LString(i, j, x: str):
    return i, f"%{j-i+1}s" % x


def Real(i, j, n, m, x: float):
    return i, (f"%{n}.{m}f" % x)[: j - i + 1]


def RecordName(x: str):
    return 1, f"%-6s" % x


def ResidueName(i, j, x: str):
    return i, f"%{j-i+1}s" % x


def Line(parts):
    line = [" " for _ in range(80)]
    for i, s in parts:
        for off in range(len(s)):
            line[i - 1 + off] = s[off]
    return "".join(line)


def MODEL(serial: int):
    return Line([RecordName("MODEL"), Integer(11, 14, serial)])


def ATOM(
    serial: int,
    name: str,
    altLoc: str,
    resName: str,
    chainID: str,
    resSeq: str,
    iCode: str,
    x: float,
    y: float,
    z: float,
    occupancy: float,
    tempFactor: float,
    element: str,
    charge: str,
):
    return Line(
        [
            RecordName("ATOM"),
            Integer(7, 11, serial),
            _Atom(13, 16, name),
            Character(17, altLoc),
            ResidueName(18, 20, resName),
            Character(22, chainID),
            Integer(23, 26, resSeq),
            AChar(27, iCode),
            Real(31, 38, 8, 3, x),
            Real(39, 46, 8, 3, y),
            Real(47, 54, 8, 3, z),
            Real(55, 60, 6, 2, occupancy),
            Real(61, 66, 6, 2, tempFactor),
            LString(77, 78, element),
            LString(79, 80, charge),
        ]
    )


def CRYST1(
    a: float,
    b: float,
    c: float,
    alpha: float,
    beta: float,
    gamma: float,
    sGroup: str,
    z: int,
):
    return Line(
        [
            RecordName("CRYST1"),
            Real(7, 15, 9, 3, a),
            Real(16, 24, 9, 3, b),
            Real(25, 33, 9, 3, c),
            Real(34, 40, 7, 2, alpha),
            Real(41, 47, 7, 2, beta),
            Real(48, 54, 7, 2, gamma),
            LString(56, 66, sGroup),
            Integer(67, 70, z),
        ]
    )


def TER(serial: int, resName: str, chainID: str, resSeq: str, iCode: str):
    return Line(
        [
            RecordName("TER"),
            Integer(7, 11, serial),
            ResidueName(18, 20, resName),
            Character(22, chainID),
            Integer(23, 26, resSeq),
            AChar(27, iCode),
        ]
    )


def ENDMDL():
    return Line([RecordName("ENDMDL")])


def END():
    return Line([RecordName("END")])


In [6]:
import numpy as np
from typing import List, Optional


class Atom:
    def __init__(self):
        self.name: str = ""
        self.pos: np.ndarray = np.array([0.0, 0.0, 0.0])
        self.occupancy: float = 0.0
        self.tempFactor: float = 0.0
        self.charge: str = ' '

    @property
    def element(self):
        return self.name[0]

    def pdb(self, serial: int, chainID: str, resName: str, resSeq: int):
        lines = [
            ATOM(
                serial,
                self.name,
                " ",
                resName,
                chainID,
                resSeq,
                " ",
                *self.pos,
                self.occupancy,
                self.tempFactor,
                self.element,
                self.charge
            )
        ]
        return lines, serial + 1, resSeq


class Residue:
    def __init__(self):
        self.name: str = ""
        self.atoms: List[Atom] = []

    def pdb(self, serial: int, chainID: str, resSeq: int):
        lines = []
        for atom in self.atoms:
            atom_lines, serial, resSeq = atom.pdb(serial, chainID, self.name, resSeq)
            lines.extend(atom_lines)
        return lines, serial, resSeq


class Chain:
    def __init__(self):
        self.id: str = 'A'
        self.residues: List[Residue] = []

    def pdb(self, serial: int):
        lines = []
        resSeq = 1
        for residue in self.residues:
            res_lines, serial, resSeq = residue.pdb(serial, self.id, resSeq)
            lines.extend(res_lines)
            resSeq += 1
        lines.append(TER(serial, self.residues[-1].name, self.id, resSeq, " "))
        serial += 1
        return lines, serial


class Model:
    def __init__(self):
        self.serial: int = 1
        self.cryst1: Optional[np.ndarray] = None
        self.chains: List[Chain] = []

    @property
    def pdb(self) -> str:
        lines = []
        lines.append(MODEL(self.serial))
        if self.cryst1 is not None:
            lines.append(CRYST1(*self.cryst1, 90.0, 90.0, 90.0, "P 1", 1))

        serial = 1
        for chain in self.chains:
            chain_lines, serial = chain.pdb(serial)
            lines.extend(chain_lines)

        lines.extend([ENDMDL(), END()])

        return "\n".join(lines)


In [7]:
from scipy.spatial.transform import Rotation

RESIDUES = [
    "ALA",
    "ARG",
    "ASN",
    "ASP",
    "CYS",
    "GLU",
    "GLN",
    "GLY",
    "HIS",
    "ILE",
    "LEU",
    "LYS",
    "MET",
    "PHE",
    "PRO",
    "SER",
    "THR",
    "TRP",
    "TYR",
    "VAL",
    "PHO",
]


def random_residue():
    return np.random.choice(RESIDUES)


def partition(x: int, n: int):
    pivots = np.random.choice(x, size=n - 1, replace=False)
    pivots.sort()
    pivots = np.array([0, *pivots, x])
    return pivots[1:] - pivots[:-1]


def sample_from_box(box: np.ndarray):
    fs = np.random.rand(*box.shape)
    return box * (fs - 0.5)


def sample_from_sphere():
    v = np.random.randn(3)
    return v / np.linalg.norm(v)


def rotate(v: np.ndarray, by: float):
    while True:
        v_perp = np.cross(v, sample_from_sphere())
        if np.linalg.norm(v_perp) > 0.5 * np.linalg.norm(v):
            break

    R1 = Rotation.from_rotvec(by * v_perp)
    R2 = Rotation.from_rotvec(2.0 * np.pi * v)
    return (R2 * R1).apply(v)


def random_model(
    num_chains: int,
    num_residues: int,
    bond_distance=3.8,
    max_angle=np.pi / 3.0,
    start_box_size=10.0,
    simul_box_density=1e-4,
):
    model = Model()
    model.serial = 1

    if simul_box_density > 0.0:
        box_size = np.cbrt(num_residues / simul_box_density)
        model.cryst1 = np.array([box_size, box_size, box_size])

    init_box = np.array([start_box_size, start_box_size, start_box_size])

    min_residues_per_chain = int(0.75 * num_residues / num_chains)
    residues_per_chain = min_residues_per_chain + partition(
        num_residues - min_residues_per_chain * num_chains, num_chains
    )
    for chain_idx in range(num_chains):
        chain = Chain()
        chain.id = chr(ord("A") + chain_idx)

        cur_pos = sample_from_box(init_box)
        cur_dir = sample_from_sphere()
        for res_idx in range(residues_per_chain[chain_idx]):
            CA_atom = Atom()
            CA_atom.name = "CA"
            CA_atom.pos = cur_pos

            res = Residue()
            res.name = random_residue()
            res.atoms = [CA_atom]

            chain.residues.append(res)

            cur_pos = cur_pos + bond_distance * cur_dir
            rot_angle = np.random.rand() * max_angle
            cur_dir = rotate(cur_dir, rot_angle)

        model.chains.append(chain)

    return model


In [8]:
open('sample.pdb', 'w').write(random_model(5, 1000).pdb)

81728