# Step 1: Define a trainer without LoRA to train the base model

In [1]:
import marveltoolbox as mt 
import torch
import torch.nn as nn
import torch.nn.functional as F

class Confs(mt.BaseConfs):
    """Configuration for the base (non-LoRA) MNIST classifier demo.

    The ``get_*`` methods only assign attributes on ``self``; they are
    presumably invoked as hooks by ``mt.BaseConfs.__init__`` (confirm against
    marveltoolbox).
    """

    def __init__(self):
        super().__init__()

    def get_dataset(self):
        """Set the dataset name and the training hyper-parameters."""
        self.dataset = 'mnist'
        self.nc = 1    # passed as first argument to Enet32; presumably input channels
        self.nz = 10   # passed as second argument to Enet32; presumably output size
        self.batch_size = 128
        self.epochs = 50
        self.seed = 0

    def get_flag(self):
        """Set the experiment flag (seen in the log/checkpoint file names)."""
        self.flag = 'demo-{}-clf'.format(self.dataset)

    def get_device(self):
        """Select the compute device: CUDA, then Apple MPS, then CPU.

        The previous version fell back to ``"mps"`` unconditionally, which
        fails on machines that have neither CUDA nor MPS as soon as a model
        is moved to the device.
        """
        self.device_ids = [0]
        self.ngpu = len(self.device_ids)
        if torch.cuda.is_available() and self.ngpu > 0:
            device_name = "cuda:{}".format(self.device_ids[0])
        else:
            # ``torch.backends.mps`` is absent on older PyTorch releases.
            mps_backend = getattr(torch.backends, "mps", None)
            if mps_backend is not None and mps_backend.is_available():
                device_name = "mps"
            else:
                device_name = "cpu"
        self.device = torch.device(device_name)


class Trainer(mt.BaseTrainer, Confs):
    """Trainer for the base MNIST classifier (no LoRA).

    Builds an ``Enet32`` classifier under ``self.models['C']``, an Adam
    optimizer over all of its parameters under ``self.optims['C']``, and the
    train/val/test data loaders. The best validation accuracy seen so far is
    kept in ``self.records['acc']``.
    """

    def __init__(self):
        # Populate the configuration first, then hand this same object to
        # BaseTrainer as its config.
        Confs.__init__(self)
        mt.BaseTrainer.__init__(self, self)

        self.models['C'] = mt.nn.dcgan.Enet32(self.nc, self.nz).to(self.device)

        self.optims['C'] = torch.optim.Adam(
            self.models['C'].parameters(), lr=1e-4, betas=(0.5, 0.99))

        # NOTE(review): the meaning of the positional arguments (1.0, 0.8, 32,
        # None, False) is defined by mt.datasets.load_data -- confirm there.
        self.train_loader, self.val_loader, self.test_loader, _ = \
            mt.datasets.load_data(self.dataset, 1.0, 0.8, self.batch_size, 32, None, False)

        self.records['acc'] = 0.0

    def train(self, epoch):
        """Run one pass over ``self.train_loader``.

        Args:
            epoch: Current epoch index; only forwarded to ``print_logs``.

        Returns:
            The cross-entropy loss of the last batch as a float, or ``nan``
            when the loader yielded no batches (previously this case raised
            ``UnboundLocalError``).
        """
        self.models['C'].train()
        last_loss = float('nan')
        for i, (x, y) in enumerate(self.train_loader):
            x, y = x.to(self.device), y.to(self.device)
            scores = self.models['C'](x)
            loss = F.cross_entropy(scores, y)
            self.optims['C'].zero_grad()
            loss.backward()
            self.optims['C'].step()
            last_loss = loss.item()
            if i % 100 == 0:
                # Log every 100th batch only.
                self.logs['Train Loss'] = last_loss
                self.print_logs(epoch, i)

        return last_loss

    def eval(self, epoch):
        """Measure accuracy on ``self.val_loader`` and update the best record.

        Accuracy is computed over the samples the loader actually yields
        rather than ``len(self.val_loader.dataset)``, so it stays correct if
        the loader iterates only part of its dataset, and an empty loader
        gives 0.0 instead of a division by zero.

        Args:
            epoch: Current epoch index (unused here).

        Returns:
            True when this accuracy is at least the best recorded so far.
        """
        self.models['C'].eval()
        correct = 0.0
        total = 0
        with torch.no_grad():
            for x, y in self.val_loader:
                x, y = x.to(self.device), y.to(self.device)
                scores = self.models['C'](x)
                pred_y = torch.argmax(scores, dim=1)
                correct += torch.sum(pred_y == y).item()
                total += len(x)
        acc = correct / total if total else 0.0
        # Ties count as "best", matching the original ``>=`` comparison.
        is_best = acc >= self.records['acc']
        if is_best:
            self.records['acc'] = acc
        print('acc: {}'.format(acc))
        return is_best


