In [None]:
%load_ext dotenv
%dotenv

import pickle
import time
import sys
import os

# Restrict this process to GPUs 0 and 1. This is set ahead of `import torch`
# below -- presumably so the restriction is in place before CUDA is initialised.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

import random
import copy
import logging

import numpy as np
from tqdm import tqdm
import argparse

import torch
import torch.nn as nn
from instruct_pipeline import InstructionTextGenerationPipeline
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

In [None]:
# Run configuration. A Namespace is used instead of a plain dict because the
# settings are written and read with attribute access (args.input_model, ...),
# which a dict does not support.
args = argparse.Namespace()
args.input_model = "dolly-v2-3b"
args.gpu_family = "a2000"
args.timestamp = int(time.time())

# One shared record layout for both the log file and the console.
logFormatter = logging.Formatter(
    "%(asctime)s " + "[%(threadName)-12.12s] " + "[%(levelname)-5.5s]  " + "%(message)s"
)

# The log file receives everything from DEBUG upwards; its name carries the run timestamp.
fileHandler = logging.FileHandler("dolly-v2-3b_{}.log".format(args.timestamp))
fileHandler.setFormatter(logFormatter)
fileHandler.setLevel(logging.DEBUG)

# The console (stdout) only shows INFO and above.
consoleHandler = logging.StreamHandler(sys.stdout)
consoleHandler.setLevel(logging.INFO)
consoleHandler.setFormatter(logFormatter)

logging.basicConfig(handlers=[fileHandler, consoleHandler], level=logging.DEBUG)

logging.debug(str(sys.path))
# PYTHONPATH may legitimately be unset; do not abort the notebook just to log it.
logging.debug(str(os.environ.get("PYTHONPATH", "<unset>")))

start_time = args.timestamp

In [None]:
# Reproducibility setup: switch off the cuDNN benchmark mode and request
# deterministic cuDNN behaviour, then seed every random number generator
# used in this notebook with 0.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)
# torch.use_deterministic_algorithms(True)

In [None]:
logging.info("setup model")

# The three known Dolly checkpoint names are prefixed with "databricks/";
# any other value is treated as a path, with env vars and "~" expanded.
if args.input_model in ("dolly-v2-3b", "dolly-v2-7b", "dolly-v2-13b"):
    model_path_or_name = "databricks/{}".format(args.input_model)
else:
    model_path_or_name = os.path.expanduser(os.path.expandvars(args.input_model))

# configure the batch_size: per-GPU-family value, falling back to 4 for unknown families
batch_size = {"a10": 6, "a100": 8, "a2000": 4}.get(args.gpu_family, 4)

tokenizer = AutoTokenizer.from_pretrained(model_path_or_name, padding_side="left")

# No device map is requested when loading the model.
device_map = None

model = AutoModelForCausalLM.from_pretrained(
    model_path_or_name,
    torchscript=True,
    torch_dtype=torch.float32,
    device_map=device_map,
)
model.eval()

# Only log the device placement when the loaded model actually exposes one.
if hasattr(model, "hf_device_map"):
    logging.info("device_map: {}".format(model.hf_device_map))

generate_text = InstructionTextGenerationPipeline(tokenizer=tokenizer, model=model)

In [None]:
# tune the model with the TVM auto-scheduler on GPU
# ref: https://tvm.apache.org/docs/how_to/tune_with_autoscheduler/tune_network_cuda.html

import numpy as np

import tvm
from tvm import relay, auto_scheduler, runtime

In [None]:
# Tuning configuration: which network is tuned, the input shape/layout/dtype,
# and the compilation target.
dtype = "float32"
layout = "NL"
network = "dolly-v2-3b"
batch_size, sequence_size = 2, 128
target = tvm.target.Target("cuda")

# Tuning records file, named after network, layout, batch size and target kind.
log_file = "%s-%s-B%d-%s.json" % (network, layout, batch_size, target.kind.name)

In [None]:
# prepare network
query = "Explain to me the difference between nuclear fission and fusion."

