In [1]:
# Fresh clone of the ViT5 repo (its configs/runs/*.gin are used by the wikilingua run below).
# `-f` keeps a first run (no existing ViT5/ directory) from printing an rm error.
!rm -rf ViT5
!git clone https://github.com/vietAI/ViT5.git

Cloning into 'ViT5'...
remote: Enumerating objects: 842, done.
remote: Counting objects: 100% (133/133), done.
remote: Compressing objects: 100% (79/79), done.
remote: Total 842 (delta 69), reused 107 (delta 48), pack-reused 709
Receiving objects: 100% (842/842), 53.23 MiB | 26.98 MiB/s, done.
Resolving deltas: 100% (488/488), done.


In [3]:
## Install JAX for GPU
# %pip (not !pip) so the package lands in the environment of the running kernel.
%pip install jaxlib==0.4.2+cuda11.cudnn86 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
## Install T5X and dependencies
# The fine-tuning cells below run ViPubmed/src/finetune_t5x.py, but nothing else
# in this notebook clones ViPubmed, so clone it here when it is missing.
!test -d ViPubmed || git clone https://github.com/vietai/ViPubmed.git
# NOTE(review): setup.py is invoked without a setuptools command, so it is
# presumably the repo's own installer script — TODO confirm.
!cd ViPubmed && python3 setup.py

