In [2]:
import tensorflow as tf
import numpy as np
import pandas as pd
import hydra
import os
from omegaconf import DictConfig, OmegaConf

In [3]:
!export CUDA_VISIBLE_DEVICES=3

### Load Configurations

In [15]:
def load_config(config_path):
    import yaml
    with open(config_path, "r") as f:
        cfg=yaml.safe_load(f)
    # cfg = pd.DataFrame.from_dict(cfg.items())
    return cfg

In [16]:
args=load_config(config_path="/home/quincy/code/DP_BACH/clam_tf_module/configs/train.yaml")
args

{'is_training': True,
 'gpu': True,
 'all_tfrecords_path': '/home/quincy/data/bach_tfrecords',
 'train_data_dir': '/home/quincy/data/bach_kf/fold_1/bach_fold_1_train.csv',
 'val_data_dir': '/home/quincy/data/bach_kf/fold_1/bach_fold_1_val.csv',
 'checkpoints_dir': '/home/quincy/exps/clam/h1/cv1',
 'i_optimizer_name': 'Adam',
 'b_optimizer_name': 'Adam',
 'a_optimizer_name': 'Adam',
 'i_loss_name': 'binary_crossentropy',
 'b_loss_name': 'binary_crossentropy',
 'c1': 0.95,
 'c2': 0.05,
 'att_gate': True,
 'mut_ex': False,
 'i_learn_rate': 0.0002,
 'b_learn_rate': 0.0002,
 'a_learn_rate': 0.0002,
 'imf_norm_op': True,
 'n_class': 2,
 'top_k_percent': 0.2,
 'm_clam_op': False,
 'batch_size': 0,
 'epochs': 50,
 'att_only': False,
 'net_size': 'small',
 'dropout_rate': 0.5,
 'dim_compress_features': 512}

### Attention Network

In [25]:
class G_Att_Net(tf.keras.Model):
    """_summary_

    Args:
        tf (_type_): _description_
    """

    def __init__(
        self,
        args,
        dim_features=1024,
        n_hidden_units=256,
    ):
        """_summary_

        Args:
            args (_type_): _description_
            dim_features (int, optional): _description_. Defaults to 1024.
            n_hidden_units (int, optional): _description_. Defaults to 256.
        """
        super(G_Att_Net, self).__init__()
        self.args = args
        self.dim_features = dim_features
        self.n_hidden_units = n_hidden_units

        self.compression_model = tf.keras.models.Sequential()
        self.model_v = tf.keras.models.Sequential()
        self.model_u = tf.keras.models.Sequential()
        self.model = tf.keras.models.Sequential()

        self.fc_compress_layer = tf.keras.layers.Dense(
            units=self.args["dim_compress_features"],
            activation="relu",
            input_shape=(self.dim_features,),
            kernel_initializer="glorot_normal",
            bias_initializer="zeros",
            name="Fully_Connected_Layer",
        )

        self.compression_model.add(self.fc_compress_layer)

        self.att_v_layer1 = tf.keras.layers.Dense(
            units=n_hidden_units,
            activation="linear",
            input_shape=(self.args["dim_compress_features"],),
            kernel_initializer="glorot_normal",
            bias_initializer="zeros",
            name="Attention_V_Layer1",
        )

        self.att_v_layer2 = tf.keras.layers.Dense(
            units=n_hidden_units,
            activation="tanh",
            input_shape=(self.args["dim_compress_features"],),
            kernel_initializer="glorot_normal",
            bias_initializer="zeros",
            name="Attention_V_Layer2",
        )

        self.att_u_layer1 = tf.keras.layers.Dense(
            units=n_hidden_units,
            activation="linear",
            input_shape=(self.args["dim_compress_features"],),
            kernel_initializer="glorot_normal",
            bias_initializer="zeros",
            name="Attention_U_Layer1",
        )

        self.att_u_layer2 = tf.keras.layers.Dense(
            units=n_hidden_units,
            activation="sigmoid",
            input_shape=(self.args["dim_compress_features"],),
            kernel_initializer="glorot_normal",
            bias_initializer="zeros",
            name="Attention_U_Layer2",
        )

        self.att_layer_f = tf.keras.layers.Dense(
            units=self.args["n_class"],
            activation="linear",
            input_shape=(n_hidden_units,),
            kernel_initializer="glorot_normal",
            bias_initializer="zeros",
            name="Attention_Gated_Final_Layer",
        )

        self.model_v.add(self.att_v_layer1)
        self.model_v.add(self.att_v_layer2)

        self.model_u.add(self.att_u_layer1)
        self.model_u.add(self.att_u_layer2)

        if self.args["dropout_rate"] > 0.0:
            self.model_v.add(
                tf.keras.layers.Dropout(self.args["dropout_rate"], name="Dropout_V_Layer")
            )
            self.model_u.add(
                tf.keras.layers.Dropout(self.args["dropout_rate"], name="Dropout_U_Layer")
            )

        self.model.add(self.att_layer_f)

    def att_model(self):
        """_summary_

        Returns:
            _type_: _description_
        """
        attention_model = [
            self.compression_model,
            self.model_v,
            self.model_u,
            self.model,
        ]
        return attention_model

    def call(self, img_features):
        """_summary_

        Args:
            img_features (_type_): _description_

        Returns:
            _type_: _description_
        """
        h = list()
        A = list()

        for i in img_features:
            c_imf = self.att_model()[0](i)
            h.append(c_imf)

        for j in h:
            att_v_output = self.att_model()[1](j)
            att_u_output = self.att_model()[2](j)
            att_input = tf.math.multiply(att_v_output, att_u_output)
            a = self.att_model()[3](att_input)
            A.append(a)

        return {"h": h, "A": A}

