## StarSpace

StarSpace [[1]](#fn1) is an entity embedding approach that uses a similarity function between entities to construct a prediction task for a neural network. It maps objects of different types into a common vector space where they can be compared to each other. StarSpace can be applied to learning word-, sentence- and document-level embeddings, ranking, text classification, graph embedding, image classification, and more.
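The core idea can be illustrated with a small pure-Python sketch. It follows our reading of the paper [[1]](#fn1): an entity is a bag of discrete features whose embedding is the sum of the feature embeddings, similarity is a dot product, and training minimises a margin ranking (hinge) loss that pushes a positive pair above sampled negatives. The names `embed`, `hinge_loss` and the toy 2-d `table` are hypothetical, and this is not StarSpace's C++ implementation. The margin of 0.05 matches the value printed in the training log further down.

```python
def embed(features, table):
    """Embed an entity (a bag of discrete features) as the sum of its feature vectors."""
    vec = [0.0] * len(next(iter(table.values())))
    for f in features:
        for i, x in enumerate(table[f]):
            vec[i] += x
    return vec

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def hinge_loss(lhs, positive, negatives, table, margin=0.05):
    """Margin ranking loss of one positive pair against sampled negative entities."""
    a = embed(lhs, table)
    pos = dot(a, embed(positive, table))
    # Penalize every negative whose similarity comes within `margin` of the positive
    return sum(max(0.0, margin - pos + dot(a, embed(neg, table))) for neg in negatives)

# Hypothetical 2-d embedding table shared by words and labels
table = {
    'goal': [1.0, 0.0], 'match': [0.8, 0.1],
    '__label__sports': [1.0, 0.0], '__label__business': [0.0, 1.0],
}
doc = ['goal', 'match']
# Positive label already ahead of the negative by more than the margin: loss 0.0
print(hinge_loss(doc, ['__label__sports'], [['__label__business']], table))
# Negative label scores higher than the positive one: positive loss
print(hinge_loss(doc, ['__label__business'], [['__label__sports']], table))
```

Because words and labels live in the same table, a document and a category can be scored against each other directly, which is what the classification example below relies on.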

This notebook requires a working StarSpace binary, which can be built on any modern Linux or Windows machine following the build instructions in the [GitHub repository](https://github.com/facebookresearch/StarSpace). Please note that to run this notebook on Windows, you will need either [MinGW with MSYS](http://www.mingw.org/) or [Cygwin](https://www.cygwin.com/) to compile StarSpace and run the first four cells; the remainder of the notebook is portable. In addition, the following packages are required:

- gensim==3.8.3
- matplotlib==3.3.2
- pandas
- scikit-learn==0.23.2

-----
<span id="fn1"> [1] Ledell Yu Wu, Adam Fisch, Sumit Chopra, Keith Adams, Antoine Bordes, and Jason Weston. StarSpace: Embed all the things! In Proceedings of the 32nd AAAI Conference on Artificial Intelligence, pages 5569–5577, 2018. </span>

----

We follow the official StarSpace documentation and present its text classification example. First, we clone the repository and compile the `starspace` binary.

In [1]:
!git clone https://github.com/facebookresearch/StarSpace.git

Cloning into 'StarSpace'...
remote: Enumerating objects: 5, done.
remote: Counting objects: 100% (5/5), done.
remote: Compressing objects: 100% (5/5), done.
remote: Total 873 (delta 0), reused 0 (delta 0), pack-reused 868
Receiving objects: 100% (873/873), 3.05 MiB | 3.75 MiB/s, done.
Resolving deltas: 100% (567/567), done.


In [2]:
!cd StarSpace && make

g++ -pthread -std=gnu++11 -O3 -funroll-loops -g -c src/utils/normalize.cpp
g++ -pthread -std=gnu++11 -O3 -funroll-loops -I/usr/local/bin/boost_1_63_0/ -g -c src/dict.cpp
g++ -pthread -std=gnu++11 -O3 -funroll-loops -g -c src/utils/args.cpp
g++ -pthread -std=gnu++11 -O3 -funroll-loops -I/usr/local/bin/boost_1_63_0/ -g -c src/proj.cpp
g++ -pthread -std=gnu++11 -O3 -funroll-loops -I/usr/local/bin/boost_1_63_0/ -g -c src/parser.cpp -o parser.o
g++ -pthread -std=gnu++11 -O3 -funroll-loops -I/usr/local/bin/boost_1_63_0/ -g -c src/data.cpp -o data.o
g++ -pthread -std=gnu++11 -O3 -funroll-loops -I/usr/local/bin/boost_1_63_0/ -g -c src/model.cpp
g++ -pthread -std=gnu++11 -O3 -funroll-loops -I/usr/local/bin/boost_1_63_0/ -g -c src/starspace.cpp
g++ -pthread -std=gnu++11 -O3 -funroll-loops -I/usr/local/bin/boost_1_63_0/ -g -c src/doc_parser.cpp -o doc_parser.o
g++ -pthread -std=gnu++11 -O3 -funroll-loops -I/usr/local/bin/boost_1_63_0/ -g -c src/doc_data.cpp -o doc_data.o
g++ -pthread -std=gnu++11

The executable is now available as `StarSpace/starspace`. The original bash script for the text classification example is available in the [StarSpace GitHub repository](https://github.com/facebookresearch/Starspace/blob/master/examples/classification_ag_news.sh). We reimplement it as a Jupyter notebook.

The data is based on [Antonio Gulli's corpus (AG)](http://groups.di.unipi.it/~gulli/AG_corpus_of_news_articles.html), a collection of more than 1 million news articles. From this collection, Zhang et al. [[2]](#fn2) constructed a smaller corpus containing only the four largest news categories of the original corpus. Each category (i.e., class value) contains 30,000 training instances and 1,900 test instances, for a total of 120,000 training samples and 7,600 test samples. We download, unpack and inspect the corpus.

----
<span id="fn2"> [2] Xiang Zhang, Junbo Zhao, and Yann LeCun. Character-level convolutional networks for text classification. In Advances in Neural Information Processing Systems 28 (NIPS 2015).</span>

----

In [3]:
!wget -c https://dl.fbaipublicfiles.com/starspace/ag_news_csv.tar.gz -P data
!cd data && tar -xzvf ag_news_csv.tar.gz
!ls data

--2020-11-16 17:32:02--  https://dl.fbaipublicfiles.com/starspace/ag_news_csv.tar.gz
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 2606:4700:10::6816:4a8e, 2606:4700:10::ac43:904, 2606:4700:10::6816:4b8e, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|2606:4700:10::6816:4a8e|:443... connected.
HTTP request sent, awaiting response... 416 Requested Range Not Satisfiable

    The file is already fully retrieved; nothing to do.

ag_news_csv/
ag_news_csv/train.csv
ag_news_csv/test.csv
ag_news_csv/classes.txt
ag_news_csv/readme.txt
ag_news_csv  ag_news_csv.tar.gz


There are four classes, and each news article in the training and test sets is labelled with the (1-based) line number of its class in `classes.txt`.
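The labelling convention can be made concrete with a short sketch that maps the numeric class column of a CSV row back to its class name. The class names are hard-coded from `data/ag_news_csv/classes.txt`, and `train_sample` is a single shortened row of `train.csv` (body truncated for illustration); the preprocessing cell further down instead uses a hand-written `idx2category` dictionary.

```python
import csv
import io

# Contents of data/ag_news_csv/classes.txt, one class name per line
classes_txt = "World\nSports\nBusiness\nSci/Tech\n"
# One shortened row of train.csv: class index, title, body
train_sample = '"3","Wall St. Bears Claw Back Into the Black (Reuters)","Reuters - Short-sellers ..."\n'

# Line i of classes.txt (counting from 1) names class i
idx2name = {i: name for i, name in enumerate(classes_txt.splitlines(), start=1)}

rows = [(idx2name[int(category)], title)
        for category, title, body in csv.reader(io.StringIO(train_sample))]
print(rows)
```

In the notebook, `classes_txt` and `train_sample` would be replaced by opening the actual files.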

In [4]:
!cat data/ag_news_csv/classes.txt
!head -n 5 data/ag_news_csv/train.csv

World
Sports
Business
Sci/Tech
"3","Wall St. Bears Claw Back Into the Black (Reuters)","Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again."
"3","Carlyle Looks Toward Commercial Aerospace (Reuters)","Reuters - Private investment firm Carlyle Group,\which has a reputation for making well-timed and occasionally\controversial plays in the defense industry, has quietly placed\its bets on another part of the market."
"3","Oil and Economy Cloud Stocks' Outlook (Reuters)","Reuters - Soaring crude prices plus worries\about the economy and the outlook for earnings are expected to\hang over the stock market next week during the depth of the\summer doldrums."
"3","Iraq Halts Oil Exports from Main Southern Pipeline (Reuters)","Reuters - Authorities have halted oil export\flows from the main pipeline in southern Iraq after\intelligence showed a rebel militia could strike\infrastructure, an oil official said on Saturday."
"3","Oil prices soar to all-time re

We read the data into a pandas DataFrame and preprocess the text by converting it to lowercase and replacing or padding a number of characters. The category is prefixed with `__label__`, as required by the fastText text classification file format. The transformed data is randomly shuffled and written to a fastText-compatible text file. Both the training and the test data are balanced across the four categories.

In [21]:
import pandas as pd
import os
from pprint import pprint

idx2category = {1: '__label__world',2: '__label__sports', 3:'__label__business', 4:'__label__scitech'}

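def preprocess(df):
    # Map numeric class indices (1-4) to fastText-style labels
    df = df.replace({'category': idx2category})
    # Merge title and body into a single lowercase text field
    df['text'] = df['title'] + ' ' + df['body']
    df = df.drop(labels=['title', 'body'], axis=1)
    df['text'] = df['text'].str.lower()
    # Pad punctuation with spaces and strip unwanted characters;
    # regex=False treats each pattern as a literal string ('.', '(', '?', ...)
    for s, rep in [("'"," ' "),
                   ('"',''),
                   ('.',' . '),
                   ('<br />',' '),
                   (',',' , '),
                   ('(',' ( '),
                   (')',' ) '),
                   ('!',' ! '),
                   ('?',' ? '),
                   (';',' '),
                   (':',' '),
                   ('\\',''),
                   ('  ',' ')
                  ]:
        df['text'] = df['text'].str.replace(s, rep, regex=False)
    # Shuffle the rows reproducibly
    df = df.sample(frac=1, random_state=42)
    return df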

for filename in ['data/ag_news_csv/train.csv','data/ag_news_csv/test.csv']:
    df = pd.read_csv(filename, names=['category', 'title', 'body'])
    df = preprocess(df)
    print('File {}'.format(os.path.split(filename)[1]))
    pprint(df['category'].value_counts().to_dict())
    with open('{}.pp'.format(os.path.splitext(filename)[0]), 'w') as fp:
        for row in df.itertuples():
            fp.write('{} {}\n'.format(row.category, row.text))

File train.csv
{'__label__business': 30000,
 '__label__scitech': 30000,
 '__label__sports': 30000,
 '__label__world': 30000}
File test.csv
{'__label__business': 1900,
 '__label__scitech': 1900,
 '__label__sports': 1900,
 '__label__world': 1900}


We can now run StarSpace on the preprocessed files, using the same parameters as the example in the StarSpace repository. `-trainMode 0` combined with the fastText file format (the default `fileFormat`, as the log below confirms) selects the mode in which the labels are individual words carrying the `__label__` prefix, i.e., the classification task.
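As a rough illustration of how each line of `train.pp` is interpreted in this mode, the sketch below separates the tokens carrying the label prefix (the right-hand side to be predicted) from the remaining words (the left-hand side used as input). `split_example` is a hypothetical helper based on our reading of the StarSpace documentation; StarSpace's own parser also handles n-grams and feature weights, which this sketch ignores.

```python
def split_example(line, label_prefix='__label__'):
    """Split one fastText-format line into input words (LHS) and labels (RHS)."""
    tokens = line.split()
    labels = [t for t in tokens if t.startswith(label_prefix)]
    words = [t for t in tokens if not t.startswith(label_prefix)]
    return words, labels

words, labels = split_example('__label__business wall st . bears claw back into the black')
print(labels)  # ['__label__business']
print(words)
```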

In [22]:
!./StarSpace/starspace train \
  -trainFile "data/ag_news_csv/train.pp" \
  -model "data/ag_news_csv/model" \
  -initRandSd 0.01 \
  -adagrad false \
  -ngrams 1 \
  -lr 0.01 \
  -epoch 5 \
  -thread 20 \
  -dim 10 \
  -negSearchLimit 5 \
  -trainMode 0 \
  -label "__label__" \
  -similarity "dot" \
  -verbose false

Arguments: 
lr: 0.01
dim: 10
epoch: 5
maxTrainTime: 8640000
validationPatience: 10
saveEveryEpoch: 0
loss: hinge
margin: 0.05
similarity: dot
maxNegSamples: 10
negSearchLimit: 5
batchSize: 5
thread: 20
minCount: 1
minCountLabel: 1
label: __label__
label: __label__
ngrams: 1
bucket: 2000000
adagrad: 0
trainMode: 0
fileFormat: fastText
normalizeText: 0
dropoutLHS: 0
dropoutRHS: 0
useWeight: 0
weightSep: :
Start to initialize starspace model.
Build dict from input file : data/ag_news_csv/train.pp
Read 5M words
Number of words in dictionary:  94698
Number of labels in dictionary: 4
Loading data from file : data/ag_news_csv/train.pp
Total number of examples loaded : 120000
Training epoch 0: 0.01 0.002
Epoch: 100.0%  lr: 0.008000  loss: 0.029552  eta: <1min   tot: 0h0m0s  (20.0%)
 ---+++ 

The resulting StarSpace model embeds the input into a common 10-dimensional space (set by `-dim 10`). We load the embeddings into a DataFrame and inspect them. As the table below shows, the model embeds everything into the same space: the words that occur in the documents as well as the categories (the last four rows). This lets us compare entities of different kinds.
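To show what comparing entities of different kinds means in practice, the sketch below ranks the category labels for a document by dot-product similarity (`-similarity "dot"`). It assumes, as we understand StarSpace's `trainMode 0`, that a document is represented by the sum of its word vectors; StarSpace may additionally rescale this sum, but with dot-product similarity a positive scale factor does not change the ranking. `rank_labels` and the 2-d toy vectors are hypothetical stand-ins for the rows of `model.tsv`.

```python
def rank_labels(words, word_vecs, label_vecs):
    """Rank labels by dot-product similarity to the summed word vectors of a document."""
    dim = len(next(iter(label_vecs.values())))
    doc = [0.0] * dim
    for w in words:
        if w in word_vecs:  # skip tokens without an embedding
            for i, x in enumerate(word_vecs[w]):
                doc[i] += x
    scores = {label: sum(a * b for a, b in zip(doc, vec)) for label, vec in label_vecs.items()}
    return sorted(scores, key=scores.get, reverse=True)

# Hypothetical toy vectors standing in for rows of model.tsv
word_vecs = {'goal': [0.9, 0.1], 'stocks': [0.1, 0.8], 'oil': [0.0, 0.6]}
label_vecs = {'__label__sports': [1.0, 0.0], '__label__business': [0.0, 1.0]}

ranking = rank_labels(['stocks', 'oil', 'rise'], word_vecs, label_vecs)
print(ranking)  # ['__label__business', '__label__sports']
```

With the notebook's data, `word_vecs` and `label_vecs` could be filled from the DataFrame loaded below, splitting the rows on whether column 0 starts with `__label__`.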

In [23]:
pd.read_csv('data/ag_news_csv/model.tsv', sep='\t', header=None, keep_default_na=False)

|       | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
|-------|---|---|---|---|---|---|---|---|---|---|----|
| 0 | . | -0.033355 | -0.080821 | -0.007533 | 0.011347 | -0.039631 | 0.035159 | -0.044535 | 0.012915 | -0.037053 | -0.075862 |
| 1 | the | 0.033314 | 0.026314 | -0.001037 | -0.011768 | -0.020240 | 0.003223 | 0.009371 | -0.016185 | 0.005037 | 0.020802 |
| 2 | , | -0.009922 | 0.037519 | 0.001665 | -0.011969 | -0.016249 | 0.027321 | 0.030746 | 0.030399 | -0.003744 | 0.043704 |
| 3 |  | -0.044209 | -0.013623 | 0.019413 | 0.010136 | -0.017818 | 0.013624 | -0.041977 | 0.039315 | -0.086028 | 0.035785 |
| 4 | to | -0.011142 | -0.018124 | -0.004958 | 0.001349 | 0.019493 | 0.005109 | -0.008347 | -0.004185 | 0.010806 | -0.005729 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 94697 | maleafter | 0.011226 | -0.013288 | -0.004855 | 0.019057 | -0.013254 | 0.010894 | -0.023351 | -0.003122 | -0.022541 | 0.006589 |
| 94698 | `__label__business` | -0.158502 | -0.062817 | -0.050236 | -0.183091 | 0.043215 | -0.446341 | 0.148448 | 0.143121 | 0.157458 | 0.114998 |
| 94699 | `__label__world` | -0.030737 | -0.115177 | 0.039558 | 0.070553 | -0.354145 | 0.125507 | -0.010681 | 0.139103 | -0.152312 | -0.192986 |
| 94700 | `__label__sports` | 0.263442 | 0.304824 | 0.009041 | 0.088723 | 0.080759 | 0.308616 | 0.063859 | -0.340450 | 0.340611 | -0.195544 |


We now compute predictions and measure the performance. In test mode, StarSpace reports the hit@k evaluation metric, i.e., the fraction of test examples whose correct label is among the top k predictions. We are interested in the most probable category, so we use the hit@1 metric (in general, assigning categories to text can be viewed as a multi-label classification problem). StarSpace achieves $hit@1=0.46$, which means that in 46% of the test cases the model's top prediction is the correct answer.
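The hit@k metric described above can be computed from ranked predictions as in the following sketch. `hit_at_k` and the toy rankings are hypothetical, and we assume exactly one gold label per example, as in AG News. Under that assumption hit@1 reduces to ordinary accuracy of the top prediction, which is why the SVM accuracy further down is directly comparable.

```python
def hit_at_k(ranked_predictions, gold_labels, k):
    """Fraction of examples whose gold label appears among the top-k ranked predictions."""
    hits = sum(1 for ranked, gold in zip(ranked_predictions, gold_labels) if gold in ranked[:k])
    return hits / len(gold_labels)

# Hypothetical rankings for four test documents (best label first)
ranked = [['b', 'a', 'c'], ['a', 'b', 'c'], ['c', 'a', 'b'], ['a', 'c', 'b']]
gold = ['b', 'b', 'a', 'a']
print(hit_at_k(ranked, gold, 1))  # 0.5
print(hit_at_k(ranked, gold, 2))  # 1.0

# With a single gold label per example, hit@1 equals top-1 accuracy
top1 = [r[0] for r in ranked]
accuracy = sum(p == g for p, g in zip(top1, gold)) / len(gold)
print(accuracy == hit_at_k(ranked, gold, 1))  # True
```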

In [24]:
!./StarSpace/starspace test \
  -model "data/ag_news_csv/model" \
  -testFile "data/ag_news_csv/test.pp" \
  -ngrams 1 \
  -dim 10 \
  -label "__label__" \
  -thread 10 \
  -similarity "dot" \
  -trainMode 0 \
  -verbose false \
  -predictionFile "data/ag_news_csv/test.y"

Arguments: 
lr: 0.01
dim: 10
epoch: 5
maxTrainTime: 8640000
validationPatience: 10
saveEveryEpoch: 0
loss: hinge
margin: 0.05
similarity: dot
maxNegSamples: 10
negSearchLimit: 50
batchSize: 5
thread: 10
minCount: 1
minCountLabel: 1
label: __label__
label: __label__
ngrams: 1
bucket: 2000000
adagrad: 1
trainMode: 0
fileFormat: fastText
normalizeText: 0
dropoutLHS: 0
dropoutRHS: 0
useWeight: 0
weightSep: :
Start to load a trained starspace model.
STARSPACE-2018-2
Model loaded.
Loading data from file : data/ag_news_csv/test.pp
Total number of examples loaded : 7600
Predictions use 4 known labels.
------Loaded model args:
Arguments: 
lr: 0.01
dim: 10
epoch: 5
maxTrainTime: 8640000
validationPatience: 10
saveEveryEpoch: 0
loss: hinge
margin: 0.05
similarity: dot
maxNegSamples: 10
negSearchLimit: 5
batchSize: 5
thread: 10
minCount: 1
minCountLabel: 1
label: __label__
label: __label__
ngrams: 1
bucket: 2000000
adagrad: 1
trainMode: 0
fileFormat: fastText
normalizeText: 0
dropoutLHS: 0
dropout

The performance in this example is poor and differs significantly from the published results [[1]](#fn1), where the authors report 91.6% accuracy on the test set for this task. It is unclear what causes this discrepancy. As a sanity check, a baseline classifier based on TF-IDF features and a linear SVM performs similarly to the BOW + multinomial logistic regression baseline reported in the paper.

In [25]:
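import gensim

def to_tfidf(documents, dic=None, tfidf_model=None):
    # Tokenize and normalize (strip punctuation and numbers, remove stopwords, stem)
    # using gensim's default filters
    documents = [gensim.parsing.preprocessing.preprocess_string(doc) for doc in documents]
    if dic is None:
        # Build the vocabulary on the training data; drop very rare and very frequent tokens
        dic = gensim.corpora.Dictionary(documents)
        dic.filter_extremes()
    bows = [dic.doc2bow(doc) for doc in documents]
    if tfidf_model is None:
        # Fit IDF weights once (on the training data) and reuse them for the test data
        tfidf_model = gensim.models.tfidfmodel.TfidfModel(dictionary=dic)
    tfidf_vectors = tfidf_model[bows]
    return tfidf_vectors, dic, tfidf_model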


train = pd.read_csv('data/ag_news_csv/train.csv', names=['category', 'title', 'body'])
X_train = [x.title + ' ' + x.body for x in train.itertuples()]
y_train = [x.category for x in train.itertuples()]

test = pd.read_csv('data/ag_news_csv/test.csv', names=['category', 'title', 'body'])
X_test = [x.title + ' ' + x.body for x in test.itertuples()]
y_test = [x.category for x in test.itertuples()]

X_train_tfidf, dic, tfidf_model = to_tfidf(X_train)
X_test_tfidf, _, __ = to_tfidf(X_test, dic, tfidf_model)
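The TF-IDF weighting used by this baseline can be sketched in pure Python. The sketch assumes gensim's default `TfidfModel` settings as we understand them: raw term counts multiplied by the inverse document frequency log2(N/df), terms with zero weight dropped, and each vector L2-normalised. `tfidf` is a hypothetical helper, and unlike the cell above it computes document frequencies on the documents it is given rather than reusing a model fitted on the training set; consult the gensim documentation for the exact formula in your version.

```python
import math
from collections import Counter

def tfidf(docs):
    """Return one {term: weight} dict per tokenized document."""
    n_docs = len(docs)
    # Document frequency: in how many documents each term occurs
    df = Counter(term for doc in docs for term in set(doc))
    vectors = []
    for doc in docs:
        tf = Counter(doc)  # raw term counts
        weights = {t: c * math.log2(n_docs / df[t]) for t, c in tf.items()}
        # Terms that occur in every document get idf = 0 and are dropped
        weights = {t: w for t, w in weights.items() if w != 0.0}
        norm = math.sqrt(sum(w * w for w in weights.values()))
        vectors.append({t: w / norm for t, w in weights.items()} if norm else {})
    return vectors

docs = [['oil', 'price', 'oil'], ['oil', 'match'], ['match', 'goal']]
for vec in tfidf(docs):
    print(vec)
```

Rare terms such as `price` and `goal` receive higher weights than terms shared across documents, which is what lets a linear classifier pick up topic-specific vocabulary.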

The linear SVM with TF-IDF features achieves an accuracy of 91%. Because each document has exactly one label, accuracy is the same metric as hit@1 reported by StarSpace.

In [26]:
from sklearn.svm import LinearSVC
from sklearn import metrics
from sklearn import preprocessing

le = preprocessing.LabelEncoder()
le.fit(y_train)

svc = LinearSVC()
svc.fit(gensim.matutils.corpus2csc(X_train_tfidf, num_terms=len(dic)).T, le.transform(y_train))
y_predicted = svc.predict(gensim.matutils.corpus2csc(X_test_tfidf, num_terms=len(dic)).T)
print('Accuracy: {:.3f}'.format(metrics.accuracy_score(le.transform(y_test), y_predicted)))

Accuracy: 0.910


We have embeddings for a large number of words, so we can run clustering to see whether the embedding vectors can be used to partition the words into four categories.

In [27]:
import numpy as np
from sklearn.cluster import KMeans

model = pd.read_csv('data/ag_news_csv/model.tsv', sep='\t', header=None, keep_default_na=False)
embeddings = model[model.columns[1:]]
kmeans = KMeans(n_clusters=4, random_state=12345).fit(embeddings)
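Because the four category embeddings were clustered together with the words, one can also ask which cluster each `__label__` row fell into. With the notebook objects this should already be available in `kmeans.labels_` (for example `kmeans.labels_[model[0].str.startswith('__label__').to_numpy()]`). The pure-Python sketch below spells out the nearest-centroid rule behind such an assignment; `nearest_centroid` and the 2-d centroids and label vectors are hypothetical stand-ins for `kmeans.cluster_centers_` and the label rows of `model.tsv`.

```python
def nearest_centroid(vec, centroids):
    """Index of the centroid with the smallest squared Euclidean distance to vec."""
    return min(range(len(centroids)),
               key=lambda c: sum((a - b) ** 2 for a, b in zip(vec, centroids[c])))

# Toy 2-d stand-ins for kmeans.cluster_centers_ and the __label__ rows of model.tsv
centroids = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
label_vecs = {'__label__sports': [0.9, 0.2], '__label__business': [0.1, 0.7]}

assignment = {label: nearest_centroid(vec, centroids) for label, vec in label_vecs.items()}
print(assignment)  # {'__label__sports': 0, '__label__business': 1}
```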

The first three clusters closely match the topics World, Business, and Sci/Tech, while the last, and by far the largest, cluster is less specific and contains words from all topics.

In [28]:
words_array = model[0].to_numpy()
for ci in range(kmeans.n_clusters):
    cluster_words = np.compress(kmeans.labels_==ci, words_array)
    print('Cluster {} ({} instances)'.format(ci, len(cluster_words)))
    print(cluster_words[:100])
    print('')

Cluster 0 (1703 instances)
['.' '-' "'" 'iraq' 'president' 'sunday' 'would' 'security' 'government'
 'people' 'afp' 'win' 'night' 'china' 'minister' 'bush' 'international'
 'killed' 'city' 'stocks' 'european' 'talks' 'league' 'country' 'british'
 'japan' 'india' 'police' 'prime' 'iraqi' 'leader' 'during' 'hit' 'say'
 'baghdad' 'expected' 'election' 'her' 'north' 'war' 'australia'
 'military' 'cut' 'nuclear' 'higher' 'un' 'official' 'palestinian' 'sox'
 'attack' 'troops' 'russia' 'israeli' 'gaza' 'press' 'west' 'even'
 'including' 'general' 'man' 'iran' 'football' 'forces' 'athens' 'past'
 'europe' 'investors' 'peace' 'release' 'canadian' 'six' 'russian' 'beat'
 'pakistan' 'public' 'eu' 'where' 'foreign' 'bomb' 'attacks' 'israel'
 'nations' 'championship' 'korea' 'australian' 'kerry' 'leaders' 'french'
 'men' 'death' 'killing' 'darfur' 'arafat' 'capital' 'army' 'japanese'
 'campaign' 'race' 'france' 'vote']

Cluster 1 (1713 instances)
['us' 'company' 'oil' 'inc' 'york' 'yesterday' 'no' 