In [1]:
%load_ext autoreload
%autoreload 2

import os
import pickle
import sys
import shutil
import h5py

sys.path.append('/jet/home/tvnguyen/accreted_catalog/gaia_accreted_catalog')

import ml_collections
import numpy as np
import scipy.special as special
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import pytorch_lightning.loggers as pl_loggers
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, Subset
from ml_collections import config_flags
from memory_profiler import profile

import datasets
from models import models, classifier, utils, infer_utils

%matplotlib inline
plt.style.use('/jet/home/tvnguyen/mplstyle/default.mplstyle')

In [11]:
# model checkpoint: logging directory, run name and checkpoint file name
logdir = '/ocean/projects/phy210068p/tvnguyen/accreted_catalog/logging'
name = 'greasy-vote-10'
checkpoint = 'epoch=51-step=1100164.ckpt'

# use the GPU when one is available, otherwise fall back to the CPU
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

# input dataset: root directory, dataset name and the hdf5 datasets that
# are read as input features
data_root = '/ocean/projects/phy210068p/tvnguyen/accreted_catalog/datasets'
data_name = 'GaiaDR3_transfer'
data_features = [
    'ra', 'dec', 'pmra', 'pmdec', 'parallax', 'radial_velocity',
]
data_dir = os.path.join(data_root, data_name)

# output catalog: root directory and catalog name
catalog_root = '/ocean/projects/phy210068p/tvnguyen/accreted_catalog/gaia_catalogs'
catalog_name = 'GaiaDR3_FeH_reduced_v1'

In [12]:
# assemble the checkpoint location from the run directory and the file name
run_dir = os.path.join(logdir, name)
checkpoint_path = os.path.join(
    run_dir, 'lightning_logs/checkpoints', checkpoint)

# restore the classifier from the checkpoint, mapped onto the chosen device
model = classifier.MLPClassifier.load_from_checkpoint(
    checkpoint_path,
    map_location=device,
)

In [13]:
def read_dataset(data_fn, features):
    """ Read the requested features from an hdf5 file.

    Parameters
    ----------
    data_fn : str
        Path to the hdf5 file.
    features : list of str or None
        Names of the hdf5 datasets to read. If 'radial_velocity' is one of
        them, only the rows with a non-NaN radial velocity are kept.

    Returns
    -------
    x : np.ndarray or None
        The requested features stacked along the last axis, or None if
        ``features`` is None or empty.
    y : np.ndarray
        Boolean array of zeros with one entry per row of ``x`` (zero-length
        if ``x`` is None).
    """
    # nothing requested: return no features and a zero-length label array
    # instead of failing on len(None) further down
    if features is None or len(features) == 0:
        return None, np.zeros(0, dtype=bool)

    with h5py.File(data_fn, 'r') as f:
        if 'radial_velocity' in features:
            # keep only the stars with an available radial velocity
            rows = ~np.isnan(f['radial_velocity'][:])
        else:
            # no filtering needed: read every row directly, which does not
            # require any particular dataset (e.g. 'ra') to be in the file
            rows = slice(None)
        x = [f[feature][rows] for feature in features]

    x = np.stack(x, axis=-1)
    y = np.zeros(len(x), dtype=bool)
    return x, y

In [14]:
for i in range(10):

    # input chunk and the matching output file for this index
    data_fn = os.path.join(data_dir, f'data.{i}.hdf5')
    catalog_path = os.path.join(catalog_root, catalog_name, f'results.{i}.pkl')

    print(f'Reading data from {data_fn}')
    features, labels = read_dataset(data_fn, data_features)

    # wrap the arrays as tensors and iterate over them in order, 1024 rows
    # at a time
    features = torch.tensor(features, dtype=torch.float32)
    labels = torch.tensor(labels, dtype=torch.long)
    dataset = TensorDataset(features, labels)
    data_loader = DataLoader(dataset, batch_size=1024, shuffle=False)

    # run the model without softmax, then apply the softmax here and keep
    # the entry at index 1 of the last axis as the score
    y_pred, y_true = infer_utils.infer(
        model, data_loader, softmax=False, to_numpy=True)
    y_pred_score = special.softmax(y_pred, axis=1)[..., 1]

    # gather the outputs and pickle them, creating the catalog directory
    # if it does not exist yet
    results = {
        'y_true': y_true,
        'y_pred': y_pred,
        'y_pred_score': y_pred_score,
    }
    os.makedirs(os.path.dirname(catalog_path), exist_ok=True)
    with open(catalog_path, 'wb') as pkl_file:
        pickle.dump(results, pkl_file)

Reading data from /ocean/projects/phy210068p/tvnguyen/accreted_catalog/datasets/GaiaDR3_transfer/data.0.hdf5


Inferencing: 100%|████████████████████████████████████████████████████| 931/931 [00:07<00:00, 131.57it/s]


Reading data from /ocean/projects/phy210068p/tvnguyen/accreted_catalog/datasets/GaiaDR3_transfer/data.1.hdf5


Inferencing: 100%|████████████████████████████████████████████████████| 966/966 [00:07<00:00, 129.78it/s]


Reading data from /ocean/projects/phy210068p/tvnguyen/accreted_catalog/datasets/GaiaDR3_transfer/data.2.hdf5


Inferencing: 100%|██████████████████████████████████████████████████| 1671/1671 [00:13<00:00, 128.41it/s]


Reading data from /ocean/projects/phy210068p/tvnguyen/accreted_catalog/datasets/GaiaDR3_transfer/data.3.hdf5


Inferencing: 100%|██████████████████████████████████████████████████| 2834/2834 [00:22<00:00, 126.74it/s]


Reading data from /ocean/projects/phy210068p/tvnguyen/accreted_catalog/datasets/GaiaDR3_transfer/data.4.hdf5


Inferencing: 100%|█████████████████████████████████████████| 2862/2862 [00:22<00:00, 127.10it/s]


Reading data from /ocean/projects/phy210068p/tvnguyen/accreted_catalog/datasets/GaiaDR3_transfer/data.5.hdf5


Inferencing: 100%|█████████████████████████████████████████| 2852/2852 [00:22<00:00, 126.72it/s]


Reading data from /ocean/projects/phy210068p/tvnguyen/accreted_catalog/datasets/GaiaDR3_transfer/data.6.hdf5


Inferencing: 100%|█████████████████████████████████████████| 2848/2848 [00:22<00:00, 126.72it/s]


Reading data from /ocean/projects/phy210068p/tvnguyen/accreted_catalog/datasets/GaiaDR3_transfer/data.7.hdf5


Inferencing: 100%|█████████████████████████████████████████| 2844/2844 [00:22<00:00, 125.77it/s]


Reading data from /ocean/projects/phy210068p/tvnguyen/accreted_catalog/datasets/GaiaDR3_transfer/data.8.hdf5


Inferencing: 100%|█████████████████████████████████████████| 2850/2850 [00:22<00:00, 126.31it/s]


Reading data from /ocean/projects/phy210068p/tvnguyen/accreted_catalog/datasets/GaiaDR3_transfer/data.9.hdf5


Inferencing: 100%|█████████████████████████████████████████| 2852/2852 [00:22<00:00, 126.05it/s]
