![ThinkOnward Logo](https://github.com/thinkonward/geophysical-foundation-model/blob/2848c9ae410b6bc334138f5689cb2a3b15fd02a6/Tutorial/assets/ThinkOnward.png?raw=True)

# Seismic Denoising: Fine Tune the Geophysical Foundation Model

### Jesse Pisel March 2025

Make sure you install the required packages from requirements.txt as outlined in the README file in the repository root directory. Note that this tutorial does not include complete code for training your own fine-tuned model for denoising. The ThinkOnward `denoizer` model is available on Hugging Face for you to use in this tutorial. It was the baseline model for the [Image Impeccable Challenge](https://thinkonward.com/app/c/challenges/image-impeccable/leaderboard) with a score of `0.901743`. We strongly recommend training your own model to see what you can improve.

## First Steps
First you need to import packages needed for data wrangling and inference.

In [None]:
!pip install timm

In [None]:
import cv2
import torch
import numpy as np
import pandas as pd  # needed by parquet2array below
from glob import glob
from functools import partial
from GFM import ElasticViTMAE
import matplotlib.pyplot as plt

## Hugging Face Access

To download the `denoizer` model, you need a [Hugging Face](https://huggingface.co/) account, and you must [request model access](https://huggingface.co/thinkonward/denoizer).

After requesting access, you need to [set up an access token](https://huggingface.co/docs/hub/security-tokens) to use the `huggingface_hub` package. Once that is set up, log in by running the next cell and entering your token.

In [None]:
from huggingface_hub import login, hf_hub_download, snapshot_download
login()

## Fine Tuning Considerations

As mentioned above, you should train your own fine-tuned model from the GFM. To do so, follow these guidelines:

1. You will need to build a paired dataset of noisy and clean (denoised) examples.

   Check out the [Image Impeccable: Journey to Clarity Challenge](https://thinkonward.com/app/c/challenges/image-impeccable) for some ideas, or [just use the data](https://forum.thinkonward.com/t/image-impeccable-training-data-links-to-open-s3-bucket/2171).

2. You will need to build a way to evaluate your model and track its loss during training.
3. You will need to choose which layers to freeze and which to train during fine-tuning, e.g.:

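```python
# `set_trainable_layers` is an illustrative wrapper; pass e.g. `args.trainable_layers`
# from your own training script as the second argument
def set_trainable_layers(model, trainable_layers="head"):
    # first set all layers to non-trainable
    for n, p in model.named_parameters():
        p.requires_grad = False

    if trainable_layers == 'head':
        # train only the final prediction projection
        for n, p in model.named_parameters():
            if n.startswith('decoder_pred'):
                p.requires_grad = True
    else:
        # train the whole decoder plus the mask token
        for n, p in model.named_parameters():
            if n.startswith('decoder_pos_embed'):  # pos embed should not be trainable
                continue
            elif n.startswith('decoder') or n.startswith('mask_token'):
                p.requires_grad = True

    # print a summary so you can confirm which parameters will be updated
    for n, p in model.named_parameters():
        print(n, p.shape, 'trainable:', p.requires_grad)
```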
4. For your `CustomDataset` and dataloader, we strongly suggest checking out the [Image Impeccable Starter Notebook](https://github.com/thinkonward/challenges/tree/main/geoscience/image-impeccable/image-impeccable-starter-notebook) for ideas.
5. Start your training and check that the model is improving over time.
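
The dataset (item 1) and evaluation (item 2) steps can be sketched without a deep learning framework. This is a minimal sketch, assuming the noisy and clean volumes are already loaded as NumPy arrays of identical shape (for example with the `parquet2array` helper defined later in this tutorial). `PairedSliceDataset` and `psnr` are hypothetical helpers, and PSNR is simply one common denoising metric, not necessarily the metric used to score the challenge.

```python
import numpy as np


class PairedSliceDataset:
    """Hypothetical paired dataset returning aligned (noisy, clean) 2D slices.

    Slices are taken along the last axis, the same axis the inference
    loop in this tutorial iterates over.
    """

    def __init__(self, noisy_vol, clean_vol):
        if noisy_vol.shape != clean_vol.shape:
            raise ValueError("noisy and clean volumes must have the same shape")
        self.noisy = noisy_vol
        self.clean = clean_vol

    def __len__(self):
        return self.noisy.shape[-1]

    def __getitem__(self, i):
        return self.noisy[:, :, i], self.clean[:, :, i]


def psnr(pred, target, data_range=None):
    """Peak signal-to-noise ratio in dB (higher is better)."""
    pred = np.asarray(pred, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    if data_range is None:
        data_range = target.max() - target.min()
    mse = np.mean((pred - target) ** 2)
    if mse == 0:
        return float("inf")
    return float(10.0 * np.log10(data_range ** 2 / mse))


# Toy check on a synthetic volume: added noise lowers the PSNR
rng = np.random.default_rng(0)
clean = rng.standard_normal((50, 20, 8))
noisy = clean + 0.5 * rng.standard_normal(clean.shape)
ds = PairedSliceDataset(noisy, clean)
noisy_slice, clean_slice = ds[3]
print(len(ds), noisy_slice.shape, psnr(noisy_slice, clean_slice))
```

In a real training loop you would wrap the same indexing logic in a PyTorch `Dataset` and log the metric on a held-out set of volumes after each epoch.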

For this tutorial we will use a fine-tuned version of the GFM that the ThinkOnward Challenges team trained on the Image Impeccable Challenge dataset. Much of the code is similar to the interpolation tutorial, but there are some notable differences. The model used in this tutorial is an example of fine-tuning and **should not be used without additional training** for your use case.

## Model Architecture Changes
To turn the GFM into a denoiser, you need to edit three methods of the `ElasticViTMAE` class in `ElasticViTMAE.py`.

**1. `forward_encoder`**
    
* The original GFM `forward_encoder` applies random masking during training:

```python
def forward_encoder(self, x, idx_shuffle, len_keep):
    # embed patches
    x = self.patch_embed(x)

    # add pos embed w/o cls token
    x = x + self.pos_embed[:, 1:, :]

    # masking: length -> length * mask_ratio
    x, mask, ids_restore = self.random_masking(x, idx_shuffle, len_keep)

    # append cls token
    cls_token = self.cls_token + self.pos_embed[:, :1, :]
    cls_tokens = cls_token.expand(x.shape[0], -1, -1)
    x = torch.cat((cls_tokens, x), dim=1)

    # apply Transformer blocks
    for blk in self.blocks:
        x = blk(x)
    x = self.norm(x)

    return x, mask, ids_restore
```

* The denoiser `forward_encoder` has the random masking removed, so it no longer needs `idx_shuffle` or `len_keep`:

```python
def forward_encoder(self, x):
    # embed patches
    x = self.patch_embed(x)

    # add pos embed w/o cls token
    x = x + self.pos_embed[:, 1:, :]

    # append cls token
    cls_token = self.cls_token + self.pos_embed[:, :1, :]
    cls_tokens = cls_token.expand(x.shape[0], -1, -1)
    x = torch.cat((cls_tokens, x), dim=1)

    # apply Transformer blocks
    for blk in self.blocks:
        x = blk(x)
    x = self.norm(x)

    return x
```
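
To see what was removed, the masking step can be mimicked in NumPy. This is a sketch, assuming `random_masking` follows the standard MAE recipe: patches are permuted by `idx_shuffle`, the first `len_keep` are kept, and `ids_restore` (the inverse permutation) is passed on to the decoder. `random_masking_np` is a hypothetical stand-in, not the GFM source.

```python
import numpy as np


def random_masking_np(x, idx_shuffle, len_keep):
    """Keep `len_keep` tokens per sample in the order given by `idx_shuffle`.

    x:           (batch, num_patches, dim) patch embeddings
    idx_shuffle: (batch, num_patches) permutation of patch indices
    Returns the kept tokens, a binary mask (1 = removed) in the original
    patch order, and ids_restore, the inverse permutation.
    """
    batch, num_patches, dim = x.shape
    ids_restore = np.argsort(idx_shuffle, axis=1)
    ids_keep = idx_shuffle[:, :len_keep]
    x_kept = np.take_along_axis(x, ids_keep[:, :, None].repeat(dim, axis=2), axis=1)

    mask = np.ones((batch, num_patches))
    mask[:, :len_keep] = 0
    # unshuffle the mask so it lines up with the original patch order
    mask = np.take_along_axis(mask, ids_restore, axis=1)
    return x_kept, mask, ids_restore


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 160, 8))  # 160 patches, as in the denoiser input
idx_shuffle = np.argsort(rng.random((2, 160)), axis=1)
x_kept, mask, ids_restore = random_masking_np(x, idx_shuffle, len_keep=40)
print(x_kept.shape, mask.sum(axis=1))  # the encoder only sees 40 of 160 patches
```

In the denoiser this step is gone, so every patch reaches the transformer blocks in its original order.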

**2. `forward_decoder`**

* The original GFM `forward_decoder` appends mask tokens and unshuffles the sequence to restore the original patch order:

```python
def forward_decoder(self, x, ids_restore):
    # embed tokens
    x = self.decoder_embed(x)

    # append mask tokens to sequence
    mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
    x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
    x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
    x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

    # add pos embed
    x = x + self.decoder_pos_embed

    # apply Transformer blocks
    for blk in self.decoder_blocks:
        x = blk(x)
    x = self.decoder_norm(x)

    # predictor projection
    x = self.decoder_pred(x)

    if not self.custom_head:
        # remove cls token
        x = x[:, 1:, :]

    return x
```

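* The denoiser `forward_decoder` has the unshuffle removed and no longer takes `ids_restore`, so the `ids_restore.shape[1]` term (the number of patches) is replaced by 160, the width of the input image. Because nothing is masked, `160 + 1 - x.shape[1]` evaluates to zero and no mask tokens are actually appended: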

```python
def forward_decoder(self, x):
    # embed tokens
    x = self.decoder_embed(x)

    # append mask tokens to sequence
    mask_tokens = self.mask_token.repeat(x.shape[0], 160 + 1 - x.shape[1], 1)
    x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
    # remove the unshuffle 
    x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

    # add pos embed
    x = x + self.decoder_pos_embed

    # apply Transformer blocks
    for blk in self.decoder_blocks:
        x = blk(x)
    x = self.decoder_norm(x)

    # predictor projection
    x = self.decoder_pred(x)

    if not self.custom_head:
        # remove cls token
        x = x[:, 1:, :]

    return x
```
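
Why is dropping the unshuffle safe? With no masking, the encoder returns every patch token plus the cls token, so the mask-token count is zero and the restore step reduces to an identity permutation. The NumPy sketch below checks both facts on random data. `decoder_restore_np` is a hypothetical helper that mirrors the original decoder's append-and-unshuffle logic, with `np.take_along_axis` standing in for `torch.gather`.

```python
import numpy as np


def decoder_restore_np(x, mask_token, ids_restore):
    """Mirror of the original decoder's append-and-unshuffle step.

    x:           (batch, 1 + len_keep, dim) encoder output, cls token first
    mask_token:  (1, 1, dim) learned mask embedding
    ids_restore: (batch, num_patches) inverse of the shuffle permutation
    """
    batch, seq_len, dim = x.shape
    num_mask = ids_restore.shape[1] + 1 - seq_len
    mask_tokens = np.tile(mask_token, (batch, num_mask, 1))
    x_ = np.concatenate([x[:, 1:, :], mask_tokens], axis=1)  # drop cls token
    index = np.repeat(ids_restore[:, :, None], dim, axis=2)
    x_ = np.take_along_axis(x_, index, axis=1)  # unshuffle
    return np.concatenate([x[:, :1, :], x_], axis=1), num_mask  # re-attach cls


num_patches, dim = 160, 8
rng = np.random.default_rng(0)
encoded = rng.standard_normal((2, num_patches + 1, dim))  # no masking: all patches kept
mask_token = rng.standard_normal((1, 1, dim))
identity = np.tile(np.arange(num_patches), (2, 1))  # no shuffle to undo

restored, num_mask = decoder_restore_np(encoded, mask_token, identity)
print(num_mask, np.array_equal(restored, encoded))  # 0 True
```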

**3. `forward` method**

* The original GFM `forward` masks traces and computes the reconstruction loss:

```python
def forward(self, imgs, idx_shuffle, len_keep):
    latent, mask, ids_restore = self.forward_encoder(imgs, idx_shuffle, len_keep)
    pred = self.forward_decoder(latent, ids_restore)
    loss = self.forward_loss(imgs, pred, mask)
    return loss, pred, mask
```

* The denoiser no longer masks traces or computes the masking loss, so those steps are removed. When you fine-tune, you will need a new loss to track, for example a pixel-wise regression loss such as `MSELoss` or `L1Loss` between the prediction and the clean target:

```python
def forward(self, img):
    latent = self.forward_encoder(img)
    pred = self.forward_decoder(latent)
    return pred
```
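
For denoising, a pixel-wise regression loss between the prediction and the clean target is a natural starting point. This is a hedged sketch, assuming the denoiser output has shape `(batch, 160, 400)` with one row per trace patch, which is what the inference code later in this tutorial implies when it transposes `pred[0]` back into a 400 x 160 image. `to_patch_layout` and `mse_loss_np` are hypothetical helpers; in PyTorch you would compute the same quantity with `torch.nn.functional.mse_loss` on tensors.

```python
import numpy as np


def to_patch_layout(img):
    """(batch, 1, height, width) image -> (batch, width, height) trace patches.

    Assumes one patch per trace column, matching `pred[0].T` at inference.
    """
    return np.transpose(img[:, 0], (0, 2, 1))


def mse_loss_np(pred, clean_img):
    """Mean squared error between patch-layout predictions and a clean image."""
    target = to_patch_layout(clean_img)
    if pred.shape != target.shape:
        raise ValueError(f"shape mismatch: {pred.shape} vs {target.shape}")
    return float(np.mean((pred - target) ** 2))


rng = np.random.default_rng(0)
clean_img = rng.standard_normal((2, 1, 400, 160))  # standardized clean slices
perfect_pred = to_patch_layout(clean_img)           # (2, 160, 400)
print(perfect_pred.shape, mse_loss_np(perfect_pred, clean_img))
```

Computing the loss in the same standardized space as the model input (here `(x - 128) / 43`) keeps its magnitude comparable across volumes.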

After saving those changes in the GFM architecture code you are all set to load up the fine-tuned model and run inference for denoising.

## Loading the fine-tuned model

To use a fine-tuned version of the GFM, you need to download the model weights and create an instance of the `ElasticViTMAE` class. Once it is instantiated, send it to the CUDA device (or the CPU if no GPU is available) and set it to eval mode so you can run inference.

In [None]:
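from GFM import ElasticViTMAE

# use the GPU when available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
denoiser = ElasticViTMAE.ElasticViTMAE.from_pretrained("thinkonward/denoizer")
denoiser = denoiser.float()
denoiser.to(device)
# switch off training-only behaviour before running inference
denoiser.eval()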

## Data Download

Next you need some data to work with. The cell below downloads a noisy seismic volume and its matching clean volume from the [Image Impeccable Challenge](https://thinkonward.com/app/c/challenges/image-impeccable) dataset.

In [None]:
DATA_REPO_ID = "thinkonward/image-impeccable"

NOISY_DATA = "train/42487393/seismic_w_noise_vol_42487393.parquet"
CLEAN_DATA = "train/42487393/seismicCubes_RFC_fullstack_2024.42487393.parquet"

hf_hub_download(
    repo_id=DATA_REPO_ID,
    filename=NOISY_DATA,
    repo_type="dataset",
    local_dir="./dataset",
)

hf_hub_download(
    repo_id=DATA_REPO_ID,
    filename=CLEAN_DATA,
    repo_type="dataset",
    local_dir="./dataset",
)

## Load a test dataset

Now that the model is loaded on your device, load a single seismic volume. It helps to visualize what the noise looks like in an inline, a crossline, and a timeslice.

In [None]:
def parquet2array(parquet_file, original_shape=(1259,300,300)):
    df = pd.read_parquet(parquet_file)
    data_only = df.drop(columns=['Row', 'Col'])
    # Convert the DataFrame back to a 2D numpy array
    reshaped_array = data_only.values
    # Reshape the 2D array back into a 3D array
    array = reshaped_array.reshape(original_shape)
    return array

def visualize_seismic(volume, inline, crossline, timeslice):
    """
    Visualize a seismic volume by displaying three orthogonal slices.

    Parameters:
    volume (3D array): The seismic volume data.
    inline (int): The index of the inline slice to display.
    crossline (int): The index of the crossline slice to display.
    timeslice (int): The index of the timeslice to display.

    Returns:
    None

    Notes:
    The function displays three subplots:
    - The first subplot shows the inline slice with a red dashed line indicating the timeslice.
    - The second subplot shows the crossline slice with a red dashed line indicating the timeslice.
    - The third subplot shows the timeslice with red dashed lines indicating the inline and crossline indices.
    """

    iline = volume.T[:, inline, :]
    xline = volume.T[crossline, :, :]
    ts = volume.T[:, :, timeslice]

    fig, ax = plt.subplots(1, 3, figsize=(10, 6), sharex=True)
    ax[0].imshow(iline.T, cmap="gray") #view from left of timeslice
    ax[0].axhline(y=timeslice, color='r', linestyle='--')
    ax[0].set_title(f"Inline {inline}")

    ax[1].imshow(xline.T, cmap="gray") #view from bottom of timeslice
    ax[1].axhline(y=timeslice, color='r', linestyle='--')
    ax[1].set_title(f"Cross-line {crossline}")

    ax[2].imshow(ts, cmap="gray")
    ax[2].axhline(y=crossline, color='r', linestyle='--')
    ax[2].axvline(x=inline, color='r', linestyle='--')
    ax[2].set_title(f"Timeslice {timeslice}")
    plt.tight_layout()


# glob returns a list of matches, so take the first (and only) file
SEISMIC_VOL = parquet2array(glob('./dataset/train/42487393/seismic_w_*.parquet')[0])
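
`parquet2array` drops the `Row` and `Col` index columns and reshapes what is left into a 3D array. The sketch below illustrates that round trip on a small in-memory array. It assumes a layout with one DataFrame row per (first-axis, second-axis) index pair and one column per last-axis sample, plus the `Row` and `Col` columns; that layout is an assumption for illustration, not a documented description of the challenge files. `array2frame` and `frame2array` are hypothetical helpers.

```python
import numpy as np
import pandas as pd


def array2frame(array):
    """Hypothetical inverse of parquet2array: flatten a 3D array into a DataFrame."""
    n0, n1, n2 = array.shape
    rows, cols = np.meshgrid(np.arange(n0), np.arange(n1), indexing="ij")
    df = pd.DataFrame(array.reshape(n0 * n1, n2), columns=[str(k) for k in range(n2)])
    df.insert(0, "Row", rows.ravel())
    df.insert(1, "Col", cols.ravel())
    return df


def frame2array(df, original_shape):
    """Same logic as parquet2array, minus the file read."""
    return df.drop(columns=["Row", "Col"]).values.reshape(original_shape)


vol = np.arange(4 * 3 * 5, dtype=np.float32).reshape(4, 3, 5)
df = array2frame(vol)
print(df.shape)  # (12, 7): 4*3 rows, 5 samples plus Row and Col
restored = frame2array(df, vol.shape)
print(np.array_equal(restored, vol))
```

Because `reshape` works in C order, any layout that stores the values row-major in the same order reshapes back correctly.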

In [None]:
visualize_seismic(SEISMIC_VOL, 80, 200, 650)

## Predict

Now have the model work through the seismic volume slice by slice, making a prediction for each slice.

In [None]:
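# start by rescaling the entire volume to 0-255
minval = np.nanmin(SEISMIC_VOL)
maxval = np.nanmax(SEISMIC_VOL)
seismic = ((SEISMIC_VOL - minval) / (maxval - minval)) * 255

# next work through each slice along the last axis and denoise it
# you might try orthogonal directions to see how it changes results
use_amp = device.type == "cuda"  # mixed precision only on the GPU
with torch.no_grad():  # inference only, so skip gradient tracking
    for i in range(seismic.shape[-1]):
        with torch.autocast(device_type="cuda", enabled=use_amp):
            # resize to 400 rows x 160 columns (cv2 takes dsize as (width, height))
            input_sample = cv2.resize(seismic[:, :, i], (160, 400))
            # standardize
            input_sample = (input_sample - 128.) / 43.

            # add batch and channel dimensions: (1, 1, 400, 160)
            x = torch.tensor(input_sample).unsqueeze(0).unsqueeze(0)
            x = x.float()
            x = x.to(device)

            pred = denoiser(x)
            # scale back to 0-255
            pred = pred * 43. + 128.

            # back to a (400, 160) image, then resize to the original slice size
            pred = pred.cpu().detach().numpy()[0].T
            pred = cv2.resize(pred.astype('float32'), (seismic.shape[1], seismic.shape[0]))

            seismic[:, :, i] = pred

SEISMIC_PRED = np.clip(seismic, 0, 255)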
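
`SEISMIC_PRED` is on the 0-255 scale used for inference, not the original amplitude scale. If you want to compare amplitudes with the clean data or write the result back out, you can invert the min-max scaling with the same `minval` and `maxval`. The sketch below, on synthetic data, checks that the forward transform (including the `(x - 128) / 43` standardization) and its inverse round-trip; the helper names are hypothetical. Note that `np.clip` discards any prediction that falls outside 0-255, so those values cannot be recovered.

```python
import numpy as np


def to_model_scale(vol, minval, maxval):
    """Min-max scale to 0-255, then standardize as in the inference loop."""
    scaled = (vol - minval) / (maxval - minval) * 255
    return (scaled - 128.0) / 43.0


def to_amplitude(pred_0_255, minval, maxval):
    """Invert the 0-255 min-max scaling back to the original amplitude range."""
    return pred_0_255 / 255 * (maxval - minval) + minval


rng = np.random.default_rng(0)
vol = rng.standard_normal((20, 10, 10)) * 1000  # synthetic amplitudes
minval, maxval = np.nanmin(vol), np.nanmax(vol)

standardized = to_model_scale(vol, minval, maxval)
back_to_0_255 = standardized * 43.0 + 128.0  # same de-standardization as the loop
amplitudes = to_amplitude(np.clip(back_to_0_255, 0, 255), minval, maxval)
print(np.allclose(amplitudes, vol))
```

On the real data this would be `to_amplitude(SEISMIC_PRED, minval, maxval)`.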

## Visualise results for QC

Now look at the denoised seismic to see how the model did. There is still quite a bit of noise, so let's compare it with the ground truth data.

In [None]:
visualize_seismic(SEISMIC_PRED, 80, 200, 650)

In [None]:
ORIGINAL = parquet2array(glob('./dataset/train/42487393/seismicCubes*.parquet')[0])
visualize_seismic(ORIGINAL, 80, 200, 650)
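
Visual QC can be backed by a number. One scale-independent option is the Pearson correlation between each denoised slice and the matching clean slice: because correlation ignores offset and gain, `SEISMIC_PRED` (0-255) can be compared with `ORIGINAL` (raw amplitudes) directly. This is a hypothetical QC helper shown on synthetic data, not the challenge's scoring metric. On the real volumes you could call `slice_correlation(SEISMIC_PRED, ORIGINAL)`, and `slice_correlation(SEISMIC_VOL, ORIGINAL)` for a noisy baseline.

```python
import numpy as np


def slice_correlation(pred_vol, clean_vol):
    """Pearson correlation between matching slices along the last axis."""
    if pred_vol.shape != clean_vol.shape:
        raise ValueError("volumes must have the same shape")
    n = pred_vol.shape[-1]
    # each column holds every sample of one slice [:, :, k]
    a = pred_vol.reshape(-1, n).astype(np.float64)
    b = clean_vol.reshape(-1, n).astype(np.float64)
    a = a - a.mean(axis=0)
    b = b - b.mean(axis=0)
    denom = np.sqrt((a ** 2).sum(axis=0) * (b ** 2).sum(axis=0))
    # constant slices have no defined correlation and come out as NaN
    with np.errstate(invalid="ignore", divide="ignore"):
        return (a * b).sum(axis=0) / denom


rng = np.random.default_rng(0)
clean = rng.standard_normal((50, 20, 6))
noisy = clean + rng.standard_normal(clean.shape)
rescaled = clean * 40 + 128  # same structure on a 0-255-like scale
print(slice_correlation(rescaled, clean).round(3))  # every entry is 1.0
print(slice_correlation(noisy, clean).mean())
```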

Some thoughts: you might want to try predicting the noise volume instead of the denoised volume, or [check out the Image Impeccable winning solutions on GitHub](https://github.com/thinkonward/challenges/tree/main/geoscience/image-impeccable) for more ideas on cutting-edge denoising solutions.
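
Predicting the noise instead of the clean volume (often called residual learning) only changes the training target and one post-processing step. A minimal sketch, assuming the noisy and clean volumes are on the same scale; `make_residual_target` and `apply_residual` are hypothetical names:

```python
import numpy as np


def make_residual_target(noisy, clean):
    """Training target for a noise-predicting model: the noise itself."""
    return noisy - clean


def apply_residual(noisy, predicted_noise):
    """Recover the denoised estimate by subtracting the predicted noise."""
    return noisy - predicted_noise


rng = np.random.default_rng(0)
clean = rng.standard_normal((100, 40))
noisy = clean + 0.3 * rng.standard_normal(clean.shape)
target = make_residual_target(noisy, clean)
# a perfect noise prediction recovers the clean slice (up to rounding)
print(np.allclose(apply_residual(noisy, target), clean))
```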