In [None]:
!rm -r ViT5
!git clone https://github.com/vietAI/ViT5.git

In [None]:
## Install JAX for GPU
!pip install jaxlib==0.4.2+cuda11.cudnn86 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
## Install T5X and dependencies
!cd ViT5 && python3 setup.py

In [None]:
# Size of the ViT5 model to work with. The value is interpolated into the
# checkpoint directory name (ViT5_<size>) and the finetune gin config path
# (<size>_finetune.gin) used by the cells below.
MODEL_SIZE: str = 'base'

In [None]:
# Download the pretrained ViT5 checkpoint directory for the selected
# MODEL_SIZE from the public GCS bucket into the current working directory
# (with MODEL_SIZE = "base" this fetches ViT5_base).
# Requires MODEL_SIZE to be defined by an earlier cell.
!gsutil -m cp -r gs://vietai_public/viT5/ViT5_{MODEL_SIZE} .

In [None]:
############################### Generation Task ###############################
######################### A Health-domain Summarization Dataset released by ######################
######################### ViHealthBERT (https://aclanthology.org/2022.lrec-1.35/) #################

!mkdir FAQ_summarization
!wget -O FAQ_summarization/dev.tsv https://raw.githubusercontent.com/vietai/ViPubmed/main/data/FAQ_summarization/dev.tsv
!wget -O FAQ_summarization/test.tsv https://raw.githubusercontent.com/vietai/ViPubmed/main/data/FAQ_summarization/test.tsv
!wget -O FAQ_summarization/train.tsv https://raw.githubusercontent.com/vietai/ViPubmed/main/data/FAQ_summarization/train.tsv

################################################################# 

task = 'FAQ_summarization'
train_file = f'{task}/train.tsv'
test_file = f'{task}/test.tsv'
dev_file = f'{task}/dev.tsv'

model_dir = f'out/{task}/vit5_base'
pretrained_path=f'ViT5_{MODEL_SIZE}/checkpoint_1000000'

gin_file = f'ViT5/configs/runs/{MODEL_SIZE}_finetune.gin'

metric = 'rouge'

# Train settings
batch_size = 64
features_length = {"inputs": 256, "targets": 64}
train_steps = 1000 + 1500000 # 1000 finetune steps + 1.5M pretraining step
save_period = 1000
eval_period = 1000

!python3 'ViT5/src/finetune_t5x.py' \
  --gin_file="{gin_file}" \
  --gin.MODEL_DIR="'{model_dir}'" \
  --gin.INITIAL_CHECKPOINT_PATH="'{pretrained_path}'" \
  --gin.TRAIN_STEPS='{train_steps}' \
  --gin.SAVE_PERIOD='{save_period}'\
  --gin.EVAL_PERIOD='{eval_period}'\
  --gin.MIXTURE_OR_TASK_NAME="'{task}'" \
  --gin.TASK_FEATURE_LENGTHS="{features_length}" \
  --gin.BATCH_SIZE='{batch_size}' \
  --task="{task}" \
  --metric="{metric}" \
  --train_file="{train_file}" \
  --predict_file="{test_file}" # or {dev_file}


I0206 18:15:33.707216 140281490196288 finetune_t5x.py:623] Epoch 1000 of 1501
I0206 18:15:33.707377 140281490196288 finetune_t5x.py:629] BEGIN Train loop.
I0206 18:15:33.707416 140281490196288 finetune_t5x.py:634] Training for 1000 steps.
I0206 18:15:33.708324 140223275026176 logging_writer.py:48] [1000000] collection=train timing/compilation_seconds=120.909
I0206 18:15:33.710309 140223275026176 logging_writer.py:48] [1000000] collection=train timing/train_iter_warmup=6.19888e-06
I0206 18:15:33.710495 140281490196288 trainer.py:500] Training: step 1000000
I0206 18:15:43.957198 140281490196288 trainer.py:500] Training: step 1000025
I0206 18:15:54.192222 140281490196288 trainer.py:500] Training: step 1000050
I0206 18:16:04.432064 140281490196288 trainer.py:500] Training: step 1000075
I0206 18:16:14.711035 140281490196288 trainer.py:500] Training: step 1000100
I0206 18:16:24.777673 140281490196288 trainer.py:500] Training: step 1000124
I0206 18:16:35.114796 140281490196288 trainer.py:500]

