# Instrument Degradation Prediction from Pretrained Embeddings

![Figure 1: Instrument degradation prediction with latents](assets/architecture_diags_degrad.svg)

In [1]:
import os
import omegaconf
import numpy as np

In [2]:
# Fine-tuning configuration; the path is relative to the notebook's working directory.
CONFIG_PATH = "finetune_degrad_config.yml"
cfg = omegaconf.OmegaConf.load(CONFIG_PATH)

In [4]:
from sdofm.datasets import DegradedSDOMLDataModule

# Resolve the dataset locations once so the constructor call stays readable.
sdoml_cfg = cfg.data.sdoml
aia_path = os.path.join(sdoml_cfg.base_directory, sdoml_cfg.sub_directory.aia)
cache_path = os.path.join(sdoml_cfg.base_directory, sdoml_cfg.sub_directory.cache)
month_splits = cfg.data.month_splits

# Only the AIA path is supplied; the HMI and EVE paths are left as None.
data_module = DegradedSDOMLDataModule(
    hmi_path=None,
    aia_path=aia_path,
    eve_path=None,
    components=sdoml_cfg.components,
    wavelengths=sdoml_cfg.wavelengths,
    ions=sdoml_cfg.ions,
    frequency=sdoml_cfg.frequency,
    batch_size=cfg.model.opt.batch_size,
    num_workers=cfg.data.num_workers,
    val_months=month_splits.val,
    test_months=month_splits.test,
    holdout_months=month_splits.holdout,
    cache_dir=cache_path,
    min_date=cfg.data.min_date,
    max_date=cfg.data.max_date,
    num_frames=cfg.data.num_frames,
)
data_module.setup()

[* CACHE SYSTEM *] Found cached index data in /mnt/sdoml/cache/aligndata_AIA_FULL_12min.csv.
[* CACHE SYSTEM *] Found cached normalization data in /mnt/sdoml/cache/normalizations_AIA_FULL_12min.json.
[* CACHE SYSTEM *] Found cached HMI mask data in /mnt/sdoml/cache/hmi_mask_512x512.npy.


In [5]:
from typing import Optional

import lightning.pytorch as pl
import torch
import torch.nn as nn

from sdofm import BaseModule
from sdofm.models import (
    Autocalibration13Head,
    ConvTransformerTokensToEmbeddingNeck,
    MaskedAutoencoderViT3D,
    WrapEncoder,
    SolarAwareMaskedAutoencoderViT3D,
)


def heteroscedastic_loss(output, gt_output, reduction):
    """
    Heteroscedastic regression loss: squared error scaled by the predicted
    precision ``exp(-log_var)``, plus the predicted log-variance, averaged
    over the first (batch) dimension.

    Args:
        output: NN output, documented as a tensor of shape
            (2, batch_size, n_channels). ``output[0]`` holds the predicted
            means and ``output[1]`` the predicted log-variances.
        gt_output: ground-truth values, documented as a tensor of shape
            (batch_size, n_channels).
        reduction: ``"mean"`` averages the per-channel loss, ``"sum"`` sums
            it, and ``None`` applies no aggregation.

    Returns:
        The per-channel loss when ``reduction`` is None, otherwise a scalar
        tensor.

    Raises:
        ValueError: if ``reduction`` is not None, "mean" or "sum".
    """
    predicted_mean = output[0]
    predicted_log_var = output[1]
    n_samples = predicted_mean.shape[0]

    inverse_variance = torch.exp(-predicted_log_var)
    per_sample = (
        inverse_variance * (gt_output - predicted_mean) ** 2.0 + predicted_log_var
    )
    per_channel = torch.sum(per_sample, 0) / n_samples

    if reduction is None:
        return per_channel
    if reduction == "mean":
        return per_channel.mean()
    if reduction == "sum":
        return per_channel.sum()
    raise ValueError("Aggregation can only be None, mean or sum.")


class HeteroscedasticLoss(nn.Module):
    """
    ``nn.Module`` wrapper around ``heteroscedastic_loss``.

    The ``reduction`` chosen at construction time ("mean", "sum" or None)
    is forwarded unchanged on every call.
    """

    def __init__(self, reduction="mean"):
        super().__init__()
        self.reduction = reduction

    def forward(self, output, target):
        # `output` carries both means and log-variances; see heteroscedastic_loss.
        return heteroscedastic_loss(output, target, reduction=self.reduction)


class Autocalibration(BaseModule):
    def __init__(
        self,
        # Backbone parameters
        img_size=512,
        patch_size=16,
        embed_dim=128,
        num_frames=5,
        # Neck parameters
        num_neck_filters: int = 32,
        # Head parameters
        output_dim: int = 1,
        loss: str = "mse",
        freeze_encoder: bool = True,
        # if finetuning
        backbone: object = None,
        # all else
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.backbone = backbone
        self.encoder = WrapEncoder(self.backbone)

        if freeze_encoder:
            self.encoder.eval()
            for param in self.encoder.parameters():
                param.requires_grad = False

        num_tokens = img_size // patch_size

        # NECK
        self.decoder = ConvTransformerTokensToEmbeddingNeck(
            embed_dim=embed_dim,
            output_embed_dim=num_neck_filters,
            Hp=num_tokens,
            Wp=num_tokens,
            drop_cls_token=True,
            num_frames=num_frames,
        )

        # HEAD
        self.head = Autocalibration13Head(
            [num_neck_filters, img_size, img_size], output_dim
        )

        # set loss function
        match loss:
            case "mse":
                self.loss_function = nn.MSELoss()
            case "heteroscedastic":
                self.loss_function = HeteroscedasticLoss()
            case _:
                raise NotImplementedError(f"Loss function {loss} not implemented")

    def training_step(self, batch, batch_idx):
        degraded_img, degrad_factor, orig_img = batch
        x = self.encoder(degraded_img)
        # print("Autocal training: encoder out dim", x.shape)
        # x_hat = self.autoencoder.unpatchify(x_hat)
        x = self.decoder(x)
        # print("Autocal training: decoder out dim", x.shape)
        y_hat = self.head(x)
        loss = self.loss_function(y_hat[0, :, :], degrad_factor)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        degraded_img, degrad_factor, orig_img = batch
        x = self.encoder(degraded_img)
        # x_hat = self.autoencoder.unpatchify(x_hat)
        y_hat = self.head(self.decoder(x))
        loss = self.loss_function(y_hat[0, :, :], degrad_factor)
        self.log("val_loss", loss)


In [6]:
from pretrain import Pretrainer

# Build the pretraining wrapper in backbone mode (the recorded output shows it
# loading a checkpoint) and keep its underlying model as the backbone.
MAE = Pretrainer(cfg, is_backbone=True, logger=None)
backbone = MAE.model

Using <class 'sdofm.datasets.SDOML.SDOMLDataModule'> Data Class
[* CACHE SYSTEM *] Found cached index data in /mnt/sdoml/cache/aligndata_AIA_FULL_12min.csv.
[* CACHE SYSTEM *] Found cached normalization data in /mnt/sdoml/cache/normalizations_AIA_FULL_12min.json.
[* CACHE SYSTEM *] Found cached HMI mask data in /mnt/sdoml/cache/hmi_mask_512x512.npy.
Loading checkpoint...
Done


In [7]:

# Geometry of the pretrained MAE, reused so the neck/head match the backbone.
mae_cfg = cfg.model.mae
backbone_params = {
    "img_size": mae_cfg.img_size,
    "patch_size": mae_cfg.patch_size,
    "embed_dim": mae_cfg.embed_dim,
    "num_frames": mae_cfg.num_frames,
}

# NOTE(review): `output_dim` and `loss` are left at their defaults (1, "mse").
# The recorded training output contains what looks like a truncated
# `F.mse_loss` size-mismatch warning — confirm `output_dim` matches the
# shape of the degradation targets.
model = Autocalibration(
    # backbone geometry
    **backbone_params,
    # pretrained backbone, excluded from the logged hyperparameters
    backbone=backbone,
    hyperparam_ignore=["backbone"],
)

Autocalibration initialising
input_channels: 32
[32, 512, 512]
cnn_output_dim: 401408


In [8]:
from lightning.pytorch import Trainer

# Set before the Trainer is created.
os.environ["PJRT_DEVICE"] = "GPU"

# Short full-precision run: two epochs over the degraded-image data module.
trainer = Trainer(precision=32, max_epochs=2)
trainer.fit(model=model, datamodule=data_module)

Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name          | Type                                 | Params | Mode 
-------------------------------------------------------------------------------
0 | backbone      | MAE                                  | 104 M  | eval 
1 | encoder       | WrapEncoder                          | 104 M  | eval 
2 | decoder       | ConvTransformerTokensToEmbeddingNeck | 78.1 K | train
3 | head          | Autocalibration13Head                | 895 K  | train
4 | loss_function | MSELoss                              | 0      | train
-------------------------------

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

  return F.mse_loss(input, target, reduction=self.reduction)


Training: |          | 0/? [00:00<?, ?it/s]