# Disclaimer
This material was prepared as an account of work sponsored by an agency of the United States Government.  Neither the United States Government nor the United States Department of Energy, nor Battelle, nor any of their employees, nor any jurisdiction or organization that has cooperated in the development of these materials, makes any warranty, express or implied, or assumes any legal liability or responsibility for the accuracy, completeness, or usefulness of any information, apparatus, product, software, or process disclosed, or represents that its use would not infringe privately owned rights. Reference herein to any specific commercial product, process, or service by trade name, trademark, manufacturer, or otherwise does not necessarily constitute or imply its endorsement, recommendation, or favoring by the United States Government or any agency thereof, or Battelle Memorial Institute. The views and opinions of authors expressed herein do not necessarily state or reflect those of the United States Government or any agency thereof.

PACIFIC NORTHWEST NATIONAL LABORATORY operated by BATTELLE for the UNITED STATES DEPARTMENT OF ENERGY under Contract DE-AC05-76RL01830.

# Fine-Tune NukeLM on Azure Databricks

Wraps `nukelm.run_fine_tune` to install dependencies, provide drop-down menus for parameters, and copy results to Azure storage. It uses all available GPUs and integrates with MLflow.

Please use Databricks Runtime 7.3 LTS ML (preferably with GPU support).

| Parameter | Description |
| --- | --- |
| Model Name | Short-hand name for a pre-trained language model starting point. |
| Label Type | Whether to use fine-grained (multi-class) or coarse-grained (binary) labels. |
| RNG Seed | Integer with which to seed a random number generator. |

In [None]:
# Note: this approach is for running on Databricks and assumes installation of nukelm via the wheel file.
# NOTE(review): the wheel path is relative, so pip resolves it against the notebook's current
# working directory — confirm the wheel is staged there before running.
%pip install nukelm-1.0.0-py3-none-any.whl

In [None]:
import json
import logging
import shutil
import sys
from pathlib import Path

import numpy as np
import torch

from nukelm import run_fine_tune


# Disable the noisy py4j logger so notebook output stays readable.
logging.getLogger("py4j").setLevel(logging.WARNING)

# Locations on Azure storage.
# NOTE(review): ``blob_path`` is a placeholder. ``Path("")`` is the current directory, so the
# existence check below passes trivially until a real storage path is filled in.
blob_path = Path("")
if not blob_path.exists():
    raise FileNotFoundError(f"Storage path '{blob_path}' does not exist")

working_path = blob_path
dbfs_working_path = blob_path  # NOTE(review): appears unused in this notebook

src_path = working_path / "src"
# sys.path holds strings: compare the string form. Comparing the Path object never matches,
# which would append a duplicate entry every time this cell is re-run.
if str(src_path) not in sys.path:
    sys.path.append(str(src_path))

data_path = working_path / "osti" / "finetune"
# Presumably the number of training examples (rows of train.csv) — it is used to estimate the
# number of optimizer steps when sizing warmup and logging intervals. Keep in sync with the data.
DATASET_SIZE = 188654

In [None]:
# Fine-tuning starting points as (widget key, model name or path, tokenizer name).
# "-ots" keys name Hugging Face hub checkpoints; "-trained" keys point at checkpoints stored in
# MLflow artifacts or on storage (presumably further-pretrained models — confirm provenance).
_MODEL_SPECS = [
    ("roberta_base-ots", "roberta-base", "roberta-base"),
    (
        "roberta_base-trained",
        "/databricks/mlflow/1636808823342661/b61deb3377ac412291956719bf0ca952/artifacts/model",
        "roberta-base",
    ),
    ("roberta_large-ots", "roberta-large", "roberta-large"),
    ("roberta_large-trained", str(working_path / "models" / "MLM"), "roberta-large"),
    ("scibert-ots", "allenai/scibert_scivocab_uncased", "allenai/scibert_scivocab_uncased"),
    (
        "scibert-trained",
        "/databricks/mlflow/1636808823342661/fcbe913b86cd4e5c807efad7cf16ef74/artifacts/model",
        "allenai/scibert_scivocab_uncased",
    ),
]
MODELS = {
    key: {"Model Name or Path": model_path, "Tokenizer Name": tokenizer}
    for key, model_path, tokenizer in _MODEL_SPECS
}

