In [1]:
import os

import torch
import torchvision

# Build MobileNetV2 initialised from the IMAGENET1K_V2 pretrained checkpoint.
pretrained_weights = torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V2
model = torchvision.models.mobilenet_v2(weights=pretrained_weights)

In [2]:
# Replace the final classifier layer with a fresh linear head for our task.
# The input width is read from the layer being replaced (1280 for this model)
# instead of being hard-coded, so the cell stays correct if the backbone
# changes and remains idempotent when re-run.
NUM_CLASSES = 4
head_in_features = model.classifier[1].in_features
model.classifier[1] = torch.nn.Linear(head_in_features, NUM_CLASSES)

In [3]:
model_name = "mobilenetv2"

# The exporter writes straight to the given path, so the output directory
# must exist beforehand; create it here so the notebook runs from a clean
# checkout (Restart Kernel -> Run All).
os.makedirs("training_artifacts", exist_ok=True)

torch.onnx.export(
    model,
    torch.randn(1, 3, 224, 224),  # dummy input, only used to trace the graph
    f"training_artifacts/{model_name}.onnx",
    input_names=["input"],
    output_names=["output"],
    # Mark axis 0 of both tensors as a named dynamic dimension ("batch").
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
)

In [4]:
import onnx
from onnxruntime.training import artifacts

In [5]:
# Reload the exported forward graph from disk.
onnx_model = onnx.load(f"training_artifacts/{model_name}.onnx")

# Only the replaced classifier layer is trainable; every other initializer
# found in the graph is passed as frozen.
requires_grad = ["classifier.1.weight", "classifier.1.bias"]
frozen_params = []
for initializer in onnx_model.graph.initializer:
    if initializer.name in requires_grad:
        continue
    frozen_params.append(initializer.name)

# Write the training artifacts (cross-entropy loss, AdamW optimizer) into
# the same directory as the exported model.
artifacts.generate_artifacts(
    onnx_model,
    artifact_directory="training_artifacts",
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
)

2024-09-18 17:37:23.981495875 [I:onnxruntime:Default, constant_sharing.cc:248 ApplyImpl] Total shared scalar initializer count: 68
2024-09-18 17:37:23.981863375 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer ConstantSharing modified: 1 with status: OK
2024-09-18 17:37:23.982323675 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer LayerNormFusion modified: 0 with status: OK
2024-09-18 17:37:23.982421975 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer CommonSubexpressionElimination modified: 0 with status: OK
2024-09-18 17:37:23.982446675 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer GeluFusion modified: 0 with status: OK
2024-09-18 17:37:23.982462075 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer SimplifiedLayerNormFusion modified: 0 with status: OK
2024-09-18 17:37:23.982476775 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer FastGeluFusion mo