In [26]:
att = G_Att_Net(args,dim_features=1024,
        n_hidden_units=256,)

### Instance Classifier

In [36]:
class Ins(tf.keras.Model):
    """_summary_

    Args:
        tf (_type_): _description_
    """

    def __init__(
        self,
        args,
    ):
        """_summary_

        Args:
            args (_type_): _description_
        """
        super(Ins, self).__init__()
        self.args = args

        self.ins_model = list()
        self.m_ins_model = tf.keras.models.Sequential()
        self.m_ins_layer = tf.keras.layers.Dense(
            units=self.args["n_class"],
            activation="linear",
            input_shape=(self.args["dim_compress_features"],),
            name="Instance_Classifier_Layer",
        )
        self.m_ins_model.add(self.m_ins_layer)

        for i in range(self.args["n_class"]):
            self.ins_model.append(self.m_ins_model)

    def ins_classifier(self):
        """_summary_

        Returns:
            _type_: _description_
        """
        return self.ins_model

    @staticmethod
    def generate_pos_labels(n_pos_sample):
        """_summary_

        Args:
            n_pos_sample (_type_): _description_

        Returns:
            _type_: _description_
        """
        return tf.fill(
            dims=[
                n_pos_sample,
            ],
            value=1,
        )

    @staticmethod
    def generate_neg_labels(n_neg_sample):
        """_summary_

        Args:
            n_neg_sample (_type_): _description_

        Returns:
            _type_: _description_
        """
        return tf.fill(
            dims=[
                n_neg_sample,
            ],
            value=0,
        )

    def in_call(self, n_ins, ins_classifier, h, A_I):
        """_summary_

        Args:
            n_ins (_type_): _description_
            ins_classifier (_type_): _description_
            h (_type_): _description_
            A_I (_type_): _description_

        Returns:
            _type_: _description_
        """
        pos_label = self.generate_pos_labels(n_ins)
        neg_label = self.generate_neg_labels(n_ins)
        ins_label_in = tf.concat(values=[pos_label, neg_label], axis=0)
        A_I = tf.reshape(tf.convert_to_tensor(A_I), (1, len(A_I)))

        top_pos_ids = tf.math.top_k(A_I, n_ins)[1][-1]
        pos_index = list()
        for i in top_pos_ids:
            pos_index.append(i)

        pos_index = tf.convert_to_tensor(pos_index)
        top_pos = list()
        for i in pos_index:
            top_pos.append(h[i])

        top_neg_ids = tf.math.top_k(-A_I, n_ins)[1][-1]
        neg_index = list()
        for i in top_neg_ids:
            neg_index.append(i)

        neg_index = tf.convert_to_tensor(neg_index)
        top_neg = list()
        for i in neg_index:
            top_neg.append(h[i])

        ins_in = tf.concat(values=[top_pos, top_neg], axis=0)
        logits_unnorm_in = list()
        logits_in = list()

        for i in range(self.args["n_class"] * n_ins):
            ins_score_unnorm_in = ins_classifier(ins_in[i])
            logit_in = tf.math.softmax(ins_score_unnorm_in)
            logits_unnorm_in.append(ins_score_unnorm_in)
            logits_in.append(logit_in)

        return ins_label_in, logits_unnorm_in, logits_in

    def out_call(self, n_ins, ins_classifier, h, A_O):
        """_summary_

        Args:
            n_ins (_type_): _description_
            ins_classifier (_type_): _description_
            h (_type_): _description_
            A_O (_type_): _description_

        Returns:
            _type_: _description_
        """
        # get compressed 512-dimensional instance-level feature vectors for following use, denoted by h
        A_O = tf.reshape(tf.convert_to_tensor(A_O), (1, len(A_O)))
        top_pos_ids = tf.math.top_k(A_O, n_ins)[1][-1]
        pos_index = list()
        for i in top_pos_ids:
            pos_index.append(i)

        pos_index = tf.convert_to_tensor(pos_index)
        top_pos = list()
        for i in pos_index:
            top_pos.append(h[i])

        # mutually-exclusive -> top k instances w/ highest attention scores ==> false pos = neg
        pos_ins_labels_out = self.generate_neg_labels(n_ins)
        ins_label_out = pos_ins_labels_out

        logits_unnorm_out = list()
        logits_out = list()

        for i in range(n_ins):
            ins_score_unnorm_out = ins_classifier(top_pos[i])
            logit_out = tf.math.softmax(ins_score_unnorm_out)
            logits_unnorm_out.append(ins_score_unnorm_out)
            logits_out.append(logit_out)

        return ins_label_out, logits_unnorm_out, logits_out

    def call(self, bag_label, h, A):
        """_summary_

        Args:
            bag_label (_type_): _description_
            h (_type_): _description_
            A (_type_): _description_

        Returns:
            _type_: _description_
        """
        n_ins = self.args["top_k_percent"] * len(h)
        n_ins = int(n_ins)
        # if n_ins computed above is less than 0, make n_ins be default be 8
        if n_ins == 0:
            n_ins += 8

        for i in range(self.args["n_class"]):
            ins_classifier = self.ins_classifier()[i]
            if i == bag_label:
                A_I = list()
                for j in range(len(A)):
                    a_i = A[j][0][i]
                    A_I.append(a_i)
                ins_label_in, logits_unnorm_in, logits_in = self.in_call(
                    n_ins, ins_classifier, h, A_I
                )
            else:
                if self.args["mut_ex"]:
                    A_O = list()
                    for j in range(len(A)):
                        a_o = A[j][0][i]
                        A_O.append(a_o)
                    ins_label_out, logits_unnorm_out, logits_out = self.out_call(
                        n_ins, ins_classifier, h, A_O
                    )
                else:
                    continue

        if self.args["mut_ex"]:
            ins_labels = tf.concat(values=[ins_label_in, ins_label_out], axis=0)
            ins_logits_unnorm = logits_unnorm_in + logits_unnorm_out
            ins_logits = logits_in + logits_out
        else:
            ins_labels = ins_label_in
            ins_logits_unnorm = logits_unnorm_in
            ins_logits = logits_in

        return {
            "ins_labels": ins_labels,
            "ins_logits_unnorm": ins_logits_unnorm,
            "ins_logits": ins_logits,
        }

