In [None]:
# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

# Pretrained Weights

_Written by: Caleb Robinson_

In this tutorial, we will demonstrate how to use pretrained models in TorchGeo to extract fixed-length embeddings from remote sensing imagery. This is useful for quickly exploring datasets, visualizing feature spaces, and establishing baseline performance without the need for extensive training.

Specifically, we will:

- Load the **EuroSAT100** dataset (a 100-image subset of EuroSAT) using TorchGeo's Lightning `DataModule`.
- Load two pretrained encoders:
  1) **DOFA** (a ViT-B/16-style encoder pretrained for Earth observation) - outputs 768-D features.
  2) **ResNet-18** (using weights from SSL4EO) - outputs 512-D features.
- Extract fixed-length embeddings for every image; this needs no labels and no gradients, just forward passes.
- Train a simple k-Nearest Neighbors classifier on the embeddings to quantify how well they separate the classes.
- Visualize the feature space with a 2-D PCA plot to see class separation.

## Setup

First, we install TorchGeo.

In [None]:
# Install TorchGeo together with the extra packages used below (scikit-learn, tqdm).
# On Colab, this ensures the latest TorchGeo is available.

%pip install torchgeo scikit-learn tqdm

## Imports

Next, we import TorchGeo and any other libraries we need.

In [None]:
import os
import tempfile

import kornia.augmentation as K
import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.decomposition import PCA
from sklearn.metrics import classification_report
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm

from torchgeo.datamodules import EuroSAT100DataModule
from torchgeo.models import DOFABase16_Weights, ResNet18_Weights, get_model

## Datamodule

We will utilize TorchGeo's [Lightning](https://lightning.ai/docs/pytorch/stable/) datamodules to organize the dataloader setup.

In [None]:
# Configure the EuroSAT100 Lightning DataModule and pull out its dataloaders.

root = os.path.join(tempfile.gettempdir(), 'eurosat100')
datamodule = EuroSAT100DataModule(
    root=root,
    batch_size=10,
    num_workers=2,
    download=True,
    # Channel order matters: it must match the wavelength order given to DOFA below.
    bands=('B02', 'B03', 'B04'),
)

# Run setup for every stage whose dataloader is requested below.
for stage in ('fit', 'validate', 'test'):
    datamodule.setup(stage)

train_dl, val_dl, test_dl = (
    datamodule.train_dataloader(),
    datamodule.val_dataloader(),
    datamodule.test_dataloader(),
)

## Embedding

We will embed the entirety of the EuroSAT100 train and validation splits using a pretrained model by extracting features from the final layer before the classification head.

With the DOFA model this will give us 768-dimensional feature vectors from each image. We use these vectors with the labels to train and evaluate a simple k-nearest neighbors classifier.

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    accelerator = 'cuda'
else:
    accelerator = 'cpu'

In [None]:
# Build the pretrained DOFA encoder (ViT-style), switch it to eval mode, and
# place it on the selected device.

model = (
    get_model('dofa_base_patch16_224', weights=DOFABase16_Weights.DOFA_MAE)
    .eval()
    .to(accelerator)
)

# Images are resized to 224x224, the input size the DOFA model expects.
resize = K.Resize((224, 224))

In [None]:
def embed_dofa(model, dataloader, accelerator, transforms=None, wavelengths=None):
    """Embed every sample of a dataloader with a DOFA model.

    Args:
        model: DOFA encoder exposing ``forward_features(x, wavelengths=...)``.
        dataloader: iterable of batches, each a dict holding ``'image'`` and
            ``'label'`` tensors.
        accelerator: device (e.g. ``'cuda'`` or ``'cpu'``) the images are moved
            to; the model is not moved here and is expected to already be there.
        transforms: optional callable applied to the image batch (on the
            device) before it is embedded, e.g. a resize.
        wavelengths: central wavelength of every input channel, in channel
            order. Defaults to ``[0.49, 0.56, 0.66]``, the Sentinel-2
            B02, B03, B04 values.

    Returns:
        Tuple ``(x_all, y_all)`` of NumPy arrays: the embeddings and the labels
        of all batches, each concatenated along axis 0.
    """
    if wavelengths is None:
        # these are B02, B03, B04 wavelengths for Sentinel-2
        wavelengths = [0.49, 0.56, 0.66]
    else:
        # Accept any sequence but always hand the model a plain list.
        wavelengths = list(wavelengths)

    x_all = []
    y_all = []
    for batch in tqdm(dataloader):
        x = batch['image'].to(accelerator)
        y = batch['label']
        if transforms is not None:
            x = transforms(x)

        with torch.inference_mode():
            # DOFA requires the central wavelength of each band
            embeddings = model.forward_features(x, wavelengths=wavelengths)
        x_all.append(embeddings.cpu().numpy())
        # .cpu() is a no-op for CPU labels and keeps this working if they are on a GPU
        y_all.append(y.cpu().numpy())
    x_all = np.concatenate(x_all, axis=0)
    y_all = np.concatenate(y_all, axis=0)
    return x_all, y_all

