diff --git a/lib/langchain/llm/google_gemini.rb b/lib/langchain/llm/google_gemini.rb
index 5be784da..75d8a724 100644
--- a/lib/langchain/llm/google_gemini.rb
+++ b/lib/langchain/llm/google_gemini.rb
@@ -6,6 +6,7 @@ module Langchain::LLM
   class GoogleGemini < Base
     DEFAULTS = {
       chat_completion_model_name: "gemini-1.5-pro-latest",
+      embeddings_model_name: "text-embedding-004",
       temperature: 0.0
     }
 
@@ -63,5 +64,35 @@ def chat(params = {})
         raise StandardError.new(response)
       end
     end
+
+    def embed(
+      text:,
+      model: @defaults[:embeddings_model_name]
+    )
+
+      params = {
+        content: {
+          parts: [
+            {
+              text: text
+            }
+          ]
+        }
+      }
+
+      uri = URI("https://generativelanguage.googleapis.com/v1beta/models/#{model}:embedContent?key=#{api_key}")
+
+      request = Net::HTTP::Post.new(uri)
+      request.content_type = "application/json"
+      request.body = params.to_json
+
+      response = Net::HTTP.start(uri.hostname, uri.port, use_ssl: uri.scheme == "https") do |http|
+        http.request(request)
+      end
+
+      parsed_response = JSON.parse(response.body)
+
+      Langchain::LLM::GoogleGeminiResponse.new(parsed_response, model: model)
+    end
   end
 end
diff --git a/lib/langchain/llm/response/google_gemini_response.rb b/lib/langchain/llm/response/google_gemini_response.rb
index 2eaeca08..e3c129f3 100644
--- a/lib/langchain/llm/response/google_gemini_response.rb
+++ b/lib/langchain/llm/response/google_gemini_response.rb
@@ -27,7 +27,11 @@ def embedding
     end
 
     def embeddings
-      [raw_response.dig("predictions", 0, "embeddings", "values")]
+      if raw_response.key?("embedding")
+        [raw_response.dig("embedding", "values")]
+      else
+        [raw_response.dig("predictions", 0, "embeddings", "values")]
+      end
     end
 
     def prompt_tokens
diff --git a/spec/fixtures/llm/google_gemini/embed.json b/spec/fixtures/llm/google_gemini/embed.json
new file mode 100644
index 00000000..36d19ccb
--- /dev/null
+++ b/spec/fixtures/llm/google_gemini/embed.json
@@ -0,0 +1,9 @@
+{
+  "embedding": {
+    "values": [
+      0.013168523,
+      -0.008711934,
+      -0.046782676
+    ]
+  }
+}
diff --git a/spec/langchain/llm/google_gemini_spec.rb b/spec/langchain/llm/google_gemini_spec.rb
new file mode 100644
index 00000000..023443ea
--- /dev/null
+++ b/spec/langchain/llm/google_gemini_spec.rb
@@ -0,0 +1,39 @@
+# frozen_string_literal: true
+
+RSpec.describe Langchain::LLM::GoogleGemini do
+  let(:subject) { described_class.new(api_key: "123") }
+
+  describe "#embed" do
+    let(:embedding) { [0.013168523, -0.008711934, -0.046782676] }
+    let(:raw_embedding_response) { double(body: File.read("spec/fixtures/llm/google_gemini/embed.json")) }
+
+    before do
+      allow(Net::HTTP).to receive(:start).and_return(raw_embedding_response)
+    end
+
+    it "returns valid llm response object" do
+      response = subject.embed(text: "Hello world")
+
+      expect(response).to be_a(Langchain::LLM::GoogleGeminiResponse)
+      expect(response.model).to eq("text-embedding-004")
+      expect(response.embedding).to eq(embedding)
+    end
+  end
+
+  describe "#chat" do
+    let(:messages) { [{role: "user", parts: [{text: "How high is the sky?"}]}] }
+    let(:raw_chat_completions_response) { double(body: File.read("spec/fixtures/llm/google_gemini/chat.json")) }
+
+    before do
+      allow(Net::HTTP).to receive(:start).and_return(raw_chat_completions_response)
+    end
+
+    it "returns valid llm response object" do
+      response = subject.chat(messages: messages)
+
+      expect(response).to be_a(Langchain::LLM::GoogleGeminiResponse)
+      expect(response.model).to eq("gemini-1.5-pro-latest")
+      expect(response.chat_completion).to eq("The answer is 4.0")
+    end
+  end
+end
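Note: the new `GoogleGemini#embed` above issues a plain `Net::HTTP` POST to the `v1beta` `embedContent` endpoint and wraps the parsed JSON in a `GoogleGeminiResponse`. A minimal usage sketch follows; the `GOOGLE_GEMINI_API_KEY` environment variable name is illustrative rather than part of this change, and the return values shown are the fixture values from embed.json:

    require "langchain"

    # Illustrative env var name; any valid Gemini API key works here.
    llm = Langchain::LLM::GoogleGemini.new(api_key: ENV["GOOGLE_GEMINI_API_KEY"])

    # The model: keyword defaults to DEFAULTS[:embeddings_model_name], i.e. "text-embedding-004".
    response = llm.embed(text: "Hello world")

    response.model     # => "text-embedding-004"
    response.embedding # => [0.013168523, -0.008711934, -0.046782676]
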
diff --git a/spec/langchain/llm/response/google_gemini_response_spec.rb b/spec/langchain/llm/response/google_gemini_response_spec.rb
index d1416fa0..94a8fa09 100644
--- a/spec/langchain/llm/response/google_gemini_response_spec.rb
+++ b/spec/langchain/llm/response/google_gemini_response_spec.rb
@@ -28,24 +28,48 @@
   end
 
   describe "#embeddings" do
-    let(:raw_embedding_response) { JSON.parse(File.read("spec/fixtures/llm/google_vertex_ai/embed.json")) }
+    context "with google vertex response" do
+      let(:raw_embedding_response) { JSON.parse(File.read("spec/fixtures/llm/google_vertex_ai/embed.json")) }
 
-    subject { described_class.new(raw_embedding_response) }
+      subject { described_class.new(raw_embedding_response) }
 
-    it "returns embeddings" do
-      expect(subject.embeddings).to eq([[
-        -0.00879860669374466,
-        0.007578692398965359,
-        0.021136576309800148
-      ]])
+      it "returns embeddings" do
+        expect(subject.embeddings).to eq([[
+          -0.00879860669374466,
+          0.007578692398965359,
+          0.021136576309800148
+        ]])
+      end
+
+      it "#returns embedding" do
+        expect(subject.embedding).to eq([
+          -0.00879860669374466,
+          0.007578692398965359,
+          0.021136576309800148
+        ])
+      end
     end
 
-    it "#returns embedding" do
-      expect(subject.embedding).to eq([
-        -0.00879860669374466,
-        0.007578692398965359,
-        0.021136576309800148
-      ])
+    context "with google gemini response" do
+      let(:raw_embeddings_response) { JSON.parse(File.read("spec/fixtures/llm/google_gemini/embed.json")) }
+
+      subject { described_class.new(raw_embeddings_response) }
+
+      it "returns embeddings" do
+        expect(subject.embeddings).to eq([[
+          0.013168523,
+          -0.008711934,
+          -0.046782676
+        ]])
+      end
+
+      it "#returns embedding" do
+        expect(subject.embedding).to eq([
+          0.013168523,
+          -0.008711934,
+          -0.046782676
+        ])
+      end
     end
   end
 end
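For reference, a sketch of the two raw payload shapes the updated `#embeddings` accessor now distinguishes: the Gemini API returns a top-level "embedding" key, while Vertex AI nests values under "predictions". The values are taken from the fixtures above:

    # Gemini embedContent shape (spec/fixtures/llm/google_gemini/embed.json)
    gemini = {"embedding" => {"values" => [0.013168523, -0.008711934, -0.046782676]}}

    # Vertex AI shape (spec/fixtures/llm/google_vertex_ai/embed.json)
    vertex = {"predictions" => [{"embeddings" => {"values" => [-0.00879860669374466, 0.007578692398965359, 0.021136576309800148]}}]}

    Langchain::LLM::GoogleGeminiResponse.new(gemini).embeddings
    # => [[0.013168523, -0.008711934, -0.046782676]]
    Langchain::LLM::GoogleGeminiResponse.new(vertex).embeddings
    # => [[-0.00879860669374466, 0.007578692398965359, 0.021136576309800148]]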