Task: Movie review sentiment classification (positive vs. negative)

Dataset: IMDB https://www.kaggle.com/datasets/lakshmi25npathi/imdb-dataset-of-50k-movie-reviews

Model: BERT

Libraries: Pytorch, HuggingFace

Reference: https://medium.com/@pyroswolf200/fine-tuning-bert-on-imdb-review-dataset-309e90b6dac0

Kaggle notebook: https://www.kaggle.com/code/soumyaprabhamaiti/finetune-bert-on-imdb-reviews-pytorch-lightning/edit

# Libraries

In [None]:
%%capture
# `%%capture` hides pip's output for this cell.
# NOTE(review): `wget` appears unused in this notebook, and `python-dotenv` /
# `google-api-python-client` (imported by the Google Drive cell) are not
# installed here -- they must already exist in the environment.
!pip install wget
!pip install transformers
!pip install lightning

In [None]:
import pandas as pd
import re
from torch.utils.data import TensorDataset, random_split
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from sklearn.model_selection import train_test_split
import torch
import matplotlib.pyplot as plt
import numpy as np


# Config

In [None]:
# Location of the Kaggle "IMDB Dataset of 50K Movie Reviews" CSV.
INPUT_CSV_PATH = "/kaggle/input/imdb-dataset-of-50k-movie-reviews/IMDB Dataset.csv"

# Training hyper-parameters.
EPOCHS = 10          # Trainer max_epochs; also sizes the LR schedule
BATCH_SIZE = 32      # batch size of every DataLoader
MAX_SEQ_LEN = 64     # tokens per review after padding/truncation, incl. [CLS]/[SEP]

# Debug / subsetting switches.
FAST_DEV_RUN = False     # forwarded to Trainer(fast_dev_run=...)
TOTAL_SAMPLES = 50_000   # number of rows kept from the top of the CSV

In [None]:
# Pick the best available torch device (CUDA, then Apple MPS, then CPU) and report it.
# NOTE(review): `device` is not referenced later in this notebook; Lightning picks
# the hardware itself via Trainer(accelerator="auto").
if not torch.cuda.is_available():
    if torch.backends.mps.is_available():
        device = torch.device("mps")
        print("Using Mac ARM64 GPU")
    else:
        device = torch.device("cpu")
        print('No GPU available, using CPU')
else:
    device = torch.device("cuda")
    print(f'{torch.cuda.device_count()} GPU(s) available. Using the GPU: {torch.cuda.get_device_name(0)}')

# Dataset

In [None]:
# Load the raw reviews, keeping at most TOTAL_SAMPLES rows from the top of the file.
df = pd.read_csv(INPUT_CSV_PATH).head(TOTAL_SAMPLES)
df.head()

In [None]:
# Binary target: 1 for 'positive', 0 for anything else.
df['sentiment'] = (df['sentiment'] == 'positive').astype('int64')
# Substitutions applied in order by `process`. Raw strings are required: the
# previous plain literals ('\.', '\S', '\s') were invalid escape sequences,
# which warn on Python 3.12+ and are slated to become errors.
_PROCESS_STEPS = (
    (re.compile(r'[,\.!?:()"]'), ''),     # delete these marks outright ("U.S." -> "US", "end.Next" -> "endNext")
    (re.compile(r'<.*?>'), ' '),          # strip HTML tags such as <br />
    (re.compile(r'http\S+'), ' '),        # strip URLs
    (re.compile(r'[^a-zA-Z0-9]'), ' '),   # everything else non-alphanumeric (incl. accented letters) -> space
    (re.compile(r'\s+'), ' '),            # collapse whitespace runs
)


def process(x):
    """Normalise one raw review into lower-case, space-separated ASCII tokens.

    Args:
        x: the review text.

    Returns:
        The cleaned, lower-cased text with no leading/trailing whitespace.
    """
    for pattern, replacement in _PROCESS_STEPS:
        x = pattern.sub(replacement, x)
    return x.lower().strip()

