In [None]:
from pathlib import Path
import numpy as np
from numpy.random import randn, permutation, seed
from numpy.linalg import norm
from scipy.spatial.distance import pdist, squareform
from scipy.linalg import sqrtm
import matplotlib.pyplot as plt
import time
import pandas as pd
from functools import partial
import seaborn as sns
from matplotlib.backends.backend_pdf import PdfPages
from sklearn.cluster import KMeans, MeanShift
from numpy import ndarray

import sys
import json

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Output folder of the precursor-collection run analysed in this notebook.
DATASET_DIR = Path("../out/prec-collected/20230117_200158-aida-ex_EC-src_ALL-r1.5_s0.7-re")
# CSV table loaded in the next cell.
DATASET_FILE = DATASET_DIR.joinpath("data-precs.csv")

In [None]:
# Load the collected per-row table into a DataFrame.
dataset = pd.read_csv(filepath_or_buffer=DATASET_FILE)

In [None]:
# Backbone coordinates are stored as JSON-encoded arrays; decode one per row.
coords = [np.array(json.loads(backbone)) for backbone in dataset.backbone]

# Identifiers and residue metadata, one entry per row, as plain lists.
unp_id = list(dataset.unp_id)
unp_idx = list(dataset.unp_idx)
pdb_id = list(dataset.pdb_id)
pdb_idx = list(dataset.res_id)
res_name = list(dataset.name)

# Backbone dihedral angles as numpy arrays (checked for NaN further below).
phi, psi, omega = dataset.phi.to_numpy(), dataset.psi.to_numpy(), dataset.omega.to_numpy()

ss = list(dataset.secondary)

In [None]:
def calc_dihedral2(v1: ndarray, v2: ndarray, v3: ndarray, v4: ndarray):
    """
    Calculates the dihedral angle defined by four 3d points.

    This is the angle between the plane defined by the first three
    points and the plane defined by the last three points.
    Fast approach, based on https://stackoverflow.com/a/34245697/1230403

    Parameters
    ----------
    v1, v2, v3, v4 : array-like, shape (3,)
        The four points. Integer arrays and plain sequences are accepted.

    Returns
    -------
    float
        The dihedral angle in radians, in [-pi, pi] (``np.arctan2`` range).
        The result is NaN when v2 == v3 (the central bond has zero length).
    """
    v1, v2, v3, v4 = (np.asarray(v) for v in (v1, v2, v3, v4))

    b0 = v1 - v2
    b1 = v3 - v2
    b2 = v4 - v3

    # normalize b1 so that it does not influence magnitude of vector
    # rejections that come next. Done out of place: the in-place `/=` fails
    # for integer coordinates (float result cannot be cast back to int).
    b1 = b1 / np.linalg.norm(b1)

    # v = projection of b0 onto plane perpendicular to b1
    #   = b0 minus component that aligns with b1
    # w = projection of b2 onto plane perpendicular to b1
    #   = b2 minus component that aligns with b1
    v = b0 - np.dot(b0, b1) * b1
    w = b2 - np.dot(b2, b1) * b1

    # angle between v and w in a plane is the torsion angle
    # v and w may not be normalized but that's fine since tan is y/x
    x = np.dot(v, w)
    y = np.dot(np.cross(b1, v), w)
    return np.arctan2(y, x)

