# Weakly-supervised end-to-end PCa grading with Attention-guided Kernel Density Matrices (WiSDOM)

In [None]:
!nvidia-smi

## Imports, libs, storage, wandb

In [None]:
import numpy as np
import tensorflow as tf
import wandb
from keras import optimizers
from keras.layers import Input, Dense
import keras
import os
import pandas as pd
from collections import OrderedDict
from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics import classification_report, ConfusionMatrixDisplay, mean_absolute_error, cohen_kappa_score, pairwise_distances
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt


## Data

#### Load Dataframes

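Some cleaning is done as well: slides whose tile mosaic is missing on disk are dropped, along with a few individually excluded image IDs.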

In [None]:
# Directory that holds one tile mosaic per slide, stored as <image_id>.jpeg.
def _mosaic_path(image_id):
  """Return the path of the tile mosaic that belongs to `image_id`."""
  return os.path.join('/root/data/KDM/data/tile_mosaics', image_id + '.jpeg')


def _labels_with_mosaic(df, excluded_ids=()):
  """Map image_id -> isup_grade for the rows of `df` that are usable.

  A row is kept when its mosaic file exists on disk and its id is not listed
  in `excluded_ids`. The CSV row order is preserved. Excluding inside the
  filter (instead of `del` afterwards) avoids a KeyError when an excluded id
  is already missing from the split.
  """
  grades = df.set_index('image_id')['isup_grade'].to_dict(into=OrderedDict)
  excluded = set(excluded_ids)
  return {
      image_id: grade
      for image_id, grade in grades.items()
      if image_id not in excluded and os.path.isfile(_mosaic_path(image_id))
  }


df_train = pd.read_csv("/root/data/KDM/data/wsi_train.csv")
df_val = pd.read_csv("/root/data/KDM/data/wsi_val.csv")
df_test = pd.read_csv("/root/data/KDM/data/wsi_test.csv")

# Image ids that are removed by hand from the train and test splits.
train_dict = _labels_with_mosaic(
    df_train,
    excluded_ids=('1c36b3db47d83f1436bd260288c5723f',
                  '50203fbd5de280144cbb16749814a3fe'))
val_dict = _labels_with_mosaic(df_val)
test_dict = _labels_with_mosaic(
    df_test,
    excluded_ids=('ecae863e7c478594aa4c84ce132b3825',))

# Parallel lists/arrays: ids, integer grades and mosaic paths share one order.
train_features = list(train_dict.keys())
train_labels = np.array(list(train_dict.values()))
val_features = list(val_dict.keys())
val_labels = np.array(list(val_dict.values()))
test_features = list(test_dict.keys())
test_labels = np.array(list(test_dict.values()))
train_paths = [_mosaic_path(image_id) for image_id in train_features]
val_paths = [_mosaic_path(image_id) for image_id in val_features]
test_paths = [_mosaic_path(image_id) for image_id in test_features]

In [None]:
# Sanity check: the id list and the label array must line up, entry by entry,
# with the (id, grade) pairs of train_dict.
for k, v, i in zip(train_features, train_labels, train_dict.items()):
  assert k == i[0] and v == i[1]


### Preprocessing

Helper functions to load a sample from storage

In [None]:
def decode_img(img):
  """Decode a JPEG-compressed byte string into a 3-channel image tensor.

  The decoded image is returned as is: no resizing and no rescaling happen here.
  """
  decoded = tf.io.decode_jpeg(img, channels=3)
  return decoded

def process_path(file_path, label):
  """Load the JPEG stored at `file_path` and pair it with `label`.

  The image is not rescaled (it is not divided by 255) and `label` is passed
  through unchanged.
  """
  raw_bytes = tf.io.read_file(file_path)
  return decode_img(raw_bytes), label

def process_sample(file_path):
  """Load the JPEG stored at `file_path` as a float32 image divided by 255."""
  pixels = decode_img(tf.io.read_file(file_path))
  return tf.cast(pixels, tf.float32) / 255.0


#### One hot encoding of the labels

In [None]:
# Fit the encoder once, on the training labels, and reuse that fitted mapping
# for validation and test. Re-fitting per split would give each split its own
# column layout whenever a grade is absent from it.
encoder = OneHotEncoder()
y_train_onehot = encoder.fit_transform(train_labels.reshape(-1, 1))
y_train_one_hot = y_train_onehot.toarray()
y_val_onehot = encoder.transform(val_labels.reshape(-1, 1))
y_val_one_hot = y_val_onehot.toarray()
y_test_onehot = encoder.transform(test_labels.reshape(-1, 1))
y_test_one_hot = y_test_onehot.toarray()

#### Regression Labels

In [None]:
# Regression targets: the integer grade divided by 5.
# presumably grades span 0..5 so the targets land in [0, 1] -- TODO confirm
train_reg_labels = train_labels / 5
val_reg_labels = val_labels / 5
test_reg_labels = test_labels / 5

In [None]:
# Show the distinct regression targets of the training split and their counts.
np.unique(train_reg_labels, return_counts=True)

## Datasets

#### Batch size

On an A100, the largest batch size that fits in memory is 8.

In [None]:
# Module-level batch size setting.
batch_size = 8

#### Dataset factory that yields batches of WSI mosaics


In [None]:
def create_dataset(paths,labels,batch_size,shuffle=False):
  """Build a batched, prefetched tf.data pipeline of (image, label) pairs.

  Arguments:
    paths: image file paths, one per sample
    labels: labels aligned with `paths`
    batch_size: number of samples per batch
    shuffle: when True, shuffle the (path, label) pairs with a buffer that
             spans `len(labels)` elements before the images are loaded
  Returns:
    a tf.data.Dataset that yields (images, labels) batches
  """
  n_items = len(labels)
  path_ds = tf.data.Dataset.from_tensor_slices(paths)
  label_ds = tf.data.Dataset.from_tensor_slices(labels)
  pairs = tf.data.Dataset.zip((path_ds, label_ds))
  if shuffle:
    pairs = pairs.shuffle(buffer_size=n_items)
  loaded = pairs.map(process_path)
  batched = loaded.batch(batch_size)
  return batched.prefetch(tf.data.experimental.AUTOTUNE)


In [None]:
# One-hot labelled training set plus regression-labelled train/val/test sets.
# Only the two training datasets are shuffled.
# NOTE(review): batch_size=4 is hard-coded here while the module-level
# `batch_size` variable is set to 8 -- confirm which value is intended.
train_dataset_one_hot = create_dataset(train_paths, y_train_one_hot, batch_size=4, shuffle=True)
train_dataset = create_dataset(train_paths, train_reg_labels, batch_size=4, shuffle=True)
val_dataset = create_dataset(val_paths, val_reg_labels, shuffle=False, batch_size=4)
test_dataset = create_dataset(test_paths, test_reg_labels, shuffle=False, batch_size=4)

In [None]:
def get_samples_from_each_class(dataset, n_samples, n_classes=6):
    """Collect up to `n_samples` samples of every class from a batched dataset.

    Arguments:
        dataset: iterable of (batch_samples, batch_labels); the class of a
                 sample is taken as the argmax of its label vector
        n_samples: maximum number of samples to keep per class
        n_classes: number of classes (defaults to 6)
    Returns:
        stacked_samples: the kept samples, grouped by class in ascending order
        stacked_labels: one-hot labels (depth `n_classes`), one per kept sample
                        and in the same order as `stacked_samples`
    """
    samples_per_class = {class_idx: [] for class_idx in range(n_classes)}
    target_total = n_samples * n_classes
    samples_collected = 0

    for batch_samples, batch_labels in dataset:
        for sample, label in zip(batch_samples, batch_labels):
            class_idx = int(np.argmax(label))
            bucket = samples_per_class[class_idx]
            if len(bucket) < n_samples:
                bucket.append(sample)
                samples_collected += 1

            if samples_collected == target_total:
                break

        # Stop scanning the dataset as soon as every class is full.
        if samples_collected == target_total:
            break

    stacked_samples = tf.stack(
        [sample for bucket in samples_per_class.values() for sample in bucket])
    # Emit one label per sample actually collected. Emitting `n_samples`
    # labels per class would misalign labels and samples whenever the dataset
    # runs out before a class is full.
    stacked_labels = tf.stack(
        [tf.one_hot(class_idx, depth=n_classes)
         for class_idx, bucket in samples_per_class.items()
         for _ in bucket])

    return stacked_samples, stacked_labels

In [None]:
# Draw up to 36 mosaics per class from the one-hot training set.
prototypes, prototype_labels = get_samples_from_each_class(train_dataset_one_hot, 36)

In [None]:
# Display the shape of the stacked prototype tensor.
prototypes.shape

## KDM

### Patches

In [None]:
class Patches(tf.keras.layers.Layer):
    """Cuts a batch of images into flattened square patches.

    Arguments:
        patch_size: side length of each patch
        image_size: side length of the input images (the same value is used
                    for both spatial axes)
        strides: step between consecutive patches along each spatial axis
    """

    def __init__(self, patch_size, image_size, strides):
        super().__init__()
        self.patch_size = patch_size
        self.strides = strides
        # Number of patches along one side for a VALID sliding window.
        self.num_patches = (image_size - patch_size) // strides + 1

    def call(self, images):
        """Return patches reshaped to (batch, num_patches ** 2, patch_dims)."""
        window = [1, self.patch_size, self.patch_size, 1]
        step = [1, self.strides, self.strides, 1]
        extracted = tf.image.extract_patches(
            images=images,
            sizes=window,
            strides=step,
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        flat_dim = extracted.shape[-1]
        n_images = tf.shape(images)[0]
        target_shape = [n_images, self.num_patches ** 2, flat_dim]
        return tf.reshape(extracted, target_shape)

### KDM Functions and Kernels

In [None]:
from keras.backend import dtype
def dm2comp(dm):
    '''
    Split a factorized density matrix into its weights and its vectors
    Arguments:
     dm: tensor of shape (bs, n, d + 1)
    Returns:
     w: tensor of shape (bs, n), taken from channel 0 of the last axis
     v: tensor of shape (bs, n, d), the remaining channels of the last axis
    '''
    weights = dm[:, :, 0]
    vectors = dm[:, :, 1:]
    return weights, vectors


def comp2dm(w, v):
    '''
    Pack weights and vectors into a factorized density matrix
    Arguments:
     w: tensor of shape (bs, n)
     v: tensor of shape (bs, n, d)
    Returns:
     dm: tensor of shape (bs, n, d + 1) whose channel 0 on the last axis
         holds the weights, followed by the vector channels
    '''
    w_column = w[:, :, tf.newaxis]
    return tf.concat([w_column, v], axis=2)

def samples2dm(samples):
    '''
    Construct a factorized density matrix from a batch of samples
    each sample will have the same weight. Samples that are all
    zero will be ignored (they receive weight 0).
    Arguments:
        samples: tensor of shape (bs, n, d)
    Returns:
        dm: tensor of shape (bs, n, d + 1)
    '''
    samples = tf.convert_to_tensor(samples)
    if not samples.dtype.is_floating:
        samples = tf.cast(samples, tf.float32)
    # tf.reduce_any only accepts boolean input, so test for non-zero entries
    # explicitly and turn the mask back into a numeric weight.
    is_nonzero = tf.reduce_any(tf.not_equal(samples, 0), axis=-1)
    w = tf.cast(is_nonzero, samples.dtype)
    # divide_no_nan keeps the weights at 0 (not NaN) for an all-zero bag.
    w = tf.math.divide_no_nan(w, tf.reduce_sum(w, axis=-1, keepdims=True))
    return comp2dm(w, samples)

def pure2dm(psi):
    '''
    Wrap a batch of state vectors as single-component density matrices
    Arguments:
     psi: tensor of shape (bs, d)
    Returns:
     dm: tensor of shape (bs, 1, d + 1); the single component has weight 1
    '''
    unit_weight = tf.ones_like(psi[:, 0:1])           # shape (bs, 1)
    flat_dm = tf.concat([unit_weight, psi], axis=1)   # shape (bs, d + 1)
    return flat_dm[:, tf.newaxis, :]


def dm2discrete(dm):
    '''
    Turn the components of a density matrix into a discrete distribution
    Arguments:
     dm: tensor of shape (bs, n, d + 1)
    Returns:
     prob: vector of probabilities (bs, d)
    '''
    w, v = dm2comp(dm)
    # Weights are rescaled to sum to one over the components.
    w_normalized = w / tf.reduce_sum(w, axis=-1, keepdims=True)
    # Each component vector is scaled to unit L2 norm before squaring.
    unit_v = v / tf.linalg.norm(v, axis=-1, keepdims=True)
    return tf.einsum('...j,...ji->...i', w_normalized, unit_v ** 2,
                     optimize="optimal")


def dm2distrib(dm, sigma):
    '''
    Creates a Gaussian mixture distribution from the components of a density
    matrix with an RBF kernel
    Arguments:
     dm: tensor of shape (bs, n, d + 1)
     sigma: sigma parameter of the RBF kernel
    Returns:
     gm: mixture of Gaussian distribution with shape (bs, )

    NOTE(review): `tfd` is not imported anywhere in the visible part of this
    file (presumably `tensorflow_probability.distributions`), so calling this
    function as is would raise a NameError -- confirm and add the import.
    '''
    w, v = dm2comp(dm)
    v = tf.cast(v, tf.float32)
    sigma = tf.cast(sigma, tf.float32)
    # The mixture probabilities are the DM weights; every component is a
    # Gaussian located at its DM vector with scale sigma / sqrt(2).
    gm = tfd.MixtureSameFamily(reparameterize=True,
            mixture_distribution=tfd.Categorical(
                                    probs=w),
            components_distribution=tfd.Independent( tfd.Normal(
                    loc=v,  # one Gaussian per density-matrix component
                    scale=sigma / np.sqrt(2.)),
                    reinterpreted_batch_ndims=1))
    return gm


def pure_dm_overlap(x, dm, kernel):
    '''
    Overlap of the state phi(x) with a density matrix in the RKHS
    induced by a kernel
    Arguments:
      x: tensor of shape (bs, d)
     dm: tensor of shape (bs, n, d + 1)
     kernel: kernel function
              k: (bs, d) x (bs, n, d) -> (bs, n)
    Returns:
     overlap: tensor with shape (bs, )
    '''
    w, v = dm2comp(dm)
    squared_kernel = kernel(x, v) ** 2
    # Weighted sum of the squared kernel values over the components.
    return tf.einsum('...i,...i->...', w, squared_kernel)

## Kernels

class CompTransKernelLayer(tf.keras.layers.Layer):
    '''
    Kernel obtained by applying a transformation to both inputs and then
    evaluating an inner kernel on the transformed vectors.
    '''

    def __init__(self, transform, kernel):
        '''
        Arguments:
            transform: function applied to the inputs before the kernel
                    f:(bs, d) -> (bs, D)
            kernel: a kernel function
                    k:(bs, n, D)x(m, D) -> (bs, n, m)
        '''
        super().__init__()
        self.transform = transform
        self.kernel = kernel

    def call(self, A, B):
        '''
        Input:
            A: tensor of shape (bs, n, d)
            B: tensor of shape (m, d)
        Result:
            K: tensor of shape (bs, n, m)
        '''
        in_shape = tf.shape(A)  # (bs, n, d)
        n_batch, n_vecs = in_shape[0], in_shape[1]
        # The transform works on 2D input, so fold the first two axes together.
        flat_A = tf.reshape(A, [n_batch * n_vecs, in_shape[2]])
        flat_A = self.transform(flat_A)
        A_t = tf.reshape(flat_A, [n_batch, n_vecs, tf.shape(flat_A)[1]])
        B_t = self.transform(B)
        return self.kernel(A_t, B_t)

    def log_weight(self):
        '''Delegate to the inner kernel's log_weight.'''
        return self.kernel.log_weight()

class RBFKernelLayer(tf.keras.layers.Layer):
    '''
    Layer that evaluates an RBF kernel between two sets of vectors.
    '''

    def __init__(self, sigma, dim, trainable=True, min_sigma=1e-3):
        '''
        Builds a layer that calculates the rbf kernel between two set of vectors
        Arguments:
            sigma: RBF scale parameter. If it is a tf.Variable it will be used as is.
                     Otherwise it will create a variable with the given value.
            dim: dimension of the input vectors; only used by `log_weight`
            trainable: whether a newly created sigma variable is trainable
                     (ignored when `sigma` is already a tf.Variable)
            min_sigma: lower bound applied to sigma whenever it is used
        '''
        super().__init__()
        # isinstance is required here: variable objects are instances of
        # tf.Variable subclasses, so an exact `type(...) is tf.Variable`
        # comparison never matches and the variable would be copied
        # instead of shared.
        if isinstance(sigma, tf.Variable):
            self.sigma = sigma
        else:
            self.sigma = tf.Variable(sigma, dtype=tf.float32, trainable=trainable)
        self.dim = dim
        self.min_sigma = min_sigma

    def call(self, A, B):
        '''
        Input:
            A: tensor of shape (bs, n, d)
            B: tensor of shape (m, d)
        Result:
            K: tensor of shape (bs, n, m)
        '''
        shape_A = tf.shape(A)
        shape_B = tf.shape(B)
        # Squared distances via ||a||^2 + ||b||^2 - 2 a.b
        A_norm = tf.norm(A, axis=-1)[..., tf.newaxis] ** 2
        B_norm = tf.norm(B, axis=-1)[tf.newaxis, tf.newaxis, :] ** 2
        A_reshaped = tf.reshape(A, [-1, shape_A[2]])
        AB = tf.matmul(A_reshaped, B, transpose_b=True)
        AB = tf.reshape(AB, [shape_A[0], shape_A[1], shape_B[0]])
        dist2 = A_norm + B_norm - 2. * AB
        # Rounding can push the expansion slightly below zero; clip it back.
        dist2 = tf.clip_by_value(dist2, 0., np.inf)
        sigma = tf.clip_by_value(self.sigma, self.min_sigma, np.inf)
        K = tf.exp(-dist2 / (2. * sigma ** 2.))
        return K

    def log_weight(self):
        '''Return -dim * log(sigma + 1e-12) - dim * log(4 * pi), with sigma clipped from below.'''
        sigma = tf.clip_by_value(self.sigma, self.min_sigma, np.inf)
        return - self.dim * tf.math.log(sigma + 1e-12) - self.dim * np.log(4 * np.pi)



'''
Keras layer version of CosineKernel
'''
class CosineKernelLayer(tf.keras.layers.Layer):
    '''
    Layer that evaluates the cosine kernel between two sets of vectors.
    '''

    def __init__(self):
        super().__init__()
        self.eps = 1e-6

    def call(self, A, B):
        '''
        Input:
            A: tensor of shape (bs, n, d)
            B: tensor of shape (m, d)
        Result:
            K: tensor of shape (bs, n, m)
        '''
        A_norm = tf.expand_dims(tf.norm(A, axis=-1), axis=-1)
        B_norm = tf.expand_dims(tf.norm(B, axis=-1), axis=-1)
        # divide_no_nan maps zero-norm vectors to zero instead of NaN.
        A_unit = tf.math.divide_no_nan(A, A_norm)
        B_unit = tf.math.divide_no_nan(B, B_norm)
        return tf.einsum("...nd,md->...nm", A_unit, B_unit)

    def log_weight(self):
        '''The cosine kernel contributes a constant log weight of 0.'''
        return 0


class CrossProductKernelLayer(tf.keras.layers.Layer):
    '''
    Product of two kernels, each evaluated on its own slice of the input
    vectors: the first `dim1` features go to `kernel1`, the rest to `kernel2`.
    '''

    def __init__(self, dim1, kernel1, kernel2):
        '''
        Arguments:
            dim1: the dimension of the first part of the input vector
            kernel1: a kernel function
                    k1:(bs, n, dim1)x(m, dim1) -> (bs, n, m)
            kernel2: a kernel function
                    k2:(bs, n, d - dim1)x(m, d - dim1) -> (bs, n, m)
        '''
        super().__init__()
        self.dim1 = dim1
        self.kernel1 = kernel1
        self.kernel2 = kernel2

    def call(self, A, B):
        '''
        Input:
            A: tensor of shape (bs, n, d)
            B: tensor of shape (m, d)
        Result:
            K: tensor of shape (bs, n, m)
        '''
        split = self.dim1
        k_first = self.kernel1(A[:, :, :split], B[:, :split])
        k_second = self.kernel2(A[:, :, split:], B[:, split:])
        return k_first * k_second

    def log_weight(self):
        '''Sum of the log weights of both inner kernels.'''
        return self.kernel1.log_weight() + self.kernel2.log_weight()

## Layers and models

def l1_loss(vals):
    '''
    L1 norm of the L2-normalized rows of a batch of vectors, averaged over
    the batch
    Arguments:
        vals: tensor with shape (b_size, n)
    '''
    n_rows = tf.cast(tf.shape(vals)[0], dtype=tf.float32)
    row_norms = tf.norm(vals, axis=1)[:, tf.newaxis]
    unit_rows = vals / row_norms
    return tf.reduce_sum(tf.abs(unit_rows)) / n_rows

class KDMUnit(tf.keras.layers.Layer):
    """Kernel Density Matrix Unit
    Receives as input a factored density matrix represented by a set of vectors
    and weight values.
    Returns a resulting factored density matrix.
    Input shape:
        (batch_size, n_comp_in, dim_x + 1)
        where dim_x is the dimension of the input state
        and n_comp_in is the number of components of the input factorization.
        The weights of the input factorization of sample i are [i, :, 0],
        and the vectors are [i, :, 1:dim_x + 1].
    Output shape:
        (batch_size, n_comp, dim_y + 1)
        where dim_y is the dimension of the output state
        and n_comp is the number of components used to represent the train
        density matrix. The weights of the
        output factorization for sample i are [i, :, 0], and the vectors
        are [i, :, 1:dim_y + 1].
    Arguments:
        kernel: kernel layer called as kernel(in_v, c_x), mapping
                (batch_size, n_comp_in, dim_x) x (n_comp, dim_x)
                to (batch_size, n_comp_in, n_comp)
        dim_x: int. the dimension of the input state
        dim_y: int. the dimension of the output state
        x_train: bool. Whether to train or not the x components of the train
                       density matrix.
        y_train: bool. Whether to train or not the y components of the train
                       density matrix.
        w_train: bool. Whether to train or not the weights of the components
                       of the train density matrix.
        n_comp: int. Number of components used to represent
                 the train density matrix. Must be positive: the default of 0
                 fails while building the uniform weight initializer
                 (1 / n_comp).
        l1_act: float. Coefficient of the regularization term penalizing the l1
                       norm of the activations.
        l1_x: float. Coefficient of the regularization term penalizing the l1
                       norm of the x components.
        l1_y: float. Coefficient of the regularization term penalizing the l1
                       norm of the y components.
    """
    def __init__(
            self,
            kernel,
            dim_x: int,
            dim_y: int,
            x_train: bool = True,
            y_train: bool = True,
            w_train: bool = True,
            n_comp: int = 0,
            l1_x: float = 0.,
            l1_y: float = 0.,
            l1_act: float = 0.,
            **kwargs
    ):
        super().__init__(**kwargs)
        self.kernel = kernel
        self.dim_x = dim_x
        self.dim_y = dim_y
        self.x_train = x_train
        self.y_train = y_train
        self.w_train = w_train
        self.n_comp = n_comp
        self.l1_x = l1_x
        self.l1_y = l1_y
        self.l1_act = l1_act
        # The weight name is passed by keyword: the first positional
        # parameter of add_weight is not the name in every Keras version.
        self.c_x = self.add_weight(
            name="c_x",
            shape=(self.n_comp, self.dim_x),
            initializer=tf.keras.initializers.RandomNormal(),
            trainable=self.x_train)
        self.c_y = self.add_weight(
            name="c_y",
            shape=(self.n_comp, self.dim_y),
            initializer=tf.keras.initializers.Constant(np.sqrt(1./self.dim_y)),
            trainable=self.y_train)
        self.comp_w = self.add_weight(
            name="comp_w",
            shape=(self.n_comp,),
            initializer=tf.keras.initializers.Constant(1./self.n_comp),
            trainable=self.w_train)
        self.eps = 1e-10

    def call(self, inputs):
        # Weight regularizers
        if self.l1_x != 0:
            self.add_loss(self.l1_x * l1_loss(self.c_x))
        if self.l1_y != 0:
            self.add_loss(self.l1_y * l1_loss(self.c_y))
        # Keep the component weights strictly positive, then normalize them
        # to sum to 1.
        comp_w = tf.abs(self.comp_w) + 1e-6
        comp_w = comp_w / tf.reduce_sum(comp_w)
        in_w = inputs[:, :, 0]  # shape (b, n_comp_in)
        in_v = inputs[:, :, 1:] # shape (b, n_comp_in, dim_x)
        out_vw = self.kernel(in_v, self.c_x)  # shape (b, n_comp_in, n_comp)
        out_w = (tf.expand_dims(tf.expand_dims(comp_w, axis=0), axis=0) *
                 tf.square(out_vw)) # shape (b, n_comp_in, n_comp)
        # Floor the entries so the normalization below never divides by zero.
        out_w = tf.maximum(out_w, self.eps)
        out_w_sum = tf.reduce_sum(out_w, axis=2) # shape (b, n_comp_in)
        out_w = out_w / tf.expand_dims(out_w_sum, axis=2)
        # Mix the per-input-component rows with the input weights.
        out_w = tf.einsum('...i,...ij->...j', in_w, out_w, optimize="optimal")
                # shape (b, n_comp)
        if self.l1_act != 0:
            self.add_loss(self.l1_act * l1_loss(out_w))
        out_w = tf.expand_dims(out_w, axis=-1) # shape (b, n_comp, 1)
        # Broadcast the shared c_y vectors over the batch: (b, n_comp, dim_y)
        out_y_shape = tf.shape(out_w) + tf.constant([0, 0, self.dim_y - 1])
        out_y = tf.broadcast_to(tf.expand_dims(self.c_y, axis=0), out_y_shape)
        out = tf.concat((out_w, out_y), 2)
        return out

    def get_config(self):
        # The kernel object is not part of the returned config.
        config = {
            "dim_x": self.dim_x,
            "dim_y": self.dim_y,
            "n_comp": self.n_comp,
            "x_train": self.x_train,
            "y_train": self.y_train,
            "w_train": self.w_train,
            "l1_x": self.l1_x,
            "l1_y": self.l1_y,
            "l1_act": self.l1_act,
        }
        base_config = super().get_config()
        return {**base_config, **config}

    def compute_output_shape(self, input_shape):
        # Matches the tensor built in call(): (batch, n_comp, dim_y + 1).
        return (input_shape[0], self.n_comp, self.dim_y + 1)

class KDMOverlap(tf.keras.layers.Layer):
    """Kernel Density Matrix Overlap Unit
    Receives as input a vector and calculates its overlap with the unit density
    matrix.
    Input shape:
        (batch_size, dim_x)
        where dim_x is the dimension of the input state
    Output shape:
        (batch_size, )
    Arguments:
        kernel: a kernel function
        dim_x: int. the dimension of the input state
        x_train: bool. Whether to train or not the components of the train
                       density matrix.
        w_train: bool. Whether to train or not the weights of the components
                       of the train density matrix.
        n_comp: int. Number of components used to represent
                 the train density matrix. Must be positive: the default of 0
                 fails while building the uniform weight initializer
                 (1 / n_comp).
    """

    def __init__(
            self,
            kernel,
            dim_x: int,
            x_train: bool = True,
            w_train: bool = True,
            n_comp: int = 0,
            **kwargs
    ):
        super().__init__(**kwargs)
        self.kernel = kernel
        self.dim_x = dim_x
        self.x_train = x_train
        self.w_train = w_train
        self.n_comp = n_comp
        # The weight name is passed by keyword: the first positional
        # parameter of add_weight is not the name in every Keras version.
        self.c_x = self.add_weight(
            name="c_x",
            shape=(self.n_comp, self.dim_x),
            initializer=tf.keras.initializers.RandomNormal(),
            trainable=self.x_train)
        self.comp_w = self.add_weight(
            name="comp_w",
            shape=(self.n_comp,),
            initializer=tf.keras.initializers.Constant(1./self.n_comp),
            trainable=self.w_train)

    def call(self, inputs):
        comp_w = tf.abs(self.comp_w)
        # normalize comp_w to sum to 1
        comp_w = comp_w / tf.reduce_sum(comp_w)
        # Add a singleton component axis so the kernel sees (b, 1, dim_x).
        in_v = inputs[:, tf.newaxis, :]
        out_vw = self.kernel(in_v, self.c_x) ** 2 # shape (b, 1, n_comp)
        # Weighted sum of the squared kernel values -> shape (b,)
        out_w = tf.einsum('...j,...ij->...', comp_w, out_vw, optimize="optimal")
        return out_w

    def get_config(self):
        # The kernel object is not part of the returned config.
        config = {
            "dim_x": self.dim_x,
            "n_comp": self.n_comp,
            "x_train": self.x_train,
            "w_train": self.w_train,
        }
        base_config = super().get_config()
        return {**base_config, **config}

    def compute_output_shape(self, input_shape):
        # call() returns one overlap value per sample: shape (batch,).
        return (input_shape[0],)

class KDMClassModel(tf.keras.Model):
    """Classifier that maps a vector to class probabilities via an RBF-kernel KDMUnit.

    Arguments:
        dim_x: dimension of the input vectors
        dim_y: dimension of the output state (length of the probability vector)
        sigma: initial scale of the RBF kernel
        n_comp: number of components of the KDM unit
        x_train: whether the x components of the KDM unit are trainable
    """

    def __init__(self, dim_x, dim_y, sigma, n_comp, x_train=True):
        super().__init__()
        self.dim_x = dim_x
        self.dim_y = dim_y
        self.n_comp = n_comp
        self.kernel_x = RBFKernelLayer(sigma, dim=dim_x)
        self.kdmu = KDMUnit(
            self.kernel_x,
            dim_x=dim_x,
            dim_y=dim_y,
            n_comp=n_comp,
            x_train=x_train,
        )

    def call(self, inputs):
        """Return the discrete distribution predicted for each input vector."""
        input_dm = pure2dm(inputs)
        output_dm = self.kdmu(input_dm)
        return dm2discrete(output_dm)

class BagKDMClassModel(tf.keras.Model):
    """KDM model over bags of vectors, each vector weighted uniformly.

    Arguments:
        dim_x: dimension of the vectors in a bag
        dim_y: dimension of the output state
        sigma: initial scale of the RBF kernel
        n_comp: number of components of the KDM unit
        x_train: whether the x components of the KDM unit are trainable
        l1_y: l1 regularization coefficient for the y components
    """

    def __init__(self,
                 dim_x,
                 dim_y,
                 sigma,
                 n_comp,
                 x_train=True,
                 l1_y=0.):
        super().__init__()
        self.dim_x = dim_x
        self.dim_y = dim_y
        self.n_comp = n_comp
        # RBFKernelLayer requires `dim`; omitting it raises a TypeError.
        kernel_x = RBFKernelLayer(sigma, dim=dim_x)
        self.kdmu = KDMUnit(kernel_x,
                            dim_x=dim_x,
                            dim_y=dim_y,
                            n_comp=n_comp,
                            x_train=x_train,
                            l1_y=l1_y)

    def call(self, inputs):
        """Return the output density matrix, shape (bs, n_comp, dim_y + 1).

        `inputs` has shape (bs, n, dim_x).
        NOTE(review): the original code also computed dm2discrete(rho_y) but
        discarded it and returned rho_y. The return value is kept as is;
        confirm whether class probabilities were meant to be returned.
        """
        # Cast the bag size to the input dtype: dividing a float tensor by
        # an int32 tensor is a dtype error in TensorFlow.
        bag_size = tf.cast(tf.shape(inputs)[1], inputs.dtype)
        w = tf.ones_like(inputs[:, :, 0]) / bag_size
        rho_x = comp2dm(w, inputs)
        rho_y = self.kdmu(rho_x)
        return rho_y

class KDMDenEstModel(tf.keras.Model):
    """Density estimator built from an RBF kernel and a KDMOverlap unit.

    Calling the model returns log-probabilities and registers their negative
    mean as a model loss.
    """

    def __init__(self, dim_x, sigma, n_comp):
        super().__init__()
        self.dim_x = dim_x
        self.n_comp = n_comp
        self.kernel = RBFKernelLayer(sigma, dim=dim_x)
        self.kdmover = KDMOverlap(self.kernel, dim_x=dim_x, n_comp=n_comp)

    def call(self, inputs):
        overlap = self.kdmover(inputs)
        # The small constant guards the log against a zero overlap.
        log_overlap = tf.math.log(overlap + 1e-12)
        log_probs = log_overlap + self.kernel.log_weight()
        self.add_loss(-tf.reduce_mean(log_probs))
        return log_probs


class KDMDenEstModel2(tf.keras.Model):
    """Density estimator over concatenated (x, y) vectors.

    The first `dim_x` features are compared with an RBF kernel and the
    remaining `dim_y` features with a cosine kernel; the two are combined by
    a cross-product kernel that feeds a KDMOverlap unit. Calling the model
    returns log-probabilities and registers their negative mean as a loss.
    """

    def __init__(self,
                 dim_x,
                 dim_y,
                 sigma,
                 n_comp,
                 trainable_sigma=True,
                 min_sigma=1e-3):
        super().__init__()
        self.dim_x = dim_x
        self.dim_y = dim_y
        self.n_comp = n_comp
        self.kernel_x = RBFKernelLayer(
            sigma,
            dim=dim_x,
            trainable=trainable_sigma,
            min_sigma=min_sigma,
        )
        self.kernel_y = CosineKernelLayer()
        self.kernel = CrossProductKernelLayer(
            dim1=dim_x,
            kernel1=self.kernel_x,
            kernel2=self.kernel_y,
        )
        self.kdmover = KDMOverlap(
            self.kernel,
            dim_x=dim_x + dim_y,
            n_comp=n_comp,
        )

    def call(self, inputs):
        overlap = self.kdmover(inputs)
        # The small constant guards the log against a zero overlap.
        log_overlap = tf.math.log(overlap + 1e-12)
        log_probs = log_overlap + self.kernel.log_weight()
        self.add_loss(-tf.reduce_mean(log_probs))
        return log_probs

### KDM Attn

In [None]:
class ProbRegression(tf.keras.layers.Layer):
    """
    Calculates the expected value and variance of a measure on a
    density matrix. The measure associates evenly distributed values
    between 0 and 1 to the different n basis states.
    Input shape:
        A tensor with shape (batch_size, n)
    Output shape:
        (batch_size, 2), where [:, 0] is the mean and [:, 1] the variance
    Arguments:
    """

    def __init__(
            self,
            **kwargs
    ):
        super().__init__(**kwargs)


    def build(self, input_shape):
        if len(input_shape) != 2 :
            raise ValueError('A `ProbRegression` layer should be '
                             'called with a tensor of shape '
                             '(batch_size, n)')
        # Values attached to the n basis states, evenly spaced in [0, 1].
        self.vals = tf.constant(tf.linspace(0.0, 1.0, input_shape[1]), dtype=tf.float32)
        self.vals2 = self.vals ** 2
        self.built = True

    def call(self, inputs):
        if len(inputs.shape) != 2:
            raise ValueError('A `ProbRegression` layer should be '
                             'called with a tensor of shape '
                             '(batch_size, n)')
        # E[v] and E[v^2] under the input distribution; Var = E[v^2] - E[v]^2
        mean = tf.einsum('...i,i->...', inputs, self.vals, optimize='optimal')
        mean2 = tf.einsum('...i,i->...', inputs, self.vals2, optimize='optimal')
        var = mean2 - mean ** 2
        return tf.stack([mean, var], axis = -1)

    def compute_output_shape(self, input_shape):
        # call() stacks mean and variance on a new last axis: (batch, 2).
        return (input_shape[0], 2)


class KDMAttentionLayer(tf.keras.layers.Layer):
    """Scores each element of a sequence and returns softmax weights.

    Every element is embedded by a first MLP, concatenated with the mean
    embedding of its sequence, and scored by a second MLP. The scores are
    normalized with a softmax over the last axis.

    Arguments:
        dim_h: output size of the first MLP
        dense_units_1: hidden units of the first MLP
        dense_units_2: hidden units of the second MLP
    """

    def __init__(self,
                 dim_h,
                 dense_units_1,
                 dense_units_2):
        super().__init__()
        self.dim_h = dim_h
        self.dense_units_1 = dense_units_1
        self.dense_units_2 = dense_units_2
        self.mlp_1 = tf.keras.Sequential([
            Dense(dense_units_1, activation='relu'),
            Dense(dim_h, activation='linear'),
        ])
        self.mlp_2 = tf.keras.Sequential([
            Dense(dense_units_2, activation='relu'),
            Dense(1, activation='linear'),
        ])

    def call(self, input):
        local_feats = self.mlp_1(input)
        # Mean embedding over axis 1, repeated for every element.
        pooled = tf.expand_dims(tf.reduce_mean(local_feats, axis=1), axis=1)
        pooled = tf.broadcast_to(pooled, tf.shape(local_feats))
        scores = self.mlp_2(tf.concat([local_feats, pooled], axis=-1))
        # Drop the trailing singleton axis produced by the final Dense(1).
        scores = tf.squeeze(scores, axis=-1)
        return tf.nn.softmax(scores)

## Regression

In [None]:
class KDMPatchRegressionModel(tf.keras.Model):
    """End-to-end regression model over image patches.

    The input image is cut into patches, the patches are encoded, the encoded
    patches are weighted (by attention or uniformly) into a density matrix,
    and a KDMUnit followed by ProbRegression produces a (mean, variance) pair
    per image.

    Arguments:
        patch_size, image_size, strides: forwarded to the Patches layer
        encoder: callable applied to the flattened patches
        encoded_size: size of the encoder output; used as the kernel dim and
                      as dim_x of the KDM unit
        dim_y: dimension of the KDM unit output state
        n_comp: number of components of the KDM unit
        sigma: initial scale of the (trainable) RBF kernel
        attention: when True, patch weights come from a KDMAttentionLayer;
                   otherwise all patches are weighted equally
        attention_dim_h, attention_dense_units_1, attention_dense_units_2:
                   sizes forwarded to the KDMAttentionLayer
    """
    def __init__(self,
                 patch_size,
                 image_size,
                 strides,
                 encoder,
                 encoded_size,
                 dim_y,
                 n_comp,
                 sigma=0.1,
                 attention=False,
                 attention_dim_h=64,
                 attention_dense_units_1=64,
                 attention_dense_units_2=64):
        super().__init__()
        self.patch_size = patch_size
        self.image_size = image_size
        self.strides = strides
        self.patch_layer = Patches(patch_size, image_size, strides)
        self.dim_y = dim_y
        self.encoded_size = encoded_size
        self.encoder = encoder
        self.n_comp = n_comp
        self.attention = attention
        self.sigma = sigma
        self.kernel = RBFKernelLayer(sigma=sigma,
                                         dim=encoded_size,
                                         trainable=True)
        self.kdm_unit = KDMUnit(kernel=self.kernel,
                                       dim_x=encoded_size,
                                       dim_y=dim_y,
                                       n_comp=n_comp)
        # The attention layer only exists when attention is enabled.
        if attention:
            self.attention_layer = KDMAttentionLayer(dim_h=attention_dim_h,
                                              dense_units_1=attention_dense_units_1,
                                              dense_units_2=attention_dense_units_2)
        self.regression_layer = ProbRegression()


    def call(self, input): # (bs, 1152,1152,3)
        """Return a (bs, 2) tensor holding the predicted mean and variance."""
        patches = self.patch_layer(input) #(bs, n_patches, w*h*c)
        encoded = self.encoder(patches) # presumably (bs, n_patches, encoded_size) -- TODO confirm
        bs = tf.shape(encoded)[0]
        if self.attention:
            w = self.attention_layer(encoded)
        else:
            # Uniform weights: every patch contributes 1 / n_patches.
            w = tf.ones((bs, self.patch_layer.num_patches ** 2,)) / (self.patch_layer.num_patches ** 2)
        rho_x = comp2dm(w, encoded)
        rho_y = self.kdm_unit(rho_x)
        # distrib = dm2distrib(rho_y,self.sigma)
        # mean = distrib.mean()
        # variance = distrib.variance()
        probs = dm2discrete(rho_y)
        mean_var = self.regression_layer(probs)

        return mean_var

    def init_components(self, samples_x, samples_y, init_sigma=False, sigma_mult=1):
        """Initialize the KDM unit components from example images.

        One random patch is drawn from each image in `samples_x`, encoded, and
        assigned to the unit's c_x; `samples_y` is assigned to c_y and the
        component weights are reset to uniform. When `init_sigma` is True the
        kernel sigma is set to the mean pairwise distance between the encoded
        patches times `sigma_mult`.
        NOTE(review): the assignments presumably require samples_x / samples_y
        to hold exactly n_comp entries -- confirm against the caller.
        """
        patches = self.patch_layer(samples_x)
        idx = tf.random.uniform(shape=(patches.shape[0],), maxval=patches.shape[1], dtype=tf.int32) #select 1 random patch from each mosaic
        # Select the desired patches using tf.gather
        selected_patches = tf.gather(patches, idx, axis=1, batch_dims=1)
        # Encode the selected patches
        encoded_x = self.encoder(selected_patches[:, tf.newaxis, :])[:, 0, :]
        if init_sigma:
            distances = pairwise_distances(encoded_x)
            sigma = np.mean(distances) * sigma_mult
            self.kernel.sigma.assign(sigma)
        self.kdm_unit.c_x.assign(encoded_x)
        self.kdm_unit.c_y.assign(samples_y)
        self.kdm_unit.comp_w.assign(tf.ones((self.n_comp,)) / self.n_comp)

    def visualize_attention(self, input):
        """Spread the per-patch attention weights back over image space.

        The weights are reshaped to a (num_patches, num_patches) grid and
        passed through a fixed all-ones transposed convolution whose kernel
        size and strides match the patch layer. A new Conv2DTranspose layer is
        created on every call. Uses self.attention_layer, which is only
        created when the model was built with attention=True.
        """
        patches = self.patch_layer(input)
        encoded = self.encoder(patches)
        w = self.attention_layer(encoded)
        conv2dt = tf.keras.layers.Conv2DTranspose(filters=1,
            kernel_size=self.patch_layer.patch_size,
            strides=self.patch_layer.strides,
            kernel_initializer=tf.keras.initializers.Ones(),
            bias_initializer=tf.keras.initializers.Zeros(),
            trainable=False)
        w = tf.reshape(w, [-1,
            self.patch_layer.num_patches,
            self.patch_layer.num_patches, 1])
        out = conv2dt(w)
        return out


## Encoder

#### Creating encoder and encoder classifier for warmup

In [None]:
# Size of the feature vector produced by the encoder for each patch.
encoded_size = 128

### Encoder that lets weights be loaded into it

In [None]:
def create_convnext_encoder(encoded_size):
  """Build the patch encoder: ConvNeXt-Tiny backbone, dropout, Dense projection.

  The backbone is created without its classification head, with ImageNet
  weights, built-in preprocessing and global average pooling, for
  192x192x3 inputs. It is followed by Dropout(0.5) and a ReLU Dense layer
  with `encoded_size` units.

  Arguments:
    encoded_size: number of units of the final Dense layer.
  Returns:
    A keras.Sequential model.
  """
  patch_shape = (192, 192, 3)
  backbone = tf.keras.applications.convnext.ConvNeXtTiny(
    model_name='convnext_tiny',
    include_top=False,
    include_preprocessing=True,
    weights='imagenet',
    input_tensor=None,
    input_shape=patch_shape,
    pooling="avg",
    # NOTE(review): presumably unused because include_top=False; confirm against the Keras docs.
    classes=6,
    classifier_activation='softmax'
  )
  stack = [Input(shape=patch_shape), backbone]
  stack.append(keras.layers.Dropout(0.5))
  stack.append(keras.layers.Dense(encoded_size, activation="relu"))  # alternatives tried: linear
  return keras.Sequential(stack)

### Encoder with reshapes and loaded weights

In [None]:
class Encoder(keras.Model):
  """Apply a per-patch encoder to every patch of every mosaic in a batch.

  Arguments:
    encoder: model mapping (n, *patch_shape) images to (n, encoded_size)
      embeddings.
    encoded_size: width of the embeddings. When None it is read from
      `encoder.output_shape[-1]`, so the class no longer depends on a
      notebook-level global of the same name. Pass it explicitly if the
      wrapped model does not expose a static output shape.
    patch_shape: (height, width, channels) of a single patch.
  """

  def __init__(self, encoder, encoded_size=None, patch_shape=(192, 192, 3)):
    super().__init__()
    if encoded_size is None:
      encoded_size = int(encoder.output_shape[-1])
    self.encoded_size = encoded_size
    self.patch_shape = tuple(patch_shape)
    self.encoder = encoder

  def call(self, input): #(bs, n_patches, w*h*c)
    bs = tf.shape(input)[0] # bs
    # Fold the patch axis into the batch axis so the encoder sees plain images.
    flat_patches = tf.reshape(input, [-1, *self.patch_shape]) # (bs * n_patches, w,h,c)
    x = self.encoder(flat_patches) #(bs * num_patches, encoded_size)
    # Restore the (mosaic, patch) structure.
    out = tf.reshape(x, [bs, -1, self.encoded_size]) # (bs, num_patches, encoded_size)
    return out

### Instantiate the encoder that will go inside KDM and load pretrained weights to it

In [None]:
# Build the patch encoder and its Encoder wrapper on a hard-coded device ('/device:GPU:1').
with tf.device('/device:GPU:1'):
    convnext = create_convnext_encoder(encoded_size)
    encoder_kdm = Encoder(convnext)
    # encoder_kdm stores a reference to `convnext`, so the weights loaded here are the ones it uses.
    convnext.load_weights('/root/data/KDM/models/best_kdm_patch_convnext_extractor.h5')

## Train the end-to-end model

#### Callbacks

In [None]:
# Weight of the variance term in `loss`; also embedded in the checkpoint file name below.
alpha = 0.001
# Save weights only, once per epoch, and only when val_loss improves.
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    f"/root/data/KDM/models/regression_attn_KDM_alpha_{alpha}_0.5_dropout.h5",
    monitor = "val_loss",
    verbose = 1,
    save_best_only = True,
    save_weights_only = True,
    mode = "min",
    save_freq="epoch",
)

# Stop after 4 epochs without val_loss improvement and roll back to the best weights.
earlystop = tf.keras.callbacks.EarlyStopping(
    monitor = "val_loss",
    patience = 4,
    verbose = 1,
    restore_best_weights=True,
    mode = "min",
)
#### Compiling the model

def loss(y_true, y_pred):
  """Per-sample MSE on the predicted mean plus `alpha` times the predicted variance.

  Arguments:
    y_true: target values.
    y_pred: tensor of shape (bs, 2); column 0 is the predicted mean and
      column 1 the predicted variance.
  Returns:
    Tensor of shape (bs,) with one loss value per sample.

  `alpha` is read from the enclosing (notebook) scope when the loss is called.
  """
  mse = tf.keras.losses.mean_squared_error(y_true, y_pred[:, 0:1])  # (bs,)
  # Take the variance as a rank-1 slice. The previous (bs, 1) slice broadcast
  # against the (bs,) MSE and produced a (bs, bs) matrix of cross-sample sums.
  variance = y_pred[:, 1]  # (bs,)
  return mse + alpha * variance

In [None]:
# Number of KDM components (n_comp) handed to the models built in later cells.
n_comp = 216
# for layer in encoder_kdm.layers:
#   layer.trainable = True

In [None]:
# strategy = tf.distribute.MirroredStrategy()
# with strategy.scope():
# Build and compile the end-to-end model on a hard-coded device ('/gpu:1').
with tf.device('/gpu:1'):
    # patch_size == strides, so patches do not overlap (1152 / 192 = 6 per side).
    kdm_class = KDMPatchRegressionModel(
                            patch_size=192,
                            image_size=1152,
                            strides=192,
                            encoded_size=encoded_size,
                            dim_y=6,
                            encoder=encoder_kdm,
                            n_comp=n_comp,
                            sigma=1.0,
                            attention=True,
                            attention_dim_h=64,
                            attention_dense_units_1=128,
                            attention_dense_units_2=128)

    # `loss` is the custom MSE + alpha * variance function defined in this notebook.
    kdm_class.compile(optimizer=optimizers.Adam(learning_rate=0.0001),
                    loss=loss, metrics=['mean_absolute_error'])



#### Init components into the model

In [None]:
# Seed the KDM components from `prototypes` / `prototype_labels` (defined elsewhere in the notebook)
# and set sigma from the mean pairwise distance of the encoded prototypes.
kdm_class.init_components(prototypes, prototype_labels, init_sigma = True, sigma_mult = 1.)
# Show the resulting kernel bandwidth.
print(kdm_class.kernel.sigma.numpy())

In [None]:
# Run one training batch through the model before loading weights
# (presumably so that all variables exist; confirm).
kdm_class(next(iter(train_dataset))[0])
# NOTE(review): this path differs from the one written by checkpoint_callback; confirm it is the intended checkpoint.
kdm_class.load_weights("/root/data/KDM/checkpoints/regression_attn_KDM_alpha_0.001.h5")

In [None]:
# Print the layer/parameter summary of the end-to-end model.
kdm_class.summary()

### t-SNE visualization of encoder

In [None]:
# Fetch the sub-layer named 'encoder' from the end-to-end model; it is used below to extract patch embeddings.
encoder_layer = kdm_class.get_layer('encoder')

In [None]:
# Model output has two columns: column 0 is used as the prediction and
# the square root of column 1 as its standard deviation.
out = kdm_class.predict(test_dataset)
y_pred, std = out[:, 0], np.sqrt(out[:, 1])


In [None]:
# Collect encoder embeddings and labels for the whole test set, batch by batch.
test_embeddings_list = []
y_true_list = []

for x_batch, y_batch in test_dataset:
    embedding_batch = encoder_layer.predict(x_batch, verbose=0)
    test_embeddings_list.append(embedding_batch)
    y_true_list.append(y_batch)

test_embeddings = np.concatenate(test_embeddings_list, axis=0)
y_true = np.concatenate(y_true_list, axis=0)
# Scale by 5 and round; presumably the labels are ISUP grades divided by 5 (confirm against the dataset pipeline).
y_true_rounded = np.round(y_true * 5)



#### Average pooling embeddings

In [None]:
import seaborn as sns

In [None]:
# Display names for the six grade values 0-5.
label_mapping = {
    0: "ISUP 0",
    1: "ISUP 1",
    2: "ISUP 2",
    3: "ISUP 3",
    4: "ISUP 4",
    5: "ISUP 5"
}

# Convert numeric labels to string labels
# (the rounded float labels compare and hash equal to the integer keys of the mapping).
y_true_labels = np.vectorize(label_mapping.get)(y_true_rounded)

# Average Pooling: mean over axis 1 of the embeddings, giving one vector per test item.
test_embeddings_avg = np.mean(test_embeddings, axis=1)

# Perform t-SNE dimensionality reduction (fixed random_state for repeatability)
tsne_model = TSNE(n_components=2, random_state=0)
tsne_data = tsne_model.fit_transform(test_embeddings_avg)

# Define your custom color palette (one colour per grade)
palette = sns.color_palette("Pastel1", n_colors=6)

# Plot using Seaborn, coloured by the ground-truth grade label
plt.figure(figsize=(10, 10))
sns.scatterplot(x=tsne_data[:, 0], y=tsne_data[:, 1], hue=y_true_labels, palette=palette, alpha=1)
plt.title('t-SNE Visualization of Encoder Layer (Colored by Ground Truth)')
plt.legend(title='ISUP Grade')
plt.show()

### Continuous predictions

In [None]:
from matplotlib.colors import Normalize

# Re-collect embeddings and labels, this time together with the model's
# continuous prediction for every test item.
test_embeddings_list = []
y_true_list = []
y_pred_list = []

for x_batch, y_batch in test_dataset:
    embedding_batch = encoder_layer.predict(x_batch, verbose=0)
    test_embeddings_list.append(embedding_batch)
    y_true_list.append(y_batch.numpy())

    # Assuming kdm_class is your trained model for prediction
    out = kdm_class.predict(x_batch, verbose=0)
    # Only column 0 (the prediction) is kept; the square root of column 1 is discarded.
    y_pred, _ = out[:, 0], np.sqrt(out[:, 1])
    y_pred_list.append(y_pred)

# Concatenate all the batches
test_embeddings = np.concatenate(test_embeddings_list, axis=0)
y_true = np.concatenate(y_true_list, axis=0)
y_pred = np.concatenate(y_pred_list, axis=0)

# Average Pooling: mean over axis 1 of the embeddings, giving one vector per test item.
test_embeddings_avg = np.mean(test_embeddings, axis=1)

# Perform t-SNE dimensionality reduction
tsne_model = TSNE(n_components=2, random_state=0)
tsne_data = tsne_model.fit_transform(test_embeddings_avg)


### Avg pooling

In [None]:
sns.set(style="white")

fig, ax = plt.subplots(figsize=(12, 12), dpi=300)

# Use one fixed colour scale for both the points and the colorbar.
# Predictions are assumed to be on the normalized grade scale (ISUP grade / 5),
# as implied by the `* 5` rescaling applied to the labels elsewhere in this
# notebook. Previously the colorbar was scaled to the observed min/max of
# y_pred and those two values were labelled "ISUP 0" and "ISUP 5", which
# mislabels the scale and can produce non-monotonic ticks.
norm = Normalize(vmin=0.0, vmax=1.0)

# Create scatter plot using Seaborn
sns.scatterplot(x=tsne_data[:, 0], y=tsne_data[:, 1], hue=y_pred, hue_norm=norm,
                palette="plasma", ax=ax, s=60, edgecolor='w', legend=False)

# Add color bar manually using Matplotlib
sm = plt.cm.ScalarMappable(cmap="plasma", norm=norm)
sm.set_array([])

# Define colorbar position and dimensions [left, bottom, width, height]
cbar_ax = fig.add_axes([0.93, 0.2, 0.02, 0.6])
cbar = plt.colorbar(sm, cax=cbar_ax)
cbar.set_label('ISUP Grade', rotation=270, labelpad=20)

# One tick per grade, placed at grade / 5
cbar.set_ticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
cbar.set_ticklabels(['ISUP 0', 'ISUP 1', 'ISUP 2', 'ISUP 3', 'ISUP 4' ,'ISUP 5'])

# Show the plot
plt.show()

### Max pooling

In [None]:
# Max pooling: element-wise maximum over axis 1 of the embeddings.
test_embeddings_max = np.max(test_embeddings, axis=1)

# Perform t-SNE dimensionality reduction
tsne_model = TSNE(n_components=2, random_state=0)
tsne_data = tsne_model.fit_transform(test_embeddings_max)

sns.set(style="white")

fig, ax = plt.subplots(figsize=(12, 12), dpi=300)

# Use one fixed colour scale for both the points and the colorbar.
# Predictions are assumed to be on the normalized grade scale (ISUP grade / 5),
# as implied by the `* 5` rescaling applied to the labels elsewhere in this
# notebook. Previously the colorbar was scaled to the observed min/max of
# y_pred and those two values were labelled "ISUP 0" and "ISUP 5", which
# mislabels the scale and can produce non-monotonic ticks.
norm = Normalize(vmin=0.0, vmax=1.0)

# Create scatter plot using Seaborn
sns.scatterplot(x=tsne_data[:, 0], y=tsne_data[:, 1], hue=y_pred, hue_norm=norm,
                palette="plasma", ax=ax, s=60, edgecolor='w', legend=False)

# Add color bar manually using Matplotlib
sm = plt.cm.ScalarMappable(cmap="plasma", norm=norm)
sm.set_array([])

# Define colorbar position and dimensions [left, bottom, width, height]
cbar_ax = fig.add_axes([0.93, 0.2, 0.02, 0.6])
cbar = plt.colorbar(sm, cax=cbar_ax)
cbar.set_label('ISUP Grade', rotation=270, labelpad=20)

# One tick per grade, placed at grade / 5
cbar.set_ticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
cbar.set_ticklabels(['ISUP 0', 'ISUP 1', 'ISUP 2', 'ISUP 3', 'ISUP 4' ,'ISUP 5'])

# Show the plot
plt.show()

#### All patches assigned the same label as their WSI

In [None]:
# Collect one embedding per patch, each paired with the label of the WSI it belongs to.
test_embeddings_list = []
y_true_list = []

# Loop through test dataset to get embeddings and labels
for x_batch, y_batch in test_dataset:
    embedding_batch = encoder_layer.predict(x_batch, verbose=0)
    # Flatten the first two dimensions. The embedding width is read from the
    # array itself rather than hard-coded, so a different encoder size cannot
    # silently mis-shape the rows.
    reshaped_embedding = embedding_batch.reshape(-1, embedding_batch.shape[-1])
    test_embeddings_list.append(reshaped_embedding)

    # Repeat each label once per patch so labels line up with the flattened rows
    repeated_y = np.repeat(y_batch, embedding_batch.shape[1])
    y_true_list.append(repeated_y)

# Concatenate all the batches
test_embeddings = np.vstack(test_embeddings_list)
y_true = np.concatenate(y_true_list)
# Scale by 5 and round; presumably the labels are ISUP grades divided by 5 (confirm).
y_true_rounded = np.round(y_true * 5)


In [None]:
# t-SNE over every individual patch embedding
tsne_model = TSNE(n_components=2, random_state=0)
tsne_data = tsne_model.fit_transform(test_embeddings)

sns.set(style="white")

fig, ax = plt.subplots(figsize=(12, 12), dpi=300)

# The points are coloured by the rounded ground-truth grade (0..5), so the
# colour scale and the colorbar must cover 0..5 as well. Previously the
# colorbar was scaled from `y_pred`, a leftover from an earlier cell on a
# different scale, so it did not correspond to the plotted colours.
norm = Normalize(vmin=0, vmax=5)

# Create scatter plot using Seaborn
sns.scatterplot(x=tsne_data[:, 0], y=tsne_data[:, 1], hue=y_true_rounded, hue_norm=norm,
                palette="plasma", ax=ax, s=60, edgecolor='w', legend=False)

# Add color bar manually using Matplotlib
sm = plt.cm.ScalarMappable(cmap="plasma", norm=norm)
sm.set_array([])

# Define colorbar position and dimensions [left, bottom, width, height]
cbar_ax = fig.add_axes([0.93, 0.2, 0.02, 0.6])
cbar = plt.colorbar(sm, cax=cbar_ax)
cbar.set_label('ISUP Grade', rotation=270, labelpad=20)

# One tick per grade
cbar.set_ticks([0, 1, 2, 3, 4, 5])
cbar.set_ticklabels(['ISUP 0', 'ISUP 1', 'ISUP 2', 'ISUP 3', 'ISUP 4' ,'ISUP 5'])

# Show the plot
plt.show()

### Bags are classified by patch model

#### patch model

In [None]:
def dm2comp(dm):
    '''
    Split a factorized density matrix into its weights and its vectors.
    Column 0 of the last axis holds the weights; the remaining columns hold
    the component vectors.
    Arguments:
     dm: tensor of shape (bs, n, d + 1)
    Returns:
     w: tensor of shape (bs, n)
     v: tensor of shape (bs, n, d)
    '''
    weights = dm[:, :, 0]
    vectors = dm[:, :, 1:]
    return weights, vectors

def comp2dm(w, v):
    '''
    Pack weights and vectors into one factorized density matrix tensor.
    The weights become column 0 of the last axis, followed by the vectors.
    Arguments:
     w: tensor of shape (bs, n)
     v: tensor of shape (bs, n, d)
    Returns:
     dm: tensor of shape (bs, n, d + 1)
    '''
    weight_column = w[:, :, tf.newaxis]  # (bs, n, 1)
    return tf.concat([weight_column, v], axis=2)

def samples2dm(samples):
    '''
    Construct a factorized density matrix from a batch of samples.
    Every sample that is not all-zero gets the same weight; all-zero samples
    get weight 0. A bag made only of all-zero samples gets all-zero weights.
    Arguments:
        samples: tensor of shape (bs, n, d)
    Returns:
        dm: tensor of shape (bs, n, d + 1)
    '''
    samples = tf.convert_to_tensor(samples)
    if not samples.dtype.is_floating:
        # The weights are fractions and are concatenated with the samples,
        # so both need a floating dtype.
        samples = tf.cast(samples, tf.float32)
    # tf.reduce_any only accepts booleans, so test for non-zero entries first
    # and convert the mask to numbers before normalizing.
    is_present = tf.reduce_any(tf.not_equal(samples, 0), axis=-1)
    w = tf.cast(is_present, samples.dtype)
    # divide_no_nan avoids 0/0 when a bag has no non-zero sample.
    w = tf.math.divide_no_nan(w, tf.reduce_sum(w, axis=-1, keepdims=True))
    return comp2dm(w, samples)

def pure2dm(psi):
    '''
    Wrap each state vector as a one-component factorized density matrix
    whose single component has weight 1.
    Arguments:
     psi: tensor of shape (bs, d)
    Returns:
     dm: tensor of shape (bs, 1, d + 1)
    '''
    unit_weight = tf.ones_like(psi[:, 0:1])  # (bs, 1)
    # Prepend the weight to every vector, then insert the component axis.
    weighted = tf.concat((unit_weight, psi), axis=1)  # (bs, d + 1)
    return weighted[:, tf.newaxis, :]

def dm2discrete(dm):
    '''
    Turn the components of a density matrix into a discrete distribution.
    The weights are normalized to sum to 1 and every vector is scaled to unit
    norm; the result is the weighted sum of the squared vector entries.
    Arguments:
     dm: tensor of shape (bs, n, d + 1)
    Returns:
     prob: vector of probabilities (bs, d)
    '''
    weights, vectors = dm2comp(dm)
    weights = weights / tf.reduce_sum(weights, axis=-1, keepdims=True)
    unit_vectors = vectors / tf.linalg.norm(vectors, axis=-1, keepdims=True)
    return tf.einsum('...j,...ji->...i', weights, unit_vectors ** 2,
                     optimize="optimal")

def dm2distrib(dm, sigma):
    '''
    Build a Gaussian mixture from the components of a density matrix defined
    with an RBF kernel. The component weights are the mixture probabilities,
    the component vectors are the means, and every dimension has standard
    deviation sigma / sqrt(2).
    Arguments:
     dm: tensor of shape (bs, n, d + 1)
     sigma: sigma parameter of the RBF kernel
    Returns:
     gm: mixture of Gaussian distribution with shape (bs, )
    '''
    weights, means = dm2comp(dm)
    mixing = tfd.Categorical(probs=weights)
    gaussians = tfd.Independent(
        tfd.Normal(loc=means, scale=sigma / np.sqrt(2.)),
        reinterpreted_batch_ndims=1)
    return tfd.MixtureSameFamily(reparameterize=True,
                                 mixture_distribution=mixing,
                                 components_distribution=gaussians)

def pure_dm_overlap(x, dm, kernel):
    r'''
    Calculates the overlap of a state \phi(x) with a density
    matrix in a RKHS defined by a kernel: the sum over components of the
    component weight times the squared kernel value between x and that
    component.
    Arguments:
      x: tensor of shape (bs, d)
     dm: tensor of shape (bs, n, d + 1)
     kernel: kernel function
              k: (bs, d) x (bs, n, d) -> (bs, n)
    Returns:
     overlap: tensor with shape (bs, )
    '''
    # The docstring is a raw string because "\p" is not a valid escape sequence.
    w, v = dm2comp(dm)
    overlap = tf.einsum('...i,...i->...', w, kernel(x, v) ** 2)
    return overlap

class CompTransKernelLayer(tf.keras.layers.Layer):
    '''Kernel obtained by transforming both inputs before applying a base kernel.'''

    def __init__(self, transform, kernel):
        '''
        Arguments:
            transform: function f applied to the inputs before the kernel
                    f:(bs, d) -> (bs, D)
            kernel: a kernel function
                    k:(bs, n, D)x(m, D) -> (bs, n, m)
        '''
        super().__init__()
        self.transform = transform
        self.kernel = kernel

    def call(self, A, B):
        '''
        Input:
            A: tensor of shape (bs, n, d)
            B: tensor of shape (m, d)
        Result:
            K: tensor of shape (bs, n, m)
        '''
        dims = tf.shape(A)  # (bs, n, d)
        batch, n_vectors, dim_in = dims[0], dims[1], dims[2]
        # The transform works on rank-2 input: merge the first two axes,
        # transform, then split them again.
        flat_A = self.transform(tf.reshape(A, [batch * n_vectors, dim_in]))
        mapped_A = tf.reshape(flat_A, [batch, n_vectors, tf.shape(flat_A)[1]])
        mapped_B = self.transform(B)
        return self.kernel(mapped_A, mapped_B)

    def log_weight(self):
        # Delegated to the wrapped kernel.
        return self.kernel.log_weight()

class RBFKernelLayer(tf.keras.layers.Layer):
    def __init__(self, sigma, dim, trainable=True, min_sigma=1e-3):
        '''
        Builds a layer that calculates the rbf kernel between two set of vectors
        Arguments:
            sigma: RBF scale parameter. If it is a tf.Variable it will be used as is.
                     Otherwise it will create a variable with the given value.
            dim: dimension of the input vectors; only used by log_weight.
            trainable: whether the variable created for sigma is trainable
                     (ignored when sigma is already a tf.Variable).
            min_sigma: lower bound applied to sigma whenever it is used.
        '''
        super(RBFKernelLayer, self).__init__()
        # isinstance rather than an exact type comparison: variables returned
        # by tf.Variable(...) are typically instances of a subclass, so the
        # exact comparison did not match them and a shared sigma was copied
        # instead of reused.
        if isinstance(sigma, tf.Variable):
            self.sigma = sigma
        else:
            self.sigma = tf.Variable(sigma, dtype=tf.float32, trainable=trainable)
        self.dim = dim
        self.min_sigma = min_sigma

    def call(self, A, B):
        '''
        Input:
            A: tensor of shape (bs, n, d)
            B: tensor of shape (m, d)
        Result:
            K: tensor of shape (bs, n, m)
        '''
        shape_A = tf.shape(A)
        shape_B = tf.shape(B)
        # Squared distances via |a|^2 + |b|^2 - 2 a.b
        A_norm = tf.norm(A, axis=-1)[..., tf.newaxis] ** 2
        B_norm = tf.norm(B, axis=-1)[tf.newaxis, tf.newaxis, :] ** 2
        A_reshaped = tf.reshape(A, [-1, shape_A[2]])
        AB = tf.matmul(A_reshaped, B, transpose_b=True)
        AB = tf.reshape(AB, [shape_A[0], shape_A[1], shape_B[0]])
        dist2 = A_norm + B_norm - 2. * AB
        # Rounding can make the expansion slightly negative; clamp at 0.
        dist2 = tf.clip_by_value(dist2, 0., np.inf)
        sigma = tf.clip_by_value(self.sigma, self.min_sigma, np.inf)
        K = tf.exp(-dist2 / (2. * sigma ** 2.))
        return K

    def log_weight(self):
        '''Log-weight term used by the density-estimation models in this notebook.'''
        sigma = tf.clip_by_value(self.sigma, self.min_sigma, np.inf)
        return - self.dim * tf.math.log(sigma + 1e-12) - self.dim * np.log(4 * np.pi)

class CosineKernelLayer(tf.keras.layers.Layer):
    '''Cosine-similarity kernel between two sets of vectors.'''

    def __init__(self):
        super().__init__()
        # Stored but not read by call() or log_weight().
        self.eps = 1e-6

    def call(self, A, B):
        '''
        Input:
            A: tensor of shape (bs, n, d)
            B: tensor of shape (m, d)
        Result:
            K: tensor of shape (bs, n, m)
        '''
        # divide_no_nan maps zero-norm vectors to zero vectors instead of NaN.
        A_unit = tf.math.divide_no_nan(A, tf.norm(A, axis=-1, keepdims=True))
        B_unit = tf.math.divide_no_nan(B, tf.norm(B, axis=-1, keepdims=True))
        return tf.einsum("...nd,md->...nm", A_unit, B_unit)

    def log_weight(self):
        # This kernel contributes no log-weight term.
        return 0

class CrossProductKernelLayer(tf.keras.layers.Layer):
    '''Product of two kernels, each applied to one slice of the input vectors.'''

    def __init__(self, dim1, kernel1, kernel2):
        '''
        The input vectors are split in two parts: the first `dim1` entries go
        to `kernel1` and the remaining d - dim1 entries go to `kernel2`.
        Arguments:
            dim1: the dimension of the first part of the input vector
            kernel1: a kernel function
                    k1:(bs, n, dim1)x(m, dim1) -> (bs, n, m)
            kernel2: a kernel function
                    k2:(bs, n, d - dim1)x(m, d - dim1) -> (bs, n, m)
        '''
        super().__init__()
        self.dim1 = dim1
        self.kernel1 = kernel1
        self.kernel2 = kernel2

    def call(self, A, B):
        '''
        Input:
            A: tensor of shape (bs, n, d)
            B: tensor of shape (m, d)
        Result:
            K: tensor of shape (bs, n, m)
        '''
        split = self.dim1
        k_first = self.kernel1(A[:, :, :split], B[:, :split])
        k_second = self.kernel2(A[:, :, split:], B[:, split:])
        return k_first * k_second

    def log_weight(self):
        # Log-weights add because the kernels multiply.
        return self.kernel1.log_weight() + self.kernel2.log_weight()

def l1_loss(vals):
    '''
    Average l1 norm of the rows of `vals` after scaling each row to unit
    l2 norm.
    Arguments:
        vals: tensor with shape (b_size, n)
    '''
    n_rows = tf.cast(tf.shape(vals)[0], dtype=tf.float32)
    unit_rows = vals / tf.norm(vals, axis=1, keepdims=True)
    return tf.reduce_sum(tf.abs(unit_rows)) / n_rows

class KDMUnit(tf.keras.layers.Layer):
    """Kernel Density Matrix Unit
    Receives as input a factored density matrix represented by a set of vectors
    and weight values.
    Returns a resulting factored density matrix.
    Input shape:
        (batch_size, n_comp_in, dim_x + 1)
        where dim_x is the dimension of the input state
        and n_comp_in is the number of components of the input factorization.
        The weights of the input factorization of sample i are [i, :, 0],
        and the vectors are [i, :, 1:dim_x + 1].
    Output shape:
        (batch_size, n_comp, dim_y + 1)
        where dim_y is the dimension of the output state
        and n_comp is the number of components used to represent the train
        density matrix. The weights of the
        output factorization for sample i are [i, :, 0], and the vectors
        are [i, :, 1:dim_y + 1].
    Arguments:
        kernel: kernel layer called as kernel(in_v, c_x) with
                k:(bs, n_comp_in, dim_x)x(n_comp, dim_x) -> (bs, n_comp_in, n_comp)
        dim_x: int. the dimension of the input state
        dim_y: int. the dimension of the output state
        x_train: bool. Whether to train or not the x components of the train
                       density matrix.
        y_train: bool. Whether to train or not the y components of the train
                       density matrix.
        w_train: bool. Whether to train or not the weights of the components
                       of the train density matrix.
        n_comp: int. Number of components used to represent
                 the train density matrix
        l1_act: float. Coefficient of the regularization term penalizing the l1
                       norm of the activations.
        l1_x: float. Coefficient of the regularization term penalizing the l1
                       norm of the x components.
        l1_y: float. Coefficient of the regularization term penalizing the l1
                       norm of the y components.
    """
    def __init__(
            self,
            kernel,
            dim_x: int,
            dim_y: int,
            x_train: bool = True,
            y_train: bool = True,
            w_train: bool = True,
            n_comp: int = 0,
            l1_x: float = 0.,
            l1_y: float = 0.,
            l1_act: float = 0.,
            **kwargs
    ):
        super().__init__(**kwargs)
        self.kernel = kernel
        self.dim_x = dim_x
        self.dim_y = dim_y
        self.x_train = x_train
        self.y_train = y_train
        self.w_train = w_train
        self.n_comp = n_comp
        self.l1_x = l1_x
        self.l1_y = l1_y
        self.l1_act = l1_act
        # The weight name is passed by keyword: the position of `name` in
        # add_weight differs between Keras versions.
        self.c_x = self.add_weight(
            name="c_x",
            shape=(self.n_comp, self.dim_x),
            #initializer=tf.keras.initializers.orthogonal(),
            initializer=tf.keras.initializers.RandomNormal(),
            trainable=self.x_train)
        self.c_y = self.add_weight(
            name="c_y",
            shape=(self.n_comp, self.dim_y),
            initializer=tf.keras.initializers.Constant(np.sqrt(1./self.dim_y)),
            #initializer=tf.keras.initializers.random_normal(),
            trainable=self.y_train)
        self.comp_w = self.add_weight(
            name="comp_w",
            shape=(self.n_comp,),
            initializer=tf.keras.initializers.Constant(1./self.n_comp),
            trainable=self.w_train)
        # Floor applied to the activations in call() before normalizing them.
        self.eps = 1e-10

    def call(self, inputs):
        # Weight regularizers
        if self.l1_x != 0:
            self.add_loss(self.l1_x * l1_loss(self.c_x))
        if self.l1_y != 0:
            self.add_loss(self.l1_y * l1_loss(self.c_y))
        #comp_w = tf.clip_by_value(self.comp_w, 1e-10, 1)
        # Keep the component weights strictly positive.
        comp_w = tf.abs(self.comp_w) + 1e-6
        # normalize comp_w to sum to 1
        comp_w = comp_w / tf.reduce_sum(comp_w)
        in_w = inputs[:, :, 0]  # shape (b, n_comp_in)
        in_v = inputs[:, :, 1:] # shape (b, n_comp_in, dim_x)
        out_vw = self.kernel(in_v, self.c_x)  # shape (b, n_comp_in, n_comp)
        out_w = (tf.expand_dims(tf.expand_dims(comp_w, axis=0), axis=0) *
                 tf.square(out_vw)) # shape (b, n_comp_in, n_comp)
        # Floor the activations so the normalization below never divides by zero.
        out_w = tf.maximum(out_w, self.eps)
        # out_w_sum = tf.maximum(tf.reduce_sum(out_w, axis=2), self.eps)  # shape (b, n_comp_in)
        out_w_sum = tf.reduce_sum(out_w, axis=2) # shape (b, n_comp_in)
        out_w = out_w / tf.expand_dims(out_w_sum, axis=2)
        # Mix the per-input-component activations with the input weights.
        out_w = tf.einsum('...i,...ij->...j', in_w, out_w, optimize="optimal")
                # shape (b, n_comp)
        if self.l1_act != 0:
            self.add_loss(self.l1_act * l1_loss(out_w))
        out_w = tf.expand_dims(out_w, axis=-1) # shape (b, n_comp, 1)
        # Broadcast the y components over the batch: (b, n_comp, dim_y)
        out_y_shape = tf.shape(out_w) + tf.constant([0, 0, self.dim_y - 1])
        out_y = tf.broadcast_to(tf.expand_dims(self.c_y, axis=0), out_y_shape)
        out = tf.concat((out_w, out_y), 2)
        return out

    def get_config(self):
        # The kernel object is not part of the config.
        config = {
            "dim_x": self.dim_x,
            "dim_y": self.dim_y,
            "n_comp": self.n_comp,
            "x_train": self.x_train,
            "y_train": self.y_train,
            "w_train": self.w_train,
            "l1_x": self.l1_x,
            "l1_y": self.l1_y,
            "l1_act": self.l1_act,
        }
        base_config = super().get_config()
        return {**base_config, **config}

    def compute_output_shape(self, input_shape):
        # Matches the tensor built in call(): one weight column followed by
        # the dim_y columns of c_y, for each of the n_comp components.
        return (input_shape[0], self.n_comp, self.dim_y + 1)

class KDMOverlap(tf.keras.layers.Layer):
    """Kernel Density Matrix Overlap Unit
    Receives as input a vector and calculates its overlap with the unit density
    matrix.
    Input shape:
        (batch_size, dim_x)
        where dim_x is the dimension of the input state
    Output shape:
        (batch_size, )
    Arguments:
        kernel: a kernel function
        dim_x: int. the dimension of the input state
        x_train: bool. Whether to train or not the components of the train
                       density matrix.
        w_train: bool. Whether to train or not the weights of the components
                       of the train density matrix.
        n_comp: int. Number of components used to represent
                 the train density matrix
    """

    def __init__(
            self,
            kernel,
            dim_x: int,
            x_train: bool = True,
            w_train: bool = True,
            n_comp: int = 0,
            **kwargs
    ):
        super().__init__(**kwargs)
        self.kernel = kernel
        self.dim_x = dim_x
        self.x_train = x_train
        self.w_train = w_train
        self.n_comp = n_comp
        # The weight name is passed by keyword: the position of `name` in
        # add_weight differs between Keras versions.
        self.c_x = self.add_weight(
            name="c_x",
            shape=(self.n_comp, self.dim_x),
            #initializer=tf.keras.initializers.orthogonal(),
            initializer=tf.keras.initializers.RandomNormal(),
            trainable=self.x_train)
        self.comp_w = self.add_weight(
            name="comp_w",
            shape=(self.n_comp,),
            initializer=tf.keras.initializers.Constant(1./self.n_comp),
            trainable=self.w_train)

    def call(self, inputs):
        #comp_w = tf.clip_by_value(self.comp_w, 1e-10, 1)
        comp_w = tf.abs(self.comp_w)
        # normalize comp_w to sum to 1
        comp_w = comp_w / tf.reduce_sum(comp_w)
        # Add a singleton component axis so the kernel sees (b, 1, dim_x).
        in_v = inputs[:, tf.newaxis, :]
        out_vw = self.kernel(in_v, self.c_x) ** 2 # shape (b, 1, n_comp)
        # Weighted sum of squared kernel values over the components: shape (b,)
        out_w = tf.einsum('...j,...ij->...', comp_w, out_vw, optimize="optimal")
        return out_w

    def get_config(self):
        # The kernel object is not part of the config.
        config = {
            "dim_x": self.dim_x,
            "n_comp": self.n_comp,
            "x_train": self.x_train,
            "w_train": self.w_train,
        }
        base_config = super().get_config()
        return {**base_config, **config}

    def compute_output_shape(self, input_shape):
        # call() returns one scalar per input row.
        return (input_shape[0],)

class KDMClassModel(tf.keras.Model):
    """Classifier that maps a single vector to class probabilities with one KDM unit.

    Arguments:
        dim_x: dimension of the input vectors.
        dim_y: dimension of the output state (number of probabilities returned).
        sigma: initial value of the RBF kernel scale.
        n_comp: number of components of the KDM unit.
        x_train: whether the x components of the KDM unit are trainable.
    """

    def __init__(self, dim_x, dim_y, sigma, n_comp, x_train=True):
        super().__init__()
        self.dim_x = dim_x
        self.dim_y = dim_y
        self.n_comp = n_comp
        self.kernel_x = RBFKernelLayer(sigma, dim=dim_x)
        self.kdmu = KDMUnit(
            self.kernel_x,
            dim_x=dim_x,
            dim_y=dim_y,
            n_comp=n_comp,
            x_train=x_train,
        )

    def call(self, inputs):
        # input vector -> one-component density matrix -> KDM unit -> probabilities
        return dm2discrete(self.kdmu(pure2dm(inputs)))

class BagKDMClassModel(tf.keras.Model):
    """KDM model whose input is a bag of vectors, each given the same weight.

    Arguments:
        dim_x: dimension of the vectors in a bag.
        dim_y: dimension of the output state.
        sigma: initial value of the RBF kernel scale.
        n_comp: number of components of the KDM unit.
        x_train: whether the x components of the KDM unit are trainable.
        l1_y: l1 regularization coefficient for the y components.

    `call` takes a tensor of shape (bs, n, dim_x) and returns the factorized
    density matrix produced by the KDM unit, of shape (bs, n_comp, dim_y + 1).
    """

    def __init__(self,
                 dim_x,
                 dim_y,
                 sigma,
                 n_comp,
                 x_train=True,
                 l1_y=0.):
        super().__init__()
        self.dim_x = dim_x
        self.dim_y = dim_y
        self.n_comp = n_comp
        # RBFKernelLayer has no default for `dim`; omitting it raised a
        # TypeError when the model was constructed.
        kernel_x = RBFKernelLayer(sigma, dim=dim_x)
        self.kdmu = KDMUnit(kernel_x,
                            dim_x=dim_x,
                            dim_y=dim_y,
                            n_comp=n_comp,
                            x_train=x_train,
                            l1_y=l1_y)

    def call(self, inputs):
        ones = tf.ones_like(inputs[:, :, 0])
        # tf.shape yields an integer tensor; cast it so the division has
        # matching dtypes.
        bag_size = tf.cast(tf.shape(inputs)[1], ones.dtype)
        rho_x = comp2dm(ones / bag_size, inputs)
        # NOTE(review): the density matrix is returned, as before, not
        # dm2discrete(rho_y). The previously computed but unused `probs` was
        # dropped. Confirm whether callers expect probabilities instead.
        rho_y = self.kdmu(rho_x)
        return rho_y

class KDMDenEstModel(tf.keras.Model):
    """Density estimator: log of the KDM overlap plus the kernel's log-weight.

    Calling the model also registers the negative mean log-probability of the
    batch as a model loss.

    Arguments:
        dim_x: dimension of the input vectors.
        sigma: initial value of the RBF kernel scale.
        n_comp: number of components of the overlap unit.
    """

    def __init__(self, dim_x, sigma, n_comp):
        super().__init__()
        self.dim_x = dim_x
        self.n_comp = n_comp
        self.kernel = RBFKernelLayer(sigma, dim=dim_x)
        self.kdmover = KDMOverlap(self.kernel, dim_x=dim_x, n_comp=n_comp)

    def call(self, inputs):
        overlap = self.kdmover(inputs)
        # The small constant keeps the logarithm finite when the overlap is 0.
        log_probs = tf.math.log(overlap + 1e-12) + self.kernel.log_weight()
        negative_log_likelihood = -tf.reduce_mean(log_probs)
        self.add_loss(negative_log_likelihood)
        return log_probs

class KDMDenEstModel2(tf.keras.Model):
    """Density estimator over concatenated (x, y) vectors.

    The first dim_x entries of each input are compared with an RBF kernel and
    the remaining dim_y entries with a cosine kernel; the two kernels are
    multiplied. Calling the model also registers the negative mean
    log-probability of the batch as a model loss.

    Arguments:
        dim_x: dimension of the x part of the input.
        dim_y: dimension of the y part of the input.
        sigma: initial value of the RBF kernel scale.
        n_comp: number of components of the overlap unit.
        trainable_sigma: whether the RBF scale is trainable.
        min_sigma: lower bound applied to the RBF scale.
    """

    def __init__(self, dim_x, dim_y, sigma, n_comp,
                 trainable_sigma=True, min_sigma=1e-3):
        super().__init__()
        self.dim_x = dim_x
        self.dim_y = dim_y
        self.n_comp = n_comp
        self.kernel_x = RBFKernelLayer(
            sigma, dim=dim_x, trainable=trainable_sigma, min_sigma=min_sigma)
        self.kernel_y = CosineKernelLayer()
        self.kernel = CrossProductKernelLayer(
            dim1=dim_x, kernel1=self.kernel_x, kernel2=self.kernel_y)
        self.kdmover = KDMOverlap(
            self.kernel, dim_x=dim_x + dim_y, n_comp=n_comp)

    def call(self, inputs):
        overlap = self.kdmover(inputs)
        # The small constant keeps the logarithm finite when the overlap is 0.
        log_probs = tf.math.log(overlap + 1e-12) + self.kernel.log_weight()
        negative_log_likelihood = -tf.reduce_mean(log_probs)
        self.add_loss(negative_log_likelihood)
        return log_probs

class KQClassBagModel(tf.keras.Model):
    """Patch-level model: encoder, KDM unit, then a regression head.

    Arguments:
        encoded_size: dimension of the encoder output.
        dim_y: dimension of the KDM output state.
        encoder: model mapping the raw input to (bs, encoded_size) vectors.
        n_comp: number of components of the KDM unit.
        sigma: initial value of the (trainable) RBF kernel scale.
        mle_weight: stored on the instance; not read by the methods of this class.
    """

    def __init__(self, encoded_size, dim_y, encoder, n_comp,
                 sigma=0.1, mle_weight=0.):
        super().__init__()
        self.dim_y = dim_y
        self.encoded_size = encoded_size
        self.encoder = encoder
        self.n_comp = n_comp
        self.mle_weight = mle_weight
        self.kernel = RBFKernelLayer(
            sigma=sigma, dim=encoded_size, trainable=True)
        self.kdm_unit = KDMUnit(
            kernel=self.kernel, dim_x=encoded_size, dim_y=dim_y, n_comp=n_comp)
        self.regression_layer = ProbRegression()

    def call(self, input):
        # encode -> one-component density matrix -> KDM unit -> probabilities -> regression head
        rho_x = pure2dm(self.encoder(input))
        probs = dm2discrete(self.kdm_unit(rho_x))
        return self.regression_layer(probs)

    def init_components(self, samples_x, samples_y, init_sigma=False, sigma_mult=1):
        """Seed the KDM components with encoded samples and their targets.

        The encoded `samples_x` become the x components, `samples_y` the
        y components, and the component weights are reset to uniform. With
        `init_sigma` the kernel scale is set to the mean pairwise distance
        between the encoded samples times `sigma_mult`.
        """
        prototypes_x = self.encoder(samples_x)
        if init_sigma:
            mean_distance = np.mean(pairwise_distances(prototypes_x))
            self.kernel.sigma.assign(mean_distance * sigma_mult)
        self.kdm_unit.c_x.assign(prototypes_x)
        self.kdm_unit.c_y.assign(samples_y)
        uniform_weights = tf.ones((self.n_comp,)) / self.n_comp
        self.kdm_unit.comp_w.assign(uniform_weights)

# Embedding width used for the patch-level encoder and KDM model built below.
encoded_size = 128

def create_convnext_encoder(encoded_size):
  """Build the patch-model encoder and a classifier stacked on top of it.

  The encoder is a headless ConvNeXt-Tiny (ImageNet weights, built-in
  preprocessing, global average pooling, 192x192x3 inputs) followed by
  Dropout(0.2) and a tanh Dense layer with `encoded_size` units. The
  classifier adds a 5-unit softmax Dense layer on top of the same encoder.

  Arguments:
    encoded_size: number of units of the encoder's final Dense layer.
  Returns:
    (encoder, encoder_cls): two keras.Sequential models sharing the encoder.
  """
  patch_shape = (192, 192, 3)
  backbone = tf.keras.applications.convnext.ConvNeXtTiny(
    model_name='convnext_tiny',
    include_top=False,
    include_preprocessing=True,
    weights='imagenet',
    input_tensor=None,
    input_shape=patch_shape,
    pooling="avg",
    # NOTE(review): presumably unused because include_top=False; confirm against the Keras docs.
    classes=5,
    classifier_activation='softmax'
  )
  stack = [Input(shape=patch_shape), backbone]
  stack.append(keras.layers.Dropout(0.2))
  stack.append(keras.layers.Dense(encoded_size, activation="tanh"))  # alternatives tried: relu, linear
  encoder = keras.Sequential(stack)

  softmax_head = keras.layers.Dense(5, activation="softmax")
  encoder_cls = keras.Sequential([encoder, softmax_head])
  return encoder, encoder_cls


# Weight of the variance term in `loss`.
# NOTE(review): `loss` reads this global when it is called, so rebinding it here also changes the
# value seen by any `loss` function defined earlier in the notebook that reads the same global.
alpha = 0.1
def loss(y_true, y_pred):
  """Per-sample MSE on the predicted mean plus `alpha` times the predicted variance.

  Arguments:
    y_true: target values.
    y_pred: tensor of shape (bs, 2); column 0 is the predicted mean and
      column 1 the predicted variance.
  Returns:
    Tensor of shape (bs,) with one loss value per sample.

  `alpha` is read from the enclosing (notebook) scope when the loss is called.
  """
  mse = tf.keras.losses.mean_squared_error(y_true, y_pred[:, 0:1])  # (bs,)
  # Take the variance as a rank-1 slice. The previous (bs, 1) slice broadcast
  # against the (bs,) MSE and produced a (bs, bs) matrix of cross-sample sums.
  variance = y_pred[:, 1]  # (bs,)
  return mse + alpha * variance


# Build the patch-level model. Only the encoder is kept; the classifier returned alongside it is discarded.
# NOTE(review): this rebinds the global name `encoder_kdm` used earlier in the notebook.
encoder_kdm, _ = create_convnext_encoder(encoded_size)
kdm_cls_patch = KQClassBagModel(
                encoded_size=encoded_size,
                dim_y=5,
                encoder=encoder_kdm,
                n_comp=n_comp,
                sigma=0.1)
# Run a dummy patch through the model before loading weights (presumably so that all variables exist; confirm).
kdm_cls_patch(np.zeros((1,192,192,3)))

kdm_cls_patch.load_weights("/root/data/KDM/models/patch_regression_weights.h5")
# Compiled without optimizer or loss; the model is only used for predict() below.
kdm_cls_patch.compile()

#### TSNE

In [None]:
# Accumulators filled by the loop below: per-patch embeddings and per-patch predictions of the patch model.
test_embeddings_list = []
y_pred_patch_list = []

def break_into_patches(image, patch_size=192):
    """Cut an image into square tiles, scanning rows first, then columns.

    Args:
        image: array-like indexed as (height, width, channels); the notebook
            assumes a (1152, 1152, 3) mosaic.
        patch_size: side length of a tile and the step between tiles.

    Returns:
        ``np.ndarray`` stacking the tiles in row-major order (top-left first).
    """
    height, width = image.shape[0], image.shape[1]
    tiles = [
        image[top:top + patch_size, left:left + patch_size, :]
        for top in range(0, height, patch_size)
        for left in range(0, width, patch_size)
    ]
    return np.array(tiles)

# Loop through the test dataset and collect, per batch, the encoder embeddings
# and the patch-model outputs for the tiles of the same mosaics.
# NOTE(review): later cells pair row i of `test_embeddings` with row i of
# `y_pred_patches`; that assumes `encoder_layer` orders patches the same way
# as `break_into_patches` -- confirm.
for x_batch, y_batch in test_dataset:

    # Get embeddings from the encoder layer
    embedding_batch = encoder_layer.predict(x_batch, verbose=0)

    # Flatten all leading dimensions to one row per embedding. The width is
    # read from the array instead of being hard-coded to 128, so the cell
    # keeps working when the encoder size changes.
    reshaped_embedding = embedding_batch.reshape(-1, embedding_batch.shape[-1])
    test_embeddings_list.append(reshaped_embedding)

    # Break every image of the batch into patches, keeping image order
    all_patches = np.array([
        patch
        for image in x_batch
        for patch in break_into_patches(image)
    ])

    # Patch-model outputs for all tiles of this batch
    y_pred_from_patches = kdm_cls_patch.predict(all_patches, verbose=0)
    y_pred_patch_list.append(y_pred_from_patches)

# Concatenate all the batches
test_embeddings = np.vstack(test_embeddings_list)
y_pred_patches = np.concatenate(y_pred_patch_list)



In [None]:
# Distinct values of the patch-model outputs after scaling by 4 and rounding.
# Note that this covers every column of y_pred_patches, not only column 0.
np.unique(np.round(y_pred_patches * 4))

In [None]:
# Split the patch-model output: column 0 is used as the prediction and
# column 1 as its variance.
pred_patches, var_patches = y_pred_patches[:, 0], y_pred_patches[:, 1]

In [None]:
# Show the raw patch predictions.
pred_patches

In [None]:
# Scale the patch predictions by 4 and round to the nearest integer; the
# plots below label these values as the 5 patch classes (0-4).
rounded_pred_patches = np.round(pred_patches * 4)

In [None]:
# Min-max scale the patch variances to [0, 1].
# NOTE(review): if every variance is identical the denominator is zero.
min_variance, max_variance = var_patches.min(), var_patches.max()
normalized_variance = (var_patches - min_variance) / (max_variance - min_variance)

In [None]:
# Show the min-max scaled variances.
normalized_variance

In [None]:
# Show the rounded patch predictions.
rounded_pred_patches

In [None]:
# Which rounded patch classes actually occur.
np.unique(rounded_pred_patches)

In [None]:
# 2-D t-SNE projection of the patch embeddings, coloured by predicted class.
tsne_model = TSNE(n_components=2, random_state=0)
tsne_data = tsne_model.fit_transform(test_embeddings)

sns.set(style="white")

fig, ax = plt.subplots(figsize=(12, 12), dpi=300)

# One normalisation over the class indices 0-4, shared by the points and the
# colour bar. The colour bar used to be scaled with the slide-level `y_pred`,
# which is unrelated to the plotted patch classes, while the point colours
# were scaled independently to whatever classes happened to be present.
class_names = ['Stroma', 'Benign', 'Gleason 3', 'Gleason 4', 'Gleason 5']
norm = Normalize(vmin=0, vmax=len(class_names) - 1)

# Create scatter plot using Seaborn
sns.scatterplot(x=tsne_data[:, 0], y=tsne_data[:, 1], hue=rounded_pred_patches,
                hue_norm=norm, palette="plasma", ax=ax, s=60, edgecolor='w',
                legend=False)

# Add color bar manually using Matplotlib
sm = plt.cm.ScalarMappable(cmap="plasma", norm=norm)
sm.set_array([])

# Define colorbar position and dimensions [left, bottom, width, height]
cbar_ax = fig.add_axes([0.93, 0.2, 0.02, 0.6])
cbar = plt.colorbar(sm, cax=cbar_ax)
# The ticks name tissue / Gleason patterns, not ISUP grades.
cbar.set_label('Gleason pattern', rotation=270, labelpad=20)

# One tick per class index, labelled with the class name
cbar.set_ticks(list(range(len(class_names))))
cbar.set_ticklabels(class_names)

# Show the plot
plt.show()

In [None]:
# Same t-SNE projection, restricted to patches whose variance is below 0.05.
low_variance_indices = np.where(var_patches < 0.05)[0]
filtered_tsne_data = tsne_data[low_variance_indices]
filtered_rounded_pred_patches = np.array(rounded_pred_patches)[low_variance_indices]

sns.set(style="white")
fig, ax = plt.subplots(figsize=(12, 12), dpi=300)

# One normalisation over the class indices 0-4, shared by the points and the
# colour bar. The colour bar used to be scaled with the slide-level `y_pred`,
# and the point colours were scaled to the classes left after filtering, so
# the same class could get a different colour than in the unfiltered plot.
class_names = ['Stroma', 'Benign', 'Gleason 3', 'Gleason 4', 'Gleason 5']
norm = Normalize(vmin=0, vmax=len(class_names) - 1)

# Create scatter plot using Seaborn for the filtered points
sns.scatterplot(x=filtered_tsne_data[:, 0], y=filtered_tsne_data[:, 1],
                hue=filtered_rounded_pred_patches, hue_norm=norm,
                palette="plasma", ax=ax, s=60, edgecolor='w', legend=False)

# Add color bar manually using Matplotlib
sm = plt.cm.ScalarMappable(cmap="plasma", norm=norm)
sm.set_array([])

# Define colorbar position and dimensions [left, bottom, width, height]
cbar_ax = fig.add_axes([0.93, 0.2, 0.02, 0.6])
cbar = plt.colorbar(sm, cax=cbar_ax)
# The ticks name tissue / Gleason patterns, not ISUP grades.
cbar.set_label('Gleason pattern', rotation=270, labelpad=20)

# One tick per class index, labelled with the class name
cbar.set_ticks(list(range(len(class_names))))
cbar.set_ticklabels(class_names)

# Show the plot
plt.show()

In [None]:
# Show the rounded patch predictions again.
rounded_pred_patches

In [None]:
# Which rounded patch classes actually occur.
np.unique(rounded_pred_patches)

In [None]:

# t-SNE projection of the patch embeddings with per-point transparency.
tsne_model = TSNE(n_components=2, random_state=0)
tsne_data = tsne_model.fit_transform(test_embeddings)

sns.set(style="white")
fig, ax = plt.subplots(figsize=(12, 12), dpi=300)

# One normalisation over the class indices 0-4, shared by the points and the
# colour bar. The colour bar used to be scaled with the slide-level `y_pred`,
# which is unrelated to the plotted patch classes.
class_names = ['Stroma', 'Benign', 'Gleason 3', 'Gleason 4', 'Gleason 5']
norm = Normalize(vmin=0, vmax=len(class_names) - 1)

# Create scatter plot using Matplotlib for variable alpha values
# NOTE(review): alpha grows with the normalised variance, so the most
# uncertain patches are the most opaque -- confirm this is intended.
scatter = ax.scatter(tsne_data[:, 0], tsne_data[:, 1], c=rounded_pred_patches,
                     cmap="plasma", norm=norm, s=60, edgecolor='w',
                     alpha=normalized_variance)

# Add color bar manually using Matplotlib
sm = plt.cm.ScalarMappable(cmap="plasma", norm=norm)
sm.set_array([])

# Define colorbar position and dimensions [left, bottom, width, height]
cbar_ax = fig.add_axes([0.93, 0.2, 0.02, 0.6])
cbar = plt.colorbar(sm, cax=cbar_ax)
cbar.set_label('Gleason pattern', rotation=270, labelpad=20)

# One tick per class index, labelled with the class name
cbar.set_ticks(list(range(len(class_names))))
cbar.set_ticklabels(class_names)
# Show the plot
plt.show()

### Training

In [None]:
#history = kdm_class.fit(train_dataset, validation_data=val_dataset, epochs=20, verbose=1, callbacks=[checkpoint_callback, earlystop])

### Test

In [None]:
# kdm_class(next(iter(train_dataset))[0])

# kdm_class.load_weights("/content/drive/MyDrive/data/kdm_data/regression_attn_KDM_patch_weights_no_alpha.h5")

In [None]:
# Slide-level inference: column 0 is used as the prediction, and the square
# root of column 1 as its standard deviation.
out = kdm_class.predict(test_dataset)
y_pred = out[:, 0]
std = np.sqrt(out[:, 1])

# Scale predictions and labels by 5 and round, giving grades 0-5.
y_true = np.round(np.concatenate([labels for _, labels in test_dataset], axis=0) * 5)
predictions = np.round(y_pred * 5)

In [None]:
# Row-normalised confusion matrix of the rounded grades.
ConfusionMatrixDisplay.from_predictions(
    y_true,
    predictions,
    normalize='true',
    display_labels=[f'ISUP {grade}' for grade in range(6)],
)

In [None]:
# Unscaled regression targets, concatenated over all test batches.
y_true_reg = np.concatenate([labels for _, labels in test_dataset], axis=0)

In [None]:
# Distinct target values present in the test set.
np.unique(y_true_reg)

In [None]:
# Show the raw slide-level predictions.
y_pred

In [None]:
# Mean absolute error on the unscaled targets (not on the rounded grades).
mean_absolute_error(y_true_reg, y_pred)

In [None]:
# Quadratic-weighted Cohen's kappa on the rounded grades.
cohen_kappa_score(y_true, predictions, weights='quadratic')

In [None]:
# Per-class precision / recall / F1 on the rounded grades.
print(classification_report(y_true, predictions))

In [None]:
# Re-run slide-level inference, this time keeping column 1 as the variance
# (no square root) for the error-group plots.
out = kdm_class.predict(test_dataset)
y_pred = out[:, 0]
var = out[:, 1]

# Scale predictions and labels by 5 and round, giving grades 0-5.
y_true = np.round(np.concatenate([labels for _, labels in test_dataset], axis=0) * 5)
predictions = np.round(y_pred * 5)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd


abs_errors = np.abs(predictions - y_true)

# Bucket every sample by its absolute grade error: exact ('0'), off by one
# ('1') or anything else ('2+').
group_order = ['0', '1', '2+']
error_groups = np.where(abs_errors == 0, '0',
                        np.where(abs_errors == 1, '1', '2+'))

# Group predictions and variance by absolute error
data_to_plot = {
    'Error Group': error_groups,
    'Variance': var,
    'Predictions': predictions,
}

# Create a DataFrame
df = pd.DataFrame(data_to_plot)

# Get the counts for each error group. reindex (instead of .loc) reports an
# empty group as 0 rather than raising KeyError.
counts = df['Error Group'].value_counts().reindex(group_order, fill_value=0)
# Define the color palette
palette = sns.color_palette("Pastel1", n_colors=3)

# Plot the variance using seaborn
plt.figure(figsize=(10, 6))
sns.violinplot(x='Error Group', y='Variance', data=df, order=group_order, palette=palette)
plt.title('Variance by Absolute Error Group', fontsize=15)
plt.xlabel('Error Group', fontsize=13)
plt.ylabel('Variance', fontsize=13)

# Create custom legend: one round marker per group, in the group's colour
legend_labels = [f'{key}: {value} samples' for key, value in counts.items()]
legend_handles = [
    plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=colour, markersize=10)
    for colour in palette
]
plt.legend(legend_handles, legend_labels, loc='center left', bbox_to_anchor=(1, 0.5))

plt.savefig('variance_plot.png', bbox_inches='tight', dpi=600)
plt.show()

# Plot the predictions using seaborn
plt.figure(figsize=(10, 6))
sns.violinplot(x='Error Group', y='Predictions', data=df, order=group_order, palette=palette)
plt.title('Predictions by Absolute Error Group', fontsize=15)
plt.xlabel('Error Group', fontsize=13)
plt.ylabel('Predictions', fontsize=13)

# Create custom legend
plt.legend(legend_handles, legend_labels, loc='center left', bbox_to_anchor=(1, 0.5))

plt.savefig('predictions_plot.png', bbox_inches='tight', dpi=600)
plt.show()




In [None]:
# Close the current Weights & Biases run.
wandb.finish()

### Inference of classification models to extract metrics

In [None]:
class KDMPatchClassModel(tf.keras.Model):
    """Image-level classifier built from patch encodings and a KDM unit.

    The forward pass cuts the input into patches (``Patches``), encodes them
    with ``encoder``, weights the encodings either uniformly or with a
    ``KDMAttentionLayer``, builds a density matrix with ``comp2dm``, runs it
    through a ``KDMUnit`` backed by a trainable ``RBFKernelLayer`` and returns
    the result of ``dm2discrete``.
    """
    def __init__(self,
                 patch_size,
                 image_size,
                 strides,
                 encoder,
                 encoded_size,
                 dim_y,
                 n_comp,
                 sigma=0.1,
                 attention=False,
                 attention_dim_h=64,
                 attention_dense_units_1=64,
                 attention_dense_units_2=64):
        """Store the configuration and create the sub-layers.

        Args:
            patch_size: forwarded to ``Patches``.
            image_size: forwarded to ``Patches``.
            strides: forwarded to ``Patches``.
            encoder: callable applied to the extracted patches.
            encoded_size: ``dim`` of the RBF kernel and ``dim_x`` of the KDM
                unit.
            dim_y: ``dim_y`` of the KDM unit.
            n_comp: ``n_comp`` of the KDM unit; also the length of the uniform
                weight vector written by ``init_components``.
            sigma: initial ``sigma`` of the RBF kernel.
            attention: when true, a ``KDMAttentionLayer`` is created and used
                in ``call`` to weight the patches.
            attention_dim_h: ``dim_h`` of the attention layer.
            attention_dense_units_1: ``dense_units_1`` of the attention layer.
            attention_dense_units_2: ``dense_units_2`` of the attention layer.
        """
        super().__init__()
        self.patch_size = patch_size
        self.image_size = image_size
        self.strides = strides
        self.patch_layer = Patches(patch_size, image_size, strides)
        self.dim_y = dim_y
        self.encoded_size = encoded_size
        self.encoder = encoder
        self.n_comp = n_comp
        self.attention = attention
        self.kernel = RBFKernelLayer(sigma=sigma,
                                         dim=encoded_size,
                                         trainable=True)
        self.kdm_unit = KDMUnit(kernel=self.kernel,
                                       dim_x=encoded_size,
                                       dim_y=dim_y,
                                       n_comp=n_comp)
        # The attention layer only exists when attention=True, so
        # visualize_attention is unavailable otherwise.
        if attention:
            self.attention_layer = KDMAttentionLayer(dim_h=attention_dim_h,
                                              dense_units_1=attention_dense_units_1,
                                              dense_units_2=attention_dense_units_2)
    def call(self, input): # (bs, 1152,1152,3)
        """Return ``dm2discrete`` of the KDM unit's output for a batch of images."""
        patches = self.patch_layer(input) #(bs, n_patches, w*h*c)
        encoded = self.encoder(patches) # presumably (bs, n_patches, encoded_size) -- TODO confirm
        bs = tf.shape(encoded)[0]
        if self.attention:
            w = self.attention_layer(encoded)
        else:
            # Uniform weights over num_patches ** 2 entries, summing to 1 per image.
            w = tf.ones((bs, self.patch_layer.num_patches ** 2,)) / (self.patch_layer.num_patches ** 2)
        rho_x = comp2dm(w, encoded)
        rho_y = self.kdm_unit(rho_x)
        probs = dm2discrete(rho_y)
        return probs

    def init_components(self, samples_x, samples_y, init_sigma=False, sigma_mult=1):
        """Initialise the KDM unit's components from one random patch per sample.

        Args:
            samples_x: images passed through the patch layer; one patch per
                image is drawn at random and encoded into ``kdm_unit.c_x``.
            samples_y: values written unchanged to ``kdm_unit.c_y``.
            init_sigma: when true, the kernel's ``sigma`` is reset to the mean
                pairwise distance between the encoded patches times
                ``sigma_mult``.
            sigma_mult: multiplier applied to that mean distance.
        """
        patches = self.patch_layer(samples_x)
        idx = tf.random.uniform(shape=(patches.shape[0],), maxval=patches.shape[1], dtype=tf.int32) #select 1 random patch from each mosaic
        # Select the desired patches using tf.gather
        selected_patches = tf.gather(patches, idx, axis=1, batch_dims=1)
        # Encode the selected patches: a length-1 patch axis is added for the
        # encoder call and dropped again afterwards.
        encoded_x = self.encoder(selected_patches[:, tf.newaxis, :])[:, 0, :]
        if init_sigma:
            distances = pairwise_distances(encoded_x)
            sigma = np.mean(distances) * sigma_mult
            self.kernel.sigma.assign(sigma)
        self.kdm_unit.c_x.assign(encoded_x)
        self.kdm_unit.c_y.assign(samples_y)
        # Every component starts with the same weight, 1 / n_comp.
        self.kdm_unit.comp_w.assign(tf.ones((self.n_comp,)) / self.n_comp)

    def visualize_attention(self, input):
        """Return the attention weights spread over a patch-grid-shaped map.

        The weights are reshaped to ``(-1, num_patches, num_patches, 1)`` and
        passed through a freshly created, non-trainable ``Conv2DTranspose``
        with an all-ones kernel and zero bias, using the patch layer's
        ``patch_size`` as kernel size and its ``strides`` as stride.

        Requires the model to have been built with ``attention=True``.
        """
        patches = self.patch_layer(input)
        encoded = self.encoder(patches)
        w = self.attention_layer(encoded)
        conv2dt = tf.keras.layers.Conv2DTranspose(filters=1,
            kernel_size=self.patch_layer.patch_size,
            strides=self.patch_layer.strides,
            kernel_initializer=tf.keras.initializers.Ones(),
            bias_initializer=tf.keras.initializers.Zeros(),
            trainable=False)
        w = tf.reshape(w, [-1,
            self.patch_layer.num_patches,
            self.patch_layer.num_patches, 1])
        out = conv2dt(w)
        return out

In [None]:
# Slide-level attention KDM classifier with 6 output classes, reusing the
# patch encoder; 1152-pixel mosaics are cut into non-overlapping 192-pixel
# patches (stride equals patch size).
kdm_classifier = KDMPatchClassModel(
    patch_size=192,
    image_size=1152,
    strides=192,
    encoder=encoder_kdm,
    encoded_size=encoded_size,
    dim_y=6,
    n_comp=n_comp,
    sigma=1.0,
    attention=True,
    attention_dim_h=64,
    attention_dense_units_1=128,
    attention_dense_units_2=128,
)

kdm_classifier.compile()

In [None]:
# Call the model once on a real training batch, presumably so its variables
# exist before the weights are loaded.
kdm_classifier(next(iter(train_dataset))[0])

# Restore the trained classifier weights.
kdm_classifier.load_weights("/data/KDM/models/attn_kdm_cls_from_kdm_patch_cls_unfrozen.h5")

In [None]:
# Class probabilities -> most likely class per slide.
preds = kdm_classifier.predict(test_dataset)
predictions = preds.argmax(axis=1)

# Labels scaled by 5 and rounded, giving grades 0-5.
y_true = np.round(np.concatenate([labels for _, labels in test_dataset], axis=0) * 5)

# Row-normalised confusion matrix.
ConfusionMatrixDisplay.from_predictions(
    y_true,
    predictions,
    normalize='true',
    display_labels=[f'ISUP {grade}' for grade in range(6)],
)

In [None]:
# Put the class predictions on the same scale as the unscaled targets by
# dividing by 5, so both can be compared with a regression metric.
y_true_reg = np.concatenate([labels for _, labels in test_dataset], axis=0)
preds_regression = predictions / 5

In [None]:
# Show the class predictions after division by 5.
preds_regression

In [None]:
# Mean absolute error of the classifier's predictions divided by 5.
mean_absolute_error(y_true_reg, preds_regression)