# Clean every review text in place.
df['review'] = df['review'].map(process)

In [None]:
# Class balance of the whole (cleaned) dataset.
df['sentiment'].value_counts().plot.pie(autopct='%1.1f%%')

In [None]:
# Row offsets where the shuffled data is cut into train / val / test (80 / 10 / 10).
[int(0.8 * len(df)), int(0.9 * len(df))]

In [None]:
# Shuffle once (fixed seed for reproducibility), then cut 80% / 10% / 10% into
# train / val / test. Positional `iloc` slicing gives the same pieces as
# `np.split(DataFrame, ...)`, which relies on the deprecated
# `DataFrame.swapaxes` (FutureWarning on pandas >= 2.1, removed afterwards).
shuffled = df.sample(frac=1, random_state=42)
train_end, val_end = int(.8 * len(shuffled)), int(.9 * len(shuffled))
train = shuffled.iloc[:train_end]
val = shuffled.iloc[train_end:val_end]
test = shuffled.iloc[val_end:]
print(len(train), len(val), len(test))

In [None]:
# Class balance of the training split.
train['sentiment'].value_counts().plot.pie(autopct='%1.1f%%')

In [None]:
# Class balance of the validation split.
val['sentiment'].value_counts().plot.pie(autopct='%1.1f%%')

In [None]:
# Class balance of the test split.
test['sentiment'].value_counts().plot.pie(autopct='%1.1f%%')

In [None]:
# Pull the cleaned review texts and their 0/1 labels out of each split.
train_sentences, train_labels = train['review'].values, train['sentiment'].values
val_sentences, val_labels = val['review'].values, val['sentiment'].values
test_sentences, test_labels = test['review'].values, test['sentiment'].values

In [None]:
# Peek at the cleaned training reviews.
train_sentences

In [None]:
# Peek at the matching training labels (1 = positive, 0 = negative).
train_labels

# Model

In [None]:
from transformers import BertTokenizer

# WordPiece tokenizer matching the 'bert-base-uncased' checkpoint used by the model.
print('Loading BERT tokenizer...')
tokenizer = BertTokenizer.from_pretrained(
    'bert-base-uncased',
    do_lower_case=True,  # uncased vocabulary: lower-case text before tokenizing
)

In [None]:
# Tokenize all of the sentences and map the tokens to their word IDs.
def generate_data(data, labels):
    """Tokenize *data* with the global BERT ``tokenizer`` and bundle the tensors.

    Every text is encoded as ``[CLS] ... [SEP]``, padded or truncated to exactly
    ``MAX_SEQ_LEN`` token ids, together with an attention mask that is 1 for real
    tokens and 0 for ``[PAD]``.

    Args:
        data: iterable of review strings.
        labels: integer class labels aligned with *data*.

    Returns:
        ``(input_ids, attention_masks, labels)``: two tensors of shape
        ``(len(data), MAX_SEQ_LEN)`` and a 1-D label tensor.
    """
    # One batched call replaces the per-sentence `encode_plus` loop: identical
    # ids and masks with far less Python overhead. `padding="max_length"` plus an
    # explicit `truncation=True` replace the deprecated `pad_to_max_length=True`,
    # which only truncated implicitly.
    encoded = tokenizer(
        list(data),
        add_special_tokens=True,      # prepend [CLS], append [SEP]
        max_length=MAX_SEQ_LEN,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",          # PyTorch tensors
    )
    return encoded["input_ids"], encoded["attention_mask"], torch.tensor(labels)

In [None]:
# Tokenize every split; each *_labels array is rebound to its tensor version.
train_input_ids, train_attention_masks, train_labels = generate_data(train_sentences, train_labels)
val_input_ids, val_attention_masks, val_labels = generate_data(val_sentences, val_labels)
test_input_ids, test_attention_masks, test_labels = generate_data(test_sentences, test_labels)

# Spot-check the first training review against its token ids.
print('Original: ', train_sentences[0])
print('Token IDs:', train_input_ids[0])

