-
-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
RemoteSocket.scala
264 lines (235 loc) · 10.4 KB
/
RemoteSocket.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
package lila.socket
import akka.actor.{ CoordinatedShutdown, Scheduler }
import chess.{ Centis, Color }
import io.lettuce.core.*
import io.lettuce.core.pubsub.StatefulRedisPubSubConnection as PubSub
import play.api.libs.json.*
import java.util.concurrent.atomic.AtomicReference
import lila.common.{ Bus, Lilakka }
import lila.core.misc.streamer.StreamersOnline
import lila.core.relation.{ Follow, UnFollow }
import lila.core.round.Mlat
import lila.core.security.CloseAccount
import lila.core.socket.remote.*
import lila.core.socket.{ SocketRequester as _, * }
/** Bridge between lila and lila-ws over Redis pub/sub.
  *
  * Incoming messages are read from Redis channels (see `subscribe`), outgoing messages are
  * published through [[RemoteSocket.Sender]] instances, and the set of connected users is
  * tracked in `onlineUserIds` from the presence and lag messages received.
  */
final class RemoteSocket(
    redisClient: RedisClient,
    shutdown: CoordinatedShutdown,
    requester: SocketRequester,
    userLag: UserLagCache
)(using Executor, Scheduler):

  import RemoteSocket.*, Protocol.*

  // Set to true by the PhaseServiceUnbind hook below (inside a Future, i.e. on a pool thread)
  // and read by every sender on whichever thread publishes. Without @volatile the JVM memory
  // model does not guarantee those threads ever observe the write.
  @volatile private var stopping = false

  private type UserIds = Set[UserId]

  /** Users currently believed to be connected; seeded with `initialUserIds`. */
  val onlineUserIds: AtomicReference[Set[UserId]] = AtomicReference(initialUserIds)

  // Created before any Bus registration below: those handlers call `send`, and a Bus event
  // delivered while this constructor is still running would otherwise see a null `send`.
  private val send: SocketSend = makeSender("site-out").send

  /** Kit for channels that use a single Redis channel per direction. */
  def kit = SocketKit(subscribe, channel => makeSender(channel, 1).send, baseHandler)

  /** Kit for channels spread over several `channel:index` subchannels. */
  def parallelKit = ParallelSocketKit(subscribeRoundRobin, makeSender, baseHandler)

  /** Handles the `Protocol.In` messages common to every channel: presence updates, lag
    * reports, tells re-published on the Bus, request responses, pings and boot notices.
    */
  val baseHandler: SocketHandler =
    case In.ConnectUser(userId) =>
      onlineUserIds.getAndUpdate(_ + userId)
    case In.ConnectUsers(userIds) =>
      onlineUserIds.getAndUpdate(_ ++ userIds)
    case In.DisconnectUsers(userIds) =>
      onlineUserIds.getAndUpdate(_ -- userIds)
    case In.NotifiedBatch(userIds) =>
      Bus.publish(lila.core.notify.NotifiedBatch(userIds), "notify")
    case In.Lags(lags) =>
      lags.foreach: (userId, centis) =>
        userLag.put(userId, centis)
      // this shouldn't be necessary... ensure that users are known to be online
      onlineUserIds.getAndUpdate((x: UserIds) => x ++ lags.keys)
    case In.TellSri(sri, userId, typ, msg) =>
      Bus.publish(TellSriIn(sri.value, userId, msg), s"remoteSocketIn:$typ")
    case In.TellUser(userId, typ, msg) =>
      Bus.publish(TellUserIn(userId, msg), s"remoteSocketIn:$typ")
    case In.ReqResponse(reqId, response) => requester.onResponse(reqId, response)
    case In.Ping(id)                     => send(Out.pong(id))
    case In.WsBoot =>
      // the remote socket (re)booted: reset presence to the initial set
      logger.warn("Remote socket boot")
      onlineUserIds.set(initialUserIds)

  // Bus events from other modules, relayed on the "site-out" channel.
  // Per-user tells are dropped unless the user is in `onlineUserIds`.
  Bus.subscribeFun(
    "socketUsers",
    "announce",
    "mlat",
    "sendToFlag",
    "remoteSocketOut",
    "accountClose",
    "shadowban",
    "impersonate",
    "relation",
    "onlineApiUsers"
  ) {
    case SendTos(userIds, payload) =>
      val connectedUsers = userIds.intersect(onlineUserIds.get)
      if connectedUsers.nonEmpty then send(Out.tellUsers(connectedUsers, payload))
    case SendTo(userId, payload) =>
      if onlineUserIds.get.contains(userId) then send(Out.tellUser(userId, payload))
    case SendToOnlineUser(userId, makePayload) =>
      if onlineUserIds.get.contains(userId) then
        makePayload.value.foreach: payload =>
          send(Out.tellUser(userId, payload))
    case Announce(_, _, json) =>
      send(Out.tellAll(Json.obj("t" -> "announce", "d" -> json)))
    case Mlat(millis) =>
      send(Out.mlat(millis))
    case SendToFlag(flag, payload) =>
      send(Out.tellFlag(flag, payload))
    case TellSriOut(sri, payload) =>
      send(Out.tellSri(Sri(sri), payload))
    case TellSrisOut(sris, payload) =>
      send(Out.tellSris(Sri.from(sris), payload))
    case CloseAccount(userId) =>
      send(Out.disconnectUser(userId))
    case lila.core.mod.Shadowban(userId, v) =>
      send(Out.setTroll(userId, v))
    case lila.core.mod.Impersonate(userId, modId) =>
      send(Out.impersonate(userId, modId))
    case ApiUserIsOnline(userId, value) =>
      send(Out.apiUserOnline(userId, value))
      if value then onlineUserIds.getAndUpdate(_ + userId)
    case Follow(u1, u2)   => send(Out.follow(u1, u2))
    case UnFollow(u1, u2) => send(Out.unfollow(u1, u2))
  }

  Bus.sub[StreamersOnline]:
    case StreamersOnline(streamers) =>
      send(Out.streamersOnline(streamers))

  /** Publishes every message on `channel` itself; becomes a no-op once `stopping` is set. */
  final class StoppableSender(val conn: PubSub[String, String], channel: Channel) extends Sender:
    def apply(msg: String)               = if !stopping then super.sendTo(channel, msg)
    def sticky(_id: String, msg: String) = apply(msg)

  /** Spreads messages over the subchannels `channel:0` .. `channel:(parallelism - 1)`. */
  final class RoundRobinSender(val conn: PubSub[String, String], channel: Channel, parallelism: Int)
      extends Sender:
    def apply(msg: String): Unit = publish(lane(msg.hashCode), msg)
    // use the ID to select the channel, not the entire message
    def sticky(id: String, msg: String): Unit = publish(lane(id.hashCode), msg)

    // `hash.abs % parallelism` is negative when hash == Int.MinValue (its abs overflows back to
    // Int.MinValue), which would target a nonexistent negative subchannel. Taking abs of the
    // remainder keeps every lane in [0, parallelism) and gives the same lane for any other hash.
    private def lane(hash: Int): Int = (hash % parallelism).abs

    private def publish(subChannel: Int, msg: String) =
      if !stopping then conn.async.publish(s"$channel:$subChannel", msg)

  /** Builds a sender on a fresh pub/sub connection; round-robin when `parallelism > 1`. */
  def makeSender(channel: Channel, parallelism: Int = 1): Sender =
    if parallelism > 1 then RoundRobinSender(redisClient.connectPubSub(), channel, parallelism)
    else StoppableSender(redisClient.connectPubSub(), channel)

  /** Subscribes to `channel` and dispatches each message to `handler`.
    *
    * A message is split on its first space into a path and its arguments, parsed with
    * `reader` falling back to `Protocol.In.baseReader`, then passed to `handler`. Empty,
    * unreadable and unhandled messages are only logged.
    */
  def subscribe(channel: Channel, reader: In.Reader)(handler: SocketHandler): Funit =
    val fullReader = reader.orElse(Protocol.In.baseReader)
    connectAndSubscribe(channel): str =>
      val parts = str.split(" ", 2)
      parts.headOption
        .map:
          new lila.core.socket.protocol.RawMsg(_, ~parts.lift(1))
        .match
          case None => logger.error(s"Invalid $channel $str")
          case Some(raw) =>
            fullReader
              .applyOrElse(
                raw,
                raw =>
                  logger.info(s"Unread $channel $raw")
                  none
              )
              .collect(handler) match
              case Some(_) => // processed
              case None    => logger.info(s"Unhandled $channel $str")

  /** Subscribes to `channel` first, then to its indexed subchannels in parallel. */
  def subscribeRoundRobin(channel: Channel, reader: In.Reader, parallelism: Int)(
      handler: SocketHandler
  ): Funit =
    // subscribe to main channel
    subscribe(channel, reader)(handler) >> {
      // and subscribe to subchannels
      // NOTE(review): `0 to parallelism` is inclusive, i.e. parallelism + 1 subchannels, one more
      // than RoundRobinSender publishes to on the outgoing side. Left unchanged because the lane
      // count of the remote publisher is not visible here; confirm before switching to `until`.
      (0 to parallelism)
        .parallelVoid(index => subscribe(s"$channel:$index", reader)(handler))
    }

  /** Opens a dedicated pub/sub connection and routes every message on `channel` to `f`.
    *
    * The returned future completes when the async subscribe command completes. If that command
    * fails, the future now fails too, instead of staying pending forever as it did when only
    * the success path completed the promise.
    */
  private def connectAndSubscribe(channel: Channel)(f: String => Unit): Funit =
    val conn = redisClient.connectPubSub()
    conn.addListener(
      new pubsub.RedisPubSubAdapter[String, String]:
        override def message(_channel: String, message: String): Unit = f(message)
    )
    val subPromise = Promise[Unit]()
    conn.async
      .subscribe(channel)
      .whenComplete: (_, err) =>
        if err == null then subPromise.success(())
        else subPromise.failure(err)
    subPromise.future

  // Before unbinding, send a stop request and log the reply; give up after one second.
  Lilakka.shutdown(shutdown, _.PhaseBeforeServiceUnbind, "Telling lila-ws we're stopping"): () =>
    requester[Unit](
      id => send(Protocol.Out.stop(id)),
      res => logger.info(s"lila-ws says: $res")
    ).withTimeout(1 second, "Lilakka.shutdown")
      .addFailureEffect(e => logger.error("lila-ws stop", e))
      .recoverDefault

  // Silence all senders, then shut the Redis client down.
  Lilakka.shutdown(shutdown, _.PhaseServiceUnbind, "Stopping the socket redis pool"): () =>
    Future:
      stopping = true
      redisClient.shutdown()