In [None]:
############################### Classification Task ##########################
############################### ViMedNLI #####################################
############## https://arxiv.org/abs/2210.05610 ##############################

!mkdir vi_mednli
!wget -O vi_mednli/dev.tsv https://raw.githubusercontent.com/vietai/ViPubmed/main/data/vi_mednli/dev_vi_refined.tsv 
!wget -O vi_mednli/test.tsv https://raw.githubusercontent.com/vietai/ViPubmed/main/data/vi_mednli/test_vi_refined.tsv 
!wget -O vi_mednli/train.tsv https://raw.githubusercontent.com/vietai/ViPubmed/main/data/vi_mednli/train_vi_refined.tsv 

############################### vimednli ###############################

MODEL_SIZE = "base"
task = 'vi_mednli'
train_file = f'{task}/train.tsv'
test_file = f'{task}/test.tsv'
dev_file = f'{task}/dev.tsv'

model_dir = f'out/{task}/vit5_base'
pretrained_path=f'ViT5_{MODEL_SIZE}/checkpoint_1000000'

gin_file = f'ViT5/configs/runs/{MODEL_SIZE}_finetune.gin'

metric = 'accuracy'

# Train settings
batch_size = 64
features_length = {"inputs": 128, "targets": 6}
train_steps = 4800 + 1000000 # 1000 finetune steps + 1.5M pretraining step
save_period = 4800
eval_period = 4800
learning_rate = 0.0005

!python3 'ViT5/src/finetune_t5x.py' \
  --gin_file="{gin_file}" \
  --gin.MODEL_DIR="'{model_dir}'" \
  --gin.INITIAL_CHECKPOINT_PATH="'{pretrained_path}'" \
  --gin.TRAIN_STEPS='{train_steps}' \
  --gin.SAVE_PERIOD='{save_period}'\
  --gin.EVAL_PERIOD='{eval_period}'\
  --gin.MIXTURE_OR_TASK_NAME="'{task}'" \
  --gin.LEARNING_RATE='{learning_rate}' \
  --gin.TASK_FEATURE_LENGTHS="{features_length}" \
  --gin.BATCH_SIZE='{batch_size}' \
  --task="{task}" \
  --metric="{metric}" \
  --train_file="{train_file}" \
  --predict_file="{test_file}" # or {dev_file}



In [None]:
# Download the wikilingua dataset directory from the public GCS bucket into
# the current working directory. The vietnews dataset sits next to it in the
# same bucket; uncomment the second command to fetch it as well.
!gsutil cp -r gs://vietai_public/viT5/data/wikilingua .
# !gsutil cp -r gs://vietai_public/viT5/data/vietnews .

In [None]:
############################### Generation Task ##########################
############################### wikilingua ###############################
MODEL_SIZE = "base"
task = 'wikilingua'
train_file = f'{task}/train.tsv'
test_file = f'{task}/test.tsv'
dev_file = f'{task}/dev.tsv'

model_dir = f'out/{task}/vit5_{MODEL_SIZE}'
pretrained_path=f'ViT5_{MODEL_SIZE}/checkpoint_1000000'

gin_file = f'ViT5/configs/runs/{MODEL_SIZE}_finetune.gin'
metric = 'rouge'

# Train settings
batch_size = 16
features_length = {"inputs": 1024, "targets": 256}
train_steps = 10000 + 1000000 # 1000 finetune steps + 1.5M pretraining step
save_period = 10000
eval_period = 10000

!python3 'ViT5/src/finetune_t5x.py' \
  --gin_file="{gin_file}" \
  --gin.MODEL_DIR="'{model_dir}'" \
  --gin.INITIAL_CHECKPOINT_PATH="'{pretrained_path}'" \
  --gin.TRAIN_STEPS='{train_steps}' \
  --gin.SAVE_PERIOD='{save_period}'\
  --gin.EVAL_PERIOD='{eval_period}'\
  --gin.MIXTURE_OR_TASK_NAME="'{task}'" \
  --gin.TASK_FEATURE_LENGTHS="{features_length}" \
  --gin.BATCH_SIZE='{batch_size}' \
  --task="{task}" \
  --metric="{metric}" \
  --train_file="{train_file}" \
  --predict_file="{test_file}" # or {dev_file}
