In [1]:
import typing

import numpy as np
import pandas as pd

from datasets import load_dataset

import tqdm 
import requests

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import classification_report

from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.tree import DecisionTreeClassifier

import cltrier_nlp as nlp

In [2]:
# Number of leading rows taken from each (shuffled) split for the experiment.
SAMPLE_SIZE: int = 1000

In [3]:
# Shuffle with a fixed seed so the row order -- and therefore the leading
# SAMPLE_SIZE rows drawn from each split -- is identical on every run.
DATASET = load_dataset("stanfordnlp/imdb").shuffle(seed=42)

In [4]:
def remote_encoding(
    batch: typing.List[str],
    url: str = 'https://inf.cl.uni-trier.de/embed/',
    timeout: float = 60.0,
) -> typing.List[typing.Optional[np.ndarray]]:
    """Embed every text in ``batch`` with one POST request per text.

    Args:
        batch: Texts to embed; each one is sent as ``{"prompt": text}``.
        url: Endpoint of the embedding service.
        timeout: Seconds to wait on each request before treating it as failed,
            so a stalled connection cannot block the notebook indefinitely.

    Returns:
        A list with one entry per input text, in input order: an array built
        from the ``"response"`` field of the JSON reply, or ``None`` if that
        request failed for any reason. Callers must handle ``None`` entries.
    """
    embeds: typing.List[typing.Optional[np.ndarray]] = []

    # A single session lets the underlying connection be reused across the batch.
    with requests.Session() as session:
        for value in tqdm.tqdm(batch):

            try:
                response = session.post(url, json={'prompt': value}, timeout=timeout)
                # Report HTTP error statuses as such rather than as a confusing
                # JSON-decoding or missing-key error further down.
                response.raise_for_status()
                embed = np.array(response.json()["response"])

            except Exception as _e:
                # Best effort: show the error and keep a placeholder so the
                # result stays positionally aligned with `batch`.
                display(_e)
                embed = None

            embeds.append(embed)

    return embeds

In [5]:
# The stateful engines are created first so that the train and test callables
# of an encoder share the same instance.
tfidf = TfidfVectorizer()
transformer = nlp.encoder.Encoder()


def _encode_local(texts):
    """Run the local transformer on ``texts`` and stack the pooled vectors into one array."""
    # NOTE(review): "sent_cls" presumably selects a per-text CLS embedding -- confirm
    # against the cltrier_nlp pooler documentation.
    pooled = nlp.encoder.EncoderPooler()(transformer(texts), form="sent_cls")
    return np.stack([vector.detach().numpy() for vector in pooled])


# Registry of text encoders: label -> callables producing train / test features
# (plus the underlying engine object where one exists locally).
ENCODERS: typing.Dict[str, typing.Dict[str, typing.Callable]] = {
    "tfidf": dict(
        engine=tfidf,
        # Fit the vocabulary on the training texts; only transform the test texts.
        embed_train=tfidf.fit_transform,
        embed_test=tfidf.transform,
    ),
    "tiny transformer (local)": dict(
        engine=transformer,
        embed_train=_encode_local,
        embed_test=_encode_local,
    ),
    "sota transformer (remote)": dict(
        embed_train=lambda texts: remote_encoding(texts),
        embed_test=lambda texts: remote_encoding(texts),
    ),
}



In [6]:
# Classifier classes keyed by the label used in the results table; each class is
# instantiated with its default arguments at evaluation time.
CLASSIFIERS: typing.Dict[str, typing.Callable] = dict(
    random_forest=RandomForestClassifier,
    ada_boost=AdaBoostClassifier,
    decision_tree=DecisionTreeClassifier,
    k_neighbors=KNeighborsClassifier,
    mlp=MLPClassifier,
)

In [7]:
# One DataFrame of scores per encoder, indexed by (encoder, classifier).
results: typing.List[pd.DataFrame] = []

# Slice each split once: texts and labels are the same for every encoder and
# classifier, so the columns do not need to be re-read inside the loops.
train_sample = DATASET["train"][:SAMPLE_SIZE]
test_sample = DATASET["test"][:SAMPLE_SIZE]

for encoder_label, encoder in ENCODERS.items():

    # Embed once per encoder; every classifier below reuses these features.
    embed_train = encoder["embed_train"](train_sample["text"])
    embed_test = encoder["embed_test"](test_sample["text"])

    reports = []
    for classifier_label, classifier in CLASSIFIERS.items():
        # Fresh, default-configured classifier for each (encoder, classifier) pair.
        predictions = (
            classifier()
            .fit(embed_train, train_sample["label"])
            .predict(embed_test)
        )
        reports.append(
            classification_report(
                test_sample["label"],
                predictions,
                zero_division=1.,
                output_dict=True
            ) | {"classifier": classifier_label, "encoder": encoder_label}
        )

    results.append(
        # json_normalize flattens the nested report, yielding dotted column
        # names such as "macro avg.f1-score".
        pd.json_normalize(data=reports)
        .set_index(["encoder", "classifier"], drop=True)
        .filter(
            items=[
                "accuracy",
                "macro avg.f1-score",
                "weighted avg.f1-score"
            ]
        )
        .sort_values(by="accuracy", ascending=False)
    )

100%|██████████| 1000/1000 [02:46<00:00,  6.02it/s]
100%|██████████| 1000/1000 [03:08<00:00,  5.31it/s]


In [8]:
# Stack the per-encoder tables and rank every (encoder, classifier) pair by
# weighted F1, best first.
summary = pd.concat(results)
summary.sort_values("weighted avg.f1-score", ascending=False)

Unnamed: 0_level_0,Unnamed: 1_level_0,accuracy,macro avg.f1-score,weighted avg.f1-score
encoder,classifier,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
sota transformer (remote),random_forest,0.885,0.884163,0.884557
sota transformer (remote),mlp,0.878,0.877875,0.878031
sota transformer (remote),k_neighbors,0.865,0.864277,0.864673
sota transformer (remote),ada_boost,0.849,0.848575,0.848896
tfidf,mlp,0.805,0.804223,0.804716
tfidf,random_forest,0.801,0.800727,0.801022
sota transformer (remote),decision_tree,0.765,0.764338,0.764838
tfidf,ada_boost,0.741,0.737769,0.738933
tiny transformer (local),mlp,0.721,0.720575,0.721011
tiny transformer (local),random_forest,0.693,0.692483,0.692987
