In [53]:
from transformers import AutoTokenizer, DistilBertForSequenceClassification
import copy
import torch.nn as nn
import torch

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
bert = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=6)

class Bert(nn.Module):
    def __init__(self):
        super(Bert, self).__init__()
        self.bert = copy.deepcopy(bert)
        
    def forward(self, input_ids, attention_mask=None, labels=None):
        return self.bert(input_ids, attention_mask=attention_mask, labels=labels)
                
    def print_grad_list(self): 
        grad_list = [] 
        for p in self.bert.parameters():
            grad_list.append(p.requires_grad)
        print(grad_list)
        
    def count_trainable_parameters(self):
        return sum(p.numel() for p in self.bert.parameters() if p.requires_grad)

class LoRALayer(nn.Module):    
    def __init__(self, layer, r = 2, alpha = 1.0):
        super(LoRALayer, self).__init__()
        self.layer = layer
        self.r = r
        self.alpha = alpha
        std_dev = 1 / torch.sqrt(torch.tensor(r).float())
        self.A0 = nn.Parameter(torch.zeros(layer.in_features, self.r) * std_dev)
        self.B0 = nn.Parameter(torch.zeros(r, layer.out_features))
        self.A1 = nn.Parameter(torch.randn(layer.in_features, self.r) * std_dev)
        self.B1 = nn.Parameter(torch.zeros(r, layer.out_features))
        self.A2 = nn.Parameter(torch.randn(layer.in_features, self.r) * std_dev)
        self.B2 = nn.Parameter(torch.zeros(r, layer.out_features))
        self.A = self.A1
        self.B = self.B1
        self.A0.requires_grad = False
        self.B0.requires_grad = False
        self.A1.requires_grad = True
        self.B1.requires_grad = True
        self.A2.requires_grad = False
        self.B2.requires_grad = False
        
    def switch_task(self, task_id):
        if task_id == 0:
            self.A = self.A0
            self.B = self.B0
        elif task_id == 1:
            self.A = self.A1
            self.B = self.B1
            self.A1.requires_grad = True
            self.B1.requires_grad = True
            self.A2.requires_grad = False
            self.B2.requires_grad = False
        elif task_id == 2:  # Corrected to task_id == 2
            self.A = self.A2
            self.B = self.B2
            self.A1.requires_grad = False
            self.B1.requires_grad = False
            self.A2.requires_grad = True
            self.B2.requires_grad = True

    def forward(self, x):
        result = self.layer(x) + self.alpha * (x @ self.A @ self.B)
        return result

class BertWithSwitchableTask(nn.Module):
    def __init__(self):
        super(BertWithSwitchableTask, self).__init__()
        self.bert = copy.deepcopy(bert)
        self.lora_layers = []
        self.add_lora_layers()

    def switch_task(self, task_id):
        for layer in self.lora_layers:
            layer.switch_task(task_id)

    def forward(self, input_ids, attention_mask=None, labels=None):
        result = self.bert(input_ids, attention_mask=attention_mask, labels=labels)
        return result

    def add_lora_layers(self):
        for param in self.bert.parameters():
            param.requires_grad = False
            
        for layer in self.bert.distilbert.transformer.layer:
                layer.attention.q_lin = LoRALayer(layer.attention.q_lin)
                layer.attention.k_lin = LoRALayer(layer.attention.k_lin)
                layer.attention.v_lin = LoRALayer(layer.attention.v_lin)
                layer.attention.out_lin = LoRALayer(layer.attention.out_lin)
                layer.ffn.lin1 = LoRALayer(layer.ffn.lin1)
                layer.ffn.lin2 = LoRALayer(layer.ffn.lin2)
                self.lora_layers += [
                    layer.attention.q_lin,
                    layer.attention.k_lin,
                    layer.attention.v_lin,
                    layer.attention.out_lin,
                    layer.ffn.lin1,
                    layer.ffn.lin2,
                ]
        self.bert.pre_classifier = LoRALayer(self.bert.pre_classifier)
        self.bert.classifier = LoRALayer(self.bert.classifier)
        self.lora_layers += [
            self.bert.pre_classifier,
            self.bert.classifier,
        ]
                
    def print_grad_list(self): 
        grad_list = [] 
        for p in self.bert.parameters():
            grad_list.append(p.requires_grad)
        print(grad_list)
        
    def count_trainable_parameters(self):
        return sum(p.numel() for p in self.bert.parameters() if p.requires_grad)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [54]:
