In [81]:
from os.path import join
from datasets import load_dataset, concatenate_datasets, DatasetDict
import numpy as np

In [95]:
# Root folder for all assets, and the datasets folder directly beneath it.
ASSETS = '/assets'
DATASET_FOLDER = f'{ASSETS}/datasets'

### Split the PHI medical records across the federated datasets in proportion to each dataset's size

In [None]:
# Folder names of the datasets generated by `src/create_dataset.py`, keyed by dataset name.
dataset_folders = {
  "medmcqa": "2024_09_10_10_30_51-medmcqa_None-val_size_0.1-max_input_length_1024",
  "pubmedqa": "2024_10_22_10_24_48-pubmedqa_1k_50000-val_size_0.1-max_input_length_1024",
  "flashcard": "2024_11_12_11_01_18-flashcard_None-val_size_0.1-max_input_length_1024",
  "PHI": "2024_10_22_10_31_43-PHI_None-val_size_0.1-max_input_length_1024",
}

# Load each folder under DATASET_FOLDER; every entry is a dict so that later cells
# can attach further per-dataset values next to the loaded "dataset".
datasets = {
  name: {"dataset": load_dataset(join(DATASET_FOLDER, folder), num_proc=64)}
  for name, folder in dataset_folders.items()
}

In [None]:
# Store each dataset's combined train + validation `size_in_bytes` under 'size'
# and print it converted to whole MiB (bytes / 1024, then floor-divided by 1024).
for name, entry in datasets.items():
  splits = entry['dataset']
  train_bytes = splits['train'].size_in_bytes
  val_bytes = splits['validation'].size_in_bytes
  entry['size'] = train_bytes + val_bytes
  print(name, entry['size'] / 1024 // 1024)

In [None]:
# Compute, for every non-PHI dataset, its share of the combined non-PHI size.
# The total is summed over the same set of entries the loop below visits (everything
# except 'PHI') instead of a hard-coded list of names, so the shares keep covering
# the whole if a dataset is added to or removed from `datasets`.
total = sum(entry['size'] for key, entry in datasets.items() if key != 'PHI')
for name, dataset in datasets.items():
  if name == 'PHI':
    continue
  # Fraction of the PHI rows that will later be assigned to this dataset.
  dataset['PHI_ratio'] = dataset['size'] / total
  print(dataset['PHI_ratio'])

In [None]:
# Shuffle both PHI splits with a fixed seed so the row order is reproducible,
# then display how many rows each split holds.
phi_splits = datasets['PHI']['dataset']
PHI_train = phi_splits['train'].shuffle(seed=42)
PHI_val = phi_splits['validation'].shuffle(seed=42)
PHI_train.num_rows, PHI_val.num_rows

In [None]:
# Turn each non-PHI dataset's ratio into a number of PHI train / validation rows.
# int() truncates, so the per-dataset counts may add up to slightly fewer rows
# than the PHI splits contain.
for name, entry in datasets.items():
  if name != 'PHI':
    share = entry['PHI_ratio']
    entry['PHI_train_rows'] = int(PHI_train.num_rows * share)
    entry['PHI_val_rows'] = int(PHI_val.num_rows * share)
    print(entry['PHI_train_rows'], entry['PHI_val_rows'])

In [None]:
# Cumulative row boundaries for cutting the PHI splits into three consecutive parts:
# the first boundary ends the medmcqa part, the second ends the pubmedqa part, and
# whatever lies after the second boundary forms the last part.
medmcqa_entry = datasets['medmcqa']
pubmedqa_entry = datasets['pubmedqa']

train_splits = [
  medmcqa_entry['PHI_train_rows'],
  medmcqa_entry['PHI_train_rows'] + pubmedqa_entry['PHI_train_rows'],
]
val_splits = [
  medmcqa_entry['PHI_val_rows'],
  medmcqa_entry['PHI_val_rows'] + pubmedqa_entry['PHI_val_rows'],
]
train_splits, val_splits

In [None]:
# Row positions 0..num_rows-1 of each shuffled PHI split, cut at the cumulative
# boundaries; two boundaries yield three consecutive index arrays per split.
train_row_ids = np.arange(PHI_train.num_rows)
val_row_ids = np.arange(PHI_val.num_rows)

train_splits_idx = np.split(train_row_ids, train_splits)
val_splits_idx = np.split(val_row_ids, val_splits)

In [88]:
# Append each dataset's slice of the shuffled PHI rows to its own train and
# validation splits, storing the result under 'new_dataset'.
# The slice position is counted over the non-PHI entries only, so the skipped
# 'PHI' entry never shifts the index, wherever it sits in `datasets`.
non_phi_names = [key for key in datasets if key != 'PHI']

# Every slice must have exactly one receiving dataset; otherwise PHI rows would be
# silently dropped or an index lookup would fail partway through the loop.
assert len(non_phi_names) == len(train_splits_idx) == len(val_splits_idx), \
  "number of non-PHI datasets does not match the number of PHI slices"

for i, name in enumerate(non_phi_names):
  dataset = datasets[name]

  datasets[name]['new_dataset'] = DatasetDict({
    'train': concatenate_datasets([dataset['dataset']['train'], PHI_train.select(train_splits_idx[i])]),
    'validation': concatenate_datasets([dataset['dataset']['validation'], PHI_val.select(val_splits_idx[i])]),
  })

In [None]:
# Write every augmented (non-PHI) dataset to its own sub-folder, named after the
# dataset, below the shared output folder for this proportional split.
output_root = join(DATASET_FOLDER, "federated/2024_11_12_13_06_32_PHI_proportional_splits")

for name, entry in datasets.items():
  if name != 'PHI':
    dataset = entry['new_dataset']
    # The torch/cuda format is set on the dataset before it is written out.
    dataset.set_format("torch", device="cuda")

    output_path = join(output_root, name)
    dataset.save_to_disk(output_path, num_proc=64)