In [1]:
import onnxruntime.training.onnxblock as onnxblock
from onnxruntime.training.api import CheckpointState, Module, Optimizer
from onnxruntime.training import artifacts
from onnxruntime import InferenceSession
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import onnx
import io

## Create the Transformer (Attention + FFN) block : declare Q,K,V as linear layers

In [2]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads,dropout=0.1):
        super(TransformerBlock, self).__init__()
        d_in = 32
        d_out =32
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out,bias=False)
        self.W_key   = nn.Linear(d_in, d_out,bias=False)
        self.W_value = nn.Linear(d_in, d_out,bias=False)
        #self.W_out = nn.Linear(d_in,d_out,bias=False)
        self.dropout = nn.Dropout(dropout) # New
        #self.attention = nn.MultiheadAttention(embed_size, heads)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, 4 * embed_size),
            nn.ReLU(),
            nn.Linear(4 * embed_size, embed_size)
        )
        self.layer_norm1 = nn.LayerNorm(embed_size)
        self.layer_norm2 = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        # Multi-head self-attention
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        #out = self.W_out(x)

        attn_scores = queries @ keys.T # Changed transpose
        #attn_scores.masked_fill_(  # New, _ ops are in-place
        #    self.mask.bool()[:n_tokens, :n_tokens], -torch.inf) 
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights) # New

        attn_output = attn_weights @ values
        #return context_vec

        # Add and norm
        out1 = self.layer_norm1(x + attn_output)
        
        # Feed forward network
        ff_output = self.feed_forward(out1)
        ff_output = self.dropout(ff_output)
        # Add and norm
        out2 = self.layer_norm2(out1 + ff_output)
        
        return out2


In [3]:
# Example usage of TransformerBlock
#input_tensor = torch.randn(10, 32, 512)  # (sequence_length, batch_size, embed_size).
#input_tensor = torch.randn(10, 32, 32) 
inputs = torch.rand(6,32)
transformer_block = TransformerBlock(embed_size=32, heads=8)
output_tensor = transformer_block(inputs)
print(output_tensor.shape)  # should print: torch.Size([10, 32, 512])
print(transformer_block)

