-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
defmodule Demo.Encoder do | ||
@moduledoc false | ||
|
||
alias Bumblebee.Shared | ||
|
||
def cross_encoder(model_info, tokenizer, opts \\ []) do | ||
%{model: model, params: params, spec: _spec} = model_info | ||
|
||
opts = | ||
Keyword.validate!(opts, [ | ||
:compile, | ||
defn_options: [], | ||
preallocate_params: false | ||
]) | ||
|
||
preallocate_params = opts[:preallocate_params] | ||
defn_options = opts[:defn_options] | ||
|
||
compile = | ||
if compile = opts[:compile] do | ||
compile | ||
|> Keyword.validate!([:batch_size, :sequence_length]) | ||
|> Shared.require_options!([:batch_size, :sequence_length]) | ||
end | ||
|
||
batch_size = compile[:batch_size] | ||
sequence_length = compile[:sequence_length] | ||
|
||
scores_fun = fn params, inputs -> | ||
Axon.predict(model, params, inputs) | ||
end | ||
|
||
batch_keys = Shared.sequence_batch_keys(sequence_length) | ||
|
||
Nx.Serving.new( | ||
fn batch_key, defn_options -> | ||
params = Shared.maybe_preallocate(params, preallocate_params, defn_options) | ||
|
||
scores_fun = | ||
Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn -> | ||
{:sequence_length, sequence_length} = batch_key | ||
|
||
inputs = %{ | ||
"input_ids" => Nx.template({batch_size, sequence_length}, :u32), | ||
"attention_mask" => Nx.template({batch_size, sequence_length}, :u32) | ||
This comment has been minimized.
Sorry, something went wrong. |
||
} | ||
|
||
[params, inputs] | ||
end) | ||
|
||
fn inputs -> | ||
inputs = Shared.maybe_pad(inputs, batch_size) | ||
scores_fun.(params, inputs) | ||
end | ||
end, | ||
defn_options | ||
) | ||
|> Nx.Serving.batch_size(batch_size) | ||
|> Nx.Serving.process_options(batch_keys: batch_keys) | ||
|> Nx.Serving.client_preprocessing(fn raw_input -> | ||
multi? = Enum.count(raw_input) > 1 | ||
|
||
inputs = | ||
Nx.with_default_backend(Nx.BinaryBackend, fn -> | ||
Bumblebee.apply_tokenizer(tokenizer, raw_input, | ||
length: sequence_length, | ||
return_token_type_ids: false | ||
This comment has been minimized.
Sorry, something went wrong.
baransu
|
||
) | ||
end) | ||
|
||
batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length) | ||
batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key) | ||
|
||
{batch, multi?} | ||
end) | ||
|> Nx.Serving.client_postprocessing(fn {scores, _metadata}, multi? -> | ||
%{results: scores.logits |> Nx.to_flat_list()} | ||
|> Shared.normalize_output(multi?) | ||
end) | ||
end | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
defmodule Demo.Repo.Migrations.AddGinIndex do | ||
use Ecto.Migration | ||
|
||
def change do | ||
execute """ | ||
CREATE INDEX sections_text_search_idx ON sections USING GIN (to_tsvector('english', text)); | ||
""" | ||
end | ||
end |
And to add: