# Missing Data Reconstruction from Pretrained Embeddings

For this example we're going to build on `01 - Finetune Virtual EVE.ipynb` and create a simpler finetuning setup.

![Figure 1: Architectural Diagram](assets/architecture_diags_corrupt.svg)

In [3]:
import os
import omegaconf
from sdofm.datasets import SDOMLDataModule
import numpy as np

The following packages are not installed:
['mpl-animators>=1.0.0', 'reproject>=0.9.0']
To install sunpy with these dependencies use `pip install sunpy[map]` or `pip install sunpy[all]` for all extras. 
If you installed sunpy via conda, please report this to the community channel: https://matrix.to/#/#sunpy:openastronomy.org [sunpy.util.sysinfo]
The following packages are not installed:
['mpl-animators>=1.0.0']
To install sunpy with these dependencies use `pip install sunpy[visualization]` or `pip install sunpy[all]` for all extras. 
If you installed sunpy via conda, please report this to the community channel: https://matrix.to/#/#sunpy:openastronomy.org [sunpy.util.sysinfo]


In [4]:
# Load the experiment configuration from a YAML file; the path is relative,
# so it is resolved against the notebook's current working directory.
cfg = omegaconf.OmegaConf.load("finetune_corrupt_data.yml")

In [29]:
# Pull the frequently used config sub-trees into short names so the
# constructor call below stays readable.
sdoml_cfg = cfg.data.sdoml
split_cfg = cfg.data.month_splits

# AIA imagery is the only input stream used here; the path is only built
# when an AIA sub-directory is actually configured.
aia_subdir = sdoml_cfg.sub_directory.aia
if aia_subdir:
    aia_dir = os.path.join(sdoml_cfg.base_directory, aia_subdir)
else:
    aia_dir = None

cache_directory = os.path.join(
    sdoml_cfg.base_directory, sdoml_cfg.sub_directory.cache
)

data_module = SDOMLDataModule(
    # HMI and EVE inputs are explicitly disabled for this example.
    hmi_path=None,
    aia_path=aia_dir,
    eve_path=None,
    components=sdoml_cfg.components,
    wavelengths=sdoml_cfg.wavelengths,
    ions=sdoml_cfg.ions,
    frequency=sdoml_cfg.frequency,
    batch_size=cfg.model.opt.batch_size,
    num_workers=cfg.data.num_workers,
    val_months=split_cfg.val,
    test_months=split_cfg.test,
    holdout_months=split_cfg.holdout,
    cache_dir=cache_directory,
    min_date=cfg.data.min_date,
    max_date=cfg.data.max_date,
    num_frames=cfg.data.num_frames,
    drop_frame_dim=cfg.data.drop_frame_dim,
)
# Called eagerly so the datasets exist before the Trainer is created.
data_module.setup()

[* CACHE SYSTEM *] Found cached index data in /mnt/sdoml/cache/aligndata_AIA_FULL_12min.csv.
[* CACHE SYSTEM *] Found cached normalization data in /mnt/sdoml/cache/normalizations_AIA_FULL_12min.json.
[* CACHE SYSTEM *] Found cached HMI mask data in /mnt/sdoml/cache/hmi_mask_512x512.npy.


In [32]:
from sdofm.models import WrapEncoder, ConvTransformerTokensToEmbeddingNeck
from sdofm.benchmarks import reconstruction as bench_recon
import torch.nn.functional as F
from sdofm.constants import ALL_WAVELENGTHS
from sdofm import BaseModule