torch.Size([6, 32])
TransformerBlock(
  (W_query): Linear(in_features=32, out_features=32, bias=False)
  (W_key): Linear(in_features=32, out_features=32, bias=False)
  (W_value): Linear(in_features=32, out_features=32, bias=False)
  (dropout): Dropout(p=0.1, inplace=False)
  (feed_forward): Sequential(
    (0): Linear(in_features=32, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=32, bias=True)
  )
  (layer_norm1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
  (layer_norm2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
)


## Lora Layers

In [4]:
class LoRALayer(torch.nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        rank = 2
        alpha = 4 
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.A = torch.nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha

    def forward(self, x):
        x = self.alpha * (x @ self.A @ self.B)
        return x

In [5]:
class LinearWithLoRA(torch.nn.Module):
    def __init__(self, linear):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features, linear.out_features
        )

    def forward(self, x):
        return self.linear(x) + self.lora(x)

In [6]:
# This LoRA code is equivalent to LinearWithLoRA
class LinearWithLoRAMerged(nn.Module):
    def __init__(self, linear):
        super().__init__()
        rank = 2
        alpha = 4 
        self.linear=linear
        self.lora=LoRALayer(
            linear.in_features, linear.out_features
        )
    
    def forward(self, x):
        lora=self.lora.A @ self.lora.B # combine LoRA metrices
        # then combine LoRA original weights
        combined_weight=self.linear.weight+self.lora.alpha*lora.T
        return F.linear(x, combined_weight, self.linear.bias)

## Add Lora layers to the transformer : Q & V

In [7]:
import copy
from functools import partial

In [8]:
transformer_lora = copy.deepcopy(transformer_block)
print(transformer_lora)

TransformerBlock(
  (W_query): Linear(in_features=32, out_features=32, bias=False)
  (W_key): Linear(in_features=32, out_features=32, bias=False)
  (W_value): Linear(in_features=32, out_features=32, bias=False)
  (dropout): Dropout(p=0.1, inplace=False)
  (feed_forward): Sequential(
    (0): Linear(in_features=32, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=32, bias=True)
  )
  (layer_norm1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
  (layer_norm2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
)


In [9]:
transformer_lora.W_query=LinearWithLoRAMerged(transformer_lora.W_query)
transformer_lora.W_value=LinearWithLoRAMerged(transformer_lora.W_value)

In [10]:
print(transformer_lora)

TransformerBlock(
  (W_query): LinearWithLoRAMerged(
    (linear): Linear(in_features=32, out_features=32, bias=False)
    (lora): LoRALayer()
  )
  (W_key): Linear(in_features=32, out_features=32, bias=False)
  (W_value): LinearWithLoRAMerged(
    (linear): Linear(in_features=32, out_features=32, bias=False)
    (lora): LoRALayer()
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (feed_forward): Sequential(
    (0): Linear(in_features=32, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=32, bias=True)
  )
  (layer_norm1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
  (layer_norm2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
)


## Create ONNX graph

In [11]:
model_outputs = transformer_lora(inputs)
if isinstance(model_outputs, torch.Tensor):
    model_outputs = [model_outputs]
    
input_names = ["input"]
output_names = ["output"]
dynamic_axes = {"input": {0: "batch_size"}, "output": {0: "batch_size"}}

In [12]:
f = io.BytesIO()
torch.onnx.export(
    transformer_lora,
    inputs,
    f,
    input_names=input_names,
    output_names=output_names,
    opset_version=14,
    do_constant_folding=False,
    training=torch.onnx.TrainingMode.TRAINING,
    dynamic_axes=dynamic_axes,
    export_params=True,
    keep_initializers_as_inputs=False,
)
onnx_model = onnx.load_model_from_string(f.getvalue())

In [13]:
requires_grad = [name for name, param in transformer_lora.named_parameters() if param.requires_grad]

frozen_params = [name for name, param in transformer_lora.named_parameters() if not param.requires_grad]

In [14]:
print(requires_grad)

['W_query.linear.weight', 'W_query.lora.A', 'W_query.lora.B', 'W_key.weight', 'W_value.linear.weight', 'W_value.lora.A', 'W_value.lora.B', 'feed_forward.0.weight', 'feed_forward.0.bias', 'feed_forward.2.weight', 'feed_forward.2.bias', 'layer_norm1.weight', 'layer_norm1.bias', 'layer_norm2.weight', 'layer_norm2.bias']


In [15]:
def freeze_linear_layers(model):
    for child in model.children():
        if isinstance(child, nn.Linear):
            for param in child.parameters():
                param.requires_grad=False
        else:
            # recursively freeze linear layers in children modules
            freeze_linear_layers(child)

freeze_linear_layers(transformer_lora)
for name, param in transformer_lora.named_parameters():
    print(f'{name}:{param.requires_grad}')

W_query.linear.weight:False
W_query.lora.A:True
W_query.lora.B:True
W_key.weight:False
W_value.linear.weight:False
W_value.lora.A:True
W_value.lora.B:True
feed_forward.0.weight:False
feed_forward.0.bias:False
feed_forward.2.weight:False
feed_forward.2.bias:False
layer_norm1.weight:True
layer_norm1.bias:True
layer_norm2.weight:True
layer_norm2.bias:True


In [16]:
artifacts.generate_artifacts(
    onnx_model,
    optimizer=artifacts.OptimType.AdamW,
    loss=artifacts.LossType.CrossEntropyLoss, #Specify the loss function, try with different ones
    #loss=artifacts.LossType.MSELoss,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    artifact_directory="Transformer_Lora",
    additional_output_names=["output"])

In [17]:
model = onnx.load("Transformer_Lora/training_model.onnx")
print('Model :\n\n{}'.format(onnx.helper.printable_graph(model.graph)))

Model :

graph main_graph (
  %input[FLOAT, batch_sizex32]
  %labels[INT64, batch_size]
  %W_query.linear.weight[FLOAT, 32x32]
  %W_query.lora.A[FLOAT, 32x2]
  %W_query.lora.B[FLOAT, 2x32]
  %W_key.weight[FLOAT, 32x32]
  %W_value.linear.weight[FLOAT, 32x32]
  %W_value.lora.A[FLOAT, 32x2]
  %W_value.lora.B[FLOAT, 2x32]
  %feed_forward.0.weight[FLOAT, 128x32]
  %feed_forward.0.bias[FLOAT, 128]
  %feed_forward.2.weight[FLOAT, 32x128]
  %feed_forward.2.bias[FLOAT, 32]
  %layer_norm1.weight[FLOAT, 32]
  %layer_norm1.bias[FLOAT, 32]
  %layer_norm2.weight[FLOAT, 32]
  %layer_norm2.bias[FLOAT, 32]
  %W_query.linear.weight_grad.accumulation.buffer[FLOAT, 32x32]
  %W_query.lora.A_grad.accumulation.buffer[FLOAT, 32x2]
  %W_query.lora.B_grad.accumulation.buffer[FLOAT, 2x32]
  %W_key.weight_grad.accumulation.buffer[FLOAT, 32x32]
  %W_value.linear.weight_grad.accumulation.buffer[FLOAT, 32x32]
  %W_value.lora.A_grad.accumulation.buffer[FLOAT, 32x2]
  %W_value.lora.B_grad.accumulation.buffer[FLOAT, 2x

## To add LoRA to all layers 
#### Accessing the FFN

In [18]:
print(transformer_lora.feed_forward)

Sequential(
  (0): Linear(in_features=32, out_features=128, bias=True)
  (1): ReLU()
  (2): Linear(in_features=128, out_features=32, bias=True)
)


In [19]:
print(transformer_lora.feed_forward[0])
print(transformer_lora.feed_forward[1])
print(transformer_lora.feed_forward[2])

Linear(in_features=32, out_features=128, bias=True)
ReLU()
Linear(in_features=128, out_features=32, bias=True)


In [20]:
transformer_lora.W_key=LinearWithLoRAMerged(transformer_lora.W_key)
transformer_lora.feed_forward[0]=LinearWithLoRAMerged(transformer_lora.feed_forward[0])
transformer_lora.feed_forward[2]=LinearWithLoRAMerged(transformer_lora.feed_forward[2])
print(transformer_lora)

TransformerBlock(
  (W_query): LinearWithLoRAMerged(
    (linear): Linear(in_features=32, out_features=32, bias=False)
    (lora): LoRALayer()
  )
  (W_key): LinearWithLoRAMerged(
    (linear): Linear(in_features=32, out_features=32, bias=False)
    (lora): LoRALayer()
  )
  (W_value): LinearWithLoRAMerged(
    (linear): Linear(in_features=32, out_features=32, bias=False)
    (lora): LoRALayer()
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (feed_forward): Sequential(
    (0): LinearWithLoRAMerged(
      (linear): Linear(in_features=32, out_features=128, bias=True)
      (lora): LoRALayer()
    )
    (1): ReLU()
    (2): LinearWithLoRAMerged(
      (linear): Linear(in_features=128, out_features=32, bias=True)
      (lora): LoRALayer()
    )
  )
  (layer_norm1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
  (layer_norm2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
)


In [21]:
def freeze_linear_layers(model):
    for child in model.children():
        if isinstance(child, nn.Linear):
            for param in child.parameters():
                param.requires_grad=False
        else:
            # recursively freeze linear layers in children modules
            freeze_linear_layers(child)

freeze_linear_layers(transformer_lora)
for name, param in transformer_lora.named_parameters():
    print(f'{name}:{param.requires_grad}')

W_query.linear.weight:False
W_query.lora.A:True
W_query.lora.B:True
W_key.linear.weight:False
W_key.lora.A:True
W_key.lora.B:True
W_value.linear.weight:False
W_value.lora.A:True
W_value.lora.B:True
feed_forward.0.linear.weight:False
feed_forward.0.linear.bias:False
feed_forward.0.lora.A:True
feed_forward.0.lora.B:True
feed_forward.2.linear.weight:False
feed_forward.2.linear.bias:False
feed_forward.2.lora.A:True
feed_forward.2.lora.B:True
layer_norm1.weight:True
layer_norm1.bias:True
layer_norm2.weight:True
layer_norm2.bias:True


In [22]:
model_outputs = transformer_lora(inputs)
if isinstance(model_outputs, torch.Tensor):
    model_outputs = [model_outputs]
    
input_names = ["input"]
output_names = ["output"]
dynamic_axes = {"input": {0: "batch_size"}, "output": {0: "batch_size"}}

In [23]:
f = io.BytesIO()
torch.onnx.export(
    transformer_lora,
    inputs,
    f,
    input_names=input_names,
    output_names=output_names,
    opset_version=14,
    do_constant_folding=False,
    training=torch.onnx.TrainingMode.TRAINING,
    dynamic_axes=dynamic_axes,
    export_params=True,
    keep_initializers_as_inputs=False,
)
onnx_model = onnx.load_model_from_string(f.getvalue())

In [24]:
requires_grad = [name for name, param in transformer_lora.named_parameters() if param.requires_grad]

frozen_params = [name for name, param in transformer_lora.named_parameters() if not param.requires_grad]
print(requires_grad)

['W_query.lora.A', 'W_query.lora.B', 'W_key.lora.A', 'W_key.lora.B', 'W_value.lora.A', 'W_value.lora.B', 'feed_forward.0.lora.A', 'feed_forward.0.lora.B', 'feed_forward.2.lora.A', 'feed_forward.2.lora.B', 'layer_norm1.weight', 'layer_norm1.bias', 'layer_norm2.weight', 'layer_norm2.bias']


In [25]:
artifacts.generate_artifacts(
    onnx_model,
    optimizer=artifacts.OptimType.AdamW,
    loss=artifacts.LossType.CrossEntropyLoss, #Specify the loss function, try with different ones
    #loss=artifacts.LossType.MSELoss,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    artifact_directory="Transformer_Lora/All_layers",
    additional_output_names=["output"])

In [26]:
model = onnx.load("Transformer_Lora/All_layers/training_model_all.onnx")
print('Model :\n\n{}'.format(onnx.helper.printable_graph(model.graph)))

Model :

graph main_graph (
  %input[FLOAT, batch_sizex32]
  %labels[INT64, batch_size]
  %W_query.linear.weight[FLOAT, 32x32]
  %W_query.lora.A[FLOAT, 32x2]
  %W_query.lora.B[FLOAT, 2x32]
  %W_key.linear.weight[FLOAT, 32x32]
  %W_key.lora.A[FLOAT, 32x2]
  %W_key.lora.B[FLOAT, 2x32]
  %W_value.linear.weight[FLOAT, 32x32]
  %W_value.lora.A[FLOAT, 32x2]
  %W_value.lora.B[FLOAT, 2x32]
  %feed_forward.0.linear.weight[FLOAT, 128x32]
  %feed_forward.0.linear.bias[FLOAT, 128]
  %feed_forward.0.lora.A[FLOAT, 32x2]
  %feed_forward.0.lora.B[FLOAT, 2x128]
  %feed_forward.2.linear.weight[FLOAT, 32x128]
  %feed_forward.2.linear.bias[FLOAT, 32]
  %feed_forward.2.lora.A[FLOAT, 128x2]
  %feed_forward.2.lora.B[FLOAT, 2x32]
  %layer_norm1.weight[FLOAT, 32]
  %layer_norm1.bias[FLOAT, 32]
  %layer_norm2.weight[FLOAT, 32]
  %layer_norm2.bias[FLOAT, 32]
  %W_query.lora.A_grad.accumulation.buffer[FLOAT, 32x2]
  %W_query.lora.B_grad.accumulation.buffer[FLOAT, 2x32]
  %W_key.lora.A_grad.accumulation.buffer[FLO

In [28]:
print(transformer_lora)

TransformerBlock(
  (W_query): LinearWithLoRAMerged(
    (linear): Linear(in_features=32, out_features=32, bias=False)
    (lora): LoRALayer()
  )
  (W_key): Linear(in_features=32, out_features=32, bias=False)
  (W_value): LinearWithLoRAMerged(
    (linear): Linear(in_features=32, out_features=32, bias=False)
    (lora): LoRALayer()
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (feed_forward): Sequential(
    (0): LinearWithLoRAMerged(
      (linear): Linear(in_features=32, out_features=128, bias=True)
      (lora): LoRALayer()
    )
    (1): ReLU()
    (2): LinearWithLoRAMerged(
      (linear): Linear(in_features=128, out_features=32, bias=True)
      (lora): LoRALayer()
    )
  )
  (layer_norm1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
  (layer_norm2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
)
