In [1]:
# Reload edited project modules automatically so changes under `src` are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
from IPython.core.interactiveshell import InteractiveShell
from IPython.display import IFrame, display
# Show the value of every expression statement in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = "all"

# Probabilities to cluster algorithm

A notebook to hash out this algorithm and check it works.

Will hopefully turn into a unit test too, hence the CSVs are checked into version control.

In [72]:
from src import locations as loc
from src.data import utils as du

import pandas as pd
import duckdb
from pathlib import Path

Tests:

* unambig_t2_e4
* unambig_t3_e2
* masked_t3_e3
* val_masked_t3_e2
* val_unambig_t3_e2

## Helper functions

In [3]:
def validate_against_answer(my_cluster, validated_cluster, n_type = 'par'):
    """Report whether a computed cluster table matches the validated answer.

    Both arguments are referenced by name inside the SQL below, so DuckDB
    reads them straight from this function's local variables. Each side is
    projected to (cluster, id, source, n) with n cast to int, sorted on
    those columns, and converted to a pandas DataFrame before comparison.

    Args:
        my_cluster: Computed clusters; must expose cluster, id, source, n.
        validated_cluster: Expected clusters; must expose cluster, id,
            source and a column named n_<n_type>.
        n_type: Suffix of the expected step column. This notebook uses
            'par' and 'seq'.

    Returns:
        The result of ``DataFrame.equals`` on the two sorted projections.
    """
    computed_query = """
        select
            cluster,
            id,
            source,
            n::int as n
        from
            my_cluster
        order by
            cluster,
            source,
            id,
            n
    """
    expected_query = f"""
        select
            cluster,
            id,
            source,
            n_{n_type}::int as n
        from
            validated_cluster
        order by
            cluster,
            source,
            id,
            n_{n_type}
    """
    computed_df = duckdb.sql(computed_query).df()
    expected_df = duckdb.sql(expected_query).df()
    return computed_df.equals(expected_df)

## Formalise algorithm

### DuckDB version

