In [10]:
def generate_pos_labels(n_pos_sample):
    """Build a 1-D tensor of positive instance labels.

    Args:
        n_pos_sample: number of labels to produce.

    Returns:
        A tensor of shape ``(n_pos_sample,)`` in which every entry is 1.
    """
    label_shape = [n_pos_sample]
    return tf.fill(dims=label_shape, value=1)

def generate_neg_labels(n_neg_sample):
    """Build a 1-D tensor of negative instance labels.

    Args:
        n_neg_sample: number of labels to produce.

    Returns:
        A tensor of shape ``(n_neg_sample,)`` in which every entry is 0.
    """
    label_shape = [n_neg_sample]
    return tf.fill(dims=label_shape, value=0)

def ins_in_call(ins_classifier, h, A_I, n_ins, n_class):
    """Score the most- and least-attended instances of the slide's own class branch.

    The ``n_ins`` instances with the highest in-class attention are sampled as
    positives (label 1). The ``n_ins`` instances with the lowest in-class
    attention are sampled as negatives (label 0). Every sampled instance is
    scored by ``ins_classifier``.

    Args:
        ins_classifier: callable mapping one instance feature to unnormalized
            scores.
        h: per-instance features, indexable by instance position (same order
            as ``A_I``).
        A_I: per-instance attention scores for the in-class branch. It needs
            at least ``n_ins`` entries, otherwise ``tf.math.top_k`` raises.
        n_ins: number of positive and of negative instances to sample.
        n_class: kept for interface compatibility only. The number of scored
            instances is always ``2 * n_ins`` and does not depend on it.

    Returns:
        Tuple ``(ins_label_in, logits_unnorm_in, logits_in)``:

        - ``ins_label_in``: ``n_ins`` ones followed by ``n_ins`` zeros.
        - ``logits_unnorm_in``: list of raw classifier outputs, in the same
          order as the labels.
        - ``logits_in``: list of the softmax of each raw output.
    """
    ins_label_in = tf.concat(
        values=[generate_pos_labels(n_ins), generate_neg_labels(n_ins)], axis=0)

    # Rank instances by attention. top_k returns (values, indices); [-1] drops
    # the leading row dimension introduced by the reshape. Negating the scores
    # makes top_k return the least-attended instances.
    scores = tf.reshape(tf.convert_to_tensor(A_I), (1, len(A_I)))
    top_pos_ids = tf.math.top_k(scores, n_ins)[1][-1]
    top_neg_ids = tf.math.top_k(-scores, n_ins)[1][-1]

    top_pos = [h[i] for i in top_pos_ids]
    top_neg = [h[i] for i in top_neg_ids]
    # Positives first, then negatives, matching the order of ins_label_in.
    ins_in = tf.concat(values=[top_pos, top_neg], axis=0)

    logits_unnorm_in = list()
    logits_in = list()
    # Score every sampled instance: there are exactly 2 * n_ins of them. The
    # former bound ``n_class * n_ins`` only matched this count when n_class == 2.
    for i in range(2 * n_ins):
        ins_score_unnorm_in = ins_classifier(ins_in[i])
        logits_unnorm_in.append(ins_score_unnorm_in)
        logits_in.append(tf.math.softmax(ins_score_unnorm_in))

    return ins_label_in, logits_unnorm_in, logits_in

def ins_out_call(ins_classifier, h, A_O, n_ins):
    """Score the most-attended instances of an out-of-class branch as negatives.

    Under the mutually-exclusive assumption, instances that draw high attention
    from a class the slide does not belong to are false positives. Every one of
    the ``n_ins`` selected instances therefore gets label 0.

    Args:
        ins_classifier: callable mapping one instance feature to unnormalized
            scores.
        h: per-instance features, indexable by instance position (same order
            as ``A_O``).
        A_O: per-instance attention scores for this out-of-class branch. It
            needs at least ``n_ins`` entries.
        n_ins: number of instances to select.

    Returns:
        Tuple ``(ins_label_out, logits_unnorm_out, logits_out)``:

        - ``ins_label_out``: ``n_ins`` zeros.
        - ``logits_unnorm_out``: list of raw classifier outputs for the
          selected instances, in attention-rank order.
        - ``logits_out``: list of the softmax of each raw output.
    """
    # top_k returns (values, indices); [-1] drops the leading row dimension
    # introduced by the reshape.
    branch_scores = tf.reshape(tf.convert_to_tensor(A_O), (1, len(A_O)))
    ranked_ids = tf.math.top_k(branch_scores, n_ins)[1][-1]
    selected_features = [h[idx] for idx in ranked_ids]

    # Mutually exclusive classes: the top-attended instances are false positives.
    ins_label_out = generate_neg_labels(n_ins)

    logits_unnorm_out, logits_out = [], []
    for feature in selected_features:
        raw_score = ins_classifier(feature)
        probabilities = tf.math.softmax(raw_score)
        logits_unnorm_out.append(raw_score)
        logits_out.append(probabilities)

    return ins_label_out, logits_unnorm_out, logits_out