In [None]:
# Embed the train and validation splits with DOFA, resizing each batch first.
(x_train, y_train), (x_val, y_val) = (
    embed_dofa(model, loader, accelerator, transforms=resize)
    for loader in (train_dl, val_dl)
)

In [None]:
# Quick baseline without any fine-tuning: a 5-nearest-neighbour classifier is fit
# on the DOFA train embeddings and then predicts labels for the validation embeddings.

knn_model = KNeighborsClassifier(n_neighbors=5).fit(x_train, y_train)
y_pred = knn_model.predict(x_val)

In [None]:
# Human-readable class names; position i is the name used for integer label i.
class_names = [
    'Annual Crop',
    'Forest',
    'Herbaceous Vegetation',
    'Highway',
    'Industrial Buildings',
    'Pasture',
    'Permanent Crop',
    'Residential Buildings',
    'River',
    'Sea & Lake',
]

# `labels` is passed explicitly so the report stays aligned with `class_names`
# even when a class is missing from both y_val and y_pred; without it,
# scikit-learn rejects a `target_names` list longer than the labels it observed.
print(
    classification_report(
        y_val,
        y_pred,
        labels=list(range(len(class_names))),
        digits=2,
        target_names=class_names,
        zero_division=0,
    )
)

Now let's do the same thing with a ResNet18 model pretrained on Sentinel-2 RGB imagery (from the SSL4EO paper).

We reuse the same dataloaders and evaluation code to isolate the effect of the encoder.

In [None]:
# The ResNet18 encoder is created through the same factory as the DOFA model,
# then switched to eval mode and placed on the selected device.
# NOTE(review): these RGB weights were presumably trained on channels ordered
# B04, B03, B02, while the datamodule requests ('B02', 'B03', 'B04') - confirm
# the expected channel order against the weights' metadata.

model = (
    get_model('resnet18', weights=ResNet18_Weights.SENTINEL2_RGB_MOCO)
    .eval()
    .to(accelerator)
)

In [None]:
def embed_standard(model, dataloader, accelerator, transforms=None):
    """Embed every sample of a dataloader with a model that returns feature maps.

    Each image batch is moved to ``accelerator``, optionally passed through
    ``transforms``, and fed to ``model.forward_features``; the result is then
    averaged over its last two dimensions.

    Returns:
        Tuple ``(embeddings, labels)`` of NumPy arrays, each concatenated over
        all batches along axis 0.
    """
    features, targets = [], []
    for batch in tqdm(dataloader):
        images = batch['image'].to(accelerator)
        if transforms is not None:
            images = transforms(images)

        with torch.inference_mode():
            # global average pooling over the spatial dims
            pooled = model.forward_features(images).mean(dim=(-2, -1))

        features.append(pooled.cpu().numpy())
        targets.append(batch['label'].numpy())
    return np.concatenate(features, axis=0), np.concatenate(targets, axis=0)

In [None]:
# Embed the same train and validation splits with the ResNet18 encoder.
# This overwrites the DOFA embeddings held in x_train / x_val.
(x_train, y_train), (x_val, y_val) = (
    embed_standard(model, loader, accelerator, transforms=resize)
    for loader in (train_dl, val_dl)
)

In [None]:
# Same 5-nearest-neighbour baseline, now on the ResNet-18 embeddings, so the two
# encoders can be compared side by side.

knn_model = KNeighborsClassifier(n_neighbors=5).fit(x_train, y_train)
y_pred = knn_model.predict(x_val)

In [None]:
# `labels` is passed explicitly so the report stays aligned with `class_names`
# even when a class is missing from both y_val and y_pred; without it,
# scikit-learn rejects a `target_names` list longer than the labels it observed.
print(
    classification_report(
        y_val,
        y_pred,
        labels=list(range(len(class_names))),
        digits=2,
        target_names=class_names,
        zero_division=0,
    )
)

In [None]:
# Project the training embeddings onto two whitened principal components.
# At this point x_train holds the ResNet18 embeddings computed most recently.
pca = PCA(whiten=True, n_components=2)
x_reduced = pca.fit_transform(x_train)

# Combined share of the variance captured by the two retained components.
explained_variance = pca.explained_variance_ratio_.sum()
print(f'Explained variance ratio: {explained_variance}')

In [None]:
# Scatter the 2-D PCA projection of the training embeddings, coloured by class.
fig, ax = plt.subplots(figsize=(7, 7))
scatter = ax.scatter(
    x_reduced[:, 0], x_reduced[:, 1], c=y_train, cmap='tab10', s=4, alpha=0.5
)

# `num=None` yields one legend handle per label value actually present (in
# sorted order), so the names are looked up from those same values instead of
# assuming that all ten classes occur in the training split.
handles, _ = scatter.legend_elements(num=None)
labels = [class_names[int(i)] for i in np.unique(y_train)]
ax.legend(handles, labels, title='Classes', loc='best')
ax.set_title('PCA of EuroSAT training set embeddings')

# Fixed view window; points that fall outside these limits are not shown.
ax.set_xlim([-1.5, 2])
ax.set_ylim([-1, 2])
ax.axis('off')
fig.tight_layout()
plt.show()
plt.close(fig)