In [None]:
def canonize(coords, more_coords=None):
    """
    Canonizes coordinates: expresses them in a local orthonormal frame.

    After stacking ``coords`` with ``np.vstack``, the frame is:
      * origin: row 2;
      * x axis (e1): unit vector from row 2 to row 4;
      * z axis (e3): unit normal of the plane through rows 2, 4 and 5;
      * y axis (e2): e3 x e1, completing a right-handed frame.
    Rows 2, 4 and 5 must not be collinear, otherwise the result is NaN.

    Parameters
    ----------
    coords : sequence of (n_i, 3) arrays, or a 2-D array
        Coordinates that define the frame and are transformed into it.
        The stacked array needs at least 6 rows.
    more_coords : sequence of (m_i, 3) arrays, or a 2-D array, optional
        Extra coordinates expressed in the same frame. They do not influence
        the frame. ``None`` or an empty sequence means "no extra coordinates".

    Returns
    -------
    X : ndarray, shape (sum n_i, 3)
        ``coords`` in the local frame, when no extra coordinates are given.
    (X, X_) : tuple of ndarray
        When ``more_coords`` is non-empty. ``X_`` is ``more_coords`` in the same frame.
    """
    # Explicit None/empty test: `if more_coords:` is ambiguous (raises) for an
    # ndarray; an empty sequence keeps returning X alone, as before.
    has_more = more_coords is not None and len(more_coords) > 0

    X = np.vstack(coords)
    # Translate both sets by the (uncentered) row 2 before X itself is recentered.
    if has_more:
        X_ = np.vstack(more_coords) - X[2, :]
    X = X - X[2, :]

    e1 = X[4, :] - X[2, :]
    e1 = e1 / np.linalg.norm(e1)

    e3 = np.cross(X[2, :] - X[4, :], X[5, :] - X[4, :])
    e3 = e3 / np.linalg.norm(e3)

    e2 = np.cross(e3, e1)
    e2 = e2 / np.linalg.norm(e2)

    # Columns are the frame axes, so `points @ U` gives frame coordinates.
    U = np.vstack([e1, e2, e3]).T

    X = X @ U
    if has_more:
        return X, X_ @ U
    return X

In [None]:
def cluster(X, k=20, subset=(0, 1, 5, 6), L=None, algo="meanshift", random_state=None):
    """
    Clusters structures and picks one representative structure per cluster.

    Features are the coordinates of the atoms listed in ``subset`` (axis 1 of
    ``X``), flattened per structure. The model is fitted on a random subsample
    of at most ``L`` structures and then used to label every structure.

    Parameters
    ----------
    X : ndarray, shape (N, n_atoms, d)
        Stacked coordinates (in this notebook: canonized residue pairs, (N, 8, 3)).
    k : int or float
        Number of clusters for ``algo="kmeans"``. Used as the bandwidth for
        ``algo="meanshift"``. Ignored when ``algo`` is an estimator object.
    subset : sequence of int
        Atom indices used as clustering features.
    L : int, optional
        Maximum number of structures used for fitting. ``None`` (or 0) uses all.
    algo : {"kmeans", "meanshift"} or estimator
        Clustering method, or an unfitted sklearn-style estimator exposing
        ``fit``, ``predict`` and ``cluster_centers_``.
    random_state : int, optional
        Seed for drawing the fitting subsample. ``None`` uses numpy's global
        RNG (``np.random.permutation``), i.e. the previous behaviour.

    Returns
    -------
    labels : ndarray, shape (N,)
        Cluster id of every structure (as returned by ``predict``).
    C_ : ndarray, shape (n_centers, n_atoms, d)
        Coordinates of each cluster's representative: the member whose features
        are closest (squared Euclidean distance) to the cluster centre. For a
        centre that received no structure, the structure closest overall is used.
    ind : list
        Row index into ``X`` of each representative, aligned with the rows of ``C_``.

    Raises
    ------
    ValueError
        If ``algo`` is an unknown string.
    """
    n = X.shape[0]
    Z = X[:, list(subset), :].reshape(n, -1)

    if not isinstance(algo, str):
        clust = algo
    elif algo == "kmeans":
        clust = KMeans(n_clusters=k, random_state=0)
    elif algo == "meanshift":
        clust = MeanShift(bandwidth=float(k), cluster_all=False, max_iter=300)
    else:
        raise ValueError(f"Unimplemented {algo=}")

    # Random subsample used for fitting; all N structures are labelled afterwards.
    n_fit = min(n, L or n)
    if random_state is None:
        perm = np.random.permutation(n)
    else:
        perm = np.random.default_rng(random_state).permutation(n)

    clust = clust.fit(Z[perm[:n_fit], :])
    labels = clust.predict(Z)
    C = clust.cluster_centers_

    # One representative per centre, so C_ and ind stay aligned with C even if
    # some centre ends up with no structure.
    C_ = np.zeros((C.shape[0], *X.shape[1:]))
    ind = []
    for l in range(C.shape[0]):
        members = np.flatnonzero(labels == l)
        if members.size == 0:
            # Empty cluster: fall back to the structure nearest to this centre.
            members = np.arange(n)
        d2 = ((Z[members, :] - C[l, :]) ** 2).sum(axis=1)
        # Map the position within the cluster back to a row of X.
        j = members[np.argmin(d2)]
        C_[l, :, :] = X[j, :, :]
        ind.append(j)

    return labels, C_, ind

