In [1]:
import os
import torch
import random
import requests
import pandas as pd
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from transformers import BertTokenizer
import os
from six.moves.urllib.request import urlretrieve
from sklearn import preprocessing
import matplotlib.pyplot as plt
plt.style.use('ggplot')
def seed_all(seed=1029):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
seed_all(seed=1234)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def seed_all(seed=1029):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
seed_all(seed=1234)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
class DataManager:
    """
    This class manages and preprocesses a simple text dataset for a sentence classification task.

    Attributes:
        verbose (bool): Controls verbosity for printing information during data processing.
        max_sentence_len (int): The maximum length of a sentence in the dataset.
        str_questions (list): A list to store the string representations of the questions in the dataset.
        str_labels (list): A list to store the string representations of the labels in the dataset.
        numeral_labels (list): A list to store the numerical representations of the labels in the dataset.
        maxlen (int): Maximum length for padding sequences. Sequences longer than this length will be truncated,
            and sequences shorter than this length will be padded with zeros. Defaults to 50.
        numeral_data (list): A list to store the numerical representations of the questions in the dataset.
        random_state (int): Seed value for random number generation to ensure reproducibility.
            Set this value to a specific integer to reproduce the same random sequence every time. Defaults to 6789.
        random (np.random.RandomState): Random number generator object initialized with the given random_state.
            It is used for various random operations in the class.

    Methods:
        maybe_download(dir_name, file_name, url, verbose=True):
            Downloads a file from a given URL if it does not exist in the specified directory.
            The directory and file are created if they do not exist.

        read_data(dir_name, file_names):
            Reads data from files in a directory, preprocesses it, and computes the maximum sentence length.
            Each file is expected to contain rows in the format "<label>:<question>".
            The labels and questions are stored as string representations.

        manipulate_data():
            Performs data manipulation by tokenizing, numericalizing, and padding the text data.
            The questions are tokenized and converted into numerical sequences using a tokenizer.
            The sequences are padded or truncated to the maximum sequence length.

        train_valid_test_split(train_ratio=0.9):
            Splits the data into training, validation, and test sets based on a given ratio.
            The data is randomly shuffled, and the specified ratio is used to determine the size of the training set.
            The string questions, numerical data, and numerical labels are split accordingly.
            TensorFlow `Dataset` objects are created for the training and validation sets.


    """

    def __init__(self, verbose=True, random_state=6789):
        self.verbose = verbose
        self.max_sentence_len = 0
        self.str_questions = list()
        self.str_labels = list()
        self.numeral_labels = list()
        self.numeral_data = list()
        self.random_state = random_state
        self.random = np.random.RandomState(random_state)

    @staticmethod
    def maybe_download(dir_name, file_name, url, verbose=True):
        if not os.path.exists(dir_name):
            os.mkdir(dir_name)
        if not os.path.exists(os.path.join(dir_name, file_name)):
            urlretrieve(url + file_name, os.path.join(dir_name, file_name))
        if verbose:
            print("Downloaded successfully {}".format(file_name))

    def read_data(self, dir_name, file_names):
        self.str_questions = list()
        self.str_labels = list()
        for file_name in file_names:
            file_path= os.path.join(dir_name, file_name)
            with open(file_path, "r", encoding="latin-1") as f:
                for row in f:
                    row_str = row.split(":")
                    label, question = row_str[0], row_str[1]
                    question = question.lower()
                    self.str_labels.append(label)
                    self.str_questions.append(question[0:-1])
                    if self.max_sentence_len < len(self.str_questions[-1]):
                        self.max_sentence_len = len(self.str_questions[-1])

        # turns labels into numbers
        le = preprocessing.LabelEncoder()
        le.fit(self.str_labels)
        self.numeral_labels = np.array(le.transform(self.str_labels))
        self.str_classes = le.classes_
        self.num_classes = len(self.str_classes)
        if self.verbose:
            print("\nSample questions and corresponding labels... \n")
            print(self.str_questions[0:5])
            print(self.str_labels[0:5])

    def manipulate_data(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        vocab = self.tokenizer.get_vocab()
        self.word2idx = {w: i for i, w in enumerate(vocab)}
        self.idx2word = {i:w for w,i in self.word2idx.items()}
        self.vocab_size = len(self.word2idx)

        token_ids = []
        num_seqs = []
        for text in self.str_questions:  # iterate over the list of text
          text_seqs = self.tokenizer.tokenize(str(text))  # tokenize each text individually
          # Convert tokens to IDs
          token_ids = self.tokenizer.convert_tokens_to_ids(text_seqs)
          # Convert token IDs to a tensor of indices using your word2idx mapping
          seq_tensor = torch.LongTensor(token_ids)
          num_seqs.append(seq_tensor)  # append the tensor for each sequence

        # Pad the sequences and create a tensor
        if num_seqs:
          self.numeral_data = pad_sequence(num_seqs, batch_first=True)  # Pads to max length of the sequences
          self.num_sentences, self.max_seq_len = self.numeral_data.shape

    def train_valid_test_split(self, train_ratio=0.8, test_ratio = 0.1):
        train_size = int(self.num_sentences*train_ratio) +1
        test_size = int(self.num_sentences*test_ratio) +1
        valid_size = self.num_sentences - (train_size + test_size)
        data_indices = list(range(self.num_sentences))
        random.shuffle(data_indices)
        self.train_str_questions = [self.str_questions[i] for i in data_indices[:train_size]]
        self.train_numeral_labels = self.numeral_labels[data_indices[:train_size]]
        train_set_data = self.numeral_data[data_indices[:train_size]]
        train_set_labels = self.numeral_labels[data_indices[:train_size]]
        train_set_labels = torch.from_numpy(train_set_labels)
        train_set = torch.utils.data.TensorDataset(train_set_data, train_set_labels)
        self.test_str_questions = [self.str_questions[i] for i in data_indices[-test_size:]]
        self.test_numeral_labels = self.numeral_labels[data_indices[-test_size:]]
        test_set_data = self.numeral_data[data_indices[-test_size:]]
        test_set_labels = self.numeral_labels[data_indices[-test_size:]]
        test_set_labels = torch.from_numpy(test_set_labels)
        test_set = torch.utils.data.TensorDataset(test_set_data, test_set_labels)
        self.valid_str_questions = [self.str_questions[i] for i in data_indices[train_size:-test_size]]
        self.valid_numeral_labels = self.numeral_labels[data_indices[train_size:-test_size]]
        valid_set_data = self.numeral_data[data_indices[train_size:-test_size]]
        valid_set_labels = self.numeral_labels[data_indices[train_size:-test_size]]
        valid_set_labels = torch.from_numpy(valid_set_labels)
        valid_set = torch.utils.data.TensorDataset(valid_set_data, valid_set_labels)
        self.train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
        self.test_loader = DataLoader(test_set, batch_size=64, shuffle=False)
        self.valid_loader = DataLoader(valid_set, batch_size=64, shuffle=False)

In [None]:
print('Loading data...')
DataManager.maybe_download("data", "train_2000.label", "http://cogcomp.org/Data/QA/QC/")

dm = DataManager()
dm.read_data("data/", ["train_2000.label"])
dm.manipulate_data()
dm.train_valid_test_split(train_ratio=0.8, test_ratio = 0.1)
for x, y in dm.train_loader:
    print(x.shape, y.shape)
    break

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class BaseTrainer:
    def __init__(self, model, criterion, optimizer, train_loader, val_loader):
        self.model = model
        self.criterion = criterion  #the loss function
        self.optimizer = optimizer  #the optimizer
        self.train_loader = train_loader  #the train loader
        self.val_loader = val_loader  #the valid loader

    #the function to train the model in many epochs
    def fit(self, num_epochs):
        self.num_batches = len(self.train_loader)

        for epoch in range(num_epochs):
            print(f'Epoch {epoch + 1}/{num_epochs}')
            train_loss, train_accuracy = self.train_one_epoch()
            val_loss, val_accuracy = self.validate_one_epoch()
            print(
                f'{self.num_batches}/{self.num_batches} - train_loss: {train_loss:.4f} - train_accuracy: {train_accuracy*100:.4f}% \
                - val_loss: {val_loss:.4f} - val_accuracy: {val_accuracy*100:.4f}%')

    #train in one epoch, return the train_acc, train_loss
    def train_one_epoch(self):
        self.model.train()
        running_loss, correct, total = 0.0, 0, 0
        for i, data in enumerate(self.train_loader):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device).long()  # Cast labels to LongTensor
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        train_accuracy = correct / total
        train_loss = running_loss / self.num_batches
        return train_loss, train_accuracy


    #evaluate on a loader and return the loss and accuracy
    def evaluate(self, loader):
        self.model.eval()
        loss, correct, total = 0.0, 0, 0
        with torch.no_grad():
            for data in loader:
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device).long()  # Cast labels to LongTensor
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        accuracy = correct / total
        loss = loss / len(self.val_loader)
        return loss, accuracy

    #return the val_acc, val_loss, be called at the end of each epoch
    def validate_one_epoch(self):
      val_loss, val_accuracy = self.evaluate(self.val_loader)
      return val_loss, val_accuracy

