# Filtering JetStream Events

In this notebook, we'll filter events from the Redis Stream that we created in the previous notebook. We'll use a combination of techniques to filter the events:

1. Deduplication using Redis Bloom Filter to avoid processing the same event multiple times
2. Content-based filtering using a machine learning model to identify software-related posts
3. Storing filtered events in Redis for further processing

Redis Bloom Filter is a probabilistic data structure that allows us to check whether an element may be in a set. It's very memory efficient, and both insertion and lookup take constant time regardless of how many elements are stored. The trade-off is that lookups can return false positives (but never false negatives). The false-positive probability is controlled by sizing the filter for the expected number of elements and the target error rate.
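
To make the trade-off concrete, here is a minimal in-memory sketch of a Bloom filter. It is a toy for illustration only, not how Redis implements its filters; the class name `ToyBloomFilter` and its hashing scheme are assumptions made for this example.

```kotlin
import java.util.BitSet

// Toy Bloom filter for illustration only; Redis uses its own implementation.
class ToyBloomFilter(private val bits: Int, private val hashes: Int) {
    private val bitSet = BitSet(bits)

    // Derive the i-th bit position from two base hashes (double hashing).
    private fun position(item: String, i: Int): Int {
        val h1 = item.hashCode()
        val h2 = item.reversed().hashCode() or 1 // force odd so the step is never zero
        return Math.floorMod(h1 + i * h2, bits)
    }

    fun add(item: String) {
        for (i in 0 until hashes) bitSet.set(position(item, i))
    }

    // false means "definitely not added"; true means "probably added".
    fun mightContain(item: String): Boolean =
        (0 until hashes).all { bitSet.get(position(item, it)) }
}
```

An item that was added always reports `true`, because its bits stay set. An item that was never added can also report `true` if other items happened to set the same bits, which is the false-positive case.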

Machine learning models can be used to classify text into different categories. In this notebook, we'll use a pre-trained zero-shot classification model to classify posts as software-related or not.

## Consuming from Redis Streams

### Model Redis Streams Event
In this section, we'll define a data class to represent the events stored in the Redis Stream. This model will be used to deserialize the events from the stream.

In [1]:
@file:DependsOn("redis.clients:jedis:6.0.0")

In [2]:
import redis.clients.jedis.resps.StreamEntry

data class Event(
    val did: String,
    val rkey: String,
    val text: String,
    val timeUs: String,
    val operation: String,
    val uri: String,
    val parentUri: String,
    val rootUri: String,
    val langs: List<String>,
) {
    companion object {
        fun fromMap(entry: StreamEntry): Event {
            val fields = entry.fields
            return Event(
                did = fields["did"] ?: "",
                rkey = fields["rkey"] ?: "",
                text = fields["text"] ?: "",
                timeUs = fields["timeUs"] ?: "",
                operation = fields["operation"] ?: "",
                uri = fields["uri"] ?: "",
                parentUri = fields["parentUri"] ?: "",
                rootUri = fields["rootUri"] ?: "",
                langs = fields["langs"]?.removeSurrounding("[", "]")?.split(", ")?.filter { it.isNotBlank() } ?: emptyList()
            )
        }
    }
}
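
The `langs` parsing above implies that the field arrives as a bracketed string such as `[en, pt]` (an assumption inferred from the code, not a documented format). A small standalone sketch of that parsing, with a hypothetical helper name, shows how the edge cases behave:

```kotlin
// Hypothetical helper: parses strings such as "[en, pt]" into a list.
// Missing values and "[]" both produce an empty list.
fun parseLangs(raw: String?): List<String> =
    raw.orEmpty()
        .removeSurrounding("[", "]")
        .split(",")
        .map { it.trim() }
        .filter { it.isNotEmpty() }
```

Filtering out empty parts matters because splitting an empty string yields a list with one empty element rather than an empty list.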

### Creating a Redis Client
Create a Jedis client to connect to Redis. `JedisPooled()` with no arguments connects to Redis on `localhost:6379` and manages a pool of connections, so the same client can be reused for every call in this notebook.

In [3]:
import redis.clients.jedis.JedisPooled

val jedis = JedisPooled()

### Creating a Consumer Group
Create a consumer group to read from the Redis Stream. A consumer group allows multiple consumers to read from the same stream without duplicating the work. Each consumer in the group will receive a different subset of the messages.

A consumer group can be created in Redis with the XGROUP CREATE command:

`XGROUP CREATE streamName groupName id [MKSTREAM]`

To create a consumer group in this notebook, we will encapsulate the command in a function. The function will take the stream name and the group name as parameters.

In [4]:
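import redis.clients.jedis.StreamEntryID
import redis.clients.jedis.exceptions.JedisDataException

fun createConsumerGroup(streamName: String, consumerGroupName: String) {
    try {
        // "0-0" starts the group at the beginning of the stream; `true` creates the stream if missing (MKSTREAM)
        jedis.xgroupCreate(streamName, consumerGroupName, StreamEntryID("0-0"), true)
    } catch (e: JedisDataException) {
        // Redis replies with a BUSYGROUP error when the group already exists; rethrow anything else
        if (e.message?.contains("BUSYGROUP") == true) println("Group already exists") else throw e
    }
}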

In [5]:
createConsumerGroup("jetstream", "printer-example")

### Reading from the Stream
Create a reusable function to read from the stream. This function will read from the stream and return a list of entries. It uses the XREADGROUP command to read from the stream as part of a consumer group:

`XREADGROUP GROUP groupName consumerName COUNT count BLOCK blockTime STREAMS streamName id`

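The command will be encapsulated in a function that takes the stream name, consumer group name, consumer name, and count as parameters. It passes the special ID `>` (`StreamEntryID.XREADGROUP_UNDELIVERED_ENTRY`), which asks only for entries that have not yet been delivered to any consumer in the group. It does not set `BLOCK`, so the call returns immediately, with an empty list when there is nothing new.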

In [6]:
import redis.clients.jedis.params.XReadGroupParams

fun readFromStream(streamName: String, consumerGroup: String, consumer: String, count: Int): List<Map.Entry<String, List<StreamEntry>>> {
    return jedis.xreadGroup(
        consumerGroup,
        consumer,
        XReadGroupParams().count(count),
        mapOf(
            streamName to StreamEntryID.XREADGROUP_UNDELIVERED_ENTRY
        )
    ) ?: emptyList()
}

### Acknowledging Messages
Create a function to acknowledge a message. Acknowledging tells Redis that the message was processed successfully, so Redis removes it from the group's pending entries list and it will not be redelivered or claimed by another consumer.

This is done by using the XACK command:

`XACK streamName groupName id`

The command will be encapsulated in a lambda function that takes the stream name, consumer group name, and entry as parameters. The function will acknowledge the message by calling the XACK command.

In [7]:
val ackFn: (String, String, StreamEntry) -> Unit = { streamName, consumerGroup, entry ->
    jedis.xack(
        streamName,
        consumerGroup,
        entry.id
    )
}

### Consuming the Stream
Create a reusable function to consume the stream.

This function implements a pipeline pattern where each event is processed sequentially by a series of handlers. If any handler returns false, the processing stops for that event.

After processing the event, the function acknowledges the message using the ack function.

In [8]:
%use coroutines

In [44]:
import kotlinx.coroutines.*

fun consumeStream(
    streamName: String,
    consumerGroup: String,
    consumer: String,
    handlers: List<(Event) -> Pair<Boolean, String>>,
    ackFunction: ((String, String, StreamEntry) -> Unit),
    count: Int = 5,
    limit: Int = 5
) {
    var lastMessageTime = System.currentTimeMillis()
    var consumed = 0

    while (consumed < limit) {
        val entries = readFromStream(streamName, consumerGroup, consumer, count)
        val allEntries = entries.flatMap { it.value }

        if (allEntries.isEmpty()) {
            // Stop once the stream has been idle for 2 seconds
            if (System.currentTimeMillis() - lastMessageTime >= 2_000) {
                println("$consumer: No new messages for 2 seconds. Stopping.")
                break
            }
            Thread.sleep(100) // avoid a busy loop while waiting for new entries
            continue
        }
        lastMessageTime = System.currentTimeMillis()

        allEntries.forEach { entry ->
            consumed++
            val event = Event.fromMap(entry)

            // Run the handlers in order; stop at the first one that returns false
            for (handler in handlers) {
                val (shouldContinue, message) = handler(event)
                if (!shouldContinue) {
                    println("$consumer: Handler stopped processing: $message")
                    break
                }
            }

            // Acknowledge once per entry, after the pipeline has finished
            ackFunction(streamName, consumerGroup, entry)
        }
    }
}
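
The short-circuit behaviour of the handler pipeline can be tried without Redis. The following is a minimal sketch; `runHandlers` is a hypothetical helper written for this illustration and is not part of the notebook's code.

```kotlin
// Runs handlers in order and returns the message of the first handler
// that stops the pipeline, or null if every handler let the item through.
// Each handler returns (shouldContinue, message), mirroring the notebook's handlers.
fun <T> runHandlers(item: T, handlers: List<(T) -> Pair<Boolean, String>>): String? {
    for (handler in handlers) {
        val (shouldContinue, message) = handler(item)
        if (!shouldContinue) return message
    }
    return null
}
```

Because the loop returns at the first `false`, handlers placed later in the list never see an item that an earlier handler rejected. This is why cheap checks belong at the front of the list and expensive ones at the back.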

To test the consumeStream function, we'll create a simple handler that prints the event's URI.

In [10]:
val printUri: (Event) -> Pair<Boolean, String> = {
    println("Got event from ${it.uri}")
    Pair(true, "OK")
}

In [46]:
runBlocking {
    consumeStream(
        streamName = "jetstream",
        consumerGroup = "printer-example",
        consumer ="printer-1",
        handlers = listOf(printUri),
        ackFunction = ackFn,
        count = 100,
        limit = 100
    )
}

Got event from at://did:plc:3ymwxtmeesirvucb2degvdsl/app.bsky.feed.post/3lpcmgglvfc2p
Got event from at://did:plc:wff7uo734vlav2j646kogbrj/app.bsky.feed.post/3lpcmghsguy2s
Got event from at://did:plc:4kqyouk3lcaodrkdw5uphgty/app.bsky.feed.post/3lpcmghvyyv25
Got event from at://did:plc:j4fa3lo7pu5ckeh2fuc3fdgl/app.bsky.feed.post/3lpcmgi2jlc2x
Got event from at://did:plc:rjn7np2j6ugqpvptlij2dbwl/app.bsky.feed.post/3lpcmgicu6k2t
Got event from at://did:plc:ueg2lbj2tvpc2hs7xytw3tic/app.bsky.feed.post/3lpcmghsfss2j
Got event from at://did:plc:5lya2736k5olysj2gkarysao/app.bsky.feed.post/3loyxxl2ccc25
Got event from at://did:plc:w4hpxq5o4dapo6ozvmhjqlav/app.bsky.feed.post/3lpcmgigxzc2w
Got event from at://did:plc:pjfggsk6zp25bgoy5ich2xio/app.bsky.feed.post/3lpcmfgtrrc2t
Got event from at://did:plc:n47xarga2ijdibgqa6swzply/app.bsky.feed.post/3lpcmgjfbqk2u
Got event from at://did:plc:judc7jgvsqmnoe7ljfbfi2oq/app.bsky.feed.post/3lpcmghi6o22f
Got event from at://did:plc:42kvq7or34iwpzkz2s4dczs6/a

## Deduplication with Bloom Filter
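As described in the introduction, a Bloom filter lets us check whether an element may already be in a set, using little memory and constant-time inserts and lookups. Here the elements are event URIs: a URI reported as present has probably been processed already, while a URI reported as absent has definitely not been.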


### Creating a Bloom Filter
This function creates a Bloom Filter with the given name. The filter is configured with an error rate of 0.01 and an initial capacity of 1,000,000 elements.

In [12]:
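import redis.clients.jedis.bloom.BFReserveParams

fun createBloomFilter(name: String) {
    runCatching {
        // Error rate 0.01, initial capacity 1,000,000; each sub-filter added on overflow is 2x the previous size
        jedis.bfReserve(name, 0.01, 1_000_000L, BFReserveParams().expansion(2))
    }.onFailure { e ->
        // BF.RESERVE fails if the key already exists; print the message so other errors are not hidden
        println("Bloom filter not created (it may already exist): ${e.message}")
    }
}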
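
To get a feel for what an error rate of 0.01 and a capacity of 1,000,000 cost in memory, the textbook formulas for a classic Bloom filter can be evaluated directly. This is a rough estimate under the assumption of an idealized filter; Redis may size its filters differently, and the function names are made up for this sketch.

```kotlin
import kotlin.math.ceil
import kotlin.math.ln
import kotlin.math.roundToInt

// Textbook sizing for a classic Bloom filter: m = -n * ln(p) / (ln 2)^2
fun optimalBits(capacity: Long, errorRate: Double): Long =
    ceil(-capacity * ln(errorRate) / (ln(2.0) * ln(2.0))).toLong()

// Textbook number of hash functions: k = (m / n) * ln 2
fun optimalHashCount(capacity: Long, bits: Long): Int =
    (bits.toDouble() / capacity * ln(2.0)).roundToInt()
```

For 1,000,000 elements at a 1% error rate this gives roughly 9.6 million bits (about 1.2 MB) and 7 hash functions, far less than storing the URIs themselves would take.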

### Deduplication Handler
This function creates a handler that checks if an event has already been processed by checking if its URI is in the Bloom Filter. If the URI is in the filter, the handler returns false, which stops the processing of the event.


In [13]:
fun deduplicate(bloomFilter: String): (Event) -> Pair<Boolean, String> {
    return { event ->
        if (jedis.bfExists(bloomFilter, event.uri)) {
            Pair(false, "${event.uri} already processed")
        } else {
            Pair(true, "OK")
        }
    }
}

### Atomic Acknowledgment and Bloom Filter Update
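This function creates an acknowledgment function that acknowledges the message and adds the URI to the Bloom Filter in a single Redis transaction (`MULTI`/`EXEC`). The two commands are queued and then executed back to back, with no other client's commands interleaved between them. Note that Redis transactions do not roll back: if one of the commands fails at runtime, the other still takes effect.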


In [14]:
import redis.clients.jedis.JedisPool

// A separate pool that hands out a dedicated connection for each transaction
val jedisPool = JedisPool()

fun ackAndBfFn(bloomFilter: String): (String, String, StreamEntry) -> Unit = { streamName, consumerGroup, entry ->
    jedisPool.resource.use { jedis ->
        // Create a transaction
        val multi = jedis.multi()

        // Acknowledge the message
        multi.xack(
            streamName,
            consumerGroup,
            entry.id
        )

        // Add the URI to the bloom filter
        multi.bfAdd(bloomFilter, Event.fromMap(entry).uri)

        // Execute the transaction
        multi.exec()
    }
}

In [15]:
createConsumerGroup("jetstream", "deduplicate-example")

In [16]:
val bloomFilterName = "processed-uris"
createBloomFilter(bloomFilterName)

In [47]:
runBlocking {
    consumeStream(
        streamName = "jetstream",
        consumerGroup = "deduplicate-example",
        consumer = "deduplicate-1",
        handlers = listOf(deduplicate(bloomFilterName), printUri),
        ackFunction = ackAndBfFn(bloomFilterName),
        count = 100,
        limit = 100
    )
}

deduplicate-1: Handler stopped processing: at://did:plc:3ymwxtmeesirvucb2degvdsl/app.bsky.feed.post/3lpcmgglvfc2p already processed
deduplicate-1: Handler stopped processing: at://did:plc:wff7uo734vlav2j646kogbrj/app.bsky.feed.post/3lpcmghsguy2s already processed
deduplicate-1: Handler stopped processing: at://did:plc:4kqyouk3lcaodrkdw5uphgty/app.bsky.feed.post/3lpcmghvyyv25 already processed
deduplicate-1: Handler stopped processing: at://did:plc:j4fa3lo7pu5ckeh2fuc3fdgl/app.bsky.feed.post/3lpcmgi2jlc2x already processed
deduplicate-1: Handler stopped processing: at://did:plc:rjn7np2j6ugqpvptlij2dbwl/app.bsky.feed.post/3lpcmgicu6k2t already processed
deduplicate-1: Handler stopped processing: at://did:plc:ueg2lbj2tvpc2hs7xytw3tic/app.bsky.feed.post/3lpcmghsfss2j already processed
deduplicate-1: Handler stopped processing: at://did:plc:5lya2736k5olysj2gkarysao/app.bsky.feed.post/3loyxxl2ccc25 already processed
deduplicate-1: Handler stopped processing: at://did:plc:w4hpxq5o4dapo6ozvmhj

## Content-Based Filtering with Machine Learning
In this section, we'll use a machine learning model to filter posts based on their content. We'll use a pre-trained zero-shot classification model to classify posts as software-related or not.

### Setting Up the Machine Learning Model
To load the model, we'll use DJL (Deep Java Library). DJL is a high-level framework for deep learning in Java that provides a simple and consistent API for loading and using models.

In [18]:
@file:DependsOn("ai.djl.huggingface:tokenizers:0.33.0")
@file:DependsOn("ai.djl.pytorch:pytorch-engine:0.33.0")
@file:DependsOn("ai.djl:api:0.33.0")
@file:DependsOn("ai.djl:model-zoo:0.33.0")

#### Creating the model translator
A translator is responsible for converting the input text into a format that the model can understand, and converting the model's output back into a format that we can use.

In this case, we'll create a custom translator for the zero-shot classification model.

In [19]:
import ai.djl.Model
import ai.djl.huggingface.tokenizers.Encoding
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer
import ai.djl.inference.Predictor
import ai.djl.modality.nlp.translator.ZeroShotClassificationInput
import ai.djl.modality.nlp.translator.ZeroShotClassificationOutput
import ai.djl.ndarray.NDArray
import ai.djl.ndarray.NDArrays
import ai.djl.ndarray.NDList
import ai.djl.ndarray.NDManager
import ai.djl.translate.*
import java.util.*

class CustomZeroShotClassificationTranslator private constructor(
    private val tokenizer: HuggingFaceTokenizer,
    private val int32: Boolean
) : NoBatchifyTranslator<ZeroShotClassificationInput, ZeroShotClassificationOutput> {

    private lateinit var predictor: Predictor<NDList, NDList>

    override fun prepare(ctx: TranslatorContext) {
        val model: Model = ctx.model
        predictor = model.newPredictor(NoopTranslator(null))
        ctx.predictorManager.attachInternal(UUID.randomUUID().toString(), predictor)
    }

    override fun processInput(ctx: TranslatorContext, input: ZeroShotClassificationInput): NDList {
        ctx.setAttachment("input", input)
        return NDList()
    }

    override fun processOutput(ctx: TranslatorContext, list: NDList): ZeroShotClassificationOutput {
        val input = ctx.getAttachment("input") as ZeroShotClassificationInput

        val template = input.hypothesisTemplate
        val candidates = input.candidates ?: throw TranslateException("Missing candidates in input")

        val manager: NDManager = ctx.ndManager
        val output = NDList(candidates.size)

        for (candidate in candidates) {
            val hypothesis = applyTemplate(template, candidate)
            val encoding: Encoding = tokenizer.encode(input.text, hypothesis)
            val encoded = encoding.toNDList(manager, true, true)
            val batch = Batchifier.STACK.batchify(arrayOf(encoded))
            output.add(predictor.predict(batch)[0])
        }

        var logits: NDArray = NDArrays.concat(output)
        logits = if (input.isMultiLabel) {
            val entailmentId = 0
            val contradictionId = 2
            val scores = NDList()
            for (i in 0 until output.size) {
                val logits2 = output[i]
                val pair = logits2.get(":, {}", manager.create(intArrayOf(contradictionId, entailmentId)))
                val probs = pair.softmax(1)
                val entailmentScore = probs.get(":, 1") // shape: [1]
                scores.add(entailmentScore)
            }
            NDArrays.stack(scores).squeeze()
        } else {
            val entailmentId = 0
            val entailLogits = logits.get(":, $entailmentId")
            val exp = entailLogits.exp()
            val sum = exp.sum()
            exp.div(sum)
        }

        val indices = logits.argSort(-1, false).toLongArray()
        val probabilities = logits.toFloatArray()

        val labels = Array(candidates.size) { "" }
        val scores = DoubleArray(candidates.size)

        for (i in labels.indices) {
            val index = indices[i].toInt()
            labels[i] = candidates[index]
            scores[i] = probabilities[index].toDouble()
        }

        return ZeroShotClassificationOutput(input.text, labels, scores)
    }

    private fun applyTemplate(template: String, arg: String): String {
        val pos = template.indexOf("{}")
        return if (pos == -1) template + arg else template.substring(0, pos) + arg + template.substring(pos + 2)
    }

    companion object {
        fun builder(tokenizer: HuggingFaceTokenizer): Builder = Builder(tokenizer)

        fun builder(tokenizer: HuggingFaceTokenizer, arguments: MutableMap<String, *>): Builder =
            builder(tokenizer).apply { configure(arguments) }
    }

    class Builder(private val tokenizer: HuggingFaceTokenizer) {
        private var int32: Boolean = false

        fun optInt32(int32: Boolean) = apply { this.int32 = int32 }

        fun configure(arguments: MutableMap<String, *>) {
            optInt32(ArgumentsUtil.booleanValue(arguments, "int32"))
        }

        fun build(): CustomZeroShotClassificationTranslator = CustomZeroShotClassificationTranslator(tokenizer, int32)
    }
}
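
The scoring arithmetic inside `processOutput` can be reproduced on plain numbers. The sketch below uses hypothetical function names and adds a max-subtraction step for numerical stability that the translator itself does not perform, but the formulas are otherwise the same: in multi-label mode each candidate gets a softmax over its own contradiction and entailment logits, and in single-label mode the entailment logits of all candidates are normalized together.

```kotlin
import kotlin.math.exp

// Multi-label: softmax over [contradiction, entailment] for one candidate,
// returning the entailment probability.
fun entailmentProbability(entailmentLogit: Double, contradictionLogit: Double): Double {
    val max = maxOf(entailmentLogit, contradictionLogit) // subtract max for numerical stability
    val e = exp(entailmentLogit - max)
    val c = exp(contradictionLogit - max)
    return e / (e + c)
}

// Single-label: softmax over the entailment logits of all candidates.
fun softmax(logits: List<Double>): List<Double> {
    val max = logits.maxOrNull() ?: return emptyList()
    val exps = logits.map { exp(it - max) }
    val sum = exps.sum()
    return exps.map { it / sum }
}
```

In multi-label mode the scores are independent, so several labels can score close to 1 at once. In single-label mode the scores always sum to 1, so the labels compete with each other.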

### Creating the Model Criteria
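The `Criteria` object tells DJL how to locate and load the model and how to create a predictor for it. It specifies the model path, the engine to use (in this case, PyTorch), and the translator to use. The paths in the next cell are absolute paths on the author's machine; replace them with the location of the model on your own machine.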


In [20]:
import ai.djl.huggingface.translator.ZeroShotClassificationTranslator
import ai.djl.huggingface.translator.ZeroShotClassificationTranslatorFactory
import ai.djl.modality.nlp.translator.ZeroShotClassificationInput
import ai.djl.modality.nlp.translator.ZeroShotClassificationOutput
import ai.djl.repository.zoo.Criteria
import java.nio.file.Paths

val tokenizer = HuggingFaceTokenizer.newInstance(Paths.get("/Users/raphaeldelio/Documents/GitHub/redis/kotlinconf-bluesky-bot/model/DeBERTa-v3-large-mnli-fever-anli-ling-wanli/tokenizer.json"))

val translator = CustomZeroShotClassificationTranslator.builder(tokenizer).build()

val criteria: Criteria<ZeroShotClassificationInput, ZeroShotClassificationOutput> = Criteria.builder()
    .setTypes(
        ZeroShotClassificationInput::class.java,
        ZeroShotClassificationOutput::class.java
    )
    .optModelPath(Paths.get("/Users/raphaeldelio/Documents/GitHub/redis/kotlinconf-bluesky-bot/model/DeBERTa-v3-large-mnli-fever-anli-ling-wanli"))
    .optEngine("PyTorch")
    .optTranslator(translator)
    .build()


### Loading the Model
Now we'll load the model and create a predictor. The predictor is used to make predictions with the model.


In [21]:
import ai.djl.repository.zoo.ModelZoo

val model = ModelZoo.loadModel(criteria)
val predictor = model.newPredictor()

### Creating a Classification Function
Now we'll create a function to classify text using the model.

The function takes a text as input and returns a classification output. The classification output contains the probabilities for each candidate label.


In [22]:
import ai.djl.modality.nlp.translator.ZeroShotClassificationOutput

fun classify(premise: String): ZeroShotClassificationOutput {
    val candidateLabels = listOf("Software Engineering", "Software Programming")
    val input = ZeroShotClassificationInput(premise, candidateLabels.toTypedArray(), true, "{}")
    return predictor.predict(input)
}
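
`classify` passes `"{}"` as the hypothesis template, so the hypothesis sent to the model is just the candidate label itself. The sketch below mirrors the substitution rule of `applyTemplate` in the translator under a different, hypothetical name, so it can be tried in isolation. The sentence-style template used in the usage note is purely illustrative and is not what this notebook uses.

```kotlin
// Same substitution rule as applyTemplate in the translator:
// replace the first "{}" with the label, or append the label if there is no placeholder.
fun expandTemplate(template: String, label: String): String {
    val pos = template.indexOf("{}")
    return if (pos == -1) template + label
    else template.substring(0, pos) + label + template.substring(pos + 2)
}
```

For example, `expandTemplate("This post is about {}.", "Software Engineering")` produces a full-sentence hypothesis, while `expandTemplate("{}", "Software Engineering")` returns the label unchanged.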

### Creating a Filter Handler
Now we'll create a handler that filters events based on their content.

The handler uses the classification function to determine if a post is software-related.

If the post is not software-related, the handler returns false, which stops the processing of the event.


In [23]:
val filter: (Event) -> Pair<Boolean, String> = { event ->
    if (event.text.isNotBlank() && event.operation != "delete") {
        val classification = classify(event.text)
        if (classification.scores.any { it > 0.90 }) {
            Pair(true, "OK")
        } else {
            Pair(false, "Not a post related to software")
        }
    } else {
        Pair(false, "Text is blank or the event is a delete")
    }
}
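
The decision rule in the handler can be isolated as a small pure function. This is a sketch with a hypothetical name; the default threshold of 0.90 is the value used in the handler above.

```kotlin
// Hypothetical helper: true if any candidate label scores above the threshold.
fun isSoftwareRelated(scores: DoubleArray, threshold: Double = 0.90): Boolean =
    scores.any { it > threshold }
```

Because `classify` requests multi-label scoring, each label is scored independently, so a post passes as soon as either label clears the threshold. Raising the threshold trades recall for precision, and the right value is best chosen by inspecting real posts.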

## Storing Filtered Events
In this section, we'll store the filtered events in Redis for further processing.


### Converting Events to Maps
First, we need a function to convert an Event object to a Map that can be stored in Redis as a Hash. Hash fields are plain strings, so the `langs` list is joined into a single `|`-separated string.

In [24]:
fun Event.toMap() = mapOf(
    "did" to this.did,
    "timeUs" to this.timeUs,
    "text" to this.text,
    "langs" to this.langs.joinToString("|"),
    "operation" to this.operation,
    "rkey" to this.rkey,
    "parentUri" to this.parentUri,
    "rootUri" to this.rootUri,
    "uri" to this.uri
)
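
Note that `toMap` writes `langs` joined with `|`, which differs from the bracketed format that `Event.fromMap` parses. A consumer reading the stored data back would need a matching decoder. The helper below is a hypothetical sketch of one, not part of the notebook.

```kotlin
// Hypothetical helper for consumers of the stored hash or the filtered stream:
// splits the "|"-joined langs field back into a list.
fun decodeLangs(joined: String?): List<String> =
    joined.orEmpty().split("|").filter { it.isNotEmpty() }
```

The filter step handles events without languages, where the joined string is empty and a plain split would return a list containing one empty string.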

### Storing Events in Redis
Now we'll create a handler that stores events in Redis. The handler stores the event as a hash in Redis, with the key being the event's URI prefixed with `post:`.


In [25]:
val storeEvent: (Event) -> Pair<Boolean, String> = { event ->
    jedis.hset("post:" + event.uri, event.toMap())
    Pair(true, "OK")
}

### Adding Filtered Events to a New Stream
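Finally, we'll create a handler that adds filtered events to a new stream. This allows other consumers to process only the filtered events, rather than having to filter the events themselves. The stream is capped at 1,000,000 entries with exact trimming, so the oldest entries are evicted as new ones arrive.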


In [26]:
import redis.clients.jedis.params.XAddParams

val addFilteredEventToStream: (Event) -> Pair<Boolean, String> = { event ->
    jedis.xadd(
        "filtered-events",
        XAddParams.xAddParams()
            .id(StreamEntryID.NEW_ENTRY)
            .maxLen(1_000_000)
            .exactTrimming(),
        event.toMap()
    )
    Pair(true, "OK")
}

In [27]:
createConsumerGroup("jetstream", "store-example")

In [28]:
val bloomFilterName = "store-bf"
createBloomFilter(bloomFilterName)

## Putting It All Together
Now we'll put all the pieces together to create a complete pipeline for filtering events from the Redis Stream.

In this example, we create two consumers in the same consumer group that process the same stream.
- This lets us scale the processing of events by adding more consumers to the group.
- Redis makes sure that each consumer receives different messages.


In [48]:
runBlocking {
        listOf(
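            async { // Runs on the runBlocking thread, so the two blocking consumers run one after the other. Outside Kotlin Notebooks, use async(Dispatchers.IO) to run them in parallel; inside a notebook, output printed from other threads is not shown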
                consumeStream(
                    streamName = "jetstream",
                    consumerGroup = "store-example",
                    consumer = "store-1",
                    handlers = listOf(
                        deduplicate(bloomFilterName),
                        filter,
                        printUri,
                        storeEvent,
                        addFilteredEventToStream
                    ),
                    ackFunction = ackAndBfFn(bloomFilterName),
                    count = 50,
                    limit = 50
                )
            },
            async {
                consumeStream(
                    streamName = "jetstream",
                    consumerGroup = "store-example",
                    consumer = "store-2", // Different consumer
                    handlers = listOf(
                        deduplicate(bloomFilterName),
                        filter,
                        printUri,
                        storeEvent,
                        addFilteredEventToStream
                    ),
                    ackFunction = ackAndBfFn(bloomFilterName),
                    count = 50,
                    limit = 50
                )
            }
        ).awaitAll()
}

store-1: Handler stopped processing: at://did:plc:rshibtcbcif6qukicpuktnm6/app.bsky.feed.post/3lpcmjjbbue2r already processed
store-1: Handler stopped processing: at://did:plc:pipprryk6mati5657cv5y63v/app.bsky.feed.post/3lpcmjjxsa226 already processed
store-1: Handler stopped processing: at://did:plc:dblfhwwk4uqa7h7wnabt6rvz/app.bsky.feed.post/3lpcmjj3ylc2q already processed
store-1: Handler stopped processing: at://did:plc:ormz52bv7npfs5t6y7j6zofg/app.bsky.feed.post/3lpcmjlbekq2i already processed
store-1: Handler stopped processing: at://did:plc:ijgsaagcz2bkl7r7ti6vbqhs/app.bsky.feed.post/3lpcmjl2z6k2h already processed
store-1: Handler stopped processing: at://did:plc:vmd4dpqyrdnelwztp4ia6rsj/app.bsky.feed.post/3lpcmjl2ogs2t already processed
store-1: Handler stopped processing: at://did:plc:srjxh4v7qmms37dpmta7qpsb/app.bsky.feed.post/3lpcmjkvghs2t already processed
store-1: Handler stopped processing: at://did:plc:dgxuyys5hitd4ylqtfv6ovkz/app.bsky.feed.post/3lpcmjj6nk227 already pr

[kotlin.Unit, kotlin.Unit]
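
The guarantee that consumers in one group receive different messages can be pictured with an in-memory analogy. This is only an analogy under simplifying assumptions: a shared queue stands in for the stream, `poll()` stands in for `XREADGROUP` handing an entry to exactly one consumer, and the function name is hypothetical. It does not model pending entries or acknowledgments.

```kotlin
import java.util.concurrent.ConcurrentLinkedQueue
import kotlin.concurrent.thread

// In-memory analogy only: each message is taken from the shared queue by exactly one consumer.
// Consumer names are assumed to be unique.
fun simulateConsumers(messages: List<String>, consumers: List<String>): Map<String, List<String>> {
    val queue = ConcurrentLinkedQueue(messages)
    val received = consumers.associateWith { mutableListOf<String>() }
    val threads = consumers.map { name ->
        thread {
            while (true) {
                val message = queue.poll() ?: break
                received.getValue(name).add(message) // each list is touched by one thread only
            }
        }
    }
    threads.forEach { it.join() }
    return received
}
```

Calling `simulateConsumers(messages, listOf("store-1", "store-2"))` returns a map in which every message appears exactly once across the two lists. How the messages are split between the consumers varies from run to run.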

## Next Steps
In the next notebook, we'll enrich the filtered events with additional information, such as topic modeling and embeddings for semantic search.
