In [1]:
import altair as alt
import panel as pn
import polars as pl

# Load Panel's Vega extension so the Altair charts below render inside Panel layouts.
pn.extension("vega")
# Lift Altair's default row limit so charts can be built from the full swap table.
alt.data_transformers.disable_max_rows()
# Alternatives kept for reference (not active): VegaFusion data transformer / Jupyter renderer.
# alt.data_transformers.enable("vegafusion")
# alt.renderers.enable("jupyter")

DataTransformerRegistry.enable('default')

In [2]:
def parse_list_strs(col: str) -> pl.Expr:
    """Parse a stringified list column such as ``"[a b c]"`` into a list of strings.

    Any leading/trailing ``"``, ``[`` and ``]`` characters are stripped and the
    remainder is split on single spaces, so ``"[a b c]"`` becomes ``["a", "b", "c"]``.
    The result is NOT exploded here; callers explode the list columns themselves.

    Args:
        col: name of the string column to parse.

    Returns:
        A list-of-strings expression.
    """
    # NOTE(review): an empty list "[]" becomes [""] and runs of spaces produce
    # empty elements — confirm the export never emits either.
    return pl.col(col).str.strip_chars("\"[]").str.split(" ")


def parse_list_ints(col: str) -> pl.Expr:
    """Parse a stringified list column such as ``"[1 2 3]"`` into a list of ``UInt64``.

    Uses the same stripping/splitting rules as ``parse_list_strs`` and then casts
    every element; a non-numeric element makes the cast fail rather than become null.

    Args:
        col: name of the string column to parse.

    Returns:
        A list-of-UInt64 expression (not exploded).
    """
    parts = pl.col(col).str.strip_chars("\"[]").str.split(" ")
    return parts.list.eval(pl.element().cast(pl.UInt64))


def parse_list_floats(col: str) -> pl.Expr:
    """Parse a stringified list column such as ``"[0.1 2.5]"`` into a list of ``Float64``.

    Uses the same stripping/splitting rules as ``parse_list_strs`` and then casts
    every element; a non-numeric element makes the cast fail rather than become null.

    Args:
        col: name of the string column to parse.

    Returns:
        A list-of-Float64 expression (not exploded).
    """
    parts = pl.col(col).str.strip_chars("\"[]").str.split(" ")
    return parts.list.eval(pl.element().cast(pl.Float64))

In [4]:
# Raw per-day swap statistics; list-valued columns are parsed in the next cell.
df_raw = pl.read_csv(
    source="swap_impact_stats_by_protocol.csv",
)

# Pool metadata with typed columns.
df_pools = pl.read_csv(
    source="pools.csv",
).with_columns(
    pl.col("block_ts").str.to_datetime(format="%Y-%m-%d %H:%M:%S.%3f %Z", time_unit="ms"),
    # Lower-case every address column so joins on addresses are case-insensitive.
    pl.col(
        "tx_from",
        "factory_address",
        "pool_address",
        "token0_address",
        "token1_address",
    ).str.to_lowercase(),
    pl.col(
        "block_n",
        "log_index",
        "token0_decimals",
        "token1_decimals",
        "pool_id",
        "bin_step",
        "fee",
        "tick_spacing",
    ).cast(pl.UInt64),
    # Cast to string first so the mapping works whether the reader inferred a
    # bool or a string column; replace_strict rejects any unmapped value.
    pl.col("stable").cast(pl.String).replace_strict({"false": False, "true": True}),
)

# Token metadata, keyed by lower-cased address.
df_tokens = pl.read_csv(
    source="tokens.csv",
).with_columns(
    pl.col("token_address").str.to_lowercase(),
    pl.col("token_decimals").cast(pl.UInt64),
)

# Alternative (currently unused): derive df_tokens from the pool metadata
# instead of tokens.csv. Kept as a comment rather than a bare string literal,
# which would be echoed as this cell's output.
# df_tokens = pl.concat(
#     [
#         df_pools.select(
#             pl.col("token0_address").alias("token_address"),
#             pl.col("token0_symbol").alias("token_symbol"),
#         ),
#         df_pools.select(
#             pl.col("token1_address").alias("token_address"),
#             pl.col("token1_symbol").alias("token_symbol"),
#         ),
#     ]
# ).unique()

