In [None]:
import sys
sys.path.append("../")

import torch
from torch.utils.data import DataLoader
import numpy as np
from model import MSBERT
from sklearn.cluster import Birch
from sklearn.metrics import accuracy_score
from sklearn import metrics

from utils import ModelEmbed
from const import tsne_cluster
from type import TokenizerConfig
from data import Tokenizer, TokenSequenceDataset

In [2]:
def get_label(spectra):
    """Derive a numeric class id for every spectrum from its SMILES string.

    Parameters
    ----------
    spectra : iterable
        Items that expose ``.get("smiles")``.

    Returns
    -------
    labels : numpy.ndarray
        float64 array with one entry per spectrum; ``labels[i]`` is the
        position of spectrum ``i``'s SMILES inside ``unique_smiles``.
    unique_smiles : numpy.ndarray
        The distinct SMILES values in sorted order (as returned by
        ``np.unique``).
    """
    smiles_seq = np.array([s.get("smiles") for s in spectra])
    # return_inverse yields, in one pass, the index of every element within
    # the sorted unique array; this replaces one boolean-mask scan per class.
    unique_smiles, inverse = np.unique(smiles_seq, return_inverse=True)
    # Labels were historically float64 (np.zeros default); keep that dtype.
    labels = inverse.reshape(-1).astype(np.float64)
    return labels, unique_smiles

In [3]:
def purity_score(y_true, y_pred):
    """Compute clustering purity.

    Every predicted cluster is credited with the size of its most frequent
    ground-truth class; purity is the sum of those counts divided by the
    number of samples. This equals the accuracy obtained when each cluster
    is assigned its majority class.

    Parameters
    ----------
    y_true : array-like
        Ground-truth class label per sample (any sortable dtype).
    y_pred : array-like
        Predicted cluster label per sample, same length as ``y_true``.

    Returns
    -------
    float
        Purity in ``[0, 1]``.

    Raises
    ------
    ValueError
        If the inputs are empty or differ in length.

    Notes
    -----
    The inputs are never modified. Class labels are mapped to ``0..K-1``
    through ``np.unique(..., return_inverse=True)`` so that distinct labels
    can never be merged, whatever their original values are.
    """
    true_arr = np.asarray(y_true).reshape(-1)
    pred_arr = np.asarray(y_pred).reshape(-1)
    if true_arr.shape != pred_arr.shape:
        raise ValueError(
            "y_true and y_pred must have the same number of samples"
        )
    if true_arr.size == 0:
        raise ValueError("purity is undefined for empty input")

    # Dense integer ids for classes and clusters, built without touching
    # the caller's arrays.
    _, true_idx = np.unique(true_arr, return_inverse=True)
    pred_values, pred_idx = np.unique(pred_arr, return_inverse=True)
    true_idx = true_idx.reshape(-1)
    pred_idx = pred_idx.reshape(-1)

    majority_total = 0
    for cluster in range(pred_values.shape[0]):
        members = true_idx[pred_idx == cluster]
        # Size of the most common true class inside this cluster.
        _, class_counts = np.unique(members, return_counts=True)
        majority_total += int(class_counts.max())

    return float(majority_total / true_arr.size)

In [4]:
def CalEvaluate(labels_true, labels_pred, beta=0.5):
    """Collect several clustering-quality metrics into one dictionary.

    Parameters
    ----------
    labels_true : array-like
        Ground-truth class label per sample.
    labels_pred : array-like
        Predicted cluster label per sample.
    beta : float, optional
        Forwarded to ``metrics.v_measure_score``. Defaults to 0.5, the value
        this function has always used.

    Returns
    -------
    dict
        Keys ``'ARI'``, ``'purity'``, ``'homogeneity'``, ``'completeness'``
        and ``'v_measure'`` mapped to the corresponding scores.
    """
    # Give the purity helper a private copy so that no in-place relabelling
    # can leak into the remaining metrics or into the caller's array.
    purity = purity_score(np.array(labels_true), labels_pred)
    ari = metrics.adjusted_rand_score(labels_true, labels_pred)
    homogeneity = metrics.homogeneity_score(labels_true, labels_pred)
    completeness = metrics.completeness_score(labels_true, labels_pred)
    v_measure = metrics.v_measure_score(labels_true, labels_pred, beta=beta)
    return {
        'ARI': ari,
        "purity": purity,
        'homogeneity': homogeneity,
        'completeness': completeness,
        'v_measure': v_measure,
    }

