In [None]:
import os
import sys
import random
import warnings
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.cluster import KMeans
from sklearn.mixture import BayesianGaussianMixture
from sklearn.mixture import GaussianMixture

from sklearn.metrics import accuracy_score
from sklearn.datasets import make_circles, load_iris
from sklearn import preprocessing
# Run from the repository root (two directories above the notebook) so the
# local modules imported below and the relative "datasets/" / "models/" paths
# resolve. The root is resolved only once per kernel: re-running this cell
# would otherwise climb two more directories every time.
if "_REPO_ROOT" not in globals():
    _REPO_ROOT = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
dir_ = _REPO_ROOT
os.chdir(dir_)
# Local modules
import utils
import prior
import transformer
import main
from sklearn.metrics import rand_score, cluster, adjusted_mutual_info_score 
from sklearn.metrics.cluster import adjusted_rand_score

# Settings
matplotlib.use("TkAgg")
warnings.filterwarnings("ignore")
%matplotlib inline


In [None]:
import math

class Normalize(nn.Module):
    """Standardisation layer: maps ``x`` to ``(x - mean) / std``.

    ``mean`` and ``std`` are stored exactly as given (plain attributes, not
    registered buffers) and broadcast against the input.
    """

    def __init__(self, mean, std):
        super().__init__()
        self.mean, self.std = mean, std

    def forward(self, x):
        centred = x - self.mean
        return centred / self.std


class Transformer(nn.Transformer):
    """Encoder-only transformer that labels points and predicts a cluster count.

    Inputs are sequence-first, ``(S, B, in_features)`` (the parent
    ``nn.Transformer`` is built without ``batch_first``). ``forward`` appends one
    extra "cluster-count" token per batch element, adds a learned embedding of
    ``num_clusters`` (valid indices ``0..buckets_size``) to every position, runs
    the inherited ``self.encoder`` and projects every position to
    ``buckets_size`` outputs.

    Only the parent's encoder is used; the parent's ``decoder`` attribute is
    overwritten by the linear read-out head defined below.
    """

    def __init__(self, d_model, nhead, nhid, nlayers, in_features=1, buckets_size=100):
        super(Transformer, self).__init__(d_model=d_model, nhead=nhead, dim_feedforward=nhid, dropout=0,
                                          num_encoder_layers=nlayers)
        self.model_type = 'Transformer'
        self.src_mask = None
        self.d_model = d_model
        self.nhead = nhead
        self.nhid = nhid
        self.nlayers = nlayers
        self.embed_size = d_model
        # Per-point feature projection.
        self.linear_x = nn.Linear(in_features, d_model)
        # Read-out head (replaces the parent's TransformerDecoder attribute).
        self.decoder = nn.Linear(d_model, buckets_size)
        # Projection of the constant query vector that forms the count token.
        self.linear_num_clusters = nn.Linear(in_features, d_model)
        # Conditioning on the cluster count; index 0 is used by callers as "unknown".
        self.embedding = nn.Embedding(buckets_size + 1, self.embed_size)

    def _generate_mask(self, size, device=None):
        """Additive attention mask of shape ``(size, size)``.

        The first ``size - 1`` positions (points) attend to each other but not to
        the last (count) token; the last token attends to every position.
        Allowed entries are 0, blocked entries are ``-inf``.

        ``device`` (optional) places the mask next to the activations; the
        default ``None`` keeps the original CPU behaviour.
        """
        matrix = torch.zeros((size, size), dtype=torch.float32, device=device)
        matrix[:size - 1, :size - 1] = 1
        matrix[size - 1] = 1
        matrix = matrix.masked_fill(matrix == 0, float('-inf')).masked_fill(matrix == 1, 0)
        return matrix

    def forward(self, X, num_clusters):
        """Run the model.

        Parameters
        ----------
        X : tensor ``(S, B, in_features)``
        num_clusters : long tensor ``(1, B)`` of cluster-count indices.

        Returns
        -------
        (point_logits, count_logits) with shapes ``(S, B, buckets_size)`` and
        ``(1, B, buckets_size)``.
        """
        # (S, B, F) -> (S, B, d_model)
        train = self.linear_x(X)
        # Extra query token (all features = -1) whose output is the count prediction.
        # Created on X's device so the model also runs off-CPU.
        cluster_input = torch.full((1, X.shape[1], X.shape[2]), -1, dtype=torch.float, device=X.device)
        cluster_embedding = self.linear_num_clusters(cluster_input)
        train = torch.cat((train, cluster_embedding), dim=0)  # (S + 1, B, d_model)
        cluster_conditional = self.embedding(num_clusters)  # (1, B, d_model)
        train = train + cluster_conditional  # broadcast over the sequence axis
        src_mask = self._generate_mask(train.shape[0], device=train.device)
        output = self.encoder(train, mask=src_mask)  # (S + 1, B, d_model)
        output = self.decoder(output)
        return output[:-1], output[-1:]

