In [None]:
# res = compare_experiments_barplot(
#     experiment_paths=[experiment_output_dir],
#     title="TARS eval.",
# )


## WANDB dev

In [None]:
from pathlib import Path

import pandas as pd
import yaml

from data_util import create_multi_label_train_test_splits

import os

# Location of the training config. Override with the TRAIN_CONFIG_PATH
# environment variable; the default is the original local checkout path.
CONFIG_PATH = Path(
    os.environ.get(
        "TRAIN_CONFIG_PATH",
        "/Users/samhardyhey/Desktop/blog/blog-multi-label/train/train_config.yaml",
    )
)
CONFIG = yaml.safe_load(CONFIG_PATH.read_bytes())

# 1.1 create splits: first train vs. holdout, then the holdout is split again
# into test vs. dev. Both calls use the same CONFIG["test_size"].
df = pd.read_csv(CONFIG["dataset"])
train_split, holdout_split = create_multi_label_train_test_splits(
    df, label_col=CONFIG["label_col"], test_size=CONFIG["test_size"]
)
test_split, dev_split = create_multi_label_train_test_splits(
    holdout_split, label_col=CONFIG["label_col"], test_size=CONFIG["test_size"]
)

# # 1.2 log splits
# with wandb.init(
#     project=CONFIG["wandb_project"],
#     name="reddit_aus_finance",
#     group=CONFIG["wandb_group"],
#     entity=CONFIG["wandb_entity"],
# ) as run:
#     log_dataframe(run, train_split, "train_split", "Train split")
#     log_dataframe(run, dev_split, "dev_split", "Dev split")
#     log_dataframe(run, test_split, "test_split", "Test split")


## Dictionary classifier

In [None]:
from model_util import fit_and_log_dictionary_classifier, fit_and_log_linear_svc

# Map each supported model name to the routine that fits it and logs results.
# NOTE(review): call signature (train, dev, test, model_config) is assumed;
# confirm against model_util.
MODEL_FITTERS = {
    "dictionary_classifier": fit_and_log_dictionary_classifier,
    "sklearn_linear_svc": fit_and_log_linear_svc,
}

for model in CONFIG["models"]:
    # NOTE(review): dispatching on model["name"]; confirm train_config.yaml uses
    # that key for each model entry.
    fit_and_log = MODEL_FITTERS.get(model["name"])
    if fit_and_log is None:
        print(f"Unsupported model: {model['name']} found")
        continue
    fit_and_log(train_split, dev_split, test_split, model)


## Flair

In [None]:
from flair.data import Corpus, Sentence, Token
from flair.models import SequenceTagger, TARSClassifier, TARSTagger, TextClassifier
from flair.tokenization import SegtokTokenizer

# Smoke test: build a Flair Sentence with an explicit SegtokTokenizer.
segtok_tokenizer = SegtokTokenizer()
sent = Sentence("hello world", use_tokenizer=segtok_tokenizer)


In [None]:
from eval_util import create_classification_report
from model.flair_tars import predict_flair_tars

# `tars` is produced by the TARS training cell (fit_and_log_flair_tars_classifier)
# further down the notebook. Fail with an actionable message instead of a bare
# NameError from inside the lambda when this cell runs first on a fresh kernel.
if "tars" not in globals():
    raise NameError(
        "`tars` is not defined: run the fit_and_log_flair_tars_classifier cell first"
    )

# Predict on each text in the test split and keep predictions next to the labels.
test_preds = test_split.assign(
    pred=test_split[CONFIG["text_col"]].apply(
        lambda text: predict_flair_tars(text, tars)
    )
)

classification_report = create_classification_report(test_split, test_preds, CONFIG)


In [None]:
from data_util import label_dictionary_to_label_mat

# Convert the `label` (presumably gold) and `pred` columns to label matrices.
# Only a cell's final expression is rendered, so bind both and display both.
true_label_mat = label_dictionary_to_label_mat(test_preds["label"])
pred_label_mat = label_dictionary_to_label_mat(test_preds["pred"])
display(true_label_mat, pred_label_mat)


In [None]:
# with wandb.init(
#         project=CONFIG["wandb_project"],
#         name=model_config["type"],
#         group=CONFIG["wandb_group"],
#         entity=CONFIG["wandb_entity"],
#     ) as run:
#     run.dir


