In [1]:
import onnxruntime.training.onnxblock as onnxblock
from onnxruntime.training.api import CheckpointState, Module, Optimizer
from onnxruntime.training import artifacts
from onnxruntime import InferenceSession
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import onnx
import io

## LoRA Layer

In [11]:
class LoRALayer(torch.nn.Module):
    """Low-rank adapter that computes ``alpha * (x @ A @ B)``.

    ``A`` has shape ``(in_dim, rank)`` and is drawn from a normal distribution
    scaled by ``1 / sqrt(rank)``; ``B`` has shape ``(rank, out_dim)`` and starts
    at zero, so a freshly built adapter outputs zeros.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output.
        rank: Inner dimension shared by ``A`` and ``B``. Defaults to 2.
        alpha: Scalar multiplier applied to the low-rank product. Defaults to 4.

    Raises:
        ValueError: If ``rank`` is smaller than 1.
    """

    def __init__(self, in_dim, out_dim, rank=2, alpha=4):
        super().__init__()
        if rank < 1:
            raise ValueError(f"rank must be a positive integer, got {rank}")
        # Same expression as before so seeded initialisation is unchanged.
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.A = torch.nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        # B starts at zero: the adapter contributes nothing until it is trained.
        self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha

    def forward(self, x):
        """Return the scaled low-rank projection of ``x``."""
        x = self.alpha * (x @ self.A @ self.B)
        return x

## LoRA combined with the linear layer

In [12]:
class LinearWithLoRA(torch.nn.Module):
    """Wrap a linear layer and add a LoRA branch to its output.

    The wrapped layer is stored as ``self.linear``; a ``LoRALayer`` sized from
    its ``in_features``/``out_features`` is stored as ``self.lora``.
    """

    def __init__(self, linear):
        super().__init__()
        self.linear = linear
        in_dim, out_dim = linear.in_features, linear.out_features
        self.lora = LoRALayer(in_dim, out_dim)

    def forward(self, x):
        """Return ``linear(x) + lora(x)``."""
        base_out = self.linear(x)
        lora_out = self.lora(x)
        return base_out + lora_out

### Single Layer Neural Network

In [13]:
# Hyperparameters: seed the RNG before anything random is created so the
# layer initialisation and the sample input below are reproducible.
random_seed = 123
torch.manual_seed(random_seed)

# One 10 -> 2 linear layer and a single random input row to feed through it.
layer = nn.Linear(in_features=10, out_features=2)
x = torch.randn(1, 10)

print(x)
print(layer)
print('Original output:', layer(x))

tensor([[ 0.5490,  0.3671,  0.1219,  0.6466, -1.4168,  0.8429, -0.6307,  1.2340,
          0.3127,  0.6972]])
Linear(in_features=10, out_features=2, bias=True)
Original output: tensor([[0.6639, 0.4487]], grad_fn=<AddmmBackward0>)


### Applying LoRA to linear layer

In [14]:
# Wrap the existing linear layer with a LoRA branch and push the same input
# through it. LoRALayer initialises B to zeros, so the printed values match
# the plain layer's output above.
layer_lora=LinearWithLoRA(layer)
print(layer_lora(x))

tensor([[0.6639, 0.4487]], grad_fn=<AddBackward0>)


In [17]:
# Build a fresh LoRA-wrapped layer (new random A) and run it once.
layer_lora = LinearWithLoRA(layer)
model_outputs = layer_lora(x)
# Normalise a single tensor result to a one-element list.
model_outputs = [model_outputs] if isinstance(model_outputs, torch.Tensor) else model_outputs

# Graph input/output names used by the ONNX export; axis 0 of both is left
# dynamic and labelled "batch_size".
input_names = ["input"]
output_names = ["output"]
dynamic_axes = {
    "input": {0: "batch_size"},
    "output": {0: "batch_size"},
}

In [18]:
# Export the LoRA-wrapped layer into an in-memory buffer, then parse the
# serialized bytes back into an ONNX model object.
f = io.BytesIO()
torch.onnx.export(
    layer_lora,
    x,
    f,
    # Graph naming and dynamic batch dimension.
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
    # Export in training mode, without constant folding.
    training=torch.onnx.TrainingMode.TRAINING,
    do_constant_folding=False,
    # Parameters are stored in the model and not exposed as graph inputs.
    export_params=True,
    keep_initializers_as_inputs=False,
    opset_version=14,
)
onnx_model = onnx.load_model_from_string(f.getvalue())

