In [56]:
import numpy as np
from numpy.random import randn, permutation, seed
from numpy.linalg import norm
from scipy.spatial.distance import pdist, squareform
from scipy.linalg import sqrtm
import matplotlib.pyplot as plt
import time
import pandas as pd
from functools import partial
import seaborn as sns
from matplotlib.backends.backend_pdf import PdfPages
from sklearn.cluster import KMeans, MeanShift
from numpy import ndarray 

import sys
import json

In [57]:
%matplotlib inline

In [58]:
dataset = pd.read_csv('../data-precs_o-non-redundant.csv')

  dataset = pd.read_csv('../data-precs_o-non-redundant.csv')


In [59]:
coords = [np.array(json.loads(x)) for x in dataset.backbone]
unp_id = list(dataset.unp_id)
unp_idx = list(k for k in dataset.unp_idx)
pdb_id = list(dataset.pdb_id)
pdb_idx = list(k for k in dataset.res_id)
res_name = list(k for k in dataset.name)
phi = dataset.phi.to_numpy()
psi = dataset.psi.to_numpy()
omega = dataset.omega.to_numpy()
ss = list(dataset.secondary)

In [60]:
def calc_dihedral2(v1: ndarray, v2: ndarray, v3: ndarray, v4: ndarray):
    """
    Calculates the dihedral angle defined by four 3d points.
    This is the angle between the plane defined by the first three
    points and the plane defined by the last three points.
    Fast approach, based on https://stackoverflow.com/a/34245697/1230403
    """
    b0 = v1 - v2
    b1 = v3 - v2
    b2 = v4 - v3

    # normalize b1 so that it does not influence magnitude of vector
    # rejections that come next
    b1 /= np.linalg.norm(b1)

    # v = projection of b0 onto plane perpendicular to b1
    #   = b0 minus component that aligns with b1
    # w = projection of b2 onto plane perpendicular to b1
    #   = b2 minus component that aligns with b1
    v = b0 - np.dot(b0, b1) * b1
    w = b2 - np.dot(b2, b1) * b1

    # angle between v and w in a plane is the torsion angle
    # v and w may not be normalized but that's fine since tan is y/x
    x = np.dot(v, w)
    y = np.dot(np.cross(b1, v), w)
    return np.arctan2(y, x)

In [61]:
def cluster(X, k=20, subset=[0,1,5,6], L=None, algo='meanshift'):
    """
    Clusters coordinates into k clusters.
    """
    
    Z = X[:,subset,:]
    Z = Z.reshape(X.shape[0], -1)

    if algo == 'kmeans':
        clust = KMeans(n_clusters=k, random_state=0, verbose=1, max_iter=100, init='k-means++', n_init=1)
    elif algo == 'meanshift':
        clust = MeanShift(bandwidth=float(k), cluster_all=False, max_iter=300)
    else:
        raise Exception(f"Unimplemented {algo=}")
    
    # Random subsample
    L = L or Z.shape[0]
    L = min(Z.shape[0], L)
    idx = np.random.permutation(Z.shape[0])[0:L]
    
    clust = clust.fit(Z[idx,:])
    labels = clust.predict(Z)
    C = clust.cluster_centers_

    C_ = np.zeros((C.shape[0],*X.shape[1:]))
    lin_idx = np.array([*range(X.shape[0])])
    ind = []
    for l in range(labels.max()+1):
        index = (labels == l)
        z = Z[index,:]
        d2 = ((C[l,:] - z)**2).sum(axis=1)
        j = np.argmin(d2)
        C_[l,:,:] = X[j,:,:]
        ind.append(lin_idx[index][j])
        
    return labels, C_, ind

In [62]:
all_angles = np.vstack([
    np.array([omega[k], phi[k], psi[k], omega[k+1], phi[k+1], psi[k+1]])
    for k in range(len(omega)-1)
])

In [63]:
marg = 5
exc_res = []
exc_pdb = ['4N6V']

