In [1]:
import sys
sys.path.append('../../../')
from sklearn.metrics import silhouette_score
from glob import glob
import torch
from tifffile import imread
import numpy as np

In [2]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
patch_size = 64  # side length of the square patch fed to the model (see get_mus reshape)
n_channel = 32  # latent channels per hierarchy level (feature width per level)
hierarchy_level = 3  # number of latent levels whose means are turned into features
centre_size = 4  # side length of the centre window averaged in get_mean_centre
model_path = "/group/jug/Sheida/HVAE/Hyperparameter_search/fine-pond-501-cl1e-3_kl1e-1/model/Contrastive_MAE_best_vae.net"
# The checkpoint is a whole pickled nn.Module, not a state_dict:
# - weights_only=False is required on PyTorch >= 2.6, whose default (True) refuses
#   to unpickle arbitrary classes. Unpickling executes code, so only load trusted files.
# - map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model = torch.load(model_path, map_location=device, weights_only=False)
model.mode_pred = True
model.eval()
model.to(device)

LadderVAE(
  (first_bottom_up): Sequential(
    (0): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ELU(alpha=1.0)
    (2): BottomUpDeterministicResBlock(
      (res): ResidualBlock(
        (block): Sequential(
          (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ELU(alpha=1.0)
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (3): Dropout2d(p=0.2, inplace=False)
          (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ELU(alpha=1.0)
          (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (7): Dropout2d(p=0.2, inplace=False)
        )
      )
    )
  )
  (top_down_layers): ModuleList(
    (0-1): 2 x TopDownLayer(
      (deterministic_block): Sequential(
        (0): TopDownDeterministicResBlock(
          (pre_conv): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(2, 2), 

In [3]:
def get_normalized_tensor(img, model, device):
    """Turn a numpy array into a standardised tensor on ``device``.

    The array is copied before conversion, so the returned tensor never shares
    memory with the caller's array. Standardisation uses the statistics stored
    on the model: ``(img - model.data_mean) / model.data_std``.
    """
    mean, std = model.data_mean, model.data_std
    tensor = torch.from_numpy(img.copy()).to(device)
    return (tensor - mean) / std
def load_data(dir):
    """Read TIFF data with ``tifffile.imread`` and return the resulting array.

    ``dir`` is passed straight to ``imread``. In this notebook it is a sorted list
    of ``.tif`` paths, which ``imread`` reads into a single stacked array.
    """
    stack = imread(dir)
    return stack
def get_mus(model, z):
    """Encode one patch and return its centre latent means as a single feature row.

    Relies on the notebook-level ``n_channel``, ``hierarchy_level``, ``patch_size``
    and ``device``, and on ``get_mean_centre``.

    Args:
        model: The loaded LadderVAE. It is put into prediction/eval mode and moved
            to ``device`` as a side effect.
        z: One patch with ``patch_size * patch_size`` elements (reshaped to
            (1, 1, patch_size, patch_size) here).

    Returns:
        numpy array of shape (1, n_channel * hierarchy_level): for each level, one
        centre-window mean per channel, concatenated level by level.
    """
    n_features = n_channel * hierarchy_level
    features = np.zeros((n_features,))
    model.mode_pred = True
    model.eval()
    model.to(device)
    patch = z.to(device=device, dtype=torch.float).reshape(1, 1, patch_size, patch_size)
    with torch.no_grad():
        # The patch is passed three times, as the original call did. The meaning of
        # each positional argument is defined by the model (not visible here).
        # Request exactly the levels the loop below reads, so changing
        # hierarchy_level keeps both in sync (previously hard-coded [0, 1, 2]).
        sample = model(patch, patch, patch, model_layers=list(range(hierarchy_level)))
        mu = sample["mu"]
        for level in range(hierarchy_level):
            features[level * n_channel : (level + 1) * n_channel] = get_mean_centre(mu, level)
    return features.reshape(1, n_features)
def get_mean_centre(x, i, centre_width=None):
    """Average the latent map of hierarchy level ``i`` into one value per channel.

    For levels 3 and 4 the whole spatial map is averaged. For every other level,
    only a ``centre_width`` x ``centre_width`` window around the spatial centre of
    the map is averaged.

    Args:
        x: Indexable of per-level tensors. ``x[i][0]`` is the first item of level
            ``i``, presumably laid out as (channels, height, width) -- TODO confirm
            against the model.
        i: Hierarchy level index.
        centre_width: Side length of the centre window. Defaults to the
            notebook-level ``centre_size``.

    Returns:
        1-D numpy array with one mean per channel.
    """
    feat = x[i][0].cpu().numpy()
    n_ch = feat.shape[0]
    if i in (3, 4):
        return feat.reshape(n_ch, -1).mean(-1)
    if centre_width is None:
        centre_width = centre_size
    half = centre_width // 2
    # Locate the centre from the map's real size. The previous hard-coded
    # 2 ** (4 - i) assumed level i is 2 ** (5 - i) pixels wide and produced an
    # off-centre (or empty -> NaN) window for any other size.
    cy, cx = feat.shape[-2] // 2, feat.shape[-1] // 2
    top, left = max(cy - half, 0), max(cx - half, 0)
    window = feat[:, top : cy + half, left : cx + half]
    return window.reshape(n_ch, -1).mean(-1)

In [4]:
data_dir = "/localscratch/testing/img/"
mask_dir = "/localscratch/testing/mask/"
# One sub-directory per class. Order: golgi, mitochondria, granule, matching class_type below.
class_dirs = ["class1", "class2", "class3"]


def _sorted_tifs(root, class_dir):
    """Return the sorted .tif paths in ``root + class_dir``; fail loudly if there are none."""
    paths = sorted(glob(root + class_dir + "/*.tif"))
    assert paths, f"no .tif files found in {root + class_dir}"
    return paths


golgi, mitochondria, granule = (
    get_normalized_tensor(load_data(_sorted_tifs(data_dir, d)), model, device)
    for d in class_dirs
)
golgi_mask, mitochondria_mask, granule_mask = (
    load_data(_sorted_tifs(mask_dir, d)) for d in class_dirs
)
class_type = [golgi, mitochondria, granule]
masks = [golgi_mask, mitochondria_mask, granule_mask]

In [5]:
# One feature row per patch (centre latent means of every hierarchy level), plus
# a class label given by the class's position in class_type.
n_features = n_channel * hierarchy_level
feature_rows = []
label_chunks = []
for label, class_images in enumerate(class_type):
    for image in class_images:
        feature_rows.append(get_mus(model, image))
    # Derive labels from the actual class sizes. The previous hard-coded
    # 19/161/363 silently mislabelled patches whenever the test set changed.
    label_chunks.append(np.full(len(class_images), label))
mus = np.concatenate(feature_rows, axis=0).reshape(-1, n_features)
labels = np.concatenate(label_chunks)
assert len(labels) == len(mus), (len(labels), len(mus))

# silhouette
silhouette = silhouette_score(mus, labels)

In [6]:
# Display the silhouette score computed from the per-patch features and class labels.
silhouette_value = silhouette
silhouette_value

0.8789143979912478