In [1]:
import fasttext
from datasets import load_dataset
import pandas as pd
import csv
from sklearn.metrics import accuracy_score, f1_score

In [2]:
# Load the 'ag_news' dataset through `datasets.load_dataset`; later cells
# read its 'train' and 'test' entries.
dataset = load_dataset('ag_news')

In [3]:
# Materialise the 'train' and 'test' splits as pandas DataFrames.
df_train, df_test = (pd.DataFrame(dataset[split]) for split in ('train', 'test'))

In [4]:
def get_fasttext_label(label):
    """Translate a numeric class id into its fastText label string.

    Parameters
    ----------
    label : int
        Class id; one of 0 (world), 1 (sports), 2 (business), 3 (technology).

    Returns
    -------
    str
        The matching ``'__label__<name>'`` string.

    Raises
    ------
    ValueError
        If ``label`` is not one of the four known class ids. Failing loudly
        here prevents an unlabeled line from silently ending up in the
        fastText training/test files.
    """
    fasttext_labels = {
        0: '__label__world',
        1: '__label__sports',
        2: '__label__business',
        3: '__label__technology',
    }
    try:
        return fasttext_labels[label]
    except (KeyError, TypeError):
        # KeyError: unknown id (including NaN); TypeError: unhashable input.
        raise ValueError(
            'Unknown class id {!r}; expected one of {}'.format(label, sorted(fasttext_labels))
        ) from None

In [5]:
# Add the fastText label string next to the numeric label in both splits.
for frame in (df_train, df_test):
    frame['fasttext_label'] = frame['label'].apply(get_fasttext_label)

In [6]:
# Preview the first 10 training rows to check the text, label and
# fasttext_label columns.
df_train.head(10)

Unnamed: 0,text,label,fasttext_label
0,Wall St. Bears Claw Back Into the Black (Reute...,2,__label__business
1,Carlyle Looks Toward Commercial Aerospace (Reu...,2,__label__business
2,Oil and Economy Cloud Stocks' Outlook (Reuters...,2,__label__business
3,Iraq Halts Oil Exports from Main Southern Pipe...,2,__label__business
4,"Oil prices soar to all-time record, posing new...",2,__label__business
5,"Stocks End Up, But Near Year Lows (Reuters) Re...",2,__label__business
6,Money Funds Fell in Latest Week (AP) AP - Asse...,2,__label__business
7,Fed minutes show dissent over inflation (USATO...,2,__label__business
8,Safety Net (Forbes.com) Forbes.com - After ear...,2,__label__business
9,Wall St. Bears Claw Back Into the Black NEW Y...,2,__label__business


In [7]:
# Write "<fasttext_label> <text>" rows for each split to a space-separated
# text file with no header, no index and no quoting; both splits share the
# same CSV options.
csv_options = dict(index = False, sep = ' ', header = None, quoting = csv.QUOTE_NONE, quotechar = "", escapechar = " ")
for split_frame, out_path in (
    (df_train, '../../data/train_fasttext_agnews.txt'),
    (df_test, '../../data/test_fasttext_agnews.txt'),
):
    split_frame[['fasttext_label', 'text']].to_csv(out_path, **csv_options)

In [8]:
# Train a supervised fastText classifier on the exported training file,
# for 25 epochs with wordNgrams=2 (per the fastText docs, word n-grams up to
# length 2, i.e. unigrams and bigrams).
model = fasttext.train_supervised('../../data/train_fasttext_agnews.txt', epoch=25, wordNgrams=2)

Read 4M words
Number of words:  188111
Number of labels: 4
Progress: 100.0% words/sec/thread: 1608438 lr:  0.000000 avg.loss:  0.049638 ETA:   0h 0m 0s


In [9]:
# Evaluate on the training file. Per the fastText docs the result is
# (number of examples, precision@1, recall@1); with a single label per
# example both values coincide with accuracy.
model.test('../../data/train_fasttext_agnews.txt') # train accuracy

(120000, 0.9998333333333334, 0.9998333333333334)

In [10]:
# Evaluate on the held-out test file. Per the fastText docs the result is
# (number of examples, precision@1, recall@1); with a single label per
# example both values coincide with accuracy.
model.test('../../data/test_fasttext_agnews.txt') # test accuracy

(7600, 0.9189473684210526, 0.9189473684210526)

In [11]:
def predict_fasttext(text):
    """Classify ``text`` with the notebook-level ``model`` and return a class id.

    Takes the first label from the first element of ``model.predict(text)``
    and maps it back to the numeric id: 0 (world), 1 (sports), 2 (business),
    3 (technology). Returns ``None`` if the label is none of these four.
    """
    top_label = model.predict(text)[0][0]
    known_labels = (
        '__label__world',
        '__label__sports',
        '__label__business',
        '__label__technology',
    )
    # A label's position in the tuple is its numeric class id.
    for class_id, name in enumerate(known_labels):
        if top_label == name:
            return class_id
    return None

In [12]:
# Run the classifier row by row and store the predicted class id in both splits.
for frame in (df_train, df_test):
    frame['predicted_fasttext'] = frame['text'].apply(predict_fasttext)

In [13]:
# Accuracy of the stored predictions against the true labels, train split first.
train_accuracy, test_accuracy = (
    accuracy_score(y_true=frame['label'], y_pred=frame['predicted_fasttext'])
    for frame in (df_train, df_test)
)
print('Train Accuracy: {}, Test Accuracy: {}'.format(train_accuracy, test_accuracy))

Train Accuracy: 0.9998333333333334, Test Accuracy: 0.9189473684210526


In [14]:
# Weighted-average F1 of the stored predictions against the true labels,
# train split first.
train_f1, test_f1 = (
    f1_score(y_true=frame['label'], y_pred=frame['predicted_fasttext'], average='weighted')
    for frame in (df_train, df_test)
)
print('Train f-score: {}, Test f-score: {}'.format(train_f1, test_f1))

Train f-score: 0.9998333256876728, Test f-score: 0.918798996049317