BertWithSwitchableTask().print_grad_list()
print(f"TRAINABLE PARAMS FOR MODEL: {Bert().count_trainable_parameters()}")
print(f"TRAINABLE PARAMS FOR LORA MODEL: {BertWithSwitchableTask().count_trainable_parameters()}")

[False, False, False, False, False, False, True, True, False, False, False, False, False, False, True, True, False, False, False, False, False, False, True, True, False, False, False, False, False, False, True, True, False, False, False, False, False, False, False, False, True, True, False, False, False, False, False, False, True, True, False, False, False, False, False, False, False, False, True, True, False, False, False, False, False, False, True, True, False, False, False, False, False, False, True, True, False, False, False, False, False, False, True, True, False, False, False, False, False, False, False, False, True, True, False, False, False, False, False, False, True, True, False, False, False, False, False, False, False, False, True, True, False, False, False, False, False, False, True, True, False, False, False, False, False, False, True, True, False, False, False, False, False, False, True, True, False, False, False, False, False, False, False, False, True, True, False, Fals

In [55]:
from datasets import load_dataset
from torch.utils.data import DataLoader, TensorDataset

class CustomDataloader:
    def __init__(self, name, label, prct):
        self.label = label
        
        # Load the dataset
        news_dataset = load_dataset(name, split=f"train[:{prct}%]")  # Using 1% of the dataset for a quick demonstration
        news_dataset = news_dataset.train_test_split(test_size=0.2)  # Split the dataset into train and test sets
            
        train_dataset = news_dataset['train']
        test_dataset = news_dataset['test']
        
        def tokenize_function(examples):
            return tokenizer(examples["text"], padding="max_length", truncation=True)
        
        train_dataset = train_dataset.map(tokenize_function, batched=True)
        test_dataset = test_dataset.map(tokenize_function, batched=True)
        
        # Convert to PyTorch tensors
        train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", label])
        test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", label])
        
        self.train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
        self.test_dataloader = DataLoader(test_dataset, batch_size=2)
        
    def train(self):
        return self.train_dataloader
        
    def test(self):
        return self.test_dataloader

In [56]:
import torch
from torch.optim import AdamW
import tqdm
import os

os.environ["TOKENIZERS_PARALLELISM"] = "true"
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
"ag_news"

class Trainer:
    def __init__(self, dataloader: CustomDataloader):
        self.dataloader = dataloader
    
    def train_model(self, model, name):
        model.to(device)
        
        # Define optimizer
        optimizer = AdamW(model.parameters(), lr=0.0001)
        criterion = nn.CrossEntropyLoss()
        
        # Training loop
        model.train()
        for epoch in range(5):
            model.train()
            train_loss = 0
            progress_bar = tqdm.tqdm(self.dataloader.train(), desc=f"Training")
            for batch in progress_bar:
                optimizer.zero_grad()
                inputs = {"input_ids": batch["input_ids"].to(device), "attention_mask": batch["attention_mask"].to(device), "labels": batch[self.dataloader.label].to(device)}
                outputs = model(**inputs)
                loss = criterion(outputs.logits, batch[self.dataloader.label].to(device))
                
                train_loss += loss.item()
                
                loss.backward()
                optimizer.step()
                progress_bar.set_postfix(epoch=epoch, loss=loss.item())
            print(f"Eval accuracy = {self.eval_model(model)}")
            
        
        torch.save(model, f"./models/{name}.pt")
        return model
        
    def eval_model(self, model):
        # Evaluation
        model.eval()
        eval_loss = 0
        correct_predictions = 0
        total_predictions = 0
        with torch.no_grad():
            for batch in self.dataloader.test():
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch[self.dataloader.label].to(device)
                
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss
                eval_loss += loss.item()
                
                predictions = outputs.logits.argmax(dim=-1)
                correct_predictions += (predictions == labels).sum().item()
                total_predictions += len(labels)
        
        eval_accuracy = correct_predictions / total_predictions
        return eval_accuracy

In [57]:
dataloader_news = CustomDataloader("ag_news", "label", 1)

Map:   0%|          | 0/960 [00:00<?, ? examples/s]

Map:   0%|          | 0/240 [00:00<?, ? examples/s]

In [58]:
dataloader_trec = CustomDataloader("trec", "coarse_label", 10)

Map:   0%|          | 0/436 [00:00<?, ? examples/s]

Map:   0%|          | 0/109 [00:00<?, ? examples/s]

In [59]:
bert_switchable = BertWithSwitchableTask().to(device)

In [60]:
bert_switchable.switch_task(1)
bert_switchable = Trainer(dataloader_trec).train_model(bert_switchable, "distilbert_lora")

Training: 100%|██████████| 218/218 [00:13<00:00, 15.91it/s, epoch=0, loss=1.34] 


Eval accuracy = 0.6238532110091743


Training: 100%|██████████| 218/218 [00:13<00:00, 15.87it/s, epoch=1, loss=1.21] 


Eval accuracy = 0.7706422018348624


Training: 100%|██████████| 218/218 [00:13<00:00, 15.70it/s, epoch=2, loss=0.961] 


Eval accuracy = 0.8348623853211009


Training: 100%|██████████| 218/218 [00:13<00:00, 15.71it/s, epoch=3, loss=0.0254] 


Eval accuracy = 0.8165137614678899


Training: 100%|██████████| 218/218 [00:14<00:00, 15.52it/s, epoch=4, loss=0.589]  


Eval accuracy = 0.7798165137614679


In [61]:
bert_switchable.switch_task(2)
bert_switchable = Trainer(dataloader_news).train_model(bert_switchable, "distilbert_lora")

Training: 100%|██████████| 480/480 [00:29<00:00, 16.22it/s, epoch=0, loss=0.135] 


Eval accuracy = 0.8875


Training: 100%|██████████| 480/480 [00:29<00:00, 16.24it/s, epoch=1, loss=0.251] 


Eval accuracy = 0.9083333333333333


Training: 100%|██████████| 480/480 [00:29<00:00, 16.35it/s, epoch=2, loss=0.101]  


Eval accuracy = 0.9


Training: 100%|██████████| 480/480 [00:29<00:00, 16.12it/s, epoch=3, loss=1.4]    


Eval accuracy = 0.8791666666666667


Training: 100%|██████████| 480/480 [00:29<00:00, 16.19it/s, epoch=4, loss=0.12]   


Eval accuracy = 0.8708333333333333


In [62]:
bert_switchable.switch_task(0)
print(f"Accuracy on news dataset without switching: {Trainer(dataloader_news).eval_model(bert_switchable)}")
print(f"Accuracy on trec dataset without switching: {Trainer(dataloader_trec).eval_model(bert_switchable)}")

bert_switchable.switch_task(1)
print(f"Accuracy on ag_news dataset when switched to trec: {Trainer(dataloader_news).eval_model(bert_switchable)}")
bert_switchable.switch_task(2)
print(f"Accuracy on trec dataset when switched to ag_news: {Trainer(dataloader_trec).eval_model(bert_switchable)}")

bert_switchable.switch_task(1)
print(f"Accuracy on trec dataset when switched to trec: {Trainer(dataloader_trec).eval_model(bert_switchable)}")
bert_switchable.switch_task(2)
print(f"Accuracy on news dataset when switched to news: {Trainer(dataloader_news).eval_model(bert_switchable)}")

Accuracy on news dataset without switching: 0.22083333333333333
Accuracy on trec dataset without switching: 0.1926605504587156
Accuracy on ag_news dataset when switched to trec: 0.1
Accuracy on trec dataset when switched to ag_news: 0.10091743119266056
Accuracy on trec dataset when switched to trec: 0.7798165137614679
Accuracy on news dataset when switched to news: 0.8708333333333333


In [66]:
import time

start = time.process_time()
bert_switchable.switch_task(1)
end = time.process_time()
print(f"Time it takes to switch tasks on lora model: {end - start}")

start = time.process_time()
test = torch.load("models/distilbert_lora.pt").to(device)
end = time.process_time()
print(f"Time it takes to load another model from drive: {end - start}")

Time it takes to switch tasks on lora model: 0.0004307069999640589
Time it takes to load another model from drive: 0.1443533159999788
