# Message Filtering

In [None]:
import re
from functools import partial
from pathlib import Path

import numpy as np
import pandas as pd
from dask.diagnostics import Profiler, ResourceProfiler, CacheProfiler
from dask.diagnostics import ProgressBar
from dask.diagnostics import visualize
from dask.distributed import LocalCluster
from dask.distributed import Client
from snorkel.labeling import labeling_function

from at_nlp.filters.string_filter import StringFilter

In [None]:
# Start a local Dask cluster with `num_processors` workers and obtain a client
# for it. Leaving `client` as the cell's last expression displays it.
num_processors = 10
cluster = LocalCluster(n_workers=num_processors)          # Fully-featured local Dask cluster
client = cluster.get_client()
client

In [None]:
# Load rich's IPython extension for richer notebook output.
%load_ext rich

In [None]:
# Source CSV of messages to filter.
# NOTE(review): this absolute path is machine-specific; consider reading it
# from configuration or an environment variable so the notebook is portable.
data_path = Path("/home/dwalker/SIGIL/natural_language_processing/nitmre/data/(CUI) alexa_816th_file_1a1.csv")
# Validate explicitly: `assert` is stripped under `python -O`, and `exists()`
# would also accept a directory, which `read_csv` cannot load.
if not data_path.is_file():
    raise FileNotFoundError(f"Data path {data_path} does not exist or is not a file")
data = pd.read_csv(data_path)

In [None]:
# NOTE(review): this cell cannot run as written and looks like scratch code:
#   - `sf` is only created in a later cell (`sf = StringFilter()`).
#   - `fn`, `data_path2`, `csv_name` and `csv_name2` are not defined in any
#     visible cell; `idx` is only bound later, as the loop variable of the
#     synthetic test-data cell.
#   - `register_csv_preprocessor` is called elsewhere with a single path
#     argument; confirm it accepts an index as a second argument.
#   - if `sf.data` is a list, `extend` takes exactly one iterable, so
#     `extend(csv_name, csv_name2)` would raise TypeError
#     (`extend([csv_name, csv_name2])` was probably intended).
sf.register_preprocessor([(0, fn)])
sf.register_csv_preprocessor(data_path, idx)
sf.register_csv_preprocessor(data_path2, idx)

sf.csv_name_data = dict()
sf.csv_name_data2 = dict()
sf.data.extend(csv_name, csv_name2)

In [None]:
# Create the string filter used by the preprocessing cells below, then call
# `reset()` on it (presumably clears previously registered state — confirm
# in at_nlp's StringFilter).
sf = StringFilter()
sf.reset()

In [None]:
def lower_case(ds: pd.Series, col_idx: int) -> pd.Series:
    """Lower-case the text stored at position ``col_idx`` of ``ds``.

    The Series is modified in place and also returned, so callers may rely
    either on the mutation or on the return value. Non-string values (for
    example ``NaN`` produced by an empty CSV field) are left unchanged
    instead of raising ``AttributeError``.

    Args:
        ds: Row to transform.
        col_idx: Positional (``iat``) index of the text cell.

    Returns:
        ``ds`` itself, with the cell lower-cased when it held a string.
    """
    value = ds.iat[col_idx]
    if isinstance(value, str):
        ds.iat[col_idx] = value.lower()
    return ds

def upper_case(ds: pd.Series, col_idx: int) -> pd.Series:
    """Upper-case the text stored at position ``col_idx`` of ``ds``.

    The Series is modified in place and also returned, so callers may rely
    either on the mutation or on the return value. Non-string values (for
    example ``NaN`` produced by an empty CSV field) are left unchanged
    instead of raising ``AttributeError``.

    Args:
        ds: Row to transform.
        col_idx: Positional (``iat``) index of the text cell.

    Returns:
        ``ds`` itself, with the cell upper-cased when it held a string.
    """
    value = ds.iat[col_idx]
    if isinstance(value, str):
        ds.iat[col_idx] = value.upper()
    return ds