In [7]:
# One row per (date, protocol, token pair): parse the stringified list columns,
# explode them in lock-step, and lower-case the addresses so they match the
# lower-cased `token_address` in df_tokens.
df = (
    df_raw.select(
        pl.col("date").str.to_datetime(format="%Y-%m-%d %H:%M:%S.%3f %Z", time_unit="ms").cast(pl.Date).alias("date"),
        parse_list_strs("protocols").alias("protocol"),
        parse_list_strs("token_sold_addresses").alias("token_sold_address"),
        parse_list_strs("token_bought_addresses").alias("token_bought_address"),
        parse_list_ints("swap_counts").alias("swap_count"),
        parse_list_floats("token_sold_amounts").alias("token_sold_amount"),
        parse_list_floats("token_bought_amounts").alias("token_bought_amount"),
        parse_list_floats("token_sold_amount_weighted_mean_slippages").alias("token_sold_amount_weighted_mean_slippage"),
        parse_list_floats("median_slippages").alias("median_slippage"),
        parse_list_floats("token_sold_amount_weighted_mean_market_impacts").alias("token_sold_amount_weighted_mean_market_impact"),
        parse_list_floats("median_market_impacts").alias("median_market_impact"),
    )
    .explode(columns=pl.exclude("date"))
    .with_columns(
        pl.col("token_sold_address").str.to_lowercase(),
        pl.col("token_bought_address").str.to_lowercase(),
    )
)

# Attach token symbols for both sides of the swap. Left joins keep swaps whose
# token is missing from tokens.csv (their symbol is null). `*df.columns` is
# evaluated before reassignment, i.e. it lists the pre-join columns.
df = df.join(
    other=df_tokens.select(pl.all().name.map(lambda x: x + "_sold")),
    left_on="token_sold_address",
    right_on="token_address_sold",
    how="left",
).join(
    other=df_tokens.select(pl.all().name.map(lambda x: x + "_bought")),
    left_on="token_bought_address",
    right_on="token_address_bought",
    how="left",
).select(
    *df.columns,
    pl.col("token_symbol_sold").alias("token_sold_symbol"),
    pl.col("token_symbol_bought").alias("token_bought_symbol"),
)

# Every calendar day between the first and last observed date (inclusive).
df_dates = pl.date_range(
    start=df.get_column("date").min(),
    end=df.get_column("date").max(),
    interval="1d",
    closed="both",
    eager=True,
).alias("date").to_frame()

# Distinct (protocol, token pair) series present in the data.
df_unique = df.select(
    "protocol",
    "token_sold_address",
    "token_bought_address",
).unique()

# Dense grid: every series on every day.
df_points = df_dates.join(
    other=df_unique,
    how="cross",
)

# Dense grid joined back to the observations. `with_columns` keeps the key
# columns (a plain `select` dropped them); days without swaps get zero
# counts/amounts, while the slippage/impact metrics and symbols stay null.
df_extended = df_points.join(
    other=df,
    on=[
        "date",
        "protocol",
        "token_sold_address",
        "token_bought_address",
    ],
    how="left",
).with_columns(
    pl.col("swap_count").fill_null(0),
    pl.col("token_sold_amount").fill_null(0),
    pl.col("token_bought_amount").fill_null(0),
)

In [9]:
# Columns driving the scatter plot.
x_column = "token_sold_amount"
# Metrics selectable for the y axis.
y_columns = [
    "token_sold_amount_weighted_mean_slippage",
    "median_slippage",
    "token_sold_amount_weighted_mean_market_impact",
    "median_market_impact",
]
group_column = "protocol"  # colour / legend grouping
filter1_column = "token_sold_symbol"
filter2_column = "token_bought_symbol"
date_column = "date"

# Human-readable labels for axis titles, legends and tooltips.
# (A duplicate "token_sold_amount" entry was removed; the value was identical.)
column_mapping = {
    "token_sold_amount": "sold amount",
    "token_sold_amount_weighted_mean_slippage": "mean slippage",
    "median_slippage": "median slippage",
    "token_sold_amount_weighted_mean_market_impact": "mean market impact",
    "median_market_impact": "median market impact",
    "protocol": "protocol",
    "version": "version",
    "token_sold_symbol": "token sold",
    "token_bought_symbol": "token bought",
    "token_bought_amount": "bought amount",
    "date": "date",
    "mark": "cutoff",
    "swap_count": "swap_count",
}

def _selection_label(selection):
    """Lower-cased symbol for the chart title, or ``"any"`` when no symbol is selected."""
    return selection.lower() if selection else "any"


