# Bringing Your Own PyTorch Model to RasterFlow

This guide shows how to export a PyTorch model, store it in S3, and run it in RasterFlow.

We'll use a simple toy model with minimal dependencies to focus on the essential steps: **model export** and **RasterFlow integration**.

---

### Resources

For exported model examples and scripts, see our [Hugging Face collection](https://huggingface.co/collections/wherobots/wherobotsai-models).

**PyTorch PT2 export documentation:**
- [Common Challenges and Solutions](https://docs.pytorch.org/tutorials/recipes/torch_export_challenges_solutions.html) — Beginner
- [Export Tutorial](https://docs.pytorch.org/tutorials/intermediate/torch_export_tutorial.html) — Advanced (complex models, accelerator optimization)
- [PT2 Archive Format](https://docs.pytorch.org/docs/stable/export/pt2_archive.html)

## Create a Toy Model

We'll create an example model that matches the signature of the [Meta/WRI Canopy Height Model](https://huggingface.co/wherobots/meta-chm-v1-pt2): the input is a tensor of image data, and the output is a tensor of continuous values (canopy height in meters) with the same shape as the input.

In [None]:
!uv pip install torch==2.8 torchvision --extra-index-url https://download.pytorch.org/whl/cu126

In [None]:
# Define a toy model (there is no checkpoint to load; it returns random predictions)
import torch.nn as nn
import torch

class ExampleModel(nn.Module):
    def forward(self, x):
        # Random values with the same shape, dtype, and device as the input
        predictions = torch.randn_like(x)
        return predictions

In [None]:
x = torch.randn(1,3,256,256)
print(x.shape)

In [None]:
model = ExampleModel()
result = model(x)
print(result.shape)

With our model defined, let's now export it.

## PyTorch 2 Export Formats

PyTorch offers multiple export formats for different use cases: storing weights, training, edge inference, and server inference.

### Why not export with the `.pth` format?

You may be familiar with the checkpoint format saved as `.pth`:

```python
torch.save(model.state_dict(), "model.pth")
```

This stores only the model weights, not the model structure or execution logic. Loading requires the original model class and all of its dependencies:

```python
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load("model.pth", weights_only=True))
```

Additionally, `state_dict` exports can't be optimized for specific accelerators, making deployment difficult.
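
To make the limitation concrete, here is a minimal sketch of what a `state_dict` actually holds. The `nn.Linear` layer is a stand-in chosen for illustration and is not part of this guide's model.

```python
import torch.nn as nn

# A tiny stand-in layer (illustrative only, not the guide's model)
layer = nn.Linear(4, 2)

# state_dict() maps parameter names to tensors; it carries no forward() logic
state = layer.state_dict()
summary = {name: tuple(tensor.shape) for name, tensor in state.items()}
print(summary)  # {'weight': (2, 4), 'bias': (2,)}
```

Nothing in that mapping says how the tensors are combined, which is why the original class must be importable at load time.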

### The `.pt2` Format — A Better Alternative

`torch.export` produces a single artifact for both training and inference with some key benefits:

1. **Device flexibility** — Export on CPU, load parameters to GPU at runtime
2. **Accelerator optimization** — Compile for faster execution on CUDA, AMD, or Intel GPUs
3. **Standardized metadata** — Store hyperparameters, configuration, and transforms alongside the model

---

We'll export our `ExampleModel` in PT2 format. For input preprocessing, we'll export transforms as an `nn.Module` in the same archive. Using `torch.nn.Sequential` simplifies export since its `forward` method takes a single `input` argument.

In [None]:
from pathlib import Path
import torchvision.transforms.v2 as T

In [None]:
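# One mean/std value per band: use the channel dimension (index 1), not the batch dimension
num_bands = x.shape[1]
per_band_mean = [.5] * num_bands
per_band_std = [.1] * num_bands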
norm_transform = torch.nn.Sequential(T.Normalize(mean=per_band_mean, std=per_band_std))
print(isinstance(norm_transform, nn.Module))
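
As a rough illustration of what per-band normalization computes, the sketch below applies `(x - mean) / std` channel by channel in plain PyTorch. The `normalize` helper is hypothetical and written only for this example; it is not the torchvision implementation.

```python
import torch

def normalize(x, mean, std):
    """Apply (x - mean) / std per channel to an (N, C, H, W) tensor."""
    # Reshape to (1, C, 1, 1) so the values broadcast over batch, height, and width
    mean_t = torch.tensor(mean, dtype=x.dtype).view(1, -1, 1, 1)
    std_t = torch.tensor(std, dtype=x.dtype).view(1, -1, 1, 1)
    return (x - mean_t) / std_t

# A 3-band image where every pixel is 0.6
image = torch.full((1, 3, 2, 2), 0.6)
normalized = normalize(image, mean=[0.5] * 3, std=[0.1] * 3)
print(normalized[0, :, 0, 0])  # each band is approximately (0.6 - 0.5) / 0.1 = 1.0
```

The mean and std lists need one entry per band, which is why they are sized from the channel dimension.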

## Defining Input Shape Constraints

`torch.export` needs to know the expected input shape. Use `Dim` objects to specify:
- **Dynamic** dimensions — can be any value ≥1
- **Static** dimensions — must be a fixed size

#### FAQ

| Question | Answer |
|----------|--------|
| How do I know what shape to use? | Input shape affects runtime performance and accuracy. Check the model creator's recommendations. |
| Should I use dynamic for all dimensions? | Typically no. Some models have data dependent control flow logic that requires fixed dimensions (e.g., object detection models). |

**Rule of thumb:** Keep batch size dynamic; fix channel, height, and width. Test other dynamic configurations through trial and error to enable flexible channels, height, or width.

We'll export with: **dynamic batch size** (denoted by `Dim.AUTO`), **3 channels**, **256×256 height/width**.

In [None]:
from torch.export import Dim
input_shape_constraints = [-1, 3, 256, 256]
example_input_shape = [2, 3, 256, 256]
example_tensor = torch.randn(*example_input_shape, requires_grad=False)
dims = tuple(Dim.AUTO if dim == -1 else dim for dim in input_shape_constraints)
print(dims)

Set the model to `eval` mode and configure device/dtype before export:

In [None]:
import inspect
device = torch.device("cuda")
dtype = torch.float32
model.eval()
model = model.to(device).to(dtype)

`torch.export` needs the forward function's argument names. Parse them using Python's `inspect` module:

In [None]:
model_arg = next(iter(inspect.signature(model.forward).parameters))
print(model_arg)
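
For reference, this is how `inspect.signature` behaves on a bound method. The `Toy` class is a hypothetical stand-in; the point is that `self` is excluded, so the first parameter is the input tensor's name.

```python
import inspect

class Toy:
    """Hypothetical class with a forward method, used only for illustration."""
    def forward(self, x, mask=None):
        return x

# On a bound method, `self` is not part of the signature
param_names = list(inspect.signature(Toy().forward).parameters)
print(param_names)  # ['x', 'mask']

first_arg = next(iter(inspect.signature(Toy().forward).parameters))
print(first_arg)  # x
```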

## Export to ExportedProgram

With the following inputs in hand, we can export to an `ExportedProgram`, an in-memory model object that can be saved to a `.pt2` archive.

* `model`: the toy model we defined earlier as an `nn.Module`
* `args`: a tuple of the arguments to the model's forward pass function
* `dynamic_shapes`: a dict mapping the input's argument name to the tuple of dimension constraints we created earlier, `dims`

The `ExportedProgram` includes the `state_dict` (weights) and `example_inputs` (useful for testing).

In [None]:
model_program = torch.export.export(mod=model, args=(example_tensor,), dynamic_shapes={model_arg: dims})
print(model_program.example_inputs[0][0].shape)

Follow the same export steps for the transforms module:

In [None]:
norm_transform.eval()
norm_transform = norm_transform.to(device).to(dtype)
transform_arg = next(iter(inspect.signature(norm_transform.forward).parameters))

In [None]:
transforms_program = torch.export.export(
    mod=norm_transform, args=(example_tensor,), dynamic_shapes={transform_arg: dims}
)

## Save to `.pt2` Archive

> **Note:** `torch.export.save` saves a single ExportedProgram. We'll use `torch.export.pt2_archive._package.package_pt2` to bundle both the model and transforms into one `.pt2` file.

In [None]:
from torch.export.pt2_archive._package import package_pt2
exported_programs = {}
local_model_path = "example.pt2"
exported_programs["model"] = model_program
exported_programs["transforms"] = transforms_program

package_pt2(
    f=local_model_path,
    exported_programs=exported_programs
)


## Running the Custom Model with RasterFlow

RasterFlow supports both Wherobots Hosted Models and custom models like the one we just exported.

**Steps:**
1. Upload the `.pt2` file to S3 (we'll use Wherobots Managed Storage)
2. Define an `InferenceConfig` to tell RasterFlow how to run the model

Note: You can also load open models stored in PT2 format directly from Hugging Face.

In [None]:
import os
import s3fs

fs = s3fs.S3FileSystem(profile="default")

# Define the destination path on S3
# We use the USER_S3_PATH environment variable to ensure it goes to your personal bucket space
s3_model_path = os.getenv("USER_S3_PATH") + local_model_path
fs.put(local_model_path, s3_model_path)
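
The concatenation above relies on `USER_S3_PATH` being set and ending with a slash. A slightly more defensive variant might look like the sketch below; `join_s3_path` is a hypothetical helper and the bucket names are placeholders.

```python
def join_s3_path(prefix, filename):
    """Join an S3 prefix and a file name with exactly one slash between them."""
    if not prefix:
        # os.getenv returns None when the variable is unset
        raise ValueError("S3 prefix is empty; is USER_S3_PATH set?")
    return prefix.rstrip("/") + "/" + filename.lstrip("/")

# Placeholder prefixes, with and without a trailing slash
with_slash = join_s3_path("s3://example-bucket/user/", "example.pt2")
without_slash = join_s3_path("s3://example-bucket/user", "example.pt2")
print(with_slash)
print(without_slash)
```

In the notebook this would be called as `join_s3_path(os.getenv("USER_S3_PATH"), local_model_path)`.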

## Build the Model Input

To run our model, we need some input imagery. We'll test it on 4-band NAIP imagery (red, green, blue, near infrared). We'll select an AOI over Nashua, New Hampshire that has some forest canopy for our toy canopy height model.

In [None]:
import wkls
import geopandas as gpd
import os

# Generate a geometry for Nashua, NH using WKLS (https://github.com/wherobots/wkls)
gdf = gpd.read_file(wkls.us.nh.nashua.geojson())

# Save the geometry to a parquet file in the user's S3 path
aoi_path = os.getenv("USER_S3_PATH") + "nashua.parquet"
gdf.to_parquet(aoi_path)

To prepare this imagery, we'll use RasterFlow to create a mosaic. Mosaics are backed by a cloud native Zarr store that enables accessing spatial subsets, individual bands, and computing on the mosaic with RasterFlow.

This workflow takes a few minutes to complete, so you can skip ahead to the next cell where we load the prepared output from the workflow.

In [None]:
from rasterflow_remote import RasterflowClient
client = RasterflowClient()

```python
from rasterflow_remote.data_models import ResamplingMethod

# Build a mosaic over the AOI from the NAIP index
mosaic_path = client.build_gti_mosaic(
    gti = "s3://wherobots-examples/rasterflow/indexes/naip_index.parquet",
    aoi = aoi_path,
    bands = ["red", "green", "blue", "nir"],
    location_field = "url",
    crs_epsg = 3857,
    xy_chunksize = 1024,
    query = "res == .6",
    requester_pays = True,
    sort_field = "time",
    resampling = ResamplingMethod.NEAREST,
    nodata = 0.0,
)
```

## Visualize a subset of the mosaic
We will use hvplot and datashader to visualize a small subset of the mosaic's red band.

In [None]:
# Import libraries for visualization and coordinate transformation
import hvplot.xarray
import xarray as xr
import s3fs 
import zarr
from pyproj import Transformer
from holoviews.element.tiles import EsriImagery 

# Open the Zarr store
mosaic_path = "s3://wherobots-examples/rasterflow/mosaics/nashua.zarr"
fs = s3fs.S3FileSystem(profile="default", asynchronous=True, anon=True)
zstore = zarr.storage.FsspecStore(fs, path=mosaic_path)
ds = xr.open_zarr(zstore)

In [None]:
# Create a transformer to convert from lat/lon to meters
transformer = Transformer.from_crs("EPSG:4326", "EPSG:3857", always_xy=True)

# Transform bounding box coordinates from lat/lon to meters
min_lon, min_lat, max_lon, max_lat = gdf.total_bounds
(min_x, max_x), (min_y, max_y) = transformer.transform(
    [min_lon, max_lon], 
    [min_lat, max_lat]
)

# Select the red band and slice the dataset to the bounding box
# y=slice(max_y, min_y) handles the standard "North-to-South" image orientation
ds_subset = ds.sel(band="red",
    x=slice(min_x, max_x), 
    y=slice(max_y, min_y) 
)

# Select the first time step and extract the variables array
arr_subset = ds_subset.isel(time=0)["variables"]

# Create a base map layer using Esri satellite imagery
base_map = EsriImagery()

# Create an overlay layer from the mosaic subset with hvplot
output_layer = arr_subset.hvplot(
    x = "x",
    y = "y",
    geo = True,           # Enable geographic plotting
    dynamic = True,       # Enable dynamic rendering for interactivity
    rasterize = True,     # Use datashader for efficient rendering of large datasets
    cmap = "viridis",     # Color map for visualization
    aspect = "equal",     # Maintain equal aspect ratio
    title = "Nashua, NH NAIP Red Band"
).opts(
    width = 600, 
    height = 600,
    alpha = 0.7           # Set transparency to see the base map underneath
)

# Combine the base map and output layer
final_plot = base_map * output_layer
final_plot
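
For intuition about the coordinate transform used above, the standard spherical Web Mercator (EPSG:3857) forward formula can be written in a few lines. This is a sketch of the math, not a replacement for `pyproj`; the Nashua coordinates are approximate and included only as sample input.

```python
import math

EARTH_RADIUS_M = 6378137.0  # sphere radius used by Web Mercator

def lonlat_to_web_mercator(lon_deg, lat_deg):
    """Convert longitude/latitude in degrees to Web Mercator meters."""
    x = EARTH_RADIUS_M * math.radians(lon_deg)
    y = EARTH_RADIUS_M * math.log(math.tan(math.pi / 4 + math.radians(lat_deg) / 2))
    return x, y

# Approximate center of Nashua, NH (illustrative values)
x_m, y_m = lonlat_to_web_mercator(-71.47, 42.77)
print(x_m, y_m)

# Sanity checks: the origin maps to (0, 0) and 180 degrees maps to pi * R
origin = lonlat_to_web_mercator(0.0, 0.0)
x_180, _ = lonlat_to_web_mercator(180.0, 0.0)
```

Western longitudes give negative `x` and northern latitudes give positive `y`, which matches the slicing order used for the subset.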

With our mosaic, we are now ready to run model prediction on the mosaic with RasterFlow.

We'll use the [`predict_mosaic`](https://docs.wherobots.com/reference/rasterflow/client#predict_mosaic) method to run our model. `predict_mosaic` leverages RasterFlow's powerful inference engine that scales from small to global scale areas of interest.

The inputs to this method are the input store we want to run prediction on and the `InferenceConfig` object we define below.

## Defining the InferenceConfig

With the model on S3, define the inference job configuration.

> **Note:** Wherobots Hosted Models come with preconfigured `ModelRecipes`, so this step is only needed for custom models.

See the [InferenceConfig documentation](https://docs.wherobots.com/reference/RasterFlow/data-models#inferenceconfig) for parameter details.

In [None]:
from dataclasses import asdict
from rasterflow_remote.data_models import InferenceConfig, InferenceActorEnum, MergeModeEnum, ResamplingMethod

custom_inference_config = InferenceConfig(
    model_path = s3_model_path,
    actor = InferenceActorEnum.REGRESSION_PYTORCH,
    patch_size = 256,  # matches the 256x256 height/width fixed at export
    clip_size = 28,
    device = "cuda",
    features = ["red", "green", "blue"],
    labels = ["canopy_height"],
    max_batch_size = 64,
    merge_mode = MergeModeEnum.WEIGHTED_AVERAGE,
)
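
Because the model was exported with a fixed channel count, it can help to confirm that `features` lists the same number of bands before submitting a job. The helper below is a hypothetical pre-flight check written for this guide, not part of the RasterFlow API.

```python
def check_feature_count(features, export_shape):
    """Return True when the number of bands matches the exported channel dimension."""
    expected_channels = export_shape[1]  # shape layout is (batch, channels, height, width)
    return len(features) == expected_channels

# Shape constraints used at export time (-1 marks the dynamic batch dimension)
export_shape = [-1, 3, 256, 256]

rgb_ok = check_feature_count(["red", "green", "blue"], export_shape)
rgbn_ok = check_feature_count(["red", "green", "blue", "nir"], export_shape)
print(rgb_ok, rgbn_ok)  # True False
```

The mosaic holds four bands, so only the three selected in `features` are passed to the model.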

In [None]:
client.predict_mosaic(
        store=mosaic_path,
        **asdict(custom_inference_config)
)