Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: 增加分片并发下载和断续下载 #78

Merged
merged 2 commits into from
Apr 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,4 @@ jobs:
- name: Test
run: |
go build .
go test -v .
go test -v ./...
20 changes: 16 additions & 4 deletions commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,15 +282,19 @@ func NewGetCommand() cli.Command {
PrintErrorAndExit("get %s: parse mtime: %v", upPath, err)
}
}
if c.Int("w") > 10 || c.Int("w") < 1 {
PrintErrorAndExit("max concurrent threads must between (1 - 10)")
}
if mc.Start != "" || mc.End != "" {
session.GetStartBetweenEndFiles(upPath, localPath, mc, c.Int("w"))
} else {
session.Get(upPath, localPath, mc, c.Int("w"))
session.Get(upPath, localPath, mc, c.Int("w"), c.Bool("c"))
}
return nil
},
Flags: []cli.Flag{
cli.IntFlag{Name: "w", Usage: "max concurrent threads", Value: 5},
cli.IntFlag{Name: "w", Usage: "max concurrent threads (1-10)", Value: 5},
cli.BoolFlag{Name: "c", Usage: "continue download, Resume Broken Download"},
cli.StringFlag{Name: "mtime", Usage: "file's data was last modified n*24 hours ago, same as linux find command."},
cli.StringFlag{Name: "start", Usage: "file download range starting location"},
cli.StringFlag{Name: "end", Usage: "file download range ending location"},
Expand All @@ -315,7 +319,9 @@ func NewPutCommand() cli.Command {
if c.NArg() > 1 {
upPath = c.Args().Get(1)
}

if c.Int("w") > 10 || c.Int("w") < 1 {
PrintErrorAndExit("max concurrent threads must between (1 - 10)")
}
session.Put(
localPath,
upPath,
Expand All @@ -332,9 +338,12 @@ func NewPutCommand() cli.Command {
func NewUploadCommand() cli.Command {
return cli.Command{
Name: "upload",
Usage: "upload multiple directory or file",
Usage: "upload multiple directory or file or http url",
Action: func(c *cli.Context) error {
InitAndCheck(LOGIN, CHECK, c)
if c.Int("w") > 10 || c.Int("w") < 1 {
PrintErrorAndExit("max concurrent threads must between (1 - 10)")
}
session.Upload(
c.Args(),
c.String("remote"),
Expand Down Expand Up @@ -422,6 +431,9 @@ func NewSyncCommand() cli.Command {
if c.NArg() > 1 {
upPath = c.Args().Get(1)
}
if c.Int("w") > 10 || c.Int("w") < 1 {
PrintErrorAndExit("max concurrent threads must between (1 - 10)")
}
session.Sync(localPath, upPath, c.Int("w"), c.Bool("delete"), c.Bool("strong"))
return nil
},
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ require (
github.com/jehiah/go-strftime v0.0.0-20171201141054-1d33003b3869
github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0
github.com/syndtr/goleveldb v1.0.0
github.com/upyun/go-sdk/v3 v3.0.3
github.com/upyun/go-sdk/v3 v3.0.4
github.com/urfave/cli v1.22.4
)

Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFd
github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ=
github.com/upyun/go-sdk/v3 v3.0.3 h1:2wUkNk2fyJReMYHMvJyav050D83rYwSjN7mEPR0Pp8Q=
github.com/upyun/go-sdk/v3 v3.0.3/go.mod h1:P/SnuuwhrIgAVRd/ZpzDWqCsBAf/oHg7UggbAxyZa0E=
github.com/upyun/go-sdk/v3 v3.0.4 h1:2DCJa/Yi7/3ZybT9UCPATSzvU3wpPPxhXinNlb1Hi8Q=
github.com/upyun/go-sdk/v3 v3.0.4/go.mod h1:P/SnuuwhrIgAVRd/ZpzDWqCsBAf/oHg7UggbAxyZa0E=
github.com/urfave/cli v1.22.4 h1:u7tSpNPPswAFymm8IehJhy4uJMlUuU/GmqSkvJ1InXA=
github.com/urfave/cli v1.22.4/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
Expand Down
18 changes: 15 additions & 3 deletions io.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,27 @@ func (w *WrappedWriter) Close() error {
return w.w.Close()
}

func NewFileWrappedWriter(localPath string, bar *uiprogress.Bar) (*WrappedWriter, error) {
fd, err := os.Create(localPath)
func NewFileWrappedWriter(localPath string, bar *uiprogress.Bar, resume bool) (*WrappedWriter, error) {
var fd *os.File
var err error

if resume {
fd, err = os.OpenFile(localPath, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0755)
} else {
fd, err = os.Create(localPath)
}
if err != nil {
return nil, err
}

fileinfo, err := fd.Stat()
if err != nil {
return nil, err
}

return &WrappedWriter{
w: fd,
Copyed: 0,
Copyed: int(fileinfo.Size()),
bar: bar,
}, nil
}
Expand Down
97 changes: 97 additions & 0 deletions partial/chunk.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package partial

import (
"sync/atomic"
)

type Chunk struct {
// 切片的顺序
index int64

// 切片内容的在源文件的开始地址
start int64

// 切片内容在源文件的结束地址
end int64

// 切片任务的下载错误
err error

// 下载完的切片的具体内容
buffer []byte
}
arrebole marked this conversation as resolved.
Show resolved Hide resolved

func NewChunk(index, start, end int64) *Chunk {
chunk := &Chunk{
start: start,
end: end,
index: index,
}
return chunk
}

func (p *Chunk) SetData(bytes []byte) {
p.buffer = bytes
}

func (p *Chunk) SetError(err error) {
p.err = err
}

func (p *Chunk) Error() error {
return p.err
}

func (p *Chunk) Data() []byte {
return p.buffer
}

// 切片乱序写入后,将切片顺序读取
type ChunksSorter struct {
// 已经读取的切片数量
readCount int64

// 切片的所有总数
chunkCount int64

// 线程数,用于阻塞写入
works int64

// 存储切片的缓存区
chunks []chan *Chunk
}

func NewChunksSorter(chunkCount int64, works int) *ChunksSorter {
chunks := make([]chan *Chunk, works)
for i := 0; i < len(chunks); i++ {
chunks[i] = make(chan *Chunk)
}

return &ChunksSorter{
chunkCount: chunkCount,
works: int64(works),
chunks: chunks,
}
}

// 将数据写入到缓存区,如果该缓存已满,则会被阻塞
func (p *ChunksSorter) Write(chunk *Chunk) {
p.chunks[chunk.index%p.works] <- chunk
}

// 关闭 workId 下的通道
func (p *ChunksSorter) Close(workId int) {
if (len(p.chunks) - 1) >= workId {
close(p.chunks[workId])
}
}

// 顺序读取切片,如果下一个切片没有下载完,则会被阻塞
func (p *ChunksSorter) Read() *Chunk {
if p.chunkCount == 0 {
return nil
}
i := atomic.AddInt64(&p.readCount, 1)
chunk := <-p.chunks[(i-1)%p.works]
return chunk
}
141 changes: 141 additions & 0 deletions partial/downloader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package partial

import (
"context"
"errors"
"io"
"os"
"sync"
)

const DefaultChunkSize = 1024 * 1024 * 10

type ChunkDownFunc func(start, end int64) ([]byte, error)

type MultiPartialDownloader struct {
arrebole marked this conversation as resolved.
Show resolved Hide resolved

// 文件路径
filePath string

// 最终文件大小
finalSize int64

// 本地文件大小
localSize int64

//分片大小
chunkSize int64

writer io.Writer
works int
downFunc ChunkDownFunc
}

func NewMultiPartialDownloader(filePath string, finalSize, chunkSize int64, writer io.Writer, works int, fn ChunkDownFunc) *MultiPartialDownloader {
return &MultiPartialDownloader{
filePath: filePath,
finalSize: finalSize,
works: works,
writer: writer,
chunkSize: chunkSize,
downFunc: fn,
}
}

func (p *MultiPartialDownloader) Download() error {
fileinfo, err := os.Stat(p.filePath)

// 如果异常
// - 文件不存在异常: localSize 默认值 0
// - 不是文件不存在异常: 报错
if err != nil && !os.IsNotExist(err) {
return err
}
if err == nil {
arrebole marked this conversation as resolved.
Show resolved Hide resolved
p.localSize = fileinfo.Size()
}

// 计算需要下载的块数
needDownSize := p.finalSize - p.localSize
chunkCount := needDownSize / p.chunkSize
if needDownSize%p.chunkSize != 0 {
chunkCount++
}

chunksSorter := NewChunksSorter(
chunkCount,
p.works,
)

// 下载切片任务
var wg sync.WaitGroup
ctx, cancel := context.WithCancel(context.Background())
defer func() {
// 取消切片下载任务,并等待
cancel()
wg.Wait()
}()

for i := 0; i < p.works; i++ {
wg.Add(1)
go func(ctx context.Context, workId int) {
defer func() {
// 关闭 workId 下的接收通道
chunksSorter.Close(workId)
wg.Done()
}()

// 每个 work 取自己倍数的 chunk
arrebole marked this conversation as resolved.
Show resolved Hide resolved
for j := workId; j < int(chunkCount); j += p.works {
select {
case <-ctx.Done():
return
default:
var (
err error
buffer []byte
)
start := p.localSize + int64(j)*p.chunkSize
end := p.localSize + int64(j+1)*p.chunkSize
if end > p.finalSize {
end = p.finalSize
}
chunk := NewChunk(int64(j), start, end)

// 重试三次
for t := 0; t < 3; t++ {
// ? 由于长度是从1开始,而数据是从0地址开始
// ? 计算字节时容量会多出开头的一位,所以末尾需要减少一位
buffer, err = p.downFunc(chunk.start, chunk.end-1)
if err == nil {
break
}
}
chunk.SetData(buffer)
chunk.SetError(err)
chunksSorter.Write(chunk)

if err != nil {
return
}
}
}
}(ctx, i)
}

// 将分片顺序写入到文件
for {
chunk := chunksSorter.Read()
if chunk == nil {
break
}
if chunk.Error() != nil {
return chunk.Error()
}
if len(chunk.Data()) == 0 {
return errors.New("chunk buffer download but size is 0")
}
p.writer.Write(chunk.Data())
}
return nil
}
32 changes: 32 additions & 0 deletions partial/downloader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package partial

import (
"bytes"
"crypto/md5"
"strings"
"testing"
)

func TestDownload(t *testing.T) {
var buffer bytes.Buffer

filedata := []byte(strings.Repeat("hello world", 1024*100))
download := NewMultiPartialDownloader(
"myTestfile",
int64(len(filedata)),
1024,
&buffer,
3,
func(start, end int64) ([]byte, error) {
return filedata[start : end+1], nil
},
)

err := download.Download()
if err != nil {
t.Fatal(err.Error())
}
if md5.Sum(buffer.Bytes()) != md5.Sum(filedata) {
t.Fatal("download file has diff MD5")
}
}
Loading