![](https://wherobots.com/wp-content/uploads/2023/12/Inline-Blue_Black_onWhite@3x.png)

## Wherobots Inference - Scene Classification 

This example demonstrates how to run inference with a classification model using Wherobots Inference to identify land cover in satellite imagery. We will use a machine learning model from [torchgeo](torchgeo)<sup>1</sup> trained on imagery from the European Space Agency’s Sentinel-2 satellites.

**Note: This notebook requires the Wherobots Inference functionality to be enabled and a GPU runtime selected in Wherobots Cloud. Please [contact us](https://wherobots.com/contact/) to enable these features.**


### 1: Set up the Wherobots Context

In [None]:
import warnings
warnings.filterwarnings('ignore')

from wherobots.inference.data.io import read_raster_table
from sedona.spark import SedonaContext
from pyspark.sql.functions import expr

config = SedonaContext.builder().appName('classification-batch-inference')\
    .getOrCreate()

sedona = SedonaContext.create(config)

### 2: Load satellite imagery

Next, we load the satellite imagery that we will run inference on. These GeoTIFF images are loaded as *out-db* rasters in WherobotsDB, where each row represents a different scene.

In [None]:
tif_folder_path = 's3a://wherobots-examples/data/eurosat_small'
files_df = read_raster_table(tif_folder_path, sedona)
df_raster_input = files_df.withColumn(
        "outdb_raster", expr("RS_FromPath(path)")
    )
df_raster_input.cache().show(truncate=False)
print(df_raster_input.count())

### 3: Run predictions with the sedona.sql APIs

To run predictions, we specify the model we wish to use. Some models are pre-loaded and available in Wherobots Cloud, and we can also load our own. Predictions are run with Wherobots' Spatial SQL functions, in this case `RS_CLASSIFY`.

Here we generate 200 predictions using `RS_CLASSIFY`.

In [None]:
#%%time
df_raster_input.createOrReplaceTempView("df_raster_input")
model_id = 'landcover-eurosat-sentinel2'
predictions_df = sedona.sql(f"SELECT name, outdb_raster, RS_CLASSIFY('{model_id}', outdb_raster) AS preds FROM df_raster_input")
predictions_df.cache().show(truncate=False)
predictions_df.createOrReplaceTempView("predictions_df")

From the prediction result, we can retrieve the most confident classification label and its probability score.

In [None]:
max_predictions_df = sedona.sql("SELECT name, outdb_raster, RS_MAX_CONFIDENCE(preds).max_confidence_label, RS_MAX_CONFIDENCE(preds).max_confidence_score FROM predictions_df")
max_predictions_df.show(20, truncate=False)
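
To make the step above concrete, here is a minimal pure-Python sketch of what selecting the most confident label amounts to. It assumes, purely for illustration, that one prediction can be viewed as a mapping from class label to probability. The `preds` dictionary, its values, and the `max_confidence` helper are hypothetical and are not part of the Wherobots API; the actual structure returned by `RS_CLASSIFY` may differ.

```python
from typing import Dict, Tuple


def max_confidence(preds: Dict[str, float]) -> Tuple[str, float]:
    """Return the (label, score) pair with the highest score.

    Hypothetical helper for illustration only; it is not the
    implementation of RS_MAX_CONFIDENCE.
    """
    if not preds:
        raise ValueError("preds must contain at least one label")
    # Pick the key whose associated score is largest.
    label = max(preds, key=preds.get)
    return label, preds[label]


# Hypothetical scores for a single scene (illustrative values only).
preds = {"Forest": 0.91, "Pasture": 0.05, "River": 0.03, "Highway": 0.01}
label, score = max_confidence(preds)
print(label, score)
```

In the notebook itself, this selection is done for every row by `RS_MAX_CONFIDENCE`, so a helper like this is only useful for reasoning about the output.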

### wherobots.inference Python API

If you prefer Python, `wherobots.inference` offers a module for registering the SQL inference functions as Python functions. Below we run the same inference as before with `RS_CLASSIFY`.

In [None]:
from wherobots.inference.engine.register import create_single_label_classification_udfs
rs_classify, rs_max_confidence = create_single_label_classification_udfs(batch_size=10, sedona=sedona)
df_predictions = df_raster_input.withColumn("preds", rs_classify(model_id, 'outdb_raster'))
df_predictions.show(1)

In [None]:
from pyspark.sql.functions import col

df_max_predictions = df_predictions.withColumn("max_confidence_temp", rs_max_confidence(col("preds"))) \
                            .withColumn("max_confidence_label", col("max_confidence_temp.max_confidence_label")) \
                            .withColumn("max_confidence_score", col("max_confidence_temp.max_confidence_score")) \
                            .drop("max_confidence_temp", "preds")
df_max_predictions.show(2, truncate=False)

We can write the label and score results in the `preds` column to Parquet so we can refer to them later.

In [None]:
output_path = "./results.parquet"
df_predictions.select(["preds"]).write.parquet(output_path, mode="overwrite")

### Visualize the model predictions and source imagery

In [None]:
df_rast = sedona.read.format("binaryFile").option("pathGlobFilter", "*.tif").option("recursiveFileLookup", "true").load(tif_folder_path).selectExpr("RS_FromGeoTiff(content) as raster")

In [None]:
htmlDF = df_max_predictions.selectExpr("RS_Band(outdb_raster, Array(4, 3, 2)) as image_raster", "name", "max_confidence_label")\
    .selectExpr("RS_NormalizeAll(image_raster, 1, 65535, True) as image_raster", "name", "max_confidence_label")\
    .selectExpr("RS_AsImage(image_raster, 500) as image_raster", "name", "max_confidence_label")
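
The `RS_NormalizeAll` step above rescales pixel values into a target range before the raster is rendered as an image. As a rough illustration of that idea, and not of Sedona's actual implementation, the following pure-Python sketch applies min-max scaling to a small list of pixel values. The `min_max_normalize` helper and the sample values are hypothetical.

```python
from typing import List


def min_max_normalize(values: List[float], lo: float, hi: float) -> List[float]:
    """Linearly rescale values so the minimum maps to lo and the maximum to hi.

    Hypothetical helper for illustration only.
    """
    vmin, vmax = min(values), max(values)
    if vmax == vmin:
        # A constant input has no spread to scale, so map everything to lo.
        return [lo for _ in values]
    scale = (hi - lo) / (vmax - vmin)
    return [lo + (v - vmin) * scale for v in values]


# Illustrative pixel values, rescaled to the range used in the cell above.
pixels = [200.0, 1200.0, 2200.0]
normalized = min_max_normalize(pixels, 1.0, 65535.0)
print(normalized)
```

Stretching values across the full output range in this way is what makes dark scenes visible when they are displayed.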

In [None]:
from sedona.raster_utils.SedonaUtils import SedonaUtils
from pyspark.sql.functions import rand
SedonaUtils.display_image(htmlDF.orderBy(rand()).limit(3))

### References

1. Stewart, A. J., Robinson, C., Corley, I. A., Ortiz, A., Lavista Ferres, J. M., & Banerjee, A. (2022). [TorchGeo: Deep Learning With Geospatial Data](https://dl.acm.org/doi/10.1145/3557915.3560953). In *Proceedings of the 30th International Conference on Advances in Geographic Information Systems* (pp. 1-12). Association for Computing Machinery. https://doi.org/10.1145/3557915.3560953
