In [1]:
!pip install --quiet altair


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m24.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [1]:
import altair as alt
import numpy as np
import pandas as pd

from tqdm import tqdm

In [2]:
alt.data_transformers.enable("vegafusion")

DataTransformerRegistry.enable('vegafusion')

In [3]:
click_df = pd.read_parquet("output/clicks.parquet", columns=["query", "title", "abstract", "url_md5", "position", "click"])
annotation_df = pd.read_parquet("output/annotations.parquet")

### Utils

In [4]:
def create_session_id(df):
    query_change = click_df["query"] != click_df["query"].shift()
    position_change = click_df["position"] < click_df["position"].shift()
    return (query_change | position_change).cumsum()

def match(df1, df2, on, left_columns=[], right_columns=[], unique=False):
    
    if unique:
        df1 = df1[on + left_columns].drop_duplicates(on)
        df2 = df2[on + right_columns].drop_duplicates(on)

    return df1[on + left_columns].merge(df2[on + right_columns], on=on)

def count_stats(df, groupby, unique = False):
    groupby = groupby if type(groupby) is list else [groupby]
    
    if unique:
        df = df.drop_duplicates(["query", "title"] + groupby)
    
    source = df.groupby(groupby).agg(
        documents=("title", "count"),
    ).reset_index()

    source["perc_documents"] = (source.documents / source.documents.sum()).round(4)
    return source

# Matching clicks and annotations:

In [5]:
# Create surrogate id for individual sessions (either on query change or position change):
click_df["session_id"] = create_session_id(click_df)
# At least one document in user sessions
match_df = match(annotation_df, click_df, on=['query', 'title'], left_columns=["frequency_bucket", "label"], right_columns=["session_id", "position", "click"])
# At least three documents in the same user session
query_df = match_df.groupby(["query", "session_id"]).agg(documents=("title", "nunique")).reset_index()
session_df = query_df[query_df["documents"] >= 3]
session_df = match_df.merge(session_df, on=["query", "session_id"])

match_df.head()

Unnamed: 0,query,title,frequency_bucket,label,session_id,position,click
0,21391 21391 21446 21368 21368 21368 76 14451 1...,9102 10144 13027 16653 5978 14383 14170 20627 ...,6,2,3026196,5,0
1,21391 21391 21446 21368 21368 21368 76 14451 1...,9102 10144 13027 16653 5978 14383 14170 20627 ...,6,2,3197388,5,0
2,21391 21391 21446 21368 21368 21368 76 14451 1...,9102 10144 13027 16653 5978 14383 14170 20627 ...,6,2,3026196,5,0
3,21391 21391 21446 21368 21368 21368 76 14451 1...,9102 10144 13027 16653 5978 14383 14170 20627 ...,6,2,3197388,5,0
4,21391 21391 21446 21368 21368 21368 76 14451 1...,445 445 76 14451 5924 21429 16951 8479 8331,6,3,3026196,1,0


In [6]:
print(f"Queries with one matching doc: {query_df['query'].nunique()}")
print(f"Queries with three or more matching docs: {query_df[query_df['documents'] >= 3]['query'].nunique()}")
print(f"Queries with five or more matching docs: {query_df[query_df['documents'] >= 5]['query'].nunique()}")

Queries with one matching doc: 2744
Queries with three or more matching docs: 1205
Queries with five or more matching docs: 243


In [7]:
source = session_df.groupby("query").agg(
    unique_documents=("title", "nunique"),
    sessions=("session_id", "nunique"),
).reset_index()

base = alt.Chart(source, width=400, height=300).transform_calculate(
    log_sessions = 'log(datum.sessions)/log(10)'
)

base.mark_bar().encode(
    x=alt.X("unique_documents:O", title="# of unique documents").axis(labelAngle=0),
    y=alt.Y("count(query):Q", title="# of queries"),
) | base.mark_bar().encode(
    x=alt.X("log_sessions:Q", title="# of user sessions per query (log scale)").bin().axis(labelExpr='format(pow(10, datum.value), ",")', labelAngle=-45, labelOverlap=False),
    y=alt.Y("count(query):Q", title="# of queries"),
    tooltip=["count(query)"]
)