Successfully installed argcomplete-2.0.0 boto-2.49.0 crcmod-1.7 cryptography-39.0.0 etils-1.0.0 fasteners-0.18 gcs-oauth2-boto-plugin-3.0 google-apitools-0.5.32 google-reauth-0.1.1 gsutil-5.8 httplib2-0.21.0 lxml-4.9.2 monotonic-1.6 nest-asyncio-1.5.6 oauth2client-4.1.3 pyOpenSSL-23.0.0 pyu2f-0.1.5 retry_decorator-1.1.1 rsa-4.7.2 tensorflow_datasets-4.8.1
[0m

In [4]:
MODEL_SIZE = 'base'  # selects the checkpoint directory and the *_finetune.gin config

In [None]:
# Download the pretrained ViT5 T5X checkpoint for MODEL_SIZE into ./ViT5_{MODEL_SIZE}
# (the wikilingua run below loads its initial checkpoint from there).
!gsutil -m cp -r gs://vietai_public/viT5/ViT5_{MODEL_SIZE} .

In [7]:
# Fetch a fine-tuning dataset from the public ViT5 bucket: wikilingua is used
# below; uncomment the vietnews line to pull that corpus as well.
!gsutil cp -r "gs://vietai_public/viT5/data/wikilingua" ./
# !gsutil cp -r gs://vietai_public/viT5/data/vietnews .


Copying gs://vietai_public/viT5/data/wikilingua/test.tsv...
Copying gs://vietai_public/viT5/data/wikilingua/train.tsv...                    
Copying gs://vietai_public/viT5/data/wikilingua/val.tsv...                      
| [3 files][ 63.5 MiB/ 63.5 MiB]    8.8 MiB/s                                   
Operation completed over 3 objects/63.5 MiB.                                     


In [None]:
############################### Generation Task ##########################
############################### wikilingua ###############################
# Fine-tune ViT5 on WikiLingua (scored with ROUGE), then predict on the test split.
MODEL_SIZE = "base"
task = 'wikilingua'
train_file = f'{task}/train.tsv'
test_file = f'{task}/test.tsv'
# The wikilingua download contains train/val/test.tsv -- there is no dev.tsv.
dev_file = f'{task}/val.tsv'

model_dir = f'out/{task}/vit5_{MODEL_SIZE}'
# Step of the pretrained checkpoint we start from (it is encoded in the directory name).
pretrained_steps = 1000000
pretrained_path = f'ViT5_{MODEL_SIZE}/checkpoint_{pretrained_steps}'

gin_file = f'ViT5/configs/runs/{MODEL_SIZE}_finetune.gin'
metric = 'rouge'

# Train settings
batch_size = 32
features_length = {"inputs": 1024, "targets": 256}
finetune_steps = 10000
# TRAIN_STEPS is an absolute step count: training resumes at the checkpoint's
# step, so the target is checkpoint step + fine-tuning steps.
train_steps = pretrained_steps + finetune_steps
save_period = train_steps
eval_period = train_steps

!python3 'ViPubmed/src/finetune_t5x.py' \
  --gin_file="{gin_file}" \
  --gin.MODEL_DIR="'{model_dir}'" \
  --gin.INITIAL_CHECKPOINT_PATH="'{pretrained_path}'" \
  --gin.TRAIN_STEPS='{train_steps}' \
  --gin.SAVE_PERIOD='{save_period}' \
  --gin.EVAL_PERIOD='{eval_period}' \
  --gin.MIXTURE_OR_TASK_NAME="'{task}'" \
  --gin.TASK_FEATURE_LENGTHS="{features_length}" \
  --gin.BATCH_SIZE='{batch_size}' \
  --task="{task}" \
  --metric="{metric}" \
  --train_file="{train_file}" \
  --predict_file="{test_file}" # or {dev_file}


In [None]:
############################### Classification Task ##########################
############################### ViMedNLI #####################################
############## https://arxiv.org/abs/2210.05610 ##############################

# `-p` makes re-running this cell a no-op instead of a "File exists" error.
!mkdir -p vi_mednli
# Refined ViMedNLI splits from the ViPubmed repo, saved as vi_mednli/<split>.tsv.
!wget -O vi_mednli/dev.tsv https://raw.githubusercontent.com/vietai/ViPubmed/main/data/vi_mednli/dev_vi_refined.tsv
!wget -O vi_mednli/test.tsv https://raw.githubusercontent.com/vietai/ViPubmed/main/data/vi_mednli/test_vi_refined.tsv
!wget -O vi_mednli/train.tsv https://raw.githubusercontent.com/vietai/ViPubmed/main/data/vi_mednli/train_vi_refined.tsv

############################### vimednli ###############################
# Fine-tune on ViMedNLI (scored by accuracy) from a ViPubmed checkpoint, then
# predict on the test split.

MODEL_SIZE = "base"
task = 'vi_mednli'
train_file = f'{task}/train.tsv'
test_file = f'{task}/test.tsv'
dev_file = f'{task}/dev.tsv'

# Follows MODEL_SIZE (like the wikilingua run) so other sizes get their own output dir.
model_dir = f'out/{task}/vit5_{MODEL_SIZE}'
# Step of the pretrained checkpoint we start from (it is encoded in the directory name).
pretrained_steps = 1500000
# NOTE(review): nothing in this notebook downloads vip_{MODEL_SIZE}/ -- fetch the
# ViPubmed checkpoint into that directory before running. TODO add that step.
pretrained_path = f'vip_{MODEL_SIZE}/checkpoint_{pretrained_steps}'

gin_file = f'ViPubmed/configs/runs/{MODEL_SIZE}_finetune.gin'

metric = 'accuracy'

# Train settings
batch_size = 64
features_length = {"inputs": 128, "targets": 6}
finetune_steps = 4800
# TRAIN_STEPS is absolute: checkpoint step + fine-tuning steps.
train_steps = pretrained_steps + finetune_steps
save_period = train_steps
eval_period = finetune_steps
learning_rate = 0.0005

!python3 'ViPubmed/src/finetune_t5x.py' \
  --gin_file="{gin_file}" \
  --gin.MODEL_DIR="'{model_dir}'" \
  --gin.INITIAL_CHECKPOINT_PATH="'{pretrained_path}'" \
  --gin.TRAIN_STEPS='{train_steps}' \
  --gin.SAVE_PERIOD='{save_period}' \
  --gin.EVAL_PERIOD='{eval_period}' \
  --gin.MIXTURE_OR_TASK_NAME="'{task}'" \
  --gin.LEARNING_RATE='{learning_rate}' \
  --gin.TASK_FEATURE_LENGTHS="{features_length}" \
  --gin.BATCH_SIZE='{batch_size}' \
  --task="{task}" \
  --metric="{metric}" \
  --train_file="{train_file}" \
  --predict_file="{test_file}" # or {dev_file}

