In [1]:
!rm -r ViPubmed
!git clone https://github.com/justinphan3110/ViPubmed.git

Cloning into 'ViPubmed'...
remote: Enumerating objects: 410, done.
remote: Counting objects: 100% (89/89), done.
remote: Compressing objects: 100% (66/66), done.
remote: Total 410 (delta 45), reused 58 (delta 21), pack-reused 321
Receiving objects: 100% (410/410), 8.03 MiB | 12.07 MiB/s, done.
Resolving deltas: 100% (213/213), done.


In [None]:
## Install JAX for GPU
!pip install jaxlib==0.4.2+cuda11.cudnn86 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
## Install T5X and dependencies
!cd ViPubmed && python3 setup.py

In [None]:
# download ViPubmedT5X base model
!gsutil -m cp -r gs://vietai_public/vipubmedt5_base .

In [None]:
############################### FAQ_summarization ###############################

task = 'FAQ_summarization'
train_file = f'ViPubmed/data/{task}/train.tsv'
test_file = f'ViPubmed/data/{task}/test.tsv'
dev_file = f'ViPubmed/data/{task}/dev.tsv'

model_dir = f'out/{task}/vipubmedt5_base'
pretrained_path=f'vipubmedt5_base/checkpoint_1500000'

gin_file = f'ViPubmed/configs/runs/base_finetune.gin'

metric = 'rouge'

# Train settings
batch_size = 64
features_length = {"inputs": 256, "targets": 64}
train_steps = 1000 + 1500000 # 1000 finetune steps + 1.5M pretraining step
save_period = train_steps
eval_period = 1000

!python3 'ViPubmed/src/finetune_t5x.py' \
  --gin_file="{gin_file}" \
  --gin.MODEL_DIR="'{model_dir}'" \
  --gin.INITIAL_CHECKPOINT_PATH="'{pretrained_path}'" \
  --gin.TRAIN_STEPS='{train_steps}' \
  --gin.SAVE_PERIOD='{save_period}'\
  --gin.EVAL_PERIOD='{eval_period}'\
  --gin.MIXTURE_OR_TASK_NAME="'{task}'" \
  --gin.TASK_FEATURE_LENGTHS="{features_length}" \
  --gin.BATCH_SIZE='{batch_size}' \
  --task="{task}" \
  --metric="{metric}" \
  --train_file="{train_file}" \
  --predict_file="{test_file}" # or {dev_file}


In [None]:
############################### acrDrAid ###############################

task = 'acrDrAid'
train_file = f'ViPubmed/data/{task}/train.tsv'
test_file = f'ViPubmed/data/{task}/test.tsv'
dev_file = f'ViPubmed/data/{task}/dev.tsv'

model_dir = f'out/{task}/vipubmedt5_base_dev_0'
pretrained_path=f'vipubmedt5_base/checkpoint_1500000'

gin_file = f'ViPubmed/configs/runs/base_finetune.gin'

metric = 'macro_f1'

# Train settings
batch_size = 64
features_length = {"inputs": 256, "targets": 3}
train_steps = 1600 + 1500000 # 1600 finetune steps + 1.5M pretraining step
save_period = train_steps
eval_period = 1600
learning_rate = 0.0005


!python3 'ViPubmed/src/finetune_t5x.py' \
  --gin_file="{gin_file}" \
  --gin.MODEL_DIR="'{model_dir}'" \
  --gin.INITIAL_CHECKPOINT_PATH="'{pretrained_path}'" \
  --gin.TRAIN_STEPS='{train_steps}' \
  --gin.SAVE_PERIOD='{save_period}'\
  --gin.EVAL_PERIOD='{eval_period}'\
  --gin.MIXTURE_OR_TASK_NAME="'{task}'" \
  --gin.LEARNING_RATE='{learning_rate}' \
  --gin.TASK_FEATURE_LENGTHS="{features_length}" \
  --gin.BATCH_SIZE='{batch_size}' \
  --task="{task}" \
  --metric="{metric}" \
  --train_file="{train_file}" \
  --predict_file="{test_file}" # or {dev_file}


In [None]:
############################### vimednli ###############################

task = 'vi_mednli'
train_file = f'ViPubmed/data/{task}/train_vi_refined.tsv'
test_file = f'ViPubmed/data/{task}/test_vi_refined.tsv'
dev_file = f'ViPubmed/data/{task}/dev_vi_refined.tsv'

model_dir = f'out/{task}/vipubmedt5_base'
pretrained_path=f'vipubmedt5_base/checkpoint_1500000'

gin_file = f'ViPubmed/configs/runs/base_finetune.gin'

metric = 'accuracy'

# Train settings
batch_size = 64
features_length = {"inputs": 128, "targets": 6}
train_steps = 4800 + 1500000 # 1000 finetune steps + 1.5M pretraining step
save_period = train_steps
eval_period = 4800
learning_rate = 0.0005

!python3 'ViPubmed/src/finetune_t5x.py' \
  --gin_file="{gin_file}" \
  --gin.MODEL_DIR="'{model_dir}'" \
  --gin.INITIAL_CHECKPOINT_PATH="'{pretrained_path}'" \
  --gin.TRAIN_STEPS='{train_steps}' \
  --gin.SAVE_PERIOD='{save_period}'\
  --gin.EVAL_PERIOD='{eval_period}'\
  --gin.MIXTURE_OR_TASK_NAME="'{task}'" \
  --gin.LEARNING_RATE='{learning_rate}' \
  --gin.TASK_FEATURE_LENGTHS="{features_length}" \
  --gin.BATCH_SIZE='{batch_size}' \
  --task="{task}" \
  --metric="{metric}" \
  --train_file="{train_file}" \
  --predict_file="{test_file}" # or {dev_file}
