In [None]:
import os
# Configure how the Rerun notebook widget obtains its front-end asset.
# NOTE(review): presumably this is read when the Rerun notebook integration is
# imported, which would be why this cell runs before `import rerun` -- confirm
# against the rerun notebook documentation.
os.environ["RERUN_NOTEBOOK_ASSET"] = "inline"

In [None]:
from __future__ import annotations

import rerun as rr
import sys
import torch
import uuid
import pyarrow as pa

from transformers import CLIPTokenizer, CLIPModel

In [None]:
# Load the CLIP tokenizer and the CLIP model from the same checkpoint.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"

# Use the GPU when one is present, otherwise stay on the CPU, so that loading
# the model does not fail on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = CLIPTokenizer.from_pretrained(CLIP_CHECKPOINT)
model = CLIPModel.from_pretrained(CLIP_CHECKPOINT).to(device)

In [None]:
# Toy corpus: two "episodes" of short English sentences. The position of a
# sentence inside its list is the sequence number it is logged at later on.
ep1_phrases = [
    "The swift fox leaps over the calm river.",
    "Rainbows dance after the gentle rain.",
    "Midnight jazz echoes through the quiet alley.",
    "An old lamp flickers in the dusty attic.",
    "A silver plane soars above cotton-candy clouds.",
    "Neon lights dazzle in the bustling cityscape.",
    "A lost kitten meows outside the warm bakery.",
    "Pastel petals flutter in the morning breeze.",
    "Golden wheat sways under a cobalt sky.",
    "A chessboard sits abandoned on the old table.",
    "The aroma of coffee fills the tiny café.",
    "Two robots gently shake hands under the moonlight.",
    "A lonely violinist plays in the park at dusk.",
    "A stack of vintage postcards lies in a drawer.",
    "The library clock chimes exactly at noon.",
]

ep2_phrases = [
    "A curious owl hoots near a glowing streetlamp.",
    "Cracked pavement lines the abandoned roller rink.",
    "Bright lanterns illuminate the narrow garden path.",
    "Waves crash gently on the sandy shoreline at dusk.",
    "An antique clock ticks quietly in the corner shop.",
    "Butterflies dance above the fragrant lavender fields.",
    "Rainbow confetti rains down during the lively parade.",
    "A solitary scarecrow watches the endless horizon.",
    "Glassy skyscrapers mirror the afternoon sunlight.",
    "A worn-out wooden sign points to a hidden trail.",
    "Dust motes swirl in the old bookshop’s golden light.",
    "Children laugh as they race paper boats in puddles.",
    "Steam rises from a fresh loaf in the rustic bakery.",
    "A jazz quartet rehearses in a forgotten theater.",
    "Fluffy sheep graze along the misty rolling hills.",
]

# Episode name -> ordered list of phrases.
dataset = {"ep1": ep1_phrases, "ep2": ep2_phrases}

In [None]:
# Connection to the remote Rerun server; later cells use it to register
# recordings and to build and query indexes.
# NOTE(review): 0.0.0.0 is normally a bind address rather than a destination;
# confirm it resolves to the local server on every platform this runs on.
conn = rr.remote.connect("http://0.0.0.0:51234")

for ep, phrases in dataset.items():
    storage_path = f"/tmp/{ep}.rrd"
    # One recording per episode, each with its own random recording id.
    rec = rr.new_recording("rerun_example_test_recording", recording_id=uuid.uuid4())

    # Tokenize the whole episode as a single padded batch and move it to the
    # device the model actually lives on (instead of assuming "cuda"), so the
    # inputs and the weights can never end up on different devices.
    inputs = tokenizer(
        phrases,
        padding=True,
        truncation=True,
        return_tensors="pt",
    ).to(model.device)

    # Inference only: no gradients needed. Row i of `outputs` belongs to phrases[i].
    with torch.no_grad():
        outputs = model.get_text_features(**inputs).cpu().numpy()

    # Log each phrase and its embedding on the same entity at the same
    # "index" sequence value, so both can be looked up by that index later.
    for i, (phrase, embedding) in enumerate(zip(phrases, outputs)):
        rec.set_time_sequence("index", i)
        rec.log("words", rr.TextLog(phrase))
        rec.log("words", rr.AnyValues(embeddings=embedding))

    # Persist the recording to disk, then register the file with the server,
    # tagging it with a one-row metadata batch that carries the episode name.
    rec.save(storage_path)
    conn.register(f"file://{storage_path}", metadata=pa.RecordBatch.from_arrays([pa.array([ep])], names=["episode"]))

