This tutorial shows how to use a Large Language Model (LLM) with [Transformers.jl](https://github.com/chengchingwen/Transformers.jl).

In [1]:
using Transformers, CUDA

After loading the package, we need to set up the GPU. Multi-GPU is currently not supported. If your machine has multiple GPU devices, you can use `CUDA.devices()` to list them and `CUDA.device!(device_number)` to select the device the model will run on.

In [2]:
CUDA.devices()

CUDA.DeviceIterator() for 8 devices:
0. NVIDIA A100 80GB PCIe
1. NVIDIA A100 80GB PCIe
2. NVIDIA A100-PCIE-40GB
3. Tesla V100-PCIE-32GB
4. Tesla V100-PCIE-32GB
5. Tesla V100S-PCIE-32GB
6. Tesla V100-PCIE-32GB
7. Tesla V100-PCIE-32GB

In [3]:
CUDA.device!(1)

CuDevice(1): NVIDIA A100 80GB PCIe

For demonstration, we disable scalar indexing on the GPU, so that any operation falling back to slow element-by-element access raises an error instead of silently hurting performance. Calling `enable_gpu(true)` gives us the `todevice` function provided by Transformers.jl, which moves data and models to the GPU.

In [4]:
CUDA.allowscalar(false)
enable_gpu(true)

todevice (generic function with 1 method)

In this tutorial, we show how to use [llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) in Julia.

The process should also work for other causal-LM-based models. With Transformers.jl, we can get the tokenizer and model by using the `hgf""` macro or `HuggingFace.load_tokenizer`/`HuggingFace.load_model`. Required files such as the model weights are downloaded and managed automatically.

You need a Hugging Face account that has access to Llama 2. Once you have one, copy your access token and pass it to Transformers.jl:

```julia
access_token = ""

# This saves the access token to disk; every later call that
# downloads files from the Hugging Face Hub will then use this token.
using HuggingFaceApi
HuggingFaceApi.save_token(access_token)

# Alternatively, pass the token to the `load` functions through the
# `auth_token` keyword argument, like this:
HuggingFace.load_tokenizer("meta-llama/Llama-2-7b-chat-hf"; auth_token = access_token)
```

In [5]:
using Transformers.HuggingFace

textenc = hgf"meta-llama/Llama-2-7b-chat-hf:tokenizer"
model = todevice(hgf"meta-llama/Llama-2-7b-chat-hf:ForCausalLM") # move to gpu with `todevice` (or `Flux.gpu`)

└ @ Transformers.HuggingFace ~/Transformers.jl/src/huggingface/tokenizer/utils.jl:96
└ @ Transformers.TextEncoders ~/Transformers.jl/src/textencoders/TextEncoders.jl:76

HGFLlamaForCausalLM(
  HGFLlamaModel(
    CompositeEmbedding(
      token = Embed(4096, 32000),       # 131_072_000 parameters
    ),
    Chain(
      Transformer<32>(
        PreNormTransformerBlock(
          SelfAttention(
            CausalGPTNeoXRoPEMultiheadQKVAttenOp(base = 10000.0, dim = 128, head = 32, p = nothing),
            Fork<3>(Dense(W = (4096, 4096), b = false)),  # 50_331_648 parameters
            Dense(W = (4096, 4096), b = false),  # 16_777_216 parameters
          ),
          RMSLayerNorm(4096, ϵ = 1.0e-6),  # 4_096 parameters
          Chain(
            LLamaGated(Dense(σ = NNlib.swish, W = (4096, 11008), b = false), Dense(W = (4096, 11008), b = false)),  # 90_177_536 parameters
            Dense(W = (11008, 4096), b = false),  # 45_088_768 parameters
          ),
          RMSLayerNorm(4096, ϵ = 1.0e-6),  # 4_096 parameters
        ),
      ),                  # Total: 288 arrays,

We define some helper functions for text generation. Here we use top-k sampling with a temperature; with the default `k = 1` it reduces to simple greedy decoding (`argmax`). It can be replaced with other decoding algorithms such as beam search. The `k` in `top_k_sample` sets the number of candidate tokens at each generation step.

In [6]:
using Flux
using StatsBase

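# Softmax over the logits after dividing by the temperature:
# values above 1 flatten the distribution, values below 1 sharpen it.
function temp_softmax(logits; temperature = 1.2)
    return softmax(logits ./ temperature)
end

# Sample one index among the `k` most probable entries, weighted by their
# probabilities. Returns a 1-element vector; with `k = 1` this is `argmax`.
function top_k_sample(probs; k = 1)
    indexes = partialsortperm(probs, 1:k, rev = true)
    index = sample(indexes, ProbabilityWeights(probs[indexes]), 1)
    return index
end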

top_k_sample (generic function with 1 method)
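
To make the effect of `temperature` and `k` concrete, here is a minimal sketch that runs on the CPU without Flux or StatsBase. The names `toy_temp_softmax` and `toy_top_k_sample` and the logits are hypothetical and only mirror the helpers above; they are not part of Transformers.jl.

```julia
using Random

# Hypothetical stand-in for `temp_softmax`: numerically stable softmax with a temperature.
function toy_temp_softmax(logits::AbstractVector{<:Real}; temperature::Real = 1.0)
    scaled = logits ./ temperature
    exps = exp.(scaled .- maximum(scaled))   # subtract the max to avoid overflow
    return exps ./ sum(exps)
end

# Hypothetical stand-in for `top_k_sample`: draw one index among the k most
# probable entries, weighted by their probabilities.
function toy_top_k_sample(rng::AbstractRNG, probs::AbstractVector{<:Real}; k::Integer = 1)
    top = partialsortperm(probs, 1:k, rev = true)   # indices of the k largest probabilities
    weights = probs[top]
    r = rand(rng) * sum(weights)                    # inverse-CDF sampling over the top-k mass
    acc = zero(r)
    for (idx, w) in zip(top, weights)
        acc += w
        r <= acc && return idx
    end
    return last(top)                                # guard against floating-point round-off
end

logits = [2.0, 1.0, 0.5, -1.0]                      # made-up logits for a 4-token vocabulary
sharp = toy_temp_softmax(logits; temperature = 0.5) # low temperature: peaked distribution
flat = toy_temp_softmax(logits; temperature = 2.0)  # high temperature: flatter distribution

rng = MersenneTwister(0)
greedy = toy_top_k_sample(rng, sharp; k = 1)        # k = 1 always picks the argmax
sampled = [toy_top_k_sample(rng, flat; k = 2) for _ in 1:100]  # only the top 2 can appear
```

With `k = 1` the temperature has no effect on the chosen token, since dividing the logits by a positive constant does not change their order; it only matters once `k > 1`.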

The main generation loop is defined as follows:

1. The prompt is first preprocessed and encoded with the tokenizer `textenc`. The `encode` function returns a `NamedTuple` where `.token` is the one-hot representation of our context tokens.
2. At each iteration, we copy the tokens to the GPU and feed them to the model. The model also returns a `NamedTuple`, where `.logit` holds the model's predictions. We then apply the decoding scheme above to the logits at the last position to pick the next token, which is appended to the context tokens. The loop stops when we reach the maximum generation length or the predicted token is an end token.
3. After the loop, we decode the one-hot encoding back to text tokens. The `decode` function converts the one-hots to text and also performs some post-processing to get the final list of strings.

In [7]:
using Transformers.TextEncoders

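function generate_text(textenc, model, context = ""; max_length = 512, k = 1, temperature = 1.2, ends = textenc.endsym)
    # One-hot encoding of the prompt; `ids` is the underlying vector of token ids.
    encoded = encode(textenc, context).token
    ids = encoded.onehots
    ends_id = lookup(textenc.vocab, ends)
    for _ in 1:max_length
        # Move the current context to the GPU and run the model on it.
        input = (; token = encoded) |> todevice
        outputs = model(input)
        # Logits at the last position of the first (and only) sequence in the batch.
        logits = @view outputs.logit[:, end, 1]
        probs = temp_softmax(logits; temperature)
        # Sample on the CPU; `collect` copies the probabilities back from the GPU.
        new_id = top_k_sample(collect(probs); k)[1]
        # Pushing to `ids` also extends `encoded`, since they share the same data.
        push!(ids, new_id)
        new_id == ends_id && break
    end
    return encoded
end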

generate_text (generic function with 2 methods)
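
To see the control flow of this loop in isolation, the sketch below replaces the tokenizer and the model with hypothetical toy stand-ins (`TOY_VOCAB`, `toy_next_logits`, `toy_generate`) and runs on the CPU with plain integer ids instead of one-hot arrays. It illustrates only the append-and-stop logic, not how Transformers.jl represents tokens.

```julia
# Hypothetical toy vocabulary; index 1 plays the role of the end token.
TOY_VOCAB = ["</s>", "julia", "is", "fast"]

# Hypothetical stand-in for the model: one logit per vocabulary entry, always
# favouring the token after the last one, and the end token after "fast".
function toy_next_logits(ids::Vector{Int})
    next_id = ids[end] == length(TOY_VOCAB) ? 1 : ids[end] + 1
    logits = fill(-1.0, length(TOY_VOCAB))
    logits[next_id] = 1.0
    return logits
end

function toy_generate(prompt_ids::Vector{Int}; max_length::Int = 16, end_id::Int = 1)
    ids = copy(prompt_ids)                 # context tokens, extended at every step
    for _ in 1:max_length
        logits = toy_next_logits(ids)      # "model" call on the whole context
        new_id = argmax(logits)            # greedy choice (top-k with k = 1)
        push!(ids, new_id)
        new_id == end_id && break          # stop at the end token
    end
    return ids
end

ids = toy_generate([2])                    # start from "julia"
text = join(TOY_VOCAB[ids], " ")
truncated = toy_generate([2]; max_length = 2)  # stops early because of the length limit
```

Note that the real loop feeds the entire context to the model again at every step, so the cost of each step grows with the length of the generated text.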

We follow the prompt format from [Hugging Face's Llama 2 blog post](https://huggingface.co/blog/llama2).

In [8]:
function generate(textenc, model, instruction; max_length = 512, k = 1, temperature = 1.2)
    prompt = """
    [INST] <<SYS>>
    You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

    If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
    <</SYS>>

    $instruction [/INST]
    
    """
    text_token = generate_text(textenc, model, prompt; max_length, k, temperature)
    gen_text = decode_text(textenc, text_token)
    println(gen_text)
end

generate (generic function with 1 method)
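
The prompt template can also be assembled by a small helper, which makes it easy to swap the system prompt. This is a minimal sketch: `build_llama2_prompt` and its default system prompt are hypothetical and not part of Transformers.jl; the markers and line breaks follow the prompt used in `generate` above.

```julia
# Hypothetical helper: wrap a system prompt and a user instruction in the
# Llama 2 chat markers, with the same line breaks as the prompt in `generate`.
function build_llama2_prompt(instruction::AbstractString;
                             system::AbstractString = "You are a helpful assistant.")
    return "[INST] <<SYS>>\n$(system)\n<</SYS>>\n\n$(instruction) [/INST]\n\n"
end

prompt = build_llama2_prompt("What is multiple dispatch?")
print(prompt)
```

The resulting string could then be passed to `generate_text` in place of the hard-coded prompt.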

In [9]:
generate(textenc, model, "Can you explain to me briefly what is the Julia programming language?")

<s>[INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>

Can you explain to me briefly what is the Julia programming language? [/INST]

Of course! Julia is a high-level, high-performance programming language for technical computing. It was created in 2009 by Jeff Bezanson, Alan Edelman, Stefan Karpinski, and Viral Shah. Julia is primarily designed for numerical and scientific computing, and it aims to provide a more efficient and expressive alternative to languages like Python, R, and Matlab.

Some of the key featu