In [None]:
import tensorflow as tf
from transformers import TFAutoModel, AutoTokenizer, BertTokenizer
import pandas as pd
import numpy as np
from tensorflow.keras import backend as K

In [None]:
# Read the competition's training and test tables into DataFrames.
train, test = map(
    pd.read_csv,
    (
        "../input/commonlitreadabilityprize/train.csv",
        "../input/commonlitreadabilityprize/test.csv",
    ),
)

In [None]:
# Fixed token length used for every encoded excerpt.
SEQ_LEN = 250
# Local directory holding the pretrained checkpoint (weights + vocabulary).
MODEL_PATH = "../input/huggingface-bert-variants/bert-base-cased/bert-base-cased"

# Load the tokenizer and the TensorFlow model from the same checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
bert = TFAutoModel.from_pretrained(MODEL_PATH)

In [None]:
# Pre-allocate one row per training excerpt. Token IDs and attention-mask flags are
# integers, so the buffers are int32 (matching the model's int32 Input layers)
# rather than NumPy's float64 default.
Xids = np.zeros((len(train), SEQ_LEN), dtype=np.int32)
Xmask = np.zeros((len(train), SEQ_LEN), dtype=np.int32)

In [None]:
def tokenize(sentence):
    """Encode one text with the module-level ``tokenizer``.

    The text is truncated or padded to exactly ``SEQ_LEN`` tokens, special
    tokens are added, and token-type IDs are not requested.

    Returns:
        A ``(input_ids, attention_mask)`` pair taken from the encoding,
        which is requested as TensorFlow tensors (``return_tensors='tf'``).
    """
    encode_options = {
        'max_length': SEQ_LEN,
        'truncation': True,
        'padding': 'max_length',
        'add_special_tokens': True,
        'return_attention_mask': True,
        'return_token_type_ids': False,
        'return_tensors': 'tf',
    }
    encoded = tokenizer.encode_plus(sentence, **encode_options)
    return encoded['input_ids'], encoded['attention_mask']

In [None]:
# Encode every training excerpt and copy its token IDs / attention mask into row i.
for i, sentence in enumerate(train['excerpt']):
    ids_row, mask_row = tokenize(sentence)
    Xids[i, :] = ids_row
    Xmask[i, :] = mask_row

def root_mean_squared_std_loss(std=1.0):
    """Build a Keras-compatible loss that raises the batch RMSE to the power ``std``.

    The exponent used to be read from a free variable ``std`` that is not
    defined in this notebook, so the returned loss raised ``NameError`` on
    first use. It is now an explicit argument captured by the closure.

    Args:
        std: Exponent applied to the root-mean-squared error. It is converted
            to a float32 tensor. The default of 1.0 leaves the RMSE unchanged.

    Returns:
        A function ``loss(y_true, y_pred)`` that computes
        ``sqrt(mean((y_true - y_pred) ** 2)) ** std``.
    """
    def loss(y_true, y_pred):
        residual = y_true - y_pred
        # Reduce over every element to a single scalar RMSE.
        rmse = tf.sqrt(tf.reduce_mean(tf.square(residual)))
        exponent = tf.convert_to_tensor(std, dtype=tf.float32)
        return tf.pow(rmse, exponent)
    return loss

In [None]:
def mean_root_loss(y_true, y_pred):
    """Root-mean-squared error between ``y_true`` and ``y_pred``.

    The squared differences are averaged over all elements with the Keras
    backend, and the square root of that mean is returned.
    """
    squared_error = K.square(y_true - y_pred)
    return K.sqrt(K.mean(squared_error))

In [None]:
# Number of entries in the training set's "standard_error" column.
error_len = train["standard_error"].shape[0]

In [None]:
# Model inputs: token IDs and attention mask. The sequence length is tied to
# SEQ_LEN so it always matches the length produced by tokenize().
input_ids = tf.keras.layers.Input(shape=(SEQ_LEN,), name='input_ids', dtype='int32')
mask = tf.keras.layers.Input(shape=(SEQ_LEN,), name='attention_mask', dtype='int32')

# [0] selects the first output of the encoder; the GRU stack below consumes it
# as a sequence (presumably the per-token last hidden states -- see the
# Hugging Face model output documentation).
embeddings = bert(input_ids, attention_mask=mask)[0]

# Recurrent stack: three GRU layers of growing width, each followed by
# batch normalization and dropout. Every GRU returns the full sequence.
X = embeddings
for units in (256, 512, 1024):
    X = tf.keras.layers.GRU(units, return_sequences=True)(X)
    X = tf.keras.layers.BatchNormalization()(X)
    X = tf.keras.layers.Dropout(0.2)(X)

# Flatten keeps every timestep of the final GRU output (no pooling).
X = tf.keras.layers.Flatten()(X)

# Dense head: only the first layer has a non-linearity (tanh); the following
# two are linear, as is the single-unit regression output.
for units, activation in ((512, 'tanh'), (256, None), (128, None)):
    X = tf.keras.layers.Dense(units, activation=activation)(X)
    X = tf.keras.layers.BatchNormalization()(X)
    X = tf.keras.layers.Dropout(0.2)(X)
y = tf.keras.layers.Dense(1)(X)

model = tf.keras.Model(inputs=[input_ids, mask], outputs=y)

# Freeze the BERT encoder. Setting the flag on `bert` directly avoids relying
# on the encoder happening to sit at a particular index of model.layers.
bert.trainable = False

# compile the model
optimizer = tf.keras.optimizers.Adam(1e-3)
# NOTE: this MSE object is not passed to compile(); training uses mean_root_loss.
loss = tf.keras.losses.MeanSquaredError()
# This metric is a root-mean-squared error (lower is better). It is reported
# under the name 'accuracy', which is what the checkpoint monitor key refers to.
acc = tf.keras.metrics.RootMeanSquaredError('accuracy')

model.compile(optimizer=optimizer, loss=mean_root_loss, metrics=[acc])

In [None]:
model.summary()

checkpoint_filepath = '/kaggle/working'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True)

In [None]:
history = model.fit([Xids, Xmask],train["target"],epochs=50)

model.load_weights(checkpoint_filepath)

In [None]:
# Pre-allocate one row per test excerpt. Token IDs and attention-mask flags are
# integers, so the buffers are int32 (matching the model's int32 Input layers)
# rather than NumPy's float64 default.
Tids = np.zeros((len(test), SEQ_LEN), dtype=np.int32)
Tmask = np.zeros((len(test), SEQ_LEN), dtype=np.int32)

In [None]:
# Encode every test excerpt and copy its token IDs / attention mask into row i.
for i, sentence in enumerate(test['excerpt']):
    ids_row, mask_row = tokenize(sentence)
    Tids[i, :] = ids_row
    Tmask[i, :] = mask_row

In [None]:
# Predict a score for every test excerpt.
prediction = model.predict([Tids, Tmask])

# Flatten the predictions to one value per row and pair each with its id.
flat_targets = prediction.reshape(len(prediction),)
submission_mine = pd.DataFrame({"id": test.id, "target": flat_targets})

# Write the submission file without the DataFrame index column.
submission_mine.to_csv("submission.csv", index=False)