In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path
from tqdm import tqdm

import pytorch_lightning as pl
import torch
import wandb
from sdofm import utils
from sdofm.datasets import SDOMLDataModule
from sdofm.pretraining import MAE
from scripts.pretrain import Pretrainer
from lightning.pytorch.loggers.wandb import WandbLogger

In [3]:
import omegaconf

cfg = omegaconf.OmegaConf.load("../experiments/finetune_32.2M_mae_virtualeve.yaml")

In [4]:
data_module = SDOMLDataModule(
    hmi_path=None,
    aia_path=os.path.join(
        cfg.data.sdoml.base_directory, cfg.data.sdoml.sub_directory.aia
    ),
    eve_path=None,
    components=cfg.data.sdoml.components,
    wavelengths=cfg.data.sdoml.wavelengths,
    ions=cfg.data.sdoml.ions,
    frequency=cfg.data.sdoml.frequency,
    batch_size=cfg.model.opt.batch_size,
    num_workers=cfg.data.num_workers,
    val_months=cfg.data.month_splits.val,
    test_months=cfg.data.month_splits.test,
    holdout_months=cfg.data.month_splits.holdout,
    cache_dir=os.path.join(
        cfg.data.sdoml.base_directory, cfg.data.sdoml.sub_directory.cache
    ),
)
data_module.setup()

[* CACHE SYSTEM *] Found cached index data in /mnt/sdoml/cache/aligndata_AIA_FULL_12min.csv.
[* CACHE SYSTEM *] Found cached normalization data in /mnt/sdoml/cache/normalizations_AIA_FULL_12min.json.
[* CACHE SYSTEM *] Found cached HMI mask data in /mnt/sdoml/cache/hmi_mask_512x512.npy.


In [5]:
#model = MAE(
#    **cfg.model.mae,
    #    **cfg.model.samae,
    #    hmi_mask=data_module.hmi_mask,
#    optimiser=cfg.model.opt.optimiser,
#    lr=cfg.model.opt.learning_rate,
#    weight_decay=cfg.model.opt.weight_decay,
    
#)

logger = WandbLogger(
    # WandbLogger params
    name=cfg.experiment.name,
    project=cfg.experiment.project,
    dir=cfg.experiment.wandb.output_directory,
    log_model=cfg.experiment.wandb.log_model,
    # kwargs for wandb.init
    tags=cfg.experiment.wandb.tags,
    notes=cfg.experiment.wandb.notes,
    group=cfg.experiment.wandb.group,
    save_code=True,
    job_type=cfg.experiment.wandb.job_type,

)
model = Pretrainer(cfg, logger=logger, is_backbone=True)


[* CACHE SYSTEM *] Found cached index data in /mnt/sdoml/cache/aligndata_AIA_FULL_12min.csv.
[* CACHE SYSTEM *] Found cached normalization data in /mnt/sdoml/cache/normalizations_AIA_FULL_12min.json.
[* CACHE SYSTEM *] Found cached HMI mask data in /mnt/sdoml/cache/hmi_mask_512x512.npy.
Loading checkpoint...
Found pre-downloaded checkpoint at artifacts/model-tk45el88:v12/model.ckpt


/opt/conda/envs/sdofm/lib/python3.10/site-packages/lightning/pytorch/utilities/migration/utils.py:56: The loaded checkpoint was produced with Lightning v2.2.5, which is newer than your current Lightning version: v2.2.1


Checkpoint loaded from artifacts/model-tk45el88:v12/model.ckpt


In [6]:
val_dataset = data_module.valid_ds
train_dataset = data_module.train_ds

In [7]:
model.model.to("cuda");
model.model.eval();

In [8]:
val_dataset.__len__(), train_dataset.__len__()

(43131, 450432)

In [9]:
import pandas as pd
dates = []
for i in range(train_dataset.__len__()):

    dates.append(train_dataset.aligndata.iloc[i].name)

dates_df = pd.DataFrame(dates, columns=["date"])

dates_df.reset_index(inplace=True, drop=False)

