## Bringing your own PyTorch Model to RasterFlow

This guide shows how to export a PyTorch model, store it in S3, and then run that model in RasterFlow.

We'll cover these steps using minimal dependencies and a simple toy PyTorch model. This lets us focus on exporting the model, storing it, and referencing it from RasterFlow.


## Resources

To see other examples of exported models and their export scripts, check out our [Hugging Face page](https://huggingface.co/collections/wherobots/wherobotsai-models).

For a deeper dive into PyTorch's PT2 export format and export methods, check out the following tutorials:

1. [Torch Export Common Challenges and Solutions](https://docs.pytorch.org/tutorials/recipes/torch_export_challenges_solutions.html) (Beginner)
2. [Torch Export Tutorial](https://docs.pytorch.org/tutorials/intermediate/torch_export_tutorial.html) (Advanced, if you want to export complex models or optimize models for different accelerators)
3. [PT2 Archive Format Walkthrough](https://docs.pytorch.org/docs/stable/export/pt2_archive.html)

We'll first make a toy model that matches the input and output of the [Meta and WRI Canopy Height Model](https://huggingface.co/wherobots/meta-chm-v1-pt2).

For this regression model, the input is a PyTorch tensor of continuous values (image data) and the output is a PyTorch tensor of continuous values (canopy height in meters).

In [None]:
!uv pip install torch==2.8 torchvision --extra-index-url https://download.pytorch.org/whl/cu126

In [None]:
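# Define a toy model that mimics the input/output contract of a real model
import torch.nn as nn
import torch

class ExampleModel(nn.Module):
    def forward(self, x):
        # Random "predictions" with the same shape, dtype, and device as the input
        predictions = torch.randn_like(x)
        return predictions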

In [None]:
# One test image with shape (batch, channels, height, width)
x = torch.randn(1, 3, 256, 256)
print(x.shape)

Now that we have our model defined, we can run it on some test data and then export it.

In [None]:
model = ExampleModel()
result = model(x)
print(result.shape)

## PyTorch 2's export formats

PyTorch has multiple export formats, each suited to a different use case, such as storing model weights, training, inference on edge devices, and inference on servers.

If you've been using PyTorch for a while, you may be familiar with its ubiquitous model checkpoint format, often saved with the `.pth` extension.

```
PATH = "model.pth"
torch.save(model.state_dict(), PATH)
```

This format stores only a dictionary of the model's weights, not the code that defines the model's structure or any description of how the model is executed. To load the weights, you need the original model class definition and all of its dependencies, which is a burdensome requirement.

```
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, weights_only=True))
```
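As a minimal sketch of that limitation, the snippet below saves a `state_dict` to an in-memory buffer and loads it back. `TinyNet` is a hypothetical stand-in model; loading only works because its class definition is available in the same process.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a real model class."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 1, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


buffer = io.BytesIO()
torch.save(TinyNet().state_dict(), buffer)  # stores only named tensors
buffer.seek(0)

state_dict = torch.load(buffer, weights_only=True)
print(list(state_dict.keys()))  # parameter names only, no graph or code

# To use the weights, we must first rebuild the module from its Python class.
restored = TinyNet()
restored.load_state_dict(state_dict)
```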

Moreover, because a `state_dict` contains no description of the model's computation, the model can't be optimized ahead of time for a specific accelerator from the `state_dict` alone. This makes models saved with `torch.save` difficult to deploy.

## .pt2, an export format for all use cases

A better alternative is `torch.export`. It lets you export a single model artifact that can be used for both training and inference, and it has a number of useful features:

1. Supports exporting the model on CPU and then, at runtime, moving the model's parameters to GPU.
2. Exported models can be compiled to run faster and consume less memory on particular accelerators (NVIDIA (CUDA), AMD, or Intel GPUs).
3. Standardizes how inference and training hyperparameters (model metadata) and model input or output transforms are stored alongside the model.

We'll now export our `ExampleModel` in PT2 format. We'll also handle the common case of input preprocessing by exporting our input transforms as a `torch.nn.Module` in the same PT2 archive as our model. `torch.nn.Sequential` can be used to chain together multiple transforms. It also simplifies using `torch.export`, because its `forward` method takes a single argument named `input`. This means we only need to describe one input argument when we later call `torch.export.export`.
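To illustrate chaining, the sketch below composes two hypothetical transforms (`ScaleToUnit` and `ClampToRange`, which are not part of any library) in `nn.Sequential` and reads the name of its `forward` argument with `inspect`. It assumes raw pixel values in the 0-255 range.

```python
import inspect

import torch
import torch.nn as nn


class ScaleToUnit(nn.Module):
    """Hypothetical transform: rescale 0-255 pixel values to the 0-1 range."""

    def __init__(self, max_value: float = 255.0):
        super().__init__()
        self.max_value = max_value

    def forward(self, x):
        return x / self.max_value


class ClampToRange(nn.Module):
    """Hypothetical transform: clip values into [low, high]."""

    def __init__(self, low: float = 0.0, high: float = 1.0):
        super().__init__()
        self.low, self.high = low, high

    def forward(self, x):
        return x.clamp(self.low, self.high)


# Transforms run in the order they are passed to nn.Sequential.
pipeline = nn.Sequential(ScaleToUnit(255.0), ClampToRange(0.0, 1.0))

# nn.Sequential.forward takes a single argument named "input".
forward_args = list(inspect.signature(pipeline.forward).parameters)
print(forward_args)  # ['input']

out = pipeline(torch.tensor([[0.0, 127.5, 300.0]]))
print(out)
```

Because the whole chain is one `nn.Module` with one input, it can be exported the same way as a single transform.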

In [None]:
from pathlib import Path
import torchvision.transforms.v2 as T

In [None]:
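# One mean/std value per band; bands are dim 1 of an (N, C, H, W) tensor
num_bands = x.shape[1]
per_band_mean = [0.5] * num_bands
per_band_std = [0.1] * num_bands
norm_transform = torch.nn.Sequential(T.Normalize(mean=per_band_mean, std=per_band_std))
print(isinstance(norm_transform, nn.Module))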
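`T.Normalize` computes `(x - mean[c]) / std[c]` for each channel `c`, so `mean` and `std` need one entry per band, and bands are dimension 1 of an `(N, C, H, W)` tensor. The plain-PyTorch sketch below spells out that computation; `normalize_per_band` is a hypothetical helper, not a torchvision API.

```python
import torch


def normalize_per_band(x: torch.Tensor, mean: list, std: list) -> torch.Tensor:
    """Normalize an (N, C, H, W) tensor band by band: (x - mean[c]) / std[c]."""
    if len(mean) != x.shape[1] or len(std) != x.shape[1]:
        raise ValueError("mean and std need one value per channel")
    # Reshape to (1, C, 1, 1) so the stats broadcast over batch, height, and width.
    mean_t = torch.tensor(mean, dtype=x.dtype, device=x.device).view(1, -1, 1, 1)
    std_t = torch.tensor(std, dtype=x.dtype, device=x.device).view(1, -1, 1, 1)
    return (x - mean_t) / std_t


batch = torch.full((2, 3, 4, 4), 0.7)
num_bands = batch.shape[1]  # channels live in dim 1, not dim 0
normalized = normalize_per_band(batch, [0.5] * num_bands, [0.1] * num_bands)
print(normalized.shape)
```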

Next, `torch.export` needs to know the expected shape of our model input. This lets torch export a static artifact that captures the structure of the model, which can depend on the shape of the input. `torch.export` provides a `Dim` object that we can use to mark a dimension as dynamic, meaning it can be any value greater than 1, or static, meaning the dimension must be a specific size.

Frequently asked questions about exporting models whose structure depends on input shape include:
1. How do I know what input shape my model should take?
   * Answer: The input shape can affect both the runtime performance and the accuracy of the model's results. It's best to check the recommendations from the model creators.
2. Should I always just make every dimension dynamic?
   * Answer: Usually no. Many models have logic that requires dimensions to be fixed. For example, some detection models have extensive control flow that depends on the shape of the input. This control flow can restrict the model to being exported with the same patch height and width it was trained on.

A good rule of thumb is that it's usually safe to leave the batch dimension dynamic while fixing the channel, height, and width dimensions. If you'd like to make other dimensions dynamic as well, a quick round of trial and error will show whether the model still exports with those settings.

We'll export our toy model with a dynamic batch size, fixed channel dimension of 3, and fixed height and width of 256.

In [None]:
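from torch.export import Dim

# -1 marks a dynamic dimension; every other entry is a fixed size
input_shape_constraints = [-1, 3, 256, 256]
# Batch size 2 in the example input: torch.export specializes dims whose example size is 0 or 1
example_input_shape = [2, 3, 256, 256]
example_tensor = torch.randn(*example_input_shape, requires_grad=False)
dims = tuple(Dim.AUTO if dim == -1 else dim for dim in input_shape_constraints)
print(dims)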

Next, we set our model to eval mode before exporting, and set the device it will run on and the data type it expects.

In [None]:
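import inspect

# Fall back to CPU if no GPU is available in this environment
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
model.eval()
model = model.to(device=device, dtype=dtype)
# Keep the example input on the same device and dtype as the model
example_tensor = example_tensor.to(device=device, dtype=dtype)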
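Why eval mode matters: some layers, such as `nn.Dropout` and `nn.BatchNorm2d`, behave differently during training and inference, and `torch.export` traces the module in whatever mode it is currently in. A quick sketch with dropout:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Dropout(p=0.5)
ones = torch.ones(1, 1000)

layer.train()
train_out = layer(ones)  # randomly zeros about half the values, scales the rest by 2

layer.eval()
eval_out = layer(ones)  # identity in eval mode

print(torch.equal(train_out, ones), torch.equal(eval_out, ones))  # False True
```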

`torch.export` also needs to know the name of the argument to the model's `forward` method, so that it can match our `dims` to that input. We can read it from the `nn.Module` with Python's built-in `inspect` module.

In [None]:
model_arg = next(iter(inspect.signature(model.forward).parameters))
print(model_arg)

And now we are ready to export our model to an `ExportedProgram`! An `ExportedProgram` is an in-memory representation of the model that can be stored on disk in a `.pt2` archive file.

The `ExportedProgram` stores many things related to our model, including its `state_dict`, which contains the model weights, and `example_inputs`, which can be handy for testing the model.

In [None]:
model_program = torch.export.export(mod=model, args=(example_tensor,), dynamic_shapes={model_arg: dims})
print(model_program.example_inputs[0][0].shape)

We'll follow the same export steps for our transforms module.

In [None]:
norm_transform.eval()
norm_transform = norm_transform.to(device).to(dtype)
transform_arg = next(iter(inspect.signature(norm_transform.forward).parameters))

In [None]:
transforms_program = torch.export.export(
    mod=norm_transform, args=(example_tensor,), dynamic_shapes={transform_arg: dims}
)

Now that we have ExportedPrograms for our model and normalize transform, we are ready to save the model to a `.pt2` archive.

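Note: There are a few ways to save `.pt2` model archives with torch. `torch.export.save` can be used to save a single `ExportedProgram`. We will instead use `torch.export.pt2_archive._package.package_pt2` to save both the transforms and the model `ExportedProgram` into one `.pt2` file. Because `package_pt2` lives in a private module (note the leading underscore), its location may change between PyTorch versions.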

In [None]:
from torch.export.pt2_archive._package import package_pt2

local_model_path = "example.pt2"

# Store both programs under named keys in a single .pt2 archive
exported_programs = {
    "model": model_program,
    "transforms": transforms_program,
}

package_pt2(
    f=local_model_path,
    exported_programs=exported_programs,
)


## Running the Custom Model with RasterFlow

In addition to Wherobots Hosted Models, RasterFlow lets you run inference with your own custom models, like the one we just exported.

To do this, we first upload the model (in this case, a `.pt2` file) to an S3 location that RasterFlow can access. We'll use the Wherobots Managed Storage bucket that comes with each account. We then define an `InferenceConfig` that tells RasterFlow how to run the model.

In [None]:
import os
import s3fs

fs = s3fs.S3FileSystem(profile="default")

# Define the destination path on S3
# We use the USER_S3_PATH environment variable to ensure it goes to your personal bucket space
s3_model_path = os.getenv("USER_S3_PATH") + local_model_path
fs.put(local_model_path, s3_model_path)
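The upload cell joins `USER_S3_PATH` and the file name by plain string concatenation, which only produces the intended key if the prefix ends with a slash. Below is a hypothetical helper (not part of any Wherobots library) that normalizes the separator and fails loudly when the variable is unset.

```python
import os


def s3_uri_for(filename: str, prefix_env: str = "USER_S3_PATH") -> str:
    """Hypothetical helper: join an S3 prefix from the environment with a file name."""
    prefix = os.environ.get(prefix_env)
    if not prefix:
        raise RuntimeError(f"{prefix_env} is not set")
    # Exactly one "/" between prefix and file name, whether or not the prefix ends with one
    return prefix.rstrip("/") + "/" + filename.lstrip("/")
```

For example, `fs.put(local_model_path, s3_uri_for(local_model_path))` uploads to the same location whether or not `USER_S3_PATH` ends in `/`.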

## Defining the RasterFlow InferenceConfig

Now that the model archive is on S3, we need to define the configuration for the inference job.

Note: If you are using a Wherobots Hosted Model, there's no need for this step, as all of these models are available as preconfigured ModelRecipes.

See the [InferenceConfig](https://docs.wherobots.com/reference/rasterflow/data-models#inferenceconfig) documentation for details on these parameters.

In [None]:
from rasterflow_remote.data_models import InferenceConfig, InferenceActorEnum, MergeModeEnum, ResamplingMethod

custom_inference_config = InferenceConfig(
    model_path=s3_model_path,
    actor=InferenceActorEnum.REGRESSION_PYTORCH,
    patch_size=224,
    clip_size=28,
    device="cuda",
    features=["red", "green", "blue"],
    labels=["canopy_height"],
    max_batch_size=64,
    merge_mode=MergeModeEnum.WEIGHTED_AVERAGE,
)


Next, we build the input imagery for our model: a NAIP mosaic covering an area of interest (AOI) around Nashua, NH.

In [None]:
import os

import geopandas as gpd
import wkls
from rasterflow_remote import RasterflowClient
from rasterflow_remote.data_models import ResamplingMethod

# Generate a geometry for Nashua, NH using WKLS (https://github.com/wherobots/wkls)
gdf = gpd.read_file(wkls.us.nh.nashua.geojson())

# Save the geometry to a parquet file in the user's S3 path
aoi_path = os.getenv("USER_S3_PATH") + "nashua.parquet"
gdf.to_parquet(aoi_path)

client = RasterflowClient(mosaics_version="v0.18.0", rasterflow_version="v1.40.2")

client.build_gti_mosaic(
    gti="s3://wherobots-examples/rasterflow/indexes/naip_index.parquet",
    aoi=aoi_path,
    bands=["red", "green", "blue", "nir"],
    location_field="url",
    crs_epsg=3857,
    xy_chunksize=1024,
    query="res == .6",
    requester_pays=True,
    sort_field="time",
    resampling=ResamplingMethod.NEAREST,
    nodata=0.0,
)