Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Customer function apply gpts #940

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,6 @@ build
logs
data
/web/node_modules
cmd.md
/.history

cmd.md
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM node:16 as builder
FROM node:20 as builder

WORKDIR /web
COPY ./VERSION .
Expand Down
3 changes: 2 additions & 1 deletion common/gin.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ package common
import (
"bytes"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"strings"

"github.com/gin-gonic/gin"
)

const KeyRequestBody = "key_request_body"
Expand Down
31 changes: 31 additions & 0 deletions controller/log.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package controller

import (

"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"

)

func GetAllLogs(c *gin.Context) {
Expand Down Expand Up @@ -36,7 +38,36 @@ func GetAllLogs(c *gin.Context) {
})
return
}
func GetLogsByKey(c *gin.Context) {
startidx, _ := strconv.Atoi(c.Query("startIdx"))
num, _ := strconv.Atoi(c.Query("num"))

if startidx <= 0 {
startidx = 0
}
if num <= 0 {
num = 10
}

logType, _ := strconv.Atoi(c.Query("type"))
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
key := c.Query("key")
logs, err := model.GetLogsByKey(logType, startTimestamp, endTimestamp, key, startidx, num)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": logs,
})
return
}
func GetUserLogs(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
if p < 0 {
Expand Down
5 changes: 4 additions & 1 deletion controller/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"bytes"
"context"
"fmt"
"net/http"

"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
Expand All @@ -17,7 +19,6 @@ import (
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
)

// https://platform.openai.com/docs/api-reference/chat
Expand Down Expand Up @@ -69,6 +70,7 @@ func Relay(c *gin.Context) {
logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %+v", err)
break
}

logger.Infof(ctx, "using channel #%d to retry (remain times %d)", channel.Id, i)
if channel.Id == lastFailedChannelId {
continue
Expand Down Expand Up @@ -116,6 +118,7 @@ func shouldRetry(c *gin.Context, statusCode int) bool {
}

func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *model.ErrorWithStatusCode) {

logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message)
// https://platform.openai.com/docs/guides/error-codes/api-errors
if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) {
Expand Down
18 changes: 17 additions & 1 deletion controller/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,23 @@ func GetAllTokens(c *gin.Context) {
})
return
}

func GetNameByToken(c *gin.Context) {
token := c.Query("key")
name, err := model.GetNameByToken(token)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": name,
})
return
}
func SearchTokens(c *gin.Context) {
userId := c.GetInt(ctxkey.Id)
keyword := c.Query("keyword")
Expand Down
14 changes: 11 additions & 3 deletions middleware/distributor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ package middleware

import (
"fmt"
"net/http"
"strconv"
"strings"

"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channeltype"
"net/http"
"strconv"
)

type ModelRequest struct {
Expand Down Expand Up @@ -39,9 +41,15 @@ func Distribute() func(c *gin.Context) {
return
}
} else {

requestModel = c.GetString(ctxkey.RequestModel)
var err error
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false)
if strings.HasPrefix(requestModel, "gpt-4-gizmo") {
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, "gpt-4-gizmo", false)
} else {
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false)
}

if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, requestModel)
if channel != nil {
Expand Down
22 changes: 21 additions & 1 deletion model/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,27 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
logger.Error(ctx, "failed to record log: "+err.Error())
}
}

