In [None]:
import os
import sys
import random
import warnings
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.cluster import KMeans
from sklearn.mixture import BayesianGaussianMixture
from sklearn.metrics import accuracy_score
from sklearn.datasets import make_circles, load_iris
from sklearn import preprocessing
# Work from the project root (two directories above the notebook's starting directory) so the
# local modules imported below and relative paths such as `model_path` are resolved from there.
# Resolve the root only once per kernel: re-running this cell would otherwise climb two more
# directory levels every time.
if "dir_" not in globals():
    dir_ = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
os.chdir(dir_)
# Local modules
import utils
import prior
import transformer
import main
from sklearn.metrics import rand_score, cluster, adjusted_mutual_info_score
from sklearn.metrics.cluster import adjusted_rand_score
# Settings
matplotlib.use("TkAgg")
warnings.filterwarnings("ignore")
%matplotlib inline


In [None]:
import math

class Normalize(nn.Module):
    """Standardisation layer computing ``(x - mean) / std``.

    ``mean`` and ``std`` are kept as plain attributes (not parameters or buffers), so they
    are not part of ``state_dict`` and are not moved by ``.to(device)``.
    """

    def __init__(self, mean, std):
        super().__init__()
        self.mean, self.std = mean, std

    def forward(self, x):
        centred = x - self.mean
        return centred / self.std


class Transformer(nn.Transformer):
    """Encoder-only transformer that scores every point and predicts a cluster count.

    Built on ``nn.Transformer`` (sequence-first layout); only ``self.encoder`` is used in
    ``forward``. ``self.decoder`` is reassigned to a linear output head, replacing the
    decoder stack that ``nn.Transformer`` creates.
    """

    def __init__(self, d_model, nhead, nhid, nlayers, in_features=1, buckets_size=100):
        super(Transformer, self).__init__(d_model=d_model, nhead=nhead, dim_feedforward=nhid, dropout=0,
                                          num_encoder_layers=nlayers)
        self.model_type = 'Transformer'
        self.src_mask = None  # not read anywhere in this class
        self.d_model = d_model
        self.nhead = nhead
        self.nhid = nhid
        self.nlayers = nlayers
        self.embed_size = d_model
        self.linear_x = nn.Linear(in_features, d_model)  # point features -> d_model
        # Output head shared by point tokens and the count-query token.
        self.decoder = nn.Linear(d_model, buckets_size)
        self.linear_num_clusters = nn.Linear(in_features, d_model)  # embeds the count-query token
        # Conditioning embedding with buckets_size + 1 rows (indices 0..buckets_size).
        self.embedding = nn.Embedding(buckets_size + 1, self.embed_size)

    def _generate_mask(self, size, device=None):
        """Return a (size, size) additive attention mask (0 = attend, -inf = blocked).

        The first ``size - 1`` tokens (points) may not attend to the last token (the
        count query); the last token attends to every position.

        Args:
            size: sequence length including the count-query token.
            device: device for the mask; ``None`` means the default (CPU) device.
        """
        mask = torch.zeros((size, size), dtype=torch.float32, device=device)
        mask[:size - 1, size - 1] = float('-inf')
        return mask

    def forward(self, X, num_clusters):
        """Run the encoder over the points plus one appended count-query token.

        Args:
            X: (S, B, in_features) sequence-first point features.
            num_clusters: integer indices in ``[0, buckets_size]`` whose embedding is added
                to every token; its embedding must broadcast against (S + 1, B, d_model),
                e.g. shape (1, B).

        Returns:
            (point_logits, count_logits) of shapes (S, B, buckets_size) and
            (1, B, buckets_size).
        """
        train = self.linear_x(X)  # (S, B, d_model)
        # Constant -1 features for the count-query token; new_full keeps it on X's device and
        # dtype, so the model also runs when moved to a GPU.
        cluster_input = X.new_full((1, X.shape[1], X.shape[2]), -1)
        cluster_embedding = self.linear_num_clusters(cluster_input)
        train = torch.cat((train, cluster_embedding), dim=0)  # (S + 1, B, d_model)
        # Conditioning embedding broadcast over the sequence axis.
        train = train + self.embedding(num_clusters)
        src_mask = self._generate_mask(train.shape[0], device=train.device)
        output = self.encoder(train, mask=src_mask)  # (S + 1, B, d_model)
        output = self.decoder(output)  # (S + 1, B, buckets_size)
        return output[:-1], output[-1:]

