# Missing Data Reconstruction from Pretrained Embeddings

In this example we build on `01 - Finetune Virtual EVE.ipynb` to create a simpler fine-tuning setup.

In [1]:
import os
import omegaconf
from sdofm.datasets import SDOMLDataModule
import numpy as np

In [None]:
from sdofm.models import WrapEncoder, ConvTransformerTokensToEmbeddingNeck

class MissingDataModel(BaseModule):
    """Fine-tuning module that reuses a pretrained MAE encoder for missing-data reconstruction.

    The backbone is wrapped with ``WrapEncoder`` (forward encoder pass only) and,
    by default, frozen. Training and validation share ``_shared_step``, which
    encodes the first 9 entries along dim 1 of the input images, passes the
    result through ``self.decoder`` and ``self.head``, and scores it with
    ``self.head.loss_func`` against ``eve[:, :38, 0]``.

    NOTE(review): ``BaseModule`` is not imported in the visible cells; confirm it
    is in scope before this cell runs.
    NOTE(review): ``self.decoder`` and ``self.head`` are used by the steps but are
    never assigned in this class. They must be provided elsewhere (e.g. by
    ``BaseModule`` or a subclass), otherwise the first step will fail.
    NOTE(review): the ``eve`` target slice and the head/CNN constructor
    parameters look carried over from the Virtual EVE notebook. Confirm they
    suit a reconstruction task.
    """

    def __init__(
        self,
        # Backbone parameters
        img_size: int = 512,
        patch_size: int = 16,
        embed_dim: int = 128,
        num_frames: int = 1,
        # Neck parameters
        num_neck_filters: int = 32,
        # Head parameters
        cnn_model: str = "efficientnet_b3",
        lr_linear: float = 0.01,
        lr_cnn: float = 0.0001,
        cnn_dp: float = 0.75,
        epochs_linear: int = 50,
        d_output=None,
        eve_norm=None,
        # For fine-tuning
        backbone: object = None,
        freeze_encoder: bool = True,
        # Everything else is forwarded to BaseModule
        *args,
        **kwargs,
    ):
        """Wrap (and optionally freeze) the pretrained backbone's encoder.

        Args:
            img_size, patch_size, embed_dim, num_frames: backbone geometry.
                Accepted for interface compatibility; not used by this class.
            num_neck_filters: accepted for interface compatibility; no neck is
                built (see comment below).
            cnn_model, lr_linear, lr_cnn, cnn_dp, epochs_linear, d_output:
                head parameters; accepted for interface compatibility, not used
                by this class as written.
            eve_norm: stored on ``self.eve_norm``.
            backbone: pretrained model passed to ``WrapEncoder``.
            freeze_encoder: if True, put the encoder in eval mode and disable
                gradients for all of its parameters.
            *args, **kwargs: forwarded to ``BaseModule.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.eve_norm = eve_norm
        self.backbone = backbone

        # WrapEncoder only runs the forward encoder pass of the MAE, which lets
        # us freeze exactly the encoder parameters below.
        self.encoder = WrapEncoder(self.backbone)

        # Remembered so ``train`` can keep a frozen encoder in eval mode.
        self._encoder_frozen = freeze_encoder
        if freeze_encoder:
            self.encoder.eval()
            for param in self.encoder.parameters():
                param.requires_grad = False

        # As this is a reconstruction task -- something the MAE was designed
        # to do -- no neck is built.

    def train(self, mode: bool = True):
        """Set train/eval mode, keeping a frozen encoder in eval mode.

        ``nn.Module.train`` propagates ``mode`` to every child module. If a
        trainer calls ``model.train()``, that call would switch the frozen
        encoder back to training behaviour. Setting ``requires_grad = False``
        does not prevent this, so the encoder is returned to eval mode here.
        """
        super().train(mode)
        if getattr(self, "_encoder_frozen", False):
            self.encoder.eval()
        return self

    def _shared_step(self, batch):
        """Encode, decode and score one ``(imgs, eve)`` batch; return the loss."""
        imgs, eve = batch
        # Only the first 9 entries along dim 1 of ``imgs`` (presumably
        # channels -- confirm) are fed to the encoder.
        latent = self.encoder(imgs[:, :9, :, :])
        y_hat = self.head(self.decoder(latent))
        # NOTE(review): target is the first 38 entries along dim 1 at index 0
        # of dim 2; confirm this matches the reconstruction target.
        return self.head.loss_func(y_hat, eve[:, :38, 0])

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the training loss."""
        loss = self._shared_step(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log (as ``val_loss``) the validation loss."""
        loss = self._shared_step(batch)
        self.log("val_loss", loss)