In [20]:
# Partition the parameter names by their requires_grad flag.
requires_grad = [
    name
    for name, param in layer_lora.named_parameters()
    if param.requires_grad
]
frozen_params = [
    name
    for name, param in layer_lora.named_parameters()
    if not param.requires_grad
]

# Write the training artifacts for the exported model into "Lora".
# Alternative loss to try: artifacts.LossType.CrossEntropyLoss
artifacts.generate_artifacts(
    onnx_model,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    loss=artifacts.LossType.MSELoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="Lora",
    additional_output_names=["output"],
)

In [21]:
# Load the generated training graph and print it in readable form.
model = onnx.load("Lora/training_model.onnx")
print(f'Model :\n\n{onnx.helper.printable_graph(model.graph)}')

Model :

graph main_graph (
  %input[FLOAT, batch_sizex10]
  %target[FLOAT, batch_sizex2]
  %linear.weight[FLOAT, 2x10]
  %linear.bias[FLOAT, 2]
  %lora.A[FLOAT, 10x2]
  %lora.B[FLOAT, 2x2]
  %linear.weight_grad.accumulation.buffer[FLOAT, 2x10]
  %linear.bias_grad.accumulation.buffer[FLOAT, 2]
  %lora.A_grad.accumulation.buffer[FLOAT, 10x2]
  %lora.B_grad.accumulation.buffer[FLOAT, 2x2]
  %lazy_reset_grad[BOOL, 1]
) initializers (
  %onnx::pow_exponent::3[FLOAT, 1]
  %/lora/Constant_output_0[FLOAT, scalar]
  %onnx::reducemean_output::6_grad[FLOAT, scalar]
  %/linear/Gemm_Grad/ReduceAxes_for_/linear/Gemm_Grad/dC_reduced[INT64, 1]
  %OneConstant_Type1[FLOAT, 1]
) {
  %/lora/MatMul_output_0 = MatMul(%input, %lora.A)
  %/lora/MatMul_1_output_0 = MatMul(%/lora/MatMul_output_0, %lora.B)
  %/lora/Mul_output_0 = Mul(%/lora/MatMul_1_output_0, %/lora/Constant_output_0)
  %/linear/Gemm_output_0 = Gemm[alpha = 1, beta = 1, transA = 0, transB = 1](%input, %linear.weight, %linear.bias)
  %output = A

In [22]:
# Parameter names that were passed to generate_artifacts as trainable.
print(requires_grad)

['linear.weight', 'linear.bias', 'lora.A', 'lora.B']


In [23]:
# Parameter names that were passed to generate_artifacts as frozen.
print(frozen_params)

[]


## SimpleNet with LoRA

In [29]:
# Pytorch class that we will use to generate the graphs.
class Net(torch.nn.Module):
    """Two-layer perceptron used to generate the graphs: Linear -> ReLU -> Linear."""

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()

        # Registration order (fc1, relu, fc2) determines the printed layout
        # and the order of named_parameters().
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, model_input):
        """Apply fc1, then ReLU, then fc2 to ``model_input``."""
        hidden = self.relu(self.fc1(model_input))
        return self.fc2(hidden)

In [30]:
# Create an instance of the plain network on the CPU.
device = "cpu"
batch_size = 64
input_size = 784
hidden_size = 500
output_size = 10
pt_model = Net(input_size, hidden_size, output_size).to(device)

In [38]:
# This LoRA code is equivalent to LinearWithLoRA
class LinearWithLoRAMerged(nn.Module):
    """Linear layer whose weight is merged with a LoRA update on every call.

    Instead of adding a separate LoRA branch to the layer output, the low-rank
    product ``A @ B`` is scaled by ``alpha``, transposed, and added to the
    wrapped layer's weight; a single ``F.linear`` then produces the result.
    The rank and alpha are those of the ``LoRALayer`` it creates.

    Args:
        linear: Layer to wrap. It must expose ``in_features``,
            ``out_features``, ``weight`` and ``bias``.
    """

    def __init__(self, linear):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features, linear.out_features
        )

    def forward(self, x):
        """Return ``F.linear(x, weight + alpha * (A @ B).T, bias)``."""
        lora = self.lora.A @ self.lora.B  # combine the LoRA matrices
        # A @ B is (in_features, out_features); transpose it so it can be
        # added to the wrapped layer's weight.
        combined_weight = self.linear.weight + self.lora.alpha * lora.T
        return F.linear(x, combined_weight, self.linear.bias)

In [34]:
import copy
from functools import partial

In [33]:
# Work on a deep copy so pt_model itself is not modified when layers of
# model_lora are replaced later on.
model_lora=copy.deepcopy(pt_model)
print(model_lora)

