In [1]:
import mlflow
import mlflow.pytorch
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torch.onnx
import os
import shutil

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Small fully connected classifier for 28x28 single-channel images
class SimpleNN(nn.Module):
    """Two-layer feed-forward network: 784 inputs -> 128 hidden units -> 10 outputs.

    The forward pass flattens every dimension after the first one, so each
    sample must contain exactly 28 * 28 values. The values returned are the
    raw outputs of the second linear layer; no softmax is applied here.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this order so parameter initialisation draws
        # from the random generator in the same sequence on every run.
        self.fc1 = nn.Linear(in_features=28 * 28, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return the 10 output scores for each sample in `x`."""
        flat = torch.flatten(x, 1)  # keep dim 0, collapse the rest
        hidden = torch.relu(self.fc1(flat))
        return self.fc2(hidden)

In [3]:
# Training settings
batch_size = 64  # samples per mini-batch drawn by the data loader
epochs = 5       # full passes over the training set
lr = 0.01        # learning rate handed to the SGD optimizer

# Seed PyTorch's global generator so weight initialisation, loader shuffling
# and the random example input are repeatable on a fresh Restart & Run All.
seed = 42
torch.manual_seed(seed)

In [4]:
# Data loader
# Preprocessing applied to every image: convert to a tensor, then normalise
# with mean 0.1307 and standard deviation 0.3081.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Training split of MNIST stored under ../data; download=True lets torchvision
# fetch the files into that directory.
mnist_train = datasets.MNIST('../data', train=True, download=True,
                             transform=mnist_transform)

# Shuffled mini-batches of `batch_size` samples.
train_loader = torch.utils.data.DataLoader(mnist_train,
                                           batch_size=batch_size,
                                           shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:37<00:00, 267823.79it/s]


Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 94504.07it/s]


Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:06<00:00, 250060.13it/s]


Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 5011978.10it/s]

Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw






In [9]:
def train_and_log_model():
    """Train a SimpleNN on `train_loader` and record the run in MLflow.

    Reads the notebook-level settings `batch_size`, `epochs` and `lr` and the
    `train_loader` defined in earlier cells. Inside a single MLflow run it
    logs those three hyper-parameters, the mean training loss of every epoch
    (metric ``avg_loss``, step = epoch number), the trained model under the
    artifact path ``model``, and one random example input together with the
    model's output for it (``example_input.pt`` / ``example_output.pt``).

    Returns:
        The trained SimpleNN instance, left in eval mode.
    """
    import tempfile  # local import: only needed for the throw-away artifact files

    # Initialize model, loss function, optimizer
    model = SimpleNN()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)

    # Loop-invariant sizes, looked up once instead of on every progress print.
    num_batches = len(train_loader)
    num_samples = len(train_loader.dataset)

    # Start MLflow run
    with mlflow.start_run():
        mlflow.log_param("batch_size", batch_size)
        mlflow.log_param("epochs", epochs)
        mlflow.log_param("learning_rate", lr)

        for epoch in range(1, epochs + 1):
            model.train()
            epoch_loss = 0.0
            for batch_idx, (data, target) in enumerate(train_loader):
                optimizer.zero_grad()
                output = model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()

                # Progress line every 100 batches.
                if batch_idx % 100 == 0:
                    print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{num_samples} '
                          f'({100. * batch_idx / num_batches:.0f}%)]\tLoss: {loss.item():.6f}')

            # Log metrics to MLflow
            avg_loss = epoch_loss / num_batches
            mlflow.log_metric("avg_loss", avg_loss, step=epoch)

        # Training is finished: switch to inference mode before the model is
        # serialised and before the example output is produced.
        model.eval()

        # Log the model to MLflow
        mlflow.pytorch.log_model(model, "model")

        # Log example input and output. no_grad keeps autograd from recording
        # the forward pass, so the saved output is a plain tensor.
        example_input = torch.randn(1, 1, 28, 28)
        with torch.no_grad():
            example_output = model(example_input)

        # Stage the tensors in a temporary directory so nothing is left behind
        # in the working directory once MLflow has stored the artifacts.
        with tempfile.TemporaryDirectory() as artifact_dir:
            examples = (
                ("example_input.pt", example_input),
                ("example_output.pt", example_output),
            )
            for filename, tensor in examples:
                local_path = os.path.join(artifact_dir, filename)
                torch.save(tensor, local_path)
                mlflow.log_artifact(local_path)

    return model

