## Bayesian Expansion Algorithm

In [1]:
import functools
import logging
import json
import multiprocessing as mp
import os
from typing import Dict, List

import numpy as np
from tqdm.notebook import tqdm

import helpers
from mssqldb import MSSQLDatabase

This runs on a SqlServer database on SciServer.

If you want to use something different, roll your own database class and have
it implement the same methods as the MSSQLDatabase class.

In [2]:
# Build the database handle used by every cell below from the JSON file at
# this path (presumably connection settings -- confirm against
# MSSQLDatabase.from_file).
db = MSSQLDatabase.from_file(
    "/home/idies/workspace/Storage/ryanhausen/persistent/tip/tip.json"
)

### Database setup

First, build the database tables that will store the final and intermediate results.

Setting `danger=True` will drop the results tables, so only do that if you want to start from scratch.

In [3]:
# Render the DDL template and run it. Per the notebook text, `danger=True`
# drops the results tables, so it stays False unless starting from scratch.
db.execute_update(helpers.get_sql("sql/DDL.pysql", danger=False))

### Run the algorithm

1. Run the sql to fill the tables with the authors to run the algorithm on
2. Build the sparse matrix
3. Run the algorithm
4. Save the results back to the database

In [3]:
#https://stackoverflow.com/a/11233293/2691018
def setup_logger(name:str, log_file:str, level=logging.INFO):
    """Return the logger called ``name``, configured to write to ``log_file``.

    Safe to call repeatedly with the same arguments: a file handler is only
    attached if the logger does not already have one for ``log_file``.

    Args:
        name: Name handed to ``logging.getLogger``.
        log_file: Path of the file the logger should append to.
        level: Level set on the logger (defaults to ``logging.INFO``).

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # `logging.getLogger` returns the same object for the same name, so
    # blindly adding a handler on every call (e.g. when a notebook cell is
    # re-run) would write each message several times and leak file handles.
    # `FileHandler.baseFilename` stores the absolute path, so compare on that.
    target_path = os.path.abspath(os.fspath(log_file))
    already_attached = any(
        isinstance(existing, logging.FileHandler)
        and existing.baseFilename == target_path
        for existing in logger.handlers
    )

    if not already_attached:
        handler = logging.FileHandler(log_file)
        handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
        logger.addHandler(handler)

    return logger

def update_year(
    year:int,
    db:MSSQLDatabase,
    scopus_version:int,
    run_id:int,
    threshold:float,
    prior_alpha:float,
    prior_beta:float,
    n_workers:int=os.cpu_count(),
) -> None:
    """Run the expansion for a single year and store the scores in the database.

    Steps, each timed and written to ``./logs/{run_id}_{year}.log``:

    1. Run ``sql/update_year.pysql`` to build the eid/auid tables.
    2. Fetch the authors and order them seeds first, then by auid.
    3. Fetch the authors per eid and build the adjacency matrix, in this
       process when ``n_workers == 1`` and in a process pool otherwise.
    4. Compute the final scores with ``helpers.compute``.
    5. Dump the records to a binary file, then run the bulk-insert and merge
       SQL scripts.

    Args:
        year: Year to process; passed to the SQL templates and used in the
            log and bulk file names.
        db: Database object used for every query and update.
        scopus_version: Passed to the SQL templates.
        run_id: Identifier of the run; passed to the SQL templates and used
            in the log and bulk file names.
        threshold: Passed to ``sql/update_year.pysql``.
        prior_alpha: Passed to ``helpers.compute``.
        prior_beta: Passed to ``helpers.compute``.
        n_workers: Number of processes used to build the adjacency matrix.

    Raises:
        ValueError: If ``n_workers`` is not an integer >= 1.
    """

    # FileHandler does not create missing directories, so make sure the log
    # directory exists before the logger opens its file.
    os.makedirs("./logs", exist_ok=True)

    logger = setup_logger(
        f"bayesian_expansion_{year}",
        f"./logs/{run_id}_{year}.log",
    )

    logger.info(f"run_id={run_id}")

    # Check the type before comparing: `os.cpu_count()` (the default) may be
    # None, and `None < 1` would raise TypeError instead of the ValueError
    # below. bool is excluded explicitly because it is a subclass of int.
    n_workers_ok = (
        isinstance(n_workers, int)
        and not isinstance(n_workers, bool)
        and n_workers >= 1
    )
    if not n_workers_ok:
        logger.fatal(f"n_workers needs to be an integer >=1, actual {n_workers}")
        raise ValueError("n_workers needs to be an integer >=1")

    with helpers.LogTime(logger, "-----Start Run-----"):
        with helpers.LogTime(logger, f"building eid/auid tables for {year}"):
            db.execute_update(helpers.get_sql(
                "sql/update_year.pysql",
                year=year,
                threshold=threshold,
                run_id=run_id,
                scopus_version=scopus_version
            ))

        # Has columns:
        # auid:          int
        # is_seed:       int {0, 1}
        # initial_score: float
        with helpers.LogTime(logger, "getting authors"):
            input_auids = db.execute_query(helpers.get_sql(
                "sql/get_authors.pysql",
                year=year,
                run_id=run_id,
            )).astype(dict(
                auid = np.int32,
                is_seed = np.uint8,
                initial_score = np.float32,
            ))

        # We want to sort by:
        # is_seed: so that there are contiguous areas of labeled and unlabeled data
        input_auids.sort_values(
            by=["is_seed","auid"],
            ascending=[False, True],
            ignore_index=True,
            inplace=True
        )

        # We zip the auids with the index values to match every auid with an integer
        # that will be its row in the adjacency matrix
        auids_idxs = zip(
            input_auids["auid"].values,
            input_auids.index.values.astype(np.int32),
        )


        auid_idx_map = {auid:idx for auid, idx in auids_idxs}
        transform_f = functools.partial(helpers.transform_row_to_idx_pairs, auid_idx_map)


        # [n,]
        score_vector = input_auids["initial_score"].values

        # this is the index that separates labeled and unlabeled data
        # NOTE(review): np.argmin returns the first position of the smallest
        # initial score. That matches the seed/non-seed boundary only if every
        # seed scores above every non-seed and the first non-seed holds the
        # minimum -- confirm against sql/get_authors.pysql.
        labeled_idx = np.argmin(score_vector)
        logger.info(f"the split index is {labeled_idx}")


        # Has columns
        # eid:   int
        # auids: str (comma separated auids per eid)
        with helpers.LogTime(logger, "getting eids with auids"):
            input_eid_auids = db.execute_query(helpers.get_sql(
                "sql/get_authors_per_eid.pysql",
                year=year,
                run_id=run_id,
                scopus_version=scopus_version,
            ))

        # array of strings containing comma separated auids per eid
        csv_auids_per_eid = input_eid_auids["auids"].values.flatten()

        # the frame is no longer needed once the auid strings are extracted
        del input_eid_auids

        with helpers.LogTime(logger, "building adjacency matrix"):
            if n_workers==1:
                logger.info("running in serial")
                adjacency_matrix = helpers.arr_to_matrix(
                    transform_f,
                    len(score_vector),
                    labeled_idx,
                    csv_auids_per_eid,
                )
            else:
                logger.info(f"running in parallel with {n_workers} workers")
                map_f = functools.partial(
                    helpers.arr_to_matrix,
                    transform_f,
                    len(score_vector),
                    labeled_idx,
                )
                # each worker builds a matrix from its own chunk of eids
                with mp.Pool(n_workers) as p:
                    matrices = p.map(
                        map_f,
                        np.array_split(csv_auids_per_eid, n_workers)
                    )

                # the per-chunk matrices are added together into one matrix
                adjacency_matrix = sum(matrices)


        with helpers.LogTime(logger, "Computing new labels"):
            score_vector = helpers.compute(
                score_vector,
                adjacency_matrix,
                labeled_idx,
                prior_alpha,
                prior_beta,
                max_iter=100,
                logger=logger,
            )

        input_auids["final_score"] = score_vector

        with helpers.LogTime(logger, "Deleting adj matrix/score vector"):
            del adjacency_matrix
            del score_vector

        # push the new scores back to the db
        with helpers.LogTime(logger, "Converting to records"):
            # the sql dtypes are:
            # auid          INT
            # is_seed       TINYINT
            # initial_score FLOAT(24)
            # final_score   FLOAT(24)
            dtypes = dict(
                auid="<i4",
                is_seed="<u1",
                initial_score="<f4",
                final_score="<f4",
            )


            # write the raw little-endian records to a file
            # (presumably the file read by sql/update_bulk_insert.pysql -- confirm)
            input_auids.astype(dtypes).to_records(
                index=False,
                column_dtypes=dtypes
            ).tofile(
                f"/home/idies/workspace/showusthedata/tip/ryan/bulk_tmp_bayes_{run_id}_{year}.bin"
            )

        with helpers.LogTime(logger, "Bulk inserting records"):

            db.execute_update(helpers.get_sql("sql/tmp_update_DDL.pysql"))
            db.execute_update(helpers.get_sql(
                "sql/update_bulk_insert.pysql",
                year=year,
                run_id=run_id,
            ))

        with helpers.LogTime(logger, "Merging updates in database"):
            db.execute_update(helpers.get_sql(
                "sql/merge_results.pysql",
                year=year,
                run_id=run_id,
            ))


def get_new_run_id(metadata:dict) -> int:
    """
    Create a new run id and return it.

    Executes "sql/generate_new_run_id.pysql" with ``metadata`` serialised to a
    JSON string (presumably this inserts the run record -- confirm against the
    SQL), then runs "sql/get_new_run_id.pysql" and returns the value in the
    first row and first column of its result.

    Both statements go through the module-level ``db`` object.

    Args:
        metadata (dict): Settings to store with the run. It is passed through
            ``json.dumps``, so it must be JSON-serialisable.

    Returns:
        int: The new run id, converted to a built-in ``int``.

    Raises:
        IndexError: If the id query returns no rows or no columns.
    """
    db.execute_update(helpers.get_sql(
        "sql/generate_new_run_id.pysql",
        metadata=json.dumps(metadata),
    ))

    new_id = db.execute_query(helpers.get_sql(
        "sql/get_new_run_id.pysql",
    ))

    # `iloc` typically yields a NumPy scalar; convert it so the value matches
    # the annotated return type and always formats as a plain integer.
    return int(new_id.iloc[0, 0])



In [None]:
# Settings for this run; the same values are recorded as the run's metadata.
threshold = 0.75
prior_alpha = 1
prior_beta = 1
n_workers = 4
scopus_version = 4

run_id = get_new_run_id({
    "threshold": threshold,
    "prior_alpha": prior_alpha,
    "prior_beta": prior_beta,
    "n_workers": n_workers,
    "scopus_version": scopus_version,
})

# Process each year from 2010 through 2022 under the same run id.
for year in tqdm(range(2010, 2023)):
    update_year(
        year=year,
        db=db,
        scopus_version=scopus_version,
        run_id=run_id,
        threshold=threshold,
        prior_alpha=prior_alpha,
        prior_beta=prior_beta,
        n_workers=n_workers,
    )

  0%|          | 0/13 [00:00<?, ?it/s]