In [None]:
# res = compare_experiments_barplot(
#     experiment_paths=[experiment_output_dir],
#     title="TARS eval.",
# )


## WANDB dev

In [36]:
from pathlib import Path

import pandas as pd
import yaml

from data_util import create_multi_label_train_test_splits

# Load the training configuration (dataset path, column names, split size, ...).
config_path = Path(
    "/Users/samhardyhey/Desktop/blog/blog-multi-label/train/train_config.yaml"
)
CONFIG = yaml.safe_load(config_path.read_bytes())

# 1.1 create splits
df = pd.read_csv(CONFIG["dataset"])

# Both splitting passes share the same label column and test-size setting.
split_kwargs = {"label_col": CONFIG["label_col"], "test_size": CONFIG["test_size"]}

# First pass: full dataset -> train / held-out.
train_split, test_split = create_multi_label_train_test_splits(df, **split_kwargs)
# Second pass: the held-out portion is split again into test / dev.
test_split, dev_split = create_multi_label_train_test_splits(
    test_split, **split_kwargs
)

# # 1.2 log splits
# with wandb.init(
#     project=CONFIG["wandb_project"],
#     name="reddit_aus_finance",
#     group=CONFIG["wandb_group"],
#     entity="cool_stonebreaker",
# ) as run:
#     log_dataframe(run, train, "train_split", "Train split")
#     log_dataframe(run, dev, "dev_split", "Dev split")
#     log_dataframe(run, test, "test_split", "Test split")


## Dictionary classifier

In [None]:
from model_util import fit_and_log_dictionary_classifier, fit_and_log_linear_svc

# Walk over every model entry in the config. The per-model dispatch below is
# still commented out, so each iteration currently only evaluates
# model["model"] and discards the result.
# NOTE(review): the commented-out calls pass train/dev/test, whereas the splits
# created earlier in this notebook are named train_split/dev_split/test_split,
# and they branch on model['name'] while the loop body reads model["model"] —
# reconcile both before re-enabling.
for model in CONFIG["models"]:
    model["model"]
    # if model['name'] == 'dictionary_classifier':
    #     fit_and_log_dictionary_classifier(train, dev, test, model)

    # elif model['name'] == 'sklearn_linear_svc':
    #     fit_and_log_linear_svc(train, dev, test, model)

    # else:
    #     print(f"Unsupported model: {model['name']} found")


## Flair

In [14]:
from flair.data import Corpus, Sentence, Token
from flair.models import (SequenceTagger, TARSClassifier, TARSTagger,
                          TextClassifier)
from flair.tokenization import SegtokTokenizer


# Scratch check: build a flair Sentence from a short string, passing an explicit SegtokTokenizer.
sent = Sentence('hello world', use_tokenizer=SegtokTokenizer())

In [37]:
import tempfile
from pathlib import Path

import pandas as pd
from clear_bow.classifier import DictionaryClassifier
from flair.data import Corpus, Sentence
from flair.models import TARSClassifier
from flair.tokenization import SegtokTokenizer
from flair.trainers import ModelTrainer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.feature_selection import VarianceThreshold
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC
from skmultilearn.problem_transform import BinaryRelevance

import wandb
from data_util import log_dataframe
from eval_util import (create_classification_report,
                       create_slim_classification_report,
                       label_dictionary_to_label_mat)

In [56]:
def create_flair_classification_sentence(text, label_object, label_type="class"):
    """Wrap ``text`` in a flair ``Sentence`` and attach its positive labels.

    Args:
        text: raw document text.
        label_object: mapping of label name -> numeric indicator; only entries
            whose value is greater than 0 are attached.
        label_type: label type name passed to ``Sentence.add_label``.

    Returns:
        The ``Sentence`` (tokenised with ``SegtokTokenizer``) carrying one label,
        with score 1.0, per positive entry of ``label_object``.
    """
    sentence = Sentence(text, use_tokenizer=SegtokTokenizer())
    for name, indicator in label_object.items():
        if indicator > 0:
            sentence.add_label(label_type, name, 1.0)
    return sentence

