-
Notifications
You must be signed in to change notification settings - Fork 110
/
copy_local.go
456 lines (419 loc) · 12.9 KB
/
copy_local.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
package shell
import (
"context"
"errors"
"fmt"
"io"
"io/fs"
"os"
"path/filepath"
"strings"
"time"
"go.uber.org/multierr"
"go.viam.com/utils"
"google.golang.org/genproto/googleapis/rpc/errdetails"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// copy_local supports local filesystem copy operations and is agnostic to which
// side of the RPC copying is happening from.
// NewLocalFileCopyFactory returns a FileCopyFactory whose FileCopiers write into
// the local filesystem. The destination is resolved once, up front, via
// fixPeerPath; it is later interpreted together with the CopyFilesSourceType
// passed to MakeFileCopier.
func NewLocalFileCopyFactory(
	destination string,
	preserve bool,
	relativeToHome bool,
) (FileCopyFactory, error) {
	// An empty destination is permitted here (allowEmpty=true).
	resolved, err := fixPeerPath(destination, true, relativeToHome)
	if err != nil {
		return nil, err
	}
	factory := &localFileCopyFactory{
		destination: resolved,
		preserve:    preserve,
	}
	return factory, nil
}
// localFileCopyFactory makes localFileCopiers that write into a single local
// destination. It is created by NewLocalFileCopyFactory.
type localFileCopyFactory struct {
// destination is the path already resolved by fixPeerPath in NewLocalFileCopyFactory.
destination string
// preserve is passed to every copier made by this factory; when true the copier
// applies the source's mode and modification time to what it writes.
preserve bool
}
// MakeFileCopier makes a new FileCopier that is ready to copy files into the
// factory's destination. How the destination is interpreted depends on sourceType:
//
//   - CopyFilesSourceTypeMultipleFiles: the destination must already exist as a
//     directory; every file is placed inside it.
//   - CopyFilesSourceTypeSingleFile / CopyFilesSourceTypeSingleDirectory:
//     if the destination exists and is a directory, the source is placed inside it;
//     if it exists and is a file, a single source file overwrites it while a source
//     directory is rejected; if it does not exist, its parent must be an existing
//     directory and the source is created there under the destination's base name.
//   - any other source type is rejected.
//
// Stat failures other than "does not exist" (e.g. permission denied) are returned
// as-is instead of being mistaken for a missing destination.
func (f *localFileCopyFactory) MakeFileCopier(ctx context.Context, sourceType CopyFilesSourceType) (FileCopier, error) {
	finalDestination := f.destination
	var overrideName string

	switch sourceType {
	case CopyFilesSourceTypeMultipleFiles:
		// For multiple files (a b c machine:~/some/dir), ~/some/dir must already
		// exist as a directory. Because it exists, its parent does too, so nothing
		// needs to be created here.
		dstInfo, err := os.Stat(f.destination)
		if err != nil && !errors.Is(err, fs.ErrNotExist) {
			return nil, err
		}
		if err != nil || dstInfo == nil || !dstInfo.IsDir() {
			return nil, fmt.Errorf("%q does not exist or is not a directory", f.destination)
		}
	case CopyFilesSourceTypeSingleFile, CopyFilesSourceTypeSingleDirectory:
		// When rename is set, the source is written into the destination's parent
		// under the destination's base name.
		var rename bool
		dstInfo, err := os.Stat(f.destination)
		switch {
		case err == nil:
			if dstInfo == nil {
				return nil, fmt.Errorf("expected file info for %q", f.destination)
			}
			switch {
			case dstInfo.IsDir():
				// Existing directory: put the source inside it; destination unchanged.
			case sourceType == CopyFilesSourceTypeSingleFile:
				// Existing file and source is a file: overwrite it.
				rename = true
			default:
				// Existing file but source is a directory.
				return nil, fmt.Errorf("destination %q is an existing file", f.destination)
			}
		case !errors.Is(err, fs.ErrNotExist):
			// The destination may exist but cannot be inspected; don't guess.
			return nil, err
		default:
			// Destination does not exist: its parent must be an existing directory.
			parent := filepath.Dir(f.destination)
			parentInfo, statErr := os.Stat(parent)
			if statErr != nil {
				return nil, statErr
			}
			if parentInfo == nil {
				return nil, fmt.Errorf("expected file info for %q", parent)
			}
			if !parentInfo.IsDir() {
				return nil, fmt.Errorf("parent of destination %q is an existing file, not a directory", f.destination)
			}
			rename = true
		}
		if rename {
			overrideName = filepath.Base(f.destination)
			finalDestination = filepath.Dir(f.destination)
		}
	case CopyFilesSourceTypeMultipleUnknown:
		fallthrough
	default:
		return nil, fmt.Errorf("do not know how to process source copy type %q", sourceType)
	}

	return &localFileCopier{
		sourceType:   sourceType,
		dst:          finalDestination,
		overrideName: overrideName,
		preserve:     f.preserve,
	}, nil
}
// Close is a no-op: a localFileCopyFactory only holds a path and a flag, so
// there is nothing to release.
func (f *localFileCopyFactory) Close(ctx context.Context) error { return nil }
// A localFileCopier takes in files and copies them to a set destination. It should be created
// with a localFileCopyFactory.
type localFileCopier struct {
// sourceType is the kind of copy this copier was made for; Copy does not read it.
sourceType CopyFilesSourceType
// dst is the directory that each file's (possibly renamed) relative name is joined onto.
dst string
// overrideName, when non-empty, replaces the first component of each file's relative
// name (SCP-style renaming of the top-level source).
overrideName string
// preserve makes Copy apply the source's mode and modification time to what it writes.
preserve bool
}
// Copy writes a single file or directory entry into the copier's destination.
//
// The entry's RelativeName is joined onto the destination. When overrideName is
// set, only the first path component is replaced (SCP-style renaming of the
// top-level source). Missing parent directories are created.
//
// Regular files are truncated before being written, so overwriting a larger file
// leaves no stale trailing bytes. The destination file is closed before Copy
// returns, and close errors are reported.
//
// When preserve is set, the source's mode and modification time are applied to
// the result. Otherwise directories are created with 0o750 and files with 0o640,
// subject to umask.
//
// file.Data is always closed before Copy returns.
func (copier *localFileCopier) Copy(ctx context.Context, file File) error {
	defer func() {
		utils.UncheckedError(file.Data.Close())
	}()

	fileName := file.RelativeName
	if copier.overrideName != "" {
		// Only change the first part of the file name for directory renaming
		// purposes. This is based on SCP logic.
		fileSplit := splitPath(file.RelativeName)
		if len(fileSplit) == 0 {
			fileSplit = []string{copier.overrideName}
		} else {
			fileSplit[0] = copier.overrideName
		}
		fileName = filepath.Join(fileSplit...)
	}
	fullPath := filepath.Join(copier.dst, fileName)

	fileInfo, err := file.Data.Stat()
	if err != nil {
		return err
	}
	if fileInfo == nil {
		return fmt.Errorf("expected file info for %q to be non-nil", fileName)
	}

	parentPath := filepath.Dir(fullPath)
	parentInfo, err := os.Stat(parentPath)
	switch {
	case errors.Is(err, fs.ErrNotExist):
		// The mode used here is provisional; a directory's own Copy call fixes it
		// up when preserving. It's safe to create directories here because whoever
		// constructed this copier validated the top-level destination as existing
		// or created.
		//nolint:gosec // this is from an authenticated/authorized connection
		if err := os.MkdirAll(parentPath, 0o755); err != nil {
			return err
		}
	case err != nil:
		return err
	case parentInfo == nil:
		return fmt.Errorf("expected file info for %q to be non-nil", parentPath)
	case !parentInfo.IsDir():
		return fmt.Errorf("invariant: parent path %q should have been a directory", parentPath)
	}

	var fileMode fs.FileMode
	switch {
	case copier.preserve:
		fileMode = fileInfo.Mode()
	case fileInfo.IsDir():
		fileMode = 0o750
	default:
		fileMode = 0o640
	}

	if fileInfo.IsDir() {
		// The directory may already exist (e.g. created by MkdirAll while its
		// children were copied); that is fine.
		if err := os.Mkdir(fullPath, fileMode); err != nil && !errors.Is(err, fs.ErrExist) {
			return err
		}
	} else {
		// O_TRUNC: an existing, larger destination file must not keep its tail.
		//nolint:gosec // this is from an authenticated/authorized connection
		localFile, err := os.OpenFile(fullPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, fileMode)
		if err != nil {
			return err
		}
		_, copyErr := io.Copy(localFile, file.Data)
		// Always close the destination and surface close errors; for writes they
		// can indicate lost data.
		if err := multierr.Combine(copyErr, localFile.Close()); err != nil {
			return err
		}
	}

	if !copier.preserve {
		return nil
	}
	// Re-apply the mode since the entry may have been created via MkdirAll above,
	// or its creation mode may have been reduced by umask.
	if err := os.Chmod(fullPath, fileMode); err != nil {
		return err
	}
	// Note(erd): maybe support access time in the future if needed
	return os.Chtimes(fullPath, time.Now(), fileInfo.ModTime())
}
// Close is a no-op: a localFileCopier only holds paths and flags, so there is
// nothing to release.
func (copier *localFileCopier) Close(ctx context.Context) error { return nil }
// localFileReadCopier copies a fixed set of already-opened local files (and,
// for directories, their contents) into a FileCopier made by copyFactory.
// It is created by NewLocalFileReadCopier.
type localFileReadCopier struct {
// filesToCopy are the top-level files opened by NewLocalFileReadCopier; Close closes them.
filesToCopy []*os.File
// copyFactory makes the FileCopier once ReadAll knows the source type.
copyFactory FileCopyFactory
}
// NewLocalFileReadCopier returns a FileReadCopier whose ReadAll method copies
// each file indicated by paths into a FileCopier created by copyFactory. The
// factory is used because the type of files being copied is not known until
// ReadAll is called.
//
// Every path is resolved with fixPeerPath (relative to HOME when relativeToHome
// is set) and opened immediately. If allowRecursive is false, a directory path
// fails with an InvalidArgument gRPC status whose message is
// ErrMsgDirectoryCopyRequestNoRecursion. At least one path is required. On any
// error, every file opened so far is closed before returning.
func NewLocalFileReadCopier(
	paths []string,
	allowRecursive bool,
	relativeToHome bool,
	copyFactory FileCopyFactory,
) (FileReadCopier, error) {
	filesToCopy := make([]*os.File, 0, len(paths))
	succeeded := false
	defer func() {
		if succeeded {
			return
		}
		// Don't leak descriptors for paths opened before the failure.
		for _, opened := range filesToCopy {
			utils.UncheckedError(opened.Close())
		}
	}()

	for _, p := range paths {
		resolved, err := fixPeerPath(p, false, relativeToHome)
		if err != nil {
			return nil, err
		}
		//nolint:gosec // this is from an authenticated/authorized connection
		fileToCopy, err := os.Open(resolved)
		if err != nil {
			return nil, err
		}
		// Track the file right away so the deferred cleanup closes it on any later error.
		filesToCopy = append(filesToCopy, fileToCopy)
		if allowRecursive {
			continue
		}
		fileInfo, err := fileToCopy.Stat()
		if err != nil {
			return nil, err
		}
		if fileInfo.IsDir() {
			details := &errdetails.BadRequest_FieldViolation{
				Field:       "paths",
				Description: fmt.Sprintf("local %q is a directory but copy recursion not used", resolved),
			}
			s, err := status.New(codes.InvalidArgument, ErrMsgDirectoryCopyRequestNoRecursion).WithDetails(details)
			if err != nil {
				return nil, err
			}
			return nil, s.Err()
		}
	}
	if len(filesToCopy) == 0 {
		return nil, errors.New("no files provided to copy")
	}
	succeeded = true
	return &localFileReadCopier{filesToCopy: filesToCopy, copyFactory: copyFactory}, nil
}
// ErrMsgDirectoryCopyRequestNoRecursion is the status message used when a copy
// request names a directory but recursion is not enabled. NewLocalFileReadCopier
// uses it as the message of an InvalidArgument gRPC status. It is a message
// string, not an error value.
var ErrMsgDirectoryCopyRequestNoRecursion = "file is a directory but copy recursion not used"
// ReadAll picks the source type from the top-level files (a single directory,
// a single file, or multiple files) and asks the factory for a matching
// FileCopier. It then copies every file depth-first.
//
// A directory's contents are copied before the directory entry itself. Anything
// the copier applies to the directory, such as a preserved modification time, is
// therefore not disturbed afterwards by writing its children.
//
// Files handed to the copier may be closed by it (the local copier does). If
// copying fails part-way through a directory, the entries opened for that
// directory are closed here. Top-level files are left for Close.
func (reader *localFileReadCopier) ReadAll(ctx context.Context) error {
	if len(reader.filesToCopy) == 0 {
		return nil
	}

	var sourceType CopyFilesSourceType
	if len(reader.filesToCopy) == 1 {
		fileInfo, err := reader.filesToCopy[0].Stat()
		if err != nil {
			return err
		}
		if fileInfo.IsDir() {
			sourceType = CopyFilesSourceTypeSingleDirectory
		} else {
			sourceType = CopyFilesSourceTypeSingleFile
		}
	} else {
		sourceType = CopyFilesSourceTypeMultipleFiles
	}

	copier, err := reader.copyFactory.MakeFileCopier(ctx, sourceType)
	if err != nil {
		return err
	}
	defer func() {
		utils.UncheckedError(copier.Close(ctx))
	}()

	makeRelName := func(relDir string, file *os.File) string {
		return filepath.Join(relDir, filepath.Base(file.Name()))
	}
	// closeAll is best-effort; files already closed by the copier just yield fs.ErrClosed.
	closeAll := func(files []*os.File) {
		for _, file := range files {
			utils.UncheckedError(file.Close())
		}
	}

	// Note: okay with recursion for now. may want to check depth later...
	var copyFiles func(relDir string, files []*os.File) error
	copyFiles = func(relDir string, files []*os.File) error {
		for _, f := range files {
			if err := ctx.Err(); err != nil {
				return err
			}
			fileInfo, err := f.Stat()
			if err != nil {
				return err
			}
			relName := makeRelName(relDir, f)
			if fileInfo.IsDir() {
				dirEntries, err := f.ReadDir(0)
				if err != nil {
					return err
				}
				children := make([]*os.File, 0, len(dirEntries))
				for _, dirEntry := range dirEntries {
					//nolint:gosec // this is from an authenticated/authorized connection
					child, err := os.Open(filepath.Join(f.Name(), dirEntry.Name()))
					if err != nil {
						closeAll(children)
						return err
					}
					children = append(children, child)
				}
				if err := copyFiles(relName, children); err != nil {
					// Children not yet handed to the copier would otherwise leak.
					closeAll(children)
					return err
				}
			}
			if err := copier.Copy(ctx, File{
				RelativeName: relName,
				Data:         f,
			}); err != nil {
				return err
			}
		}
		return nil
	}

	return copyFiles("", reader.filesToCopy)
}
// Close closes every top-level file that was queued for copying. Most of them
// have usually been closed further down the stack already (the local copier
// closes each file it copies), so fs.ErrClosed is ignored. Any other close
// errors are combined and returned.
func (reader *localFileReadCopier) Close(ctx context.Context) error {
	var combined error
	for i := range reader.filesToCopy {
		closeErr := reader.filesToCopy[i].Close()
		if closeErr == nil || errors.Is(closeErr, fs.ErrClosed) {
			continue
		}
		combined = multierr.Combine(combined, closeErr)
	}
	return combined
}
var errUnexpectedEmptyPath = errors.New("unexpected empty path")

// fixPeerPath turns a peer-supplied path into an absolute local path,
// handling ~ and empty paths.
//
// Absolute paths are returned unchanged. Otherwise:
//   - "~" becomes the user's home directory, and a leading "~/" (or "~" followed
//     by the OS path separator) has its "~" replaced by the home directory;
//   - "" is rejected with errUnexpectedEmptyPath unless allowEmpty is set. With
//     allowEmpty it becomes the home directory (relativeToHome) or the current
//     working directory;
//   - any other relative path is resolved against the home directory when
//     relativeToHome is set (the "from" side) and against the current working
//     directory otherwise (the "to" side).
//
// The home directory is only looked up when it is actually needed, so paths
// resolved against the working directory still work when HOME is unset.
//
// Security Note: this is the only time we end up interpreting a user's path
// string before it's passed to a file related syscall.
func fixPeerPath(path string, allowEmpty, relativeToHome bool) (string, error) {
	if filepath.IsAbs(path) {
		return path, nil
	}
	switch {
	case path == "~":
		return os.UserHomeDir()
	case strings.HasPrefix(path, "~/"), strings.HasPrefix(path, "~"+string(filepath.Separator)):
		homeDir, err := os.UserHomeDir()
		if err != nil {
			return "", err
		}
		return homeDir + path[1:], nil
	case path == "":
		if !allowEmpty {
			return "", errUnexpectedEmptyPath
		}
		if relativeToHome {
			return os.UserHomeDir()
		}
		return filepath.Abs("")
	case relativeToHome:
		// From path has us use HOME paths
		homeDir, err := os.UserHomeDir()
		if err != nil {
			return "", err
		}
		return filepath.Join(homeDir, path), nil
	default:
		// To path has us use CWD paths
		return filepath.Abs(path)
	}
}
// splitPath splits path on OS path separators after dropping any volume name.
// A trailing separator produces no trailing empty element. Leading or repeated
// separators do produce empty elements, and an empty path yields nil. This
// seems to mostly work as a cross-platform splitter.
func splitPath(path string) []string {
	var parts []string
	segStart := len(filepath.VolumeName(path))
	for i := segStart; i < len(path); i++ {
		if !os.IsPathSeparator(path[i]) {
			continue
		}
		parts = append(parts, path[segStart:i])
		segStart = i + 1
	}
	// Whatever follows the last separator (if non-empty) is the final element.
	if segStart < len(path) {
		parts = append(parts, path[segStart:])
	}
	return parts
}