In [None]:
# Row j holds the dihedrals of residue j followed by those of residue j+1:
# (omega_j, phi_j, psi_j, omega_{j+1}, phi_{j+1}, psi_{j+1}).
# Built by slicing instead of a per-row Python loop. With fewer than two
# residues this yields a (0, 6) array instead of failing in np.vstack([]).
all_angles = np.column_stack(
    [omega[:-1], phi[:-1], psi[:-1], omega[1:], phi[1:], psi[1:]]
)

In [None]:
# Selection parameters for the residue-pair windows built below.
marg, exc_res, exc_pdb = (
    5,         # window margin: pair (k, k+1) uses residues k-marg-1 .. k+marg
    [],        # residue names rejected at either pair position
    ["4N6V"],  # PDB ids (part before ":") excluded entirely
)

In [None]:
def _window_is_valid(k):
    """
    True if residue pair (k, k+1) and its flanking window can be used.

    The window is residues k-marg-1 .. k+marg, i.e. exactly the slice that is
    later passed to canonize() and used for res_seq. All residues in it must
    have a (4, 3) coordinate array, belong to the same UniProt entry and sit at
    consecutive UniProt positions. Contiguity is checked on both sides of k.
    NOTE(review): the window is asymmetric around the pair (marg+1 residues
    before k, marg-1 after k+1) — confirm this is intended.
    """
    window = range(k - marg - 1, k + marg + 1)
    return (
        all(tuple(coords[i].shape) == (4, 3) for i in window)
        # same UniProt entry across the whole window
        and all(unp_id[i] == unp_id[k] for i in window)
        # consecutive UniProt positions: unp_idx[i] - i is constant
        and all(unp_idx[i] - i == unp_idx[k] - k for i in window)
        and not np.isnan(all_angles[k - marg : k + marg, :]).any()
        and res_name[k] not in exc_res
        and res_name[k + 1] not in exc_res
        and pdb_id[k].split(":")[0] not in exc_pdb
    )


# NOTE: relies on the unfiltered per-row lists; re-running this cell after the
# cell that re-binds unp_id/unp_idx/pdb_id/res_name to the IDX subset gives wrong results.
IDX = [k for k in range(marg + 1, len(coords) - marg - 1) if _window_is_valid(k)]

In [None]:
# Residue names and secondary-structure codes of positions (i-1, i, i+1)
# for every selected index i, concatenated into one string each.
res_triplet = [res_name[i - 1] + res_name[i] + res_name[i + 1] for i in IDX]
ss_triplet = [ss[i - 1] + ss[i] + ss[i + 1] for i in IDX]

In [None]:
# Residue-name string of the whole window k-marg-1 .. k+marg (same slice as canonize).
res_seq = ["".join(res_name[k - marg - 1 : k + marg + 1]) for k in IDX]

In [None]:
# Restrict the per-row metadata to the selected pairs (positions now follow IDX).
# NOTE: re-binds the original names, so this cell is not idempotent.
unp_id, unp_idx, pdb_id, pdb_idx, res_name = [
    [column[k] for k in IDX]
    for column in (unp_id, unp_idx, pdb_id, pdb_idx, res_name)
]

In [None]:
# Keep only the selected pairs (not idempotent: re-binds all_angles).
all_angles = all_angles[IDX]
# Drop both omega columns (0 and 3): phi0, psi0, phi1, psi1 remain.
angles = np.delete(all_angles, [0, 3], axis=1)

In [None]:
# Residue names of the selected positions as a numpy array.
res = np.asarray(res_name)

In [None]:
# First two characters of each secondary-structure triplet, i.e. positions (k-1, k).
# NOTE(review): the pair itself is (k, k+1); confirm (k-1, k) is intended here.
ss = np.array([triplet[:2] for triplet in ss_triplet])

In [None]:
# Express each pair (k, k+1) in its own local frame, and the flanking window
# k-marg-1 .. k+marg in that same frame.
canonical_coords, canonical_coords_full = [], []
for k in IDX:
    if not k % 100000:
        print(k)  # coarse progress marker
    x, x_ = canonize(coords[k : k + 2], coords[k - marg - 1 : k + marg + 1])
    canonical_coords += [x]
    canonical_coords_full += [x_]