In [37]:
ins = Ins(args)

### Bag Classifier

In [38]:
class S_Bag(tf.keras.Model):
    """_summary_

    Args:
        tf (_type_): _description_
    """

    def __init__(
        self,
        args,
    ):
        """_summary_

        Args:
            args (_type_): _description_
        """
        super(S_Bag, self).__init__()
        self.args = args

        self.s_bag_model = tf.keras.models.Sequential()
        self.s_bag_layer = tf.keras.layers.Dense(
            units=1,
            activation="linear",
            input_shape=(self.args["n_class"], self.args["dim_compress_features"]),
            name="Bag_Classifier_Layer",
        )
        self.s_bag_model.add(self.s_bag_layer)

    def bag_classifier(self):
        """_summary_

        Returns:
            _type_: _description_
        """
        return self.s_bag_model

    def h_slide(self, A, h):
        """_summary_

        Args:
            A (_type_): _description_
            h (_type_): _description_

        Returns:
            _type_: _description_
        """
        # compute the slide-level representation aggregated per the attention score distribution for the mth class
        SAR = list()
        for i in range(len(A)):
            sar = tf.linalg.matmul(tf.transpose(A[i]), h[i])  # shape be (2,512)
            SAR.append(sar)
        slide_agg_rep = tf.math.add_n(SAR)  # return h_[slide,m], shape be (2,512)
        ## need to reshape slide_agg_rep be (1,2,512), which will be compatible with input layer dimension
        if len(slide_agg_rep.shape) == 2:
            slide_agg_rep = tf.reshape(
                slide_agg_rep, (1, slide_agg_rep.shape[0], slide_agg_rep.shape[1])
            )

        return slide_agg_rep

    def call(self, bag_label, A, h):
        """_summary_

        Args:
            bag_label (_type_): _description_
            A (_type_): _description_
            h (_type_): _description_

        Returns:
            _type_: _description_
        """
        slide_agg_rep = self.h_slide(A, h)
        bag_classifier = self.bag_classifier()
        slide_score_unnorm = bag_classifier(slide_agg_rep)
        slide_score_unnorm = tf.reshape(slide_score_unnorm, (1, self.args["n_class"]))
        Y_hat = tf.math.top_k(slide_score_unnorm, 1)[1][-1]
        Y_prob = tf.math.softmax(
            tf.reshape(slide_score_unnorm, (1, self.args["n_class"]))
        )  # shape be (1,2), predictions for each of the classes
        predict_slide_label = np.argmax(Y_prob.numpy())

        Y_true = tf.one_hot([bag_label], 2)

        return {
            "slide_score_unnorm": slide_score_unnorm,
            "Y_hat": Y_hat,
            "Y_prob": Y_prob,
            "predict_slide_label": predict_slide_label,
            "Y_true": Y_true,
        }

