In [45]:
import os
import numpy as np
from tqdm import tqdm

DATA_DIR = '../../datasets/reuters'
train_folder = os.path.join(DATA_DIR, 'training')
test_folder = os.path.join(DATA_DIR, 'test')

In [68]:
def get_docs(folder):
    files = np.array(os.listdir(folder))
    files = files[np.argsort(np.array(files, dtype=int))]
    docs = []
    ids = []
    for filename in tqdm(files):
        filepath = os.path.join(folder, filename)
        try:
            doc = open(filepath, 'r', encoding='utf-8')
            docs.append(doc.read())
            ids.append(filename)
        except UnicodeDecodeError:
            doc.close()
    return docs, ids

In [69]:
train_docs, train_ids = get_docs(train_folder)
test_docs, test_ids = get_docs(test_folder)

100%|██████████| 7769/7769 [00:00<00:00, 20624.80it/s]
100%|██████████| 3019/3019 [00:00<00:00, 25722.26it/s]


In [43]:
len(test_docs)

3018

In [44]:
len(train_docs)

7769

In [74]:
classes = open(os.path.join(DATA_DIR, 'cats.txt'), 'r', encoding='utf-8')

In [75]:
class_lines = classes.readlines()

In [87]:
labels_set = set()

In [90]:
train_labels_dict = {}
test_labels_dict = {}

for line in class_lines:
    labels_split = line.strip().split()
    split_type, doc_id = labels_split[0].split('/')
    labels = labels_split[1:]
    labels_set |= set(labels)
    #print(split_type, doc_id, labels)
    if split_type == 'test':
        test_labels_dict[doc_id] = labels
    else:
        train_labels_dict[doc_id] = labels

In [92]:
train_labels = []
for doc_id in train_ids:
    train_labels.append(train_labels_dict[doc_id])
test_labels = []
for doc_id in test_ids:
    test_labels.append(test_labels_dict[doc_id])

In [101]:
np.mean([len(train_labels[i]) for i in range(len(train_labels))])

1.2337495173123953

In [126]:
label_names = np.array(list(labels_set))
label_to_index = dict(zip(label_names, np.arange(len(label_names))))

In [128]:
frequences = [0] * len(label_to_index)
for ind in range(len(train_labels)):
    for label_index in train_labels[ind]:
        frequences[label_to_index[label_index]] += 1
        
index_to_frequency_index = dict(zip(np.argsort(frequences)[::-1],
                                    np.arange(len(frequences))))
frequency_indexes = [index_to_frequency_index[label_to_index[label]]
                     for label in label_names]
label_to_index = dict(zip(label_names, frequency_indexes))
index_to_label = dict(zip(frequency_indexes, label_names))

labels_train = []
for labels in train_labels:
    labels_train.append(sorted(
        [label_to_index[label] for label in labels]))
    
labels_test = []
for labels in test_labels:
    labels_test.append(sorted(
        [label_to_index[label] for label in labels]))

In [161]:
import pandas as pd

In [171]:
df_train_valid = pd.DataFrame({
    'text': train_docs,
    'id': train_ids,
    'labels': labels_train
})

df_test = pd.DataFrame({
    'text': test_docs,
    'id': test_ids,
    'labels': labels_test
})

In [174]:
from sklearn.model_selection import train_test_split

In [175]:
df_train, df_valid = train_test_split(df_train_valid, test_size=0.1,
                                      random_state=3773)

In [172]:
df_train_valid.head()

Unnamed: 0,text,id,labels
0,BAHIA COCOA REVIEW\n Showers continued throug...,1,[23]
1,NATIONAL AVERAGE PRICES FOR FARMER-OWNED RESER...,5,"[3, 7, 9, 32, 39, 62]"
2,ARGENTINE 1986/87 GRAIN/OILSEED REGISTRATIONS\...,6,"[3, 7, 9, 13, 17, 18, 39, 53, 57, 67, 84]"
3,CHAMPION PRODUCTS &lt;CH> APPROVES STOCK SPLIT...,9,[0]
4,COMPUTER TERMINAL SYSTEMS &lt;CPML> COMPLETES ...,10,[1]


