diff --git a/runtime/drivers/duckdb/model_executor_self_file.go b/runtime/drivers/duckdb/model_executor_self_file.go index 69ea3982fa1..78fd2886459 100644 --- a/runtime/drivers/duckdb/model_executor_self_file.go +++ b/runtime/drivers/duckdb/model_executor_self_file.go @@ -72,22 +72,22 @@ func (e *selfToFileExecutor) Execute(ctx context.Context) (*drivers.ModelResult, }, nil } -func exportSQL(qry, path, format string) (string, error) { +func exportSQL(qry, path string, format drivers.FileFormat) (string, error) { switch format { - case "parquet": + case drivers.FileFormatParquet: return fmt.Sprintf("COPY (%s\n) TO '%s' (FORMAT PARQUET)", qry, path), nil - case "csv": + case drivers.FileFormatCSV: return fmt.Sprintf("COPY (%s\n) TO '%s' (FORMAT CSV, HEADER true)", qry, path), nil - case "json": + case drivers.FileFormatJSON: return fmt.Sprintf("COPY (%s\n) TO '%s' (FORMAT JSON)", qry, path), nil default: return "", fmt.Errorf("duckdb: unsupported export format %q", format) } } -func supportsExportFormat(format string) bool { +func supportsExportFormat(format drivers.FileFormat) bool { switch format { - case "parquet", "csv", "json": + case drivers.FileFormatParquet, drivers.FileFormatCSV, drivers.FileFormatJSON: return true default: return false diff --git a/runtime/drivers/file/model_executor.go b/runtime/drivers/file/model_executor.go index f607b3167e6..63fff050d58 100644 --- a/runtime/drivers/file/model_executor.go +++ b/runtime/drivers/file/model_executor.go @@ -1,10 +1,14 @@ package file -import "fmt" +import ( + "fmt" + + "github.com/rilldata/rill/runtime/drivers" +) type ModelOutputProperties struct { - Path string `mapstructure:"path"` - Format string `mapstructure:"format"` + Path string `mapstructure:"path"` + Format drivers.FileFormat `mapstructure:"format"` } func (p *ModelOutputProperties) Validate() error { @@ -13,11 +17,13 @@ func (p *ModelOutputProperties) Validate() error { } if p.Format == "" { return fmt.Errorf("missing property 'format'") + } else if !p.Format.Valid() { + return fmt.Errorf("invalid property 'format': %q", p.Format) } return nil } type ModelResultProperties struct { - Path string `mapstructure:"path"` - Format string `mapstructure:"format"` + Path string `mapstructure:"path"` + Format drivers.FileFormat `mapstructure:"format"` } diff --git a/runtime/drivers/file/model_executor_olap_self.go b/runtime/drivers/file/model_executor_olap_self.go index f9c551827c3..fe3aeceb0ee 100644 --- a/runtime/drivers/file/model_executor_olap_self.go +++ b/runtime/drivers/file/model_executor_olap_self.go @@ -62,12 +62,14 @@ func (e *olapToSelfExecutor) Execute(ctx context.Context) (*drivers.ModelResult, defer res.Close() switch outputProps.Format { - case "csv": + case drivers.FileFormatParquet: + err = writeParquet(res, outputProps.Path) + case drivers.FileFormatCSV: err = writeCSV(res, outputProps.Path) - case "xlsx": + case drivers.FileFormatJSON: + return nil, errors.New("json file output not currently supported") + case drivers.FileFormatXLSX: err = writeXLSX(res, outputProps.Path) - case "parquet": - err = writeParquet(res, outputProps.Path) default: return nil, fmt.Errorf("unsupported output format %q", outputProps.Format) } diff --git a/runtime/drivers/models.go b/runtime/drivers/models.go index 04432f6fc3e..c76384b2ee3 100644 --- a/runtime/drivers/models.go +++ b/runtime/drivers/models.go @@ -40,3 +40,25 @@ type ModelExecutorOptions struct { IncrementalRun bool PreviousResult *ModelResult } + +type FileFormat string + +const ( + FileFormatUnspecified FileFormat = "" + FileFormatParquet FileFormat = "parquet" + FileFormatCSV FileFormat = "csv" + FileFormatJSON FileFormat = "json" + FileFormatXLSX FileFormat = "xlsx" +) + +func (f FileFormat) Filename(stem string) string { + return stem + "." + string(f) +} + +func (f FileFormat) Valid() bool { + switch f { + case FileFormatParquet, FileFormatCSV, FileFormatJSON, FileFormatXLSX: + return true + } + return false +} diff --git a/runtime/metricsview/executor.go b/runtime/metricsview/executor.go index 0e682b49f8b..fd8d741fe14 100644 --- a/runtime/metricsview/executor.go +++ b/runtime/metricsview/executor.go @@ -261,7 +261,7 @@ func (e *Executor) Query(ctx context.Context, qry *Query, executionTime *time.Ti // Export executes and exports the provided query against the metrics view. // It returns a path to a temporary file containing the export. The caller is responsible for cleaning up the file. -func (e *Executor) Export(ctx context.Context, qry *Query, executionTime *time.Time, format string) (string, error) { +func (e *Executor) Export(ctx context.Context, qry *Query, executionTime *time.Time, format drivers.FileFormat) (string, error) { if e.security != nil && !e.security.Access { return "", runtime.ErrForbidden } diff --git a/runtime/metricsview/executor_export.go b/runtime/metricsview/executor_export.go index 8a9d6711259..5448ba16e68 100644 --- a/runtime/metricsview/executor_export.go +++ b/runtime/metricsview/executor_export.go @@ -15,7 +15,7 @@ import ( // executeExport works by simulating a model that outputs to a file. // This means it creates a ModelExecutor with the provided input connector and props as input, // and with the "file" driver as the output connector targeting a temporary output path. -func (e *Executor) executeExport(ctx context.Context, format, inputConnector string, inputProps map[string]any) (string, error) { +func (e *Executor) executeExport(ctx context.Context, format drivers.FileFormat, inputConnector string, inputProps map[string]any) (string, error) { ctx, cancel := context.WithTimeout(ctx, defaultExportTimeout) defer cancel() @@ -29,7 +29,7 @@ func (e *Executor) executeExport(ctx context.Context, format, inputConnector str if err != nil { return "", err } - name = fmt.Sprintf("%s.%s", name, format) + name = format.Filename(name) path = filepath.Join(path, name) ic, ir, err := e.rt.AcquireHandle(ctx, e.instanceID, inputConnector) diff --git a/runtime/metricsview/executor_pivot.go b/runtime/metricsview/executor_pivot.go index 8bd281dd14c..054c43160e3 100644 --- a/runtime/metricsview/executor_pivot.go +++ b/runtime/metricsview/executor_pivot.go @@ -115,7 +115,7 @@ func (e *Executor) rewriteQueryForPivot(qry *Query) (*pivotAST, bool, error) { } // executePivotExport executes a PIVOT query prepared using rewriteQueryForPivot, and exports the result to a file in the given format. -func (e *Executor) executePivotExport(ctx context.Context, ast *AST, pivot *pivotAST, format string) (string, error) { +func (e *Executor) executePivotExport(ctx context.Context, ast *AST, pivot *pivotAST, format drivers.FileFormat) (string, error) { ctx, cancel := context.WithTimeout(ctx, defaultPivotExportTimeout) defer cancel() diff --git a/runtime/queries/metricsview_aggregation.go b/runtime/queries/metricsview_aggregation.go index 2ef2f529cd5..a6937d703e3 100644 --- a/runtime/queries/metricsview_aggregation.go +++ b/runtime/queries/metricsview_aggregation.go @@ -10,6 +10,7 @@ import ( runtimev1 "github.com/rilldata/rill/proto/gen/rill/runtime/v1" "github.com/rilldata/rill/runtime" + "github.com/rilldata/rill/runtime/drivers" "github.com/rilldata/rill/runtime/metricsview" ) @@ -33,7 +34,7 @@ type MetricsViewAggregation struct { Exact bool `json:"exact,omitempty"` Result *runtimev1.MetricsViewAggregationResponse `json:"-"` - Exporting bool `json:"-"` + Exporting bool `json:"-"` // Deprecated: Remove when tests call Export directly } var _ runtime.Query = &MetricsViewAggregation{} @@ -127,14 +128,14 @@ func (q *MetricsViewAggregation) Export(ctx context.Context, rt *runtime.Runtime } defer e.Close() - var format string + var format drivers.FileFormat switch opts.Format { case runtimev1.ExportFormat_EXPORT_FORMAT_CSV: - format = "csv" + format = drivers.FileFormatCSV case runtimev1.ExportFormat_EXPORT_FORMAT_XLSX: - format = "xlsx" + format = drivers.FileFormatXLSX case runtimev1.ExportFormat_EXPORT_FORMAT_PARQUET: - format = "parquet" + format = drivers.FileFormatParquet default: return fmt.Errorf("unsupported format: %s", opts.Format.String()) } @@ -263,7 +264,7 @@ func (q *MetricsViewAggregation) rewriteToMetricsViewQuery(export bool) (*metric qry.ComparisonTimeRange = res } - if q.Filter != nil { // backwards backwards compatibility + if q.Filter != nil { // Backwards compatibility if q.Where != nil { return nil, fmt.Errorf("both filter and where is provided") } diff --git a/runtime/queries/metricsview_comparison_toplist.go b/runtime/queries/metricsview_comparison_toplist.go index 9549d888fec..3dbe8aaf47c 100644 --- a/runtime/queries/metricsview_comparison_toplist.go +++ b/runtime/queries/metricsview_comparison_toplist.go @@ -10,6 +10,7 @@ import ( runtimev1 "github.com/rilldata/rill/proto/gen/rill/runtime/v1" "github.com/rilldata/rill/runtime" + "github.com/rilldata/rill/runtime/drivers" "github.com/rilldata/rill/runtime/metricsview" "github.com/rilldata/rill/runtime/pkg/pbutil" @@ -89,7 +90,6 @@ func (q *MetricsViewComparison) Resolve(ctx context.Context, rt *runtime.Runtime return err } - // Attempt to route to metricsview executor qry, err := q.rewriteToMetricsViewQuery(false) if err != nil { return fmt.Errorf("error rewriting to metrics query: %w", err) @@ -181,7 +181,6 @@ func (q *MetricsViewComparison) Export(ctx context.Context, rt *runtime.Runtime, return err } - // Attempt to route to metricsview executor qry, err := q.rewriteToMetricsViewQuery(true) if err != nil { return fmt.Errorf("error rewriting to metrics query: %w", err) @@ -193,14 +192,14 @@ func (q *MetricsViewComparison) Export(ctx context.Context, rt *runtime.Runtime, } defer e.Close() - var format string + var format drivers.FileFormat switch opts.Format { case runtimev1.ExportFormat_EXPORT_FORMAT_CSV: - format = "csv" + format = drivers.FileFormatCSV case runtimev1.ExportFormat_EXPORT_FORMAT_XLSX: - format = "xlsx" + format = drivers.FileFormatXLSX case runtimev1.ExportFormat_EXPORT_FORMAT_PARQUET: - format = "parquet" + format = drivers.FileFormatParquet default: return fmt.Errorf("unsupported format: %s", opts.Format.String()) } @@ -354,7 +353,7 @@ func (q *MetricsViewComparison) rewriteToMetricsViewQuery(export bool) (*metrics }) } - if q.Filter != nil { // backwards backwards compatibility + if q.Filter != nil { // Backwards compatibility if q.Where != nil { return nil, fmt.Errorf("both filter and where is provided") } diff --git a/runtime/queries/metricsview_timeseries.go b/runtime/queries/metricsview_timeseries.go index 443460ab15f..ee1cef6e9be 100644 --- a/runtime/queries/metricsview_timeseries.go +++ b/runtime/queries/metricsview_timeseries.go @@ -326,7 +326,7 @@ func (q *MetricsViewTimeSeries) rewriteToMetricsViewQuery(timeDimension string) }) } - if q.Filter != nil { // backwards backwards compatibility + if q.Filter != nil { // Backwards compatibility if q.Where != nil { return nil, fmt.Errorf("both filter and where is provided") } diff --git a/runtime/queries/metricsview_toplist.go b/runtime/queries/metricsview_toplist.go index b3649251b06..56e6ac62dfc 100644 --- a/runtime/queries/metricsview_toplist.go +++ b/runtime/queries/metricsview_toplist.go @@ -10,6 +10,7 @@ import ( runtimev1 "github.com/rilldata/rill/proto/gen/rill/runtime/v1" "github.com/rilldata/rill/runtime" + "github.com/rilldata/rill/runtime/drivers" "github.com/rilldata/rill/runtime/metricsview" "google.golang.org/protobuf/types/known/timestamppb" ) @@ -128,14 +129,14 @@ func (q *MetricsViewToplist) Export(ctx context.Context, rt *runtime.Runtime, in } defer e.Close() - var format string + var format drivers.FileFormat switch opts.Format { case runtimev1.ExportFormat_EXPORT_FORMAT_CSV: - format = "csv" + format = drivers.FileFormatCSV case runtimev1.ExportFormat_EXPORT_FORMAT_XLSX: - format = "xlsx" + format = drivers.FileFormatXLSX case runtimev1.ExportFormat_EXPORT_FORMAT_PARQUET: - format = "parquet" + format = drivers.FileFormatParquet default: return fmt.Errorf("unsupported format: %s", opts.Format.String()) } @@ -201,7 +202,7 @@ func (q *MetricsViewToplist) rewriteToMetricsViewQuery(export bool) (*metricsvie }) } - if q.Filter != nil { // backwards backwards compatibility + if q.Filter != nil { // Backwards compatibility if q.Where != nil { return nil, fmt.Errorf("both filter and where is provided") } diff --git a/runtime/queries/metricsview_totals.go b/runtime/queries/metricsview_totals.go index 8f8536ac8f5..a96fd34bc3f 100644 --- a/runtime/queries/metricsview_totals.go +++ b/runtime/queries/metricsview_totals.go @@ -124,7 +124,7 @@ func (q *MetricsViewTotals) rewriteToMetricsViewQuery(exporting bool) (*metricsv qry.TimeRange = res } - if q.Filter != nil { // backwards backwards compatibility + if q.Filter != nil { // Backwards compatibility if q.Where != nil { return nil, fmt.Errorf("both filter and where is provided") }