diff --git a/src/pubsub.c b/src/pubsub.c index d1fffa20a72..bbcfc1f43b5 100644 --- a/src/pubsub.c +++ b/src/pubsub.c @@ -125,6 +125,8 @@ int pubsubUnsubscribeChannel(client *c, robj *channel, int notify) { /* Subscribe a client to a pattern. Returns 1 if the operation succeeded, or 0 if the client was already subscribed to that pattern. */ int pubsubSubscribePattern(client *c, robj *pattern) { + dictEntry *de; + list *clients; int retval = 0; if (listSearchKey(c->pubsub_patterns,pattern) == NULL) { @@ -136,6 +138,16 @@ int pubsubSubscribePattern(client *c, robj *pattern) { pat->pattern = getDecodedObject(pattern); pat->client = c; listAddNodeTail(server.pubsub_patterns,pat); + /* Add the client to the pattern -> list of clients hash table */ + de = dictFind(server.pubsub_patterns_dict,pattern); + if (de == NULL) { + clients = listCreate(); + dictAdd(server.pubsub_patterns_dict,pattern,clients); + incrRefCount(pattern); + } else { + clients = dictGetVal(de); + } + listAddNodeTail(clients,c); } /* Notify the client */ addReply(c,shared.mbulkhdr[3]); @@ -148,6 +160,8 @@ int pubsubSubscribePattern(client *c, robj *pattern) { /* Unsubscribe a client from a channel. Returns 1 if the operation succeeded, or * 0 if the client was not subscribed to the specified channel. */ int pubsubUnsubscribePattern(client *c, robj *pattern, int notify) { + dictEntry *de; + list *clients; listNode *ln; pubsubPattern pat; int retval = 0; @@ -160,6 +174,18 @@ int pubsubUnsubscribePattern(client *c, robj *pattern, int notify) { pat.pattern = pattern; ln = listSearchKey(server.pubsub_patterns,&pat); listDelNode(server.pubsub_patterns,ln); + /* Remove the client from the pattern -> clients list hash table */ + de = dictFind(server.pubsub_patterns_dict,pattern); + serverAssertWithInfo(c,NULL,de != NULL); + clients = dictGetVal(de); + ln = listSearchKey(clients,c); + serverAssertWithInfo(c,NULL,ln != NULL); + listDelNode(clients,ln); + if (listLength(clients) == 0) { + /* Free the list and associated hash entry at all if this was + * the latest client. */ + dictDelete(server.pubsub_patterns_dict,pattern); + } } /* Notify the client */ if (notify) { @@ -225,6 +251,7 @@ int pubsubUnsubscribeAllPatterns(client *c, int notify) { int pubsubPublishMessage(robj *channel, robj *message) { int receivers = 0; dictEntry *de; + dictIterator *di; listNode *ln; listIter li; @@ -247,25 +274,32 @@ int pubsubPublishMessage(robj *channel, robj *message) { } } /* Send to clients listening to matching channels */ - if (listLength(server.pubsub_patterns)) { - listRewind(server.pubsub_patterns,&li); + di = dictGetIterator(server.pubsub_patterns_dict); + if (di) { channel = getDecodedObject(channel); - while ((ln = listNext(&li)) != NULL) { - pubsubPattern *pat = ln->value; - - if (stringmatchlen((char*)pat->pattern->ptr, - sdslen(pat->pattern->ptr), + while((de = dictNext(di)) != NULL) { + robj *pattern = dictGetKey(de); + list *clients = dictGetVal(de); + if (!stringmatchlen((char*)pattern->ptr, + sdslen(pattern->ptr), (char*)channel->ptr, sdslen(channel->ptr),0)) { - addReply(pat->client,shared.mbulkhdr[4]); - addReply(pat->client,shared.pmessagebulk); - addReplyBulk(pat->client,pat->pattern); - addReplyBulk(pat->client,channel); - addReplyBulk(pat->client,message); + continue; + } + listRewind(clients,&li); + while ((ln = listNext(&li)) != NULL) { + client *c = listNodeValue(ln); + + addReply(c,shared.mbulkhdr[4]); + addReply(c,shared.pmessagebulk); + addReplyBulk(c,pattern); + addReplyBulk(c,channel); + addReplyBulk(c,message); receivers++; } } decrRefCount(channel); + dictReleaseIterator(di); } return receivers; } diff --git a/src/server.c b/src/server.c index 1a6f3038111..a071fceedc1 100644 --- a/src/server.c +++ b/src/server.c @@ -1900,6 +1900,7 @@ void initServer(void) { evictionPoolAlloc(); /* Initialize the LRU keys pool. */ server.pubsub_channels = dictCreate(&keylistDictType,NULL); server.pubsub_patterns = listCreate(); + server.pubsub_patterns_dict = dictCreate(&keylistDictType,NULL); listSetFreeMethod(server.pubsub_patterns,freePubsubPattern); listSetMatchMethod(server.pubsub_patterns,listMatchPubsubPattern); server.cronloops = 0; diff --git a/src/server.h b/src/server.h index 29919f5eedf..a56663495a3 100644 --- a/src/server.h +++ b/src/server.h @@ -1163,6 +1163,7 @@ struct redisServer { /* Pubsub */ dict *pubsub_channels; /* Map channels to list of subscribed clients */ list *pubsub_patterns; /* A list of pubsub_patterns */ + dict *pubsub_patterns_dict; /* A dict of pubsub_patterns */ int notify_keyspace_events; /* Events to propagate via Pub/Sub. This is an xor of NOTIFY_... flags. */ /* Cluster */