In [1]:
import collections
import pathlib
import sys

from sklearn.preprocessing import MultiLabelBinarizer
from tqdm import tqdm

sys.path.append(str(pathlib.Path("../..").resolve()))


from source.data import (
    create_splits,
    explode_locc,
    explode_multiple_locc,
    get_label_to_index_mapping,
)
from source.files import get_book_text
from source.metrics import calculate_flat_binary_metrics
from source.models.xgb import (
    create_all_minilm_xgboost_model,
    create_paraphrase_multilingual_minilm_xgboost_model,
    create_tfidf_xgboost_model,
    create_bge_m3_xgboost_model
)

[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/oleksandr/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to /Users/oleksandr/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     /Users/oleksandr/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     /Users/oleksandr/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     /Users/oleksandr/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Build the train/test split once. The full `splits` object is kept as well,
# because get_label_to_index_mapping(splits) consumes it in a later cell.
X_train, X_test, y_train, y_test = splits = create_splits(verbose=False)

In [3]:
# Text sampling budget forwarded to get_book_text() for every book.
# NOTE(review): presumably NUM_CHUNKS excerpts of ~TOKENS_PER_CHUNK tokens each — confirm in source.files.
NUM_CHUNKS, TOKENS_PER_CHUNK = 10, 500

In [4]:
def load_book_texts(etext_numbers, num_chunks: int, tokens_per_chunk: int, desc: str | None = None) -> list:
    """Fetch the sampled text of every book in ``etext_numbers``, in input order.

    Each identifier is converted to ``str`` and passed to ``get_book_text`` together
    with the sampling budget, so train and test texts are built identically.

    Args:
        etext_numbers: iterable of book identifiers (e.g. a split's "Etext Number" column).
        num_chunks: forwarded as the second argument of ``get_book_text``.
        tokens_per_chunk: forwarded as the third argument of ``get_book_text``.
        desc: optional label for the tqdm progress bar.

    Returns:
        A list with one ``get_book_text`` result per identifier
        (presumably a text string — confirm in ``source.files``).
    """
    return [
        get_book_text(str(num), num_chunks, tokens_per_chunk)
        for num in tqdm(etext_numbers, desc=desc)
    ]


X_train_texts = load_book_texts(X_train["Etext Number"], NUM_CHUNKS, TOKENS_PER_CHUNK, desc="train texts")
X_test_texts = load_book_texts(X_test["Etext Number"], NUM_CHUNKS, TOKENS_PER_CHUNK, desc="test texts")

100%|██████████| 35542/35542 [00:35<00:00, 1002.90it/s]
100%|██████████| 13898/13898 [00:14<00:00, 992.34it/s] 


In [5]:
# NOTE(review): lti / itl are presumably label->index and index->label mappings — confirm in source.data.
labels, lti, itl = get_label_to_index_mapping(splits)


def get_parent_locc(labels, locc) -> str | None:
    """Return the nearest ancestor class of ``locc`` that appears in ``labels``.

    Walks ``explode_locc(locc)`` from its last element backwards, skipping ``locc``
    itself, and returns the first class that is a member of ``labels``.
    NOTE(review): "nearest" assumes explode_locc lists classes from most general
    to most specific — confirm in source.data.

    Args:
        labels: collection of known label codes (membership-tested with ``in``).
        locc: the class code whose parent is wanted.

    Returns:
        The matching ancestor, or ``None`` if no other class from
        ``explode_locc(locc)`` is in ``labels``.
    """
    for cls in reversed(explode_locc(locc)):
        if cls != locc and cls in labels:
            return cls
    return None


# Pin the column order to `labels` so targets and predictions share one indicator layout.
mlb = MultiLabelBinarizer(classes=labels)
y_train_binarized = mlb.fit_transform([explode_multiple_locc(locc) for locc in y_train])
# transform (not fit_transform) reuses the same class order; with `classes` fixed,
# any exploded class not in `labels` is ignored (sklearn emits a warning).
y_test_binarized = mlb.transform([explode_multiple_locc(locc) for locc in y_test])



In [6]:
# Registry of model factories keyed by short name. The factories are stored
# directly (no wrapping lambdas); `models[name]()` calls the chosen factory.
models = {
    "tfidf": create_tfidf_xgboost_model,
    "all-minilm": create_all_minilm_xgboost_model,
    "multilingual": create_paraphrase_multilingual_minilm_xgboost_model,
    "bge_m3": create_bge_m3_xgboost_model,
}

In [7]:
# Which entry of `models` to train and evaluate; also used as the run name in the metrics call.
model_name = "bge_m3"
if model_name not in models:
    # Fail early with the valid choices instead of a bare KeyError on the lookup.
    raise KeyError(f"Unknown model_name {model_name!r}; expected one of {sorted(models)}")
model = models[model_name]()

In [None]:
# Train the selected model on the sampled book texts and their binarized LoCC targets.
model.fit(
    X_train_texts,
    y_train_binarized,
)

In [13]:
y_pred_binarized = model.predict(X_test_texts)

In [14]:
y_pred_binarized = mlb.transform(y_pred_binarized)

In [16]:
# Spot-check the first test prediction: one indicator entry per element of `labels`, in that order.
y_pred_binarized[0]

array([0, 0, 0, 0, 1, 0, 0, 0, 0, 0])

In [None]:
# Score the test predictions against the true indicator matrix and persist the results
# (save=True) under the "xgboost" family / current model_name.
calculate_flat_binary_metrics(
    y_test_binarized,
    y_pred_binarized,
    labels,
    "xgboost",
    model_name,
    save=True,
)