In [1]:
from sklearn.datasets import fetch_20newsgroups
from sklearn import metrics
# Load both splits with headers, footers and quoted replies stripped, so the
# classifier has to work from each post's body text.
newsgroups_train, newsgroups_test = (
    fetch_20newsgroups(subset=split, remove=('headers', 'footers', 'quotes'))
    for split in ('train', 'test')
)

In [2]:
from nltk import tokenize
#nltk.download('punkt')
import numpy as np
from bert_serving.client import BertClient
import random
from tqdm import tqdm_notebook as tqdm
import pickle

In [3]:
# Category names; a document's integer target indexes into this list.
print(list(newsgroups_train.target_names))

['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc', 'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware', 'comp.windows.x', 'misc.forsale', 'rec.autos', 'rec.motorcycles', 'rec.sport.baseball', 'rec.sport.hockey', 'sci.crypt', 'sci.electronics', 'sci.med', 'sci.space', 'soc.religion.christian', 'talk.politics.guns', 'talk.politics.mideast', 'talk.politics.misc', 'talk.religion.misc']


In [45]:
# Peek at one training post followed by its human-readable category.
print(f"{newsgroups_train.data[0]}\n")
print("Label: " + newsgroups_train.target_names[newsgroups_train.target[0]])

I was wondering if anyone out there could enlighten me on this car I saw
the other day. It was a 2-door sports car, looked to be from the late 60s/
early 70s. It was called a Bricklin. The doors were really small. In addition,
the front bumper was separate from the rest of the body. This is 
all I know. If anyone can tellme a model name, engine specs, years
of production, where this car is made, history, or whatever info you
have on this funky looking car, please e-mail.

Label: rec.autos


In [4]:
NUM_TAGS = len(newsgroups_train.target_names)

# Bucket the training documents by their integer category label.
tags_data = [[] for _ in range(NUM_TAGS)]
for doc, label in zip(newsgroups_train.data, newsgroups_train.target):
    tags_data[label].append(doc)

# Merge each category's documents into one string so it can be embedded as a
# single "category document". Join with a newline: plain concatenation glued
# the end of one post to the start of the next (e.g. "e-mail.Hello"), which
# the sentence tokenizer may not split, producing cross-document sentences.
comb_tags_data = ["\n".join(docs) for docs in tags_data]

In [5]:
# Separates into sentences then gets the sentence embedding
def poolEncode(doc, n=-1, client=None, rng=None):
    """Embed a document as the mean of its sentence embeddings.

    Args:
        doc: Raw document text; split into sentences with nltk's
            ``sent_tokenize``.
        n: Number of sentences to randomly sample before encoding. The default
            -1, or any value <= 0 / None, encodes every sentence. Values larger
            than the sentence count are capped at that count.
        client: Object with an ``encode(list_of_str)`` method. Defaults to the
            notebook-level ``bc`` BertClient, looked up at call time.
        rng: Object with a ``sample(population, k)`` method, e.g.
            ``random.Random(seed)`` for reproducible sampling. Defaults to the
            global ``random`` module.

    Returns:
        ``np.mean`` over axis 0 of the encoded sentences. This assumes
        ``client.encode`` returns one row per sentence (TODO confirm against
        the bert-serving client).
    """
    if client is None:
        client = bc
    if rng is None:
        rng = random

    # Drop very short fragments (5 characters or fewer).
    sentences = [s for s in tokenize.sent_tokenize(doc) if len(s) > 5]
    # Tokenizing + filtering can leave nothing (e.g. an empty document). Encode
    # a placeholder so the caller still gets a vector.
    if not sentences:
        sentences = ["a"]

    sample = sentences
    if n is not None and n > 0:
        # random.sample rejects k > population (and negative k), so only call
        # it with a positive count capped at the number of sentences.
        sample = rng.sample(sentences, min(n, len(sentences)))
    doc_vecs = client.encode(sample)

    # Mean-pool the sentence embeddings into one document vector.
    return np.mean(doc_vecs, axis=0)