In [64]:
IDX = [
    k
    for k in range(marg+1, len(coords)-marg-1)
    if all(
        tuple(c.shape) == (4,3)
        for c in coords[k-marg-1:k+marg+1]
    ) and all(
        unp_id[k] == unp_id[k+m] == unp_id[k-1] for m in range(1, marg+1)
    ) and all(
        unp_idx[k] == unp_idx[k+m] - m and unp_idx[k-m] + m for m in range(1, marg+1)
    )
    and not np.isnan(all_angles[k-marg:k+marg,:]).any()
    and res_name[k] not in exc_res and res_name[k+1] not in exc_res
    and pdb_id[k].split(':')[0] not in exc_pdb
]

In [65]:
res_triplet = [
    res_name[k-1] + res_name[k] + res_name[k+1]
    for k in IDX
]

ss_triplet = [
    ss[k-1] + ss[k] + ss[k+1]
    for k in IDX
]

In [66]:
res_seq = [
    ''.join(res_name[k+m] for m in range(-marg-1, marg+1))
    for k in IDX
]

In [67]:
unp_id = [unp_id[k] for k in IDX]
unp_idx = [unp_idx[k] for k in IDX]
pdb_id = [pdb_id[k] for k in IDX]
pdb_idx = [pdb_idx[k] for k in IDX]
res_name = [res_name[k] for k in IDX]

In [68]:
all_angles = all_angles[IDX,:]
angles = all_angles[:,[1,2,4,5]]

In [69]:
res = np.array(res_name)

In [70]:
ss = np.array([s[0:2] for s in ss_triplet])

In [71]:
def canonize_batch(X: np.ndarray):
    #X = np.vstack(coords)
    X = X - X[:,2,:][:,None,:]
    
    e1 = X[:,4,:]-X[:,2,:]
    e1 = e1/np.linalg.norm(e1, axis=1)[:,None]
    
    e3 = np.cross(X[:,2,:]-X[:,4,:], X[:,5,:]-X[:,4,:], axis=1)
    e3 = e3/np.linalg.norm(e3, axis=1)[:,None]
    
    e2 = np.cross(e3, e1, axis=1)
    e2 = e2/np.linalg.norm(e2, axis=1)[:,None]
    
    U = np.stack([e1, e2, e3], axis=2)
    return np.einsum('nij,njk->nik', X, U) #X

In [72]:
inputs = []
for k in IDX:
    x = coords[k:k+2]
    inputs.append(np.vstack(x))
inputs = np.stack(inputs, axis=0)

In [73]:
canonical_coords = canonize_batch(inputs)

In [74]:
X = canonical_coords

In [75]:
k=4000
subset=[0,1,2,3,4,5,6,7]
labels, C_, ind = cluster(X, k=k, L=None, algo='kmeans', subset=subset)

Initialization complete
Iteration 0, inertia 102157.80687986978.
Iteration 1, inertia 91064.48215469395.
Iteration 2, inertia 89411.83428869775.
Iteration 3, inertia 88635.84818425271.
Iteration 4, inertia 88150.0409253068.
Iteration 5, inertia 87809.52577237817.
Iteration 6, inertia 87552.12928459814.
Iteration 7, inertia 87355.03394745704.
Iteration 8, inertia 87195.80983338397.
Iteration 9, inertia 87061.05404277309.
Iteration 10, inertia 86947.6467754198.
Iteration 11, inertia 86849.70732868912.
Iteration 12, inertia 86762.98773932224.
Iteration 13, inertia 86687.7331593255.
Iteration 14, inertia 86621.51556961105.
Iteration 15, inertia 86563.35860743266.
Iteration 16, inertia 86510.99020329703.
Iteration 17, inertia 86463.75950846732.
Iteration 18, inertia 86420.94932850328.
Iteration 19, inertia 86382.80026697584.
Iteration 20, inertia 86348.39798702061.
Iteration 21, inertia 86317.21814115722.
Iteration 22, inertia 86288.59840131554.
Iteration 23, inertia 86261.68993656748.
Iter

