Skip to content

Commit

Permalink
feat: wh Endpoint To Fetch Tables Per Connection (#3279)
Browse files Browse the repository at this point in the history
* feat: wh endpoint to fetch tables per connection
- A new endpoint `/v1/warehouse/fetch-tables`
 It will fetch the most recent tables for the connection `source - destination`

* chore: fix merging master issue

* chore: fix formatting error

* chore: check for empty connections list

* chore: lint fix

* chore: code review changes
- rename structs for Request and response
- avoid sql injection

* chore: lint fix

* chore: fix json name for `FetchTablesResponse.ConnectionsTables`

* chore: add namespace FetchTableInfo struct
  • Loading branch information
am6010 committed May 17, 2023
1 parent 4a4f293 commit ea7d5ce
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 0 deletions.
53 changes: 53 additions & 0 deletions warehouse/internal/repo/schema.go
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"fmt"
"strings"

"github.com/rudderlabs/rudder-server/utils/timeutil"
"github.com/rudderlabs/rudder-server/warehouse/internal/model"
Expand Down Expand Up @@ -192,3 +193,55 @@ func (repo *WHSchema) GetNamespace(ctx context.Context, sourceID, destID string)

return namespace, nil
}

// GetTablesForConnection returns, for each given (source_id, destination_id)
// pair, the table names of the most recent non-empty schema recorded for that
// connection, along with its namespace.
// It returns an error when connections is empty or the query fails.
func (repo *WHSchema) GetTablesForConnection(ctx context.Context, connections []warehouseutils.SourceIDDestinationID) ([]warehouseutils.FetchTableInfo, error) {
	if len(connections) == 0 {
		return nil, fmt.Errorf("no source id and destination id pairs provided")
	}

	// Build one "($n,$n+1)" placeholder pair per connection so the IDs are
	// passed as bind parameters (avoids SQL injection).
	parameters := make([]interface{}, 0, 2*len(connections))
	sourceIDDestinationIDPairs := make([]string, len(connections))
	for idx, connection := range connections {
		sourceIDDestinationIDPairs[idx] = fmt.Sprintf("($%d,$%d)", 2*idx+1, 2*idx+2)
		parameters = append(parameters, connection.SourceID, connection.DestinationID)
	}

	// Select the single row with the max id (i.e. the most recent schema) for
	// each (source_id, destination_id) pair, skipping empty schemas.
	// NOTE: the grouping must be by the pair, not by id — `GROUP BY id` makes
	// max(id) a per-row no-op and would return every historical schema row.
	query := `SELECT ` + whSchemaTableColumns + ` FROM ` + whSchemaTableName + `
	WHERE id IN (
		SELECT max(id) FROM ` + whSchemaTableName + `
		WHERE
			(source_id, destination_id) IN (` + strings.Join(sourceIDDestinationIDPairs, ", ") + `)
		AND
			schema::text <> '{}'::text
		GROUP BY source_id, destination_id
	)`
	rows, err := repo.db.QueryContext(
		ctx,
		query,
		parameters...)
	if err != nil {
		return nil, fmt.Errorf("querying schema: %w", err)
	}

	entries, err := repo.parseRows(rows)
	if err != nil {
		return nil, fmt.Errorf("parsing rows: %w", err)
	}

	// Flatten each entry's schema (a map keyed by table name) into the
	// response structure.
	var tables []warehouseutils.FetchTableInfo
	for _, entry := range entries {
		allTablesOfConnections := make([]string, 0, len(entry.Schema))
		for tableName := range entry.Schema {
			allTablesOfConnections = append(allTablesOfConnections, tableName)
		}
		tables = append(tables, warehouseutils.FetchTableInfo{
			SourceID:      entry.SourceID,
			DestinationID: entry.DestinationID,
			Namespace:     entry.Namespace,
			Tables:        allTablesOfConnections,
		})
	}

	return tables, nil
}
31 changes: 31 additions & 0 deletions warehouse/internal/repo/schema_test.go
Expand Up @@ -6,6 +6,10 @@ import (
"testing"
"time"

"golang.org/x/exp/slices"

warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils"

"github.com/rudderlabs/rudder-server/warehouse/internal/model"
"github.com/rudderlabs/rudder-server/warehouse/internal/repo"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -112,4 +116,31 @@ func TestWHSchemasRepo(t *testing.T) {
require.NoError(t, err)
require.Empty(t, expectedNamespace, expectedNamespace)
})

	t.Run("GetTablesForConnection", func(t *testing.T) {
		// Happy path: the schema rows inserted earlier in this test for
		// (sourceID, destinationID) should surface as exactly one
		// FetchTableInfo entry listing the schema's table names.
		t.Log("existing")
		connection := warehouseutils.SourceIDDestinationID{SourceID: sourceID, DestinationID: destinationID}
		expectedTableNames, err := r.GetTablesForConnection(ctx, []warehouseutils.SourceIDDestinationID{connection})
		require.NoError(t, err)
		require.Equal(t, len(expectedTableNames), 1)
		require.Equal(t, expectedTableNames[0].SourceID, sourceID)
		require.Equal(t, expectedTableNames[0].DestinationID, destinationID)
		require.Equal(t, expectedTableNames[0].Namespace, namespace)
		require.True(t, slices.Contains(expectedTableNames[0].Tables, "table_name_1"))
		require.True(t, slices.Contains(expectedTableNames[0].Tables, "table_name_2"))

		// A cancelled context must propagate as a wrapped query error.
		t.Log("cancelled context")
		_, err = r.GetTablesForConnection(cancelledCtx, []warehouseutils.SourceIDDestinationID{connection})
		require.EqualError(t, err, errors.New("querying schema: context canceled").Error())

		// Unknown connection pairs yield an empty result, not an error.
		t.Log("not found")
		expectedTableNames, err = r.GetTablesForConnection(ctx,
			[]warehouseutils.SourceIDDestinationID{{SourceID: notFound, DestinationID: notFound}})
		require.NoError(t, err)
		require.Empty(t, expectedTableNames, expectedTableNames)

		// An empty connections slice is rejected up front with a
		// descriptive error.
		t.Log("empty")
		_, err = r.GetTablesForConnection(ctx, []warehouseutils.SourceIDDestinationID{})
		require.EqualError(t, err, errors.New("no source id and destination id pairs provided").Error())
	})
}
20 changes: 20 additions & 0 deletions warehouse/utils/utils.go
Expand Up @@ -297,6 +297,26 @@ type TriggerUploadRequest struct {
DestinationID string `json:"destination_id"`
}

