/
vfshandler.go
106 lines (91 loc) · 2.21 KB
/
vfshandler.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
package sftpd
import (
"io"
"log"
"os"
"github.com/pkg/sftp"
"sync"
)
// vfs is the single handler value that VfsHandler installs as the FileGet,
// FilePut, FileCmd and FileList hooks of an sftp.Handlers set.
type vfs struct {
// pathMap translates the virtual path of an SFTP request into a path on
// disk (PathTo) and answers List/Stat queries for the virtual tree.
pathMap PathMapper
// pathMapLock is held by Fileread and Filelist while they use pathMap.
// NOTE(review): vfs contains a Mutex, so it must not be copied by value.
pathMapLock sync.Mutex
}
// VfsHandler returns an sftp.Handlers set whose FileGet, FilePut, FileCmd and
// FileList hooks are all served by one shared vfs instance.
//
// The vfs stores a copy of the PathMapper value that mapper points to. This is
// a shallow copy, so reference-typed fields inside it remain shared with the
// caller. mapper must be non-nil because it is dereferenced immediately.
func VfsHandler(mapper *PathMapper) sftp.Handlers {
	shared := new(vfs)
	shared.pathMap = *mapper

	var handlers sftp.Handlers
	handlers.FileGet = shared
	handlers.FilePut = shared
	handlers.FileCmd = shared
	handlers.FileList = shared
	return handlers
}
// dumpSftpRequest writes one line to the standard logger. The line holds
// message followed by the request's Filepath, Target and Method fields.
func dumpSftpRequest(message string, r *sftp.Request) {
	fields := []interface{}{
		message,
		"Filepath: ", r.Filepath,
		", Target: ", r.Target,
		", Method: ", r.Method,
	}
	log.Println(fields...)
}
// Fileread serves the FileGet hook. It translates the virtual path in
// r.Filepath into a real path via pathMap.PathTo, then opens that file
// read-only with os.Open. pathMapLock is held for the duration of the call.
//
// On any failure it returns a literal nil io.ReaderAt together with the
// error. The original code returned the nil *os.File wrapped in the interface,
// and that wrapped value compares non-nil.
//
// NOTE(review): the opened *os.File is not closed here. This relies on
// pkg/sftp closing ReaderAt values that also implement io.Closer when the
// client closes the handle. Confirm against the vendored sftp version.
func (fs *vfs) Fileread(r *sftp.Request) (io.ReaderAt, error) {
	dumpSftpRequest("Fileread: ", r)
	fs.pathMapLock.Lock()
	defer fs.pathMapLock.Unlock()

	filePath, err := fs.pathMap.PathTo(r.Filepath)
	if err != nil {
		return nil, err
	}

	// A separate name avoids redeclaring err, whose type is whatever
	// PathTo returns.
	f, openErr := os.Open(filePath)
	if openErr != nil {
		return nil, openErr
	}
	return f, nil
}
// Filewrite serves the FilePut hook. Uploads are disabled: the request is
// logged and rejected with os.ErrInvalid without consulting pathMap.
func (fs *vfs) Filewrite(r *sftp.Request) (io.WriterAt, error) {
	dumpSftpRequest("Filewrite (disabled): ", r)
	var rejected io.WriterAt // always the nil interface; writing is unsupported
	return rejected, os.ErrInvalid
}
// Filecmd serves the FileCmd hook. Every command is disabled: the request
// is logged and os.ErrInvalid is returned without consulting pathMap.
func (fs *vfs) Filecmd(r *sftp.Request) error {
	rejection := os.ErrInvalid
	dumpSftpRequest("Filecmd (disabled): ", r)
	return rejection
}
// listerAt is an in-memory directory listing. Filelist returns it as an
// sftp.ListerAt.
type listerAt []os.FileInfo

// ListAt copies entries of l, starting at offset, into ls and returns how
// many were copied. It is modeled after strings.Reader.ReadAt.
//
// It returns io.EOF in two cases: when offset is at or past the end of the
// listing, and when fewer than len(ls) entries remained to copy. A negative
// offset returns os.ErrInvalid; the slice expression below would otherwise
// panic.
func (l listerAt) ListAt(ls []os.FileInfo, offset int64) (int, error) {
	if offset < 0 {
		return 0, os.ErrInvalid
	}
	if offset >= int64(len(l)) {
		return 0, io.EOF
	}
	n := copy(ls, l[offset:])
	if n < len(ls) {
		return n, io.EOF
	}
	return n, nil
}
// Filelist serves the FileList hook. pathMapLock is held for the whole call.
//
//   - "List": returns the FileInfo of every entry that pathMap.List reports
//     for r.Filepath. Entries whose Stat fails are logged and omitted. If
//     List reports !ok, it returns os.ErrInvalid.
//   - "Stat": returns a one-element listing with the FileInfo of r.Filepath,
//     or the Stat error.
//   - any other method: returns os.ErrInvalid.
func (fs *vfs) Filelist(r *sftp.Request) (sftp.ListerAt, error) {
	dumpSftpRequest("Fileinfo: ", r)
	fs.pathMapLock.Lock()
	defer fs.pathMapLock.Unlock()
	switch r.Method {
	case "List":
		listing, ok := fs.pathMap.List(r.Filepath)
		if !ok {
			return nil, os.ErrInvalid
		}
		// Append only successful stats. The original code left a nil
		// os.FileInfo at the failed index, and calling any method on it
		// (e.g. Name()) panics.
		statList := make([]os.FileInfo, 0, len(listing))
		for _, fileName := range listing {
			stat, err := fs.pathMap.Stat(fileName)
			if err != nil {
				log.Println("Could not stat file", fileName, err)
				continue
			}
			statList = append(statList, stat)
			log.Println("Stat for file "+fileName+": isDir=>", stat.IsDir(), "size=>", stat.Size())
		}
		return listerAt(statList), nil
	case "Stat":
		stat, err := fs.pathMap.Stat(r.Filepath)
		if err != nil {
			return nil, err
		}
		return listerAt([]os.FileInfo{stat}), nil
	}
	return nil, os.ErrInvalid
}