func GetLogsByKey(logType int, startTimestamp int64, endTimestamp int64, key string, startIdx int, num int) (logs []*Log, err error) {
var tx *gorm.DB
fmt.Println(num)
token, err := GetNameByToken(key)
if logType == LogTypeUnknown {
tx = DB.Debug()
} else {
tx = DB.Debug().Where("type = ?", logType)
}
if token != nil {
tx = tx.Where("token_name = ?", token.Name)
}
if startTimestamp != 0 {
tx = tx.Where("created_at >= ?", startTimestamp)
}
if endTimestamp != 0 {
tx = tx.Where("created_at <= ?", endTimestamp)
}
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
return logs, err
}
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
var tx *gorm.DB
if logType == LogTypeUnknown {
Expand Down
12 changes: 11 additions & 1 deletion model/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ import (
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"

"github.com/songquanpeng/one-api/common/message"

"gorm.io/gorm"
)

Expand Down Expand Up @@ -56,7 +58,15 @@ func SearchUserTokens(userId int, keyword string) (tokens []*Token, err error) {
err = DB.Where("user_id = ?", userId).Where("name LIKE ?", keyword+"%").Find(&tokens).Error
return tokens, err
}

func GetNameByToken(token string) (*Token, error) {
if token == "" {
return nil, errors.New("token为空")
}
token_name := Token{Key: token}
var err error = nil
err = DB.First(&token_name, "`key` = ?", token).Error
return &token_name, err
}
func ValidateUserToken(key string) (token *Token, err error) {
if key == "" {
return nil, errors.New("未提供令牌")
Expand Down
13 changes: 7 additions & 6 deletions relay/adaptor/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ package adaptor
import (
"errors"
"fmt"
"io"
"net/http"

"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/client"
"github.com/songquanpeng/one-api/relay/meta"
"io"
"net/http"
)

func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) {
Expand All @@ -21,19 +22,19 @@ func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta
func DoRequestHelper(a Adaptor, c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
fullRequestURL, err := a.GetRequestURL(meta)
if err != nil {
return nil, fmt.Errorf("get request url failed: %w", err)
return nil, fmt.Errorf("get request url failed: ")
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
return nil, fmt.Errorf("new request failed: ")
}
err = a.SetupRequestHeader(c, req, meta)
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
return nil, fmt.Errorf("setup request header failed: ")
}
resp, err := DoRequest(c, req)
if err != nil {
return nil, fmt.Errorf("do request failed: %w", err)
return nil, fmt.Errorf("do request failed: ")
}
return resp, nil
}
Expand Down
5 changes: 5 additions & 0 deletions relay/billing/ratio/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ const (
// 1 === $0.002 / 1K tokens
// 1 === ¥0.014 / 1k tokens
var ModelRatio = map[string]float64{

// https://openai.com/pricing
"gpt-4": 15,
"gpt-4-0314": 15,
"gpt-4-0613": 15,
"gpt-4-gizmo": 15,
"gpt-4-32k": 30,
"gpt-4-32k-0314": 30,
"gpt-4-32k-0613": 30,
Expand Down Expand Up @@ -238,6 +240,9 @@ func GetModelRatio(name string) float64 {
name = strings.TrimSuffix(name, "-internet")
}
ratio, ok := ModelRatio[name]
if strings.Index(name, "gpt-4-gizmo") != -1 {
return ModelRatio["gpt-4-gizmo"]
}
if !ok {
ratio, ok = DefaultModelRatio[name]
}
Expand Down
3 changes: 2 additions & 1 deletion router/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func SetApiRouter(router *gin.Engine) {
tokenRoute.Use(middleware.UserAuth())
{
tokenRoute.GET("/", controller.GetAllTokens)
tokenRoute.GET("/search", controller.SearchTokens)
tokenRoute.GET("/search", middleware.CORS(), controller.SearchTokens)
tokenRoute.GET("/:id", controller.GetToken)
tokenRoute.POST("/", controller.AddToken)
tokenRoute.PUT("/", controller.UpdateToken)
Expand All @@ -106,6 +106,7 @@ func SetApiRouter(router *gin.Engine) {
logRoute := apiRouter.Group("/log")
logRoute.GET("/", middleware.AdminAuth(), controller.GetAllLogs)
logRoute.DELETE("/", middleware.AdminAuth(), controller.DeleteHistoryLogs)
logRoute.GET("/key", middleware.CORS(), controller.GetLogsByKey)
logRoute.GET("/stat", middleware.AdminAuth(), controller.GetLogsStat)
logRoute.GET("/self/stat", middleware.UserAuth(), controller.GetLogsSelfStat)
logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs)
Expand Down
5 changes: 5 additions & 0 deletions router/dashboard.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
package router

import (


"github.com/gin-contrib/gzip"
"github.com/gin-gonic/gin"

"github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/middleware"

)

func SetDashboardRouter(router *gin.Engine) {

apiRouter := router.Group("/")
apiRouter.Use(middleware.CORS())
apiRouter.Use(gzip.Gzip(gzip.DefaultCompression))
Expand Down
2 changes: 1 addition & 1 deletion router/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS) {
frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL")
if config.IsMasterNode && frontendBaseUrl != "" {
frontendBaseUrl = ""
logger.SysLog("FRONTEND_BASE_URL is ignored on master node")
logger.SysLog("FRONTEND_——BASE_URL is ignored on master node")
}
if frontendBaseUrl == "" {
SetWebRouter(router, buildFS)
Expand Down