[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dbamman/anlp24/blob/main/5.classification/BertClassification_TODO.ipynb)

This notebook explores using BERT for text classification. Before starting, change the runtime to GPU: Runtime > Change runtime type > Hardware accelerator: GPU.

First, let's download some classification data (feel free to use other data we've worked with this semester).

In [None]:
!wget https://raw.githubusercontent.com/dbamman/anlp24/main/data/convote/train.tsv
!wget https://raw.githubusercontent.com/dbamman/anlp24/main/data/convote/dev.tsv

In [None]:
from transformers import BertModel, BertTokenizer
import torch
from tqdm import tqdm
import torch.nn as nn
import numpy as np
import random
import time

Double-check that this notebook is running on the GPU (this should print "Running on cuda").

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on {}".format(device))

In [None]:
def read_labels(filename):
    labels={}
    with open(filename) as file:
        for line in file:
            cols = line.split("\t")
            label = cols[0]
            if label not in labels:
                labels[label]=len(labels)
    return labels
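`read_labels` assigns integer indices in the order labels are first seen in the file, so the mapping depends on line order. Below is a minimal sketch of the same logic, run on in-memory lines instead of a file. The helper name `label_index_from_lines` and the toy labels `A`/`B` are made up for illustration and are not the real convote labels.

```python
def label_index_from_lines(lines):
    """Map each label (first tab-separated column) to an index, in first-seen order."""
    labels = {}
    for line in lines:
        label = line.split("\t")[0]
        if label not in labels:
            labels[label] = len(labels)
    return labels


# Hypothetical toy lines in the "label<TAB>text" layout that read_labels expects
toy_lines = ["A\tI support this bill.\n", "B\tI oppose it.\n", "A\tI yield back.\n"]
print(label_index_from_lines(toy_lines))  # {'A': 0, 'B': 1}
```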

In [None]:
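def read_data(filename, labels, max_data_points=None):
    """
    :param filename: the name of a TSV file with one "label<TAB>text" example per line
    :param labels: dict mapping each label string to its integer index (from read_labels)
    :param max_data_points: if not None, keep only this many examples after shuffling
    :return: a tuple of texts and a parallel tuple of integer label indices
    """
    data = []
    data_labels = []
    with open(filename) as file:
        for line in file:
            # Strip the trailing newline so it does not end up in the text
            cols = line.rstrip("\n").split("\t")
            label = cols[0]
            text = cols[1]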

            data.append(text)
            data_labels.append(labels[label])


    # shuffle the data
    tmp = list(zip(data, data_labels))
    random.shuffle(tmp)
    data, data_labels = zip(*tmp)

    if max_data_points is None:
        return data, data_labels

    return data[:max_data_points], data_labels[:max_data_points]
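`read_data` zips texts with their labels before shuffling so each text stays paired with its label, then truncates. Because the notebook never seeds the random number generator, each run draws a different subset of 1,000 examples. Below is a minimal sketch of the same shuffle-then-truncate step using a seeded, local `random.Random` instance. The helper name `shuffle_and_truncate` and the `seed` parameter are assumptions, not part of the notebook.

```python
import random


def shuffle_and_truncate(data, data_labels, max_data_points=None, seed=None):
    """Shuffle (text, label) pairs together, then keep at most max_data_points of them."""
    rng = random.Random(seed)  # a local generator; leaves the global random state untouched
    pairs = list(zip(data, data_labels))
    rng.shuffle(pairs)
    if max_data_points is not None:
        pairs = pairs[:max_data_points]
    # Unzip back into two parallel tuples (handle the empty case explicitly)
    texts, labels = zip(*pairs) if pairs else ((), ())
    return texts, labels


texts = ["t0", "t1", "t2", "t3", "t4"]
labels = [0, 1, 2, 3, 4]
x, y = shuffle_and_truncate(texts, labels, max_data_points=3, seed=13)
print(x, y)
```

With a fixed seed, repeated runs select the same examples, which makes comparisons between models easier.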

In [None]:
labels=read_labels("train.tsv")
print(labels)

We'll limit the training and dev data to 1,000 data points for this exercise.

In [None]:
train_x, train_y=read_data("train.tsv", labels, max_data_points=1000)

In [None]:
dev_x, dev_y=read_data("dev.tsv", labels, max_data_points=1000)

In [None]:
# Calculates the accuracy of the model on pre-batched data (here, the dev set)
def evaluate(model, all_x, all_y):
    model.eval()
    corr = 0.
    total = 0.
    with torch.no_grad():
        for x, y in zip(all_x, all_y):
            y_preds = model.forward(x)
            # Predicted label = index of the highest-scoring class for each example
            predictions = torch.argmax(y_preds, dim=1)
            corr += (predictions == y).sum().item()
            total += len(y)
    return corr/total
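`evaluate` takes the argmax of each row of logits as the predicted label and counts how often it matches the gold label. Below is the same computation in plain Python on made-up logits, as a sketch independent of PyTorch. The function name `accuracy_from_logits` is hypothetical.

```python
def accuracy_from_logits(logit_batches, label_batches):
    """Fraction of examples whose highest-scoring class equals the gold label."""
    correct = 0
    total = 0
    for logits, gold in zip(logit_batches, label_batches):
        for row, y in zip(logits, gold):
            # argmax over classes; with max() here, ties go to the lowest index
            prediction = max(range(len(row)), key=lambda k: row[k])
            correct += int(prediction == y)
            total += 1
    return correct / total


# Two made-up batches of logits over 2 classes, with their gold labels
logit_batches = [[[2.0, -1.0], [0.1, 0.3]], [[-0.5, 0.5]]]
label_batches = [[0, 0], [1]]
print(accuracy_from_logits(logit_batches, label_batches))  # 0.6666666666666666
```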

In [None]:
class BERTClassifier(nn.Module):
    """
    BERTClassifier is a PyTorch module for text classification using a pre-trained BERT model.
    Attributes:
        model_name (str): The name of the pre-trained BERT model to use.
        tokenizer (BertTokenizer): The tokenizer associated with the pre-trained BERT model.
        bert (BertModel): The pre-trained BERT model.
        num_labels (int): The number of labels for classification.
        fc (nn.Linear): A fully connected layer for classification.
    Methods:
        get_batches(all_x, all_y, batch_size=32, max_toks=256):
            Generates batches of tokenized input data and corresponding labels.
        forward(batch_x):
            Performs a forward pass through the BERT model and the fully connected layer.
    """

    def __init__(self, params):
        super().__init__()

        # Initialize model name and tokenizer
        self.model_name = params["model_name"]
        self.tokenizer = BertTokenizer.from_pretrained(self.model_name, do_lower_case=params["doLowerCase"], do_basic_tokenize=False)
        
        # Load pre-trained BERT model
        self.bert = BertModel.from_pretrained(self.model_name)

        # Number of labels for classification
        self.num_labels = params["label_length"]

        # Fully connected layer for classification
        self.fc = nn.Linear(params["embedding_size"], self.num_labels)

    def get_batches(self, all_x, all_y, batch_size=32, max_toks=256):
        """ Get batches for input x, y data, with data tokenized according to the BERT tokenizer
        (and limited to a maximum number of WordPiece tokens) """

        batches_x = []
        batches_y = []

        # Iterate over data in batches
        for i in range(0, len(all_x), batch_size):

            # Get current batch of data
            x = all_x[i:i+batch_size]

            # Tokenize the batch of data
            batch_x = self.tokenizer(x, padding=True, truncation=True, return_tensors="pt", max_length=max_toks)
            batch_y = all_y[i:i+batch_size]

            # Append tokenized data and labels to batches
            batches_x.append(batch_x.to(device))
            batches_y.append(torch.LongTensor(batch_y).to(device))

        return batches_x, batches_y
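Calling the tokenizer with `padding=True, truncation=True, max_length=max_toks` gives every sequence in a batch the same length. Sequences longer than `max_toks` are cut, shorter ones are padded up to the longest sequence in that batch, and `attention_mask` marks real tokens with 1 and padding with 0. Below is a simplified, pure-Python sketch of that behavior on already-tokenized ids. It assumes a pad id of 0 and, unlike the real tokenizer, ignores the special `[CLS]`/`[SEP]` tokens that truncation preserves. Both helper names are hypothetical.

```python
def pad_and_truncate(batch_ids, max_toks, pad_id=0):
    """Truncate each id list to max_toks, then right-pad to the batch's longest length."""
    truncated = [ids[:max_toks] for ids in batch_ids]
    longest = max(len(ids) for ids in truncated)
    input_ids = [ids + [pad_id] * (longest - len(ids)) for ids in truncated]
    attention_mask = [[1] * len(ids) + [0] * (longest - len(ids)) for ids in truncated]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


def get_batches_sketch(all_ids, all_y, batch_size=2, max_toks=4):
    """Slice the data into consecutive batches, as get_batches does with range(0, n, batch_size)."""
    batches_x, batches_y = [], []
    for i in range(0, len(all_ids), batch_size):
        batches_x.append(pad_and_truncate(all_ids[i:i + batch_size], max_toks))
        batches_y.append(list(all_y[i:i + batch_size]))
    return batches_x, batches_y


# Hypothetical token-id sequences of different lengths
ids = [[5, 6, 7], [8], [9, 10, 11, 12, 13]]
bx, by = get_batches_sketch(ids, [0, 1, 0])
print(bx[0])  # {'input_ids': [[5, 6, 7], [8, 0, 0]], 'attention_mask': [[1, 1, 1], [1, 0, 0]]}
print(bx[1])  # {'input_ids': [[9, 10, 11, 12]], 'attention_mask': [[1, 1, 1, 1]]}
```

As in `get_batches`, the last batch is smaller than `batch_size` when the data size is not a multiple of it.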

    def forward(self, batch_x):
        # Forward pass through BERT model
        bert_output = self.bert(input_ids=batch_x["input_ids"],
                                attention_mask=batch_x["attention_mask"],
                                token_type_ids=batch_x["token_type_ids"],
                                output_hidden_states=True)

        # Get hidden states from BERT output
        bert_hidden_states = bert_output['hidden_states']

        # Represent the document by its [CLS] embedding (at position 0)
        out = bert_hidden_states[-1][:,0,:]

        # Pass through fully connected layer for classification
        out = self.fc(out)

        return out  # Return the output
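`forward` keeps only the final-layer vector at position 0 (the `[CLS]` token) for each document and maps it to one score per label with `self.fc`, i.e. `logits = W h_cls + b`. Below is a pure-Python sketch of those two steps on tiny made-up tensors, written as nested lists of shape batch × sequence × hidden. The weights are arbitrary, not learned.

```python
def cls_logits(last_hidden_state, weight, bias):
    """Take each sequence's position-0 vector and apply a linear layer: W @ h + b."""
    logits = []
    for sequence in last_hidden_state:
        h_cls = sequence[0]  # [CLS] embedding, like bert_hidden_states[-1][:, 0, :]
        row = [sum(w * h for w, h in zip(w_row, h_cls)) + b for w_row, b in zip(weight, bias)]
        logits.append(row)
    return logits


# Batch of 2 documents, 3 tokens each, hidden size 2 (all values made up)
hidden = [
    [[1.0, 2.0], [9.0, 9.0], [9.0, 9.0]],
    [[0.0, -1.0], [9.0, 9.0], [9.0, 9.0]],
]
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 labels x hidden size 2, like nn.Linear(2, 3).weight
bias = [0.0, 0.5, -1.0]
print(cls_logits(hidden, weight, bias))  # [[1.0, 2.5, 2.0], [0.0, -0.5, -2.0]]
```

The non-`[CLS]` positions (filled with 9.0 here) have no effect on the output, which is why this representation depends entirely on what BERT encodes at position 0.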

Now let's train BERT on this data. A few practical notes for this environment: if you encounter an out-of-memory error:

* Reset the notebook (Runtime > Factory reset runtime) and execute all cells from the beginning.
* If `max_toks` (passed to the tokenizer as `max_length`) is high, try reducing the `batch_size` in `get_batches` above.

Even on a GPU, BERT can take a long time to train, so you might first experiment with a smaller `max_data_points` above before running it on the full training data.

In [None]:
def train_and_evaluate(bert_model_name, model_filename, train_x, train_y, dev_x, dev_y, labels, embedding_size=768, doLowerCase=None):
    """
    Trains and evaluates a BERT model for classification.

    Args:
        bert_model_name (str): The name of the pre-trained BERT model to use.
        model_filename (str): The filename to save the best model.
        train_x (list): Training data features.
        train_y (list): Training data labels.
        dev_x (list): Development/validation data features.
        dev_y (list): Development/validation data labels.
        labels (dict): Mapping from each unique label in the dataset to its integer index.
        embedding_size (int, optional): Size of the BERT embeddings. Default is 768.
        doLowerCase (bool, optional): Whether to lowercase the input text. Default is None.

    Returns:
        None
    """
    # Record the start time
    start_time = time.time()

    # Initialize the BERT classifier with the given parameters
    bert_model = BERTClassifier(params={"doLowerCase": doLowerCase, "model_name": bert_model_name, "embedding_size": embedding_size, "label_length": len(labels)})
    
    # Move the model to the appropriate device (GPU or CPU)
    bert_model.to(device)

    # Get batches of training and development data
    batch_x, batch_y = bert_model.get_batches(train_x, train_y)
    dev_batch_x, dev_batch_y = bert_model.get_batches(dev_x, dev_y)

    # Initialize the optimizer and loss function
    optimizer = torch.optim.Adam(bert_model.parameters(), lr=1e-5)
    cross_entropy = nn.CrossEntropyLoss()

    # Set the number of epochs and initialize the best development accuracy
    num_epochs = 5
    best_dev_acc = 0.

    # Training loop
    for epoch in range(num_epochs):

        # Set PyTorch model to training mode (activates things like dropout and batch normalization)
        bert_model.train()

        # Train on each batch
        for x, y in tqdm(list(zip(batch_x, batch_y))):
            y_pred = bert_model.forward(x)
            loss = cross_entropy(y_pred.view(-1, bert_model.num_labels), y.view(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Evaluate the model on the development set
        dev_accuracy = evaluate(bert_model, dev_batch_x, dev_batch_y)
        print("Epoch %s, dev accuracy: %.3f" % (epoch, dev_accuracy))

        # Save a checkpoint whenever dev accuracy improves on the best so far
        if dev_accuracy > best_dev_acc:
            torch.save(bert_model.state_dict(), model_filename)
            best_dev_acc = dev_accuracy

    # Load the best model and print the final results
    bert_model.load_state_dict(torch.load(model_filename))
    print("\nBest performing model achieves dev accuracy of: %.3f" % (best_dev_acc))
    print("Time: %.3f seconds ---" % (time.time() - start_time))
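`nn.CrossEntropyLoss` applies a log-softmax to the raw scores returned by `forward` and, with its default `reduction='mean'`, averages the negative log-probability of the gold label over the batch. Below is a pure-Python sketch of that formula, meant only for building intuition. The training loop should keep using the PyTorch loss.

```python
import math


def cross_entropy(logits, gold):
    """Mean over the batch of -log softmax(logits)[gold], computed stably."""
    losses = []
    for row, y in zip(logits, gold):
        m = max(row)  # subtract the max before exponentiating to avoid overflow
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        losses.append(log_sum_exp - row[y])
    return sum(losses) / len(losses)


# With equal scores over 2 classes, the loss is log(2) whatever the gold label is
print(cross_entropy([[0.0, 0.0]], [1]))  # 0.6931471805599453
```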


In [None]:
train_and_evaluate("bert-base-cased", "convote-bert-base-cased", train_x, train_y, dev_x, dev_y, labels, embedding_size=768, doLowerCase=False)

As you can see, training `bert-base` can be expensive.  Google has released a number of [smaller BERT models](https://github.com/google-research/bert) with fewer layers (2, 4, 6, 8, 10) and smaller dimensions (128, 256, 512) that effectively trade off accuracy for speed.  Select a few of these models and train them.  To use these models in the huggingface library that we have been using, the huggingface name of the model can be derived from the URL linking to it:

https://storage.googleapis.com/bert_models/2020_02_20/uncased_L-2_H-128_A-2.zip -> `google/bert_uncased_L-2_H-128_A-2`
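In the example above, the Hugging Face name is derived from the download URL by taking the file name, dropping `.zip`, and adding the prefix `google/bert_`. Below is a small sketch of that rule. The helper name `hf_name_from_url` is hypothetical, and the rule is inferred only from the single example shown.

```python
import posixpath
from urllib.parse import urlparse


def hf_name_from_url(url):
    """'.../uncased_L-2_H-128_A-2.zip' -> 'google/bert_uncased_L-2_H-128_A-2'."""
    filename = posixpath.basename(urlparse(url).path)
    stem = filename[:-len(".zip")] if filename.endswith(".zip") else filename
    return "google/bert_" + stem


url = "https://storage.googleapis.com/bert_models/2020_02_20/uncased_L-2_H-128_A-2.zip"
print(hf_name_from_url(url))  # google/bert_uncased_L-2_H-128_A-2
```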

All of these smaller models are uncased (so all text is lowercased), so be sure to set `doLowerCase=True`. You'll also need to change the `embedding_size` parameter of this function to match the H value of the model (listed both on the BERT GitHub page and in the model's URL). One sample model is provided below.
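Because `embedding_size` must equal the model's hidden size H, it can be read off the model name instead of being typed by hand. Below is a minimal sketch, assuming names follow the `..._L-<layers>_H-<hidden>_A-<heads>` pattern in the URL above. The helper `parse_bert_dims` is hypothetical.

```python
import re


def parse_bert_dims(model_name):
    """Return (layers, hidden_size, attention_heads) parsed from a name like '..._L-2_H-128_A-2'."""
    match = re.search(r"L-(\d+)_H-(\d+)_A-(\d+)", model_name)
    if match is None:
        raise ValueError("no L-/H-/A- dimensions found in %r" % model_name)
    layers, hidden, heads = (int(g) for g in match.groups())
    return layers, hidden, heads


layers, hidden, heads = parse_bert_dims("google/bert_uncased_L-2_H-128_A-2")
print(layers, hidden, heads)  # 2 128 2
```

The parsed `hidden` value is what would be passed as `embedding_size` to `train_and_evaluate`.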

In [None]:
train_and_evaluate("google/bert_uncased_L-2_H-128_A-2", "convote-uncased_L-2_H-128_A-2", train_x, train_y, dev_x, dev_y, labels, embedding_size=128, doLowerCase=True)