In [None]:
from sklearn import preprocessing
import torch
from scipy.stats import invwishart
import random

def sort(x, y, X_true, centers):
    """Canonically relabel a clustering, then shuffle its rows.

    Rows are first ordered by the Euclidean norm of ``x``. Labels are then
    renumbered 0, 1, 2, ... in order of first appearance along that ordering
    (stopping once ``centers`` labels have been assigned). Finally, all three
    arrays are shuffled with the same ``np.random.permutation``, so
    ``x``/``y``/``X_true`` rows stay aligned.

    Returns ``(x_shuffled, y_relabelled_shuffled, X_true_shuffled)``.
    """
    order = np.argsort(np.linalg.norm(x, axis=1))
    x_by_norm = x[order]
    y_by_norm = y[order]
    true_by_norm = X_true[order]

    relabel = {}
    for label in y_by_norm:
        if len(relabel) == centers:
            break
        if label not in relabel:
            relabel[label] = len(relabel)

    y_relabelled = np.array([relabel[label] for label in y_by_norm])
    perm = np.random.permutation(len(x_by_norm))
    return x_by_norm[perm], y_relabelled[perm], true_by_norm[perm]


def generate_bayesian_gmm_data(
        batch_size=32,
        start_seq_len = 100,
        seq_len=500,
        num_features=2,
        min_classes=1,
        num_classes=10,
        weight_concentration_prior=1.0,  # Dirichlet concentration for the mixture weights
        mean_prior=0.0,  # Every coordinate of the prior mean of the cluster means
        mean_precision_prior=0.01,  # Cluster means ~ N(mean_prior, Sigma / mean_precision_prior)
        degrees_of_freedom_prior=None,  # Inverse-Wishart degrees of freedom (None -> num_features)
        covariance_prior=None,  # Inverse-Wishart scale matrix (None -> identity)
        seed=None,
        nan_frac = None,
        **kwargs
):
    """Sample a batch of synthetic clustering datasets from a Bayesian Gaussian mixture.

    One sequence length is drawn with ``random.randint(start_seq_len, seq_len)``
    and shared by every dataset in the batch. For each of the ``batch_size``
    datasets:

    * the number of components K is drawn uniformly from
      ``[min_classes, num_classes]`` (both inclusive);
    * weights ~ Dirichlet(weight_concentration_prior * ones(K)) and
      per-component point counts ~ Multinomial(seq_len, weights);
    * each covariance ~ inverse-Wishart(degrees_of_freedom_prior,
      covariance_prior), each mean ~ N(mean_prior, Sigma / mean_precision_prior);
    * the points are min-max scaled per feature; optionally a random fraction
      (drawn from U(0, nan_frac)) of entries is replaced by the -2 missing-value
      sentinel, never masking a whole row;
    * ``sort`` relabels the components and shuffles the rows.

    ``seed`` (if given) seeds numpy's global RNG, Python's ``random`` and torch.
    Extra ``**kwargs`` are accepted and ignored.

    Returns
    -------
    clusters_x : float32 tensor (seq_len, batch_size, num_features)
        Scaled (and possibly -2-masked) features.
    clusters_y : float32 tensor (seq_len, batch_size)
        Relabelled component ids.
    clusters_x_true : float32 tensor (seq_len, batch_size, num_features)
        Unscaled, unmasked features, row-aligned with ``clusters_x``.
    batch_classes : integer tensor (1, batch_size)
        K for every dataset in the batch.
    """
    if seed is not None:
        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)

    #features = random.randint(2, num_features)
    # Feature count is fixed to num_features, so the np.pad call below adds no columns.
    features = num_features
    if covariance_prior is None:
        covariance_prior = np.eye(features)

    if degrees_of_freedom_prior is None:
        degrees_of_freedom_prior = features

    # A single length for the whole batch (the seq_len argument is the upper bound).
    seq_len = random.randint(start_seq_len, seq_len)
    clusters_x = np.zeros((batch_size, seq_len, num_features))
    clusters_y = np.zeros((batch_size, seq_len))
    clusters_x_true = np.zeros((batch_size, seq_len, num_features))
    batch_classes = []
    for i in range(batch_size):
        # Upper bound is exclusive, so num_classes itself can be drawn.
        n_components = np.random.randint(min_classes, num_classes + 1)

        # Sample weights from Dirichlet prior
        weights = np.random.dirichlet(np.ones(n_components) * weight_concentration_prior)

        # Per-component point counts ~ Multinomial(seq_len, weights); some may be 0
        counts = np.random.multinomial(seq_len, weights)

        # Sample cluster parameters: Sigma ~ inverse-Wishart, mu | Sigma ~ Normal
        means = []
        covariances = []
        for _ in range(n_components):
            Sigma = invwishart.rvs(df=degrees_of_freedom_prior, scale=covariance_prior)
            mu = np.random.multivariate_normal(np.full(features, mean_prior), Sigma / mean_precision_prior)
            means.append(mu)
            covariances.append(Sigma)
        # Sample points from GMM (grouped by component; shuffled later by `sort`)
        X = []
        y = []
        for k in range(n_components):
            n_k = counts[k]
            X_k = np.random.multivariate_normal(means[k], covariances[k], size=n_k)
            X.append(X_k)
            y.append(np.full(n_k, k))

        X = np.vstack(X)
        y = np.concatenate(y)
        X = np.pad(X, ((0, 0), (0, num_features - features)), mode='constant')

        # Scaled copy; X itself stays unscaled and is returned as the "true" data.
        x = preprocessing.MinMaxScaler().fit_transform(X)
        n_samples, n_features = seq_len, features

        if nan_frac:
            fraction_missing = np.random.uniform(0, nan_frac)
            n_elements = n_samples * n_features
            n_missing = int(fraction_missing * n_elements)

            if n_missing > 0:
                # Flat indices of the entries to mask, chosen without replacement.
                all_indices = np.arange(n_elements)
                np.random.shuffle(all_indices)
                chosen = all_indices[:n_missing]

                rows, cols = np.unravel_index(chosen, (n_samples, n_features))

                # Ensure no row is fully masked: for every such row, drop one
                # randomly chosen entry from `chosen`.
                full_rows = [r for r in np.unique(rows) if np.sum(rows == r) == n_features]
                to_keep = []
                for r in full_rows:
                    row_mask = rows == r
                    row_positions_in_chosen = np.where(row_mask)[0]
                    keep_pos = np.random.choice(row_positions_in_chosen)
                    to_keep.append(keep_pos)
                # Remove all "to_keep" indices from chosen at once
                chosen = np.delete(chosen, to_keep)
                rows, cols = np.unravel_index(chosen, (n_samples, n_features))
                # -2 lies outside the [0, 1] min-max range and marks a missing value.
                x[rows, cols] = -2

        # Norms used by `sort` are computed on the (possibly masked) scaled x.
        x, y, X_true = sort(x, y, X, n_components)
        clusters_x[i] = x
        clusters_y[i] = y
        clusters_x_true[i] = X_true
        batch_classes.append(n_components)

    # (batch, seq, feat) -> (seq, batch, feat): sequence-first layout.
    clusters_x_true = torch.tensor(clusters_x_true, dtype=torch.float32).permute(1, 0, 2)
    clusters_x = torch.tensor(clusters_x, dtype=torch.float32).permute(1, 0, 2)
    clusters_y = torch.tensor(clusters_y, dtype=torch.float32).permute(1, 0)
    batch_classes = torch.tensor(batch_classes).unsqueeze(0)

    return clusters_x, clusters_y, clusters_x_true, batch_classes


