# End to end patch classifier with KDM


## Data

In [None]:
# IPython shell escape: print the GPU status reported by the NVIDIA driver.
!nvidia-smi

## Imports

In [None]:
import numpy as np
import tensorflow as tf
import wandb
import sys
from wandb.keras import WandbMetricsLogger, WandbModelCheckpoint
from tensorflow.keras.models import Model
import numpy as np
from keras import optimizers
from keras import losses
from keras import metrics
from tensorflow.keras.layers import Conv2D, Resizing, InputLayer, Flatten, Dense
from tensorflow.keras.models import Sequential
from keras.layers import Input, Dense
from keras.models import Model
import keras
import tensorflow_probability as tfp
import tensorflow as tf
import numpy as np
import numpy as np
import tensorflow_addons as tfa
import os
from PIL import Image
import pandas as pd
from collections import OrderedDict
import pathlib
from matplotlib import pyplot as plt
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel
from sklearn.preprocessing import MinMaxScaler, normalize
from matplotlib import pyplot as pl
from tensorflow.keras.callbacks import EarlyStopping
from tqdm.notebook import tqdm
from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay, mean_absolute_error, mean_squared_error, r2_score, mean_squared_log_error, recall_score, f1_score, accuracy_score, cohen_kappa_score, precision_score
import glob
from sklearn.metrics import pairwise_distances
# Short alias for the TensorFlow Probability distributions module.
tfd = tfp.distributions
from sklearn.utils.class_weight import compute_class_weight
# Restrict CUDA to GPUs 0 and 1. A bare Python variable is never seen by the
# CUDA runtime, so the value is also exported through the process environment.
# NOTE(review): this only takes effect if it runs before TensorFlow first
# initializes the GPUs -- confirm no earlier cell touches a device.
CUDA_VISIBLE_DEVICES = "0,1"
os.environ["CUDA_VISIBLE_DEVICES"] = CUDA_VISIBLE_DEVICES

## Data from WSI

In [None]:
# Weights & Biases setup (disabled). The API key that used to be hard-coded
# here was removed: credentials must not live in source. Provide it through
# the WANDB_API_KEY environment variable or an interactive `wandb login`.
# !wandb login
# wandb.init(project="KDM Patch classifier end-to-end")
data_dir = '/data/KDM/data'
image_folder = f'{data_dir}/train_images/'
mask_folder = f'{data_dir}/train_label_masks/'
patches_folder = f'{data_dir}/patches'
train_csv = f'{data_dir}/df_train.csv'
val_csv = f'{data_dir}/df_val.csv'
test_csv = f'{data_dir}/df_test.csv'

# Patch-level tables. The scores are shifted down by one so that they can be
# used directly as zero-based class indices.
df_train = pd.read_csv(train_csv)
df_val = pd.read_csv(val_csv)
df_test = pd.read_csv(test_csv)
df_train['gleason_score'] -= 1
df_val['gleason_score'] -= 1
df_test['gleason_score'] -= 1


def _mosaic_path(image_id):
  """Return the path of the tile-mosaic JPEG that belongs to a slide id."""
  return os.path.join(f'{data_dir}/tile_mosaics', image_id + '.jpeg')


def _existing_labels(df_wsi):
  """Map image_id -> isup_grade, keeping only slides whose mosaic file exists."""
  labels = df_wsi.set_index('image_id')['isup_grade'].to_dict(into=OrderedDict)
  return {k: v for k, v in labels.items() if os.path.isfile(_mosaic_path(k))}


#WSI
df_train_wsi = pd.read_csv(f"{data_dir}/wsi_train.csv")
df_val_wsi = pd.read_csv(f"{data_dir}/wsi_val.csv")
df_test_wsi = pd.read_csv(f"{data_dir}/wsi_test.csv")
train_dict = _existing_labels(df_train_wsi)
val_dict = _existing_labels(df_val_wsi)
test_dict = _existing_labels(df_test_wsi)

# Slides explicitly excluded by the original author (reason not recorded in
# this file). `pop` with a default is used so that a slide already dropped by
# the file-existence filter above does not raise KeyError.
train_dict.pop('1c36b3db47d83f1436bd260288c5723f', None)
train_dict.pop('50203fbd5de280144cbb16749814a3fe', None)
test_dict.pop('ecae863e7c478594aa4c84ce132b3825', None)

train_features = list(train_dict.keys())
train_labels = np.array(list(train_dict.values()))
val_features = list(val_dict.keys())
val_labels = np.array(list(val_dict.values()))
test_features = list(test_dict.keys())
test_labels = np.array(list(test_dict.values()))
train_paths = [_mosaic_path(train_path) for train_path in train_features]
val_paths = [_mosaic_path(val_path) for val_path in val_features]
test_paths = [_mosaic_path(test_path) for test_path in test_features]

# Fit the encoder on the training labels only and reuse it for the other
# splits. Re-fitting per split would give each split its own column layout
# whenever a grade is missing from it.
encoder = OneHotEncoder()
y_train_onehot = encoder.fit_transform(train_labels.reshape(-1,1))
y_train_one_hot = y_train_onehot.toarray()
y_val_onehot = encoder.transform(val_labels.reshape(-1,1))
y_val_one_hot = y_val_onehot.toarray()
y_test_onehot = encoder.transform(test_labels.reshape(-1,1))
y_test_one_hot = y_test_onehot.toarray()



## Create Datasets

In [None]:
# Root directory of the per-slide tiles; WsiDataset reads
# '<tiles_path>/<stage>_tiles/<wsi>/<tile>' below it.
tiles_path = f'{data_dir}/tiles'

def decode_img(img):
  """Decode JPEG-encoded bytes into a 3-channel image tensor.

  Args:
    img: the JPEG-compressed contents of an image file.

  Returns:
    The result of `tf.io.decode_jpeg(img, channels=3)`. No resizing or
    rescaling is applied.
  """
  return tf.io.decode_jpeg(img, channels=3)

def process_path(file_path):
  """Read an image file from disk and decode it with `decode_img`.

  Args:
    file_path: path of the JPEG file to load.

  Returns:
    The decoded image tensor.
  """
  raw_bytes = tf.io.read_file(file_path)
  return decode_img(raw_bytes)

def create_dataset(df, data_dir, shuffle=False, batch_size=32, num_classes=5):
  """Build a batched tf.data pipeline of (patch image, one-hot label) pairs.

  Args:
    df: DataFrame with an 'image_id' column and a 'gleason_score' column of
      zero-based class indices.
    data_dir: root folder of the patches. Each patch is read from
      '<data_dir>/<image_id up to the first "_">/<image_id>.jpeg'.
    shuffle: if True, shuffle the (path, label) pairs with a buffer covering
      the whole table.
    batch_size: number of patches per batch.
    num_classes: width of the one-hot label vectors (default 5, the value
      that was previously hard-coded).

  Returns:
    A `tf.data.Dataset` yielding (images, one_hot_labels) batches.
  """
  patches_path = df['image_id'].apply(lambda x: os.path.join(data_dir, x.split('_')[0], x + '.jpeg'))
  patches_labels = keras.utils.to_categorical(df['gleason_score'], num_classes=num_classes)
  dataset = tf.data.Dataset.from_tensor_slices((patches_path, patches_labels))
  if shuffle:
    # Shuffling happens on the cheap (path, label) pairs, before decoding.
    dataset = dataset.shuffle(buffer_size=len(df))
  dataset = dataset.map(lambda x, y: (process_path(x), y), num_parallel_calls=tf.data.AUTOTUNE)
  dataset = dataset.batch(batch_size)
  # Prefetch last so that complete batches are prepared ahead of the consumer.
  dataset = dataset.prefetch(tf.data.AUTOTUNE)
  return dataset

def create_regression_dataset(df, data_dir, shuffle=False, batch_size=32, max_score=4):
  """Build a batched tf.data pipeline of (patch image, scaled score) pairs.

  Args:
    df: DataFrame with an 'image_id' column and a 'gleason_score' column.
    data_dir: root folder of the patches. Each patch is read from
      '<data_dir>/<image_id up to the first "_">/<image_id>.jpeg'.
    shuffle: if True, shuffle the (path, label) pairs with a buffer covering
      the whole table.
    batch_size: number of patches per batch.
    max_score: divisor applied to 'gleason_score' to obtain the regression
      target (default 4, the value that was previously hard-coded).

  Returns:
    A `tf.data.Dataset` yielding (images, score / max_score) batches.
  """
  patches_path = df['image_id'].apply(lambda x: os.path.join(data_dir, x.split('_')[0], x + '.jpeg'))
  regression_labels = df['gleason_score'] / max_score
  dataset = tf.data.Dataset.from_tensor_slices((patches_path, regression_labels))
  if shuffle:
    # Shuffling happens on the cheap (path, label) pairs, before decoding.
    dataset = dataset.shuffle(buffer_size=len(df))
  dataset = dataset.map(lambda x, y: (process_path(x), y), num_parallel_calls=tf.data.AUTOTUNE)
  dataset = dataset.batch(batch_size)
  # Prefetch last so that complete batches are prepared ahead of the consumer.
  dataset = dataset.prefetch(tf.data.AUTOTUNE)
  return dataset

class WsiDataset(tf.keras.utils.Sequence):
    """Keras Sequence that yields all tiles of one whole-slide image per item.

    Tiles of slide `wsi` are read from '<tiles_path>/<stage>_tiles/<wsi>/'.

    Arguments:
        wsi_names: list of slide identifiers (folder names).
        stage: split name used to build the '<stage>_tiles' folder.
        tiles_path: root directory containing the '<stage>_tiles' folders.
        labels: per-slide labels, indexed like `wsi_names`.
        shuffle: if True, the slide order is reshuffled by `on_epoch_end`.
    """

    def __init__(self, wsi_names, stage, tiles_path, labels, shuffle=True):
        super().__init__()
        self.wsi_names = wsi_names
        self.stage = stage
        self.tiles_path = tiles_path
        self.labels = labels
        self.shuffle = shuffle
        # Position -> slide index. Starts as the identity so the first pass
        # visits the slides in their given order.
        self.indexes = np.arange(len(self.wsi_names))

    def __len__(self):
        """Number of slides (one item per slide)."""
        return len(self.wsi_names)

    def __getitem__(self, idx):
        """Return (tiles, label) for the slide at position `idx`.

        Returns:
            tile_matrix: array stacking every decoded tile of the slide.
            y: the label stored for that slide.
        """
        slide = self.indexes[idx]
        wsi = self.wsi_names[slide]
        y = self.labels[slide]
        # Use the directory this instance was configured with, and sort the
        # listing because os.listdir returns entries in arbitrary order.
        wsi_dir = os.path.join(self.tiles_path, f'{self.stage}_tiles', wsi)
        tile_paths = [os.path.join(wsi_dir, tile) for tile in sorted(os.listdir(wsi_dir))]
        tile_matrix = np.array([process_path(tile) for tile in tile_paths])
        return tile_matrix, y

    def on_epoch_end(self):
        """Reshuffle the slide order between epochs when `shuffle` is set."""
        if self.shuffle:
            np.random.shuffle(self.indexes)

class Prototypes(tf.keras.utils.Sequence):
    """Wraps a slide dataset so that each item carries one label row per tile.

    Arguments:
        wsi_dataset: indexable dataset returning (tiles, label) per slide.
        num_classes: number of slide-level classes (stored; default 6).
        num_samples_per_class: how many times the slide label is repeated,
            i.e. the number of label rows returned per item (default 36).
            NOTE(review): this presumably equals the number of tiles per
            slide -- confirm, otherwise samples and labels differ in length.
    """

    def __init__(self, wsi_dataset, num_classes=6, num_samples_per_class=36):
        super().__init__()
        self.wsi_dataset = wsi_dataset
        self.num_classes = num_classes
        self.num_samples_per_class = num_samples_per_class

    def __len__(self):
        """Number of items, identical to the wrapped dataset's length."""
        return len(self.wsi_dataset)

    def __getitem__(self, idx):
        """Return (tiles, labels) for item `idx`, with the label repeated
        `num_samples_per_class` times along a new leading axis."""
        sample, label = self.wsi_dataset[idx]
        return sample, np.tile(label, (self.num_samples_per_class, 1))

def get_prototypes():

  """
    Retrieves slide-level prototypes: for each of the 6 classes, the tiles of
    the first training slide carrying that label.

    The class of every slide is read from the one-hot labels first, so only
    the selected slides are loaded from disk instead of decoding every slide
    visited during the search.

    Returns:
        protos (ndarray): Stacked array of prototypes with shape (N, H, W, C),
            where N is the total number of prototypes, H is the height of each prototype,
            W is the width of each prototype, and C is the number of channels.
        label_protos (ndarray): Stacked array of labels corresponding to the prototypes
            with shape (N, L), where N is the total number of prototypes and L is the number
            of label dimensions.

    Raises:
        ValueError: if no class has a matching slide (nothing to concatenate).

    Usage:
        protos, label_protos = get_prototypes()
    """
  prototypes = Prototypes(WsiDataset(train_features, "train", tiles_path, y_train_one_hot, shuffle=True))
  # Class index of every training slide, in dataset order.
  slide_classes = [np.argmax(row) for row in y_train_one_hot]
  protos = []
  label_protos = []
  for item in tqdm(range(6)):
    first_match = next((i for i, cls in enumerate(slide_classes) if cls == item), None)
    if first_match is None:
      # Class absent from the training slides: contributes no prototypes.
      continue
    sample, label = prototypes[first_match]
    protos.append(sample)
    label_protos.append(label)
  protos = np.concatenate(protos, axis=0)
  label_protos = np.concatenate(label_protos, axis=0)
  return protos, label_protos

