Skip to content

Commit

Permalink
feat: Allow to skip local TLS when token is set
Browse files Browse the repository at this point in the history
By default we enforce TLS if token or secret is set.
This feature allow to disable TLS when connecting to local
instance with authenticatiaon enabled.
TLS can only be disabled for HTTP.
GRPC always enforce TLS when credentials non empty, it only
allows to disble it on Unix sockets.

* Added unix socket support.
* Fixed a bug when user was automatically logged in to local instance
  with authentication enabled.
* This also updated the client to the latest, which requires inference
code updates.
  • Loading branch information
efirs committed Jun 1, 2023
1 parent 8ee9c21 commit 9156f93
Show file tree
Hide file tree
Showing 11 changed files with 193 additions and 114 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/go-releaser.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
node-version: 18
registry-url: 'https://registry.npmjs.org'

- name: Genearate shell completions
- name: Generate shell completions
run: |
make
mkdir -p .tmp
Expand Down
5 changes: 3 additions & 2 deletions client/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,11 @@ func initConfig(inCfg *config.Config) {
Token: inCfg.Token,
Protocol: inCfg.Protocol,
Branch: inCfg.Branch,
SkipLocalTLS: inCfg.SkipLocalTLS,
}

if inCfg.UseTLS || (cfg.URL == "" && cfg.Protocol == "") ||
strings.HasSuffix(cfg.URL, config.Domain) {
if !cfg.SkipLocalTLS && (inCfg.UseTLS || (cfg.URL == "" && cfg.Protocol == "") ||
strings.HasSuffix(cfg.URL, config.Domain)) {
cfg.TLS = &tls.Config{MinVersion: tls.VersionTLS12}
}
}
Expand Down
5 changes: 3 additions & 2 deletions cmd/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ func waitServerUp(port string) {
err := tclient.Init(&cfg)
util.Fatal(err, "init tigris client")

if err := pingLow(context.Background(), waitUpTimeout, pingSleepTimeout, true); err != nil {
if err = pingLow(context.Background(), waitUpTimeout, pingSleepTimeout, true, true,
util.IsTTY(os.Stdout) && !util.Quiet); err != nil {
util.Fatal(err, "tigris initialization failed")
}

Expand Down Expand Up @@ -236,7 +237,7 @@ var serverUpCmd = &cobra.Command{
}

if loginParam {
login.LocalLogin(net.JoinHostPort("localhost", port))
login.LocalLogin(net.JoinHostPort("localhost", port), "")
} else if port != "8081" {
util.Stdoutf("run 'export TIGRIS_URL=localhost:%s' for tigris cli to connect to the local instance\n", port)
}
Expand Down
79 changes: 71 additions & 8 deletions cmd/ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,78 @@ package cmd
import (
"context"
"fmt"
"math/rand"
"os"
"strings"
"time"

"github.com/schollz/progressbar/v3"
"github.com/spf13/cobra"
"github.com/tigrisdata/tigris-cli/client"
"github.com/tigrisdata/tigris-cli/config"
"github.com/tigrisdata/tigris-cli/util"
)

var pingTimeout time.Duration

func pingLow(cmdCtx context.Context, timeout time.Duration, initSleep time.Duration, linear bool) error {
func pingCall(ctx context.Context, waitAuth bool) error {
var err error

if waitAuth {
_, err = client.D.ListProjects(ctx)
} else {
_, err = client.D.Health(ctx)
}

return err
}

func localURL(url string) bool {
return strings.HasPrefix(url, "localhost:") ||
strings.HasPrefix(url, "127.0.0.1:") ||
strings.HasPrefix(url, "http://localhost:") ||
strings.HasPrefix(url, "http://127.0.0.1:") ||
strings.HasPrefix(url, "[::1]") ||
strings.HasPrefix(url, "http://[::1]:")
}

func initPingProgressBar(init bool) *progressbar.ProgressBar {
if !init {
return nil
}

return progressbar.NewOptions64(
-1,
progressbar.OptionSetDescription("Waiting for OK response"),
progressbar.OptionSetWriter(os.Stderr),
progressbar.OptionSetWidth(10),
progressbar.OptionThrottle(65*time.Millisecond),
progressbar.OptionOnCompletion(func() {
_, _ = fmt.Fprint(os.Stderr, "\n")
}),
progressbar.OptionSpinnerType(int(rand.Int63()%76)), //nolint:gosec
progressbar.OptionSetRenderBlankState(true),
progressbar.OptionFullWidth(),
)
}

func pingLow(cmdCtx context.Context, timeout time.Duration, sleep time.Duration, linear bool, waitAuth bool,
pgBar bool,
) error {
ctx, cancel := util.GetContext(cmdCtx)

err := client.InitLow()
if err == nil {
_, err = client.D.Health(ctx)
err = pingCall(ctx, waitAuth)
}

_ = util.Error(err, "ping")

cancel()

pb := initPingProgressBar(pgBar)

end := time.Now().Add(timeout)
sleep := initSleep

for err != nil && timeout > 0 && time.Now().Add(sleep).Before(end) {
_ = util.Error(err, "ping sleep %v", sleep)
Expand All @@ -55,7 +103,9 @@ func pingLow(cmdCtx context.Context, timeout time.Duration, initSleep time.Durat
ctx, cancel = util.GetContext(cmdCtx)

if err = client.InitLow(); err == nil {
_, err = client.D.Health(ctx)
if err = pingCall(ctx, waitAuth); err == nil {
break
}
}

cancel()
Expand All @@ -68,6 +118,10 @@ func pingLow(cmdCtx context.Context, timeout time.Duration, initSleep time.Durat
sleep = rem
_ = util.Error(err, "ping sleep1 %v", sleep)
}

if pb != nil {
_ = pb.Add(int(sleep / time.Millisecond))
}
}

return err
Expand All @@ -77,12 +131,21 @@ var pingCmd = &cobra.Command{
Use: "ping",
Short: "Checks connection to Tigris",
Run: func(cmd *cobra.Command, args []string) {
if err := pingLow(cmd.Context(), pingTimeout, 32*time.Millisecond, false); err != nil {
_, _ = fmt.Fprintf(os.Stderr, "FAILED\n")
os.Exit(1) //nolint:revive
var err error

_ = client.Init(&config.DefaultConfig)

waitForAuth := localURL(config.DefaultConfig.URL) && (config.DefaultConfig.Token != "" ||
config.DefaultConfig.ClientSecret != "")

if err = pingLow(cmd.Context(), pingTimeout, 32*time.Millisecond, localURL(config.DefaultConfig.URL),
waitForAuth, util.IsTTY(os.Stdout) && !util.Quiet); err == nil {
_, _ = fmt.Fprintf(os.Stderr, "OK\n")
return
}

_, _ = fmt.Fprintf(os.Stderr, "OK\n")
_, _ = fmt.Fprintf(os.Stderr, "FAILED\n")
os.Exit(1) //nolint:revive
},
}

Expand Down
17 changes: 10 additions & 7 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,19 @@ type Log struct {
}

type Config struct {
ClientID string `json:"client_id" yaml:"client_id,omitempty" mapstructure:"client_id"`
ClientSecret string `json:"client_secret" yaml:"client_secret,omitempty" mapstructure:"client_secret"`
Token string `json:"token" yaml:"token,omitempty"`
URL string `json:"url" yaml:"url,omitempty"`
Protocol string `json:"protocol" yaml:"protocol,omitempty"`
Project string `json:"project" yaml:"project,omitempty"`
Branch string `json:"branch" yaml:"branch,omitempty"`
ClientID string `json:"client_id" yaml:"client_id,omitempty" mapstructure:"client_id"`
ClientSecret string `json:"client_secret" yaml:"client_secret,omitempty" mapstructure:"client_secret"`
Token string `json:"token" yaml:"token,omitempty"`
URL string `json:"url" yaml:"url,omitempty"`
Protocol string `json:"protocol" yaml:"protocol,omitempty"`
Project string `json:"project" yaml:"project,omitempty"`
Branch string `json:"branch" yaml:"branch,omitempty"`
DataDir string `json:"data_dir" yaml:"data_dir,omitempty"`

Log Log `json:"log" yaml:"log,omitempty"`
Timeout time.Duration `json:"timeout" yaml:"timeout,omitempty"`
UseTLS bool `json:"use_tls" yaml:"use_tls,omitempty" mapstructure:"use_tls"`
SkipLocalTLS bool `json:"skip_local_tls" yaml:"skip_local_tls,omitempty" mapstructure:"skip_local_tls"`
}

var DefaultName = "tigris-cli"
Expand Down
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/tigrisdata/tigris-cli

go 1.18
go 1.20

require (
github.com/coreos/go-oidc/v3 v3.5.0
Expand All @@ -19,7 +19,7 @@ require (
github.com/spf13/cobra v1.7.0
github.com/spf13/viper v1.15.0
github.com/stretchr/testify v1.8.2
github.com/tigrisdata/tigris-client-go v1.0.0
github.com/tigrisdata/tigris-client-go v1.1.0-next.6
golang.org/x/net v0.10.0
golang.org/x/oauth2 v0.8.0
gopkg.in/yaml.v2 v2.4.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,8 @@ github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/subosito/gotenv v1.4.2 h1:X1TuBLAMDFbaTAChgCBLu3DU3UPyELpnF2jjJ2cz/S8=
github.com/subosito/gotenv v1.4.2/go.mod h1:ayKnFf/c6rvx/2iiLrJUk1e6plDbT3edrFNGqEflhK0=
github.com/tigrisdata/tigris-client-go v1.0.0 h1:07Qw8Tm0qL15WiadP0hp4iBiRzfNSJ+GH4/ozO0nNs0=
github.com/tigrisdata/tigris-client-go v1.0.0/go.mod h1:2n6TQUdoTbzuTtakHT/ZNuK5X+I/i57BqqCcYAzG7y4=
github.com/tigrisdata/tigris-client-go v1.1.0-next.6 h1:Bkr74x8uXeArEbTI5osyLsPyRwK07TzwtXqvydmW/fY=
github.com/tigrisdata/tigris-client-go v1.1.0-next.6/go.mod h1:2n6TQUdoTbzuTtakHT/ZNuK5X+I/i57BqqCcYAzG7y4=
github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM=
github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI1Bc68Uw=
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU=
Expand Down
20 changes: 15 additions & 5 deletions login/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ func Ensure(cctx context.Context, fn func(ctx context.Context) error) {
var ep *driver.Error
if !errors.As(err, &ep) || ep.Code != ecode.Unauthenticated ||
os.Getenv(driver.EnvClientID) != "" || os.Getenv(driver.EnvClientSecret) != "" ||
config.DefaultConfig.ClientID != "" || config.DefaultConfig.ClientSecret != "" || !util.IsTTY(os.Stdin) {
config.DefaultConfig.ClientID != "" || config.DefaultConfig.ClientSecret != "" || !util.IsTTY(os.Stdin) ||
isLocalConn(GetHost("")) {
util.PrintError(err)
os.Exit(1) //nolint:revive
}
Expand Down Expand Up @@ -269,14 +270,23 @@ func waitCallbackServerUp() {
}
}

func isUnixSock(url string) bool {
return len(url) > 0 && (url[0] == '/' || url[0] == '.')
}

func isLocalConn(host string) bool {
return host == "local" || host == "dev" || strings.HasPrefix(host, "localhost") ||
isUnixSock(host)
}

func localLogin(host string) bool {
if host == "local" || host == "dev" || strings.HasPrefix(host, "localhost") {
if isLocalConn(host) {
// handle the cases without a port
if host == "local" || host == "dev" || host == "localhost" {
host = "localhost:8081"
}

LocalLogin(host)
LocalLogin(host, "")

return true
}
Expand Down Expand Up @@ -368,10 +378,10 @@ func CmdLow(_ context.Context, host string) error {
return util.Error(callbackErr, "callback error")
}

func LocalLogin(host string) {
func LocalLogin(host string, token string) {
config.DefaultConfig.ClientSecret = ""
config.DefaultConfig.ClientID = ""
config.DefaultConfig.Token = ""
config.DefaultConfig.Token = token
config.DefaultConfig.URL = host

err := config.Save(config.DefaultName, config.DefaultConfig)
Expand Down
32 changes: 16 additions & 16 deletions schema/inference.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func parseNumber(v any, existing *schema.Field) (string, string, error) {
return "", "", ErrExpectedNumber
}

if _, err := n.Int64(); err != nil || (!DetectIntegers && (existing == nil || existing.Type != typeInteger)) {
if _, err := n.Int64(); err != nil || (!DetectIntegers && (existing == nil || existing.Type.First() != typeInteger)) {
if _, err = n.Float64(); err != nil {
return "", "", err
}
Expand Down Expand Up @@ -208,17 +208,17 @@ func traverseObject(name string, existingField *schema.Field, newField *schema.F
switch {
case existingField == nil:
newField.Fields = make(map[string]*schema.Field)
case existingField.Type == typeObject:
case existingField.Type.First() == typeObject:
if existingField.Fields == nil {
newField.Fields = make(map[string]*schema.Field)
} else {
newField.Fields = existingField.Fields
}
default:
log.Debug().Str("oldType", existingField.Type).Str("newType", newField.Type).Interface("values", values).
Msg("object converted to primitive")
log.Debug().Str("oldType", existingField.Type.First()).Str("newType", newField.Type.First()).
Interface("values", values).Msg("object converted to primitive")

return newInompatibleSchemaError(name, existingField.Type, "", newField.Type, "")
return newInompatibleSchemaError(name, existingField.Type.First(), "", newField.Type.First(), "")
}

return traverseFields(newField.Fields, values, nil)
Expand All @@ -234,23 +234,23 @@ func traverseArray(name string, existingField *schema.Field, newField *schema.Fi
if i == 0 {
switch {
case existingField == nil:
newField.Items = &schema.Field{Type: t, Format: format}
case existingField.Type == typeArray:
newField.Items = &schema.Field{Type: schema.NewMultiType(t), Format: format}
case existingField.Type.First() == typeArray:
newField.Items = existingField.Items
default:
log.Debug().Str("oldType", existingField.Type).Str("newType", newField.Type).Interface("values", v).
log.Debug().Str("oldType", existingField.Type.First()).Str("newType", newField.Type.First()).Interface("values", v).
Msg("object converted to primitive")

return newInompatibleSchemaError(name, existingField.Type, "", newField.Type, "")
return newInompatibleSchemaError(name, existingField.Type.First(), "", newField.Type.First(), "")
}
}

nt, nf, err := extendedType(name, newField.Items.Type, newField.Items.Format, t, format)
nt, nf, err := extendedType(name, newField.Items.Type.First(), newField.Items.Format, t, format)
if err != nil {
return err
}

newField.Items.Type = nt
newField.Items.Type.Set(nt)
newField.Items.Format = nf

if t == typeObject {
Expand Down Expand Up @@ -309,12 +309,12 @@ func traverseFieldsLow(t string, format string, name string, f *schema.Field, v
return true, nil // empty object
}
case sch[name] != nil:
nt, nf, err := extendedType(name, sch[name].Type, sch[name].Format, t, format)
nt, nf, err := extendedType(name, sch[name].Type.First(), sch[name].Format, t, format)
if err != nil {
return false, err
}

f.Type = nt
f.Type.Set(nt)
f.Format = nf
}

Expand All @@ -333,7 +333,7 @@ func traverseFields(sch map[string]*schema.Field, fields map[string]any, autoGen
return err
}

f := &schema.Field{Type: t, Format: format}
f := &schema.Field{Type: schema.NewMultiType(t), Format: format}

skip, err := traverseFieldsLow(t, format, name, f, val, sch)
if err != nil {
Expand Down Expand Up @@ -423,7 +423,7 @@ func GenerateInitDoc(sch *schema.Schema, doc json.RawMessage) ([]byte, error) {
}

func initDocTraverseFields(field *schema.Field, doc map[string]any, fieldName string) error {
switch field.Type {
switch field.Type.First() {
case typeNumber:
doc[fieldName] = 0.0000001
case typeObject:
Expand All @@ -436,7 +436,7 @@ func initDocTraverseFields(field *schema.Field, doc map[string]any, fieldName st

doc[fieldName] = vo
case typeArray:
if field.Items.Type == typeObject {
if field.Items.Type.First() == typeObject {
vo := map[string]any{}
for name := range field.Items.Fields {
if err := initDocTraverseFields(field.Items.Fields[name], vo, name); err != nil {
Expand Down
Loading

0 comments on commit 9156f93

Please sign in to comment.