def get_clusters_bayesian_gmm(X, batch_classes=None, n_components=10, weight_concentration_prior=1.0, mean_prior=None,
                              mean_precision_prior=None, degrees_of_freedom_prior=None,
                              covariance_prior=None, random_state=42, n_init=None):
    """Pick a cluster count for ``X`` by fitting Bayesian GMMs with 1..n_components components.

    Each candidate K is a ``BayesianGaussianMixture`` with a Dirichlet-distribution
    weight prior. It is scored by ``lower_bound_ - log(K!)``; the log(K!) term
    penalises the K! equivalent label permutations. The best-scoring model's
    hard labels and responsibilities are returned.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
    batch_classes : ignored; kept for call-site compatibility.
    n_components : int
        Largest number of components tried.
    weight_concentration_prior, mean_precision_prior :
        Forwarded to ``BayesianGaussianMixture``.
    mean_prior, degrees_of_freedom_prior, covariance_prior :
        Forwarded when given; ``None`` means ``zeros(n_features)``,
        ``n_features`` and ``eye(n_features)`` respectively.
    random_state : int
        Seed used for every fit (previously ignored and hard-coded to 42).
    n_init : int or None
        Initialisations per fit; ``None`` is treated as 1, because sklearn
        rejects ``None``.

    Returns
    -------
    best_cluster : int (-1 if no fit produced a finite score)
    best_label : ndarray (n_samples,)
    best_prob : ndarray (n_samples, best_cluster)
    time_tot : float
        Summed wall-clock seconds spent in ``fit`` only.
    """
    X_curr = X
    features = X_curr.shape[1]
    curr_mean_prior = np.zeros(features) if mean_prior is None else mean_prior
    curr_covariance_prior = np.eye(features) if covariance_prior is None else covariance_prior
    curr_degrees_of_freedom_prior = features if degrees_of_freedom_prior is None else degrees_of_freedom_prior
    n_init = 1 if n_init is None else n_init

    best_elbo = float('-inf')
    best_cluster = -1
    best_prob = []
    best_label = []
    time_tot = 0
    for component in range(1, n_components + 1):
        model = BayesianGaussianMixture(n_components=component,
                                        weight_concentration_prior_type='dirichlet_distribution',
                                        weight_concentration_prior=weight_concentration_prior,
                                        mean_prior=curr_mean_prior, mean_precision_prior=mean_precision_prior,
                                        n_init=n_init,
                                        degrees_of_freedom_prior=curr_degrees_of_freedom_prior, max_iter=200,
                                        covariance_prior=curr_covariance_prior, random_state=random_state)
        start_time = time.time()
        model.fit(X_curr)
        time_tot += time.time() - start_time
        # log(K!) == lgamma(K + 1); stays finite where np.log(math.factorial(K)) would not.
        score = model.lower_bound_ - math.lgamma(component + 1)
        if score > best_elbo:
            best_elbo = score
            best_cluster = component
            best_label = model.predict(X_curr)
            best_prob = model.predict_proba(X_curr)
    return best_cluster, best_label, best_prob, time_tot


