Statistics about Dataloaders: distribution of sequence lengths (in tokens) produced by the train dataloader

In [6]:
import json
import os
from collections import Counter
from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from tqdm import tqdm

from hf_ehr.config import PATH_TO_FEMR_EXTRACT_v8
from hf_ehr.data.datasets import BaseDataset
from hf_ehr.data.tokenization import CLMBRTokenizer
from hf_ehr.trainer.loaders import load_datasets, load_dataloaders
from hf_ehr.utils import load_config_from_path
 
# Load the training config stored with the `mamba-tiny-16384--clmbr` checkpoint, then point
# its dataset section at the FEMR v8 extract.
PATH_TO_CKPT = '/share/pi/nigam/suhana/hf_ehr/cache/runs_backup/mamba-tiny-16384--clmbr/ckpts/train-tokens-total_nonPAD-ckpt_val=2000000000-persist.ckpt'
config = load_config_from_path(PATH_TO_CKPT)
OmegaConf.set_struct(config, False)  # disable struct mode so config keys can be overridden below
config.data.dataset.name = 'FEMRDataset'
config.data.dataset.path_to_femr_extract = PATH_TO_FEMR_EXTRACT_v8

# Build tokenizer -> datasets -> dataloaders (dicts keyed by split name, e.g. 'train')
tokenizer = CLMBRTokenizer(config.data.tokenizer.path_to_config)
datasets: Dict[str, BaseDataset] = load_datasets(config, tokenizer)
dataloaders: Dict[str, DataLoader] = load_dataloaders(config, datasets, tokenizer)