In [None]:
from sklearn import preprocessing
import torch
from scipy.stats import invwishart
import random

def sort(x, y, X_true, centers):
    """Canonically relabel clusters by distance from the origin, then shuffle the points.

    Rows are ordered by the Euclidean norm of ``x``. Cluster ids are renumbered 0, 1, 2, ...
    in the order each id is first met along that ordering (stopping once ``centers`` ids are
    numbered). Finally ``x``, the new labels and ``X_true`` are shuffled with one shared
    ``np.random.permutation``.

    Returns:
        (x, y, X_true) after relabelling and shuffling.
    """
    order = np.argsort(np.linalg.norm(x, axis=1))
    x_by_norm, y_by_norm, true_by_norm = x[order], y[order], X_true[order]

    relabel = {}
    for label in y_by_norm:
        if len(relabel) == centers:
            break
        if label not in relabel:
            relabel[label] = len(relabel)

    new_labels = np.array([relabel[label] for label in y_by_norm])
    shuffle = np.random.permutation(len(x_by_norm))
    return x_by_norm[shuffle], new_labels[shuffle], true_by_norm[shuffle]


def generate_bayesian_gmm_data(
        batch_size=32,
        start_seq_len = 100,
        seq_len=500,
        num_features=2,
        min_classes=1,
        num_classes=10,
        weight_concentration_prior=1.0,  # Dirichlet concentration for the mixture weights
        mean_prior=0.0,  # Value of every coordinate of the prior mean of the cluster means
        mean_precision_prior=0.01,  # Cluster means ~ N(mean_prior, Sigma / mean_precision_prior)
        degrees_of_freedom_prior=None,  # Degrees of freedom of the inverse-Wishart covariance prior (default: features)
        covariance_prior=None,  # Scale matrix of the inverse-Wishart covariance prior (default: identity)
        seed=None,
        nan_frac = None,
        **kwargs
):
    """Sample a batch of synthetic Gaussian-mixture clustering datasets.

    All ``batch_size`` datasets share one length drawn with
    ``random.randint(start_seq_len, seq_len)`` (both ends inclusive). For each dataset:

    * ``n_components`` is drawn uniformly from ``[min_classes, num_classes]``;
    * mixture weights ~ Dirichlet(``weight_concentration_prior`` * 1), and per-component point
      counts ~ Multinomial(length, weights), so a component can receive zero points;
    * each covariance ~ inverse-Wishart(``degrees_of_freedom_prior``, ``covariance_prior``)
      and each mean ~ N(``mean_prior`` * 1, covariance / ``mean_precision_prior``);
    * the points are min-max scaled per feature to give ``x``. If ``nan_frac`` is truthy,
      roughly a fraction (drawn uniformly from ``[0, nan_frac)``) of the entries of ``x`` is
      overwritten with the missing-value marker ``-2``, never every feature of one row;
    * ``sort`` renumbers the labels and shuffles the points.

    Args:
        seed: when not None, seeds ``np.random``, ``random`` and ``torch`` globally.
        **kwargs: accepted and ignored.

    Returns:
        clusters_x: float32 tensor (length, batch_size, num_features) of scaled features.
        clusters_y: float32 tensor (length, batch_size) of cluster labels.
        clusters_x_true: float32 tensor (length, batch_size, num_features) of the unscaled
            points, without missing-value markers.
        batch_classes: integer tensor (1, batch_size) holding each dataset's ``n_components``.
            This can exceed the number of distinct labels when a component got no points.
    """
    if seed is not None:
        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)

    #features = random.randint(2, num_features)
    features = num_features
    if covariance_prior is None:
        covariance_prior = np.eye(features)

    if degrees_of_freedom_prior is None:
        degrees_of_freedom_prior = features

    # One length for the whole batch so the datasets stack into a single array.
    seq_len = random.randint(start_seq_len, seq_len)
    clusters_x = np.zeros((batch_size, seq_len, num_features))
    clusters_y = np.zeros((batch_size, seq_len))
    clusters_x_true = np.zeros((batch_size, seq_len, num_features))
    batch_classes = []
    for i in range(batch_size):
        n_components = np.random.randint(min_classes, num_classes + 1)

        # Sample weights from Dirichlet prior
        weights = np.random.dirichlet(np.ones(n_components) * weight_concentration_prior)

        # Number of points per component; a component may receive zero points
        counts = np.random.multinomial(seq_len, weights)

        # Sample cluster parameters (covariance from the inverse-Wishart prior, then the mean)
        means = []
        covariances = []
        for _ in range(n_components):
            Sigma = invwishart.rvs(df=degrees_of_freedom_prior, scale=covariance_prior)
            mu = np.random.multivariate_normal(np.full(features, mean_prior), Sigma / mean_precision_prior)
            means.append(mu)
            covariances.append(Sigma)
        # Sample points from GMM; points are grouped by component here and shuffled by sort()
        X = []
        y = []
        for k in range(n_components):
            n_k = counts[k]
            X_k = np.random.multivariate_normal(means[k], covariances[k], size=n_k)
            X.append(X_k)
            y.append(np.full(n_k, k))

        X = np.vstack(X)
        y = np.concatenate(y)
        # No-op while features == num_features; would zero-pad if fewer features were sampled
        X = np.pad(X, ((0, 0), (0, num_features - features)), mode='constant')

        # Scale each feature to [0, 1]; the unscaled X is kept as the "true" data
        x = preprocessing.MinMaxScaler().fit_transform(X)
        n_samples, n_features = seq_len, features

        if nan_frac:
            fraction_missing = np.random.uniform(0, nan_frac)
            n_elements = n_samples * n_features
            n_missing = int(fraction_missing * n_elements)

            if n_missing > 0:
                # Choose n_missing distinct flat positions in the (n_samples, n_features) grid
                all_indices = np.arange(n_elements)
                np.random.shuffle(all_indices)
                chosen = all_indices[:n_missing]

                rows, cols = np.unravel_index(chosen, (n_samples, n_features))

                # Ensure no row is fully NaN
                full_rows = [r for r in np.unique(rows) if np.sum(rows == r) == n_features]
                to_keep = []
                for r in full_rows:
                    row_mask = rows == r
                    row_positions_in_chosen = np.where(row_mask)[0]
                    keep_pos = np.random.choice(row_positions_in_chosen)
                    to_keep.append(keep_pos)
                # Remove all "to_keep" indices from chosen at once
                chosen = np.delete(chosen, to_keep)
                rows, cols = np.unravel_index(chosen, (n_samples, n_features))
                # -2 marks a missing entry; it lies outside the [0, 1] range of the scaled data
                x[rows, cols] = -2

        # Renumber labels by distance from the origin and shuffle x, y and X together
        x, y, X_true = sort(x, y, X, n_components)
        clusters_x[i] = x
        clusters_y[i] = y
        clusters_x_true[i] = X_true
        batch_classes.append(n_components)

    # Sequence-first layout: (length, batch, ...)
    clusters_x_true = torch.tensor(clusters_x_true, dtype=torch.float32).permute(1, 0, 2)
    clusters_x = torch.tensor(clusters_x, dtype=torch.float32).permute(1, 0, 2)
    clusters_y = torch.tensor(clusters_y, dtype=torch.float32).permute(1, 0)
    batch_classes = torch.tensor(batch_classes).unsqueeze(0)

    return clusters_x, clusters_y, clusters_x_true, batch_classes


