In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from multilora import LoRAModel, MultiLoRALayerMaskingHom, MultiLoRALayerMaskingHomEfficient

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from multilora.benchmarking import MultiAdapterDataset, get_bitext_dataset, get_finetome_dataset, get_guanaco_dataset
N = 1000
model_id = "openai-community/gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

dataset = MultiAdapterDataset([get_bitext_dataset(N, tokenizer), get_finetome_dataset(N, tokenizer), get_guanaco_dataset(N, tokenizer)], tokenizer)
n_adapters = 99
n_datasets = 3

In [4]:
def create_lora_hom(in_features, out_features, adapter_ids):
    return MultiLoRALayerMaskingHom(in_features, out_features, adapter_ids, n_adapters=n_adapters, rank=32)

def create_lora_hom_eff(in_features, out_features, adapter_ids):
    return MultiLoRALayerMaskingHomEfficient(in_features, out_features, adapter_ids, n_adapters=n_adapters, rank=32)

## Homogenious LoRA Adapters

In [5]:
model = GPT2LMHeadModel.from_pretrained(model_id, device_map="auto")
lora_model = LoRAModel(model, target_modules=["c_attn"], lora_factory=create_lora_hom).cuda()
lora_model.freeze_base_model()

In [6]:
from torch.optim import AdamW
from transformers import get_scheduler

dataloader = DataLoader(dataset, batch_size=8, collate_fn=dataset.collate_fn)

optimizer = AdamW(lora_model.parameters(), lr=2e-4, weight_decay=0)
num_epochs = 1
num_training_steps = num_epochs * len(dataloader)

lr_scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

In [7]:
device = "cuda"

loss_fn = nn.CrossEntropyLoss()

def train_step(data):
    ids, masks, labels, adapter_ids = data
    adapter_ids = adapter_ids + torch.randint_like(adapter_ids, low=0, high=n_adapters // n_datasets - 1) * n_datasets
    adapter_ids %= n_adapters
    logits = lora_model(input_ids=ids.to(device), attention_mask=masks.to(device), adapter_ids=adapter_ids.to(device))[0]
    
    loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1).to('cuda'))
    optimizer.zero_grad()
    loss.backward()

    optimizer.step()  
    lr_scheduler.step()

    return loss.item()

In [None]:
from tqdm import tqdm
from time import time
running_loss = None
alpha = 0.95
start = time()

iters = 0

for epoch in range(num_epochs):
    if iters >= 100:
        break
    for i, batch in tqdm(enumerate(dataloader)):
        loss = train_step(batch)
        if not running_loss:
            running_loss = loss
        else:
            running_loss = running_loss * alpha + loss * (1 - alpha)
        if iters % 20 == 19:
            print("AVG TIME:", (time() - start) / iters)
            print("LOSS:", running_loss)
        iters += 1

20it [00:46,  2.32s/it]

AVG TIME: 2.4418084119495593
LOSS: 5.937821571017335


40it [01:32,  2.33s/it]

AVG TIME: 2.381277750699948
LOSS: 6.3602954647132925


60it [02:19,  2.32s/it]

AVG TIME: 2.3619285357200495
LOSS: 6.452357551625577


80it [03:05,  2.32s/it]

AVG TIME: 2.352299222463294
LOSS: 6.560008989069966


100it [03:52,  2.32s/it]

AVG TIME: 2.346536655618687
LOSS: 6.504811457670055


105it [04:06,  2.34s/it]


KeyboardInterrupt: 

## Homogenious Efficient

In [None]:
model = GPT2LMHeadModel.from_pretrained(model_id, device_map="auto")
lora_model = LoRAModel(model, target_modules=["c_attn"], lora_factory=create_lora_hom_eff).cuda()
lora_model.freeze_base_model()

In [6]:
from torch.optim import AdamW
from transformers import get_scheduler

dataloader = DataLoader(dataset, batch_size=8, collate_fn=dataset.collate_fn)

optimizer = AdamW(lora_model.parameters(), lr=2e-4, weight_decay=0)
num_epochs = 1
num_training_steps = num_epochs * len(dataloader)

lr_scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

In [7]:
device = "cuda"

loss_fn = nn.CrossEntropyLoss()

def train_step(data):
    ids, masks, labels, adapter_ids = data
    adapter_ids = adapter_ids + torch.randint_like(adapter_ids, low=0, high=n_adapters // n_datasets - 1) * n_datasets
    adapter_ids %= n_adapters
    logits = lora_model(input_ids=ids.to(device), attention_mask=masks.to(device), adapter_ids=adapter_ids.to(device))[0]
    
    loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1).to('cuda'))
    optimizer.zero_grad()
    loss.backward()

    optimizer.step()  
    lr_scheduler.step()

    return loss.item()

In [8]:
from tqdm import tqdm
from time import time
running_loss = None
alpha = 0.95
start = time()

iters = 0

for epoch in range(num_epochs):
    if iters >= 100:
        break
    for i, batch in tqdm(enumerate(dataloader)):
        loss = train_step(batch)
        if not running_loss:
            running_loss = loss
        else:
            running_loss = running_loss * alpha + loss * (1 - alpha)
        if iters % 20 == 19:
            print("AVG TIME:", (time() - start) / iters)
            print("LOSS:", running_loss)
        iters += 1

20it [00:12,  1.62it/s]

AVG TIME: 0.6696739573227731
LOSS: 5.943526124257177


40it [00:25,  1.61it/s]

AVG TIME: 0.6452094897245749
LOSS: 6.3694197912188795


60it [00:37,  1.61it/s]

AVG TIME: 0.6371923664868888
LOSS: 6.4645828136119885


80it [00:50,  1.61it/s]

AVG TIME: 0.6331507525866544
LOSS: 6.584524753180639


100it [01:02,  1.61it/s]

AVG TIME: 0.6308055598326404
LOSS: 6.534201938608992


120it [01:14,  1.61it/s]

AVG TIME: 0.6292342398346973
LOSS: 6.504324700386215


140it [01:27,  1.61it/s]

AVG TIME: 0.6282946583178404
LOSS: 6.422906713234135


160it [01:39,  1.61it/s]

AVG TIME: 0.6274771075578606
LOSS: 6.527804858080267


180it [01:52,  1.61it/s]

AVG TIME: 0.6267301436909084
LOSS: 6.473699284115535


200it [02:04,  1.61it/s]

AVG TIME: 0.6262224595151354
LOSS: 6.569186793815509


220it [02:17,  1.60it/s]

AVG TIME: 0.6259730308567553
LOSS: 6.344581522823477


240it [02:29,  1.61it/s]

AVG TIME: 0.6255208198994273
LOSS: 6.403301445888076


260it [02:41,  1.61it/s]

AVG TIME: 0.6252444259908668
LOSS: 6.34719771275639


280it [02:54,  1.61it/s]

AVG TIME: 0.6248547090851705
LOSS: 6.36092244311708


300it [03:06,  1.61it/s]

AVG TIME: 0.624578987874315
LOSS: 6.346212910387776


320it [03:19,  1.61it/s]

AVG TIME: 0.6243348405652659
LOSS: 6.233039415371515


340it [03:31,  1.61it/s]

AVG TIME: 0.6240777491116594
LOSS: 6.331039570982172


360it [03:43,  1.61it/s]

AVG TIME: 0.6238408865702849
LOSS: 6.309516063228866


375it [03:53,  1.61it/s]