In [None]:
# Each dataset item is an (input_ids, attention_mask, label) triple.
train_dataset = TensorDataset(train_input_ids, train_attention_masks, train_labels)
val_dataset = TensorDataset(val_input_ids, val_attention_masks, val_labels)
test_dataset = TensorDataset(test_input_ids, test_attention_masks, test_labels)


def _make_dataloader(dataset, sampler, num_workers=3):
    """Build a BATCH_SIZE DataLoader over *dataset* drawing indices from *sampler*.

    Worker processes are kept alive between epochs when there are any
    (PyTorch rejects ``persistent_workers=True`` with ``num_workers=0``).
    """
    return DataLoader(
        dataset,
        sampler=sampler,
        batch_size=BATCH_SIZE,
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
    )


# Training batches are drawn in a fresh random order every epoch.
train_dataloader = _make_dataloader(train_dataset, RandomSampler(train_dataset))

# Order is irrelevant for evaluation, so validation and test data are read sequentially.
val_dataloader = _make_dataloader(val_dataset, SequentialSampler(val_dataset))
test_dataloader = _make_dataloader(test_dataset, SequentialSampler(test_dataset))

In [None]:
import logging

# Raise every already-registered logger whose name mentions "transformers" to
# ERROR so lower-severity records from those loggers are dropped.
# Approach from https://stackoverflow.com/a/78844884
loggers = list(map(logging.getLogger, logging.root.manager.loggerDict))
for logger in loggers:
    if "transformers" not in logger.name.lower():
        continue
    logger.setLevel(logging.ERROR)

In [None]:
import lightning as L
from torch import nn
import torchmetrics
import torch.nn.functional as F
import torch.optim as optim
import torch
import torch.optim as optim
from transformers import AdamW
from transformers import BertForSequenceClassification
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger
import pandas as pd
from transformers import get_linear_schedule_with_warmup, PretrainedConfig