def match_mask(train_X, X_true):
    """Return a copy of ``X_true`` with NaN wherever ``train_X`` holds the -2 sentinel.

    Neither input is modified; ``train_X`` and ``X_true`` must have matching shapes.
    """
    is_missing = train_X == -2
    result = X_true.clone()
    result[is_missing] = float('nan')
    return result

def apply_mask(train_X, miss_level, rng):
    """Mask the same fraction of rows independently in every column, in place.

    For each column, ``int(miss_level * n_samples)`` distinct rows are drawn from
    ``rng`` (without replacement). Any row that ends up masked in every column
    gets one randomly chosen column unmasked again. Selected entries of
    ``train_X`` are overwritten with -2 **in place**, and ``train_X`` is returned.
    """
    n_rows, n_cols = train_X.shape[0], train_X.shape[1]
    per_column = int(miss_level * n_rows)
    hit = np.zeros_like(train_X, dtype=bool)

    for col in range(n_cols):
        hit[rng.choice(n_rows, size=per_column, replace=False), col] = True

    # Never leave a row with no observed feature.
    for row in np.flatnonzero(hit.all(axis=1)):
        hit[row, rng.integers(0, n_cols)] = False

    train_X[hit] = -2
    return train_X


def generate_progressive_masks(n_samples, n_features, miss_levels, rng):
    """
    Generate nested ("progressive") missingness masks, one per miss level.

    Levels are processed in the given order. Each level starts from the previous
    level's mask and adds random entries until ``int(miss_level * n_samples *
    n_features)`` entries are masked in total (a fraction of the whole matrix,
    not per feature). Rows that end up fully masked then get one entry unmasked
    again. That entry is chosen among the entries newly masked at *this* level,
    so every mask stays a superset of the previous one. Because of this
    restoration, the realised fraction can be slightly below the target at
    high levels (e.g. ``miss_level=1``). A level lower than its predecessor
    simply reuses the previous mask.

    Parameters
    ----------
    n_samples, n_features : int
        Shape of every mask.
    miss_levels : iterable of float in [0, 1]
        Target fractions, meant to be increasing.
    rng : numpy.random.Generator

    Returns a dict: {miss_level: boolean mask of shape (n_samples, n_features)},
    where True marks a missing entry.
    """
    masks = {}
    base_mask = np.zeros((n_samples, n_features), dtype=bool)  # start with no missing
    total_entries = n_samples * n_features

    for miss_level in miss_levels:
        n_to_mask_total = int(miss_level * total_entries)
        mask = base_mask.copy()
        n_new_to_mask = max(0, n_to_mask_total - int(mask.sum()))

        if n_new_to_mask > 0:
            available_indices = np.argwhere(~mask)  # coordinates of still-observed entries
            selected = rng.choice(len(available_indices), size=n_new_to_mask, replace=False)
            chosen_coords = available_indices[selected]
            mask[chosen_coords[:, 0], chosen_coords[:, 1]] = True

        # Ensure no row is fully masked, restoring only entries added at this
        # level so the previous level's mask is preserved (nesting).
        newly_masked = mask & ~base_mask
        for row in np.flatnonzero(mask.all(axis=1)):
            candidates = np.flatnonzero(newly_masked[row])
            if candidates.size == 0:
                # Defensive: base_mask never contains a fully masked row.
                candidates = np.arange(n_features)
            mask[row, candidates[rng.integers(0, candidates.size)]] = False

        masks[miss_level] = mask
        base_mask = mask.copy()  # progressively build

    return masks
    