if __name__ == '__main__':
    # Build the base trainer and start its run loop. `trainer_base` is kept as
    # a module-level name because a later cell passes its model to LoRATrainer.
    # NOTE(review): in the captured output below an existing checkpoint
    # (epoch 50) was loaded with these flags -- presumably retrain=False
    # resumes from a saved checkpoint; confirm against BaseTrainer.run.
    trainer_base = Trainer()
    trainer_base.run(load_best=False, retrain=False)

Configs:
Flag:       demo-mnist-clf
Batch size: 128
Epochs:     50
device:     cuda:0

image range [0, 1]
Set random seed to: 0
Log file save at:  ./logs/demo-mnist-clf.log
=> loading checkpoint './chkpts/checkpoint_demo-mnist-clf.pth.tar'
=> loaded checkpoint (epoch 50)


# Step 2: Define a LoRA trainer that takes a base model as input and injects LoRA weights.

In [2]:
import marveltoolbox as mt 
import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRAConfs(mt.BaseConfs):
    """Configuration for the LoRA fine-tuning run of the MNIST classifier demo.

    Identical to the base configuration except for the flag, which gets a
    ``-lora`` suffix so logs and checkpoints are stored under separate names.
    The ``get_*`` methods only assign attributes on ``self``; they are
    presumably invoked as hooks by ``mt.BaseConfs.__init__`` (confirm against
    marveltoolbox).
    """

    def __init__(self):
        super().__init__()

    def get_dataset(self):
        """Set the dataset name and the training hyper-parameters."""
        self.dataset = 'mnist'
        self.nc = 1
        self.nz = 10
        self.batch_size = 128
        self.epochs = 50
        self.seed = 0

    def get_flag(self):
        """Set the experiment flag (seen in the log/checkpoint file names)."""
        self.flag = 'demo-{}-clf-lora'.format(self.dataset)

    def get_device(self):
        """Select the compute device: CUDA, then Apple MPS, then CPU.

        The previous version fell back to ``"mps"`` unconditionally, which
        fails on machines that have neither CUDA nor MPS as soon as a model
        is moved to the device.
        """
        self.device_ids = [0]
        self.ngpu = len(self.device_ids)
        if torch.cuda.is_available() and self.ngpu > 0:
            device_name = "cuda:{}".format(self.device_ids[0])
        else:
            # ``torch.backends.mps`` is absent on older PyTorch releases.
            mps_backend = getattr(torch.backends, "mps", None)
            if mps_backend is not None and mps_backend.is_available():
                device_name = "mps"
            else:
                device_name = "cpu"
        self.device = torch.device(device_name)


class LoRATrainer(mt.BaseTrainer, LoRAConfs):
    """Trainer that fine-tunes a given base model through injected LoRA weights.

    The base model is wrapped with ``mt.utils.inject_lora`` and only the
    parameters returned by ``mt.utils.get_lora_parameters`` are handed to the
    optimizer. The best validation accuracy seen so far is kept in
    ``self.records['acc']``.
    """

    def __init__(self, BaseModel):
        """Inject LoRA into ``BaseModel`` and set up optimizer and data.

        Args:
            BaseModel: A regular ``nn.Module`` to adapt.
                NOTE(review): whether ``inject_lora`` modifies this module in
                place or returns a copy is not visible here -- confirm.
        """
        LoRAConfs.__init__(self)
        mt.BaseTrainer.__init__(self, self)

        # Inject LoRA (Low-Rank Adaptation) into a regular nn.Module using
        # the inject_lora function.
        self.models['C'] = mt.utils.inject_lora(BaseModel, r=2, alpha=1).to(self.device)

        # Extract the trainable parameters of the LoRA module, then pass them
        # to the optimizer.
        self.optims['C'] = torch.optim.Adam(
            mt.utils.get_lora_parameters(self.models['C']), lr=1e-4, betas=(0.5, 0.99))

        # NOTE(review): the meaning of the positional arguments (1.0, 0.8, 32,
        # None, False) is defined by mt.datasets.load_data -- confirm there.
        self.train_loader, self.val_loader, self.test_loader, _ = \
            mt.datasets.load_data(self.dataset, 1.0, 0.8, self.batch_size, 32, None, False)

        self.records['acc'] = 0.0

    def train(self, epoch):
        """Run one pass over ``self.train_loader``.

        Args:
            epoch: Current epoch index; only forwarded to ``print_logs``.

        Returns:
            The cross-entropy loss of the last batch as a float, or ``nan``
            when the loader yielded no batches (previously this case raised
            ``UnboundLocalError``).
        """
        self.models['C'].train()
        last_loss = float('nan')
        for i, (x, y) in enumerate(self.train_loader):
            x, y = x.to(self.device), y.to(self.device)
            scores = self.models['C'](x)
            loss = F.cross_entropy(scores, y)
            self.optims['C'].zero_grad()
            loss.backward()
            self.optims['C'].step()
            last_loss = loss.item()
            if i % 100 == 0:
                # Log every 100th batch only.
                self.logs['Train Loss'] = last_loss
                self.print_logs(epoch, i)

        return last_loss

    def eval(self, epoch):
        """Measure accuracy on ``self.val_loader`` and update the best record.

        Accuracy is computed over the samples the loader actually yields
        rather than ``len(self.val_loader.dataset)``, so it stays correct if
        the loader iterates only part of its dataset, and an empty loader
        gives 0.0 instead of a division by zero.

        Args:
            epoch: Current epoch index (unused here).

        Returns:
            True when this accuracy is at least the best recorded so far.
        """
        self.models['C'].eval()
        correct = 0.0
        total = 0
        with torch.no_grad():
            for x, y in self.val_loader:
                x, y = x.to(self.device), y.to(self.device)
                scores = self.models['C'](x)
                pred_y = torch.argmax(scores, dim=1)
                correct += torch.sum(pred_y == y).item()
                total += len(x)
        acc = correct / total if total else 0.0
        # Ties count as "best", matching the original ``>=`` comparison.
        is_best = acc >= self.records['acc']
        if is_best:
            self.records['acc'] = acc
        print('acc: {}'.format(acc))
        return is_best


