In [1]:
from tqdm.auto import tqdm
import pandas as pd
import numpy as np
from sklearn.metrics import classification_report
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import MultiLabelBinarizer

from sklearn.linear_model import SGDClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
from sklearn.multiclass import OneVsRestClassifier

In [2]:
test = []
train = []
all_labels = []

with open("./reuters/reuters/reuters/cats.txt") as f:
    content = f.readlines()
    for line in tqdm(content):
        line = line.replace("\n","")
        line = line.split(" ")
        file = line[0]
        labels = line[1:]
        all_labels.extend(labels)

        with open("./reuters/reuters/reuters/" + file, encoding="utf8", errors='ignore') as f:
            data = f.read().replace('\n', '')
            if 'test' in file:
                test.append({"text" : data, "labels" : labels})
            elif 'train' in file:
                train.append({"text" : data, "labels" : labels})
            else:
                print("invalid file ", file)

all_labels = list(set(all_labels))
len(all_labels), len(train), len(test)

  0%|          | 0/10788 [00:00<?, ?it/s]

(90, 7769, 3019)

In [3]:
data = pd.DataFrame(train)
data.dropna(inplace = True)
data

Unnamed: 0,text,labels
0,BAHIA COCOA REVIEW Showers continued througho...,[cocoa]
1,NATIONAL AVERAGE PRICES FOR FARMER-OWNED RESER...,"[sorghum, oat, barley, corn, wheat, grain]"
2,ARGENTINE 1986/87 GRAIN/OILSEED REGISTRATIONS ...,"[wheat, sorghum, grain, sunseed, corn, oilseed..."
3,CHAMPION PRODUCTS &lt;CH> APPROVES STOCK SPLIT...,[earn]
4,COMPUTER TERMINAL SYSTEMS &lt;CPML> COMPLETES ...,[acq]
...,...,...
7764,BANK OF JAPAN INTERVENES SOON AFTER TOKYO OPEN...,"[money-fx, dlr]"
7765,JAPAN RUBBER STOCKS FALL IN MARCH Japan's rub...,[rubber]
7766,SOUTH KOREAN WON FIXED AT 25-MONTH HIGH THE B...,[money-fx]
7767,NIPPON MINING LOWERS COPPER PRICE Nippon Mini...,[copper]


In [4]:
mlb = MultiLabelBinarizer()
mlb.fit_transform([all_labels])

array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1]])

In [5]:
test_set = pd.DataFrame(test)
test_set

Unnamed: 0,text,labels
0,ASIAN EXPORTERS FEAR DAMAGE FROM U.S.-JAPAN RI...,[trade]
1,CHINA DAILY SAYS VERMIN EAT 7-12 PCT GRAIN STO...,[grain]
2,JAPAN TO REVISE LONG-TERM ENERGY DEMAND DOWNWA...,"[nat-gas, crude]"
3,THAI TRADE DEFICIT WIDENS IN FIRST QUARTER Th...,"[rubber, tin, sugar, corn, rice, grain, trade]"
4,INDONESIA SEES CPO PRICE RISING SHARPLY Indon...,"[palm-oil, veg-oil]"
...,...,...
3014,N.Z.'S CHASE CORP MAKES OFFER FOR ENTREGROWTH ...,[acq]
3015,TOKYO DEALERS SEE DOLLAR POISED TO BREACH 140 ...,"[yen, dlr, money-fx]"
3016,JAPAN/INDIA CONFERENCE CUTS GULF WAR RISK CHAR...,[ship]
3017,SOVIET INDUSTRIAL GROWTH/TRADE SLOWER IN 1987 ...,[ipi]


In [6]:
stopwords = []
with open("./reuters/reuters/reuters/stopwords") as f:
    content = f.readlines()
    for line in content:
        line = line.replace("\n","")
        line = line.split(" ")
        stopwords.extend(line)
len(stopwords)

571

In [7]:
import nltk
stoplist = nltk.corpus.stopwords.words('english')
stopwords.extend(stoplist)
stopwords = list(set(stopwords))

In [8]:

DROP_STOPWORDS = False
STEMMING = False

MAX_NGRAM_LENGTH = 1
VECTOR_LENGTH = 1000
SET_RANDOM = 9999

codelist = ['\r', '\n', '\t']    