In [5]:
# Device selection: keep using the second GPU when one exists, fall back to
# the first GPU on single-GPU hosts (where "cuda:1" is not a valid device),
# and to the CPU when CUDA is unavailable.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

model_state_path = "/data1/xp/data/MSBert/MSBERT.pkl"
# map_location="cpu" lets a checkpoint that was saved from a GPU be read on
# any host; the weights are moved to `device` together with the model below.
model_state = torch.load(model_state_path, map_location="cpu")

# NOTE(review): the seven positional arguments are passed through unchanged;
# their meaning is defined by MSBERT's constructor, which is not visible in
# this notebook. They must match the configuration the checkpoint was
# trained with, otherwise load_state_dict will fail.
model = MSBERT(
    100002,
    512,
    6,
    16,
    0,
    100,
    3
)
model.load_state_dict(model_state)
model = model.to(device)

# Tokenizer settings, expanded into keyword arguments for Tokenizer.
show_progress_bar = False
tokenizer_config = TokenizerConfig(
    max_len=100,
    n_decimals=2,
    show_progress_bar=show_progress_bar
)
tokenizer = Tokenizer(**tokenizer_config)

In [6]:
# Cluster the spectra stored at tsne_cluster.SPECEMBEDDING_CLUSTER and score
# the result against SMILES-derived labels.
# NOTE(review): allow_pickle=True makes np.load unpickle arbitrary objects;
# only point this at files from a trusted source.
spectra = np.load(tsne_cluster.SPECEMBEDDING_CLUSTER, allow_pickle=True)
# One numeric class id per spectrum plus the array of distinct SMILES.
labels, unique_smiles = get_label(spectra)

# Tokenize the spectra and wrap them for batched inference.
sequences = tokenizer.tokenize_sequence(spectra)
dataset = TokenSequenceDataset(sequences)
# shuffle=False keeps batches in dataset order.
# Presumably ModelEmbed returns one embedding per item in loader order, so
# rows line up with `labels` -- TODO confirm against utils.ModelEmbed.
loader = DataLoader(
    dataset,
    batch_size=512,
    shuffle=False
)
spectra_embedding = ModelEmbed(model, loader, device)

# Birch is asked for as many clusters as there are distinct SMILES.
brc = Birch(threshold=0.5, n_clusters=len(unique_smiles))
pred_labels = brc.fit_predict(spectra_embedding)
# Last expression of the cell: the metrics dict is displayed as the output.
CalEvaluate(labels, pred_labels)

{'ARI': 0.2047164225469265,
 'purity': 0.7600961256667252,
 'homogeneity': 0.8797025437129594,
 'completeness': 0.7818874624710648,
 'v_measure': 0.844487086690822}

In [7]:
# Same evaluation as the previous cell, run on the spectra stored at
# tsne_cluster.MSBERT_CLUSTER. This rebinds spectra, labels, unique_smiles,
# sequences, dataset, loader, spectra_embedding, brc and pred_labels.
# NOTE(review): allow_pickle=True makes np.load unpickle arbitrary objects;
# only point this at files from a trusted source.
spectra = np.load(tsne_cluster.MSBERT_CLUSTER, allow_pickle=True)
# One numeric class id per spectrum plus the array of distinct SMILES.
labels, unique_smiles = get_label(spectra)

# Tokenize the spectra and wrap them for batched inference.
sequences = tokenizer.tokenize_sequence(spectra)
dataset = TokenSequenceDataset(sequences)
# shuffle=False keeps batches in dataset order.
# Presumably ModelEmbed returns one embedding per item in loader order, so
# rows line up with `labels` -- TODO confirm against utils.ModelEmbed.
loader = DataLoader(
    dataset,
    batch_size=512,
    shuffle=False
)
spectra_embedding = ModelEmbed(model, loader, device)
# Progress marker printed once embedding has finished.
print("embedded end")

# Birch is asked for as many clusters as there are distinct SMILES.
brc = Birch(threshold=0.5, n_clusters=len(unique_smiles))
pred_labels = brc.fit_predict(spectra_embedding)
# Last expression of the cell: the metrics dict is displayed as the output.
CalEvaluate(labels, pred_labels)

embedded end


{'ARI': 0.23916166662338104,
 'purity': 0.8050994930770977,
 'homogeneity': 0.8850667055388773,
 'completeness': 0.7675688506401166,
 'v_measure': 0.8420978393347991}