# Matching Funnel
## Query funnel

In [9]:
query_funnel_df = pd.DataFrame([
    {"stage": "1. Test queries", "queries": annotation_df["query"].nunique()},
    {"stage": "2. Appear in train set", "queries": match(annotation_df, click_df, on=["query"], unique=True)["query"].nunique()},
    {"stage": "3. At least one matching doc", "queries": match_df["query"].nunique()},
    {"stage": "4. At least three docs in session", "queries": session_df["query"].nunique()},
])

query_funnel_df["perc_queries"] = query_funnel_df["queries"] / query_funnel_df["queries"].max()

In [10]:
base = alt.Chart(
    query_funnel_df,
    title="Funnel of Expert-Annotated Queries",
    width=800
).encode(
    x=alt.X("stage:N", title="").axis(labelAngle=0),
    y=alt.Y("queries:Q", title="% of annotated queries"),
    color=alt.Color("stage:N", legend=None).scale(scheme="darkblue"),
)
bars = base.mark_bar()
absolute = base.mark_text(align="center", dy=-8, dx=-60, fontSize=12).encode(text=alt.Text("queries:Q", format=","))
percentage = base.mark_text(align="center", dy=-8, dx=60, fontSize=12).encode(text=alt.Text("perc_queries:Q", format=".2%"))
chart = (bars + absolute + percentage)
chart.configure_axis(
    labelFontSize=12,
    titleFontSize=14,
).configure_title(
    fontSize=16,
)

## Document funnel

In [8]:
query_doc_funnel_df = pd.DataFrame([
    {"stage": "1. Test query-doc pairs", "documents": len(annotation_df)},
    {"stage": "2. Unique query-doc pairs", "documents": len(annotation_df.drop_duplicates(["query", "title"]))},
    {"stage": "3. Title in train set", "documents": len(match_df.drop_duplicates(["query", "title"]))},
    {"stage": "4. At least three titles in session", "documents": len(session_df.drop_duplicates(["query", "title"]))},
    {"stage": "5. Title/abstract in train set", "documents": len(match(annotation_df, click_df, on=["query", "title", "abstract"], unique=True))},
])

query_doc_funnel_df["perc_documents"] = query_doc_funnel_df["documents"] / query_doc_funnel_df["documents"].max()

In [9]:
base = alt.Chart(
    query_doc_funnel_df,
    title="Funnel of (Unique) Expert-Annotated Query-Document pairs",
    width=800
).encode(
    x=alt.X("stage:N", title="").axis(labelAngle=0),
    y=alt.Y("documents:Q", title="% of annotated query-document pairs"),
    color=alt.Color("stage:N", legend=None).scale(scheme="darkblue"),
)
bars = base.mark_bar()
absolute = base.mark_text(align="center", dy=-8, dx=-40, fontSize=12).encode(text=alt.Text("documents:Q", format=","))
percentage = base.mark_text(align="center", dy=-8, dx=40, fontSize=12).encode(text=alt.Text("perc_documents:Q", format=".2%"))
chart = (bars + absolute + percentage)
chart.configure_axis(
    labelFontSize=12,
    titleFontSize=14,
).configure_title(
    fontSize=16,
)

## Distribution Shifts after Matching
### Expert Annotations

In [10]:
alt.Chart(
    count_stats(annotation_df, groupby="label", unique=True),
    title="Expert Annotations",
    width=600,
    height=150,
).mark_bar().encode(
    x=alt.X("label:O", title="Relevance label").axis(labelAngle=0),
    y=alt.Y("perc_documents", title="% of (unique) query/title pairs").axis(format="%"),
    color=alt.Color("label:N", legend=None)
) & alt.Chart(
    count_stats(session_df, groupby="label", unique=True),
    title="Expert Annotations with User Sessions (sessions with >= 3 matching docs)",
    width=600,
    height=150,
).mark_bar().encode(
    x=alt.X("label:O", title="Relevance label").axis(labelAngle=0),
    y=alt.Y("perc_documents", title="% of (unique) query/title pairs").axis(format="%"),
    color=alt.Color("label:N", legend=None)
)

### Popularity shift

