/
encryptmegolm.go
279 lines (257 loc) · 11.1 KB
/
encryptmegolm.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
// Copyright (c) 2020 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package crypto
import (
"encoding/json"
"errors"
"fmt"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
// Sentinel errors returned by the megolm encryption and session sharing methods in this file.
var (
// AlreadyShared is returned by ShareGroupSession when the stored outbound group
// session for the room is already marked as shared and has not expired.
AlreadyShared = errors.New("group session already shared")
// NoGroupSession is returned by EncryptMegolmEvent when the crypto store has no
// outbound group session for the room.
NoGroupSession = errors.New("no group session created")
)
// getRelatesTo returns the relation metadata of the given event content, or nil
// if the content does not implement event.Relatable.
//
// When content is an *event.Content wrapper, its Parsed field is inspected
// instead of the wrapper itself.
func getRelatesTo(content interface{}) *event.RelatesTo {
	// Unwrap the generic container so the assertion below sees the parsed payload.
	if wrapper, isWrapper := content.(*event.Content); isWrapper {
		content = wrapper.Parsed
	}
	relatable, isRelatable := content.(event.Relatable)
	if !isRelatable {
		return nil
	}
	return relatable.OptionalGetRelatesTo()
}
// rawMegolmEvent is the plaintext payload that EncryptMegolmEvent serializes to
// JSON and hands to the outbound group session for encryption.
type rawMegolmEvent struct {
// RoomID is the room the event is being encrypted for.
RoomID id.RoomID `json:"room_id"`
// Type is the type of the event being encrypted.
Type event.Type `json:"type"`
// Content is the event content exactly as passed in by the caller.
Content interface{} `json:"content"`
}
// IsShareError returns true if the error is caused by the lack of an outgoing megolm session and can be solved with OlmMachine.ShareGroupSession.
//
// The check uses errors.Is, so the sentinel errors (SessionExpired,
// SessionNotShared and NoGroupSession) are recognized both when returned
// directly and when wrapped by a caller (e.g. with fmt.Errorf and %w).
// A nil error yields false.
func IsShareError(err error) bool {
	return errors.Is(err, SessionExpired) ||
		errors.Is(err, SessionNotShared) ||
		errors.Is(err, NoGroupSession)
}
// EncryptMegolmEvent encrypts data with the m.megolm.v1.aes-sha2 algorithm.
//
// If you use the event.Content struct, make sure you pass a pointer to the struct,
// as JSON serialization will not work correctly otherwise.
//
// The outbound group session for roomID is loaded from the crypto store. If the
// store has no session for the room, NoGroupSession is returned. Errors from JSON
// serialization and from the session's Encrypt method are returned unchanged.
// A failure to write the session back to the store after encrypting is only
// logged as a warning; the encrypted content is still returned.
func (mach *OlmMachine) EncryptMegolmEvent(roomID id.RoomID, evtType event.Type, content interface{}) (*event.EncryptedEventContent, error) {
	mach.Log.Trace("Encrypting event of type %s for %s", evtType.Type, roomID)

	outbound, err := mach.CryptoStore.GetOutboundGroupSession(roomID)
	if err != nil {
		return nil, fmt.Errorf("failed to get outbound group session: %w", err)
	}
	if outbound == nil {
		return nil, NoGroupSession
	}

	payload := rawMegolmEvent{
		RoomID:  roomID,
		Type:    evtType,
		Content: content,
	}
	plaintext, err := json.Marshal(&payload)
	if err != nil {
		return nil, err
	}

	ciphertext, err := outbound.Encrypt(plaintext)
	if err != nil {
		return nil, err
	}

	// Best effort: the event has already been encrypted, so a store failure must not lose it.
	if updateErr := mach.CryptoStore.UpdateOutboundGroupSession(outbound); updateErr != nil {
		mach.Log.Warn("Failed to update megolm session in crypto store after encrypting: %v", updateErr)
	}

	encrypted := &event.EncryptedEventContent{
		Algorithm:        id.AlgorithmMegolmV1,
		SenderKey:        mach.account.IdentityKey(),
		DeviceID:         mach.Client.DeviceID,
		SessionID:        outbound.ID(),
		MegolmCiphertext: ciphertext,
		RelatesTo:        getRelatesTo(content),
	}
	return encrypted, nil
}
// newOutboundGroupSession creates a fresh outbound group session for roomID,
// configured from the room's encryption event in the state store, and passes
// the new session's ID and key to createGroupSession together with this
// account's identity and signing keys.
// NOTE(review): createGroupSession presumably registers the matching inbound
// session locally — confirm against its definition.
func (mach *OlmMachine) newOutboundGroupSession(roomID id.RoomID) *OutboundGroupSession {
	encryptionEvent := mach.StateStore.GetEncryptionEvent(roomID)
	outbound := NewOutboundGroupSession(roomID, encryptionEvent)

	signingKey, identityKey := mach.account.Keys()
	mach.createGroupSession(identityKey, signingKey, roomID, outbound.ID(), outbound.Internal.Key())
	return outbound
}
// deviceSessionWrapper pairs an olm session with the identity of the device it
// was looked up for. findOlmSessionsForUser produces these values and
// encryptAndSendGroupSession consumes them.
type deviceSessionWrapper struct {
// session is the olm session returned by the crypto store for the device's identity key.
session *OlmSession
// identity is the device the group session will be encrypted for.
identity *DeviceIdentity
}
// ShareGroupSession shares a group session for a specific room with all the devices of the given user list.
//
// For devices with TrustStateBlacklisted, a m.room_key.withheld event with code=m.blacklisted is sent.
// If AllowUnverifiedDevices is false, a similar event with code=m.unverified is sent to devices with TrustStateUnset
//
// The method works in these steps:
//  1. Load the room's outbound group session from the crypto store. If it is
//     already shared and not expired, AlreadyShared is returned. If there is no
//     session or it has expired, a new one is created.
//  2. For every user, look up the known devices and sort them into devices with
//     an olm session, devices the key is withheld from, and devices that still
//     need an olm session. Users for which the store returns nil are queued for
//     a key fetch.
//  3. Fetch keys for the queued users, create the missing olm sessions and retry
//     the lookup for those devices. Failures here are logged, not returned.
//  4. Encrypt and send the group session to all devices with an olm session,
//     then report withheld keys (failures to report are only logged).
//  5. Mark the session as shared and store it.
//
// An error is returned if the stored session can't be loaded, if sending the
// encrypted group session fails, or if storing the shared session fails.
func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) error {
	mach.Log.Debug("Sharing group session for room %s to %v", roomID, users)
	session, err := mach.CryptoStore.GetOutboundGroupSession(roomID)
	if err != nil {
		return fmt.Errorf("failed to get previous outbound group session: %w", err)
	} else if session != nil && session.Shared && !session.Expired() {
		return AlreadyShared
	}
	if session == nil || session.Expired() {
		session = mach.newOutboundGroupSession(roomID)
	}

	withheldCount := 0
	toDeviceWithheld := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)}
	olmSessions := make(map[id.UserID]map[id.DeviceID]deviceSessionWrapper)
	missingSessions := make(map[id.UserID]map[id.DeviceID]*DeviceIdentity)
	// Scratch map reused across users; it is only replaced once it has been
	// handed over to missingSessions.
	missingUserSessions := make(map[id.DeviceID]*DeviceIdentity)
	var fetchKeys []id.UserID

	for _, userID := range users {
		devices, err := mach.CryptoStore.GetDevices(userID)
		if err != nil {
			mach.Log.Error("Failed to get devices of %s: %v", userID, err)
		} else if devices == nil {
			mach.Log.Trace("GetDevices returned nil for %s, will fetch keys and retry", userID)
			fetchKeys = append(fetchKeys, userID)
		} else if len(devices) == 0 {
			mach.Log.Trace("%s has no devices, skipping", userID)
		} else {
			mach.Log.Trace("Trying to find olm sessions to encrypt %s for %s", session.ID(), userID)
			toDeviceWithheld.Messages[userID] = make(map[id.DeviceID]*event.Content)
			olmSessions[userID] = make(map[id.DeviceID]deviceSessionWrapper)
			mach.findOlmSessionsForUser(session, userID, devices, olmSessions[userID], toDeviceWithheld.Messages[userID], missingUserSessions)
			mach.Log.Trace("Found %d sessions, withholding from %d sessions and missing %d sessions to encrypt %s for %s", len(olmSessions[userID]), len(toDeviceWithheld.Messages[userID]), len(missingUserSessions), session.ID(), userID)
			withheldCount += len(toDeviceWithheld.Messages[userID])
			if len(missingUserSessions) > 0 {
				missingSessions[userID] = missingUserSessions
				missingUserSessions = make(map[id.DeviceID]*DeviceIdentity)
			}
			// Don't send an empty per-user map in the withheld report.
			if len(toDeviceWithheld.Messages[userID]) == 0 {
				delete(toDeviceWithheld.Messages, userID)
			}
		}
	}

	if len(fetchKeys) > 0 {
		mach.Log.Trace("Fetching missing keys for %v", fetchKeys)
		for userID, devices := range mach.fetchKeys(fetchKeys, "", true) {
			mach.Log.Trace("Got %d device keys for %s", len(devices), userID)
			missingSessions[userID] = devices
		}
	}

	if len(missingSessions) > 0 {
		mach.Log.Trace("Creating missing outbound sessions")
		err = mach.createOutboundSessions(missingSessions)
		if err != nil {
			// Not fatal: devices that still lack a session are skipped in the retry below.
			mach.Log.Error("Failed to create missing outbound sessions: %v", err)
		}
	}

	for userID, devices := range missingSessions {
		if len(devices) == 0 {
			// No missing sessions
			continue
		}
		output, ok := olmSessions[userID]
		if !ok {
			output = make(map[id.DeviceID]deviceSessionWrapper)
			olmSessions[userID] = output
		}
		withheld, ok := toDeviceWithheld.Messages[userID]
		if !ok {
			withheld = make(map[id.DeviceID]*event.Content)
			toDeviceWithheld.Messages[userID] = withheld
		}
		// Entries already present were counted during the first pass, so only
		// the entries added by this retry may be added to withheldCount.
		withheldBefore := len(withheld)

		mach.Log.Trace("Trying to find olm sessions to encrypt %s for %s (post-fetch retry)", session.ID(), userID)
		mach.findOlmSessionsForUser(session, userID, devices, output, withheld, nil)
		mach.Log.Trace("Found %d sessions and withholding from %d sessions to encrypt %s for %s (post-fetch retry)", len(output), len(withheld), session.ID(), userID)
		withheldCount += len(withheld) - withheldBefore
		if len(withheld) == 0 {
			delete(toDeviceWithheld.Messages, userID)
		}
	}

	err = mach.encryptAndSendGroupSession(session, olmSessions)
	if err != nil {
		return fmt.Errorf("failed to share group session: %w", err)
	}

	if len(toDeviceWithheld.Messages) > 0 {
		mach.Log.Trace("Sending to-device messages to %d devices of %d users to report withheld keys in %s", withheldCount, len(toDeviceWithheld.Messages), roomID)
		// TODO remove the next 4 lines once clients support m.room_key.withheld
		_, err = mach.Client.SendToDevice(event.ToDeviceOrgMatrixRoomKeyWithheld, toDeviceWithheld)
		if err != nil {
			mach.Log.Warn("Failed to report withheld keys in %s (legacy event type): %v", roomID, err)
		}
		_, err = mach.Client.SendToDevice(event.ToDeviceRoomKeyWithheld, toDeviceWithheld)
		if err != nil {
			mach.Log.Warn("Failed to report withheld keys in %s: %v", roomID, err)
		}
	}

	mach.Log.Debug("Group session %s for %s successfully shared", session.ID(), roomID)
	session.Shared = true
	return mach.CryptoStore.AddOutboundGroupSession(session)
}
// encryptAndSendGroupSession encrypts the share content of the given group
// session for every device in olmSessions and sends the results in a single
// to-device request of type event.ToDeviceEncrypted.
//
// Users with no sessions are left out of the request. mach.olmLock is held for
// the whole call, including the SendToDevice request. The error from
// SendToDevice is returned as-is.
func (mach *OlmMachine) encryptAndSendGroupSession(session *OutboundGroupSession, olmSessions map[id.UserID]map[id.DeviceID]deviceSessionWrapper) error {
	mach.olmLock.Lock()
	defer mach.olmLock.Unlock()

	mach.Log.Trace("Encrypting group session %s for all found devices", session.ID())

	request := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)}
	encryptedCount := 0
	for userID, userSessions := range olmSessions {
		if len(userSessions) == 0 {
			continue
		}
		perDevice := make(map[id.DeviceID]*event.Content)
		for deviceID, target := range userSessions {
			mach.Log.Trace("Encrypting group session %s for %s of %s", session.ID(), deviceID, userID)
			encrypted := mach.encryptOlmEvent(target.session, target.identity, event.ToDeviceRoomKey, session.ShareContent())
			perDevice[deviceID] = &event.Content{Parsed: encrypted}
			encryptedCount++
			mach.Log.Trace("Encrypted group session %s for %s of %s", session.ID(), deviceID, userID)
		}
		request.Messages[userID] = perDevice
	}

	mach.Log.Trace("Sending to-device to %d devices of %d users to share group session %s", encryptedCount, len(request.Messages), session.ID())
	_, err := mach.Client.SendToDevice(event.ToDeviceEncrypted, request)
	return err
}
// findOlmSessionsForUser sorts the given devices of userID into the provided
// output maps and records the outcome in session.Users.
//
// For each device whose state in session.Users is still OGSNotShared:
//   - the machine's own device is marked OGSIgnored;
//   - a blacklisted device gets a withheld entry (code RoomKeyWithheldBlacklisted)
//     and is marked OGSIgnored;
//   - if AllowUnverifiedDevices is false and the device is not trusted, it gets a
//     withheld entry (code RoomKeyWithheldUnverified) and is marked OGSIgnored;
//   - if looking up the latest olm session fails, the error is logged and the
//     device is left untouched;
//   - if there is no olm session, the device is added to missingOutput (when that
//     map is non-nil) and left in its current state;
//   - otherwise the session is added to output and the device is marked
//     OGSAlreadyShared.
//
// Devices in any other state are skipped. output and withheld must be non-nil;
// missingOutput may be nil.
func (mach *OlmMachine) findOlmSessionsForUser(session *OutboundGroupSession, userID id.UserID, devices map[id.DeviceID]*DeviceIdentity, output map[id.DeviceID]deviceSessionWrapper, withheld map[id.DeviceID]*event.Content, missingOutput map[id.DeviceID]*DeviceIdentity) {
	for deviceID, identity := range devices {
		key := UserDevice{UserID: userID, DeviceID: deviceID}

		// Only devices that haven't been handled for this session yet are considered.
		if session.Users[key] != OGSNotShared {
			continue
		}

		// Never send the key to the device that created the session.
		if userID == mach.Client.UserID && deviceID == mach.Client.DeviceID {
			session.Users[key] = OGSIgnored
			continue
		}

		if identity.Trust == TrustStateBlacklisted {
			mach.Log.Debug("Not encrypting group session %s for %s of %s: device is blacklisted", session.ID(), deviceID, userID)
			withheld[deviceID] = &event.Content{Parsed: &event.RoomKeyWithheldEventContent{
				RoomID:    session.RoomID,
				Algorithm: id.AlgorithmMegolmV1,
				SessionID: session.ID(),
				SenderKey: mach.account.IdentityKey(),
				Code:      event.RoomKeyWithheldBlacklisted,
				Reason:    "Device is blacklisted",
			}}
			session.Users[key] = OGSIgnored
			continue
		}

		if !mach.AllowUnverifiedDevices && !mach.IsDeviceTrusted(identity) {
			mach.Log.Debug("Not encrypting group session %s for %s of %s: device is not verified", session.ID(), deviceID, userID)
			withheld[deviceID] = &event.Content{Parsed: &event.RoomKeyWithheldEventContent{
				RoomID:    session.RoomID,
				Algorithm: id.AlgorithmMegolmV1,
				SessionID: session.ID(),
				SenderKey: mach.account.IdentityKey(),
				Code:      event.RoomKeyWithheldUnverified,
				Reason:    "This device does not encrypt messages for unverified devices",
			}}
			session.Users[key] = OGSIgnored
			continue
		}

		olmSession, err := mach.CryptoStore.GetLatestSession(identity.IdentityKey)
		if err != nil {
			mach.Log.Error("Failed to get session for %s of %s: %v", deviceID, userID, err)
			continue
		}
		if olmSession == nil {
			mach.Log.Warn("Didn't find a session for %s of %s", deviceID, userID)
			if missingOutput != nil {
				missingOutput[deviceID] = identity
			}
			continue
		}

		output[deviceID] = deviceSessionWrapper{
			session:  olmSession,
			identity: identity,
		}
		session.Users[key] = OGSAlreadyShared
	}
}