Skip to content

Commit

Permalink
added cross encoder serving
Browse files Browse the repository at this point in the history
  • Loading branch information
toranb committed Nov 11, 2023
1 parent f1a3f37 commit 4b515ed
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 7 deletions.
12 changes: 12 additions & 0 deletions lib/demo/application.ex
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ defmodule Demo.Application do
def start(_type, _args) do
children = [
DemoWeb.Telemetry,
{Nx.Serving, serving: cross(), name: CrossEncoder},
{Nx.Serving, serving: serving(), name: SentenceTransformer},
Demo.Repo,
{DNSCluster, query: Application.get_env(:demo, :dns_cluster_query) || :ignore},
Expand Down Expand Up @@ -48,4 +49,15 @@ defmodule Demo.Application do
defn_options: [compiler: EXLA]
)
end

# Builds the Nx.Serving used by the CrossEncoder process: a BERT-based
# cross-encoder that scores {query, passage} pairs for relevance ranking.
def cross() do
  model_repo = {:hf, "cross-encoder/ms-marco-MiniLM-L-6-v2"}

  {:ok, model_info} = Bumblebee.load_model(model_repo)
  {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bert-base-uncased"})

  serving_opts = [
    compile: [batch_size: 32, sequence_length: [512]],
    defn_options: [compiler: EXLA]
  ]

  Demo.Encoder.cross_encoder(model_info, tokenizer, serving_opts)
end
end
81 changes: 81 additions & 0 deletions lib/demo/encoder.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
defmodule Demo.Encoder do
  @moduledoc false

  alias Bumblebee.Shared

  # Builds an Nx.Serving that scores {query, passage} pairs with a
  # cross-encoder model. Accepts `model_info`/`tokenizer` as returned by
  # Bumblebee.load_model/1 and Bumblebee.load_tokenizer/1.
  #
  # Options:
  #   * :compile - [batch_size: n, sequence_length: l] to precompile shapes
  #   * :defn_options - forwarded to Nx (e.g. compiler: EXLA)
  #   * :preallocate_params - preallocate params on the serving device
  def cross_encoder(model_info, tokenizer, opts \\ []) do
    %{model: model, params: params, spec: _spec} = model_info

    opts =
      Keyword.validate!(opts, [
        :compile,
        defn_options: [],
        preallocate_params: false
      ])

    preallocate_params = opts[:preallocate_params]
    defn_options = opts[:defn_options]

    compile =
      if compile = opts[:compile] do
        compile
        |> Keyword.validate!([:batch_size, :sequence_length])
        |> Shared.require_options!([:batch_size, :sequence_length])
      end

    batch_size = compile[:batch_size]
    sequence_length = compile[:sequence_length]

    scores_fun = fn params, inputs ->
      Axon.predict(model, params, inputs)
    end

    batch_keys = Shared.sequence_batch_keys(sequence_length)

    Nx.Serving.new(
      fn batch_key, defn_options ->
        params = Shared.maybe_preallocate(params, preallocate_params, defn_options)

        scores_fun =
          Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn ->
            {:sequence_length, sequence_length} = batch_key

            # BERT-style cross encoders take token_type_ids alongside
            # input_ids/attention_mask; the template must include them so the
            # precompiled shapes match what the tokenizer produces.
            inputs = %{
              "input_ids" => Nx.template({batch_size, sequence_length}, :u32),
              "attention_mask" => Nx.template({batch_size, sequence_length}, :u32),
              "token_type_ids" => Nx.template({batch_size, sequence_length}, :u32)
            }

            [params, inputs]
          end)

        fn inputs ->
          inputs = Shared.maybe_pad(inputs, batch_size)
          scores_fun.(params, inputs)
        end
      end,
      defn_options
    )
    |> Nx.Serving.batch_size(batch_size)
    |> Nx.Serving.process_options(batch_keys: batch_keys)
    |> Nx.Serving.client_preprocessing(fn raw_input ->
      multi? = Enum.count(raw_input) > 1

      inputs =
        Nx.with_default_backend(Nx.BinaryBackend, fn ->
          # Keep token type ids (the tokenizer default). Passing
          # return_token_type_ids: false changes the model's output scores
          # relative to the reference sentence-transformers CrossEncoder.
          Bumblebee.apply_tokenizer(tokenizer, raw_input, length: sequence_length)
        end)

      batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length)
      batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key)

      {batch, multi?}
    end)
    |> Nx.Serving.client_postprocessing(fn {scores, _metadata}, multi? ->
      %{results: scores.logits |> Nx.to_flat_list()}
      |> Shared.normalize_output(multi?)
    end)
  end
end
25 changes: 22 additions & 3 deletions lib/demo/section.ex
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,32 @@ defmodule Demo.Section do
|> validate_required(@required_attrs)
end

def search_document(document_id, embedding) do
# Top 4 sections of a document ranked by embedding similarity
# (max inner product) against the given query embedding.
# Returns a list of {id, page, text, document_id} tuples.
def search_document_embedding(document_id, embedding) do
  query =
    from(s in Section,
      select: {s.id, s.page, s.text, s.document_id},
      where: s.document_id == ^document_id,
      order_by: max_inner_product(s.embedding, ^embedding),
      limit: 4
    )

  Demo.Repo.all(query)
