Skip to content

Commit

Permalink
Add embedding method for Langchain::LLM::GoogleGemini (#631)
Browse files Browse the repository at this point in the history
* Add embedding method for Langchain::LLM::GoogleGemini

* Fix rubocop violation

* Use dummy value for api key in tests
  • Loading branch information
swerner committed May 20, 2024
1 parent 7c44bfc commit c6c4e11
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 15 deletions.
31 changes: 31 additions & 0 deletions lib/langchain/llm/google_gemini.rb
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ module Langchain::LLM
class GoogleGemini < Base
# Default model names and generation options for Google Gemini requests.
# Frozen so this shared constant cannot be mutated at runtime
# (mutable constants are a RuboCop Style/MutableConstant offense).
DEFAULTS = {
  chat_completion_model_name: "gemini-1.5-pro-latest",
  embeddings_model_name: "text-embedding-004",
  temperature: 0.0
}.freeze

Expand Down Expand Up @@ -63,5 +64,35 @@ def chat(params = {})
raise StandardError.new(response)
end
end

# Generate an embedding for the given text using the Google Gemini API.
#
# @param text [String] the text to embed
# @param model [String] the embedding model to use
#   (defaults to DEFAULTS[:embeddings_model_name])
# @return [Langchain::LLM::GoogleGeminiResponse] wrapper around the parsed API response
# @raise [StandardError] if the API returns an error payload
def embed(
  text:,
  model: @defaults[:embeddings_model_name]
)
  params = {
    content: {
      parts: [
        {
          text: text
        }
      ]
    }
  }

  uri = URI("https://generativelanguage.googleapis.com/v1beta/models/#{model}:embedContent?key=#{api_key}")

  request = Net::HTTP::Post.new(uri)
  request.content_type = "application/json"
  request.body = params.to_json

  response = Net::HTTP.start(uri.hostname, uri.port, use_ssl: uri.scheme == "https") do |http|
    http.request(request)
  end

  parsed_response = JSON.parse(response.body)

  # Mirror #chat's behavior: surface API errors instead of silently
  # wrapping an error payload in a response object.
  raise StandardError.new(parsed_response) if parsed_response.is_a?(Hash) && parsed_response.key?("error")

  Langchain::LLM::GoogleGeminiResponse.new(parsed_response, model: model)
end
end
end
6 changes: 5 additions & 1 deletion lib/langchain/llm/response/google_gemini_response.rb
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ def embedding
end

# Returns the embedding vectors as an array of arrays, handling both
# response shapes: Google Gemini ({"embedding" => {"values" => [...]}})
# and Google Vertex AI ({"predictions" => [{"embeddings" => {"values" => [...]}}]}).
def embeddings
  values =
    if raw_response.key?("embedding")
      raw_response.dig("embedding", "values")
    else
      raw_response.dig("predictions", 0, "embeddings", "values")
    end

  [values]
end

def prompt_tokens
Expand Down
9 changes: 9 additions & 0 deletions spec/fixtures/llm/google_gemini/embed.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"embedding": {
"values": [
0.013168523,
-0.008711934,
-0.046782676
]
}
}
39 changes: 39 additions & 0 deletions spec/langchain/llm/google_gemini_spec.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# frozen_string_literal: true

RSpec.describe Langchain::LLM::GoogleGemini do
  # `let(:subject)` shadows RSpec's built-in subject and is flagged by
  # rubocop-rspec (RSpec/SubjectDeclaration); declare it with `subject` instead.
  subject { described_class.new(api_key: "123") }

  describe "#embed" do
    let(:embedding) { [0.013168523, -0.008711934, -0.046782676] }
    let(:raw_embedding_response) { double(body: File.read("spec/fixtures/llm/google_gemini/embed.json")) }

    before do
      allow(Net::HTTP).to receive(:start).and_return(raw_embedding_response)
    end

    it "returns valid llm response object" do
      response = subject.embed(text: "Hello world")

      expect(response).to be_a(Langchain::LLM::GoogleGeminiResponse)
      expect(response.model).to eq("text-embedding-004")
      expect(response.embedding).to eq(embedding)
    end
  end

  describe "#chat" do
    let(:messages) { [{role: "user", parts: [{text: "How high is the sky?"}]}] }
    let(:raw_chat_completions_response) { double(body: File.read("spec/fixtures/llm/google_gemini/chat.json")) }

    before do
      allow(Net::HTTP).to receive(:start).and_return(raw_chat_completions_response)
    end

    it "returns valid llm response object" do
      response = subject.chat(messages: messages)

      expect(response).to be_a(Langchain::LLM::GoogleGeminiResponse)
      expect(response.model).to eq("gemini-1.5-pro-latest")
      expect(response.chat_completion).to eq("The answer is 4.0")
    end
  end
end
52 changes: 38 additions & 14 deletions spec/langchain/llm/response/google_gemini_response_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -28,24 +28,48 @@
end

describe "#embeddings" do
  context "with google vertex response" do
    let(:raw_embedding_response) { JSON.parse(File.read("spec/fixtures/llm/google_vertex_ai/embed.json")) }

    subject { described_class.new(raw_embedding_response) }

    it "returns embeddings" do
      expect(subject.embeddings).to eq([[
        -0.00879860669374466,
        0.007578692398965359,
        0.021136576309800148
      ]])
    end

    it "returns embedding" do
      expect(subject.embedding).to eq([
        -0.00879860669374466,
        0.007578692398965359,
        0.021136576309800148
      ])
    end
  end

  context "with google gemini response" do
    # Same let name as the vertex context above for consistency
    # (was `raw_embeddings_response`).
    let(:raw_embedding_response) { JSON.parse(File.read("spec/fixtures/llm/google_gemini/embed.json")) }

    subject { described_class.new(raw_embedding_response) }

    it "returns embeddings" do
      expect(subject.embeddings).to eq([[
        0.013168523,
        -0.008711934,
        -0.046782676
      ]])
    end

    it "returns embedding" do
      expect(subject.embedding).to eq([
        0.013168523,
        -0.008711934,
        -0.046782676
      ])
    end
  end
end
end

0 comments on commit c6c4e11

Please sign in to comment.