In [None]:
%autosave 300
%reload_ext autoreload
%config Completer.use_jedi = False

In [None]:
import os

# Work from the project root: the cells below import from `src` and build data
# paths from os.getcwd(). The root can be overridden per machine through the
# ABI_BERT_PROJECT_ROOT environment variable; the original workspace path stays
# the default, so existing runs are unaffected.
os.chdir(
    os.environ.get(
        "ABI_BERT_PROJECT_ROOT",
        r"/home/azureuser/cloudfiles/code/Users/soutrik.chowdhury/abi_genai_bert_classifier",
    )
)

In [None]:
import copy
import json

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    classification_report,
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
)

from src.utils.model_helpers import (
    set_seed,
    plot_loss_accuracy,
    get_device,
)
from src.preprocess import data_preprocess
from src.settings import (
    DataSettings,
    env_settings,
    ModelSettings,
    TokenizerSettings,
    AzureblobSettings,
    LoggerSettings
)
from src.pretrained_model import tokenizer, pretrained_model
from src.dataloader import create_data_loader
from src.model import BertSentimentClassifier, BertSentimentClassifierAdvanced
from src.trainer import train_module, test_module, training_drivers, get_predictions
from src.utils.azure_connector import AzureBlobConnection
from src.utils.logger import setup_logging # type: ignore


# Six custom colours used as the default colour cycle for the figures below.
HAPPY_COLORS_PALETTE = [
    "#01BEFE", "#FFDD00", "#FF7D00",
    "#FF006D", "#ADFF02", "#8F00FF",
]

# Apply the base theme first, then replace its "muted" palette with the custom one.
sns.set(style="whitegrid", palette="muted", font_scale=1.2)
sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))

In [None]:
# Notebook-wide logger: name and level are read from LoggerSettings, and
# "ModelTrainEval.log" is passed as the log file.
logger_name = LoggerSettings().logger_name
log_level = LoggerSettings().log_level
logger = setup_logging(
    logger_name=logger_name,
    log_file="ModelTrainEval.log",
    log_level=log_level,
)

In [None]:
# Merge every CSV found in the configured data folder into one DataFrame.
logger.info("Consolidating all the syntesized files")
folder_path = os.path.join(os.getcwd(), DataSettings().data_path)

# os.listdir returns names in arbitrary order; sorting keeps the row order of
# the merged frame (and any seeded split derived from it) stable across runs
# and machines.
csv_files = sorted(file for file in os.listdir(folder_path) if file.endswith(".csv"))
if not csv_files:
    # Fail with a clear message instead of pd.concat's "No objects to concatenate".
    raise FileNotFoundError(f"No .csv files found in {folder_path}")

concatenated_df = pd.concat(
    (pd.read_csv(os.path.join(folder_path, file)) for file in csv_files),
    ignore_index=True,
)

In [None]:
# Label balance per domain: one group of bars per Domain value, split by FinalLabel.
fig, ax = plt.subplots(figsize=(13, 7))
sns.countplot(data=concatenated_df, x="Domain", hue="FinalLabel", ax=ax)
plt.setp(ax.get_xticklabels(), rotation=45)  # tilt the tick labels
ax.set_title("Domain Distribution")
plt.show()
plt.close(fig)

In [None]:
# Produce the train / test frames from the consolidated data, using the
# configured evaluation size and seed.
evaluation_size = DataSettings().evaluation_size
split_seed = ModelSettings().seed
train_df, test_df = data_preprocess(concatenated_df, evaluation_size, split_seed)

In [None]:
def _make_loader(frame, shuffle):
    """Build a data loader over the ``Question`` / ``FinalLabel`` columns of ``frame``.

    Sequence length and batch size are read from TokenizerSettings; the
    module-level ``tokenizer`` is forwarded unchanged.
    """
    return create_data_loader(
        question=frame["Question"].values,
        targets=frame["FinalLabel"].values,
        max_len=TokenizerSettings().max_length,
        batch_size=TokenizerSettings().batch_size,
        shuffle=shuffle,
        tokenizer=tokenizer,
    )