In [None]:
# Build a full-text-search index over the text logged at "/words", keyed by
# the "index" timeline.
fts_text_column = rr.dataframe.ComponentColumnSelector("/words", rr.components.Text)
fts_time_index = rr.dataframe.IndexColumnSelector("index")

conn.create_fts_index(
    collection="default",
    column=fts_text_column,
    time_index=fts_time_index,
    store_position=False,
    base_tokenizer="simple",
)

In [None]:
# TODO(jleibs): Wrapping query params in a record-batch is annoying. Python APIs should do this for us.
# The search term travels as a one-row record batch with a single "item" column.
coffee_query = pa.RecordBatch.from_arrays([pa.array(["coffee"])], names=["item"])
queried_text_column = rr.dataframe.ComponentColumnSelector("/words", rr.components.Text)

# Run the full-text query and materialize the whole result.
search_result = conn.query_fts_index(
    collection="default",
    column=queried_text_column,
    query=coffee_query,
).read_all()

In [None]:
# Show the raw search hits as this cell's output.
search_result

In [None]:
# Open the recording behind the first search hit, show it in an embedded
# viewer, and move the viewer's time cursor to the matching row.
if len(search_result) == 0:
    # Fail with a clear message instead of the opaque IndexError that
    # indexing into an empty result would raise.
    raise RuntimeError("Full-text search returned no matches; there is no recording to open")

# First hit: column 0 is what `open_recording` is given, column 2 is used as
# the position on the "index" timeline.
first_hit_recording = search_result[0][0].as_py()
first_hit_sequence = search_result[2][0].as_py()

# Load the recording
rec = conn.open_recording(first_hit_recording)
view = rec.view(index="index", contents="/words", include_indicator_columns=True).select()

# Display it in the notebook
rr.init("episode")
rr.dataframe.send_dataframe(view)
viewer = rr.notebook.Viewer(width=1024)
viewer.display()

# Jump to the timepoint from the search results
viewer.set_time_ctrl(timeline="index", sequence=first_hit_sequence)

In [None]:
def search_for(text):
    """Full-text search for `text` and move the viewer to the first hit.

    Queries the full-text index on the "/words" text column of the "default"
    collection through the module-level `conn`. If there is at least one hit,
    the module-level `viewer` is moved on the "index" timeline to the value
    found in column 2 of the first row. Otherwise "No match found" is printed.

    Only the time cursor is changed: the recording shown by `viewer` is not
    switched, and column 0 of the result is not consulted here.

    Args:
        text: The term to search for.

    Returns:
        None.
    """
    words_text_column = rr.dataframe.ComponentColumnSelector("/words", rr.components.Text)
    # The query API expects the term wrapped in a one-row batch with an "item" column.
    query_batch = pa.RecordBatch.from_arrays([pa.array([text])], names=["item"])

    hits = conn.query_fts_index(
        collection="default",
        column=words_text_column,
        query=query_batch,
    ).read_all()

    if len(hits) == 0:
        print("No match found")
        return

    first_sequence = hits[2][0].as_py()
    viewer.set_time_ctrl(timeline="index", sequence=first_sequence)

In [None]:
# "Neon" appears in dataset["ep1"][5]; a hit there moves the viewer to index 5.
# NOTE(review): assumes the index matches case-insensitively -- confirm.
search_for("neon")

In [None]:
# "petals" appears in dataset["ep1"][7]; a hit there moves the viewer to index 7.
search_for("petals")

In [None]:
# "library" appears in dataset["ep1"][14], the last phrase of the episode (index 14).
search_for("library")

In [None]:
# TODO(jleibs): Something went wrong with indexing the first time-point -- off-by-1 somewhere?
# "fox" appears in dataset["ep1"][0], i.e. the phrase logged at index 0, so this
# search exercises the first time-point mentioned in the TODO above.
search_for("fox")

In [None]:
# Need to debug -- looks like something incorrect with list-array-unwrapping
# Build a vector index over the "embeddings" values logged at "/words", keyed
# by the "index" timeline.
embeddings_column = rr.dataframe.ComponentColumnSelector("/words", "embeddings")
embeddings_time_index = rr.dataframe.IndexColumnSelector("index")

conn.create_vector_index(
    collection="default",
    column=embeddings_column,
    time_index=embeddings_time_index,
    num_partitions=5,
    num_sub_vectors=16,
    distance_metric="L2",
)