Skip to content

Commit

Permalink
Support for gemini-1.5-pro-latest, closes #5
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Mar 27, 2024
1 parent faab0a1 commit 7500ba9
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ To chat interactively with the model, run `llm chat`:
llm chat -m gemini-pro
```

If you have access to the Gemini 1.5 Pro preview you can use `-m gemini-1.5-pro-latest` to work with that model.

## Development

To set up this plugin locally, first checkout the code. Then create a new virtual environment:
Expand Down
18 changes: 12 additions & 6 deletions llm_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,28 +26,34 @@

@llm.hookimpl
def register_models(register):
register(GeminiPro())
register(GeminiPro("gemini-pro"))
register(GeminiPro("gemini-1.5-pro-latest"))


class GeminiPro(llm.Model):
model_id = "gemini-pro"
can_stream = True

def __init__(self, model_id):
self.model_id = model_id

def build_messages(self, prompt, conversation):
if not conversation:
return [{"role": "user", "parts": [{"text": prompt.prompt}]}]
messages = []
for response in conversation.responses:
messages.append({"role": "user", "parts": [{"text": response.prompt.prompt}]})
messages.append(
{"role": "user", "parts": [{"text": response.prompt.prompt}]}
)
messages.append({"role": "model", "parts": [{"text": response.text()}]})
messages.append({"role": "user", "parts": [{"text": prompt.prompt}]})
return messages

def execute(self, prompt, stream, response, conversation):
key = llm.get_key("", "gemini", "LLM_GEMINI_KEY")
url = (
"https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:streamGenerateContent?"
+ urllib.parse.urlencode({"key": key})
url = "https://generativelanguage.googleapis.com/v1beta/models/{}:streamGenerateContent?".format(
self.model_id
) + urllib.parse.urlencode(
{"key": key}
)
gathered = []
with httpx.stream(
Expand Down

0 comments on commit 7500ba9

Please sign in to comment.