logger.info("Creating train and test dataloaders")
# Only the training loader is shuffled; the test loader keeps the frame order.
train_loader = _make_loader(train_df, shuffle=True)
test_loader = _make_loader(test_df, shuffle=False)

In [None]:
def _log_first_batch(loader, split_name):
    """Log the keys and tensor shapes of the first batch produced by ``loader``.

    Args:
        loader: Iterable of batches. Each batch must provide ``keys()`` and the
            entries ``input_ids``, ``attention_mask`` and ``targets``.
        split_name: Name of the split ("train", "test", ...) used in the log
            lines, so the output says which loader it describes.

    Returns:
        The first batch, or ``None`` if the loader yields nothing.
    """
    first_batch = next(iter(loader), None)
    if first_batch is None:
        logger.warning(f"The {split_name} loader produced no batches")
        return None

    logger.info(first_batch.keys())
    logger.info(f"Shape of the {split_name} input ids: {first_batch['input_ids'].shape}")
    logger.info(f"Shape of {split_name} attention mask: {first_batch['attention_mask'].shape}")
    logger.info(f"Shape of {split_name} targets:: {first_batch['targets'].shape}")
    logger.info("\n")
    return first_batch


# Sanity check: inspect one batch from each loader. The logged shapes show the
# batch size and sequence length actually in effect.
train_data = _log_first_batch(train_loader, "train")
test_data = _log_first_batch(test_loader, "test")

In [None]:
# Apply the configured seed through the project helper, then resolve the
# device used by the model and training cells below.
seed = ModelSettings().seed
set_seed(seed)
device = get_device()

In [None]:
logger.info("Base Classifier")
# Give this classifier its own copy of the encoder. The same `pretrained_model`
# object is also passed to the advanced classifier further down; if the model
# class keeps a reference to the module it receives (NOTE(review): confirm in
# src.model), fine-tuning one classifier would silently change the encoder
# weights of the other.
bert_base_classifier = BertSentimentClassifier(
    copy.deepcopy(pretrained_model),
    ModelSettings().num_classes,
    ModelSettings().drop_out,
).to(device)

In [None]:
# Loss, optimizer, scheduler, metric and early-stopping objects for the base
# classifier, as returned by training_drivers.
base_driver_kwargs = dict(
    learning_rate=ModelSettings().learning_rate,
    train_loader=train_loader,
    epochs=ModelSettings().epochs,
    device=device,
    model_name="base_bert_model.pt",
)
criterion, optimizer, scheduler, metric, early_stopping = training_drivers(
    bert_base_classifier, **base_driver_kwargs
)

In [None]:
# Histories for the base run. train_module / test_module receive the current
# lists and return the updated ones.
train_losses, train_metrics = [], []
test_losses, test_metrics = [], []

# Arguments shared by the training and evaluation passes.
base_step_args = dict(
    model=bert_base_classifier,
    device=device,
    criterion=criterion,
    metric=metric,
)

total_epochs = ModelSettings().epochs
for epoch in range(total_epochs):
    logger.info(f"Epoch {epoch + 1}/{total_epochs}")
    logger.info("-" * 10)

    # Training pass.
    train_losses, train_metrics = train_module(
        train_dataloader=train_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        train_losses=train_losses,
        train_metrics=train_metrics,
        **base_step_args,
    )

    # Evaluation pass on the test loader.
    test_losses, test_metrics = test_module(
        test_dataloader=test_loader,
        test_losses=test_losses,
        test_metrics=test_metrics,
        **base_step_args,
    )

    # NOTE(review): the scheduler is also handed to train_module above — confirm
    # it is not stepped there as well.
    scheduler.step()
    logger.info(f"The learing rate is going to be next::{scheduler.get_last_lr()}")

    # Early stopping is fed the most recently recorded test loss.
    early_stopping(test_losses[-1], bert_base_classifier, epoch)
    if early_stopping.early_stop:
        logger.info("Early stopping")
        break