In [None]:
import json
import tempfile

# Scratch check: write a label dictionary as JSON into a temporary artefact
# directory and confirm it reads back unchanged. The directory is removed on exit.
with tempfile.TemporaryDirectory() as artefact_dir:
    label_dictionary_path = Path(artefact_dir) / "label_dictionary.json"
    label_dictionary_path.write_text(json.dumps({"a": 10}))
    assert json.loads(label_dictionary_path.read_text()) == {"a": 10}
    # run.save(str(label_dictionary_path))


In [None]:
from model.flair_tars import fit_and_log_flair_tars_classifier

# Train (and log) the TARS classifier on the splits created above.
# NOTE(review): assumes the last entry of CONFIG["models"] is the TARS config.
tars_model_config = CONFIG["models"][-1]
tars = fit_and_log_flair_tars_classifier(
    train_split,
    dev_split,
    test_split,
    CONFIG,
    tars_model_config,
)


## WANDB misc

In [None]:
import wandb

# Dev cleanup (destructive): delete every run in the dev project whose name is
# STALE_RUN_NAME.
WANDB_DEV_PROJECT_PATH = "cool_stonebreaker/tyre_kick"
STALE_RUN_NAME = "inter_group_model_comparison"

# NOTE(review): a new Api() client is created on each run of this cell,
# presumably so results are not served from a stale client — confirm.
api = wandb.Api()

# Collect matches first, then delete, so the (presumably lazily paginated) run
# listing is not mutated while it is still being iterated.
stale_runs = [
    run for run in api.runs(path=WANDB_DEV_PROJECT_PATH) if run.name == STALE_RUN_NAME
]
for run in stale_runs:
    run.delete()


In [None]:
# Dev housekeeping: create a W&B public-API client bound to `api`, which the
# run-listing and project cells below use.
import wandb

api = wandb.Api()

# Uncommenting the next line deletes EVERY run in cool_stonebreaker/tyre_kick
# (destructive; dev use only).
# _ = [run.delete() for run in api.runs(path="cool_stonebreaker/tyre_kick")]


In [None]:
# Names of all runs currently in the dev project (rendered as the cell output).
[wandb_run.name for wandb_run in api.runs(path="cool_stonebreaker/tyre_kick")]
# log_inter_group_model_comparisons(project_artifacts, CONFIG)


In [None]:
# Fetch a handle to the W&B project by name.
project_name = "blog-multi-label-train"
proj = api.project(project_name)


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style("darkgrid")
sns.set_palette("pastel", 12)

# Rows of a classification report whose label matches this pattern are
# aggregates (accuracy, samples/macro/micro averages), not individual labels.
SUMMARY_ROW_PATTERN = "accuracy|samples|macro|micro"

# NOTE(review): `group_model_classification_reports` is not defined in this part
# of the notebook; it must be created by an earlier cell — confirm.
per_label_reports = group_model_classification_reports.pipe(
    # na=True treats a missing label as an aggregate row, so it is dropped
    # instead of putting NaN into the boolean mask.
    lambda reports: reports[
        ~reports["label"].str.contains(SUMMARY_ROW_PATTERN, na=True)
    ]
)

# plot results: per-label F1-score, coloured by `type`
g = sns.catplot(
    x="label",
    y="f1-score",
    hue="type",
    data=per_label_reports,
    height=10,
    kind="bar",
    ci=None,  # no error bars; seaborn >= 0.12 spells this errorbar=None
)
g.despine(left=True)
g.set_xticklabels(rotation=45)
g.set_axis_labels("Label", "F1-score")
g.set(title="Per-label F1-score")
plt.show()


## Label formatting

