In [None]:
!pip install -q zarr

In [None]:
import os
from os.path import join

import dask
import dask.array as da
import pandas as pd
import numpy as np
import numba

from numba.typed import Dict
from numba import prange

In [None]:
# Root directory of the input dataset; the train/val/test obs tables, the
# zarr stores and the var table are all read from subdirectories of this path.
PATH = '/mnt/dssfs02/cxg_census/data_2023_05_15'

# Get indices for subsampling

In [None]:
# Per-row metadata tables for the three splits.
obs_train = pd.read_parquet(join(PATH, 'train/obs.parquet'))
obs_val = pd.read_parquet(join(PATH, 'val/obs.parquet'))
obs_test = pd.read_parquet(join(PATH, 'test/obs.parquet'))

# Discard the stored index so every frame is labelled 0..n-1. The index
# labels are later used as positional row indices into the matching X array
# (assumes obs rows and X rows are stored in the same order -- TODO confirm).
obs_train, obs_val, obs_test = (
    frame.reset_index(drop=True) for frame in (obs_train, obs_val, obs_test)
)

# Lazily opened dask arrays backed by the 'X' component of each zarr store.
x_train = da.from_zarr(join(PATH, 'train/zarr'), component='X')
x_val = da.from_zarr(join(PATH, 'val/zarr'), component='X')
x_test = da.from_zarr(join(PATH, 'test/zarr'), component='X')

In [None]:
# `var` table of the train split (presumably per-feature annotations for the
# columns of X -- confirm). This same frame is written unchanged into every
# output split further down.
var_parquet = join(PATH, 'train/var.parquet')
var = pd.read_parquet(var_parquet)

In [None]:
# Strip category levels that have no rows left, for every categorical column
# of each split's obs table. The frames are modified in place, in the order
# train -> val -> test.
for obs_frame in (obs_train, obs_val, obs_test):
    for col in obs_frame.columns:
        column = obs_frame[col]
        if column.dtype.name != 'category':
            # only categorical columns carry a category list to prune
            continue
        obs_frame[col] = column.cat.remove_unused_categories()


In [None]:
# Fixed seed so the shuffled row order is reproducible between runs.
rng = np.random.default_rng(seed=1)

# split name -> shuffled array of row labels belonging to lung tissue
subset_idxs = {}

splits = (('train', obs_train), ('val', obs_val), ('test', obs_test))
for split, obs in splits:
    is_lung = obs.tissue_general == 'lung'
    idx_subset = obs.loc[is_lung].index.to_numpy()
    # in-place shuffle; the single generator is shared across splits, so the
    # result for each split depends on the train -> val -> test order
    rng.shuffle(idx_subset)
    subset_idxs[split] = idx_subset


In [None]:
# Show the shuffled index arrays per split as the cell output (sanity check).
subset_idxs

# Store lung-only subset to disk

In [None]:
# Output root; one subdirectory per split is created underneath it.
# (Plain string literal: the previous f-string had no placeholders.)
SAVE_PATH = '/mnt/dssfs02/cxg_census/data_2023_05_15_lung_only'
# Number of rows per chunk of the rechunked X array written to zarr.
CHUNK_SIZE = 16384

In [None]:
# Options shared by both parquet writes in the loop below.
parquet_kwargs = dict(engine='pyarrow', compression='snappy', index=None)

# For each split: select the shuffled rows, then write var, obs and X into
# SAVE_PATH/<split>/.
for split, x, obs in zip(
    ('train', 'val', 'test'),
    (x_train, x_val, x_test),
    (obs_train, obs_val, obs_test),
):
    shuffled_idxs = subset_idxs[split]

    # out-of-order indexing is on purpose here as we want to shuffle the data to break up data sets
    # rechunk: CHUNK_SIZE rows per chunk, second axis kept as a single chunk (-1)
    X_split = x[shuffled_idxs, :].rechunk((CHUNK_SIZE, -1))
    obs_split = obs.iloc[shuffled_idxs, :]

    save_dir = join(SAVE_PATH, split)
    # no exist_ok: raises FileExistsError if the split directory is already there
    os.makedirs(save_dir)

    var.to_parquet(path=join(save_dir, 'var.parquet'), **parquet_kwargs)
    obs_split.to_parquet(path=join(save_dir, 'obs.parquet'), **parquet_kwargs)
    da.to_zarr(
        X_split,
        join(save_dir, 'zarr'),
        component='X',
        compute=True,
        compressor='default',
        order='C',
    )
