<a href="https://colab.research.google.com/github/sadidhasan/text-classification/blob/main/reutersdata_keras_processing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!mkdir reuters21578

In [2]:
%cd reuters21578/

/content/reuters21578


In [3]:
# Confirm the working directory before downloading into it.
!pwd

/content/reuters21578


In [4]:
!wget http://kdd.ics.uci.edu/databases/reuters21578/reuters21578.tar.gz

--2021-09-11 18:16:23--  http://kdd.ics.uci.edu/databases/reuters21578/reuters21578.tar.gz
Resolving kdd.ics.uci.edu (kdd.ics.uci.edu)... 128.195.1.86
Connecting to kdd.ics.uci.edu (kdd.ics.uci.edu)|128.195.1.86|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 8150596 (7.8M) [application/x-gzip]
Saving to: ‘reuters21578.tar.gz’


2021-09-11 18:16:25 (4.00 MB/s) - ‘reuters21578.tar.gz’ saved [8150596/8150596]



In [5]:
# Check the archive landed here; its size should equal the length wget reported.
!ls -l

total 7960
-rw-r--r-- 1 root root 8150596 Mar 12  1999 reuters21578.tar.gz


In [6]:
# Extract in place: -x extract, -v list each file, -z gunzip, -f read the named archive.
!tar -xvzf reuters21578.tar.gz

README.txt
all-exchanges-strings.lc.txt
all-orgs-strings.lc.txt
all-people-strings.lc.txt
all-places-strings.lc.txt
all-topics-strings.lc.txt
cat-descriptions_120396.txt
feldman-cia-worldfactbook-data.txt
lewis.dtd
reut2-000.sgm
reut2-001.sgm
reut2-002.sgm
reut2-003.sgm
reut2-004.sgm
reut2-005.sgm
reut2-006.sgm
reut2-007.sgm
reut2-008.sgm
reut2-009.sgm
reut2-010.sgm
reut2-011.sgm
reut2-012.sgm
reut2-013.sgm
reut2-014.sgm
reut2-015.sgm
reut2-016.sgm
reut2-017.sgm
reut2-018.sgm
reut2-019.sgm
reut2-020.sgm
reut2-021.sgm


In [7]:
# List the extracted files; the reut2-*.sgm files hold the newswires parsed below.
!ls

all-exchanges-strings.lc.txt	    reut2-002.sgm  reut2-013.sgm
all-orgs-strings.lc.txt		    reut2-003.sgm  reut2-014.sgm
all-people-strings.lc.txt	    reut2-004.sgm  reut2-015.sgm
all-places-strings.lc.txt	    reut2-005.sgm  reut2-016.sgm
all-topics-strings.lc.txt	    reut2-006.sgm  reut2-017.sgm
cat-descriptions_120396.txt	    reut2-007.sgm  reut2-018.sgm
feldman-cia-worldfactbook-data.txt  reut2-008.sgm  reut2-019.sgm
lewis.dtd			    reut2-009.sgm  reut2-020.sgm
README.txt			    reut2-010.sgm  reut2-021.sgm
reut2-000.sgm			    reut2-011.sgm  reuters21578.tar.gz
reut2-001.sgm			    reut2-012.sgm


In [8]:
import os
import re
from keras.preprocessing.text import Tokenizer
import collections

