# Introduction

This notebook use the Longformer with the [CLS] + entity-token.  

# Libraries

In [1]:
%%capture
!pip install transformers
!pip install tdqm
#!pip install tensorflow
#!pip install torch

In [2]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt

# import torch
# print(torch.__version__)

import tensorflow as tf
print(tf.__version__)

2.5.0


# Pretrained Model

In [3]:
# Tokenizer and Models
from transformers import LongformerConfig, LongformerTokenizerFast, TFLongformerModel

tokenizer = LongformerTokenizerFast.from_pretrained('allenai/longformer-base-4096')
transformer_model = TFLongformerModel.from_pretrained('allenai/longformer-base-4096',
                                                      output_attentions=True,
                                                      attention_window=[32, 64, 128, 256, 512, 512, 
                                                                        512, 512, 512, 512, 512, 512])
transformer_model.config

# from transformers import AutoTokenizer, TFAutoModel

# tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
# transformer_model = TFAutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT", from_pt=True,
#                                                output_attentions=True)
# transformer_model.config

Some layers from the model checkpoint at allenai/longformer-base-4096 were not used when initializing TFLongformerModel: ['lm_head']
- This IS expected if you are initializing TFLongformerModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFLongformerModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the layers of TFLongformerModel were initialized from the model checkpoint at allenai/longformer-base-4096.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFLongformerModel for predictions without further training.


LongformerConfig {
  "_name_or_path": "allenai/longformer-base-4096",
  "attention_mode": "longformer",
  "attention_probs_dropout_prob": 0.1,
  "attention_window": [
    32,
    64,
    128,
    256,
    512,
    512,
    512,
    512,
    512,
    512,
    512,
    512
  ],
  "bos_token_id": 0,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "ignore_attention_mask": false,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 4098,
  "model_type": "longformer",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_attentions": true,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "sep_token_id": 2,
  "transformers_version": "4.8.2",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 50265
}

# Data

In [6]:
#train_path = '/home/ubuntu/Longformer/train_test_data/mini_df.csv'
test_path = '/home/ubuntu/Longformer/train_test_data/mini_test_df.csv'
resample_path = '/home/ubuntu/Longformer/train_test_data/resample_df.csv'

# Get all the data
test_df = pd.read_csv(test_path)
resample_df = pd.read_csv(resample_path)


display(resample_df.head())

print('resample shape: ', resample_df.shape)
print('test shape: ', test_df.shape)