def get_clusters_bayesian_gmm(X, batch_classes=None, n_components=10, weight_concentration_prior=1.0,
                              mean_prior=None, mean_precision_prior=None, degrees_of_freedom_prior=None,
                              covariance_prior=None, random_state=42, n_init=None):
    """Select a cluster count per dataset with variational Bayesian GMMs.

    For every dataset in the batch, fits ``BayesianGaussianMixture`` models with
    1..``n_components`` components. It keeps the fit with the largest
    ``lower_bound_ - log(k!)``, i.e. the ELBO penalised by log k! (presumably to discount
    the k! equivalent label permutations).

    Args:
        X: array-like (seq_len, batch, features), sequence-first. Feature columns that are
            zero for every point of a dataset are dropped before fitting.
        batch_classes: unused; kept for call compatibility.
        n_components: largest number of components tried.
        weight_concentration_prior, mean_precision_prior: forwarded to sklearn.
        mean_prior, degrees_of_freedom_prior, covariance_prior: forwarded to sklearn. When
            None they default to a zero mean, ``features`` degrees of freedom and an identity
            covariance (sized after dropping all-zero columns). Supplied values must match that
            reduced dimensionality.
        random_state: seed forwarded to every fitted model.
        n_init: initialisations per fit; ``None`` means 1 (sklearn rejects ``None``).

    Returns:
        (clusters, time_arr, labels, probs), each a list with one entry per dataset:
        - clusters: the chosen number of components;
        - time_arr: total fitting wall time in seconds;
        - labels: the chosen model's ``predict``;
        - probs: the chosen model's ``predict_proba``.
    """
    if n_init is None:
        n_init = 1
    clusters, time_arr, labels, probs = [], [], [], []
    for batch in range(X.shape[1]):
        X_curr = np.asarray(X[:, batch, :])
        # Drop feature columns that are zero for every point (e.g. zero padding).
        X_curr = X_curr[:, (X_curr != 0).any(axis=0)]
        features = X_curr.shape[1]
        curr_mean_prior = np.zeros(features) if mean_prior is None else mean_prior
        curr_covariance_prior = np.eye(features) if covariance_prior is None else covariance_prior
        curr_degrees_of_freedom_prior = (features if degrees_of_freedom_prior is None
                                         else degrees_of_freedom_prior)

        time_tot = 0.0
        best_score = float('-inf')
        best_cluster, best_label, best_prob = -1, [], []
        for component in range(1, n_components + 1):
            bgm = BayesianGaussianMixture(
                n_components=component,
                weight_concentration_prior_type='dirichlet_distribution',
                weight_concentration_prior=weight_concentration_prior,
                mean_prior=curr_mean_prior,
                mean_precision_prior=mean_precision_prior,
                n_init=n_init,
                degrees_of_freedom_prior=curr_degrees_of_freedom_prior,
                max_iter=5000,
                covariance_prior=curr_covariance_prior,
                random_state=random_state,
            )
            start_time = time.perf_counter()
            bgm.fit(X_curr)
            time_tot += time.perf_counter() - start_time
            score = bgm.lower_bound_ - np.log(math.factorial(component))
            if score > best_score:
                best_score = score
                best_cluster = component
                best_label = bgm.predict(X_curr)
                best_prob = bgm.predict_proba(X_curr)
        time_arr.append(time_tot)
        clusters.append(best_cluster)
        labels.append(best_label)
        probs.append(best_prob)
    return clusters, time_arr, labels, probs