/** Companion of [[RemoteSocket]]: the Redis sender base trait and the text protocol spoken
  * over the socket channels.
  */
object RemoteSocket:

  /** A socket sender backed by a Redis pub/sub connection. */
  trait Sender extends ParallelSocketSend:
    protected val conn: PubSub[String, String]
    // async publish; the Redis future is returned to the caller, not awaited here
    protected def sendTo(channel: Channel, msg: String) =
      conn.async.publish(channel, msg)

  /** Message readers and writers; each message is a path followed by its arguments. */
  object Protocol:

    trait In

    object In:

      export lila.core.socket.protocol.In.*
      import lila.core.socket.protocol.RawMsg

      /** Parses the paths shared by every channel; undefined for any other path. */
      val baseReader: Reader =
        case RawMsg("connect/user", raw) =>
          ConnectUser(UserId(raw.args)).some
        case RawMsg("connect/users", raw) =>
          ConnectUsers(UserId.from(commas(raw.args))).some
        case RawMsg("disconnect/users", raw) =>
          DisconnectUsers(UserId.from(commas(raw.args))).some
        case RawMsg("connect/sris", raw) =>
          // each comma-separated entry is a sri, optionally followed by a space and a user id
          ConnectSris(
            commas(raw.args).map { entry =>
              val tokens = entry.split(' ')
              (Sri(tokens(0)), UserId.from(tokens.lift(1)))
            }
          ).some
        case RawMsg("disconnect/sris", raw) =>
          DisconnectSris(commas(raw.args).map(Sri(_))).some
        case RawMsg("notified/batch", raw) =>
          NotifiedBatch(UserId.from(commas(raw.args))).some
        case RawMsg("lag", raw) =>
          // user id then centis; no message when the second field is missing or not an Int
          val fields = raw.all
          Centis
            .from(fields.lift(1).flatMap(_.toIntOption))
            .map(centis => Lag(UserId(fields(0)), centis))
        case RawMsg("lags", raw) =>
          // comma-separated "user:centis" pairs; malformed pairs are skipped
          val pairs = commas(raw.args).flatMap { pair =>
            pair.split(':') match
              case Array(user, centis) => centis.toIntOption.map(lag => UserId(user) -> Centis(lag))
              case _                   => None
          }
          Lags(pairs.toMap).some
        case RawMsg("tell/sri", raw) =>
          raw.get(3)(lila.core.socket.protocol.In.tellSriMapper)
        case RawMsg("tell/user", raw) =>
          raw.get(2) { case Array(user, payload) =>
            Json
              .parse(payload)
              .asOpt[JsObject]
              .flatMap(obj => obj.str("t").map(typ => TellUser(UserId(user), typ, obj)))
          }
        case RawMsg("req/response", raw) =>
          raw.get(2) { case Array(reqId, response) =>
            reqId.toIntOption.map(id => ReqResponse(id, response))
          }
        case RawMsg("ping", raw) => Ping(raw.args).some
        case RawMsg("boot", _)   => WsBoot.some

    /** Writers producing the outgoing message strings. */
    object Out:
      export lila.core.socket.protocol.Out.*

      def tellUser(userId: UserId, payload: JsObject): String =
        s"tell/users $userId ${Json.stringify(payload)}"

      def tellUsers(userIds: Set[UserId], payload: JsObject): String =
        s"tell/users ${commas(userIds)} ${Json.stringify(payload)}"

      def tellAll(payload: JsObject): String =
        s"tell/all ${Json.stringify(payload)}"

      def tellFlag(flag: String, payload: JsObject): String =
        s"tell/flag $flag ${Json.stringify(payload)}"

      def mlat(millis: Int): String = s"mlat $millis"

      def disconnectUser(userId: UserId): String = s"disconnect/user $userId"

      def setTroll(userId: UserId, v: Boolean): String =
        s"mod/troll/set $userId ${boolean(v)}"

      def impersonate(userId: UserId, by: Option[UserId]): String =
        s"mod/impersonate $userId ${optional(by.map(_.value))}"

      def follow(u1: UserId, u2: UserId): String = s"rel/follow $u1 $u2"

      def unfollow(u1: UserId, u2: UserId): String = s"rel/unfollow $u1 $u2"

      def apiUserOnline(u: UserId, v: Boolean): String = s"api/online $u ${boolean(v)}"

      def streamersOnline(streamers: Iterable[(UserId, String)]): String =
        val encoded = streamers.map((user, name) => s"$user:$name")
        s"streamers/online ${commas(encoded)}"

      def respond(reqId: Int, payload: JsObject): String =
        s"req/response $reqId ${Json.stringify(payload)}"

      def stop(reqId: Int): String = s"lila/stop $reqId"

  /** Users considered online right after startup and after a remote socket boot. */
  val initialUserIds: Set[UserId] = Set(UserId("lichess"))