In [173]:
df_test.head()

Unnamed: 0,text,id,labels
0,ASIAN EXPORTERS FEAR DAMAGE FROM U.S.-JAPAN RI...,14826,[5]
1,CHINA DAILY SAYS VERMIN EAT 7-12 PCT GRAIN STO...,14828,[3]
2,JAPAN TO REVISE LONG-TERM ENERGY DEMAND DOWNWA...,14829,"[4, 21]"
3,THAI TRADE DEFICIT WIDENS IN FIRST QUARTER\n ...,14832,"[3, 5, 9, 12, 34, 36, 46]"
4,INDONESIA SEES CPO PRICE RISING SHARPLY\n Ind...,14833,"[17, 37]"


In [176]:
len(df_train_valid), len(df_train),len(df_valid), len(df_test)

(7769, 6992, 777, 3018)

In [177]:
df_train.to_csv(os.path.join(DATA_DIR, 'train.csv'),
                encoding='utf-8', index=0)
df_valid.to_csv(os.path.join(DATA_DIR, 'valid.csv'),
                encoding='utf-8', index=0)
df_test.to_csv(os.path.join(DATA_DIR, 'test.csv'),
               encoding='utf-8', index=0)

In [181]:
df_train

Unnamed: 0,text,id,labels
7124,&lt;OCELOT INDUSTRIES LTD> YEAR LOSS\n Shr lo...,12391,[0]
4953,JAPAN ECONOMY MAY START BOTTOMING OUT SOON -AG...,8671,"[5, 29]"
2105,PLACER &lt;PLC> TO INCREASE STAKE IN EQUITY SI...,3868,"[1, 42]"
5225,Sumita says major nations cooperated to stabil...,9132,[2]
900,UNICORP AMERICAN CORP &lt;UAC> 4TH QTR NET\n ...,1722,[0]
3396,&lt;CABRE EXPLORATION LTD> SIX MTHS JAN 31 NET...,6055,[0]
3985,JOHN FAIRFAX LTD &lt;FFXA.S> FIRST HALF\n 26 ...,6999,[0]
7484,CULLEN/FROST &lt;CFBI> TO OMIT DIVIDEND\n Cul...,13180,[0]
2270,LEIGH-PEMBERTON OPPOSES TAKEOVER PROTECTION RU...,4127,[1]
6808,U.S. CREDIT MARKET OUTLOOK - PRIME RATE\n The...,11861,[6]


In [179]:
doc_lengths = [len(train_docs[i].split()) for i in range(len(train_docs))]
np.mean(doc_lengths), np.std(doc_lengths)

(130.11185480756853, 137.5386571049968)

In [180]:
130 + 137

267

In [158]:
labels_train[:10]

[[23],
 [3, 7, 9, 32, 39, 62],
 [3, 7, 9, 13, 17, 18, 39, 53, 57, 67, 84],
 [0],
 [1],
 [0],
 [0, 1],
 [0],
 [0],
 [0]]

In [123]:
label_to_index

{'ship': 10}

In [113]:
label_names

