Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't fail when decoding empty strings as floats in server responses #117

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 43 additions & 21 deletions trino/trino.go
Original file line number Diff line number Diff line change
Expand Up @@ -713,27 +713,27 @@ type stmtResponse struct {
}

type stmtStats struct {
State string `json:"state"`
Scheduled bool `json:"scheduled"`
Nodes int `json:"nodes"`
TotalSplits int `json:"totalSplits"`
QueuesSplits int `json:"queuedSplits"`
RunningSplits int `json:"runningSplits"`
CompletedSplits int `json:"completedSplits"`
UserTimeMillis int `json:"userTimeMillis"`
CPUTimeMillis int64 `json:"cpuTimeMillis"`
WallTimeMillis int64 `json:"wallTimeMillis"`
QueuedTimeMillis int64 `json:"queuedTimeMillis"`
ElapsedTimeMillis int64 `json:"elapsedTimeMillis"`
ProcessedRows int64 `json:"processedRows"`
ProcessedBytes int64 `json:"processedBytes"`
PhysicalInputBytes int64 `json:"physicalInputBytes"`
PhysicalWrittenBytes int64 `json:"physicalWrittenBytes"`
PeakMemoryBytes int64 `json:"peakMemoryBytes"`
SpilledBytes int64 `json:"spilledBytes"`
RootStage stmtStage `json:"rootStage"`
ProgressPercentage float32 `json:"progressPercentage"`
RunningPercentage float32 `json:"runningPercentage"`
State string `json:"state"`
Scheduled bool `json:"scheduled"`
Nodes int `json:"nodes"`
TotalSplits int `json:"totalSplits"`
QueuesSplits int `json:"queuedSplits"`
RunningSplits int `json:"runningSplits"`
CompletedSplits int `json:"completedSplits"`
UserTimeMillis int `json:"userTimeMillis"`
CPUTimeMillis int64 `json:"cpuTimeMillis"`
WallTimeMillis int64 `json:"wallTimeMillis"`
QueuedTimeMillis int64 `json:"queuedTimeMillis"`
ElapsedTimeMillis int64 `json:"elapsedTimeMillis"`
ProcessedRows int64 `json:"processedRows"`
ProcessedBytes int64 `json:"processedBytes"`
PhysicalInputBytes int64 `json:"physicalInputBytes"`
PhysicalWrittenBytes int64 `json:"physicalWrittenBytes"`
PeakMemoryBytes int64 `json:"peakMemoryBytes"`
SpilledBytes int64 `json:"spilledBytes"`
RootStage stmtStage `json:"rootStage"`
ProgressPercentage jsonFloat64 `json:"progressPercentage"`
RunningPercentage jsonFloat64 `json:"runningPercentage"`
}

type ErrTrino struct {
Expand Down Expand Up @@ -792,6 +792,28 @@ type stmtStage struct {
SubStages []stmtStage `json:"subStages"`
}

type jsonFloat64 float64

func (f *jsonFloat64) UnmarshalJSON(data []byte) error {
if string(data) == `""` {
if f != nil {
*f = 0
}
return nil
}

var v float64
err := json.Unmarshal(data, &v)
if err != nil {
return err
}
p := (*float64)(f)
*p = v
return nil
}

var _ json.Unmarshaler = new(jsonFloat64)

func (st *driverStmt) Query(args []driver.Value) (driver.Rows, error) {
return nil, driver.ErrSkip
}
Expand Down
39 changes: 37 additions & 2 deletions trino/trino_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,34 @@ func TestRoundTripRetryQueryError(t *testing.T) {
assert.IsTypef(t, new(ErrQueryFailed), err, "unexpected error: %w", err)
}

func TestRoundTripBogusData(t *testing.T) {
count := 0
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if count == 0 {
count++
w.WriteHeader(http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
// some invalid JSON
w.Write([]byte(`{"stats": {"progressPercentage": ""}}`))
}))

t.Cleanup(ts.Close)

db, err := sql.Open("trino", ts.URL)
require.NoError(t, err)

t.Cleanup(func() {
assert.NoError(t, db.Close())
})

rows, err := db.Query("SELECT 1")
require.NoError(t, err)
assert.False(t, rows.Next())
require.NoError(t, rows.Err())
}

func TestRoundTripCancellation(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
Expand Down Expand Up @@ -336,10 +364,12 @@ func TestQueryForUsername(t *testing.T) {
}

type TestQueryProgressCallback struct {
statusMap map[time.Time]string
progressMap map[time.Time]float64
statusMap map[time.Time]string
}

func (qpc *TestQueryProgressCallback) Update(qpi QueryProgressInfo) {
qpc.progressMap[time.Now()] = float64(qpi.QueryStats.ProgressPercentage)
qpc.statusMap[time.Now()] = qpi.QueryStats.State
}

Expand Down Expand Up @@ -387,9 +417,11 @@ func TestQueryProgressWithCallbackPeriod(t *testing.T) {
assert.NoError(t, db.Close())
})

progressMap := make(map[time.Time]float64)
statusMap := make(map[time.Time]string)
progressUpdater := &TestQueryProgressCallback{
statusMap: statusMap,
progressMap: progressMap,
statusMap: statusMap,
}
progressUpdaterPeriod, err := time.ParseDuration("1ms")
require.NoError(t, err)
Expand All @@ -416,6 +448,8 @@ func TestQueryProgressWithCallbackPeriod(t *testing.T) {
}

// sort time in order to calculate interval
assert.NotEmpty(t, progressMap)
assert.NotEmpty(t, statusMap)
var keys []time.Time
for k := range statusMap {
keys = append(keys, k)
Expand All @@ -428,6 +462,7 @@ func TestQueryProgressWithCallbackPeriod(t *testing.T) {
if i > 0 {
assert.GreaterOrEqual(t, k.Sub(keys[i-1]), progressUpdaterPeriod)
}
assert.GreaterOrEqual(t, progressMap[k], 0.0)
}
}

Expand Down