Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions internal/channel/anthropic_channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,16 @@ func (ch *AnthropicChannel) ExtractModel(c *gin.Context, bodyBytes []byte) strin
}

// ValidateKey checks if the given API key is valid by making a messages request.
func (ch *AnthropicChannel) ValidateKey(ctx context.Context, apiKey *models.APIKey, group *models.Group) (bool, error) {
func (ch *AnthropicChannel) ValidateKey(ctx context.Context, apiKey *models.APIKey, group *models.Group) (bool, int, error) {
upstreamURL := ch.getUpstreamURL()
if upstreamURL == nil {
return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name)
return false, 0, fmt.Errorf("no upstream URL configured for channel %s", ch.Name)
}

// Parse validation endpoint to extract path and query parameters
endpointURL, err := url.Parse(ch.ValidationEndpoint)
if err != nil {
return false, fmt.Errorf("failed to parse validation endpoint: %w", err)
return false, 0, fmt.Errorf("failed to join upstream URL and validation endpoint: %w", err)
}

// Build final URL with path and query parameters
Expand All @@ -102,12 +102,12 @@ func (ch *AnthropicChannel) ValidateKey(ctx context.Context, apiKey *models.APIK
}
body, err := json.Marshal(payload)
if err != nil {
return false, fmt.Errorf("failed to marshal validation payload: %w", err)
return false, 0, fmt.Errorf("failed to marshal validation payload: %w", err)
}

req, err := http.NewRequestWithContext(ctx, "POST", reqURL, bytes.NewBuffer(body))
if err != nil {
return false, fmt.Errorf("failed to create validation request: %w", err)
return false, 0, fmt.Errorf("failed to create validation request: %w", err)
}
req.Header.Set("x-api-key", apiKey.KeyValue)
req.Header.Set("anthropic-version", "2023-06-01")
Expand All @@ -121,23 +121,25 @@ func (ch *AnthropicChannel) ValidateKey(ctx context.Context, apiKey *models.APIK

resp, err := ch.HTTPClient.Do(req)
if err != nil {
return false, fmt.Errorf("failed to send validation request: %w", err)
return false, 0, fmt.Errorf("failed to send validation request: %w", err)
}
defer resp.Body.Close()

statusCode := resp.StatusCode

// Any 2xx status code indicates the key is valid.
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
return true, nil
return true, statusCode, nil
}

// For non-200 responses, parse the body to provide a more specific error reason.
errorBody, err := io.ReadAll(resp.Body)
if err != nil {
return false, fmt.Errorf("key is invalid (status %d), but failed to read error body: %w", resp.StatusCode, err)
return false, statusCode, fmt.Errorf("key is invalid (status %d), but failed to read error body: %w", resp.StatusCode, err)
}

// Use the new parser to extract a clean error message.
parsedError := app_errors.ParseUpstreamError(errorBody)

return false, fmt.Errorf("[status %d] %s", resp.StatusCode, parsedError)
return false, statusCode, fmt.Errorf("[status %d] %s", resp.StatusCode, parsedError)
}
2 changes: 1 addition & 1 deletion internal/channel/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,5 @@ type ChannelProxy interface {
ExtractModel(c *gin.Context, bodyBytes []byte) string

// ValidateKey checks if the given API key is valid.
ValidateKey(ctx context.Context, apiKey *models.APIKey, group *models.Group) (bool, error)
ValidateKey(ctx context.Context, apiKey *models.APIKey, group *models.Group) (bool, int, error)
}
20 changes: 11 additions & 9 deletions internal/channel/gemini_channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,16 @@ func (ch *GeminiChannel) ExtractModel(c *gin.Context, bodyBytes []byte) string {
}

