In [None]:
import sqlite3

import pandas as pd

# Connect to the local SQLite database whose tables will be copied to Snowflake.
sqlite_conn = sqlite3.connect("default.sqlite")
cursor = sqlite_conn.cursor()

# List the user tables. Table names beginning with "sqlite_" are reserved by
# SQLite for its internal bookkeeping tables (e.g. sqlite_sequence), so they are
# skipped rather than migrated. "_" is a LIKE wildcard, hence the ESCAPE clause
# to match a literal underscore.
# Each row returned is a 1-tuple ``(name,)``.
cursor.execute(
    "SELECT name FROM sqlite_master "
    "WHERE type='table' AND name NOT LIKE 'sqlite\\_%' ESCAPE '\\';"
)
tables = cursor.fetchall()

In [None]:
# Display the list of (name,) tuples fetched from sqlite_master.
tables

In [None]:
import os

import snowflake.connector

# Target schema for the migrated tables; change it here to load elsewhere.
SNOWFLAKE_TARGET_SCHEMA = "CONTEXT_RELEVANCE_TREC_COMBINED_NO_RUBRIC"

# Credentials come from the environment so no secrets live in the notebook.
# Check the required ones up front: os.environ.get would otherwise hand None to
# the connector, which then fails later with a less direct error.
_missing_snowflake_env = [
    name
    for name in ("SNOWFLAKE_ACCOUNT", "SNOWFLAKE_USER", "SNOWFLAKE_USER_PASSWORD")
    if not os.environ.get(name)
]
if _missing_snowflake_env:
    raise RuntimeError(
        "Missing required environment variable(s) for the Snowflake connection: "
        + ", ".join(_missing_snowflake_env)
    )

snowflake_conn = snowflake.connector.connect(
    account=os.environ["SNOWFLAKE_ACCOUNT"],
    user=os.environ["SNOWFLAKE_USER"],
    password=os.environ["SNOWFLAKE_USER_PASSWORD"],
    # Optional settings are passed as None when unset (presumably the connector
    # then uses the user's defaults — confirm).
    database=os.environ.get("SNOWFLAKE_DATABASE"),
    schema=SNOWFLAKE_TARGET_SCHEMA,
    warehouse=os.environ.get("SNOWFLAKE_WAREHOUSE"),
    role=os.environ.get("SNOWFLAKE_ROLE"),
)

In [None]:
def _snowflake_column_type(dtype) -> str:
    """Return the Snowflake column type used for a column of pandas ``dtype``.

    Boolean dtypes map to BOOLEAN, any integer dtype (signed, unsigned or
    nullable) to INTEGER, any float dtype to FLOAT; everything else
    (object/string, datetime, ...) falls back to TEXT.
    """
    # Check bool first so it is never classified by the numeric branches.
    if pd.api.types.is_bool_dtype(dtype):
        return "BOOLEAN"
    if pd.api.types.is_integer_dtype(dtype):
        return "INTEGER"
    if pd.api.types.is_float_dtype(dtype):
        return "FLOAT"
    return "TEXT"


# Create (or replace) an empty Snowflake table for every SQLite table, with
# column types derived from the pandas dtypes of the loaded data.
# NOTE(review): table and column names are interpolated unquoted, so Snowflake
# treats them case-insensitively and names containing spaces or reserved words
# will fail to parse.
with snowflake_conn.cursor() as snowflake_cursor:
    for (table_name,) in tables:
        df = pd.read_sql_query(f"SELECT * FROM {table_name}", sqlite_conn)
        column_defs = ", ".join(
            f"{col_name} {_snowflake_column_type(col_type)}"
            for col_name, col_type in df.dtypes.items()
        )
        snowflake_cursor.execute(
            f"CREATE OR REPLACE TABLE {table_name} ({column_defs});"
        )

In [None]:
import os
import re
import tempfile

# Load each SQLite table into its Snowflake counterpart: dump the rows to CSV,
# PUT the file onto a temporary stage, then COPY INTO the table.
# CSVs are written to a private temporary directory so an existing
# "<table>.csv" in the working directory is never overwritten or deleted.
with tempfile.TemporaryDirectory() as csv_dir:
    with snowflake_conn.cursor() as snowflake_cursor:
        for (table_name,) in tables:
            df = pd.read_sql_query(f"SELECT * FROM {table_name}", sqlite_conn)

            csv_file = os.path.join(csv_dir, f"{table_name}.csv")
            df.to_csv(csv_file, index=False)

            # Build a stage name from the table name by collapsing every run
            # of non-word characters into a single underscore.
            sanitized_stage_name = re.sub(r"\W+", "_", f"temp_stage_{table_name}")
            snowflake_cursor.execute(
                f"CREATE OR REPLACE TEMPORARY STAGE {sanitized_stage_name}"
            )
            try:
                # Absolute, forward-slash path, single-quoted so spaces in the
                # temporary directory path do not break the PUT command.
                put_path = os.path.abspath(csv_file).replace(os.sep, "/")
                snowflake_cursor.execute(
                    f"PUT 'file://{put_path}' @{sanitized_stage_name}"
                )
                snowflake_cursor.execute(f"""
                    COPY INTO {table_name}
                    FROM @{sanitized_stage_name}
                    FILE_FORMAT = (TYPE = 'CSV' FIELD_OPTIONALLY_ENCLOSED_BY = '"' SKIP_HEADER = 1)
                """)
            finally:
                # Clean up the stage and the local CSV even if PUT/COPY failed.
                snowflake_cursor.execute(
                    f"DROP STAGE IF EXISTS {sanitized_stage_name}"
                )
                os.remove(csv_file)