In [None]:
from transformers import AutoModel, AutoTokenizer, AdamW
from datasets import Dataset

model_name = "bert-base-uncased"  # BERT or any similar model

# Tokenize input and prepare model inputs
tokenizer = AutoTokenizer.from_pretrained(model_name)

dataset = Dataset.from_dict({"text": dm.str_questions, "label": dm.numeral_labels})

# Tokenize the dataset
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length= 36)

dataset = dataset.map(tokenize_function, batched=True)
dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
print(dataset)

Map: 100%|██████████| 2000/2000 [00:00<00:00, 9154.52 examples/s]

Dataset({
    features: ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 2000
})





The following function splits the BERT dataset `dataset` into three BERT datasets for training, valid, and testing.

In [None]:
def train_valid_test_split(dataset, train_ratio=0.8, test_ratio = 0.1):
    num_sentences = len(dataset)
    train_size = int(num_sentences*train_ratio) +1
    test_size = int(num_sentences*test_ratio) +1
    valid_size = num_sentences - (train_size + test_size)
    train_set = dataset[:train_size]
    train_set = Dataset.from_dict(train_set)
    train_set.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
    test_set = dataset[-test_size:]
    test_set = Dataset.from_dict(test_set)
    test_set.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
    valid_set = dataset[train_size:-test_size]
    valid_set = Dataset.from_dict(valid_set)
    valid_set.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
    train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=64, shuffle=False)
    valid_loader = DataLoader(valid_set, batch_size=64, shuffle=False)
    return train_loader, test_loader, valid_loader