In [11]:
def resolve_clusters(prob, val, clus, n, threshold: float = 0.7):
    """Resolve probabilistic links into clusters using in-process DuckDB.

    The working cluster table starts as ``clus`` plus the rows of ``val``
    whose source appears in ``prob``. Links from ``prob`` at or above
    ``threshold`` (and with a non-zero cluster) are then added round by
    round until a round finds nothing to add.

    ``prob``, ``val`` and ``clus`` are referenced by name inside the SQL, so
    DuckDB reads them from this function's local variables. The same is true
    of the locals ``clus_init`` and ``to_insert``; do not rename any of them
    without editing the SQL.

    Args:
        prob: Probabilities; needs uuid, link_type, cluster, id, source and
            probability columns.
        val: Validated links; needs cluster, id and source columns.
        clus: Existing clusters; needs uuid, cluster, id, source and n.
        n: Step number stamped on every row this call adds.
        threshold: Minimum probability for a link to be considered.
            Defaults to 0.7, the value that used to be hard-coded, matching
            the ``threshold`` parameter of ``resolve_clusters_pg``.

    Returns:
        pandas.DataFrame with columns uuid, cluster, id, source, n.

    Note:
        A DuckDB sequence named ``uuid`` must already exist; the callers in
        this notebook create it before initialising ``clus``.
    """
    # Coerce once so a non-numeric threshold fails here, not inside the SQL.
    min_probability = float(threshold)

    # The clusters are initialised outside the function, as in the real repo.
    # The "where" on validation prevents data leaking when this runs in
    # steps: we only resolve against the sources present in prob.
    clus_init = duckdb.sql(f"""
        select
            uuid,
            cluster,
            id,
            source,
            n,
        from
            clus
        union
        select
            nextval('uuid') as uuid,
            cluster,
            id,
            source,
            {n} as n,
        from
            val
        where
            source in (
                select
                    source
                from
                    prob
            )
    """)
    # probabilities_temp: a copy we can delete from as links are consumed.
    # clusters_temp: a real table DuckDB can insert into (would not be
    # needed in a database).
    duckdb.sql(f"""
        drop table if exists probabilities_temp;
        drop table if exists clusters_temp;

        create temp table probabilities_temp as
            select
                uuid,
                link_type,
                cluster,
                id,
                source,
                probability
            from
                prob prob
            where
                prob.probability >= {min_probability}
                and cluster != 0
            order by
                probability desc;

        create temp table clusters_temp as
            select
                uuid,
                cluster,
                id,
                source,
                n,
            from
                clus_init;
    """)
    # Each round: find candidates by comparing probabilities_temp with
    # clusters_temp, insert them into clusters_temp, delete them from
    # probabilities_temp. Stop when a round finds no candidates.
    while True:
        # Candidates are rows whose (id, source) or (cluster, source) is not
        # yet in clusters_temp; keep one per (cluster, source) ordered by
        # probability desc, then one per (id, source).
        # NOTE(review): to_insert is a DuckDB relation object, not a
        # materialised table, and the insert/delete below reference it by
        # name. Confirm whether each reference re-runs this query (and so
        # re-draws nextval('uuid')) before changing the order of statements.
        to_insert = duckdb.sql(f"""
            select
                distinct on (agg.id, agg.source)
                nextval('uuid') as uuid,
                cluster,
                id,
                source,
                {n} as n,
            from (
                select
                    distinct on (prob.cluster, prob.source)
                    prob.*
                from
                    probabilities_temp prob
                where
                    not exists (
                        select
                            id,
                            source
                        from
                            clusters_temp clus
                        where
                            clus.id = prob.id
                            and clus.source = prob.source
                    )
                    or not exists (
                        select
                            cluster,
                            source
                        from
                            clusters_temp clus
                        where
                            clus.cluster = prob.cluster
                            and clus.source = prob.source
                    )
                order by
                    probability desc
            ) agg;
        """)

        if len(to_insert.df().index) == 0:
            break

        duckdb.sql("""
            insert into clusters_temp
            select
                uuid,
                cluster,
                id,
                source,
                n,
            from
                to_insert;
        """)

        duckdb.sql("""
            delete from probabilities_temp prob_temp
            where exists (
                select
                    cl.cluster,
                    cl.id,
                    cl.source
                from
                    to_insert cl
                where
                    cl.id = prob_temp.id
                    and cl.cluster = prob_temp.cluster
                    and cl.source = prob_temp.source
            );
        """)

    result = duckdb.sql("""
        select
            uuid,
            cluster,
            id,
            source,
            n,
        from
            clusters_temp;
    """)

    return result.df()

### Postgres version