def predict_flair_tars(text, flair_tars_model):
    """Predict label confidences for ``text`` with a TARS model.

    Args:
        text: raw string to classify.
        flair_tars_model: model exposing ``get_current_label_dictionary()`` and
            ``predict(sentence)``.

    Returns:
        dict with one entry per item of the model's current label dictionary,
        defaulting to 0.0 and overwritten with the confidence (rounded to two
        decimals) of every label found on the sentence after prediction.
    """
    sentence = Sentence(text)
    known_labels = flair_tars_model.get_current_label_dictionary().get_items()
    flair_tars_model.predict(sentence)

    # Start every known label at zero, then fill in what the model predicted.
    scores = dict.fromkeys(known_labels, 0.0)
    for predicted in sentence.labels:
        fields = predicted.to_dict()
        name = fields["value"]
        scores[name] = round(float(fields["confidence"]), 2)
    return scores

In [39]:
from model_util import fit_and_log_flair_tars

# Fine-tune the pretrained TARS classifier on the project corpus inside a W&B run.
# The last entry of CONFIG["models"] supplies the run name and hyper-parameters.
model_config = CONFIG['models'][-1]

with wandb.init(
    project=CONFIG["wandb_project"],
    name=model_config["type"],
    group=CONFIG["wandb_group"],
    entity=CONFIG["wandb_entity"],
) as run:
    wandb.config.type = model_config["type"]
    wandb.config.group = CONFIG["wandb_group"]
    label_type = model_config.get("label_type", 'multi_label_class')

    def _row_to_sentence(row):
        """Build a labelled flair Sentence from one dataframe row."""
        return create_flair_classification_sentence(
            row[CONFIG["text_col"]], row[CONFIG["label_col"]], label_type
        )

    # 1. train on train + dev combined; concatenate once and reuse the frame
    train_dev = pd.concat([train_split, dev_split], sort=True)
    train_sents = train_dev.apply(_row_to_sentence, axis=1).tolist()
    test_sents = test_split.apply(_row_to_sentence, axis=1).tolist()

    # make a corpus with train and test split
    corpus = Corpus(train=train_sents, test=test_sents)

    # load the pretrained TARS base model that will be fine-tuned
    tars = TARSClassifier.load("tars-base")

    # 2. make the model aware of the desired set of labels from the new corpus
    tars.add_and_switch_to_new_task(
        task_name=label_type,
        label_dictionary=corpus.make_label_dictionary(label_type),
        label_type=label_type,
        multi_label=True,
    )

    # 3. initialize the text classifier trainer with your corpus
    trainer = ModelTrainer(tars, corpus)

    # 4. train model; artefacts live in a temp dir that is removed on exit
    with tempfile.TemporaryDirectory() as artefact_dir:
        trainer.train(
            base_path=artefact_dir,  # path to store the model artifacts
            learning_rate=model_config.get(
                "learning_rate", 0.02
            ),  # use very small learning rate
            mini_batch_size=model_config.get(
                "mini_batch_size", 1
            ),  # small mini-batch size since corpus is tiny
            max_epochs=model_config.get("max_epochs", 10),
            # Read the flag from its own key; this previously looked up
            # "max_epochs", so any configured epoch count acted as a truthy flag.
            save_final_model=model_config.get("save_final_model", False),
        )
        # The final model is always written explicitly, independent of the flag above.
        trainer.model.save(
            Path(artefact_dir) / "final-model.pt", checkpoint=False)

2022-09-21 13:43:08,108 loading file /Users/samhardyhey/.flair/models/tars-base-v8.pt
2022-09-21 13:43:20,656 Computing label dictionary. Progress:


100%|██████████| 38/38 [00:00<00:00, 2603.07it/s]

