Skip to content

Commit

Permalink
fix(server): dataset limitation by policies, asset size calc
Browse files Browse the repository at this point in the history
  • Loading branch information
rot1024 committed Nov 21, 2022
1 parent 2f7fb95 commit e07b785
Show file tree
Hide file tree
Showing 10 changed files with 102 additions and 140 deletions.
36 changes: 34 additions & 2 deletions server/internal/adapter/gql/resolver_mutation_property.go
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/reearth/reearth/server/internal/usecase/interfaces"
"github.com/reearth/reearth/server/pkg/id"
"github.com/reearth/reearth/server/pkg/property"
"github.com/reearth/reearthx/rerror"
"github.com/reearth/reearthx/util"
"github.com/samber/lo"
)
Expand Down Expand Up @@ -74,14 +75,45 @@ func (r *mutationResolver) UploadFileToProperty(ctx context.Context, input gqlmo
return nil, err
}

p, pgl, pg, pf, err := usecases(ctx).Property.UploadFile(ctx, interfaces.UploadFileParam{
uc := usecases(ctx)
pr, err := uc.Property.Fetch(ctx, []id.PropertyID{pid}, getOperator(ctx))
if err != nil || len(pr) == 0 {
if err == nil {
err = rerror.ErrNotFound
}
return nil, err
}
ws, err := uc.Scene.Fetch(ctx, []id.SceneID{pr[0].Scene()}, getOperator(ctx))
if err != nil || len(ws) == 0 {
if err == nil {
err = rerror.ErrNotFound
}
return nil, err
}
prj, err := uc.Project.Fetch(ctx, []id.ProjectID{ws[0].Project()}, getOperator(ctx))
if err != nil || len(prj) == 0 {
if err == nil {
err = rerror.ErrNotFound
}
return nil, err
}

a, err := uc.Asset.Create(ctx, interfaces.CreateAssetParam{
WorkspaceID: prj[0].Workspace(),
File: gqlmodel.FromFile(&input.File),
}, getOperator(ctx))
if err != nil {
return nil, err
}

p, pgl, pg, pf, err := uc.Property.UpdateValue(ctx, interfaces.UpdatePropertyValueParam{
PropertyID: pid,
Pointer: gqlmodel.FromPointer(
gqlmodel.ToStringIDRef[id.PropertySchemaGroup](input.SchemaGroupID),
input.ItemID,
gqlmodel.ToStringIDRef[id.PropertyField](&input.FieldID),
),
File: gqlmodel.FromFile(&input.File),
Value: property.ValueTypeURL.ValueFrom(a.URL()),
}, getOperator(ctx))
if err != nil {
return nil, err
Expand Down
30 changes: 17 additions & 13 deletions server/internal/infrastructure/fs/file.go
Expand Up @@ -42,12 +42,13 @@ func (f *fileRepo) ReadAsset(ctx context.Context, filename string) (io.ReadClose
return f.read(ctx, filepath.Join(assetDir, sanitize.Path(filename)))
}

func (f *fileRepo) UploadAsset(ctx context.Context, file *file.File) (*url.URL, error) {
func (f *fileRepo) UploadAsset(ctx context.Context, file *file.File) (*url.URL, int64, error) {
filename := sanitize.Path(newAssetID() + path.Ext(file.Path))
if err := f.upload(ctx, filepath.Join(assetDir, filename), file.Content); err != nil {
return nil, err
size, err := f.upload(ctx, filepath.Join(assetDir, filename), file.Content)
if err != nil {
return nil, 0, err
}
return getAssetFileURL(f.urlBase, filename), nil
return getAssetFileURL(f.urlBase, filename), size, nil
}

func (f *fileRepo) RemoveAsset(ctx context.Context, u *url.URL) error {
Expand All @@ -68,7 +69,8 @@ func (f *fileRepo) ReadPluginFile(ctx context.Context, pid id.PluginID, filename
}

func (f *fileRepo) UploadPluginFile(ctx context.Context, pid id.PluginID, file *file.File) error {
return f.upload(ctx, filepath.Join(pluginDir, pid.String(), sanitize.Path(file.Path)), file.Content)
_, err := f.upload(ctx, filepath.Join(pluginDir, pid.String(), sanitize.Path(file.Path)), file.Content)
return err
}

func (f *fileRepo) RemovePlugin(ctx context.Context, pid id.PluginID) error {
Expand All @@ -82,7 +84,8 @@ func (f *fileRepo) ReadBuiltSceneFile(ctx context.Context, name string) (io.Read
}

func (f *fileRepo) UploadBuiltScene(ctx context.Context, reader io.Reader, name string) error {
return f.upload(ctx, filepath.Join(publishedDir, sanitize.Path(name+".json")), reader)
_, err := f.upload(ctx, filepath.Join(publishedDir, sanitize.Path(name+".json")), reader)
return err
}

func (f *fileRepo) MoveBuiltScene(ctx context.Context, oldName, name string) error {
Expand Down Expand Up @@ -114,30 +117,31 @@ func (f *fileRepo) read(ctx context.Context, filename string) (io.ReadCloser, er
return file, nil
}

func (f *fileRepo) upload(ctx context.Context, filename string, content io.Reader) error {
func (f *fileRepo) upload(ctx context.Context, filename string, content io.Reader) (int64, error) {
if filename == "" {
return gateway.ErrFailedToUploadFile
return 0, gateway.ErrFailedToUploadFile
}

if fnd := path.Dir(filename); fnd != "" {
if err := f.fs.MkdirAll(fnd, 0755); err != nil {
return rerror.ErrInternalBy(err)
return 0, rerror.ErrInternalBy(err)
}
}

dest, err := f.fs.Create(filename)
if err != nil {
return rerror.ErrInternalBy(err)
return 0, rerror.ErrInternalBy(err)
}
defer func() {
_ = dest.Close()
}()

if _, err := io.Copy(dest, content); err != nil {
return gateway.ErrFailedToUploadFile
size, err := io.Copy(dest, content)
if err != nil {
return 0, gateway.ErrFailedToUploadFile
}

return nil
return size, nil
}

func (f *fileRepo) move(ctx context.Context, from, dest string) error {
Expand Down
3 changes: 2 additions & 1 deletion server/internal/infrastructure/fs/file_test.go
Expand Up @@ -47,11 +47,12 @@ func TestFile_UploadAsset(t *testing.T) {
fs := mockFs()
f, _ := NewFile(fs, "https://example.com/assets")

u, err := f.UploadAsset(context.Background(), &file.File{
u, s, err := f.UploadAsset(context.Background(), &file.File{
Path: "aaa.txt",
Content: io.NopCloser(strings.NewReader("aaa")),
})
assert.NoError(t, err)
assert.Equal(t, int64(3), s)
assert.Equal(t, "https", u.Scheme)
assert.Equal(t, "example.com", u.Host)
assert.True(t, strings.HasPrefix(u.Path, "/assets/"))
Expand Down
40 changes: 22 additions & 18 deletions server/internal/infrastructure/gcs/file.go
Expand Up @@ -63,29 +63,30 @@ func (f *fileRepo) ReadAsset(ctx context.Context, name string) (io.ReadCloser, e
return f.read(ctx, path.Join(gcsAssetBasePath, sn))
}

func (f *fileRepo) UploadAsset(ctx context.Context, file *file.File) (*url.URL, error) {
func (f *fileRepo) UploadAsset(ctx context.Context, file *file.File) (*url.URL, int64, error) {
if file == nil {
return nil, gateway.ErrInvalidFile
return nil, 0, gateway.ErrInvalidFile
}
if file.Size >= fileSizeLimit {
return nil, gateway.ErrFileTooLarge
return nil, 0, gateway.ErrFileTooLarge
}

sn := sanitize.Path(newAssetID() + path.Ext(file.Path))
if sn == "" {
return nil, gateway.ErrInvalidFile
return nil, 0, gateway.ErrInvalidFile
}

filename := path.Join(gcsAssetBasePath, sn)
u := getGCSObjectURL(f.base, filename)
if u == nil {
return nil, gateway.ErrInvalidFile
return nil, 0, gateway.ErrInvalidFile
}

if err := f.upload(ctx, filename, file.Content); err != nil {
return nil, err
s, err := f.upload(ctx, filename, file.Content)
if err != nil {
return nil, 0, err
}
return u, nil
return u, s, nil
}

func (f *fileRepo) RemoveAsset(ctx context.Context, u *url.URL) error {
Expand All @@ -111,7 +112,8 @@ func (f *fileRepo) UploadPluginFile(ctx context.Context, pid id.PluginID, file *
if sn == "" {
return gateway.ErrInvalidFile
}
return f.upload(ctx, path.Join(gcsPluginBasePath, pid.String(), sanitize.Path(file.Path)), file.Content)
_, err := f.upload(ctx, path.Join(gcsPluginBasePath, pid.String(), sanitize.Path(file.Path)), file.Content)
return err
}

func (f *fileRepo) RemovePlugin(ctx context.Context, pid id.PluginID) error {
Expand All @@ -132,7 +134,8 @@ func (f *fileRepo) UploadBuiltScene(ctx context.Context, content io.Reader, name
if sn == "" {
return gateway.ErrInvalidFile
}
return f.upload(ctx, path.Join(gcsMapBasePath, sn), content)
_, err := f.upload(ctx, path.Join(gcsMapBasePath, sn), content)
return err
}

func (f *fileRepo) MoveBuiltScene(ctx context.Context, oldName, name string) error {
Expand Down Expand Up @@ -186,37 +189,38 @@ func (f *fileRepo) read(ctx context.Context, filename string) (io.ReadCloser, er
return reader, nil
}

func (f *fileRepo) upload(ctx context.Context, filename string, content io.Reader) error {
func (f *fileRepo) upload(ctx context.Context, filename string, content io.Reader) (int64, error) {
if filename == "" {
return gateway.ErrInvalidFile
return 0, gateway.ErrInvalidFile
}

bucket, err := f.bucket(ctx)
if err != nil {
log.Errorf("gcs: upload bucket err: %+v\n", err)
return rerror.ErrInternalBy(err)
return 0, rerror.ErrInternalBy(err)
}

object := bucket.Object(filename)
if err := object.Delete(ctx); err != nil && !errors.Is(err, storage.ErrObjectNotExist) {
log.Errorf("gcs: upload delete err: %+v\n", err)
return gateway.ErrFailedToUploadFile
return 0, gateway.ErrFailedToUploadFile
}

writer := object.NewWriter(ctx)
writer.ObjectAttrs.CacheControl = f.cacheControl

if _, err := io.Copy(writer, content); err != nil {
size, err := io.Copy(writer, content)
if err != nil {
log.Errorf("gcs: upload err: %+v\n", err)
return gateway.ErrFailedToUploadFile
return 0, gateway.ErrFailedToUploadFile
}

if err := writer.Close(); err != nil {
log.Errorf("gcs: upload close err: %+v\n", err)
return gateway.ErrFailedToUploadFile
return 0, gateway.ErrFailedToUploadFile
}

return nil
return size, nil
}

func (f *fileRepo) move(ctx context.Context, from, dest string) error {
Expand Down
2 changes: 1 addition & 1 deletion server/internal/usecase/gateway/file.go
Expand Up @@ -19,7 +19,7 @@ var (

type File interface {
ReadAsset(context.Context, string) (io.ReadCloser, error)
UploadAsset(context.Context, *file.File) (*url.URL, error)
UploadAsset(context.Context, *file.File) (*url.URL, int64, error)
RemoveAsset(context.Context, *url.URL) error
ReadPluginFile(context.Context, id.PluginID, string) (io.ReadCloser, error)
UploadPluginFile(context.Context, id.PluginID, *file.File) error
Expand Down
15 changes: 8 additions & 7 deletions server/internal/usecase/interactor/asset.go
Expand Up @@ -59,6 +59,11 @@ func (i *Asset) Create(ctx context.Context, inp interfaces.CreateAssetParam, ope
return nil, err
}

url, size, err := i.gateways.File.UploadAsset(ctx, inp.File)
if err != nil {
return nil, err
}

// enforce policy
if policyID := operator.Policy(ws.Policy()); policyID != nil {
p, err := i.repos.Policy.FindByID(ctx, *policyID)
Expand All @@ -69,21 +74,17 @@ func (i *Asset) Create(ctx context.Context, inp interfaces.CreateAssetParam, ope
if err != nil {
return nil, err
}
if err := p.EnforceAssetStorageSize(s + inp.File.Size); err != nil {
if err := p.EnforceAssetStorageSize(s + size); err != nil {
_ = i.gateways.File.RemoveAsset(ctx, url)
return nil, err
}
}

url, err := i.gateways.File.UploadAsset(ctx, inp.File)
if err != nil {
return nil, err
}

a, err := asset.New().
NewID().
Workspace(inp.WorkspaceID).
Name(path.Base(inp.File.Path)).
Size(inp.File.Size).
Size(size).
URL(url.String()).
Build()
if err != nil {
Expand Down
8 changes: 4 additions & 4 deletions server/internal/usecase/interactor/dataset.go
Expand Up @@ -55,6 +55,7 @@ func NewDataset(r *repo.Container, gr *gateway.Container) interfaces.Dataset {
layerRepo: r.Layer,
pluginRepo: r.Plugin,
transaction: r.Transaction,
policyRepo: r.Policy,
datasource: gr.DataSource,
file: gr.File,
google: gr.Google,
Expand Down Expand Up @@ -219,12 +220,11 @@ func (i *Dataset) importDataset(ctx context.Context, content io.Reader, name str
return nil, err
}

err = i.datasetSchemaRepo.Save(ctx, schema)
if err != nil {
if err := i.datasetSchemaRepo.Save(ctx, schema); err != nil {
return nil, err
}
err = i.datasetRepo.SaveAll(ctx, datasets)
if err != nil {

if err := i.datasetRepo.SaveAll(ctx, datasets); err != nil {
return nil, err
}

Expand Down
26 changes: 14 additions & 12 deletions server/internal/usecase/interactor/project.go
Expand Up @@ -263,18 +263,20 @@ func (i *Project) Publish(ctx context.Context, params interfaces.PublishProjectP
return nil, err
}

// enforce policy
if policyID := operator.Policy(ws.Policy()); policyID != nil {
p, err := i.policyRepo.FindByID(ctx, *policyID)
if err != nil {
return nil, err
}
s, err := i.projectRepo.CountPublicByWorkspace(ctx, ws.ID())
if err != nil {
return nil, err
}
if err := p.EnforcePublishedProjectCount(s); err != nil {
return nil, err
if prj.PublishmentStatus() == project.PublishmentStatusPrivate {
// enforce policy
if policyID := operator.Policy(ws.Policy()); policyID != nil {
p, err := i.policyRepo.FindByID(ctx, *policyID)
if err != nil {
return nil, err
}
s, err := i.projectRepo.CountPublicByWorkspace(ctx, ws.ID())
if err != nil {
return nil, err
}
if err := p.EnforcePublishedProjectCount(s); err != nil {
return nil, err
}
}
}

Expand Down

0 comments on commit e07b785

Please sign in to comment.