In [1]:
# Enable IPython's autoreload extension in mode 2: all imported modules are
# reloaded before each cell executes, so edits to the project packages are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path
from tqdm import tqdm

import pytorch_lightning as pl
import torch
import wandb
from sdofm import utils
from sdofm.datasets import SDOMLDataModule
from sdofm.pretraining import MAE
from scripts.pretrain import Pretrainer
from lightning.pytorch.loggers.wandb import WandbLogger

In [3]:
import omegaconf

# Experiment configuration that drives the data module, logger and model below.
config_path = "../experiments/finetune_32.2M_mae_virtualeve.yaml"
cfg = omegaconf.OmegaConf.load(config_path)

In [4]:
# Build the SDOML data module from the experiment config. Only the AIA path is
# supplied; the HMI and EVE paths are explicitly left as None.
sdoml_cfg = cfg.data.sdoml
month_splits = cfg.data.month_splits

# Both directories live under the same configured base directory.
aia_dir = os.path.join(sdoml_cfg.base_directory, sdoml_cfg.sub_directory.aia)
cache_dir = os.path.join(sdoml_cfg.base_directory, sdoml_cfg.sub_directory.cache)

data_module = SDOMLDataModule(
    hmi_path=None,
    aia_path=aia_dir,
    eve_path=None,
    components=sdoml_cfg.components,
    wavelengths=sdoml_cfg.wavelengths,
    ions=sdoml_cfg.ions,
    frequency=sdoml_cfg.frequency,
    batch_size=cfg.model.opt.batch_size,
    num_workers=cfg.data.num_workers,
    val_months=month_splits.val,
    test_months=month_splits.test,
    holdout_months=month_splits.holdout,
    cache_dir=cache_dir,
)
data_module.setup()

[* CACHE SYSTEM *] Found cached index data in /mnt/sdoml/cache/aligndata_AIA_FULL_12min.csv.
[* CACHE SYSTEM *] Found cached normalization data in /mnt/sdoml/cache/normalizations_AIA_FULL_12min.json.
[* CACHE SYSTEM *] Found cached HMI mask data in /mnt/sdoml/cache/hmi_mask_512x512.npy.


In [5]:
# Alternative kept for reference (disabled): construct the MAE directly
# instead of going through Pretrainer.
# model = MAE(
#     **cfg.model.mae,
#     # **cfg.model.samae,
#     # hmi_mask=data_module.hmi_mask,
#     optimiser=cfg.model.opt.optimiser,
#     lr=cfg.model.opt.learning_rate,
#     weight_decay=cfg.model.opt.weight_decay,
# )

experiment_cfg = cfg.experiment
wandb_cfg = experiment_cfg.wandb

logger = WandbLogger(
    # WandbLogger params
    name=experiment_cfg.name,
    project=experiment_cfg.project,
    dir=wandb_cfg.output_directory,
    log_model=wandb_cfg.log_model,
    # kwargs for wandb.init
    tags=wandb_cfg.tags,
    notes=wandb_cfg.notes,
    group=wandb_cfg.group,
    save_code=True,
    job_type=wandb_cfg.job_type,
)

# Pretrainer builds the model from the config; the network itself is reached
# through `model.model` in later cells.
model = Pretrainer(cfg, logger=logger, is_backbone=True)


[* CACHE SYSTEM *] Found cached index data in /mnt/sdoml/cache/aligndata_AIA_FULL_12min.csv.
[* CACHE SYSTEM *] Found cached normalization data in /mnt/sdoml/cache/normalizations_AIA_FULL_12min.json.
[* CACHE SYSTEM *] Found cached HMI mask data in /mnt/sdoml/cache/hmi_mask_512x512.npy.
Loading checkpoint...
Found pre-downloaded checkpoint at artifacts/model-tk45el88:v12/model.ckpt


/opt/conda/envs/sdofm/lib/python3.10/site-packages/lightning/pytorch/utilities/migration/utils.py:56: The loaded checkpoint was produced with Lightning v2.2.5, which is newer than your current Lightning version: v2.2.1


Checkpoint loaded from artifacts/model-tk45el88:v12/model.ckpt


In [6]:
# Handle to the data module's validation dataset (`valid_ds`).
val_dataset = data_module.valid_ds

In [7]:
# Move the wrapped network onto the CUDA device. The trailing `;` keeps the
# module repr returned by `.to()` out of the cell output.
model.model.to("cuda");

In [8]:
# Number of samples available in the validation split.
len(val_dataset)

43131

In [9]:
# Embed validation samples with the model's encoder and keep two pooled
# representations per sample: token 0 (used as the class token) and the mean
# over all remaining tokens. `names` holds the matching `aligndata` index label.
num_samples = min(5000, len(val_dataset))  # never index past the end of the split
cls_embeddings = []
mean_embeddings = []
names = []

backbone = model.model
# Inference only: switch off train-time behaviour such as dropout.
# NOTE: the network stays in eval mode after this cell.
backbone.eval()
# Send inputs to whichever device the network currently lives on rather than
# hard-coding one.
device = next(backbone.parameters()).device

# Gradients are not needed for feature extraction; disabling autograd avoids
# building a graph for every forward pass.
with torch.no_grad():
    for i in tqdm(range(num_samples)):
        sample = val_dataset[i]
        name = val_dataset.aligndata.iloc[i].name
        # as_tensor accepts arrays and tensors alike, without the extra copy
        # (and copy-construct warning) of torch.tensor; unsqueeze adds the
        # leading batch dimension of size 1.
        batch = torch.as_tensor(sample).unsqueeze(0).to(device)
        x, mask, ids_restore = backbone.forward_encoder(batch, mask_ratio=0)
        cls_embeddings.append(x[:, 0, :].cpu())
        mean_embeddings.append(x[:, 1:, :].mean(dim=1).cpu())
        names.append(name)

# Stack the per-sample rows into (num_samples, embedding_dim) tensors.
cls_embeddings = torch.cat(cls_embeddings, dim=0)
mean_embeddings = torch.cat(mean_embeddings, dim=0)

100%|██████████| 5000/5000 [30:59<00:00,  2.69it/s]


In [10]:
from sklearn.manifold import TSNE
import numpy as np

In [11]:

# Project the class-token embeddings to 2-D with t-SNE; the fixed random_state
# keeps the layout reproducible between runs.
cls_embeddings_np = cls_embeddings.numpy()

tsne = TSNE(n_components=2, random_state=0)
cls_embeddings_tsne = tsne.fit_transform(cls_embeddings_np)

# Show the projected shape as the cell output.
cls_embeddings_tsne.shape


(5000, 2)

In [14]:
# Create a color scale for the dates we see. Darker = earlier
import matplotlib.colors as mcolors
import matplotlib.cm as cm

norm = mcolors.Normalize(vmin=0, vmax=512)
cmap = cm.ScalarMappable(norm=norm, cmap=cm.viridis)
cmap.set_array([])
colors = cmap.to_rgba(np.arange(num_samples ))


In [15]:
import plotly.express as px

# t-SNE map of the class-token embeddings, coloured by sample index.
# The colour values are numeric, so Plotly Express uses a continuous colour
# scale; the discrete colour map passed previously had no effect. Request the
# viridis scale directly (darker = lower index).
sample_index = list(range(len(cls_embeddings_tsne)))

fig = px.scatter(
    x=cls_embeddings_tsne[:, 0],
    y=cls_embeddings_tsne[:, 1],
    color=sample_index,
    color_continuous_scale="Viridis",
    hover_name=names,  # sample name is shown only on hover
)
fig.update_traces(
    marker=dict(size=12, line=dict(width=2, color="DarkSlateGrey")),
    selector=dict(mode="markers"),
)
fig.update_layout(
    coloraxis_showscale=False,  # hide the colorbar
    showlegend=False,
    autosize=False,
    width=800,
    height=800,
    margin=dict(l=0, r=0, b=0, t=0),
)

fig.show()


In [16]:

# Project the mean-pooled embeddings to 2-D with a fresh t-SNE instance; the
# fixed random_state keeps the layout reproducible between runs.
mean_embeddings_np = mean_embeddings.numpy()

tsne = TSNE(n_components=2, random_state=0)
mean_embeddings_tsne = tsne.fit_transform(mean_embeddings_np)

# Show the projected shape as the cell output.
mean_embeddings_tsne.shape


(5000, 2)

In [None]:

# Mean pooling: t-SNE map of the mean-pooled embeddings, coloured by sample index.
# The colour values are numeric, so Plotly Express uses a continuous colour
# scale; the discrete colour map passed previously had no effect. Request the
# viridis scale directly (darker = lower index).
sample_index = list(range(len(mean_embeddings_tsne)))

fig = px.scatter(
    x=mean_embeddings_tsne[:, 0],
    y=mean_embeddings_tsne[:, 1],
    color=sample_index,
    color_continuous_scale="Viridis",
    hover_name=names,  # sample name is shown only on hover
)
fig.update_traces(
    marker=dict(size=12, line=dict(width=2, color="DarkSlateGrey")),
    selector=dict(mode="markers"),
)
fig.update_layout(
    coloraxis_showscale=False,  # hide the colorbar
    showlegend=False,
    autosize=False,
    width=800,
    height=800,
    margin=dict(l=0, r=0, b=0, t=0),
)

fig.show()


In [None]:
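# Uncomment only if plotly is missing from the kernel's environment
# (%pip installs into the environment the kernel is running in):
# %pip install plotly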