In [None]:
# Release both database connections now that the migration is finished:
# SQLite first, then Snowflake.
for _open_connection in (sqlite_conn, snowflake_conn):
    _open_connection.close()
del _open_connection

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


def _plot_heatmap(
    matrix: pd.DataFrame, title: str, fmt: str, figsize: tuple = (8, 6)
) -> None:
    """Render ``matrix`` as an annotated heatmap with confusion-matrix axis labels.

    Args:
        matrix: Table to draw; rows are ground-truth scores, columns feedback scores.
        title: Figure title.
        fmt: Format string for the cell annotations (e.g. ".0f").
        figsize: Figure size in inches, passed to ``plt.subplots``.
    """
    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(matrix, annot=True, fmt=fmt, cmap="Blues", ax=ax)
    ax.set_title(title)
    ax.set_xlabel("Feedback Score")
    ax.set_ylabel("Ground Truth")
    plt.show()
    # Close explicitly so plotting many app versions doesn't keep every figure open.
    plt.close(fig)


def plot_confusion_matrices(
    csv_file_path: str, title: str, figsize: tuple = (8, 6)
) -> None:
    """Plot a raw and a row-normalized confusion matrix for each app version in a CSV.

    The CSV must contain the columns ``APP_VERSION``, ``RAW_GT_SCORE``,
    ``RAW_FEEDBACK_SCORE`` and ``COUNT``. For each distinct ``APP_VERSION``
    (in order of first appearance) the counts are pivoted into a matrix with
    ground-truth scores as rows and feedback scores as columns. Missing cells
    are 0, and repeated (ground truth, feedback) rows are summed.

    Args:
        csv_file_path: Path to the CSV file to load.
        title: Prefix for each figure title; the app version is appended.
        figsize: Size in inches of each figure.

    Raises:
        KeyError: If any required column is absent from the CSV.
    """
    data = pd.read_csv(csv_file_path)
    print(data.head())

    required_columns = ["APP_VERSION", "RAW_GT_SCORE", "RAW_FEEDBACK_SCORE", "COUNT"]
    missing_columns = [col for col in required_columns if col not in data.columns]
    if missing_columns:
        raise KeyError(
            f"{csv_file_path} is missing required column(s): {missing_columns}"
        )

    for app_version in data["APP_VERSION"].unique():
        app_data = data[data["APP_VERSION"] == app_version]

        # pivot_table (unlike pivot) tolerates repeated (GT, feedback) pairs by
        # summing their counts instead of raising on duplicate index entries.
        confusion_matrix = app_data.pivot_table(
            index="RAW_GT_SCORE",
            columns="RAW_FEEDBACK_SCORE",
            values="COUNT",
            aggfunc="sum",
            fill_value=0,
        )

        # Each row is scaled to sum to 1: the share of each feedback score for a
        # given ground truth. A row whose counts total 0 becomes NaN (0 / 0).
        confusion_matrix_normalized = confusion_matrix.div(
            confusion_matrix.sum(axis=1), axis=0
        )

        _plot_heatmap(confusion_matrix, f"{title}: {app_version}", ".0f", figsize)
        _plot_heatmap(
            confusion_matrix_normalized,
            f"Normalized {title}: {app_version}",
            ".2f",
            figsize,
        )

In [None]:
import os

# Directory holding the exported TREC benchmark CSVs. Defaults to the original
# author's checkout; set TREC_BENCHMARK_DATA_DIR to point at your own copy.
TREC_DATA_DIR = os.environ.get(
    "TREC_BENCHMARK_DATA_DIR",
    "/Users/dhuang/Documents/git/trulens/src/benchmark/trulens/benchmark/benchmark_frameworks/experiments/data",
)
csv_file_with_rubric = os.path.join(TREC_DATA_DIR, "TREC_rubric.csv")
csv_file_no_rubric = os.path.join(TREC_DATA_DIR, "TREC_no_rubric.csv")

In [None]:
# Plot raw and row-normalized confusion matrices, one pair per app version,
# for the CSV at csv_file_no_rubric.
# NOTE(review): csv_file_with_rubric is defined in the previous cell but never
# plotted — TODO add a matching call if the rubric run should be shown too.
plot_confusion_matrices(csv_file_no_rubric, "Original prompt for")