In [None]:
train_loader, test_loader, valid_loader = train_valid_test_split(dataset)

You need to implement the class `PrefixTuningForClassification` for the prefix prompt fine-tuning. We first load a pre-trained BERT model specified by `model_name`. The parameter `prefix_length` specifies the length of the prefix prompts we add to the pre-trained BERT model. Specifically, given the input batch `[batch_size, seq_len]`, we input to the embedding layer of the pre-trained BERT model to obtain `[batch_size, seq_len, embed_size]`. We create the prefix prompts $P$ of the size `[prefix_length, embed_size]` and concatenate to the embeddings from the pre-trained BERT to obtain `[batch_size, seq_len + prefix_length, embed_size]`. This concatenation tensor will then be fed to the encoder layers of the pre-trained BERT layer to obtain the last `[batch_size, seq_len + prefix_length, embed_size]`.

We then take mean across the seq_len to obtain `[batch_size, embed_size]` on which we can build up a linear layer for making predictions. Please note that **the parameters to tune include the prefix prompts $P$** and **the output linear layer**, and you should freeze the parameters of the BERT pre-trained model. Moreover, your code should cover the edge case when `prefix_length=None`. In this case, we do not insert any prefix prompts and we only do fine-tuning for the output linear layer on top.  

In [None]:
class PrefixTuningForClassification(nn.Module):
    def __init__(self, model_name, prefix_length=None, data_manager=None):
        super(PrefixTuningForClassification, self).__init__()

        # Load the pretrained transformer model (BERT-like model)
        self.model = AutoModel.from_pretrained(model_name)
        self.hidden_size = self.model.config.hidden_size
        self.prefix_length = prefix_length
        self.num_classes = data_manager.num_classes

        # Freeze BERT parameters (only fine-tuning the prompt and classifier layers)
        for param in self.model.parameters():
            param.requires_grad = False

        # If prefix_length is specified, create prefix embeddings
        if self.prefix_length is not None:
            self.prefix_embeddings = nn.Parameter(
                torch.randn(self.prefix_length, self.hidden_size)
            )
        
        # Linear layer for classification
        self.classifier = nn.Linear(self.hidden_size, self.num_classes)

    def forward(self, input_ids, attention_mask):
        # Get the embeddings from BERT
        embedding_output = self.model.embeddings(input_ids)

        # If prefix_length is specified, prepend the prefix to the embedding
        if self.prefix_length is not None:
            batch_size = embedding_output.size(0)
            # Repeat the prefix for each sample in the batch
            prefix = self.prefix_embeddings.unsqueeze(0).expand(batch_size, -1, -1)
            # Concatenate prefix with embeddings along sequence length dimension
            embedding_output = torch.cat((prefix, embedding_output), dim=1)

            # Extend the attention mask for the prefix tokens (with value 1, since they should not be masked)
            prefix_mask = torch.ones((batch_size, self.prefix_length), dtype=attention_mask.dtype).to(device)
            attention_mask = torch.cat((prefix_mask, attention_mask), dim=1)

        # Cast attention_mask to float (if not already) and expand dimensions
        attention_mask = attention_mask.to(dtype=torch.float)
        attention_mask = attention_mask[:, None, None, :]  # Expand mask for multi-head attention

        # Pass through the BERT encoder
        encoder_outputs = self.model.encoder(
            embedding_output,
            attention_mask=attention_mask
        )
        
        # Get the last hidden state from the encoder
        sequence_output = encoder_outputs.last_hidden_state  # [batch_size, seq_len + prefix_length, hidden_size]

        # Take the mean across sequence length (global average pooling)
        pooled_output = torch.mean(sequence_output, dim=1)  # [batch_size, hidden_size]

        # Pass through the classifier for prediction
        logits = self.classifier(pooled_output)  # [batch_size, num_classes]

        return logits