# Register both case converters with the filter.
# NOTE(review): the integer in each tuple is interpreted by
# `StringFilter.register_preprocessor` (ordering key vs. column index is not
# visible here). If both converters end up applied to the same column, only
# whichever runs last has a visible effect.
pre_processors = [
    (1, lower_case),
    (0, upper_case)
]

sf.register_preprocessor(pre_processors)

In [None]:
# Print the preprocessors currently registered on the filter.
sf.print_preprocessor_stack()

In [None]:
# Relative path: resolved against the notebook's current working directory,
# so this only points at tests/test.csv when the notebook is run from its
# own directory.
test_csv_path = Path("../../tests/test.csv")

In [None]:
# Hand the test CSV to `register_csv_preprocessor`, then print the updated
# preprocessor stack.
# NOTE(review): unlike `data_path`, this path is not checked for existence.
sf.register_csv_preprocessor(test_csv_path)
sf.print_preprocessor_stack()

In [None]:
# Synthetic frame for exercising the preprocessors: 10,000 rows of
# (id, "test<id>"), of which 200 distinct rows with ids in [1, 1000) have
# their text replaced by "APL".
# A seeded generator keeps the data reproducible across runs, and sampling
# without replacement guarantees exactly 200 distinct "APL" rows.
test_arr = [[idx, f"test{idx}"] for idx in range(10_000)]

rng = np.random.default_rng(seed=0)
csv_indices = rng.choice(np.arange(1, 1000), size=200, replace=False)

for idx in csv_indices:
    # int(): keep the "id" column as plain Python ints, like the other rows.
    test_arr[idx] = [int(idx), "APL"]

test_df = pd.DataFrame(
    test_arr,
    columns=["id", "text"],
)

In [None]:
# Passed as the fourth positional argument of `sf.preprocess` below
# (presumably the number of Dask partitions — confirm in at_nlp).
num_divisions = 20

In [None]:
# Build the preprocessing pipeline and render its task graph.
# NOTE(review): the positional arguments (test_df, 1, True, num_divisions,
# False) are defined by at_nlp; passing them by keyword would make this
# readable. Since `.visualize()` / `.compute()` are called on the result,
# `True` presumably requests a lazy Dask object.
sf.preprocess(test_df, 1, True, num_divisions, False).visualize()

In [None]:
# Build the same pipeline and execute it with `.compute()`, keeping the result.
out_df = sf.preprocess(test_df, 1, True, num_divisions, False).compute()

In [None]:
# Inspect the first 100 preprocessed rows.
out_df.head(100)

In [None]:
# Time `sf.preprocess` with the third argument set to False (True above).
# NOTE(review): if True means "return a lazy graph", this measures eager
# execution rather than graph construction — confirm in at_nlp.
%timeit sf.preprocess(test_df, 1, False, num_divisions, False)

## Train

In [None]:
# Train both stages of the message filter on `data`, without serializing
# the result. Each stage gets its own "split" / "amt" settings.
# NOTE(review): `msg_filter` is not created in any visible cell.
train_config = {
    "stage-one": {"split": 0.9, "amt": 2000},
    "stage-two": {"split": 0.9, "amt": 1700},
}
msg_filter.train(data, train_config, serialize=False)

In [None]:
# Pre-bind the filter's template miner as the `tm` keyword of its row transform.
# NOTE(review): `row_apply` is not used in any visible cell.
row_apply = partial(msg_filter.template_miner_transform, tm=msg_filter.template_miner)

In [None]:
# Stage-one test data and its "labels" column.
# NOTE(review): `test_set` still contains the "labels" column when passed to
# `latency_trace` / `evaluate`; confirm those methods drop it before predicting.
test_set = msg_filter.stage_one_test_data
test_labels = test_set["labels"]