def create_scatter_plot(
    y_column_selection=y_columns[0],
    x_axis_log_range_selection=None,
    y_axis_linear_range_selection=None,
    token_sold_symbol_selection=None,
    token_bought_symbol_selection=None,
    start_date_selection=None,
    end_date_selection=None,
    cutoff_date_selection=None,
):
    """Build the Altair scatter of ``x_column`` against the chosen metric.

    Rows of the module-level ``df`` are filtered by the token-pair and date
    selections, tagged ``"pre"``/``"post"`` relative to the cutoff date, and
    drawn one point per row, coloured by ``group_column``. Clicking a legend
    entry fades the other groups.

    Args:
        y_column_selection: column for the y axis; must be a key of ``column_mapping``.
        x_axis_log_range_selection: ``(lo, hi)`` base-10 exponents; the x domain
            becomes ``[10**lo, 10**hi]``. ``None`` lets Altair choose the domain.
        y_axis_linear_range_selection: ``(lo, hi)`` y domain, or ``None`` for automatic.
        token_sold_symbol_selection: keep only rows sold in this token; falsy means any.
        token_bought_symbol_selection: keep only rows bought in this token; falsy means any.
        start_date_selection: inclusive lower date bound; falsy means unbounded.
        end_date_selection: inclusive upper date bound; falsy means unbounded.
        cutoff_date_selection: rows strictly before this date are marked ``"pre"``,
            the rest ``"post"``; ``None`` marks every row ``"post"``.

    Returns:
        alt.Chart: the configured scatter chart.
    """
    # Start with an always-true predicate so the list is never empty.
    filters = [pl.lit(True)]
    if token_sold_symbol_selection:
        filters.append(pl.col(filter1_column).eq(token_sold_symbol_selection))
    if token_bought_symbol_selection:
        filters.append(pl.col(filter2_column).eq(token_bought_symbol_selection))
    if start_date_selection:
        filters.append(pl.col(date_column).ge(start_date_selection))
    if end_date_selection:
        filters.append(pl.col(date_column).le(end_date_selection))

    # A null cutoff compares as null, which `when` treats as false; spell out
    # that "everything is post" outcome instead of relying on it.
    if cutoff_date_selection is None:
        mark = pl.lit("post")
    else:
        mark = (
            pl.when(pl.col(date_column) < cutoff_date_selection)
            .then(pl.lit("pre"))
            .otherwise(pl.lit("post"))
        )

    filtered_df = df.filter(filters).with_columns(mark.alias("mark")).to_pandas()

    # Without a range, let Altair pick the domain instead of failing on None.
    if x_axis_log_range_selection is None:
        x_domain = alt.Undefined
    else:
        x_domain = [pow(10, v) for v in x_axis_log_range_selection]
    if y_axis_linear_range_selection is None:
        y_domain = alt.Undefined
    else:
        y_domain = y_axis_linear_range_selection

    # Built outside the f-string: nesting double quotes inside a double-quoted
    # f-string is a SyntaxError before Python 3.12.
    title = (
        f"{_selection_label(token_sold_symbol_selection)} -> "
        f"{_selection_label(token_bought_symbol_selection)}: "
        f"{column_mapping[x_column]} vs {column_mapping[y_column_selection]} "
        f"grouped by {column_mapping[group_column]}"
    )

    legend_selection = alt.selection_point(fields=[group_column], bind="legend")

    scatter = alt.Chart(filtered_df).mark_point(
        size=60,
        filled=True,
    ).encode(
        x=alt.X(
            x_column,
            scale=alt.Scale(type="log", domain=x_domain, clamp=True),
            title=column_mapping[x_column],
        ),
        y=alt.Y(
            y_column_selection,
            scale=alt.Scale(domain=y_domain, type="linear", clamp=True),
            title=column_mapping[y_column_selection],
        ),
        color=alt.Color(
            group_column,
            scale=alt.Scale(scheme="category20"),
            title=column_mapping[group_column],
        ),
        # Unselected legend groups are faded almost to invisibility.
        opacity=alt.condition(legend_selection, alt.value(1), alt.value(0.001)),
        shape=alt.Shape(
            "mark:N",
            title="cutoff",
            scale=alt.Scale(range=["triangle-right", "triangle-left"]),
        ),
        tooltip=[
            alt.Tooltip("date:T", title=column_mapping["date"], format="%y-%m-%d"),
            alt.Tooltip("mark:N", title=column_mapping["mark"]),
            alt.Tooltip("protocol:N", title=column_mapping["protocol"]),
            alt.Tooltip("token_sold_symbol:N", title=column_mapping["token_sold_symbol"]),
            alt.Tooltip("token_bought_symbol:N", title=column_mapping["token_bought_symbol"]),
            alt.Tooltip("token_sold_amount:Q", title=column_mapping["token_sold_amount"], format=".3f"),
            alt.Tooltip("token_bought_amount:Q", title=column_mapping["token_bought_amount"], format=".3f"),
            alt.Tooltip(y_column_selection + ":Q", title=column_mapping[y_column_selection], format=".3f"),
            alt.Tooltip("swap_count:N", title=column_mapping["swap_count"]),
        ],
    ).add_params(
        legend_selection,
    ).properties(
        width="container",
        height="container",
        title=title,
    )

    return scatter

