diff --git a/lib/demo/application.ex b/lib/demo/application.ex
index ae7c5d6..1b4dfc4 100644
--- a/lib/demo/application.ex
+++ b/lib/demo/application.ex
@@ -9,6 +9,7 @@ defmodule Demo.Application do
   def start(_type, _args) do
     children = [
       DemoWeb.Telemetry,
+      {Nx.Serving, serving: cross(), name: CrossEncoder},
       {Nx.Serving, serving: serving(), name: SentenceTransformer},
       Demo.Repo,
       {DNSCluster, query: Application.get_env(:demo, :dns_cluster_query) || :ignore},
@@ -48,4 +49,15 @@ defmodule Demo.Application do
       defn_options: [compiler: EXLA]
     )
   end
+
+  def cross() do
+    repo = "cross-encoder/ms-marco-MiniLM-L-6-v2"
+    {:ok, model_info} = Bumblebee.load_model({:hf, repo})
+    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bert-base-uncased"})
+
+    Demo.Encoder.cross_encoder(model_info, tokenizer,
+      compile: [batch_size: 32, sequence_length: [512]],
+      defn_options: [compiler: EXLA]
+    )
+  end
 end
diff --git a/lib/demo/encoder.ex b/lib/demo/encoder.ex
new file mode 100644
index 0000000..113c1d8
--- /dev/null
+++ b/lib/demo/encoder.ex
@@ -0,0 +1,81 @@
+defmodule Demo.Encoder do
+  @moduledoc false
+
+  alias Bumblebee.Shared
+
+  def cross_encoder(model_info, tokenizer, opts \\ []) do
+    %{model: model, params: params, spec: _spec} = model_info
+
+    opts =
+      Keyword.validate!(opts, [
+        :compile,
+        defn_options: [],
+        preallocate_params: false
+      ])
+
+    preallocate_params = opts[:preallocate_params]
+    defn_options = opts[:defn_options]
+
+    compile =
+      if compile = opts[:compile] do
+        compile
+        |> Keyword.validate!([:batch_size, :sequence_length])
+        |> Shared.require_options!([:batch_size, :sequence_length])
+      end
+
+    batch_size = compile[:batch_size]
+    sequence_length = compile[:sequence_length]
+
+    scores_fun = fn params, inputs ->
+      Axon.predict(model, params, inputs)
+    end
+
+    batch_keys = Shared.sequence_batch_keys(sequence_length)
+
+    Nx.Serving.new(
+      fn batch_key, defn_options ->
+        params = Shared.maybe_preallocate(params, preallocate_params, defn_options)
+
+        scores_fun =
+          Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn ->
+            {:sequence_length, sequence_length} = batch_key
+
+            inputs = %{
+              "input_ids" => Nx.template({batch_size, sequence_length}, :u32),
+              "attention_mask" => Nx.template({batch_size, sequence_length}, :u32)
+            }
+
+            [params, inputs]
+          end)
+
+        fn inputs ->
+          inputs = Shared.maybe_pad(inputs, batch_size)
+          scores_fun.(params, inputs)
+        end
+      end,
+      defn_options
+    )
+    |> Nx.Serving.batch_size(batch_size)
+    |> Nx.Serving.process_options(batch_keys: batch_keys)
+    |> Nx.Serving.client_preprocessing(fn raw_input ->
+      multi? = Enum.count(raw_input) > 1
+
+      inputs =
+        Nx.with_default_backend(Nx.BinaryBackend, fn ->
+          Bumblebee.apply_tokenizer(tokenizer, raw_input,
+            length: sequence_length,
+            return_token_type_ids: false
+          )
+        end)
+
+      batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length)
+      batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key)
+
+      {batch, multi?}
+    end)
+    |> Nx.Serving.client_postprocessing(fn {scores, _metadata}, multi? ->
+      %{results: scores.logits |> Nx.to_flat_list()}
+      |> Shared.normalize_output(multi?)
+    end)
+  end
+end
diff --git a/lib/demo/section.ex b/lib/demo/section.ex
index 5be474c..922a3ab 100644
--- a/lib/demo/section.ex
+++ b/lib/demo/section.ex
@@ -27,13 +27,32 @@ defmodule Demo.Section do
     |> validate_required(@required_attrs)
   end
 
-  def search_document(document_id, embedding) do
+  def search_document_embedding(document_id, embedding) do
     from(s in Section,
+      select: {s.id, s.page, s.text, s.document_id},
       where: s.document_id == ^document_id,
       order_by: max_inner_product(s.embedding, ^embedding),
-      limit: 1
+      limit: 4
+    )
+    |> Demo.Repo.all()
+  end
+
+  def search_document_text(document_id, search) do
+    from(s in Section,
+      select: {s.id, s.page, s.text, s.document_id},
+      where:
+        s.document_id == ^document_id and
+          fragment("to_tsvector('english', ?) @@ plainto_tsquery('english', ?)", s.text, ^search),
+      order_by: [
+        desc:
+          fragment(
+            "ts_rank_cd(to_tsvector('english', ?), plainto_tsquery('english', ?))",
+            s.text,
+            ^search
+          )
+      ],
+      limit: 4
     )
     |> Demo.Repo.all()
-    |> List.first()
   end
 end
diff --git a/lib/demo_web/live/page_live.ex b/lib/demo_web/live/page_live.ex
index 08a0dff..3e96a34 100644
--- a/lib/demo_web/live/page_live.ex
+++ b/lib/demo_web/live/page_live.ex
@@ -12,7 +12,7 @@ defmodule DemoWeb.PageLive do
 
     socket =
       socket
-      |> assign(task: nil, lookup: nil, filename: nil, messages: messages, version: version, documents: documents, result: nil, text: nil, loading: false, selected: nil, query: nil, transformer: nil, llama: nil, path: nil, focused: false, loadingpdf: false)
+      |> assign(encoder: nil, task: nil, lookup: nil, filename: nil, messages: messages, version: version, documents: documents, result: nil, text: nil, loading: false, selected: nil, query: nil, transformer: nil, llama: nil, path: nil, focused: false, loadingpdf: false)
       |> allow_upload(:document, accept: ~w(.pdf), progress: &handle_progress/3, auto_upload: true, max_entries: 1)
 
     {:ok, socket}
@@ -71,9 +71,36 @@ defmodule DemoWeb.PageLive do
   @impl true
   def handle_info({ref, {selected, question, %{embedding: embedding}}}, socket)
       when socket.assigns.lookup.ref == ref do
-    version = socket.assigns.version
+    sections = Demo.Section.search_document_embedding(selected.id, embedding)
+    others = Demo.Section.search_document_text(selected.id, question)
+    deduplicated = (sections ++ others) |> Enum.uniq_by(fn {id, _, _, _} -> id end)
+
+    data =
+      deduplicated
+      |> Enum.map(fn {_id, _page, text, _} -> {question, text} end)
+
+    encoder =
+      Task.async(fn ->
+        section =
+          CrossEncoder
+          |> Nx.Serving.batched_run(data)
+          |> results()
+          |> Enum.zip(deduplicated)
+          |> Enum.map(fn {score, {id, page, text, document_id}} ->
+            %{id: id, page: page, text: text, document_id: document_id, score: score}
+          end)
+          |> Enum.sort(fn x, y -> x.score > y.score end)
+          |> List.first()
+
+        {question, section}
+      end)
 
-    section = Demo.Section.search_document(selected.id, embedding)
+    {:noreply, assign(socket, lookup: nil, encoder: encoder)}
+  end
+
+  @impl true
+  def handle_info({ref, {question, section}}, socket) when socket.assigns.encoder.ref == ref do
+    version = socket.assigns.version
     document = socket.assigns.documents |> Enum.find(&(&1.id == section.document_id))
 
     prompt = """
@@ -91,7 +118,7 @@ defmodule DemoWeb.PageLive do
         {section, Replicate.Predictions.wait(prediction)}
       end)
 
-    {:noreply, assign(socket, lookup: nil, llama: llama, selected: document)}
+    {:noreply, assign(socket, encoder: nil, llama: llama, selected: document)}
   end
 
   @impl true
@@ -215,6 +242,8 @@ defmodule DemoWeb.PageLive do
     end)
   end
 
+  def results(%{results: results}), do: results
+
   @impl true
   def render(assigns) do
     ~H"""
diff --git a/priv/repo/migrations/20231110230828_add_gin_index.exs b/priv/repo/migrations/20231110230828_add_gin_index.exs
new file mode 100644
index 0000000..74dfec6
--- /dev/null
+++ b/priv/repo/migrations/20231110230828_add_gin_index.exs
@@ -0,0 +1,9 @@
+defmodule Demo.Repo.Migrations.AddGinIndex do
+  use Ecto.Migration
+
+  def change do
+    execute """
+    CREATE INDEX sections_text_search_idx ON sections USING GIN (to_tsvector('english', text));
+    """
+  end
+end
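Not part of the patch: the sketch below shows, under stated assumptions, how the retrieve-then-rerank flow added above could be exercised from IEx. It assumes the supervision tree is running with the CrossEncoder and SentenceTransformer servings, that `document` is a persisted document record whose sections already have embeddings, and that more than one candidate section comes back; `question` and the intermediate variable names are illustrative only, and the `%{embedding: ...}` shape simply mirrors what the existing lookup task pattern-matches on.

question = "How long is the warranty?"

# Embed the question with the existing sentence-transformer serving.
%{embedding: embedding} = Nx.Serving.batched_run(SentenceTransformer, question)

# Candidate sections from vector search plus full-text search, deduplicated by id.
candidates =
  (Demo.Section.search_document_embedding(document.id, embedding) ++
     Demo.Section.search_document_text(document.id, question))
  |> Enum.uniq_by(fn {id, _page, _text, _document_id} -> id end)

# Score each {question, section text} pair with the cross-encoder serving;
# the %{results: scores} shape comes from Demo.Encoder's client_postprocessing.
pairs = Enum.map(candidates, fn {_id, _page, text, _document_id} -> {question, text} end)
%{results: scores} = Nx.Serving.batched_run(CrossEncoder, pairs)

# Keep the highest-scoring section, as the LiveView handler does.
scores
|> Enum.zip(candidates)
|> Enum.max_by(fn {score, _candidate} -> score end)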