Skip to content

Commit

Permalink
Merge pull request #7 from smspillaz/download-progress-callback
Browse files Browse the repository at this point in the history
ggml-gobject: Add progress callbacks for downloads
  • Loading branch information
smspillaz committed Jul 21, 2023
2 parents 48fcc37 + 4f5ee70 commit 6511e3a
Show file tree
Hide file tree
Showing 10 changed files with 696 additions and 102 deletions.
39 changes: 30 additions & 9 deletions examples/llm-writer-app/data/main.ui
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,43 @@
<property name="default-height">480</property>
<property name="default-width">640</property>
<child>
<object class="GtkBox">
<property name="visible">True</property>
<property name="orientation">vertical</property>
<object class="GtkOverlay">
<property name="visible">true</property>
<child>
<object class="GtkScrolledWindow" id="content-view">
<object class="GtkBox">
<property name="visible">true</property>
<property name="hexpand">true</property>
<property name="vexpand">true</property>
<property name="valign">fill</property>
<property name="orientation">vertical</property>
<child>
<object class="GtkTextView" id="text-view">
<object class="GtkScrolledWindow" id="content-view">
<property name="visible">true</property>
<property name="hexpand">true</property>
<property name="vexpand">true</property>
<property name="valign">fill</property>
<property name="wrap-mode">word</property>
<child>
<object class="GtkTextView" id="text-view">
<property name="visible">true</property>
<property name="hexpand">true</property>
<property name="vexpand">true</property>
<property name="valign">fill</property>
<property name="wrap-mode">word</property>
</object>
</child>
</object>
</child>
</object>
</child>
<child type="overlay">
<object class="GtkBox">
<property name="visible">true</property>
<property name="hexpand">true</property>
<property name="halign">end</property>
<property name="valign">end</property>
<child>
<object class="GtkProgressBar" id="progress-bar">
<property name="visible">false</property>
<property name="fraction">0.0</property>
<property name="text">Starting Download</property>
<property name="show-text">true</property>
</object>
</child>
</object>
Expand Down
233 changes: 187 additions & 46 deletions examples/llm-writer-app/src/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,137 @@ const STATE_TEXT_EDITOR = 0;
const STATE_PREDICTING = 1;
const STATE_WAITING = 2;

const list_store_from_rows = (rows) => {
const list_store = Gtk.ListStore.new(rows[0].map(() => GObject.TYPE_STRING));

rows.forEach(columns => {
const iter = list_store.append();
columns.forEach((c, i) => {
list_store.set_value(iter, i, c)
});
});

return list_store;
};

const load_model = (model, cancellable, callback, progress_callback) => {
const istream = GGML.LanguageModel.stream_from_cache(model);

if (progress_callback) {
istream.set_download_progress_callback(progress_callback);
}

GGML.LanguageModel.load_defined_from_istream_async(
model,
istream,
cancellable,
(src, res) => {
try {
callback(GGML.LanguageModel.load_defined_from_istream_finish(res));
} catch (e) {
if (e.code === Gio.IOErrorEnum.CANCELLED) {
return;
}
logError(e);
}
}
);
};

const COMBOBOX_ID_TO_LANGUAGE_MODEL_ENUM = Object.keys(GGML.DefinedLanguageModel).map(k => GGML.DefinedLanguageModel[k]);

class ModelLoader {
constructor() {
this._model_enum = null;
this._model = null;
this._pending_load = null;
}

/**
* with_model:
* @model_enum: A #GGMLModelDescription
* @cancellable: A #GCancellable
* @callback: A callback to invoke once the model is done loading
*
* Does some action with a model. Also accepts a @cancellable -
* if the action is cancelled, then @callback won't be invoked, but
* the model will stil be downloaded if the download is in progress.
*/
with_model(model_enum, cancellable, callback, progress_callback) {
if (this._model_enum === model_enum) {
return callback(this._model)
}

if (this._pending_load) {
/* We only do the most recent callback once the model is loaded
* and discard other ones */
if (this._pending_load.model_enum !== model_enum) {
/* Cancel the existing pending load and start over again */
this._pending_load.load_cancellable.cancel();
} else {
/* Don't cancel the pending load operation, but change the callback */
this._pending_load = {
model_enum: model_enum,
callback: callback,
load_cancellable: this._pending_load.load_cancellable,
action_cancellable: cancellable
};
return;
}
}

/* Create a pending load and load the model */
this._pending_load = {
model_enum: model_enum,
callback: callback,
load_cancellable: new Gio.Cancellable(),
action_cancellable: cancellable
};

load_model(model_enum, this._pending_load.load_cancellable, model => {
const { callback, action_cancellable } = this._pending_load;

if (action_cancellable === null || !action_cancellable.is_cancelled()) {
this._model_enum = model_enum;
this._model = model;

System.gc();
return callback(this._model);
}
}, progress_callback);
}
}