Unnamed: 0.1,Unnamed: 0,ID,TLINK,Target_A,Target_B,Text,End,Words,left_0,left_1,right_0,right_1,distance
0,0,612_SECTIME1,AFTER,further evaluation,2014-08-07,Admission Date : [R] 2014-08-07 [R] Discharge ...,end,['Admission' 'Date' ':' '[R]' '2014-08-07' '[R...,188,191,3,5,188
1,1,123_TL12,OVERLAP,NPO,IVF,Admission Date : 2013-11-21 Discharge Date : 2...,end,['Admission' 'Date' ':' '2013-11-21' 'Discharg...,65,67,69,71,2
2,2,413_TL22,AFTER,anti-Candida regimen,consultation,ADMISSION DATE : 5-3-93 DISCHARGE DATE : 5-12-...,end,['ADMISSION' 'DATE' ':' '5-3-93' 'DISCHARGE' '...,173,176,81,83,95
3,3,492_TL0,AFTER,elective coronary artery bypass grafting,Admission,[R] Admission [R] Date : 2016-02-14 Discharge ...,end,['[R]' 'Admission' '[R]' 'Date' ':' '2016-02-1...,91,97,0,2,97
4,4,337_TL19,AFTER,monitoring,referred,Admission Date : 2017-06-12 Discharge Date : 2...,end,['Admission' 'Date' ':' '2017-06-12' 'Discharg...,79,81,74,76,7


resample shape:  (9000, 13)
test shape:  (11200, 13)


In [8]:
# Map labels to numbers
TLINK_map = {'AFTER':0,
             'OVERLAP': 1,
             'BEFORE': 2}

y_train = np.asarray(resample_df['TLINK'].apply(lambda x: TLINK_map[x])).reshape(-1,1)
y_test = np.asarray(test_df['TLINK'].apply(lambda x: TLINK_map[x])).reshape(-1,1)

# Distribution
print('train')
dist = np.unique(y_train, return_counts=True)[1]/sum(np.unique(y_train, return_counts=True)[1])
print(dist)

print()
print('test')
dist = np.unique(y_test, return_counts=True)[1]/sum(np.unique(y_test, return_counts=True)[1])
print(dist)

train
[0.33333333 0.33333333 0.33333333]

test
[0.10017857 0.36410714 0.53571429]


# Tokenize

In [9]:
def tokenize(max_length, sentences, tokenizer, 
             left_token=None, right_token=None, entity_mask=False, 
             global_attention=False):
    '''
    Tokenize each sentence one at a time and then batch together
    '''

    # Initialize tokens with the first example
    tokens = tokenizer.encode_plus(sentences[0], add_special_tokens=True, max_length=max_length, 
                                   padding='max_length', truncation=True, 
                                   return_tensors='tf')
    
    # Initialize entity mask with first example
    if entity_mask == True:
        left_mask = [0]*max_length
        right_mask = [0]*max_length

        # First: Index for [L] entity [L]
        idx = np.where(tokens.input_ids[0] == left_token)[0]
        for i in range(idx[0]+1, idx[1]):
            left_mask[i] = 1
        left_mask = tf.constant(left_mask, shape=(1, max_length), dtype='int32')
        tokens['left_mask'] = left_mask

        # First: Index for [R] entity [R]
        idx = np.where(tokens.input_ids[0] == right_token)[0]
        for i in range(idx[0]+1, idx[1]):
            right_mask[i] = 1
        right_mask = tf.constant(right_mask, shape=(1, max_length), dtype='int32')
        tokens['right_mask'] = right_mask
        
    # Initialize global attention mask with first example
    if global_attention == True:
        global_attention_mask = [0]*max_length
        
        # Find positions of tokens
        left_idx = np.where(tokens.input_ids[0] == left_token)[0]
        right_idx = np.where(tokens.input_ids[0] == right_token)[0]
        
        # Left and Right
        for i in range(left_idx[0]+1, left_idx[1]):
            global_attention_mask[i] = 1
         
        for i in range(right_idx[0]+1, right_idx[1]):
            global_attention_mask[i] = 1
            
        # Convert to tf.tensor  
        global_attention_mask = tf.constant(global_attention_mask, shape=(1, max_length), dtype='int32')
        tokens['global_attention_mask'] = global_attention_mask

    # Rest of the examples
    for s, sentence in tqdm(enumerate(sentences[1:])):
        inputs = tokenizer.encode_plus(sentence, add_special_tokens=True, max_length=max_length,
                                       padding='max_length', truncation=True, 
                                       return_tensors='tf')
        
        tokens.input_ids = tf.concat([tokens.input_ids, inputs.input_ids], 0)
        tokens.attention_mask = tf.concat([tokens.attention_mask, inputs.attention_mask], 0)

        # Rest of entity mask
        if entity_mask == True:
            left_mask = [0]*max_length
            right_mask = [0]*max_length

            # Later: Index for [L] entity [L]
            idx = np.where(tokens.input_ids[s+1] == left_token)[0]
            for i in range(idx[0]+1, idx[1]):
                left_mask[i] = 1
            left_mask = tf.constant(left_mask, shape=(1, max_length), dtype='int32')
            tokens['left_mask'] = tf.concat([tokens.left_mask, left_mask], 0)  

            # Later: Index for [R] entity [R]
            idx = np.where(tokens.input_ids[s+1] == right_token)[0]
            for i in range(idx[0]+1, idx[1]):
                right_mask[i] = 1
            right_mask = tf.constant(right_mask, shape=(1, max_length), dtype='int32')
            tokens['right_mask'] = tf.concat([tokens.right_mask, right_mask], 0) 
        
        # Rest of global attentino mask
        if global_attention == True:
            global_attention_mask = [0]*max_length
            
            # Later: left and right positions
            left_idx = np.where(tokens.input_ids[s+1] == left_token)[0]
            right_idx = np.where(tokens.input_ids[s+1] == right_token)[0]
            
            # Left and Right
            for i in range(left_idx[0]+1, left_idx[1]):
                global_attention_mask[i] = 1

            for i in range(right_idx[0]+1, right_idx[1]):
                global_attention_mask[i] = 1
                
            global_attention_mask = tf.constant(global_attention_mask, shape=(1, max_length), dtype='int32')
            tokens['global_attention_mask'] = tf.concat([tokens.global_attention_mask, global_attention_mask], 0) 

    return tokens

In [10]:
# Add special tokens
extra_tokens = ['[L]', '[R]']
tokenizer.add_tokens(extra_tokens, special_tokens=True)

left_token = tokenizer.convert_tokens_to_ids('[L]')
right_token = tokenizer.convert_tokens_to_ids('[R]')

# Resize model vocabulary
transformer_model.resize_token_embeddings(len(tokenizer)) 

print(left_token, right_token)

50265 50266


In [11]:
# Batch tokenization - 1000 examples (it) take 6 seconds 
max_length=1500
global_attention=True
# X_train = tokenize(max_length, train_df['Text'], 
#                   tokenizer,
#                   left_token=left_token,
#                   right_token=right_token,
#                   entity_mask=True,
#                   global_attention=False)
# X_test = tokenize(max_length, test_df['Text'], 
#                   tokenizer,
#                   left_token=left_token,
#                   right_token=right_token,
#                   entity_mask=True,
#                   global_attention=False)
X_train = tokenize(max_length, resample_df['Text'], 
                  tokenizer,
                  left_token=left_token,
                  right_token=right_token,
                  entity_mask=True,
                  global_attention=global_attention)
X_test = tokenize(max_length, test_df['Text'], 
                  tokenizer,
                  left_token=left_token,
                  right_token=right_token,
                  entity_mask=True,
                  global_attention=global_attention)

print('input ids')
print(X_train.input_ids[0])
print()
print('attention mask')
print(X_train.attention_mask[0])
print()
print('left mask')
print(X_train.left_mask[0])
print()
print('right mask')
print(X_train.right_mask[0])
print()
if global_attention==True:
    print('global attention mask')
    print(X_train.global_attention_mask[0])
    print()

8999it [03:27, 43.31it/s] 
11199it [06:59, 26.72it/s]

input ids
tf.Tensor([    0  9167 12478 ...     1     1     1], shape=(1500,), dtype=int32)

attention mask
tf.Tensor([1 1 1 ... 0 0 0], shape=(1500,), dtype=int32)

left mask
tf.Tensor([0 0 0 ... 0 0 0], shape=(1500,), dtype=int32)

right mask
tf.Tensor([0 0 0 ... 0 0 0], shape=(1500,), dtype=int32)

global attention mask
tf.Tensor([0 0 0 ... 0 0 0], shape=(1500,), dtype=int32)






# Save numpy after tokenization

In [37]:
# np.save('X_train.input_ids', X_train.input_ids)
# np.save('X_train.attention_mask', X_train.attention_mask)
# np.save('X_train.left_mask', X_train.left_mask)
# np.save('X_train.right_mask', X_train.right_mask)

# Build Model

In [12]:
from tensorflow.keras.models import Model
import tensorflow.keras.backend as K
from tensorflow.keras.layers import Input, Dense, Dropout,\
                                    Concatenate, GlobalAveragePooling1D, \
                                    BatchNormalization

def create_model(freeze_layers=True,
                 train_size=None,
                 batch_size=1,
                 epochs=2,
                 initial_learning_rate=1e-5, 
                 final_learning_rate=1e-6, 
                 global_attention=False):
    
    # Input Layers
    input_ids = Input(shape=(max_length,), name='input_ids', dtype='int32')
    attention_mask = Input(shape=(max_length,), name='attention_mask', dtype='int32') 
    left_mask = Input(shape=(max_length,), name='left_mask', dtype='float32') 
    right_mask = Input(shape=(max_length,), name='right_mask', dtype='float32') 
    
    if global_attention == False:
        # Transformer Layer
        bert_out = transformer_model(input_ids=input_ids,
                                    attention_mask=attention_mask) 
    
    elif global_attention == True:
        print('----Performing Global Attention----')
        global_attention_mask = Input(shape=(max_length,), name='global_attention_mask', dtype='int32') 
        
        # Transformer Layer
        bert_out = transformer_model(input_ids=input_ids,
                              attention_mask=attention_mask,
                              global_attention_mask=global_attention_mask)

    # Get embeddings    
    X = bert_out[0]
    CLS = bert_out[1]
        
    # Get embeddings of entity pairs only
    # Approach one - dot product, only sums across token
    # left = tf.expand_dims(left_mask, -1, name='expand_left_mask')
    # right = tf.expand_dims(right_mask, -1, name='expand_right_mask')
    # left = tf.matmul(left, X, transpose_a=True, name='matmul_left_mask')
    # right = tf.matmul(right, X, transpose_a=True, name='matmul_right_mask')
    # Approach two - gather using index
    left_idx = tf.squeeze(tf.where(left_mask[0]==1), axis=1)
    right_idx = tf.squeeze(tf.where(right_mask[0]==1), axis=1)
    left = tf.gather(X, left_idx, axis=1)
    right = tf.gather(X, right_idx, axis=1)
    left = GlobalAveragePooling1D()(left)
    right = GlobalAveragePooling1D()(right)

    X = Concatenate()([CLS, left, right])

    # Softmax
    X = Dropout(0.5)(X)
    Y = Dense(3, activation='softmax')(X)
    
    # Instantiate model
    if global_attention == False:
            model = Model(inputs=[input_ids, attention_mask, left_mask, right_mask], 
                          outputs = Y)
            
    elif global_attention == True:
            model = Model(inputs=[input_ids, attention_mask, left_mask, right_mask, global_attention_mask], 
                          outputs = Y)

    # Training the embeddings
    if freeze_layers == True:
        for layer in model.layers[:6]:
            layer.trainable = False
    elif freeze_layers == False:
        for layer in model.layers[:6]:
            layer.trainable = True

#     for idx, layer in enumerate(transformer_model.layers[0].encoder.layer):
#         if freeze_layers != None:
#             if idx in freeze_layers:
#                 layer.trainable = False
#             else: 
#                 layer.trainable = True
#         else:
#             layer.trainable = True

#         print(layer, layer.trainable)
            
    # Optimizer and Learning Rate Decay
    # learning_rate_decay_factor = (final_learning_rate / initial_learning_rate)**(1/epochs)
    # steps_per_epoch = int(train_size/batch_size)
    
    # lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    #     initial_learning_rate=initial_learning_rate,
    #     decay_steps=steps_per_epoch,
    #     decay_rate=learning_rate_decay_factor)

    opt = tf.keras.optimizers.Adam(learning_rate=initial_learning_rate)

    # Compile the model
    model.compile(optimizer=opt, 
                  loss='sparse_categorical_crossentropy')
    
    # Print model summary
    model.summary()
    
    return model

# Train 

In [47]:
del model
tf.keras.backend.clear_session()

In [13]:
from datetime import datetime
from tensorflow.keras.callbacks import History, ReduceLROnPlateau, EarlyStopping

# Decreases learning rate as loss plateaus
rlr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=0.000001, verbose=1, min_delta=1e-5)
early_stopping_monitor = EarlyStopping(
    monitor='val_loss',
    min_delta=0,
    patience=2,
    verbose=2,
    mode='auto',
    baseline=None,
    restore_best_weights=True)

# Store loss on tensorboard
logdir = "/home/ubuntu/Longformer/logs/scalars/" + datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir, update_freq=20)

# Model params:
data = X_train
labels = y_train
# class_weight = resample_class_weights

freeze_layers=False
epochs = 20
train_size = data.input_ids.shape[0]
batch_size = 2
initial_learning_rate = 1e-05
# final_learning_rate = 1e-06

# Create model and fit
model = create_model(freeze_layers=freeze_layers, 
                     epochs=epochs,
                     train_size=train_size,
                     batch_size=batch_size,
                     initial_learning_rate=initial_learning_rate,
#                      final_learning_rate=final_learning_rate,
                     global_attention=global_attention)

if global_attention == False:
    output = model.fit(x=[data.input_ids, data.attention_mask, 
                          data.left_mask, data.right_mask], 
                      y=labels, 
                      batch_size=batch_size,
#                       class_weight=class_weight,
                      epochs=epochs,
                      validation_split=0.2,
                      callbacks=[tensorboard_callback, rlr, early_stopping_monitor])
elif global_attention == True:
    output = model.fit(x=[data.input_ids[0:-1], data.attention_mask[0:-1], 
                          data.left_mask[0:-1], data.right_mask[0:-1], 
                          data.global_attention_mask[0:-1]], 
                      y=labels[0:-1], 
#                       class_weight=class_weight,
                      batch_size=batch_size,
                      epochs=epochs,
                      validation_split=0.2,
                      callbacks=[tensorboard_callback, rlr, early_stopping_monitor])

----Performing Global Attention----
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: <cyfunction Socket.send at 0x7fa7ca648f20> is not a module, class, method, function, traceback, frame, or code object
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: <cyfunction Socket.send at 0x7fa7ca648f20> is not a module, class, method, function, traceback, frame, or code object

Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
left_mask (InputLayer)          [(None, 15

# Evaluation

In [14]:
from sklearn.metrics import classification_report

pred = model.predict(x=[X_train.input_ids, X_train.attention_mask, 
                        X_train.left_mask, X_train.right_mask,
                        X_train.global_attention_mask],
                     batch_size=4)
pred_labels = tf.argmax(pred, axis=1)

pred_df = pd.DataFrame({'Actual':y_train.ravel(), 'Predicted':pred_labels.numpy().tolist()})
print()
print(np.unique(pred_df['Predicted'], return_counts=True))
print(classification_report(pred_df['Actual'], pred_df['Predicted'], digits=3))


(array([0, 1, 2]), array([3101, 2999, 2900]))
              precision    recall  f1-score   support

           0      0.958     0.991     0.974      3000
           1      0.964     0.964     0.964      3000
           2      0.988     0.955     0.971      3000

    accuracy                          0.970      9000
   macro avg      0.970     0.970     0.970      9000
weighted avg      0.970     0.970     0.970      9000



In [15]:
from sklearn.metrics import classification_report

pred = model.predict(x=[X_test.input_ids, X_test.attention_mask, 
                        X_test.left_mask, X_test.right_mask,
                        X_test.global_attention_mask],
                     batch_size=4)
pred_labels = tf.argmax(pred, axis=1)

pred_df = pd.DataFrame({'Actual':y_test.ravel(), 'Predicted':pred_labels.numpy().tolist()})
print()
print(np.unique(pred_df['Predicted'], return_counts=True))
print(classification_report(pred_df['Actual'], pred_df['Predicted'], digits=3))


(array([0, 1, 2]), array([1719, 4729, 4752]))
              precision    recall  f1-score   support

           0      0.326     0.500     0.395      1122
           1      0.662     0.768     0.711      4078
           2      0.927     0.734     0.819      6000

    accuracy                          0.723     11200
   macro avg      0.638     0.667     0.642     11200
weighted avg      0.770     0.723     0.737     11200



In [16]:
# Save predictions
results_df = test_df.copy()
results_df['labels'] = y_test.ravel()
results_df['longformer_labels'] = pred_labels.numpy().tolist()
results_df.to_csv('longformer_1500_window_pred_df.csv')

In [28]:
# Save trained model
model.save('longformer_768_20210729-034052')





INFO:tensorflow:Assets written to: longformer_20210726-033317/assets


INFO:tensorflow:Assets written to: longformer_20210726-033317/assets


# Prediction on Handcrafted Examples

In [96]:
# Calculate distance between entity pairs
def get_distances(results):
    results['Token'] = results['Text'].apply(lambda x: tokenizer.encode(x, return_tensors='np'))
    results['left_0'] = results['Token'].apply(lambda x: np.where(x[0] == left_token)[0][0])
    results['left_1'] = results['Token'].apply(lambda x: np.where(x[0] == left_token)[0][1])
    results['right_0'] = results['Token'].apply(lambda x: np.where(x[0] == right_token)[0][0])
    results['right_1'] = results['Token'].apply(lambda x: np.where(x[0] == right_token)[0][1])

    results = results.drop('Token', axis=1)
#     results = results.drop('Unnamed: 0', axis=1)

    # Since the order b/w [L] and [R] maybe flipped, use shortest distance
    results['left_right'] = abs(results['left_1'] - results['right_0']) 
    results['right_left'] = abs(results['right_1'] - results['left_0']) 

    results['distance'] = results[['left_right', 'right_left']].min(axis=1)
    
    return results

In [161]:
example1 = "Admission Date : [R] 2012-03-23 [R] Discharge Date : 2012-03-26 Service : MEDICINE History of Present \
        Illness : 39 year old male w/ h/o low back pain on chronic narcotics presents after being found \
        [L] unresponsive [L] at home . His daughter awoke him at 7 a.m. , reports he said he felt cold and shivery ,\
        vomited several times , then drove her to school ."

example2 = "Admission Date : [L] 2012-03-23 [L] Discharge Date : 2012-03-26 Service : MEDICINE History of Present \
        Illness : 39 year old male w/ h/o low back pain on chronic narcotics presents after being found \
        [R] unresponsive [R] at home . His daughter awoke him at 7 a.m. , reports he said he felt cold and shivery ,\
        vomited several times , then drove her to school ."

example3 = "Admission Date : [L] 2012-03-23 [L] Discharge Date : 2012-03-26 Service : MEDICINE History of Present \
        Illness : 39 year old male w/ h/o low back pain, low back pain, cold and shivery, vomited\
        low back pain, low back pain, cold and shivery, vomited\
        low back pain, low back pain, cold and shivery, vomited\
        low back pain, low back pain, cold and shivery, vomited\
        low back pain,low back pain, cold and shivery, vomited, on chronic narcotics presents after being found \
        [R] unresponsive [R] at home . His daughter awoke him at 7 a.m. , reports he said he felt cold and shivery ,\
        vomited several times , then drove her to school ."

hc_df = pd.DataFrame({'Example': [1, 2, 3], 'Text': [example1, example2, example3]})

hc_df = get_distances(hc_df)
hc_df


Unnamed: 0,Example,Text,left_0,left_1,right_0,right_1,left_right,right_left,distance
0,1,Admission Date : [R] 2012-03-23 [R] Discharge ...,70,74,6,13,68,57,57
1,2,Admission Date : [L] 2012-03-23 [L] Discharge ...,6,13,70,74,57,68,57
2,3,Admission Date : [L] 2012-03-23 [L] Discharge ...,6,13,171,175,158,169,158


In [162]:
X_hc = tokenize(max_length, hc_df['Text'], 
                  tokenizer,
                  left_token=left_token,
                  right_token=right_token,
                  entity_mask=True,
                  global_attention=global_attention)

2it [00:00, 346.61it/s]


In [163]:
from sklearn.metrics import classification_report

pred = model.predict(x=[X_hc.input_ids, X_hc.attention_mask, 
                        X_hc.left_mask, X_hc.right_mask,
                        X_hc.global_attention_mask],
                     batch_size=1)
pred_labels = tf.argmax(pred, axis=1)

print(pred)
pred_labels

[[1.4002675e-04 6.5801464e-02 9.3405855e-01]
 [1.7425803e-03 6.9971627e-01 2.9854113e-01]
 [2.3077088e-03 8.9609867e-01 1.0159367e-01]]


<tf.Tensor: shape=(3,), dtype=int64, numpy=array([2, 1, 1])>