In [1]:
# Re-import edited source modules (e.g. the local `sdofm` package) automatically
# before each cell runs, so library edits show up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path

import lightning.pytorch as pl
import omegaconf
import torch
import wandb
import yaml

from sdofm import utils
from sdofm.datasets import DegradedSDOMLDataModule, SDOMLDataModule
from sdofm.finetuning import Autocalibration
from sdofm.pretraining import MAE, SAMAE

In [3]:
# Fine-tuning experiment configuration, resolved relative to this notebook.
FINETUNE_CONFIG_PATH = "../experiments/finetune_tiny.yaml"
cfg = omegaconf.OmegaConf.load(FINETUNE_CONFIG_PATH)

# Seed Python/NumPy/torch (and dataloader workers) so the fine-tuning head's
# initialisation and training are reproducible across Restart & Run All.
# Falls back to 0 if the config has no `experiment.seed` entry.
pl.seed_everything(cfg.experiment.get("seed", 0), workers=True)

In [4]:
# from sdofm.utils import flatten_dict
# import yaml

# data = flatten_dict(cfg, sep="___")
# with open('testingout.yaml', 'w+') as outfile:
#     yaml.dump(data, outfile, default_flow_style=False)
# outfile.close()

In [5]:
# from omegaconf import OmegaConf
# OmegaConf.save(config, "testout.yaml")

In [6]:
# Data module for the autocalibration task. Only an AIA archive path is given;
# HMI and EVE paths are passed as None. All paths live under the SDOML base dir.
sdoml_cfg = cfg.data.sdoml
splits_cfg = cfg.data.month_splits

degraded_data_module = DegradedSDOMLDataModule(
    hmi_path=None,
    aia_path=os.path.join(sdoml_cfg.base_directory, sdoml_cfg.sub_directory.aia),
    eve_path=None,
    components=sdoml_cfg.components,
    wavelengths=sdoml_cfg.wavelengths,
    ions=sdoml_cfg.ions,
    frequency=sdoml_cfg.frequency,
    batch_size=cfg.model.opt.batch_size,
    num_workers=cfg.data.num_workers,
    val_months=splits_cfg.val,
    test_months=splits_cfg.test,
    holdout_months=splits_cfg.holdout,
    cache_dir=os.path.join(sdoml_cfg.base_directory, sdoml_cfg.sub_directory.cache),
    min_date=cfg.data.min_date,
    max_date=cfg.data.max_date,
    num_frames=1,
)

# NOTE(review): trainer.fit() also calls setup() on the datamodule; this eager
# call presumably exists to surface cache/data errors early - confirm setup()
# is safe to run twice.
degraded_data_module.setup()

[* CACHE SYSTEM *] Found cached index data in /mnt/sdoml/cache/aligndata_AIA_FULL_12min.csv.
[* CACHE SYSTEM *] Found cached normalization data in /mnt/sdoml/cache/normalizations_AIA_FULL_12min.json.
[* CACHE SYSTEM *] Found cached HMI mask data in /mnt/sdoml/cache/hmi_mask_512x512.npy.


In [7]:
# Rebuild the backbone's hyperparameters from the pretraining run's wandb
# config. Keys in that file are "___"-flattened (see `sep`), so unflatten them.
# NOTE: machine-specific default; set SDOFM_BACKBONE_RUN_DIR to point elsewhere.
# NOTE(review): `latest-run` presumably resolves to the most recent wandb run in
# this directory - confirm it is the run whose checkpoint is loaded later.
BACKBONE_RUN_DIR = Path(
    os.environ.get(
        "SDOFM_BACKBONE_RUN_DIR",
        "/home/walsh/repos/SDO-FM/output/2024-04-17-00-41-29/0",
    )
)
BACKBONE_CONFIG_PATH = (
    BACKBONE_RUN_DIR / "wandb" / "latest-run" / "files" / "config.yaml"
)

conf = yaml.safe_load(BACKBONE_CONFIG_PATH.read_text())
backbone_cfg = utils.unflatten_dict(conf, sep="___", wandb_mode=True)
backbone_cfg

