Skip to content

Commit

Permalink
Copy data from one destination to another (#146)
Browse files Browse the repository at this point in the history
* start of copy implementation

* worker that deadlocks

* producer and consumer

* update

* working copy

* hack

* update to copier

* remove local data

* progress on copy

* working copy command

---------

Co-authored-by: poundifdef <jay@scratchdata.com>
  • Loading branch information
poundifdef and poundifdef committed Apr 1, 2024
1 parent 4956089 commit 3bf8a3c
Show file tree
Hide file tree
Showing 20 changed files with 422 additions and 79 deletions.
2 changes: 2 additions & 0 deletions config.yaml
Expand Up @@ -18,6 +18,8 @@ workers:
enabled: true
count: 1
data_directory: ./data/worker
max_bulk_query_size_bytes: 500000000
bulk_chunk_size_bytes: 50000000

blob_store:
type: memory
Expand Down
6 changes: 6 additions & 0 deletions pkg/api/auth.go
Expand Up @@ -47,6 +47,7 @@ func (a *ScratchDataAPIStruct) AuthMiddleware(next http.Handler) http.Handler {
}

ctx := context.WithValue(r.Context(), "databaseId", keyDetails.DestinationID)
ctx = context.WithValue(ctx, "teamId", keyDetails.Destination.TeamID)
next.ServeHTTP(w, r.WithContext(ctx))
}
})
Expand All @@ -57,6 +58,11 @@ func (a *ScratchDataAPIStruct) AuthGetDatabaseID(ctx context.Context) int64 {
return int64(dbId)
}

// AuthGetTeamID returns the team ID that AuthMiddleware stored on the
// request context for the authenticated API key.
func (a *ScratchDataAPIStruct) AuthGetTeamID(ctx context.Context) uint {
	// Comma-ok assertion: if the middleware did not run (misconfigured
	// route), return the zero value instead of panicking the handler.
	teamID, _ := ctx.Value("teamId").(uint)
	return teamID
}