In [1]:
# Raw label -> keyword text: one block per label, blocks separated by a blank
# line. The first line of each block is the label name, the second a ", "-
# separated keyword list. The next cell parses this into `label_dicts`.
# Note: some keywords appear under more than one label ("rba", "rate", "buy").
x = """workplace
boss, co-workers, WFH, life balance, office, culture, hybrid

property
refinance, real estate, property, landlord, loan, buy, house, rate, rent, resident, afford, mortgage, bedroom, townhouse, auction, agent, defect, layout, floor plan, builder, boom, salary

tax
tax, land tax, gst, salary sacrifice

insurance
insurance, indemnity, income protection

super
super, contribution, fund, balance, self-funded, retire, pension

public institution
watch dog, rba, central bank, mint, fair work, bond

inflation
inflation, interest rates, reserve bank, phillip lowe, rba, petrol

exchange
exchange, rate, dollar

stocks
stock, shares, invest, indexed, van guard, wealth, assets, asx, commsec, etf, return, vdhg, high growth, selfwealth, dividends, securities, buy, dip, 200

toxic
butt, salty, fuck, laughable, fool, tard, lol, bro, shit"""

In [11]:
# Parse `x` into {label: sorted keywords}. Blocks are separated by a blank line;
# the first non-empty line of a block is the label and any remaining line(s)
# hold a comma-separated keyword list. Blank blocks, trailing newlines and a
# label without keywords are tolerated instead of raising IndexError.
label_dicts = {}
for block in x.strip().split('\n\n'):
    block_lines = [line.strip() for line in block.split('\n') if line.strip()]
    if not block_lines:
        continue  # stray blank block, e.g. from extra empty lines
    label, keyword_lines = block_lines[0], block_lines[1:]
    keywords = (kw.strip() for kw in ','.join(keyword_lines).split(','))
    label_dicts[label] = sorted(kw for kw in keywords if kw)

In [15]:
# Hard-coded copy of the parsed label -> sorted keywords mapping.
{
    'workplace': ['WFH', 'boss', 'co-workers', 'culture', 'hybrid', 'life balance', 'office'],
    'property': [
        'afford', 'agent', 'auction', 'bedroom', 'boom', 'builder', 'buy', 'defect',
        'floor plan', 'house', 'landlord', 'layout', 'loan', 'mortgage', 'property',
        'rate', 'real estate', 'refinance', 'rent', 'resident', 'salary', 'townhouse',
    ],
    'tax': ['gst', 'land tax', 'salary sacrifice', 'tax'],
    'insurance': ['income protection', 'indemnity', 'insurance'],
    'super': ['balance', 'contribution', 'fund', 'pension', 'retire', 'self-funded', 'super'],
    'public institution': ['bond', 'central bank', 'fair work', 'mint', 'rba', 'watch dog'],
    'inflation': ['inflation', 'interest rates', 'petrol', 'phillip lowe', 'rba', 'reserve bank'],
    'exchange': ['dollar', 'exchange', 'rate'],
    'stocks': [
        '200', 'assets', 'asx', 'buy', 'commsec', 'dip', 'dividends', 'etf', 'high growth',
        'indexed', 'invest', 'return', 'securities', 'selfwealth', 'shares', 'stock',
        'van guard', 'vdhg', 'wealth',
    ],
    'toxic': ['bro', 'butt', 'fool', 'fuck', 'laughable', 'lol', 'salty', 'shit', 'tard'],
}

{'workplace': ['WFH',
  'boss',
  'co-workers',
  'culture',
  'hybrid',
  'life balance',
  'office'],
 'property': ['afford',
  'agent',
  'auction',
  'bedroom',
  'boom',
  'builder',
  'buy',
  'defect',
  'floor plan',
  'house',
  'landlord',
  'layout',
  'loan',
  'mortgage',
  'property',
  'rate',
  'real estate',
  'refinance',
  'rent',
  'resident',
  'salary',
  'townhouse'],
 'tax': ['gst', 'land tax', 'salary sacrifice', 'tax'],
 'insurance': ['income protection', 'indemnity', 'insurance'],
 'super': ['balance',
  'contribution',
  'fund',
  'pension',
  'retire',
  'self-funded',
  'super'],
 'public institution': ['bond',
  'central bank',
  'fair work',
  'mint',
  'rba',
  'watch dog'],
 'inflation': ['inflation',
  'interest rates',
  'petrol',
  'phillip lowe',
  'rba',
  'reserve bank'],
 'exchange': ['dollar', 'exchange', 'rate'],
 'stocks': ['200',
  'assets',
  'asx',
  'buy',
  'commsec',
  'dip',
  'dividends',
  'etf',
  'high growth',
  'indexed',
  'inve

## Save novel plot