In [40]:
bag=S_Bag(args)

### CLAM

In [41]:
class S_CLAM(tf.keras.Model):
    """_summary_

    Args:
        tf (_type_): _description_
    """

    def __init__(
        self,
        args,
    ):
        """_summary_

        Args:
            args (_type_): _description_
        """
        super(S_CLAM, self).__init__()
        self.args = args

        self.net_shape_dict = {"small": [1024, 512, 256], "big": [1024, 512, 384]}
        self.net_shape = self.net_shape_dict[self.args["net_size"]]

        if self.args["att_gate"]:
            self.att_net = G_Att_Net(
                args=self.args,
                dim_features=self.net_shape[0],
                n_hidden_units=self.net_shape[2],
            )
        else:
            self.att_net = NG_Att_Net(
                args=self.args,
                dim_features=self.net_shape[0],
                n_hidden_units=self.net_shape[2],
            )

        self.ins_net = Ins(
            args=self.args,
        )
        self.bag_net = S_Bag(
            args=self.args,
        )

    def networks(self):
        """_summary_

        Returns:
            _type_: _description_
        """
        c_nets = {
            "a_net": self.att_net,
            "i_net": self.ins_net,
            "b_net": self.bag_net,
        }

        return c_nets

    def clam_model(self):
        """_summary_

        Returns:
            _type_: _description_
        """
        att_model = self.att_net.att_model()
        ins_classifier = self.ins_net.ins_classifier()
        bag_classifier = self.bag_net.bag_classifier()

        clam_model = {
            "att_model": att_model,
            "ins_classifier": ins_classifier,
            "bag_classifier": bag_classifier,
        }

        return clam_model

    def call(self, img_features, slide_label):
        """_summary_

        Args:
            img_features (_type_): original 1024-dimensional instance-level feature vectors
            slide_label (_type_): ground-truth slide label, could be 0 or 1 for binary classification

        Returns:
            _type_: _description_
        """
        att_net_dict = self.att_net.call(img_features)
        (h, att_score) = (att_net_dict["h"], att_net_dict["A"])
        A = tf.math.softmax(att_score)  # softmax on attention scores

        if self.args["att_only"]:
            return att_score

        ins_net_dict = self.ins_net.call(slide_label, h, A)

        bag_net_dict = self.bag_net.call(slide_label, A, h)

        return {
            "att_score": att_score,
            "A": A,
            "h": h,
            "ins_labels": ins_net_dict["ins_labels"],
            "ins_logits_unnorm": ins_net_dict["ins_logits_unnorm"],
            "ins_logits": ins_net_dict["ins_logits"],
            "slide_score_unnorm": bag_net_dict["slide_score_unnorm"],
            "Y_prob": bag_net_dict["Y_prob"],
            "Y_hat": bag_net_dict["Y_hat"],
            "Y_true": bag_net_dict["Y_true"],
            "predict_slide_label": bag_net_dict["predict_slide_label"],
        }