end

# Postgres full-text search over a document's sections, ranked by
# ts_rank_cd. Returns up to 4 {id, page, text, document_id} tuples as a
# list — matching search_document_embedding/2, since callers concatenate
# both result sets and de-duplicate by id. (A trailing List.first/1 here
# would hand back a bare tuple and break that concatenation.)
def search_document_text(document_id, search) do
  from(s in Section,
    select: {s.id, s.page, s.text, s.document_id},
    where:
      s.document_id == ^document_id and
        fragment("to_tsvector('english', ?) @@ plainto_tsquery('english', ?)", s.text, ^search),
    order_by: [
      desc:
        fragment(
          "ts_rank_cd(to_tsvector('english', ?), plainto_tsquery('english', ?))",
          s.text,
          ^search
        )
    ],
    limit: 4
  )
  |> Demo.Repo.all()
end
end
37 changes: 33 additions & 4 deletions lib/demo_web/live/page_live.ex
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ defmodule DemoWeb.PageLive do

socket =
socket
|> assign(task: nil, lookup: nil, filename: nil, messages: messages, version: version, documents: documents, result: nil, text: nil, loading: false, selected: nil, query: nil, transformer: nil, llama: nil, path: nil, focused: false, loadingpdf: false)
|> assign(encoder: nil, task: nil, lookup: nil, filename: nil, messages: messages, version: version, documents: documents, result: nil, text: nil, loading: false, selected: nil, query: nil, transformer: nil, llama: nil, path: nil, focused: false, loadingpdf: false)
|> allow_upload(:document, accept: ~w(.pdf), progress: &handle_progress/3, auto_upload: true, max_entries: 1)

{:ok, socket}
Expand Down Expand Up @@ -71,9 +71,36 @@ defmodule DemoWeb.PageLive do

@impl true
# Embedding lookup finished: gather candidate sections, then re-rank them
# with the cross encoder in a background task. The task result is picked up
# by the encoder handle_info clause.
def handle_info({ref, {selected, question, %{embedding: embedding}}}, socket) when socket.assigns.lookup.ref == ref do
  # Candidates come from two retrievers: nearest-neighbor embedding search
  # and Postgres full-text search, de-duplicated by section id.
  sections = Demo.Section.search_document_embedding(selected.id, embedding)
  others = Demo.Section.search_document_text(selected.id, question)
  deduplicated = Enum.uniq_by(sections ++ others, fn {id, _, _, _} -> id end)

  # Cross encoders score {query, passage} pairs.
  data = Enum.map(deduplicated, fn {_id, _page, text, _} -> {question, text} end)

  encoder =
    Task.async(fn ->
      # Scores come back in input order, so zip them with the candidates,
      # then keep the single best-scoring section (nil when no candidates).
      section =
        CrossEncoder
        |> Nx.Serving.batched_run(data)
        |> results()
        |> Enum.zip(deduplicated)
        |> Enum.map(fn {score, {id, page, text, document_id}} ->
          %{id: id, page: page, text: text, document_id: document_id, score: score}
        end)
        |> Enum.max_by(& &1.score, fn -> nil end)

      {question, section}
    end)

  {:noreply, assign(socket, lookup: nil, encoder: encoder)}
end

@impl true
def handle_info({ref, {question, section}}, socket) when socket.assigns.encoder.ref == ref do
version = socket.assigns.version
document = socket.assigns.documents |> Enum.find(&(&1.id == section.document_id))

prompt = """
Expand All @@ -91,7 +118,7 @@ defmodule DemoWeb.PageLive do
{section, Replicate.Predictions.wait(prediction)}
end)

{:noreply, assign(socket, lookup: nil, llama: llama, selected: document)}
{:noreply, assign(socket, encoder: nil, llama: llama, selected: document)}
end

@impl true
Expand Down Expand Up @@ -215,6 +242,8 @@ defmodule DemoWeb.PageLive do
end)
end

# Unwraps the score list from the cross-encoder serving's output map.
def results(%{results: scores}), do: scores

@impl true
def render(assigns) do
~H"""
Expand Down
9 changes: 9 additions & 0 deletions priv/repo/migrations/20231110230828_add_gin_index.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
defmodule Demo.Repo.Migrations.AddGinIndex do
  use Ecto.Migration

  # GIN index so full-text queries on to_tsvector('english', text) use an
  # index scan instead of computing tsvectors per row.
  def change do
    # execute/2 keeps the migration reversible inside change/0: the first
    # command runs on migrate, the second on rollback. A bare execute/1
    # would make `mix ecto.rollback` raise IrreversibleMigrationError.
    execute(
      "CREATE INDEX sections_text_search_idx ON sections USING GIN (to_tsvector('english', text));",
      "DROP INDEX sections_text_search_idx;"
    )
  end
end

0 comments on commit 4b515ed

Please sign in to comment.