Net(
  (fc1): Linear(in_features=784, out_features=500, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)


In [39]:
# Replace both fully connected layers with LoRA wrappers; each original
# nn.Linear is kept as the wrapper's `.linear` attribute.
# NOTE(review): re-running this cell without first re-running the deepcopy
# cell would try to wrap an already wrapped layer — confirm before re-executing.
model_lora.fc1=LinearWithLoRAMerged(model_lora.fc1)
model_lora.fc2=LinearWithLoRAMerged(model_lora.fc2)

In [40]:
# Show the module tree after fc1 and fc2 were replaced by LoRA wrappers.
print(model_lora)

Net(
  (fc1): LinearWithLoRAMerged(
    (linear): Linear(in_features=784, out_features=500, bias=True)
    (lora): LoRALayer()
  )
  (relu): ReLU()
  (fc2): LinearWithLoRAMerged(
    (linear): Linear(in_features=500, out_features=10, bias=True)
    (lora): LoRALayer()
  )
)


In [50]:
# Generate a random input batch and run it through the LoRA model once.
model_inputs = (torch.randn(batch_size, input_size, device=device),)
model_outputs = model_lora(*model_inputs)
# Normalise a single tensor result to a one-element list.
model_outputs = [model_outputs] if isinstance(model_outputs, torch.Tensor) else model_outputs

# Graph input/output names; axis 0 of both is left dynamic ("batch_size").
input_names = ["input"]
output_names = ["output"]
dynamic_axes = {
    "input": {0: "batch_size"},
    "output": {0: "batch_size"},
}

# Export into an in-memory buffer, then parse the bytes back into an ONNX model.
f = io.BytesIO()
torch.onnx.export(
    model_lora,
    model_inputs,
    f,
    # Graph naming and dynamic batch dimension.
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
    # Export in training mode, without constant folding.
    training=torch.onnx.TrainingMode.TRAINING,
    do_constant_folding=False,
    # Parameters are stored in the model and not exposed as graph inputs.
    export_params=True,
    keep_initializers_as_inputs=False,
    opset_version=14,
)
onnx_model = onnx.load_model_from_string(f.getvalue())

In [51]:
# Partition the parameter names of the LoRA model by their requires_grad flag.
requires_grad = [
    name
    for name, param in model_lora.named_parameters()
    if param.requires_grad
]
frozen_params = [
    name
    for name, param in model_lora.named_parameters()
    if not param.requires_grad
]

# Write the training artifacts for the exported model into "SimpleNet_Lora".
# Alternative loss to try: artifacts.LossType.CrossEntropyLoss
artifacts.generate_artifacts(
    onnx_model,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    loss=artifacts.LossType.MSELoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="SimpleNet_Lora",
    additional_output_names=["output"],
)

In [52]:
# Load the training graph from the directory generate_artifacts wrote to.
# The directory name must match "SimpleNet_Lora" exactly: a differently cased
# path only resolves on case-insensitive filesystems.
model = onnx.load("SimpleNet_Lora/training_model.onnx")
print('Model :\n\n{}'.format(onnx.helper.printable_graph(model.graph)))

Model :

graph main_graph (
  %input[FLOAT, batch_sizex784]
  %target[FLOAT, batch_sizex10]
  %fc1.linear.weight[FLOAT, 500x784]
  %fc1.linear.bias[FLOAT, 500]
  %fc1.lora.A[FLOAT, 784x2]
  %fc1.lora.B[FLOAT, 2x500]
  %fc2.linear.weight[FLOAT, 10x500]
  %fc2.linear.bias[FLOAT, 10]
  %fc2.lora.A[FLOAT, 500x2]
  %fc2.lora.B[FLOAT, 2x10]
  %fc1.linear.weight_grad.accumulation.buffer[FLOAT, 500x784]
  %fc1.linear.bias_grad.accumulation.buffer[FLOAT, 500]
  %fc1.lora.A_grad.accumulation.buffer[FLOAT, 784x2]
  %fc1.lora.B_grad.accumulation.buffer[FLOAT, 2x500]
  %fc2.linear.weight_grad.accumulation.buffer[FLOAT, 10x500]
  %fc2.linear.bias_grad.accumulation.buffer[FLOAT, 10]
  %fc2.lora.A_grad.accumulation.buffer[FLOAT, 500x2]
  %fc2.lora.B_grad.accumulation.buffer[FLOAT, 2x10]
  %lazy_reset_grad[BOOL, 1]
) initializers (
  %onnx::pow_exponent::37[FLOAT, 1]
  %/fc1/Constant_output_0[FLOAT, scalar]
  %onnx::reducemean_output::40_grad[FLOAT, scalar]
  %/fc1/Gemm_Grad/ReduceAxes_for_/fc1/Gemm_Gr

In [53]:
# Parameter names that were passed to generate_artifacts as trainable.
print(requires_grad)

['fc1.linear.weight', 'fc1.linear.bias', 'fc1.lora.A', 'fc1.lora.B', 'fc2.linear.weight', 'fc2.linear.bias', 'fc2.lora.A', 'fc2.lora.B']


In [48]:
# Parameter names that were passed to generate_artifacts as frozen.
print(frozen_params)

[]


## Activating only the LoRA layers

In [54]:
def freeze_linear_layers(model):
    """Disable gradients for every ``nn.Linear`` found in ``model``.

    ``model`` itself and all of its descendants are inspected, so passing a
    bare ``nn.Linear`` freezes it as well. Parameters that do not belong to an
    ``nn.Linear`` (for example LoRA matrices) are left untouched.

    Args:
        model: Module to modify in place.

    Returns:
        None.
    """
    # modules() yields the module itself followed by all nested submodules,
    # which replaces the manual recursion over children().
    for module in model.modules():
        if isinstance(module, nn.Linear):
            for param in module.parameters():
                param.requires_grad = False

# Freeze the nn.Linear layers of the LoRA model, then list every parameter
# together with its requires_grad flag.
freeze_linear_layers(model_lora)
for name, param in model_lora.named_parameters():
    print(f'{name}:{param.requires_grad}')

fc1.linear.weight:False
fc1.linear.bias:False
fc1.lora.A:True
fc1.lora.B:True
fc2.linear.weight:False
fc2.linear.bias:False
fc2.lora.A:True
fc2.lora.B:True


In [55]:
# Re-collect the parameter names by requires_grad flag and regenerate the
# artifacts in "SimpleNet_Lora", the same directory as the previous run.
requires_grad = [
    name
    for name, param in model_lora.named_parameters()
    if param.requires_grad
]
frozen_params = [
    name
    for name, param in model_lora.named_parameters()
    if not param.requires_grad
]

# Alternative loss to try: artifacts.LossType.CrossEntropyLoss
artifacts.generate_artifacts(
    onnx_model,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    loss=artifacts.LossType.MSELoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="SimpleNet_Lora",
    additional_output_names=["output"],
)

In [56]:
# Parameter names that were passed to generate_artifacts as trainable.
print(requires_grad)

['fc1.lora.A', 'fc1.lora.B', 'fc2.lora.A', 'fc2.lora.B']


In [57]:
# Load the regenerated training graph from the directory generate_artifacts
# wrote to. The directory name must match "SimpleNet_Lora" exactly: a
# differently cased path only resolves on case-insensitive filesystems.
model = onnx.load("SimpleNet_Lora/training_model.onnx")
print('Model :\n\n{}'.format(onnx.helper.printable_graph(model.graph)))

Model :

graph main_graph (
  %input[FLOAT, batch_sizex784]
  %target[FLOAT, batch_sizex10]
  %fc1.linear.weight[FLOAT, 500x784]
  %fc1.linear.bias[FLOAT, 500]
  %fc1.lora.A[FLOAT, 784x2]
  %fc1.lora.B[FLOAT, 2x500]
  %fc2.linear.weight[FLOAT, 10x500]
  %fc2.linear.bias[FLOAT, 10]
  %fc2.lora.A[FLOAT, 500x2]
  %fc2.lora.B[FLOAT, 2x10]
  %fc1.lora.A_grad.accumulation.buffer[FLOAT, 784x2]
  %fc1.lora.B_grad.accumulation.buffer[FLOAT, 2x500]
  %fc2.lora.A_grad.accumulation.buffer[FLOAT, 500x2]
  %fc2.lora.B_grad.accumulation.buffer[FLOAT, 2x10]
  %lazy_reset_grad[BOOL, 1]
) initializers (
  %onnx::pow_exponent::46[FLOAT, 1]
  %/fc1/Constant_output_0[FLOAT, scalar]
  %onnx::reducemean_output::49_grad[FLOAT, scalar]
  %OneConstant_Type1[FLOAT, 1]
) {
  %/fc1/MatMul_output_0 = MatMul(%fc1.lora.A, %fc1.lora.B)
  %/fc1/Transpose_output_0 = Transpose[perm = [1, 0]](%/fc1/MatMul_output_0)
  %/fc1/Mul_output_0 = Mul(%/fc1/Transpose_output_0, %/fc1/Constant_output_0)
  %/fc1/Add_output_0 = Add(%fc