# Ocean colour geospatial foundation model (GFM) demo notebook

The following notebook demonstrates fine-tuning the ocean colour geospatial foundation model for primary production quantification. The notebook makes use of [TerraTorch](https://github.com/IBM/terratorch) for fine-tuning and prediction.

The primary production data used in fine-tuning were collected from various sources ([Mattei, Francesco; Scardi, Michele (2021)](https://doi.pangaea.de/10.1594/PANGAEA.932417); Simons CMAP, 2023; Marra, 2021; Buitenhuis, 2013; Goericke, 2021) and include ship-borne observations and buoy data. The accompanying Sentinel-3 Ocean and Land Colour Instrument (OLCI) and Sea and Land Surface Temperature Radiometer (SLSTR) images were created from a 6-day median of all cloud-free measurements. Full details on the creation of this dataset can be found [here]().

It's best to run this notebook on a machine with one or more GPUs. If this is not possible, you can reduce the amount of training data to shorten the training time, at a cost of reduced performance. You can also try reducing the batch size.

## 0.1 Setup for running on Google Colab

Before proceeding any further, double check that your Google Colab session is using a GPU. We have tested this notebook using a T4 GPU on a free Colab account.
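
As a minimal sketch of the GPU check suggested above, the snippet below asks PyTorch whether a CUDA device is visible. The helper name `gpu_available` is hypothetical, and the fallback (looking for `nvidia-smi` on the `PATH`) is only a heuristic for sessions where PyTorch is not installed yet.

```python
import shutil


def gpu_available() -> bool:
    """Return True if a CUDA GPU appears to be usable (hypothetical helper)."""
    try:
        import torch  # assumed to be present once the environment is set up
    except ImportError:
        # Heuristic fallback: the NVIDIA driver tool is on the PATH
        return shutil.which("nvidia-smi") is not None
    return bool(torch.cuda.is_available())


print("GPU available:", gpu_available())
```

If this prints `False` on Colab, switching the runtime type to a GPU runtime before continuing is likely the fix.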

### 0.1.1 Check Python version

It's recommended that you run this notebook using Python 3.10. Let's check the Python version by executing the cell below.

In [None]:
!python --version
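
The same check can be sketched in pure Python, which also works outside a notebook shell. The helper `python_version_ok` and the constant `RECOMMENDED` are illustrative names, not part of the package.

```python
import sys

RECOMMENDED = (3, 10)  # the version this notebook recommends


def python_version_ok(version=None, recommended=RECOMMENDED) -> bool:
    """Return True if the major.minor version matches the recommendation."""
    version = sys.version_info if version is None else version
    return tuple(version[:2]) == tuple(recommended)


if not python_version_ok():
    # Only a warning: other versions may still work, but are not the recommended one
    print(
        f"Warning: running Python {sys.version_info.major}.{sys.version_info.minor}, "
        f"but {RECOMMENDED[0]}.{RECOMMENDED[1]} is recommended."
    )
```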

### 0.1.2 Set up the environment

To install the necessary packages on Colab, execute the cell below. This will take a few minutes. Once the installation is done, a window will pop up asking you to restart the session. This is normal; restart the session using the button in the pop-up window. Once the session has restarted, do not re-run the cell below. Go straight to section 0.1.3.

In [None]:
import sys

# if running on colab
if "google.colab" in sys.modules:
    # Clone the ibm-granite GitHub repo
    !git clone https://github.com/ibm-granite/geospatial.git
    # Install the package
    !pip install -e ./geospatial/granite-geospatial-ocean[colab]

### 0.1.3 Set up working directory

This is the first thing you should run after restarting your Colab session.

In [None]:
# Only run this cell if running on Colab.
import sys  # re-import: restarting the session clears earlier imports

if "google.colab" in sys.modules:
    # Change to the notebooks directory
    %cd geospatial/granite-geospatial-ocean/notebooks
    %pwd

Now your environment is set up for Google Colab. Please proceed to section 0.3.

## 0.2 Setup for running on your local machine

Before running this notebook, it's best to create a virtual environment and install the necessary packages there. The instructions can be found in README.md.

Once that's done, come back to this notebook and make sure it's using the newly made virtual environment.

Please proceed to section 0.3.

## 0.3 Imports and basic set-up

In [None]:
# Imports
import glob
import os
import matplotlib.pyplot as plt
import numpy as np
import random
import re
import rioxarray
import tarfile
import zipfile

from huggingface_hub import hf_hub_download
from pathlib import Path

from granite_geo_ocean_colour.helper import (
    get_rgb,
    plot_inference_data,
    plot_training_data,
    crop_image,
)

In [None]:
# some basic set-up
%matplotlib inline

project_root = Path("../")
hf_repo_name = "ibm-granite/"

## 1. Fine Tuning

### 1.1 Data prep

Let's download the fine-tuning dataset and extract it into the `granite-geospatial-ocean/data` directory.

In [None]:
# Download the data set
dataset_name = "granite-geospatial-ocean-processed-sentinel-3-primary-production.tar.gz"
data_url = f"https://zenodo.org/records/17093560/files/{dataset_name}"
download_cmd = f"wget {data_url} -O {project_root}/{dataset_name}"
os.system(download_cmd)

# Extract the gzipped tar archive into the project root
with tarfile.open(f"{project_root}/{dataset_name}", "r:gz") as tar:
    tar.extractall(project_root)
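
If `wget` is not available on your machine, the download and extraction can be sketched with the standard library alone. The helpers `download_if_missing` and `extract_archive` are hypothetical names introduced here for illustration; the first skips the download when the file is already present, which avoids fetching the archive again on a re-run.

```python
import tarfile
import urllib.request
from pathlib import Path


def download_if_missing(url: str, dest) -> Path:
    """Download url to dest unless dest already exists."""
    dest = Path(dest)
    if not dest.exists():
        dest.parent.mkdir(parents=True, exist_ok=True)
        urllib.request.urlretrieve(url, dest)
    return dest


def extract_archive(archive, target_dir) -> None:
    """Extract a .tar.gz archive into target_dir."""
    with tarfile.open(archive, "r:gz") as tar:
        tar.extractall(target_dir)


# Example usage with the variables defined in the cell above:
# archive = download_if_missing(data_url, project_root / dataset_name)
# extract_archive(archive, project_root)
```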

In [None]:
# specify where the training and inference data are stored
data_path = project_root / "data"

#### Plot some samples.

Here we plot a random selection of samples from the fine-tuning dataset. As Sentinel-3 OLCI data do not have separate red, green and blue bands, we plot a log-scaled, broad-band natural-colour composite. This code was taken from the [EUMETlab](https://gitlab.eumetsat.int/eumetlab/oceans/ocean-training/sensors/learn-olci) repository, which is a great resource for learning more about the data.

Note that the label location is shown in purple in the centre of each image.


In [None]:
ft_data_path = glob.glob(f"{data_path}/*/*_img.tif")
selected_images = random.sample(ft_data_path, 4)

plot_training_data(selected_images)
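
The cell above draws a different set of images on every run. If you would like the same samples each time, one option is a seeded selection, sketched below. The helper `sample_paths` and its default seed are illustrative choices, not part of the package.

```python
import random


def sample_paths(paths, k=4, seed=42):
    """Return up to k paths, chosen reproducibly for a given seed."""
    ordered = sorted(paths)    # glob does not guarantee a stable order
    rng = random.Random(seed)  # local generator; global random state is untouched
    return rng.sample(ordered, min(k, len(ordered)))


# Example usage with the list built in the cell above:
# plot_training_data(sample_paths(ft_data_path))
```

Capping `k` at the number of available files also avoids the `ValueError` that `random.sample` raises when fewer than four images are found.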

### 1.2 Model prep - checkpoints

Download the pre-trained model weights from Hugging Face.

In [None]:
# checkpoint-specific
checkpoint_folder = project_root / "data" / "checkpoints"
os.makedirs(checkpoint_folder, exist_ok=True)

hf_repo_name = "ibm-granite/granite-geospatial-ocean"

inference_checkpoint = Path(
    hf_hub_download(
        repo_id=hf_repo_name,
        filename="checkpoint.pt",
        local_dir=checkpoint_folder,
    )
)

config_name = "config.yaml"
config_folder = project_root / "configs"

# download model config
model_config = Path(
    hf_hub_download(
        repo_id=hf_repo_name,
        filename=config_name,
        local_dir=config_folder,
    )
)

### 1.3 Model prep - configs

As this model uses different bands from the Prithvi model included in TerraTorch, we define a custom module in [../custom_modules/prithvi_vit_S3.py](../custom_modules/prithvi_vit_S3.py) so that TerraTorch can read in the weights for these additional bands.

We make sure to point to this as a backbone in our config file.

We have also set the maximum epochs to 30 for this demonstration, but you may want to run the model for more epochs (e.g. 100).

In [None]:
config_file = config_folder / "config-fine-tuning.yaml"
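
To train for more epochs without editing the YAML file, one option is to override the value on the command line. The sketch below assumes that `terratorch fit` accepts a Lightning-CLI-style `--trainer.max_epochs` override, in the same way that the predict command later in this notebook overrides `--data.init_args.predict_data_root`. Treat that flag as an assumption and check it against your config. `build_fit_command` is a hypothetical helper.

```python
from pathlib import Path


def build_fit_command(config, max_epochs=None) -> str:
    """Build a `terratorch fit` command, optionally overriding the epoch count."""
    cmd = f"terratorch fit --config {Path(config)}"
    if max_epochs is not None:
        # Assumed override syntax; the config is expected to define trainer.max_epochs
        cmd += f" --trainer.max_epochs {int(max_epochs)}"
    return cmd


print(build_fit_command("config-fine-tuning.yaml", max_epochs=100))
```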

### 1.4 Carry out fine-tuning

Execute the cell below to print out the fine-tuning command. Check the command and the config location to make sure that the config file exists in the expected folder.

In [None]:
fine_tuning_command = f"terratorch fit --config ./{config_file}"
print(fine_tuning_command)

If everything looks OK, execute the cell below to fine-tune the model. The first run places the fine-tuning output in `model_runs/version_0`; additional runs create subsequent version directories. The fine-tuning output includes the model checkpoints used in the next section.

In [None]:
os.system(fine_tuning_command)
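
`os.system` does not raise an error when the command fails, so a failed run can go unnoticed. As a sketch of an alternative, the hypothetical helper below runs the same command string with `subprocess` and reports a non-zero exit code.

```python
import shlex
import subprocess


def run_command(command: str) -> int:
    """Run a command string and return its exit code, reporting failures."""
    completed = subprocess.run(shlex.split(command))
    if completed.returncode != 0:
        print(f"Command failed with exit code {completed.returncode}: {command}")
    return completed.returncode


# Example usage with the command printed above:
# run_command(fine_tuning_command)
```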


## 2. Checking the results - inference prep

Let's gather and specify the relevant files for carrying out inference in a new folder. Look for the .ckpt file produced during the fine-tuning process and list it in the cell in section 2.1. We are performing inference for a region off the coast of Spain and Portugal between 7th and 13th July 2020.


### 2.1 Inference checkpoint specification
Identify the `model_runs` directory containing the checkpoint file you wish to use. This will be `model_runs/version_0` the first time fine-tuning is performed.

In [None]:
# Directory of the fine-tuning run to use; edit this if your run is not version_0
inference_checkpoint_loc = project_root / "data" / "model_runs" / "version_0"

# List all checkpoints saved by that run
inference_checkpoint = glob.glob(f"{inference_checkpoint_loc}/checkpoints/*.ckpt")
# If there are several, keep the one with the shortest path
inference_checkpoint = min(inference_checkpoint, key=len)
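
After several fine-tuning runs there will be more than one `version_N` directory. The sketch below picks the one with the highest number, so the path does not have to be edited by hand. `latest_version_dir` is a hypothetical helper, and it assumes the directories follow the `version_N` naming described above.

```python
import re
from pathlib import Path


def latest_version_dir(model_runs):
    """Return the version_N sub-directory with the highest N, or None if absent."""
    best, best_n = None, -1
    for child in Path(model_runs).glob("version_*"):
        match = re.fullmatch(r"version_(\d+)", child.name)
        # Compare numerically so that version_10 ranks above version_2
        if child.is_dir() and match and int(match.group(1)) > best_n:
            best, best_n = child, int(match.group(1))
    return best


# Example usage:
# inference_checkpoint_loc = latest_version_dir(project_root / "data" / "model_runs")
```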

### 2.2 Preparing paths for inference results
Create a new directory for the inference results to be placed.

In [None]:
inference_output = project_root / "data" / "inference_results"
inference_input = project_root / "data" / "inference"

os.makedirs(inference_output, exist_ok=True)

The region we are running inference on is quite large, so inference may take a long time. To make it run faster, execute the cell below, which crops the image to a smaller area. Otherwise, skip the cell below and go directly to running inference.
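
The crop call in the next cell passes the same path as input and output, so it appears to replace the full image with the cropped one. If you want to keep the original, a backup copy can be made first. This is a minimal sketch; `backup_once` and the `.orig` suffix are hypothetical choices.

```python
import shutil
from pathlib import Path


def backup_once(path, suffix=".orig") -> Path:
    """Copy path to path + suffix unless that backup already exists."""
    path = Path(path)
    backup = path.with_name(path.name + suffix)
    if not backup.exists():
        # copy2 also preserves file metadata such as timestamps
        shutil.copy2(path, backup)
    return backup


# Example usage, before cropping:
# backup_once(inference_input / "2020-07-07_00_00_00+2020-07-13_00_00_00_img.tif")
```

Because an existing backup is never overwritten, re-running the cell after cropping keeps the full-size copy intact.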

In [None]:
bbox = [-9.8,42.5,-8.7,43.3]
inference_image = inference_input / '2020-07-07_00_00_00+2020-07-13_00_00_00_img.tif'
crop_image(inference_image, bbox, inference_image)

### 2.3 Run inference
Let's carry out inference on the test images. Execute the cell below to print out a command. Make sure the paths look correct.

In [None]:
inference_command = f"terratorch predict -c {config_file} --ckpt_path {inference_checkpoint} --predict_output_dir ./{inference_output} --data.init_args.predict_data_root ./{inference_input}"
print(inference_command)
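
Checking the paths by eye works, but it can also be sketched in code. The hypothetical helper `missing_paths` returns every path in a list that does not exist on disk, so an empty result means the command's inputs are all in place.

```python
from pathlib import Path


def missing_paths(paths):
    """Return the subset of paths (as strings) that do not exist."""
    return [str(p) for p in paths if not Path(p).exists()]


# Example usage with the variables defined in the cells above:
# print(missing_paths([config_file, inference_checkpoint, inference_input]))
```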

If everything looks good, execute the cell below.

In [None]:
os.system(inference_command)

## 3. Checking and visualizing results

We can then plot the results:

In [None]:
plot_inference_data(
    f"{inference_input}/2020-07-07_00_00_00+2020-07-13_00_00_00_img.tif",
    f"{inference_output}/2020-07-07_00_00_00+2020-07-13_00_00_00_img_pred.tif",
)

## Next steps

Check out the other granite-geospatial models for [Above Ground Biomass](https://huggingface.co/ibm-granite/granite-geospatial-biomass), [Canopy Height](https://huggingface.co/ibm-granite/granite-geospatial-canopyheight), [Land Surface Temperature](https://huggingface.co/ibm-granite/granite-geospatial-land-surface-temperature) and [Weather and Climate Downscaling](https://huggingface.co/ibm-granite/granite-geospatial-wxc-downscaling).