Skip to content

Commit

Permalink
Wrap gocql.Iter with custom wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
alexshtin committed Nov 10, 2022
1 parent bbd386e commit 64536e8
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 40 deletions.
15 changes: 4 additions & 11 deletions common/persistence/cassandra/schema_version_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
package cassandra

import (
"errors"
"fmt"

"go.temporal.io/server/common/persistence/nosql/nosqlplugin/cassandra/gocql"
)
Expand All @@ -40,10 +40,6 @@ type (
}
)

var (
ErrGetSchemaVersion = errors.New("failed to get current schema version from cassandra")
)

func NewSchemaVersionReader(session gocql.Session) *SchemaVersionReader {
return &SchemaVersionReader{
session: session,
Expand All @@ -56,12 +52,9 @@ func (svr *SchemaVersionReader) ReadSchemaVersion(keyspace string) (string, erro

iter := query.Iter()
var version string
if !iter.Scan(&version) {
_ = iter.Close()
return "", ErrGetSchemaVersion
}
if err := iter.Close(); err != nil {
return "", err
success := iter.Scan(&version)
if err := iter.Close(); err != nil || !success {
return "", fmt.Errorf("unable to get current schema version from Cassandra: %w", err)
}
return version, nil
}
59 changes: 59 additions & 0 deletions common/persistence/nosql/nosqlplugin/cassandra/gocql/iter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// The MIT License
//
// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved.
//
// Copyright (c) 2020 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package gocql

import (
"github.com/gocql/gocql"
)

type iter struct {
session *session
gocqlIter *gocql.Iter
}

func newIter(session *session, gocqlIter *gocql.Iter) Iter {
return &iter{
session: session,
gocqlIter: gocqlIter,
}
}

func (it *iter) Scan(dest ...interface{}) bool {
return it.gocqlIter.Scan(dest...)
}

func (it *iter) MapScan(m map[string]interface{}) bool {
return it.gocqlIter.MapScan(m)
}

func (it *iter) PageState() []byte {
return it.gocqlIter.PageState()
}

func (it *iter) Close() (retError error) {
defer func() { it.session.handleError(retError) }()

return it.gocqlIter.Close()
}
23 changes: 6 additions & 17 deletions common/persistence/nosql/nosqlplugin/cassandra/gocql/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,46 +50,46 @@ func newQuery(
}

func (q *query) Exec() (retError error) {
defer func() { q.handleError(retError) }()
defer func() { q.session.handleError(retError) }()

return q.gocqlQuery.Exec()
}

func (q *query) Scan(
dest ...interface{},
) (retError error) {
defer func() { q.handleError(retError) }()
defer func() { q.session.handleError(retError) }()

return q.gocqlQuery.Scan(dest...)
}

func (q *query) ScanCAS(
dest ...interface{},
) (_ bool, retError error) {
defer func() { q.handleError(retError) }()
defer func() { q.session.handleError(retError) }()

return q.gocqlQuery.ScanCAS(dest...)
}

func (q *query) MapScan(
m map[string]interface{},
) (retError error) {
defer func() { q.handleError(retError) }()
defer func() { q.session.handleError(retError) }()

return q.gocqlQuery.MapScan(m)
}

func (q *query) MapScanCAS(
dest map[string]interface{},
) (_ bool, retError error) {
defer func() { q.handleError(retError) }()
defer func() { q.session.handleError(retError) }()

return q.gocqlQuery.MapScanCAS(dest)
}

func (q *query) Iter() Iter {
iter := q.gocqlQuery.Iter()
return iter
return newIter(q.session, iter)
}

func (q *query) PageSize(n int) Query {
Expand Down Expand Up @@ -124,14 +124,3 @@ func (q *query) Bind(v ...interface{}) Query {
q.gocqlQuery.Bind(v...)
return newQuery(q.session, q.gocqlQuery)
}

func (q *query) handleError(
err error,
) {
switch err {
case gocql.ErrNoConnections:
q.session.refresh()
default:
// noop
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -90,21 +90,21 @@ func (s *session) refresh() {
defer s.Unlock()

if time.Now().UTC().Sub(s.sessionInitTime) < sessionRefreshMinInternal {
s.logger.Warn("too soon to refresh cql session")
s.logger.Warn("gocql wrapper: too soon to refresh gocql session")
return
}

newSession, err := initSession(s.config, s.resolver)
if err != nil {
s.logger.Error("unable to refresh cql session", tag.Error(err))
s.logger.Error("gocql wrapper: unable to refresh gocql session", tag.Error(err))
return
}

s.sessionInitTime = time.Now().UTC()
oldSession := s.Value.Load().(*gocql.Session)
s.Value.Store(newSession)
go oldSession.Close()
s.logger.Warn("successfully refreshed cql session")
s.logger.Warn("gocql wrapper: successfully refreshed gocql session")
}

func initSession(
Expand Down
12 changes: 3 additions & 9 deletions tools/cassandra/cqlclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ package cassandra

import (
"context"
"errors"
"fmt"
"time"

Expand Down Expand Up @@ -65,8 +64,6 @@ type (
}
)

var errGetSchemaVersion = errors.New("unable to get current schema version from cassandra")

const (
defaultTimeout = 30 // Timeout in seconds
systemKeyspace = "system"
Expand Down Expand Up @@ -190,12 +187,9 @@ func (client *cqlClient) ReadSchemaVersion() (string, error) {

iter := query.Iter()
var version string
if !iter.Scan(&version) {
iter.Close()
return "", errGetSchemaVersion
}
if err := iter.Close(); err != nil {
return "", err
success := iter.Scan(&version)
if err := iter.Close(); err != nil || !success {
return "", fmt.Errorf("unable to get current schema version from Cassandra: %w", err)
}
return version, nil
}
Expand Down

0 comments on commit 64536e8

Please sign in to comment.