In [None]:
import numpy as np
from sklearn.cluster import AgglomerativeClustering, KMeans
from scipy.cluster.hierarchy import dendrogram
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
# Dataset names; each one is used to locate a `<name>_morgan_fingerprints.npz` dump.
AVAILABLE_DATASETS = [
    'ST14',
    'KLKB1',
    'TMPRSS11D',
    'TMPRSS6',
    'TMPRSS2',
]

In [None]:
# Merge the fingerprints of every dataset into a single lookup keyed by `cid`
# (presumably a compound id -- confirm against the code that writes the dumps).
cid_to_fingerprint = {}
for ds in AVAILABLE_DATASETS:
    # An NpzFile keeps its archive open and re-reads an array on every `[]`
    # access, so close it deterministically and fetch each entry exactly once.
    with np.load(f'../dumps/{ds}_morgan_fingerprints.npz') as archive:
        for cid in archive:
            fingerprint = archive[cid]
            if cid in cid_to_fingerprint:
                # The same key may appear in several datasets; its fingerprint
                # must be identical everywhere, otherwise the dumps disagree.
                assert np.array_equal(fingerprint, cid_to_fingerprint[cid]), (
                    f'conflicting fingerprints for {cid!r} (seen again in {ds})'
                )
            cid_to_fingerprint[cid] = fingerprint

In [None]:
# Collect the fingerprints into one array; axis 0 follows the dict's insertion order.
X = np.array([fingerprint for fingerprint in cid_to_fingerprint.values()])
# Number of fingerprints (length of axis 0).
N = X.shape[0]

In [None]:
def jaccard(X):
    """Return the pairwise Jaccard similarity matrix of the rows of ``X``.

    Entry ``[i, j]`` is ``sum(X[i] & X[j]) / sum(X[i] | X[j])``, using
    element-wise bitwise AND / OR.

    Parameters
    ----------
    X : array-like with a boolean or integer dtype
        One fingerprint per entry along axis 0. Any trailing axes are
        flattened, so every element of a fingerprint takes part in the sums.

    Returns
    -------
    numpy.ndarray of shape (N, N), dtype float64
        Similarity matrix. A pair whose union is empty (e.g. two all-zero
        fingerprints) is assigned 1.0 instead of the ``nan`` that 0/0 would
        produce, so the diagonal is well defined for empty fingerprints.
    """
    X = np.asarray(X)
    N = X.shape[0]
    # Start from 1.0 everywhere: positions skipped by the masked division
    # below (empty union) keep this value.
    X_jaccard = np.ones((N, N), dtype=np.double)
    if N == 0:
        return X_jaccard
    rows = X.reshape(N, -1)
    for i in range(N):
        # Compare fingerprint i against all fingerprints in one vectorised
        # step instead of N separate Python-level calls.
        intersection = np.bitwise_and(rows[i], rows).sum(axis=1)
        union = np.bitwise_or(rows[i], rows).sum(axis=1)
        np.divide(intersection, union, out=X_jaccard[i], where=union != 0)
    return X_jaccard

def _jaccard(x,y):
    return np.double(np.bitwise_and(x, y).sum()) / np.double(np.bitwise_or(x, y).sum())


In [None]:
# Pairwise Jaccard similarity between all fingerprints, in the row order of X.
X_jaccard = jaccard(X)

In [None]:
# Show the similarity matrix as an image; rows/columns follow the row order of X.
plt.matshow(X_jaccard)

In [None]:
# k-means starts from randomly chosen centroids, so without a fixed seed the
# cluster labels (and every plot ordered by them) change between runs.
N_CLUSTERS = 8
RANDOM_STATE = 0
kmeans = KMeans(n_clusters=N_CLUSTERS, random_state=RANDOM_STATE)

In [None]:
# Fit k-means on the fingerprint array and get one cluster label per row of X.
X_kmeans_clusters = kmeans.fit_predict(X)

In [None]:
# Reorder the fingerprints so rows with the same k-means label are contiguous.
rows_by_cluster = np.argsort(X_kmeans_clusters)
X_sorted = X[rows_by_cluster]

In [None]:
# X_sorted is only a row permutation of X, so its pairwise similarities are
# already in X_jaccard: permute rows and columns instead of recomputing
# every pair from scratch.
cluster_order = np.argsort(X_kmeans_clusters)
# Guard against X_sorted having been built with a different ordering.
assert np.array_equal(X[cluster_order], X_sorted), 'X_sorted is not X ordered by cluster label'
X_sorted_jaccard = X_jaccard[cluster_order][:, cluster_order]

In [None]:
# Draw the cluster-ordered similarity matrix on a single 12x12 inch axes.
figure_size = (12, 12)
fig, ax = plt.subplots(figsize=figure_size)
ax.matshow(X_sorted_jaccard)