# Purpose

- Inspired by https://towardsdatascience.com/text-classification-challenge-with-extra-small-datasets-fine-tuning-versus-chatgpt-6348fecea357
- Goal: explore transformer-based approaches that work well for multi-label text classification when labelled data is limited

# Packages

In [1]:
import pandas as pd
import time
import warnings

warnings.filterwarnings('ignore')

%load_ext autoreload
%autoreload 2
from constants import DATASOURCES
from data import DataCollector
from src.classifier.basic import BasicClassifier
from src.classifier.retrieval import RetrievalClassifier
from src.utils import evaluate_clf

# Data

- Refer to README.md for details about the datasets being used
- For each dataset, sample a small subset to mimic a situation with limited labelled data
- Preserve the distribution of labels between train & test using stratified random sampling
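
A minimal sketch of stratified sampling for multi-label data, assuming records are grouped by their exact label combination and each group is split in the same proportion. `stratified_split` is a hypothetical helper written for illustration; how `DataCollector` actually samples is not shown in this notebook. Because each group is rounded separately, the resulting sizes can differ slightly from the requested ones on real data.

```python
import random
from collections import defaultdict


def stratified_split(label_rows, n_train, n_test, seed=0):
    """Return (train_idx, test_idx), stratified on the full label combination."""
    rng = random.Random(seed)

    # Group record indices by their exact combination of labels.
    groups = defaultdict(list)
    for idx, row in enumerate(label_rows):
        groups[tuple(row)].append(idx)

    # Fraction of the pool to keep, and share of the kept records used for training.
    keep_frac = min(1.0, (n_train + n_test) / len(label_rows))
    train_frac = n_train / (n_train + n_test)

    train_idx, test_idx = [], []
    for combo in sorted(groups):
        idxs = groups[combo]
        rng.shuffle(idxs)
        kept = idxs[: round(len(idxs) * keep_frac)]
        cut = round(len(kept) * train_frac)  # per-group rounding
        train_idx.extend(kept[:cut])
        test_idx.extend(kept[cut:])
    return train_idx, test_idx


# Toy pool with 2 labels and 3 label combinations (illustrative data only).
pool = [(1, 0)] * 400 + [(0, 1)] * 200 + [(1, 1)] * 150
train_idx, test_idx = stratified_split(pool, n_train=500, n_test=250)
print(len(train_idx), len(test_idx))
```

Stratifying on whole label combinations is the simplest option. It can break down when many combinations occur only once, in which case an iterative multi-label stratification scheme may be a better fit.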

In [2]:
data = DataCollector(n_train=500, n_test=250, n_labels=5)
data.load()

Loading goemotions - goemotions_1.csv ...
Done in 0.3s with 501 records in train & 249 records in test
Loading just_dance - jd-multi-label-dataset.csv ...
Done in 1.3s with 498 records in train & 252 records in test
Loading pubmed - PubMed Multi Label Text Classification Dataset Processed.csv ...
Done in 1.1s with 500 records in train & 250 records in test
Loading research_papers - train.csv ...
Done in 0.2s with 501 records in train & 249 records in test


In [3]:
data.print_label_distribution()


*** Dataset: goemotions (train=501, test=249) ***

            train  test
admiration    124    62
amusement      75    37
anger          68    34
annoyance     115    58
approval      159    79

*** Dataset: just_dance (train=498, test=252) ***

              train  test
Usability       198    99
UX              457   229
H-QOL           377   189
Memorability      6     1
Learnability     29    13

*** Dataset: pubmed (train=500, test=250) ***

   train  test
A    223   130
B    477   239
C    264   132
D    322   161
E    395   198

*** Dataset: research_papers (train=501, test=249) ***

                      train  test
Computer Science        209   105
Physics                 149    74
Mathematics             121    60
Statistics              133    66
Quantitative Biology     15     8


# Fit Classifiers

- Compare 2 different classifiers
    - BasicClassifier: traditional TF-IDF + SVM classifier approach
    - RetrievalClassifier: builds a vector store from the training data & predicts using the labels of the relevant training records retrieved from it
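
The retrieval idea can be sketched in a few lines. This is a toy stand-in, not the code in `src.classifier.retrieval`: the embedding here is a bag-of-words count vector instead of a sentence-transformer model, the "vector store" is a plain list, and the voting rule (keep a label if at least half of the `k` neighbours carry it) is an assumption. `ToyRetrievalClassifier`, `embed` and `cosine` are hypothetical names.

```python
import math
from collections import Counter


def embed(text):
    """Toy stand-in for a sentence embedding: bag-of-words counts."""
    return Counter(text.lower().split())


def cosine(a, b):
    """Cosine similarity between two sparse count vectors."""
    dot = sum(a[w] * b[w] for w in a if w in b)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(
        sum(v * v for v in b.values())
    )
    return dot / norm if norm else 0.0


class ToyRetrievalClassifier:
    def __init__(self, k=4, threshold=0.5):
        self.k = k
        self.threshold = threshold  # assumed voting rule, not from the source

    def fit(self, texts, label_rows):
        # "Vector store": one (embedding, labels) pair per training record.
        self.store = [(embed(t), tuple(row)) for t, row in zip(texts, label_rows)]
        self.n_labels = len(self.store[0][1])
        return self

    def predict_one(self, text):
        query = embed(text)
        # Retrieve the k most similar training records.
        ranked = sorted(self.store, key=lambda item: cosine(query, item[0]), reverse=True)
        neighbours = ranked[: self.k]
        # A label is predicted when enough neighbours carry it.
        votes = [sum(row[j] for _, row in neighbours) for j in range(self.n_labels)]
        return [int(v / len(neighbours) >= self.threshold) for v in votes]


# Illustrative data: labels are (sports, food).
texts = [
    "great football match",
    "football team wins match",
    "tasty pasta recipe",
    "pasta and pizza recipe",
    "football pizza party",
]
labels = [(1, 0), (1, 0), (0, 1), (0, 1), (1, 1)]
toy_clf = ToyRetrievalClassifier(k=2).fit(texts, labels)
print(toy_clf.predict_one("football match tonight"))
print(toy_clf.predict_one("pasta recipe ideas"))
```

Nothing is learned at fit time beyond storing embeddings, so the cost shifts to prediction, where every query is compared against the stored records.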

In [4]:
clfs = []
for data_name, (X_train, y_train, X_test, y_test) in data.datasets:
    for clf_cls in [BasicClassifier, RetrievalClassifier]:
        t = time.time()
        clf = clf_cls()
        clf.fit(X_train, y_train)
        clf_name = clf.__class__.__name__
        print(f'{data_name} - {clf_name} fit completed in {time.time()-t:.1f}s')
        clfs.append({'data_name': data_name, 'clf_name': clf_name, 'clf': clf})

goemotions - BasicClassifier fit completed in 0.0s
goemotions - RetrievalClassifier fit completed in 3.3s
just_dance - BasicClassifier fit completed in 0.0s
just_dance - RetrievalClassifier fit completed in 1.5s
pubmed - BasicClassifier fit completed in 0.2s
pubmed - RetrievalClassifier fit completed in 12.0s
research_papers - BasicClassifier fit completed in 0.1s
research_papers - RetrievalClassifier fit completed in 11.2s


In [5]:
clfs

[{'data_name': 'goemotions',
  'clf_name': 'BasicClassifier',
  'clf': BasicClassifier()},
 {'data_name': 'goemotions',
  'clf_name': 'RetrievalClassifier',
  'clf': RetrievalClassifier(k=4, model_name='sentence-transformers/all-MiniLM-L6-v2')},
 {'data_name': 'just_dance',
  'clf_name': 'BasicClassifier',
  'clf': BasicClassifier()},
 {'data_name': 'just_dance',
  'clf_name': 'RetrievalClassifier',
  'clf': RetrievalClassifier(k=4, model_name='sentence-transformers/all-MiniLM-L6-v2')},
 {'data_name': 'pubmed',
  'clf_name': 'BasicClassifier',
  'clf': BasicClassifier()},
 {'data_name': 'pubmed',
  'clf_name': 'RetrievalClassifier',
  'clf': RetrievalClassifier(k=4, model_name='sentence-transformers/all-MiniLM-L6-v2')},
 {'data_name': 'research_papers',
  'clf_name': 'BasicClassifier',
  'clf': BasicClassifier()},
 {'data_name': 'research_papers',
  'clf_name': 'RetrievalClassifier',
  'clf': RetrievalClassifier(k=4, model_name='sentence-transformers/all-MiniLM-L6-v2')}]

# Evaluate Classifiers

In [6]:
summaries, details = [], []
for clf in clfs:
    data_name = clf['data_name']
    _, _, X_test, y_test = data.get_datasets(data_name)
    summary, detail = evaluate_clf(clf['clf'], X_test, y_test)
    summary['data_name'], detail['data_name'] = data_name, data_name
    summaries.append(summary)
    details.append(detail)

BasicClassifier predict completed in 0.0s


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 249/249 [00:10<00:00, 22.84it/s]


RetrievalClassifier predict completed in 10.9s
BasicClassifier predict completed in 0.0s


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 252/252 [00:08<00:00, 29.16it/s]


RetrievalClassifier predict completed in 8.7s
BasicClassifier predict completed in 0.1s


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:28<00:00,  8.89it/s]


RetrievalClassifier predict completed in 28.1s
BasicClassifier predict completed in 0.1s


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 249/249 [00:23<00:00, 10.72it/s]

RetrievalClassifier predict completed in 23.2s





In [7]:
pd.DataFrame(summaries).set_index(['data_name', 'clf'])

data_name        clf                  accuracy  f1_score  precision    recall
goemotions       BasicClassifier      0.228916  0.315086   0.587828  0.233426
goemotions       RetrievalClassifier  0.269076  0.424673   0.433897  0.423973
just_dance       BasicClassifier      0.353175  0.478793   0.442012  0.523706
just_dance       RetrievalClassifier  0.373016  0.513497   0.489912  0.556591
pubmed           BasicClassifier      0.304000  0.836427   0.825244  0.855761
pubmed           RetrievalClassifier  0.256000  0.837754   0.777192  0.910791
research_papers  BasicClassifier      0.510040  0.557957   0.660592  0.489712
research_papers  RetrievalClassifier  0.534137  0.683246   0.674833  0.722843
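
On pubmed, accuracy is around 0.3 while F1 is around 0.84. That gap would be expected if `accuracy` is exact-match (subset) accuracy and the F1 score is computed over individual label decisions. Both are assumptions here, since the internals of `evaluate_clf` are not shown. A small sketch with hypothetical helpers `subset_accuracy` and `micro_f1` illustrates the effect:

```python
def subset_accuracy(y_true, y_pred):
    """Share of records whose full label vector is predicted exactly."""
    hits = sum(tuple(t) == tuple(p) for t, p in zip(y_true, y_pred))
    return hits / len(y_true)


def micro_f1(y_true, y_pred):
    """F1 over all individual (record, label) decisions pooled together."""
    tp = fp = fn = 0
    for true_row, pred_row in zip(y_true, y_pred):
        for t, p in zip(true_row, pred_row):
            if t and p:
                tp += 1
            elif p:
                fp += 1
            elif t:
                fn += 1
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


# Illustrative data: a classifier that predicts every label for every record.
y_true = [(1, 1, 1), (1, 1, 0), (1, 0, 1), (0, 1, 1)]
y_pred = [(1, 1, 1), (1, 1, 1), (1, 1, 1), (1, 1, 1)]

acc = subset_accuracy(y_true, y_pred)
f1 = micro_f1(y_true, y_pred)
print(f"subset accuracy={acc:.3f}, micro F1={f1:.3f}")
```

When most labels are present on most records, one wrong label per record is enough to drive exact-match accuracy down while label-level F1 stays high.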


In [8]:
detail = pd.concat(details)
for data_name in detail['data_name'].unique():
    _, _, _, y_test = data.get_datasets(data_name)
    print(f'\n*** Dataset: {data_name} ***\n')
    display(
        detail.query('data_name == @data_name')
        .pivot_table(index='clf', columns=['label'], values='f1-score')
        [y_test.columns]
    )


*** Dataset: goemotions ***



clf                  admiration  amusement     anger  annoyance  approval
BasicClassifier        0.382022   0.416667  0.190476   0.219178  0.367089
RetrievalClassifier    0.448980   0.492754  0.387097   0.377358  0.417178



*** Dataset: just_dance ***



clf                  Usability        UX     H-QOL  Memorability  Learnability
BasicClassifier       0.602620  0.950000  0.841346      0.000000      0.000000
RetrievalClassifier   0.572614  0.950000  0.844869      0.000000      0.200000



*** Dataset: pubmed ***



clf                         A         B         C         D         E
BasicClassifier      0.666667  0.977505  0.821293  0.834320  0.882353
RetrievalClassifier  0.717241  0.977505  0.769231  0.840000  0.884793



*** Dataset: research_papers ***



clf                  Computer Science   Physics  Mathematics  Statistics  Quantitative Biology
BasicClassifier              0.738916  0.805970     0.673469    0.571429              0.000000
RetrievalClassifier          0.781513  0.847682     0.750000    0.703704              0.333333


# Summary

- Looking at the F1 scores across the 4 datasets, RetrievalClassifier outperforms BasicClassifier on the datasets with limited labels (i.e. goemotions, just_dance, research_papers): 0.42 vs 0.32, 0.51 vs 0.48 and 0.68 vs 0.56 respectively.
- Performance is similar on the dataset with sufficient labels (i.e. pubmed), where both classifiers reach an F1 score of about 0.84.