def ins_call(ins_classifiers, bag_label, h, A, n_class, n_ins, mut_ex):
    """Collect instance-level labels and logits for instance-level clustering.

    Contributions per class branch:

    - The branch of the slide's own class (``i == bag_label``) contributes
      ``n_ins`` positive and ``n_ins`` negative instances via ``ins_in_call``.
    - When ``mut_ex`` is true, every other class branch also contributes
      ``n_ins`` negative instances via ``ins_out_call``.

    Args:
        ins_classifiers: per-class instance classifiers, indexed by class.
        bag_label: slide-level class label, compared with ``==`` against each
            class index.
        h: per-instance features.
        A: per-instance attention scores, indexed as ``A[instance][0][class]``.
        n_class: number of classes.
        n_ins: number of instances sampled per selection.
        mut_ex: whether the classes are mutually exclusive.

    Returns:
        Tuple ``(ins_labels, ins_logits_unnorm, ins_logits)``. The in-class
        entries come first, followed by each out-of-class branch in ascending
        class order (the latter only when ``mut_ex`` is true).

    Raises:
        ValueError: if ``bag_label`` matches no class index in ``range(n_class)``.
    """
    in_result = None
    out_labels = []
    out_logits_unnorm = []
    out_logits = []

    for i in range(n_class):
        ins_classifier = ins_classifiers[i]
        if i == bag_label:
            A_I = [A[j][0][i] for j in range(len(A))]
            in_result = ins_in_call(ins_classifier, h, A_I, n_ins, n_class)
        elif mut_ex:
            A_O = [A[j][0][i] for j in range(len(A))]
            label_out, unnorm_out, logit_out = ins_out_call(ins_classifier, h, A_O, n_ins)
            # Accumulate every out-of-class branch. Previously each iteration
            # overwrote the last, so only one branch survived when n_class > 2.
            out_labels.append(label_out)
            out_logits_unnorm += unnorm_out
            out_logits += logit_out

    if in_result is None:
        raise ValueError(
            f"bag_label {bag_label!r} is not a class index in range({n_class})")

    ins_labels, ins_logits_unnorm, ins_logits = in_result
    if out_labels:
        ins_labels = tf.concat(values=[ins_labels] + out_labels, axis=0)
        ins_logits_unnorm = ins_logits_unnorm + out_logits_unnorm
        ins_logits = ins_logits + out_logits

    return ins_labels, ins_logits_unnorm, ins_logits

In [11]:
def s_bag_h_slide(A, h):
    """Aggregate instance features into per-class slide-level representations.

    Computes ``sum_i transpose(A[i]) @ h[i]`` over all instances.

    Args:
        A: per-instance attention scores. Each ``A[i]`` is transposed and
            matrix-multiplied with ``h[i]``. Presumably ``A[i]`` is
            ``(1, n_class)`` and ``h[i]`` is ``(1, feature_dim)`` (TODO confirm).
        h: per-instance features, indexed in step with ``A``.

    Returns:
        The element-wise sum of the per-instance products. The original
        comments give ``(2, 512)`` for the binary case.
    """
    per_instance_reps = [
        tf.linalg.matmul(tf.transpose(A[idx]), h[idx])
        for idx in range(len(A))
    ]
    return tf.math.add_n(per_instance_reps)
    