In [6]:
def rankScore(query, tag_vecs, tag_names, topk=1, sample=-1, v=False):
    """Rank categories for a free-text query by embedding similarity.

    The query is embedded with ``poolEncode``, which samples ``sample``
    sentences unless ``sample`` is -1. Each row of ``tag_vecs`` is scored by
    its dot product with the query vector divided by that row's L2 norm. The
    query norm is not divided out, so for a nonzero query the ranking matches
    cosine similarity but the printed scores are not in [-1, 1].

    Args:
        query: Text to classify.
        tag_vecs: One embedding per category (list of vectors or 2-D array).
        tag_names: Category names, used only for verbose printing.
        topk: Number of category indices to return.
        sample: Sentence sample size forwarded to ``poolEncode``; -1 uses all.
        v: If true, print each returned category with its score.

    Returns:
        Array of the ``topk`` highest-scoring category indices, best first.
    """
    query_vec = poolEncode(query) if sample == -1 else poolEncode(query, sample)

    # Dot product against every category, normalized by category vector length.
    tag_norms = np.linalg.norm(tag_vecs, axis=1)
    scores = np.sum(query_vec * tag_vecs, axis=1) / tag_norms

    descending = np.argsort(scores)[::-1]
    best = descending[:topk]

    if v:
        for idx in best:
            print('> %s\t%s' % (scores[idx], tag_names[idx]))

    return best

In [7]:
def train(n_sentences=200):
    """Build one embedding per category from its combined training text.

    For each category index ``i`` in ``range(NUM_TAGS)``, ``comb_tags_data[i]``
    is embedded with ``poolEncode`` using a random sample of up to
    ``n_sentences`` sentences, and a progress line is printed.

    Args:
        n_sentences: Sentences sampled per category. Defaults to 200, the value
            previously hard-coded. -1 encodes every sentence.

    Returns:
        List of category vectors, indexed by category label.
    """
    tag_vecs = []
    for i in range(NUM_TAGS):
        tag_vecs.append(poolEncode(comb_tags_data[i], n_sentences))
        # Take names from the training set itself instead of the `tag_names`
        # global, which is only defined in the later interactive-search cell.
        print("Loaded " + newsgroups_train.target_names[i] + ".")
    return tag_vecs

In [None]:
print("Loading model...")
BERT_SERVER_IP = "10.0.0.11"  # host running the bert-serving server
bc = BertClient(ip=BERT_SERVER_IP)
tag_names = newsgroups_test.target_names

# On a fresh kernel `tag_vecs` does not exist yet: reuse the cached vectors
# written by the save cell if present, otherwise build them (slow).
if "tag_vecs" not in globals():
    try:
        with open("tag_vecs.pkl", "rb") as pickle_in:
            # pickle.load can execute arbitrary code; only load files this
            # notebook wrote itself.
            tag_vecs = pickle.load(pickle_in)
    except FileNotFoundError:
        tag_vecs = train()

# Interactive search: each query prints its top-k categories with scores.
# Enter q / quit / exit to leave the loop. Empty input is ignored.
topk = 10
while True:
    query = input('Search: ').strip()
    if query.lower() in ("q", "quit", "exit"):
        break
    if not query:
        continue
    # v=True prints the ranking. Before, the call used topk=1 and v=False, so
    # results were computed and then silently discarded.
    rankScore(query, tag_vecs, tag_names, topk=topk, v=True)
    print()
    print()

In [126]:
def evaluate(tag_vecs, n_docs=500, sample=50, topk=10, tag_names=None):
    """Predict ranked categories for the first ``n_docs`` test documents.

    Args:
        tag_vecs: One embedding per category, as produced by ``train``.
        n_docs: Number of leading test documents to score (default 500).
        sample: Sentences sampled per document by ``poolEncode`` (default 50).
        topk: Category indices kept per document (default 10).
        tag_names: Category names passed to ``rankScore``. Defaults to
            ``newsgroups_test.target_names``, so this no longer relies on a
            ``tag_names`` global defined in another cell.

    Returns:
        2-D ``np.ndarray`` with one row of category indices per document, best
        first. Column 0 is the top-1 prediction, so ``preds[:, 0]`` works (it
        raised TypeError on the list this used to return).
    """
    if tag_names is None:
        tag_names = newsgroups_test.target_names
    preds = [
        rankScore(query, tag_vecs, tag_names, sample=sample, topk=topk)
        for query in tqdm(newsgroups_test.data[:n_docs])
    ]
    return np.asarray(preds)

