In [None]:
# res = compare_experiments_barplot(
#     experiment_paths=[experiment_output_dir],
#     title="TARS eval.",
# )


## WANDB dev

In [None]:
from pathlib import Path

import pandas as pd
import yaml

from data_util import create_multi_label_train_test_splits

# Load the training configuration from its YAML file into a plain Python object.
config_path = Path(
    "/Users/samhardyhey/Desktop/blog/blog-multi-label/train/train_config.yaml"
)
CONFIG = yaml.safe_load(config_path.read_bytes())

# 1.1 create splits
# The dataset is split once, then the second part of that split is split again
# with the same label column and test_size, giving train / test / dev frames.
df = pd.read_csv(CONFIG["dataset"])
split_kwargs = {"label_col": CONFIG["label_col"], "test_size": CONFIG["test_size"]}
train_split, test_split = create_multi_label_train_test_splits(df, **split_kwargs)
test_split, dev_split = create_multi_label_train_test_splits(
    test_split, **split_kwargs
)

# # 1.2 log splits
# with wandb.init(
#     project=CONFIG["wandb_project"],
#     name="reddit_aus_finance",
#     group=CONFIG["wandb_group"],
#     entity="cool_stonebreaker",
# ) as run:
#     log_dataframe(run, train_split, "train_split", "Train split")
#     log_dataframe(run, dev_split, "dev_split", "Dev split")
#     log_dataframe(run, test_split, "test_split", "Test split")


## Dictionary classifier

In [None]:
from model_util import fit_and_log_dictionary_classifier, fit_and_log_linear_svc

# Walk every entry of CONFIG["models"]. The only live statement is a bare lookup
# of model["model"], whose value is discarded (it raises KeyError if an entry
# lacks that key); the dispatch to the fit_and_log_* helpers is commented out.
# NOTE(review): the commented dispatch keys on model['name'] and passes
# train/dev/test, whereas the split cell defines train_split/dev_split/test_split
# -- confirm key and variable names before re-enabling.
for model in CONFIG["models"]:
    model["model"]
    # if model['name'] == 'dictionary_classifier':
    #     fit_and_log_dictionary_classifier(train, dev, test, model)

    # elif model['name'] == 'sklearn_linear_svc':
    #     fit_and_log_linear_svc(train, dev, test, model)

    # else:
    #     print(f"Unsupported model: {model['name']} found")


## Flair

In [None]:
from flair.data import Corpus, Sentence, Token
from flair.models import SequenceTagger, TARSClassifier, TARSTagger, TextClassifier
from flair.tokenization import SegtokTokenizer

# Smoke test: build one flair Sentence with an explicit SegtokTokenizer.
sent = Sentence("hello world", use_tokenizer=SegtokTokenizer())


In [None]:
import pandas as pd
from clear_bow.classifier import DictionaryClassifier
from flair.data import Corpus, Sentence
from flair.models import TARSClassifier
from flair.tokenization import SegtokTokenizer
from flair.trainers import ModelTrainer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.feature_selection import VarianceThreshold
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC
from skmultilearn.problem_transform import BinaryRelevance

import wandb
from data_util import log_dataframe
from eval_util import (
    create_classification_report,
    create_slim_classification_report,
    label_dictionary_to_label_mat,
)


In [None]:
def create_flair_classification_sentence(text, label_object, label_type="class"):
    """Build a flair ``Sentence`` for ``text`` and attach its positive labels.

    Args:
        text: Raw text passed to ``Sentence`` (tokenised with ``SegtokTokenizer``).
        label_object: Mapping of label name to a numeric value; only entries
            whose value is greater than zero are attached.
        label_type: Label type under which every label is added.

    Returns:
        The ``Sentence``, carrying one label with score 1.0 per positive entry.
    """
    tagged = Sentence(text, use_tokenizer=SegtokTokenizer())
    for name, value in label_object.items():
        if value > 0:
            tagged.add_label(label_type, name, 1.0)
    return tagged


def predict_flair_tars(text, flair_tars_model):
    """Predict label confidences for ``text`` with a TARS model.

    Args:
        text: Raw text to classify; wrapped in a flair ``Sentence``.
        flair_tars_model: Model exposing ``get_current_label_dictionary()`` and
            ``predict()``.

    Returns:
        Dict mapping every label of the model's current label dictionary to a
        confidence. Labels the model did not attach to the sentence stay at
        0.0; attached labels carry their confidence rounded to two decimals.
    """
    sentence = Sentence(text)
    known_labels = flair_tars_model.get_current_label_dictionary().get_items()
    flair_tars_model.predict(sentence)

    # Start every known label at zero, then overwrite with what was predicted.
    scores = dict.fromkeys(known_labels, 0.0)
    for predicted in sentence.labels:
        entry = predicted.to_dict()
        scores[entry["value"]] = round(float(entry["confidence"]), 2)
    return scores


