Skip to content

Commit

Permalink
Merge pull request #938 from treeverse/fix/local-adapter-namespace
Browse files Browse the repository at this point in the history
local adapter - respect namespace in objects full path
  • Loading branch information
guy-har committed Nov 19, 2020
2 parents 95e2125 + 6ad53c1 commit 2c137c6
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 24 deletions.
93 changes: 71 additions & 22 deletions block/local/adapter.go
Expand Up @@ -70,9 +70,23 @@ func NewAdapter(path string, opts ...func(a *Adapter)) (*Adapter, error) {
}
return adapter, nil
}
func resolveNamespace(obj block.ObjectPointer) (block.QualifiedKey, error) {
qualifiedKey, err := block.ResolveNamespace(obj.StorageNamespace, obj.Identifier)
if err != nil {
return qualifiedKey, err
}
if qualifiedKey.StorageType != block.StorageTypeLocal {
return qualifiedKey, block.ErrInvalidNamespace
}
return qualifiedKey, nil
}

func (l *Adapter) getPath(identifier string) string {
return path.Join(l.path, identifier)
func (l *Adapter) getPath(identifier block.ObjectPointer) (string, error) {
obj, err := resolveNamespace(identifier)
if err != nil {
return "", err
}
return path.Join(l.path, obj.StorageNamespace, obj.Key), nil
}

// maybeMkdir runs f(path), but if f fails due to file-not-found MkdirAll's its dir and then
Expand All @@ -89,7 +103,10 @@ func maybeMkdir(path string, f func(p string) (*os.File, error)) (*os.File, erro
}

func (l *Adapter) Put(obj block.ObjectPointer, _ int64, reader io.Reader, _ block.PutOpts) error {
p := l.getPath(obj.Identifier)
p, err := l.getPath(obj)
if err != nil {
return err
}
f, err := maybeMkdir(p, os.Create)
if err != nil {
return err
Expand All @@ -102,20 +119,29 @@ func (l *Adapter) Put(obj block.ObjectPointer, _ int64, reader io.Reader, _ bloc
}

func (l *Adapter) Remove(obj block.ObjectPointer) error {
p := l.getPath(obj.Identifier)
p, err := l.getPath(obj)
if err != nil {
return err
}
return os.Remove(p)
}

func (l *Adapter) Copy(sourceObj, destinationObj block.ObjectPointer) error {
source := l.getPath(sourceObj.Identifier)
source, err := l.getPath(sourceObj)
if err != nil {
return err
}
sourceFile, err := os.Open(source)
defer func() {
_ = sourceFile.Close()
}()
if err != nil {
return err
}
dest := l.getPath(destinationObj.Identifier)
dest, err := l.getPath(destinationObj)
if err != nil {
return err
}
destinationFile, err := maybeMkdir(dest, os.Create)
if err != nil {
return err
Expand All @@ -128,7 +154,10 @@ func (l *Adapter) Copy(sourceObj, destinationObj block.ObjectPointer) error {
}

func (l *Adapter) Get(obj block.ObjectPointer, _ int64) (reader io.ReadCloser, err error) {
p := l.getPath(obj.Identifier)
p, err := l.getPath(obj)
if err != nil {
return nil, err
}
f, err := os.OpenFile(p, os.O_RDONLY, 0755)
if err != nil {
return nil, err
Expand All @@ -137,7 +166,10 @@ func (l *Adapter) Get(obj block.ObjectPointer, _ int64) (reader io.ReadCloser, e
}

func (l *Adapter) GetRange(obj block.ObjectPointer, start int64, end int64) (io.ReadCloser, error) {
p := l.getPath(obj.Identifier)
p, err := l.getPath(obj)
if err != nil {
return nil, err
}
f, err := os.Open(p)
if err != nil {
return nil, err
Expand All @@ -152,8 +184,11 @@ func (l *Adapter) GetRange(obj block.ObjectPointer, start int64, end int64) (io.
}

func (l *Adapter) GetProperties(obj block.ObjectPointer) (block.Properties, error) {
p := l.getPath(obj.Identifier)
_, err := os.Stat(p)
p, err := l.getPath(obj)
if err != nil {
return block.Properties{}, err
}
_, err = os.Stat(p)
if err != nil {
return block.Properties{}, err
}
Expand All @@ -175,11 +210,14 @@ func isDirectoryWritable(pth string) bool {
return true
}

func (l *Adapter) CreateMultiPartUpload(obj block.ObjectPointer, r *http.Request, opts block.CreateMultiPartUploadOpts) (string, error) {
func (l *Adapter) CreateMultiPartUpload(obj block.ObjectPointer, _ *http.Request, _ block.CreateMultiPartUploadOpts) (string, error) {
if strings.Contains(obj.Identifier, "/") {
fullPath := l.getPath(obj.Identifier)
fullPath, err := l.getPath(obj)
if err != nil {
return "", err
}
fullDir := path.Dir(fullPath)
err := os.MkdirAll(fullDir, 0755)
err = os.MkdirAll(fullDir, 0755)
if err != nil {
return "", err
}
Expand All @@ -190,16 +228,16 @@ func (l *Adapter) CreateMultiPartUpload(obj block.ObjectPointer, r *http.Request
return uploadID, nil
}

func (l *Adapter) UploadPart(obj block.ObjectPointer, sizeBytes int64, reader io.Reader, uploadID string, partNumber int64) (string, error) {
func (l *Adapter) UploadPart(obj block.ObjectPointer, _ int64, reader io.Reader, uploadID string, partNumber int64) (string, error) {
md5Read := block.NewHashingReader(reader, block.HashFunctionMD5)
fName := uploadID + fmt.Sprintf("-%05d", (partNumber))
err := l.Put(block.ObjectPointer{StorageNamespace: "", Identifier: fName}, -1, md5Read, block.PutOpts{})
err := l.Put(block.ObjectPointer{StorageNamespace: obj.StorageNamespace, Identifier: fName}, -1, md5Read, block.PutOpts{})
etag := "\"" + hex.EncodeToString(md5Read.Md5.Sum(nil)) + "\""
return etag, err
}

func (l *Adapter) AbortMultiPartUpload(obj block.ObjectPointer, uploadID string) error {
files, err := l.getPartFiles(uploadID)
files, err := l.getPartFiles(uploadID, obj)
if err != nil {
return err
}
Expand All @@ -209,11 +247,11 @@ func (l *Adapter) AbortMultiPartUpload(obj block.ObjectPointer, uploadID string)

func (l *Adapter) CompleteMultiPartUpload(obj block.ObjectPointer, uploadID string, multipartList *block.MultipartUploadCompletion) (*string, int64, error) {
etag := computeETag(multipartList.Part) + "-" + strconv.Itoa(len(multipartList.Part))
partFiles, err := l.getPartFiles(uploadID)
partFiles, err := l.getPartFiles(uploadID, obj)
if err != nil {
return nil, -1, fmt.Errorf("part files not found for %s: %w", uploadID, err)
}
size, err := l.unitePartFiles(obj.Identifier, partFiles)
size, err := l.unitePartFiles(obj, partFiles)
if err != nil {
return nil, -1, fmt.Errorf("multipart upload unite for %s: %w", uploadID, err)
}
Expand All @@ -237,8 +275,11 @@ func computeETag(parts []*s3.CompletedPart) string {
return csm
}

func (l *Adapter) unitePartFiles(identifier string, files []string) (int64, error) {
p := l.getPath(identifier)
func (l *Adapter) unitePartFiles(identifier block.ObjectPointer, files []string) (int64, error) {
p, err := l.getPath(identifier)
if err != nil {
return 0, err
}
unitedFile, err := os.Create(p)
if err != nil {
return 0, fmt.Errorf("create path %s: %w", p, err)
Expand Down Expand Up @@ -268,8 +309,16 @@ func (l *Adapter) removePartFiles(files []string) {
}
}

func (l *Adapter) getPartFiles(uploadID string) ([]string, error) {
globPathPattern := l.getPath(uploadID) + "-*"
func (l *Adapter) getPartFiles(uploadID string, obj block.ObjectPointer) ([]string, error) {
newObj := block.ObjectPointer{
StorageNamespace: obj.StorageNamespace,
Identifier: uploadID,
}
globPathPattern, err := l.getPath(newObj)
if err != nil {
return nil, err
}
globPathPattern += "*"
names, err := filepath.Glob(globPathPattern)
if err != nil {
return nil, err
Expand Down
49 changes: 47 additions & 2 deletions block/local/adapter_test.go
@@ -1,6 +1,8 @@
package local_test

import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/s3"
"io/ioutil"
"os"
"strings"
Expand All @@ -15,7 +17,7 @@ func makeAdapter(t *testing.T) (*local.Adapter, func()) {
t.Helper()
dir, err := ioutil.TempDir("", "testing-local-adapter-*")
testutil.MustDo(t, "TempDir", err)
os.MkdirAll(dir, 0700)
testutil.MustDo(t, "NewAdapter", os.MkdirAll(dir, 0700))
a, err := local.NewAdapter(dir)
testutil.MustDo(t, "NewAdapter", err)

Expand All @@ -27,7 +29,7 @@ func makeAdapter(t *testing.T) (*local.Adapter, func()) {
}

func makePointer(path string) block.ObjectPointer {
return block.ObjectPointer{Identifier: path}
return block.ObjectPointer{Identifier: path, StorageNamespace: "local://test/"}
}

func TestLocalPutGet(t *testing.T) {
Expand Down Expand Up @@ -58,6 +60,49 @@ func TestLocalPutGet(t *testing.T) {
}
}

func TestLocalMultipartUpload(t *testing.T) {
a, cleanup := makeAdapter(t)
defer cleanup()

cases := []struct {
name string
path string
partData []string
}{
{"simple", "abc", []string{"one ", "two ", "three"}},
{"nested", "foo/bar", []string{"one ", "two ", "three"}},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
pointer := makePointer(c.path)
uploadID, err := a.CreateMultiPartUpload(pointer, nil, block.CreateMultiPartUploadOpts{})
testutil.MustDo(t, "CreateMultiPartUpload", err)
parts := make([]*s3.CompletedPart, 0)
for partNumber, content := range c.partData {
cs, err := a.UploadPart(pointer, 0, strings.NewReader(content), uploadID, int64(partNumber))
testutil.MustDo(t, "UploadPart", err)
parts = append(parts, &s3.CompletedPart{
ETag: aws.String(cs),
PartNumber: aws.Int64(int64(partNumber)),
})
}
_, _, err = a.CompleteMultiPartUpload(pointer, uploadID, &block.MultipartUploadCompletion{
Part: parts,
})
testutil.MustDo(t, "CompleteMultiPartUpload", err)
reader, err := a.Get(pointer, 0)
testutil.MustDo(t, "Get", err)
got, err := ioutil.ReadAll(reader)
testutil.MustDo(t, "ReadAll", err)
expected := strings.Join(c.partData, "")
if string(got) != expected {
t.Errorf("expected to read \"%s\" as written, got \"%s\"", expected, string(got))
}
})
}
}

func TestLocalCopy(t *testing.T) {
a, cleanup := makeAdapter(t)
defer cleanup()
Expand Down

0 comments on commit 2c137c6

Please sign in to comment.