## Imports

In [None]:
import sys
import os
from pathlib import Path
import io
import socket
import argparse
import bert_squad_main
from central.habana_model_runner_utils import get_canonical_path, get_canonical_path_str, is_valid_multi_node_config, get_multi_node_config_nodes
from central.multi_node_utils import run_per_ip
import TensorFlow.nlp.bert.download.download_pretrained_model as download_pretrained_model
from TensorFlow.common.common import setup_jemalloc
from TensorFlow.nlp.bert.demo_bert import *

## Customize training options

In [None]:
# --- Run configuration (read by the launch cell below) ---

# Workflow to run: "finetuning" or "pretraining".
command = "finetuning"
# BERT size: "tiny", "mini", "small", "medium", "base" or "large".
model_variant = "large"
# Numeric precision: "fp32" or "bf16".
data_type = "bf16"
# Evaluation set; keep this as "squad" for now.
test_set = "squad"
# Location of the processed SQuAD dataset.
dataset_path = "/nfs/pvc-datasets-research/processed/tf/bert/squad/"

# Switch into the BERT model directory so the code picks up checkpoints there.
os.chdir(Path("/cnvrg/TensorFlow/nlp/bert"))

## Launch Training on Gaudi

In [None]:
# Build a CLI-style argument list from the options cell, parse it with the
# BERT demo parser, and run the selected workflow. Any Exception raised along
# the way is re-raised as a RuntimeError naming the requested command, with the
# original exception chained as __cause__.
try:
    arg_list = [
        "--command", command,
        "--model_variant", model_variant,
        "--data_type", data_type,
        "--test_set", test_set,
        "--dataset_path", dataset_path,
    ]
    args = BertArgparser().parse_args(arg_list)

    setup_jemalloc()
    check_data_type_and_tf_bf16_conversion(args)
    check_and_log_synapse_env_vars()
    model = get_model(args)
    if args.bert_config_dir is not None:
        # A user-supplied config directory overrides the model's default.
        model.pretrained_model = args.bert_config_dir

    # Fetch the pretrained model unless a custom checkpoint folder was given.
    # The True flag is used for Horovod runs outside Kubernetes; per the
    # original note this downloads the model on all remote IPs when the env
    # variable MULTI_HLS_IPS is set (implemented in demo_bert, not shown here).
    if args.checkpoint_folder is not None:
        print(f"Using custom model folder: {args.checkpoint_folder}")
    elif args.use_horovod is not None and not args.kubernetes_run:
        model.download_pretrained_model(True)
    else:
        model.download_pretrained_model(False)

    # Dispatch table from the --command value to the model method that runs it.
    the_map = {
        "pretraining": model.pretraining,
        "finetuning": model.finetuning,
    }
    the_map[args.command](args)
except Exception as exc:
    # Report the command that was actually requested rather than a fixed
    # string, so pretraining failures are not labelled as finetuning.
    raise RuntimeError(f"Error running {command}") from exc