func (a *ScratchDataAPIStruct) Login(w http.ResponseWriter, r *http.Request) {
url := a.googleOauthConfig.AuthCodeURL(uuid.New().String())
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
Expand Down
32 changes: 32 additions & 0 deletions pkg/api/data.go
Expand Up @@ -7,9 +7,12 @@ import (
"strings"

"github.com/go-chi/chi/v5"
"github.com/go-chi/render"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/rs/zerolog/log"
"github.com/scratchdata/scratchdata/pkg/storage/database/models"
queue_models "github.com/scratchdata/scratchdata/pkg/storage/queue/models"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
Expand All @@ -26,6 +29,35 @@ var insertArraySize = promauto.NewHistogram(prometheus.HistogramOpts{
Buckets: prometheus.LinearBuckets(1, 50, 10),
})

// Copy enqueues a job that copies the result of a query from the caller's
// authenticated source database into another destination owned by the same
// team. On success it responds with the queued job ID as JSON.
func (a *ScratchDataAPIStruct) Copy(w http.ResponseWriter, r *http.Request) {
	message := queue_models.CopyDataMessage{}

	err := render.DecodeJSON(r.Body, &message)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	// The source is always the caller's own database; any client-supplied
	// source_id in the payload is overwritten.
	message.SourceID = a.AuthGetDatabaseID(r.Context())

	// Make sure the destination db is the same team as the source.
	// NOTE(review): GetDestination returns a zero-value Destination on a
	// failed lookup, so this also rejects unknown destination IDs as long as
	// real team IDs are non-zero — confirm that invariant.
	dest := a.storageServices.Database.GetDestination(r.Context(), message.DestinationID)
	if dest.TeamID != a.AuthGetTeamID(r.Context()) {
		http.Error(w, "invalid destination", http.StatusBadRequest)
		return
	}

	// Enqueue the copy job.
	msg, err := a.storageServices.Database.Enqueue(models.CopyData, message)
	if err != nil {
		// A queueing failure is a server-side problem, not a bad request.
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	render.JSON(w, r, render.M{"job_id": msg.ID})
}

func (a *ScratchDataAPIStruct) Select(w http.ResponseWriter, r *http.Request) {
databaseID := a.AuthGetDatabaseID(r.Context())

Expand Down
24 changes: 3 additions & 21 deletions pkg/api/router.go
@@ -1,14 +1,14 @@
package api

import (
"context"
"net/http"

"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/cors"
"github.com/go-chi/jwtauth/v5"
"github.com/scratchdata/scratchdata/pkg/config"
"github.com/scratchdata/scratchdata/pkg/view"
"github.com/scratchdata/scratchdata/pkg/view/static"
"net/http"

"github.com/go-chi/chi/v5"
"github.com/prometheus/client_golang/prometheus"
Expand All @@ -27,25 +27,6 @@ var responseSize = promauto.NewHistogramVec(prometheus.HistogramOpts{
Buckets: prometheus.ExponentialBucketsRange(1000, 100_000_000, 20),
}, []string{"route"})

type ScratchDataAPI interface {
Healthcheck(w http.ResponseWriter, r *http.Request)

Select(w http.ResponseWriter, r *http.Request)
Insert(w http.ResponseWriter, r *http.Request)
Tables(w http.ResponseWriter, r *http.Request)
Columns(w http.ResponseWriter, r *http.Request)

CreateQuery(w http.ResponseWriter, r *http.Request)
ShareData(w http.ResponseWriter, r *http.Request)

AuthMiddleware(next http.Handler) http.Handler
AuthGetDatabaseID(context.Context) int64

GetDestinations(w http.ResponseWriter, r *http.Request)
CreateDestination(w http.ResponseWriter, r *http.Request)
AddAPIKey(w http.ResponseWriter, r *http.Request)
}

func CreateMux(apiFunctions *ScratchDataAPIStruct, c config.ScratchDataConfig) *chi.Mux {
r := chi.NewRouter()
r.Use(PrometheusMiddleware)
Expand All @@ -57,6 +38,7 @@ func CreateMux(apiFunctions *ScratchDataAPIStruct, c config.ScratchDataConfig) *
api.Post("/data/insert/{table}", apiFunctions.Insert)
api.Get("/data/query", apiFunctions.Select)
api.Post("/data/query", apiFunctions.Select)
api.Post("/data/copy", apiFunctions.Copy)
api.Get("/tables", apiFunctions.Tables)
api.Get("/tables/{table}/columns", apiFunctions.Columns)

Expand Down
5 changes: 4 additions & 1 deletion pkg/config/config.go
Expand Up @@ -19,10 +19,13 @@ type API struct {
}

// Workers configures the pool of background workers that process queued
// jobs (file imports and copy jobs).
type Workers struct {
	// Enabled toggles the worker pool on or off.
	Enabled bool `yaml:"enabled" env:"SCRATCH_WORKERS_ENABLED"`
	// Count is the number of workers to run.
	Count int `yaml:"count"`
	// DataDirectory is the local directory workers use for scratch data.
	DataDirectory string `yaml:"data_directory"`
	// FreeSpaceRequiredBytes is the minimum free disk space, in bytes,
	// presumably required before a worker accepts new work — TODO confirm
	// against the worker implementation.
	FreeSpaceRequiredBytes int64 `yaml:"free_space_required_bytes"`

	// MaxBulkQuerySizeBytes caps the size, in bytes, of a bulk copy query
	// (500000000 in the sample config.yaml).
	MaxBulkQuerySizeBytes int `yaml:"max_bulk_query_size_bytes"`
	// BulkChunkSizeBytes is the chunk size, in bytes, used when splitting
	// bulk data (50000000 in the sample config.yaml).
	BulkChunkSizeBytes int `yaml:"bulk_chunk_size_bytes"`
}

type Queue struct {
Expand Down
5 changes: 4 additions & 1 deletion pkg/destinations/bigquery/insert.go
Expand Up @@ -3,10 +3,12 @@ package bigquery
import (
"context"
"fmt"
"github.com/scratchdata/scratchdata/pkg/util"
"os"
"path/filepath"
"strings"
"time"

"github.com/scratchdata/scratchdata/pkg/util"

"cloud.google.com/go/bigquery"

Expand Down Expand Up @@ -71,6 +73,7 @@ func (s *BigQueryServer) createColumns(table string, jsonTypes map[string]string
log.Error().Err(err).Str("query", query).Msg("createColumns: cannot run query")
return err
}
time.Sleep(2 * time.Second)
}

return nil
Expand Down
5 changes: 5 additions & 0 deletions pkg/destinations/bigquery/query.go
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/csv"
"encoding/json"
"errors"
"fmt"
"io"

Expand All @@ -12,6 +13,10 @@ import (
"google.golang.org/api/iterator"
)

// QueryNDJson is meant to stream query results as newline-delimited JSON.
// It is not yet supported for BigQuery and always returns an error.
// Receiver renamed s -> b to match the sibling method QueryJSON, which
// uses b for *BigQueryServer.
func (b *BigQueryServer) QueryNDJson(query string, writer io.Writer) error {
	return errors.New("not implemented")
}

func (b *BigQueryServer) QueryJSON(query string, writer io.Writer) error {
// NOTE: Query should be with dataset as prefix. Example: SELECT * FROM `dataset.table`

Expand Down
16 changes: 15 additions & 1 deletion pkg/destinations/clickhouse/clickhouse.go
Expand Up @@ -4,12 +4,14 @@ import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"github.com/scratchdata/scratchdata/pkg/util"
"io"
"net/http"
"time"

"github.com/scratchdata/scratchdata/pkg/util"

"github.com/ClickHouse/clickhouse-go/v2"
"github.com/ClickHouse/clickhouse-go/v2/lib/driver"
"github.com/rs/zerolog/log"
Expand Down Expand Up @@ -103,9 +105,21 @@ func (s *ClickhouseServer) httpQuery(query string) (io.ReadCloser, error) {
resp, err := client.Do(req)
if err != nil {
log.Error().Err(err).Msg("request failed")
resp.Body.Close()
return nil, err
}

if resp.StatusCode != 200 {
errMsg, readErr := io.ReadAll(resp.Body)
resp.Body.Close()

if readErr != nil {
return nil, readErr
}

return nil, errors.New(string(errMsg))
}

return resp.Body, nil
}

Expand Down
17 changes: 16 additions & 1 deletion pkg/destinations/clickhouse/query.go
Expand Up @@ -2,10 +2,25 @@ package clickhouse

import (
"bufio"
"github.com/scratchdata/scratchdata/pkg/util"
"io"

"github.com/scratchdata/scratchdata/pkg/util"
)

// QueryNDJson runs query against ClickHouse over HTTP and streams the
// result to writer as newline-delimited JSON (one object per row, via
// ClickHouse's JSONEachRow output format).
func (s *ClickhouseServer) QueryNDJson(query string, writer io.Writer) error {
	trimmed := util.TrimQuery(query)
	wrapped := "SELECT * FROM (" + trimmed + ") FORMAT " + "JSONEachRow"

	body, err := s.httpQuery(wrapped)
	if err != nil {
		return err
	}
	defer body.Close()

	// Stream the response straight through without buffering it in memory.
	_, err = io.Copy(writer, body)
	return err
}

func (s *ClickhouseServer) QueryJSON(query string, writer io.Writer) error {
sanitized := util.TrimQuery(query)
sql := "SELECT * FROM (" + sanitized + ") FORMAT " + "JSONEachRow"
Expand Down
1 change: 1 addition & 0 deletions pkg/destinations/destinations.go
Expand Up @@ -24,6 +24,7 @@ type DestinationManager struct {
}

type Destination interface {
QueryNDJson(query string, writer io.Writer) error
QueryJSON(query string, writer io.Writer) error
QueryCSV(query string, writer io.Writer) error

Expand Down
1 change: 0 additions & 1 deletion pkg/destinations/duckdb/duckdb.go
Expand Up @@ -14,7 +14,6 @@ import (
"github.com/scratchdata/scratchdata/pkg/util"

"github.com/marcboeker/go-duckdb"
_ "github.com/marcboeker/go-duckdb"
)

type DuckDBServer struct {
Expand Down
9 changes: 8 additions & 1 deletion pkg/destinations/duckdb/query.go
@@ -1,13 +1,14 @@
package duckdb

import (
"github.com/scratchdata/scratchdata/pkg/util"
"io"
"os"
"path/filepath"
"syscall"
"time"

"github.com/scratchdata/scratchdata/pkg/util"

"github.com/rs/zerolog/log"
)

Expand Down Expand Up @@ -50,6 +51,8 @@ func (s *DuckDBServer) QueryPipe(query string, format string, writer io.Writer)
switch format {
case "csv":
formatClause = "(FORMAT CSV)"
case "ndjson":
formatClause = "(FORMAT JSON, ARRAY FALSE)"
default:
formatClause = "(FORMAT JSON, ARRAY TRUE)"
}
Expand Down Expand Up @@ -199,6 +202,10 @@ func (s *DuckDBServer) QueryJSONString(query string, writer io.Writer) error {
return nil
}

// QueryNDJson executes query and writes the result to writer as
// newline-delimited JSON, delegating to QueryPipe with the "ndjson"
// format (which QueryPipe maps to DuckDB's FORMAT JSON, ARRAY FALSE).
func (s *DuckDBServer) QueryNDJson(query string, writer io.Writer) error {
	return s.QueryPipe(query, "ndjson", writer)
}

func (s *DuckDBServer) QueryJSON(query string, writer io.Writer) error {
return s.QueryPipe(query, "json", writer)
// return s.QueryJSONString(query, writer)
Expand Down
5 changes: 5 additions & 0 deletions pkg/destinations/redshift/query.go
Expand Up @@ -3,12 +3,17 @@ package redshift
import (
"encoding/csv"
"encoding/json"
"errors"
"fmt"
"io"

"github.com/rs/zerolog/log"
)

// QueryNDJson is meant to stream query results as newline-delimited JSON.
// It is not yet supported for Redshift and always returns an error.
func (s *RedshiftServer) QueryNDJson(query string, writer io.Writer) error {
	return errors.New("not implemented")
}

func (s *RedshiftServer) QueryJSON(query string, writer io.Writer) error {
rows, err := s.conn.Query(query)
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions pkg/storage/database/database.go
Expand Up @@ -16,6 +16,7 @@ type Database interface {
VerifyAdminAPIKey(ctx context.Context, hashedAPIKey string) bool

GetDestinations(ctx context.Context, teamId uint) []config.Destination
GetDestination(ctx context.Context, destId uint) models.Destination
CreateDestination(ctx context.Context, teamId uint, destType string, settings map[string]any) (config.Destination, error)
GetDestinationCredentials(ctx context.Context, dbID int64) (config.Destination, error)

Expand Down
9 changes: 9 additions & 0 deletions pkg/storage/database/gorm/gorm.go
Expand Up @@ -199,6 +199,15 @@ func (s *Gorm) GetDestinations(c context.Context, userId uint) []config.Destinat
return rc
}

// GetDestination looks up a destination row by its primary key. On a lookup
// failure the error is logged and the zero-value Destination is returned.
func (db *Gorm) GetDestination(ctx context.Context, destId uint) models.Destination {
	var dest models.Destination
	tx := db.db.First(&dest, destId)
	if tx.Error != nil {
		// Fixed copy-pasted log text: this fetches a destination, not a user.
		log.Error().Err(tx.Error).Msg("Unable to get destination")
	}
	return dest
}

func (s *Gorm) Hash(str string) string {
hash := sha256.Sum256([]byte(str))
return hex.EncodeToString(hash[:])
Expand Down
13 changes: 13 additions & 0 deletions pkg/storage/database/static/static.go
Expand Up @@ -52,6 +52,19 @@ func (db *StaticDatabase) GetDestinations(ctx context.Context, teamId uint) []co
return db.destinations
}

// GetDestination returns the destination at index destId from the statically
// configured destination list, converted to a models.Destination. The static
// database has no team concept, so TeamID is always 0.
func (db *StaticDatabase) GetDestination(ctx context.Context, destId uint) models.Destination {
	// Guard the slice index so an unknown ID yields a zero value instead of
	// a panic, mirroring the Gorm implementation's behavior on lookup failure.
	if int(destId) >= len(db.destinations) {
		return models.Destination{}
	}
	dest := db.destinations[destId]
	// Settings is persisted as a JSON string; marshal error deliberately
	// ignored since the source is an in-memory settings map.
	settings, _ := json.Marshal(dest.Settings)
	return models.Destination{
		TeamID:   0,
		Type:     dest.Type,
		Name:     dest.Name,
		Settings: string(settings),
	}
}

func (db *StaticDatabase) AddAPIKey(ctx context.Context, destId int64, key string) error {
return StaticDBError
}
Expand Down
7 changes: 7 additions & 0 deletions pkg/storage/queue/models/models.go
Expand Up @@ -5,3 +5,10 @@ type FileUploadMessage struct {
Table string `json:"table"`
Key string `json:"key"`
}

// CopyDataMessage is the queue payload for a copy job: run Query against
// the source destination and write the results into DestinationTable on
// the target destination.
type CopyDataMessage struct {
	// SourceID is the database ID to read from; the API sets this from the
	// authenticated caller's database, overriding any client-supplied value.
	SourceID int64 `json:"source_id"`
	// Query is the SQL to execute against the source.
	Query string `json:"query"`
	// DestinationID identifies the destination to copy data into.
	DestinationID uint `json:"destination_id"`
	// DestinationTable is the table name to insert results into.
	DestinationTable string `json:"destination_table"`
}

0 comments on commit 3bf8a3c

Please sign in to comment.