In [None]:
# Loss / metric curves of the base run. Every series is wrapped in a
# one-element list, matching the single entry in `labels` and `colors`.
base_curves = dict(
    train_loss=[train_losses],
    val_loss=[test_losses],
    train_acc=[train_metrics],
    val_acc=[test_metrics],
)
base_plot_style = dict(
    labels=["baseline_Bert"],
    colors=["blue"],
    loss_legend_loc="upper left",
    acc_legend_loc="upper left",
    legend_font=7,
)
plot_loss_accuracy(**base_curves, **base_plot_style)

In [None]:
logger.info("Advanced Classifier")
# Wrap a private copy of the encoder instead of the shared `pretrained_model`
# object, so training this classifier cannot change the weights of any other
# model built from that same object.
# NOTE(review): assumes the classifier keeps a reference to the module it is
# given — confirm in src.model.
bert_advanced_classifier = BertSentimentClassifierAdvanced(
    bert=copy.deepcopy(pretrained_model),
    n_classes=ModelSettings().num_classes,
    dropout=ModelSettings().drop_out,
).to(device)

In [None]:
# Loss, optimizer, scheduler, metric and early-stopping objects for the
# advanced classifier; these rebind the names used by the previous run.
advanced_driver_kwargs = dict(
    learning_rate=ModelSettings().learning_rate,
    train_loader=train_loader,
    epochs=ModelSettings().epochs,
    device=device,
    model_name="advanced_bert_model.pt",
)
criterion, optimizer, scheduler, metric, early_stopping = training_drivers(
    bert_advanced_classifier, **advanced_driver_kwargs
)

In [None]:
# Fresh histories for the advanced run (the lists are reset, so earlier values
# are discarded). train_module / test_module return the updated lists.
train_losses, train_metrics = [], []
test_losses, test_metrics = [], []

# Arguments shared by the training and evaluation passes.
advanced_step_args = dict(
    model=bert_advanced_classifier,
    device=device,
    criterion=criterion,
    metric=metric,
)

total_epochs = ModelSettings().epochs
for epoch in range(total_epochs):
    logger.info(f"Epoch {epoch + 1}/{total_epochs}")
    logger.info("-" * 10)

    # Training pass.
    train_losses, train_metrics = train_module(
        train_dataloader=train_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        train_losses=train_losses,
        train_metrics=train_metrics,
        **advanced_step_args,
    )

    # Evaluation pass on the test loader.
    test_losses, test_metrics = test_module(
        test_dataloader=test_loader,
        test_losses=test_losses,
        test_metrics=test_metrics,
        **advanced_step_args,
    )

    # NOTE(review): the scheduler is also handed to train_module above — confirm
    # it is not stepped there as well.
    scheduler.step()
    logger.info(f"The learing rate is going to be next::{scheduler.get_last_lr()}")

    # Early stopping is fed the most recently recorded test loss.
    early_stopping(test_losses[-1], bert_advanced_classifier, epoch)
    if early_stopping.early_stop:
        logger.info("Early stopping")
        break

In [None]:
# Loss / metric curves of the advanced run. Every series is wrapped in a
# one-element list, matching the single entry in `labels` and `colors`.
advanced_curves = dict(
    train_loss=[train_losses],
    val_loss=[test_losses],
    train_acc=[train_metrics],
    val_acc=[test_metrics],
)
advanced_plot_style = dict(
    labels=["advanced_Bert"],
    colors=["green"],
    loss_legend_loc="upper left",
    acc_legend_loc="upper left",
    legend_font=7,
)
plot_loss_accuracy(**advanced_curves, **advanced_plot_style)

In [None]:
# Load the evaluation questions. The encoding is pinned to UTF-8 so the result
# does not depend on the platform's locale default.
eval_path = os.path.join(os.getcwd(), "data/testing/eval_questions.json")
with open(eval_path, encoding="utf-8") as file:
    eval_data = json.load(file)