// ValidateKey checks if the given API key is valid by making a generateContent request.
func (ch *GeminiChannel) ValidateKey(ctx context.Context, apiKey *models.APIKey, group *models.Group) (bool, error) {
func (ch *GeminiChannel) ValidateKey(ctx context.Context, apiKey *models.APIKey, group *models.Group) (bool, int, error) {
upstreamURL := ch.getUpstreamURL()
if upstreamURL == nil {
return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name)
return false, 0, fmt.Errorf("no upstream URL configured for channel %s", ch.Name)
}

// Safely join the path segments
reqURL, err := url.JoinPath(upstreamURL.String(), "v1beta", "models", ch.TestModel+":generateContent")
if err != nil {
return false, fmt.Errorf("failed to create gemini validation path: %w", err)
return false, 0, fmt.Errorf("failed to create gemini validation path: %w", err)
}
reqURL += "?key=" + apiKey.KeyValue

Expand All @@ -121,12 +121,12 @@ func (ch *GeminiChannel) ValidateKey(ctx context.Context, apiKey *models.APIKey,
}
body, err := json.Marshal(payload)
if err != nil {
return false, fmt.Errorf("failed to marshal validation payload: %w", err)
return false, 0, fmt.Errorf("failed to marshal validation payload: %w", err)
}

req, err := http.NewRequestWithContext(ctx, "POST", reqURL, bytes.NewBuffer(body))
if err != nil {
return false, fmt.Errorf("failed to create validation request: %w", err)
return false, 0, fmt.Errorf("failed to create validation request: %w", err)
}
req.Header.Set("Content-Type", "application/json")

Expand All @@ -138,23 +138,25 @@ func (ch *GeminiChannel) ValidateKey(ctx context.Context, apiKey *models.APIKey,

resp, err := ch.HTTPClient.Do(req)
if err != nil {
return false, fmt.Errorf("failed to send validation request: %w", err)
return false, 0, fmt.Errorf("failed to send validation request: %w", err)
}
defer resp.Body.Close()

statusCode := resp.StatusCode

// Any 2xx status code indicates the key is valid.
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
return true, nil
return true, statusCode, nil
}

// For non-200 responses, parse the body to provide a more specific error reason.
errorBody, err := io.ReadAll(resp.Body)
if err != nil {
return false, fmt.Errorf("key is invalid (status %d), but failed to read error body: %w", resp.StatusCode, err)
return false, statusCode, fmt.Errorf("key is invalid (status %d), but failed to read error body: %w", resp.StatusCode, err)
}

// Use the new parser to extract a clean error message.
parsedError := app_errors.ParseUpstreamError(errorBody)

return false, fmt.Errorf("[status %d] %s", resp.StatusCode, parsedError)
return false, statusCode, fmt.Errorf("[status %d] %s", resp.StatusCode, parsedError)
}
20 changes: 11 additions & 9 deletions internal/channel/openai_channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,16 @@ func (ch *OpenAIChannel) ExtractModel(c *gin.Context, bodyBytes []byte) string {
}

