In [1]:
# Reload imported modules (e.g. the project's `src` package) before each cell
# runs, so edits to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
from IPython.core.interactiveshell import InteractiveShell
# NOTE(review): IFrame and display are not used in the visible cells.
from IPython.display import IFrame, display
# Show the value of every bare expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = "all"

# Probabilities to cluster algorithm

A notebook to hash out this algorithm and check that it works.

It will hopefully turn into a unit test too, hence the test CSVs are committed to version control.

In [2]:
from src import locations as loc
from src.data import utils as du

import pandas as pd
import duckdb
from pathlib import Path

Test cases (fixture directories under `test/`):

* unambig_t2_e4
* unambig_t3_e2
* masked_t3_e3
* val_masked_t3_e2
* val_unambig_t3_e2

## Helper functions

In [27]:
def validate_against_answer(my_cluster, validated_cluster, n_type = 'par'):
    """Check that a resolved clustering exactly matches the expected one.

    Both inputs are projected to (cluster, id, source, n), with the n column
    cast to int. Each is sorted by cluster, source, id and then its n column.
    The two resulting DataFrames are compared with ``DataFrame.equals``, so
    cluster ids, values, column dtypes and row order must all agree.

    Args:
        my_cluster: Clusters to check; needs columns cluster, id, source and
            n (e.g. the output of ``resolve_clusters``).
        validated_cluster: Expected clusters; needs columns cluster, id,
            source and ``n_<n_type>``.
        n_type: Suffix selecting which ``n_*`` column of
            ``validated_cluster`` is compared against. Defaults to 'par'.

    Returns:
        True if the two normalised tables are identical, otherwise False.
    """
    def _normalise(clusters_df, n_col):
        # DuckDB resolves `clusters_df` in the SQL by replacement scan,
        # i.e. by looking the name up among this function's local variables.
        query = f"""
            select
                cluster,
                id,
                source,
                {n_col}::int as n
            from
                clusters_df
            order by
                cluster,
                source,
                id,
                {n_col}
        """
        return duckdb.sql(query).df()

    mine = _normalise(my_cluster, "n")
    expected = _normalise(validated_cluster, f"n_{n_type}")
    return mine.equals(expected)

## Formalise the algorithm