def _symbol_options(column):
    """Return ``""`` (meaning "any") followed by the distinct non-null values of
    ``df[column]``, sorted case-insensitively with a case-sensitive tie-break.

    Nulls are dropped because the symbol joins are left joins, so unmatched
    tokens have a null symbol and ``None.lower()`` would raise.
    """
    symbols = df[column].drop_nulls().unique().to_list()
    return [""] + sorted(symbols, key=lambda s: (s.lower(), s))


y_column_widget = pn.widgets.Select(name='metric', options=y_columns, value=y_columns[0])
token_sold_symbol_widget = pn.widgets.Select(name='token sold', options=_symbol_options(filter1_column), value="USDC")
token_bought_symbol_widget = pn.widgets.Select(name='token bought', options=_symbol_options(filter2_column), value="WAVAX")
start_date_widget = pn.widgets.DatePicker(name='start date', value=df[date_column].min())
end_date_widget = pn.widgets.DatePicker(name='end date', value=df[date_column].max())
# Defaults to the first date, so initially no row is before the cutoff ("pre").
cutoff_date_widget = pn.widgets.DatePicker(name="cutoff date", value=df[date_column].min())

# Base-10 exponents for the log x axis: -3..9 spans 1e-3 .. 1e9.
x_axis_log_range_widget = pn.widgets.RangeSlider(
    name="x-axis (log) range",
    start=-3,
    end=9,
    value=(-3, 9),
    step=0.1,
    format="0.1f",
)

# Linear y-axis domain for the selected metric.
y_axis_linear_range_widget = pn.widgets.RangeSlider(
    name="y-axis (linear) range",
    start=0,
    end=1000,
    value=(0, 1000),
    step=10,
)

def update_plot(
    y_column_selection,
    token_sold_symbol_selection,
    token_bought_symbol_selection,
    start_date_selection,
    end_date_selection,
    cutoff_date_selection,
    x_axis_log_range_selection,
    y_axis_linear_range_selection,
):
    """Panel callback: forward the current widget values to ``create_scatter_plot``.

    Every argument is passed through unchanged under the same keyword name.

    Returns:
        The chart produced by ``create_scatter_plot``.
    """
    chart_kwargs = {
        "y_column_selection": y_column_selection,
        "x_axis_log_range_selection": x_axis_log_range_selection,
        "y_axis_linear_range_selection": y_axis_linear_range_selection,
        "token_sold_symbol_selection": token_sold_symbol_selection,
        "token_bought_symbol_selection": token_bought_symbol_selection,
        "start_date_selection": start_date_selection,
        "end_date_selection": end_date_selection,
        "cutoff_date_selection": cutoff_date_selection,
    }
    return create_scatter_plot(**chart_kwargs)

# Re-evaluate update_plot whenever a bound widget changes; each keyword wires
# one widget's current value to the update_plot parameter of the same name.
interactive_plot = pn.bind(
    update_plot,
    y_column_selection=y_column_widget,
    x_axis_log_range_selection=x_axis_log_range_widget,
    y_axis_linear_range_selection=y_axis_linear_range_widget,
    token_sold_symbol_selection=token_sold_symbol_widget,
    token_bought_symbol_selection=token_bought_symbol_widget,
    start_date_selection=start_date_widget,
    end_date_selection=end_date_widget,
    cutoff_date_selection=cutoff_date_widget,
)

# Three rows of controls (metric/axes, token pair, dates) above the chart.
# .servable() lets the same layout be served as an app (e.g. `panel serve`).
layout = pn.Column(
    pn.Row(y_column_widget, x_axis_log_range_widget, y_axis_linear_range_widget),
    pn.Row(token_sold_symbol_widget, token_bought_symbol_widget),
    pn.Row(start_date_widget, end_date_widget, cutoff_date_widget),
    interactive_plot
).servable()