['money-supply',
 'lei',
 'veg-oil',
 'dlr',
 'cotton',
 'oilseed',
 'bop',
 'iron-steel',
 'cpu',
 'strategic-metal',
 'lin-oil',
 'cpi',
 'ipi',
 'oat',
 'gnp',
 'castor-oil',
 'income',
 'earn',
 'l-cattle',
 'jet',
 'livestock',
 'tea',
 'coconut',
 'palladium',
 'palm-oil',
 'cotton-oil',
 'barley',
 'silver',
 'propane',
 'housing',
 'nkr',
 'palmkernel',
 'acq',
 'cocoa',
 'soy-oil',
 'platinum',
 'wpi',
 'pet-chem',
 'rand',
 'gold',
 'reserves',
 'corn',
 'nzdlr',
 'rice',
 'carcass',
 'groundnut-oil',
 'nat-gas',
 'jobs',
 'heat',
 'naphtha',
 'grain',
 'trade',
 'dfl',
 'lumber',
 'sun-oil',
 'coffee',
 'hog',
 'soy-meal',
 'orange',
 'yen',
 'instal-debt',
 'nickel',
 'crude',
 'interest',
 'gas',
 'sorghum',
 'copra-cake',
 'potato',
 'sugar',
 'rape-oil',
 'zinc',
 'sunseed',
 'dmk',
 'fuel',
 'wheat',
 'soybean',
 'retail',
 'rubber',
 'copper',
 'tin',
 'ship',
 'meal-feed',
 'sun-meal',
 'alum',
 'groundnut',
 'lead',
 'rye',
 'coconut-oil',
 'rapeseed',
 'money-fx']

In [91]:
train_labels_dict

{'1': ['cocoa'],
 '5': ['sorghum', 'oat', 'barley', 'corn', 'wheat', 'grain'],
 '6': ['wheat',
  'sorghum',
  'grain',
  'sunseed',
  'corn',
  'oilseed',
  'soybean',
  'sun-oil',
  'soy-oil',
  'lin-oil',
  'veg-oil'],
 '9': ['earn'],
 '10': ['acq'],
 '11': ['earn'],
 '12': ['acq', 'earn'],
 '13': ['earn'],
 '14': ['earn'],
 '18': ['earn'],
 '19': ['grain', 'wheat'],
 '22': ['copper'],
 '23': ['earn'],
 '24': ['earn'],
 '27': ['earn'],
 '29': ['housing'],
 '30': ['money-supply'],
 '36': ['earn'],
 '37': ['earn'],
 '38': ['earn'],
 '40': ['earn'],
 '41': ['earn'],
 '42': ['coffee'],
 '44': ['ship', 'acq'],
 '45': ['acq'],
 '46': ['sugar'],
 '47': ['trade'],
 '48': ['reserves'],
 '49': ['ship'],
 '50': ['earn'],
 '53': ['earn'],
 '56': ['earn'],
 '57': ['corn', 'grain'],
 '58': ['money-supply'],
 '59': ['ship'],
 '64': ['earn'],
 '65': ['earn'],
 '66': ['earn'],
 '68': ['acq'],
 '69': ['soy-meal', 'meal-feed', 'oilseed', 'soybean', 'veg-oil'],
 '71': ['earn'],
 '74': ['earn'],
 '75': [

In [21]:
train_docs[2]

'ARGENTINE 1986/87 GRAIN/OILSEED REGISTRATIONS\n  Argentine grain board figures show\n  crop registrations of grains, oilseeds and their products to\n  February 11, in thousands of tonnes, showing those for futurE\n  shipments month, 1986/87 total and 1985/86 total to February\n  12, 1986, in brackets:\n      Bread wheat prev 1,655.8, Feb 872.0, March 164.6, total\n  2,692.4 (4,161.0).\n      Maize Mar 48.0, total 48.0 (nil).\n      Sorghum nil (nil)\n      Oilseed export registrations were:\n      Sunflowerseed total 15.0 (7.9)\n      Soybean May 20.0, total 20.0 (nil)\n      The board also detailed export registrations for\n  subproducts, as follows,\n      SUBPRODUCTS\n      Wheat prev 39.9, Feb 48.7, March 13.2, Apr 10.0, total\n  111.8 (82.7) .\n      Linseed prev 34.8, Feb 32.9, Mar 6.8, Apr 6.3, total 80.8\n  (87.4).\n      Soybean prev 100.9, Feb 45.1, MAr nil, Apr nil, May 20.0,\n  total 166.1 (218.5).\n      Sunflowerseed prev 48.6, Feb 61.5, Mar 25.1, Apr 14.5,\n  total 149.