In [None]:
from model_util import fit_and_log_flair_tars

import tempfile  # used for the scratch artefact dir below; not imported in any cell visible here

# Train the last model entry listed in the config.
model_config = CONFIG["models"][-1]

with wandb.init(
    project=CONFIG["wandb_project"],
    name=model_config["type"],
    group=CONFIG["wandb_group"],
    entity=CONFIG["wandb_entity"],
) as run:
    # Record which model type / group this run belongs to.
    wandb.config.type = model_config["type"]
    wandb.config.group = CONFIG["wandb_group"]
    label_type = model_config.get("label_type", "multi_label_class")

    def _row_to_sentence(row):
        """Convert one dataframe row into a labelled flair Sentence."""
        return create_flair_classification_sentence(
            row[CONFIG["text_col"]], row[CONFIG["label_col"]], label_type
        )

    # 1. build the corpus: train + dev rows form the train set, test rows the test set
    train_dev = pd.concat([train_split, dev_split], sort=True)
    train_sents = train_dev.apply(_row_to_sentence, axis=1).tolist()
    test_sents = test_split.apply(_row_to_sentence, axis=1).tolist()
    corpus = Corpus(train=train_sents, test=test_sents)

    # start from the "tars-base" checkpoint
    tars = TARSClassifier.load("tars-base")

    # 2. make the model aware of the desired set of labels from the new corpus
    tars.add_and_switch_to_new_task(
        task_name=label_type,
        label_dictionary=corpus.make_label_dictionary(label_type),
        label_type=label_type,
        multi_label=True,
    )

    # 3. initialize the text classifier trainer with your corpus
    trainer = ModelTrainer(tars, corpus)

    # 4. train model; every hyper-parameter can be overridden from model_config
    # NOTE(review): artefact_dir is removed when this block exits, so the saved
    # file does not outlive the cell -- only the in-memory `tars` is used by the
    # cells that follow. Log or copy the file here if it must be kept.
    with tempfile.TemporaryDirectory() as artefact_dir:
        trainer.train(
            base_path=artefact_dir,  # path to store the model artifacts
            learning_rate=model_config.get("learning_rate", 0.02),
            mini_batch_size=model_config.get("mini_batch_size", 1),
            max_epochs=model_config.get("max_epochs", 10),
            # Read the flag from its own key (it previously read "max_epochs",
            # which made it truthy whenever epochs were configured). The model
            # is saved explicitly right below, so the default stays False.
            save_final_model=model_config.get("save_final_model", False),
        )
        trainer.model.save(Path(artefact_dir) / "final-model.pt", checkpoint=False)


In [None]:
# Run predict_flair_tars with the `tars` model over the text column of the test
# split and store each returned dict in a new "pred" column.
test_preds = test_split.assign(
    pred=lambda frame: frame[CONFIG["text_col"]].apply(
        lambda text: predict_flair_tars(text, tars)
    )
)


In [None]:
# Call the eval_util report helper with the test split, the frame holding the
# "pred" column, and the config; its return value is shown as the cell output.
create_classification_report(test_split, test_preds, CONFIG)


In [None]:
# Spot-check: take the item at index 1 of the corpus test set and run the model on it.
# NOTE(review): predict() presumably annotates `sentence` in place -- confirm.
sentence = corpus.test[1]
tars.predict(sentence)


In [None]:
# Spot-check the prediction helper on a trivial input; the returned dict is the cell output.
predict_flair_tars("hello world", tars)


In [None]:
# Display the labels currently attached to `sentence`.
sentence.labels


## WANDB misc

In [None]:
# clear out for dev purposes
# WARNING: destructive -- deletes every run in the project named below.
import wandb

api = wandb.Api()

# api.runs() already yields run objects, so delete each one directly instead of
# re-fetching it by id first (one extra API round-trip per run, and it rebound
# the loop variable mid-iteration).
for run in api.runs(path="cool_stonebreaker/tyre_kick"):
    run.delete()


In [None]:
!pip install plotly

In [None]:
import matplotlib.pyplot as plt

import wandb

# Demo series: the first ten Fibonacci numbers.
fibonacci = [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]

# Draw onto pyplot's implicit current figure.
plt.plot(fibonacci)
plt.ylabel("some interesting numbers")

# Open a run and log the plot to it; the run is finished when the block exits.
with wandb.init(
    entity="cool_stonebreaker",
    project=CONFIG["wandb_project"],
    group=CONFIG["wandb_group"],
    name="flair_tars",
) as run:
    # The pyplot module itself is passed as the value to log.
    run.log({"plot": plt})
