In [31]:
import os

# Keras picks its backend from KERAS_BACKEND when it is first imported, so the
# variable must be set *before* `import keras` / `import keras_nlp`; assigning it
# afterwards (as this cell used to) has no effect.
os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
import pandas as pd
import tensorflow as tf

import keras
import keras_nlp
import wandb



### Splitting data

In [32]:
# Split sizes in rows; whatever remains after train + validation becomes the test split.
TRAIN_SIZE = 6000
VAL_SIZE = 800
SPLIT_SEED = 42

# Drop every row that has a missing value in any column.
full_train_df = pd.read_csv("data/cleaned_train.csv").dropna()

# Shuffle all rows exactly once (frac=1, no replacement) so the positional
# slices below form random, non-overlapping splits.
shuffled_train_df = full_train_df.sample(frac=1,
                                         random_state=SPLIT_SEED,
                                         replace=False)

train_end = TRAIN_SIZE
val_end = TRAIN_SIZE + VAL_SIZE
train_data_split_df = shuffled_train_df.iloc[:train_end]
val_data_split_df = shuffled_train_df.iloc[train_end:val_end]
test_data_split_df = shuffled_train_df.iloc[val_end:]

train_split_ids = train_data_split_df["id"]
val_split_ids = val_data_split_df["id"]
# Bug fix: this previously took the ids of the *validation* split.
test_split_ids = test_data_split_df["id"]

# Fail loudly if the CSV is too small for the requested sizes; the slices
# above would otherwise silently come back short or empty.
assert len(val_data_split_df) == VAL_SIZE, "not enough rows for the train/validation splits"
assert len(test_data_split_df) > 0, "no rows left for the test split"


print("Train  split shape: ", train_data_split_df.shape)
print("Validation split shape: ", val_data_split_df.shape)
print("Test  split shape: ", test_data_split_df.shape)

Train  split shape:  (6000, 3)
Validation split shape:  (800, 3)
Test  split shape:  (810, 3)


In [18]:
def _features_and_labels(split_df):
    """Return (cleaned tweet texts, targets) of one split as NumPy arrays."""
    texts = split_df["text_cleaned"].values
    labels = split_df["target"].values
    return texts, labels


X_train, y_train = _features_and_labels(train_data_split_df)
X_val, y_val = _features_and_labels(val_data_split_df)
X_test, y_test = _features_and_labels(test_data_split_df)

### TODO: check Kaggle for how to fine-tune BERT with KerasNLP

In [27]:
PRETRAINED_MODEL = "bert_tiny_en_uncased"
SEQUENCE_LENGTH = 120  # tokens kept per tweet (padded / truncated to this length)
NUM_CLASSES = 2

# Tokenizer + packer for raw strings, configured with our shorter sequence length.
preprocessor = keras_nlp.models.BertPreprocessor.from_preset(PRETRAINED_MODEL,
                                                             sequence_length=SEQUENCE_LENGTH,
                                                             name="bert_preprocessor")

# The preprocessor above must be passed in explicitly; previously it was built but
# never used, so the classifier silently fell back to the preset's default one.
# No activation is requested, so the head is expected to emit logits (the
# loss below is configured with from_logits=True).
bert_classifier = keras_nlp.models.BertClassifier.from_preset(PRETRAINED_MODEL,
                                                              preprocessor=preprocessor,
                                                              num_classes=NUM_CLASSES)


  return id(getattr(self, attr)) not in self._functional_layer_ids
  return id(getattr(self, attr)) not in self._functional_layer_ids


In [28]:
# Print the layer-by-layer structure of the classifier built above.
bert_classifier.summary()

In [29]:

# Fine-tuning pretrained BERT weights needs a small step size: the previous 1e-2
# is orders of magnitude above the usual 2e-5..5e-5 range and typically
# destabilises training / overwrites the pretrained weights.
LEARNING_RATE = 5e-5

# The classifier outputs one logit per class; labels are presumably integer class
# ids (0/1 in `target`) - hence sparse categorical cross-entropy on logits.
LOSS = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
OPTIMIZER = tf.keras.optimizers.Adam(LEARNING_RATE)
METRICS = ["accuracy"]


bert_classifier.compile(
    loss=LOSS,
    optimizer=OPTIMIZER,
    metrics=METRICS,
)