In [233]:
def resolve_clusters_pg(prob, val, clus, n, threshold: float = 0.7):
    """Resolve probabilistic links into clusters inside Postgres.

    Unlike the DuckDB variant, everything is read from and written to the
    database: ``prob``, ``val`` and ``clus`` are table names interpolated
    into the SQL, and the rows found are inserted back into ``clus``.

    Args:
        prob: Name of the probabilities table (uuid, link_type, cluster, id,
            source, probability).
        val: Name of the validation table (cluster, id, source are read).
        clus: Name of the clusters table (uuid, cluster, id, source, n);
            new rows are appended to it.
        n: Step number stamped on every row this call adds.
        threshold: Minimum probability for a link to be considered.

    Returns:
        None. The result is the updated ``clus`` table.

    Note:
        Assumes ``du.query_nonreturn`` runs every statement in the same
        database session, so the temporary tables survive between calls.
        TODO confirm against ``src.data.utils``.
    """
    # Names of the three scratch tables used while the algorithm runs.
    clusters_temp = "clusters_temp"
    probabilities_temp = "probabilities_temp"
    to_insert_temp = "to_insert_temp"

    # Work on a temporary copy of the clusters (plus the validated links for
    # the sources present in prob) so the real table is only touched once
    # the algorithm has finished.
    du.query_nonreturn(f"""
        drop table if exists {clusters_temp};
        create temporary table {clusters_temp} as
            select
                uuid,
                cluster,
                id,
                source,
                n
            from
                {clus}
            union
            select
                gen_random_uuid() as uuid,
                cluster,
                id,
                source,
                {n} as n
            from
                {val}
            where 
                source in (
                    select
                        source
                    from
                        {prob}
                );
    """)

    # Temporary copy of the above-threshold probabilities, so consumed
    # links can be deleted as we go.
    du.query_nonreturn(f"""
        drop table if exists {probabilities_temp};
        create temporary table {probabilities_temp} as
            select
                uuid,
                link_type,
                cluster,
                id,
                source,
                probability
            from
                {prob} prob
            where 
                prob.probability >= {threshold}
            order by
                probability desc;
    """)

    # Round by round: work out what to insert from clusters_temp and
    # probabilities_temp, insert it, remove the consumed probabilities,
    # and stop as soon as a round produces nothing.
    while True:
        du.query_nonreturn(f"""
            drop table if exists {to_insert_temp};
            create temporary table {to_insert_temp} as
                select
                	distinct on (id_rank.id, id_rank.source)
                	gen_random_uuid() as uuid,
                	id_rank.cluster,
                	id_rank.id,
                	id_rank.source,
                	{n} as n
                from (
                	select
                		distinct on (clus_rank.cluster, clus_rank.source)
                		clus_rank.*,
                		rank() over (
                			partition by
                				clus_rank.id,
                				clus_rank.source
                			order by 
                				clus_rank.probability desc
                		) as id_rank
                	from (
                		select
                			prob.*,
                			rank() over(
                				partition by 
                					prob.cluster, 
                					prob.source
                				order by 
                					prob.probability desc
                			) as clus_rank
                		from
                			{probabilities_temp} prob
                	) clus_rank
                	where 
                		clus_rank.clus_rank = 1
                		and (
                			not exists (
                				select
                					id,
                					source
                				from
                					{clusters_temp} clus
                				where
                					clus.id = clus_rank.id
                					and clus.source = clus_rank.source
                			)
                			or not exists (
                				select
                					cluster,
                					source
                				from
                					{clusters_temp} clus
                				where
                					clus.cluster = clus_rank.cluster
                					and clus.source = clus_rank.source
                			)
                		)
                	order by
                		clus_rank.cluster, 
                		clus_rank.source
                ) id_rank
                where
                	id_rank.id_rank = 1
                order by
                	id_rank.id, 
                	id_rank.source;
        """)

        # An empty candidate table means the clusters are fully resolved.
        if du.check_table_empty(to_insert_temp):
            break

        du.query_nonreturn(f"""
            insert into {clusters_temp}
            select
                uuid,
                cluster,
                id,
                source,
                n
            from
                {to_insert_temp};
        """)

        # Drop every probability that shares (id, source) or
        # (cluster, source) with something just inserted.
        du.query_nonreturn(f"""
            delete from {probabilities_temp} prob_temp
            where exists (
                select 
                    cl.cluster,
                    cl.id,
                    cl.source
                from 
                    {to_insert_temp} cl
                where
                    (
                        cl.id = prob_temp.id
                        and cl.source = prob_temp.source
                    )
                    or (
                        cl.cluster = prob_temp.cluster
                        and cl.source = prob_temp.source
                    )
            );
        """)

    # Copy back into the real clusters table only the rows whose uuid is
    # not already there.
    du.query_nonreturn(f"""
        insert into {clus}
        select
            uuid,
            cluster,
            id,
            source,
            n
        from
            {clusters_temp} ct
        where not exists (
            select
                uuid,
                cluster,
                id,
                source,
                n
            from
                {clus} c
            where
                c.uuid = ct.uuid
        );
    """)

    # Remove the scratch tables.
    du.query_nonreturn(f"""
        drop table if exists {clusters_temp};
        drop table if exists {probabilities_temp};
        drop table if exists {to_insert_temp};
    """)
    

## Testing

In [234]:
# Names of the test cases to run. Each name is passed to
# du.load_test_data(Path(loc.PROJECT_DIR, "test", <name>)) by the test cells
# below, which unpack the result as (prob, clus, val).
tests = [
    "unambig_t2_e4",
    "unambig_t3_e2",
    "masked_t3_e3",
    "val_masked_t3_e2",
    "val_unambig_t3_e2",
]

### DuckDB version

#### Parallel tests

