In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from multilora import LoRAModel, MultiLoRALayerMaskingHom, MultiLoRALayerMaskingHomEfficient, MultiLoRALayerMasking

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from multilora.benchmarking import MultiAdapterDataset, get_bitext_dataset, get_finetome_dataset, get_guanaco_dataset
N = 1000
model_id = "openai-community/gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

dataset = MultiAdapterDataset([get_bitext_dataset(N, tokenizer), get_finetome_dataset(N, tokenizer), get_guanaco_dataset(N, tokenizer)], tokenizer)
n_adapters = 99
n_datasets = 3

In [4]:
def create_lora_hom(in_features, out_features, adapter_ids):
    return MultiLoRALayerMaskingHom(in_features, out_features, adapter_ids, n_adapters=n_adapters, rank=32)

def create_lora_hom_eff(in_features, out_features, adapter_ids):
    return MultiLoRALayerMaskingHomEfficient(in_features, out_features, adapter_ids, n_adapters=n_adapters, rank=32)

def create_lora_het(in_features, out_features, adapter_ids):
    return MultiLoRALayerMasking(in_features, out_features, adapter_ids, ranks=[32] * n_adapters)


## Homogenious LoRA Adapters

In [5]:
model = GPT2LMHeadModel.from_pretrained(model_id, device_map="auto")
lora_model = LoRAModel(model, target_modules=["c_attn"], lora_factory=create_lora_hom).cuda()
lora_model.freeze_base_model()

In [6]:
from torch.optim import AdamW
from transformers import get_scheduler

dataloader = DataLoader(dataset, batch_size=8, collate_fn=dataset.collate_fn)

optimizer = AdamW(lora_model.parameters(), lr=2e-4, weight_decay=0)
num_epochs = 1
num_training_steps = num_epochs * len(dataloader)

lr_scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

In [7]:
device = "cuda"

loss_fn = nn.CrossEntropyLoss()

def train_step(data):
    ids, masks, labels, adapter_ids = data
    adapter_ids = adapter_ids + torch.randint_like(adapter_ids, low=0, high=n_adapters // n_datasets - 1) * n_datasets
    adapter_ids %= n_adapters
    logits = lora_model(input_ids=ids.to(device), attention_mask=masks.to(device), adapter_ids=adapter_ids.to(device))[0]
    
    loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1).to('cuda'))
    optimizer.zero_grad()
    loss.backward()

    optimizer.step()  
    lr_scheduler.step()

    return loss.item()

In [8]:
from tqdm import tqdm
from time import time
running_loss = None
alpha = 0.95
start = time()

iters = 0

for epoch in range(num_epochs):
    if iters >= 100:
        break
    for i, batch in tqdm(enumerate(dataloader)):
        loss = train_step(batch)
        if not running_loss:
            running_loss = loss
        else:
            running_loss = running_loss * alpha + loss * (1 - alpha)
        if iters % 20 == 19:
            print("AVG TIME:", (time() - start) / iters)
            print("LOSS:", running_loss)
        iters += 1
        if iters >= 100:
            break

20it [00:32,  1.59s/it]

AVG TIME: 1.6881590767910606
LOSS: 5.943533623020265


40it [01:03,  1.59s/it]

AVG TIME: 1.6371148977524195
LOSS: 6.369998167898361


60it [01:35,  1.59s/it]

AVG TIME: 1.620033813735186
LOSS: 6.4659664555656375


80it [02:07,  1.59s/it]

AVG TIME: 1.611857906172547
LOSS: 6.582099420359769


99it [02:39,  1.61s/it]

AVG TIME: 1.6071466580785887
LOSS: 6.538887024955727





## Homogenious Efficient

In [5]:
model = GPT2LMHeadModel.from_pretrained(model_id, device_map="auto")
lora_model = LoRAModel(model, target_modules=["c_attn"], lora_factory=create_lora_hom_eff).cuda()
lora_model.freeze_base_model()

In [6]:
from torch.optim import AdamW
from transformers import get_scheduler

dataloader = DataLoader(dataset, batch_size=8, collate_fn=dataset.collate_fn)

optimizer = AdamW(lora_model.parameters(), lr=2e-4, weight_decay=0)
num_epochs = 1
num_training_steps = num_epochs * len(dataloader)

lr_scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