In [42]:
clam = S_CLAM(args)

### Load TFRecords

In [43]:
def get_data_from_tf(
    tf_path,
    args,
):
    feature = {
        "height": tf.io.FixedLenFeature([], tf.int64),
        "width": tf.io.FixedLenFeature([], tf.int64),
        "depth": tf.io.FixedLenFeature([], tf.int64),
        "label": tf.io.FixedLenFeature([], tf.int64),
        "image/format": tf.io.FixedLenFeature([], tf.string),
        "image_name": tf.io.FixedLenFeature([], tf.string),
        "image/encoded": tf.io.FixedLenFeature([], tf.string),
        "image_feature": tf.io.FixedLenFeature([], tf.string),
    }

    tfrecord_dataset = tf.data.TFRecordDataset(tf_path)

    def _parse_image_function(key):
        return tf.io.parse_single_example(key, feature)

    CLAM_dataset = tfrecord_dataset.map(_parse_image_function)

    image_features = list()

    for tfrecord_value in CLAM_dataset:
        img_feature = tf.io.parse_tensor(tfrecord_value["image_feature"], "float32")

        if args["imf_norm_op"]:
            img_feature = tf.math.l2_normalize(img_feature)

        slide_labels = tfrecord_value["label"]
        slide_label = int(slide_labels)

        image_features.append(img_feature)
    image_features = tf.convert_to_tensor(image_features)
    return image_features, slide_label

In [61]:
image_features1, slide_label1 = get_data_from_tf(tf_path=os.path.join(args["all_tfrecords_path"], "Benign_b001.tif.tfrecords"), args=args)
image_features2, slide_label2 = get_data_from_tf(tf_path=os.path.join(args["all_tfrecords_path"], "Benign_b002.tif.tfrecords"), args=args)

In [92]:
image_features = [image_features1, image_features2]
# image_features = tf.convert_to_tensor(image_features)
# image_features.shape
image_labels = [slide_label1, slide_label2]

In [97]:
d = {
"f": image_features,
"l": image_labels
}

In [134]:
f,l = d["f"], d["l"]

In [137]:
for ff, ll in f,l:
    print(ff, ll)