def get_nll_transformer(model, X, true_labels, batch_classes, n_components=5, condition=False):
    """Permutation-invariant negative log-likelihood of the true labels under the model.

    For each dataset in the batch the model runs twice. The first pass is conditioned on
    index 0; the second is conditioned on (argmax of the first pass's cluster-count logits)
    + 1. The log-softmax of the second pass's point logits is scored against the true labels
    under the best one-to-one relabelling of the true classes onto the first
    ``n_components`` output columns. That minimum is found with a linear assignment solver
    instead of enumerating permutations.

    Args:
        model: callable ``model(x, num_clusters) -> (point_logits, cluster_logits)``. For a
            single dataset, point_logits is (S, 1, C) and cluster_logits is (1, 1, C').
        X: (S, B, F) sequence-first features.
        true_labels: (S, B) integer-valued labels.
        batch_classes: (B,) tensor; only the shape of each item is used, to build the
            conditioning tensor.
        n_components: number of output columns the true classes may be mapped onto.
        condition: unused; kept for call compatibility.

    Returns:
        list[float] with one loss per dataset. The loss is ``inf`` when there are more true
        classes than usable output columns.
    """
    from scipy.optimize import linear_sum_assignment

    losses = []
    with torch.no_grad():  # evaluation only; no autograd graph needed
        for batch_index in range(true_labels.shape[1]):
            train_x = X[:, batch_index].unsqueeze(1)  # (S, 1, F)
            cond_shape = batch_classes[batch_index].unsqueeze(-1).shape
            _, cluster_output = model(train_x, torch.full(cond_shape, 0, dtype=torch.long))
            predicted_index = cluster_output.view(-1, cluster_output.shape[2]).argmax(dim=1).item()
            logits, _ = model(train_x, torch.full(cond_shape, predicted_index + 1, dtype=torch.long))
            log_probs = F.log_softmax(logits.squeeze(1), dim=-1)[:, :n_components]  # (S, <=n_components)

            # Re-index the true classes to 0..K-1 (identity when labels are already contiguous).
            classes, class_ids = torch.unique(true_labels[:, batch_index].long(), return_inverse=True)
            class_ids = class_ids.to(log_probs.device)
            num_true_classes = len(classes)
            if num_true_classes > log_probs.shape[1]:
                # No one-to-one relabelling exists.
                losses.append(float("inf"))
                continue

            # cost[c, j] = -sum of log p(column j) over the points of class c. The minimum-cost
            # assignment of classes to columns is the best relabelling.
            cost = log_probs.new_zeros((num_true_classes, log_probs.shape[1]))
            cost.index_add_(0, class_ids, -log_probs)
            _, cols = linear_sum_assignment(cost.double().cpu().numpy())
            column_of_class = torch.as_tensor(cols, dtype=torch.long, device=log_probs.device)
            rows = torch.arange(log_probs.shape[0], device=log_probs.device)
            losses.append(-log_probs[rows, column_of_class[class_ids]].mean().item())
    return losses