class Bert(L.LightningModule):
    """Binary sentiment classifier: ``BertForSequenceClassification`` wrapped in Lightning.

    Args:
        num_classes: size of the classification head (2 for binary); used when
            loading the pretrained ``bert-base-uncased`` weights.
        training_steps: total optimizer steps of the run; sizes the linear LR
            schedule built in ``configure_optimizers``.
        from_checkpoint: if True, build the model from the config JSON at
            ``model_config_json_filepath`` (randomly initialised -- presumably the
            weights are restored afterwards, e.g. from a Lightning checkpoint;
            confirm with the caller) instead of downloading pretrained weights.
        model_config_json_filepath: path to that config JSON.
    """

    def __init__(self, num_classes=None, training_steps=None, from_checkpoint=False, model_config_json_filepath=None):
        super().__init__()
        if from_checkpoint:
            model_config = PretrainedConfig.from_json_file(model_config_json_filepath)
            self._model = BertForSequenceClassification(config=model_config)
        else:
            self._model = BertForSequenceClassification.from_pretrained(
                "bert-base-uncased",  # 12-layer BERT with an uncased vocabulary.
                num_labels=num_classes,  # Output labels: 2 for binary; raise for multi-class.
                output_attentions=False,  # Do not return attention weights.
                output_hidden_states=False,  # Do not return all hidden states.
            )
        self.total_training_steps = training_steps

        # One metric pair per stage. A single shared pair mixed stages: the
        # validation loop runs inside the training epoch, so `val_*` was computed
        # over train+val predictions and then reset before on_train_epoch_end,
        # which then computed `train_*` on empty state.
        self.train_f1 = torchmetrics.F1Score(task="binary")
        self.val_f1 = torchmetrics.F1Score(task="binary")
        self.test_f1 = torchmetrics.F1Score(task="binary")
        self.train_confmat = torchmetrics.ConfusionMatrix(task="binary", num_classes=2)
        self.val_confmat = torchmetrics.ConfusionMatrix(task="binary", num_classes=2)
        self.test_confmat = torchmetrics.ConfusionMatrix(task="binary", num_classes=2)

        # TODO remove: raw validation predictions, dumped to CSV every validation epoch.
        self.preds = []
        self.labels = []

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, return_dict=False):
        """Delegate to the wrapped HF model; when *labels* are given its output also carries the loss."""
        outputs = self._model(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels, return_dict=return_dict)
        return outputs

    def _stage_metrics(self, prefix):
        """Return the ``(F1Score, ConfusionMatrix)`` pair owned by stage *prefix*."""
        return getattr(self, f"{prefix}_f1"), getattr(self, f"{prefix}_confmat")

    def _common_step(self, batch, batch_idx, prefix):
        """Forward one ``(input_ids, attention_mask, labels)`` batch and update *prefix* metrics.

        Returns:
            ``(loss, logits, predicted_labels, labels)``.
        """
        input_ids = batch[0]
        input_mask = batch[1]
        labels = batch[2]
        result = self(input_ids,
                      token_type_ids=None,
                      attention_mask=input_mask,
                      labels=labels,
                      return_dict=True)
        loss = result.loss
        # Log the tensor itself: `.item()` would force a device->host sync every step.
        self.log(f"{prefix}_loss", loss, prog_bar=True)

        logits = result.logits
        y_hat = torch.argmax(logits, dim=1)
        f1, confmat = self._stage_metrics(prefix)
        f1.update(y_hat, labels)
        confmat.update(y_hat, labels)

        return loss, logits, y_hat, labels

    def training_step(self, batch, batch_idx):
        loss, logits, y_hat, y = self._common_step(batch, batch_idx, "train")
        return loss

    def validation_step(self, batch, batch_idx):
        loss, logits, y_hat, y = self._common_step(batch, batch_idx, "val")

        # TODO remove
        self.preds.extend(y_hat.cpu().numpy())
        self.labels.extend(y.cpu().numpy())

    def test_step(self, batch, batch_idx):
        loss, logits, y_hat, y = self._common_step(batch, batch_idx, "test")

    def _common_on_epoch_end(self, prefix):
        """Log the *prefix* stage's epoch F1 and confusion-matrix cells, then reset them."""
        f1, confmat = self._stage_metrics(prefix)
        self.log(f'{prefix}_f1', f1.compute(), prog_bar=True)
        f1.reset()

        # Binary confusion matrix layout: [[TN, FP], [FN, TP]]. The counts are
        # integers; cast so Lightning does not have to convert them with a warning.
        cm = confmat.compute().float()
        self.log(f'{prefix}_TN', cm[0, 0], prog_bar=True)
        self.log(f'{prefix}_FP', cm[0, 1], prog_bar=True)
        self.log(f'{prefix}_FN', cm[1, 0], prog_bar=True)
        self.log(f'{prefix}_TP', cm[1, 1], prog_bar=True)
        confmat.reset()

    def on_train_epoch_end(self):
        self._common_on_epoch_end("train")

    def on_validation_epoch_end(self):
        self._common_on_epoch_end("val")

        # TODO remove: dump this epoch's raw validation predictions for inspection.
        epoch = self.current_epoch
        df = pd.DataFrame({
            'actual_label': self.labels,
            'predicted_label': self.preds
        })
        df.to_csv(f'validation_predictions_epoch_{epoch}.csv', index=False)
        self.preds = []
        self.labels = []

    def on_test_epoch_end(self):
        self._common_on_epoch_end("test")

    def configure_optimizers(self):
        """AdamW (lr 2e-5, eps 1e-8, no weight decay) with a per-batch linear LR decay to 0."""
        # torch's AdamW replaces the deprecated `transformers.AdamW`; weight_decay=0.0
        # keeps that class's default (torch's own default would be 0.01).
        optimizer = optim.AdamW(self.parameters(), lr=2e-5, eps=1e-8, weight_decay=0.0)
        lr_scheduler = get_linear_schedule_with_warmup(optimizer,
                                                       num_warmup_steps=0,
                                                       num_training_steps=self.total_training_steps)
        # `total_training_steps` counts optimizer steps, so the scheduler must be
        # stepped every batch. Lightning's default interval is once per epoch,
        # which left the learning rate almost constant.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": lr_scheduler, "interval": "step"},
        }