// SourceIDDestinationID identifies a single warehouse connection as a
// (source, destination) pair.
type SourceIDDestinationID struct {
	SourceID      string `json:"source_id"`
	DestinationID string `json:"destination_id"`
}

// FetchTablesRequest is the request body for the
// /v1/warehouse/fetch-tables endpoint: the connections whose tables
// should be fetched.
type FetchTablesRequest struct {
	Connections []SourceIDDestinationID `json:"connections"`
}

// FetchTableInfo holds the namespace and table names of the most recent
// schema for one source/destination connection.
type FetchTableInfo struct {
	SourceID      string   `json:"source_id"`
	DestinationID string   `json:"destination_id"`
	Namespace     string   `json:"namespace"`
	Tables        []string `json:"tables"`
}

// FetchTablesResponse is the response body for the
// /v1/warehouse/fetch-tables endpoint: one entry per requested connection.
type FetchTablesResponse struct {
	ConnectionsTables []FetchTableInfo `json:"connections_tables"`
}

func TimingFromJSONString(str sql.NullString) (status string, recordedTime time.Time) {
timingsMap := gjson.Parse(str.String).Map()
for s, t := range timingsMap {
Expand Down
42 changes: 42 additions & 0 deletions warehouse/warehouse.go
Expand Up @@ -1462,6 +1462,45 @@ func databricksVersionHandler(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte(deltalake.GetDatabricksVersion()))
}

// fetchTablesHandler serves /v1/warehouse/fetch-tables. It decodes a JSON
// FetchTablesRequest from the request body, looks up the most recent tables
// for each requested (source, destination) connection via the schemas repo,
// and writes a JSON FetchTablesResponse.
func fetchTablesHandler(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	defer func() { _ = r.Body.Close() }()

	body, err := io.ReadAll(r.Body)
	if err != nil {
		pkgLogger.Errorf("[WH]: Error reading body: %v", err)
		http.Error(w, "can't read body", http.StatusBadRequest)
		return
	}

	var connectionsTableRequest warehouseutils.FetchTablesRequest
	if err := json.Unmarshal(body, &connectionsTableRequest); err != nil {
		pkgLogger.Errorf("[WH]: Error unmarshalling body: %v", err)
		// "unmarshal" (typo fixed) — this message is returned to the client.
		http.Error(w, "can't unmarshal body", http.StatusBadRequest)
		return
	}

	schemaRepo := repo.NewWHSchemas(dbHandle)
	tables, err := schemaRepo.GetTablesForConnection(ctx, connectionsTableRequest.Connections)
	if err != nil {
		pkgLogger.Errorf("[WH]: Error fetching tables: %v", err)
		http.Error(w, "can't fetch tables from schemas repo", http.StatusInternalServerError)
		return
	}

	resBody, err := json.Marshal(warehouseutils.FetchTablesResponse{
		ConnectionsTables: tables,
	})
	if err != nil {
		err := fmt.Errorf("marshalling tables to response: %w", err)
		pkgLogger.Errorf("[WH]: %v", err)
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	// The payload is JSON; advertise it so clients decode it correctly.
	w.Header().Set("Content-Type", "application/json")
	_, _ = w.Write(resBody)
}

func isUploadTriggered(wh model.Warehouse) bool {
triggerUploadsMapLock.RLock()
defer triggerUploadsMapLock.RUnlock()
Expand Down Expand Up @@ -1546,6 +1585,9 @@ func startWebHandler(ctx context.Context) error {
srvMux.HandleFunc("/v1/warehouse/jobs", asyncWh.AddWarehouseJobHandler).Methods("POST") // FIXME: add degraded mode
srvMux.HandleFunc("/v1/warehouse/jobs/status", asyncWh.StatusWarehouseJobHandler).Methods("GET") // FIXME: add degraded mode

// fetch schema info
srvMux.HandleFunc("/v1/warehouse/fetch-tables", fetchTablesHandler).Methods("GET")

pkgLogger.Infof("WH: Starting warehouse master service in %d", webPort)
} else {
pkgLogger.Infof("WH: Starting warehouse slave service in %d", webPort)
Expand Down

0 comments on commit ea7d5ce

Please sign in to comment.