/
input_files_helper.go
103 lines (94 loc) · 2.96 KB
/
input_files_helper.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
package playwright
import (
"encoding/base64"
"errors"
"fmt"
"os"
"path/filepath"
)
// fileSizeLimitInBytes caps the combined Buffer size (in bytes) of the
// in-memory InputFile values accepted by convertInputFiles: 50 MiB.
const fileSizeLimitInBytes = 50 << 20

// ErrInputFilesSizeExceeded is returned by convertInputFiles when the
// in-memory InputFile buffers together exceed fileSizeLimitInBytes. Callers
// can detect it with errors.Is.
var ErrInputFilesSizeExceeded = errors.New("Cannot set buffer larger than 50Mb, please write it to a file and pass its path instead.")
// inputFiles is the JSON payload describing which files to set on a file
// input element. convertInputFiles fills at most one of Streams, LocalPaths
// or Payloads (none when the caller passes an empty list); empty fields are
// dropped from the JSON via omitempty.
type inputFiles struct {
	// Selector is not set by convertInputFiles; presumably the caller fills it
	// when targeting an element by selector — NOTE(review): confirm at call sites.
	Selector *string `json:"selector,omitempty"`
	// Streams holds the channels of writableStream objects created via
	// "createTempFile" when the connection is remote; it replaces LocalPaths.
	Streams []*channel `json:"streams,omitempty"` // writableStream
	// LocalPaths holds file paths as given by the caller (not made absolute here).
	LocalPaths []string `json:"localPaths,omitempty"`
	// Payloads holds one map per in-memory file with the keys "name",
	// "mimeType" and "buffer" (base64, standard encoding).
	Payloads []map[string]string `json:"payloads,omitempty"`
}
// convertInputFiles converts the value passed to a SetInputFiles-style call
// into the inputFiles payload sent to the Playwright driver.
//
// files must be one of:
//   - string: a single local file path
//   - []string: several local file paths
//   - InputFile: a single in-memory file
//   - []InputFile: several in-memory files
//
// In-memory files are base64-encoded into Payloads; if their combined buffer
// size exceeds fileSizeLimitInBytes, ErrInputFilesSizeExceeded is returned.
// Local paths are sent as LocalPaths, except when the connection is remote:
// then each file is uploaded through a "createTempFile" writable stream and
// only the resulting stream channels are sent (LocalPaths is cleared).
//
// Any other type of files yields an error. context is only dereferenced when
// at least one local path was given.
func convertInputFiles(files interface{}, context *browserContextImpl) (*inputFiles, error) {
	converted := &inputFiles{}
	switch items := files.(type) {
	case InputFile:
		if sizeOfInputFiles([]InputFile{items}) > fileSizeLimitInBytes {
			return nil, ErrInputFilesSizeExceeded
		}
		converted.Payloads = normalizeFilePayloads([]InputFile{items})
	case []InputFile:
		if sizeOfInputFiles(items) > fileSizeLimitInBytes {
			return nil, ErrInputFilesSizeExceeded
		}
		converted.Payloads = normalizeFilePayloads(items)
	case string: // local file path
		converted.LocalPaths = []string{items}
	case []string:
		converted.LocalPaths = items
	default:
		return nil, errors.New("files should be one of: string, []string, InputFile, []InputFile")
	}

	// Short-circuit keeps context untouched when there is nothing to upload.
	if len(converted.LocalPaths) == 0 || !context.connection.isRemote {
		return converted, nil
	}

	// A remote driver cannot read this machine's filesystem, so stream each
	// file's contents to a server-side temp file instead of sending paths.
	streams := make([]*channel, 0, len(converted.LocalPaths))
	for _, file := range converted.LocalPaths {
		lastModifiedMs, err := getFileLastModifiedMs(file)
		if err != nil {
			return nil, fmt.Errorf("failed to get last modified time of %s: %w", file, err)
		}
		result, err := context.connection.WrapAPICall(func() (interface{}, error) {
			return context.channel.Send("createTempFile", map[string]interface{}{
				"name":           filepath.Base(file),
				"lastModifiedMs": lastModifiedMs,
			})
		}, true)
		if err != nil {
			return nil, err
		}
		// Checked assertion: an unexpected protocol response must surface as an
		// error, not a panic inside the library.
		obj := fromChannel(result)
		stream, ok := obj.(*writableStream)
		if !ok {
			return nil, fmt.Errorf("createTempFile for %s returned unexpected object %T", file, obj)
		}
		if err := stream.Copy(file); err != nil {
			return nil, fmt.Errorf("uploading %s to remote temp file: %w", file, err)
		}
		streams = append(streams, stream.channel)
	}
	converted.Streams = streams
	converted.LocalPaths = nil
	return converted, nil
}
// getFileLastModifiedMs returns the modification time of the file at path as
// milliseconds since the Unix epoch. The os.Stat error is returned unchanged
// when path cannot be stat'ed, and an error is returned when path names a
// directory; in both cases the returned time is 0.
func getFileLastModifiedMs(path string) (int64, error) {
	fi, statErr := os.Stat(path)
	switch {
	case statErr != nil:
		return 0, statErr
	case fi.IsDir():
		return 0, fmt.Errorf("%s is a directory", path)
	default:
		return fi.ModTime().UnixMilli(), nil
	}
}
// sizeOfInputFiles returns the total number of bytes across the Buffer of
// every InputFile in files; an empty or nil slice yields 0.
func sizeOfInputFiles(files []InputFile) int {
	var total int
	for i := range files {
		total += len(files[i].Buffer)
	}
	return total
}
// normalizeFilePayloads converts each InputFile into the map sent to the
// driver, with keys "name", "mimeType" and "buffer" (the file contents in
// standard base64). Order is preserved, and the result is always non-nil,
// even for an empty input.
func normalizeFilePayloads(files []InputFile) []map[string]string {
	payloads := make([]map[string]string, 0, len(files))
	for i := range files {
		f := &files[i]
		payloads = append(payloads, map[string]string{
			"name":     f.Name,
			"mimeType": f.MimeType,
			"buffer":   base64.StdEncoding.EncodeToString(f.Buffer),
		})
	}
	return payloads
}