## This notebook was run on Google Colab because of local GPU constraints

In [None]:
!pip install pandas
!pip install tensorflow
!pip install transformers
!pip install scikit-learn

In [None]:
import pandas as pd
import tensorflow as tf
from transformers import TFAutoModel, AutoTokenizer
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import numpy as np
from keras.callbacks import ModelCheckpoint

## Get Data

In [None]:
!wget https://raw.githubusercontent.com/vennietweek/aita-analysis-tool/main/data/balanced/train.csv
!wget https://raw.githubusercontent.com/vennietweek/aita-analysis-tool/main/data/balanced/test.csv

!wget https://raw.githubusercontent.com/vennietweek/aita-analysis-tool/main/data/summarised/test_summarised_gpt2.csv
!wget https://raw.githubusercontent.com/vennietweek/aita-analysis-tool/main/data/summarised/train_summarised_gpt2.csv

!wget https://raw.githubusercontent.com/vennietweek/aita-analysis-tool/main/data/test_with_pagerank.csv
!wget https://raw.githubusercontent.com/vennietweek/aita-analysis-tool/main/data/train_with_pagerank_part1.csv
!wget https://raw.githubusercontent.com/vennietweek/aita-analysis-tool/main/data/train_with_pagerank_part2.csv

## Config

In [None]:
config = "ORIG"  # one of: "ORIG", "GPT2-S", "PR-S"

# Each configuration selects the text column fed to the model (`col_name`)
# and the CSV files that contain it (`train_df`, `test_df`).
if config == "ORIG":
    col_name = "content"
    train_df = pd.read_csv('train.csv')
    test_df = pd.read_csv('test.csv')
elif config in ("GPT2-S", "GPT2-R"):
    # "GPT2-S" is the documented name. "GPT2-R" is the spelling this cell
    # used to test for, so it is still accepted.
    col_name = "summarised"
    train_df = pd.read_csv('train_summarised_gpt2.csv')
    test_df = pd.read_csv('test_summarised_gpt2.csv')
elif config == "PR-S":
    col_name = "pagerank"
    # The PageRank training set is stored as two files; stitch them together.
    d1 = pd.read_csv('train_with_pagerank_part1.csv')
    d2 = pd.read_csv('train_with_pagerank_part2.csv')
    train_df = pd.concat([d1, d2])
    test_df = pd.read_csv('test_with_pagerank.csv')
else:
    # Fail loudly instead of silently training on the PageRank data when the
    # configuration name is mistyped.
    raise ValueError(
        f'Unknown config {config!r}; expected "ORIG", "GPT2-S" or "PR-S".'
    )

## Fine-tuning

In [None]:
# Tokenizer loaded from the same "bert-base-uncased" checkpoint name as the model.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")


def aita_tokenize(c):
    """Tokenize ``c`` with the notebook-level ``tokenizer``.

    The text is truncated and padded so that the output is always exactly
    512 tokens long.
    """
    tokenizer_options = {
        'max_length': 512,
        'truncation': True,
        'padding': 'max_length',
    }
    return tokenizer(c, **tokenizer_options)

# Tokenize every row of the selected text column, one example at a time.
train_encodings = list(map(aita_tokenize, train_df[col_name]))
test_encodings = list(map(aita_tokenize, test_df[col_name]))

def create_tf_dataset(encodings, labels):
    """Build an unbatched ``tf.data.Dataset`` of ``(features, label)`` pairs.

    Args:
        encodings: Sequence of per-example tokenizer outputs. Each one is a
            mapping with ``'input_ids'`` and ``'attention_mask'`` and,
            optionally, ``'token_type_ids'``.
        labels: Labels aligned one-to-one with ``encodings``.

    Returns:
        The dataset created by ``tf.data.Dataset.from_tensor_slices`` from
        the feature dictionary and ``labels``.
    """
    features = {
        'input_ids': [encoding['input_ids'] for encoding in encodings],
        'attention_mask': [encoding['attention_mask'] for encoding in encodings],
    }
    # Add the key only when the tokenizer produced it. Storing ``None`` under
    # the key instead leaves a feature that cannot be turned into a tensor.
    # The length guard keeps an empty ``encodings`` from raising IndexError.
    if len(encodings) > 0 and 'token_type_ids' in encodings[0]:
        features['token_type_ids'] = [
            encoding['token_type_ids'] for encoding in encodings
        ]
    return tf.data.Dataset.from_tensor_slices((features, labels))

