In [62]:
import re
import os
import nltk
import tensorflow as tf
import numpy as np
import pandas as pd
from sklearn.utils import shuffle
from nltk.corpus import reuters
from transformers import pipeline

%load_ext nb_black

The nb_black extension is already loaded. To reload it, use:
  %reload_ext nb_black


<IPython.core.display.Javascript object>

In [14]:
documents = reuters.fileids()
train_docs = list(filter(lambda doc: doc.startswith("train"), documents))
test_docs = list(filter(lambda doc: doc.startswith("test"), documents))
categories = reuters.categories()

<IPython.core.display.Javascript object>

In [15]:
title_re = "([A-Z].*[a-z]$)"


def title_split(doc):
    prev_word = ""
    for word in doc:
        if bool(re.match(title_re, word)):
            title_pos = doc.index(word)
            if prev_word == "A" or word == "USA":
                title_pos = title_pos - 1
            return " ".join(doc[:title_pos])
        prev_word = word


def words_split(doc):
    prev_word = ""
    for word in doc:
        if bool(re.match(title_re, word)):
            title_pos = doc.index(word)
            if prev_word == "A" or word == "USA":
                title_pos = title_pos - 1
            return " ".join(doc[title_pos:])
        prev_word = word

<IPython.core.display.Javascript object>

In [16]:
train_list = [int(x.split("/")[1]) for x in train_docs]
cat_list = [reuters.categories(x) for x in train_docs]
title_list = [title_split(reuters.words(x)) for x in train_docs]
words_list = [words_split(reuters.words(x)) for x in train_docs]

<IPython.core.display.Javascript object>

In [17]:
news_df = pd.DataFrame(
    {
        "doc_id": train_list,
        "categories": cat_list,
        "title": title_list,
        "content": words_list,
    }
)

<IPython.core.display.Javascript object>

In [18]:
def title_split(row):
    split_word = row["content"].split()[0]
    title_new = split_word + row["content"].split(split_word)[1]
    row["title"] = title_new
    row["content"] = title_new
    return row

<IPython.core.display.Javascript object>

In [19]:
title_split_df = news_df[news_df["title"] == ""].apply(title_split, axis=1)
cols = list(news_df.columns)
news_df.loc[news_df.doc_id.isin(title_split_df.doc_id), cols] = title_split_df[cols]

<IPython.core.display.Javascript object>

In [20]:
cat_dict = {
    "acq": "acorn capital investment fund limited",
    "alum": "aluminum",
    "bop": "bottom of the pyramid",
    "castor-oil": "castor oil",
    "carcass": "livestock",
    "cpu": "computershare",
    "coconut-oil": "coconut oil",
    "copra-cake": "copra cake",
    "cpi": "consumer price index",
    "dfl": "deutsche fussball liga",
    "dlr": "digital realty trust",
    "dmk": "dravida munnetra kazhagam",
    "earn": "ellington residential mortgage reit stock",
    "gnp": "grupo nacional provincial",
    "groundnut-oil": "ground nut oil",
    "instal-debt": "instalco intressenter debt",
    "ipi": "intrepid potash",
    "iron-steel": "iron steel",
    "l-cattle": "cattle",
    "lei": "leading economic index",
    "lin-oil": "linseed oil",
    "meal-feed": "meal feed",
    "money-fx": "money foreign exchange",
    "money-supply": "money supply",
    "nat-gas": "natural gas",
    "nkr": "nokia",
    "palm-oil": "palm oil",
    "palmkernel": "palm kernel",
    "pet-chem": "petronas chemicals",
    "rape-oil": "rapeseed oil",
    "soy-meal": "soy meal",
    "soy-oil": "soy oil",
    "strategic-metal": "strategic metal",
    "sun-meal": "sunflower meal",
    "sun-oil": "sunflower oil",
    "sunseed": "sunflower seed",
    "veg-oil": "vagetable oil",
    "wpi": "waterfront philippines",
}

<IPython.core.display.Javascript object>

In [21]:
news_df

Unnamed: 0,doc_id,categories,title,content
0,1,[cocoa],BAHIA COCOA REVIEW,Showers continued throughout the week in the B...
1,10,[acq],COMPUTER TERMINAL SYSTEMS & lt ; CPML > COMPLE...,Computer Terminal Systems Inc said it has comp...
2,100,[money-supply],N . Z . TRADING BANK DEPOSIT GROWTH RISES SLIG...,New Zealand ' s trading bank seasonally adjust...
3,1000,[acq],NATIONAL AMUSEMENTS AGAIN UPS VIACOM & lt ; VI...,Viacom International Inc said & lt ; National ...
4,10000,[earn],ROGERS & lt ; ROG > SEES 1ST QTR NET UP SIGNIF...,Rogers Corp said first quarter earnings will b...
...,...,...,...,...
7764,999,"[interest, money-fx]",U . K . MONEY MARKET SHORTAGE FORECAST REVISED...,The Bank of England said it had revised its fo...
7765,9992,[earn],KNIGHT - RIDDER INC & lt ; KRN > SETS QUARTERLY,Qtly div 25 cts vs 25 cts prior Pay April 13 R...
7766,9993,[earn],TECHNITROL INC & lt ; TNL > SETS QUARTERLY,Qtly div 12 cts vs 12 cts prior Pay April 21 R...
7767,9994,[earn],NATIONWIDE CELLULAR SERVICE INC & lt ; NCEL > ...,"Shr loss six cts vs loss 18 cts Net loss 89 , ..."