In [16]:
def resolve_clusters(prob, val):
    """Resolve link probabilities into clusters, seeded with validated links.

    Works entirely through the module-level ``duckdb.sql`` connection.
    ``prob`` and ``val`` are referenced by name inside the SQL (DuckDB
    replacement scan), so they are presumably pandas DataFrames. TODO:
    confirm against ``du.load_test_data``.

    Args:
        prob: Candidate links. Columns read: uuid, link_type, cluster, id,
            source, probability. Each row with ``cluster = 0`` starts a new
            cluster. Rows with ``cluster != 0`` and ``probability >= 0.7``
            are candidate memberships of that cluster.
        val: Validated links. Columns read: cluster, id, source. Only rows
            whose source also appears in ``prob`` are used.

    Returns:
        pandas.DataFrame with columns uuid, cluster, id, source and n.
        n is 0 for rows that seeded a new cluster. It is 1 for validated
        rows and for rows added from ``prob``.

    Side effects:
        Drops and recreates the DuckDB sequences ``uuid`` and ``cluster``
        and the temp tables ``probabilities_temp`` and ``clusters_temp``.
    """
    # Initialise clusters and insert validated links.
    # Each prob row with cluster = 0 gets a fresh uuid and a fresh cluster id
    # (n = 0). Validated links keep their own cluster id (n = 1).
    # The "where" on val prevents data leaking when we do this in steps:
    # we only resolve against the sources present in prob.
    # NOTE(review): fresh cluster ids come from a sequence starting at 1, and
    # no ORDER BY fixes which prob row gets which id. The algorithm appears to
    # rely on these ids matching the cluster ids already used by prob rows
    # with cluster != 0, and on them not colliding with val's cluster ids.
    # Confirm.
    # NOTE(review): `union` de-duplicates, but every row carries a distinct
    # uuid, so it should behave like `union all` here.
    clus_init = duckdb.sql("""
        drop sequence if exists uuid;
        drop sequence if exists cluster;
        create sequence uuid start 1;
        create sequence cluster start 1;
        select
            nextval('uuid') as uuid,
            nextval('cluster') as cluster,
            id,
            source,
            0 as n,
        from
            prob
        where
            cluster = 0
        union
        select
            nextval('uuid') as uuid,
            cluster,
            id,
            source,
            1 as n,
        from
            val
        where 
            source in (
                select
                    source
                from
                    prob
            )
    """)
    # Create a temporary probabilities table so we can delete from it. It
    # keeps only candidate links (cluster != 0) at or above the probability
    # threshold. NOTE(review): 0.7 is a hard-coded threshold.
    # Create a temporary clusters table so DuckDB can insert into it; this
    # wouldn't be needed in a database. Presumably this CREATE is where the
    # lazy `clus_init` query (and its nextval calls) actually runs.
    duckdb.sql("""
        drop table if exists probabilities_temp;
        drop table if exists clusters_temp;
        
        create temp table probabilities_temp as
            select
                uuid,
                link_type,
                cluster,
                id,
                source,
                probability
            from
                prob prob
            where 
                prob.probability >= 0.7
                and cluster != 0
            order by
                probability desc;
        
        create temp table clusters_temp as
            select
                uuid,
                cluster,
                id,
                source,
                n,
            from
                clus_init;
    """)
    # Find what we need to insert by comparing clusters_temp and
    # probabilities_temp, insert it into clusters_temp, delete it from
    # probabilities_temp, and keep going until there's nothing to find.
    data_to_insert = True
    while data_to_insert:
        # Candidate rows are links where either the (id, source) is not yet in
        # any cluster, OR the link's cluster has no member from that source
        # yet. For each (cluster, source), keep the highest-probability link;
        # then keep one link per (id, source).
        # NOTE(review): the outer DISTINCT ON has no ORDER BY of its own, so
        # which cluster an entity lands in when several compete may not be
        # deterministic. Confirm.
        # NOTE(review): `to_insert` is a lazy DuckDB relation, not a result
        # set. It appears to be re-executed each time it is consumed (the
        # .df() below, the INSERT and the DELETE). Consequences:
        #   - nextval('uuid') is drawn more than once per pass.
        #   - The DELETE sees clusters_temp *after* the INSERT, so it removes
        #     the candidates the next pass would otherwise pick, not the rows
        #     just inserted.
        # Combined with the OR condition above, this seems to be what stops an
        # already-placed entity from being added to a second cluster. Do not
        # materialise `to_insert` without re-checking that.
        to_insert = duckdb.sql("""
            select
                distinct on (agg.id, agg.source)
                nextval('uuid') as uuid,
                cluster,
                id,
                source,
                1 as n,
            from (
                select
                    distinct on (prob.cluster, prob.source)
                    prob.*
                from
                    probabilities_temp prob
                where 
                    not exists (
                        select
                            id,
                            source
                        from
                            clusters_temp clus
                        where
                            clus.id = prob.id
                            and clus.source = prob.source
                    )
                    or not exists (
                        select
                            cluster,
                            source
                        from
                            clusters_temp clus
                        where
                            clus.cluster = prob.cluster
                            and clus.source = prob.source
                    )
                order by
                    probability desc
            ) agg;
        """)
        
        # Stop once no candidate rows remain (the `break` alone would suffice).
        if len(to_insert.df().index) == 0:
            data_to_insert = False
            break
        
        # Add the candidate rows to the clusters.
        duckdb.sql("""
            insert into clusters_temp 
            select
                uuid,
                cluster,
                id,
                source,
                n,
            from
                to_insert;
        """)

        # Remove links that match `to_insert` on (id, cluster, source) from the
        # working probabilities (see the NOTE above on re-evaluation).
        duckdb.sql("""
            delete from probabilities_temp prob_temp
            where exists (
                select 
                    cl.cluster,
                    cl.id,
                    cl.source
                from 
                    to_insert cl
                where
                    cl.id = prob_temp.id
                    and cl.cluster = prob_temp.cluster
                    and cl.source = prob_temp.source
            );
        """)

    # Everything accumulated in clusters_temp: seeds, validated links and
    # resolved links.
    result = duckdb.sql("""
        select
            uuid,
            cluster,
            id,
            source,
            n,
        from
            clusters_temp;
    """)

    return result.df()

## Testing

In [30]:
# Names of the fixture directories (under <PROJECT_DIR>/test) to run.
tests = """
    unambig_t2_e4
    unambig_t3_e2
    masked_t3_e3
    val_masked_t3_e2
    val_unambig_t3_e2
""".split()

In [31]:
# Run each fixture through resolve_clusters and report whether the result
# matches the expected clusters (compared on the 'par' n column).
for test in tests:
    fixture_dir = Path(loc.PROJECT_DIR) / "test" / test
    prob, clus, val = du.load_test_data(fixture_dir)
    my_answer = resolve_clusters(prob, val)
    passed = validate_against_answer(my_answer, clus, n_type="par")
    print(f"{test} passed: {passed}")

unambig_t2_e4 passed: True
unambig_t3_e2 passed: True
masked_t3_e3 passed: True
val_masked_t3_e2 passed: True
val_unambig_t3_e2 passed: True