# Labels are read from the 'flag' column of each split.
train_labels = train_df['flag'].values
test_labels = test_df['flag'].values

BATCH_SIZE = 16

# Training examples are shuffled with a buffer as large as the whole set and
# then batched. The test examples keep their original order.
train_examples = create_tf_dataset(train_encodings, train_labels)
train_dataset = train_examples.shuffle(len(train_examples)).batch(BATCH_SIZE)
test_dataset = create_tf_dataset(test_encodings, test_labels).batch(BATCH_SIZE)

# Pretrained encoder that the classifier below is built on.
model = TFAutoModel.from_pretrained("bert-base-uncased")


class BERTForClassification(tf.keras.Model):
    """Keras model that stacks a dense softmax layer on top of a BERT encoder."""

    def __init__(self, bert_model, num_classes):
        """Keep a reference to the encoder and create the output layer.

        Args:
            bert_model: Encoder whose output is fed to the classifier.
            num_classes: Number of units in the softmax output layer.
        """
        super().__init__()
        self.bert = bert_model
        self.fc = tf.keras.layers.Dense(num_classes, activation='softmax')

    def call(self, inputs):
        """Run ``inputs`` through the encoder and return class probabilities."""
        # Element 1 of the encoder output is used as the sequence
        # representation (presumably BERT's pooled output; confirm against
        # the transformers documentation).
        pooled = self.bert(inputs)[1]
        class_probabilities = self.fc(pooled)
        return class_probabilities

# Classification head with two output classes on top of the pretrained encoder.
classifier = BERTForClassification(model, num_classes=2)

# The model emits softmax probabilities and the labels are integer class ids,
# hence sparse categorical cross-entropy with its default settings.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
classifier.compile(optimizer=optimizer, loss=loss_fn, metrics=['accuracy'])

## Training 

In [None]:
# Single source of truth for where the best model is written and later
# reloaded from, so the save and load locations cannot drift apart.
# NOTE(review): assumes Google Drive is mounted at /content/drive; no mount
# call is visible in this notebook, so confirm before running.
CHECKPOINT_PATH = '/content/drive/MyDrive/MComp/CS5246 - Text Mining/project/model'

# Keep only the model with the lowest validation loss seen so far.
checkpoint_callback = ModelCheckpoint(
    CHECKPOINT_PATH,
    monitor='val_loss',
    verbose=1,
    save_best_only=True,
    mode='min'
)

# NOTE(review): the test split is also used as validation data here, so
# checkpoint selection is influenced by the test set.
history = classifier.fit(
    train_dataset,
    epochs=10,
    validation_data=test_dataset,
    callbacks=[checkpoint_callback]
)

# load best state
classifier = tf.keras.models.load_model(CHECKPOINT_PATH)
classifier.evaluate(test_dataset)

## Evaluation Metrics

In [None]:
# Per-class probabilities for every test example, reduced to the most likely class.
predictions = classifier.predict(test_dataset)
class_labels = np.argmax(predictions, axis=-1)

# Calculate the metrics
precision, recall, f1, computed_accuracy = (
    score(test_labels, class_labels)
    for score in (precision_score, recall_score, f1_score, accuracy_score)
)

# Print out the metrics
print(f'Precision: {precision:.4f}')
print(f'Recall: {recall:.4f}')
print(f'F1 Score: {f1:.4f}')
print(f'Computed Accuracy: {computed_accuracy:.4f}')

## Inference 

In [None]:
def inference_analysis(model, text):
  """Print the class probabilities ``model`` predicts for one piece of text.

  Args:
      model: Object with a ``predict`` method that accepts a batched dataset.
      text: A single string to classify.

  Returns:
      None. The probabilities are printed.
  """
  encoding = aita_tokenize(text)
  feature_names = ['input_ids', 'attention_mask']
  if 'token_type_ids' in encoding:
      feature_names.append('token_type_ids')
  # Tokenizing one string gives flat per-token lists. Wrap each in a list to
  # add a leading example axis; otherwise ``from_tensor_slices`` would yield
  # one element per token and the model would be run on single tokens
  # instead of on the whole sequence.
  inputs = {name: [encoding[name]] for name in feature_names}
  single_example = tf.data.Dataset.from_tensor_slices(inputs).batch(1)
  predictions = model.predict(single_example)
  print("Probabilities for 0 and 1 :")
  print(predictions)


# Smoke-test the fine-tuned classifier on a short example sentence.
text = "We love text mining :D"
inference_analysis(model=classifier, text=text)