Skip to content

Commit

Permalink
Merge pull request #8 from smspillaz/quantization-conversion-support
Browse files Browse the repository at this point in the history
Support for on-loading quantization
  • Loading branch information
smspillaz committed Jul 27, 2023
2 parents 6511e3a + 102980b commit ea09786
Show file tree
Hide file tree
Showing 19 changed files with 1,103 additions and 84 deletions.
95 changes: 71 additions & 24 deletions examples/llm-writer-app/src/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,28 @@ const list_store_from_rows = (rows) => {
return list_store;
};

const load_model = (model, cancellable, callback, progress_callback) => {
const load_model = (model, quantization_level, cancellable, callback, progress_callback) => {
const istream = GGML.LanguageModel.stream_from_cache(model);

if (progress_callback) {
istream.set_download_progress_callback(progress_callback);
}

const config = GGML.ModelConfig.new();

if (quantization_level !== null)
{
config.set_quantization_config(
quantization_level,
GGML.gpt_model_quantization_regexes(),
null
)
}

GGML.LanguageModel.load_defined_from_istream_async(
model,
istream,
config,
cancellable,
(src, res) => {
try {
Expand All @@ -77,10 +89,20 @@ const load_model = (model, cancellable, callback, progress_callback) => {
};

/* Maps a combobox row index to the matching GGML.DefinedLanguageModel enum
 * value. Relies on Object.keys() preserving the enum's declaration order,
 * so the rows added to the model combobox must be listed in that same
 * order — presumably they are; verify against the combobox rows. */
const COMBOBOX_ID_TO_LANGUAGE_MODEL_ENUM = Object.keys(GGML.DefinedLanguageModel).map(k => GGML.DefinedLanguageModel[k]);
/* Maps a combobox row index to a quantization data type. Index 0 (null)
 * means "no quantization"; the remaining entries must stay in the same
 * order as the rows passed to the quantization combobox. */
const COMBOBOX_ID_TO_QUANTIZATION_LEVEL_ENUM = [
    null,
    GGML.DataType.F16,
    GGML.DataType.Q8_0,
    GGML.DataType.Q5_0,
    GGML.DataType.Q5_1,
    GGML.DataType.Q4_0,
    GGML.DataType.Q4_1,
];

class ModelLoader {
constructor() {
this._model_enum = null;
this._quantization_enum = null;
this._model = null;
this._pending_load = null;
}
Expand All @@ -95,21 +117,24 @@ class ModelLoader {
* if the action is cancelled, then @callback won't be invoked, but
 * the model will still be downloaded if the download is in progress.
*/
with_model(model_enum, cancellable, callback, progress_callback) {
if (this._model_enum === model_enum) {
with_model(model_enum, quantization_enum, cancellable, callback, progress_callback) {
if (this._model_enum === model_enum &&
this._quantization_enum === quantization_enum) {
return callback(this._model)
}

if (this._pending_load) {
/* We only do the most recent callback once the model is loaded
* and discard other ones */
if (this._pending_load.model_enum !== model_enum) {
if (this._pending_load.model_enum !== model_enum ||
this._pending_load.quantization_enum !== quantization_enum) {
/* Cancel the existing pending load and start over again */
this._pending_load.load_cancellable.cancel();
} else {
/* Don't cancel the pending load operation, but change the callback */
this._pending_load = {
model_enum: model_enum,
quantization_enum: quantization_enum,
callback: callback,
load_cancellable: this._pending_load.load_cancellable,
action_cancellable: cancellable
Expand All @@ -121,16 +146,18 @@ class ModelLoader {
/* Create a pending load and load the model */
this._pending_load = {
model_enum: model_enum,
quantization_enum: quantization_enum,
callback: callback,
load_cancellable: new Gio.Cancellable(),
action_cancellable: cancellable
};

load_model(model_enum, this._pending_load.load_cancellable, model => {
load_model(model_enum, quantization_enum, this._pending_load.load_cancellable, model => {
const { callback, action_cancellable } = this._pending_load;

if (action_cancellable === null || !action_cancellable.is_cancelled()) {
this._model_enum = model_enum;
this._quantization_enum = quantization_enum;
this._model = model;

System.gc();
Expand All @@ -140,6 +167,19 @@ class ModelLoader {
}
}

/**
 * Builds a Gtk.ComboBox showing a single text column.
 *
 * @param {Array} listOptions - Rows for the backing model; each row is a
 *                              single-element array holding the display string.
 * @param {Function} callback - Handler connected to the 'changed' signal.
 * @returns {Gtk.ComboBox} The configured combobox with row 0 pre-selected.
 */
const makeCombobox = (listOptions, callback) => {
    const textRenderer = new Gtk.CellRendererText();
    const combobox = Gtk.ComboBox.new_with_model(list_store_from_rows(listOptions));

    combobox.pack_start(textRenderer, true);
    combobox.add_attribute(textRenderer, 'text', 0);
    combobox.set_active(0);
    combobox.connect('changed', callback);

    return combobox;
};

const LLMWriterAppMainWindow = GObject.registerClass({
Template: `${RESOURCE_PATH}/main.ui`,
Children: [
Expand Down Expand Up @@ -178,30 +218,36 @@ const LLMWriterAppMainWindow = GObject.registerClass({
this._spinner = new Gtk.Spinner({
visible: true
});
const combobox = Gtk.ComboBox.new_with_model(
list_store_from_rows([
['GPT2 117M'],
['GPT2 345M'],
['GPT2 774M'],
['GPT2 1558M'],
])
);
const renderer = new Gtk.CellRendererText();
combobox.pack_start(renderer, true);
combobox.add_attribute(renderer, 'text', 0);
combobox.set_active(0);
combobox.connect('changed', () => {
const comboboxChangedCallback = () => {
resetProgress();
this._model_loader.with_model(
COMBOBOX_ID_TO_LANGUAGE_MODEL_ENUM[combobox.active],
COMBOBOX_ID_TO_LANGUAGE_MODEL_ENUM[modelCombobox.active],
COMBOBOX_ID_TO_QUANTIZATION_LEVEL_ENUM[quantizationCombobox.active],
null,
() => this._spinner.stop(),
progressCallback
);
});
combobox.show();

header.pack_start(combobox);
};
const modelCombobox = makeCombobox([
['GPT2 117M'],
['GPT2 345M'],
['GPT2 774M'],
['GPT2 1558M'],
], comboboxChangedCallback);
modelCombobox.show();
const quantizationCombobox = makeCombobox([
['No quantization'],
['F16'],
['Q8_0'],
['Q5_0'],
['Q5_1'],
['Q4_0'],
['Q4_1'],
], comboboxChangedCallback);
quantizationCombobox.show();

header.pack_start(modelCombobox);
header.pack_start(quantizationCombobox);
header.pack_end(this._spinner);
this.set_titlebar(header);

Expand Down Expand Up @@ -262,7 +308,8 @@ const LLMWriterAppMainWindow = GObject.registerClass({
buffer.create_mark("predictions-start", buffer.get_end_iter(), true);

this._model_loader.with_model(
COMBOBOX_ID_TO_LANGUAGE_MODEL_ENUM[combobox.active],
COMBOBOX_ID_TO_LANGUAGE_MODEL_ENUM[modelCombobox.active],
COMBOBOX_ID_TO_QUANTIZATION_LEVEL_ENUM[quantizationCombobox.active],
this._cancellable,
model => {
model.complete_async(
Expand Down
21 changes: 21 additions & 0 deletions ggml-gobject/ggml-context.c
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,27 @@ ggml_context_unref (GGMLContext *context)
}
}

/**
 * ggml_context_new_tensor:
 * @context: A #GGMLContext
 * @data_type: A #GGMLDataType for the new tensor
 * @shape: (array length=n_dims): Shape of the tensor
 * @n_dims: Number of dimensions in the tensor shape
 *
 * Creates a new #GGMLTensor from the memory pool of @context
 * with shape @shape
 *
 * Returns: (transfer full): The #GGMLTensor
 */
GGMLTensor *
ggml_context_new_tensor (GGMLContext *context,
                         GGMLDataType data_type,
                         int64_t *shape,
                         size_t n_dims)
{
  /* Thin public wrapper around the internal constructor; the tensor's
   * storage comes from @context's memory pool, not from the heap. */
  return ggml_tensor_new (context, data_type, shape, n_dims);
}

/**
* ggml_context_new_tensor_1d:
* @context: A #GGMLContext
Expand Down
4 changes: 4 additions & 0 deletions ggml-gobject/ggml-context.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ GGMLContext *ggml_context_new_from_mem_buffer (GBytes *mem_buffer);
GGMLContext *ggml_context_new (size_t memory_size);
GGMLContext *ggml_context_ref (GGMLContext *context);
void ggml_context_unref (GGMLContext *context);
GGMLTensor *ggml_context_new_tensor (GGMLContext *context,
GGMLDataType data_type,
int64_t *shape,
size_t n_dims);
GGMLTensor *ggml_context_new_tensor_1d (GGMLContext *context,
GGMLDataType data_type, size_t size);
GGMLTensor *ggml_context_new_tensor_2d (GGMLContext *context,
Expand Down
22 changes: 22 additions & 0 deletions ggml-gobject/ggml-gpt.c
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
* 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
*/

#include <ggml-gobject/ggml-quantize.h>
#include <ggml-gobject/ggml-gpt.h>
#include <math.h>

Expand Down Expand Up @@ -619,3 +620,24 @@ ggml_gpt_model_forward_pass (GGMLModel *model,
NULL);
return g_steal_pointer (&lm_head_output);
}

/* Regexes naming the GPT model weight tensors eligible for quantization
 * (token embedding, LM head, and the per-layer attention/MLP weight
 * matrices). NULL-terminated so it can be exposed as a strv. Biases and
 * layer-norm parameters are not listed — presumably kept at full
 * precision; confirm against the quantization code. */
static const char *ggml_gpt_model_quantize_regexes[] = {
  "model/wte",
  "model/lm_head",
  "model/h.*/attn/c_attn/w",
  "model/h.*/attn/c_proj/w",
  "model/h.*/mlp/c_fc/w",
  "model/h.*/mlp/c_proj/w",
  NULL
};

/**
 * ggml_gpt_model_quantization_regexes:
 *
 * Returns: (transfer none) (array zero-terminated=1): A strv of weights to quantize for GPT models
 */
const char **
ggml_gpt_model_quantization_regexes (void)
{
  /* Static storage — callers must not free or modify the returned array,
   * as the (transfer none) annotation indicates. */
  return ggml_gpt_model_quantize_regexes;
}
1 change: 1 addition & 0 deletions ggml-gobject/ggml-gpt.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ GGMLModelDescNode * ggml_create_gpt2_model_desc (int32_t n_vocab,
int32_t n_ctx);

GGMLModelDescNode * ggml_create_gpt2_model_desc_from_hyperparameters (GGMLHyperparameters *hyperparameters);
const char ** ggml_gpt_model_quantization_regexes (void);

GBytes * ggml_gpt_model_forward_pass_create_memory_buffer (size_t n_tokens);

Expand Down
Loading

0 comments on commit ea09786

Please sign in to comment.