In [161]:
# Parallel run: all sources are resolved in a single call with n = 1 and the
# result is compared with the answer's n_par column.
for test in tests:
    prob, clus, val = du.load_test_data(Path(loc.PROJECT_DIR, "test", test))
    # Recreate the uuid and cluster sequences so numbering restarts at 1 for
    # every test case, then seed the clusters from the prob rows whose
    # cluster is 0, stamping them with n = 0. `prob` is referenced by name in
    # the SQL, so the variable must keep this name.
    clus_init = duckdb.sql("""
        drop sequence if exists uuid;
        drop sequence if exists cluster;
        create sequence uuid start 1;
        create sequence cluster start 1;
        select
            nextval('uuid') as uuid,
            nextval('cluster') as cluster,
            id,
            source,
            0 as n,
        from
            prob
        where
            cluster = 0
    """)
    my_answer = resolve_clusters(prob, val, clus_init, 1)
    passed = validate_against_answer(my_answer, clus, n_type = 'par')
    print(f"{test} passed: {passed}")

unambig_t2_e4 passed: True
unambig_t3_e2 passed: True
masked_t3_e3 passed: True
val_masked_t3_e2 passed: True
val_unambig_t3_e2 passed: True


#### Sequential tests

In [160]:
# Sequential run: one resolve_clusters call per source, each feeding its
# output into the next, compared with the answer's n_seq column.
for test in tests:
    prob, clus, val = du.load_test_data(Path(loc.PROJECT_DIR, "test", test))
    # Fresh sequences per test case; seed clusters from the prob rows whose
    # cluster is 0. `prob` is referenced by name inside the SQL.
    clus_init = duckdb.sql("""
        drop sequence if exists uuid;
        drop sequence if exists cluster;
        create sequence uuid start 1;
        create sequence cluster start 1;
        select
            nextval('uuid') as uuid,
            nextval('cluster') as cluster,
            id,
            source,
            0 as n,
        from
            prob
        where
            cluster = 0
    """)
    # Key each source's rows by (source - 1) so the step index can look
    # them up directly.
    prob_sequence_dict = {
        source_key - 1: source_rows
        for source_key, source_rows in prob.groupby('source')
    }
    val_sequence_dict = {
        source_key - 1: source_rows
        for source_key, source_rows in val.groupby('source')
    }
    for i in range(len(prob_sequence_dict)):
        prob_n = prob_sequence_dict[i]
        # A step with no validation rows gets an empty frame with val's columns.
        val_n = val_sequence_dict.get(i, val.iloc[0:0])
        clus_init = resolve_clusters(prob_n, val_n, clus_init, i)
    my_answer = clus_init
    passed = validate_against_answer(my_answer, clus, n_type = 'seq')
    print(f"{test} passed: {passed}")

unambig_t2_e4 passed: True
unambig_t3_e2 passed: True
masked_t3_e3 passed: True
val_masked_t3_e2 passed: True
val_unambig_t3_e2 passed: True


### Postgres version

#### Parallel tests

In [236]:
# Workspace schema and scratch tables used by this cell. The schema is
# user-specific; change it here rather than in every statement.
workspace_schema = "_user_eaf4fd9a"
prob_table = f"{workspace_schema}.temp_prob"
val_table = f"{workspace_schema}.temp_val"
clus_table = f"{workspace_schema}.temp_clus"

# Parallel run: all sources are resolved in a single call with n = 1 and the
# result is compared with the answer's n_par column.
for test in tests:
    prob, clus, val = du.load_test_data(Path(loc.PROJECT_DIR, "test", test))

    # Recreate the probabilities table with an explicit schema, then load it.
    du.query_nonreturn(f"""
        drop table if exists {prob_table};
        create table {prob_table} (
            uuid bigint,
            link_type text,
            cluster bigint,
            id text,
            source bigint,
            probability double precision
        )
    """)
    du.data_workspace_write(workspace_schema, "temp_prob", prob, if_exists="append")

    # Same for the validation table.
    du.query_nonreturn(f"""
        drop table if exists {val_table};
        create table {val_table} (
            uuid bigint,
            id text,
            cluster bigint,
            source bigint,
            "user" text,
            match bool
        )
    """)
    du.data_workspace_write(workspace_schema, "temp_val", val, if_exists="append")

    # Seed the clusters from the source = 1 probabilities, one cluster per
    # row, stamped with n = 0.
    du.query_nonreturn(f"""
        drop table if exists {clus_table};
        create table {clus_table} as
            select
                gen_random_uuid() as uuid,
                row_number() over () as cluster,
                init.id,
                init.source,
                0 as n
            from (
                select
                    *
                from
                    {prob_table}
                where
                    source = 1
            ) init
    """)

    resolve_clusters_pg(prob_table, val_table, clus_table, 1, 0.7)

    passed = validate_against_answer(
        du.query(f"select * from {clus_table}"),
        clus,
        n_type = 'par'
    )

    # Leave the workspace clean for the next test case.
    du.query_nonreturn(f"""
        drop table if exists {prob_table};
        drop table if exists {clus_table};
        drop table if exists {val_table};
    """)
    print(f"{test} passed: {passed}")

