-
Notifications
You must be signed in to change notification settings - Fork 1
/
hijackedio.go
115 lines (94 loc) · 2.49 KB
/
hijackedio.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
package execcmd
import (
"io"
"context"
"github.com/docker/docker/api/types"
"github.com/docker/docker/pkg/stdcopy"
"github.com/moby/term"
"github.com/sirupsen/logrus"
)
// HijackedIOStreamer handles copying input to and output from streams to the
// connection.
//
// Stream is the entry point: it copies InputStream into Resp.Conn and the
// data read from Resp.Reader into OutputStream/ErrorStream, returning when
// output ends, the detach key sequence is read, or the context is cancelled.
type HijackedIOStreamer struct {
	// InputStream, when non-nil, is copied into Resp.Conn. When nil, no input
	// is sent and Resp.CloseWrite is called straight away.
	// NOTE(review): nothing in Stream or its helpers closes InputStream.
	InputStream io.ReadCloser
	// OutputStream receives the response data: the raw stream when Tty is
	// true, or the stdout frames when the output is demultiplexed.
	OutputStream io.Writer
	// ErrorStream is the destination for stderr frames when the output is
	// demultiplexed (Tty false).
	ErrorStream io.Writer
	// Resp is the hijacked connection: input is written to Resp.Conn and
	// output is read from Resp.Reader.
	Resp types.HijackedResponse
	// Tty selects how output is read: a plain copy when true, demultiplexing
	// with stdcopy.StdCopy when false.
	Tty bool
}
// Stream starts the output and input copiers and then blocks until output
// has been fully read, the user enters the detach key sequence (TTY mode),
// or ctx is cancelled.
//
// If input finishes first and at least one output writer is configured,
// Stream keeps waiting for the output copy to finish (or for ctx to be
// cancelled) before returning. The returned error is the output copy's
// result, the detach error, or ctx.Err().
func (h *HijackedIOStreamer) Stream(ctx context.Context) error {
	outC := h.beginOutputStream()
	inC, detachC := h.beginInputStream()

	select {
	case err := <-outC:
		return err
	case err := <-detachC:
		// The detach key sequence was entered.
		return err
	case <-ctx.Done():
		return ctx.Err()
	case <-inC:
		// Input is finished; fall through to wait on output.
	}

	// Without any output writer there is nothing left to wait for.
	if h.OutputStream == nil && h.ErrorStream == nil {
		return nil
	}

	select {
	case err := <-outC:
		return err
	case <-ctx.Done():
		return ctx.Err()
	}
}
// beginOutputStream starts a goroutine that copies the hijacked response
// into the configured output writers. It returns a channel that receives the
// copy's result (nil when the stream ends cleanly).
//
// It returns nil when neither OutputStream nor ErrorStream is set. A receive
// from a nil channel blocks forever, so Stream's select never picks it.
//
// In TTY mode the response is one raw stream and is copied verbatim to
// OutputStream. Otherwise it is demultiplexed with stdcopy.StdCopy into
// OutputStream (stdout frames) and ErrorStream (stderr frames).
//
// A writer left nil is replaced with io.Discard. This has two effects:
//   - A configuration with only ErrorStream set still drains the connection
//     and forwards stderr, instead of reporting completion without copying
//     anything.
//   - A frame for an unconfigured stream is dropped rather than written to a
//     nil io.Writer, which would panic.
//
// The returned channel has capacity 1 so the goroutine can always deliver its
// result and exit, even after Stream has already returned (detach or context
// cancellation).
func (h *HijackedIOStreamer) beginOutputStream() <-chan error {
	if h.OutputStream == nil && h.ErrorStream == nil {
		// There is no need to copy output.
		return nil
	}

	stdout := h.OutputStream
	if stdout == nil {
		stdout = io.Discard
	}
	stderr := h.ErrorStream
	if stderr == nil {
		stderr = io.Discard
	}

	outputDone := make(chan error, 1)
	go func() {
		var err error
		if h.Tty {
			// With a TTY the output is not multiplexed, so there are no
			// frame headers to split; copy it as-is.
			_, err = io.Copy(stdout, h.Resp.Reader)
		} else {
			_, err = stdcopy.StdCopy(stdout, stderr, h.Resp.Reader)
		}
		if err != nil {
			logrus.Debugf("Error receiveStdout: %s", err)
		}
		outputDone <- err
	}()
	return outputDone
}
// beginInputStream starts a goroutine that copies InputStream (when non-nil)
// into the hijacked connection. It then calls Resp.CloseWrite so the remote
// end is sent EOF on its input.
//
// It returns two channels:
//   - doneC is closed once input copying has finished and CloseWrite has
//     been attempted.
//   - detachedC receives the term.EscapeError returned when the detach key
//     sequence is read. In that case CloseWrite is not called and doneC is
//     never closed.
//
// detachedC has capacity 1 so the goroutine can deliver a detach and exit
// even if Stream has already returned (output finished or ctx cancelled).
// With an unbuffered channel that send would block forever and leak the
// goroutine.
func (h *HijackedIOStreamer) beginInputStream() (doneC <-chan struct{}, detachedC <-chan error) {
	inputDone := make(chan struct{})
	detached := make(chan error, 1)
	go func() {
		if h.InputStream != nil {
			_, err := io.Copy(h.Resp.Conn, h.InputStream)
			if _, ok := err.(term.EscapeError); ok {
				detached <- err
				return
			}
			if err != nil {
				// This error will also occur on the receive side (from
				// stdout), where it is propagated back to the caller.
				logrus.Debugf("Error sendStdin: %s", err)
			}
		}
		if err := h.Resp.CloseWrite(); err != nil {
			logrus.Debugf("Couldn't send EOF: %s", err)
		}
		close(inputDone)
	}()
	return inputDone, detached
}