<IPython.core.display.Javascript object>

In [23]:
news_df["actual"] = news_df["categories"].apply(lambda x: x[0])
news_df["title_length"] = news_df["title"].apply(lambda x: len(x) if x != None else 0)
news_df["con_length"] = news_df["content"].apply(lambda x: len(x) if x != None else 0)
news_df["actual"] = news_df["actual"].apply(
    lambda x: cat_dict[x] if x in cat_dict.keys() else x
)

<IPython.core.display.Javascript object>

In [41]:
news_df.to_pickle("C:/Users/rparg/Documents/Data/Reuters/news_df.pkl")

<IPython.core.display.Javascript object>

In [24]:
new_cats = [
    "Acorn Capital Investment Fund Limited",
    "aluminum",
    "bottom of the pyramid",
    "barley",
    "castor oil",
    "carcass",
    "cattle",
    "cocoa",
    "coconut",
    "coconut oil",
    "coffee",
    "consumer price index",
    "copper",
    "copra cake",
    "corn",
    "cotton",
    "computershare",
    "crude",
    "deutsche fussball liga",
    "digital realty trust",
    "dravida munnetra kazhagam",
    "ellington residential mortgage reit stock",
    "fuel",
    "gas",
    "gold",
    "grain",
    "groundnut",
    "ground nut oil",
    "grupo nacional provincial",
    "heat",
    "hog",
    "housing",
    "income",
    "instalco intressenter debt",
    "interest",
    "intrepid potash",
    "iron steel",
    "jet",
    "jobs",
    "lead",
    "leading economic index",
    "linseed oil",
    "livestock",
    "lumber",
    "meal feed",
    "money foreign exchange",
    "money supply",
    "natural gas",
    "naphtha",
    "nickel",
    "nokia",
    "new zealand dollar",
    "oat",
    "orange",
    "palladium",
    "palm oil",
    "palm kernel",
    "petronas chemicals",
    "platinum",
    "potato",
    "propane",
    "rand",
    "rapeseed",
    "rapeseed oil",
    "reserves",
    "retail",
    "rice",
    "rubber",
    "rye",
    "ship",
    "silver",
    "sorghum",
    "soybean",
    "soy meal",
    "soy oil",
    "strategic metal",
    "sugar",
    "sunflower meal",
    "sunflower oil",
    "sunflower seed",
    "tea",
    "tin",
    "trade",
    "vagetable oil",
    "waterfront philippines",
    "wheat",
    "yen",
    "zinc",
]

<IPython.core.display.Javascript object>

In [3]:
news_df = pd.read_pickle("C:/Users/rparg/Documents/Data/Reuters/news_df.pkl")

<IPython.core.display.Javascript object>

## load zero-shot classifier

In [25]:
classifier = pipeline("zero-shot-classification", model="roberta-large-mnli")

All model checkpoint layers were used when initializing TFRobertaForSequenceClassification.

All the layers of TFRobertaForSequenceClassification were initialized from the model checkpoint at roberta-large-mnli.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFRobertaForSequenceClassification for predictions without further training.


<IPython.core.display.Javascript object>

## selecting documents from the 16 top-categories with 30+ docs

In [26]:
cat_count_df = news_df.groupby(by="actual").count()
cat_count_df.sort_values(by="doc_id", ascending=False).iloc[:35]

Unnamed: 0_level_0,doc_id,categories,title,content,title_length,con_length
actual,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
ellington residential mortgage reit stock,2843,2843,2693,2693,2843,2843
acorn capital investment fund limited,1650,1650,1494,1494,1650,1650
crude,370,370,331,331,370,370
interest,329,329,285,285,329,329
money foreign exchange,266,266,238,238,266,266
trade,253,253,230,230,253,253
grain,218,218,204,204,218,218
corn,157,157,137,137,157,157
digital realty trust,126,126,109,109,126,126
money supply,125,125,76,76,125,125


<IPython.core.display.Javascript object>

In [51]:
sub_cat_list = [
    "aluminum",
    "barley",
    "cocoa",
    "coffee",
    "copper",
    "corn",
    "cotton",
    "crude",
    "gold",
    "grain",
    "iron steel",
    "livestock",
    "natural gas",
    "palm oil",
    "rubber",
    "sugar",
]