In [None]:
# Run the filter's latency trace over the stage-one test data
# (presumably reports prediction timing — confirm in the filter class).
msg_filter.latency_trace(test_set)

### Test Stage One

In [None]:
# Evaluate the "rf" and "mlp" stage-one models against the stage-one labels.
msg_filter.evaluate(test_set, test_labels, "rf")
msg_filter.evaluate(test_set, test_labels, "mlp")

## Label Ensemble

### Data Preparation

In [None]:
# Stage-two test data: keep its "labels" column, then replace `test_set`
# with the output of `msg_filter.applier.apply` (presumably the
# labeling-function vote matrix consumed by "label_model" — confirm).
test_set = msg_filter.stage_two_test_data
test_labels = test_set["labels"]
test_set = msg_filter.applier.apply(test_set)

### Test Ensemble

In [None]:
# Evaluate the "label_model" on the applied stage-two test data.
msg_filter.evaluate(test_set, test_labels, "label_model")

In [None]:
# Smoke-test prediction on the first five rows of the raw data.
ds = data[:5]
msg_filter.predict(ds)

In [None]:
# "ok"/"okay" must appear as a whole word: a plain substring test would also
# fire on words such as "book", "look" or "token".
_OK_WORD = re.compile(r"\bok(?:ay)?\b")


@labeling_function()
def lf_confirmation(in_ds: pd.Series) -> int:
    """Vote 2 when the message reads like a confirmation, otherwise 0.

    A message (compared case-insensitively) counts as a confirmation when it
    contains "wilco", contains "affirm" (so "affirmative"/"affirmed" also
    match), or contains "ok"/"okay" as a standalone word.

    NOTE(review): Snorkel's convention is -1 for ABSTAIN; returning 0 casts a
    vote for class 0 on every non-matching message. Confirm 0 is intended.

    Args:
        in_ds: Row with a "Message" field.

    Returns:
        2 for a confirmation, otherwise 0 (including a missing or non-string
        "Message", e.g. NaN from an empty CSV field).
    """
    msg = in_ds["Message"]
    if not isinstance(msg, str):
        return 0
    msg = msg.lower()
    if "wilco" in msg or "affirm" in msg or _OK_WORD.search(msg):
        return 2
    return 0

In [None]:
# Add `lf_confirmation` to the filter's labeling functions (the method takes a list).
msg_filter.register_new_labeling_fn([lf_confirmation])

In [None]:
# Display the filter's current labeling functions.
msg_filter.labeling_functions

In [None]:
# Retrain both stages now that `lf_confirmation` has been registered, using
# the same per-stage "split" / "amt" settings and no serialization.
train_config = {
    "stage-one": {"split": 0.9, "amt": 2000},
    "stage-two": {"split": 0.9, "amt": 1700},
}
msg_filter.train(data, train_config, serialize=False)

In [None]:
# Re-evaluate the "label_model" on the stage-two test data after retraining:
# keep the "labels" column, apply `msg_filter.applier`, then evaluate.
test_set = msg_filter.stage_two_test_data
test_labels = test_set["labels"]
test_set = msg_filter.applier.apply(test_set)
msg_filter.evaluate(test_set, test_labels, "label_model")

In [51]:
def test_function(ds: pd.Series, int) -> None:
    """Debug helper: print ``ds``.

    The second parameter is unused and the function returns ``None``.
    NOTE(review): the parameter is named ``int``, shadowing the builtin inside
    this function; the name is kept so existing keyword callers keep working.
    """
    print(ds)

In [52]:
# Call the debug helper on a two-element Series; it prints the Series.
test_function(pd.Series([0, 0,]), None)

0    0
1    0
dtype: int64


In [54]:
# Display the notebook's current namespace.
# NOTE(review): the recorded TypeError ("locals() takes no arguments
# (1 given)") cannot come from this zero-argument call; it looks like stale
# output from an earlier version of the cell.
locals()

TypeError: locals() takes no arguments (1 given)