In [None]:
# Loader over the "questions" / "targets" entries of eval_data: one example
# per batch, no shuffling.
eval_loader_kwargs = dict(
    question=eval_data["questions"],
    targets=eval_data["targets"],
    max_len=TokenizerSettings().max_length,
    batch_size=1,
    shuffle=False,
    tokenizer=tokenizer,
)
eval_loader = create_data_loader(**eval_loader_kwargs)

In [None]:
# Run the base classifier over the evaluation loader, passing the configured
# binary threshold.
# NOTE(review): only bert_base_classifier is scored here; the advanced
# classifier is not evaluated in this part of the notebook.
binary_thresh = ModelSettings().binary_thresh
review_texts, predictions, prediction_probs, real_values = get_predictions(
    bert_base_classifier,
    eval_loader,
    device,
    binary_thresh,
)

In [None]:
# Display the texts returned by get_predictions for a visual check.
review_texts

In [None]:
# Show predicted and true labels side by side as a (predictions, real_values) pair.
(predictions, real_values)

In [None]:
def classific_metrics(real_values, predictions, class_names):
    """Log the headline classification metrics and return them.

    Accuracy, recall, precision and F1 are computed with each scikit-learn
    scorer's default arguments and logged one by one, followed by the full
    classification report.

    Args:
        real_values: Ground-truth labels.
        predictions: Predicted labels, aligned with ``real_values``.
        class_names: Names forwarded to ``classification_report`` as
            ``target_names``.

    Returns:
        dict: The logged values under the keys ``"accuracy"``, ``"recall"``,
        ``"precision"`` and ``"f1"``. The function previously returned
        ``None`` despite its docstring; callers that ignore the result are
        unaffected.
    """
    # (result key, label used in the log line, scorer)
    scorers = (
        ("accuracy", "Accuracy", accuracy_score),
        ("recall", "Recall", recall_score),
        ("precision", "Precision", precision_score),
        ("f1", "F1 Score", f1_score),
    )

    scores = {}
    for key, label, scorer in scorers:
        # Compute and log one metric at a time so earlier results are still
        # logged if a later scorer raises.
        scores[key] = scorer(real_values, predictions)
        logger.info(f"Test {label} : {scores[key]}")

    logger.info("\nClassification Report : ")
    logger.info(classification_report(real_values, predictions, target_names=class_names))
    return scores

In [None]:
# Report metrics for the evaluation predictions using the configured class names.
classific_metrics(
    real_values=real_values,
    predictions=predictions,
    class_names=DataSettings().class_names,
)

In [None]:
# Connection object for Azure Blob Storage. All four parameters are read from
# env_settings, so no credential is written in the notebook itself.
az_connection = AzureBlobConnection(
    storage_account=env_settings.STORAGE_ACCOUNT,
    client_id=env_settings.CLIENT_ID,
    tenant_id=env_settings.TENANT_ID,
    client_secret=env_settings.SECRET_ID,  # the secret is stored under SECRET_ID in env_settings
)

In [None]:
# Upload from the configured local input path to the configured blob path.
# NOTE(review): file_names is an empty list — presumably "no filter / all
# files"; confirm against AzureBlobConnection.azblob_upload.
logger.info("Uploading to Azure Blob Storage")
upload_kwargs = dict(
    container_name=env_settings.CONTAINER_NAME,
    root_path=os.getcwd(),
    local_input_path=AzureblobSettings().input_path,
    blob_path=AzureblobSettings().blob_path,
    file_names=[],
)
az_connection.azblob_upload(**upload_kwargs)

In [None]:
# Download from the configured blob path back to the local machine.
# NOTE(review): local_output_path is taken from AzureblobSettings().input_path,
# i.e. the same folder the upload reads from — confirm this is intended.
logger.info("Downloading from Azure Blob Storage")
download_kwargs = dict(
    container_name=env_settings.CONTAINER_NAME,
    root_path=os.getcwd(),
    local_output_path=AzureblobSettings().input_path,
    blob_path=AzureblobSettings().blob_path,
    file_names=[],
)
az_connection.azblob_download(**download_kwargs)