In [None]:
import io

from torchtune.models.llama2 import llama2_7b, lora_llama2_7b
from torchtune.modules.peft import get_adapter_params, set_trainable_params
import onnxruntime.training.onnxblock as onnxblock
from onnxruntime.training.api import CheckpointState, Module, Optimizer
from onnxruntime.training import artifacts
from onnxruntime import InferenceSession
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import onnx
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM

In [None]:
# Build Llama2-7B without any LoRA layers (reference model).
# NOTE(review): `base_model` is not used by any of the visible cells below, yet this
# presumably materializes a full 7B-parameter model. Drop this cell (or `del base_model`)
# if nothing further down the notebook depends on it.
base_model = llama2_7b()

In [None]:
# lora_llama2_7b uses the same defaults as llama2_7b; the only decision made here is
# which linear layers get LoRA adapters.
#   * Self-attention projections that can be adapted: "q_proj", "k_proj", "v_proj",
#     and "output_proj".
#   * Linear layers outside self-attention can be adapted too, by passing
#     apply_lora_to_mlp=True and/or apply_lora_to_output=True.
LORA_ATTN_MODULES = ["q_proj", "v_proj"]

lora_model = lora_llama2_7b(lora_attn_modules=LORA_ATTN_MODULES)

In [None]:
# Sample input used for the forward pass and for tracing the ONNX export:
# two rows of four integer ids each -> shape (2, 4), dtype int64 (inferred from Python ints).
batch = torch.tensor(
    [
        [6109, 3626, 6100, 345],
        [6109, 1110, 6622, 257],
    ]
)

In [None]:
# Forward graph: export the LoRA-wrapped model to ONNX; the next cell derives the
# onnxruntime-training artifacts from this file.

# Sanity-check forward pass on the sample batch. The result is only inspected, so
# run it without autograd: recording a graph over the whole model would keep every
# intermediate activation alive for no benefit.
with torch.no_grad():
    model_outputs = lora_model(batch)

# Normalize to a list so single- and multi-output models can be handled uniformly.
if isinstance(model_outputs, torch.Tensor):
    model_outputs = [model_outputs]

input_names = ["input"]
# "output" is referenced again as an additional output when generating artifacts.
output_names = ["output"]
# Only the batch dimension is declared dynamic; every other dimension (including the
# sequence length) is exported with the size seen in `batch` during tracing.
dynamic_axes = {"input": {0: "batch_size"}, "output": {0: "batch_size"}}

torch.onnx.export(
    lora_model,
    batch,
    "torchtune_lora_llama2.onnx",
    input_names=input_names,
    output_names=output_names,
    opset_version=14,
    # Constant folding is disabled so parameters stay as their own initializers;
    # the artifacts step refers to them by their PyTorch parameter names.
    do_constant_folding=False,
    # Export training-mode behavior, since this graph is the basis for training.
    training=torch.onnx.TrainingMode.TRAINING,
    dynamic_axes=dynamic_axes,
    export_params=True,
    keep_initializers_as_inputs=False,
)



In [None]:
# Backward graph: generate the onnxruntime-training artifacts (training/eval/optimizer
# graphs and checkpoint) from the exported forward graph.

# lora_llama2_7b() does not freeze the base weights: every parameter is created with
# requires_grad=True. Without this call, `requires_grad` below would list all ~7B
# parameters and `frozen_params` would be empty, so ORT would build gradients and AdamW
# state for the entire model. Restrict training to the LoRA adapter weights.
set_trainable_params(lora_model, get_adapter_params(lora_model))

# Parameter names double as ONNX initializer names in the exported graph.
requires_grad = [name for name, param in lora_model.named_parameters() if param.requires_grad]
frozen_params = [name for name, param in lora_model.named_parameters() if not param.requires_grad]
assert requires_grad, "no trainable (LoRA adapter) parameters found"
assert frozen_params, "base model weights are not frozen"

artifacts.generate_artifacts(
    "torchtune_lora_llama2.onnx",
    optimizer=artifacts.OptimType.AdamW,
    # Loss appended to the graph output; try other LossType values to compare.
    # NOTE(review): the model output is presumably (batch, seq_len, vocab), while
    # softmax cross-entropy treats axis 1 as the class axis — confirm the logits
    # layout matches what this loss expects.
    loss=artifacts.LossType.CrossEntropyLoss,
    # Alternative: loss=artifacts.LossType.MSELoss,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    artifact_directory="torchtune_lora_llama2",
    # Keep the raw model output available alongside the loss.
    additional_output_names=["output"])