Skip to content

Commit

Permalink
validate flow name and node name
Browse files Browse the repository at this point in the history
  • Loading branch information
yuyang0 committed Apr 15, 2024
1 parent 56a3540 commit 089f9b0
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 7 deletions.
4 changes: 2 additions & 2 deletions flow/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,11 @@ func (e *Executor) getExecResult(execID string) (eRes *ExecResult, err error) {
}

func newExecID(flowName, nodeName, randomID string) string {
return fmt.Sprintf("%s:%s:%s", flowName, nodeName, randomID)
return fmt.Sprintf("%s%s%s%s%s", flowName, nameSep, nodeName, nameSep, randomID)
}

func parseExecID(execID string) (flowName string, nodeName string, sessID string, err error) {
parts := strings.Split(execID, ":")
parts := strings.Split(execID, nameSep)
if len(parts) != 3 {
err = errors.Newf("failed to parse execution id")
return
Expand Down
24 changes: 22 additions & 2 deletions flow/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"log/slog"
"strings"

"github.com/cockroachdb/errors"
"github.com/heimdalr/dag"
Expand All @@ -12,6 +13,10 @@ import (
"github.com/yuyang0/dagflow/types"
)

const (
nameSep = ":"
)

type NodeFunc func([]byte, map[string][]string) ([]byte, error)
type SwitchCondFunc func([]byte) string
type Definitor func(ctx context.Context, f *Flow) error
Expand All @@ -29,7 +34,10 @@ type Flow struct {
func New(
name string, stor store.Store, cli *asynq.Client,
logger *slog.Logger, cfg *types.Config, insp *asynq.Inspector,
) *Flow {
) (*Flow, error) {
if strings.Contains(name, nameSep) {
return nil, errors.Newf("flow name can't contain %s", nameSep)
}
return &Flow{
Name: name,
DAG: dag.NewDAG(),
Expand All @@ -38,7 +46,7 @@ func New(
logger: logger,
cfg: cfg,
insp: insp,
}
}, nil
}

type flowNode struct {
Expand All @@ -52,6 +60,9 @@ type flowNode struct {
}

func (f *Flow) Node(name string, fn NodeFunc, opts ...Option) error {
if strings.Contains(name, nameSep) {
return errors.Newf("dag node name can't contain %s", nameSep)
}
execOpts := &ExecutionOptions{}
for _, opt := range opts {
opt(execOpts)
Expand All @@ -67,6 +78,9 @@ func (f *Flow) SwitchNode(
name string, condFn SwitchCondFunc,
cases map[string]NodeFunc, opts ...Option,
) error {
if strings.Contains(name, nameSep) {
return errors.Newf("dag node name can't contain %s", nameSep)
}
execOpts := &ExecutionOptions{}
for _, opt := range opts {
opt(execOpts)
Expand All @@ -80,6 +94,12 @@ func (f *Flow) SwitchNode(
}

func (f *Flow) Edge(src, dst string) error {
if strings.Contains(src, nameSep) {
return errors.Newf("dag src node name can't contain %s", nameSep)
}
if strings.Contains(dst, nameSep) {
return errors.Newf("dag dst node name can't contain %s", nameSep)
}
return f.DAG.AddEdge(src, dst)
}

Expand Down
4 changes: 3 additions & 1 deletion flow/flow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestDAG(t *testing.T) {
// TODO
f := New("flow1", nil, nil, nil, nil, nil)
f, err := New("flow1", nil, nil, nil, nil, nil)
require.NoError(t, err)
f.Node("l1n1", nil)
f.Node("l2n1", nil)
f.Node("l2n2", nil)
Expand Down
3 changes: 1 addition & 2 deletions service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ func New(cfg *types.Config, logger *slog.Logger) (*Service, error) {
}

func (svc *Service) NewFlow(flowName string) (*flow.Flow, error) {
flow := flow.New(flowName, svc.stor, svc.cli, svc.logger, svc.cfg, svc.insp)
return flow, nil
return flow.New(flowName, svc.stor, svc.cli, svc.logger, svc.cfg, svc.insp)
}

// submit a flow task
Expand Down

0 comments on commit 089f9b0

Please sign in to comment.