### Preparing callbacks for training

In [22]:
class WandbLoggerCallback(tf.keras.callbacks.Callback):
    """Keras callback that forwards training metrics to Weights & Biases.

    After every training batch it logs ``loss``/``accuracy`` as
    ``batch_loss``/``batch_accuracy``; after every epoch it logs the training and
    validation metrics. Metrics absent from ``logs`` (e.g. ``val_*`` when ``fit``
    runs without validation data, or ``logs=None``) are skipped instead of raising.
    """

    # Keras log key -> W&B metric name.
    _BATCH_KEYS = {"loss": "batch_loss", "accuracy": "batch_accuracy"}
    _EPOCH_KEYS = {
        "loss": "epoch_loss",
        "accuracy": "epoch_accuracy",
        "val_loss": "val_loss",
        "val_accuracy": "val_accuracy",
    }

    @staticmethod
    def _select(logs, key_map):
        """Return the entries of ``logs`` named in ``key_map``, renamed; missing keys are dropped."""
        logs = logs or {}
        return {wandb_key: logs[keras_key]
                for keras_key, wandb_key in key_map.items()
                if keras_key in logs}

    def on_batch_end(self, epoch, logs=None):
        """Log the running batch metrics.

        Keras passes the *batch* index as the first positional argument; the
        parameter keeps its original name for interface compatibility.
        """
        metrics = self._select(logs, self._BATCH_KEYS)
        if metrics:
            wandb.log(metrics)

    def on_epoch_end(self, epoch, logs=None):
        """Log the end-of-epoch training and (if present) validation metrics."""
        metrics = self._select(logs, self._EPOCH_KEYS)
        if metrics:
            wandb.log(metrics)
            print("\nMETRIC LOGGED")


# Save the whole model (not just the weights) whenever the validation loss improves.
# NOTE(review): under Keras 3, full-model checkpoints must use a `.keras` filepath;
# `.ckpt` only works with tf.keras (Keras 2) - confirm which Keras is installed.
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath="train_checkpoints/bert_checkpoint.ckpt",
    monitor='val_loss',
    save_best_only=True,
    save_weights_only=False,
    verbose=1,
)

# Back up training state under train_backups/ so an interrupted `fit` can resume.
backup_restore_callback = keras.callbacks.BackupAndRestore(backup_dir="train_backups/")

# Custom callback (defined above) that sends batch/epoch metrics to W&B.
wandb_logger = WandbLoggerCallback()



        

In [23]:

# Run metadata recorded in W&B. Optimizer and loss are stored via `get_config()`
# (plain dicts: name, learning rate, from_logits, ...) instead of opaque objects,
# so the values are readable in the W&B config panel.
TRIAL_CONFIG = {
    # Softmax, not sigmoid: the 2-unit logit head is trained with sparse
    # categorical cross-entropy on logits.
    "pipeline": ["BERT_CLASSIFIER", "LOGITS -> SOFTMAX"],
    "pretrained_model": PRETRAINED_MODEL,
    "train_params": {
        "optimizer": OPTIMIZER.get_config(),
        "metrics": METRICS,
        "loss": LOSS.get_config(),
    },
    "framework": "keras_nlp",
    "data_split": [train_data_split_df.shape,
                   val_data_split_df.shape,
                   test_data_split_df.shape],
}


wandb.init(
    project="disaster_tweets",
    name="bert_finetuning",
    # Previously `config=BERT_CONFIG`, an undefined name (NameError).
    config=TRIAL_CONFIG,
)

In [30]:

# Smoke-test subset sizes; set both to None for a full run (`arr[:None]` keeps every row).
TRAIN_SUBSET = 60
VAL_SUBSET = 20
BATCH_SIZE = 32
EPOCHS = 10

history = bert_classifier.fit(
    X_train[:TRAIN_SUBSET], y_train[:TRAIN_SUBSET],
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(X_val[:VAL_SUBSET], y_val[:VAL_SUBSET]),
    # The callbacks prepared earlier were never passed to `fit`, so no backups,
    # checkpoints or W&B metrics were produced.
    callbacks=[backup_restore_callback, checkpoint_callback, wandb_logger],
)




Epoch 1/10
Epoch 2/10

KeyboardInterrupt: 