# Run the pipeline's preprocessing on the query to obtain the model inputs.
model_inputs = generate_text.preprocess(query)
input_names = ["input_ids", "attention_mask"]

# Replicate each input along the first dimension to reach the configured batch size.
if batch_size > 1:
    for input_name in input_names:
        single_sample = model_inputs[input_name]
        model_inputs[input_name] = single_sample.repeat(batch_size, 1)

# Example inputs, in input_names order.
dummy_inputs = [model_inputs[key] for key in input_names]

# Replace the configured sizes with the ones the prepared inputs really have.
batch_size, sequence_size = (
    model_inputs["input_ids"].shape[0],
    model_inputs["input_ids"].shape[1],
)

In [None]:
# trace model
# Disable gradient tracking on every model parameter.
for weight in model.parameters():
    weight.requires_grad = False

# The traced module is cached on disk so that tracing only has to happen once.
traced_file = "{}_traced.pt".format(network)
if not os.path.exists(traced_file):
    logging.info("Trace model...")
    scripted_model = torch.jit.trace(model, dummy_inputs)
    torch.jit.save(scripted_model, traced_file)
else:
    logging.info("Load traced model...")
    scripted_model = torch.jit.load(traced_file)

scripted_model.eval()
# Same treatment for the traced module's parameters.
for weight in scripted_model.parameters():
    weight.requires_grad = False

In [None]:
# Extract tasks from the network
logging.info("Extract tasks...")

# Every model input is declared with the same [batch_size, sequence_size] shape.
# NOTE(review): no per-input dtype is given, so presumably the inputs take
# default_dtype -- confirm against the frontend documentation.
shape_list = [(name, [batch_size, sequence_size]) for name in input_names]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list, default_dtype=dtype)

tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)

# One INFO banner per task; the compute DAG itself is only logged at DEBUG level.
for idx in range(len(tasks)):
    task = tasks[idx]
    logging.info(
        "========== Task %d  (workload key: %s) ==========" % (idx, task.workload_key)
    )
    logging.debug(str(task.compute_dag))

In [None]:
def run_tuning(num_measure_trials=200):
    """Tune the extracted tasks with the TVM auto-scheduler.

    Reads the notebook-level ``tasks``, ``task_weights`` and ``log_file``;
    measurement records are written to ``log_file`` through a
    ``RecordToFile`` callback.

    Args:
        num_measure_trials: Number of measurement trials passed to
            ``TuningOptions``. Defaults to 200; change this to 20000 to
            achieve the best performance.
    """
    logging.info("Begin tuning...")
    measure_ctx = auto_scheduler.LocalRPCMeasureContext(
        repeat=1, min_repeat_ms=300, timeout=10
    )

    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=num_measure_trials,
        runner=measure_ctx.runner,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    )

    tuner.tune(tune_option)

In [None]:
# Start the tuning session defined above; run_tuning records its measurements
# to log_file through the RecordToFile callback.
run_tuning()

In [None]:
# Compile with the history best
# NOTE: use vm because graph executor does not support dynamic shape
logging.info("Compile...")
with auto_scheduler.ApplyHistoryBest(log_file):
    with tvm.transform.PassContext(
        opt_level=3, config={"relay.backend.use_auto_scheduler": True}
    ):
        vmc = relay.vm.compile(mod, target=target, params=params)

# Create VM runtime
logging.info("Create runtime...")
dev = tvm.device(str(target), 0)
module = runtime.vm.VirtualMachine(vmc, dev)

# VirtualMachine.set_input expects the *function name* as its first argument;
# the actual inputs are passed by keyword, keyed by the names declared in
# input_names. The arrays are cast to `dtype`, as in the original cell.
module.set_input(
    "main",
    **{
        input_name: tvm.nd.array(model_inputs[input_name].numpy().astype(dtype))
        for input_name in input_names
    },
)

# Evaluate
logging.info("Evaluate inference time cost...")
logging.info(str(module.benchmark(dev, repeat=3, min_repeat_ms=500)))