In [10]:
def convert_to_onnx(model, output_path="model.onnx",
                    input_name="onnx::Flatten_0", output_name="8"):
    """Export `model` to ONNX with explicit tensor names and a dynamic batch axis.

    The default tensor names are the ones the Triton config written by this
    notebook declares. Passing them to the exporter explicitly keeps the
    exported graph and that config in agreement instead of relying on the
    names the exporter happens to auto-generate.

    Args:
        model: Module to export; it is traced with a (1, 1, 28, 28) input.
        output_path: Destination file for the ONNX graph.
        input_name: Name given to the graph's single input tensor.
        output_name: Name given to the graph's single output tensor.
    """
    # Create a dummy input for model export
    dummy_input = torch.randn(1, 1, 28, 28)

    # Triton treats the first dimension as the batch dimension when
    # max_batch_size > 0, so export it as a variable axis rather than
    # freezing it to the dummy input's batch size of 1.
    batch_axis = {0: "batch_size"}

    # Export the model to ONNX format
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        input_names=[input_name],
        output_names=[output_name],
        dynamic_axes={input_name: batch_axis, output_name: batch_axis},
    )


In [11]:
def prepare_triton_repository(onnx_path="model.onnx",
                              repository_dir="model_repository",
                              model_name="mnist_model",
                              input_name="onnx::Flatten_0",
                              output_name="8"):
    """Lay out a Triton model repository for an exported ONNX model.

    Creates ``<repository_dir>/<model_name>/1/``, moves the ONNX file there as
    ``model.onnx`` and writes ``<repository_dir>/<model_name>/config.pbtxt``.
    With the default arguments the resulting files and config text are the
    same as the previously hard-coded layout.

    Args:
        onnx_path: Existing ONNX file to move into the repository.
        repository_dir: Root directory of the model repository.
        model_name: Used both as the model's directory name and as the
            ``name`` field of the config, so the two cannot drift apart.
        input_name: Input tensor name written to the config.
        output_name: Output tensor name written to the config.

    Raises:
        FileNotFoundError: If `onnx_path` is not an existing file. The check
            runs first so a failed call leaves no empty directories behind.
    """
    if not os.path.isfile(onnx_path):
        raise FileNotFoundError(f"ONNX model not found: {onnx_path!r}")

    model_dir = os.path.join(repository_dir, model_name)
    version_dir = os.path.join(model_dir, "1")

    # Create necessary directories for Triton
    os.makedirs(version_dir, exist_ok=True)
    # Move the ONNX model to the correct location
    shutil.move(onnx_path, os.path.join(version_dir, "model.onnx"))

    # Write config.pbtxt with the model's input/output tensor names
    with open(os.path.join(model_dir, "config.pbtxt"), "w") as f:
        f.write(f"""
name: "{model_name}"
platform: "onnxruntime_onnx"
max_batch_size: 1
input [
  {{
    name: "{input_name}"
    data_type: TYPE_FP32
    dims: [ 1, 28, 28 ]
  }}
]
output [
  {{
    name: "{output_name}"
    data_type: TYPE_FP32
    dims: [ 10 ]
  }}
]
        """)


In [12]:
# Run the whole pipeline in order: train and log the model, export it to ONNX,
# then arrange the exported file into a Triton model repository.
if __name__ == "__main__":
    # The trained network returned here is the input to the ONNX export.
    model = train_and_log_model()
    # Writes model.onnx, which prepare_triton_repository() then moves into
    # the model_repository/ directory tree.
    convert_to_onnx(model)
    prepare_triton_repository()

    print("Model training, logging, and Triton deployment preparation complete.")

Model training, logging, and Triton deployment preparation complete.