def purity_score(y_true, y_pred):
    """Compute purity score given true and predicted labels.

    Purity is the fraction of samples that belong to the majority true class of
    their predicted cluster. ``contingency_matrix`` has true classes on rows and
    predicted clusters on columns, so a column-wise max is each cluster's
    majority count.
    """
    contingency = cluster.contingency_matrix(y_true, y_pred)
    majority_total = contingency.max(axis=0).sum()
    return majority_total / contingency.sum()


def plot_scores(scores, metric, title, ylabel):
    """Plot mean +/- standard error of ``metric`` against missingness level.

    ``scores`` maps method name -> metric name -> {miss_level: list of per-run
    values}. The x-axis is the sorted union of miss levels over all methods.
    A level a method has no runs for is plotted as NaN (a gap). The standard
    error uses the sample std (ddof=1), so a single run yields a NaN error bar.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    levels = sorted({lvl for per_method in scores.values() for lvl in per_method[metric].keys()})

    for method, per_metric in scores.items():
        means, stderrs = [], []
        for lvl in levels:
            runs = np.array(per_metric[metric].get(lvl, []))
            if len(runs) == 0:
                means.append(np.nan)
                stderrs.append(np.nan)
                continue
            means.append(runs.mean())
            stderrs.append(runs.std(ddof=1) / np.sqrt(len(runs)))

        ax.errorbar(levels, means, yerr=stderrs,
                    marker='o', markersize=6, capsize=5,
                    linewidth=2.2, elinewidth=1.5, capthick=1.5,
                    label=method)

    ax.set_xlabel("Missingness level", fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.4, linestyle="--")
    fig.tight_layout()
    plt.show()
from sklearn.decomposition import PCA

def apply_pca_5d(X, n_components=5, random_state=0):
    """Reduce dataset to its leading principal components (5 by default).

    ``random_state`` is fixed so the projection is reproducible across runs
    when scikit-learn's ``svd_solver='auto'`` picks its randomized solver.
    The previous version passed no seed, so results could vary between kernels.
    """
    pca = PCA(n_components=n_components, random_state=random_state)
    return pca.fit_transform(X)

def impute(X, strategy="mean"):
    """Fill NaNs column by column with that column's NaN-ignoring mean or median.

    Returns a new tensor; ``X`` (2-D) is not modified. Columns without NaNs are
    copied unchanged. ``"median"`` uses ``torch.nanmedian``, which returns the
    *lower* of the two middle values for an even count. A column that is
    entirely NaN stays NaN.

    Raises ValueError for an unknown ``strategy``, but only if some column
    actually contains a NaN.
    """
    reducers = {"mean": torch.nanmean, "median": torch.nanmedian}
    filled = X.clone()

    for j in range(X.shape[1]):
        column = X[:, j]
        missing = torch.isnan(column)
        if not missing.any():
            continue
        reducer = reducers.get(strategy)
        if reducer is None:
            raise ValueError("strategy must be 'mean' or 'median'")
        filled[missing, j] = reducer(column)

    return filled

def get_labels_pfn(model, train_X):
    """Cluster data with a PFN model using two forward passes.

    Pass 1 conditions every batch element on count index 0 and reads the
    predicted cluster count as ``argmax + 1`` of the model's second output.
    Pass 2 re-runs the model conditioned on that predicted count. Its point
    logits give the labels, and its count output gives the returned count.

    Parameters
    ----------
    model : callable ``model(X, num_clusters) -> (point_logits, count_logits)``
        e.g. the ``Transformer`` above; ``num_clusters`` is a long ``(1, B)`` tensor.
    train_X : tensor ``(S, B, F)``
        Sequence-first input; the notebook passes ``B = 1`` via ``unsqueeze(1)``.
        Any B is supported.

    Returns
    -------
    predictions : numpy int array ``(S, B)``
    predictions_cluster_count : long tensor ``(B,)``
    elapsed : float
        Wall-clock seconds for both passes.
    """
    start_time = time.time()
    n_batch = train_X.shape[1]
    with torch.no_grad():
        # Count index 0 = "unknown"; created on the input's device.
        unknown_count = torch.zeros((1, n_batch), dtype=torch.long, device=train_X.device)
        _, cluster_output = model(train_X, unknown_count)
        predicted_count = torch.argmax(cluster_output, dim=-1) + 1  # (1, B)
        logits, cluster_output = model(train_X, predicted_count)
        predictions = torch.argmax(logits, dim=-1).cpu().numpy()
        predictions_cluster_count = (torch.argmax(cluster_output, dim=-1) + 1).squeeze(0)
    elapsed = time.time() - start_time
    return predictions, predictions_cluster_count, elapsed


def get_labels_gmm(X, batch_class, n_init=None, random_state=42, max_components=10):
    """Fit GaussianMixture models with 1..max_components components; keep the AIC- and BIC-best labels.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
    batch_class : ignored; kept for call-site compatibility.
    n_init : int or None
        Initialisations per fit; ``None`` is treated as 1, because sklearn
        rejects ``None``.
    random_state : int
        Seed for every fit.
    max_components : int
        Largest number of components tried (default 10, as before).

    Returns
    -------
    aic_labels, bic_labels : ndarray (n_samples,)
        Hard labels of the lowest-AIC / lowest-BIC model.
    time_tot : float
        Summed wall-clock seconds spent in ``fit`` only.
    """
    n_init = 1 if n_init is None else n_init
    best_aic, best_bic = float('inf'), float('inf')
    aic_labels, bic_labels = [], []
    time_tot = 0
    for component in range(1, max_components + 1):
        model = GaussianMixture(n_components=component, n_init=n_init, reg_covar=1e-04, random_state=random_state)
        start_time = time.time()
        model.fit(X)
        time_tot += time.time() - start_time
        current_aic = model.aic(X)
        current_bic = model.bic(X)

        improves_aic = current_aic < best_aic
        improves_bic = current_bic < best_bic
        if not (improves_aic or improves_bic):
            continue
        # Only predict when the labels might be kept.
        predictions = model.predict(X)
        if improves_aic:
            best_aic = current_aic
            aic_labels = predictions
        if improves_bic:
            best_bic = current_bic
            bic_labels = predictions

    return aic_labels, bic_labels, time_tot

In [None]:
# PFN architecture (must match the checkpoints) and experiment settings.
d_model, nhead, nhid, nlayers = 256, 4, 512, 4
n_init = 10
num_features = 5
num_outputs = 10
mean_precision_prior = 0.1
weight_concentration_prior = 1.0
miss_levels = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
num_iterations = 20
start_seed = 3000
exp = f"gls2_pfn_{start_seed}"

# Built with os.path.join so the paths resolve on any OS (backslash literals
# only work on Windows and contain invalid escape sequences).
model_path_10 = os.path.join("models", "models_missingness", "pfn_easy_5D.pt")
model_path_20 = os.path.join("models", "models_missingness", "pfn_hard_5D.pt")


def _load_pfn(path):
    """Build a Transformer with the settings above, load CPU weights from ``path``, return it in eval mode."""
    net = Transformer(d_model, nhead, nhid, nlayers, in_features=num_features, buckets_size=num_outputs)
    checkpoint = torch.load(path, weights_only=True, map_location=torch.device('cpu'))
    net.load_state_dict(checkpoint['model_state_dict'])
    return net.eval()


pfn_10 = _load_pfn(model_path_10)
pfn_20 = _load_pfn(model_path_20)
start_time = time.time()


In [None]:
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import MinMaxScaler
# ---------------------------- Dataset: RNA ----------------------------
# Features: PCA to 5 components. The PFN gets a min-max scaled copy; the
# (B)GMM baselines use the unscaled PCA scores.
data_rna = apply_pca_5d(pd.read_csv("datasets/data_top_20_RNA.csv").to_numpy())

# Labels: second column of the labels CSV, integer-encoded.
labels_str = pd.read_csv("datasets/labels_RNA.csv").to_numpy()[:, 1]
le = LabelEncoder()
labels_rna = le.fit_transform(labels_str)

scaler = MinMaxScaler()
train_X = torch.from_numpy(scaler.fit_transform(data_rna)).float()
X_true = torch.from_numpy(data_rna).float()
batch_class = 5  # NOTE(review): presumably the true number of RNA classes - verify
train_Y = labels_rna


######################################## Data GLS  #################################################
# data_gls = pd.read_csv("datasets/GWAS_cl_data1.csv")
# data_gls = data_gls.drop(columns=['entropy', 'Unnamed: 0'])
# le = LabelEncoder()
# labels = le.fit_transform(data_gls['Cluster'].values)

# X_data_gls  = (data_gls.drop(columns=['Cluster']).to_numpy())
# train_Y = torch.tensor(labels, dtype=torch.long)


# scaler = MinMaxScaler()
# train_X = torch.tensor(scaler.fit_transform(X_data_gls), dtype = torch.float)
# X_true = torch.tensor(X_data_gls)
# batch_class= 3
######################################## Data GLS 2 #################################################
# data_gls2 = pd.read_csv("datasets/GWAS_cl_data2.csv")
# data_gls2 = data_gls2.drop(columns=['entropy', 'Unnamed: 0'])
# le = LabelEncoder()
# labels = le.fit_transform(data_gls2['Cluster'].values)

# X_data_gls2  = (data_gls2.drop(columns=['Cluster']).to_numpy())
# train_Y = torch.tensor(labels, dtype=torch.long)


# scaler = MinMaxScaler()
# train_X = torch.tensor(scaler.fit_transform(X_data_gls2), dtype = torch.float)
# X_true = torch.tensor(X_data_gls2) 
# batch_class = 3 

#################################################### Data GLS3 ####################################
# data_gls3 = pd.read_csv("datasets/GWAS_cl_data3.csv")
# data_gls3 = data_gls3.drop(columns=['entropy', 'Unnamed: 0'])
# le = LabelEncoder()
# labels = le.fit_transform(data_gls3['Cluster'].values)

# X_data_gls3  = (data_gls3.drop(columns=['Cluster']).to_numpy())
# train_Y = torch.tensor(labels, dtype=torch.long)


# scaler = MinMaxScaler()
# train_X = torch.tensor(scaler.fit_transform(X_data_gls3), dtype = torch.float)
# X_true = torch.tensor(X_data_gls3) 
# batch_class = 3 


In [None]:
from collections import defaultdict

methods = [
    "pfn_easy", "pfn_hard",
    "bgmm_mean", "bgmm_median",
    "aic_mean", "aic_median",
    "bic_mean", "bic_median",
]

# Methods whose run time is recorded. The baseline timings ("bgmm", "gmm") are
# both taken from the mean-imputed fits, so they are measured on the same input.
methods2 = ["pfn_easy", "pfn_hard", "bgmm", "gmm"]
scores = {m: {"ari": defaultdict(list),
              "ami": defaultdict(list),
              "purity": defaultdict(list)}
          for m in methods}
time_dict = {m: defaultdict(list) for m in methods2}

for iteration in range(num_iterations):
    # One nested set of masks per repetition, seeded for reproducibility.
    rng_mask = np.random.default_rng(start_seed + iteration)
    masks = generate_progressive_masks(X_true.shape[0], X_true.shape[1], miss_levels, rng_mask)
    for miss_level in miss_levels:
        # PFN input: scaled features with missing entries set to the -2 sentinel.
        train_X_masked = train_X.clone()
        train_X_masked[masks[miss_level]] = -2
        # Baseline input: unscaled features with NaN at the same positions
        # (match_mask returns a new tensor and leaves its inputs untouched).
        X_true_masked = match_mask(train_X_masked, X_true)

        # zeros_pad = torch.zeros((X_data_gls.shape[0], 2), dtype=train_X.dtype) # for gls1 only
        # train_X_masked = torch.cat([train_X_masked, zeros_pad], dim=1)

        # imputations
        X_true_mean = impute(X_true_masked, strategy='mean')
        X_true_median = impute(X_true_masked, strategy='median')

        # clustering with masked (PFN) / imputed (baselines) data
        labels_pfn_10, cluster_pfn_10, time_tot_pfn_easy = get_labels_pfn(pfn_10, train_X_masked.unsqueeze(1))
        labels_pfn_10 = labels_pfn_10.squeeze()
        labels_pfn_20, cluster_pfn_20, time_tot_pfn_hard = get_labels_pfn(pfn_20, train_X_masked.unsqueeze(1))
        labels_pfn_20 = labels_pfn_20.squeeze()

        cluster_bgmm_mean, labels_bgmm_mean, probs_bgmm_mean, time_tot_bgmm_mean = get_clusters_bayesian_gmm(
            X=X_true_mean, batch_classes=None, n_init=n_init,
            mean_precision_prior=mean_precision_prior, weight_concentration_prior=weight_concentration_prior)

        cluster_bgmm_median, labels_bgmm_median, probs_bgmm_median, time_tot_bgmm_median = get_clusters_bayesian_gmm(
            X=X_true_median, batch_classes=None, n_init=n_init,
            mean_precision_prior=mean_precision_prior, weight_concentration_prior=weight_concentration_prior)

        labels_aic_mean, labels_bic_mean, time_tot_gmm_mean = get_labels_gmm(X_true_mean, batch_class, n_init=n_init)
        labels_aic_median, labels_bic_median, time_tot_gmm_median = get_labels_gmm(X_true_median, batch_class, n_init=n_init)

        time_dict['pfn_easy'][miss_level].append(time_tot_pfn_easy)
        time_dict['pfn_hard'][miss_level].append(time_tot_pfn_hard)
        time_dict['bgmm'][miss_level].append(time_tot_bgmm_mean)
        time_dict['gmm'][miss_level].append(time_tot_gmm_mean)

        # evaluation
        predicted = {
            "pfn_easy": labels_pfn_10,
            "pfn_hard": labels_pfn_20,
            "bgmm_mean": labels_bgmm_mean,
            "bgmm_median": labels_bgmm_median,
            "aic_mean": labels_aic_mean,
            "aic_median": labels_aic_median,
            "bic_mean": labels_bic_mean,
            "bic_median": labels_bic_median,
        }
        for name, labels in predicted.items():
            scores[name]['ari'][miss_level].append(adjusted_rand_score(train_Y, labels))
            scores[name]['ami'][miss_level].append(adjusted_mutual_info_score(train_Y, labels, average_method="arithmetic"))
            scores[name]['purity'][miss_level].append(purity_score(train_Y, labels))

In [None]:
# (metric key, figure title, y-axis label) for each summary figure.
_SCORE_PLOTS = (
    ("ari", "Clustering performance vs. Missingness (ARI)", "Adjusted Rand Index (ARI)"),
    ("ami", "Clustering performance vs. Missingness (AMI)", "Adjusted Mutual Information (AMI)"),
    ("purity", "Clustering performance vs. Missingness (Purity)", "Purity"),
)
for metric_key, plot_title, axis_label in _SCORE_PLOTS:
    plot_scores(scores, metric_key, plot_title, axis_label)