const LLMWriterAppMainWindow = GObject.registerClass({
Template: `${RESOURCE_PATH}/main.ui`,
Children: [
'content-view',
'text-view'
'text-view',
'progress-bar'
]
}, class LLMWriterAppMainWindow extends Gtk.ApplicationWindow {
_init(params) {
super._init(params);

this._model_loader = new ModelLoader();

const resetProgress = () => {
this.progress_bar.set_visible(false);
this.progress_bar.set_text("Starting Download");
};
const progressCallback = (received_bytes, total_bytes) => {
if (received_bytes === -1) {
resetProgress();
return;
}

const fraction = received_bytes / total_bytes;

this.progress_bar.set_visible(true);
this.progress_bar.set_fraction(fraction);
this.progress_bar.set_text(`Downloading ${Math.trunc(fraction * 100)}%`);
};

const header = new Gtk.HeaderBar({
visible: true,
title: GLib.get_application_name(),
Expand All @@ -57,10 +178,33 @@ const LLMWriterAppMainWindow = GObject.registerClass({
this._spinner = new Gtk.Spinner({
visible: true
});
const combobox = Gtk.ComboBox.new_with_model(
list_store_from_rows([
['GPT2 117M'],
['GPT2 345M'],
['GPT2 774M'],
['GPT2 1558M'],
])
);
const renderer = new Gtk.CellRendererText();
combobox.pack_start(renderer, true);
combobox.add_attribute(renderer, 'text', 0);
combobox.set_active(0);
combobox.connect('changed', () => {
resetProgress();
this._model_loader.with_model(
COMBOBOX_ID_TO_LANGUAGE_MODEL_ENUM[combobox.active],
null,
() => this._spinner.stop(),
progressCallback
);
});
combobox.show();

header.pack_start(combobox);
header.pack_end(this._spinner);
this.set_titlebar(header);

this._languageModel = null;
this._textBufferState = STATE_TEXT_EDITOR;
this._predictionsStartedAt = -1;
this._cancellable = null;
Expand Down Expand Up @@ -101,56 +245,65 @@ const LLMWriterAppMainWindow = GObject.registerClass({
if (currentPosition > 0 &&
currentPosition === this._lastCursorOffset &&
count > 0 &&
this._languageModel !== null &&
this._textBufferState === STATE_TEXT_EDITOR) {
const text = buffer.get_text(
buffer.get_start_iter(),
buffer.get_end_iter(),
false
);

this.text_view.set_editable(false);
/* Reset state immediately if the operation is cancelled */
this._cancellable = new Gio.Cancellable({});
this._cancellable.connect(() => resetState());

this._textBufferState = STATE_PREDICTING;
this._candidateText = '';
this._spinner.start();
buffer.create_mark("predictions-start", buffer.get_end_iter(), true);
this._languageModel.complete_async(
text,
10,
2,

this._model_loader.with_model(
COMBOBOX_ID_TO_LANGUAGE_MODEL_ENUM[combobox.active],
this._cancellable,
(src, res) => {
let part, is_complete, is_complete_eos;
try {
[part, is_complete, is_complete_eos] = this._languageModel.complete_finish(res);
} catch (e) {
if (e.code == Gio.IOErrorEnum.CANCELLED) {
resetState();
model => {
model.complete_async(
text,
10,
2,
this._cancellable,
(src, res) => {
let part, is_complete, is_complete_eos;
try {
[part, is_complete, is_complete_eos] = model.complete_finish(res);
} catch (e) {
if (e.code == Gio.IOErrorEnum.CANCELLED) {
return;
}
logError(e);
return;
}

if (part === text) {
return;
}

if (is_complete) {
this._cancellable = null;
this._textBufferState = STATE_WAITING;
this._spinner.stop();
}

this._candidateText += part;
const markup = `<span foreground="gray">${GLib.markup_escape_text(part, part.length)}</span>`
buffer.insert_markup(buffer.get_end_iter(), markup, markup.length);
System.gc();
}
return;
}

if (part === text) {
return;
}

if (is_complete) {
this._cancellable = null;
this._textBufferState = STATE_WAITING;
this._spinner.stop();
}

this._candidateText += part;
const markup = `<span foreground="gray">${GLib.markup_escape_text(part, part.length)}</span>`
buffer.insert_markup(buffer.get_end_iter(), markup, markup.length);
System.gc();
}
);
},
progressCallback
);
} else if (currentPosition > 0 &&
currentPosition === this._lastCursorOffset &&
count > 0 &&
this._languageModel !== null &&
this._textBufferState === STATE_WAITING) {
// Delete the gray text and substitute the real text.
removePredictedText();
Expand Down Expand Up @@ -184,18 +337,6 @@ const LLMWriterAppMainWindow = GObject.registerClass({
}

vfunc_show() {
this._spinner.start();
const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2);
GGML.LanguageModel.load_defined_from_istream_async(
GGML.DefinedLanguageModel.GPT2,
istream,
null,
(src, res) => {
this._languageModel = GGML.LanguageModel.load_defined_from_istream_finish(res);
this._spinner.stop();
}
);

super.vfunc_show();
}
});
Expand Down
Loading

0 comments on commit 6511e3a

Please sign in to comment.