You can use the following `FineTunedBaseTrainer` to train the prompt fine-tuning models.

In [None]:
class FineTunedBaseTrainer:
    def __init__(self, model, criterion, optimizer, train_loader, val_loader):
        self.model = model
        self.criterion = criterion  #the loss function
        self.optimizer = optimizer  #the optimizer
        self.train_loader = train_loader  #the train loader
        self.val_loader = val_loader  #the valid loader

    #the function to train the model in many epochs
    def fit(self, num_epochs):
        self.num_batches = len(self.train_loader)

        for epoch in range(num_epochs):
            print(f'Epoch {epoch + 1}/{num_epochs}')
            train_loss, train_accuracy = self.train_one_epoch()
            val_loss, val_accuracy = self.validate_one_epoch()
            print(
                f'{self.num_batches}/{self.num_batches} - train_loss: {train_loss:.4f} - train_accuracy: {train_accuracy*100:.4f}% \
                - val_loss: {val_loss:.4f} - val_accuracy: {val_accuracy*100:.4f}%')

    #train in one epoch, return the train_acc, train_loss
    def train_one_epoch(self):
        self.model.train()
        running_loss, correct, total = 0.0, 0, 0
        for batch in self.train_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)
            self.optimizer.zero_grad()
            outputs = self.model(input_ids= input_ids, attention_mask= attention_mask)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        train_accuracy = correct / total
        train_loss = running_loss / self.num_batches
        return train_loss, train_accuracy

    #evaluate on a loader and return the loss and accuracy
    def evaluate(self, loader):
        self.model.eval()
        loss, correct, total = 0.0, 0, 0
        with torch.no_grad():
            for batch in loader:
                input_ids = batch["input_ids"].to(device)
                labels = batch["label"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                outputs = self.model(input_ids= input_ids, attention_mask= attention_mask)
                loss = self.criterion(outputs, labels)
                loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        accuracy = correct / total
        loss = loss / len(self.val_loader)
        return loss, accuracy

    #return the val_acc, val_loss, be called at the end of each epoch
    def validate_one_epoch(self):
      val_loss, val_accuracy = self.evaluate(self.val_loader)
      return val_loss, val_accuracy

We declare and train the prefix-prompt tuning model. In addition, you need to be patient with this model because it might converge slowly with many epochs.

In [None]:
prefix_tuning_model = PrefixTuningForClassification(model_name="bert-base-uncased", prefix_length=20, data_manager=dm).to(device)

In [None]:
# Optimizer: tune the classifier and prefix embeddings if prefix_length is provided
if prefix_tuning_model.prefix_length is not None:
    optimizer = torch.optim.Adam(
        list(prefix_tuning_model.classifier.parameters()) + [prefix_tuning_model.prefix_embeddings], lr=1e-3
    )
else:
    optimizer = torch.optim.Adam(prefix_tuning_model.classifier.parameters(), lr=1e-3)

# Loss function
criterion = nn.CrossEntropyLoss()

# Trainer setup
trainer = FineTunedBaseTrainer(model=prefix_tuning_model, criterion=criterion, optimizer=optimizer, train_loader=train_loader, val_loader=valid_loader)

# Train the model
trainer.fit(num_epochs=100)

Epoch 1/100
26/26 - train_loss: 1.6873 - train_accuracy: 25.4841%                 - val_loss: 0.8231 - val_accuracy: 28.7879%
Epoch 2/100
26/26 - train_loss: 1.5871 - train_accuracy: 38.6633%                 - val_loss: 0.7774 - val_accuracy: 32.3232%
Epoch 3/100
26/26 - train_loss: 1.6114 - train_accuracy: 46.4085%                 - val_loss: 0.8169 - val_accuracy: 32.8283%
Epoch 4/100
26/26 - train_loss: 1.4183 - train_accuracy: 48.2823%                 - val_loss: 0.7836 - val_accuracy: 53.0303%
Epoch 5/100
26/26 - train_loss: 1.3544 - train_accuracy: 51.3429%                 - val_loss: 0.8449 - val_accuracy: 55.0505%
Epoch 6/100
26/26 - train_loss: 1.2763 - train_accuracy: 56.9644%                 - val_loss: 0.7435 - val_accuracy: 60.1010%
Epoch 7/100
26/26 - train_loss: 1.1752 - train_accuracy: 59.7751%                 - val_loss: 0.7205 - val_accuracy: 60.6061%
Epoch 8/100
26/26 - train_loss: 1.1207 - train_accuracy: 62.2736%                 - val_loss: 0.6775 - val_accuracy: 6