In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from multilora import LoRAModel, MultiLoRALayerMaskingHom, MultiLoRALayerMaskingHomEfficient, MultiLoRALayerMasking, MultiLoRALayerSTK

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from multilora.benchmarking import MultiAdapterDataset, get_bitext_dataset, get_finetome_dataset, get_guanaco_dataset
N = 1000
model_id = "openai-community/gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

dataset = MultiAdapterDataset([get_bitext_dataset(N, tokenizer), get_finetome_dataset(N, tokenizer), get_guanaco_dataset(N, tokenizer)], tokenizer)
n_adapters = 99
n_datasets = 3

In [4]:
def create_lora_hom(in_features, out_features, adapter_ids):
    return MultiLoRALayerMaskingHom(in_features, out_features, adapter_ids, n_adapters=n_adapters, rank=32)

def create_lora_hom_eff(in_features, out_features, adapter_ids):
    return MultiLoRALayerMaskingHomEfficient(in_features, out_features, adapter_ids, n_adapters=n_adapters, rank=32)

def create_lora_het(in_features, out_features, adapter_ids):
    return MultiLoRALayerMasking(in_features, out_features, adapter_ids, ranks=[32] * n_adapters)

def create_lora_het_stk(in_features, out_features, adapter_ids):
    return MultiLoRALayerSTK(in_features, out_features, adapter_ids, ranks=[32] * n_adapters)

## Homogenious LoRA Adapters

In [5]:
model = GPT2LMHeadModel.from_pretrained(model_id, device_map="auto").to(dtype=torch.bfloat16)
lora_model = LoRAModel(model, target_modules=["c_attn"], lora_factory=create_lora_hom).cuda().to(torch.bfloat16)
lora_model.freeze_base_model()

In [6]:
from torch.optim import AdamW
from transformers import get_scheduler

dataloader = DataLoader(dataset, batch_size=8, collate_fn=dataset.collate_fn)

optimizer = AdamW(lora_model.parameters(), lr=2e-4, weight_decay=0)
num_epochs = 1
num_training_steps = num_epochs * len(dataloader)

lr_scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

In [7]:
device = "cuda"

loss_fn = nn.CrossEntropyLoss()

def train_step(data):
    ids, masks, labels, adapter_ids = data
    adapter_ids = adapter_ids + torch.randint_like(adapter_ids, low=0, high=n_adapters // n_datasets - 1) * n_datasets
    adapter_ids %= n_adapters
    logits = lora_model(input_ids=ids.to(device), attention_mask=masks.to(device), adapter_ids=adapter_ids.to(device))[0]
    
    loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1).to('cuda'))
    optimizer.zero_grad()
    loss.backward()

    optimizer.step()  
    lr_scheduler.step()

    return loss.item()

In [8]:
from tqdm import tqdm
from time import time
running_loss = None
alpha = 0.95
start = time()

iters = 0

for epoch in range(num_epochs):
    if iters >= 100:
        break
    for i, batch in tqdm(enumerate(dataloader)):
        loss = train_step(batch)
        if not running_loss:
            running_loss = loss
        else:
            running_loss = running_loss * alpha + loss * (1 - alpha)
        if iters % 20 == 19:
            print("AVG TIME:", (time() - start) / iters)
            print("LOSS:", running_loss)
        iters += 1
        if iters >= 100:
            break

0it [00:00, ?it/s]

20it [00:10,  1.97it/s]

AVG TIME: 0.5509742561139559
LOSS: 5.942682955053021


40it [00:20,  1.98it/s]

AVG TIME: 0.5269916424384484
LOSS: 6.373347408739958


60it [00:30,  1.98it/s]

AVG TIME: 0.5192565513869464
LOSS: 6.472410628159011


80it [00:40,  1.98it/s]

AVG TIME: 0.515574219860608
LOSS: 6.601495394466152


99it [00:50,  1.95it/s]

AVG TIME: 0.5131114010859017
LOSS: 6.56423440711488





## Homogenious Efficient

In [5]:
model = GPT2LMHeadModel.from_pretrained(model_id, device_map="auto").to(dtype=torch.bfloat16)
lora_model = LoRAModel(model, target_modules=["c_attn"], lora_factory=create_lora_hom_eff).cuda().to(torch.bfloat16)
lora_model.freeze_base_model()

In [6]:
from torch.optim import AdamW
from transformers import get_scheduler

dataloader = DataLoader(dataset, batch_size=8, collate_fn=dataset.collate_fn)

optimizer = AdamW(lora_model.parameters(), lr=2e-4, weight_decay=0)
num_epochs = 1
num_training_steps = num_epochs * len(dataloader)

lr_scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

In [7]:
device = "cuda"

loss_fn = nn.CrossEntropyLoss()

def train_step(data):
    ids, masks, labels, adapter_ids = data
    adapter_ids = adapter_ids + torch.randint_like(adapter_ids, low=0, high=n_adapters // n_datasets - 1) * n_datasets
    adapter_ids %= n_adapters
    logits = lora_model(input_ids=ids.to(device), attention_mask=masks.to(device), adapter_ids=adapter_ids.to(device))[0]
    
    loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1).to('cuda'))
    optimizer.zero_grad()
    loss.backward()

    optimizer.step()  
    lr_scheduler.step()

    return loss.item()

In [8]:
from tqdm import tqdm
from time import time
running_loss = None
alpha = 0.95
start = time()

iters = 0

for epoch in range(num_epochs):
    if iters >= 100:
        break
    for i, batch in tqdm(enumerate(dataloader)):
        loss = train_step(batch)
        if not running_loss:
            running_loss = loss
        else:
            running_loss = running_loss * alpha + loss * (1 - alpha)
        if iters % 20 == 19:
            print("AVG TIME:", (time() - start) / iters)
            print("LOSS:", running_loss)
        iters += 1
        if iters >= 100:
            break