In [10]:
dates_df['year'] = pd.to_datetime(dates_df['date']).dt.year
dates_df['month'] = pd.to_datetime(dates_df['date']).dt.month
dates_df['day'] = pd.to_datetime(dates_df['date']).dt.day
dates_df['hour'] = pd.to_datetime(dates_df['date']).dt.hour


In [11]:
df_2011 = dates_df[dates_df['year'] == 2011]
# groupby month, select 100 random samples
df_2011_subset = df_2011.groupby('month').apply(lambda x: x.sample(100, random_state=1)).reset_index(drop=True)
quiet_months = [1, 2, 5, 6, 7, 8]
df_2011_subset['is_active'] = df_2011_subset['month'].apply(lambda x: 0 if x in quiet_months else 1)

  df_2011_subset = df_2011.groupby('month').apply(lambda x: x.sample(100, random_state=1)).reset_index(drop=True)


In [12]:
cls_embeddings = []
mean_embeddings = []
names = []

for idx in tqdm(df_2011_subset['index'].values):
    batch = train_dataset[idx]
    name = train_dataset.aligndata.iloc[idx].name
    batch = torch.tensor(batch).unsqueeze(0)    
    batch = batch.to("cuda")    
    x, mask, ids_restore = model.model.forward_encoder(batch, mask_ratio = 0)
    # cls_token 
    cls_embedding = x[:,0,:].detach().cpu()
    mean_embedding = x[:,1:,:].mean(dim=1).detach().cpu()
    cls_embeddings.append(cls_embedding)
    mean_embeddings.append(mean_embedding)
    names.append(name)
cls_embeddings = torch.cat(cls_embeddings, dim=0)
mean_embeddings = torch.cat(mean_embeddings, dim=0)

 49%|████▉     | 493/1000 [10:17<12:46,  1.51s/it]

num_samples = 5000
cls_embeddings = []
mean_embeddings = []
names = []
for i in tqdm(range(num_samples)):
    batch = val_dataset[i]
    name = val_dataset.aligndata.iloc[i].name
    batch = torch.tensor(batch).unsqueeze(0)    
    batch = batch.to("cuda")    
    x, mask, ids_restore = model.model.forward_encoder(batch, mask_ratio = 0)
    # cls_token 
    cls_embedding = x[:,0,:].detach().cpu()
    mean_embedding = x[:,1:,:].mean(dim=1).detach().cpu()
    cls_embeddings.append(cls_embedding)
    mean_embeddings.append(mean_embedding)
    names.append(name)
cls_embeddings = torch.cat(cls_embeddings, dim=0)
mean_embeddings = torch.cat(mean_embeddings, dim=0)

In [None]:
from sklearn.manifold import TSNE
import plotly.express as px

In [None]:

tsne = TSNE(n_components=2, random_state=0)

cls_embeddings_np = cls_embeddings.numpy()
cls_embeddings_tsne = tsne.fit_transform(cls_embeddings_np)



df_2011_subset['cls_tsne_x'] = cls_embeddings_tsne[:,0]
df_2011_subset['cls_tsne_y'] = cls_embeddings_tsne[:,1]


tsne = TSNE(n_components=2, random_state=0)

mean_embeddings_np = mean_embeddings.numpy()
mean_embeddings_tsne = tsne.fit_transform(mean_embeddings_np)



df_2011_subset['avg_tsne_x'] = mean_embeddings_tsne[:,0]
df_2011_subset['avg_tsne_y'] = mean_embeddings_tsne[:,1]



In [None]:

fig = px.scatter(df_2011_subset, x="cls_tsne_x", y="cls_tsne_y", color="is_active", hover_data=["month", "day", "hour"])

fig.show()


In [None]:
fig = px.scatter(df_2011_subset, x="avg_tsne_x", y="avg_tsne_y", color="is_active", hover_data=["month", "day", "hour"])

fig.show()


In [None]:
# save html
fig.write_html("mean_pooling_tsne.html")


In [None]:
#!pip install plotly