In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import LearningRateMonitor
import argparse
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, Optional
from abc import abstractmethod
from warnings import warn

from polarbert.pretraining import (
    load_and_process_config,
    setup_callbacks,
    get_dataloaders,
    update_training_steps,
    compute_batch_params,
    default_transform,
    add_random_time_offset,
)

from polarbert.loss_functions import angles_to_unit_vector, angular_dist_score_unit_vectors
from typing import Dict, Tuple, Any, Optional, Callable
from torch.utils.data import DataLoader

In [9]:
from polarbert.dataloader_utils import (
    get_dataloaders, 
    target_transform_prometheus, 
    target_transform_kaggle,
    default_transform)

from polarbert.config import PolarBertConfig

In [10]:
# Load the fine-tuning configuration (te_finetuning_combined.yaml).
# NOTE(review): hardcoded absolute cluster path; consider moving it to a configurable constant.
config = PolarBertConfig.from_yaml('/groups/pheno/inar/PolarBERT/configs/te_finetuning_combined.yaml')

Loading configuration from: /groups/pheno/inar/PolarBERT/configs/te_finetuning_combined.yaml
Validating configuration...




In [11]:
# Build (train, val) loaders for the Kaggle and Prometheus datasets, each with its own target transform.
# NOTE(review): get_dataloaders / default_transform / target_transform_* are bound by several
# cells in this notebook (imports and local redefinitions), so which implementation runs here
# depends on execution order.
kaggle_train, kaggle_val = get_dataloaders(
    config,
    dataset_type='kaggle',
    transform=default_transform,
    target_transform=target_transform_kaggle
)

prometheus_train, prometheus_val = get_dataloaders(
    config,
    dataset_type='prometheus',
    transform=default_transform,
    target_transform=target_transform_prometheus
)

In [12]:
# len() of each loader: Kaggle train/val, Prometheus train/val.
len(kaggle_train), len(kaggle_val), len(prometheus_train), len(prometheus_val)

(96, 96, 96, 96)

In [13]:
from polarbert.time_embed_polarbert import PolarBertModel

In [14]:
# Instantiate the PolarBertModel from polarbert.time_embed_polarbert with the same config.
model = PolarBertModel(config)

INFO: Concatenated embeddings directly match model embedding dim. No projection layer used.


In [15]:
# Pull the first batch from each training loader for shape inspection.
batch_kaggle = next(iter(kaggle_train))
batch_prometheus = next(iter(prometheus_train))

In [17]:
# A batch unpacks as ((x, l), (y, c)); display the four tensor shapes.
# NOTE(review): presumably x = per-pulse features, l = sequence lengths, y = target angles,
# c = an auxiliary per-event value -- confirm against the dataset implementation.
(x, l), (y, c) = batch_kaggle
x.shape, l.shape, y.shape, c.shape

(torch.Size([1024, 127, 4]),
 torch.Size([1024]),
 torch.Size([1024, 2]),
 torch.Size([1024]))

In [18]:
# Same unpacking for the Prometheus batch; shapes should match the Kaggle batch above.
(x, l), (y, c) = batch_prometheus
x.shape, l.shape, y.shape, c.shape

(torch.Size([1024, 127, 4]),
 torch.Size([1024]),
 torch.Size([1024, 2]),
 torch.Size([1024]))

In [21]:
# Forward pass on the Prometheus batch; the displayed output is a tuple whose first element is a 3-D tensor.
model(batch_prometheus)