In [None]:
# Architecture of the pretrained Cluster-PFN. These values must match the checkpoint
# loaded below (load_state_dict is strict by default).
d_model = 256
nhead = 4
nhid = 512
nlayers = 4
in_features = 2
num_features = in_features
num_outputs = 10

# Synthetic-data prior and evaluation settings.
num_classes = 8
batch_size = 1
mean_precision_prior = 0.1
weight_concentration_prior = 1.0
start_seq_len = 100
seq_len = 500
iterations = 20
seed = 0
n_init = 1
# Not referenced by the visible cells.
start_seed = 0
nr = 1

model_path = "models/models_original/pfn_hard_2D.pt"
model = Transformer(
    d_model, nhead, nhid, nlayers,
    in_features=in_features,
    buckets_size=num_outputs,
)
checkpoint = torch.load(model_path, weights_only=True, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print('')

In [None]:
%matplotlib inline
import itertools 

def _gmm_permutation_nll(prob, target, n_pred, n_true):
    """Best-relabelling negative log-likelihood of ``target`` under GMM probabilities.

    Returns the minimum, over one-to-one maps ``perm`` of the ``n_true`` true classes onto
    the ``n_pred`` fitted components, of ``-mean(log prob[i, perm[target[i]]])``. This is
    equivalent to an exhaustive search over
    ``itertools.permutations(range(n_pred), n_true)`` with ``nn.NLLLoss``, solved as a
    linear assignment problem.

    Args:
        prob: (N, n_pred) component probabilities (``predict_proba`` output).
        target: length-N sequence of integer labels in ``[0, n_true)``.
        n_pred: number of fitted components.
        n_true: number of true classes.

    Returns:
        float loss. It is ``inf`` when ``n_pred < n_true`` or when every relabelling hits a
        zero probability.
    """
    from scipy.optimize import linear_sum_assignment

    if n_pred < n_true:
        return float("inf")
    target = np.asarray(target, dtype=np.int64)
    with np.errstate(divide="ignore", invalid="ignore"):
        neg_log = -np.log(np.asarray(prob, dtype=np.float64))
    # cost[c, j] = -sum of log prob[i, j] over the points i with target[i] == c.
    cost = np.zeros((n_true, neg_log.shape[1]))
    np.add.at(cost, target, neg_log)
    # A NaN loss can never be the minimum, so forbid NaN entries.
    cost[np.isnan(cost)] = np.inf
    try:
        _, cols = linear_sum_assignment(cost)
    except ValueError:
        # Only infinite-cost assignments exist.
        return float("inf")
    return float(np.mean(neg_log[np.arange(len(target)), cols[target]]))


def compute_correct_clusters(batch_size = 100, iterations = 1000,start_seq_len = 100,
                             seq_len=500,num_features=2, min_classes=1, num_classes=10, weight_concentration_prior=1,
                             mean_precision_prior = None, n_init = None, seed = 0,
                             condition_on_true_count=True):
    """Benchmark the Cluster-PFN ``model`` against Bayesian GMM VI on synthetic GMM data.

    Each of ``iterations`` batches is drawn with ``generate_bayesian_gmm_data`` and seeded
    with ``seed + iteration``. The function uses the module-level ``model`` and the
    ``utils`` helpers.

    The model's cluster-count prediction comes from a forward pass conditioned on index 0,
    so the model is never given the true count when predicting it; this is the same first
    pass ``get_nll_transformer`` uses. Point labels come from a second pass conditioned on
    the true count (default) or on the model's own predicted count.

    Args:
        batch_size, start_seq_len, seq_len, num_features, min_classes, num_classes,
        weight_concentration_prior, mean_precision_prior: forwarded to the data generator.
        iterations: number of batches to evaluate.
        n_init: GMM initialisations per fit; ``None`` means 1.
        seed: base seed.
        condition_on_true_count: if True, condition the label pass on the true number of
            components (oracle). If False, condition it on the model's predicted count.

    Returns:
        dict with keys 'model' and 'gmm'. Each maps 'clusters' (per dataset: was the count
        predicted exactly), 'rand_index', 'ami', 'purity' and 'nll' to numpy arrays.

    NOTE(review): the model's 'rand_index' comes from
    utils.compute_external_metrics(..., 'rand_index'), while the GMM's uses
    adjusted_rand_score. Confirm both are ARI before comparing them.
    """
    gmm_n_init = 1 if n_init is None else n_init  # sklearn rejects n_init=None
    clusters_gmm, ari_gmm, ami_gmm, purity_gmm, nll_gmm = [], [], [], [], []
    clusters_model, r_model, a_model, p_model, nll_transformer = [], [], [], [], []
    for iteration in range(iterations):
        train_X, train_Y, X_true, batch_classes = generate_bayesian_gmm_data(
            batch_size=batch_size, start_seq_len=start_seq_len, seq_len=seq_len,
            num_features=num_features, min_classes=min_classes, num_classes=num_classes,
            mean_precision_prior=mean_precision_prior,
            weight_concentration_prior=weight_concentration_prior,
            seed=seed + iteration)
        train_Y = train_Y.cpu()

        with torch.no_grad():
            # Count prediction without access to the true count (conditioning index 0).
            _, count_output = model(train_X, torch.zeros_like(batch_classes))
            predictions_cluster_count = (torch.argmax(count_output, dim=-1) + 1).squeeze(0)  # (B,)
            label_condition = (batch_classes if condition_on_true_count
                               else predictions_cluster_count.unsqueeze(0))
            logits, _ = model(train_X, label_condition)
        predictions = torch.argmax(logits, dim=-1).cpu().numpy()

        # Model metrics
        r_model.extend(utils.compute_external_metrics(train_Y, predictions, 'rand_index'))
        a_model.extend(utils.compute_external_metrics(train_Y, predictions, 'ami'))
        p_model.extend(utils.compute_external_metrics(train_Y, predictions, 'purity'))

        batch_classes = batch_classes.squeeze(0).detach()
        bayes_cluster_prediction, _, labels, probabilities = get_clusters_bayesian_gmm(
            X_true, batch_classes, mean_precision_prior=mean_precision_prior,
            weight_concentration_prior=weight_concentration_prior, n_init=gmm_n_init)

        nll_transformer.extend(get_nll_transformer(model, train_X, train_Y, batch_classes=batch_classes,
                                                   n_components=num_classes, condition=True))
        true_labels = [train_Y[:, i].to(torch.int).tolist() for i in range(train_Y.shape[1])]
        for i, (n_pred, label_pred, prob_pred) in enumerate(zip(bayes_cluster_prediction, labels, probabilities)):
            curr_true_label = true_labels[i]
            n_true = batch_classes[i]
            clusters_gmm.append(n_true == n_pred)
            clusters_model.append(predictions_cluster_count[i] == n_true)
            ari_gmm.append(adjusted_rand_score(curr_true_label, label_pred))
            ami_gmm.append(adjusted_mutual_info_score(curr_true_label, label_pred))
            # Purity: for each predicted cluster, count its majority true class.
            matrix = cluster.contingency_matrix(curr_true_label, label_pred)
            purity_gmm.append(np.sum(np.amax(matrix, axis=0)) / np.sum(matrix))
            nll_gmm.append(_gmm_permutation_nll(prob_pred, curr_true_label, n_pred, int(n_true)))
    return {
        'model': {
            'clusters': np.array(clusters_model),
            'rand_index': np.array(r_model),
            'ami': np.array(a_model),
            'purity': np.array(p_model),
            'nll': np.array(nll_transformer)
        },
        'gmm': {
            'clusters': np.array(clusters_gmm),
            'rand_index': np.array(ari_gmm),
            'ami': np.array(ami_gmm),
            'purity': np.array(purity_gmm),
            'nll' : np.array(nll_gmm)
        }
    }

In [None]:
results = compute_correct_clusters(
    batch_size=batch_size,
    iterations=iterations,
    start_seq_len=start_seq_len,
    seq_len=seq_len,
    num_features=num_features,
    num_classes=num_classes,
    weight_concentration_prior=weight_concentration_prior,
    mean_precision_prior=mean_precision_prior,
    n_init=n_init,
    seed=seed,
)

# Per-method metric arrays, one entry per evaluated dataset.
# Cluster-PFN results
clusters_pfn, rand_index_pfn, ami_pfn, purity_pfn, nll_pfn = (
    results['model'][metric] for metric in ('clusters', 'rand_index', 'ami', 'purity', 'nll')
)
# GMM results
clusters_gmm, rand_index_gmm, ami_gmm, purity_gmm, nll_gmm = (
    results['gmm'][metric] for metric in ('clusters', 'rand_index', 'ami', 'purity', 'nll')
)


In [None]:
def compute_violin_plots_external_metrics(ami_pfn, ami_bayes, ari_pfn, ari_bayes, purity_pfn, purity_bayes,
                                          ylim=(0.4, 1.01)):
    """Draw split violin plots comparing Cluster-PFN and Bayesian GMM VI external metrics.

    Args:
        ami_pfn, ari_pfn, purity_pfn: per-dataset scores for Cluster-PFN (equal lengths).
        ami_bayes, ari_bayes, purity_bayes: per-dataset scores for the Bayesian GMM (equal
            lengths; may differ from the PFN lengths).
        ylim: y-axis limits. Pass ``None`` to let matplotlib choose, e.g. to show negative
            ARI/AMI values.

    Returns:
        None; the figure is displayed with ``plt.show()``.
    """
    df = pd.DataFrame({
        'AMI': np.concatenate([ami_pfn, ami_bayes]),
        'ARI': np.concatenate([ari_pfn, ari_bayes]),
        'Purity': np.concatenate([purity_pfn, purity_bayes]),
        # Each model's rows are labelled with that model's own length.
        'Model': ['PFN'] * len(ami_pfn) + ['Bayesian GMM VI'] * len(ami_bayes)
    })

    # Long form for seaborn: one row per (Model, Metric, Score).
    df_melted = df.melt(id_vars='Model', var_name='Metric', value_name='Score')

    fig, ax = plt.subplots(figsize=(10, 6))
    sns.violinplot(
        data=df_melted,
        x='Metric',
        y='Score',
        hue='Model',
        split=True,
        inner=None,
        cut=0,
        palette='tab10',
        ax=ax,
    )

    ax.set_title("Cluster-PFN vs Bayesian GMM VI external metrics", fontsize=16)
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.tick_params(axis='x', labelsize=16)
    ax.set_xlabel("Metric", fontsize=16)
    ax.set_ylabel("Score", fontsize=16)
    ax.legend(loc='lower right')
    fig.tight_layout()
    plt.show()
    return None

In [None]:
# Fraction of datasets whose number of clusters was predicted exactly; each method is
# normalised by its own number of entries.
print(f"pfn cluster accuracy: {np.sum(clusters_pfn) / len(clusters_pfn)}")
print(f"VI cluster accuracy: {np.sum(clusters_gmm) / len(clusters_gmm)}")


In [None]:
# NOTE(review): the ARI column is fed from the 'rand_index' arrays. The GMM one is
# adjusted_rand_score; confirm the PFN one (utils.compute_external_metrics) is adjusted too.
compute_violin_plots_external_metrics(
    ami_pfn=ami_pfn, ami_bayes=ami_gmm,
    ari_pfn=rand_index_pfn, ari_bayes=rand_index_gmm,
    purity_pfn=purity_pfn, purity_bayes=purity_gmm,
)