In [9]:
def parse_doc(text):
    text = text.lower()
    text = re.sub(r'&(.)+', "", text)  # no & references  
    text = re.sub(r'pct', 'percent', text)  # replace pct abreviation  
    text = re.sub(r"[^\w\d'\s]+", '', text)  # no punct except single quote 
    text = re.sub(r'[^\x00-\x7f]',r'', text)  # no non-ASCII strings    
    if text.isdigit(): text = ""  # omit words that are all digits    
    for code in codelist:
        text = re.sub(code, ' ', text)  # get rid of escape codes  
    # replace multiple spacess with one space
    text = re.sub('\s+', ' ', text)        
    return text


In [10]:

def parse_words(text): 
    # split document into individual words
    tokens=text.split()
    re_punc = re.compile('[%s]' % re.escape(string.punctuation))
    # remove punctuation from each word
    tokens = [re_punc.sub('', w) for w in tokens]
    # remove remaining tokens that are not alphabetic
    tokens = [word for word in tokens if word.isalpha()]
    # filter out tokens that are one or two characters long
    tokens = [word for word in tokens if len(word) > 2]
    # filter out tokens that are more than twenty characters long
    tokens = [word for word in tokens if len(word) < 21]
    # filter out stop words if requested
    if DROP_STOPWORDS:
        tokens = [w for w in tokens if not w in stopwords]         
    # perform word stemming if requested
    if STEMMING:
        ps = PorterStemmer()
        tokens = [ps.stem(word) for word in tokens]
    # recreate the document string from parsed words
    text = ''
    for token in tokens:
        text = text + ' ' + token
    return tokens, text

In [11]:
# import re
# text = data.text[1]
# aa = parse_doc(text)
# aa
    

In [12]:
# import string
# parse_words(aa)

In [13]:
multilabel = MultiLabelBinarizer()
y = multilabel.fit_transform(data['labels'])
y

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 1, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]])

In [14]:
tfidf = TfidfVectorizer(analyzer='word', ngram_range=(1,3), stop_words='english')
X = tfidf.fit_transform(data['text'])
X.shape, y.shape

((7769, 831643), (7769, 90))

In [15]:
X_test = tfidf.transform(test_set['text'])
y_test = multilabel.transform(test_set['labels'])


In [16]:
sgd = SGDClassifier()
lr = LogisticRegression(solver='lbfgs')
svc = LinearSVC()

In [25]:
from sklearn.metrics import jaccard_score

def print_score(y_pred, clf):
    print("Classifier: ", clf.__class__.__name__)
    print('Jaccard score: {}'.format(jaccard_score(y_test, y_pred, average = 'samples')))

In [28]:
# for classifier in [sgd, lr, svc]: 
#     clf = OneVsRestClassifier(classifier)
#     clf.fit(X, y)
#     y_pred = clf.predict(X_test)
#     print(y_pred)
#     print_score(y_pred, classifier)

## SGD

In [29]:
clf_sgd = OneVsRestClassifier(sgd)
clf_sgd.fit(X, y)
y_pred_sgd = clf_sgd.predict(X_test)
print_score(y_pred_sgd, sgd)
print(classification_report(y_test, y_pred_sgd))
print('-------------------------------------------------------------------------')

Classifier:  SGDClassifier
Jaccard score: 0.854579737181592
              precision    recall  f1-score   support

           0       0.98      0.96      0.97       719
           1       1.00      0.22      0.36        23
           2       1.00      0.57      0.73        14
           3       0.91      0.67      0.77        30
           4       0.89      0.44      0.59        18
           5       0.00      0.00      0.00         1
           6       1.00      0.94      0.97        18
           7       0.00      0.00      0.00         2
           8       0.00      0.00      0.00         3
           9       0.96      0.96      0.96        28
          10       1.00      0.83      0.91        18
          11       0.00      0.00      0.00         1
          12       0.96      0.77      0.85        56
          13       1.00      0.45      0.62        20
          14       0.00      0.00      0.00         2
          15       0.85      0.39      0.54        28
          16       0.

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [30]:
test_set['org'] = mlb.inverse_transform(y_pred_sgd)
test_set