In [None]:
# The LR schedule is sized in optimizer steps: one per training batch, for EPOCHS epochs.
steps_per_epoch = len(train_dataloader)
lightning_model = Bert(num_classes=2, training_steps=steps_per_epoch * EPOCHS)

# Metrics are written under logs/lightning_logs/version_N/.
logger = CSVLogger("logs")

# save_top_k=-1 keeps every saved checkpoint rather than only the best ones.
checkpoint_callback = ModelCheckpoint(dirpath="checkpoints", save_top_k=-1)

trainer = L.Trainer(
    max_epochs=EPOCHS,
    fast_dev_run=FAST_DEV_RUN,
    accelerator="auto",
    logger=logger,
    callbacks=[checkpoint_callback],
    gradient_clip_val=1.0,
)
trainer.fit(lightning_model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

In [None]:
# Evaluate on the held-out test split. With no `ckpt_path`, Lightning uses the
# weights currently held by `lightning_model`, i.e. those left by `trainer.fit`.
trainer.test(model=lightning_model, dataloaders=test_dataloader)

In [None]:
# Show the LightningModule's repr: the wrapped BERT model and its metric modules.
lightning_model

In [None]:
# List the CSVLogger output of the first run (presumably later runs land in version_1, version_2, ...).
!ls logs/lightning_logs/version_0

In [None]:
# List the working directory: checkpoints/, logs/ and the validation_predictions_epoch_*.csv dumps.
!ls

In [None]:
# Steps to enable google drive file upload/download:
# 1. Create a project in google cloud console
# 2. Enable google drive api for the project at https://console.cloud.google.com/apis/library/drive.googleapis.com
# 3. Create a service account for the project at https://console.cloud.google.com/iam-admin/serviceaccounts
# 4. Download the json file containing the service account credentials
# 5. Share the google drive folder with the service account email with Editor permissions
# 6. pip install google-api-python-client==2.142.0 python-dotenv (both are imported below)
# 7. Set the environment variable GOOGLE_SERVICE_ACC_CREDS to the stringified json creds, or pass it as an argument to the functions
# 8. Run the functions

import datetime
import io
import json
import os

from dotenv import load_dotenv
from google.oauth2 import service_account
from googleapiclient.discovery import build
from googleapiclient.http import MediaIoBaseDownload


class GDriveUtils:
    LOG_EVENTS = True

    @staticmethod
    def get_gdrive_service(creds_stringified: str | None = None):
        SCOPES = ["https://www.googleapis.com/auth/drive"]
        if not creds_stringified:
            print(
                "Attempting to use google drive creds from environment variable"
            ) if GDriveUtils.LOG_EVENTS else None
            creds_stringified = os.getenv("GOOGLE_SERVICE_ACC_CREDS")
        creds_dict = json.loads(creds_stringified)
        creds = service_account.Credentials.from_service_account_info(
            creds_dict, scopes=SCOPES
        )
        return build("drive", "v3", credentials=creds)

    @staticmethod
    def upload_file_to_gdrive(
        local_file_path,
        drive_parent_folder_id: str,
        drive_filename: str | None = None,
        creds_stringified: str | None = None,
    ) -> str:
        service = GDriveUtils.get_gdrive_service(creds_stringified)

        if not drive_filename:
            drive_filename = os.path.basename(local_file_path)

        file_metadata = {
            "name": drive_filename,
            "parents": [drive_parent_folder_id],
        }
        file = (
            service.files()
            .create(body=file_metadata, media_body=local_file_path)
            .execute()
        )
        print(
            "File uploaded, drive file id: ", file.get("id")
        ) if GDriveUtils.LOG_EVENTS else None
        return file.get("id")

    @staticmethod
    def upload_file_to_gdrive_sanity_check(
        drive_parent_folder_id: str,
        creds_stringified: str | None = None,
    ):
        try:
            curr_time_utc = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
            file_name = f"gdrive_upload_test_{curr_time_utc}_UTC.txt"
            print(
                "Creating local file to upload: ", file_name
            ) if GDriveUtils.LOG_EVENTS else None
            with open(file_name, "w") as f:
                f.write(f"gdrive_upload_test_{curr_time_utc}_UTC")
            return GDriveUtils.upload_file_to_gdrive(
                file_name, drive_parent_folder_id, creds_stringified=creds_stringified
            )
        except Exception as e:
            raise e
        finally:
            if os.path.exists(file_name):
                print(
                    "Deleting local file: ", file_name
                ) if GDriveUtils.LOG_EVENTS else None
                os.remove(file_name)

    @staticmethod
    def download_file_from_gdrive(
        drive_file_id: str,
        local_file_path: str | None = None,
        creds_stringified: str | None = None,
    ):
        service = GDriveUtils.get_gdrive_service(creds_stringified)

        drive_filename = service.files().get(fileId=drive_file_id, fields="name").execute().get('name')

        if not local_file_path:
            local_file_path = f"{drive_file_id}_{drive_filename}"

        request = service.files().get_media(fileId=drive_file_id)
        file = io.BytesIO()
        downloader = MediaIoBaseDownload(file, request, chunksize= 25 * 1024 * 1024)
        done = False
        while done is False:
            status, done = downloader.next_chunk()
            print(f"Downloading gdrive file {drive_filename} to local file {local_file_path}: {int(status.progress() * 100)}%.") if GDriveUtils.LOG_EVENTS else None

        if os.path.dirname(local_file_path):
            os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
        with open(local_file_path, "wb") as f:
            f.write(file.getvalue())
        print(
            "Downloaded file locally to: ", local_file_path
        ) if GDriveUtils.LOG_EVENTS else None

    @staticmethod
    def download_file_from_gdrive_sanity_check(
        drive_parent_folder_id: str,
        creds_stringified: str | None = None,
    ):
        file_id = GDriveUtils.upload_file_to_gdrive_sanity_check(
            drive_parent_folder_id, creds_stringified
        )
        GDriveUtils.download_file_from_gdrive(
            file_id, creds_stringified=creds_stringified
        )

    @staticmethod
    def stringify_json_creds(json_file: str, txt_file: str) -> str:
        with open(json_file, "r") as f:
            creds_dict = json.load(f)
        with open(txt_file, "w") as f:
            f.write(json.dumps(creds_dict))


# if __name__ == "__main__":
#     creds_stringified = input("Enter stringified creds: ")
#     print(
#         GDriveUtils.upload_file_to_gdrive(
#             "validation_predictions_epoch_0.csv", "16QGpGwyIbA29BJa8uLMD1lMCrnOb1_Gj", creds_stringified=creds_stringified
#         )
#     )
#     print(
#         GDriveUtils.upload_file_to_gdrive(
#             "validation_predictions_epoch_1.csv", "16QGpGwyIbA29BJa8uLMD1lMCrnOb1_Gj", creds_stringified=creds_stringified
#         )
#     )
#     print(
#         GDriveUtils.upload_file_to_gdrive(
#             "validation_predictions_epoch_2.csv", "16QGpGwyIbA29BJa8uLMD1lMCrnOb1_Gj", creds_stringified=creds_stringified
#         )
#     )
#     print(
#         GDriveUtils.upload_file_to_gdrive(
#             "validation_predictions_epoch_3.csv", "16QGpGwyIbA29BJa8uLMD1lMCrnOb1_Gj", creds_stringified=creds_stringified
#         )
#     )
#     print(
#         GDriveUtils.upload_file_to_gdrive(
#             "logs/lightning_logs/version_0/metrics.csv", "16QGpGwyIbA29BJa8uLMD1lMCrnOb1_Gj", creds_stringified=creds_stringified
#         )
#     )
    