def get_patch_prototypes(n_protos=216):
  """
    Samples patch-level prototypes from the shuffled training patch dataset.

    Patches are drawn in rounds: each round takes, for every class, the first
    patch of that class found in a freshly shuffled pass over the dataset.
    When `n_protos` is not a multiple of the number of classes, the remainder
    is filled with the first patches of one more shuffled pass, whatever
    their class. Every pass is shuffled independently, so the same patch may
    be selected more than once.

    Arguments:
        n_protos (int): total number of prototypes to collect (default 216).

    Returns:
        protos (ndarray): Stacked array of prototypes with shape (N, H, W, C),
            where N is the total number of prototypes, H is the height of each prototype,
            W is the width of each prototype, and C is the number of channels.
        label_protos (ndarray): Stacked array of one-hot labels corresponding to the
            prototypes with shape (N, L), where L is the number of label dimensions.

    Raises:
        ValueError: if no prototype could be collected (nothing to concatenate).

    Usage:
        protos, label_protos = get_patch_prototypes()
    """
  # Must match the one-hot width produced by create_dataset.
  num_classes = 5
  prototypes = create_dataset(df_train, patches_folder, shuffle=True, batch_size=1)
  protos = []
  label_protos = []
  for _ in tqdm(range(n_protos // num_classes)):
    for item in range(num_classes):
      # Each `for` over the dataset starts a new, reshuffled pass.
      for sample, label in prototypes:
        if np.argmax(label[0]) == item:
          protos.append(sample)
          label_protos.append(label)
          break
  remainder = n_protos % num_classes
  if remainder:
    for sample, label in prototypes.take(remainder):
      protos.append(sample)
      label_protos.append(label)
  protos = np.concatenate(protos, axis=0)
  label_protos = np.concatenate(label_protos, axis=0)
  return protos, label_protos

### Train, Val, Test patch datasets

In [None]:
# Patch-level regression datasets for the three splits, batched by 64.
# Only the training split is shuffled.
train_dataset = create_regression_dataset(df_train, patches_folder, shuffle=True, batch_size=64)
val_dataset = create_regression_dataset(df_val, patches_folder, shuffle=False, batch_size=64)
test_dataset = create_regression_dataset(df_test, patches_folder, shuffle=False, batch_size=64)

### Compute class weights

In [None]:
# Balanced per-class weights for the training patch labels.
# NOTE: despite its name, `class_sample_count` holds the unique class ids
# found in the training labels (the result of np.unique), not their counts.
class_sample_count = np.unique(df_train['gleason_score'].to_numpy())
class_weights = compute_class_weight(class_weight='balanced',classes=class_sample_count,y=df_train['gleason_score'].to_numpy())
print(class_weights)

## KDM

## Regression layer

In [None]:
class ProbRegression(tf.keras.layers.Layer):
    """
    Calculates the expected value and variance of a measure on a
    discrete distribution. The measure associates evenly distributed values
    between 0 and 1 to the n different states.
    Input shape:
        A tensor with shape (batch_size, n)
    Output shape:
        (batch_size, 2), where [:, 0] is the mean and [:, 1] the variance
    Arguments:
        **kwargs: forwarded to `tf.keras.layers.Layer`.
    """

    def __init__(
            self,
            **kwargs
    ):
        super().__init__(**kwargs)


    def build(self, input_shape):
        if len(input_shape) != 2 :
            raise ValueError('A `ProbRegression` layer should be '
                             'called with a tensor of shape '
                             '(batch_size, n)')
        # Values attached to the n states: n evenly spaced points in [0, 1].
        self.vals = tf.constant(tf.linspace(0.0, 1.0, input_shape[1]), dtype=tf.float32)
        # Squared values, used for the second moment.
        self.vals2 = self.vals ** 2
        self.built = True

    def call(self, inputs):
        if len(inputs.shape) != 2:
            raise ValueError('A `ProbRegression` layer should be '
                             'called with a tensor of shape '
                             '(batch_size, n)')
        # First and second moments of the values under the input weights.
        mean = tf.einsum('...i,i->...', inputs, self.vals, optimize='optimal')
        mean2 = tf.einsum('...i,i->...', inputs, self.vals2, optimize='optimal')
        # Var[X] = E[X^2] - E[X]^2
        var = mean2 - mean ** 2
        return tf.stack([mean, var], axis = -1)
        #return mean

    def compute_output_shape(self, input_shape):
        # One (mean, variance) pair per sample.
        return (input_shape[0], 2)

In [None]:
def dm2comp(dm):
    '''
    Split a factorized density matrix into its weights and its vectors.
    Arguments:
     dm: tensor of shape (bs, n, d + 1); entry 0 of the last axis is the
         weight, the remaining d entries are the vector
    Returns:
     w: tensor of shape (bs, n)
     v: tensor of shape (bs, n, d)
    '''
    weights = dm[:, :, 0]
    vectors = dm[:, :, 1:]
    return weights, vectors

def comp2dm(w, v):
    '''
    Assemble a factorized density matrix by placing each weight in front of
    its vector along the last axis.
    Arguments:
     w: tensor of shape (bs, n)
     v: tensor of shape (bs, n, d)
    Returns:
     dm: tensor of shape (bs, n, d + 1)
    '''
    weight_column = w[:, :, tf.newaxis]
    return tf.concat((weight_column, v), axis=2)

def samples2dm(samples):
    '''
    Construct a factorized density matrix from a batch of samples
    each sample will have the same weight. Samples that are all
    zero will be ignored (they receive weight 0).
    Arguments:
        samples: floating point tensor of shape (bs, n, d)
    Returns:
        dm: tensor of shape (bs, n, d + 1)
    '''
    samples = tf.convert_to_tensor(samples)
    # tf.reduce_any only accepts boolean input, so test for non-zero entries
    # explicitly, then turn the mask back into numeric weights.
    is_nonzero = tf.reduce_any(tf.not_equal(samples, 0), axis=-1)
    w = tf.cast(is_nonzero, samples.dtype)
    # divide_no_nan keeps a bag made only of zero samples at weight 0
    # instead of producing NaN from 0 / 0.
    w = tf.math.divide_no_nan(w, tf.reduce_sum(w, axis=-1, keepdims=True))
    return comp2dm(w, samples)

def pure2dm(psi):
    '''
    Wrap each state vector as a single-component factorized density matrix
    whose only component has weight 1.
    Arguments:
     psi: tensor of shape (bs, d)
    Returns:
     dm: tensor of shape (bs, 1, d + 1)
    '''
    unit_weight = tf.ones_like(psi[:, 0:1])[:, tf.newaxis, :]
    vectors = psi[:, tf.newaxis, :]
    return tf.concat((unit_weight, vectors), axis=2)

def dm2discrete(dm):
    '''
    Turn the components of a density matrix into a discrete distribution:
    the weights are normalized to sum to one, every vector is scaled to unit
    norm, and the squared vector entries are averaged with those weights.
    Arguments:
     dm: tensor of shape (bs, n, d + 1)
    Returns:
     prob: vector of probabilities (bs, d)
    '''
    w, v = dm2comp(dm)
    total_w = tf.reduce_sum(w, axis=-1, keepdims=True)
    w = w / total_w
    v_norm = tf.linalg.norm(v, axis=-1, keepdims=True)
    unit_v = v / v_norm
    return tf.einsum('...j,...ji->...i', w, unit_v ** 2, optimize="optimal")

def dm2distrib(dm, sigma):
    '''
    Build a Gaussian mixture from the components of a density matrix: the
    weights become the mixture probabilities and the vectors the component
    means, each component having scale sigma / sqrt(2) in every dimension.
    Arguments:
     dm: tensor of shape (bs, n, d + 1)
     sigma: sigma parameter of the RBF kernel
    Returns:
     gm: mixture of Gaussian distribution with shape (bs, )
    '''
    w, v = dm2comp(dm)
    mixture_weights = tfd.Categorical(probs=w)
    gaussians = tfd.Independent(
        tfd.Normal(loc=v, scale=sigma / np.sqrt(2.)),
        reinterpreted_batch_ndims=1)
    return tfd.MixtureSameFamily(
        reparameterize=True,
        mixture_distribution=mixture_weights,
        components_distribution=gaussians)

def pure_dm_overlap(x, dm, kernel):
    r'''
    Calculates the overlap of a state \phi(x) with a density
    matrix in a RKHS defined by a kernel, i.e. the weighted sum of the
    squared kernel values between x and each component of dm.
    Arguments:
      x: tensor of shape (bs, d)
     dm: tensor of shape (bs, n, d + 1)
     kernel: kernel function
              k: (bs, d) x (bs, n, d) -> (bs, n)
    Returns:
     overlap: tensor with shape (bs, )
    '''
    # The docstring is a raw string because it contains a backslash;
    # '\p' is not a valid escape sequence in a normal string literal.
    w, v = dm2comp(dm)
    overlap = tf.einsum('...i,...i->...', w, kernel(x, v) ** 2)
    return overlap

class CompTransKernelLayer(tf.keras.layers.Layer):
    '''
    Kernel obtained by applying a transformation to both inputs and then
    evaluating another kernel on the transformed vectors.
    '''

    def __init__(self, transform, kernel):
        '''
        Arguments:
            transform: a function f applied to the inputs before they reach
                    the kernel
                    f:(bs, d) -> (bs, D)
            kernel: a kernel function
                    k:(bs, n, D)x(m, D) -> (bs, n, m)
        '''
        super().__init__()
        self.transform = transform
        self.kernel = kernel

    def call(self, A, B):
        '''
        Input:
            A: tensor of shape (bs, n, d)
            B: tensor of shape (m, d)
        Result:
            K: tensor of shape (bs, n, m)
        '''
        dims = tf.shape(A)  # (bs, n, d)
        # Merge the first two axes so that the transform sees a 2-D batch.
        flat_a = self.transform(tf.reshape(A, [dims[0] * dims[1], dims[2]]))
        # Restore the (bs, n, D) layout using the transformed width.
        a_transformed = tf.reshape(flat_a, [dims[0], dims[1], tf.shape(flat_a)[1]])
        b_transformed = self.transform(B)
        return self.kernel(a_transformed, b_transformed)

    def log_weight(self):
        '''Return the log weight of the wrapped kernel.'''
        return self.kernel.log_weight()

class RBFKernelLayer(tf.keras.layers.Layer):
    '''
    Layer that calculates the RBF kernel exp(-||a - b||^2 / (2 sigma^2))
    between two sets of vectors.
    '''

    def __init__(self, sigma, dim, trainable=True, min_sigma=1e-3):
        '''
        Builds a layer that calculates the rbf kernel between two set of vectors
        Arguments:
            sigma: RBF scale parameter. If it is a tf.Variable it will be used as is.
                     Otherwise it will create a variable with the given value.
            dim: dimension of the input vectors; only used by `log_weight`.
            trainable: whether a newly created sigma variable is trainable.
                     Ignored when `sigma` is already a tf.Variable.
            min_sigma: lower bound applied to sigma whenever it is used.
        '''
        super(RBFKernelLayer, self).__init__()
        # isinstance, not `type(...) is`: concrete variables are instances of
        # tf.Variable subclasses, so an exact type comparison never matches
        # and a variable meant to be shared would be copied instead.
        if isinstance(sigma, tf.Variable):
            self.sigma = sigma
        else:
            self.sigma = tf.Variable(sigma, dtype=tf.float32, trainable=trainable)
        self.dim = dim
        self.min_sigma = min_sigma

    def call(self, A, B):
        '''
        Input:
            A: tensor of shape (bs, n, d)
            B: tensor of shape (m, d)
        Result:
            K: tensor of shape (bs, n, m)
        '''
        shape_A = tf.shape(A)
        shape_B = tf.shape(B)
        # Squared distances via ||a||^2 + ||b||^2 - 2 a.b
        A_norm = tf.norm(A, axis=-1)[..., tf.newaxis] ** 2
        B_norm = tf.norm(B, axis=-1)[tf.newaxis, tf.newaxis, :] ** 2
        A_reshaped = tf.reshape(A, [-1, shape_A[2]])
        AB = tf.matmul(A_reshaped, B, transpose_b=True)
        AB = tf.reshape(AB, [shape_A[0], shape_A[1], shape_B[0]])
        dist2 = A_norm + B_norm - 2. * AB
        # Rounding can make the expansion slightly negative; clamp at zero.
        dist2 = tf.clip_by_value(dist2, 0., np.inf)
        sigma = tf.clip_by_value(self.sigma, self.min_sigma, np.inf)
        K = tf.exp(-dist2 / (2. * sigma ** 2.))
        return K

    def log_weight(self):
        '''
        Returns -dim * log(sigma + 1e-12) - dim * log(4 * pi), with sigma
        clipped from below at min_sigma.
        '''
        sigma = tf.clip_by_value(self.sigma, self.min_sigma, np.inf)
        return - self.dim * tf.math.log(sigma + 1e-12) - self.dim * np.log(4 * np.pi)

class CosineKernelLayer(tf.keras.layers.Layer):
    '''
    Layer that calculates the cosine kernel (dot product of the vectors
    scaled to unit norm) between two sets of vectors.
    '''

    def __init__(self):
        '''
        Builds the layer; it has no parameters.
        '''
        super().__init__()
        self.eps = 1e-6

    def call(self, A, B):
        '''
        Input:
            A: tensor of shape (bs, n, d)
            B: tensor of shape (m, d)
        Result:
            K: tensor of shape (bs, n, m)
        '''
        # divide_no_nan maps zero vectors to zero instead of NaN.
        a_unit = tf.math.divide_no_nan(A, tf.norm(A, axis=-1, keepdims=True))
        b_unit = tf.math.divide_no_nan(B, tf.norm(B, axis=-1, keepdims=True))
        return tf.einsum("...nd,md->...nm", a_unit, b_unit)

    def log_weight(self):
        '''The cosine kernel contributes no log weight.'''
        return 0

class CrossProductKernelLayer(tf.keras.layers.Layer):
    '''
    Product of two kernels, each evaluated on its own slice of the input
    vectors.
    '''

    def __init__(self, dim1, kernel1, kernel2):
        '''
        The input vectors are split in two parts: the first dim1 entries go
        to kernel1 and the remaining d - dim1 entries go to kernel2.
        Arguments:
            dim1: the dimension of the first part of the input vector
            kernel1: a kernel function
                    k1:(bs, n, dim1)x(m, dim1) -> (bs, n, m)
            kernel2: a kernel function
                    k2:(bs, n, d - dim1)x(m, d - dim1) -> (bs, n, m)
        '''
        super().__init__()
        self.dim1 = dim1
        self.kernel1 = kernel1
        self.kernel2 = kernel2

    def call(self, A, B):
        '''
        Input:
            A: tensor of shape (bs, n, d)
            B: tensor of shape (m, d)
        Result:
            K: tensor of shape (bs, n, m)
        '''
        cut = self.dim1
        a_head, a_tail = A[:, :, :cut], A[:, :, cut:]
        b_head, b_tail = B[:, :cut], B[:, cut:]
        k_head = self.kernel1(a_head, b_head)
        k_tail = self.kernel2(a_tail, b_tail)
        return k_head * k_tail

    def log_weight(self):
        '''Sum of the log weights of the two kernels.'''
        return self.kernel1.log_weight() + self.kernel2.log_weight()

def l1_loss(vals):
    '''
    Average l1 norm of the rows of a batch after each row has been scaled
    to unit l2 norm.
    Arguments:
        vals: tensor with shape (b_size, n)
    '''
    n_rows = tf.cast(tf.shape(vals)[0], dtype=tf.float32)
    row_norms = tf.norm(vals, axis=1)[:, tf.newaxis]
    unit_rows = vals / row_norms
    return tf.reduce_sum(tf.abs(unit_rows)) / n_rows

class KDMUnit(tf.keras.layers.Layer):
    """Kernel Quantum Measurement Unit
    Receives as input a factored density matrix represented by a set of vectors
    and weight values.
    Returns a resulting factored density matrix.
    Input shape:
        (batch_size, n_comp_in, dim_x + 1)
        where dim_x is the dimension of the input state
        and n_comp_in is the number of components of the input factorization.
        The weights of the input factorization of sample i are [i, :, 0],
        and the vectors are [i, :, 1:dim_x + 1].
    Output shape:
        (batch_size, n_comp, dim_y + 1)
        where dim_y is the dimension of the output state
        and n_comp is the number of components used to represent the train
        density matrix. The weights of the
        output factorization for sample i are [i, :, 0], and the vectors
        are [i, :, 1:dim_y + 1].
    Arguments:
        kernel: kernel layer called as kernel(in_v, c_x) with in_v of shape
                (batch_size, n_comp_in, dim_x) and c_x of shape
                (n_comp, dim_x).
        dim_x: int. the dimension of the input state
        dim_y: int. the dimension of the output state
        x_train: bool. Whether to train or not the x components of the train
                       density matrix.
        y_train: bool. Whether to train or not the y components of the train
                       density matrix.
        w_train: bool. Whether to train or not the weights of the components
                       of the train density matrix.
        n_comp: int. Number of components used to represent
                 the train density matrix. Must be positive: the initial
                 weights are 1 / n_comp.
        l1_act: float. Coefficient of the regularization term penalizing the l1
                       norm of the activations.
        l1_x: float. Coefficient of the regularization term penalizing the l1
                       norm of the x components.
        l1_y: float. Coefficient of the regularization term penalizing the l1
                       norm of the y components.
    """
    def __init__(
            self,
            kernel,
            dim_x: int,
            dim_y: int,
            x_train: bool = True,
            y_train: bool = True,
            w_train: bool = True,
            n_comp: int = 0,
            l1_x: float = 0.,
            l1_y: float = 0.,
            l1_act: float = 0.,
            **kwargs
    ):
        super().__init__(**kwargs)
        self.kernel = kernel
        self.dim_x = dim_x
        self.dim_y = dim_y
        self.x_train = x_train
        self.y_train = y_train
        self.w_train = w_train
        self.n_comp = n_comp
        self.l1_x = l1_x
        self.l1_y = l1_y
        self.l1_act = l1_act
        # The weight name is passed by keyword so the call does not depend on
        # the positional order of `add_weight`'s parameters.
        self.c_x = self.add_weight(
            name="c_x",
            shape=(self.n_comp, self.dim_x),
            #initializer=tf.keras.initializers.orthogonal(),
            initializer=tf.keras.initializers.random_normal(),
            trainable=self.x_train)
        self.c_y = self.add_weight(
            name="c_y",
            shape=(self.n_comp, self.dim_y),
            initializer=tf.keras.initializers.Constant(np.sqrt(1./self.dim_y)),
            #initializer=tf.keras.initializers.random_normal(),
            trainable=self.y_train)
        self.comp_w = self.add_weight(
            name="comp_w",
            shape=(self.n_comp,),
            initializer=tf.keras.initializers.constant(1./self.n_comp),
            trainable=self.w_train)
        # Floor applied to the unnormalized output weights in `call`.
        self.eps = 1e-10

    def call(self, inputs):
        # Weight regularizers
        if self.l1_x != 0:
            self.add_loss(self.l1_x * l1_loss(self.c_x))
        if self.l1_y != 0:
            self.add_loss(self.l1_y * l1_loss(self.c_y))
        #comp_w = tf.clip_by_value(self.comp_w, 1e-10, 1)
        # Keep the component weights strictly positive.
        comp_w = tf.abs(self.comp_w) + 1e-6
        # normalize comp_w to sum to 1
        comp_w = comp_w / tf.reduce_sum(comp_w)
        in_w = inputs[:, :, 0]  # shape (b, n_comp_in)
        in_v = inputs[:, :, 1:] # shape (b, n_comp_in, dim_x)
        out_vw = self.kernel(in_v, self.c_x)  # shape (b, n_comp_in, n_comp)
        out_w = (tf.expand_dims(tf.expand_dims(comp_w, axis=0), axis=0) *
                 tf.square(out_vw)) # shape (b, n_comp_in, n_comp)
        # Floor the weights so the normalization below never divides by zero.
        out_w = tf.maximum(out_w, self.eps) #########
        # out_w_sum = tf.maximum(tf.reduce_sum(out_w, axis=2), self.eps)  # shape (b, n_comp_in)
        out_w_sum = tf.reduce_sum(out_w, axis=2) # shape (b, n_comp_in)
        out_w = out_w / tf.expand_dims(out_w_sum, axis=2)
        # Mix the per-input-component weights with the input weights.
        out_w = tf.einsum('...i,...ij->...j', in_w, out_w, optimize="optimal")
                # shape (b, n_comp)
        if self.l1_act != 0:
            self.add_loss(self.l1_act * l1_loss(out_w))
        out_w = tf.expand_dims(out_w, axis=-1) # shape (b, n_comp, 1)
        # Target shape (b, n_comp, dim_y) used to repeat c_y for every sample.
        out_y_shape = tf.shape(out_w) + tf.constant([0, 0, self.dim_y - 1])
        out_y = tf.broadcast_to(tf.expand_dims(self.c_y, axis=0), out_y_shape)
        out = tf.concat((out_w, out_y), 2)
        return out

    def get_config(self):
        # NOTE: the kernel is not part of the config.
        config = {
            "dim_x": self.dim_x,
            "dim_y": self.dim_y,
            "n_comp": self.n_comp,
            "x_train": self.x_train,
            "y_train": self.y_train,
            "w_train": self.w_train,
            "l1_x": self.l1_x,
            "l1_y": self.l1_y,
            "l1_act": self.l1_act,
        }
        base_config = super().get_config()
        return {**base_config, **config}

    def compute_output_shape(self, input_shape):
        # Matches the tensor built in `call`: one weight plus dim_y vector
        # entries for each of the n_comp components of every sample.
        return (input_shape[0], self.n_comp, self.dim_y + 1)

class KDMOverlap(tf.keras.layers.Layer):
    """Kernel Quantum Measurement Overlap Unit
    Receives as input a vector and calculates its overlap with the unit density
    matrix.
    Input shape:
        (batch_size, dim_x)
        where dim_x is the dimension of the input state
    Output shape:
        (batch_size, )
    Arguments:
        kernel: a kernel function
        dim_x: int. the dimension of the input state
        x_train: bool. Whether to train or not the components of the train
                       density matrix.
        w_train: bool. Whether to train or not the weights of the components
                       of the train density matrix.
        n_comp: int. Number of components used to represent
                 the train density matrix. Must be positive: the initial
                 weights are 1 / n_comp.
    """

    def __init__(
            self,
            kernel,
            dim_x: int,
            x_train: bool = True,
            w_train: bool = True,
            n_comp: int = 0,
            **kwargs
    ):
        super().__init__(**kwargs)
        self.kernel = kernel
        self.dim_x = dim_x
        self.x_train = x_train
        self.w_train = w_train
        self.n_comp = n_comp
        # The weight name is passed by keyword so the call does not depend on
        # the positional order of `add_weight`'s parameters.
        self.c_x = self.add_weight(
            name="c_x",
            shape=(self.n_comp, self.dim_x),
            #initializer=tf.keras.initializers.orthogonal(),
            initializer=tf.keras.initializers.random_normal(),
            trainable=self.x_train)
        self.comp_w = self.add_weight(
            name="comp_w",
            shape=(self.n_comp,),
            initializer=tf.keras.initializers.constant(1./self.n_comp),
            trainable=self.w_train)

    def call(self, inputs):
        #comp_w = tf.clip_by_value(self.comp_w, 1e-10, 1)
        comp_w = tf.abs(self.comp_w)
        # normalize comp_w to sum to 1
        comp_w = comp_w / tf.reduce_sum(comp_w)
        # Treat every input vector as a bag holding a single component.
        in_v = inputs[:, tf.newaxis, :]
        out_vw = self.kernel(in_v, self.c_x) ** 2 # shape (b, 1, n_comp)
        # Weighted sum of the squared kernel values over the components.
        out_w = tf.einsum('...j,...ij->...', comp_w, out_vw, optimize="optimal")
        return out_w

    def get_config(self):
        # NOTE: the kernel is not part of the config.
        config = {
            "dim_x": self.dim_x,
            "n_comp": self.n_comp,
            "x_train": self.x_train,
            "w_train": self.w_train,
        }
        base_config = super().get_config()
        return {**base_config, **config}

    def compute_output_shape(self, input_shape):
        # One overlap value per sample, as produced by `call`.
        return (input_shape[0],)

class KDMClassModel(tf.keras.Model):
    """Classifier that feeds single input vectors through a KDM unit with an
    RBF kernel and returns the discrete distribution of the output.

    Arguments:
        dim_x: dimension of the input vectors.
        dim_y: dimension of the output state.
        sigma: initial scale of the RBF kernel.
        n_comp: number of components of the KDM unit.
        x_train: whether the x components of the KDM unit are trainable.
    """

    def __init__(self,
                 dim_x,
                 dim_y,
                 sigma,
                 n_comp,
                 x_train=True):
        super().__init__()
        self.dim_x = dim_x
        self.dim_y = dim_y
        self.n_comp = n_comp
        self.kernel_x = RBFKernelLayer(sigma, dim=dim_x)
        self.kdmu = KDMUnit(self.kernel_x,
                            dim_x=dim_x,
                            dim_y=dim_y,
                            n_comp=n_comp,
                            x_train=x_train)

    def call(self, inputs):
        # input vector -> pure state -> output density matrix -> probabilities
        return dm2discrete(self.kdmu(pure2dm(inputs)))

class BagKDMClassModel(tf.keras.Model):
    """KDM model whose input is a bag of vectors per sample; every vector in
    a bag receives the same weight.

    Arguments:
        dim_x: dimension of the input vectors.
        dim_y: dimension of the output state.
        sigma: initial scale of the RBF kernel.
        n_comp: number of components of the KDM unit.
        x_train: whether the x components of the KDM unit are trainable.
        l1_y: l1 regularization coefficient for the y components.
    Input shape:
        (batch_size, n, dim_x)
    Output shape:
        (batch_size, n_comp, dim_y + 1), the factorized output density
        matrix. NOTE(review): unlike KDMClassModel this returns the density
        matrix, not class probabilities; apply `dm2discrete` to the result
        to obtain them. The return value is kept as it was.
    """

    def __init__(self,
                 dim_x,
                 dim_y,
                 sigma,
                 n_comp,
                 x_train=True,
                 l1_y=0.):
        super().__init__()
        self.dim_x = dim_x
        self.dim_y = dim_y
        self.n_comp = n_comp
        # RBFKernelLayer requires `dim`; omitting it raised a TypeError.
        kernel_x = RBFKernelLayer(sigma, dim=dim_x)
        self.kdmu = KDMUnit(kernel_x,
                            dim_x=dim_x,
                            dim_y=dim_y,
                            n_comp=n_comp,
                            x_train=x_train,
                            l1_y=l1_y)

    def call(self, inputs):
        # tf.shape returns int32; cast the bag size to the input dtype
        # because TensorFlow does not divide a float tensor by an int tensor.
        bag_size = tf.cast(tf.shape(inputs)[1], inputs.dtype)
        w = tf.ones_like(inputs[:, :, 0]) / bag_size
        rho_x = comp2dm(w, inputs)
        rho_y = self.kdmu(rho_x)
        return rho_y

class KDMDenEstModel(tf.keras.Model):
    """Density estimation model built from a KDM overlap unit with an RBF
    kernel. Each call registers the negative mean log value as a loss.

    Arguments:
        dim_x: dimension of the input vectors.
        sigma: initial scale of the RBF kernel.
        n_comp: number of components of the overlap unit.
    """

    def __init__(self,
                 dim_x,
                 sigma,
                 n_comp):
        super().__init__()
        self.dim_x = dim_x
        self.n_comp = n_comp
        self.kernel = RBFKernelLayer(sigma, dim=dim_x)
        self.kdmover = KDMOverlap(self.kernel,
                                dim_x=dim_x,
                                n_comp=n_comp)

    def call(self, inputs):
        overlap = self.kdmover(inputs)
        # The small constant keeps the logarithm finite for a zero overlap.
        log_probs = tf.math.log(overlap + 1e-12) + self.kernel.log_weight()
        self.add_loss(-tf.reduce_mean(log_probs))
        return log_probs

class KDMDenEstModel2(tf.keras.Model):
    """Density estimation model over concatenated (x, y) vectors: an RBF
    kernel is applied to the first dim_x entries and a cosine kernel to the
    remaining ones, and their product feeds a KDM overlap unit. Each call
    registers the negative mean log value as a loss.

    Arguments:
        dim_x: dimension of the x part of the input.
        dim_y: dimension of the y part of the input.
        sigma: initial scale of the RBF kernel.
        n_comp: number of components of the overlap unit.
        trainable_sigma: whether the RBF scale is trainable.
        min_sigma: lower bound applied to the RBF scale.
    """

    def __init__(self,
                 dim_x,
                 dim_y,
                 sigma,
                 n_comp,
                 trainable_sigma=True,
                 min_sigma=1e-3):
        super().__init__()
        self.dim_x = dim_x
        self.dim_y = dim_y
        self.n_comp = n_comp
        self.kernel_x = RBFKernelLayer(sigma, dim=dim_x,
                                       trainable=trainable_sigma,
                                       min_sigma=min_sigma)
        self.kernel_y = CosineKernelLayer()
        self.kernel = CrossProductKernelLayer(dim1=dim_x,
                                              kernel1=self.kernel_x,
                                              kernel2=self.kernel_y)
        self.kdmover = KDMOverlap(self.kernel,
                                dim_x=dim_x + dim_y,
                                n_comp=n_comp)

    def call(self, inputs):
        overlap = self.kdmover(inputs)
        # The small constant keeps the logarithm finite for a zero overlap.
        log_probs = tf.math.log(overlap + 1e-12) + self.kernel.log_weight()
        self.add_loss(-tf.reduce_mean(log_probs))
        return log_probs

class KQClassBagModel(tf.keras.Model):
    """End-to-end model: encoder -> KDM unit with an RBF kernel -> discrete
    distribution -> (mean, variance) computed by `ProbRegression`.

    Arguments:
        encoded_size: dimension of the encoder output.
        dim_y: dimension of the output state.
        encoder: model mapping the raw input to vectors of size encoded_size.
        n_comp: number of components of the KDM unit.
        sigma: initial scale of the RBF kernel.
        mle_weight: stored on the instance; not used by the code shown here.
    """

    def __init__(self,
                 encoded_size,
                 dim_y,
                 encoder,
                 n_comp,
                 sigma=0.1,
                 mle_weight=0.):
        super().__init__()
        self.dim_y = dim_y
        self.encoded_size = encoded_size
        self.encoder = encoder
        self.n_comp = n_comp
        self.mle_weight = mle_weight
        self.kernel = RBFKernelLayer(sigma=sigma,
                                     dim=encoded_size,
                                     trainable=True)
        self.kdm_unit = KDMUnit(kernel=self.kernel,
                                dim_x=encoded_size,
                                dim_y=dim_y,
                                n_comp=n_comp)
        self.regression_layer = ProbRegression()

    def call(self, input):
        """Return a (batch_size, 2) tensor holding mean and variance."""
        rho_y = self.kdm_unit(pure2dm(self.encoder(input)))
        return self.regression_layer(dm2discrete(rho_y))

    def init_components(self, samples_x, samples_y, init_sigma=False, sigma_mult=1):
        """Initialise the KDM unit from example samples.

        The encoded `samples_x` become the x components, `samples_y` the y
        components, and the component weights are reset to 1 / n_comp. When
        `init_sigma` is set, the kernel scale becomes the mean pairwise
        distance between the encoded samples times `sigma_mult`.
        """
        encoded_x = self.encoder(samples_x)
        if init_sigma:
            mean_distance = np.mean(pairwise_distances(encoded_x))
            self.kernel.sigma.assign(mean_distance * sigma_mult)
        self.kdm_unit.c_x.assign(encoded_x)
        self.kdm_unit.c_y.assign(samples_y)
        uniform_weights = tf.ones((self.n_comp,)) / self.n_comp
        self.kdm_unit.comp_w.assign(uniform_weights)

##  Create encoder

In [None]:
# Size of the embedding; presumably the value meant to be passed to
# create_convnext_encoder (its only use is in a commented-out call below).
encoded_size = 128

def create_convnext_encoder(encoded_size, num_classes=5, input_shape=(256, 256, 3),
                            dropout_rate=0.2):
    """Build a ConvNeXt-Tiny patch encoder and a classifier that shares its weights.

    Args:
        encoded_size: Number of units of the final `tanh` Dense layer, i.e. the
            size of the embedding returned by the encoder.
        num_classes: Number of softmax outputs of the classifier head.
        input_shape: Shape of one input image, `(height, width, channels)`.
        dropout_rate: Dropout rate applied to the pooled backbone features.

    Returns:
        A tuple `(encoder, encoder_cls)`. `encoder_cls` wraps the very same
        `encoder` object and adds a softmax Dense layer on top, so training
        one also updates the other.
    """
    # With include_top=False the backbone has no classification head, so the
    # `classes` / `classifier_activation` arguments do not apply and are omitted.
    convnext = tf.keras.applications.convnext.ConvNeXtTiny(
        model_name='convnext_tiny',
        include_top=False,
        include_preprocessing=True,
        weights='imagenet',
        input_tensor=None,
        input_shape=input_shape,
        pooling="avg",
    )
    encoder = keras.Sequential([
        Input(shape=input_shape),
        convnext,
        keras.layers.Dropout(dropout_rate),
        # Alternative activations noted by the author: relu, linear.
        keras.layers.Dense(encoded_size, activation="tanh"),
    ])

    encoder_cls = keras.Sequential([
        encoder,
        keras.layers.Dense(num_classes, activation="softmax"),
    ])
    return encoder, encoder_cls

In [None]:
#encoder, encoder_cls = create_convnext_encoder(encoded_size)

## ConvNeXt Training / Warmup

In [None]:
alpha = 0.1  # weight of the predicted-variance penalty added by `loss`
def loss(y_true, y_pred):
    """Squared error of the predicted mean plus `alpha` times the predicted variance.

    Args:
        y_true: Targets, either of shape `(batch,)` or `(batch, 1)`.
        y_pred: Model output whose column 0 is used as the predicted mean and
            column 1 as the predicted variance.

    Returns:
        A tensor of shape `(batch,)` with one loss value per sample.

    The previous implementation added a `(batch,)` squared-error tensor to a
    `(batch, 1)` variance slice, which broadcasts to a `(batch, batch)` matrix.
    Both terms are now kept one-dimensional so the result is truly per-sample.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    # Flatten the targets so that (batch,) and (batch, 1) inputs behave the same.
    targets = tf.reshape(tf.cast(y_true, y_pred.dtype), [-1])
    predicted_mean = y_pred[:, 0]
    predicted_variance = y_pred[:, 1]
    return tf.math.squared_difference(targets, predicted_mean) + alpha * predicted_variance

In [None]:
# checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
#     "/home/smedin7/data/KDM/models/patch_regression_weights.h5",
#     monitor = "val_cohen_kappa",
#     verbose = 1,
#     save_best_only = True,
#     save_weights_only = True,
#     mode = "max",
#     save_freq="epoch",
# )
# early_stopping = EarlyStopping(monitor='val_loss', patience=10, verbose=1, mode="min")

In [None]:
# encoder_cls.compile(optimizer=optimizers.Adam(learning_rate=1e-4),
#                 loss=losses.categorical_crossentropy,
#                 metrics=[metrics.categorical_accuracy, tfa.metrics.CohenKappa(num_classes = 5, weightage='quadratic')], loss_weights=class_weights)

In [None]:
# encoder_cls.load_weights('/content/drive/MyDrive/data/kdm_data/best_convnext_patch_classifier.h5')
# encoder.save_weights('/content/drive/MyDrive/data/kdm_data/best_convnext_patch_encoder.h5')

In [None]:
# history = encoder_cls.fit(train_dataset, validation_data=val_dataset, epochs=1, verbose=1, callbacks=[checkpoint_callback, early_stopping, WandbMetricsLogger(log_freq='epoch')])

In [None]:
# encoder_cls.evaluate(test_dataset)

Confusion matrix of ConvNeXt alone when classifying patches

In [None]:
# preds = encoder_cls.predict(test_dataset)
# predictions = np.argmax(preds, axis =1)
# y_true = np.argmax(np.concatenate([y for x,y in test_dataset], axis=0), axis = 1)
# ConfusionMatrixDisplay.from_predictions(y_true, predictions, normalize='true', display_labels=['Stroma', 'Healthy', 'G 3', 'G 4', 'G 5'])

## KDM Training

#### Create encoder for KDM

### Callbacks

In [None]:
# Save the weights once per epoch, but only when the validation MAE improves.
kdm_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath="/data/KDM/models/patch_regression_weights.h5",
    monitor="val_mean_absolute_error",
    mode="min",
    save_best_only=True,
    save_weights_only=True,
    save_freq="epoch",
    verbose=1,
)

# Stop after 10 epochs without a lower validation MAE and roll back to the best weights.
early_stopping_callback = EarlyStopping(
    monitor='val_mean_absolute_error',
    mode='min',
    patience=10,
    restore_best_weights=True,
    verbose=1,
)

### Define number of components

In [None]:
# Number of KDM components passed to KQClassBagModel.
n_comp = 216

### Instantiate model

### Initialize KDM with prototypes

## Train

In [None]:
# Build, restore and compile the KDM model inside a MirroredStrategy scope.
mirrored_strategy = tf.distribute.MirroredStrategy()

with mirrored_strategy.scope():
    # Only the encoder is needed here; the classifier wrapper is discarded.
    encoder_kdm, _ = create_convnext_encoder(encoded_size)
    encoder_kdm.load_weights('/data/KDM/models/best_convnext_patch_encoder.h5')
    kdm_cls = KQClassBagModel(
                    encoded_size=encoded_size,
                    dim_y=5,
                    encoder=encoder_kdm,
                    n_comp=n_comp,
                    sigma=0.1)
    # Call the model once on a real batch so its variables exist before
    # weights are loaded into it.
    kdm_cls(next(iter(train_dataset))[0])
    # patch_protos, patch_label_protos = get_patch_prototypes()
    # kdm_cls.init_components(patch_protos, patch_label_protos, init_sigma = True, sigma_mult = 1.)
    # NOTE(review): this is the same file the checkpoint callback writes to, so
    # the cell resumes from a previous run and needs that file to exist.
    kdm_cls.load_weights("/data/KDM/models/patch_regression_weights.h5")
    # NOTE(review): the prediction cell reads two columns from the model output
    # (`out[:, 0]` and `out[:, 1]`). Both metrics below receive that full
    # output, not only column 0 — confirm they measure what is intended.
    kdm_cls.compile(optimizer=optimizers.Adam(learning_rate=1e-4),
                    loss=loss,
                    metrics=[metrics.mean_absolute_error,
                            tfa.metrics.CohenKappa(num_classes = 5, weightage='quadratic')])
# Show the kernel sigma currently held by the restored model.
print(kdm_cls.kernel.sigma.numpy())


In [None]:
# _ = kdm_cls(tf.zeros((1, 256,256,3)))
# kdm_cls.load_weights('/content/drive/MyDrive/data/kdm_data/best_kdm_patch_convnext.h5')
#encoder_kdm.save_weights('/content/drive/MyDrive/data/kdm_data/best_kdm_patch_convnext_extractor.h5')

In [None]:
# Train for at most 30 epochs; checkpointing and early stopping both follow the validation MAE.
history = kdm_cls.fit(train_dataset, validation_data=val_dataset, epochs=30, verbose=1, callbacks=[kdm_checkpoint_callback, early_stopping_callback])

In [None]:
# Report the compiled loss and metrics on the test split.
kdm_cls.evaluate(test_dataset)

In [None]:
out = kdm_cls.predict(test_dataset)
# Column 0 is used as the predicted value, column 1 as its variance.
y_pred, std = out[:, 0], np.sqrt(out[:, 1])

# Predictions and targets are multiplied by 4 and rounded to get class indices.
predictions = np.round(y_pred * 4)
# Targets are flattened so they align element-wise with the 1-D predictions.
# NOTE(review): assumes test_dataset yields samples in the same order on every
# pass (i.e. it is not reshuffled) — confirm.
y_true = np.round(np.concatenate([y for _, y in test_dataset], axis=0).ravel() * 4)

class_names = ['Stroma', 'Healthy', 'G3', 'G 4', 'G 5']
# Pin the label set so the five display names always match the matrix size,
# even when a class is missing from both targets and predictions.
print(ConfusionMatrixDisplay.from_predictions(
    y_true,
    predictions,
    labels=np.arange(len(class_names), dtype=float),
    normalize='true',
    display_labels=class_names))



In [None]:
# Regression targets on the same scale as `y_pred`. They must not be rounded
# here: class indices are obtained as round(target * 4), so the raw targets are
# fractional and rounding them directly would collapse them to 0 or 1.
y_true_reg = np.concatenate([y for _, y in test_dataset], axis=0).ravel()

print(mean_absolute_error(y_true_reg, y_pred))

print(cohen_kappa_score(y_true, predictions, weights='quadratic'))

print(classification_report(y_true, predictions))

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

# Fixed category order shared by the counts, both plots and the legend.
ERROR_GROUPS = ['0', '1', '2+']

var = std ** 2
abs_errors = np.abs(np.ravel(predictions) - np.ravel(y_true))

# One row per test sample: its absolute class error bucket, variance and prediction.
df = pd.DataFrame({
    'Error Group': np.where(abs_errors == 0, '0',
                            np.where(abs_errors == 1, '1', '2+')),
    'Variance': np.ravel(var),
    'Predictions': np.ravel(predictions),
})

# reindex (instead of .loc) reports 0 for a bucket that has no samples
# rather than raising a KeyError.
counts = df['Error Group'].value_counts().reindex(ERROR_GROUPS, fill_value=0)
palette = sns.color_palette("Pastel1", n_colors=3)
sample_fraction = 0.005  # 0.5% of each error group is drawn as individual points
df_sampled = df.groupby('Error Group').apply(lambda x: x.sample(frac=sample_fraction)).reset_index(drop=True)

# Custom legend: one coloured marker per error group with its sample count.
legend_labels = [f'{key}: {value} samples' for key, value in counts.items()]
legend_handles = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=colour, markersize=10)
                  for colour in palette]


def _plot_by_error_group(column, filename):
    """Draw a violin plot of `column` per error group, overlay sampled points, save and show it."""
    plt.figure(figsize=(10, 6))
    sns.violinplot(x='Error Group', y=column, data=df, order=ERROR_GROUPS, palette=palette)
    # `order` keeps the points on the right category even when a group has
    # no rows in the sampled frame.
    sns.stripplot(x='Error Group', y=column, data=df_sampled, order=ERROR_GROUPS,
                  jitter=True, marker="o", alpha=0.2, color='gray')

    plt.title(f'{column} by Absolute Error Group', fontsize=15)
    plt.xlabel('Error Group', fontsize=13)
    plt.ylabel(column, fontsize=13)
    plt.legend(legend_handles, legend_labels, loc='center left', bbox_to_anchor=(1, 0.5))

    plt.savefig(filename, bbox_inches='tight', dpi=600)
    plt.show()


_plot_by_error_group('Variance', 'variance_plot.png')
_plot_by_error_group('Predictions', 'predictions_plot.png')




## Testing the model when taking the predicted variance into account

In [None]:
# Keep only the test samples whose predicted variance is below the cut-off.
threshold = 0.05
indices = np.flatnonzero(var < threshold)
confident_predictions, ground_truth_subset = predictions[indices], y_true[indices]

In [None]:
# Confusion matrix restricted to the low-variance subset. The label set is
# pinned to the five classes because filtering can remove a class entirely,
# which would otherwise leave fewer matrix rows than display labels.
print(ConfusionMatrixDisplay.from_predictions(
    ground_truth_subset,
    confident_predictions,
    labels=np.arange(5, dtype=float),
    normalize='true',
    display_labels=['Stroma', 'Healthy', 'G3', 'G 4', 'G 5']))


In [None]:
# Regression targets and raw predictions restricted to the low-variance subset.
y_true_reg_subset, y_pred_subset = y_true_reg[indices], y_pred[indices]

subset_mae = mean_absolute_error(y_true_reg_subset, y_pred_subset)
print("MAE: ", subset_mae)

subset_kappa = cohen_kappa_score(ground_truth_subset, confident_predictions, weights='quadratic')
print("KAPPA: ", subset_kappa)

subset_report = classification_report(ground_truth_subset, confident_predictions)
print(subset_report)

## Extract features and compare with prototypes

### Load pretrained KDM

In [None]:
# Rebuild the model around the existing `encoder_kdm` and restore saved weights.
kdm_cls = KQClassBagModel(
                 encoded_size=encoded_size,
                 dim_y=5,
                 encoder=encoder_kdm,
                 n_comp=n_comp,
                 sigma=0.1)
# NOTE(review): a categorical loss and accuracy are configured here, while the
# training cell compiles the same model class with the regression `loss` —
# confirm which setup is intended for this checkpoint.
kdm_cls.compile(optimizer=optimizers.Adam(learning_rate=1e-4),
                loss=losses.categorical_crossentropy,
                metrics=[metrics.categorical_accuracy,
                         tfa.metrics.CohenKappa(num_classes = 5, weightage='quadratic')])
# Call the model once on a dummy image so its variables exist before loading weights.
_ = kdm_cls(tf.zeros((1, 256,256,3)))
kdm_cls.load_weights('/content/drive/MyDrive/data/kdm_data/best_kdm_patch_convnext.h5')

In [None]:
# Encoder embeddings of the training patches. This line used to be commented
# out, which left `features` undefined for the save below and for later cells.
# NOTE(review): assumes train_dataset is finite and yields samples in
# df_train row order (later cells index df_train with positions in
# `features`) — confirm.
features = encoder_kdm.predict(train_dataset)
np.save("/content/drive/MyDrive/data/kdm_data/train_patch_features_kdm.npy", features)

# Learned KDM components as NumPy arrays: prototypes, their labels and weights.
c_x = kdm_cls.kdm_unit.c_x.numpy()
c_y = kdm_cls.kdm_unit.c_y.numpy()
comp_w = kdm_cls.kdm_unit.comp_w.numpy()
np.save("/content/drive/MyDrive/data/kdm_data/c_x_features_kdm.npy", c_x)
np.save("/content/drive/MyDrive/data/kdm_data/c_y_features_kdm.npy", c_y)
# The nearest-feature lookup (`indices`) is computed further down, after
# `dist` has been defined; computing it here would read `dist` too early.

In [None]:
# Plot the component weights in component order.
plt.plot(comp_w)

In [None]:
# Show the label vectors of the components whose weight exceeds 0.001.
c_y[comp_w > 0.001]


In [None]:
# Prototype vectors of the components whose weight exceeds 0.001.
weighted_c_x = c_x[comp_w > 0.001]

In [None]:
# Distance matrix with one row per feature vector and one column per selected prototype.
dist = pairwise_distances(features, weighted_c_x)

In [None]:
# For every selected prototype (column), the row index of its closest feature vector.
indices = np.argmin(dist, axis = 0)

In [None]:
# Display the nearest-feature index of each selected prototype.
indices

In [None]:
# Rows of df_train at the nearest-feature positions.
# NOTE(review): assumes row i of df_train corresponds to row i of `features` — confirm.
prototype_patches = df_train.iloc[indices]

In [None]:
# Index of the largest entry in each selected prototype's label vector.
learned_prototype_labels = np.argmax(c_y[comp_w > 0.001], axis=1)

In [None]:
# Display the label index of each selected prototype.
learned_prototype_labels

In [None]:
# Image identifiers of the patches closest to the selected prototypes.
img_ids = prototype_patches['image_id']

In [None]:
def plot_patches(prototype_patches, gleason_score, patches_dir="/content/patches"):
    """Display every prototype patch whose `gleason_score` column equals `gleason_score`.

    Args:
        prototype_patches: DataFrame with at least the columns
            `'gleason_score'` and `'image_id'`.
        gleason_score: Value to select in the `'gleason_score'` column.
        patches_dir: Root directory of the patch images. Defaults to the
            location used previously, so existing calls behave the same.

    Each image is read from
    `<patches_dir>/<part of image_id before the first '_'>/<image_id>.jpeg`
    and shown in its own figure. Nothing is returned.
    """
    selected = prototype_patches[prototype_patches['gleason_score'] == gleason_score]
    for image_id in selected['image_id']:
        path = os.path.join(patches_dir, image_id.split('_')[0], image_id + '.jpeg')
        # The context manager closes the image file once it has been handed to
        # matplotlib, instead of leaving one open handle per patch.
        with Image.open(path) as patch:
            plt.imshow(patch)
        plt.show()





In [None]:
# Show the prototype patches whose gleason_score column equals 4.
plot_patches(prototype_patches,4)

In [None]:
# NOTE(review): scratch expression; `x` is not defined at top level in the visible
# part of the notebook, so this cell presumably fails on a fresh run — confirm or remove.
os.path.join(data_dir, x.split('_')[0], x + '.jpeg')

In [None]:
# Display the shape of the extracted feature array.
features.shape

In [None]:
# Classification-style evaluation: arg-max over the model output and over the targets.
# NOTE(review): the earlier prediction cell treats the model output as two
# columns (value, variance) and the targets as scalars; an arg-max over axis 1
# looks like a leftover from a one-hot classification variant — confirm.
# This cell also overwrites `predictions` and `y_true` from the cells above.
preds = kdm_cls.predict(test_dataset)
predictions = np.argmax(preds, axis =1)
y_true = np.argmax(np.concatenate([y for x,y in test_dataset], axis=0), axis = 1)
ConfusionMatrixDisplay.from_predictions(y_true, predictions, normalize='true', display_labels=['Stroma', 'Healthy', 'G 3', 'G 4', 'G 5'])

In [None]:
# Per-class precision, recall and F1 for the current `y_true` / `predictions`.
print(classification_report(y_true, predictions))

In [None]:
# Close the Weights & Biases run.
# NOTE(review): `wandb.init` is commented out near the top of the notebook — confirm a run is active.
wandb.finish()

In [None]:
# Release the Colab runtime; this import only works when running inside Google Colab.
from google.colab import runtime
runtime.unassign()