In [68]:
from transformers import AutoTokenizer, DistilBertForSequenceClassification
import copy
import torch.nn as nn
import torch

# The tokenizer and the classifier are both loaded from the same checkpoint.
CHECKPOINT = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
# Shared base model: the wrapper classes in this notebook deep-copy it instead of
# training it directly.
bert = DistilBertForSequenceClassification.from_pretrained(CHECKPOINT, num_labels=6)

class Bert(nn.Module):
    """Baseline wrapper around a private deep copy of the module-level ``bert`` model."""

    def __init__(self):
        super().__init__()
        # Deep-copy so that training this wrapper never modifies the shared ``bert`` object.
        self.bert = copy.deepcopy(bert)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Call the wrapped model with the given inputs and return its output unchanged."""
        outputs = self.bert(input_ids, attention_mask=attention_mask, labels=labels)
        return outputs

    def count_trainable_parameters(self):
        """Return the number of scalar elements in parameters that have ``requires_grad`` set."""
        total = 0
        for param in self.bert.parameters():
            if param.requires_grad:
                total += param.numel()
        return total

class LoRALayer(nn.Module):
    """Wrap a linear-like layer with three switchable low-rank adapters.

    The output is ``layer(x) + alpha * (x @ A @ B)``, where ``(A, B)`` is the
    adapter pair selected by :meth:`switch_task`:

    * task 0 -- ``A0``/``B0``: all zeros and never trainable, so the output
      equals that of the wrapped ``layer``;
    * task 1 -- ``A1``/``B1`` (selected on construction);
    * task 2 -- ``A2``/``B2``.

    Only the adapter of the active task (1 or 2) has ``requires_grad=True``.
    The ``requires_grad`` flags of the wrapped ``layer`` are not modified here.

    Args:
        layer: module exposing ``in_features`` and ``out_features`` that can be
            called on the input ``x``.
        r: rank of every adapter (inner dimension of ``A @ B``).
        alpha: scale applied to the adapter output before it is added.
    """

    def __init__(self, layer, r = 2, alpha = 1.0):
        super().__init__()
        self.layer = layer
        self.r = r
        self.alpha = alpha
        std_dev = 1 / torch.sqrt(torch.tensor(r).float())
        in_features, out_features = layer.in_features, layer.out_features

        # Task 0 is the "no adapter" slot; it is frozen, so it stays at zero.
        self.A0 = nn.Parameter(torch.zeros(in_features, self.r), requires_grad=False)
        self.B0 = nn.Parameter(torch.zeros(r, out_features), requires_grad=False)
        # A is random and B is zero, so each trainable adapter starts as a no-op
        # (A @ B == 0) while B still receives a non-zero gradient.
        self.A1 = nn.Parameter(torch.randn(in_features, self.r) * std_dev)
        self.B1 = nn.Parameter(torch.zeros(r, out_features))
        self.A2 = nn.Parameter(torch.randn(in_features, self.r) * std_dev)
        self.B2 = nn.Parameter(torch.zeros(r, out_features))

        # ``A``/``B`` are aliases of the active pair. Assigning a Parameter to an
        # attribute registers it under that name too, so ``state_dict()`` lists
        # the active adapter twice ("A"/"B" and e.g. "A1"/"B1"); ``parameters()``
        # still yields each tensor once.
        self.switch_task(1)

    def switch_task(self, task_id):
        """Select the adapter used by :meth:`forward` and update trainability.

        After the call only the selected adapter (for tasks 1 and 2) requires
        gradients; selecting task 0 freezes every adapter.

        Args:
            task_id: ``0``, ``1`` or ``2``.

        Raises:
            ValueError: if ``task_id`` is not a known task. The layer is left
                unchanged in that case.
        """
        adapters = {
            0: (self.A0, self.B0),
            1: (self.A1, self.B1),
            2: (self.A2, self.B2),
        }
        if task_id not in adapters:
            raise ValueError(
                f"Unknown task_id {task_id!r}; expected one of {sorted(adapters)}"
            )

        for candidate, pair in adapters.items():
            # The zero adapter (task 0) must never be trained.
            trainable = candidate != 0 and candidate == task_id
            for param in pair:
                param.requires_grad = trainable

        self.A, self.B = adapters[task_id]

    def forward(self, x):
        """Return the wrapped layer's output plus the scaled low-rank update."""
        return self.layer(x) + self.alpha * (x @ self.A @ self.B)

