# Objective:
Build an inference pipeline in both PyTorch (PT) and TensorFlow (TF), and verify that the features extracted by the two frameworks are the same.

# Steps:
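1. Start with a small batch of sentences (the first 10 training excerpts)
2. Pass the sentences to the model and get back the PyTorch tensors
3. Pass the sentences to the model and get back the TensorFlow tensors
4. Compare the tokenized inputs and the pooled outputs of the two frameworks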

In [None]:
import torch
import numpy as np
import pandas as pd
from pathlib import Path
from transformers import BertTokenizer, BertModel, TFBertModel

In [None]:
# Directory holding the competition CSV files; the train and test
# file locations are derived from it.
DATA_PATH = Path('../input/commonlitreadabilityprize')
TRAIN_DATA_PATH = DATA_PATH.joinpath('train.csv')
TEST_DATA_PATH = DATA_PATH.joinpath('test.csv')

In [None]:
# Load the train and test tables from their CSV files.
train_data = pd.read_csv(filepath_or_buffer=TRAIN_DATA_PATH)
test_data = pd.read_csv(filepath_or_buffer=TEST_DATA_PATH)

In [None]:
class Inference:
    """Run one pretrained BERT checkpoint through PyTorch and TensorFlow.

    A single tokenizer is shared by both backends, so each backend is fed
    the encoding of the same sentences.

    Attributes:
        tokenizer: ``BertTokenizer`` loaded from ``model_name``.
        pt_model: ``BertModel`` switched to evaluation mode.
        tf_model: ``TFBertModel`` whose layers are all marked non-trainable.
    """

    def __init__(self, model_name):
        """Load the tokenizer and both framework variants of the checkpoint.

        Args:
            model_name: Identifier handed to each ``from_pretrained`` call.
        """
        self.tokenizer = BertTokenizer.from_pretrained(model_name)

        # Evaluation mode: this pipeline is only used for inference.
        self.pt_model = BertModel.from_pretrained(model_name)
        self.pt_model = self.pt_model.eval()

        # Mark every TF layer as non-trainable for the same reason.
        self.tf_model = TFBertModel.from_pretrained(model_name)
        for layer in self.tf_model.layers:
            layer.trainable = False

    def run(self, sentences, platform):
        """Tokenize ``sentences`` and run them through the chosen backend.

        Args:
            sentences: Text input accepted by the tokenizer (the notebook
                passes a list of strings).
            platform: ``'pt'`` for the PyTorch model, ``'tf'`` for the
                TensorFlow model.

        Returns:
            A ``(inputs, outputs)`` tuple: the tokenizer encoding and the
            model output for the selected backend.

        Raises:
            ValueError: If ``platform`` is neither ``'pt'`` nor ``'tf'``.
        """
        if platform == 'pt':
            inputs = self.tokenizer(
                sentences, return_tensors="pt", padding=True, truncation=True
            )
            # No gradients are needed for feature extraction.
            with torch.no_grad():
                outputs = self.pt_model(**inputs)
            return inputs, outputs

        if platform == 'tf':
            inputs = self.tokenizer(
                sentences, return_tensors="tf", padding=True, truncation=True
            )
            outputs = self.tf_model(inputs)
            return inputs, outputs

        # Raising a bare string is itself a TypeError in Python 3; raise a
        # real exception so the caller sees the intended message.
        raise ValueError(f"platform must be tf or pt, got {platform!r}")
            
# Shared pipeline object used by the cells below.
inference = Inference("bert-base-uncased")

In [None]:
# Take the first 10 excerpts; slicing first avoids converting the whole column.
sentences = train_data['excerpt'].iloc[:10].tolist()

# PyTorch

In [None]:
# Run the PyTorch branch, then convert its pooled output to a NumPy array.
pt_inputs, pt_outputs = inference.run(sentences=sentences, platform='pt')
pt_pooler_outputs = pt_outputs.pooler_output.detach().numpy()

# TensorFlow

In [None]:
# Run the TensorFlow branch, then convert its pooled output to a NumPy array.
tf_inputs, tf_outputs = inference.run(sentences=sentences, platform='tf')
tf_pooler_outputs = tf_outputs.pooler_output.numpy()

# Compare TF and PT inputs

In [None]:
# Token ids are integers, so compare them exactly: array_equal requires the
# same shape and identical values, with no tolerance and no broadcasting.
np.array_equal(tf_inputs.input_ids.numpy(), pt_inputs.input_ids.detach().numpy())

# Compare PT vs TF outputs

In [None]:
# Compare the pooled outputs of the two frameworks one row at a time.
# NOTE(review): only rtol is set here; np.allclose keeps its default atol,
# so values very close to zero are compared strictly.
tolerance = 1e-4
for pt_row, tf_row in zip(pt_pooler_outputs, tf_pooler_outputs):
    is_it_close = np.allclose(pt_row, tf_row, rtol=tolerance)
    print(f'Tolerance: {tolerance}: is PT and TF close: {is_it_close}')

In [None]:
# Check that both frameworks produced arrays of the same NumPy dtype.
is_dtype_same = (tf_pooler_outputs.dtype == pt_pooler_outputs.dtype)
print(f'Is dtype same: {is_dtype_same}')

In [None]:
# Display the shapes of both pooled-output arrays side by side.
np.shape(pt_pooler_outputs), np.shape(tf_pooler_outputs)