In [1]:
import pandas as pd

https://meta.wikimedia.org/wiki/List_of_Wikipedias

# Download samples of Wikipedia in multiple languages

In [2]:
tables = pd.read_html('https://meta.wikimedia.org/wiki/List_of_Wikipedias')

In [3]:
len(tables)

9

In [4]:
all_langs = pd.concat(tables[0:8], ignore_index=True)
all_langs = all_langs[all_langs.Articles > 0].copy()

In [5]:
all_langs

Unnamed: 0,№,Language,Language (local),Wiki,Articles,Total,Edits,Admins,Users,Active Users,Files,Depth
0,1,English,English,en,6522949,56115029,1091916972,1034,43848293,118305,894011,1125
1,2,Cebuano,Sinugboanong Binisaya,ceb,6125836,11231578,34857733,6,91542,173,0,2
2,3,German,Deutsch,de,2700621,7458794,222426812,193,3947411,17446,128572,93
3,4,Swedish,Svenska,sv,2551055,6110382,50520284,66,825828,2051,0,16
4,5,French,Français,fr,2432573,12080528,194298530,159,4406976,17874,68725,253
...,...,...,...,...,...,...,...,...,...,...,...,...
320,321,Hiri Motu,Hiri Motu,ho,3,129,3786,1,1575,0,0,--
321,322,Sichuan Yi,ꆇꉙ,ii,3,189,11653,1,2031,0,0,--
322,323,Afar,Afar,aa,1,510,4685,1,4061,2,0,--
323,324,Northern Luri,لۊری شومالی,lrc,1,237,140064,1,5036,1,0,--


In [6]:
import requests

In [7]:
from tqdm.auto import tqdm, trange

In [8]:
import numpy as np

In [9]:
import random

In [10]:
w = all_langs.Articles.values ** (1/5)
w /= sum(w)
sum(w)

0.999999999999999

## Scrape the texts

Sample the languages with temperature sampling. 

In [202]:
import threading
from bs4 import BeautifulSoup
import time

In [203]:
def get_page(lang, timeout=0.1):
    res = requests.get(f'https://{lang}.wikipedia.org/wiki/Special:Random')
    if '<title>Wikimedia Error</title>' in res.text:
        time.sleep(timeout)
        return get_page(lang, timeout=timeout*2)
    return [res.url, res.text]

In [31]:
def url2lang(url):
    return url.split('//')[1].split('.')[0]

url2lang('https://en.wikipedia.org/wiki/Dunbar_Duncan')

'en'

In [32]:
from collections import defaultdict

In [29]:
pages[0][0]

'https://en.wikipedia.org/wiki/Dunbar_Duncan'

In [26]:
pages[0][0].split('//')[1].split('.')[0]

'en'

In [20]:
get_paragraphs(pages[1][1])