alt.Chart(
    source,
    title="Expert Annotations",
    width=600,
    height=150,
).mark_bar().encode(
    x=alt.X("frequency_bucket:O", title="Query frequency (high to low)").axis(labelAngle=0),
    y=alt.Y("queries", title="# of queries").scale(domain=(0, 1000)),
    color=alt.Color("frequency_bucket:O", legend=None).scale(scheme="darkblue"),
) & alt.Chart(
    session_source,
    title="Expert Annotations with User Sessions (sessions with >= 3 matching docs)",
    width=600,
    height=150,
).mark_bar().encode(
    x=alt.X("frequency_bucket:O", title="Query frequency (high to low)").axis(labelAngle=0),
    y=alt.Y("queries", title="# of queries").scale(domain=(0, 1000)),
    color=alt.Color("frequency_bucket:O", legend=None).scale(scheme="darkblue"),
)

# Analysis
## CTR per Position and Relevance Label

In [11]:
matching_docs = 3
filter_df = query_df[query_df["documents"] >= matching_docs]
filter_df = match_df.merge(filter_df, on=["query", "session_id"])

source = filter_df.groupby(["query", "title", "position"]).agg(
    ctr=("click", "mean"),
    label=("label", "max"),
).reset_index()

session_source = filter_df.groupby(["position", "label"]).agg(
    queries=("query", "nunique"),
    sessions=("session_id", "nunique"),
).reset_index()

In [12]:
selection = alt.selection_point(fields=["label"], bind="legend")

base = alt.Chart(source, title=f"Avg. CTR per position and expert label (sessions with >= {matching_docs} matching docs)", width=800).encode(
    x=alt.X("position:O", title="Position"),
    color=alt.Color("label:N", title="Relevance"),
).add_params(selection)

mean = base.mark_line(point=True).encode(
    y=alt.Y("mean(ctr)", title="CTR"),
    opacity=alt.condition(selection, alt.value(1), alt.value(0.2)),
    strokeWidth=alt.value(2),
    tooltip=["position", "label", "mean(ctr)"],
)

ci = base.mark_errorband(extent="ci").encode(
    y=alt.Y("ctr", title="CTR"),
    opacity=alt.condition(selection, alt.value(0.2), alt.value(0.05)),
)

queries = alt.Chart(session_source, title="Queries", width=800, height=200).mark_bar().encode(
    x=alt.X("position:O", title="Position"),
    xOffset="label:O",
    y=alt.Y("queries", title="# of queries (log scale)").scale(type="log"),
    color=alt.Color("label:N", title="Relevance"),
    opacity=alt.condition(selection, alt.value(1), alt.value(0.2)),
    tooltip=["position", "label", "queries"]
)

chart = (ci + mean) & queries
chart

In [103]:
source = session_df.groupby(["query", "title", "position"]).agg(
    impressions=("click", "count"),
    ctr=("click", "mean"),
    label=("label", "max"),
).reset_index()

In [133]:
def plot_mean_relevance(source, min_label=0, min_impressions=10):
    source = source[(source.label >= min_label) & (source.impressions >= min_impressions)].copy()
    
    base = alt.Chart(
        source,
        width=600,
        height=200,
        title=f"Avg. expert label over CTRs (impressions >= {min_impressions}, relevance >= {min_label})"
    ).mark_line(point=True).encode(
        x=alt.X("ctr:Q", title="Mean CTR (binned)").bin(step=0.1),
    )
    
    mean = base.mark_line(point=True).encode(
        y=alt.Y("mean(label):Q", title="Mean relevance").scale(domain=(1, 4.0)),
    )
    
    ci = base.mark_errorband(extent="ci", opacity=0.2).encode(
        y=alt.Y("label:Q", title="Mean relevance").scale(domain=(1, 4.0)),
    )
    
    impressions = alt.Chart(
        source,
        width=600,
        height=200,
        title="Queries",
    ).mark_bar().encode(
        x=alt.X("ctr:Q", title="Mean CTR (binned)").bin(step=0.1),
        y=alt.Y("distinct(query):Q", title="# of queries").scale(domain=(0, 800)),
    ).properties(height=150)
    
    return (ci + mean) & impressions

In [139]:
plot_mean_relevance(source, min_label=1, min_impressions=10)