Google Vertex AI provider with Region Selection support #58934
Replies: 33 comments
-
Does Vertex AI have distinct API semantics, or is it just an alternate endpoint for Gemini Flash / Gemini Pro? If the latter, it may only require tweaking our endpoint code. See Assistant Configuration: Custom Endpoint in the docs. Currently we assume the following URL structure under that endpoint (zed/crates/google_ai/src/google_ai.rs, lines 20 to 23 in cdead57): …
But looking at the Vertex AI docs, I think the endpoints instead take the form: …
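For comparison, the two URL shapes in question can be sketched as plain string builders. This is a minimal sketch, not Zed's code, and every function name is hypothetical: the Gemini API template follows Google's public Gemini API docs (the API key travels separately, as a query parameter or header), the Vertex AI template matches the one used by the provider posted later in this thread, and `vertex_host` reflects my reading of the Vertex AI docs, where the `global` location uses the global host and any other location (for example `us-central1`) uses a regional `{location}-aiplatform.googleapis.com` host.

```rust
/// Hypothetical helper: pick the Vertex AI host for a location.
/// Assumes `global` maps to the global host and every other
/// location maps to a regional host.
fn vertex_host(location_id: &str) -> String {
    if location_id == "global" {
        "https://aiplatform.googleapis.com".to_string()
    } else {
        format!("https://{location_id}-aiplatform.googleapis.com")
    }
}

/// Gemini API streaming URL: the model is the only path parameter.
fn gemini_api_stream_url(api_url: &str, model_id: &str) -> String {
    format!("{api_url}/v1beta/models/{model_id}:streamGenerateContent?alt=sse")
}

/// Vertex AI streaming URL: project and location are path parameters too,
/// and the location also selects the regional host.
fn vertex_stream_url(project_id: &str, location_id: &str, model_id: &str) -> String {
    format!(
        "{host}/v1/projects/{project_id}/locations/{location_id}/publishers/google/models/{model_id}:streamGenerateContent?alt=sse",
        host = vertex_host(location_id)
    )
}
```

If this reading is right, the request and response bodies are largely shared, and the main differences are the extra project/location settings and OAuth bearer-token auth in place of an API key.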
-
I find this interesting because apparently Claude can also run on Vertex AI with zero downtime.
-
%{
"contents" => [
%{
"role" => "user",
"parts" => [
%{
"inlineData" => %{
"mimeType" => "application/pdf",
"data" => base64_content
}
},
%{
"text" => prompt
}
]
}
]
}
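The Elixir map above is the request body for a Gemini `generateContent` call: one user turn with an inline PDF part followed by a text prompt. Below is a dependency-free Rust sketch that emits the same JSON. `pdf_request_body` and `json_escape` are hypothetical helpers; the hand-rolled escaper handles quotes, backslashes, and control characters, and real code would more likely use `serde_json`.

```rust
/// Escape a string so it can be embedded inside a JSON string literal.
fn json_escape(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            // Remaining control characters must use the \uXXXX form.
            c if (c as u32) < 0x20 => out.push_str(&format!("\\u{:04x}", c as u32)),
            c => out.push(c),
        }
    }
    out
}

/// Build a body with one user turn: an inline base64 PDF, then a text prompt.
fn pdf_request_body(base64_content: &str, prompt: &str) -> String {
    format!(
        r#"{{"contents":[{{"role":"user","parts":[{{"inlineData":{{"mimeType":"application/pdf","data":"{}"}}}},{{"text":"{}"}}]}}]}}"#,
        json_escape(base64_content),
        json_escape(prompt)
    )
}
```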
-
+1. Please add support for Vertex AI, specifically for Anthropic models through Vertex AI.
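Anthropic models would need a different request path from the Gemini models. As I understand Anthropic's Vertex AI documentation, Claude is served under the `anthropic` publisher and called through `rawPredict` / `streamRawPredict` rather than `streamGenerateContent`, with an Anthropic Messages API body that carries `"anthropic_version": "vertex-2023-10-16"` instead of a `model` field. A hypothetical URL builder under those assumptions (the host rule for `global` versus regional locations is also an assumption):

```rust
/// Hypothetical helper: build the Vertex AI URL for a Claude model.
/// Assumes the `anthropic` publisher path and the rawPredict /
/// streamRawPredict methods described in Anthropic's Vertex AI docs.
fn anthropic_vertex_url(project_id: &str, location_id: &str, model_id: &str, stream: bool) -> String {
    // Assumed: `global` uses the global host, other locations a regional one.
    let host = if location_id == "global" {
        "https://aiplatform.googleapis.com".to_string()
    } else {
        format!("https://{location_id}-aiplatform.googleapis.com")
    };
    let method = if stream { "streamRawPredict" } else { "rawPredict" };
    format!(
        "{host}/v1/projects/{project_id}/locations/{location_id}/publishers/anthropic/models/{model_id}:{method}"
    )
}
```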
-
Hi all, I created a provider for this that works well for inline edits, agent mode, token counting, thread summarization, etc. Adding it to the project is simple:
use std::mem;
use anyhow::{Result, anyhow, bail};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
// MODIFICATION 1: Update API_URL to the correct Vertex AI endpoint.
pub const API_URL: &str = "https://aiplatform.googleapis.com";
// MODIFICATION 2: Change function signature to accept Vertex AI parameters and remove api_key.
pub async fn stream_generate_content(
client: &dyn HttpClient,
api_url: &str,
project_id: &str,
location_id: &str,
access_token: &str,
mut request: GenerateContentRequest,
) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
validate_generate_content_request(&request)?;
// The `model` field is emptied as it is provided as a path parameter.
let model_id = mem::take(&mut request.model.model_id);
// MODIFICATION 3: Update URL to the correct Vertex AI format.
let uri = format!(
"{api_url}/v1/projects/{project_id}/locations/{location_id}/publishers/google/models/{model_id}:streamGenerateContent?alt=sse"
);
// MODIFICATION 4: Add Authorization header for bearer token authentication.
let request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Authorization", format!("Bearer {}", access_token))
.header("Content-Type", "application/json");
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
let mut response = client.send(request).await?;
if response.status().is_success() {
let reader = BufReader::new(response.into_body());
Ok(reader
.lines()
.filter_map(|line| async move {
match line {
Ok(line) => {
if let Some(line) = line.strip_prefix("data: ") {
match serde_json::from_str(line) {
Ok(response) => Some(Ok(response)),
Err(error) => Some(Err(anyhow!(
"Error parsing JSON: {error:?}\n{line:?}"
))),
}
} else {
None
}
}
Err(error) => Some(Err(anyhow!(error))),
}
})
.boxed())
} else {
let mut text = String::new();
response.body_mut().read_to_string(&mut text).await?;
Err(anyhow!(
"error during streamGenerateContent, status code: {:?}, body: {}",
response.status(),
text
))
}
}
// MODIFICATION 5: Change function signature to accept Vertex AI parameters and remove api_key.
pub async fn count_tokens(
client: &dyn HttpClient,
api_url: &str,
project_id: &str,
location_id: &str,
access_token: &str,
request: CountTokensRequest,
) -> Result<CountTokensResponse> {
validate_generate_content_request(&request.generate_content_request)?;
// MODIFICATION 6: Update URL to the correct Vertex AI format.
let uri = format!(
"{api_url}/v1/projects/{project_id}/locations/{location_id}/publishers/google/models/{model_id}:countTokens",
model_id = &request.generate_content_request.model.model_id,
);
// Wrap the request contents as {"contents": [...]}, the body shape sent to countTokens.
#[derive(Serialize)]
struct CountTokensPayload {
contents: Vec<Content>,
}
let payload = CountTokensPayload {
contents: request.generate_content_request.contents,
};
let request_body = serde_json::to_string(&payload)?;
// MODIFICATION 7: Add Authorization header for bearer token authentication.
let request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(&uri)
.header("Authorization", format!("Bearer {}", access_token))
.header("Content-Type", "application/json");
let http_request = request_builder.body(AsyncBody::from(request_body))?;
let mut response = client.send(http_request).await?;
let mut text = String::new();
response.body_mut().read_to_string(&mut text).await?;
anyhow::ensure!(
response.status().is_success(),
"error during countTokens, status code: {:?}, body: {}",
response.status(),
text
);
Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
}
pub fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<()> {
if request.model.is_empty() {
bail!("Model must be specified");
}
if request.contents.is_empty() {
bail!("Request must contain at least one content item");
}
if let Some(user_content) = request
.contents
.iter()
.find(|content| content.role == Role::User)
{
if user_content.parts.is_empty() {
bail!("User content must contain at least one part");
}
}
Ok(())
}
#[derive(Debug, Serialize, Deserialize)]
pub enum Task {
#[serde(rename = "generateContent")]
GenerateContent,
#[serde(rename = "streamGenerateContent")]
StreamGenerateContent,
#[serde(rename = "countTokens")]
CountTokens,
#[serde(rename = "embedContent")]
EmbedContent,
#[serde(rename = "batchEmbedContents")]
BatchEmbedContents,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentRequest {
#[serde(default, skip_serializing_if = "ModelName::is_empty")]
pub model: ModelName,
pub contents: Vec<Content>,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_instruction: Option<SystemInstruction>,
#[serde(skip_serializing_if = "Option::is_none")]
pub generation_config: Option<GenerationConfig>,
#[serde(skip_serializing_if = "Option::is_none")]
pub safety_settings: Option<Vec<SafetySetting>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<Tool>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_config: Option<ToolConfig>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentResponse {
#[serde(skip_serializing_if = "Option::is_none")]
pub candidates: Option<Vec<GenerateContentCandidate>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_feedback: Option<PromptFeedback>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage_metadata: Option<UsageMetadata>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentCandidate {
#[serde(skip_serializing_if = "Option::is_none")]
pub index: Option<usize>,
pub content: Content,
#[serde(skip_serializing_if = "Option::is_none")]
pub finish_reason: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub finish_message: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub safety_ratings: Option<Vec<SafetyRating>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub citation_metadata: Option<CitationMetadata>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Content {
#[serde(default)]
pub parts: Vec<Part>,
pub role: Role,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SystemInstruction {
pub parts: Vec<Part>,
}
#[derive(Debug, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub enum Role {
User,
Model,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Part {
TextPart(TextPart),
InlineDataPart(InlineDataPart),
FunctionCallPart(FunctionCallPart),
FunctionResponsePart(FunctionResponsePart),
ThoughtPart(ThoughtPart),
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TextPart {
pub text: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InlineDataPart {
pub inline_data: GenerativeContentBlob,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerativeContentBlob {
pub mime_type: String,
pub data: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallPart {
pub function_call: FunctionCall,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionResponsePart {
pub function_response: FunctionResponse,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThoughtPart {
pub thought: bool,
pub thought_signature: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationSource {
#[serde(skip_serializing_if = "Option::is_none")]
pub start_index: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub end_index: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub uri: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub license: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationMetadata {
pub citation_sources: Vec<CitationSource>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptFeedback {
#[serde(skip_serializing_if = "Option::is_none")]
pub block_reason: Option<String>,
pub safety_ratings: Vec<SafetyRating>,
#[serde(skip_serializing_if = "Option::is_none")]
pub block_reason_message: Option<String>,
}
#[derive(Debug, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct UsageMetadata {
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_token_count: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub cached_content_token_count: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub candidates_token_count: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_use_prompt_token_count: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub thoughts_token_count: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub total_token_count: Option<u64>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ThinkingConfig {
pub thinking_budget: u32,
}
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Copy, Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
pub enum GoogleModelMode {
#[default]
Default,
Thinking {
budget_tokens: Option<u32>,
},
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
#[serde(skip_serializing_if = "Option::is_none")]
pub candidate_count: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_sequences: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_output_tokens: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub thinking_config: Option<ThinkingConfig>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetySetting {
pub category: HarmCategory,
pub threshold: HarmBlockThreshold,
}
#[derive(Debug, Serialize, Deserialize)]
pub enum HarmCategory {
#[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
Unspecified,
#[serde(rename = "HARM_CATEGORY_DEROGATORY")]
Derogatory,
#[serde(rename = "HARM_CATEGORY_TOXICITY")]
Toxicity,
#[serde(rename = "HARM_CATEGORY_VIOLENCE")]
Violence,
#[serde(rename = "HARM_CATEGORY_SEXUAL")]
Sexual,
#[serde(rename = "HARM_CATEGORY_MEDICAL")]
Medical,
#[serde(rename = "HARM_CATEGORY_DANGEROUS")]
Dangerous,
#[serde(rename = "HARM_CATEGORY_HARASSMENT")]
Harassment,
#[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
HateSpeech,
#[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
SexuallyExplicit,
#[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
DangerousContent,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmBlockThreshold {
#[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
Unspecified,
BlockLowAndAbove,
BlockMediumAndAbove,
BlockOnlyHigh,
BlockNone,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmProbability {
#[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
Unspecified,
Negligible,
Low,
Medium,
High,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetyRating {
pub category: HarmCategory,
pub probability: HarmProbability,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensRequest {
pub generate_content_request: GenerateContentRequest,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensResponse {
pub total_tokens: u64,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionCall {
pub name: String,
pub args: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionResponse {
pub name: String,
pub response: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
pub function_declarations: Vec<FunctionDeclaration>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolConfig {
pub function_calling_config: FunctionCallingConfig,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallingConfig {
pub mode: FunctionCallingMode,
#[serde(skip_serializing_if = "Option::is_none")]
pub allowed_function_names: Option<Vec<String>>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum FunctionCallingMode {
Auto,
Any,
None,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionDeclaration {
pub name: String,
pub description: String,
pub parameters: serde_json::Value,
}
// NOTE: Vertex AI takes the model only as a path parameter, so the `models/` prefix
// added by this serializer never reaches a Vertex request: `stream_generate_content`
// empties `model_id` with `mem::take` (and `skip_serializing_if` then drops the field),
// and `count_tokens` sends only `contents`. The type is left unchanged.
#[derive(Debug, Default)]
pub struct ModelName {
pub model_id: String,
}
impl ModelName {
pub fn is_empty(&self) -> bool {
self.model_id.is_empty()
}
}
const MODEL_NAME_PREFIX: &str = "models/";
impl Serialize for ModelName {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&format!("{MODEL_NAME_PREFIX}{}", &self.model_id))
}
}
impl<'de> Deserialize<'de> for ModelName {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let string = String::deserialize(deserializer)?;
if let Some(id) = string.strip_prefix(MODEL_NAME_PREFIX) {
Ok(Self {
model_id: id.to_string(),
})
} else {
// Vertex AI model names (e.g., in responses) might not have this prefix,
// so we handle that case gracefully.
Ok(Self {
model_id: string,
})
}
}
}
// MODIFICATION STARTS: Model enum updated to only include versions 2.0 and higher.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
pub enum Model {
#[serde(rename = "gemini-2.0-flash")]
Gemini20Flash,
#[serde(
rename = "gemini-2.5-flash",
alias = "gemini-2.0-flash-thinking-exp",
alias = "gemini-2.5-flash-preview-04-17",
alias = "gemini-2.5-flash-preview-05-20",
alias = "gemini-2.5-flash-preview-latest"
)]
#[default]
Gemini25Flash,
#[serde(
rename = "gemini-2.5-pro",
alias = "gemini-2.0-pro-exp",
alias = "gemini-2.5-pro-preview-latest",
alias = "gemini-2.5-pro-exp-03-25",
alias = "gemini-2.5-pro-preview-03-25",
alias = "gemini-2.5-pro-preview-05-06",
alias = "gemini-2.5-pro-preview-06-05"
)]
Gemini25Pro,
#[serde(rename = "custom")]
Custom {
name: String,
/// The name displayed in the UI, such as in the assistant panel model dropdown menu.
display_name: Option<String>,
max_tokens: u64,
#[serde(default)]
mode: GoogleModelMode,
},
}
impl Model {
pub fn default_fast() -> Self {
Self::Gemini20Flash
}
pub fn id(&self) -> &str {
match self {
Self::Gemini20Flash => "gemini-2.0-flash",
Self::Gemini25Flash => "gemini-2.5-flash",
Self::Gemini25Pro => "gemini-2.5-pro",
Self::Custom { name, .. } => name,
}
}
pub fn request_id(&self) -> &str {
match self {
Self::Gemini20Flash => "gemini-2.0-flash",
Self::Gemini25Flash => "gemini-2.5-flash",
Self::Gemini25Pro => "gemini-2.5-pro",
Self::Custom { name, .. } => name,
}
}
pub fn display_name(&self) -> &str {
match self {
Self::Gemini20Flash => "Gemini 2.0 Flash",
Self::Gemini25Flash => "Gemini 2.5 Flash",
Self::Gemini25Pro => "Gemini 2.5 Pro",
Self::Custom {
name, display_name, ..
} => display_name.as_ref().unwrap_or(name),
}
}
pub fn max_token_count(&self) -> u64 {
match self {
Self::Gemini20Flash => 1_048_576,
Self::Gemini25Flash => 1_048_576,
Self::Gemini25Pro => 1_048_576,
Self::Custom { max_tokens, .. } => *max_tokens,
}
}
pub fn max_output_tokens(&self) -> Option<u64> {
match self {
Model::Gemini20Flash => Some(8_192),
Model::Gemini25Flash => Some(65_536),
Model::Gemini25Pro => Some(65_536),
Model::Custom { .. } => None,
}
}
pub fn supports_tools(&self) -> bool {
true
}
pub fn supports_images(&self) -> bool {
true
}
pub fn mode(&self) -> GoogleModelMode {
match self {
Self::Gemini20Flash => GoogleModelMode::Default,
Self::Gemini25Flash | Self::Gemini25Pro => {
GoogleModelMode::Thinking {
// By default these models are set to "auto", so we preserve that behavior
// but indicate they are capable of thinking mode
budget_tokens: None,
}
}
Self::Custom { mode, .. } => *mode,
}
}
}
impl std::fmt::Display for Model {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.id())
}
}
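The streaming branch of `stream_generate_content` relies on `?alt=sse`, which makes the endpoint return server-sent events whose `data:` lines each carry one JSON `GenerateContentResponse`. A std-only sketch of that line filtering follows; `sse_data_payloads` is a hypothetical name. Unlike the provider, which only strips `"data: "`, it also accepts `data:` without a space, since the SSE format makes that space optional. Like the provider, it assumes each event's JSON fits on a single `data:` line.

```rust
/// Extract the payload of every `data:` line from a text/event-stream body.
/// Blank lines (event separators) and other fields or comments are skipped.
fn sse_data_payloads(body: &str) -> Vec<&str> {
    body.lines()
        .filter_map(|line| line.strip_prefix("data:"))
        // SSE removes at most one leading space after the colon.
        .map(|rest| rest.strip_prefix(' ').unwrap_or(rest))
        .collect()
}
```

Each returned payload would then go to `serde_json::from_str::<GenerateContentResponse>`, as in the provider.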
use anyhow::{Context as _, Result, anyhow};
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
use google_vertex_ai::{
FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
ThinkingConfig, UsageMetadata,
};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelToolChoice, LanguageModelToolSchemaFormat, LanguageModelToolUse,
LanguageModelToolUseId, MessageContent, StopReason,
};
use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, RateLimiter, Role,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::sync::{
Arc,
atomic::{self, AtomicU64},
};
use strum::IntoEnumIterator;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
use util::ResultExt;
use crate::AllLanguageModelSettings;
use crate::ui::InstructionListItem;
const PROVIDER_ID: &str = "google-vertex-ai";
const PROVIDER_NAME: &str = "Google Vertex AI";
#[derive(Default, Clone, Debug, PartialEq)]
pub struct GoogleVertexSettings {
pub api_url: String,
pub project_id: String, // ADDED
pub location_id: String, // ADDED
pub available_models: Vec<AvailableModel>,
}
#[derive(Clone, Copy, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
#[default]
Default,
Thinking {
/// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
budget_tokens: Option<u32>,
},
}
impl From<ModelMode> for GoogleModelMode {
fn from(value: ModelMode) -> Self {
match value {
ModelMode::Default => GoogleModelMode::Default,
ModelMode::Thinking { budget_tokens } => GoogleModelMode::Thinking { budget_tokens },
}
}
}
impl From<GoogleModelMode> for ModelMode {
fn from(value: GoogleModelMode) -> Self {
match value {
GoogleModelMode::Default => ModelMode::Default,
GoogleModelMode::Thinking { budget_tokens } => ModelMode::Thinking { budget_tokens },
}
}
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
name: String,
display_name: Option<String>,
max_tokens: u64,
mode: Option<ModelMode>,
}
pub struct GoogleVertexLanguageModelProvider {
http_client: Arc<dyn HttpClient>,
state: gpui::Entity<State>,
}
pub struct State {
api_key: Option<String>,
api_key_from_env: bool,
_subscription: Subscription,
}
impl State {
fn is_authenticated(&self) -> bool {
self.api_key.is_some()
}
fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
// Credentials are stored under api_url, so read the Vertex AI settings to get it.
let settings = AllLanguageModelSettings::get_global(cx)
.google_vertex
.clone();
cx.spawn(async move |this, cx| {
credentials_provider
.delete_credentials(&settings.api_url, &cx) // Use api_url
.await
.log_err();
this.update(cx, |this, cx| {
this.api_key = None;
this.api_key_from_env = false;
cx.notify();
})
})
}
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
log::info!("Authenticating Google Vertex AI...");
if self.is_authenticated() {
return Task::ready(Ok(()));
}
// gpui's executor is not a Tokio runtime, so `tokio::process` and
// `tokio::task::spawn_blocking` are unavailable here. Run the blocking `gcloud`
// command on a plain std thread instead and send its result back to this async
// task over a oneshot channel.
cx.spawn(async move |this, cx| {
let (tx, rx) = futures::channel::oneshot::channel();
std::thread::spawn(move || {
let result = std::process::Command::new("gcloud")
.args(&["auth", "application-default", "print-access-token"])
.output()
.map_err(|e| AuthenticateError::Other(anyhow!("Failed to execute gcloud command: {}", e)));
// Send the result back to the async task, ignoring if the receiver was dropped.
let _ = tx.send(result);
});
// Await the result from the channel. The first `?` handles the channel being
// cancelled; the second propagates the `AuthenticateError` from running gcloud.
let token_output = rx
.await
.map_err(|_cancelled| AuthenticateError::Other(anyhow!("Authentication task was cancelled")))??;
// Check the exit status before decoding, so a failure reports gcloud's stderr.
if !token_output.status.success() {
let stderr = String::from_utf8_lossy(&token_output.stderr).into_owned();
return Err(AuthenticateError::Other(anyhow!("gcloud command failed: {}", stderr)));
}
// Decode the access token from stdout as UTF-8 and trim surrounding whitespace.
let access_token = String::from_utf8(token_output.stdout)
.map_err(|e| AuthenticateError::Other(anyhow!("Invalid UTF-8 in gcloud output: {}", e)))?
.trim()
.to_string();
let api_key = access_token; // Use the retrieved token as the API key.
let from_env = false; // This token is dynamically fetched, not from env or keychain.
this.update(cx, |this, cx| {
this.api_key = Some(api_key);
this.api_key_from_env = from_env;
cx.notify();
})?;
Ok(())
})
}
}
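// Hypothetical, std-only sketch (not part of the original provider) of the
// pattern `authenticate` uses: run a blocking closure, such as a `gcloud`
// invocation, on a plain std thread and receive its result over a channel.
// The provider uses `futures::channel::oneshot` so it can `.await` the
// receiver; this sketch uses `std::sync::mpsc` and a blocking `recv()` so it
// stays dependency-free. Note that gcloud access tokens are short-lived
// (typically about an hour), so a cached token will eventually need refreshing.
#[allow(dead_code)]
fn run_blocking_on_thread<T, F>(f: F) -> std::sync::mpsc::Receiver<T>
where
    T: Send + 'static,
    F: FnOnce() -> T + Send + 'static,
{
    let (tx, rx) = std::sync::mpsc::channel();
    std::thread::spawn(move || {
        // Ignore the send error if the receiver was already dropped.
        let _ = tx.send(f());
    });
    rx
}

// Hypothetical helper: turn raw `gcloud auth application-default
// print-access-token` output into a token. The exit status is checked before
// stdout is decoded, so a failure reports stderr rather than a UTF-8 error.
#[allow(dead_code)]
fn token_from_gcloud_output(
    success: bool,
    stdout: &[u8],
    stderr: &[u8],
) -> std::result::Result<String, String> {
    if !success {
        return Err(format!("gcloud command failed: {}", String::from_utf8_lossy(stderr)));
    }
    let token = std::str::from_utf8(stdout)
        .map_err(|e| format!("Invalid UTF-8 in gcloud output: {e}"))?
        .trim()
        .to_string();
    if token.is_empty() {
        return Err("gcloud returned an empty access token".to_string());
    }
    Ok(token)
}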
impl GoogleVertexLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
let state = cx.new(|cx| State {
api_key: None,
api_key_from_env: false,
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
cx.notify();
}),
});
Self { http_client, state }
}
fn create_language_model(&self, model: google_vertex_ai::Model) -> Arc<dyn LanguageModel> {
Arc::new(GoogleVertexLanguageModel {
id: LanguageModelId::from(model.id().to_string()),
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
})
}
}
impl LanguageModelProviderState for GoogleVertexLanguageModelProvider {
type ObservableEntity = State;
fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
Some(self.state.clone())
}
}
impl LanguageModelProvider for GoogleVertexLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn icon(&self) -> IconName {
IconName::AiGoogle
}
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
Some(self.create_language_model(google_vertex_ai::Model::default()))
}
fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
Some(self.create_language_model(google_vertex_ai::Model::default_fast()))
}
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
let mut models = BTreeMap::default();
// Add base models from google_vertex_ai::Model::iter()
for model in google_vertex_ai::Model::iter() {
if !matches!(model, google_vertex_ai::Model::Custom { .. }) {
models.insert(model.id().to_string(), model);
}
}
// Override with available models from settings
for model in &AllLanguageModelSettings::get_global(cx)
.google_vertex
.available_models
{
models.insert(
model.name.clone(),
google_vertex_ai::Model::Custom {
name: model.name.clone(),
display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
mode: model.mode.unwrap_or_default().into(),
},
);
}
models
.into_values()
.map(|model| {
Arc::new(GoogleVertexLanguageModel {
id: LanguageModelId::from(model.id().to_string()),
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
})
.collect()
}
fn is_authenticated(&self, cx: &App) -> bool {
self.state.read(cx).is_authenticated()
}
fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
self.state.update(cx, |state, cx| state.authenticate(cx))
}
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
.into()
}
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
self.state.update(cx, |state, cx| state.reset_api_key(cx))
}
}
pub struct GoogleVertexLanguageModel {
id: LanguageModelId,
model: google_vertex_ai::Model,
state: gpui::Entity<State>,
http_client: Arc<dyn HttpClient>,
request_limiter: RateLimiter,
}
impl GoogleVertexLanguageModel {
fn stream_completion(
&self,
request: google_vertex_ai::GenerateContentRequest,
cx: &AsyncApp,
) -> BoxFuture<
'static,
Result<futures::stream::BoxStream<'static, Result<GenerateContentResponse>>>,
> {
let http_client = self.http_client.clone();
let Ok((access_token_option, api_url, project_id, location_id)) = cx.read_entity(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).google_vertex;
(
state.api_key.clone(), // This is the access token for Vertex AI
settings.api_url.clone(),
settings.project_id.clone(), // ADDED
settings.location_id.clone(), // ADDED
)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
async move {
let access_token = access_token_option.context("Missing Google API key (access token)")?;
let request = google_vertex_ai::stream_generate_content(
http_client.as_ref(),
&api_url,
&project_id, // ADDED
&location_id, // ADDED
&access_token,
request,
);
request.await.context("failed to stream completion")
}
.boxed()
}
}
impl LanguageModel for GoogleVertexLanguageModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
}
fn name(&self) -> LanguageModelName {
LanguageModelName::from(self.model.display_name().to_string())
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn supports_tools(&self) -> bool {
self.model.supports_tools()
}
fn supports_images(&self) -> bool {
self.model.supports_images()
}
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
match choice {
LanguageModelToolChoice::Auto
| LanguageModelToolChoice::Any
| LanguageModelToolChoice::None => true,
}
}
fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
LanguageModelToolSchemaFormat::JsonSchemaSubset
}
fn telemetry_id(&self) -> String {
format!("google_vertex/{}", self.model.request_id())
}
fn max_token_count(&self) -> u64 {
self.model.max_token_count()
}
fn max_output_tokens(&self) -> Option<u64> {
self.model.max_output_tokens()
}
fn count_tokens(
&self,
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<u64>> {
let model_id = self.model.request_id().to_string();
let request = into_vertex_ai(request, model_id.clone(), self.model.mode());
let http_client = self.http_client.clone();
// Synchronously read the state and settings.
// `read_entity` executes the closure and returns its result directly.
let (access_token_option, api_url, project_id, location_id) = cx.read_entity(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).google_vertex;
(
state.api_key.clone(), // This is the access token for Vertex AI (Option<String>)
settings.api_url.clone(), // String
settings.project_id.clone(), // String
settings.location_id.clone(), // String
)
}); // No .unwrap_or_default() here, as read_entity directly returns the tuple
async move {
// Check if the access token is present. If not, return an error.
let access_token = access_token_option
.context("Missing Google API key (access token). Please authenticate.")?;
let response = google_vertex_ai::count_tokens(
http_client.as_ref(),
&api_url,
&project_id,
&location_id,
&access_token,
google_vertex_ai::CountTokensRequest {
generate_content_request: request,
},
)
.await?;
Ok(response.total_tokens)
}
.boxed()
}
fn stream_completion(
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<
'static,
Result<
futures::stream::BoxStream<
'static,
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
>,
LanguageModelCompletionError,
>,
> {
let request = into_vertex_ai(
request,
self.model.request_id().to_string(),
self.model.mode(),
);
let request = self.stream_completion(request, cx);
let future = self.request_limiter.stream(async move {
let response = request
.await
.map_err(|err| LanguageModelCompletionError::Other(anyhow!(err)))?;
Ok(GoogleVertexEventMapper::new().map_stream(response))
});
async move { Ok(future.await?.boxed()) }.boxed()
}
}
pub fn into_vertex_ai(
mut request: LanguageModelRequest,
model_id: String,
mode: GoogleModelMode,
) -> google_vertex_ai::GenerateContentRequest {
fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
content
.into_iter()
.flat_map(|content| match content {
language_model::MessageContent::Text(text) => {
if !text.is_empty() {
vec![Part::TextPart(google_vertex_ai::TextPart { text })]
} else {
vec![]
}
}
language_model::MessageContent::Thinking {
text: _,
signature: Some(signature),
} => {
if !signature.is_empty() {
vec![Part::ThoughtPart(google_vertex_ai::ThoughtPart {
thought: true,
thought_signature: signature,
})]
} else {
vec![]
}
}
language_model::MessageContent::Thinking { .. } => {
vec![]
}
language_model::MessageContent::RedactedThinking(_) => vec![],
language_model::MessageContent::Image(image) => {
vec![Part::InlineDataPart(google_vertex_ai::InlineDataPart {
inline_data: google_vertex_ai::GenerativeContentBlob {
mime_type: "image/png".to_string(), // Assuming PNG for simplicity, could derive from format
data: image.source.to_string(), // Assuming base64 encoded for simplicity
},
})]
}
language_model::MessageContent::ToolUse(tool_use) => {
vec![Part::FunctionCallPart(google_vertex_ai::FunctionCallPart {
function_call: google_vertex_ai::FunctionCall {
name: tool_use.name.to_string(),
args: tool_use.input,
},
})]
}
language_model::MessageContent::ToolResult(tool_result) => {
match tool_result.content {
language_model::LanguageModelToolResultContent::Text(text) => {
vec![Part::FunctionResponsePart(
google_vertex_ai::FunctionResponsePart {
function_response: google_vertex_ai::FunctionResponse {
name: tool_result.tool_name.to_string(),
// The API expects a valid JSON object
response: serde_json::json!({
"output": text
}),
},
},
)]
}
language_model::LanguageModelToolResultContent::Image(image) => {
vec![
Part::FunctionResponsePart(google_vertex_ai::FunctionResponsePart {
function_response: google_vertex_ai::FunctionResponse {
name: tool_result.tool_name.to_string(),
// The API expects a valid JSON object
response: serde_json::json!({
"output": "Tool responded with an image"
}),
},
}),
Part::InlineDataPart(google_vertex_ai::InlineDataPart {
inline_data: google_vertex_ai::GenerativeContentBlob {
mime_type: "image/png".to_string(),
data: image.source.to_string(),
},
}),
]
}
}
}
})
.collect()
}
let system_instructions = if request
.messages
.first()
.map_or(false, |msg| matches!(msg.role, Role::System))
{
let message = request.messages.remove(0);
Some(SystemInstruction {
parts: map_content(message.content),
})
} else {
None
};
google_vertex_ai::GenerateContentRequest {
model: google_vertex_ai::ModelName { model_id },
system_instruction: system_instructions,
contents: request
.messages
.into_iter()
.filter_map(|message| {
let parts = map_content(message.content);
if parts.is_empty() {
None
} else {
Some(google_vertex_ai::Content {
parts,
role: match message.role {
Role::User => google_vertex_ai::Role::User,
Role::Assistant => google_vertex_ai::Role::Model,
Role::System => google_vertex_ai::Role::User, // Google AI doesn't have a distinct system role; often maps to user for initial context
},
})
}
})
.collect(),
generation_config: Some(google_vertex_ai::GenerationConfig {
candidate_count: Some(1),
stop_sequences: Some(request.stop),
max_output_tokens: None,
temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
thinking_config: match mode {
GoogleModelMode::Thinking { budget_tokens } => {
budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
}
GoogleModelMode::Default => None,
},
top_p: None,
top_k: None,
}),
safety_settings: None, // Safety settings are handled at a different layer or can be configured.
tools: (!request.tools.is_empty()).then(|| {
vec![google_vertex_ai::Tool {
function_declarations: request
.tools
.into_iter()
.map(|tool| FunctionDeclaration {
name: tool.name,
description: tool.description,
parameters: tool.input_schema,
})
.collect(),
}]
}),
tool_config: request.tool_choice.map(|choice| google_vertex_ai::ToolConfig {
function_calling_config: google_vertex_ai::FunctionCallingConfig {
mode: match choice {
LanguageModelToolChoice::Auto => google_vertex_ai::FunctionCallingMode::Auto,
LanguageModelToolChoice::Any => google_vertex_ai::FunctionCallingMode::Any,
LanguageModelToolChoice::None => google_vertex_ai::FunctionCallingMode::None,
},
allowed_function_names: None,
},
}),
}
}
pub struct GoogleVertexEventMapper {
usage: UsageMetadata,
stop_reason: StopReason,
}
impl GoogleVertexEventMapper {
pub fn new() -> Self {
Self {
usage: UsageMetadata::default(),
stop_reason: StopReason::EndTurn,
}
}
pub fn map_stream(
mut self,
events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
{
events
.map(Some)
.chain(futures::stream::once(async { None }))
.flat_map(move |event| {
futures::stream::iter(match event {
Some(Ok(event)) => self.map_event(event),
Some(Err(error)) => {
vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))]
}
None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
})
})
}
pub fn map_event(
&mut self,
event: GenerateContentResponse,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
let mut events: Vec<_> = Vec::new();
let mut wants_to_use_tool = false;
if let Some(usage_metadata) = event.usage_metadata {
update_usage(&mut self.usage, &usage_metadata);
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
convert_usage(&self.usage),
)))
}
if let Some(candidates) = event.candidates {
for candidate in candidates {
if let Some(finish_reason) = candidate.finish_reason.as_deref() {
self.stop_reason = match finish_reason {
"STOP" => StopReason::EndTurn,
"MAX_TOKENS" => StopReason::MaxTokens,
_ => {
log::error!("Unexpected google_vertex finish_reason: {finish_reason}");
StopReason::EndTurn
}
};
}
candidate
.content
.parts
.into_iter()
.for_each(|part| match part {
Part::TextPart(text_part) => {
events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text)))
}
Part::InlineDataPart(_) => {}
Part::FunctionCallPart(function_call_part) => {
wants_to_use_tool = true;
let name: Arc<str> = function_call_part.function_call.name.into();
let next_tool_id =
TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
let id: LanguageModelToolUseId =
format!("{}-{}", name, next_tool_id).into();
events.push(Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id,
name,
is_input_complete: true,
raw_input: function_call_part.function_call.args.to_string(),
input: function_call_part.function_call.args,
},
)));
}
Part::FunctionResponsePart(_) => {}
Part::ThoughtPart(part) => {
events.push(Ok(LanguageModelCompletionEvent::Thinking {
text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries?
signature: Some(part.thought_signature),
}));
}
});
}
}
// Even when Gemini wants to use a Tool, the API
// responds with `finish_reason: STOP`
if wants_to_use_tool {
self.stop_reason = StopReason::ToolUse;
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
}
events
}
}
pub fn count_google_tokens(
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<u64>> {
// We can't use GoogleLanguageModelProvider to count tokens because GitHub Copilot doesn't have direct access to google_ai,
// so we fall back to the tiktoken_rs tokenizer.
cx.background_spawn(async move {
let messages = request
.messages
.into_iter()
.map(|message| tiktoken_rs::ChatCompletionRequestMessage {
role: match message.role {
Role::User => "user".into(),
Role::Assistant => "assistant".into(),
Role::System => "system".into(),
},
content: Some(message.string_contents()),
name: None,
function_call: None,
})
.collect::<Vec<_>>();
// Tiktoken doesn't yet support these models, so we manually use the
// same tokenizer as GPT-4.
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
})
.boxed()
}
fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
if let Some(prompt_token_count) = new.prompt_token_count {
usage.prompt_token_count = Some(prompt_token_count);
}
if let Some(cached_content_token_count) = new.cached_content_token_count {
usage.cached_content_token_count = Some(cached_content_token_count);
}
if let Some(candidates_token_count) = new.candidates_token_count {
usage.candidates_token_count = Some(candidates_token_count);
}
if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count {
usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count);
}
if let Some(thoughts_token_count) = new.thoughts_token_count {
usage.thoughts_token_count = Some(thoughts_token_count);
}
if let Some(total_token_count) = new.total_token_count {
usage.total_token_count = Some(total_token_count);
}
}
fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage {
let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
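// saturating_sub: a chunk may report cached tokens without (or above) the prompt count; avoid u64 underflow.
let input_tokens = prompt_tokens.saturating_sub(cached_tokens);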
let output_tokens = usage.candidates_token_count.unwrap_or(0);
language_model::TokenUsage {
input_tokens,
output_tokens,
cache_read_input_tokens: cached_tokens,
cache_creation_input_tokens: 0,
}
}
struct ConfigurationView {
state: gpui::Entity<State>,
load_credentials_task: Option<Task<()>>,
}
impl ConfigurationView {
fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
cx.observe(&state, |_, _, cx| {
cx.notify();
})
.detach();
let load_credentials_task = Some(cx.spawn_in(window, {
let state = state.clone();
async move |this, cx| {
if let Some(task) = state
.update(cx, |state, cx| state.authenticate(cx))
.log_err()
{
// We don't log an error, because "not signed in" is also an error.
let _ = task.await;
}
this.update(cx, |this, cx| {
this.load_credentials_task = None;
cx.notify();
})
.log_err();
}
}));
Self {
state,
load_credentials_task,
}
}
fn authenticate_gcloud(&mut self, window: &mut Window, cx: &mut Context<Self>) {
log::info!("Authenticating with gcloud...");
let state = self.state.clone();
self.load_credentials_task = Some(cx.spawn_in(window, {
async move |this, cx| {
if let Some(task) = state
.update(cx, |state, cx| state.authenticate(cx))
.log_err()
{
let _ = task.await;
}
this.update(cx, |this, cx| {
this.load_credentials_task = None;
cx.notify();
})
.log_err();
}
}));
cx.notify();
}
fn reset_gcloud_auth(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state.update(cx, |state, cx| state.reset_api_key(cx))?.await
})
.detach_and_log_err(cx);
cx.notify();
}
}
impl Render for ConfigurationView {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let is_authenticated = self.state.read(cx).is_authenticated();
if self.load_credentials_task.is_some() {
div().child(Label::new("Attempting to authenticate with gcloud...")).into_any()
} else if !is_authenticated {
v_flex()
.size_full()
.child(Label::new("Please authenticate with Google Cloud to use this provider."))
.child(
List::new()
.child(InstructionListItem::text_only(
"1. Ensure Google Cloud SDK is installed and configured.",
))
.child(InstructionListItem::text_only(
"2. Run 'gcloud auth application-default login' in your terminal.",
))
.child(InstructionListItem::text_only(
"3. Configure your desired Google Cloud Project ID and Location ID in Zed's settings.json file under 'language_models.google_vertex'.",
))
)
.child(
h_flex()
.w_full()
.my_2()
.child(
Button::new("authenticate-gcloud", "Authenticate with gcloud")
.label_size(LabelSize::Small)
.icon_size(IconSize::Small)
.on_click(cx.listener(|this, _, window, cx| this.authenticate_gcloud(window, cx))),
),
)
.child(
Label::new(
// Trailing backslashes join the lines so the label has no embedded newlines or indentation.
"This will attempt to acquire an access token using your \
gcloud application-default credentials. You might need to run \
'gcloud auth application-default login' manually first."
)
.size(LabelSize::Small).color(Color::Muted),
)
.into_any()
} else {
h_flex()
.mt_1()
.p_1()
// .justify_between() // Removed, button is handled separately
.rounded_md()
.border_1()
.border_color(cx.theme().colors().border)
.bg(cx.theme().colors().background)
.child(
h_flex()
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(Label::new("Authenticated with gcloud.")),
)
.child(
Button::new("reset-gcloud-auth", "Clear Token")
.label_size(LabelSize::Small)
.icon(Some(IconName::Trash))
.icon_size(IconSize::Small)
.icon_position(IconPosition::Start)
.tooltip(Tooltip::text("Clear the in-memory access token. You will need to re-authenticate to use the provider."))
.on_click(cx.listener(|this, _, window, cx| this.reset_gcloud_auth(window, cx))),
)
.into_any()
}
}
}
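The `map_event` method above folds each candidate's `finish_reason` into a `StopReason`, then overrides it when any function call was seen, because Gemini still reports `STOP` in that case; it also builds tool-use IDs from a process-wide counter. The std-only sketch below isolates that decision for a single response. `StopReason` here is a local stand-in for Zed's type, and `map_finish_reason`, `final_stop_reason` and `next_tool_use_id` are hypothetical helper names, not part of the provider.

```rust
use std::sync::atomic::{AtomicU64, Ordering};

/// Local stand-in for Zed's `StopReason`, limited to the variants used here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
}

/// Maps a `finish_reason` string to a stop reason. Unrecognized values
/// (e.g. "SAFETY") fall back to `EndTurn`, as the provider code does after logging.
fn map_finish_reason(finish_reason: &str) -> StopReason {
    match finish_reason {
        "STOP" => StopReason::EndTurn,
        "MAX_TOKENS" => StopReason::MaxTokens,
        _ => StopReason::EndTurn,
    }
}

/// Any function call in the response overrides the reported reason,
/// since the API answers `STOP` even when it wants a tool.
fn final_stop_reason(finish_reason: Option<&str>, saw_function_call: bool) -> StopReason {
    if saw_function_call {
        StopReason::ToolUse
    } else {
        finish_reason.map_or(StopReason::EndTurn, map_finish_reason)
    }
}

/// Process-wide counter keeping IDs unique when the same function is called repeatedly.
static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Builds an ID of the form "<function name>-<counter>".
fn next_tool_use_id(function_name: &str) -> String {
    let n = TOOL_CALL_COUNTER.fetch_add(1, Ordering::SeqCst);
    format!("{function_name}-{n}")
}

fn main() {
    assert_eq!(final_stop_reason(Some("MAX_TOKENS"), false), StopReason::MaxTokens);
    assert_eq!(final_stop_reason(Some("STOP"), true), StopReason::ToolUse);
    let first = next_tool_use_id("read_file");
    let second = next_tool_use_id("read_file");
    assert_ne!(first, second);
    println!("{first} {second}");
}
```

A counter is simpler than hashing arguments, at the cost of IDs that are only unique within one process.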
use std::sync::Arc;
use client::{Client, UserStore};
use fs::Fs;
use gpui::{App, Context, Entity};
use language_model::LanguageModelRegistry;
use provider::deepseek::DeepSeekLanguageModelProvider;
pub mod provider;
mod settings;
pub mod ui;
use crate::provider::anthropic::AnthropicLanguageModelProvider;
use crate::provider::bedrock::BedrockLanguageModelProvider;
use crate::provider::cloud::CloudLanguageModelProvider;
use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
use crate::provider::google::GoogleLanguageModelProvider;
use crate::provider::google_vertex::GoogleVertexLanguageModelProvider;
use crate::provider::lmstudio::LmStudioLanguageModelProvider;
use crate::provider::mistral::MistralLanguageModelProvider;
use crate::provider::ollama::OllamaLanguageModelProvider;
use crate::provider::open_ai::OpenAiLanguageModelProvider;
use crate::provider::open_router::OpenRouterLanguageModelProvider;
pub use crate::settings::*;
pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, fs: Arc<dyn Fs>, cx: &mut App) {
crate::settings::init(fs, cx);
let registry = LanguageModelRegistry::global(cx);
registry.update(cx, |registry, cx| {
register_language_model_providers(registry, user_store, client, cx);
});
}
fn register_language_model_providers(
registry: &mut LanguageModelRegistry,
user_store: Entity<UserStore>,
client: Arc<Client>,
cx: &mut Context<LanguageModelRegistry>,
) {
registry.register_provider(
CloudLanguageModelProvider::new(user_store.clone(), client.clone(), cx),
cx,
);
registry.register_provider(
AnthropicLanguageModelProvider::new(client.http_client(), cx),
cx,
);
registry.register_provider(
OpenAiLanguageModelProvider::new(client.http_client(), cx),
cx,
);
registry.register_provider(
OllamaLanguageModelProvider::new(client.http_client(), cx),
cx,
);
registry.register_provider(
LmStudioLanguageModelProvider::new(client.http_client(), cx),
cx,
);
registry.register_provider(
DeepSeekLanguageModelProvider::new(client.http_client(), cx),
cx,
);
registry.register_provider(
GoogleLanguageModelProvider::new(client.http_client(), cx),
cx,
);
registry.register_provider( // NEW REGISTRATION BY DIAB
GoogleVertexLanguageModelProvider::new(client.http_client(), cx),
cx,
);
registry.register_provider(
MistralLanguageModelProvider::new(client.http_client(), cx),
cx,
);
registry.register_provider(
BedrockLanguageModelProvider::new(client.http_client(), cx),
cx,
);
registry.register_provider(
OpenRouterLanguageModelProvider::new(client.http_client(), cx),
cx,
);
registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
}
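Back in the provider code, `update_usage` keeps the latest value of each usage counter across streamed chunks, and `convert_usage` reports input tokens as prompt tokens minus cached tokens. A minimal std-only sketch of that bookkeeping, with local stand-in structs trimmed to three fields (assumptions, not the real `google_vertex_ai` types), using `saturating_sub` so an odd chunk cannot underflow:

```rust
/// Stand-in for the response's usage metadata; fields are optional because
/// a streamed chunk may report only some of them.
#[derive(Debug, Default, Clone, Copy)]
struct UsageMetadata {
    prompt_token_count: Option<u64>,
    cached_content_token_count: Option<u64>,
    candidates_token_count: Option<u64>,
}

/// Stand-in for the usage totals reported to the editor.
#[derive(Debug, PartialEq, Eq)]
struct TokenUsage {
    input_tokens: u64,
    output_tokens: u64,
    cache_read_input_tokens: u64,
}

/// Overwrites a counter only when the new chunk reports it.
fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
    usage.prompt_token_count = new.prompt_token_count.or(usage.prompt_token_count);
    usage.cached_content_token_count =
        new.cached_content_token_count.or(usage.cached_content_token_count);
    usage.candidates_token_count = new.candidates_token_count.or(usage.candidates_token_count);
}

/// Splits prompt tokens into uncached input and cache reads.
fn convert_usage(usage: &UsageMetadata) -> TokenUsage {
    let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
    let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
    TokenUsage {
        // saturating_sub clamps at 0 instead of panicking (debug) or wrapping (release).
        input_tokens: prompt_tokens.saturating_sub(cached_tokens),
        output_tokens: usage.candidates_token_count.unwrap_or(0),
        cache_read_input_tokens: cached_tokens,
    }
}

fn main() {
    let mut usage = UsageMetadata::default();
    update_usage(&mut usage, &UsageMetadata {
        prompt_token_count: Some(120),
        cached_content_token_count: Some(20),
        candidates_token_count: None,
    });
    update_usage(&mut usage, &UsageMetadata { candidates_token_count: Some(45), ..Default::default() });
    let totals = convert_usage(&usage);
    assert_eq!(totals, TokenUsage { input_tokens: 100, output_tokens: 45, cache_read_input_tokens: 20 });
    println!("{totals:?}");
}
```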
// Google Vertex AI has api_url, project_id and location_id
merge(
&mut settings.google_vertex.api_url,
value.google_vertex.as_ref().and_then(|s| s.api_url.clone()),
);
merge(
&mut settings.google_vertex.project_id,
value.google_vertex.as_ref().and_then(|s| s.project_id.clone()),
);
merge(
&mut settings.google_vertex.location_id,
value.google_vertex.as_ref().and_then(|s| s.location_id.clone()),
);
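The `merge` calls above only overwrite a default when the user's settings.json supplies a value. A std-only sketch of that layering, assuming `merge` behaves like the generic helper below; `GoogleVertexSettings`, `GoogleVertexSettingsContent` and `apply_user_settings` are hypothetical stand-ins, and the `api_url` default is a placeholder rather than the provider's real default:

```rust
/// Hypothetical resolved settings for the provider.
#[derive(Debug, Clone, PartialEq)]
struct GoogleVertexSettings {
    api_url: String,
    project_id: String,
    location_id: String,
}

/// Hypothetical partial settings as read from the user's settings.json.
#[derive(Debug, Default)]
struct GoogleVertexSettingsContent {
    api_url: Option<String>,
    project_id: Option<String>,
    location_id: Option<String>,
}

/// Replaces `target` only when a value was provided.
fn merge<T>(target: &mut T, value: Option<T>) {
    if let Some(value) = value {
        *target = value;
    }
}

/// Mirrors the three merge calls; `user` is None when the section is absent.
fn apply_user_settings(settings: &mut GoogleVertexSettings, user: Option<&GoogleVertexSettingsContent>) {
    merge(&mut settings.api_url, user.and_then(|s| s.api_url.clone()));
    merge(&mut settings.project_id, user.and_then(|s| s.project_id.clone()));
    merge(&mut settings.location_id, user.and_then(|s| s.location_id.clone()));
}

fn main() {
    let mut settings = GoogleVertexSettings {
        api_url: "https://example.invalid/vertex".to_string(), // placeholder default
        project_id: String::new(),
        location_id: "us-central1".to_string(),
    };
    let user = GoogleVertexSettingsContent {
        project_id: Some("my-gcp-project".to_string()),
        location_id: Some("europe-west4".to_string()),
        ..Default::default()
    };
    apply_user_settings(&mut settings, Some(&user));
    assert_eq!(settings.api_url, "https://example.invalid/vertex");
    assert_eq!(settings.location_id, "europe-west4");
    println!("{settings:?}");
}
```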
I have been testing it for the last couple of days, and inline edits, token counting and agent mode all seem to work well. The only thing I have not confirmed is whether ALL the additional models work, since I only needed the Google models for my purposes. Best wishes,
-
@rupesh1 also updated the repo again: joeyboey@2b5891a |
-
I also have access to some models (Anthropic's Claude, for example) via our corporate Google Vertex contract. Not being able to use them directly from Zed (if I set the environment variables, Claude Code or opencode will use them) is a major loss of function. Given how supportive Zed is of AI features, I'm surprised this hasn't been merged.
-
#40023 |
-
I'm on the same page: I would love to be able to use Kimi K2 Thinking and Claude through our corporate Google Vertex AI from Zed. Other agents already support this.
-
I ended up having Claude make the changes and compile it for me. 😅
-
I might create a fork and keep it updated in the meantime, until someone creates a polished version: https://github.com/NewtonChutney/zed (you'll have to compile it yourself).
-
Have you looked at contributing that, @NewtonChutney? I hope this gets supported directly by the project. Do we have a sense of whether this can or will be supported, and what the process would be? I assume this is a question for the Zed org/maintainers.
-
I'm not against contributing, but the code was entirely generated by Claude, and I think there are gaps to patch. For one, Claude didn't properly handle auth token renewal; it said, and I quote, "it's cheap to renew a token once expired!" 😂
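Token renewal is the gap mentioned here. A minimal sketch of one common approach, assuming the provider caches a single access token together with its lifetime: reuse it until shortly before expiry, then call a refresh function (for example, one that runs `gcloud auth application-default print-access-token`). `CachedToken`, `get_token` and the 60-second margin are hypothetical; the fork's actual renewal logic is not shown in this thread.

```rust
use std::time::{Duration, Instant};

/// Hypothetical cached access token and the instant it stops being valid.
struct CachedToken {
    value: String,
    expires_at: Instant,
}

/// Refresh this long before expiry so an in-flight request does not race the deadline.
const REFRESH_MARGIN: Duration = Duration::from_secs(60);

/// Returns the cached token while it is fresh; otherwise calls `refresh`,
/// which yields a new token and its lifetime, and caches the result.
fn get_token<F>(cache: &mut Option<CachedToken>, now: Instant, refresh: F) -> String
where
    F: FnOnce() -> (String, Duration),
{
    if let Some(token) = cache.as_ref() {
        if now + REFRESH_MARGIN < token.expires_at {
            return token.value.clone();
        }
    }
    let (value, lifetime) = refresh();
    *cache = Some(CachedToken { value: value.clone(), expires_at: now + lifetime });
    value
}

fn main() {
    let mut cache = None;
    let start = Instant::now();
    let mut refreshes = 0;
    let hour = Duration::from_secs(3600);

    let a = get_token(&mut cache, start, || { refreshes += 1; ("token-a".to_string(), hour) });
    // Ten minutes later the cached token is reused.
    let b = get_token(&mut cache, start + Duration::from_secs(600), || { refreshes += 1; ("token-b".to_string(), hour) });
    // Inside the margin before expiry, a new token is fetched.
    let c = get_token(&mut cache, start + Duration::from_secs(3590), || { refreshes += 1; ("token-c".to_string(), hour) });

    assert_eq!((a.as_str(), b.as_str(), c.as_str()), ("token-a", "token-a", "token-c"));
    assert_eq!(refreshes, 2);
    println!("refreshed {refreshes} times");
}
```

Passing `now` in explicitly keeps the expiry logic testable without sleeping.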
-
FYI, token renewal is working now. Added docs:
{
  "language_models": {
    "vertex_ai": {
      "project_id": "your-google-cloud-project-id",
      "location_id": "us-central1"
    }
  }
}
Required Settings:
Optional Settings:
-
Would anybody be kind enough to get this merged?
-
I would really love to use Zed with the Claude models offered on Vertex AI.
-
In the meantime, that can work, whether solo or remote/team. Collaboration on that project is welcome! Tell me what doesn't work for you and I'll improve it (PRs welcome too). It's a workaround for using just about any IDE with Vertex, etc.
-
As would I. |
-
This is a feature request rather than a bug, so we're moving it to Discussions where provider requests and roadmap/product prioritization are easier to track. |
-
Description
Platforms like Cursor.ai and Continue.dev allow the use of Vertex AI as a provider, which offers models like Gemini Flash and Gemini Pro. A key benefit of Vertex AI is region selection, enabling users to pick servers closer to them for lower latency.
Additionally, using Vertex AI provides other advantages, such as unified billing with GCP. Many companies already use Google Cloud, so activating Vertex AI is as simple as enabling a service they’re already paying for.
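Region selection ultimately comes down to the location used in the request URL. A sketch based on the publisher-model URL pattern in Google's Vertex AI REST documentation, assuming Gemini models under the `google` publisher: regional locations use a region-prefixed host, while the `global` location uses the bare `aiplatform.googleapis.com` host. `vertex_model_url` is a hypothetical helper and the model ID is illustrative. The method suffix would be `generateContent`, `streamGenerateContent` (typically with `?alt=sse` for server-sent events) or `countTokens`.

```rust
/// Builds a Vertex AI publisher-model REST URL for the given project and region.
fn vertex_model_url(project_id: &str, location_id: &str, model_id: &str, method: &str) -> String {
    // Regional locations get a region-prefixed host; "global" uses the bare host.
    let host = if location_id == "global" {
        "aiplatform.googleapis.com".to_string()
    } else {
        format!("{location_id}-aiplatform.googleapis.com")
    };
    format!(
        "https://{host}/v1/projects/{project_id}/locations/{location_id}/publishers/google/models/{model_id}:{method}"
    )
}

fn main() {
    // Picking a nearby region changes both the host and the location segment.
    let url = vertex_model_url("my-gcp-project", "europe-west4", "gemini-2.0-flash", "generateContent");
    println!("{url}");
}
```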