['Bukid ang Kūh-e Bandāv-e Zard (Pinulongang Persiyano: کوه بنداو زرد) sa Iran.[1] Nahimutang ni sa lalawigan sa Yazd, sa sentro nga bahin sa nasod, 500 km sa habagatan-sidlakan sa Tehrān ang ulohan sa nasod. 2,519 metros ibabaw sa dagat kahaboga ang nahimutangan sa Kūh-e Bandāv-e Zard,[1] o 76 ka metros sa ibabaw sa naglibot nga tereyn[saysay 1]. Mga 2.0 ka kilometro ang gilapdon sa tiilan niini.[saysay 2]\n',
 'Ang yuta palibot sa Kūh-e Bandāv-e Zard kasagaran kabungtoran, apan sa sidlakan nga kini mao ang kabukiran.[saysay 3] Ang kinahabogang dapit sa palibot dunay gihabogon nga 2,631 ka metro ug 2.5 km sa amihanan-kasadpan sa Kūh-e Bandāv-e Zard.[saysay 4] Dunay mga 3 ka tawo kada kilometro kwadrado Hapit nalukop sa desiyerto ug kamingawan ang palibot sa Kūh-e Bandāv-e Zard may kaayo gamay nga populasyon.[3] Sa palibot sa Kūh-e Bandāv-e Zard.[4] Sa rehiyon palibot sa Kūh-e Bandāv-e Zard, kabukiran, ug mga tubud talagsaon komon.[saysay 5]\n',
 'Ang klima bugnaw nga ugahon.[5] Ang ka

In [15]:
pages = []

In [13]:
from multiprocessing.pool import ThreadPool
pool = ThreadPool(processes=100)

%%time
async_result = pool.map(get_page, all_langs.Wiki.tolist())

In [204]:
for i in trange(10):
    async_result = pool.map(get_page, all_langs.Wiki.tolist())
    pages.extend(async_result)

  0%|          | 0/10 [00:00<?, ?it/s]

In [325]:
for i in trange(100):
    if i % 5 == 0:
        # uniform
        langs = all_langs.Wiki.tolist()
    else:
        # heated distribution
        langs = all_langs.Wiki.loc[random.choices(range(len(all_langs)), weights=w, k=len(all_langs))].tolist()
    async_result = pool.map(get_page, langs)
    pages.extend(async_result)

  0%|          | 0/100 [00:00<?, ?it/s]

ConnectionError: ('Connection aborted.', ConnectionResetError(10054, 'Удаленный хост принудительно разорвал существующее подключение', None, 10054, None))

for i in trange(1_000_000):

    if random.random() < 0.2:
        # uniform
        idx = random.choice(range(len(all_langs)))
    else:
        # heated uniform
        idx = random.choices(range(len(all_langs)), weights=w)[0]
    lang = all_langs.Wiki.loc[idx]
    res = requests.get(f'https://{lang}.wikipedia.org/wiki/Special:Random')
    pages.append([res.url, res.text])

In [326]:
len(pages)

93600

## put texts by language together

In [1]:
import re

def get_paragraphs(html):
    soup = BeautifulSoup(html)
    body = soup.find('div', {'id': 'bodyContent'})
    if not body:
        return []
    # todo: remove <sup> elements
    # 
    result = []
    for p in body.findAll('p'):
        for bad_tag in ['sup', 'style', 'script']:
            for unwanted in p.find_all(bad_tag):
                unwanted.extract()
        text = p.text.replace('\xa0', ' ')
        if text.strip():
            result.append(text)
    return result

In [443]:
url, html = random.choice(pages)
print(url)

print(get_paragraphs(html))

https://ce.wikipedia.org/wiki/%D0%9A%D0%B0%D0%BD%D0%B0%D0%B9_(%D0%90%D0%BA%D1%82%D0%BE%D0%B1%D0%B5%D0%BD_%D0%BE%D0%B1%D0%BB%D0%B0%D1%81%D1%82%D1%8C)
['Канай — Кхазакхстанан Актобен областан Хобдинан кӀоштара эвла.\n', 'Кхузахь климат барамехь континентан ю, аьхка йовха хуьлу, ткъа Ӏа барамехь-шийла хуьлу. Шаран уггаре а бовха бутт бу — июль (мангалан), уггаре а шийла — январь (кхолламан).\n', 'Кеп:Хобдинан кӀошт\n']


In [444]:
lang2texts = defaultdict(list)

In [445]:
for url, html in tqdm(pages):
    lang2texts[url2lang(url)].extend(get_paragraphs(html))

  0%|          | 0/93600 [00:00<?, ?it/s]

In [446]:
for k in sorted(lang2texts.keys()):
    v = [x for x in lang2texts[k] if x.strip()]
    if not v: 
        continue
    print(k, len(v))
    for i in range(2):
        print(random.choice(v))

aa 1
Wikipedia bödero kœáth amliethog e rhyδ gē gell tō un gelœga. Denn grëað zero œáthws ar ðechre di zera gentannwm hon.
Template:Link FA

Wikipedia bödero kœáth amliethog e rhyδ gē gell tō un gelœga. Denn grëað zero œáthws ar ðechre di zera gentannwm hon.
Template:Link FA

ab 241
617 (фышәи жәибжь) шықәса — иулиантәи амзар 617-тәи ашықәс ауп.

1439 (зқьы ԥшьышәи ҩажәи зеижә) шықәса — иулиантәи амзар 1439-тәи ашықәс ауп.

ace 361
Panté Gajah nakeuh gampông di Peusangan, Kabupatèn Bireuen, Acèh. Lumbôi gampông nyoe nakeuh 11.11.05.2006.

Robert Scott Adsit (1965 – ) nakeuh sidroe aktor asay Amirika Syarikat.

ady 524
Къэрал тхьаматэр – Кая Каллас.

тӀы

af 1527
Die bekendste seerowers in die Middeleeuse Europa was die Wikings, seevegters van Skandinawië wat hoofsaaklik tussen die 8ste en 12de eeu geroof en geplunder het; dit was die Wikingtydperk van die vroeë Middeleeue. Hulle het rooftogte uitgevoer op die kus, riviere en binnelandse stede van die hele Wes-Europa tot so ver as Sevil

ff 413
Eɗen ciftora wonnde Jallo Teli heblaa ko he gannde jojjannde walla mbiyen ganndal sariyaaji (duruwaa), ko ɗuum waɗi nde o arti Gine tan Seku toɗɗii mo ko Jaagorɗo Kalfinaaɗo Ñaawooje. O wonii heen tuggude nde o art he 1972 haa 1976. He ngoon posto kadi o addi heen karallaagal timmungal e pellital gollaade ngal o anndiraa ngal. O wayli kuule leydi (code civil) ɗee, o itti ɗum en e ɗe ɓe ndonnoo e laamu koloñaal Farayse o wiɗti ɗum en kanum en fof kuule Ginenaaje. O naati kadi he lannda mbajja ka (Parti Démocratique Guinéen), wonnoka he jappeere ko aldaa e pooɗondiral. Hono kaadareeɓe lannda kaa fof ni, saha kala omo yaha he nder dowriiji he ngam daronoyaade lannda ka, kono kadi no kala kaadareeɓe leydi ndii kala ɗo yahi ko deenaaɗo o alaa wellitaare, o waawaa yaltude leydi ngam ɗannaade.

Sir Ahmadu Bello Hambe teetii Lesdi Niseeriya gadda njuungo

fi 2523
Waltteri Torikka voitti MTV3:n Tähdet, tähdet -televisiosarjan toisen tuotantokauden finaalin vuonna 2015. Esikoisalbuminsa S

sh 2143
Вуковић је насеље у Србији у општини Кучево у Браничевском округу. Према попису из 2002. био је 301 становник (према попису из 1991. било је 486 становника).

Američka Samoa •
Američki Djevičanski otoci •
otok Baker •
Guam •
otok Howland •
otok Jarvis  •
atol Johnston •
greben Kingman •
atol Midway •
Navassa •
atol Palmyra •
Portoriko •
Sjeverni Marijanski otoci •
otok Wake

shi 890
Amllay ɣid ar ittli stusna ntmazirt d ufgan nes, isussmt a y ẓerra manik sa adrn midn, smaqqeln daɣ ɣu umzruy d maddifl ufgan amzwaru.  Ar n tin itmallaytad tafransist ( tourisme social et culturel). 

Asnus n wurɣ (s tlatinit: Asinus aureus)، tga ungal akumidi yara-t ufulay، tga ungal amzwaru ɣ umzruy n ufgan. Adlis n Asnus n wurɣ illa gis 11 n tnfust n ufgan iẓlin s umiyn.

shn 899
(1) ဝၼ်း။

(1) ဝၼ်း။

si 1335
ජෝන් ෆිට්ස්‍ජෙරාල්ඩ් “ජැක්” කෙනඩි (1917 මැයි 29 – 1963 නොවැම්බර් 22) මොහුව සාමාන්‍යයෙන් හඳුන්වනු ලබන්නේ ඔහුගේ නමේ මුලකුරු (JFK) වලිනි.

සටනින් වැදි රජු පෙරළා බිම හෙළන කුමරු, ඔහු මරා දැමීමට 

In [447]:
stats = []

for k, v in lang2texts.items():
    all_texts = ' '.join(v)
    stats.append({
        'lang': k,
        'pars': len(v),
        'chars': len(all_texts),
        'unique_chars': len(set(all_texts))
    })
stats = pd.DataFrame(stats)

In [448]:
stats.sort_values('chars', ascending=False)

Unnamed: 0,lang,pars,chars,unique_chars
2,de,5234,2052325,555
0,en,5481,1834006,720
7,es,4416,1744099,501
29,sh,2143,1685462,312
4,fr,6135,1631283,531
...,...,...,...,...
318,kj,121,1956,42
323,lrc,94,1503,14
321,ii,112,1495,82
320,ho,43,1032,28


In [449]:
stats.sort_values('unique_chars', ascending=False)

Unnamed: 0,lang,pars,chars,unique_chars
12,zh,2894,327836,5278
152,zh-classical,985,77716,3501
11,ja,4546,585020,3437
100,wuu,486,47433,2771
64,zh-yue,779,64044,2753
...,...,...,...,...
318,kj,121,1956,42
322,aa,1,135,39
320,ho,43,1032,28
324,mus,660,2507,22


In [450]:
stats[stats.lang.apply(lambda x: x in ['ru', 'mdf', 'myv'])]

Unnamed: 0,lang,pars,chars,unique_chars
6,ru,4626,1308708,1068
173,myv,848,159507,240
251,mdf,353,79764,244


In [455]:
import json

In [456]:
with open('lang2texts_wiki.json', 'w') as f:
    json.dump(lang2texts, f, ensure_ascii=False, indent='  ')

# Train the detection model

In [572]:
from sklearn.model_selection import train_test_split
train_texts = []
train_labels = []
test_texts = []
test_labels = []

for k, v in lang2texts.items():
    v2 = sorted(set(v))
    if len(v2) < 2:
        continue
    tr, ts = train_test_split(v2, test_size=0.2, random_state=1)
    train_texts.extend(tr)
    test_texts.extend(ts)
    train_labels.extend([k] * len(tr))
    test_labels.extend([k] * len(ts))
print(len(train_texts), len(test_texts))

267754 67096


In [573]:
print(pd.Series(train_labels).value_counts()['myv'])
pd.Series(train_labels).value_counts()

548


en     4352
fr     4240
de     4139
ru     3680
pl     3656
       ... 
ii       12
cho      10
mh        7
kj        4
ho        2
Length: 323, dtype: int64

In [474]:
bible = pd.read_csv('erzya_bible.tsv', sep='\t')
print(bible.shape)
bible = bible.dropna()
print(bible.shape)
bible = bible[~bible.myv.str.startswith('Глава ')]
print(bible.shape)
print(bible.columns)

(12926, 4)
(12899, 4)
(12483, 4)
Index(['Unnamed: 0', 'myv', 'ru', 'source'], dtype='object')


In [755]:
mbible = pd.read_csv('moksha_bible.tsv', sep='\t')
print(mbible.shape)
mbible = mbible.dropna()
print(mbible.shape)
mbible = mbible[~mbible.mdf.str.startswith('Глава ')]
print(mbible.shape)
print(mbible.columns)

(12517, 4)
(12500, 4)
(12344, 4)
Index(['Unnamed: 0', 'mdf', 'ru', 'source'], dtype='object')


In [566]:
from nltk import sent_tokenize

def split_by_newline(text):
    return text.split('\n')

def split_overlapping_chunks(text, step=100, chunk_size=300):
    results = []
    for i in range(0, len(text), step):
        results.append(text[i:i+chunk_size])
    return results

def try_split(text, max_length=500):
    results = [text]
    parts = [text]
    
    for splitter in [split_by_newline, sent_tokenize, split_overlapping_chunks]:
        new_parts = []
        for part in parts:
            small_parts = [p.strip() for p in splitter(part)]
            for small_part in small_parts:
                if len(small_part) < 3: # the text is too short 
                    continue
                if len(small_part) <= min(len(part) * 0.5, max_length):
                    results.append(small_part)
                new_parts.append(small_part)
        parts = new_parts
    results.extend(parts)
    return sorted(set(results), key=lambda x: len(x))

In [575]:
train_texts_aug = []
train_labels_aug = []
for t, l in zip(tqdm(train_texts), train_labels):
    a = try_split(t)
    train_texts_aug.extend(a)
    train_labels_aug.extend([l]*len(a))

  0%|          | 0/267754 [00:00<?, ?it/s]

In [576]:
print(len(train_texts), len(train_texts_aug))

267754 1393009


In [577]:
print(pd.Series(train_labels_aug).value_counts()['myv'])
pd.Series(train_labels_aug).value_counts()

3127


de     38694
es     25238
en     24849
fr     21234
ru     18800
       ...  
cho       27
ii        18
mus       16
kj         9
ho         4
Length: 323, dtype: int64

In [579]:
train_texts_aug.extend(bible.myv.tolist())
train_texts_aug.extend(bible.ru.tolist())

train_labels_aug.extend(['myv']*len(bible))
train_labels_aug.extend(['ru']*len(bible))

In [756]:
train_texts_aug.extend(mbible.mdf.tolist())
train_labels_aug.extend(['mdf']*len(mbible))

In [757]:
print(pd.Series(train_labels_aug).value_counts()['myv'])
print(pd.Series(train_labels_aug).value_counts()['mdf'])
pd.Series(train_labels_aug).value_counts()

15610
13603


de     38694
ru     31283
es     25238
en     24849
fr     21234
       ...  
cho       27
ii        18
mus       16
kj         9
ho         4
Length: 323, dtype: int64

#### Sklearn pipeline

In [607]:
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.feature_extraction.text import CountVectorizer, HashingVectorizer
from sklearn.model_selection import cross_val_score, KFold

In [611]:
pipe = make_pipeline(
    HashingVectorizer(analyzer='char_wb', ngram_range=(1,4), n_features=100_000), 
    LogisticRegression(C=1e-4, max_iter=1_000, solver='saga')
)

In [613]:
%%time
pipe.fit(train_texts_aug, train_labels_augg)

Wall time: 6h 16min 31s


Pipeline(steps=[('hashingvectorizer',
                 HashingVectorizer(analyzer='char_wb', n_features=100000,
                                   ngram_range=(1, 4))),
                ('logisticregression',
                 LogisticRegression(C=0.0001, max_iter=1000, solver='saga'))])

In [622]:
%%time
preds = pipe.predict_proba(test_texts)

Wall time: 50 s


In [624]:
preds.argmax(1)

array([63, 63, 63, ..., 63, 63, 63], dtype=int64)

In [629]:
(pipe.classes_[preds.argmax(1)] == test_labels).mean()

0.028913795159174912

In [631]:
pd.Series(pipe.classes_[preds.argmax(1)]).value_counts()

de    56051
ru    11045
dtype: int64

After 6 hours of training a model, it is still underfit as fuck...

#### Try FastText

In [633]:
import fasttext

In [758]:
with open('ft_train.txt', 'w') as f:
    for label, text in zip(train_labels_aug, train_texts_aug):
        f.write(f'__label__{label} ')
        f.write(text.replace('\n', ' ') + '\n')

model = fasttext.train_supervised(
    input="ft_train.txt", 
    lr=0.1, 
    epoch=25, 
    wordNgrams=0, 
    bucket=100_000,
    dim=50, # 16 seems to be OK
    loss='softmax',
    minn=1,
    maxn=4,
    minCount=1, # a larger number is required
)

model = fasttext.train_supervised(
    input="ft_train.txt", 
    lr=0.1, # 0.1 gives 38.8% / 0.5 gives 34.8% (overfitting?)
    epoch=25, # 5 gives 38.8%  / 10 gives 49.5% / 25 gives 69%
    wordNgrams=0, 
    bucket=100_000,  # default is 200K; 100K gives 38% acc, and 200K as well.
    dim=32, # FB uses 16, but 32 is much better with me, and 64 seems +- the same
    loss='softmax',
    minn=1,
    maxn=3, # if I decrease this to 3, the quality is no worse.
    minCount=100, # a larger number is required  # 5 gets 39% accuracy with 650K words; 50 gets 38% acc with 50K words; 
    # 300 gets acc 38% with 6K words
)

model = fasttext.train_supervised(
    input="ft_train.txt", 
    lr=0.1, # 0.1 gives 38.8% / 0.5 gives 34.8% (overfitting?)
    epoch=25, # 5 gives 38.8%  / 10 gives 49.5% / 
    wordNgrams=0, 
    bucket=200_000,  # default is 200K; 100K gives 38% acc, and 200K as well.
    dim=64, # FB uses 16, but 32 is much better with me, and 64 seems +- the same
    loss='softmax',
    minn=1,
    maxn=4, # if I decrease this to 3, the quality is no worse.
    minCount=100, # a larger number is required  # 5 gets 39% accuracy with 650K words; 50 gets 38% acc with 50K words; 
    # 300 gets acc 38% with 6K words
) # 0.7683170382735185


In [759]:
model = fasttext.train_supervised(
    input="ft_train.txt", 
    lr=0.1, # 0.1 gives 38.8% / 0.5 gives 34.8% (overfitting?)
    epoch=25, # 5 gives 38.8%  / 10 gives 49.5% / 
    wordNgrams=0, 
    bucket=200_000,  # default is 200K; 100K gives 38% acc, and 200K as well.
    dim=64, # FB uses 16, but 32 is much better with me, and 64 seems +- the same
    loss='softmax',
    minn=1,
    maxn=4, # if I decrease this to 3, the quality is no worse.
    minCount=100, # a larger number is required  # 5 gets 39% accuracy with 650K words; 50 gets 38% acc with 50K words; 
    # 300 gets acc 38% with 6K words
)

Try a super long model

In [765]:
model = fasttext.train_supervised(
    input="ft_train.txt", 
    lr=0.05, # 0.1 gives 38.8% / 0.5 gives 34.8% (overfitting?)
    epoch=100, # 5 gives 38.8%  / 10 gives 49.5% / 
    wordNgrams=0, 
    bucket=200_000,  # default is 200K; 100K gives 38% acc, and 200K as well.
    dim=64, # FB uses 16, but 32 is much better with me, and 64 seems +- the same
    loss='softmax',
    minn=1,
    maxn=4, # if I decrease this to 3, the quality is no worse.
    minCount=100, # a larger number is required  # 5 gets 39% accuracy with 650K words; 50 gets 38% acc with 50K words; 
    # 300 gets acc 38% with 6K words
)

In [766]:
model.predict('пек вадря', k=5)

(('__label__myv', '__label__ru', '__label__lez', '__label__tg', '__label__be'),
 array([0.74181491, 0.20231327, 0.0135121 , 0.00760135, 0.00718429]))

In [767]:
model.predict(['пек вадря', 'привет'], k=1)

([['__label__myv'], ['__label__ru']],
 [array([0.7418149], dtype=float32), array([0.9943341], dtype=float32)])

In [768]:
%%time
ft_preds = model.predict([t.replace('\n', ' ') for t in test_texts])

Wall time: 9.3 s


In [769]:
ft_labels = [l[0][9:] for l in ft_preds[0]]

77% accuracy, which is not bad! / 0.69618 in a smaller version / 0.76831 a middle one / 87% for longer training

In [770]:
(np.array(ft_labels) == test_labels).mean()

0.8764009777035889

In [656]:
ft_sota = fasttext.load_model("../langid/lid.176.bin")



In [657]:
ft_sota_preds = ft_sota.predict([t.replace('\n', ' ') for t in test_texts])

In [658]:
(np.array([l[0][9:] for l in ft_sota_preds[0]]) == test_labels).mean()

0.5884553475617026

In [772]:
model.save_model('../langid/lid.323.bin')

In [785]:
model = fasttext.load_model('../langid/lid.323.bin')



In [787]:
model.quantize(retrain=True, input="ft_train.txt", qnorm=True, cutoff=50_000)

In [788]:
len(model.words)

2282

In [789]:
model.save_model('../langid/lid.323.ftz')

In [790]:
%%time
ft_preds = model.predict([t.replace('\n', ' ') for t in test_texts])

Wall time: 25.1 s


After compression, the model retains 89% accuracy - even more than the large one. Probably, due to longer training.

In [792]:
ft_labels = [l[0][9:] for l in ft_preds[0]]
(np.array(ft_labels) == test_labels).mean()

0.8900828663407655

# More quality estimation

In [1]:
import fasttext

In [2]:
model = fasttext.load_model('../langid/lid.323.ftz')



In [4]:
import json

In [5]:
with open('lang2texts_wiki.json', 'r') as f:
    lang2texts= json.load(f)

In [6]:
from sklearn.model_selection import train_test_split
train_texts = []
train_labels = []
test_texts = []
test_labels = []

for k, v in lang2texts.items():
    v2 = sorted(set(v))
    if len(v2) < 2:
        continue
    tr, ts = train_test_split(v2, test_size=0.2, random_state=1)
    train_texts.extend(tr)
    test_texts.extend(ts)
    train_labels.extend([k] * len(tr))
    test_labels.extend([k] * len(ts))
print(len(train_texts), len(test_texts))

267754 67096


In [17]:
len(train_texts)

267754

In [8]:
%%time
ft_preds = model.predict([t.replace('\n', ' ') for t in test_texts])

Wall time: 15.9 s


In [10]:
import numpy as np

In [11]:
ft_labels = [l[0][9:] for l in ft_preds[0]]
(np.array(ft_labels) == test_labels).mean()

0.8900828663407655

In [12]:
from sklearn.metrics import classification_report

In [16]:
len(set(test_labels))

323

In [15]:
print(classification_report(test_labels, ft_labels))

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

          ab       0.93      0.95      0.94        44
         ace       0.98      0.91      0.95        68
         ady       0.77      0.68      0.72        79
          af       0.95      0.90      0.93       304
          ak       0.70      0.30      0.42       102
         als       0.73      0.89      0.80       417
         alt       0.93      0.95      0.94       319
          am       0.98      0.98      0.98       151
         ami       0.86      0.90      0.88       159
          an       0.79      0.95      0.87       186
         ang       0.97      0.94      0.95       108
          ar       0.93      0.99      0.96       408
         arc       1.00      1.00      1.00        40
         ary       0.93      0.90      0.91       117
         arz       0.97      0.91      0.94       300
          as       0.98      0.96      0.97       302
         ast       0.87      0.96      0.91       628
         atj       0.95    

  _warn_prf(average, modifier, msg_start, len(result))
