In [None]:
import dask.array as da
import dask.dataframe as dd
import dask
import pandas as pd
import numpy as np
from constants import DEBUG_POSTGRESQL_PARQUET_FOLDER
from pathlib import Path
from scipy.sparse import dok_array
import networkx as nx
from tqdm.contrib.concurrent import process_map
from multiprocessing import cpu_count, Pool
import matplotlib.pyplot as plt
from tqdm import tqdm
from collections import defaultdict

In [None]:
class QueryTemplateEncoder:
    """
    Incrementally maps query templates (any hashable label) to integer codes.

    Why not sklearn.preprocessing.LabelEncoder()?

    - Not all labels (query templates) are known ahead of time.
    - Not that many query templates, so hopefully this isn't a bottleneck.

    Codes are handed out in first-seen order, starting at 1, and are never
    reassigned, so the encoder can keep being fitted as new data arrives.
    """

    def __init__(self):
        # label -> code and code -> label, always kept in sync.
        self._encodings = {}
        self._inverse = {}
        # Code that the next previously unseen label will receive.
        self._next_label = 1

    def _assign(self, label):
        """Return the code for `label`, allocating a new one if it is unseen."""
        code = self._encodings.get(label)
        # Codes start at 1, so None unambiguously means "not seen yet".
        if code is None:
            code = self._next_label
            self._encodings[label] = code
            self._inverse[code] = label
            self._next_label = code + 1
        return code

    def fit(self, labels):
        """Register every unseen label in `labels`; returns self for chaining."""
        for label in labels:
            self._assign(label)
        return self

    def transform(self, labels):
        """Encode `labels`; raises KeyError for a label that was never fitted."""
        return [self._encodings[label] for label in labels]

    def fit_transform(self, labels):
        """
        Register unseen labels and return the codes for `labels` in order.

        Done in a single pass so that one-shot iterables (e.g. generators)
        are encoded correctly instead of being exhausted by a separate fit.
        """
        return [self._assign(label) for label in labels]

    def inverse_transform(self, encodings):
        """Decode `encodings`; raises KeyError for a code that was never assigned."""
        return [self._inverse[encoding] for encoding in encodings]