In [76]:
counts = np.array([(labels==i).sum() for i in range(k)])
rms = np.array([
    np.sqrt(((X[labels==i,:,:]-C_[i,:,:][None,:,:])**2).sum(2).mean(1).mean())
     for i in range(k)
])

In [77]:
clusters = {
    'num': [],
    'samples': [],
    'rms': [],
    'pdb_id': [],
    'pdb_idx': [],
    'res_prev': [],
    'res': [],
    'res_next': [],
    'ss_prev': [],
    'ss': [],
    'ss_next': [],
    'phi0': [],
    'psi0': [],
    'phi1': [],
    'psi1': [],
    'X': [],
}
for l, idx in enumerate(ind):
        clusters['num'].append(l)
        clusters['samples'].append((labels==l).sum())
        clusters['pdb_id'].append(pdb_id[idx])
        clusters['pdb_idx'].append(pdb_idx[idx])
        clusters['res_prev'].append(res_triplet[idx][0])
        clusters['res'].append(res_triplet[idx][1])
        clusters['res_next'].append(res_triplet[idx][2])
        clusters['ss_prev'].append(ss_triplet[idx][0])
        clusters['ss'].append(ss_triplet[idx][1])
        clusters['ss_next'].append(ss_triplet[idx][2])
        clusters['phi0'].append(all_angles[idx, 1])
        clusters['psi0'].append(all_angles[idx, 2])
        clusters['phi1'].append(all_angles[idx, 4])
        clusters['psi1'].append(all_angles[idx, 5])
        clusters['rms'].append(rms[l])
        clusters['X'].append(X[idx,:])

In [78]:
df = pd.DataFrame(clusters)
df.to_csv(f'clusters-{k//1000}K.csv', index=False)
df

Unnamed: 0,num,samples,rms,pdb_id,pdb_idx,res_prev,res,res_next,ss_prev,ss,ss_next,phi0,psi0,phi1,psi1,X
0,0,4690,1.412437,2HZL:A,312,E,A,A,H,H,H,-65.37,-42.63,-63.13,-42.72,"[[-0.020095734553920004, 2.2641227768423957, 0..."
1,1,735,1.759498,6G4J:A,53,V,A,L,E,E,E,-94.39,128.60,-115.32,119.68,"[[-1.6241520389069006, 1.3133620801362533, -1...."
2,2,556,1.280369,5EXP:A,346,F,N,S,E,-,H,-84.53,173.65,-68.24,-32.78,"[[-2.124916433684262, 1.152695995324641, -0.32..."
3,3,333,1.506222,4IA6:A,336,C,E,N,H,S,S,-111.19,144.83,64.28,33.48,"[[-1.8854035048908218, 1.4943351325819934, -0...."
4,4,227,1.195918,7CYW:A,150,T,T,T,E,S,S,-101.18,5.56,-107.70,-178.27,"[[0.2517040228824545, 2.4808253600048285, -0.1..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3995,3995,551,1.491680,3UG3:A,408,I,A,V,E,E,E,-86.54,130.05,-133.59,124.28,"[[-1.6888381542054245, 1.286453708259629, -1.1..."
3996,3996,280,1.382126,5JOD:A,107,T,D,A,E,E,-,-139.72,107.88,-115.95,21.60,"[[-1.281977182898812, 1.6326275333353828, -1.2..."
3997,3997,456,1.028167,6V42:A,309,D,L,K,S,-,T,-97.06,123.20,-88.54,157.25,"[[-1.5664116559581382, 1.5800994056285744, -0...."
3998,3998,190,1.292819,3T5T:A,396,V,N,E,H,-,S,-56.62,142.92,-116.09,20.53,"[[-1.8874519986201617, 1.3074337950955368, -0...."


In [79]:
import pickle
with open(f'clusters-{k//1000}K.pkl', 'wb') as f:
    pickle.dump(clusters, f)

In [80]:
import pickle
with open(f'labels-{k//1000}K.pkl', 'wb') as f:
    pickle.dump({
        'X': X,
        'angles': angles,
        'labels': labels,
    }, f)