/
main.go
137 lines (114 loc) · 3.07 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
package main
import (
"context"
"fmt"
"os"
"os/signal"
"github.com/jessevdk/go-flags"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"golang.org/x/sync/errgroup"
"net/http"
_ "net/http/pprof"
)
var (
	// Version is the version of the binary. It defaults to "dev" and is
	// assigned during build.
	Version = "dev"

	// logger is the package-wide zerolog logger. It renders events through
	// a console writer onto standard error; main narrows its level
	// according to the verbosity flags.
	logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
)
// Options holds the command-line flags and positional arguments of the
// relay. The struct tags are consumed by the go-flags parser in main.
type Options struct {
	// Args groups the positional (non-flag) arguments.
	Args struct {
		// Addr is the SSH host:port to connect with and relay.
		Addr string `positional-arg-name:"Addr" description:"SSH host:port to connect with and relay"`
	} `positional-args:"yes"`

	// Websocket is the host:port to bind the websocket relay to. The
	// websocket relay is only served when this is non-empty.
	Websocket string `long:"websocket" description:"Websocket host:port to bind to and supply a relay"`

	// Username is the name used for the SSH connection.
	Username string `long:"name" description:"Username to connect with" default:"ssh-chat-relay"`

	// Verbose collects one entry per -v occurrence; its length selects the
	// log level (0: warn, 1: info, 2 or more: debug).
	Verbose []bool `long:"verbose" short:"v" description:"Show verbose logging."`

	// Version requests that the binary print its version and exit.
	Version bool `long:"version" description:"Print version and exit."`

	// Pprof is the address to serve the pprof HTTP endpoints on. No pprof
	// server is started when this is empty.
	Pprof string `long:"pprof" description:"Bind pprof on http server on this addr. (Example: \"localhost:6060\")"`
}
// exit writes the formatted message to standard error and then terminates
// the process with the given status code. It does not return.
func exit(code int, format string, args ...interface{}) {
	message := fmt.Sprintf(format, args...)
	fmt.Fprint(os.Stderr, message)
	os.Exit(code)
}
// main parses the command-line flags, configures the log level, optionally
// starts the pprof HTTP server, installs the interrupt handler and then hands
// control to run. It exits with status 1 when flag parsing or run fails.
func main() {
	options := Options{}
	p, err := flags.NewParser(&options, flags.Default).ParseArgs(os.Args[1:])
	if err != nil {
		// p holds the remaining arguments; the error is only printed here
		// when none were returned.
		// NOTE(review): presumably the parser (flags.Default) has already
		// printed the error itself in the other cases — confirm.
		if p == nil {
			fmt.Println(err)
		}
		// Asking for help is not a failure. Every other parse error must be
		// reflected in the exit status so callers and scripts can detect it.
		if flagsErr, ok := err.(*flags.Error); ok && flagsErr.Type == flags.ErrHelp {
			return
		}
		os.Exit(1)
	}

	if options.Version {
		fmt.Println(Version)
		os.Exit(0)
	}

	// Logging: each -v raises the verbosity by one step.
	switch len(options.Verbose) {
	case 0:
		logger = logger.Level(zerolog.WarnLevel)
	case 1:
		logger = logger.Level(zerolog.InfoLevel)
	default:
		logger = logger.Level(zerolog.DebugLevel)
	}

	if options.Pprof != "" {
		go func() {
			logger.Debug().Str("bind", options.Pprof).Msg("serving pprof http server")
			// ListenAndServe always returns a non-nil error; report it
			// through the logger like every other diagnostic instead of
			// dumping it raw onto stdout.
			serveErr := http.ListenAndServe(options.Pprof, nil)
			logger.Error().Err(serveErr).Str("bind", options.Pprof).Msg("pprof http server stopped")
		}()
	}

	// Signals: the first interrupt cancels the context for a graceful
	// shutdown, a second one aborts the process via panic.
	ctx, abort := context.WithCancel(context.Background())
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, os.Interrupt)
	go func(abort context.CancelFunc) {
		<-sigCh
		logger.Warn().Msg("interrupt received, shutting down")
		abort()
		<-sigCh
		logger.Error().Msg("second interrupt received, panicking")
		panic("aborted")
	}(abort)

	if err := run(ctx, options); err != nil {
		exit(1, "failed: %s\n", err)
	}
}
// run opens the SSH connection described by options and serves the relay on
// top of it. The I/O source, and the websocket relay when options.Websocket
// is set, run as goroutines of one errgroup derived from ctx; run returns
// the connection error, or otherwise whatever the group's Wait reports.
// The connection is closed when run returns.
func run(ctx context.Context, options Options) error {
	conn := sshConnection{
		Addr: options.Args.Addr,
		Name: options.Username,
		Term: "bot",
	}

	logger.Info().Str("addr", conn.Addr).Str("name", conn.Name).Msg("connecting")
	err := conn.Connect(ctx)
	if err != nil {
		return err
	}
	defer conn.Close()

	// logReceived is the baseline handler: every relayed message is logged
	// at debug level.
	logReceived := func(msg string) {
		logger.Debug().Str("received", msg).Msg("msg")
	}

	handlers := RelayHandlers{OnMessage: logReceived}
	src := ioSource{RelayHandlers: handlers}

	group, ctx := errgroup.WithContext(ctx)

	if options.Websocket != "" {
		relay := wsRelay{
			Bind: options.Websocket,
			Send: src.Send,
		}

		// With a websocket relay active, each message is logged first and
		// then forwarded to the relay.
		src.RelayHandlers.OnMessage = func(msg string) {
			logReceived(msg)
			relay.OnMessage(msg)
		}

		logger.Info().Str("addr", relay.Bind).Msg("serving websocket relay")
		group.Go(func() error {
			return relay.Serve(ctx)
		})
	}

	group.Go(func() error {
		return src.Serve(ctx, conn.Reader, conn.Writer)
	})

	return group.Wait()
}