/
tunnel.go
131 lines (115 loc) · 3.32 KB
/
tunnel.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
// Copyright 2020 SEQSENSE, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tunnel
import (
"context"
"encoding/json"
"errors"
"sync"
"github.com/at-wat/mqtt-go"
"github.com/seqsense/aws-iot-device-sdk-go/v4"
"github.com/seqsense/aws-iot-device-sdk-go/v4/internal/ioterr"
)
// Tunnel is the secure tunneling client returned by New.
//
// It is an mqtt.Handler: tunnel notifications are only processed when the
// incoming MQTT messages are delivered to it.
type Tunnel interface {
	mqtt.Handler

	// OnError registers cb to receive errors raised while handling tunnel
	// notifications and running destination proxies.
	OnError(cb func(error))
}
// tunnel is the Tunnel implementation constructed by New.
//
// The embedded mqtt.ServeMux supplies the mqtt.Handler behavior required by
// Tunnel; New registers the notification handler on it.
type tunnel struct {
	mqtt.ServeMux

	// thingName is the thing name of the device client, used to build the
	// default topic names.
	thingName string

	// mu guards onError, which is written by OnError and read by
	// handleError, possibly from proxy goroutines.
	mu      sync.Mutex
	onError func(error)

	// dialerMap maps a service name from a tunnel notification to the
	// Dialer used to reach that service.
	dialerMap map[string]Dialer
	// opts holds the options resolved by New.
	opts *Options
}
// Options stores the configurable behavior of a tunnel created by New.
// New fills in defaults for every field before applying any Option.
type Options struct {
	// EndpointHostFunc is a function that returns the host of the secure
	// proxy endpoint for the given region.
	EndpointHostFunc func(region string) string
	// TopicFunc is a function that returns the MQTT topic for the given
	// tunnel operation (for example "notify").
	TopicFunc func(operation string) string
}
// Option is a functional option that modifies Options. New applies options
// in order and aborts, returning the error, as soon as one fails.
type Option func(opts *Options) error
// ErrInvalidClientMode indicates that the requested client mode is not valid
// for the tunnel. It is reported when a notification requests a client mode
// other than Destination.
var ErrInvalidClientMode = errors.New("invalid client mode")
// topic returns the default MQTT topic for a tunnel operation of this thing,
// in the form "$aws/things/<thingName>/tunnels/<operation>".
func (t *tunnel) topic(operation string) string {
	prefix := "$aws/things/" + t.thingName + "/tunnels/"
	return prefix + operation
}
// New creates a secure tunneling proxy for the thing bound to cli.
//
// dialer maps a service name listed in a tunnel notification to the Dialer
// used to reach that service; services without an entry are ignored. The map
// is retained as-is, not copied.
//
// Options are applied in order on top of the defaults: TopicFunc builds
// "$aws/things/<thing>/tunnels/<operation>" and EndpointHostFunc is
// endpointHost. New then registers a handler for the "notify" topic and
// subscribes cli to it with QoS 1.
//
// An error is returned if an option fails, if options leave TopicFunc or
// EndpointHostFunc nil, or if registering the handler or subscribing fails.
func New(ctx context.Context, cli awsiotdev.Device, dialer map[string]Dialer, opts ...Option) (Tunnel, error) {
	t := &tunnel{
		thingName: cli.ThingName(),
		dialerMap: dialer,
	}
	t.opts = &Options{
		TopicFunc:        t.topic,
		EndpointHostFunc: endpointHost,
	}
	for _, o := range opts {
		if err := o(t.opts); err != nil {
			return nil, ioterr.New(err, "applying options")
		}
	}
	// Fail fast here: a nil EndpointHostFunc would otherwise only surface as
	// a panic inside the proxy goroutine started on a notification.
	if t.opts.TopicFunc == nil || t.opts.EndpointHostFunc == nil {
		return nil, ioterr.New(
			errors.New("TopicFunc and EndpointHostFunc must be non-nil"),
			"applying options",
		)
	}

	// Resolve the topic once so the handler and the subscription are
	// guaranteed to use the same topic string.
	notifyTopic := t.opts.TopicFunc("notify")
	if err := t.ServeMux.Handle(notifyTopic, mqtt.HandlerFunc(t.notify)); err != nil {
		return nil, ioterr.New(err, "registering message handler")
	}
	if err := cli.Subscribe(ctx,
		mqtt.Subscription{Topic: notifyTopic, QoS: mqtt.QoS1},
	); err != nil {
		return nil, ioterr.New(err, "subscribing tunnel topic")
	}
	return t, nil
}
func (t *tunnel) notify(msg *mqtt.Message) {
n := &Notification{}
if err := json.Unmarshal(msg.Payload, n); err != nil {
t.handleError(ioterr.New(err, "unmarshaling notification"))
return
}
if n.ClientMode != Destination {
t.handleError(ioterr.Newf(ErrInvalidClientMode, "requested %s", n.ClientMode))
return
}
for _, srv := range n.Services {
if d, ok := t.dialerMap[srv]; ok {
go func() {
err := ProxyDestination(
d,
t.opts.EndpointHostFunc(n.Region),
n.ClientAccessToken,
WithErrorHandler(ErrorHandlerFunc(t.handleError)),
)
if err != nil {
t.handleError(ioterr.New(err, "creating proxy destination"))
}
}()
}
}
}
// OnError sets the callback that receives errors reported by the tunnel.
// Passing nil removes the callback. It is safe to call concurrently with
// error reporting because the callback is stored under t.mu.
func (t *tunnel) OnError(cb func(err error)) {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.onError = cb
}
// handleError forwards err to the callback registered with OnError, if any.
func (t *tunnel) handleError(err error) {
	t.mu.Lock()
	handler := t.onError
	t.mu.Unlock()

	// The callback runs after the lock is released, so it may itself call
	// OnError without deadlocking.
	if handler == nil {
		return
	}
	handler(err)
}