In [1]:
cd ..

/Users/yasas/Documents/Projects/textkit-learn


In [2]:
# import datasets
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from tklearn import datasets
from tklearn.metrics import TextClassificationMetric
from tklearn.nn.trainer import Trainer
from tklearn.nn.evaluator import Evaluator
from tklearn.nn.callbacks import ProgbarLogger
from tklearn.config import config, config_scope

In [3]:
config['emotion/Trainer'] = {
    'epochs': 2,
}

In [4]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")


def tokenize(texts):
    return tokenizer(texts["text"].tolist(), padding="max_length", truncation=True)

In [5]:
train_dset = (
    datasets.load_dataset('hf', 'dair-ai/emotion', split="train").take(range(100))
    .map(tokenize, batched=True, keep_columns=True)
    .rename_column('label', 'labels')
    .remove_columns(["text"])
)

# train_dset = train_dset[:500].to_pylist()

Map:   0%|          | 0/100 [00:00<?, ?it/s]

Rename:   0%|          | 0/100 [00:00<?, ?it/s]

Remove:   0%|          | 0/100 [00:00<?, ?it/s]

In [6]:
train_dset

Path,Format,#Rows,#Columns
/Users/yasas/.tklearn/cache/dataset-mapped-4712c86c82069e1b9a689596132b2037,arrow,100,4

Unnamed: 0,labels,input_ids,token_type_ids,attention_mask
0,0,"[101, 1045, 2134, 2102, 2514, 26608, 102, 0, 0...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,0,"[101, 1045, 2064, 2175, 2013, 3110, 2061, 2062...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
2,3,"[101, 10047, 9775, 1037, 3371, 2000, 2695, 104...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, ..."
3,2,"[101, 1045, 2572, 2412, 3110, 16839, 9080, 128...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
4,3,"[101, 1045, 2572, 3110, 24665, 7140, 11714, 10...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, ..."


In [7]:
valid_dset = (
    datasets.load_dataset('hf', "dair-ai/emotion", split="validation").take(range(100))
    .map(tokenize, batched=True, keep_columns=True)
    .rename_column('label', 'labels')
    .remove_columns(["text"])
)

Map:   0%|          | 0/100 [00:00<?, ?it/s]

Rename:   0%|          | 0/100 [00:00<?, ?it/s]

Remove:   0%|          | 0/100 [00:00<?, ?it/s]

In [8]:
len(train_dset), len(valid_dset)

(100, 100)

In [9]:
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=8,
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
callbacks = [
    ProgbarLogger(),
]

with config_scope('emotion'):
    trainer = Trainer(model, callbacks=callbacks)

In [11]:
metric = TextClassificationMetric(num_labels=8)

evaluator = Evaluator(valid_dset, metric=metric, postprocessor='argmax')

In [12]:
trainer.fit(train_dset, evaluator=evaluator)

Train:   0%|          | 0/26 [00:00<?, ?it/s]

Predict:   0%|          | 0/13 [00:00<?, ?it/s]

Predict:   0%|          | 0/13 [00:00<?, ?it/s]

<tklearn.nn.trainer.Trainer at 0x2d2087190>

In [13]:
test_dset = (
    datasets.load_dataset('hf', "dair-ai/emotion", split="test")
    .map(tokenize, batched=True, keep_columns=True)
    .rename_column('label', 'labels')
    .remove_columns(["text"])
)

evaluator = Evaluator(test_dset, metric=metric, postprocessor='argmax')

In [14]:
evaluator.evaluate(trainer)

Predict:   0%|          | 0/250 [00:00<?, ?it/s]

{'micro_f1': 0.378,
 'micro_precision': 0.378,
 'micro_recall': 0.378,
 'macro_f1': 0.13633159655519014,
 'macro_precision': 0.19161561310684116,
 'macro_recall': 0.18868489748961273,
 'weighted_f1': 0.258912318063247,
 'weighted_precision': 0.3059185533277638,
 'weighted_recall': 0.378,
 'accuracy': 0.378}