2022-09-21 13:43:20,678 Corpus contains the labels: multi_label_class (#38)
2022-09-21 13:43:20,679 Created (for label 'multi_label_class') Dictionary with 6 tags: <unk>, covid, retirement, regulation, fund, contribution
2022-09-21 13:43:20,683 ----------------------------------------------------------------------------------------------------
2022-09-21 13:43:20,688 Model: "TARSClassifier(
  (tars_model): TextClassifier(
    (loss_function): CrossEntropyLoss()
    (document_embeddings): TransformerDocumentEmbeddings(
      (model): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(30522, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0): BertLayer(
           




2022-09-21 13:43:24,780 epoch 1 - iter 3/38 - loss 0.50980861 - samples/sec: 0.77 - lr: 0.020000
2022-09-21 13:43:34,915 epoch 1 - iter 6/38 - loss 0.44571636 - samples/sec: 0.30 - lr: 0.020000
2022-09-21 13:43:40,714 epoch 1 - iter 9/38 - loss 0.39601129 - samples/sec: 0.52 - lr: 0.020000
2022-09-21 13:43:46,622 epoch 1 - iter 12/38 - loss 0.31869319 - samples/sec: 0.51 - lr: 0.020000
2022-09-21 13:43:52,265 epoch 1 - iter 15/38 - loss 0.31104973 - samples/sec: 0.53 - lr: 0.020000
2022-09-21 13:43:56,940 epoch 1 - iter 18/38 - loss 0.28198198 - samples/sec: 0.64 - lr: 0.020000
2022-09-21 13:44:01,651 epoch 1 - iter 21/38 - loss 0.26863811 - samples/sec: 0.64 - lr: 0.020000
2022-09-21 13:44:05,657 epoch 1 - iter 24/38 - loss 0.25754949 - samples/sec: 0.75 - lr: 0.020000
2022-09-21 13:44:09,523 epoch 1 - iter 27/38 - loss 0.26086130 - samples/sec: 0.78 - lr: 0.020000
2022-09-21 13:44:13,664 epoch 1 - iter 30/38 - loss 0.24829121 - samples/sec: 0.72 - lr: 0.020000
2022-09-21 13:44:17,284

{'test_score': 0.33333333333333326,
 'dev_score_history': [0.42857142857142855],
 'train_loss_history': [0.24733116775751113],
 'dev_loss_history': [0.46979967495426533]}

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

In [57]:
# Score every test document with the fine-tuned TARS model and store the
# resulting label -> confidence dict in a new "pred" column.
test_texts = test_split[CONFIG["text_col"]]
test_preds = test_split.assign(
    pred=test_texts.apply(predict_flair_tars, flair_tars_model=tars)
)


In [60]:
# Build the classification report from the test split and its TARS predictions;
# the returned table is rendered as this cell's output.
create_classification_report(test_split, test_preds, CONFIG)

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Unnamed: 0,label,precision,recall,f1-score,support
0,covid,0.0,0.0,0.0,2.0
1,retirement,1.0,0.5,0.666667,2.0
2,regulation,0.0,0.0,0.0,2.0
3,fund,0.0,0.0,0.0,3.0
4,contribution,0.0,0.0,0.0,1.0
5,micro avg,0.5,0.1,0.166667,10.0
6,macro avg,0.2,0.1,0.133333,10.0
7,weighted avg,0.2,0.1,0.133333,10.0
8,samples avg,0.125,0.0625,0.083333,10.0


In [48]:
# Spot check: take the sentence at index 1 of the corpus test split and run the
# trained TARS model on it; predict() attaches its labels to the sentence in place.
# NOTE(review): this sentence already carries labels of the same type from corpus
# construction, and re-running this cell may append further labels — confirm.
sentence = corpus.test[1]
tars.predict(sentence)

In [55]:
# Sanity check of the prediction helper on a string unrelated to the corpus topics.
predict_flair_tars('hello world', tars)

{'covid': 0.0,
 'retirement': 0.0,
 'regulation': 0.0,
 'fund': 0.0,
 'contribution': 0.67}

In [49]:
# Inspect the labels now attached to the spot-check sentence.
sentence.labels

[retirement (0.8649),
 regulation (0.7273),
 fund (0.6314),
 retirement (0.8649),
 regulation (0.7273),
 fund (0.6314)]

## WANDB misc

In [None]:
# clear out for dev purposes
import wandb

# Dev-only cleanup: permanently delete every run in the scratch project.
api = wandb.Api()

# api.runs() already yields Run objects, so delete each one directly instead of
# re-fetching it by id (one fewer API round-trip per run, and the loop variable
# is no longer rebound mid-iteration).
for stale_run in api.runs(path="cool_stonebreaker/tyre_kick"):
    stale_run.delete()


In [None]:
!pip install plotly

In [None]:
import matplotlib.pyplot as plt

import wandb

# Demo: draw a small line chart with pyplot and log it to a W&B run.
fibonacci = [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
plt.plot(fibonacci)
plt.ylabel("some interesting numbers")

# Settings for the throw-away run that receives the plot.
run_settings = {
    "project": CONFIG["wandb_project"],
    "name": "flair_tars",
    "group": CONFIG["wandb_group"],
    "entity": "cool_stonebreaker",
}

# Initialize run
with wandb.init(**run_settings) as run:
    # Log plot object (the pyplot module itself is handed to wandb.log)
    wandb.log({"plot": plt})
