# News Classification

In this notebook we're training a multiclass text classifier. 

## Toy Example

In [1]:
import pandas as pd

from simpletransformers.classification import ClassificationModel

In [3]:
train_data = [
    ["Pizza and pasta are Italian food", 0],
    ["Before start cooking find a good recipe", 0],
    ["Cooking is one of my hobbies", 0],
    ["I like football", 1],
    ["I hate tennis", 1],
    ["This year the Olympic Games are held in Tokyo", 1],
    ["Natural Language Processing deals with talking machines", 2],
    ["Textual entailment and semantic similarity are NLP tasks", 2],
    ["NLU stands for natural language understanding", 2],
]

train_df = pd.DataFrame(train_data, columns=["text", "labels"])

eval_data = [
    ["NLU stands for natural language understanding", 2],
    ["I hate tennis", 1],
    ["Cooking is one of my hobbies", 0],
]

eval_df = pd.DataFrame(eval_data, columns=["text", "labels"])

In [4]:
# configuration
args = {
    "output_dir": "outputs/",
    "cache_dir": "cache_dir/",
    "fp16": False,
    "fp16_opt_level": "O1",
    "max_seq_length": 128,
    "train_batch_size": 32,
    "gradient_accumulation_steps": 1,
    "eval_batch_size": 8,
    "num_train_epochs": 10,
    "weight_decay": 0,
    "learning_rate": 4e-5,
    "adam_epsilon": 1e-8,
    "warmup_ratio": 0.06,
    "warmup_steps": 0,
    "max_grad_norm": 1.0,
    "logging_steps": 50,
    "save_steps": 2000,
    "overwrite_output_dir": True,
    "reprocess_input_data": False,
    "evaluate_during_training": False,
    # "process_count": cpu_count() - 2 if cpu_count() > 2 else 1,
    "n_gpu": 1,
    "wandb_project": "test-master",
}

In [5]:
# Create a ClassificationModel
model = ClassificationModel(
    "bert", "bert-base-cased", num_labels=3, args=args
)

# Train the model
model.train_model(train_df)

Features loaded from cache at cache_dir/cached_train_bert_128_3_9