Unnamed: 0,text,labels,org
0,ASIAN EXPORTERS FEAR DAMAGE FROM U.S.-JAPAN RI...,[trade],"(trade,)"
1,CHINA DAILY SAYS VERMIN EAT 7-12 PCT GRAIN STO...,[grain],"(grain,)"
2,JAPAN TO REVISE LONG-TERM ENERGY DEMAND DOWNWA...,"[nat-gas, crude]",()
3,THAI TRADE DEFICIT WIDENS IN FIRST QUARTER Th...,"[rubber, tin, sugar, corn, rice, grain, trade]","(trade,)"
4,INDONESIA SEES CPO PRICE RISING SHARPLY Indon...,"[palm-oil, veg-oil]","(palm-oil, veg-oil)"
...,...,...,...
3014,N.Z.'S CHASE CORP MAKES OFFER FOR ENTREGROWTH ...,[acq],"(acq,)"
3015,TOKYO DEALERS SEE DOLLAR POISED TO BREACH 140 ...,"[yen, dlr, money-fx]","(dlr, money-fx)"
3016,JAPAN/INDIA CONFERENCE CUTS GULF WAR RISK CHAR...,[ship],"(ship,)"
3017,SOVIET INDUSTRIAL GROWTH/TRADE SLOWER IN 1987 ...,[ipi],"(ipi, trade)"


## Logistic Regression

In [31]:
clf_lr = OneVsRestClassifier(lr)
clf_lr.fit(X, y)
y_pred_lr = clf_lr.predict(X_test)
print_score(y_pred_lr, lr)
print(classification_report(y_test, y_pred_lr))
print('-------------------------------------------------------------------------')

KeyboardInterrupt: 

In [None]:
test_set['org'] = mlb.inverse_transform(y_pred_lr)
test_set

## SVM


In [35]:
clf_svm = OneVsRestClassifier(svc)
clf_svm.fit(X, y)
y_pred_svm = clf_svm.predict(X_test)
print_score(y_pred_svm, svc)
print(classification_report(y_test, y_pred_svm))
print("-----------------------------------------------------")

Classifier:  LinearSVC
Jaccard score: 0.8459325187712069
              precision    recall  f1-score   support

           0       0.98      0.95      0.97       719
           1       1.00      0.22      0.36        23
           2       1.00      0.57      0.73        14
           3       0.90      0.63      0.75        30
           4       0.88      0.39      0.54        18
           5       0.00      0.00      0.00         1
           6       1.00      0.83      0.91        18
           7       0.00      0.00      0.00         2
           8       0.00      0.00      0.00         3
           9       0.96      0.96      0.96        28
          10       1.00      0.83      0.91        18
          11       0.00      0.00      0.00         1
          12       0.93      0.75      0.83        56
          13       1.00      0.45      0.62        20
          14       0.00      0.00      0.00         2
          15       0.91      0.36      0.51        28
          16       0.00 

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [37]:
test_set['org'] = mlb.inverse_transform(y_pred_svm)
test_set

Unnamed: 0,text,labels,org
0,ASIAN EXPORTERS FEAR DAMAGE FROM U.S.-JAPAN RI...,[trade],"(trade,)"
1,CHINA DAILY SAYS VERMIN EAT 7-12 PCT GRAIN STO...,[grain],"(grain,)"
2,JAPAN TO REVISE LONG-TERM ENERGY DEMAND DOWNWA...,"[nat-gas, crude]",()
3,THAI TRADE DEFICIT WIDENS IN FIRST QUARTER Th...,"[rubber, tin, sugar, corn, rice, grain, trade]","(trade,)"
4,INDONESIA SEES CPO PRICE RISING SHARPLY Indon...,"[palm-oil, veg-oil]","(palm-oil, veg-oil)"
...,...,...,...
3014,N.Z.'S CHASE CORP MAKES OFFER FOR ENTREGROWTH ...,[acq],"(acq,)"
3015,TOKYO DEALERS SEE DOLLAR POISED TO BREACH 140 ...,"[yen, dlr, money-fx]","(dlr, money-fx)"
3016,JAPAN/INDIA CONFERENCE CUTS GULF WAR RISK CHAR...,[ship],"(ship,)"
3017,SOVIET INDUSTRIAL GROWTH/TRADE SLOWER IN 1987 ...,[ipi],"(ipi, trade)"


## USING BERT

In [None]:
from transformers import BertTokenizer, BertModel

model = BertModel.from_pretrained("bert-base-uncased")
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [None]:
encoded = data['labels'].apply(lambda x: tokenizer.encode_plus(x, add_special_tokens=True,
                                     pad_to_max_length=True,
                                     return_attention_mask=True,
                                     truncation = True,
                                     max_length=max_length,
                                     return_tensors='pt'))
encoded

In [None]:
import torch
input_ids = torch.cat(tuple(encoded.apply(lambda x:x['input_ids'])))
input_ids

In [None]:
attention_mask = torch.cat(tuple(encoded.apply(lambda x:x['attention_mask'])))
attention_mask

In [None]:
output[0][:, 0, :]

In [None]:
len(output)