[32m2024-10-17 20:47:41.184[0m | [1mINFO    [0m | [36mhf_ehr.trainer.loaders[0m:[36mload_dataloaders[0m:[36m30[0m - [1m====> Loading ApproxBatchSampler[0m


Loading `seq_length_per_patient.json` from `/share/pi/nigam/mwornow/hf_ehr/cache/tokenizers/clmbr_v8/versions/2024-07-20_05-21-33/datasets/2024-07-20_05-22-12/seq_length_per_patient.json` for split=`train`
Loading `seq_length_per_patient.json` from `/share/pi/nigam/mwornow/hf_ehr/cache/tokenizers/clmbr_v8/versions/2024-07-20_05-21-33/datasets/2024-07-21_10-55-32/seq_length_per_patient.json` for split=`val`
Loading `seq_length_per_patient.json` from `/share/pi/nigam/mwornow/hf_ehr/cache/tokenizers/clmbr_v8/versions/2024-07-20_05-21-33/datasets/2024-07-21_11-44-30/seq_length_per_patient.json` for split=`test`


In [9]:
# Loop through the train dataloader once, recording the number of attended tokens in every sequence.
# The full pass is slow (the progress bar below estimated ~18h), so the result is cached as JSON and
# later runs simply reload it.
PATH_TO_TRAIN_SEQ_LENGTHS_CACHE = '../cache/train_seq_lengths-mamba-16k.json'

if os.path.exists(PATH_TO_TRAIN_SEQ_LENGTHS_CACHE):
    with open(PATH_TO_TRAIN_SEQ_LENGTHS_CACHE, 'r') as f:
        train_seq_lengths: List[int] = json.load(f)['train_seq_lengths']
else:
    train_seq_lengths: List[int] = []
    for batch in tqdm(dataloaders['train']):
        # Per-row sum of the attention mask; assumes mask is 1 for real tokens, 0 for padding -- TODO confirm
        lengths = batch['tokens']['attention_mask'].sum(dim=1)
        assert len(lengths) == len(batch['patient_ids'])
        # Convert to Python ints right away instead of holding one 0-d tensor per sequence
        train_seq_lengths.extend(lengths.tolist())

    # Write to a temp file, then rename, so an interrupted dump cannot leave a truncated
    # cache that the `os.path.exists` branch above would later fail to parse.
    os.makedirs(os.path.dirname(PATH_TO_TRAIN_SEQ_LENGTHS_CACHE), exist_ok=True)
    path_to_tmp_cache = PATH_TO_TRAIN_SEQ_LENGTHS_CACHE + '.tmp'
    with open(path_to_tmp_cache, 'w') as f:
        json.dump({ 'train_seq_lengths' : train_seq_lengths, }, f)
    os.replace(path_to_tmp_cache, PATH_TO_TRAIN_SEQ_LENGTHS_CACHE)
    print("# of batches:", len(dataloaders['train']))
print("# of seqs:", len(train_seq_lengths))

  0%|          | 7/118431 [00:06<18:14:38,  1.80it/s] Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7feaeb356a70>
Traceback (most recent call last):
  File "/home/mwornow/llama_hf_env/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()
  File "/home/mwornow/llama_hf_env/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
    Exception ignored in: if w.is_alive():<function _MultiProcessingDataLoaderIter.__del__ at 0x7feaeb356a70>
  File "/home/mwornow/llama_hf_env/lib/python3.10/multiprocessing/process.py", line 160, in is_alive

    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: Traceback (most recent call last):
can only test a child process  File "/home/mwornow/llama_hf_env/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    
self._shutdown_workers()
  File "/home/mwornow/llama

KeyboardInterrupt: 

In [None]:
# Plot each distinct sequence length against how many train sequences have that length
# (a log-log frequency scatter, not a binned histogram).
# `counter` maps seq length -> # of sequences; the CDF cell below reuses it.
counter = Counter(train_seq_lengths)

fig, ax = plt.subplots()
ax.scatter(list(counter.keys()), list(counter.values()), s=5)
ax.set_yscale('log')
ax.set_xscale('log')
ax.set(
    title='Seq Lengths v. Frequency from Train Dataloader (log-log plot)',
    xlabel='Seq Length (# of tokens)',
    ylabel='# of Occurrences in Train Dataloader',
)

# Show plot
plt.show()

In [None]:
# Plot the empirical CDF of sequence lengths, using `counter` (seq length -> # of sequences)
# from the frequency-plot cell.
sorted_items = sorted(counter.items())
sorted_seq_lengths = np.array([length for length, _ in sorted_items])
sorted_occurrences = np.array([count for _, count in sorted_items])

# CDF(L) = (# of sequences with length <= L) / (total # of sequences)
cdf = np.cumsum(sorted_occurrences) / np.sum(sorted_occurrences)

# An empirical CDF is a right-continuous step function, so draw it with `step(where='post')`.
# Note: `scatter(..., linestyle='-')` never connects points; there `linestyle` only styles marker edges.
fig, ax = plt.subplots()
ax.step(sorted_seq_lengths, cdf, where='post')
ax.set(
    title='CDF of Sequence Lengths from Train Dataloader',
    xlabel='Seq Length (# of tokens)',
    ylabel='Cumulative Proportion of Samples',
)

# Show plot
plt.show()

In [None]:
# Percentiles of the train sequence-length distribution, plus the share of sequences that are at
# least as long as common context-window sizes.
# `print_percentiles` is not defined or imported anywhere in this notebook, so the percentile
# summary is computed directly with pandas.
CONTEXT_LENGTH_THRESHOLDS: List[int] = [512, 768, 1024, 2048, 4096, 8192, 16384]

assert len(train_seq_lengths) > 0, "No sequence lengths loaded -- run the dataloader cell first"
train_seq_lengths_series = pd.Series(train_seq_lengths)
print(train_seq_lengths_series.describe(percentiles=[0.01, 0.05, 0.25, 0.5, 0.75, 0.95, 0.99]).to_string())
print("\n")
for threshold in CONTEXT_LENGTH_THRESHOLDS:
    # Mean of a boolean mask = fraction of True values; scale by 100 so the "%" label is accurate
    pct_at_or_above = (train_seq_lengths_series >= threshold).mean() * 100
    print(f"% of sequences >={threshold}: {pct_at_or_above:.2f}%")