if __name__ == '__main__':
    # `trainer_lora` is kept as a module-level name because a later cell
    # extracts the LoRA state dict from its model.
    trainer_lora = LoRATrainer(trainer_base.models['C']) # Feed the base model to LoRATrainer
    # NOTE(review): in the captured output below an existing checkpoint
    # (epoch 50) was loaded with these flags -- presumably retrain=False
    # resumes from a saved checkpoint; confirm against BaseTrainer.run.
    trainer_lora.run(load_best=False, retrain=False)

Configs:
Flag:       demo-mnist-clf-lora
Batch size: 128
Epochs:     50
device:     cuda:0

image range [0, 1]
Set random seed to: 0
Log file save at:  ./logs/demo-mnist-clf-lora.log
=> loading checkpoint './chkpts/checkpoint_demo-mnist-clf-lora.pth.tar'
=> loaded checkpoint (epoch 50)


# Step 3: Extract, inject, or merge an existing LoRA

In [4]:
# Load and extract the existing LoRA weights from the fine-tuned model.
lora = mt.utils.extract_lora_state_dict(trainer_lora.models['C'])

# Inject LoRA into the new model.
# NOTE(review): Trainer() on its own does not appear to load the trained base
# checkpoint (the captured output shows no "loading checkpoint" line here,
# unlike after run() in Step 1) -- confirm this is intended for the demo.
new_trainer_base = Trainer()
# r and alpha repeat the values used when the LoRA weights were trained.
# The result is moved to the trainer's device, as LoRATrainer does, so the
# newly created LoRA parameters end up on the same device as the base weights.
new_trainer_base.models['C'] = mt.utils.inject_lora(
    new_trainer_base.models['C'], r=2, alpha=1, lora_state_dict=lora
).to(new_trainer_base.device)
print(new_trainer_base.models['C'])

Configs:
Flag:       demo-mnist-clf
Batch size: 128
Epochs:     50
device:     cuda:0

image range [0, 1]
Enet32(
  (main): Sequential(
    (0): NormalizedModel()
    (1): LoRAConv2d(1, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), dilation=(1, 1), groups=1, r=2, alpha=1)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): LoRAConv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), dilation=(1, 1), groups=1, r=2, alpha=1)
    (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.2, inplace=True)
    (6): LoRAConv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), dilation=(1, 1), groups=1, r=2, alpha=1)
    (7): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (output): Sequential(
    (0): LoRAConv2d(1024, 128, kernel_size=(4, 4), stride=(1, 1), padding=(0, 0), dilation=(1, 1), group

In [5]:
# Merge the LoRA weights into the original weights and store the result back
# on the trainer (the printed model below shows plain Conv2d layers again).
merged_model = mt.utils.merge_lora_weights(new_trainer_base.models['C'])
new_trainer_base.models['C'] = merged_model
print(merged_model)

Enet32(
  (main): Sequential(
    (0): NormalizedModel()
    (1): Conv2d(1, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.2, inplace=True)
    (6): Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (output): Sequential(
    (0): Conv2d(1024, 128, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Conv2d(128, 20, kernel_size=(1, 1), stride=(1, 1))
  )
)
