forked from hyperledger-archives/aries-framework-go
-
Notifications
You must be signed in to change notification settings - Fork 0
/
client.go
347 lines (278 loc) · 9.59 KB
/
client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
/*
Copyright SecureKey Technologies Inc. All Rights Reserved.
SPDX-License-Identifier: Apache-2.0
*/
package messaging
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"

	"github.com/google/uuid"
	"github.com/hyperledger/aries-framework-go/pkg/common/log"
	"github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service"
	"github.com/hyperledger/aries-framework-go/pkg/didcomm/dispatcher"
	"github.com/hyperledger/aries-framework-go/pkg/framework/aries/api/vdr"
	"github.com/hyperledger/aries-framework-go/pkg/kms"
	"github.com/hyperledger/aries-framework-go/pkg/storage"
	"github.com/hyperledger/aries-framework-go/pkg/store/connection"
	"github.com/hyperledger/aries-framework-go/pkg/vdr/fingerprint"
)
const (
	// stateNameCompleted is the connection-record state that sendToTheirDID
	// requires before it reuses an existing connection to a DID.
	stateNameCompleted = "completed"

	// errMsgDestinationMissing is the error text Send returns when none of the
	// destination options was supplied.
	errMsgDestinationMissing = "missing message destination"
)
var (
	// logger is the package-wide logger, tagged with this client's module name.
	logger = log.New("aries-framework/client/messaging")
)
// provider is the set of framework services the messaging client needs.
// A value returned by aries.Context() is the usual implementation.
type provider interface {
	// VDRegistry resolves DIDs when a message is sent by their DID without a
	// completed connection.
	VDRegistry() vdr.Registry
	// Messenger performs the actual send / reply operations.
	Messenger() service.Messenger
	// ProtocolStateStorageProvider and StorageProvider are not called directly
	// in this file; presumably consumed by connection.NewLookup — verify.
	ProtocolStateStorageProvider() storage.Provider
	StorageProvider() storage.Provider
	// KMS creates the sender key used when sending to a raw destination.
	KMS() kms.KeyManager
}
// MessageHandler is a registry of message services whose contents can be
// changed at runtime.
type MessageHandler interface {
	// Services lists every message service currently held by the handler.
	Services() []dispatcher.MessageService
	// Register adds the supplied message services to the handler.
	Register(msgSvcs ...dispatcher.MessageService) error
	// Unregister removes the message service registered under name.
	Unregister(name string) error
}
// Notifier delivers a raw message payload to listeners of a topic.
type Notifier interface {
	// Notify publishes message under topic; a non-nil error reports a failed delivery.
	Notify(topic string, message []byte) error
}
// sendMsgOpts accumulates the settings applied by SendMessageOpions for a
// single Send call.
type sendMsgOpts struct {
	// connectionID picks an existing connection by its ID. When set, it beats
	// every other destination option.
	connectionID string

	// theirDID picks the recipient by DID. It beats destination.
	theirDID string

	// destination is an explicit service endpoint; it allows sending a
	// message outside of any connection.
	destination *service.Destination

	// responseMsgType, when non-empty, makes Send wait after sending for a
	// reply of this message type.
	responseMsgType string

	// waitForResponseCtx bounds how long Send waits for that reply.
	waitForResponseCtx context.Context
}
// SendMessageOpions configures a Send call: which destination to use and
// whether to wait for a reply. (The misspelled name is part of the exported
// API and is kept for compatibility.)
type SendMessageOpions func(opts *sendMsgOpts)
// SendByConnectionID returns an option that addresses the message to the
// existing connection identified by connectionID. This option takes precedence
// over SendByTheirDID and SendByDestination.
func SendByConnectionID(connectionID string) SendMessageOpions {
	apply := func(o *sendMsgOpts) {
		o.connectionID = connectionID
	}

	return apply
}
// SendByTheirDID returns an option that addresses the message to the recipient
// with the given DID. If a completed connection to that DID exists it is used;
// otherwise the destination is resolved from the DID. This option is ignored
// when SendByConnectionID is also given, and takes precedence over
// SendByDestination.
func SendByTheirDID(theirDID string) SendMessageOpions {
	return func(opts *sendMsgOpts) {
		opts.theirDID = theirDID
	}
}
// SendByDestination returns an option that sends the message straight to the
// given service endpoint destination, which works without any connection.
// It has the lowest precedence of the destination options.
func SendByDestination(destination *service.Destination) SendMessageOpions {
	apply := func(o *sendMsgOpts) {
		o.destination = destination
	}

	return apply
}
// WaitForResponse returns an option that makes Send block after sending until
// a reply of responseType on the same thread arrives, or until ctx is done.
func WaitForResponse(ctx context.Context, responseType string) SendMessageOpions {
	apply := func(o *sendMsgOpts) {
		o.waitForResponseCtx = ctx
		o.responseMsgType = responseType
	}

	return apply
}
// messageDispatcher is a deferred send/reply operation. Calling it performs the
// dispatch and returns only an error; it does not return a message ID.
type messageDispatcher func() error
// Client gives access to messaging features: registering message services,
// sending messages and replying to received ones.
type Client struct {
	// ctx supplies framework dependencies (messenger, VDR registry, KMS, storage).
	ctx provider
	// msgRegistrar holds the message services registered through this client.
	msgRegistrar MessageHandler
	// notifier is handed to every service created by RegisterService.
	notifier Notifier
	// connectionLookup resolves connection records for connection- and DID-based sends.
	connectionLookup *connection.Lookup
}
// New creates a messaging Client. registrar receives the services added via
// RegisterService, and notifier is attached to each of them. It returns an
// error if a connection lookup cannot be built from ctx.
func New(ctx provider, registrar MessageHandler, notifier Notifier) (*Client, error) {
	lookup, err := connection.NewLookup(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to initialize connection lookup : %w", err)
	}

	return &Client{
		ctx:              ctx,
		msgRegistrar:     registrar,
		notifier:         notifier,
		connectionLookup: lookup,
	}, nil
}
// RegisterService builds a message service named name that handles messages
// of type msgType (optionally restricted by purpose) and adds it to the
// client's registrar. Handled messages are forwarded to the client's notifier.
func (c *Client) RegisterService(name, msgType string, purpose ...string) error {
	svc := newMessageService(name, msgType, purpose, c.notifier)

	return c.msgRegistrar.Register(svc)
}
// UnregisterService removes the message service registered under name from
// the client's registrar, returning any error the registrar reports.
func (c *Client) UnregisterService(name string) error {
	registrar := c.msgRegistrar

	return registrar.Unregister(name)
}
// Services returns the names of all message services currently registered.
// The result is never nil; it is an empty slice when nothing is registered.
func (c *Client) Services() []string {
	registered := c.msgRegistrar.Services()

	names := make([]string, 0, len(registered))
	for i := range registered {
		names = append(names, registered[i].Name())
	}

	return names
}
// Send sends msg to the destination selected by opts and returns the reply
// when WaitForResponse was supplied (otherwise an empty message).
//
// Destination options are applied in precedence order: SendByConnectionID,
// then SendByTheirDID, then SendByDestination. An error is returned when none
// is given. If msg carries no message ID, a fresh UUID is assigned first; that
// ID is also the thread ID a reply must match when waiting for one.
func (c *Client) Send(msg json.RawMessage, opts ...SendMessageOpions) (json.RawMessage, error) {
	sendOpts := &sendMsgOpts{}
	for _, opt := range opts {
		opt(sendOpts)
	}

	didCommMsg, err := prepareMessage(msg)
	if err != nil {
		return nil, err
	}

	var action messageDispatcher

	switch {
	case sendOpts.connectionID != "":
		action, err = c.sendToConnection(didCommMsg, sendOpts.connectionID)
	case sendOpts.theirDID != "":
		action, err = c.sendToTheirDID(didCommMsg, sendOpts.theirDID)
	case sendOpts.destination != nil:
		action, err = c.sendToDestination(didCommMsg, sendOpts.destination)
	default:
		// errors.New, not fmt.Errorf: the message is plain text, not a format
		// string (go vet flags non-constant format strings).
		return nil, errors.New(errMsgDestinationMissing)
	}

	if err != nil {
		return nil, err
	}

	return c.sendAndWaitForReply(sendOpts.waitForResponseCtx, action, didCommMsg.ID(), sendOpts.responseMsgType)
}
// Reply sends msg as a reply to the previously received message msgID.
//
// With startNewThread the reply goes out through ReplyToNested, and, when
// waitForResponse is non-empty, only an incoming reply whose thread ID equals
// msg's ID is accepted. Otherwise it goes out through ReplyTo, and any incoming
// message of type waitForResponse is accepted. An empty waitForResponse makes
// Reply return as soon as the reply has been sent.
func (c *Client) Reply(ctx context.Context, msg json.RawMessage, msgID string, startNewThread bool,
	waitForResponse string) (json.RawMessage, error) {
	didCommMsg, err := prepareMessage(msg)
	if err != nil {
		return nil, err
	}

	if !startNewThread {
		sameThread := func() error {
			return c.ctx.Messenger().ReplyTo(msgID, didCommMsg) // nolint: staticcheck
		}

		return c.sendAndWaitForReply(ctx, sameThread, "", waitForResponse)
	}

	nested := func() error {
		return c.ctx.Messenger().ReplyToNested(didCommMsg, &service.NestedReplyOpts{MsgID: msgID})
	}

	return c.sendAndWaitForReply(ctx, nested, didCommMsg.ID(), waitForResponse)
}
// sendToConnection looks up the connection record connectionID and returns a
// dispatcher that sends msg from that record's MyDID to its TheirDID. The
// lookup happens now; the send happens when the dispatcher is invoked.
func (c *Client) sendToConnection(msg service.DIDCommMsgMap, connectionID string) (messageDispatcher, error) {
	record, err := c.connectionLookup.GetConnectionRecord(connectionID)
	if err != nil {
		return nil, err
	}

	dispatch := func() error {
		return c.ctx.Messenger().Send(msg, record.MyDID, record.TheirDID)
	}

	return dispatch, nil
}
// sendToTheirDID returns a dispatcher for msg addressed to theirDID. The first
// connection record in state "completed" whose TheirDID matches is used if one
// exists; otherwise the destination is resolved from the DID through the VDR
// registry and the message is sent as to a raw destination.
func (c *Client) sendToTheirDID(msg service.DIDCommMsgMap, theirDID string) (messageDispatcher, error) {
	records, err := c.connectionLookup.QueryConnectionRecords()
	if err != nil {
		return nil, err
	}

	for _, rec := range records {
		if rec.State != stateNameCompleted || rec.TheirDID != theirDID {
			continue
		}

		// Pin the matching record so the closure does not depend on the loop variable.
		conn := rec

		return func() error {
			return c.ctx.Messenger().Send(msg, conn.MyDID, conn.TheirDID)
		}, nil
	}

	dest, err := service.GetDestination(theirDID, c.ctx.VDRegistry())
	if err != nil {
		return nil, err
	}

	return c.sendToDestination(msg, dest)
}
// sendToDestination returns a dispatcher that sends msg to dest. A new ED25519
// key is created in the KMS on every call (before anything is sent) and its
// did:key form is used as the sender identity.
func (c *Client) sendToDestination(msg service.DIDCommMsgMap, dest *service.Destination) (messageDispatcher, error) {
	_, pubKeyBytes, err := c.ctx.KMS().CreateAndExportPubKeyBytes(kms.ED25519Type)
	if err != nil {
		return nil, err
	}

	// Only the first result of CreateDIDKey is needed here; the second is
	// presumably the key ID — verify against the fingerprint package.
	senderDIDKey, _ := fingerprint.CreateDIDKey(pubKeyBytes)

	dispatch := func() error {
		return c.ctx.Messenger().SendToDestination(msg, senderDIDKey, dest)
	}

	return dispatch, nil
}
// sendAndWaitForReply runs action and, when replyType is non-empty, waits until
// ctx is done for an incoming message of type replyType whose thread ID equals
// thID (any thread ID is accepted when thID is empty).
//
// The temporary message service that captures the reply is registered under a
// random UUID name before action runs, so an immediate reply is not missed. It
// is unregistered when this function returns. When replyType is empty, ctx is
// not used and an empty, non-nil json.RawMessage is returned on success.
func (c *Client) sendAndWaitForReply(ctx context.Context, action messageDispatcher, thID string,
	replyType string) (json.RawMessage, error) {
	var notificationCh chan NotificationPayload

	if replyType != "" {
		topic := uuid.New().String()
		// NOTE(review): unbuffered channel; if the notifier's send blocks while no
		// one is receiving (e.g. after ctx is done), the sender may stall — confirm
		// against NewNotifier.
		notificationCh = make(chan NotificationPayload)

		err := c.msgRegistrar.Register(newMessageService(topic, replyType, nil,
			NewNotifier(notificationCh, func(_ string, msgBytes []byte) bool {
				// Only the "message" field of the notification is needed to read the thread ID.
				var message struct {
					Message service.DIDCommMsgMap `json:"message"`
				}

				if err := json.Unmarshal(msgBytes, &message); err != nil {
					logger.Debugf("failed to unmarshal incoming message reply: %s", err)
					return false
				}

				msgThID, err := message.Message.ThreadID()
				if err != nil {
					logger.Debugf("failed to read incoming message reply thread ID: %s", err)
					return false
				}

				return thID == "" || thID == msgThID
			})))
		if err != nil {
			return nil, err
		}

		defer func() {
			if e := c.msgRegistrar.Unregister(topic); e != nil {
				// %v, not %w: %w is only understood by fmt.Errorf; a printf-style
				// logger would render it as "%!w(...)".
				logger.Warnf("Failed to unregister wait for reply notifier: %v", e)
			}
		}()
	}

	if err := action(); err != nil {
		return nil, err
	}

	if notificationCh == nil {
		return json.RawMessage{}, nil
	}

	return waitForResponse(ctx, notificationCh)
}
// waitForResponse blocks until a payload arrives on notificationCh, returning
// its raw bytes, or until ctx is done.
//
// On ctx completion the error wraps ctx.Err(), so callers can distinguish a
// timeout (errors.Is(err, context.DeadlineExceeded)) from cancellation
// (errors.Is(err, context.Canceled)). For a deadline the message text is
// unchanged: "failed to get reply, context deadline exceeded".
func waitForResponse(ctx context.Context, notificationCh chan NotificationPayload) (json.RawMessage, error) {
	select {
	case payload := <-notificationCh:
		return json.RawMessage(payload.Raw), nil
	case <-ctx.Done():
		return nil, fmt.Errorf("failed to get reply, %w", ctx.Err())
	}
}
// prepareMessage parses msg into a DIDComm message map. If the parsed message
// has no ID, a random UUID is assigned; any error from setting it is returned
// together with the (possibly partially prepared) message.
func prepareMessage(msg json.RawMessage) (service.DIDCommMsgMap, error) {
	parsed, err := service.ParseDIDCommMsgMap(msg)
	if err != nil {
		return nil, err
	}

	if parsed.ID() != "" {
		return parsed, nil
	}

	return parsed, parsed.SetID(uuid.New().String())
}