<IPython.core.display.Javascript object>

In [63]:
sub_cat_df = news_df[news_df["actual"].isin(sub_cat_list)]
sub_cat_df = sub_cat_df[
    (sub_cat_df["con_length"] > 200) & (sub_cat_df["title_length"] > 20)
]
sc_sample_df = (
    sub_cat_df.groupby("actual")
    .apply(lambda x: x.sample(20, replace=True))
    .reset_index(drop=True)
)
sc_sample_df = shuffle(
    sc_sample_df.drop(
        columns=["categories", "title_length", "con_length"]
    ).drop_duplicates()
)
sc_sample_df.to_pickle("C:/Users/rparg/Documents/Data/Reuters/sc_sample.pkl")

<IPython.core.display.Javascript object>

# classified based on document content

In [81]:
## classified based on document content

sc_con_list = list(sc_sample_df["content"])
sc_class_list = [
    classifier(content, sub_cat_list, multi_label=True) for content in sc_con_list
]
categorized_df = pd.DataFrame.from_records(sc_class_list)
categorized_df["labels"] = categorized_df["labels"].apply(lambda x: x[0])
categorized_df["scores"] = categorized_df["scores"].apply(lambda x: x[0])
categorized_df = categorized_df.rename(
    {
        "sequence": "content",
        "labels": "con_predicted",
        "scores": "con_score",
    },
    axis=1,
)
con_clfd_df = sc_sample_df.merge(categorized_df, how="inner", on="content")
con_clfd_df.to_pickle("C:/Users/rparg/Documents/Data/Reuters/con_clfd.pkl")


## classified based on document titles

sc_title_list = list(sc_sample_df["title"])
tc_clfd_list = [
    classifier(title, sub_cat_list, multi_label=True) for title in sc_title_list
]
tc_clfd_df = pd.DataFrame.from_records(tc_clfd_list)
tc_clfd_df["labels"] = tc_clfd_df["labels"].apply(lambda x: x[0])
tc_clfd_df["scores"] = tc_clfd_df["scores"].apply(lambda x: x[0])
tc_clfd_df = tc_clfd_df.rename(
    {
        "sequence": "title",
        "labels": "title_predicted",
        "scores": "title_score",
    },
    axis=1,
)
title_clfd_df = sc_sample_df.merge(tc_clfd_df, how="inner", on="title")
title_clfd_df.to_pickle("C:/Users/rparg/Documents/Data/Reuters/title_clfd.pkl")
sc_clfd_df = con_clfd_df.merge(tc_clfd_df, how="inner", on="title")
sc_clfd_df.to_pickle("C:/Users/rparg/Documents/Data/Reuters/sc_clfd.pkl")

<IPython.core.display.Javascript object>

# classified based on document titles

In [82]:
sc_clfd_df

Unnamed: 0,doc_id,title,content,actual,con_predicted,con_score,title_predicted,title_score
0,3862,PERU GUERRILLAS INTERRUPT TRAIN ROUTE TO MINES,Maoist guerrillas using dynamite derailed two ...,copper,copper,0.777180,crude,0.859559
1,12728,COOPER BASIN NATURAL GAS RESERVES UPGRADED,Remaining recoverable gas reserves in the area...,natural gas,natural gas,0.980061,natural gas,0.997715
2,1556,JANUARY CRUDE OIL MOVEMENTS FALL SEVEN MLN TONS,Worldwide spot crude oil movements fell to 30 ...,crude,crude,0.980273,crude,0.970165
3,9521,ARGENTINE SOYBEAN YIELD ESTIMATES DOWN FURTHER,Argentine grain producers again reduced their ...,corn,grain,0.703691,livestock,0.882180
4,11124,VENEZUELA PLANS METALS INVESTMENT FOR 1987 - 89,"The Venezuela Guayana Corporation , CVG , whic...",aluminum,iron steel,0.773195,copper,0.334685
...,...,...,...,...,...,...,...,...
272,4679,"CHINA TRYING TO INCREASE COTTON OUTPUT , PAPER...",China ' s 1987 cotton output must rise above t...,cotton,cotton,0.996337,cotton,0.990517
273,6344,"INDONESIA TO IMPORT PALM OIL , FEARS MAY SHORTAGE",Indonesia has issued licences to traders to im...,palm oil,palm oil,0.917967,palm oil,0.989234
274,12489,40 MINERS TRAPPED BY FIRE IN GASPE COPPER MINE,Some 40 miners were trapped underground today ...,copper,copper,0.996184,crude,0.782214
275,12355,COFFEE FUTURES UNDER DLR A POUND AT SIX - YEAR...,Coffee futures dipped further and closed below...,coffee,livestock,0.823553,coffee,0.991903


<IPython.core.display.Javascript object>