In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell executes and edits to project source are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path

import pytorch_lightning as pl
import torch
import wandb
from sdofm import utils
from sdofm.datasets import SDOMLDataModule, DimmedSDOMLDataModule
from sdofm.pretraining import NVAE
from sdofm.finetuning import Autocalibration

In [3]:
import omegaconf

# Location of the experiment configuration file; the path is relative, so it
# is resolved against the notebook's current working directory.
EXPERIMENT_CONFIG = '../experiments/pretrain_tiny_nvae.yaml'

# Parsed configuration object; later cells read their settings from it.
cfg = omegaconf.OmegaConf.load(EXPERIMENT_CONFIG)

In [4]:
# Shorthands for the config sections this cell reads from.
sdoml_cfg = cfg.data.sdoml
split_cfg = cfg.data.month_splits
sdoml_root = sdoml_cfg.base_directory
sdoml_subdirs = sdoml_cfg.sub_directory

# Build the data module. The HMI, AIA and cache locations are each a
# sub-directory joined onto the shared base directory; no EVE path is given.
data_module = SDOMLDataModule(
    hmi_path=os.path.join(sdoml_root, sdoml_subdirs.hmi),
    aia_path=os.path.join(sdoml_root, sdoml_subdirs.aia),
    eve_path=None,
    components=sdoml_cfg.components,
    wavelengths=sdoml_cfg.wavelengths,
    ions=sdoml_cfg.ions,
    frequency=sdoml_cfg.frequency,
    batch_size=cfg.model.opt.batch_size,
    num_workers=cfg.data.num_workers,
    val_months=split_cfg.val,
    test_months=split_cfg.test,
    holdout_months=split_cfg.holdout,
    cache_dir=os.path.join(sdoml_root, sdoml_subdirs.cache),
)

# Call setup() explicitly instead of leaving it to the trainer.
# NOTE(review): presumably this is what makes `data_module.hmi_mask` available
# to the following cells before training starts -- confirm against the class.
data_module.setup()

[* CACHE SYSTEM *] Found cached index data in /mnt/sdoml/cache/aligndata_AIA_FULL_HMI_FULL_12min.csv.
[* CACHE SYSTEM *] Found cached normalization data in /mnt/sdoml/cache/normalizations_AIA_FULL_HMI_FULL_12min.json.
[* CACHE SYSTEM *] Found cached HMI mask data in /mnt/sdoml/cache/hmi_mask_512x512.npy.


In [5]:
# Debugging cell: print the HMI mask currently held by the data module.
# The commented-out lines are scratch experiments with loading the mask from
# its cache file and casting it to different dtypes; only the print() runs.
# torch.Tensor(data_module.hmi_mask).to(dtype=torch.bool)
# data_module.hmi_mask_cache_filename
# import numpy as np
# np.load(data_module.hmi_mask_cache_filename)
print(data_module.hmi_mask)
# loaded_mask = np.load(data_module.hmi_mask_cache_filename)
# torch.Tensor(loaded_mask).to(dtype=torch.uint8)

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], dtype=torch.uint8)


In [6]:
# Optimiser-related settings live in the `model.opt` section of the config.
opt_cfg = cfg.model.opt

# Instantiate the NVAE: architecture keyword arguments are unpacked from
# `cfg.model.nvae`, optimiser settings are passed explicitly, and the HMI mask
# is taken from the data module.
model = NVAE(
    **cfg.model.nvae,
    optimiser=opt_cfg.optimiser,
    lr=opt_cfg.learning_rate,
    weight_decay=opt_cfg.weight_decay,
    hmi_mask=data_module.hmi_mask,
)

In [7]:
# Trainer settings: a single device, with the accelerator type and the epoch
# budget read from the experiment config.
trainer_kwargs = dict(
    devices=1,
    accelerator=cfg.experiment.accelerator,
    max_epochs=cfg.model.opt.epochs,
)
trainer = pl.Trainer(**trainer_kwargs)

# Start training, using the data module for the data loaders.
trainer.fit(model=model, datamodule=data_module)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


/opt/conda/envs/sdofm/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type        | Params
--------------------------------------------
0 | autoencoder | AutoEncoder | 21.4 M
--------------------------------------------
21.4 M    Trainable params
2.5 K     Non-trainable params
21.4 M    Total params
85.652    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

/opt/conda/envs/sdofm/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
