Skip to content

Commit

Permalink
Doの実装
Browse files Browse the repository at this point in the history
  • Loading branch information
ikura-hamu committed Jul 13, 2023
1 parent bfeb3a4 commit 098d7be
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 37 deletions.
18 changes: 13 additions & 5 deletions service/search/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,24 @@ type Sort struct {
}

// GetSortKey ソートに使うキーの情報を抽出します
func (q Query) GetSortKey() Sort {
func (q Query) GetSortKey() string {
if !q.Sort.Valid {
return Sort{Key: createdAtSortKey, Desc: true}
return createdAtSortKey + ":" + descSortKey
}
match := allowedSortKeysRegExp.FindStringSubmatch(q.Sort.ValueOrZero())
if match[2] == "" {
return Sort{Key: createdAtSortKey, Desc: true}
return createdAtSortKey + ":" + descSortKey
}
if match[1] == "-" {
return Sort{Key: match[2], Desc: !shouldUseDescendingAsDefault(match[2])}
if shouldUseDescendingAsDefault(match[2]) {
return match[2] + ":" + ascSortKey
}
return match[2] + ":" + descSortKey
}
return Sort{Key: match[2], Desc: shouldUseDescendingAsDefault(match[2])}
if shouldUseDescendingAsDefault(match[2]) {
return match[2] + ":" + descSortKey
}
return match[2] + ":" + ascSortKey
}

// Result 検索結果インターフェイス
Expand All @@ -87,6 +93,8 @@ type Result interface {

const createdAtSortKey = "createdAt" // 作成日時の新しい順
const updatedAtSortKey = "updatedAt" // 更新日時の新しい順
const ascSortKey = "asc" //昇順
const descSortKey = "desc" //降順

var allowedSortKeys = []string{createdAtSortKey, updatedAtSortKey}
var allowedSortKeysRegExp = regexp.MustCompile("([+-]?)(" + strings.Join(allowedSortKeys, "|") + ")")
Expand Down
125 changes: 93 additions & 32 deletions service/search/es.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (

"github.com/elastic/go-elasticsearch/v7"
"github.com/gofrs/uuid"
"github.com/olivere/elastic/v7"
"go.uber.org/zap"

"github.com/traPtitech/traQ/repository"
Expand Down Expand Up @@ -247,70 +246,112 @@ func NewESEngine(mm message.Manager, cm channel.Manager, repo repository.Reposit
return engine, nil
}

type searchQuery m

type searchBody struct {
Query *struct {
Bool *struct {
Musts []searchQuery `json:"musts,omitempty"`
} `json:"bool,omitempty"`
} `json:"query,omitempty"`
}

func NewSearchBody(sq []searchQuery) searchBody {
sb := searchBody{}
sb.Query.Bool.Musts = sq
return sb
}

type simpleQueryString struct {
Query string `json:"query"`
Fields []string `json:"fields"`
DefaultOperator string `json:"default_operator"`
}

type rangeQuery map[string]rangeParameters

type rangeParameters struct {
Lt string `json:"lt,omitempty"`
Gt string `json:"gt,omitempty"`
}

type termQuery map[string]termQueryParameter

type termQueryParameter struct {
Value any `json:"value,omitempty"`
}

func (e *esEngine) Do(q *Query) (Result, error) {
e.l.Debug("do search", zap.Reflect("q", q))

var musts []elastic.Query
var musts []searchQuery

if q.Word.Valid {
musts = append(musts, elastic.NewSimpleQueryStringQuery(q.Word.V).
Field("text").
DefaultOperator("AND"))
body := simpleQueryString{
Query: q.Word.V,
Fields: []string{"text"},
DefaultOperator: "AND",
}

musts = append(musts, searchQuery{"simple_query_string": body})
}

switch {
case q.After.Valid && q.Before.Valid:
musts = append(musts, elastic.NewRangeQuery("createdAt").
Gt(q.After.ValueOrZero().Format(esDateFormat)).
Lt(q.Before.ValueOrZero().Format(esDateFormat)))
musts = append(musts, searchQuery{"range": rangeQuery{"createdAt": rangeParameters{
Gt: q.After.ValueOrZero().Format(esDateFormat),
Lt: q.Before.ValueOrZero().Format(esDateFormat),
}}})
case q.After.Valid && !q.Before.Valid:
musts = append(musts, elastic.NewRangeQuery("createdAt").
Gt(q.After.ValueOrZero().Format(esDateFormat)))
musts = append(musts, searchQuery{"range": rangeQuery{"createdAt": rangeParameters{
Gt: q.After.ValueOrZero().Format(esDateFormat),
}}})
case !q.After.Valid && q.Before.Valid:
musts = append(musts, elastic.NewRangeQuery("createdAt").
Lt(q.Before.ValueOrZero().Format(esDateFormat)))
musts = append(musts, searchQuery{"rage": rangeQuery{"createdAt": rangeParameters{
Lt: q.Before.ValueOrZero().Format(esDateFormat),
}}})
}

// チャンネル指定があるときはそのチャンネルを検索
// そうでないときはPublicチャンネルを検索
if q.In.Valid {
musts = append(musts, elastic.NewTermQuery("channelId", q.In))
musts = append(musts, searchQuery{"term": termQuery{"channelId": termQueryParameter{Value: q.In}}})
} else {
musts = append(musts, elastic.NewTermQuery("isPublic", true))
musts = append(musts, searchQuery{"term": termQuery{"isPublic": termQueryParameter{Value: true}}})
}

if q.To.Valid {
musts = append(musts, elastic.NewTermQuery("to", q.To))
musts = append(musts, searchQuery{"term": termQuery{"to": termQueryParameter{Value: q.To}}})
}

if q.From.Valid {
musts = append(musts, elastic.NewTermQuery("userId", q.From))
musts = append(musts, searchQuery{"term": termQuery{"userId": termQueryParameter{Value: q.From}}})
}

if q.Citation.Valid {
musts = append(musts, elastic.NewTermQuery("citation", q.Citation))
musts = append(musts, searchQuery{"term": termQuery{"citation": termQueryParameter{Value: q.Citation}}})
}

if q.Bot.Valid {
musts = append(musts, elastic.NewTermQuery("bot", q.Bot))
musts = append(musts, searchQuery{"term": termQuery{"bot": termQueryParameter{Value: q.Bot}}})
}

if q.HasURL.Valid {
musts = append(musts, elastic.NewTermQuery("hasURL", q.HasURL))
musts = append(musts, searchQuery{"term": termQuery{"hasURL": termQueryParameter{Value: q.HasURL}}})
}

if q.HasAttachments.Valid {
musts = append(musts, elastic.NewTermQuery("hasAttachments", q.HasAttachments))
musts = append(musts, searchQuery{"term": termQuery{"hasAttachments": termQueryParameter{Value: q.HasAttachments}}})
}

if q.HasImage.Valid {
musts = append(musts, elastic.NewTermQuery("hasImage", q.HasImage))
musts = append(musts, searchQuery{"term": termQuery{"hasImage": termQueryParameter{Value: q.HasImage}}})
}
if q.HasVideo.Valid {
musts = append(musts, elastic.NewTermQuery("hasVideo", q.HasVideo))
musts = append(musts, searchQuery{"term": termQuery{"hasVideo": termQueryParameter{Value: q.HasVideo}}})
}
if q.HasAudio.Valid {
musts = append(musts, elastic.NewTermQuery("hasAudio", q.HasAudio))
musts = append(musts, searchQuery{"term": termQuery{"hasAudio": termQueryParameter{Value: q.HasAudio}}})
}

limit, offset := 20, 0
Expand All @@ -324,21 +365,41 @@ func (e *esEngine) Do(q *Query) (Result, error) {
// NOTE: 現状`sort.Key`はそのままesのソートキーとして使える前提
sort := q.GetSortKey()

b, err := json.Marshal(NewSearchBody(musts))
if err != nil {
return nil, fmt.Errorf("failed to marshal search query: %w", err)
}

sr, err := e.client.Search(
e.client.Search.WithIndex(getIndexName(esMessageIndex)),
e.client.Search.WithBody(bytes.NewBuffer(b)),
e.client.Search.WithSort(sort),
e.client.Search.WithSize(limit),
e.client.Search.WithFrom(offset),
e.client.Search.WithContext(context.Background()),
)

sr, err := e.client.Search().
Index(getIndexName(esMessageIndex)).
Query(elastic.NewBoolQuery().Must(musts...)).
Sort(sort.Key, !sort.Desc).
Size(limit).
From(offset).
Do(context.Background())
if err != nil {
return nil, err
}
if sr.IsError() {
return nil, fmt.Errorf("failed to get search result")
}
var searchResultBody []byte
_, err = sr.Body.Read(searchResultBody)
if err != nil {
return nil, fmt.Errorf("failed to get search result body")
}
defer sr.Body.Close()

var res m
err = json.Unmarshal(searchResultBody, &res)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal response body")
}

e.l.Debug("search result", zap.Reflect("hits", sr.Hits))
return e.bindESResult(sr)
e.l.Debug("search result", zap.Reflect("hits", res["hits"]))
return e.parseResBody(res)
}

func (e *esEngine) Available() bool {
Expand Down
33 changes: 33 additions & 0 deletions service/search/es_result.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,39 @@ type esResult struct {
messages []message.Message
}

func (e *esEngine) parseResBody(resBody m) (Result, error) {
totalHits := resBody["hits"].(m)["total"].(m)["value"].(int64)
hits := resBody["hits"].(m)["hits"].([]map[string]any)

r := &esResult{
totalHits: totalHits,
messages: make([]message.Message, 0, len(hits)),
}

messageIDs := utils.Map(hits, func(hit map[string]any) uuid.UUID {
return uuid.Must(uuid.FromString(hit["_id"].(string)))
})

messages, err := e.mm.GetIn(messageIDs)
if err != nil {
return nil, err
}

messagesMap := lo.SliceToMap(messages, func(m message.Message) (uuid.UUID, message.Message) {
return m.GetID(), m
})
// sort result
for _, id := range messageIDs {
msg, ok := messagesMap[id]
if !ok {
return nil, fmt.Errorf("message %v not found", id)
}
r.messages = append(r.messages, msg)
}

return r, nil
}

func (e *esEngine) bindESResult(sr *elastic.SearchResult) (Result, error) {
r := &esResult{
totalHits: sr.TotalHits(),
Expand Down

0 comments on commit 098d7be

Please sign in to comment.