# Executive Summary 
We'll demonstrate how to use the class GaussianAsymmetricSBM for soft-decision clustering of items that have
no attributes of their own, but for which there is a contingency table whose values depend on pairwise group membership $(r,s)$.

The class GaussianAsymmetricSBM assumes each value is sampled from a normal distribution parameterized by $(r,s)$.
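
In symbols, one way to write this assumption is shown below. The notation is introduced here for illustration and is not taken from the class itself: $z_i$ is the (latent) group of item $i$, and $\mu_{rs}$ and $\sigma_{rs}^2$ are the mean and variance for the ordered group pair $(r,s)$.

$$
A_{ij} \mid z_i = r,\ z_j = s \;\sim\; \mathcal{N}\!\left(\mu_{rs},\ \sigma_{rs}^2\right), \qquad i \neq j .
$$

Because the model is asymmetric, $\mu_{rs}$ need not equal $\mu_{sr}$. The diagonal entries $A_{ii}$ are excluded.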

We'll build an example with 90 elements from 3 groups, each group containing 30 elements.
Let's see whether we can recover the correct clustering.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from bicluster import GaussianAsymmetricSBM
import warnings
from tqdm.autonotebook import tqdm

warnings.filterwarnings('ignore')
np.set_printoptions(linewidth=200)

Here is where we generate our contingency table. Because we know the correct membership of every item, we can compare the clustering results against it.

In [None]:
# --- Example Usage ---

# 1. Generate Synthetic Data with an Asymmetric Block Structure (K=3)
K = 3
N = 90
groups = np.repeat(range(K), N // K)
np.random.shuffle(groups)

A_synth = np.zeros((N, N))
mu_matrix = np.array([
    [10, 20, 5],  
    [30, 10, 40], 
    [5, 40, 10]
])
sigma_matrix = np.array([
    [1, 1, 1],
    [1, 1, 1],
    [1, 1, 1]
])

for i in range(N):
    for j in range(N):
        if i == j:
            A_synth[i, j] = 0 # Ensure diagonal is excluded / set to a dummy value
            continue

        r = groups[i]
        s = groups[j]
        # sigma_matrix is treated as a variance here, hence the square root to get a standard deviation
        A_synth[i, j] = np.random.normal(loc=mu_matrix[r, s], scale=np.sqrt(sigma_matrix[r, s]))
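
The double loop above can also be written without loops. The following is a minimal sketch of an equivalent vectorized generation, assuming the same block means and unit variances. The `_demo` names and the seeded generator are introduced here for illustration only, so the values will not match `A_synth` entry for entry.

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded generator (an assumption, for reproducibility)

K, N = 3, 90
groups_demo = np.repeat(np.arange(K), N // K)
rng.shuffle(groups_demo)

mu_demo = np.array([[10, 20, 5], [30, 10, 40], [5, 40, 10]], dtype=float)
var_demo = np.ones((K, K))  # treated as variances, as in the loop version

# Fancy indexing expands the K x K parameters into N x N per-entry parameters:
# entry (i, j) picks the parameter for the group pair (groups[i], groups[j]).
loc = mu_demo[groups_demo[:, None], groups_demo[None, :]]
scale = np.sqrt(var_demo[groups_demo[:, None], groups_demo[None, :]])

A_demo = rng.normal(loc=loc, scale=scale)
np.fill_diagonal(A_demo, 0.0)  # same dummy diagonal as the loop version
```

The loop version is easier to read. The vectorized version mainly pays off for larger $N$.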


## Default Initialization
We have 3 different ways of initializing the hill climb:

* **None** -- we use summary statistics of the contingency matrix $A$ to generate initial estimates of the parameters. We do not recommend this method.
* **"kmeans"** -- we use k-means for an initial hard-decision clustering to generate initial estimates of the parameters.
* **"spectral"** -- we use Scikit-Learn's SpectralBiclustering class for an initial hard-decision clustering to generate initial estimates of the parameters.

  
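We'll start by initializing from summary statistics. We're looking for the items to be split into 3 groups of 30 each. Any other split is
either wrong or **very** wrong. Note that a 30/30/30 split is necessary but not sufficient: the counts alone do not show that the right items ended up together.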

In [None]:
from collections import Counter
print(f"Original group counts: {Counter(groups)}")
print("\nAccuracy check (against synthetic truth):")


for t in tqdm(range(20),desc='Attempt #'):

    # 2. Instantiate and Fit the Model
    # Start with K=3 groups
    sbm = GaussianAsymmetricSBM(K=3, max_iter=50, tol=1e-5)
    sbm.fit(A_synth)
        
    # Get hard cluster assignments from the soft assignments
    hard_clusters = np.argmax(sbm.tau_i, axis=1)
    
    # Check if the recovered structure resembles the ground truth    
    print(f"Predicted group counts: {Counter(hard_clusters)}")
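
Group counts are only a rough check, and the cluster labels returned by a fit are arbitrary (the model's cluster 0 may correspond to true group 2). A minimal sketch of a label-invariant accuracy check follows. The function `best_match_accuracy` is a hypothetical helper written for this notebook, not part of the `bicluster` package, and it tries every relabeling, which is only practical for small $K$.

```python
from itertools import permutations
import numpy as np

def best_match_accuracy(true_labels, pred_labels, K):
    """Fraction of items labeled correctly under the best relabeling of clusters."""
    true_labels = np.asarray(true_labels)
    pred_labels = np.asarray(pred_labels)
    # confusion[t, p] counts items with true group t and predicted cluster p.
    confusion = np.zeros((K, K), dtype=int)
    np.add.at(confusion, (true_labels, pred_labels), 1)
    # Try every mapping of predicted cluster p -> true group perm[p] (K! options).
    best = max(
        sum(confusion[perm[p], p] for p in range(K))
        for perm in permutations(range(K))
    )
    return best / len(true_labels)

truth = np.array([0, 0, 1, 1, 2, 2])
pred = np.array([2, 2, 0, 0, 1, 1])  # same partition, different label names
print(best_match_accuracy(truth, pred, K=3))
```

In this notebook the call would be `best_match_accuracy(groups, hard_clusters, K=3)`. A value of 1.0 means the recovered partition matches the synthetic truth exactly.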

## Using the K-means Initializer
Let's see how well the k-means initializer works:

In [None]:
from collections import Counter
print(f"Original group counts: {Counter(groups)}")
print("\nAccuracy check (against synthetic truth):")


for t in tqdm(range(20),desc='Attempt #'):

    # 2. Instantiate and Fit the Model
    # Start with K=3 groups
    sbm = GaussianAsymmetricSBM(K=3, max_iter=50, tol=1e-5)
    sbm.fit(A_synth,init='kmeans')
        
    # Get hard cluster assignments from the soft assignments
    hard_clusters = np.argmax(sbm.tau_i, axis=1)
    
    # Check if the recovered structure resembles the ground truth    
    print(f"Predicted group counts: {Counter(hard_clusters)}")

## Using the Spectral Initializer

Now let's initialize with Scikit-Learn's SpectralBiclustering algorithm and compare.

In [None]:
from collections import Counter
print(f"Original group counts: {Counter(groups)}")
print("\nAccuracy check (against synthetic truth):")


for t in tqdm(range(20),desc='Attempt #'):

    # 2. Instantiate and Fit the Model
    # Start with K=3 groups
    sbm = GaussianAsymmetricSBM(K=3, max_iter=50, tol=1e-5)
    sbm.fit(A_synth,init='spectral')
        
    # Get hard cluster assignments from the soft assignments
    hard_clusters = np.argmax(sbm.tau_i, axis=1)
    
    # Check if the recovered structure resembles the ground truth    
    print(f"Predicted group counts: {Counter(hard_clusters)}")

## Before and After Clustering

I want to demonstrate visually how well the biclustering worked, so I'll plot a heat map of the
contingency table $A$ and then plot the same table $A$ with rows and columns permuted so
that the elements from group 0 are listed first, then those from group 1, and lastly those from group 2.

If the clustering worked correctly, you should see a heat map made up of blocks indexed by
group pairs $(r,s)$, each appearing relatively constant. If the clustering is incorrect, you
should see artifacts in the blocks, a telltale sign that the clustering is not quite right.

In [None]:
# 2. Instantiate and Fit the Model
# Start with K=3 groups
sbm = GaussianAsymmetricSBM(K=3, max_iter=50, tol=1e-5)
sbm.fit(A_synth,init='spectral')

# 3. Analyze Results
print("\n--- Model Results ---")
print(f"Final Block Means (mu_rs):\n{sbm.mu.round(2)}")

# Get hard cluster assignments from the soft assignments
hard_clusters = np.argmax(sbm.tau_i, axis=1)

# Check if the recovered structure resembles the ground truth
print("\nAccuracy check (against synthetic truth):")
from collections import Counter
print(f"Original group counts: {Counter(groups)}")
print(f"Predicted group counts: {Counter(hard_clusters)}")
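
As a cross-check on the fitted means, the block means can also be computed directly from the hard assignments. The sketch below uses a hypothetical helper, `empirical_block_means`, and assumes that the rows and columns of `sbm.mu` are indexed by the same cluster labels as the columns of `sbm.tau_i`. It is demonstrated on a tiny noiseless matrix so it runs on its own.

```python
import numpy as np

def empirical_block_means(A, labels, K):
    """Mean of the off-diagonal entries of A within each (r, s) block of hard labels."""
    A = np.asarray(A, dtype=float)
    labels = np.asarray(labels)
    off_diag = ~np.eye(A.shape[0], dtype=bool)  # exclude A[i, i]
    means = np.full((K, K), np.nan)             # NaN marks blocks with no entries
    for r in range(K):
        for s in range(K):
            mask = np.outer(labels == r, labels == s) & off_diag
            if mask.any():
                means[r, s] = A[mask].mean()
    return means

# Tiny noiseless example: 2 groups of 2 items each.
labels_small = np.array([0, 0, 1, 1])
mu_small = np.array([[1.0, 2.0], [3.0, 4.0]])
A_small = mu_small[labels_small[:, None], labels_small[None, :]]
print(empirical_block_means(A_small, labels_small, K=2))
```

In the notebook, `empirical_block_means(A_synth, hard_clusters, K=3)` should be close to `sbm.mu` if the assumption above holds and the soft assignments are nearly one-hot.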

This is what the heat map of $A$ looks like before the elements are reordered by cluster. Some structure is visible, but the original ordering
does not place elements of the same group next to each other, as you'll see below.

In [None]:
plt.title("Original matrix $A$")
plt.imshow(A_synth,interpolation='none'); plt.colorbar();

Now we'll re-sort the elements by cluster assignment to reveal the block structure:

In [None]:
sorted_rows = np.argsort(hard_clusters)
sorted_mesh  = np.ix_(sorted_rows,sorted_rows)
A_sorted = A_synth[sorted_mesh]
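
To see what this permutation does, here is a small self-contained sketch on a 4 x 4 matrix. The names `labels_small`, `A_small`, `order`, and `A_perm` are made up for the example.

```python
import numpy as np

labels_small = np.array([1, 0, 1, 0])
A_small = np.arange(16).reshape(4, 4)

# Indices that list the items of group 0 first, then those of group 1.
order = np.argsort(labels_small, kind="stable")
# np.ix_ builds an open mesh so rows and columns are permuted together.
A_perm = A_small[np.ix_(order, order)]

print(order)
print(A_perm)
```

Entry `A_perm[a, b]` equals `A_small[order[a], order[b]]`. The stable sort only fixes the order of items within a group. Any within-group order yields the same block structure.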

And here are the before and after plots.

The dark diagonal line appears because we deliberately ignore $A[i,i]$ as uninformative when clustering,
so we've zeroed out the diagonal entries.

In [None]:
plt.figure(figsize=(14,6))
plt.subplot(1,2,1)
plt.title("Original matrix $A$")
plt.imshow(A_synth,interpolation='none');

plt.subplot(1,2,2)
plt.title('Matrix $A$ permuted by group assignment. Should be \napproximately constant within each block $(r,s)$')
plt.imshow(A_sorted,interpolation='none');
plt.xlabel("Group S")
plt.ylabel("Group R");
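
The blocks are easier to judge when their edges are marked. The sketch below computes where each group ends in the sorted ordering. `block_boundaries` is a hypothetical helper, not part of the `bicluster` package.

```python
import numpy as np

def block_boundaries(labels, K):
    """Positions where one group ends and the next begins after sorting by label."""
    sizes = np.bincount(np.asarray(labels), minlength=K)
    return np.cumsum(sizes)[:-1]  # drop the final edge at N

labels_small = np.array([2, 0, 1, 0, 2, 2])
print(block_boundaries(labels_small, K=3))
```

In the notebook, each boundary `b` from `block_boundaries(hard_clusters, K=3)` can be drawn on the sorted heat map with `plt.axhline(b - 0.5)` and `plt.axvline(b - 0.5)`. The half-pixel offset is needed because `imshow` centers pixels on integer coordinates.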