{'wandb_version': 1,
 'experiment': {'name': 'default',
  'project': 'sdofm',
  'model': 'samae',
  'task': 'pretrain',
  'seed': 0,
  'disable_cuda': False,
  'disable_wandb': False,
  'wandb_entity': 'fdlx',
  'wandb_group': 'sdofm-phase1',
  'wandb_job_type': 'pretrain',
  'wandb_tags': [],
  'wandb_notes': '',
  'fold': None,
  'evaluate': False,
  'checkpoint': None,
  'device': 'cuda',
  'precision': 64,
  'log_n_batches': 1000,
  'save_results': True,
  'accelerator': 'auto',
  'distributed_enabled': True,
  'distributed_backend': 'ddp',
  'distributed_world_size': 'auto'},
 'data': {'min_date': '0000-00-00 00:00:00',
  'max_date': '0000-00-00 00:00:00',
  'month_splits_val': [11],
  'month_splits_test': [12],
  'month_splits_holdout': [],
  'num_workers': 16,
  'output_directory': 'output',
  'sdoml_base_directory': '/mnt/sdoml',
  'sdoml_sub_directory_hmi': 'HMI.zarr',
  'sdoml_sub_directory_aia': 'AIA.zarr',
  'sdoml_sub_directory_eve': 'EVE_legacy.zarr"',
  'sdoml_sub_direct

In [8]:
# Optimiser settings recorded for the pretraining run; the optimiser, learning
# rate and weight decay are reused below when reloading the backbone.
backbone_cfg.model.opt

{'loss': 'mse',
 'scheduler': 'constant',
 'scheduler_warmup': 0,
 'batch_size': 3,
 'learning_rate': 0.0001,
 'weight_decay': 0.0003,
 'optimiser': 'adam',
 'epochs': 4,
 'patience': 2}

In [15]:
# Restore the pretrained SAMAE backbone. Constructor kwargs are taken from the
# pretraining run's own config (backbone_cfg); load_from_checkpoint passes them
# through, overriding any hyperparameters stored in the checkpoint.
# NOTE: machine-specific default; set SDOFM_BACKBONE_RUN_DIR to point elsewhere.
# NOTE(review): backbone_cfg is read from wandb/latest-run - make sure that run
# is zy68fa00, the one this checkpoint belongs to.
BACKBONE_CHECKPOINT = (
    Path(
        os.environ.get(
            "SDOFM_BACKBONE_RUN_DIR",
            "/home/walsh/repos/SDO-FM/output/2024-04-17-00-41-29/0",
        )
    )
    / "sdofm"
    / "zy68fa00"
    / "checkpoints"
    / "epoch=3-step=48556.ckpt"
)

backbone = SAMAE.load_from_checkpoint(
    **backbone_cfg.model.mae,
    **backbone_cfg.model.samae,
    optimiser=backbone_cfg.model.opt.optimiser,
    lr=backbone_cfg.model.opt.learning_rate,
    weight_decay=backbone_cfg.model.opt.weight_decay,
    checkpoint_path=str(BACKBONE_CHECKPOINT),
)

In [16]:
# Autocalibration fine-tuning model wrapping the reloaded backbone.
# Architecture kwargs come from the fine-tuning config's `mae` and `dimming`
# sections; optimiser settings come from its `opt` section.
# NOTE(review): these `mae` kwargs come from `cfg`, whereas the backbone was
# built from `backbone_cfg.model.mae` - confirm the two agree (e.g. dimensions).
opt_cfg = cfg.model.opt
model = Autocalibration(
    **cfg.model.mae,
    **cfg.model.dimming,
    optimiser=opt_cfg.optimiser,
    lr=opt_cfg.learning_rate,
    weight_decay=opt_cfg.weight_decay,
    backbone=backbone,
)

In [17]:
# Single-device fine-tuning run; accelerator and epoch budget come from the
# fine-tuning config.
# TODO: the last recorded run of this cell raised
# "ValueError: dictionary update sequence element #0 has length 1; 2 is required"
# right after the model summary - investigate before relying on results.
trainer = pl.Trainer(
    max_epochs=cfg.model.opt.epochs,
    accelerator=cfg.experiment.accelerator,
    devices=1,
)
trainer.fit(model, datamodule=degraded_data_module)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs



  | Name          | Type                                 | Params
-----------------------------------------------------------------------
0 | backbone      | SAMAE                                | 32.2 M
1 | encoder       | PrithviEncoder                       | 32.2 M
2 | decoder       | ConvTransformerTokensToEmbeddingNeck | 28.9 K
3 | head          | Autocalibration13                    | 93.4 K
4 | loss_function | MSELoss                              | 0     
-----------------------------------------------------------------------
122 K     Trainable params
32.2 M    Non-trainable params
32.3 M    Total params
129.203   Total estimated model params size (MB)


ValueError: dictionary update sequence element #0 has length 1; 2 is required