unambig_t2_e4 passed: True
unambig_t3_e2 passed: True
masked_t3_e3 passed: True
val_masked_t3_e2 passed: True
val_unambig_t3_e2 passed: True


#### Sequential tests

In [231]:
# Workspace schema and scratch tables used by this cell. The schema is
# user-specific; change it here rather than in every statement.
workspace_schema = "_user_eaf4fd9a"
prob_table = f"{workspace_schema}.temp_prob"
val_table = f"{workspace_schema}.temp_val"
clus_table = f"{workspace_schema}.temp_clus"

# Sequential run: one resolve_clusters_pg call per source, each building on
# the clusters table left by the previous one, compared with the answer's
# n_seq column.
for test in tests:
    prob, clus, val = du.load_test_data(Path(loc.PROJECT_DIR, "test", test))
    # Key each source's rows by (source - 1) so the step index can look them
    # up directly. Assumes source values are consecutive integers starting
    # at 1 -- verify against the test data.
    prob_sequence_dict = {i - 1: g for i, g in prob.groupby('source')}
    val_sequence_dict = {i - 1: g for i, g in val.groupby('source')}

    # Initialise clusters -- involves some messy work with the prob table:
    # load the full probabilities once, only to seed the clusters from the
    # source = 1 rows (one cluster per row, n = 0).
    du.query_nonreturn(f"drop table if exists {prob_table};")
    du.data_workspace_write(workspace_schema, "temp_prob", prob, if_exists="append")
    du.query_nonreturn(f"""
        drop table if exists {clus_table};
        create table {clus_table} as
            select
                gen_random_uuid() as uuid,
                row_number() over () as cluster,
                init.id,
                init.source,
                0 as n
            from (
                select
                    *
                from
                    {prob_table}
                where
                    source = 1
            ) init
    """)

    for i in range(len(prob_sequence_dict)):
        # Create probability table at step n
        prob_n = prob_sequence_dict[i]
        du.query_nonreturn(f"""
            drop table if exists {prob_table};
            create table {prob_table} (
                uuid bigint,
                link_type text,
                cluster bigint,
                id text,
                source bigint,
                probability double precision
            )
        """)
        du.data_workspace_write(workspace_schema, "temp_prob", prob_n, if_exists="append")

        # Create validation table at step n; a step with no validation rows
        # gets an empty frame with val's columns.
        val_n = val_sequence_dict.get(i, val.iloc[0:0])
        du.query_nonreturn(f"""
            drop table if exists {val_table};
            create table {val_table} (
                uuid bigint,
                id text,
                cluster bigint,
                source bigint,
                "user" text,
                match bool
            )
        """)
        du.data_workspace_write(workspace_schema, "temp_val", val_n, if_exists="append")

        # Resolve clusters
        resolve_clusters_pg(prob_table, val_table, clus_table, i, 0.7)

    # The answer lives in the clusters table, so read it back from there.
    passed = validate_against_answer(
        du.query(f"select * from {clus_table}"),
        clus,
        n_type = 'seq'
    )

    # Leave the workspace clean, as the parallel test cell does.
    du.query_nonreturn(f"""
        drop table if exists {prob_table};
        drop table if exists {clus_table};
        drop table if exists {val_table};
    """)
    print(f"{test} passed: {passed}")

unambig_t2_e4 passed: True
unambig_t3_e2 passed: True
masked_t3_e3 passed: True
val_masked_t3_e2 passed: True
val_unambig_t3_e2 passed: True
