Skip to content

Commit

Permalink
Merge pull request #69 from pulumi/eronwright/issue-65
Browse files Browse the repository at this point in the history
Launch providers into workspace directory
  • Loading branch information
EronWright committed Feb 26, 2024
2 parents 06d8947 + dd332f3 commit a0511ba
Show file tree
Hide file tree
Showing 9 changed files with 32 additions and 16 deletions.
4 changes: 2 additions & 2 deletions providers/attach.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

// StartProviders starts each of the given providers and returns a map of provider names to the ports they are listening on.
// The context should be cancelled when the test is complete to shut down the providers.
func StartProviders(ctx context.Context, factories map[ProviderName]ProviderFactory) (map[ProviderName]Port, error) {
func StartProviders(ctx context.Context, factories map[ProviderName]ProviderFactory, pt PulumiTest) (map[ProviderName]Port, error) {
if len(factories) == 0 {
return nil, nil
}
Expand All @@ -24,7 +24,7 @@ func StartProviders(ctx context.Context, factories map[ProviderName]ProviderFact
portMappings := map[ProviderName]Port{}
for _, providerName := range providerNames {
factory := factories[providerName]
port, err := factory(ctx)
port, err := factory(ctx, pt)
if err != nil {
return nil, fmt.Errorf("failed to start provider %s: %v", providerName, err)
}
Expand Down
4 changes: 2 additions & 2 deletions providers/downloadPluginBinary.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ import (
)

func DownloadPluginBinaryFactory(name, version string) ProviderFactory {
factory := func(ctx context.Context) (Port, error) {
factory := func(ctx context.Context, pt PulumiTest) (Port, error) {
binaryPath, err := DownloadPluginBinary(name, version)
if err != nil {
return 0, err
}
return startLocalBinary(ctx, binaryPath, name)
return startLocalBinary(ctx, binaryPath, name, pt.Source())
}
return factory
}
Expand Down
8 changes: 7 additions & 1 deletion providers/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@ type ProviderName string
// Port is the port that a provider is listening on.
type Port int

// PulumiTest provides context about the program under test.
type PulumiTest interface {
// Source returns the current source directory.
Source() string
}

// ProviderFactory is a function that starts a provider and returns the port it is listening on.
// The function should return an error if the provider fails to start.
// When the test is complete, the context will be cancelled and the provider should exit.
type ProviderFactory func(ctx context.Context) (Port, error)
type ProviderFactory func(ctx context.Context, pt PulumiTest) (Port, error)
7 changes: 4 additions & 3 deletions providers/localBinary.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ import (
)

func LocalBinary(name, path string) ProviderFactory {
factory := func(ctx context.Context) (Port, error) {
return startLocalBinary(ctx, path, name)
factory := func(ctx context.Context, pt PulumiTest) (Port, error) {
return startLocalBinary(ctx, path, name, pt.Source())
}
return factory
}

func startLocalBinary(ctx context.Context, path string, name string) (Port, error) {
func startLocalBinary(ctx context.Context, path, name, cwd string) (Port, error) {
stat, err := os.Stat(path)
if err != nil {
return 0, err
Expand All @@ -28,6 +28,7 @@ func startLocalBinary(ctx context.Context, path string, name string) (Port, erro
path = filepath.Join(path, binaryName)
}
cmd := exec.CommandContext(ctx, path)
cmd.Dir = cwd
reader, err := cmd.StdoutPipe()
cmd.Stderr = os.Stderr
if err != nil {
Expand Down
11 changes: 10 additions & 1 deletion providers/localBinary_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,20 @@ import (
"github.com/stretchr/testify/assert"
)

type mockPulumiTest struct {
source string
}

func (m *mockPulumiTest) Source() string {
return m.source
}

func TestLocalBinaryAttach(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
pt := &mockPulumiTest{source: t.TempDir()}
factory := providers.DownloadPluginBinaryFactory("azure-native", "2.25.0")
port, err := factory(ctx)
port, err := factory(ctx, pt)
assert.NoError(t, err)
assert.NotZero(t, port)
}
4 changes: 2 additions & 2 deletions providers/providerInterceptProxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ type ProviderInterceptors struct {

// ProviderInterceptFactory creates a new provider factory that can be used to intercept calls to a downstream provider.
func ProviderInterceptFactory(ctx context.Context, factory ProviderFactory, interceptors ProviderInterceptors) ProviderFactory {
return ResourceProviderFactory(func() (rpc.ResourceProviderServer, error) {
port, err := factory(ctx)
return ResourceProviderFactory(func(pt PulumiTest) (rpc.ResourceProviderServer, error) {
port, err := factory(ctx, pt)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion providers/providerMock.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ type ProviderMocks struct {

// ProviderInterceptFactory creates a new provider factory that can be used to intercept calls to a downstream provider.
func ProviderMockFactory(mocks ProviderMocks) ProviderFactory {
return ResourceProviderFactory(func() (rpc.ResourceProviderServer, error) {
return ResourceProviderFactory(func(_ PulumiTest) (rpc.ResourceProviderServer, error) {
return NewProviderMock(mocks)
})
}
Expand Down
6 changes: 3 additions & 3 deletions providers/resourceProvider.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ import (
"google.golang.org/grpc"
)

type ResourceProviderServerFactory func() (pulumirpc.ResourceProviderServer, error)
type ResourceProviderServerFactory func(PulumiTest) (pulumirpc.ResourceProviderServer, error)

// startProvider starts the provider in a goProc and returns the port it's listening on.
// To shut down the provider, cancel the context.
func ResourceProviderFactory(makeResourceProviderServer ResourceProviderServerFactory) ProviderFactory {
return func(ctx context.Context) (Port, error) {
return func(ctx context.Context, pt PulumiTest) (Port, error) {
cancelChannel := make(chan bool)
go func() {
<-ctx.Done()
Expand All @@ -24,7 +24,7 @@ func ResourceProviderFactory(makeResourceProviderServer ResourceProviderServerFa
handle, err := rpcutil.ServeWithOptions(rpcutil.ServeOptions{
Cancel: cancelChannel,
Init: func(srv *grpc.Server) error {
prov, proverr := makeResourceProviderServer()
prov, proverr := makeResourceProviderServer(pt)
if proverr != nil {
return fmt.Errorf("failed to create resource provider server: %v", proverr)
}
Expand Down
2 changes: 1 addition & 1 deletion pulumitest/newStack.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (pt *PulumiTest) NewStack(stackName string, opts ...optnewstack.NewStackOpt
if len(providerFactories) > 0 {
pt.t.Log("starting providers")
providerContext, cancelProviders := context.WithCancel(pt.ctx)
providerPorts, err := providers.StartProviders(providerContext, providerFactories)
providerPorts, err := providers.StartProviders(providerContext, providerFactories, pt)
if err != nil {
cancelProviders()
pt.t.Fatalf("failed to start providers: %v", err)
Expand Down

0 comments on commit a0511ba

Please sign in to comment.