def _parse_reuters_sgm(path):
    """Scan the ``*.sgm`` files in ``path`` (sorted by name) for topic-tagged wires.

    Returns two parallel lists ``(wire_topics, wire_bodies)``, one entry per
    wire that has a non-empty ``<TOPICS>`` element.
    """
    wire_topics = []
    wire_bodies = []
    tag = '<TOPICS>'
    bodytag = '<BODY>'
    for fname in sorted(os.listdir(path)):
        if not fname.endswith('.sgm'):
            continue
        # Read raw bytes and decode latin-1 explicitly (no newline translation);
        # the context manager closes the file handle.
        with open(os.path.join(path, fname), 'rb') as f:
            s = f.read().decode('latin-1')
        while tag in s:
            s = s[s.find(tag) + len(tag):]
            # NOTE(review): `topics` stops at the first '</', i.e. the first
            # '</D>', so a multi-topic wire yields '<D>first' and the
            # '</D><D>' test below never fires -- multi-topic wires are kept
            # under their first topic. Presumably this is how
            # keras.datasets.reuters was built (the 11228 kept wires equal its
            # 8982 train + 2246 test split); TODO confirm before changing.
            topics = s[:s.find('</')]
            if not topics or '</D><D>' in topics:
                continue
            topic = topics.replace('<D>', '').replace('</D>', '')
            # NOTE(review): '<BODY>' is searched over the rest of the file, so
            # a wire with no body picks up the next wire's body (and if none
            # follows, find() returns -1 and the slice starts at offset 5).
            # Kept as-is for parity with the dataset size noted above.
            body = s[s.find(bodytag) + len(bodytag):]
            body = body[:body.find('</')]
            wire_topics.append(topic)
            wire_bodies.append(body)
    return wire_topics, wire_bodies


def make_reuters_dataset(path='/content/reuters21578', min_samples_per_topic=15):
    """Rebuild a Reuters-21578 newswire classification dataset from the raw SGML.

    Every ``*.sgm`` file in ``path`` is scanned for topic-tagged wires. Wires
    whose topic occurs fewer than ``min_samples_per_topic`` times are dropped.
    The remaining topics are numbered in order of first appearance, and the
    wire bodies are converted to word-index sequences with a Keras ``Tokenizer``.
    Progress and sanity checks are printed along the way.

    Args:
        path: Directory containing the extracted ``reut2-*.sgm`` files.
        min_samples_per_topic: Minimum number of wires a topic needs to be kept.

    Returns:
        ``(X, labels)``: ``X`` is a list of token-index sequences (one per kept
        wire) and ``labels`` the aligned list of integer topic indexes.
    """
    wire_topics, wire_bodies = _parse_reuters_sgm(path)

    # only keep most common topics; most_common() sorts by count (stable,
    # so ties keep first-appearance order)
    topic_counts = collections.Counter(wire_topics)
    kept_topics = set()
    for topic, count in topic_counts.most_common():
        print(topic + ': ' + str(count))
        if count >= min_samples_per_topic:
            kept_topics.add(topic)
    print()
    print('-')
    print('Kept topics:', len(kept_topics))

    # filter wires with rare topics; topic index = order of first appearance
    kept_wires = []
    labels = []
    topic_indexes = {}
    for topic, body in zip(wire_topics, wire_bodies):
        if topic not in kept_topics:
            continue
        labels.append(topic_indexes.setdefault(topic, len(topic_indexes)))
        kept_wires.append(body)

    print('Kept wires:', len(kept_wires))
    print('-')
    print('Topic mapping:', sorted(topic_indexes.items(), key=lambda x: x[1]))
    print('-')

    # vectorize wires
    tokenizer = Tokenizer()
    tokenizer.fit_on_texts(kept_wires)
    X = tokenizer.texts_to_sequences(kept_wires)

    print('Sanity check:')
    for w in ["banana", "oil", "chocolate", "the", "dsft"]:
        print('...index of', w, ':', tokenizer.word_index.get(w))
    if X:
        print('text reconstruction:')
        reverse_word_index = {v: k for k, v in tokenizer.word_index.items()}
        # Show wire 10 when it exists; fall back to the last wire on tiny inputs.
        sample = X[min(10, len(X) - 1)]
        print(' '.join(reverse_word_index[i] for i in sample))

    print(collections.Counter(labels))
    print("number of labels:", len(labels), "\n")
    print("number of data points:", len(X), "\n")
    for i, (seq, label) in enumerate(zip(X[:4], labels[:4])):
        print("data point ", i, ":", seq, label)
    print('-')

    return X, labels
      
# Build the dataset when run at top level; in this notebook kernel the guard is
# true, which is why the topic counts and mapping are printed below.
if __name__ == "__main__":
    make_reuters_dataset()