tf.Tensor(
[[[-0.04770432 -0.02885905  0.03580757 ... -0.03625662  0.02170119
   -0.01594341]]

 [[-0.04828378 -0.03103301  0.03278226 ... -0.03719165  0.01908513
   -0.01416116]]

 [[-0.04782384 -0.03494673  0.034795   ... -0.03757655  0.02054382
   -0.01873459]]

 ...

 [[-0.05088389 -0.02966737  0.03198273 ... -0.04303918  0.01688325
   -0.01034377]]

 [[-0.04771412 -0.03413296  0.03635693 ... -0.03302318  0.02247962
   -0.01806119]]

 [[-0.04698806 -0.03595782  0.03680196 ... -0.03351949  0.02090824
   -0.01440851]]], shape=(48, 1, 1024), dtype=float32) tf.Tensor(
[[[-0.04684934 -0.03560995  0.03652195 ... -0.03152008  0.02395286
   -0.02166679]]

 [[-0.04859848 -0.03001648  0.03324676 ... -0.04270793  0.01947537
   -0.02428757]]

 [[-0.04545477 -0.02825014  0.03524903 ... -0.03711646  0.02121446
   -0.01648372]]

 ...

 [[-0.04651819 -0.0304567   0.0355213  ... -0.03727447  0.02321582
   -0.01662292]]

 [[-0.04652981 -0.03055941  0.03568071 ... -0.03823921  0.02248095
   -0.017066

In [45]:
att, ins, bag, clam

(<__main__.G_Att_Net at 0x7f6528fd0b70>,
 <__main__.Ins at 0x7f6528ec8f60>,
 <__main__.S_Bag at 0x7f6528709cf8>,
 <__main__.S_CLAM at 0x7f652871db38>)

In [65]:
att_dict = att.call(image_features)
h, A = att_dict["h"], att_dict["A"]

ValueError: Input 0 of layer "sequential" is incompatible with the layer: expected shape=(None, 1024), found shape=(48, 1, 1024)

In [50]:
h1 = tf.convert_to_tensor(h)
A1 = tf.convert_to_tensor(A)

In [56]:
ins_dict = ins.call(slide_label, h1, A1)
ins_labels, ins_unnorm_logits, ins_logits = ins_dict['ins_labels'], ins_dict['ins_logits_unnorm'], ins_dict['ins_logits']

In [73]:
ins_l11 = tf.one_hot(ins_l1, 2)
ins_l11

<tf.Tensor: shape=(18, 2), dtype=float32, numpy=
array([[0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [0., 1.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.]], dtype=float32)>

In [57]:
ins_l1 = tf.convert_to_tensor(ins_labels)
ins_ulo1 = tf.convert_to_tensor(ins_unnorm_logits)
ins_lo1 = tf.convert_to_tensor(ins_logits)

In [78]:
bag_dict = bag.call(slide_label, A1, h1)
y_true, y_prob = bag_dict["Y_true"], bag_dict["Y_prob"]

In [60]:
# clam.call(image_features, slide_label)

In [75]:
ins_loss = tf.keras.losses.binary_crossentropy(ins_l11, ins_lo1)

In [76]:
ins_loss1 = tf.math.reduce_mean(ins_loss)

In [77]:
ins_loss1

<tf.Tensor: shape=(), dtype=float32, numpy=0.6931622>

In [79]:
bag_loss = tf.keras.losses.binary_crossentropy(y_true, y_prob)

In [80]:
bag_loss

<tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.6577803], dtype=float32)>

In [82]:
ssu, sp, yhat = bag_dict['slide_score_unnorm'], bag_dict['predict_slide_label'], bag_dict["Y_hat"]

In [83]:
ssu

<tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[ 0.0416797 , -0.03035028]], dtype=float32)>

In [84]:
yhat

<tf.Tensor: shape=(1,), dtype=int32, numpy=array([0], dtype=int32)>

In [85]:
sp

0

In [86]:
float(bag_loss)

0.6577802896499634

In [89]:
tf.math.l2_normalize(image_features).shape

TensorShape([2, 48, 1, 1024])

In [90]:
image_features.shape

TensorShape([2, 48, 1, 1024])

In [None]:
dict{"features": image_features, 