In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from transformers import TrainingArguments, Trainer

In [5]:
# Dataset registry: Hugging Face hub path plus the text column(s) that are fed
# to the tokenizer. 'data' is ALWAYS a tuple of column names; note the trailing
# comma for single-column datasets -- ('sentence') without it is just a string.
datasets_dict = {
    'MRPC': {'path': "SetFit/mrpc", 'data': ('text1', 'text2')},
    'QQP': {'path': "SetFit/qqp", 'data': ('text1', 'text2')},
    'SST': {'path': "sst", 'data': ('sentence',)},
    'ANLI': {'path': "facebook/anli", 'data': ('premise', 'hypothesis')},
}

# Model registry: display name -> Hugging Face hub checkpoint id.
models_dict = {
    'BERT-Base': "google-bert/bert-base-uncased",
    'BERT-Large': "google-bert/bert-large-uncased",
    'RoBERTa-Base': "FacebookAI/roberta-base",
    'RoBERTa-Large': "FacebookAI/roberta-large",
}

In [7]:
class FineTune():
    """Fine-tune a sequence-classification checkpoint on a single dataset.

    Construction loads the dataset split, tokenises it and loads the model;
    call ``train`` afterwards to run the Hugging Face ``Trainer``.

    Attributes set by ``__init__``:
        dataset: the loaded split (with an extra ``'text'`` column for
            two-column tasks).
        data: name of the text column that was tokenised.
        paraphrase: True when the dataset supplies two text columns.
        tokenizer, tokenized_dataset, num_labels, model.
    """

    def __init__(self, model_path, dataset_dict):
        """Load data, tokenizer and model.

        Args:
            model_path: checkpoint id/path passed to ``from_pretrained``.
            dataset_dict: mapping with ``'path'`` (dataset id) and ``'data'``
                (a column name or a tuple of one or two column names). An
                optional ``'split'`` key selects the split; it defaults to
                ``'train'``.
        """
        # Load the dataset
        split = dataset_dict.get('split', 'train')
        print(f"Loading the dataset from {dataset_dict['path']}")
        self.dataset = load_dataset(dataset_dict['path'], split=split)

        # A bare string such as ('sentence') is one column, not a sequence of
        # characters -- normalise so len() below counts columns.
        columns = dataset_dict['data']
        if isinstance(columns, str):
            columns = (columns,)
        self.data = tuple(columns)
        # Task is paraphrasing if we have more than one data column
        self.paraphrase = len(self.data) > 1

        print('Generating tokens for the dataset')
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.tokenized_dataset = self.tokenize()
        # The number of classes is inferred from the distinct values in 'label'.
        self.num_labels = len(set(self.dataset['label']))

        print(f'Loading the model from {model_path}')
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels=self.num_labels)

    def tokenize(self):
        """Tokenise the text column and return the mapped dataset.

        For two-column tasks both sentences are first merged into a new
        ``'text'`` column (sentence one, a ``PARAPHRASE:`` separator line,
        sentence two). Afterwards ``self.data`` holds the name of the column
        that was tokenised, so calling this method again simply re-tokenises
        that column.

        Raises:
            ValueError: if a multi-column task does not provide exactly two
                columns.
        """
        if isinstance(self.data, str):
            # Column already resolved (e.g. by an earlier call).
            text_column = self.data
        elif self.paraphrase:
            if len(self.data) != 2:
                raise ValueError(
                    f"Expected exactly two text columns for a sentence-pair task, got {self.data!r}"
                )
            prompt = "{text1}\nPARAPHRASE:\n{text2}"
            col1, col2 = self.data
            text_data = [
                prompt.format(text1=sentence1, text2=sentence2)
                for sentence1, sentence2 in zip(self.dataset[col1], self.dataset[col2])
            ]
            self.dataset = self.dataset.add_column('text', text_data)
            text_column = 'text'
        else:
            text_column = self.data[0]
        self.data = text_column

        def tokenize_function(examples):
            return self.tokenizer(examples[text_column], padding="max_length", truncation=True)

        tokenized_dataset = self.dataset.map(tokenize_function, batched=True)

        return tokenized_dataset

    def train(self, num_epochs=6, output_dir="test_trainer"):
        """Run ``Trainer.train`` on the tokenised dataset.

        Args:
            num_epochs: passed to ``TrainingArguments`` as ``num_train_epochs``.
            output_dir: passed to ``TrainingArguments`` as ``output_dir``.
        """
        training_args = TrainingArguments(output_dir=output_dir, num_train_epochs=num_epochs)

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=self.tokenized_dataset
        )

        trainer.train()


In [12]:
from peft import LoraConfig, TaskType
from peft import get_peft_model

class PEFT(FineTune):
    """FineTune variant whose model is wrapped with a LoRA adapter.

    Everything (dataset loading, tokenisation, ``train``) is inherited from
    ``FineTune``; only ``self.model`` is replaced by the result of
    ``get_peft_model``.
    """

    def __init__(self, model_path, dataset_dict, lora_rank=4, lora_alpha=32):
        """Build the base model, then wrap it with ``get_peft_model``.

        Args:
            model_path: checkpoint id/path, forwarded to ``FineTune``.
            dataset_dict: dataset description, forwarded to ``FineTune``.
            lora_rank: forwarded to ``LoraConfig`` as ``r``.
            lora_alpha: forwarded to ``LoraConfig`` as ``lora_alpha``
                (default 32, the value previously hard-coded).
        """
        super().__init__(model_path, dataset_dict)
        self.rank = lora_rank
        self.config = LoraConfig(
            task_type=TaskType.SEQ_CLS,
            inference_mode=False,
            r=self.rank,
            lora_alpha=lora_alpha,
        )
        self.model = get_peft_model(self.model, self.config)

In [13]:
# Build the full fine-tuning setup for BERT-Base on MRPC. Constructing FineTune
# loads the dataset, tokenises it and loads the checkpoint (see the log below).
# Note: `model` is the FineTune wrapper; the underlying network is `model.model`.
model = FineTune(models_dict['BERT-Base'], datasets_dict['MRPC'])

Loading the dataset from SetFit/mrpc




Generating tokens for the dataset
Loading the model from google-bert/bert-base-uncased


Some weights of the model checkpoint at google-bert/bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base