In [10]:
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import MultinomialNB
from sklearn.metrics import accuracy_score,classification_report

In [11]:
newsgroups=fetch_20newsgroups(subset='all')

In [12]:
newsgroups.data[1]

'From: mblawson@midway.ecn.uoknor.edu (Matthew B Lawson)\nSubject: Which high-performance VLB video card?\nSummary: Seek recommendations for VLB video card\nNntp-Posting-Host: midway.ecn.uoknor.edu\nOrganization: Engineering Computer Network, University of Oklahoma, Norman, OK, USA\nKeywords: orchid, stealth, vlb\nLines: 21\n\n  My brother is in the market for a high-performance video card that supports\nVESA local bus with 1-2MB RAM.  Does anyone have suggestions/ideas on:\n\n  - Diamond Stealth Pro Local Bus\n\n  - Orchid Farenheit 1280\n\n  - ATI Graphics Ultra Pro\n\n  - Any other high-performance VLB card\n\n\nPlease post or email.  Thank you!\n\n  - Matt\n\n-- \n    |  Matthew B. Lawson <------------> (mblawson@essex.ecn.uoknor.edu)  |   \n  --+-- "Now I, Nebuchadnezzar, praise and exalt and glorify the King  --+-- \n    |   of heaven, because everything he does is right and all his ways  |   \n    |   are just." - Nebuchadnezzar, king of Babylon, 562 B.C.           |   \n'

In [13]:
vectorizer=TfidfVectorizer(stop_words='english')
X=vectorizer.fit_transform(newsgroups.data)
y=newsgroups.target

In [14]:
X.shape

(18846, 173451)

In [15]:
X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.3,random_state=1)

In [16]:
model=MultinomialNB()
model.fit(X_train,y_train)
y_pred=model.predict(X_test)

In [17]:
y_pred

array([16, 19, 18, ...,  8,  2,  3])

In [18]:
y_test

array([16, 19, 18, ...,  8,  2,  3])

In [19]:
y_test

array([16, 19, 18, ...,  8,  2,  3])

In [20]:
accuracy_score(y_test,y_pred)

0.8793774319066148

In [21]:
accuracy_score

In [22]:
print(classification_report(y_test,y_pred,target_names=newsgroups.target_names))

                          precision    recall  f1-score   support

             alt.atheism       0.91      0.84      0.87       251
           comp.graphics       0.84      0.84      0.84       289
 comp.os.ms-windows.misc       0.89      0.84      0.87       318
comp.sys.ibm.pc.hardware       0.79      0.84      0.82       304
   comp.sys.mac.hardware       0.87      0.90      0.88       278
          comp.windows.x       0.90      0.88      0.89       290
            misc.forsale       0.92      0.79      0.85       300
               rec.autos       0.92      0.89      0.90       311
         rec.motorcycles       0.89      0.97      0.93       297
      rec.sport.baseball       0.91      0.98      0.95       283
        rec.sport.hockey       0.94      0.98      0.96       314
               sci.crypt       0.88      0.97      0.92       312
         sci.electronics       0.89      0.82      0.86       287
                 sci.med       0.97      0.90      0.93       297
         

In [23]:
from gensim.models import Word2Vec
import re

In [37]:
import numpy as np

In [24]:
def preprocess_text(text):
  # use regex to extract the list of tokens (only words from a string)
  tokens=re.findall(r'\b[a-zA-Z]+\b',text.lower())
  return(tokens)

In [25]:
preprocessed_data=[preprocess_text(text) for text in newsgroups.data]

In [28]:
preprocessed_data

'from'

In [29]:
model=Word2Vec(sentences=preprocessed_data,vector_size=100,window=5,min_count=5,sg=1,workers=4)

In [32]:
model.wv['computer']

array([-0.06506277,  0.3493992 , -0.13675642, -0.5615899 , -0.09025293,
       -0.20775332,  0.44752806,  0.50011367, -0.27830014, -0.3287856 ,
        0.15212156, -0.5202427 ,  0.33188426,  0.57080245, -0.3237164 ,
       -0.33437228, -0.5770538 ,  0.24276102,  0.15060677, -0.15626763,
        0.12359783,  0.8363527 , -0.040628  , -0.50147635, -0.0115218 ,
        0.3769376 ,  0.37988892,  0.08129547, -0.2898867 ,  0.77681255,
       -0.14792529, -0.10054581,  0.45180997, -0.15928599, -0.07883754,
        0.12883921,  0.09934603,  0.07930949,  0.5309025 , -0.50919   ,
       -0.3657965 , -0.1464366 ,  0.18016645, -0.35796562, -0.11324129,
       -0.32701358,  0.59696335, -0.17565508,  0.37851396,  0.05381031,
        0.08693907, -0.29159117,  0.13613757,  0.00711787, -0.17435715,
       -0.00184189,  0.05187647, -0.362711  ,  0.0314949 , -0.24650747,
        0.3393591 , -0.05058189,  0.25475407, -0.18724525, -0.40233245,
        0.03243318, -0.0231081 ,  0.9593594 , -1.0683002 , -0.03

In [30]:
def get_average_word2vec(tokens,model,vector_size=100):
  word_vectors=[model.wv[token] for token in tokens if token in model.wv]
  if len(word_vectors)==0:
    return([0]*vector_size)
  return(sum(word_vectors)/len(word_vectors))

In [33]:
X=[get_average_word2vec(text,model) for text in preprocessed_data]

In [39]:
y=newsgroups.target

In [40]:
X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.3,random_state=1)

In [42]:
from sklearn.tree import DecisionTreeClassifier
#from sklearn.ensemble import RandomForest

In [43]:
model_tree=DecisionTreeClassifier(random_state=42)

In [44]:
model_tree.fit(X_train,y_train)

In [45]:
y_pred=model_tree.predict(X_test)

In [46]:
accuracy=accuracy_score(y_test,y_pred)

In [47]:
accuracy

0.44623275557127695

In [48]:
print(classification_report(y_test,y_pred,target_names=newsgroups.target_names))

                          precision    recall  f1-score   support

             alt.atheism       0.42      0.38      0.40       251
           comp.graphics       0.35      0.33      0.34       289
 comp.os.ms-windows.misc       0.35      0.34      0.35       318
comp.sys.ibm.pc.hardware       0.32      0.34      0.33       304
   comp.sys.mac.hardware       0.27      0.29      0.28       278
          comp.windows.x       0.52      0.53      0.53       290
            misc.forsale       0.55      0.50      0.52       300
               rec.autos       0.34      0.32      0.33       311
         rec.motorcycles       0.39      0.43      0.41       297
      rec.sport.baseball       0.56      0.60      0.58       283
        rec.sport.hockey       0.70      0.69      0.69       314
               sci.crypt       0.61      0.55      0.58       312
         sci.electronics       0.27      0.28      0.28       287
                 sci.med       0.50      0.47      0.48       297
         