In [None]:
%load_ext autoreload
%autoreload 2

In [19]:
import os
from pathlib import Path

import pytorch_lightning as pl
import torch
import wandb
from sdofm import utils
from sdofm.datasets import SDOMLDataModule
from sdofm.benchmarks import reconstruction as bench_recon
from sdofm.constants import ALL_WAVELENGTHS

from sdofm.pretraining import MAE

In [20]:
import omegaconf

cfg = omegaconf.OmegaConf.load("../experiments/pretrain_tiny_mae.yaml")

In [21]:
data_module = SDOMLDataModule(
# hmi_path=os.path.join(
#      cfg.data.sdoml.base_directory,  cfg.data.sdoml.sub_directory.hmi
# ),
hmi_path=None,
aia_path=os.path.join(
        cfg.data.sdoml.base_directory,
        cfg.data.sdoml.sub_directory.aia,
),
eve_path=None,
components= cfg.data.sdoml.components,
wavelengths= cfg.data.sdoml.wavelengths,
ions= cfg.data.sdoml.ions,
frequency= cfg.data.sdoml.frequency,
batch_size= cfg.model.opt.batch_size,
num_workers= cfg.data.num_workers,
val_months= cfg.data.month_splits.val,
test_months= cfg.data.month_splits.test,
holdout_months= cfg.data.month_splits.holdout,
cache_dir=os.path.join(
        cfg.data.sdoml.base_directory,
        cfg.data.sdoml.sub_directory.cache,
),
min_date=cfg.data.min_date,
max_date=cfg.data.max_date,
num_frames=cfg.model.mae.num_frames,
)
data_module.setup()

model = MAE(
**cfg.model.mae,
optimiser=cfg.model.opt.optimiser,
lr=cfg.model.opt.learning_rate,
weight_decay=cfg.model.opt.weight_decay,
)

[* CACHE SYSTEM *] Found cached index data in /mnt/sdoml/cache/aligndata_AIA_FULL_12min.csv.
[* CACHE SYSTEM *] Found cached normalization data in /mnt/sdoml/cache/normalizations_AIA_FULL_12min.json.
[* CACHE SYSTEM *] Found cached HMI mask data in /mnt/sdoml/cache/hmi_mask_512x512.npy.


In [22]:
x = next(iter(data_module.train_dataloader()))

In [23]:
x.shape

torch.Size([4, 9, 2, 512, 512])

In [24]:
loss, x_hat, mask = model.autoencoder(x)
x_hat = model.autoencoder.unpatchify(x_hat)


In [25]:
x_hat.shape, x.shape

(torch.Size([4, 9, 2, 512, 512]), torch.Size([4, 9, 2, 512, 512]))

In [32]:
validation_metrics = []
for i in range(x.shape[2]):
    validation_metrics.append(bench_recon.get_metrics(x[i,:,0,:,:], x_hat[i,:,0,:,:], ALL_WAVELENGTHS))


In [33]:
merged_metrics = bench_recon.merge_metrics(validation_metrics)


In [34]:
batch_metrics = bench_recon.mean_metrics(merged_metrics)


In [36]:
batch_metrics

