In [None]:
from sklearn.decomposition import PCA
import utils
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans, AgglomerativeClustering
from scipy.spatial.distance import cdist
import umap
import torch
import mrc
from lattice import Lattice
from models import HetOnlyVAE
import torch.nn as nn
import numpy as np

In [None]:
def get_nearest_point(data, query, k=None):
    """
    Find the closest point(s) in @data to each row of @query (Euclidean distance).

    Args:
        data: 2-D array (N, D) of candidate points.
        query: 2-D array (M, D) of query points; a single 1-D point of
            length D is also accepted and treated as M = 1.
        k: None (default) -> single nearest neighbour per query, returned as
            (datapoints (M, D), indices (M,)) -- the original two-value API.
            int -> the k nearest neighbours per query, closest first, returned
            as (datapoints (M, k, D), indices (M, k), distances (M, k)).

    Returns:
        datapoint(s), index/indices[, distances] as described above.
    """
    # cdist requires 2-D inputs; promote a lone query vector to a 1-row matrix.
    query = np.atleast_2d(query)
    dist = cdist(query, data)
    if k is None:
        ind = dist.argmin(axis=1)
        return data[ind], ind
    # Stable sort so ties resolve to the lowest index, matching argmin for k=1.
    ind = np.argsort(dist, axis=1, kind='stable')[:, :k]
    return data[ind], ind, np.take_along_axis(dist, ind, axis=1)

# PCA + K-means

In [None]:
# Load the saved latent encodings (means and log-variances) from 'z.19.pkl'.
# NOTE(review): load_pkl presumably unpickles the file -- only use trusted files.
dirpath = '/home/bml/storage/mnt/v-5aaaf3e8ff1a43a8/org/huangyue/10049/1112_10049_geom32_256mlp2_heter/'
raw = utils.load_pkl(dirpath + 'z.19.pkl')
z_mu = raw['z_mu']
z_logvar = raw['z_logvar']

# Full-rank PCA: one component per latent dimension. fit() returns the estimator itself.
pca = PCA(n_components=z_mu.shape[1])
pc = pca.fit(z_mu).transform(z_mu)
print("Explained variance ratio:")
print(pca.explained_variance_ratio_)

# K-means in the full latent space; each centroid is then replaced by the
# nearest actual row of z_mu so it can be indexed directly later.
# NOTE(review): max_iter=10 is far below sklearn's default (300), so K-means may
# stop before convergence -- confirm this is intended.
K = 4
kmeans = KMeans(n_clusters=K, random_state=0, max_iter=10)
labels = kmeans.fit_predict(z_mu)
centers, centers_ind = get_nearest_point(z_mu, kmeans.cluster_centers_)

# All encodings on the first two PCs, with the representatives overlaid in black.
plt.scatter(pc[:, 0], pc[:, 1])
for ind in centers_ind:
    plt.scatter(pc[ind, 0], pc[ind, 1], c='k')

# UMAP

In [None]:
# 2-D UMAP embedding of the latent means, with the K-means representatives
# (centers_ind) overlaid in black.
# random_state makes the embedding reproducible across Restart & Run All; the
# agglomerative-clustering plot below reuses z_embedded. Per umap-learn docs,
# fixing the seed disables its parallel optimisation, so this runs slower.
reducer = umap.UMAP(random_state=0)
z_embedded = reducer.fit_transform(z_mu)
plt.scatter(z_embedded[:, 0], z_embedded[:, 1])
for ind in centers_ind:
    plt.scatter(z_embedded[ind, 0], z_embedded[ind, 1], c='k')

# Agglomerative Clustering

In [None]:
# Ward-linkage agglomerative clustering of the latent means into K groups.
K = 6
Agg = AgglomerativeClustering(n_clusters=K, linkage='ward')
labels = Agg.fit_predict(z_mu)
# Centroid of each cluster = mean latent vector of its members.
centers = np.stack([z_mu[labels == i].mean(0) for i in range(K)])
# Snap each centroid to the nearest actual row of z_mu (indices used to decode
# volumes below). get_nearest_point returns (points, 1-D index array).
centers_near, centers_ind = get_nearest_point(z_mu, centers)
# UMAP embedding coloured by cluster; representatives in black, labelled by cluster id.
plt.scatter(z_embedded[:, 0], z_embedded[:, 1], s=.1, c=labels)
for it, ind in enumerate(centers_ind):
    plt.scatter(z_embedded[ind, 0], z_embedded[ind, 1], c='k')
    plt.annotate('{}'.format(it), (z_embedded[ind, 0], z_embedded[ind, 1]))

# Volume Generation

In [None]:
# Decode one volume per selected latent point (centers_ind) with a trained
# HetOnlyVAE checkpoint and write each to an .mrc file in dirpath.
use_cuda = torch.cuda.is_available()
# NOTE(review): use_cuda is computed but decoding is pinned to the CPU -- confirm intended.
device = torch.device('cpu')
D = 192
Dz = D
extent = 0.5
D_sample = D
lattice = Lattice(D_sample, D_sample, extent, device=device, endpoint=False)
mask = lattice.get_sphere_mask(D//2, soft_edge=0.15*(D//2))

# Architecture hyper-parameters; presumably these must match the run that
# produced the checkpoint, otherwise load_state_dict will fail.
in_dim = D**2
activation = nn.ReLU
qlayers = 3
qdim = 256
players = 2
pdim = 256
model = HetOnlyVAE(lattice, qlayers, qdim, players, pdim, in_dim, zdim=8, enc_type='geom_ft', enc_dim=32, activation=activation)
model.to(device)
# Pixel size passed to mrc.write; 1.23 is presumably the dataset's Å/pixel at box size D -- confirm.
Apix = 1.23*D/D_sample
epochs = [39]
# centers_ind may be 1-D (one index per center) or (n_centers, 1) depending on
# how it was produced; flatten to plain ints so z_mu[ind:ind+1] is a valid slice.
center_indices = [int(i) for i in np.ravel(centers_ind)]
with torch.no_grad():
    for epoch in epochs:
        print('Generating epoch {}'.format(epoch))
        # The checkpoint depends only on the epoch, so load it once and decode
        # every center from it (previously it was re-loaded for every center).
        # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
        ckpt = torch.load(dirpath+'/weights.{}.pkl'.format(epoch),map_location=device)
        model.load_state_dict(ckpt["model_state_dict"])
        model.eval()
        for ind in center_indices:
            vol_recon, _, _ = model(coords=lattice.coords[None], mask=mask[None]>0, z=torch.tensor(z_mu[ind:ind+1],device=device))
            mrc.write(dirpath+'vol_it{:03d}_no{}.{}.mrc'.format(epoch, ind, D_sample),vol_recon.squeeze().cpu().numpy(), Apix, is_vol=True)