class BertWithSwitchableTask(nn.Module):
    """Frozen deep copy of the module-level ``bert`` with its linear layers wrapped in ``LoRALayer``."""

    def __init__(self):
        super().__init__()
        self.bert = copy.deepcopy(bert)
        # Flat list of every wrapper so switch_task can reach them all.
        self.lora_layers = []
        self.add_lora_layers()

    def switch_task(self, task_id):
        """Forward ``task_id`` to every wrapped layer."""
        for lora in self.lora_layers:
            lora.switch_task(task_id)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Call the wrapped model with the given inputs and return its output unchanged."""
        return self.bert(input_ids, attention_mask=attention_mask, labels=labels)

    def add_lora_layers(self):
        """Freeze the copied model, then replace its linear layers with ``LoRALayer`` wrappers.

        Wrapped per transformer block: attention ``q_lin``, ``k_lin``, ``v_lin``,
        ``out_lin`` and feed-forward ``lin1``, ``lin2``; afterwards the
        ``pre_classifier`` and ``classifier`` heads. Each wrapper is also
        appended to ``self.lora_layers`` in that order.
        """
        # Freeze first: the wrappers created below add their own parameters afterwards.
        for weight in self.bert.parameters():
            weight.requires_grad = False

        for block in self.bert.distilbert.transformer.layer:
            targets = (
                (block.attention, "q_lin"),
                (block.attention, "k_lin"),
                (block.attention, "v_lin"),
                (block.attention, "out_lin"),
                (block.ffn, "lin1"),
                (block.ffn, "lin2"),
            )
            for owner, attr in targets:
                self._wrap(owner, attr)

        for attr in ("pre_classifier", "classifier"):
            self._wrap(self.bert, attr)

    def _wrap(self, owner, attr):
        """Replace ``owner.<attr>`` with a ``LoRALayer`` around it and remember the wrapper."""
        wrapped = LoRALayer(getattr(owner, attr))
        setattr(owner, attr, wrapped)
        self.lora_layers.append(wrapped)

    def count_trainable_parameters(self):
        """Return the number of scalar elements in parameters that have ``requires_grad`` set."""
        total = 0
        for param in self.bert.parameters():
            if param.requires_grad:
                total += param.numel()
        return total

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [69]:
# Report how many weights each model variant would update during training.
for _caption, _model_cls in (
    ("TRAINABLE PARAMS FOR MODEL", Bert),
    ("TRAINABLE PARAMS FOR LORA MODEL", BertWithSwitchableTask),
):
    print(f"{_caption}: {_model_cls().count_trainable_parameters()}")

TRAINABLE PARAMS FOR MODEL: 66958086
TRAINABLE PARAMS FOR LORA MODEL: 170508


In [82]:
from datasets import load_dataset
from torch.utils.data import DataLoader, TensorDataset

class CustomDataloader:
    """Load a slice of a dataset, tokenize it and expose train/test ``DataLoader`` objects.

    The first ``prct`` percent of the dataset's ``train`` split is loaded and
    then divided into a train part and a test part. The text column is
    tokenized with the module-level ``tokenizer`` (padded to the tokenizer's
    maximum length, truncated), and both parts are formatted as torch tensors
    holding ``input_ids``, ``attention_mask`` and the label column.

    Args:
        name: dataset identifier passed to ``load_dataset``.
        label: name of the label column; also stored as ``self.label``.
        prct: percentage of the ``train`` split to load.
        batch_size: batch size of both loaders (default 2, the previously
            hard-coded value).
        test_size: value passed to ``train_test_split`` (default 0.5).
        seed: optional seed for the train/test split. ``None`` (default) keeps
            the previous unseeded behaviour; pass an int for a reproducible split.
        text_column: name of the column that is tokenized (default ``"text"``).
    """

    def __init__(self, name, label, prct, batch_size=2, test_size=0.5, seed=None, text_column="text"):
        self.label = label

        # Load the requested slice and split it into train and test parts.
        raw_dataset = load_dataset(name, split=f"train[:{prct}%]")
        splits = raw_dataset.train_test_split(test_size=test_size, seed=seed)

        def tokenize_function(examples):
            return tokenizer(examples[text_column], padding="max_length", truncation=True)

        tensor_columns = ["input_ids", "attention_mask", label]
        prepared = {}
        for split_name in ("train", "test"):
            tokenized = splits[split_name].map(tokenize_function, batched=True)
            # Expose only the model inputs and the label, as PyTorch tensors.
            tokenized.set_format(type="torch", columns=tensor_columns)
            prepared[split_name] = tokenized

        # Only the training loader shuffles; evaluation order stays fixed.
        self.train_dataloader = DataLoader(prepared["train"], batch_size=batch_size, shuffle=True)
        self.test_dataloader = DataLoader(prepared["test"], batch_size=batch_size)

    def train(self):
        """Return the shuffled training ``DataLoader``."""
        return self.train_dataloader

    def test(self):
        """Return the evaluation ``DataLoader``."""
        return self.test_dataloader

In [71]:
import torch
from torch.optim import AdamW
import tqdm
import os
from torch.optim.lr_scheduler import CosineAnnealingLR

# Opt in to parallel tokenization (read by the Hugging Face tokenizers library).
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Trainer:
    """Train and evaluate a sequence-classification model on one ``CustomDataloader``.

    The wrapped dataloader must provide ``train()``, ``test()`` and a ``label``
    attribute naming the label key of each batch; batches must also contain
    ``input_ids`` and ``attention_mask``.
    """

    def __init__(self, dataloader: "CustomDataloader"):
        self.dataloader = dataloader

    def train_model(self, model, name):
        """Fine-tune ``model`` for five epochs, save it and return it.

        Uses AdamW (lr 1e-4) with a cosine learning-rate schedule stepped once
        per epoch, and cross-entropy on ``outputs.logits``. After every epoch
        the evaluation accuracy is printed. The whole module object is saved to
        ``./models/{name}.pt`` (the directory is created if it is missing).

        Args:
            model: module called as ``model(input_ids=..., attention_mask=...)``
                whose result has a ``logits`` attribute.
            name: file stem of the saved model.

        Returns:
            The trained ``model`` (moved to the module-level ``device``).

        Raises:
            ValueError: raised by the optimizer if ``model`` has no parameter
                with ``requires_grad=True``.
        """
        model.to(device)

        num_epochs = 5
        # Drop gradients left over from an earlier run, then give the optimizer
        # only trainable tensors so frozen weights (e.g. inactive adapters) can
        # never be touched by an update or by weight decay.
        model.zero_grad()
        trainable = [p for p in model.parameters() if p.requires_grad]
        optimizer = AdamW(trainable, lr=0.0001)
        criterion = nn.CrossEntropyLoss()
        scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
        label_key = self.dataloader.label

        for epoch in range(num_epochs):
            # eval_model() switches to eval mode at the end of every epoch.
            model.train()
            train_loss = 0.0
            progress_bar = tqdm.tqdm(self.dataloader.train(), desc="Training")
            for step, batch in enumerate(progress_bar, start=1):
                optimizer.zero_grad()
                labels = batch[label_key].to(device)
                # Labels are not passed to the model: the loss is computed once, below.
                outputs = model(
                    input_ids=batch["input_ids"].to(device),
                    attention_mask=batch["attention_mask"].to(device),
                )
                loss = criterion(outputs.logits, labels)

                loss.backward()
                optimizer.step()

                train_loss += loss.item()
                # Mean loss over the batches seen so far in this epoch.
                progress_bar.set_postfix(epoch=epoch, loss=train_loss / step)
            scheduler.step()
            print(f"Eval accuracy = {self.eval_model(model)}")

        os.makedirs("./models", exist_ok=True)
        # NOTE(review): this pickles the whole module, so loading it later needs
        # the same class definitions to be importable.
        torch.save(model, f"./models/{name}.pt")
        return model

    def eval_model(self, model):
        """Return the classification accuracy of ``model`` on the test loader.

        The model is put in eval mode and left there. Batches are moved to the
        device of the model's first parameter (or to the module-level ``device``
        when the model has no parameters).

        Returns:
            Fraction of correctly predicted labels, or ``nan`` when the test
            loader yields no examples.
        """
        model.eval()
        first_param = next(model.parameters(), None)
        eval_device = first_param.device if first_param is not None else device
        label_key = self.dataloader.label

        correct_predictions = 0
        total_predictions = 0
        with torch.no_grad():
            for batch in self.dataloader.test():
                input_ids = batch["input_ids"].to(eval_device)
                attention_mask = batch["attention_mask"].to(eval_device)
                labels = batch[label_key].to(eval_device)

                outputs = model(input_ids, attention_mask=attention_mask)
                predictions = outputs.logits.argmax(dim=-1)
                correct_predictions += (predictions == labels).sum().item()
                total_predictions += len(labels)

        if total_predictions == 0:
            # Accuracy is undefined without examples; avoid a ZeroDivisionError.
            return float("nan")
        return correct_predictions / total_predictions

In [72]:
# AG News: the first 8% of the train split, read from the "label" column.
dataloader_news = CustomDataloader(name="ag_news", label="label", prct=8)

Map:   0%|          | 0/4800 [00:00<?, ? examples/s]

Map:   0%|          | 0/4800 [00:00<?, ? examples/s]

In [84]:
# TREC: the whole train split, read from the "coarse_label" column.
dataloader_trec = CustomDataloader(name="trec", label="coarse_label", prct=100)

Map:   0%|          | 0/2726 [00:00<?, ? examples/s]

Map:   0%|          | 0/2726 [00:00<?, ? examples/s]

In [74]:
# Build the LoRA-wrapped model, then move it to the selected device.
bert_switchable = BertWithSwitchableTask()
bert_switchable = bert_switchable.to(device)

In [75]:
# Activate adapter 1 and train it on TREC.
bert_switchable.switch_task(task_id=1)
bert_switchable = Trainer(dataloader=dataloader_trec).train_model(
    model=bert_switchable,
    name="distilbert_lora",
)

Training: 100%|██████████| 1363/1363 [01:12<00:00, 18.85it/s, epoch=0, loss=0.735]


Eval accuracy = 0.8950843727072634


Training: 100%|██████████| 1363/1363 [01:22<00:00, 16.56it/s, epoch=1, loss=0.304]


Eval accuracy = 0.8980190755685987


Training: 100%|██████████| 1363/1363 [01:21<00:00, 16.65it/s, epoch=2, loss=0.194]


Eval accuracy = 0.921863536316948


Training: 100%|██████████| 1363/1363 [01:18<00:00, 17.43it/s, epoch=3, loss=0.114]


Eval accuracy = 0.9310344827586207


Training: 100%|██████████| 1363/1363 [01:15<00:00, 18.09it/s, epoch=4, loss=0.0766]


Eval accuracy = 0.9328686720469552


In [76]:
# Activate adapter 2 and train it on AG News; the save overwrites the same file,
# which then holds the model with both adapters.
bert_switchable.switch_task(task_id=2)
bert_switchable = Trainer(dataloader=dataloader_news).train_model(
    model=bert_switchable,
    name="distilbert_lora",
)

Training: 100%|██████████| 2400/2400 [02:03<00:00, 19.41it/s, epoch=0, loss=0.488]


Eval accuracy = 0.8727083333333333


Training: 100%|██████████| 2400/2400 [02:04<00:00, 19.24it/s, epoch=1, loss=0.306]


Eval accuracy = 0.8945833333333333


Training: 100%|██████████| 2400/2400 [02:16<00:00, 17.64it/s, epoch=2, loss=0.225]


Eval accuracy = 0.8929166666666667


Training: 100%|██████████| 2400/2400 [02:05<00:00, 19.12it/s, epoch=3, loss=0.151]


Eval accuracy = 0.90125


Training: 100%|██████████| 2400/2400 [02:05<00:00, 19.09it/s, epoch=4, loss=0.111]


Eval accuracy = 0.908125


In [77]:
# Full fine-tuning baseline on TREC.
bert_original_trec = Bert()
bert_original_trec = Trainer(dataloader=dataloader_trec).train_model(
    model=bert_original_trec,
    name="distilbert_original_trec",
)

Training: 100%|██████████| 1363/1363 [01:32<00:00, 14.73it/s, epoch=0, loss=0.83] 


Eval accuracy = 0.8345561261922231


Training: 100%|██████████| 1363/1363 [01:32<00:00, 14.71it/s, epoch=1, loss=0.461]


Eval accuracy = 0.8910491562729274


Training: 100%|██████████| 1363/1363 [01:33<00:00, 14.53it/s, epoch=2, loss=0.271]


Eval accuracy = 0.8782098312545855


Training: 100%|██████████| 1363/1363 [01:31<00:00, 14.87it/s, epoch=3, loss=0.178]


Eval accuracy = 0.8961848862802642


Training: 100%|██████████| 1363/1363 [01:30<00:00, 15.01it/s, epoch=4, loss=0.104]


Eval accuracy = 0.9104915627292737


In [78]:
# Full fine-tuning baseline on AG News.
bert_original_news = Bert()
bert_original_news = Trainer(dataloader=dataloader_news).train_model(
    model=bert_original_news,
    name="distilbert_original_news",
)

Training: 100%|██████████| 2400/2400 [02:38<00:00, 15.11it/s, epoch=0, loss=0.585]


Eval accuracy = 0.6625


Training: 100%|██████████| 2400/2400 [02:32<00:00, 15.78it/s, epoch=1, loss=0.713]


Eval accuracy = 0.6404166666666666


Training: 100%|██████████| 2400/2400 [02:38<00:00, 15.12it/s, epoch=2, loss=0.567]


Eval accuracy = 0.8585416666666666


Training: 100%|██████████| 2400/2400 [02:33<00:00, 15.66it/s, epoch=3, loss=0.372]


Eval accuracy = 0.8654166666666666


Training: 100%|██████████| 2400/2400 [02:49<00:00, 14.12it/s, epoch=4, loss=0.28] 


Eval accuracy = 0.8635416666666667


In [79]:
# Save the untrained base model so its load time can be measured later.
torch.save(bert, "./models/distilbert_original.pt")

In [80]:
# Each row: (adapter to activate, or None to keep the current one; caption; evaluation data).
_switchable_runs = (
    (0, "Accuracy on news dataset without switching", dataloader_news),
    (None, "Accuracy on trec dataset without switching", dataloader_trec),
    (1, "Accuracy on ag_news dataset when switched to trec", dataloader_news),
    (2, "Accuracy on trec dataset when switched to ag_news", dataloader_trec),
    (1, "Accuracy on trec dataset when switched to trec", dataloader_trec),
    (2, "Accuracy on news dataset when switched to news", dataloader_news),
)
for _task_id, _caption, _loader in _switchable_runs:
    if _task_id is not None:
        bert_switchable.switch_task(_task_id)
    print(f"{_caption}: {Trainer(_loader).eval_model(bert_switchable)}")

# Fully fine-tuned baselines, each evaluated on the dataset it was trained on.
_baseline_runs = (
    ("Accuracy of fine-tuned bert on news dataset", dataloader_news, bert_original_news),
    ("Accuracy of fine-tuned bert on trec dataset", dataloader_trec, bert_original_trec),
)
for _caption, _loader, _model in _baseline_runs:
    print(f"{_caption}: {Trainer(_loader).eval_model(_model)}")

Accuracy on news dataset without switching: 0.211875
Accuracy on trec dataset without switching: 0.05062362435803375
Accuracy on ag_news dataset when switched to trec: 0.15375
Accuracy on trec dataset when switched to ag_news: 0.10051357300073367
Accuracy on trec dataset when switched to trec: 0.9328686720469552
Accuracy on news dataset when switched to news: 0.908125
Accuracy of fine-tuned bert on news dataset: 0.8635416666666667
Accuracy of fine-tuned bert on trec dataset: 0.9104915627292737


In [85]:
import time

# Wall-clock timing: time.perf_counter() also counts time spent waiting on disk
# reads, which time.process_time() (CPU time of this process only) leaves out.
#
# NOTE(review): torch.load unpickles arbitrary objects, so only load files you
# trust. On PyTorch >= 2.6 the default is weights_only=True, which rejects these
# whole-module pickles; pass weights_only=False there.

start = time.perf_counter()
bert_switchable.switch_task(1)
end = time.perf_counter()
print(f"Time it takes to switch task: {end - start}")

start = time.perf_counter()
test = torch.load("models/distilbert_original.pt").to(device)
end = time.perf_counter()
print(f"Time it takes to load model: {end - start}")

start = time.perf_counter()
test = torch.load("models/distilbert_lora.pt").to(device)
# Switch tasks on the model that was just loaded, so the measurement matches its label.
for task_id in (1, 0, 1, 2, 1):
    test.switch_task(task_id)
end = time.perf_counter()
print(f"Time it takes to load lora model from drive and then switch tasks 5 times: {end - start}")

start = time.perf_counter()
for model_path in (
    "models/distilbert_original.pt",
    "models/distilbert_original_trec.pt",
    "models/distilbert_original.pt",
    "models/distilbert_original_trec.pt",
    "models/distilbert_original_news.pt",
    "models/distilbert_original_trec.pt",
):
    test = torch.load(model_path).to(device)
end = time.perf_counter()
print(f"Time it takes to load 6 different models from drive: {end - start}")

Time it takes to switch task: 0.000921797999581031
Time it takes to load model: 0.14060739699925762
Time it takes to load lora model from drive and then switch tasks 5 times: 0.1611400900001172
Time it takes to load 6 different models from drive: 0.7930469319999247
