diff --git a/server/internal/app/config.go b/server/internal/app/config.go index 9cf5588fa..911ecb6bc 100644 --- a/server/internal/app/config.go +++ b/server/internal/app/config.go @@ -321,6 +321,7 @@ func addHTTPScheme(host string) string { type MarketplaceConfig struct { Endpoint string + Secret string OAuth *OAuthClientCredentialsConfig } diff --git a/server/internal/app/repo.go b/server/internal/app/repo.go index 6070cb02c..e928a73bd 100644 --- a/server/internal/app/repo.go +++ b/server/internal/app/repo.go @@ -72,7 +72,7 @@ func initReposAndGateways(ctx context.Context, conf *Config, debug bool) (*repo. // Marketplace if conf.Marketplace.Endpoint != "" { - gateways.PluginRegistry = marketplace.New(conf.Marketplace.Endpoint, conf.Marketplace.OAuth.Config()) + gateways.PluginRegistry = marketplace.New(conf.Marketplace.Endpoint, conf.Marketplace.Secret, conf.Marketplace.OAuth.Config()) } // release lock of all scenes diff --git a/server/internal/infrastructure/marketplace/marketplace.go b/server/internal/infrastructure/marketplace/marketplace.go index 9577cd26b..3278264ae 100644 --- a/server/internal/infrastructure/marketplace/marketplace.go +++ b/server/internal/infrastructure/marketplace/marketplace.go @@ -13,110 +13,103 @@ import ( "golang.org/x/oauth2/clientcredentials" ) -var pluginPackageSizeLimit int64 = 10 * 1024 * 1024 // 10MB +const ( + secretHeader string = "X-Reearth-Secret" + pluginPackageSizeLimit int64 = 10 * 1024 * 1024 // 10MB +) type Marketplace struct { endpoint string + secret string conf *clientcredentials.Config } -func New(endpoint string, conf *clientcredentials.Config) *Marketplace { +func New(endpoint, secret string, conf *clientcredentials.Config) *Marketplace { return &Marketplace{ endpoint: strings.TrimSuffix(endpoint, "/"), + secret: secret, conf: conf, } } func (m *Marketplace) FetchPluginPackage(ctx context.Context, pid id.PluginID) (*pluginpack.Package, error) { - purl, err := m.getPluginURL(pid) + url, err := m.getPluginURL(pid) if err != nil { return nil, err } - return m.downloadPluginPackage(ctx, purl) -} -func (m *Marketplace) getPluginURL(pid id.PluginID) (string, error) { - return strings.TrimSpace(fmt.Sprintf("%s/api/plugins/%s/%s.zip", m.endpoint, pid.Name(), pid.Version().String())), nil -} + log.Infof("marketplace: downloading plugin package from \"%s\"", url) -/* -func (m *Marketplace) getPluginURL(ctx context.Context, pid id.PluginID) (string, error) { - body := strings.NewReader(fmt.Sprintf( - `{"query":"query { node(id:"%s" type:PLUGIN) { ...Plugin { url } } }"}`, - pid.Name(), - )) - req, err := http.NewRequestWithContext(ctx, "POST", m.endpoint+"/graphql", body) + req, err := http.NewRequest("GET", url, nil) if err != nil { - return "", rerror.ErrInternalBy(err) + return nil, rerror.ErrInternalBy(err) } - req.Header.Set("Content-Type", "application/json") - - res, err := m.client.Do(req) - if err != nil { - return "", rerror.ErrInternalBy(err) + if m.secret != "" { + req.Header.Set(secretHeader, m.secret) } - if res.StatusCode != http.StatusOK { - return "", rerror.ErrNotFound + res, err := m.client(ctx).Do(req) + if err != nil { + return nil, rerror.ErrInternalBy(err) } + defer func() { _ = res.Body.Close() }() - var pluginRes response - if err := json.NewDecoder(res.Body).Decode(&pluginRes); err != nil { - return "", rerror.ErrInternalBy(err) - } - if pluginRes.Errors != nil { - return "", rerror.ErrInternalBy(fmt.Errorf("gql returns errors: %v", pluginRes.Errors)) - } - purl := pluginRes.PluginURL() - if purl == "" { - return "", rerror.ErrNotFound + if res.StatusCode == http.StatusNotFound { + return nil, rerror.ErrNotFound } - return purl, nil -} -type response struct { - Data pluginNodeQueryData `json:"data"` - Errors any `json:"errors"` -} + if res.StatusCode != http.StatusOK { + return nil, rerror.ErrInternalBy(fmt.Errorf("status code is %d", res.StatusCode)) + } -func (r response) PluginURL() string { - return r.Data.Node.URL + return pluginpack.PackageFromZip(res.Body, nil, pluginPackageSizeLimit) } -type pluginNodeQueryData struct { - Node plugin -} +func (m *Marketplace) NotifyDownload(ctx context.Context, pid id.PluginID) error { + url, err := m.getPluginURL(pid) + if err != nil { + return err + } + url = url + "/download" -type plugin struct { - URL string `json:"url"` -} -*/ + log.Infof("marketplace: notify donwload to \"%s\"", url) -func (m *Marketplace) downloadPluginPackage(ctx context.Context, url string) (*pluginpack.Package, error) { - var client *http.Client - if m.conf != nil && m.conf.ClientID != "" && m.conf.ClientSecret != "" && m.conf.TokenURL != "" { - client = m.conf.Client(ctx) + req, err := http.NewRequest("POST", url, nil) + if err != nil { + return rerror.ErrInternalBy(err) } - if client == nil { - client = http.DefaultClient + if m.secret != "" { + req.Header.Set(secretHeader, m.secret) } - log.Infof("marketplace: downloading plugin package from \"%s\"", url) - - res, err := client.Get(url) + res, err := m.client(ctx).Do(req) if err != nil { - return nil, rerror.ErrInternalBy(err) + return rerror.ErrInternalBy(err) } + defer func() { _ = res.Body.Close() }() - if res.StatusCode == http.StatusNotFound { - return nil, rerror.ErrNotFound + + if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusNotFound { + return rerror.ErrInternalBy(fmt.Errorf("status code is %d", res.StatusCode)) } - if res.StatusCode != http.StatusOK { - return nil, rerror.ErrInternalBy(fmt.Errorf("status code is %d", res.StatusCode)) + return nil +} + +func (m *Marketplace) getPluginURL(pid id.PluginID) (string, error) { + return strings.TrimSpace(fmt.Sprintf("%s/api/plugins/%s/%s", m.endpoint, pid.Name(), pid.Version().String())), nil +} + +func (m *Marketplace) client(ctx context.Context) (client *http.Client) { + if m.conf != nil && m.conf.ClientID != "" && m.conf.ClientSecret != "" && m.conf.TokenURL != "" { + client = m.conf.Client(ctx) } - return pluginpack.PackageFromZip(res.Body, nil, pluginPackageSizeLimit) + if client == nil { + client = http.DefaultClient + } + + return } diff --git a/server/internal/infrastructure/marketplace/marketplace_test.go b/server/internal/infrastructure/marketplace/marketplace_test.go index 21564ad8d..699fb84d7 100644 --- a/server/internal/infrastructure/marketplace/marketplace_test.go +++ b/server/internal/infrastructure/marketplace/marketplace_test.go @@ -9,13 +9,17 @@ import ( "testing" "github.com/jarcoal/httpmock" + "github.com/reearth/reearth/server/internal/usecase/gateway" "github.com/reearth/reearth/server/pkg/id" "github.com/stretchr/testify/assert" "golang.org/x/oauth2/clientcredentials" ) +var _ gateway.PluginRegistry = (*Marketplace)(nil) + func TestMarketplace_FetchPluginPackage(t *testing.T) { ac := "xxxxx" + secret := "secret" pid := id.MustPluginID("testplugin~1.0.1") f, err := os.Open("testdata/test.zip") @@ -85,16 +89,19 @@ func TestMarketplace_FetchPluginPackage(t *testing.T) { httpmock.RegisterResponder( "GET", - "https://marketplace.example.com/api/plugins/testplugin/1.0.1.zip", + "https://marketplace.example.com/api/plugins/testplugin/1.0.1", func(req *http.Request) (*http.Response, error) { if req.Header.Get("Authorization") != "Bearer "+ac { return httpmock.NewStringResponse(http.StatusUnauthorized, ""), nil } + if req.Header.Get(secretHeader) != secret { + return httpmock.NewStringResponse(http.StatusUnauthorized, ""), nil + } return httpmock.NewBytesResponse(http.StatusOK, z), nil }, ) - m := New("https://marketplace.example.com/", &clientcredentials.Config{ + m := New("https://marketplace.example.com/", secret, &clientcredentials.Config{ ClientID: "x", ClientSecret: "y", TokenURL: "https://marketplace.example.com/oauth/token", @@ -123,26 +130,87 @@ func TestMarketplace_FetchPluginPackage_NoAuth(t *testing.T) { defer httpmock.Deactivate() httpmock.RegisterResponder( - "GET", "https://marketplace.example.com/api/plugins/testplugin/1.0.1.zip", + "GET", "https://marketplace.example.com/api/plugins/testplugin/1.0.1", func(req *http.Request) (*http.Response, error) { return httpmock.NewBytesResponse(http.StatusOK, z), nil }, ) - m := New("https://marketplace.example.com/", nil) + m := New("https://marketplace.example.com/", "", nil) got, err := m.FetchPluginPackage(context.Background(), pid) assert.NoError(t, err) // no need to test pluginpack in detail here assert.Equal(t, id.MustPluginID("testplugin~1.0.1"), got.Manifest.Plugin.ID()) } +func TestMarketplace_NotifyDownload(t *testing.T) { + ac := "xxxxx" + pid := id.MustPluginID("testplugin~1.0.1") + + httpmock.Activate() + defer httpmock.Deactivate() + + httpmock.RegisterResponder( + "POST", "https://marketplace.example.com/oauth/token", + func(req *http.Request) (*http.Response, error) { + _ = req.ParseForm() + if req.Form.Get("grant_type") != "client_credentials" { + return httpmock.NewStringResponse(http.StatusUnauthorized, ""), nil + } + if req.Form.Get("audience") != "d" { + return httpmock.NewStringResponse(http.StatusUnauthorized, ""), nil + } + if req.Form.Get("client_id") != "x" { + return httpmock.NewStringResponse(http.StatusUnauthorized, ""), nil + } + if req.Form.Get("client_secret") != "y" { + return httpmock.NewStringResponse(http.StatusUnauthorized, ""), nil + } + + resp, err := httpmock.NewJsonResponse(200, map[string]any{ + "access_token": ac, + "token_type": "Bearer", + "expires_in": 86400, + }) + if err != nil { + return httpmock.NewStringResponse(http.StatusInternalServerError, ""), nil + } + return resp, nil + }, + ) + + called := false + httpmock.RegisterResponder( + "POST", + "https://marketplace.example.com/api/plugins/testplugin/1.0.1/download", + func(req *http.Request) (*http.Response, error) { + if req.Header.Get("Authorization") != "Bearer "+ac { + return httpmock.NewStringResponse(http.StatusUnauthorized, ""), nil + } + called = true + return httpmock.NewBytesResponse(http.StatusOK, nil), nil + }, + ) + + m := New("https://marketplace.example.com/", "", &clientcredentials.Config{ + ClientID: "x", + ClientSecret: "y", + TokenURL: "https://marketplace.example.com/oauth/token", + EndpointParams: url.Values{ + "audience": []string{"d"}, + }, + }) + assert.NoError(t, m.NotifyDownload(context.Background(), pid)) + assert.True(t, called) +} + func TestMarketplace_GetPluginURL(t *testing.T) { pid := id.MustPluginID("aaaabbbxxxbb~1.0.0") u, err := (&Marketplace{ endpoint: "https://xxxxx", }).getPluginURL(pid) assert.NoError(t, err) - assert.Equal(t, "https://xxxxx/api/plugins/aaaabbbxxxbb/1.0.0.zip", u) + assert.Equal(t, "https://xxxxx/api/plugins/aaaabbbxxxbb/1.0.0", u) _, err = url.Parse(u) assert.NoError(t, err) } diff --git a/server/internal/usecase/gateway/plugin_registry.go b/server/internal/usecase/gateway/plugin_registry.go index e9f9aed28..976575331 100644 --- a/server/internal/usecase/gateway/plugin_registry.go +++ b/server/internal/usecase/gateway/plugin_registry.go @@ -12,4 +12,5 @@ var ErrFailedToFetchDataFromPluginRegistry = errors.New("failed to fetch data fr type PluginRegistry interface { FetchPluginPackage(context.Context, id.PluginID) (*pluginpack.Package, error) + NotifyDownload(context.Context, id.PluginID) error } diff --git a/server/internal/usecase/interactor/plugin_common.go b/server/internal/usecase/interactor/plugin_common.go index b3f3e2065..154b645a7 100644 --- a/server/internal/usecase/interactor/plugin_common.go +++ b/server/internal/usecase/interactor/plugin_common.go @@ -56,6 +56,12 @@ func (i *pluginCommon) GetOrDownloadPlugin(ctx context.Context, pid id.PluginID) if plugin, err := i.pluginRepo.FindByID(ctx, pid); err != nil && !errors.Is(err, rerror.ErrNotFound) { return nil, err } else if plugin != nil { + if plugin.ID().Scene() == nil { + if err := i.pluginRegistry.NotifyDownload(ctx, plugin.ID()); err != nil { + return nil, err + } + } + return plugin, nil }