21it [00:02, 10.69it/s]

AVG TIME: 0.11871823511625591
LOSS: 5.942794941708233


41it [00:04, 10.94it/s]

AVG TIME: 0.10463229203835511
LOSS: 6.374149546534218


61it [00:06, 10.84it/s]

AVG TIME: 0.10051029415453895
LOSS: 6.475503611293823


81it [00:07, 10.90it/s]

AVG TIME: 0.09805604777758635
LOSS: 6.603067205769785


99it [00:09, 10.31it/s]

AVG TIME: 0.09706232764504173
LOSS: 6.566872642360667





## Heterogenious Naive

In [5]:
model = GPT2LMHeadModel.from_pretrained(model_id, device_map="auto").to(dtype=torch.bfloat16)
lora_model = LoRAModel(model, target_modules=["c_attn"], lora_factory=create_lora_het).cuda().to(torch.bfloat16)
lora_model.freeze_base_model()

In [6]:
from torch.optim import AdamW
from transformers import get_scheduler

dataloader = DataLoader(dataset, batch_size=8, collate_fn=dataset.collate_fn)

optimizer = AdamW(lora_model.parameters(), lr=2e-4, weight_decay=0)
num_epochs = 1
num_training_steps = num_epochs * len(dataloader)

lr_scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

In [7]:
device = "cuda"

loss_fn = nn.CrossEntropyLoss()

def train_step(data):
    ids, masks, labels, adapter_ids = data
    adapter_ids = adapter_ids + torch.randint_like(adapter_ids, low=0, high=n_adapters // n_datasets - 1) * n_datasets
    adapter_ids %= n_adapters
    logits = lora_model(input_ids=ids.to(device), attention_mask=masks.to(device), adapter_ids=adapter_ids.to(device))[0]
    
    loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1).to('cuda'))
    optimizer.zero_grad()
    loss.backward()

    optimizer.step()  
    lr_scheduler.step()

    return loss.item()

In [8]:
from tqdm import tqdm
from time import time
running_loss = None
alpha = 0.95
start = time()

iters = 0

for epoch in range(num_epochs):
    if iters >= 100:
        break
    for i, batch in tqdm(enumerate(dataloader)):
        loss = train_step(batch)
        if not running_loss:
            running_loss = loss
        else:
            running_loss = running_loss * alpha + loss * (1 - alpha)
        if iters % 20 == 19:
            print("AVG TIME:", (time() - start) / iters)
            print("LOSS:", running_loss)
        iters += 1
        if iters >= 100:
            break

20it [00:06,  3.12it/s]

AVG TIME: 0.35765408214769867
LOSS: 5.941646360629132


40it [00:13,  3.12it/s]

AVG TIME: 0.3376512771997696
LOSS: 6.375575699811821


60it [00:19,  3.19it/s]

AVG TIME: 0.33037236989554714
LOSS: 6.478838890604542


80it [00:25,  3.13it/s]

AVG TIME: 0.3272677886335156
LOSS: 6.610806591793325


99it [00:32,  3.07it/s]

AVG TIME: 0.32538832317699085
LOSS: 6.583812531106318





## Heterogenious MegaBlocks

In [5]:
model = GPT2LMHeadModel.from_pretrained(model_id, device_map="auto").to(dtype=torch.bfloat16)
lora_model = LoRAModel(model, target_modules=["c_attn"], lora_factory=create_lora_het_stk).cuda()
lora_model.freeze_base_model()

In [6]:
from torch.optim import AdamW
from transformers import get_scheduler

dataloader = DataLoader(dataset, batch_size=8, collate_fn=dataset.collate_fn)

optimizer = AdamW(lora_model.parameters(), lr=2e-4, weight_decay=0)
num_epochs = 1
num_training_steps = num_epochs * len(dataloader)

lr_scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

In [7]:
device = "cuda"

loss_fn = nn.CrossEntropyLoss()

def train_step(data):
    ids, masks, labels, adapter_ids = data
    adapter_ids = adapter_ids + torch.randint_like(adapter_ids, low=0, high=n_adapters // n_datasets - 1) * n_datasets
    adapter_ids %= n_adapters
    logits = lora_model(input_ids=ids.to(device), attention_mask=masks.to(device), adapter_ids=adapter_ids.to(device))[0]
    
    loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1).to('cuda'))
    optimizer.zero_grad()
    loss.backward()

    optimizer.step()  
    lr_scheduler.step()

    return loss.item()

In [8]:
from tqdm import tqdm
from time import time
running_loss = None
alpha = 0.95
start = time()

iters = 0

for epoch in range(num_epochs):
    if iters >= 100:
        break
    for i, batch in tqdm(enumerate(dataloader)):
        loss = train_step(batch)
        if not running_loss:
            running_loss = loss
        else:
            running_loss = running_loss * alpha + loss * (1 - alpha)
        if iters % 20 == 19:
            print("AVG TIME:", (time() - start) / iters)
            print("LOSS:", running_loss)
        iters += 1
        if iters >= 100:
            break

21it [00:03,  7.65it/s]

AVG TIME: 0.18297824106718363
LOSS: 5.942448457634257


41it [00:06,  7.76it/s]

AVG TIME: 0.1562349184965476
LOSS: 6.372899227386387


61it [00:08,  7.69it/s]

AVG TIME: 0.1468204360897258
LOSS: 6.472101177471724


81it [00:11,  7.90it/s]

AVG TIME: 0.14184650288352482
LOSS: 6.599937714595329


99it [00:13,  7.18it/s]

AVG TIME: 0.13927348695620143
LOSS: 6.564399818292757



