Let application pass session options to the runtime, allow float16 for LLM KV-cache #631

Open · wants to merge 3 commits into main
63 changes: 18 additions & 45 deletions package-lock.json

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions package.json
@@ -38,12 +38,12 @@
   },
   "homepage": "https://github.com/xenova/transformers.js#readme",
   "dependencies": {
-    "onnxruntime-web": "1.14.0",
+    "onnxruntime-web": "1.17.1",
     "sharp": "^0.32.0",
     "@huggingface/jinja": "^0.2.1"
   },
   "optionalDependencies": {
-    "onnxruntime-node": "1.14.0"
+    "onnxruntime-node": "1.17.1"
   },
   "devDependencies": {
     "@types/jest": "^29.5.1",
64 changes: 38 additions & 26 deletions src/models.js
100644 → 100755
@@ -123,23 +123,29 @@ async function constructSession(pretrained_model_name_or_path, fileName, options
     let buffer = await getModelFile(pretrained_model_name_or_path, modelFileName, true, options);

     try {
-        return await InferenceSession.create(buffer, {
-            executionProviders,
-        });
-    } catch (err) {
-        // If the execution provider was only wasm, throw the error
-        if (executionProviders.length === 1 && executionProviders[0] === 'wasm') {
-            throw err;
+        let opt = options.session_options || {};
+
+        // use the default execution providers if the application did not specify any
+        if (opt.executionProviders === undefined) {
+            opt.executionProviders = executionProviders;
+        }
+
+        // handle ONNX external data files
+        if (opt.externalData !== undefined) {
+            for (let i = 0; i < opt.externalData.length; i++) {
+                const ext = opt.externalData[i];
+                // if the external data is given as a string, fetch the file and replace the string with its contents
+                if (typeof ext.data === "string") {
+                    const ext_buffer = await getModelFile(pretrained_model_name_or_path, ext.data, true, options);
+                    ext.data = ext_buffer;
+                }
+            }
         }
-        console.warn(err);
-        console.warn(
-            'Something went wrong during model construction (most likely a missing operation). ' +
-            'Using `wasm` as a fallback. '
-        )
-        return await InferenceSession.create(buffer, {
-            executionProviders: ['wasm']
-        });
+        return await InferenceSession.create(buffer, opt);
+    } catch (err) {
+        // if session creation fails, let the application handle it, i.e. if webgpu fails and we
+        // fall back to wasm, let the application decide whether to use a quantized model, etc.
+        throw err;
     }
 }
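With this change, backend failures are no longer silently retried on wasm: the application owns the session configuration and the fallback policy. A minimal sketch of what a caller might now pass (the model id and external-data file name are hypothetical; the { path, data } entry shape follows onnxruntime-web's externalData session option):

    import { AutoModelForCausalLM } from '@xenova/transformers';

    // Hypothetical model id and weight file, for illustration only.
    const model = await AutoModelForCausalLM.from_pretrained('my-org/my-llm', {
        session_options: {
            // Ask for WebGPU instead of the default wasm provider.
            executionProviders: ['webgpu'],
            // Weights stored outside the .onnx file: constructSession fetches
            // the string `data` and replaces it with the downloaded buffer.
            externalData: [
                { path: 'model.onnx_data', data: 'model.onnx_data' },
            ],
        },
    });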

@@ -741,6 +747,7 @@ export class PreTrainedModel extends Callable {
         local_files_only = false,
         revision = 'main',
         model_file_name = null,
+        session_options = {},
     } = {}) {

         let options = {
@@ -751,6 +758,7 @@ export class PreTrainedModel extends Callable {
             local_files_only,
             revision,
             model_file_name,
+            session_options,
         }

         const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this);
@@ -1296,6 +1304,8 @@ export class PreTrainedModel extends Callable {
         } else {
             // TODO support batches (i.e., batch_size > 1)
             const batch_size = 1;
+            const dtype = this.config.precision || 'float32';
+            const empty = (dtype === 'float16') ? new Uint16Array() : [];

             // @ts-ignore
             if (this.config.is_encoder_decoder && (this.add_encoder_pkv ?? true)) {
@@ -1305,26 +1315,26 @@ export class PreTrainedModel extends Callable {
                 let decoder_dims = [batch_size, this.num_decoder_heads, 0, this.decoder_dim_kv];
                 // @ts-ignore
                 for (let i = 0; i < this.num_decoder_layers; ++i) {
-                    decoderFeeds[`past_key_values.${i}.encoder.key`] = new Tensor('float32', [], encoder_dims)
-                    decoderFeeds[`past_key_values.${i}.encoder.value`] = new Tensor('float32', [], encoder_dims)
-                    decoderFeeds[`past_key_values.${i}.decoder.key`] = new Tensor('float32', [], decoder_dims)
-                    decoderFeeds[`past_key_values.${i}.decoder.value`] = new Tensor('float32', [], decoder_dims)
+                    decoderFeeds[`past_key_values.${i}.encoder.key`] = new Tensor(dtype, empty, encoder_dims)
+                    decoderFeeds[`past_key_values.${i}.encoder.value`] = new Tensor(dtype, empty, encoder_dims)
+                    decoderFeeds[`past_key_values.${i}.decoder.key`] = new Tensor(dtype, empty, decoder_dims)
+                    decoderFeeds[`past_key_values.${i}.decoder.value`] = new Tensor(dtype, empty, decoder_dims)
                 }
             } else if (this.config.model_type === 'falcon') {
                 // NOTE: Custom implementation for Falcon
                 // @ts-ignore
                 let dims = [batch_size * this.num_heads, 0, this.dim_kv]
                 // @ts-ignore
                 for (let i = 0; i < this.num_layers; ++i) {
-                    decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], dims)
-                    decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], dims)
+                    decoderFeeds[`past_key_values.${i}.key`] = new Tensor(dtype, empty, dims)
+                    decoderFeeds[`past_key_values.${i}.value`] = new Tensor(dtype, empty, dims)
                 }
             } else if (this.config.multi_query) { // e.g., for `gpt_bigcode`
                 // @ts-ignore
                 let dims = [batch_size * this.num_heads, 0, 2 * this.dim_kv]
                 // @ts-ignore
                 for (let i = 0; i < this.num_layers; ++i) {
-                    decoderFeeds[`past_key_values.${i}.key_value`] = new Tensor('float32', [], dims)
+                    decoderFeeds[`past_key_values.${i}.key_value`] = new Tensor(dtype, empty, dims)
                 }
             } else if (this.config.model_type === 'bloom') {
                 // NOTE: Custom implementation for Bloom
@@ -1335,16 +1345,16 @@ export class PreTrainedModel extends Callable {
                 let valueDims = [batch_size * this.num_heads, 0, this.dim_kv] // [batch_size x num_heads, past_sequence_length, 64]
                 // @ts-ignore
                 for (let i = 0; i < this.num_layers; ++i) {
-                    decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], keyDims)
-                    decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], valueDims)
+                    decoderFeeds[`past_key_values.${i}.key`] = new Tensor(dtype, empty, keyDims)
+                    decoderFeeds[`past_key_values.${i}.value`] = new Tensor(dtype, empty, valueDims)
                 }
             } else { // Decoder-only
                 // @ts-ignore
                 let dims = [batch_size, this.num_heads, 0, this.dim_kv]
                 // @ts-ignore
                 for (let i = 0; i < this.num_layers; ++i) {
-                    decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], dims)
-                    decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], dims)
+                    decoderFeeds[`past_key_values.${i}.key`] = new Tensor(dtype, empty, dims)
+                    decoderFeeds[`past_key_values.${i}.value`] = new Tensor(dtype, empty, dims)
                 }
             }
         }
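The empty cache buffer is a Uint16Array for float16 because JavaScript has no Float16Array: each 16-bit element carries the raw IEEE-754 half-precision bit pattern that the runtime reads for a 'float16' tensor. A short sketch of that packing, a hypothetical helper not taken from this PR (mantissa is truncated, denormals flush to zero):

    // Convert a JS number to its float16 bit pattern. Shown only to
    // illustrate why Uint16Array can back a float16 tensor.
    function toHalfBits(value) {
        const f32 = new Float32Array(1);
        const u32 = new Uint32Array(f32.buffer);
        f32[0] = value;
        const bits = u32[0];
        const sign = (bits >>> 16) & 0x8000;
        const exp = ((bits >>> 23) & 0xff) - 127 + 15; // re-bias exponent
        const frac = (bits >>> 13) & 0x3ff;            // top 10 mantissa bits
        if (exp <= 0) return sign;                     // underflow: signed zero
        if (exp >= 31) return sign | 0x7c00;           // overflow: infinity
        return sign | (exp << 10) | frac;
    }

    const halfData = new Uint16Array([toHalfBits(0.5), toHalfBits(-2.0)]);
    // e.g. new Tensor('float16', halfData, [2]) would wrap this buffer.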
@@ -5380,6 +5390,7 @@ export class PretrainedMixin {
         local_files_only = false,
         revision = 'main',
         model_file_name = null,
+        session_options = {},
     } = {}) {

         let options = {
@@ -5390,6 +5401,7 @@ export class PretrainedMixin {
             local_files_only,
             revision,
             model_file_name,
+            session_options,
         }
         config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options);
         if (!options.config) {
2 changes: 2 additions & 0 deletions src/pipelines.js
100644 → 100755
@@ -3019,6 +3019,7 @@ export async function pipeline(
         cache_dir = null,
         local_files_only = false,
         revision = 'main',
+        session_options = {},
     } = {}
 ) {
     // Helper method to construct pipeline
@@ -3046,6 +3047,7 @@ export async function pipeline(
         cache_dir,
         local_files_only,
         revision,
+        session_options,
     }

     const classes = new Map([
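The new option threads through the pipeline() factory unchanged, so runtime configuration can be set at the top-level API as well. A hedged usage sketch (the provider list is illustrative, and Xenova/gpt2 stands in for any text-generation checkpoint):

    import { pipeline } from '@xenova/transformers';

    // session_options is forwarded to each InferenceSession the pipeline
    // creates; execution providers are tried in the order given.
    const generator = await pipeline('text-generation', 'Xenova/gpt2', {
        session_options: { executionProviders: ['webgpu', 'wasm'] },
    });
    const output = await generator('Hello, my name is');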
1 change: 1 addition & 0 deletions src/utils/hub.js
100644 → 100755
@@ -30,6 +30,7 @@ if (!globalThis.ReadableStream) {
  * since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
  * NOTE: This setting is ignored for local requests.
  * @property {string} [model_file_name=null] If specified, load the model with this name (excluding the .onnx suffix). Currently only valid for encoder- or decoder-only models.
+ * @property {{}} [session_options={}] Session options passed to the runtime.
  */

 class FileResponse {