{'131A': {'flux_difference': -1.4445356768376798,
  'ppe10s': 0.012109756469726562,
  'ppe50s': 0.544952392578125,
  'rms_contrast_measure': 0.4743000500586574,
  'pixel_correlation': -0.017537385392165428,
  'rmse_intensity': 0.7930678055708583},
 '1600A': {'flux_difference': -1.0408771465211641,
  'ppe10s': 0.4959297180175781,
  'ppe50s': 0.5631198883056641,
  'rms_contrast_measure': 0.6043902129714565,
  'pixel_correlation': -0.022836154032107474,
  'rmse_intensity': 0.8877159866519266},
 '1700A': {'flux_difference': -0.9620249346126075,
  'ppe10s': 0.4963951110839844,
  'ppe50s': 0.5646400451660156,
  'rms_contrast_measure': 0.5965223285931959,
  'pixel_correlation': 0.020996891555964937,
  'rmse_intensity': 0.8857073985798133},
 '171A': {'flux_difference': -1.076724,
  'ppe10s': 0.034648895263671875,
  'ppe50s': 0.18700599670410156,
  'rms_contrast_measure': 0.69384634,
  'pixel_correlation': 0.0018008068821008419,
  'rmse_intensity': 1.0837147},
 '193A': {'flux_difference': -0.92

In [None]:
import matplotlib.pyplot as plt

plt.imshow(rbg_image_batch[0, 0, 0, :, :].cpu().numpy(), cmap="gray")

In [None]:
mask = data_module.hmi_mask
print(mask.shape)
plt.imshow(mask.cpu().numpy(), cmap="gray")

In [None]:
from sdofm.utils import get_1d_sincos_pos_embed_from_grid, get_3d_sincos_pos_embed

embed_dim = 128
num_frames = 1
tubelet_size = 1
img_size = (512, 512)
patch_size = (16, 16)
grid_size = (
    num_frames // tubelet_size,
    img_size[0] // patch_size[0],
    img_size[1] // patch_size[1],
)
num_patches = grid_size[0] * grid_size[1] * grid_size[2]
pos_embed_zeros = torch.zeros(1, num_patches + 1, embed_dim)
pos_embed = get_3d_sincos_pos_embed(
    pos_embed_zeros.shape[-1], grid_size, cls_token=False
)

In [None]:
single_image_batch.shape

In [None]:
mask.shape
a = np.repeat(mask[:, :, np.newaxis], 9, axis=2)
b = np.repeat(a[:, :, :, np.newaxis], 1, axis=3)
c = np.repeat(b[:, :, :, :, np.newaxis], 1, axis=4)
# c.shape
d = torch.Tensor(np.transpose(c, axes=(3, 2, 4, 0, 1))).to(dtype=torch.float)

In [None]:
mask.shape

In [None]:
single_image_batch[0, 0]

In [None]:
patch_embed = PatchEmbed(512, 16, 1, 1, 9, 128, flatten=True)
x = patch_embed(d)
x.shape

In [None]:
indices = np.arange(32 * 32).reshape(32, 32).astype(np.int64)
ai = np.repeat(indices[:, :, np.newaxis], 128, axis=2)
bi = np.repeat(ai[:, :, :, np.newaxis], 1, axis=3)
ci = np.repeat(bi[:, :, :, :, np.newaxis], 1, axis=4)
di = torch.Tensor(np.transpose(ci, axes=(3, 2, 4, 0, 1)))  # .to(dtype=torch.float)
di.shape
fi = di.flatten(2).transpose(1, 2)
fi.shape

In [None]:
plt.scatter(x=range(1024), y=fi[0, :, 0].detach().numpy())

In [None]:
15 * 15

In [None]:
fi[0, :, 0].max()

In [None]:
plt.imshow(x[0, 0, 0, :, :].detach().numpy())

In [None]:
plt.imshow(x[0, :, :].detach().numpy(), aspect="auto")

In [None]:
pos_embed = get_3d_sincos_pos_embed(
    pos_embed.shape[-1], patch_embed.grid_size, cls_token=False
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
pos_embed.shape

In [None]:
patchified.shape

In [None]:
patch_embed.grid_size[1]

In [None]:
import astropy.units as u
from astropy.coordinates import SkyCoord
import sunpy.data.sample
import sunpy.map
from sunpy.coordinates.frames import HeliographicStonyhurst

aiamap = sunpy.map.Map(
    sunpy.data.sample.AIA_171_IMAGE
)  # example image is loaded at 1024x1024


def stonyhurst_to_patch_index(lat, lon):
    # Heliographic Stonyhurst coordinates to patch index
    # lat, lon = 15.73, 0
    coord = SkyCoord(lat * u.deg, lon * u.deg, frame=HeliographicStonyhurst)
    x, y = aiamap.wcs.world_to_pixel(coord)  # (x, y) in pixels
    x, y = x / 2 // patch_embed.patch_size[0], y / 2 // patch_embed.patch_size[0]
    return np.array([x, y])

In [None]:
middle_patch = stonyhurst_to_patch_index(0, 0)
middle_patch

In [None]:
r1_patch = stonyhurst_to_patch_index(0, -60)[1]
r2_patch = stonyhurst_to_patch_index(0, 60)[1]

In [None]:
stonyhurst_to_patch_index(60, 0)
# r2_patch = stonyhurst_to_patch_index(0, 60)[1]

In [None]:
mean_patch = (
    abs(
        (stonyhurst_to_patch_index(15.73, 0) - stonyhurst_to_patch_index(0, 0))
        + abs(stonyhurst_to_patch_index(-15.73, 0) - stonyhurst_to_patch_index(0, 0))
    )
)[0] / 2
mean_patch

In [None]:
std_patch = (stonyhurst_to_patch_index(6.14, 0) - stonyhurst_to_patch_index(0, 0))[0]
std_patch

In [None]:
# torch.uniform(0, 32)
# x = torch.distributions.uniform.Uniform(0,32).sample((1024,))
# x

# get uniform random numbers between [r1_patch, r2_patch]
# (r1 - r2) * torch.rand(a, b) + r2
# random_lons = torch.floor( (r1_patch - r2_patch ) * torch.rand((1024,)) + r2_patch ).to(dtype=torch.uint8)
N = 2
random_lons = torch.floor((r1_patch - r2_patch) * torch.rand((N, 1024)) + r2_patch).to(
    dtype=torch.uint8
)

In [None]:
normal_lats = torch.floor(torch.normal(mean_patch, std_patch, size=(N, 1024)))

In [None]:
random_hemisphere = torch.floor(torch.rand((N, 1024)) * (2)).to(dtype=torch.int8)
random_hemisphere[random_hemisphere == 0] = -1
random_lats = random_hemisphere * normal_lats

In [None]:
random_lons.shape

In [None]:
plt.scatter(random_lons[0, :], random_lats[0, :] + 15)
plt.title("Per-hemisphere lat-normally lon-uniformly distributed patch locations")
plt.xlabel("Patch index (solar longitude)")
plt.ylabel("Patch index (solar latitude)")

In [None]:
N = 1
L = 1024
noise = torch.rand(N, L)  # noise in [0, 1]

# sort noise for each sample
ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1)

In [None]:
ids_shuffle = random_lons * (random_lats + 15)
ids_shuffle.to(dtype=torch.int16)

In [None]:
torch.argsort(ids_shuffle, dim=1)

In [None]:
ids_shuffle

In [None]:
32 * 32

In [None]:
15 - 12
19 - 15
# patch_idx = x*y

# Patch index to embedding index
# frame_number

In [None]:
torch.normal(3.5)

In [None]:
import numpy as np

# Patch number to embedding index
indices = np.arange(32 * 32).reshape(32, 32).astype(np.int64)
ai = np.repeat(indices[:, :, np.newaxis], 128, axis=2)
bi = np.repeat(ai[:, :, :, np.newaxis], 1, axis=3)
ci = np.repeat(bi[:, :, :, :, np.newaxis], 1, axis=4)
di = torch.Tensor(np.transpose(ci, axes=(3, 2, 4, 0, 1)))  # .to(dtype=torch.float)
di.shape
fi = di.flatten(2).transpose(1, 2)
fi.shape

In [None]:
torch.where(fi[0, :, 0] == patch_idx)

In [None]:
p = patch_embed.patch_size[0]
num_p = patch_embed.img_size[0] // p
tub = patch_embed.tubelet_size
imgs = rearrange(
    pos_embed,
    "b (t h w) (tub p q c) -> b c (t tub) (h p) (w q)",
    h=num_p,
    w=num_p,
    tub=tub,
    p=p,
    q=p,
)

In [None]:
num_p

In [None]:
plt.imshow(imgs[0, 0, 0, :, :].cpu().numpy(), cmap="gray")

In [None]:
import numpy as np

b = np.repeat(mask[:, :, np.newaxis], 4, axis=2)
b = np.transpose(b, axes=[2, 0, 1])
(b[0, :, :] == b[1, :, :])

In [None]:
from sdofm.models.samae3d import PatchEmbed

patch_embed = PatchEmbed(512, 16, 1, 1, 9, 128, flatten=False)

nn.Conv3d(
    in_chans=3,
    embed_dim=128,
    kernel_size=(1, 16, 16),
    stride=(1, 16, 16),
    bias=True,
)

In [None]:
patch_embed(image_batch).shape

In [None]:
patch_embed.proj.weight.shape

In [None]:
16 * 32

The MaskedConv3D is a standard Conv3D with a binary mask applied on sampling locations that shouldn't contribute to the learning process. Whilst in theory a Conv3D could be written to take non-cubic input voxels this should achieve the same effect. The standard torch `nn.Conv3d` is modified such that 

In the simplest case, the output value of the layer with input size $(N, C_{in},D,H,W)$, output $(N, C_{out},D_{out},H_{out},W_{out})$, and logical mask $M$ can be described as:

$$out(N_i, C_{out_j}) = bias(C_{out_j})+ \sum_{C_{in}-1}^{k=0} M*weight(C_{out_j}, k) \star input(N_i, k)$$