In [2]:
%load_ext autoreload
%autoreload 2

# Virtual EVE From Pretrained Embeddings

## Purpose:
This notebook provides an example of fine-tuning with SDO-FM. Here we create a virtual EVE instrument by starting the training from a pretrained SDO-FM foundation model, which reaches a production-ready model much faster than training from scratch would.

## Foundation Models
The process is akin to transfer learning, a method commonly used in computer vision — for example, freezing a feature-extracting neural network pretrained on ImageNet and reusing it for a new task. For a discussion of transfer learning as it was understood before the advent of large modern models, see, for example, [this paper](https://arxiv.org/abs/1811.08883).

For the sake of conceptual understanding, we can think of the foundation model as a feature extractor and the head as a classifier. The foundation model is pretrained on a large dataset. It is then frozen while the head is trained on a smaller dataset. Finally, the foundation model is unfrozen and the entire model is fine-tuned on the smaller dataset. This process, called transfer learning, makes it possible to train models on datasets too small for training from scratch to be feasible. It is the approach we take in this notebook.

We begin by importing the libraries we will need:

In [3]:
import os
from pathlib import Path

import lightning.pytorch as pl
import torch
import wandb
import omegaconf
from sdofm import utils
from sdofm.datasets import SDOMLDataModule, DegradedSDOMLDataModule
from sdofm.pretraining import MAE, SAMAE
from sdofm.finetuning import Autocalibration

In [7]:
# Experiment configuration read by the cells below (data locations, month
# splits, model/optimizer options). The path is relative, so it is resolved
# against the kernel's current working directory.
CONFIG_PATH = Path("../../experiments/finetune_32.2M_mae_virtualeve.yaml")
cfg = omegaconf.OmegaConf.load(CONFIG_PATH)

In [9]:
# Resolve the SDOML locations once from the experiment config. hmi_path and
# eve_path are passed as None, so presumably only AIA data is loaded here.
sdoml_cfg = cfg.data.sdoml
sdoml_aia_path = os.path.join(sdoml_cfg.base_directory, sdoml_cfg.sub_directory.aia)
sdoml_cache_dir = os.path.join(sdoml_cfg.base_directory, sdoml_cfg.sub_directory.cache)

# NOTE(review): `ions` is passed but eve_path is None, while the model cell
# below unpacks each batch as `(imgs, eve)` -- confirm whether EVE targets
# should be loaded here for virtual-EVE fine-tuning.
data_module = SDOMLDataModule(
    # data sources
    hmi_path=None,
    aia_path=sdoml_aia_path,
    eve_path=None,
    # sample selection options taken from the config
    components=sdoml_cfg.components,
    wavelengths=sdoml_cfg.wavelengths,
    ions=sdoml_cfg.ions,
    frequency=sdoml_cfg.frequency,
    num_frames=1,
    # loader settings and on-disk cache
    batch_size=cfg.model.opt.batch_size,
    num_workers=cfg.data.num_workers,
    cache_dir=sdoml_cache_dir,
    # month-based splits and overall date range
    val_months=cfg.data.month_splits.val,
    test_months=cfg.data.month_splits.test,
    holdout_months=cfg.data.month_splits.holdout,
    min_date=cfg.data.min_date,
    max_date=cfg.data.max_date,
)
data_module.setup()

[* CACHE SYSTEM *] Found cached index data in /mnt/sdoml/cache/aligndata_AIA_FULL_12min.csv.
[* CACHE SYSTEM *] Found cached normalization data in /mnt/sdoml/cache/normalizations_AIA_FULL_12min.json.
[* CACHE SYSTEM *] Found cached HMI mask data in /mnt/sdoml/cache/hmi_mask_512x512.npy.


In [None]:
class VirtualEVEExample(pl.LightningModule):
    """Virtual-EVE model fine-tuned on top of a pretrained SDO-FM backbone.

    Both the training and validation steps use the same path:
    ``imgs -> encoder -> decoder (neck) -> head -> prediction``, and compute
    the loss with ``head.loss_func`` against the EVE targets of the batch.

    NOTE(review): ``WrapEncoder``, ``ConvTransformerTokensToEmbeddingNeck`` and
    ``HybridIrradianceModel`` are not imported anywhere in this notebook.
    Defining this class does not need them, but instantiating it does -- import
    them from the sdofm package first. No ``configure_optimizers`` is defined
    here either; this cell is an illustrative excerpt of the model.
    """

    def __init__(
        self,
        # Backbone parameters
        img_size: int = 512,
        patch_size: int = 16,
        embed_dim: int = 128,
        num_frames: int = 5,
        # Neck parameters
        num_neck_filters: int = 32,
        # Head parameters
        cnn_model: str = "efficientnet_b3",
        lr_linear: float = 0.01,
        lr_cnn: float = 0.0001,
        cnn_dp: float = 0.75,
        epochs_linear: int = 50,
        d_output=None,
        eve_norm=None,
        # for finetuning
        backbone: object = None,
        freeze_encoder: bool = True,
        # all else
        *args,
        num_input_channels: int = 9,
        num_targets: int = 38,
        **kwargs,
    ):
        """Build encoder, neck and head.

        Args:
            img_size: Input image side length; ``img_size // patch_size`` is
                used as the neck's token-grid height and width.
            patch_size: Backbone patch size (see ``img_size``).
            embed_dim: Token embedding width handed to the neck; presumably it
                must match the backbone's embedding width -- confirm.
            num_frames: Number of time frames handed to the neck.
            num_neck_filters: Neck output width, also the head's ``d_input``.
            cnn_model, lr_linear, lr_cnn, cnn_dp, epochs_linear: Forwarded
                unchanged to ``HybridIrradianceModel``.
            d_output: Head output size, forwarded to the head.
            eve_norm: EVE normalisation, stored on ``self`` and forwarded to
                the head.
            backbone: Pretrained backbone, wrapped by ``WrapEncoder``.
            freeze_encoder: If True, the encoder's parameters get
                ``requires_grad=False`` and the encoder is kept in eval mode.
            *args, **kwargs: Forwarded to the parent ``__init__``.
            num_input_channels: Only ``imgs[:, :num_input_channels]`` is fed to
                the encoder (default 9, the previously hard-coded value).
            num_targets: Only ``eve[:, :num_targets]`` is used as the loss
                target (default 38, the previously hard-coded value).
        """
        super().__init__(*args, **kwargs)
        self.eve_norm = eve_norm
        self.freeze_encoder = freeze_encoder
        self.num_input_channels = num_input_channels
        self.num_targets = num_targets

        self.backbone = backbone
        self.encoder = WrapEncoder(self.backbone)

        if freeze_encoder:
            self.encoder.eval()
            for param in self.encoder.parameters():
                param.requires_grad = False

        # Side length of the patch-token grid the neck reassembles.
        num_tokens = img_size // patch_size

        # NECK
        self.decoder = ConvTransformerTokensToEmbeddingNeck(
            embed_dim=embed_dim,
            output_embed_dim=num_neck_filters,
            Hp=num_tokens,
            Wp=num_tokens,
            drop_cls_token=True,
            num_frames=num_frames,
        )

        # HEAD
        self.head = HybridIrradianceModel(
            # virtual eve
            d_input=num_neck_filters,
            d_output=d_output,
            eve_norm=eve_norm,
            # from config
            cnn_model=cnn_model,
            lr_linear=lr_linear,
            lr_cnn=lr_cnn,
            cnn_dp=cnn_dp,
            epochs_linear=epochs_linear,
        )

    def train(self, mode: bool = True):
        """Set train/eval mode on the model while keeping a frozen encoder in eval.

        ``nn.Module.train`` recurses into every submodule, so without this
        override any later ``model.train()`` call (e.g. issued by a trainer when
        it switches to training) would put the frozen encoder back into
        training mode, even though its parameters receive no gradients.
        """
        super().train(mode)
        # getattr: train() may in principle be called before __init__ sets it.
        if getattr(self, "freeze_encoder", False):
            self.encoder.eval()
        return self

    def _shared_step(self, batch):
        """Run encoder -> decoder -> head on ``(imgs, eve)`` and return the loss."""
        imgs, eve = batch
        features = self.encoder(imgs[:, : self.num_input_channels, :, :, :])
        y_hat = self.head(self.decoder(features))
        return self.head.loss_func(y_hat, eve[:, : self.num_targets])

    def training_step(self, batch, batch_idx):
        """Compute, log (``train_loss``) and return the training loss."""
        loss = self._shared_step(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log (``val_loss``) the validation loss."""
        loss = self._shared_step(batch)
        self.log("val_loss", loss)
