In [2]:
# Reload imported modules (e.g. the sdofm package) before each cell executes,
# so edits to the library source are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import os
from pathlib import Path

import pytorch_lightning as pl
import torch
import wandb
from sdofm import utils
from sdofm.datasets import SDOMLDataModule, DimmedSDOMLDataModule
from sdofm.pretraining import SAMAE

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import omegaconf

# Experiment configuration for SAMAE pre-training. Every later cell reads its
# data locations, model hyperparameters and trainer settings from `cfg`.
config_file = "../experiments/pretrain_32.2M_samae_tpu.yaml"
cfg = omegaconf.OmegaConf.load(config_file)

In [5]:
# AIA-only data module: the HMI and EVE inputs are disabled (None).
# Paths, sampling, batch size and month-based splits come from `cfg`.
sdoml_cfg = cfg.data.sdoml
splits_cfg = cfg.data.month_splits
sdoml_root = sdoml_cfg.base_directory

data_module = SDOMLDataModule(
    hmi_path=None,
    aia_path=os.path.join(sdoml_root, sdoml_cfg.sub_directory.aia),
    eve_path=None,
    components=sdoml_cfg.components,
    wavelengths=sdoml_cfg.wavelengths,
    ions=sdoml_cfg.ions,
    frequency=sdoml_cfg.frequency,
    batch_size=cfg.model.opt.batch_size,
    num_workers=cfg.data.num_workers,
    val_months=splits_cfg.val,
    test_months=splits_cfg.test,
    holdout_months=splits_cfg.holdout,
    cache_dir=os.path.join(sdoml_root, sdoml_cfg.sub_directory.cache),
)
# In the recorded run, setup() found cached index, normalization and HMI-mask
# files in the cache directory (see the cell output).
data_module.setup()

[* CACHE SYSTEM *] Found cached index data in /mnt/sdoml/cache/aligndata_AIA_FULL_12min.csv.
[* CACHE SYSTEM *] Found cached normalization data in /mnt/sdoml/cache/normalizations_AIA_FULL_12min.json.
[* CACHE SYSTEM *] Found cached HMI mask data in /mnt/sdoml/cache/hmi_mask_512x512.npy.


In [6]:
# SAMAE model. Architecture keyword arguments are merged from the `mae` and
# `samae` config sections; a key present in both (or clashing with the optimiser
# kwargs below) makes the call raise TypeError for a duplicate keyword.
opt_cfg = cfg.model.opt
model = SAMAE(
    **cfg.model.mae,
    **cfg.model.samae,
    optimiser=opt_cfg.optimiser,
    lr=opt_cfg.learning_rate,
    weight_decay=opt_cfg.weight_decay,
)

In [17]:
from torchsummary import summary

# Layer-by-layer summary of the autoencoder for one sample.
# In Conv3d terms this is (C, D, H, W). It matches the training batch shape seen
# later ([3, 9, 1, 512, 512]) without the batch dimension.
# NOTE(review): confirm 9 == number of configured AIA wavelengths.
SUMMARY_INPUT_SHAPE = (9, 1, 512, 512)

# torchsummary defaults to device="cuda" and builds its dummy input there. Pass
# the device the weights actually live on. A freshly constructed model sits on
# CPU, and the default would fail with an input/weight device mismatch on a clean
# Restart & Run All.
summary(
    model.autoencoder,
    input_size=SUMMARY_INPUT_SHAPE,
    device=next(model.autoencoder.parameters()).device.type,
)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1       [-1, 128, 1, 32, 32]         295,040
          Identity-2            [-1, 1024, 128]               0
        PatchEmbed-3            [-1, 1024, 128]               0
         LayerNorm-4             [-1, 257, 128]             256
            Linear-5             [-1, 257, 384]          49,536
          Identity-6           [-1, 16, 257, 8]               0
          Identity-7           [-1, 16, 257, 8]               0
            Linear-8             [-1, 257, 128]          16,512
           Dropout-9             [-1, 257, 128]               0
        Attention-10             [-1, 257, 128]               0
         Identity-11             [-1, 257, 128]               0
         Identity-12             [-1, 257, 128]               0
        LayerNorm-13             [-1, 257, 128]             256
           Linear-14             [-1, 2

In [26]:
# Single-device training run. The accelerator type and epoch budget come from
# the experiment config.
trainer = pl.Trainer(
    devices=1,
    accelerator=cfg.experiment.accelerator,
    max_epochs=cfg.model.opt.epochs,
)
trainer.fit(model, datamodule=data_module)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name        | Type                             | Params
-----------------------------------------------------------------
0 | autoencoder | SolarAwareMaskedAutoencoderViT3D | 3.3 M 
-----------------------------------------------------------------
3.0 M     Trainable params
262 K     Non-trainable params
3.3 M     Total params
13.005    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

/opt/conda/envs/sdofm/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [19]:
# Draw a single training batch for inspection.
train_loader = data_module.train_dataloader()
batch = next(iter(train_loader))

In [25]:
# Shape of the inspected batch (torch.Size).
batch.size()

torch.Size([3, 9, 1, 512, 512])

In [27]:
# Memory footprint of one training batch, in bytes. It is derived from the tensor
# itself rather than from a hard-coded shape and an assumed 4-byte (float32)
# element size, so it stays correct if batch_size or dtype change.
batch.nelement() * batch.element_size()

28311552