# Pretraining ViT-MAE on GOES Satellite Data

In this tutorial, we’ll walk through pretraining a masked autoencoder (MAE) built on a Vision Transformer (ViT) using 10 weeks of GOES-16 satellite data. Before diving into supervised learning tasks in climate and Earth science, it helps to first build a model that learns the structure of the data itself, even without labels. Satellite imagery from instruments like GOES-16 captures complex, high-dimensional patterns in the atmosphere, including cloud motion, temperature gradients, and convective activity, and making sense of these patterns requires a model that can learn spatial representations of them.

In this context, our goal is to pretrain a Vision Transformer (ViT) using a Masked Autoencoder (MAE) approach. By hiding random patches of satellite data and asking the model to reconstruct what’s missing, we force it to learn underlying spatial patterns. This self-supervised approach is particularly useful when extracting meaningful visual representations from Earth observation data, even when labeled datasets are limited. With a pretrained model, we can later fine-tune on more specific downstream tasks, such as detecting the time of day, predicting extreme weather events, classifying cloud types, and more. 
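To make the masking idea concrete, here is a minimal NumPy sketch that hides a random half of the 16×16 patches in a 224×224 image. The helper name `mask_random_patches` and the choice to zero out hidden patches are our own illustrative assumptions; the actual ViT-MAE model drops masked patches from the encoder input instead of zeroing them.

```python
import numpy as np

def mask_random_patches(image, patch_size=16, mask_ratio=0.5, seed=0):
    """Zero out a random subset of non-overlapping square patches (illustrative only)."""
    h, w = image.shape
    rows, cols = h // patch_size, w // patch_size
    num_patches = rows * cols
    num_masked = int(num_patches * mask_ratio)

    # pick which patches to hide, without repeats
    rng = np.random.default_rng(seed)
    masked_ids = rng.choice(num_patches, size=num_masked, replace=False)

    masked = image.copy()
    for pid in masked_ids:
        r, c = divmod(int(pid), cols)  # patch index -> (row, column) in the patch grid
        masked[r * patch_size:(r + 1) * patch_size,
               c * patch_size:(c + 1) * patch_size] = 0.0
    return masked, masked_ids

image = np.ones((224, 224), dtype=np.float32)
masked, masked_ids = mask_random_patches(image)
print(len(masked_ids), "of", (224 // 16) ** 2, "patches hidden")
```

The reconstruction task is then to predict the hidden pixels from the patches that remain visible.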

To begin, we'll install and import the essential libraries: `numpy` and `xarray` for working with the data, Hugging Face's ViT-MAE implementation from `transformers`, and core PyTorch modules.

In [1]:
# installing additional libraries into kernel 
import sys
import subprocess
subprocess.run([sys.executable, "-m", "pip", "install", "numcodecs[pcodec]==0.15.1", "xarray==2025.3.1", "zarr==3.0.6"])
subprocess.run([sys.executable, "-m", "pip", "install", "transformers==4.32", "datasets", "torch", "accelerate"])

# scientific computing
import random
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt

# PyTorch and ViT transformer
import os
import torch
from torch.optim import AdamW
from torch import nn
import torch.nn.functional as F
from transformers import ViTImageProcessor, ViTMAEForPreTraining





Along with scientific libraries such as `numpy` and `xarray`, we install and import Hugging Face's ViT-MAE implementation and PyTorch. ViT-MAE is a prebuilt architecture designed for masked autoencoding, so we don't need to construct ViT layers or implement masking logic by hand. PyTorch, a flexible and widely used deep learning framework, handles optimization and backpropagation throughout the training loop.

# Loading Data

In [2]:
# set configs
zarr_path   = "/notebook_dir/public/mickellals-public/goes-16-2003-10-weeks.tmp.zarr"
channel    = "CMI_C13" 
lat_range   = None 
lon_range   = None 
crop_size     = 224 
mask_ratio    = 0.5
learning_rate = 5e-4 
num_epochs    = 100
checkpoint = 20 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Here, we define the key configuration parameters that will control our training process. These parameters allow us to flexibly control dataset loading, training performance, and resource usage.

`zarr_path`: Path to the GOES-16 satellite imagery dataset stored in zarr format.

`channel`: The selected band. Here we use `CMI_C13`, which is band 13, the "clean" longwave infrared window.

`lat_range`, `lon_range`: `(start, stop)` tuples of latitude/longitude bounds used to spatially crop the dataset. Set both to `None` to disable cropping.

`crop_size`: Side length (in pixels) of the square crop that will be fed into the Vision Transformer. The model then splits each crop into smaller patches internally.

`mask_ratio`: Proportion of image patches to randomly mask during MAE pretraining.

`learning_rate`: Initial learning rate for the AdamW optimizer.  

`num_epochs`: Total number of training epochs.  

`checkpoint`: Interval (in epochs) at which we save model checkpoints and plot training loss.

`device`: Whether to use GPU ("cuda") or CPU, depending on system availability.
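As a quick sanity check on how `crop_size` and `mask_ratio` interact, the sketch below counts the patches the model sees per crop. The helper `patch_budget` is hypothetical, and the default patch size of 16 is an assumption based on the base ViT configuration; check `model.config.patch_size` if you use a different checkpoint.

```python
def patch_budget(crop_size, mask_ratio, patch_size=16):
    """Return (total, visible, hidden) patch counts for one square crop."""
    if crop_size % patch_size != 0:
        raise ValueError("crop_size should be a multiple of patch_size")
    per_side = crop_size // patch_size
    total = per_side ** 2
    # the encoder only sees the patches that are not masked
    visible = int(total * (1 - mask_ratio))
    return total, visible, total - visible

total, visible, hidden = patch_budget(crop_size=224, mask_ratio=0.5)
print(total, visible, hidden)  # 196 98 98
```

With the settings above, each 224×224 crop becomes a 14×14 grid of patches, half of which are hidden at every step.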

In [3]:
ds = xr.open_zarr(zarr_path)
# readjusting the coordinates
ds = ds.assign_coords(lon=((ds.lon + 180) % 360) - 180)
# slicing data if necessary
if lat_range is not None and lon_range is not None:
    ds = ds.sel(lat=slice(*lat_range), lon=slice(*lon_range))

cube = ds[channel]

# split up timestamps for training the transformer and for testing
num_t   = cube.sizes["t"]
all_idx = list(range(num_t))
split_pt = int(0.75 * num_t)
train_idx = all_idx[:split_pt]
test_idx  = all_idx[split_pt:]



Now, we load the GOES-16 Zarr dataset and get it ready for training. We use `xarray` to open the store, then wrap the longitude values from the `[0, 360)` range into `[-180, 180)` so that the coordinates match standard map conventions. If both a latitude and a longitude range are given, we crop the data to that region. After that, we select the channel we want to use.
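The longitude wrap is easy to verify by hand. The NumPy sketch below applies the same formula to a few sample values; `wrap_longitude` is a hypothetical helper name. It also shows that wrapped coordinates may no longer be in ascending order, which matters for label-based slicing; in `xarray`, `ds.sortby("lon")` is one way to restore the order, though whether it is needed depends on how the dataset is laid out.

```python
import numpy as np

def wrap_longitude(lon):
    """Map longitudes from [0, 360) into [-180, 180)."""
    return ((np.asarray(lon, dtype=float) + 180.0) % 360.0) - 180.0

lon = np.array([0.0, 90.0, 180.0, 270.0, 359.0])
wrapped = wrap_longitude(lon)
print(wrapped.tolist())  # [0.0, 90.0, -180.0, -90.0, -1.0]

# the wrapped values are no longer sorted; argsort gives the ascending order
order = np.argsort(wrapped)
print(wrapped[order].tolist())  # [-180.0, -90.0, -1.0, 0.0, 90.0]
```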

Next, we split the timestamps chronologically into two parts: the first 75% for training the model and the last 25% held out for testing. This lets the model learn from most of the data while leaving us with frames it hasn’t seen for evaluating reconstruction quality.
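The split logic can be sketched in isolation. The helper `chronological_split` is a hypothetical name; it mirrors the slicing in the cell above, so the test set always contains the latest timestamps and no shuffling takes place.

```python
def chronological_split(num_timestamps, train_fraction=0.75):
    """Split time indices into an early training block and a late test block."""
    split_pt = int(train_fraction * num_timestamps)
    indices = list(range(num_timestamps))
    return indices[:split_pt], indices[split_pt:]

train_idx, test_idx = chronological_split(10)
print(train_idx)  # [0, 1, 2, 3, 4, 5, 6]
print(test_idx)   # [7, 8, 9]
```

Because `int()` truncates, the training share rounds down when the number of timestamps is not divisible by four.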

# Creating a Directory for Model Training

We want to create a directory to store all outputs related to model training, including optimizer states, training logs, and loss plots. This is important because training a model can take a long time, and we don’t want to lose progress if the process is interrupted or if we want to continue training later. 

This also allows us to track how the model is improving over time by saving loss values, plots, and reconstructed images at each checkpoint. In addition, by logging key metadata, it makes the training process more reproducible, helps with debugging, and allows us to compare different experiments more easily in the future.

In [4]:
# create a directory 
directory = "./goes16mae"
os.makedirs(directory, exist_ok=True)
# setting up model and optimizer
if os.path.exists(os.path.join(directory, "pytorch_model.bin")):
    print("✅ loading checkpoint …")
    model = ViTMAEForPreTraining.from_pretrained(directory, ignore_mismatched_sizes=True).train()
    # move the model before building the optimizer so the restored state matches the parameters' device
    model.to(device)
    opt   = AdamW(model.parameters(), lr=learning_rate)
    opt.load_state_dict(torch.load(os.path.join(directory, "opt.pt"), map_location=device))
else:
    print("⚠️ starting from facebook/vit-mae-base")
    model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base").train()
    model.to(device)
    opt   = AdamW(model.parameters(), lr=learning_rate)

# apply the mask ratio, then save initial configs to directory
model.config.mask_ratio = mask_ratio
processor = ViTImageProcessor.from_pretrained("facebook/vit-mae-base")
model.save_pretrained(directory)
torch.save(opt.state_dict(), os.path.join(directory, "opt.pt"))



⚠️ starting from facebook/vit-mae-base


Here, we set up the training environment for our ViT-MAE. First, we create a directory (if it doesn't already exist) to store all outputs related to training, such as model weights and optimizer states. Then, we check whether a previously saved model checkpoint, `pytorch_model.bin`, exists in that directory. If it does, we load the model and optimizer from those saved states so training can continue where it left off. If no checkpoint is found, we start from Hugging Face's pretrained `facebook/vit-mae-base` weights and create a new optimizer.

Once the model is loaded or initialized, we set the mask ratio, which controls the fraction of image patches hidden from the model during training, and place the model on the appropriate device (CPU/GPU). Finally, we load a processor that converts each satellite image into the input format the model expects.

In [None]:
# function to create subfolders within the initial directory, loss_plots and output_plots
def create_folders(base):
    loss   = os.path.join(base, "loss_plots")
    output = os.path.join(base, "output_plots")
    os.makedirs(loss, exist_ok=True)
    os.makedirs(output, exist_ok=True)
    return loss, output

loss, output = create_folders(directory)

Using the previously created directory, this function creates two subfolders: `loss_plots` (for saving training-loss history) and `output_plots` (for storing the model’s reconstructed images). Keeping these outputs in separate folders makes it straightforward to monitor and analyze the model’s progress.

# Model Building and Training

Now that our environment is set up, we can start building and training our model. 

In [None]:
# training for one epoch
def train_one_epoch(model, cube, processor, train_idx, opt, crop_size, device):
    model.train()
    total_loss = 0.0
    last_inputs = None

    # iterate through every index in training timestamps
    for idx in train_idx:
        # load the frame and min-max normalize it using only its finite values
        frame = cube.isel(t=idx).values.astype(np.float32)
        valid = frame[np.isfinite(frame)]
        lo, hi = (valid.min(), valid.max()) if valid.size else (0.0, 1.0)
        frame = (frame - lo) / (hi - lo + 1e-6)
        # replace NaNs or infinities with 0
        frame = np.nan_to_num(frame, nan=0.0, posinf=0.0, neginf=0.0)

        H, W = frame.shape
        # randomly crop the frame to a crop_size x crop_size patch
        top  = random.randint(0, H - crop_size)
        left = random.randint(0, W - crop_size)
        patch = frame[top:top+crop_size, left:left+crop_size]

        # stack the patch three times to mimic RGB and run the processor
        rgb = np.stack([patch]*3, axis=-1)
        # values are already in [0, 1], so skip the processor's 1/255 rescaling
        inputs = processor(images=rgb, return_tensors='pt', do_rescale=False).to(device)
        # save for later checkpoint
        last_inputs = inputs  

        # forward/backward
        opt.zero_grad()
        loss = model(pixel_values=inputs.pixel_values).loss
        loss.backward()
        opt.step()

        total_loss += loss.item()
        # free memory and clear cache
        del frame, patch, rgb, inputs, loss
        torch.cuda.empty_cache()

    avg_loss = total_loss / len(train_idx)
    return avg_loss, last_inputs

The `train_one_epoch` function trains the model for one epoch, iterating through every timestamp in the training set, and returns the average loss along with the last processed input used during the epoch. It starts by setting the model to training mode and initializing variables to track the total loss and store the last input.

For each time index in the training set, it extracts the corresponding image frame, normalizes its values to the [0, 1] range, and replaces any invalid entries (NaNs or infinities) with zero to ensure the data is numerically stable. A random crop of the specified size is taken from the frame to increase input variety and help regularize the model. Since the model expects three-channel input, the single-channel image is stacked three times to mimic an RGB image.
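The preprocessing steps described above can be sketched as a standalone NumPy function. The name `prepare_patch` and the synthetic test frame are our own assumptions for illustration. One detail worth noting: the minimum and maximum are taken over finite values only, since a plain `min()` on a frame containing NaNs returns NaN and would wipe out the whole frame.

```python
import numpy as np

def prepare_patch(frame, crop_size, rng):
    """Normalize to [0, 1], zero out invalid pixels, crop randomly, stack to 3 channels."""
    frame = frame.astype(np.float32)
    valid = frame[np.isfinite(frame)]
    if valid.size == 0:
        return np.zeros((crop_size, crop_size, 3), dtype=np.float32)

    # min-max normalization based on finite values only
    lo, hi = valid.min(), valid.max()
    frame = (frame - lo) / (hi - lo + 1e-6)
    frame = np.nan_to_num(frame, nan=0.0, posinf=0.0, neginf=0.0)

    # random top-left corner such that the crop stays inside the frame
    h, w = frame.shape
    top = int(rng.integers(0, h - crop_size + 1))
    left = int(rng.integers(0, w - crop_size + 1))
    patch = frame[top:top + crop_size, left:left + crop_size]

    # repeat the single channel three times to mimic an RGB image
    return np.stack([patch] * 3, axis=-1)

rng = np.random.default_rng(0)
frame = rng.normal(280.0, 20.0, size=(300, 400))  # synthetic stand-in for one frame
frame[0, 0] = np.nan                              # simulate a missing pixel
rgb = prepare_patch(frame, crop_size=224, rng=rng)
print(rgb.shape)  # (224, 224, 3)
```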

The image patch is then passed through a processor that formats it into PyTorch tensors compatible with the model’s input structure. Then, gradients are cleared, loss is computed from the forward pass, gradients are backpropagated, and the model weights are updated. The loss for each sample is accumulated to compute the average loss over the entire epoch. 
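It can help to see what the reported loss measures. As we understand the MAE objective, it is a mean squared error between predicted and true pixels, averaged over the masked patches only. The NumPy sketch below illustrates that idea with a hypothetical helper, `masked_reconstruction_loss`; it is not the library's code, and options such as per-patch pixel normalization are left out.

```python
import numpy as np

def masked_reconstruction_loss(pred, target, mask):
    """Mean squared error over hidden patches.

    pred, target: arrays of shape (num_patches, patch_dim)
    mask: array of shape (num_patches,), 1 for hidden patches and 0 for visible ones
    """
    per_patch = ((pred - target) ** 2).mean(axis=-1)  # one error value per patch
    return float((per_patch * mask).sum() / mask.sum())

target = np.zeros((4, 3))
pred = np.array([[1.0, 1.0, 1.0],
                 [0.0, 0.0, 0.0],
                 [2.0, 2.0, 2.0],
                 [0.0, 0.0, 0.0]])
mask = np.array([1, 0, 0, 1])  # patches 0 and 3 are hidden

loss = masked_reconstruction_loss(pred, target, mask)
print(loss)  # 0.5
```

The large error on patch 2 does not count toward the loss because that patch was visible to the model.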

The function concludes by clearing memory and returning both the average training loss and the last set of processed inputs.


In [None]:
# plot reconstruction at checkpoint
def plot_reconstruction(model, last_inputs, output_dir, epoch):
    # put the model into evaluation mode
    model.eval()
    
    # disable gradient tracking to speed up inference and reduce memory usage
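    with torch.no_grad():
        outputs = model(**last_inputs)
        # `logits` holds the per-patch pixel predictions; `unpatchify` folds them
        # back into a (batch, channels, height, width) image
        recon = model.unpatchify(outputs.logits)
        out = recon.squeeze(0).permute(1, 2, 0).cpu().numpy()
    # values are in the processor's normalized space, so show one channel in grayscale
    plt.imshow(out[..., 0], cmap='gray')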
    plt.axis('off')
    
    # save the figure to output folder with the epoch number as the filename
    plt.savefig(os.path.join(output_dir, f"recon_epoch_{epoch}.png"),
                bbox_inches='tight')
    plt.close()
    
    # return model back to training mode
    model.train()

The `plot_reconstruction` function visualizes and saves a reconstruction result from the masked autoencoder model at a specific training checkpoint. First, it switches the model to evaluation mode to ensure layers like dropout behave consistently during inference. Then, it disables gradient tracking with `torch.no_grad()` to reduce memory usage and speed up computation, since no backpropagation is needed for this forward pass.

The model generates an output reconstruction from `last_inputs`, which is a batch of image tensors. The output is squeezed to remove the batch dimension, and its axes are permuted to match the height × width × channels format expected by `matplotlib`. The reconstructed image is then plotted and saved to the specified directory using a filename that reflects the current epoch (e.g., `recon_epoch_20.png`). Finally, the model is returned to training mode so training can resume in the next epoch. This function is useful for tracking how well the model is learning to reconstruct masked inputs over time, providing a visual checkpoint alongside quantitative metrics like loss.
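The model predicts pixels patch by patch, so its raw output has to be folded back into an image grid before it can be plotted. The NumPy sketch below shows one way to do this for a single image, along with the inverse operation. The helper names `patchify` and `unpatchify` and the per-patch layout (row, column, channel) are assumptions for illustration; the library's own layout may differ in detail.

```python
import numpy as np

def patchify(image, p):
    """Split an (H, W, C) image into flattened p x p patches, in row-major patch order."""
    h, w, c = image.shape
    gh, gw = h // p, w // p
    x = image.reshape(gh, p, gw, p, c).transpose(0, 2, 1, 3, 4)
    return x.reshape(gh * gw, p * p * c)

def unpatchify(patches, p, c):
    """Fold flattened patches from a square patch grid back into an (H, W, C) image."""
    g = int(round(patches.shape[0] ** 0.5))  # patches per side
    x = patches.reshape(g, g, p, p, c).transpose(0, 2, 1, 3, 4)
    return x.reshape(g * p, g * p, c)

image = np.arange(32 * 32 * 3, dtype=np.float32).reshape(32, 32, 3)
patches = patchify(image, p=16)
restored = unpatchify(patches, p=16, c=3)
print(patches.shape, restored.shape)  # (4, 768) (32, 32, 3)
```

A round trip through both functions returns the original image, which is a quick way to check that the layout is consistent.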

In [None]:
# save loss and model checkpoint
def save_checkpoint(epoch, loss_history, model, opt, loss_dir, base_dir):
    plt.figure()
    
    # plot loss values against epoch numbers
    plt.plot(range(1, len(loss_history) + 1), loss_history)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'Loss up to epoch {epoch}')
    
    # save the loss plot to the loss folder
    plt.savefig(os.path.join(loss_dir, f'loss_epoch_{epoch}.png'))
    plt.close()
    
    # retrieve the most recent loss value
    latest_loss = loss_history[-1]
    
    # save the latest model's config for metadata
    model.config.last_training_loss    = latest_loss
    model.config.last_checkpoint_epoch = epoch
    model.save_pretrained(base_dir)
    torch.save(opt.state_dict(), os.path.join(base_dir, 'opt.pt'))
    print(f"— checkpointed at epoch {epoch}, loss={latest_loss:.4f}")

The `save_checkpoint` function manages the periodic saving of both model progress and training diagnostics during the training process. Its purpose is to preserve the state of the model and optimizer, along with visual documentation of how the training loss evolves over time.

First, it creates a plot of the loss values stored in `loss_history`, mapping them against epoch numbers to provide a clear visualization of the model’s learning trajectory. This plot is saved in the designated `loss_plots` folder with the epoch number embedded in the filename, helping track training performance over time.

Next, it extracts the most recent loss value to annotate the model configuration. It then updates the model’s configuration by recording the last observed training loss and the epoch number at which the checkpoint was saved. This metadata can be useful for future analysis or resuming training. Finally, it saves the model weights and the optimizer state to the base directory and prints a short confirmation.

In [None]:
loss_history = []
# main training loop
for epoch in range(1, num_epochs+1):
    avg_loss, last_inputs = train_one_epoch(
        model, cube, processor, train_idx, opt, crop_size, device
    )
    
    # record the epoch's average loss
    loss_history.append(avg_loss)
    print(f"Epoch {epoch}/{num_epochs}  loss={avg_loss:.4f}")
    
    # for every checkpoint epoch, visualize a reconstruction and save the loss in the corresponding folders
    if epoch % checkpoint == 0:
        plot_reconstruction(model, last_inputs, output, epoch)
        save_checkpoint(epoch, loss_history, model, opt, loss, directory)

This is the main training loop that orchestrates the full training process for our model over multiple epochs. It begins by initializing an empty list, `loss_history`, which will store the average loss value for each epoch. The loop then runs from epoch 1 to the specified `num_epochs`, calling the `train_one_epoch` function on each iteration. After each epoch, the average loss is appended to `loss_history` and printed to the console, providing real-time feedback on the model’s learning progress. 

Every time it reaches a `checkpoint` epoch, the code calls `plot_reconstruction` to generate and save a reconstructed image using the latest `last_inputs`, and then `save_checkpoint` to store the current model state, optimizer state, and a loss curve image.
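The checkpoint schedule follows directly from the modulo test in the loop. The small sketch below lists the epochs that trigger a checkpoint; `checkpoint_epochs` is a hypothetical helper used only for illustration.

```python
def checkpoint_epochs(num_epochs, checkpoint):
    """Epochs (1-based) at which `epoch % checkpoint == 0` holds."""
    return [e for e in range(1, num_epochs + 1) if e % checkpoint == 0]

print(checkpoint_epochs(100, 20))  # [20, 40, 60, 80, 100]
print(checkpoint_epochs(50, 20))   # [20, 40]
```

With the settings in this tutorial, the final epoch is always checkpointed. If `num_epochs` is not a multiple of `checkpoint`, as in the second call, the last few epochs run without being saved.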

The loop ensures that loss tracking and visual reconstruction evaluations are performed regularly, and that model state is periodically saved to enable robust experimentation and reproducibility.

# Model Evaluation and Residual Analysis 

To evaluate model performance over time, we saved loss plots and reconstruction visualizations at regular checkpoint intervals. The loss plots, generated and saved by the `save_checkpoint` function, chart the average training loss across epochs. These plots provide a quantitative view of how the model’s reconstruction error changes over time, helping us monitor convergence and detect potential training instabilities. A steadily decreasing loss curve indicates effective learning, while plateaus or spikes may signal learning bottlenecks or unstable optimization. Because the curve tracks training loss only, it cannot reveal overfitting by itself.

Complementing the loss plots are the qualitative reconstruction outputs generated by `plot_reconstruction`. These visualize the model’s predictions on masked satellite image patches using the final batch from each epoch. By comparing reconstructed images across checkpoints, we can observe how the model improves in restoring spatial features and structure. These outputs act as an intuitive form of residual analysis, highlighting areas where the model performs well or struggles, including blurring, missing edges, or inconsistent textures.
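The visual comparison can be made more quantitative by computing an explicit residual map. The sketch below assumes you have an original patch and its reconstruction as NumPy arrays of the same shape and on the same scale; `residual_summary` is a hypothetical helper and is not part of the training code above.

```python
import numpy as np

def residual_summary(original, reconstruction):
    """Return the residual map plus its RMSE and mean bias."""
    residual = reconstruction - original
    rmse = float(np.sqrt(np.mean(residual ** 2)))
    bias = float(residual.mean())
    return residual, rmse, bias

original = np.zeros((4, 4))
reconstruction = np.full((4, 4), 0.5)
residual, rmse, bias = residual_summary(original, reconstruction)
print(rmse, bias)  # 0.5 0.5
```

Plotting `residual` with a diverging colormap would show where the reconstruction runs too bright or too dark, while RMSE and bias give single numbers to compare across checkpoints.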


# Limitations and Final Remarks

One limitation of the current training setup is that it only uses random spatial crops from individual frames as input, which may limit the model’s understanding of global spatial context. The training loop also does not include explicit validation loss tracking or early stopping, so the held-out `test_idx` timestamps go unused and overfitting cannot be detected. Moreover, due to memory constraints, frames are processed one at a time (an effective batch size of 1), and larger batch sizes or full-frame context learning are not currently feasible.
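As a rough idea of what early stopping could look like, here is a minimal sketch that is independent of any framework. The `EarlyStopping` class, its `patience` and `min_delta` parameters, and the validation losses in the example are all hypothetical; in practice the validation loss would come from running the model on the held-out timestamps.

```python
class EarlyStopping:
    """Signal a stop once validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one validation loss and return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

stopper = EarlyStopping(patience=2)
made_up_val_losses = [0.9, 0.8, 0.85, 0.82, 0.7]  # illustrative values only
stopped_at = None
for epoch, val_loss in enumerate(made_up_val_losses, start=1):
    if stopper.step(val_loss):
        stopped_at = epoch
        break
print(stopped_at, stopper.best)  # 4 0.8
```

In this example training stops at epoch 4, after two epochs in a row without improvement, so the later drop to 0.7 is never reached. Choosing a larger `patience` trades extra compute for robustness to such noise.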

Despite these limitations, the training pipeline is modular, efficient, and effective in visualizing model progression. The clear separation between training, reconstruction plotting, and checkpoint saving makes the codebase easy to extend and scale. With further tuning, this framework can serve as a strong foundation for future satellite image reconstruction or other self-supervised learning tasks.