# Parameter widgets; ``dbutils`` is injected by the Databricks runtime (hence the lint suppressions).
model_names = list(MODELS)
dbutils.widgets.dropdown("Model Name", model_names[0], model_names)  # type: ignore # NOQA: F821
dbutils.widgets.dropdown("Label Type", "Coarse", ["Fine", "Coarse"])  # type: ignore # NOQA: F821
dbutils.widgets.text("RNG Seed", "42")  # type: ignore # NOQA: F821

In [None]:
# Read the widget selections (``dbutils`` is provided by the Databricks runtime).
model_name = dbutils.widgets.get("Model Name")  # type: ignore # NOQA: F821
label_type = dbutils.widgets.get("Label Type")  # type: ignore # NOQA: F821
rng_seed = int(dbutils.widgets.get("RNG Seed"))  # type: ignore # NOQA: F821

n_devices = torch.cuda.device_count()

# Local scratch directory for this run; its name encodes model, label granularity and seed.
run_dir = Path(f"/tmp/run_{model_name.replace('_', '-')}_{label_type.lower()}-grained-labels_{rng_seed}")
try:
    run_dir.mkdir()
except FileExistsError:
    # Not fatal: the existing directory is reused and its training_config.json is overwritten.
    print("We've run with that seed before -- try a new one?")

params = {
    "model_name_or_path": MODELS[model_name]["Model Name or Path"],
    "tokenizer_name": MODELS[model_name]["Tokenizer Name"],
    "train_file": str(data_path / f"{label_type.lower()}-grained-labels" / "train.csv"),
    "validation_file": str(data_path / f"{label_type.lower()}-grained-labels" / "val.csv"),
    "num_train_epochs": 3,
    "learning_rate": 1e-5,
    "per_device_train_batch_size": 4,
    "gradient_accumulation_steps": 4,
    "per_device_eval_batch_size": 8,
    "output_dir": str(run_dir / "output"),
    "logging_dir": str(run_dir / "logs"),
    "seed": rng_seed,
    "do_train": True,
    "do_eval": True,
    "evaluation_strategy": "steps",
    "load_best_model_at_end": True,
    # Mixed precision needs a CUDA device; disable it on CPU-only clusters instead of failing.
    "fp16": n_devices > 0,
    "n_gpu": n_devices,
}

# Estimate the total number of optimizer steps to size warmup and logging/eval/save intervals.
# device_count() is 0 without a GPU; count that as one device so the estimate never divides by zero.
effective_batch_size = (
    params["per_device_train_batch_size"] * params["gradient_accumulation_steps"] * max(n_devices, 1)
)
n_steps = params["num_train_epochs"] * int(np.ceil(DATASET_SIZE / effective_batch_size))

# Warm up over the first 6% of the estimated steps.
params["warmup_steps"] = int(0.06 * n_steps)

# Log, checkpoint and evaluate about 20 times per run; never use an interval of 0 steps.
logging_steps = max(n_steps // 20, 1)
params["logging_steps"] = logging_steps
params["save_steps"] = logging_steps
params["eval_steps"] = logging_steps

with open(run_dir / "training_config.json", "w") as fh:
    json.dump(params, fh)

# Display the final configuration as the cell output.
params

In [None]:
# Launch fine-tuning from the JSON configuration saved in the run directory.
run_fine_tune(config_path=str(run_dir.joinpath("training_config.json")))

In [None]:
# Persistent destination on storage, named like the local run directory.
# exist_ok is left False: an existing result folder for the same model/labels/seed raises
# FileExistsError rather than being silently overwritten.
_save_run_name = f"run_{model_name.replace('_', '-')}_{label_type.lower()}-grained-labels_{rng_seed}"
save_dir = working_path.joinpath("finetune_output", _save_run_name)
save_dir.mkdir(parents=True)

In [None]:
# Artifacts to copy from the Trainer output directory. RoBERTa tokenizers are saved as
# vocab.json + merges.txt, while BERT-style (SciBERT) tokenizers use vocab.txt; files a given
# model does not produce are reported and skipped below.
files_to_save = [
    "config.json",
    "merges.txt",
    "pytorch_model.bin",
    "special_tokens_map.json",
    "tokenizer_config.json",
    "training_args.bin",
    "vocab.json",
    "vocab.txt",
]

# The training configuration was written to ``run_dir / "training_config.json"``. The
# ``config.json`` in files_to_save lives under ``output/``; it is not in run_dir itself.
shutil.copy2(run_dir / "training_config.json", save_dir / "training-config.json")

for fname in files_to_save:
    filepath = run_dir / "output" / fname
    if not filepath.exists():
        print(f"File {filepath} does not exist")
        continue
    shutil.copy2(filepath, save_dir)