diff --git a/commands.go b/commands.go index c9906f0..5eaf622 100644 --- a/commands.go +++ b/commands.go @@ -282,15 +282,19 @@ func NewGetCommand() cli.Command { PrintErrorAndExit("get %s: parse mtime: %v", upPath, err) } } + if c.Int("w") > 10 || c.Int("w") < 1 { + PrintErrorAndExit("max concurrent threads must be between 1 and 10") + } if mc.Start != "" || mc.End != "" { session.GetStartBetweenEndFiles(upPath, localPath, mc, c.Int("w")) } else { - session.Get(upPath, localPath, mc, c.Int("w")) + session.Get(upPath, localPath, mc, c.Int("w"), c.Bool("c")) } return nil }, Flags: []cli.Flag{ - cli.IntFlag{Name: "w", Usage: "max concurrent threads", Value: 5}, + cli.IntFlag{Name: "w", Usage: "max concurrent threads (1-10)", Value: 5}, + cli.BoolFlag{Name: "c", Usage: "continue download, resume broken download"}, cli.StringFlag{Name: "mtime", Usage: "file's data was last modified n*24 hours ago, same as linux find command."}, cli.StringFlag{Name: "start", Usage: "file download range starting location"}, cli.StringFlag{Name: "end", Usage: "file download range ending location"}, @@ -315,7 +319,9 @@ func NewPutCommand() cli.Command { if c.NArg() > 1 { upPath = c.Args().Get(1) } - + if c.Int("w") > 10 || c.Int("w") < 1 { + PrintErrorAndExit("max concurrent threads must be between 1 and 10") + } session.Put( localPath, upPath, @@ -332,9 +338,12 @@ func NewUploadCommand() cli.Command { return cli.Command{ Name: "upload", - Usage: "upload multiple directory or file", + Usage: "upload multiple directory or file or http url", Action: func(c *cli.Context) error { InitAndCheck(LOGIN, CHECK, c) + if c.Int("w") > 10 || c.Int("w") < 1 { + PrintErrorAndExit("max concurrent threads must be between 1 and 10") + } session.Upload( c.Args(), c.String("remote"), @@ -422,6 +431,9 @@ func NewSyncCommand() cli.Command { if c.NArg() > 1 { upPath = c.Args().Get(1) } + if c.Int("w") > 10 || c.Int("w") < 1 { + PrintErrorAndExit("max concurrent threads must be between 1 and 10") + }
 session.Sync(localPath, upPath, c.Int("w"), c.Bool("delete"), c.Bool("strong")) return nil }, diff --git a/go.mod b/go.mod index ab132ff..358cb44 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/jehiah/go-strftime v0.0.0-20171201141054-1d33003b3869 github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 github.com/syndtr/goleveldb v1.0.0 - github.com/upyun/go-sdk/v3 v3.0.3 + github.com/upyun/go-sdk/v3 v3.0.4 github.com/urfave/cli v1.22.4 ) diff --git a/go.sum b/go.sum index 0d614b5..50460e3 100644 --- a/go.sum +++ b/go.sum @@ -39,6 +39,8 @@ github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFd github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ= github.com/upyun/go-sdk/v3 v3.0.3 h1:2wUkNk2fyJReMYHMvJyav050D83rYwSjN7mEPR0Pp8Q= github.com/upyun/go-sdk/v3 v3.0.3/go.mod h1:P/SnuuwhrIgAVRd/ZpzDWqCsBAf/oHg7UggbAxyZa0E= +github.com/upyun/go-sdk/v3 v3.0.4 h1:2DCJa/Yi7/3ZybT9UCPATSzvU3wpPPxhXinNlb1Hi8Q= +github.com/upyun/go-sdk/v3 v3.0.4/go.mod h1:P/SnuuwhrIgAVRd/ZpzDWqCsBAf/oHg7UggbAxyZa0E= github.com/urfave/cli v1.22.4 h1:u7tSpNPPswAFymm8IehJhy4uJMlUuU/GmqSkvJ1InXA= github.com/urfave/cli v1.22.4/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= diff --git a/io.go b/io.go index 40306af..1f2ffd2 100644 --- a/io.go +++ b/io.go @@ -34,15 +34,27 @@ func (w *WrappedWriter) Close() error { return w.w.Close() } -func NewFileWrappedWriter(localPath string, bar *uiprogress.Bar) (*WrappedWriter, error) { - fd, err := os.Create(localPath) +func NewFileWrappedWriter(localPath string, bar *uiprogress.Bar, resume bool) (*WrappedWriter, error) { + var fd *os.File + var err error + + if resume { + fd, err = os.OpenFile(localPath, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0755) + } else { + fd, err = os.Create(localPath) + } + if err != nil { + return nil, err + } + + fileinfo, err := 
fd.Stat() if err != nil { return nil, err } return &WrappedWriter{ w: fd, - Copyed: 0, + Copyed: int(fileinfo.Size()), bar: bar, }, nil } diff --git a/partial/chunk.go b/partial/chunk.go new file mode 100644 index 0000000..1184a53 --- /dev/null +++ b/partial/chunk.go @@ -0,0 +1,97 @@ +package partial + +import ( + "sync/atomic" +) + +type Chunk struct { + // 切片的顺序 + index int64 + + // 切片内容的在源文件的开始地址 + start int64 + + // 切片内容在源文件的结束地址 + end int64 + + // 切片任务的下载错误 + err error + + // 下载完的切片的具体内容 + buffer []byte +} + +func NewChunk(index, start, end int64) *Chunk { + chunk := &Chunk{ + start: start, + end: end, + index: index, + } + return chunk +} + +func (p *Chunk) SetData(bytes []byte) { + p.buffer = bytes +} + +func (p *Chunk) SetError(err error) { + p.err = err +} + +func (p *Chunk) Error() error { + return p.err +} + +func (p *Chunk) Data() []byte { + return p.buffer +} + +// 切片乱序写入后,将切片顺序读取 +type ChunksSorter struct { + // 已经读取的切片数量 + readCount int64 + + // 切片的所有总数 + chunkCount int64 + + // 线程数,用于阻塞写入 + works int64 + + // 存储切片的缓存区 + chunks []chan *Chunk +} + +func NewChunksSorter(chunkCount int64, works int) *ChunksSorter { + chunks := make([]chan *Chunk, works) + for i := 0; i < len(chunks); i++ { + chunks[i] = make(chan *Chunk) + } + + return &ChunksSorter{ + chunkCount: chunkCount, + works: int64(works), + chunks: chunks, + } +} + +// 将数据写入到缓存区,如果该缓存已满,则会被阻塞 +func (p *ChunksSorter) Write(chunk *Chunk) { + p.chunks[chunk.index%p.works] <- chunk +} + +// 关闭 workId 下的通道 +func (p *ChunksSorter) Close(workId int) { + if (len(p.chunks) - 1) >= workId { + close(p.chunks[workId]) + } +} + +// 顺序读取切片,如果下一个切片没有下载完,则会被阻塞 +func (p *ChunksSorter) Read() *Chunk { + if p.chunkCount == 0 { + return nil + } + i := atomic.AddInt64(&p.readCount, 1) + chunk := <-p.chunks[(i-1)%p.works] + return chunk +} diff --git a/partial/downloader.go b/partial/downloader.go new file mode 100644 index 0000000..5cc43f0 --- /dev/null +++ b/partial/downloader.go @@ -0,0 +1,120 @@ +package 
partial + +import ( + "context" + "errors" + "io" + "os" +) + +const ChunkSize = 1024 * 1024 * 10 + +type ChunkDownFunc func(start, end int64) ([]byte, error) + +type MultiPartialDownloader struct { + + // 文件路径 + filePath string + + // 最终文件大小 + finalSize int64 + + // 本地文件大小 + localSize int64 + + writer io.Writer + works int + downFunc ChunkDownFunc +} + +func NewMultiPartialDownloader(filePath string, finalSize int64, writer io.Writer, works int, fn ChunkDownFunc) *MultiPartialDownloader { + return &MultiPartialDownloader{ + filePath: filePath, + finalSize: finalSize, + works: works, + writer: writer, + downFunc: fn, + } +} + +func (p *MultiPartialDownloader) Download() error { + fileinfo, err := os.Stat(p.filePath) + if err == nil { + p.localSize = fileinfo.Size() + } + + // 计算需要下载的块数 + needDownSize := p.finalSize - p.localSize + chunkCount := needDownSize / ChunkSize + if needDownSize%ChunkSize != 0 { + chunkCount++ + } + + chunksSorter := NewChunksSorter( + chunkCount, + p.works, + ) + + // 下载切片任务 + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for i := 0; i < p.works; i++ { + go func(ctx context.Context, workId int) { + // 关闭 workId 下的接收通道 + defer chunksSorter.Close(workId) + + // 每个 work 取自己倍数的 chunk + for j := workId; j < int(chunkCount); j += p.works { + select { + case <-ctx.Done(): + return + default: + var ( + err error + buffer []byte + ) + start := p.localSize + int64(j)*ChunkSize + end := p.localSize + int64(j+1)*ChunkSize + if end > p.finalSize { + end = p.finalSize + } + chunk := NewChunk(int64(j), start, end) + + // 重试三次 + for t := 0; t < 3; t++ { + // ? 由于长度是从1开始,而数据是从0地址开始 + // ? 
计算字节时容量会多出开头的一位,所以末尾需要减少一位 + buffer, err = p.downFunc(chunk.start, chunk.end-1) + if err == nil { + break + } + } + chunk.SetData(buffer) + chunk.SetError(err) + chunksSorter.Write(chunk) + + if err != nil { + return + } + } + } + }(ctx, i) + } + + // 将分片顺序写入到文件 + for { + chunk := chunksSorter.Read() + if chunk == nil { + break + } + if chunk.Error() != nil { + return chunk.Error() + } + if len(chunk.Data()) == 0 { + return errors.New("chunk buffer download but size is 0") + } + p.writer.Write(chunk.Data()) + } + return nil +} diff --git a/session.go b/session.go index 7fb16a0..ea26050 100644 --- a/session.go +++ b/session.go @@ -1,6 +1,7 @@ package main import ( + "bytes" "encoding/json" "fmt" "io/ioutil" @@ -20,6 +21,7 @@ import ( "github.com/gosuri/uiprogress" "github.com/jehiah/go-strftime" "github.com/upyun/go-sdk/v3/upyun" + "github.com/upyun/upx/partial" ) const ( @@ -298,7 +300,7 @@ func (sess *Session) getDir(upPath, localPath string, match *MatchConfig, worker os.MkdirAll(lpath, 0755) } else { for i := 1; i <= MaxRetry; i++ { - id, e = sess.getFileWithProgress(id, fpath, lpath, fInfo) + id, e = sess.getFileWithProgress(id, fpath, lpath, fInfo, 1, false) if e == nil { break } @@ -328,7 +330,7 @@ func (sess *Session) getDir(upPath, localPath string, match *MatchConfig, worker return err } -func (sess *Session) getFileWithProgress(id int, upPath, localPath string, upInfo *upyun.FileInfo) (int, error) { +func (sess *Session) getFileWithProgress(id int, upPath, localPath string, upInfo *upyun.FileInfo, works int, resume bool) (int, error) { var err error var bar *uiprogress.Bar @@ -361,20 +363,35 @@ func (sess *Session) getFileWithProgress(id int, upPath, localPath string, upInf return id, err } - w, err := NewFileWrappedWriter(localPath, bar) + w, err := NewFileWrappedWriter(localPath, bar, resume) if err != nil { return id, err } defer w.Close() - _, err = sess.updriver.Get(&upyun.GetObjectConfig{ - Path: sess.AbsPath(upPath), - Writer: w, - }) + downloader 
:= partial.NewMultiPartialDownloader( + localPath, + upInfo.Size, + w, + works, + func(start, end int64) ([]byte, error) { + var buffer bytes.Buffer + _, err = sess.updriver.Get(&upyun.GetObjectConfig{ + Path: sess.AbsPath(upPath), + Writer: &buffer, + Headers: map[string]string{ + "Range": fmt.Sprintf("bytes=%d-%d", start, end), + }, + }) + return buffer.Bytes(), err + }, + ) + err = downloader.Download() + return id, err } -func (sess *Session) Get(upPath, localPath string, match *MatchConfig, workers int) { +func (sess *Session) Get(upPath, localPath string, match *MatchConfig, workers int, resume bool) { upPath = sess.AbsPath(upPath) upInfo, err := sess.updriver.GetInfo(upPath) if err != nil { @@ -406,7 +423,12 @@ func (sess *Session) Get(upPath, localPath string, match *MatchConfig, workers i if isDir { localPath = filepath.Join(localPath, path.Base(upPath)) } - sess.getFileWithProgress(-1, upPath, localPath, upInfo) + + // 小于 100M 不开启多线程 + if upInfo.Size < 1024*1024*100 { + workers = 1 + } + sess.getFileWithProgress(-1, upPath, localPath, upInfo, workers, resume) } } @@ -451,7 +473,7 @@ func (sess *Session) GetStartBetweenEndFiles(upPath, localPath string, match *Ma for fInfo := range fInfoChan { fp := filepath.Join(fpath, fInfo.Name) if (fp >= startList || startList == "") && (fp < endList || endList == "") { - sess.Get(fp, localPath, match, workers) + sess.Get(fp, localPath, match, workers, false) } else if strings.HasPrefix(startList, fp) { //前缀相同进入下一级文件夹,继续递归判断 if fInfo.IsDir {