-
Notifications
You must be signed in to change notification settings - Fork 0
/
ws.go
153 lines (134 loc) · 4.08 KB
/
ws.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
// SPDX-FileCopyrightText: 2023 Comcast Cable Communications Management, LLC
// SPDX-License-Identifier: Apache-2.0
package main
import (
"context"
"errors"
"net/url"
"time"
"github.com/xmidt-org/wrp-go/v3"
"github.com/xmidt-org/xmidt-agent/internal/credentials"
"github.com/xmidt-org/xmidt-agent/internal/jwtxt"
"github.com/xmidt-org/xmidt-agent/internal/metadata"
"github.com/xmidt-org/xmidt-agent/internal/websocket"
"github.com/xmidt-org/xmidt-agent/internal/websocket/event"
"github.com/xmidt-org/xmidt-agent/internal/wrpkit"
"go.uber.org/fx"
"go.uber.org/zap"
)
// ErrWebsocketConfig is joined with any error returned while constructing the
// websocket client, so callers can detect configuration failures with
// errors.Is.
var ErrWebsocketConfig = errors.New("websocket configuration error")
// wsIn collects the dependencies provideWS pulls from the fx graph.
type wsIn struct {
	fx.In

	// Device identity and process-level plumbing.
	Identity Identity
	Logger   *zap.Logger
	CLI      *CLI

	// JWTXT may be nil; the websocket URL then comes from
	// Websocket.BackUpURL instead.
	JWTXT *jwtxt.Instructions

	// Cred may be nil when the agent should connect without credentials.
	Cred *credentials.Credentials

	// Providers used to decorate the connection.
	Metadata      *metadata.MetadataProvider
	InterfaceUsed *metadata.InterfaceUsedProvider

	// Websocket holds the websocket configuration section.
	Websocket Websocket
}
// wsOut lists the values provideWS publishes into the fx graph.
type wsOut struct {
	fx.Out

	// NOTE(review): provideWS never assigns WSHandler, so it is published
	// as a nil wrpkit.Handler — confirm this is intended.
	WSHandler wrpkit.Handler

	// WS and Egress are both set to the same websocket client by provideWS.
	WS     *websocket.Websocket
	Egress websocket.Egress

	// Cancels carries listener cancel functions; the slice is flattened
	// into the "cancels" value group.
	Cancels []func() `group:"cancels,flatten"`
}
// provideWS builds the websocket client from the injected configuration and
// publishes it to the fx graph as both *websocket.Websocket and
// websocket.Egress.
//
// If in.Websocket.Disable is set, an empty wsOut is returned and nothing is
// constructed. If in.CLI.Dev is set, logging listeners for messages,
// connects, disconnects and heartbeats are attached, and their cancel
// functions are published through the "cancels" group.
//
// Any error from websocket.New is joined with ErrWebsocketConfig. In that case
// a zero wsOut is returned, so that neither a partially built client nor a
// non-nil Egress interface wrapping a nil pointer escapes.
func provideWS(in wsIn) (wsOut, error) {
	if in.Websocket.Disable {
		return wsOut{}, nil
	}

	// JWTXT is not required: when it is nil, fetchURL falls back to
	// in.Websocket.BackUpURL.
	var fetchURLFunc func(context.Context) (string, error)
	if in.JWTXT != nil {
		fetchURLFunc = in.JWTXT.Endpoint
	}

	var opts []websocket.Option

	// Allow operation without credentials (in.Cred is nil in that case).
	if in.Cred != nil {
		opts = append(opts, websocket.CredentialsDecorator(in.Cred.Decorate))
	}

	// Configuration options.
	opts = append(opts,
		websocket.DeviceID(in.Identity.DeviceID),
		websocket.FetchURLTimeout(in.Websocket.FetchURLTimeout),
		websocket.FetchURL(
			fetchURL(in.Websocket.URLPath, in.Websocket.BackUpURL,
				fetchURLFunc)),
		websocket.InactivityTimeout(in.Websocket.InactivityTimeout),
		websocket.PingWriteTimeout(in.Websocket.PingWriteTimeout),
		websocket.SendTimeout(in.Websocket.SendTimeout),
		websocket.KeepAliveInterval(in.Websocket.KeepAliveInterval),
		websocket.HTTPClientWithForceSets(in.Websocket.HTTPClient),
		websocket.MaxMessageBytes(in.Websocket.MaxMessageBytes),
		websocket.ConveyDecorator(in.Metadata.Decorate),
		websocket.AdditionalHeaders(in.Websocket.AdditionalHeaders),
		websocket.NowFunc(time.Now),
		websocket.WithIPv6(!in.Websocket.DisableV6),
		websocket.WithIPv4(!in.Websocket.DisableV4),
		websocket.Once(in.Websocket.Once),
		websocket.RetryPolicy(in.Websocket.RetryPolicy),
		websocket.InterfaceUsedProvider(in.InterfaceUsed),
	)

	// Listener cancel functions are passed by pointer to the listener
	// options, so (presumably) they are only populated once websocket.New
	// has applied those options; they are collected after New returns.
	var msg, con, discon, heartbeat event.CancelFunc
	if in.CLI.Dev {
		logger := in.Logger.Named("websocket")
		opts = append(opts,
			websocket.AddMessageListener(
				event.MsgListenerFunc(
					func(m wrp.Message) {
						logger.Info("message listener", zap.Any("msg", m))
					}), &msg),
			websocket.AddConnectListener(
				event.ConnectListenerFunc(
					func(e event.Connect) {
						logger.Info("connect listener", zap.Any("event", e))
					}), &con),
			websocket.AddDisconnectListener(
				event.DisconnectListenerFunc(
					func(e event.Disconnect) {
						logger.Info("disconnect listener", zap.Any("event", e))
					}), &discon),
			websocket.AddHeartbeatListener(
				event.HeartbeatListenerFunc(func(e event.Heartbeat) {
					logger.Info("heartbeat listener", zap.Any("event", e))
				}), &heartbeat),
		)
	}

	ws, err := websocket.New(opts...)
	if err != nil {
		// Don't publish a half-built result: Egress would otherwise be a
		// non-nil interface holding a (likely) nil *websocket.Websocket.
		return wsOut{}, errors.Join(ErrWebsocketConfig, err)
	}

	var cancels []func()
	if in.CLI.Dev {
		cancels = append(cancels, msg, con, discon, heartbeat)
	}

	return wsOut{
		WS:      ws,
		Egress:  ws,
		Cancels: cancels,
	}, nil
}
// errNoWebsocketURL is returned by the resolver built by fetchURL when there is
// no endpoint lookup function and no backup URL, i.e. nothing to connect to.
var errNoWebsocketURL = errors.New("no websocket url available: no endpoint lookup and no backup url configured")

// fetchURL returns a resolver that produces the full websocket URL by joining
// path onto a base URL.
//
// The base URL is obtained from f when f is non-nil. If f is nil, or f returns
// an error, backUpURL is used instead. If no base URL can be obtained, the
// resolver returns the error from f, or errNoWebsocketURL when f is nil and
// backUpURL is empty. Without that check it would return a relative path with
// no scheme or host.
func fetchURL(path, backUpURL string, f func(context.Context) (string, error)) func(context.Context) (string, error) {
	return func(ctx context.Context) (string, error) {
		if f == nil {
			if backUpURL == "" {
				return "", errNoWebsocketURL
			}
			return url.JoinPath(backUpURL, path)
		}

		baseURL, err := f(ctx)
		if err != nil {
			// Best-effort fallback: the lookup error is deliberately
			// dropped when a backup URL is configured.
			if backUpURL != "" {
				return url.JoinPath(backUpURL, path)
			}
			return "", err
		}

		return url.JoinPath(baseURL, path)
	}
}