HBox(children=(FloatProgress(value=0.0, description='Epoch', max=10.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Current iteration', max=1.0, style=ProgressStyle(descript…

Running loss: 1.193155


HBox(children=(FloatProgress(value=0.0, description='Current iteration', max=1.0, style=ProgressStyle(descript…

Running loss: 1.197329


HBox(children=(FloatProgress(value=0.0, description='Current iteration', max=1.0, style=ProgressStyle(descript…

Running loss: 1.075137


HBox(children=(FloatProgress(value=0.0, description='Current iteration', max=1.0, style=ProgressStyle(descript…

Running loss: 0.974361


HBox(children=(FloatProgress(value=0.0, description='Current iteration', max=1.0, style=ProgressStyle(descript…

Running loss: 0.965430


HBox(children=(FloatProgress(value=0.0, description='Current iteration', max=1.0, style=ProgressStyle(descript…

Running loss: 0.973906


HBox(children=(FloatProgress(value=0.0, description='Current iteration', max=1.0, style=ProgressStyle(descript…

Running loss: 0.960047


HBox(children=(FloatProgress(value=0.0, description='Current iteration', max=1.0, style=ProgressStyle(descript…

Running loss: 0.891930


HBox(children=(FloatProgress(value=0.0, description='Current iteration', max=1.0, style=ProgressStyle(descript…

Running loss: 0.888590


HBox(children=(FloatProgress(value=0.0, description='Current iteration', max=1.0, style=ProgressStyle(descript…

Running loss: 0.856310

Training of bert model complete. Saved to outputs/.


In [None]:
# Evaluate the model
result, model_outputs, wrong_predictions = model.eval_model(eval_df)
print(result)

In [8]:
predictions, raw_outputs = model.predict(["This class is about natural language"])
print(predictions)

Converting to features started. Cache is not used.


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


[2]


## Training a Text Classifier for News


In [9]:
train = pd.read_csv("../datasets/agnews/train.csv", header=None)
train.columns = "labels text paragraph".split()
train.head(10)

Unnamed: 0,labels,text,paragraph
0,3,Wall St. Bears Claw Back Into the Black (Reuters),"Reuters - Short-sellers, Wall Street's dwindli..."
1,3,Carlyle Looks Toward Commercial Aerospace (Reu...,Reuters - Private investment firm Carlyle Grou...
2,3,Oil and Economy Cloud Stocks' Outlook (Reuters),Reuters - Soaring crude prices plus worries\ab...
3,3,Iraq Halts Oil Exports from Main Southern Pipe...,Reuters - Authorities have halted oil export\f...
4,3,"Oil prices soar to all-time record, posing new...","AFP - Tearaway world oil prices, toppling reco..."
5,3,"Stocks End Up, But Near Year Lows (Reuters)",Reuters - Stocks ended slightly higher on Frid...
6,3,Money Funds Fell in Latest Week (AP),AP - Assets of the nation's retail money marke...
7,3,Fed minutes show dissent over inflation (USATO...,USATODAY.com - Retail sales bounced back a bit...
8,3,Safety Net (Forbes.com),Forbes.com - After earning a PH.D. in Sociolog...
9,3,Wall St. Bears Claw Back Into the Black,"NEW YORK (Reuters) - Short-sellers, Wall Stre..."


AGNews is a collection of news categorized under 4 distinc categories:

- 1: World
- 2: Sports
- 3: Business
- 4: Sci/Tech

In [12]:
train[train["labels"] == 1].head()

Unnamed: 0,labels,text,paragraph
492,1,Venezuelans Vote Early in Referendum on Chavez...,Reuters - Venezuelans turned out early\and in ...
493,1,S.Koreans Clash with Police on Iraq Troop Disp...,Reuters - South Korean police used water canno...
494,1,Palestinians in Israeli Jails Start Hunger Str...,Reuters - Thousands of Palestinian\prisoners i...
495,1,Seven Georgian soldiers wounded as South Osset...,AFP - Sporadic gunfire and shelling took place...
496,1,Rwandan Troops Arrive in Darfur (AP),AP - Dozens of Rwandan soldiers flew into Suda...


In [13]:
train[train["labels"] == 2].head()

Unnamed: 0,labels,text,paragraph
448,2,"Phelps, Thorpe Advance in 200 Freestyle (AP)",AP - Michael Phelps took care of qualifying fo...
449,2,Reds Knock Padres Out of Wild-Card Lead (AP),AP - Wily Mo Pena homered twice and drove in f...
450,2,"Dreaming done, NBA stars awaken to harsh Olymp...",AFP - National Basketball Association players ...
451,2,"Indians Beat Twins 7-1, Nearing AL Lead (AP)",AP - The Cleveland Indians pulled within one g...
452,2,"Galaxy, Crew Play to 0-0 Tie (AP)",AP - Kevin Hartman made seven saves for Los An...


In [14]:
train[train["labels"] == 3].head()

Unnamed: 0,labels,text,paragraph
0,3,Wall St. Bears Claw Back Into the Black (Reuters),"Reuters - Short-sellers, Wall Street's dwindli..."
1,3,Carlyle Looks Toward Commercial Aerospace (Reu...,Reuters - Private investment firm Carlyle Grou...
2,3,Oil and Economy Cloud Stocks' Outlook (Reuters),Reuters - Soaring crude prices plus worries\ab...
3,3,Iraq Halts Oil Exports from Main Southern Pipe...,Reuters - Authorities have halted oil export\f...
4,3,"Oil prices soar to all-time record, posing new...","AFP - Tearaway world oil prices, toppling reco..."


In [15]:
train[train["labels"] == 4].head()

Unnamed: 0,labels,text,paragraph
78,4,"'Madden,' 'ESPN' Football Score in Different W...",Reuters - Was absenteeism a little high\on Tue...
79,4,Group to Propose New High-Speed Wireless Forma...,Reuters - A group of technology companies\incl...
80,4,AOL to Sell Cheap PCs to Minorities and Senior...,Reuters - America Online on Thursday said it\p...
81,4,Companies Approve New High-Capacity Disc Forma...,Reuters - A group of consumer electronics\make...
82,4,Missing June Deals Slow to Return for Software...,Reuters - The mystery of what went wrong for t...


In [None]:
train.hist(column="labels")

In [10]:
# configuration
args = {
    "output_dir": "outputs/",
    "cache_dir": "cache_dir/",
    "fp16": False,
    "fp16_opt_level": "O1",
    "max_seq_length": 128,
    "train_batch_size": 32,
    "gradient_accumulation_steps": 1,
    "eval_batch_size": 8,
    "num_train_epochs": 10,
    "weight_decay": 0,
    "learning_rate": 4e-5,
    "adam_epsilon": 1e-8,
    "warmup_ratio": 0.06,
    "warmup_steps": 0,
    "max_grad_norm": 1.0,
    "logging_steps": 50,
    "save_steps": 2000,
    "overwrite_output_dir": True,
    "reprocess_input_data": False,
    "evaluate_during_training": False,
    # "process_count": cpu_count() - 2 if cpu_count() > 2 else 1,
    "n_gpu": 1,
    "wandb_project": "nlp-exercises",
}

In [11]:
# Create a ClassificationModel
model = ClassificationModel(
    "bert", "bert-base-cased", num_labels=4, args=args
)

# Train the model
model.train_model(train)

RuntimeError: CUDA error: device-side assert triggered

In [None]:
# load the test set
test = pd.read_csv("../datasets/agnews/test.csv", header=None)
test.columns = "labels headline text".split()

In [None]:
# Evaluate the model
result, model_outputs, wrong_predictions = model.eval_model(test)

In [None]:
model.predict(["Brazil recalls diplomats, officials from Argentina."])

## A Quick Demo

In [None]:
def load_model(
    model_architecture: str,
    directory: str = "outputs/",
    use_cuda: bool = False,
    **kwargs
):
    """Loads a pre-trained model"""
    model = ClassificationModel(
        model_architecture, directory, use_cuda=use_cuda, args=kwargs
    )
    return model

In [None]:
model = load_model("bert")

In [None]:
from IPython.core.magic import register_cell_magic

@register_cell_magic
def classify_news(line, text):
    """Prints predictions of a Text Classifier"""
    predictions, raw_outputs = model.predict([text])
    return predictions[0]

In [None]:
%%classify_news
Venezuelan President has urged families to have six children for the good of the country.