// ValidateKey checks if the given API key is valid by making a chat completion request.
func (ch *OpenAIChannel) ValidateKey(ctx context.Context, apiKey *models.APIKey, group *models.Group) (bool, error) {
func (ch *OpenAIChannel) ValidateKey(ctx context.Context, apiKey *models.APIKey, group *models.Group) (bool, int, error) {
upstreamURL := ch.getUpstreamURL()
if upstreamURL == nil {
return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name)
return false, 0, fmt.Errorf("no upstream URL configured for channel %s", ch.Name)
}

// Parse validation endpoint to extract path and query parameters
endpointURL, err := url.Parse(ch.ValidationEndpoint)
if err != nil {
return false, fmt.Errorf("failed to parse validation endpoint: %w", err)
return false, 0, fmt.Errorf("failed to join upstream URL and validation endpoint: %w", err)
}

// Build final URL with path and query parameters
Expand All @@ -100,12 +100,12 @@ func (ch *OpenAIChannel) ValidateKey(ctx context.Context, apiKey *models.APIKey,
}
body, err := json.Marshal(payload)
if err != nil {
return false, fmt.Errorf("failed to marshal validation payload: %w", err)
return false, 0, fmt.Errorf("failed to marshal validation payload: %w", err)
}

req, err := http.NewRequestWithContext(ctx, "POST", reqURL, bytes.NewBuffer(body))
if err != nil {
return false, fmt.Errorf("failed to create validation request: %w", err)
return false, 0, fmt.Errorf("failed to create validation request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+apiKey.KeyValue)
req.Header.Set("Content-Type", "application/json")
Expand All @@ -118,23 +118,25 @@ func (ch *OpenAIChannel) ValidateKey(ctx context.Context, apiKey *models.APIKey,

resp, err := ch.HTTPClient.Do(req)
if err != nil {
return false, fmt.Errorf("failed to send validation request: %w", err)
return false, 0, fmt.Errorf("failed to send validation request: %w", err)
}
defer resp.Body.Close()

statusCode := resp.StatusCode

// Any 2xx status code indicates the key is valid.
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
return true, nil
return true, statusCode, nil
}

// For non-200 responses, parse the body to provide a more specific error reason.
errorBody, err := io.ReadAll(resp.Body)
if err != nil {
return false, fmt.Errorf("key is invalid (status %d), but failed to read error body: %w", resp.StatusCode, err)
return false, statusCode, fmt.Errorf("key is invalid (status %d), but failed to read error body: %w", resp.StatusCode, err)
}

// Use the new parser to extract a clean error message.
parsedError := app_errors.ParseUpstreamError(errorBody)

return false, fmt.Errorf("[status %d] %s", resp.StatusCode, parsedError)
return false, statusCode, fmt.Errorf("[status %d] %s", resp.StatusCode, parsedError)
}
7 changes: 6 additions & 1 deletion internal/db/migrations/migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@ func MigrateDatabase(db *gorm.DB) error {
}

// Run v1.1.0 migration
return V1_1_0_AddKeyHashColumn(db)
if err := V1_1_0_AddKeyHashColumn(db); err != nil {
return err
}

// Run v1.2.0 migration
return V1_2_0_AddStatusCodeColumn(db)
}

// HandleLegacyIndexes removes old indexes from previous versions to prevent migration errors
Expand Down
23 changes: 23 additions & 0 deletions internal/db/migrations/v1_2_0_AddStatusCodeColumn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package db

import (
"fmt"
"gpt-load/internal/models"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
)

// V1_2_0_AddStatusCodeColumn adds status_code column to api_keys table
func V1_2_0_AddStatusCodeColumn(db *gorm.DB) error {
// Prefer GORM migrator for portability
if db.Migrator().HasColumn(&models.APIKey{}, "status_code") {
logrus.Info("status_code column already exists, skipping migration")
return nil
}
// Will honor gorm:"default:0" tag on models.APIKey.StatusCode
if err := db.Migrator().AddColumn(&models.APIKey{}, "StatusCode"); err != nil {
return fmt.Errorf("failed to add status_code column: %w", err)
}
logrus.Info("Successfully added status_code column to api_keys table")
return nil
}
36 changes: 35 additions & 1 deletion internal/handler/key_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ func (s *Server) ListKeysInGroup(c *gin.Context) {
searchHash = s.EncryptionSvc.Hash(searchKeyword)
}

query := s.KeyService.ListKeysInGroupQuery(groupID, statusFilter, searchHash)
statusCodeFilter := c.Query("status_code")
query := s.KeyService.ListKeysInGroupQuery(groupID, statusFilter, searchHash, statusCodeFilter)

var keys []models.APIKey
paginatedResult, err := response.Paginate(c, query, &keys)
Expand Down Expand Up @@ -408,6 +409,39 @@ func (s *Server) ClearAllKeys(c *gin.Context) {
response.SuccessI18n(c, "success.all_keys_cleared", nil, map[string]any{"count": rowsAffected})
}

// ClearCurrentQueryKeys deletes keys based on current query conditions.
func (s *Server) ClearCurrentQueryKeys(c *gin.Context) {
var req struct {
GroupID uint `json:"group_id" binding:"required"`
KeyValue *string `json:"key_value,omitempty"`
StatusCode *int `json:"status_code,omitempty"`
Status *string `json:"status,omitempty"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error()))
return
}

if _, ok := s.findGroupByID(c, req.GroupID); !ok {
return
}

// 构建查询条件
var keyHash *string
if req.KeyValue != nil && *req.KeyValue != "" {
hash := s.EncryptionSvc.Hash(*req.KeyValue)
keyHash = &hash
}

rowsAffected, err := s.KeyService.ClearCurrentQueryKeys(req.GroupID, keyHash, req.StatusCode, req.Status)
if err != nil {
response.Error(c, app_errors.ParseDBError(err))
return
}

response.SuccessI18n(c, "success.current_query_keys_cleared", nil, map[string]any{"count": rowsAffected})
}
Comment thread
AAEE86 marked this conversation as resolved.

// ExportKeys handles exporting keys to a text file.
func (s *Server) ExportKeys(c *gin.Context) {
groupID, ok := validateGroupIDFromQuery(c)
Expand Down
2 changes: 1 addition & 1 deletion internal/keypool/cron_checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func (s *CronChecker) validateGroupKeys(group *models.Group) {
keyForValidation := *key
keyForValidation.KeyValue = decryptedKey

isValid, _ := s.Validator.ValidateSingleKey(&keyForValidation, group)
isValid, _, _ := s.Validator.ValidateSingleKey(&keyForValidation, group)
if isValid {
atomic.AddInt32(&becameValidCount, 1)
}
Expand Down
10 changes: 10 additions & 0 deletions internal/keypool/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,16 @@ func (p *KeyProvider) UpdateStatus(apiKey *models.APIKey, group *models.Group, i
}()
}

// UpdateStatusCode 更新密钥的状态码
func (p *KeyProvider) UpdateStatusCode(apiKey *models.APIKey, statusCode int) error {
// keep in-memory model consistent to avoid accidental overwrite
apiKey.StatusCode = statusCode
if err := p.db.Model(apiKey).Update("status_code", statusCode).Error; err != nil {
return err
}
return nil
}

// executeTransactionWithRetry wraps a database transaction with a retry mechanism.
func (p *KeyProvider) executeTransactionWithRetry(operation func(tx *gorm.DB) error) error {
const maxRetries = 3
Expand Down
Loading