(tensor([[[-7.2651e-01, -8.8787e-01,  2.1512e-02,  ..., -1.4985e-01,
            2.8236e-01,  4.3037e-02],
          [-1.1994e+00, -6.5726e-01,  1.0863e+00,  ...,  9.8079e-02,
            3.2950e-01,  8.0259e-01],
          [ 1.9032e-01, -6.8657e-01,  1.9045e-01,  ..., -1.3535e-01,
            5.2155e-01, -1.1243e-01],
          ...,
          [-1.1697e+00,  7.2458e-01, -1.0877e-02,  ..., -1.7296e-01,
           -3.6154e-01, -4.2121e-01],
          [-1.1350e+00,  4.7216e-01, -5.5229e-01,  ...,  4.6709e-01,
           -2.7490e-02,  3.0391e-01],
          [-4.8471e-01, -1.4974e-02, -5.0631e-01,  ...,  2.2416e-01,
            2.1430e-01, -7.4323e-01]],
 
         [[ 4.2968e-01, -6.8890e-01, -2.0038e-01,  ..., -8.3614e-01,
            8.2966e-03, -5.4272e-02],
          [ 2.4566e-01, -7.9746e-02,  1.1343e+00,  ...,  4.1714e-01,
           -3.3114e-01,  4.1901e-01],
          [-3.7914e-01, -2.6378e-01, -1.2692e-01,  ...,  2.3355e-01,
            1.0407e+00,  6.7495e-01],
          ...,
    

In [3]:

def target_transform_prometheus(y, c):
    """Turn Prometheus targets into a float32 [azimuth, zenith] array.

    The ``initial_state_azimuth`` and ``initial_state_zenith`` columns of ``y`` are
    cast to float32 and placed side by side (for 1-D columns of length N the result
    has shape (N, 2)). ``c`` is cast to float32 unchanged otherwise.
    """
    azimuth = y['initial_state_azimuth'].astype(np.float32)
    zenith = y['initial_state_zenith'].astype(np.float32)
    angles = np.vstack([azimuth, zenith]).T
    return angles, c.astype(np.float32)


def target_transform_kaggle(y, c):
    """Cast both Kaggle targets ``y`` and ``c`` to float32."""
    y_float, c_float = (arr.astype(np.float32) for arr in (y, c))
    return y_float, c_float

In [4]:
from polarbert.te_pretraining import get_dataloaders
from polarbert.config import PolarBertConfig

In [12]:
# Reload the same config file as the earlier cell (duplicate load, same hardcoded absolute path).
config = PolarBertConfig.from_yaml('/groups/pheno/inar/PolarBERT/configs/te_finetuning_combined.yaml')

Loading configuration from: /groups/pheno/inar/PolarBERT/configs/te_finetuning_combined.yaml
Validating configuration...




In [37]:
def default_target_transform(y, c):
    """Default target transform: discard ``y`` and cast ``c`` to float32.

    Returns:
        ``(None, c)`` with ``c`` converted to a float32 array.
    """
    return None, c.astype(np.float32)


def default_transform(x, l):
    """Default input transform: cast ``x`` and ``l`` to float32."""
    return x.astype(np.float32), l.astype(np.float32)


def get_dataloaders(
        config: "PolarBertConfig",
        dataset_type: str,
        transform=default_transform,
        target_transform=default_target_transform,
        override_batch_size: Optional[int] = None,
    ) -> Tuple[DataLoader, DataLoader]:
    """Create train and validation dataloaders from a PolarBertConfig object.

    Args:
        config: project config; only ``config.data`` is read (directories,
            event counts, batch size and DataLoader options).
        dataset_type: ``'prometheus'`` or ``'kaggle'``.
        transform: input transform handed to the dataset.
        target_transform: target transform handed to the dataset.
        override_batch_size: if not None, used instead of
            ``config.data.max_per_device_batch_size``.

    Returns:
        ``(train_loader, val_loader)``.

    Raises:
        ValueError: if ``dataset_type`` is unknown, or if it is ``'prometheus'``
            and ``config.data.val_events`` is None.

    Split strategy:
        * prometheus: a single directory. ``slice(0, val_events)`` is the
          validation set. The training set starts at ``val_events`` and spans
          ``train_events`` items, or everything remaining if ``train_events``
          is falsy (None or 0).
        * kaggle: separate ``train_dir`` / ``val_dir``, each sliced from 0 to
          ``train_events`` / ``val_events``.
    """
    # Use a real exception: ``assert`` is removed under ``python -O``, which would let an
    # unknown type fall through to an unbound ``data_dir``.
    if dataset_type not in ('prometheus', 'kaggle'):
        raise ValueError(f"Unknown dataset type: {dataset_type}")

    data_cfg = config.data
    train_events = data_cfg.train_events
    val_events = data_cfg.val_events
    # Validate before touching the filesystem.
    if dataset_type == 'prometheus' and val_events is None:
        raise ValueError("Number of validation events must be specified for the Prometheus dataset")

    batch_size = (
        override_batch_size if override_batch_size is not None
        else data_cfg.max_per_device_batch_size
    )

    # The two sources ship different dataset implementations under the same class name.
    if dataset_type == 'prometheus':
        from polarbert.prometheus_dataset import IceCubeDataset
    else:
        from polarbert.icecube_dataset import IceCubeDataset

    def make_dataset(data_dir):
        # Every dataset in this function shares batch size and transforms.
        return IceCubeDataset(
            data_dir=data_dir,
            batch_size=batch_size,
            transform=transform,
            target_transform=target_transform,
        )

    if dataset_type == 'prometheus':
        full_dataset = make_dataset(data_cfg.prometheus_dir)
        val_dataset = full_dataset.slice(0, val_events)
        # A falsy train_events means "use everything after the validation range".
        train_end = val_events + train_events if train_events else None
        train_dataset = full_dataset.slice(val_events, train_end)
    else:
        train_dataset = make_dataset(data_cfg.train_dir).slice(0, train_events)
        val_dataset = make_dataset(data_cfg.val_dir).slice(0, val_events)

    num_workers = data_cfg.num_workers
    loader_kwargs = {
        # The dataset is built with its own batch_size; batch_size=None disables
        # DataLoader's automatic batching so items are passed through as-is.
        'batch_size': None,
        'num_workers': num_workers,
        'pin_memory': data_cfg.pin_memory,
        # DataLoader raises ValueError for persistent_workers=True with num_workers=0,
        # so only request persistent workers when there are workers to persist.
        'persistent_workers': bool(data_cfg.persistent_workers) and num_workers > 0,
    }

    return (
        DataLoader(train_dataset, **loader_kwargs),
        DataLoader(val_dataset, **loader_kwargs),
    )

In [38]:
# Per-device batch size that get_dataloaders uses unless override_batch_size is given.
config.data.max_per_device_batch_size

1024

In [39]:
# Rebuild the Kaggle loaders using the get_dataloaders defined in the cell above.
kaggle_train, kaggle_val = get_dataloaders(
    config,
    dataset_type='kaggle',
    transform=default_transform,
    target_transform=target_transform_kaggle
)

In [40]:
# Rebuild the Prometheus loaders; validation comes from the start of the same directory as training.
prometheus_train, prometheus_val = get_dataloaders(
    config,
    dataset_type='prometheus',
    transform=default_transform,
    target_transform=target_transform_prometheus
)


In [42]:
# len() of each rebuilt loader: Kaggle train/val, Prometheus train/val.
len(kaggle_train), len(kaggle_val), len(prometheus_train), len(prometheus_val)

(96, 96, 96, 96)

In [23]:
# Check which IceCubeDataset implementation (icecube_dataset vs prometheus_dataset) backs each loader.
kaggle_train.dataset, kaggle_val.dataset, prometheus_train.dataset, prometheus_val.dataset

(<polarbert.icecube_dataset.IceCubeDataset at 0x7f5a88ca17d0>,
 <polarbert.icecube_dataset.IceCubeDataset at 0x7f5a8917b850>,
 <polarbert.prometheus_dataset.IceCubeDataset at 0x7f5a88b8fb90>,
 <polarbert.prometheus_dataset.IceCubeDataset at 0x7f5a88b11750>)