class MissingDataModel(BaseModule):
    """Fine-tune a pretrained MAE backbone to reconstruct a missing wavelength.

    During training one wavelength (index ``simulated_corrupt_wavelength`` on
    axis 1 of the batch) is zeroed in the model *input*, and the backbone's
    autoencoder is asked to reconstruct the full, uncorrupted image.
    """

    def __init__(
        self,
        # Backbone parameters
        img_size: int = 512,
        patch_size: int = 16,
        embed_dim: int = 128,
        num_frames: int = 1,
        # for finetuning
        backbone: object = None,
        freeze_encoder: bool = True,
        # all else
        *args,
        **kwargs,
    ):
        """Store the backbone and optionally freeze its encoder blocks.

        Args:
            img_size: Accepted for interface parity with the backbone config;
                not read by this class's own code.
            patch_size: Same as ``img_size``.
            embed_dim: Same as ``img_size``.
            num_frames: Same as ``img_size``.
            backbone: Pretrained model exposing an ``autoencoder`` attribute
                with ``blocks``, ``forward_encoder``, ``forward_decoder``,
                ``forward_loss`` and ``unpatchify``.
            freeze_encoder: When true, put ``backbone.autoencoder.blocks`` in
                eval mode and disable gradients for their parameters.
            *args: Forwarded to ``BaseModule``.
            **kwargs: Forwarded to ``BaseModule``.

        Raises:
            ValueError: If ``freeze_encoder`` is true but no ``backbone`` was
                supplied (previously an opaque ``AttributeError``).
        """
        super().__init__(*args, **kwargs)

        self.backbone = backbone

        self.masking_ratio = 0.75
        # NOTE(review): this list is appended to on every validation step and
        # is never cleared in this class, so it grows across epochs.
        self.validation_metrics = []

        if freeze_encoder:
            if self.backbone is None:
                raise ValueError(
                    "freeze_encoder=True requires a pretrained `backbone`."
                )
            encoder_blocks = self.backbone.autoencoder.blocks
            encoder_blocks.eval()
            for param in encoder_blocks.parameters():
                param.requires_grad = False

        # Index (on axis 1 of the batch) of the wavelength that is blanked
        # out to simulate a missing/corrupt channel.
        self.simulated_corrupt_wavelength = 5

        # As this is a reconstruction task, something that the MAE
        # was designed to do, we don't require the neck.

    def forward_corrupt_data_override(self, imgs, mask_ratio=0.75):
        """Run the autoencoder on a copy of ``imgs`` with one wavelength zeroed.

        The caller's tensor is left untouched: the corruption is applied to a
        clone so that the reconstruction target (both here and in the calling
        step) is still the original, uncorrupted data.

        Args:
            imgs: Input batch; axis 1 is indexed by wavelength.
            mask_ratio: Masking ratio forwarded to the encoder.

        Returns:
            Tuple ``(loss, pred, mask)`` where ``loss`` is the autoencoder's
            own loss of ``pred`` against the uncorrupted ``imgs``.
        """
        # Corrupt a copy: zeroing `imgs` in place would also zero the target
        # the loss is measured against, teaching the model to predict zeros.
        corrupted = imgs.clone()
        corrupted[:, self.simulated_corrupt_wavelength] = 0
        # continue as normal
        autoencoder = self.backbone.autoencoder
        latent, mask, ids_restore = autoencoder.forward_encoder(corrupted, mask_ratio)
        pred = autoencoder.forward_decoder(latent, ids_restore)
        loss = autoencoder.forward_loss(imgs, pred, mask)
        return loss, pred, mask

    def training_step(self, batch, batch_idx):
        """Reconstruct the batch from corrupted input and return the MSE loss.

        The autoencoder's own (masked) loss is discarded; the logged and
        returned loss is the MSE between the unpatchified prediction and the
        uncorrupted batch.
        """
        x = batch
        _, x_hat, _ = self.forward_corrupt_data_override(
            x, mask_ratio=self.masking_ratio
        )
        x_hat = self.backbone.autoencoder.unpatchify(x_hat)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation MSE and collect per-frame reconstruction metrics.

        NOTE(review): validation runs the autoencoder on the *uncorrupted*
        batch, unlike ``training_step``. Confirm whether validation should
        also go through ``forward_corrupt_data_override``.
        """
        x = batch
        _, x_hat, _ = self.backbone.autoencoder(x, mask_ratio=self.masking_ratio)
        x_hat = self.backbone.autoencoder.unpatchify(x_hat)
        loss = F.mse_loss(x_hat, x)
        # The indexing below requires each sample to have a frame axis at
        # position 1 (i.e. axis 2 of the batch).
        for target, recon in zip(x, x_hat):
            for frame in range(target.shape[1]):
                self.validation_metrics.append(
                    bench_recon.get_metrics(
                        target[:, frame, :, :],
                        recon[:, frame, :, :],
                        ALL_WAVELENGTHS,
                    )
                )

        self.log("val_loss", loss)

In [33]:
# `Pretrainer` comes from the project's `pretrain` module.
from pretrain import Pretrainer
# Build the pretraining wrapper from the same config, without a logger.
# The cell output shows it loading a checkpoint during construction;
# presumably `is_backbone=True` selects that path — confirm in `pretrain`.
MAE = Pretrainer(cfg, logger=None, is_backbone=True)

Using <class 'sdofm.datasets.SDOML.SDOMLDataModule'> Data Class
[* CACHE SYSTEM *] Found cached index data in /mnt/sdoml/cache/aligndata_AIA_FULL_12min.csv.
[* CACHE SYSTEM *] Found cached normalization data in /mnt/sdoml/cache/normalizations_AIA_FULL_12min.json.
[* CACHE SYSTEM *] Found cached HMI mask data in /mnt/sdoml/cache/hmi_mask_512x512.npy.
Loading checkpoint...
Done


In [34]:
# Grab the model held by the Pretrainer; it is handed to MissingDataModel
# as its `backbone` in the next cell.
backbone = MAE.model

In [35]:
# Copy the MAE geometry settings out of the config so they can be forwarded
# to the fine-tuning module as keyword arguments.
mae_cfg = cfg.model.mae
backbone_params = {
    field: getattr(mae_cfg, field)
    for field in ("img_size", "patch_size", "embed_dim", "num_frames")
}

model = MissingDataModel(
    # pretrained model to fine-tune
    backbone=backbone,
    # forwarded to the base module via **kwargs
    hyperparam_ignore=["backbone"],
    # geometry settings taken from the config
    **backbone_params,
)

In [36]:
from lightning.pytorch import Trainer 
# Set before the Trainer is created.
# NOTE(review): PJRT_DEVICE looks like a torch_xla runtime setting — confirm
# it is actually needed for plain CUDA training.
os.environ['PJRT_DEVICE'] = 'GPU'
# Short demonstration run: two epochs at full (32-bit) precision.
trainer = Trainer(max_epochs=2, precision=32)
# Fine-tune using the data module that was set up earlier in the notebook.
trainer.fit(model=model, datamodule=data_module)

Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name     | Type | Params | Mode 
------------------------------------------
0 | backbone | MAE  | 104 M  | train
------------------------------------------
27.8 M    Trainable params
76.7 M    Non-trainable params
104 M     Total params
418.215   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