In [7]:
device = "cuda"

loss_fn = nn.CrossEntropyLoss()

def train_step(data):
    ids, masks, labels, adapter_ids = data
    adapter_ids = adapter_ids + torch.randint_like(adapter_ids, low=0, high=n_adapters // n_datasets - 1) * n_datasets
    adapter_ids %= n_adapters
    logits = lora_model(input_ids=ids.to(device), attention_mask=masks.to(device), adapter_ids=adapter_ids.to(device))[0]
    
    loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1).to('cuda'))
    optimizer.zero_grad()
    loss.backward()

    optimizer.step()  
    lr_scheduler.step()

    return loss.item()

In [8]:
from tqdm import tqdm
from time import time
running_loss = None
alpha = 0.95
start = time()

iters = 0

for epoch in range(num_epochs):
    if iters >= 100:
        break
    for i, batch in tqdm(enumerate(dataloader)):
        loss = train_step(batch)
        if not running_loss:
            running_loss = loss
        else:
            running_loss = running_loss * alpha + loss * (1 - alpha)
        if iters % 20 == 19:
            print("AVG TIME:", (time() - start) / iters)
            print("LOSS:", running_loss)
        iters += 1
        if iters >= 100:
            break

0it [00:00, ?it/s]

20it [00:07,  2.75it/s]

AVG TIME: 0.3960070735529849
LOSS: 5.943649301706349


40it [00:14,  2.76it/s]

AVG TIME: 0.378990894708878
LOSS: 6.369598071105254


60it [00:22,  2.76it/s]

AVG TIME: 0.3734110129081597
LOSS: 6.465682817730355


80it [00:29,  2.76it/s]

AVG TIME: 0.3705313718771633
LOSS: 6.583523474863273


99it [00:36,  2.71it/s]

AVG TIME: 0.36882089364408244
LOSS: 6.536543183439562





## Heterogenious Naive

In [5]:
model = GPT2LMHeadModel.from_pretrained(model_id, device_map="auto")
lora_model = LoRAModel(model, target_modules=["c_attn"], lora_factory=create_lora_het).cuda()
lora_model.freeze_base_model()

In [6]:
from torch.optim import AdamW
from transformers import get_scheduler

dataloader = DataLoader(dataset, batch_size=8, collate_fn=dataset.collate_fn)

optimizer = AdamW(lora_model.parameters(), lr=2e-4, weight_decay=0)
num_epochs = 1
num_training_steps = num_epochs * len(dataloader)

lr_scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

In [7]:
device = "cuda"

loss_fn = nn.CrossEntropyLoss()

def train_step(data):
    ids, masks, labels, adapter_ids = data
    adapter_ids = adapter_ids + torch.randint_like(adapter_ids, low=0, high=n_adapters // n_datasets - 1) * n_datasets
    adapter_ids %= n_adapters
    logits = lora_model(input_ids=ids.to(device), attention_mask=masks.to(device), adapter_ids=adapter_ids.to(device))[0]
    
    loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1).to('cuda'))
    optimizer.zero_grad()
    loss.backward()

    optimizer.step()  
    lr_scheduler.step()

    return loss.item()

In [8]:
from tqdm import tqdm
from time import time
running_loss = None
alpha = 0.95
start = time()

iters = 0

for epoch in range(num_epochs):
    if iters >= 100:
        break
    for i, batch in tqdm(enumerate(dataloader)):
        loss = train_step(batch)
        if not running_loss:
            running_loss = loss
        else:
            running_loss = running_loss * alpha + loss * (1 - alpha)
        if iters % 20 == 19:
            print("AVG TIME:", (time() - start) / iters)
            print("LOSS:", running_loss)
        iters += 1
        if iters >= 100:
            break

20it [00:12,  1.69it/s]

AVG TIME: 0.6422461835961593
LOSS: 5.944175354099627


40it [00:23,  1.70it/s]

AVG TIME: 0.615334425217066
LOSS: 6.373000084942059


60it [00:35,  1.70it/s]

AVG TIME: 0.6061642695281465
LOSS: 6.475421756737563


80it [00:47,  1.71it/s]

AVG TIME: 0.601677846304978
LOSS: 6.605075527242298


99it [00:59,  1.67it/s]

AVG TIME: 0.5996107573461051
LOSS: 6.576296012283517