In [42]:
# Persist the category vectors so they can be reloaded without re-running
# train(). Run `tag_vecs = train()` first to refresh them.
with open("tag_vecs.pkl", "wb") as vecs_file:
    pickle.dump(tag_vecs, vecs_file)

In [None]:
# Score the evaluation slice of the test set and cache the predictions.
# NOTE: despite its name, `preds3000` holds predictions for the documents
# evaluate() scores (the first 500 by default), not 3000.
preds3000 = evaluate(tag_vecs)
with open("preds3000.pkl", "wb") as preds_file:
    pickle.dump(preds3000, preds_file)

## Top-1 metrics

In [39]:
# Top-1 quality on the evaluated slice. np.asarray makes the column slice work
# whether preds3000 is a list of per-document arrays (e.g. reloaded from an
# older pickle) or already a 2-D array. Targets are trimmed to the number of
# predictions instead of a hard-coded 500, so the lengths always agree.
top1_preds = np.asarray(preds3000)[:, 0]
eval_targets = newsgroups_test.target[:len(top1_preds)]
print("f1_score: " + str(metrics.f1_score(eval_targets, top1_preds, average='macro')))
print("accuracy: " + str(metrics.accuracy_score(eval_targets, top1_preds)))

f1_score: 0.42671604576131295
accuracy: 0.43


## Top-k accuracy

In [26]:
# Top-k accuracy: a document is correct when its true label is among its first
# k predicted categories. The first few misses are printed for inspection.
k = 3
MAX_MISSES_SHOWN = 2  # number of misclassified documents to print
preds = preds3000
eval_size = len(preds)
count = 0
shown = 0
for doc_idx, y in enumerate(newsgroups_test.target[:eval_size]):
    guesses = list(preds[doc_idx][:k])
    if y in guesses:
        count += 1
    elif shown < MAX_MISSES_SHOWN:
        print('Index: ' + str(doc_idx))
        print('Document:')
        print('  ' + newsgroups_test.data[doc_idx])
        # Use a distinct comprehension variable; reusing `i` here shadowed the
        # document index and made the line hard to read.
        print(f'Guesses: {[newsgroups_test.target_names[g] for g in guesses]}')
        print(f'Correct: {newsgroups_test.target_names[y]}')
        print()
        shown += 1

# Guard against an empty prediction set rather than dividing by zero.
accuracy = count / eval_size if eval_size else float('nan')
print(f'Top {k} accuracy: {accuracy}')

Index: 2
Document:
  
In a word, yes.

Guesses: ['talk.politics.guns', 'sci.med', 'rec.motorcycles']
Correct: alt.atheism

Index: 7
Document:
  A friend of mine managed to get a copy of a computerised Greek and Hebrew 
Lexicon called "The Word Perfect" (That is not the word processing 
package WordPerfect). However, some one wiped out the EXE file, and she 
has not been able to restore it. There are no distributors of the package in 
South Africa. I would appreciate it, if some one could email me the file, or 
at least tell me where I could get it from. 

My email address is
	fortmann@superbowl.und.ac.za     or
	fortmann@shrike.und.ac.za
 
Many thanks.
Guesses: ['comp.graphics', 'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware']
Correct: soc.religion.christian

Top 3 accuracy: 0.702


In [13]:
# Reload the cached evaluation predictions written by the evaluation cell.
# pickle.load can execute arbitrary code: only load files this notebook wrote.
with open("preds3000.pkl", "rb") as pickle_in:
    preds3000 = pickle.load(pickle_in)