def s_bag_call(bag_classifier, bag_label, A, h, n_class):
    """Classify a slide from its attention-pooled representation.

    Args:
        bag_classifier: callable mapping the aggregated slide representation to
            unnormalized class scores. Its output must hold ``n_class`` values,
            because it is reshaped to ``(1, n_class)``.
        bag_label: ground-truth slide class index.
        A: per-instance attention scores (see ``s_bag_h_slide``).
        h: per-instance features (see ``s_bag_h_slide``).
        n_class: number of classes.

    Returns:
        Tuple ``(slide_score_unnorm, Y_hat, Y_prob, predict_slide_label, Y_true)``:

        - ``slide_score_unnorm``: raw scores, shape ``(1, n_class)``.
        - ``Y_hat``: index tensor of the top-scoring class.
        - ``Y_prob``: softmax of the raw scores.
        - ``predict_slide_label``: argmax of ``Y_prob`` as a NumPy integer.
        - ``Y_true``: one-hot encoding of ``bag_label`` with depth ``n_class``.
    """
    slide_agg_rep = s_bag_h_slide(A, h)

    slide_score_unnorm = tf.reshape(bag_classifier(slide_agg_rep), (1, n_class))
    Y_hat = tf.math.top_k(slide_score_unnorm, 1)[1][-1]
    Y_prob = tf.math.softmax(slide_score_unnorm)
    predict_slide_label = np.argmax(Y_prob.numpy())

    # Depth follows n_class; it was hard-coded to 2, which broke multi-class.
    Y_true = tf.one_hot([bag_label], n_class)

    return slide_score_unnorm, Y_hat, Y_prob, predict_slide_label, Y_true

In [9]:
def g_att_call(g_att_net, img_features):
    """Run the gated attention network over every instance feature.

    ``g_att_net`` is indexed as a four-element sequence:

    - ``g_att_net[0]`` maps each input feature to its entry in ``h``.
    - ``g_att_net[1]`` and ``g_att_net[2]`` are both applied to the same input
      feature, and their outputs are multiplied element-wise.
    - ``g_att_net[3]`` maps that product to the attention entry in ``A``.

    NOTE(review): the attention layers receive the raw input features, not the
    output of ``g_att_net[0]``. Confirm this is intended.

    Args:
        g_att_net: indexable sequence of four callables (see above).
        img_features: per-instance features. It is iterated twice, so it must
            be a re-iterable sequence rather than a one-shot iterator.

    Returns:
        Tuple ``(h, A)``: lists with one entry per instance, in input order.
    """
    h = [g_att_net[0](feature) for feature in img_features]

    A = []
    for feature in img_features:
        gated = tf.math.multiply(g_att_net[1](feature), g_att_net[2](feature))
        A.append(g_att_net[3](gated))

    return h, A

In [5]:
def s_clam_call(att_net, ins_net, bag_net, img_features, slide_label, n_class, n_ins, att_only, mil_ins, mut_ex):
    """Forward pass of the CLAM model for one slide.

    Args:
        att_net: gated attention network (see ``g_att_call``).
        ins_net: per-class instance classifiers (see ``ins_call``).
        bag_net: slide-level classifier (see ``s_bag_call``).
        img_features: per-instance features of the slide.
        slide_label: ground-truth slide class index.
        n_class: number of classes.
        n_ins: number of instances sampled per selection for instance-level
            clustering.
        att_only: if true, return only the raw attention scores.
        mil_ins: if true, compute instance-level labels and logits.
        mut_ex: whether classes are mutually exclusive; forwarded to ``ins_call``.

    Returns:
        If ``att_only`` is true, only the raw (pre-softmax) attention scores.

        Otherwise a tuple ``(att_score, A, h, ins_labels, ins_logits_unnorm,
        ins_logits, slide_score_unnorm, Y_prob, Y_hat, Y_true,
        predict_slide_label)``. The three ``ins_*`` entries are ``None`` when
        ``mil_ins`` is false.
    """
    h, A = g_att_call(att_net, img_features)
    att_score = A  # raw output of the attention network, before softmax

    if att_only:
        return att_score

    # Normalize attention across instances (axis 0 of the stacked per-instance
    # scores), so each class branch pools a weighted average over the slide.
    # The default axis=-1 normalized across classes within each instance.
    A = tf.math.softmax(A, axis=0)

    # Defined up front so the return below works when mil_ins is false.
    ins_labels, ins_logits_unnorm, ins_logits = None, None, None
    if mil_ins:
        ins_labels, ins_logits_unnorm, ins_logits = ins_call(
            ins_net, slide_label, h, A, n_class, n_ins, mut_ex)

    slide_score_unnorm, Y_hat, Y_prob, predict_slide_label, Y_true = s_bag_call(
        bag_net, slide_label, A, h, n_class)

    return att_score, A, h, ins_labels, ins_logits_unnorm, ins_logits, slide_score_unnorm, Y_prob, Y_hat, Y_true, predict_slide_label