earn: 3972
acq: 2423
money-fx: 682
crude: 543
grain: 537
trade: 473
interest: 339
ship: 209
money-supply: 177
sugar: 154
gnp: 127
coffee: 126
gold: 123
veg-oil: 94
cpi: 86
oilseed: 81
cocoa: 67
copper: 62
reserves: 62
bop: 60
livestock: 58
ipi: 57
jobs: 57
alum: 53
iron-steel: 52
nat-gas: 51
dlr: 46
rubber: 42
gas: 38
tin: 32
carcass: 29
pet-chem: 29
cotton: 28
wpi: 27
retail: 23
wheat: 22
meal-feed: 22
orange: 22
zinc: 21
housing: 19
strategic-metal: 19
lead: 19
hog: 17
heat: 16
lei: 16
silver: 16
lumber: 13
fuel: 13
income: 13
tea: 9
corn: 9
instal-debt: 7
yen: 6
soybean: 5
potato: 5
nickel: 5
stg: 5
cpu: 4
jet: 4
platinum: 4
propane: 3
rice: 3
inventories: 3
palm-oil: 3
groundnut: 3
l-cattle: 2
coconut: 2
rapeseed: 2
plywood: 2
f-cattle: 2
fishmeal: 1
tapioca: 1
rand: 1
saudriyal: 1
nzdlr: 1
wool: 1
austdlr: 1
soy-meal: 1
barley: 1
cruzado: 1
hk: 1
naphtha: 1
-
Kept topics: 46
Kept wires: 11228
-
Topic mapping: [('cocoa', 0), ('grain', 1), ('veg-oil', 2), ('earn', 3), ('acq', 4), ('

In [9]:
import os, sys
import keras
import numpy as np
import collections
import statistics

import tensorflow as tf
import tensorflow.keras as keras

In [10]:
(X_train, y_train), (X_test, y_test) = keras.datasets.reuters.load_data()

# Number of train inputs and test labels (one entry per newswire).
for shape in (X_train.shape, y_test.shape):
    print(shape)
# Longest wire (in word indices) in each split.
for split in (X_train, X_test):
    print(max(map(len, split)))

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/reuters.npz
(8982,)
(2246,)
2376
1032


In [11]:
# Peek at the first training example: a list of word indices plus its integer label.
print("Raw training entry No 0: {}".format(X_train[0]))
print()  # bare `print` is a no-op in Python 3; call it to emit the blank line
print("Raw training label No 0: '{}'".format(y_train[0]))

Raw training entry No 0: [1, 27595, 28842, 8, 43, 10, 447, 5, 25, 207, 270, 5, 3095, 111, 16, 369, 186, 90, 67, 7, 89, 5, 19, 102, 6, 19, 124, 15, 90, 67, 84, 22, 482, 26, 7, 48, 4, 49, 8, 864, 39, 209, 154, 6, 151, 6, 83, 11, 15, 22, 155, 11, 15, 7, 48, 9, 4579, 1005, 504, 6, 258, 6, 272, 11, 15, 22, 134, 44, 11, 15, 16, 8, 197, 1245, 90, 67, 52, 29, 209, 30, 32, 132, 6, 109, 15, 17, 12]
Raw training label No 0: '3'


In [12]:
raw_word_index = keras.datasets.reuters.get_word_index()
# Despite its name, `word_index` maps index -> word (later cells rely on the name).
# load_data() reserves 0-2 for padding / start / out-of-vocabulary and shifts
# every word rank by 3 (its default index_from), hence the +3.
word_index = {v+3:k for k,v in raw_word_index.items()}
word_index[0] = '-PAD-'
word_index[1] = '-START-'
word_index[2] = '-UNK-'

# Reconstruct train data entry as string
entry = 202
print("Newswire category: {}".format(y_train[entry]))
# Default must be a string: an int default would make str.join raise TypeError.
print(" ".join(word_index.get(w, '-UNK-') for w in X_train[entry]))

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/reuters_word_index.json
Newswire category: 21
-START- high labour costs and slower corporate investment could hinder sweden's economic growth after 1987 the organisation for economic cooperation and development oecd said the swedish economy grew at a slower rate in 1986 than in previous years gdp rose about 1 7 pct in 1986 compared with 2 2 pct in 1985 but this growth depended largely on external factors particularly lower oil prices the oecd secretariat said in its latest annual report on sweden it warned that labour costs had risen more rapidly in sweden than in other oecd countries because of high labour costs swedish industry which largely relies upon export markets was losing market share wages in the manufacturing sector grew by seven pct in 1986 in line with 1985 increases while public sector wages rose an estimated 9 2 pct in 1986 up from six pct in 1985 this was significantly higher than average 

In [13]:
category = 5

# (label, train index) pairs; kept as a global in case later cells use it.
train_elabels = [(c, i) for i,c in enumerate(y_train)]
# Train indices of every wire labelled `category`.
cat = [i for c, i in train_elabels if c == category]
print("Number of entries for category {}: {}".format(category, len(cat)))
print()
for c in cat:
    # String default: an int default would make str.join raise TypeError.
    print(" ".join(word_index.get(w, '-UNK-') for w in X_train[c]))
    print()

Number of entries for category 5: 17

-START- the european commission's decision to release an additional 300 000 tonnes of british intervention feed wheat for the home market will provide only moderate relief in an increasingly tight market traders said some operators had been anticipating a larger tonnage pointing out that at this week's u k intervention tender the market sought to buy 340 000 tonnes but only 126 000 tonnes were granted the new tranche of intervention grain is unlikely to satisfy demand they said and keen buying competition for supplies in stores is expected to keep prices firm the release of the feed wheat followed recent strong representations by the u k grain trade to the commission there has been growing concern that rising internal prices triggered by heavy exports were creating areas of shortage in interior markets the latest ec authorisation will add 70 000 tonnes at the april 14 tender and a further 30 000 tonnes later in the month the remaining 200 000 tonne

In [14]:
# Topic name for each label id, in the order of make_reuters_dataset's
# "Topic mapping" output.
mapping = ['cocoa','grain','veg-oil','earn','acq','wheat','copper','housing','money-supply',
           'coffee','sugar','trade','reserves','ship','cotton','carcass','crude','nat-gas',
           'cpi','money-fx','interest','gnp','meal-feed','alum','oilseed','gold','tin',
           'strategic-metal','livestock','retail','ipi','iron-steel','rubber','heat','jobs',
           'lei','bop','zinc','orange','pet-chem','dlr','gas','silver','wpi','hog','lead']
num_classes = len(mapping)
assert num_classes == int(max(y_train)) + 1, "mapping must name every label id"

train_count = collections.Counter(y_train)
test_count = collections.Counter(y_test)
# Mean wire length (word indices) per class in the train set; NaN for a class
# with no training wires instead of a StatisticsError.
total_words = []
for i in range(num_classes):
    lengths = [len(e) for e in X_train[y_train.flatten() == i]]
    total_words.append(statistics.mean(lengths) if lengths else float('nan'))

print("{:5s} {:20s} {:5s} {:5s}  {:7s}".format(" "    ," "         , "Nr of", "docs", "Mean nr of words"))
print("{:5s} {:20s} {:5s}  {:5s} {:7s}".format("Index","Class name", "train", "test", "in train set"))
for i in range(num_classes):
    print("{:5d} {:20s} {:5d} {:5d}   {:6.2f}".format(i,mapping[i], train_count[i], test_count[i], total_words[i]))

                           Nr of docs   Mean nr of words
Index Class name           train  test  in train set
    0 cocoa                   55    12   225.78
    1 grain                  432   105   188.67
    2 veg-oil                 74    20   184.86
    3 earn                  3159   813    87.67
    4 acq                   1949   474   135.83
    5 wheat                   17     5   213.35
    6 copper                  48    14   154.46
    7 housing                 16     3   180.38
    8 money-supply           139    38   191.48
    9 coffee                 101    25   225.87
   10 sugar                  124    30   184.73
   11 trade                  390    83   253.80
   12 reserves                49    13   186.92
   13 ship                   172    37   164.66
   14 cotton                  26     2   142.69
   15 carcass                 20     9   170.45
   16 crude                  444    99   219.79
   17 nat-gas                 39    12   149.82
   18 cpi                 