In [None]:
# Stack the canonized pairs into one (n_pairs, 8, 3) array.
X = np.asarray(canonical_coords)

In [None]:
# Number of top-level k-means clusters.
k = 20
# cluster() draws its fitting subsample (L=5000) from numpy's global RNG.
# Seed it so the clustering, and the CSV written at the end, are reproducible.
np.random.seed(0)
labels, C_, ind = cluster(X, k=k, L=5000, algo="kmeans")

In [None]:
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.manifold import MDS

# Embed the cluster representatives (same atom subset as the clustering) on a
# line, so that cluster ids can be reordered to follow that 1-D layout.
Z = C_[:, [0, 1, 5, 6], :].reshape(C_.shape[0], -1)
mds = MDS(dissimilarity="euclidean", n_components=1, random_state=0)
x = mds.fit_transform(Z)[:, 0]
idx_sort = np.argsort(x)
# Inverse permutation: idx_inv[old_id] is the new id of cluster old_id.
idx_inv = np.argsort(idx_sort)

In [None]:
# Relabel the clusters so that ids follow the MDS order.
# NOTE: not idempotent — running this cell twice applies the permutation twice.
labels = idx_inv[labels]
ind = list(np.asarray(ind)[idx_sort])
C_ = C_.take(idx_sort, axis=0)

In [None]:
# Sub-clusters: re-cluster the members of each top-level cluster. Keep as extra
# representatives the sub-clusters holding at least half of an equal share of members.
lin_idx = np.arange(X.shape[0])
ind_sub = []
for l in range(len(ind)):
    idx = labels == l
    members = lin_idx[idx]  # rows of X that belong to cluster l
    if members.size == 0:
        # Empty top-level cluster: nothing to sub-cluster.
        ind_sub.append([])
        print(f"{l:4d}\t0")
        continue

    # KMeans cannot fit more clusters than there are samples.
    labels_, C__, ind_ = cluster(
        X[idx, :, :], k=min(15, members.size), L=1000, algo="kmeans"
    )

    # minlength keeps counts aligned with ind_ even when the highest
    # sub-cluster id received no points (otherwise the mask is too short).
    counts = np.bincount(labels_, minlength=len(ind_))
    freqs = counts / counts.sum()
    mask = freqs >= 0.5 / len(counts)

    # Map representative indices from the subset X[idx] back to rows of X.
    ind_ = list(members[np.array(ind_)[mask]])
    assert all(lab == l for lab in labels[ind_])
    ind_sub.append(ind_)

    print(f"{l:4d}\t{len(ind_)}")

In [None]:
# Column-oriented table of representatives: for each top-level cluster, row
# sub=0 is its representative and sub=1.. are its sub-cluster representatives.
clusters = {
    column: []
    for column in (
        "num", "sub", "pdb_id", "pdb_idx",
        "res_prev", "res", "res_next",
        "ss_prev", "ss", "ss_next",
        "phi0", "psi0", "phi1", "psi1",
    )
}
for l, idx_ in enumerate(ind):
    for s, idx in enumerate([idx_, *ind_sub[l]]):
        res3 = res_triplet[idx]
        ss3 = ss_triplet[idx]
        row = {
            "num": l,
            "sub": s,
            "pdb_id": pdb_id[idx],
            "pdb_idx": pdb_idx[idx],
            "res_prev": res3[0],
            "res": res3[1],
            "res_next": res3[2],
            "ss_prev": ss3[0],
            "ss": ss3[1],
            "ss_next": ss3[2],
            # all_angles columns: 1/2 = phi/psi of k, 4/5 = phi/psi of k+1
            "phi0": all_angles[idx, 1],
            "psi0": all_angles[idx, 2],
            "phi1": all_angles[idx, 4],
            "psi1": all_angles[idx, 5],
        }
        for column, value in row.items():
            clusters[column].append(value)

In [None]:
# Build the table and save it; the file name carries the number of clusters.
df = pd.DataFrame.from_dict(clusters)
df.to_csv(f"clusters_{labels.max() + 1}.csv", index=False)

In [None]:
# Show the cluster table that was written to CSV in the previous cell.
df