class DfMeta:
    """
    Accumulates query-template transition counts across dataframes.

    Each call to `augment` encodes the frame's query templates and adds the
    observed template-to-template transitions (per session) to
    `transition_sessions`, which uses the networkx dict_of_dicts layout:
    `{src: {dst: {"weight": count}}}`.
    """

    SESSION_BEGIN = "SESSION_BEGIN"
    SESSION_END = "SESSION_END"

    def __init__(self):
        self.qt_enc = QueryTemplateEncoder()
        # Dummy tokens for session begin and session end.
        self.qt_enc.fit([self.SESSION_BEGIN, self.SESSION_END, pd.NA])

        # networkx dict_of_dicts format.
        self.transition_sessions = {}

    def augment(self, df):
        """
        Augment the dataframe while updating internal state.

        Adds a "query_template_enc" column to `df` in place, then folds the
        session-level transitions found in `df` into `transition_sessions`.
        """
        # Encode the query templates.
        df["query_template_enc"] = self.qt_enc.fit_transform(df["query_template"])

        self._update_transition_dict(self.transition_sessions, self._compute_transition_dict("session_id", df))
        # In a world of autocommit, this doesn't matter as much.
        # self._update_transition_dict(self.transition_sessions,
        #                              self._compute_transition_dict("virtual_transaction_id", df))

    @staticmethod
    def _update_transition_dict(current, other):
        """Add every edge weight in `other` onto `current`, in place."""
        for src, dsts in other.items():
            merged = current.setdefault(src, {})
            for dst, attrs in dsts.items():
                # A fresh attribute dict per new edge, so `current` never aliases `other`.
                merged.setdefault(dst, {"weight": 0})["weight"] += attrs["weight"]

    @staticmethod
    def _chunksize(num_groups, num_workers):
        """
        Pick a per-worker chunk size that is always at least 1.

        Plain `num_groups // num_workers` is 0 whenever there are fewer groups
        than workers, and the executor behind process_map rejects chunk sizes
        below 1.
        """
        return max(1, num_groups // max(1, num_workers))

    @staticmethod
    def _count_transitions(sequences):
        """Count adjacent (src, dst) pairs over all sequences, as dict_of_dicts."""
        transitions = {}
        for seq in sequences:
            for src, dst in zip(seq, seq[1:]):
                edge = transitions.setdefault(src, {}).setdefault(dst, {"weight": 0})
                edge["weight"] += 1
        return transitions

    def _compute_transition_dict(self, group_key, df):
        """
        Build the transition dict for `df`, grouped by `group_key`.

        `df` is passed explicitly (it must already carry "query_template_enc")
        rather than being picked up from whatever global happens to be named
        `df` at call time.
        """
        group_fns = {
            "session_id": self._group_session,
            "virtual_transaction_id": self._group_txn,
        }
        assert group_key in group_fns, f"Unknown group key: {group_key}"
        group_fn = group_fns[group_key]

        groups = df.groupby(group_key)
        chunksize = self._chunksize(len(groups), cpu_count())
        grouped = process_map(group_fn, groups, chunksize=chunksize, desc=f"Grouping on {group_key}.", disable=True)
        # TODO(WAN): Parallelize.
        progress = tqdm(grouped, desc=f"Computing transition matrix for {group_key}.", disable=True)
        return self._count_transitions(group_qt_encs for _, group_qt_encs in progress)

    @staticmethod
    def _sorted_encodings(df):
        """Encoded templates of `df`, ordered by log time then session line number."""
        ordered = df.sort_values(["log_time", "session_line_num"])
        return ordered["query_template_enc"].values

    def _group_txn(self, item):
        """Map a (group_id, frame) pair to (group_id, ordered encodings)."""
        group_id, df = item
        return group_id, self._sorted_encodings(df)

    def _group_session(self, item):
        """Like `_group_txn`, but bracketed by the SESSION_BEGIN / SESSION_END codes."""
        group_id, df = item
        group_vals = np.concatenate([
            self.qt_enc.transform([self.SESSION_BEGIN]),
            self._sorted_encodings(df),
            self.qt_enc.transform([self.SESSION_END]),
        ])
        return group_id, group_vals

    # @staticmethod
    # def _compute_graph(transition_dict):
    #     dd = {}
    #     for src in transition_dict:
    #         for dst in transition_dict[src]:
    #             if dst == "count":
    #                 continue
    #             dd[src] = dd.get(src, {})
    #             dd[src][dst] = transition_dict[src][dst] / transition_dict[src]["count"]
    #         print(dd)
    #         raise Exception
    #     return nx.from_dict_of_dicts(dd)

    # def _build_markov_chain(self, group):
    #     assert group in ["session_id", "virtual_transaction_id"], "What are you grouping by?"
    #     grouped = self.df.groupby(group)
    #     group_vals = np.concatenate(
    #         self.qt_enc.inverse_transform([self.SESSION_BEGIN])[0],
    #         grouped.values,
    #         self.qt_enc.inverse_transform([self.SESSION_END])[0])


dfm = DfMeta()
parquet_paths = sorted(Path(DEBUG_POSTGRESQL_PARQUET_FOLDER).glob("*.parquet"))
for pq_file in tqdm(parquet_paths, desc="Reading Parquet files.", disable=True):
    df = pd.read_parquet(pq_file)
    # Treat empty-string templates as missing so that dropna removes them too.
    df["query_template"] = df["query_template"].replace("", np.nan)
    num_rows_raw = len(df)
    df = df.dropna(subset=["query_template"])
    num_dropped = num_rows_raw - len(df)
    print(f"Dropped {num_dropped} empty query templates in {pq_file}.")
    # Encodes the templates and accumulates this file's session transitions.
    dfm.augment(df)

In [None]:
# Spot check: show which query template was assigned encoding 257.
# Raises KeyError if the encoder never handed out that encoding.
dfm.qt_enc.inverse_transform([257])

In [None]:
%matplotlib notebook

G = nx.DiGraph(dfm.transition_sessions)
fig = plt.figure(figsize=(24, 36))
pos = nx.nx_agraph.graphviz_layout(G)
nx.draw(G, pos, with_labels=True)
labels = nx.get_edge_attributes(G, "weight")
nx.draw_networkx_edge_labels(G, pos, edge_labels=labels)
plt.savefig("sessions.pdf")

In [None]:
# Keep a second name for the session graph before the contraction cell below
# rebinds `G`. NOTE: this is a reference to the same graph object, not a copy.
G_old = G

In [None]:
# num_qt = df["query_template"].nunique()
# S = dok_array((num_qt, num_qt), dtype=np.int64)
# S[1, 2] = 3
# S[2, 3] = 3
# S[1, 5] = 4
# print(num_qt)
# print(S)
# G = nx.from_scipy_sparse_array(S, parallel_edges=False, edge_attribute="counts")
# print(G)

In [None]:
def contract_linear_chains(graph):
    """
    Return a copy of `graph` in which equal-weight linear chains are merged.

    A node is merged with its successor when the node has exactly one
    incoming and one outgoing edge, the successor has exactly one outgoing
    edge, and the weight entering the node equals the weight leaving it.
    The merged node is relabelled "<node>,<succ>". Merging repeats until no
    such pair is left.

    The input graph is never modified: the contraction runs with copy=False
    for speed, so it is applied to a private copy.
    """
    contracted = graph.copy()

    while True:
        # Recomputed every round because each merge changes the degrees.
        deg_out_one = {node for node, degree in contracted.out_degree if degree == 1}
        deg_in_one = {node for node, degree in contracted.in_degree if degree == 1}

        for node in deg_in_one & deg_out_one:
            pred = next(iter(contracted.predecessors(node)))
            succ = next(iter(contracted.successors(node)))
            if succ == node:
                # The only edge is a self-loop; a node cannot be merged with itself.
                continue
            predw = contracted.edges[pred, node]["weight"]
            succw = contracted.edges[node, succ]["weight"]

            # Try to merge the node with its successor.
            if succ in deg_out_one and predw == succw:
                nx.contracted_nodes(contracted, node, succ, self_loops=False, copy=False)
                nx.relabel_nodes(contracted, {node: f"{node},{succ}"}, copy=False)
                # The degree sets are stale now; start a new round.
                break
        else:
            # A full pass without a single merge: nothing left to contract.
            return contracted


# Work on a copy so that G_old stays intact and this cell can be re-run.
G = contract_linear_chains(G_old)
print(G)

In [None]:
fig = plt.figure(figsize=(24, 36))
# pos = nx.nx_agraph.graphviz_layout(G)
# spring_layout starts from random positions; a fixed seed makes the saved
# figure reproducible across runs. weight=None ignores the edge weights
# when computing the layout.
pos = nx.spring_layout(G, weight=None, seed=0)
nx.draw(G, pos, with_labels=True)
# Annotate every edge with its transition count.
labels = nx.get_edge_attributes(G, "weight")
nx.draw_networkx_edge_labels(G, pos, edge_labels=labels)
plt.savefig("sessions_contracted.pdf")