Breaking the Ceiling of the LLM Community by Treating Token Generation as a Classification for Ensembling [EMNLP 2024]
This work (GaC) allows multiple heterogeneous LLMs to form an ensemble at each generation step and collectively decide the next token. We take the union of the vocabularies of all LLMs participating in the ensemble and, at each generation step, map the probability vector produced by each LLM onto that union vocabulary. We then compute the (weighted) average of these vectors to determine the next token. Experiments show that this simple method can break the ceiling of the open-source LLM community, allowing the ensemble to outperform any single state-of-the-art LLM. [Paper]
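The core step can be summarized in a few lines of illustrative Python. The function below is a hypothetical sketch of the idea, not the repo's actual API: each model's per-token probabilities are keyed by token string (i.e., already mapped onto the union vocabulary), averaged, and the argmax becomes the next token.

```python
# Illustrative sketch of one GaC step (hypothetical helper, not the repo's API).
# probs_per_model[i] is a dict {token_string: probability} for model i, i.e.
# model i's probability vector keyed by the union vocabulary.
def ensemble_next_token(probs_per_model, weights=None):
    n = len(probs_per_model)
    weights = weights or [1.0 / n] * n          # default: plain average
    union_probs = {}                            # accumulates over the union vocab
    for probs, w in zip(probs_per_model, weights):
        for token, p in probs.items():
            union_probs[token] = union_probs.get(token, 0.0) + w * p
    # The token with the highest (weighted) average probability is emitted.
    return max(union_probs, key=union_probs.get)
```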
We provide two types of ensemble methods:
- Every Step Ensemble: All LLMs ensemble at each generation step.
- Thresholded Ensemble: A primary (gate) LLM is specified, and the ensemble is performed at a generation step only if the primary LLM's top-token confidence at that step is below a threshold, thereby saving computational resources (see the sketch below).
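For the thresholded variant, the per-step decision looks roughly like the following. This is again a hypothetical sketch, reusing `ensemble_next_token` from the sketch above; `threshold` corresponds to the `THRESHOLD_API_SERVER` setting described later.

```python
# Sketch of one thresholded-ensemble step (hypothetical, not the repo's API).
def thresholded_step(primary_probs, probs_per_model, threshold):
    # primary_probs: {token: prob} from the primary (gate) model.
    top_token, top_prob = max(primary_probs.items(), key=lambda kv: kv[1])
    if top_prob >= threshold:
        # The gate model is confident enough: keep its token and skip the
        # ensemble for this step.
        return top_token
    # Otherwise fall back to the full GaC ensemble for this step.
    return ensemble_next_token(probs_per_model)
```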
We support running the LLMs involved in the ensemble in parallel, all managed by Ray, to save time. If each LLM is allocated to a different GPU, the latency of the ensemble is almost the same as using a single LLM.
Id | Models | MMLU | GSM8K | BBH | TriviaQA | NQ | Avg. | Date | Latency |
---|---|---|---|---|---|---|---|---|---|
1 | Yi-34B-Chat | 72.75 | 68.76 | 50.88 | 70.01 | 29.81 | 58.44 | 2023/11/08 | 67.96 ms/token |
2 | Mixtral-8x7B-Instruct-v0.1 | 70.89 | 66.82 | 49.84 | 76.54 | 34.35 | 59.69 | 2023/12/11 | 96.64 ms/token |
3 | Qwen1.5-72B-Chat | 77.79 | 83.33 | 48.94 | 65.69 | 27.02 | 60.55 | 2024/02/04 | 102.11 ms/token |
4 | Llama-3-70B-Instruct | 79.68 | 90.00 | 57.13 | 79.12 | 35.57 | 68.30 | 2024/04/18 | 150.32 ms/token |
5 | Qwen2-72B-Instruct | 82.30 | 89.70 | 62.57 | 73.58 | 33.11 | 68.25 | 2024/06/07 | 113.91 ms/token |
6 | GaC(Yi + Mixtral) | 74.83 | 71.21 | 52.64 | 75.60 | 33.52 | 61.56 ↑3.13% | ~2023/12/11 | 98.13 ms/token |
7 | GaC(Qwen1.5-72B + Yi) | 79.83 | 77.27 | 52.05 | 70.88 | 33.80 | 62.77 ↑3.65% | ~2024/02/04 | 103.69 ms/token |
8 | GaC(Qwen1.5-72B + Mixtral) | 79.55 | 75.76 | 54.19 | 75.71 | 31.09 | 63.26 ↑4.47% | ~2024/02/04 | 112.83 ms/token |
9 | GaC(Llama-3 + Qwen1.5-72B) | 81.49 | 87.06 | 56.73 | 78.60 | 36.01 | 67.98 ↓0.47% | ~2024/04/18 | 153.96 ms/token |
10 | GaC(Qwen2-72B + Llama-3) | 83.54 | 90.91 | 63.99 | 79.29 | 37.65 | 71.08 ↑4.06% | ~2024/06/07 | 151.56 ms/token |
Note: Ensembles of the SOTA open-source LLMs available at different points in time. The top part lists the individual models, while the bottom part shows the ensemble results (model names abbreviated). ↑/↓ indicates the percentage change relative to the best individual model available as of the ensemble's date.
- Operating System: Ubuntu 20.04. We have not tested on Windows.
- Python Version: 3.11
- GPU: Ensure that your GPU has enough RAM to load all the models you want to ensemble.
- Environment Management Tool: Anaconda or any other suitable tool
- Open your terminal.
- Create and activate a new conda environment:
```bash
conda create -n gac_env python=3.11
conda activate gac_env
```
- Install the dependencies:
```bash
cd [root-of-this-repo]/GaC
pip install -r requirements.txt
```
We have integrated our work into an API server, which can be configured with a YAML file at startup to determine which LLMs to use for ensembling. An example is shown below:
```yaml
NORM_TYPE_API_SERVER: 'average' # 'average' or 'score'
THRESHOLD_API_SERVER: 1.0
CONFIG_API_SERVER:
  - weight: '[Please replace with the path to the local model weight]' # or 'upstage/SOLAR-10.7B-Instruct-v1.0'
    max_memory:
      0: '24GiB'
    num_gpus: 0.5
    name: 'SOLAR-10.7B-Instruct-v1.0'
    score: 100
    priority: 'supportive' # 'primary' or 'supportive'
    quantization: 'none' # 'none'/'8bit'/'4bit'
  - weight: '[Please replace with the path to the local model weight]' # or 'openchat/openchat-3.5-0106'
    max_memory:
      0: '24GiB'
    num_gpus: 0.5
    name: 'openchat-3.5-0106'
    score: 100
    priority: 'supportive' # 'primary' or 'supportive'
    quantization: 'none' # 'none'/'8bit'/'4bit'
```
Note: Please ensure that the number of GPUs on your machine is greater than or equal to the sum of all `num_gpus` values, and that the `max_memory` index for each model always starts from `0` (you can assume each model runs on an independent machine managed by Ray).
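Before starting the server, a quick sanity check like the one below can catch an over-subscribed GPU budget. This snippet is not part of the repo, and the config file name `my_config.yaml` is an assumption.

```python
# Hedged sanity check (not part of the repo): load the YAML config and verify
# that the total requested `num_gpus` does not exceed the GPUs on this machine.
import yaml
import torch

with open("my_config.yaml") as f:  # assumed file name
    cfg = yaml.safe_load(f)

requested = sum(model["num_gpus"] for model in cfg["CONFIG_API_SERVER"])
available = torch.cuda.device_count()
print(f"Requested {requested} GPU(s); {available} visible on this machine.")
assert requested <= available, "Not enough GPUs for this ensemble config."
```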
Explanation of Parameters
- CONFIG_API_SERVER: List of models to be used in the ensemble. Each model configuration includes:
  - weight: Local path to the model weights. You can also use the Hugging Face model card name to download the model automatically.
  - max_memory: Controls how much memory each GPU uses. Since each model is managed independently by Ray, the GPU IDs always start from 0. For example, if you set `num_gpus` to 2, you should allocate the maximum memory for each GPU, such as `{0: 'xxGiB', 1: 'xxGiB'}`.
  - num_gpus: Number of GPUs allocated to this model, controlled by Ray. To load two models on one GPU, set `num_gpus` to 0.5 for both models, so a total of 0.5 + 0.5 = 1 GPU will be used in this case.
  - priority: If all models are 'supportive', the ensemble will be performed at every generation step. For threshold-based ensembling, set the gate model's priority to 'primary'.
  - quantization: Whether to load the model with `4bit` or `8bit` quantization. Setting `'none'` will use the default data type specified in the Hugging Face model config.
- NORM_TYPE_API_SERVER: Ensemble weight type, `'average'` or `'score'`. With `'score'`, each model's output vector in the GaC ensemble is weighted by its score divided by the total score (see the sketch below).
- THRESHOLD_API_SERVER: Threshold for the thresholded ensemble. This parameter is ineffective if all models are 'supportive'.
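For example, with `NORM_TYPE_API_SERVER: 'score'` and two models scored 100 and 50, the first model's vector is weighted by 100/150 ≈ 0.67 and the second by 50/150 ≈ 0.33. A minimal sketch of that weighting (hypothetical helper, not the repo's code):

```python
# 'score' weighting: each model contributes score_i / sum(scores) to the
# weighted average over the union vocabulary.
def score_weights(scores):
    total = sum(scores)
    return [s / total for s in scores]

print(score_weights([100, 50]))  # [0.666..., 0.333...]
```

These values play the role of the `weights` argument in the ensemble sketch shown earlier.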
Examples and Tested Models
We have listed examples of ensembling SOLAR-10.7B-Instruct-v1.0 and openchat-3.5-0106 under `example_configs/`:
- example_ensemble_every_step.yaml: Ensembles at every generation step; every model's priority is 'supportive' and `THRESHOLD_API_SERVER` will be ignored.
- example_thresholded_ensemble.yaml: Only ensembles at a generation step if the primary model's highest confidence token is below `THRESHOLD_API_SERVER`.
Additionally, we have listed the models that have been tested in `tested_models.yaml`. However, this does not mean that newer models not included in the list won't work; it just means we do not guarantee them.
To start the GaC server, use the following command in your terminal:
```bash
python gac_api_server.py --config-path [path-to-your-config-file.yaml] --host 0.0.0.0 --port 8000
```
After setting up the API, you can directly execute the GaC ensemble by making calls as demonstrated in `call.py`. Here's an explanation of the key parameters:
- messages_list: A list of conversations, where each conversation is itself a list of messages in the format `{"role": "..", "content": ".."}`. The value of "role" can be `system`, `user`, or `assistant`. A list containing `n` conversations corresponds to a batch size of `n`. However, thresholded ensembles currently only support a batch size of 1.
- max_new_tokens: An integer that determines the maximum number of tokens that can be generated.
- apply_chat_template: A boolean. If `True`, each model will assemble the messages into its own prompt format according to the "chat_template" Jinja template defined in its `tokenizer_config.json` (see the illustration below).
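To see what `apply_chat_template: True` does on the model side, the standalone illustration below uses the Hugging Face `transformers` tokenizer directly. It runs outside the GaC server, only shows the template expansion, and assumes the chosen tokenizer ships a `chat_template`.

```python
# Illustration only: render messages through a model's "chat_template"
# (stored in tokenizer_config.json), which is what each model in the
# ensemble does internally when apply_chat_template is True.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openchat/openchat-3.5-0106")
messages = [{"role": "user", "content": "How are you?"}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(prompt)  # the raw prompt string the model actually sees
```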
Here's an example request to the API:
```python
import requests

url = "http://0.0.0.0:8000/api/generate/"
data = {
    "messages_list": [
        # Conversation 1
        [{"role": "user", "content": "9.11 and 9.9, which is bigger?"},
         {"role": "assistant", "content": "..."},
         {"role": "user", "content": "..."}],
        # Conversation 2
        [{"role": "user", "content": "How are you?"}]
    ],
    "max_new_tokens": 1024,
    "apply_chat_template": True,
}

response = requests.post(url, json=data)
print(response.json())
```
```bibtex
@misc{yu2024breaking,
      title={Breaking the Ceiling of the LLM Community by Treating Token Generation as a Classification for Ensembling},
      author={Yu, Yao-Ching and Kuo, Chun-Chih and Ye, Ziqi and Chang, Yu-Cheng and Li, Yueh-Se},
      year={2024},
      eprint={2406.12585},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```