ClinicalMamba

This repository contains the implementation of prompt-based fine-tuning of ClinicalMamba on the n2c2 2018 shared task 1: Cohort Selection for Clinical Trials. This is a classification task that identifies which patients do and do not meet the selection criteria, given their longitudinal clinical notes.

The paper ClinicalMamba: A Generative Clinical Language Model on Longitudinal Clinical Notes introduces two ClinicalMamba models of different sizes: clinicalmamba-2.8b and clinicalmamba-130m. Both models are currently under review and will be made available under the MIMIC license.

Dependencies

  • python=3.9.18
  • numpy=1.26.3
  • transformers=4.36.2
  • tokenizers=0.15.0
  • mamba-ssm=1.1.2
  • causal-conv1d=1.1.1
  • pytorch=2.1.2
  • pytorch-cuda=12.1
  • scikit-learn=1.4.0 

The full environment is listed in conda-environment.yaml and can be installed with:

conda env create -f conda-environment.yaml
conda activate mixtral
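
As a quick sanity check of the CUDA stack (a minimal sketch, not part of the repository; it only assumes the packages listed above), you can verify that PyTorch sees the GPU and that the mamba-ssm kernels load:

# sanity_check.py -- illustrative only, not part of this repository
import torch
from mamba_ssm import Mamba  # provided by the mamba-ssm package

print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

# Instantiate a tiny Mamba block on the GPU and run a dummy batch through it.
layer = Mamba(d_model=64).to("cuda")
x = torch.randn(1, 16, 64, device="cuda")
print("Output shape:", layer(x).shape)  # expected: torch.Size([1, 16, 64])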

Download / preprocess data

  1. Download the raw n2c2 data folders train and n2c2-t1_gold_standard_test_data, and place them under ./data.
  2. Preprocess the data by running the notebook ./preprocess/preprocess.ipynb. It converts the raw XML files to JSON, where each instance is a dictionary whose input is stored under 'text' and whose output keys start with 'label' (an illustrative sketch follows this list).
  3. Define your labels and their associated prompts in ./config_labels.py (also sketched below).
  4. The model then learns to assign the token yes or no to each prompt.
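
For illustration, the sketch below shows what the two formats might look like. The criterion names and prompt wording follow the n2c2 2018 selection criteria, but the exact field and variable names are assumptions; consult preprocess.ipynb and config_labels.py for the real structure.

# Illustrative only -- field and variable names are assumptions.
# A preprocessed instance written by preprocess.ipynb might look like:
#   {"text": "Record date 2093-01-13. 68 y/o male, HbA1c 7.9% ...",
#    "label_ABDOMINAL": "no",
#    "label_HBA1C": "yes"}

# A hypothetical ./config_labels.py: each selection criterion maps to a prompt
# that the model answers with the single token "yes" or "no".
LABELS = {
    "ABDOMINAL": "Has the patient had intra-abdominal surgery, small or large "
                 "intestine resection, or small bowel obstruction?",
    "HBA1C": "Does the patient have any HbA1c value between 6.5 and 9.5%?",
}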

Train and Eval

To test the perplexity of ClinicalMamba:

CUDA_VISIBLE_DEVICES=0 python test.py \
                --overwrite_output_dir --seed 42 --data_seed 42 --ddp_find_unused_parameters False \
                --data_path ./preprocess \
                --model_name mamba \
                --tokenizer_name PATH_TO_MODEL/clinicalmamba-130m \
                --model_name_or_path PATH_TO_MODEL/clinicalmamba-130m \
                --do_train --do_eval --max_seq_length 16002 \
                --per_device_train_batch_size 1 --gradient_accumulation_steps 4 --per_device_eval_batch_size 1 \
                --logging_first_step \
                --output_dir ./saved_models/mamba-test01

To test the perplexity of Asclepius:

CUDA_VISIBLE_DEVICES=0 python test.py \
                --overwrite_output_dir --seed 42 --data_seed 42 \
                --data_path ./preprocess \
                --model_name gpt \
                --tokenizer_name PATH_TO_MODEL/Asclepius-R-7B \
                --model_name_or_path PATH_TO_MODEL/Asclepius-R-7B \
                --do_train --do_eval --max_seq_length 16002 \
                --per_device_train_batch_size 1 --gradient_accumulation_steps 4 --per_device_eval_batch_size 1 \
                --logging_first_step \
                --output_dir ./saved_models/asclepius-test01

To fine-tune on Cohort Selection for Clinical Trials:

CUDA_VISIBLE_DEVICES=0 python main.py \
                --seed 3407 --data_seed 3407 --ddp_find_unused_parameters False \
                --data_path ./data \
                --config_name PATH_TO_MODEL/clinicalmamba-130m \
                --tokenizer_name PATH_TO_MODEL/clinicalmamba-130m \
                --model_name_or_path PATH_TO_MODEL/clinicalmamba-130m \
                --do_train --do_eval --max_seq_length 15004 \
                --per_device_train_batch_size 2 --gradient_accumulation_steps 4 --per_device_eval_batch_size 2 \
                --adam_beta1 0.9 --adam_beta2 0.95 --adam_epsilon 1e-5  \
                --learning_rate 0.000445 --weight_decay 1e-2 --num_train_epochs 12 \
                --lr_scheduler_type linear --warmup_ratio 0.15 \
                --logging_steps 50 \
                --evaluation_strategy epoch --save_strategy no \
                --logging_first_step \
                --output_dir ./saved_models/clinicalmamba-test01
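
At inference time, prompt-based classification reduces to comparing the model's next-token scores for "yes" versus "no" after each prompt. The sketch below illustrates that scoring step; it assumes the checkpoint loads through mamba-ssm's MambaLMHeadModel and that a tokenizer ships alongside it, and it is a simplified illustration rather than the logic in main.py.

# prompt_scoring_sketch.py -- illustrative only; names and paths are assumptions
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# PATH_TO_MODEL is the same placeholder used in the commands above.
tokenizer = AutoTokenizer.from_pretrained("PATH_TO_MODEL/clinicalmamba-130m")
model = MambaLMHeadModel.from_pretrained("PATH_TO_MODEL/clinicalmamba-130m",
                                         device="cuda", dtype=torch.float16)
model.eval()

def answer(note: str, prompt: str) -> str:
    """Compare the next-token logits of ' yes' vs. ' no' after note + prompt."""
    input_ids = tokenizer(note + "\n" + prompt,
                          return_tensors="pt").input_ids.to("cuda")
    with torch.no_grad():
        logits = model(input_ids).logits[0, -1]  # distribution over the next token
    yes_id = tokenizer(" yes", add_special_tokens=False).input_ids[0]
    no_id = tokenizer(" no", add_special_tokens=False).input_ids[0]
    return "yes" if logits[yes_id] > logits[no_id] else "no"

print(answer("HbA1c 7.9% on metformin ...",
             "Does the patient have any HbA1c value between 6.5 and 9.5%?"))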

Citation

@misc{yang2024clinicalmamba,
      title={ClinicalMamba: A Generative Clinical Language Model on Longitudinal Clinical Notes}, 
      author={Zhichao Yang and Avijit Mitra and Sunjae Kwon and Hong Yu},
      year={2024},
      eprint={2403.05795},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

License

See the LICENSE file for more details.
