In [1]:
from argparse import Namespace
from collections import Counter
import json
import os
import re
import string
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm_notebook

## Preprocessing ##

The BBC news data is a csv with category label and the news item. Here I load this dataset using pandas, and do a 70, 15, 15 split across the categories into training, validation, and test items. This information is added as a new column, and the data is written out to a new csv with this additional column in it.

The data has been exported from BigQuery and is stored on GCS, and can be access directly via

!gsutil cp gs://dataset-uploader/bbc/bbc-text.csv .

In [2]:
data = pd.read_csv("bbc-text.csv")

In [3]:
news_categories = set(data['category'])

In [4]:
class NewsCategory(object):
    def __init__(self, name, num_train, num_val, num_test, num_total):
        self._name = name
        self._num_train = num_train
        self._num_val = num_val
        self._num_test = num_test
        self._num_total = num_total
        
    @classmethod
    def create_class(cls, name, total_count):
        train = int(0.7 * total_count)
        dev = int(0.15 * total_count)
        test = total_count - train - dev
        return cls(name, train, dev, test, total_count)

In [5]:
new_data = pd.DataFrame()
cat_dict = {}
for category in news_categories:
    dss = data[data['category'] == category].copy(deep=True)
    count = len(dss)
    print("{} has {} entries".format(category, count))
    dss['split'] = ''
    cat_obj = NewsCategory.create_class(category, count)
    cat_dict[category] = cat_obj
    t = cat_obj._num_train
    v = cat_obj._num_train + cat_obj._num_val
    print("{}:{} train, {}:{} val and {}:{} test".format(0, t, t, v, v, count))
    
    dss.split.iloc[:t] = 'train'
    dss.split.iloc[t:v] = 'val'
    dss.split.iloc[v:count] = 'test'
    
    #print("{} with split train".format(len(dss[dss['split'] == 'train'])))
    #print("{} with split val".format(len(dss[dss['split'] == 'val'])))
    #print("{} with split test".format(len(dss[dss['split'] == 'test'])))
    
    new_data = new_data.append(dss)

sport has 511 entries
0:357 train, 357:433 val and 433:511 test
business has 510 entries
0:357 train, 357:433 val and 433:510 test
entertainment has 386 entries
0:270 train, 270:327 val and 327:386 test
politics has 417 entries
0:291 train, 291:353 val and 353:417 test
tech has 401 entries
0:280 train, 280:340 val and 340:401 test


In [6]:
len(new_data[new_data['split'] == 'train'])

1555

In [7]:
new_data.category.value_counts()

sport            511
business         510
politics         417
tech             401
entertainment    386
Name: category, dtype: int64

In [8]:
# Preprocess the news [remove non-aplha, add spaces]
def preprocess_text(text):
    text = text.lower()
    text = re.sub(r"([.,!?])", r" \1 ", text)
    text = re.sub(r"[^a-zA-Z.,!?]+", r" ", text)
    return text

In [9]:
new_data.text = new_data.text.apply(preprocess_text)

In [10]:
new_data = new_data.sample(frac=1)

In [11]:
new_data.head()

Unnamed: 0,category,text,split
1493,entertainment,tautou film tops cesar prize nods french film ...,train
679,sport,davenport puts retirement on hold lindsay dave...,train
1267,business,watchdog probes vivendi bond sale french stock...,train
2109,entertainment,court halts mark morrison album premiership fo...,test
2001,business,china continues rapid growth china s economy h...,test


In [12]:
new_data.to_csv("bbc_news_split.csv", index=False)

## The Vocabulary ##

This will create a numeric representation of our tokens (words).

In [13]:
class Vocabulary(object):
    
    def __init__(self, token_to_idx=None, add_unk=True, unk_token="<UNK>"):
        """
        Args:
            token_to_idx (dict): a pre-existing map of tokens to indices
            add_unk (bool): a flag that indicates whether to add the UNK token
            unk_token (str): the UNK token to add into the Vocabulary
        """

        self._add_unk = add_unk
        self._unk_token = unk_token
            
        if token_to_idx is None:
            token_to_idx = {}
        
        self._token_to_idx = token_to_idx
        
        self._index_to_token = {idx: token
                                for token, idx in self._token_to_idx.items()}
        
        self._unk_index = -1
        if add_unk:
            self._unk_index = self.add_token(unk_token)

    def to_serializable(self):
        """ returns a dictionary that can be serialized """
        return {'token_to_idx': self._token_to_idx, 
                'add_unk': self._add_unk, 
                'unk_token': self._unk_token}

    @classmethod
    def from_serializable(cls, contents):
        """ instantiates the Vocabulary from a serialized dictionary """
        return cls(**contents)


    def add_token(self, token):
        """Update mapping dicts based on the token.

        Args:
            token (str): the item to add into the Vocabulary
        Returns:
            index (int): the integer corresponding to the token
        """
        if token in self._token_to_idx:
            index = self._token_to_idx[token]
        else:
            index = len(self._index_to_token)
            self._index_to_token[index] = token
            self._token_to_idx[token] = index
        return index
    
    def add_many(self, tokens):
        """Add a list of tokens into the Vocabulary
        
        Args:
            tokens (list): a list of string tokens
        Returns:
            indices (list): a list of indices corresponding to the tokens
        """
        return [self.add_token(token) for token in tokens]
    
    def lookup_token(self, token):
        """Retrieve the index associated with the token 
          or the UNK index if token isn't present.
        
        Args:
            token (str): the token to look up 
        Returns:
            index (int): the index corresponding to the token
        Notes:
            `unk_index` needs to be >=0 (having been added into the Vocabulary) 
              for the UNK functionality 
        """
        if self._unk_index >= 0:
            index = self._token_to_idx.get(token, self._unk_index)
        else:
            index = self._token_to_idx[token]
        return index
    
    def lookup_index(self, index):
        """Return the token associated with the index
        
        Args: 
            index (int): the index to look up
        Returns:
            token (str): the token corresponding to the index
        Raises:
            KeyError: if the index is not in the Vocabulary
        """
        if index not in self._index_to_token:
            raise KeyError("the index (%d) is not in the Vocabulary" % index)
        return self._index_to_token[index]
    
    def __str__(self):
        return "<Vocabulary(size=%d)>" % len(self)
    
    def __len__(self):
        return len(self._token_to_idx)

## The Vectorizer ##

This will use the vocabulary to return a 1-hot vector with length equal to the vocabulary size.

In [14]:
class BBCNewsVectorizer(object):
    """ The Vectorizer which coordinates the Vocabularies and puts them to use"""
    def __init__(self, news_vocab, category_vocab):
        """
        Args:
            news_vocab (Vocabulary): maps words to integers
            category_vocab (Vocabulary): maps class labels to integers
        """
        self.news_vocab = news_vocab
        self.category_vocab = category_vocab
    
    def vectorize(self, news_item):
        """Create a collapsed one-hit vector for the news item
        
        Args:
            review (str): the news item 
        Returns:
            one_hot (np.ndarray): the collapsed one-hot encoding 
        """
        one_hot = np.zeros(len(self.news_vocab))
        for token in news_item.split(" "):
            if token not in string.punctuation:
                one_hot[self.news_vocab.lookup_token(token)] = 1
                
        return one_hot
    
    @classmethod
    def from_dataframe(cls, news_df, cutoff=10):
        """Instantiate the vectorizer from the dataset dataframe
        
        Args:
            news_df (pandas.DataFrame): the news dataset
            cutoff (int): the parameter for frequency-based filtering
        Returns:
            an instance of the BBCNewsVectorizer
        """
        
        news_vocab = Vocabulary(add_unk=True)
        category_vocab = Vocabulary(add_unk=False)
        
        # add the categories
        for category in sorted(set(news_df.category)):
            category_vocab.add_token(category)
            
        # add words that occur more than a certain times
        word_counts = Counter()
        for news_item in news_df.text:
            for word in news_item.split(" "):
                if word not in string.punctuation:
                    word_counts[word] += 1
                    
        for word, count in word_counts.items():            
            if count >= cutoff:
                news_vocab.add_token(word)
                
        return cls(news_vocab, category_vocab)
    
    @classmethod
    def from_serializable(cls, contents):
        """Instantiate a BBCNewsVectorizer from a serializable dictionary
        
        Args:
            contents (dict): the serializable dictionary
        Returns:
            an instance of the BBCNewsVectorizer class
        """
        news_vocab = Vocabulary.from_serializable(contents['news_vocab'])
        category_vocab = Vocabulary.from_serializable(contents['category_vocab'])
        
        return cls(news_vocab, category_vocab)
    
    def to_serializable(self):
        """Create the serializable dictionary for caching
        
        Returns:
            contents (dict): the serializable dictionary
        """
        return {'news_vocab': self.news_vocab.to_serializable(),
                'category_vocab': self.category_vocab.to_serializable()}

## The Dataset ##

I use the Vocabulary and the Vectorizer to build the dataset.

In [15]:
class BBCNewsDataset(Dataset):
    def __init__(self, news_df, vectorizer):
        """
        Args:
            news_df (pandas.DataFrame): the dataset
            vectorizer (BBCNewsVectorizer): vectorizer instantiated from dataset
        """
        self._news_df = news_df
        self._vectorizer = vectorizer
        
        self.train_df = self._news_df[self._news_df.split == 'train']
        self.train_size = len(self.train_df)
        
        self.val_df = self._news_df[self._news_df.split == 'val']
        self.val_size = len(self.val_df)
        
        self.test_df = self._news_df[self._news_df.split == 'test']
        self.test_len = len(self.test_df)
        
        self._lookup_dict = {'train': (self.train_df, self.train_size),
                             'val': (self.val_df, self.val_size),
                             'test': (self.test_df, self.test_len)}
        
        self.set_split('train')
        
    @classmethod
    def load_dataset_and_make_vectorizer(cls, news_csv):
        """Load dataset and make a new vectorizer from scratch
        
        Args:
            news_csv (str): location of the dataset
        Returns:
            an instance of BBCNewsDataset
        """
        news_df = pd.read_csv(news_csv)
        # only build the vectorizer using the training data (???)
        train_news_df = news_df[news_df.split == 'train']
        vectorizer = BBCNewsVectorizer.from_dataframe(train_news_df, cutoff=10)
        
        return cls(news_df, vectorizer)
    
    @classmethod
    def load_dataset_and_load_vectorizer(cls, news_csv, vectorizer_filepath):
        """Load dataset and the corresponding vectorizer. 
        Used in the case in the vectorizer has been cached for re-use
        
        Args:
            review_csv (str): location of the dataset
            vectorizer_filepath (str): location of the saved vectorizer
        Returns:
            an instance of BBCNewsDataset
        """
        news_df = pd.read_csv(news_csv)
        vectorizer = cls.load_vectorizer_only(vectorizer_filepath)
        
        return cls(news_df, vectorizer)

    @staticmethod
    def load_vectorizer_only(self, vectorizer_filepath):
        """a static method for loading the vectorizer from file
        
        Args:
            vectorizer_filepath (str): the location of the serialized vectorizer
        Returns:
            an instance of BBCNewsVectorizer
        """
        with open(vectorizer_filepath) as fp:
            return BBCNewsVectorizer.from_serializable(json.load(fp))
        
    def save_vectorizer(self, vectorizer_filepath):
        """saves the vectorizer to disk using json
        
        Args:
            vectorizer_filepath (str): the location to save the vectorizer
        """
        with open(vectorizer_filepath, 'w') as fp:
            json.dump(self._vectorizer.to_serializable(), fp)
            
    def get_vectorizer(self):
        """ returns the vectorizer """
        return self._vectorizer

    def set_split(self, split="train"):
        """ selects the splits in the dataset using a column in the dataframe 
        
        Args:
            split (str): one of "train", "val", or "test"
        """
        self._target_split = split
        self._target_df, self._target_size = self._lookup_dict[split]
        
    def __len__(self):
        return self._target_size

    def __getitem__(self, index):
        """the primary entry point method for PyTorch datasets
        
        Args:
            index (int): the index to the data point 
        Returns:
            a dictionary holding the data point's features (x_data) and label (y_target)
        """
        row = self._target_df.iloc[index]
        
        vector = self._vectorizer.vectorize(row.text)
        
        category = self._vectorizer.category_vocab.lookup_token(row.category)
        
        return {'x_data': vector,
                'y_target': category}
    
    def get_num_batches(self, batch_size):
        """Given a batch size, return the number of batches in the dataset
        
        Args:
            batch_size (int)
        Returns:
            number of batches in the dataset
        """
        return len(self) // batch_size  

def generate_batches(dataset, batch_size, shuffle=True,
                     drop_last=True, device="cpu"):
    """
    A generator function which wraps the PyTorch DataLoader. It will 
      ensure each tensor is on the right device location.
    """
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size,
                            shuffle=shuffle, drop_last=drop_last)

    for data_dict in dataloader:
        out_data_dict = {}
        for name, _ in data_dict.items():
            out_data_dict[name] = data_dict[name].to(device)
        yield out_data_dict

## The Model: BBCNewsClassifier ##



In [16]:
class BBCNewsClassifier(nn.Module):
    """ a simple perceptron based classifier """
    def __init__(self, num_features):
        """
        Args:
            num_features (int): the size of the input feature vector
        """
        super(BBCNewsClassifier, self).__init__()
        self.fc1 = nn.Linear(in_features=num_features, out_features=5)
        
    def forward(self, x_in, apply_sigmoid=False):
        """The forward pass of the classifier
        
        Args:
            x_in (torch.Tensor): an input data tensor. 
                x_in.shape should be (batch, num_features)
            apply_sigmoid (bool): a flag for the sigmoid activation
                should be false if used with the Cross Entropy losses
        Returns:
            the resulting tensor. tensor.shape should be (batch,)
        """
        y_out = self.fc1(x_in).squeeze()
        if apply_sigmoid:
            y_out = torch.sigmoid(y_out)
        
        return y_out

## Training Routine ##

### Helper Functions ###

In [17]:
def make_train_state(args):
    return {'stop_early': False,
            'early_stopping_step': 0,
            'early_stopping_best_val': 1e8,
            'learning_rate': args.learning_rate,
            'epoch_index': 0,
            'train_loss': [],
            'train_acc': [],
            'val_loss': [],
            'val_acc': [],
            'test_loss': -1,
            'test_acc': -1,
            'model_filename': args.model_state_file}

def update_train_state(args, model, train_state):
    """Handle the training state updates.

    Components:
     - Early Stopping: Prevent overfitting.
     - Model Checkpoint: Model is saved if the model is better

    :param args: main arguments
    :param model: model to train
    :param train_state: a dictionary representing the training state values
    :returns:
        a new train_state
    """

    # Save one model at least
    if train_state['epoch_index'] == 0:
        torch.save(model.state_dict(), train_state['model_filename'])
        train_state['stop_early'] = False

    # Save model if performance improved
    elif train_state['epoch_index'] >= 1:
        loss_tm1, loss_t = train_state['val_loss'][-2:]

        # If loss worsened
        if loss_t >= train_state['early_stopping_best_val']:
            # Update step
            train_state['early_stopping_step'] += 1
        # Loss decreased
        else:
            # Save the best model
            if loss_t < train_state['early_stopping_best_val']:
                torch.save(model.state_dict(), train_state['model_filename'])

            # Reset early stopping step
            train_state['early_stopping_step'] = 0

        # Stop early ?
        train_state['stop_early'] = \
            train_state['early_stopping_step'] >= args.early_stopping_criteria

    return train_state

def compute_accuracy(y_pred, y_target):
    #print(y_pred)
    #print(y_target)
    #print(torch.argmax(y_pred).cpu())
    y_target = y_target.cpu()
    y_pred_indices = (torch.argmax(y_pred, dim=1)).cpu().long()#.max(dim=1)[1]
    n_correct = torch.eq(y_pred_indices, y_target).sum().item()
    return n_correct / len(y_pred_indices) * 100

### General Utilities ###

In [18]:
def set_seed_everywhere(seed, cuda):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed_all(seed)

def handle_dirs(dirpath):
    if not os.path.exists(dirpath):
        os.makedirs(dirpath)

### Settings and some prep work ###

In [19]:
args = Namespace(
    # Data and Path information
    frequency_cutoff=10,
    model_state_file='model.pth',
    news_csv='bbc_news_split.csv',
    save_dir='.',
    vectorizer_file='vectorizer.json',
    # No Model hyper parameters
    # Training hyper parameters
    batch_size=32,
    early_stopping_criteria=5,
    learning_rate=0.001,
    num_epochs=100,
    seed=1337,
    # Runtime options
    catch_keyboard_interrupt=True,
    cuda=False,
    expand_filepaths_to_save_dir=True,
    reload_from_files=False,
)

if args.expand_filepaths_to_save_dir:
    args.vectorizer_file = os.path.join(args.save_dir,
                                        args.vectorizer_file)

    args.model_state_file = os.path.join(args.save_dir,
                                         args.model_state_file)
    
    print("Expanded filepaths: ")
    print("\t{}".format(args.vectorizer_file))
    print("\t{}".format(args.model_state_file))
    
# Check CUDA
if not torch.cuda.is_available():
    args.cuda = False

print("Using CUDA: {}".format(args.cuda))

args.device = torch.device("cuda" if args.cuda else "cpu")

# Set seed for reproducibility
set_seed_everywhere(args.seed, args.cuda)

# handle dirs
handle_dirs(args.save_dir)

Expanded filepaths: 
	./vectorizer.json
	./model.pth
Using CUDA: False


## Initializations ##

In [20]:
if args.reload_from_files:
    # training from a checkpoint
    print("Loading dataset and vectorizer")
    dataset = BBCNewsDataset.load_dataset_and_load_vectorizer(args.news_csv,
                                                              args.vectorizer_file)
else:
    print("Loading dataset and creating vectorizer")
    # create dataset and vectorizer
    dataset = BBCNewsDataset.load_dataset_and_make_vectorizer(args.news_csv)
    dataset.save_vectorizer(args.vectorizer_file)    

vectorizer = dataset.get_vectorizer()

classifier = BBCNewsClassifier(len(vectorizer.news_vocab))

Loading dataset and creating vectorizer


In [21]:
classifier

BBCNewsClassifier(
  (fc1): Linear(in_features=5504, out_features=5, bias=True)
)

In [22]:
vectorizer

<__main__.BBCNewsVectorizer at 0x11d85e0b8>

## Training Loop ##

In [23]:
# switch to cpu or cuda device
classifier = classifier.to(args.device)

# instantiate loss function
loss_fn = nn.CrossEntropyLoss()

# instantiatiate optimizer
optimizer = optim.Adam(classifier.parameters(), lr=args.learning_rate)

# instantiatiate scheduler for the Adam optimizer
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                                 mode='min', factor=0.5,
                                                 patience=1)

train_state = make_train_state(args)

epoch_bar = tqdm_notebook(desc='training routine', 
                          total=args.num_epochs,
                          position=0)

dataset.set_split('train')
train_bar = tqdm_notebook(desc='split=train',
                          total=dataset.get_num_batches(args.batch_size), 
                          position=1, 
                          leave=True)
dataset.set_split('val')
val_bar = tqdm_notebook(desc='split=val',
                        total=dataset.get_num_batches(args.batch_size), 
                        position=1, 
                        leave=True)

try:
    for epoch_index in range(args.num_epochs):
        train_state['epoch_index'] = epoch_index
        
        # first train on the training split of the data set
        dataset.set_split('train')
        
        batch_generator = generate_batches(dataset=dataset,
                                           batch_size=args.batch_size,
                                           device=args.device)
        running_loss = 0.0
        running_acc  = 0.0
        
        # set classifier to training mode
        classifier.train()
        
        for batch_index, batch_dict in enumerate(batch_generator):
            #
            # there are 5 steps in training
            #
            
            # step 1: initialize the gradiants to 0
            optimizer.zero_grad()
            
            # step 2: compute the output
            y_pred = classifier(x_in=batch_dict['x_data'].float())
            
            # step 3: compute the loss
            loss = loss_fn(y_pred, batch_dict['y_target'].long())
            loss_t = loss.item()
            running_loss += (loss_t - running_loss) / (batch_index + 1)
            
            # step 4: backprop the gradient
            loss.backward()
            
            # step 5: use the propagated gradient to adjust parameters
            optimizer.step()
            
            # -----------------------------------------
            # compute the accuracy
            acc_t = compute_accuracy(y_pred, batch_dict['y_target'])
            running_acc += (acc_t - running_acc) / (batch_index + 1)

            # update bar
            train_bar.set_postfix(loss=running_loss, 
                                  acc=running_acc, 
                                  epoch=epoch_index)
            train_bar.update()

        train_state['train_loss'].append(running_loss)
        train_state['train_acc'].append(running_acc)

        # now evaluate on the validation dataset
        dataset.set_split('val')
        
        batch_generator = generate_batches(dataset=dataset,
                                           batch_size=args.batch_size,
                                           device=args.device)
        running_loss = 0.0
        running_acc  = 0.0
        
        # set classifier to training mode
        classifier.eval()

        for batch_index, batch_dict in enumerate(batch_generator):
            
            # step 1: compute the output
            y_pred = classifier(x_in=batch_dict['x_data'].float())
            
            # step 2: compute the loss
            loss = loss_fn(y_pred, batch_dict['y_target'].long())
            loss_t = loss.item()
            running_loss += (loss_t - running_loss) / (batch_index + 1)

            # compute the accuracy
            acc_t = compute_accuracy(y_pred, batch_dict['y_target'])
            running_acc += (acc_t - running_acc) / (batch_index + 1)

            # update bar
            val_bar.set_postfix(loss=running_loss, 
                                acc=running_acc, 
                                epoch=epoch_index)
            val_bar.update()
            
        train_state['val_loss'].append(running_loss)
        train_state['val_acc'].append(running_acc)

        train_state = update_train_state(args=args, model=classifier,
                                         train_state=train_state)

        scheduler.step(train_state['val_loss'][-1])

        train_bar.n = 0
        val_bar.n = 0
        epoch_bar.update()

        if train_state['stop_early']:
            break

        train_bar.n = 0
        val_bar.n = 0
        epoch_bar.update()

except KeyboardInterrupt:
    print("Exiting loop")

HBox(children=(IntProgress(value=0, description='training routine', style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='split=train', max=48, style=ProgressStyle(description_width='…

HBox(children=(IntProgress(value=0, description='split=val', max=10, style=ProgressStyle(description_width='in…

## Inference ##

In [24]:
# compute the loss & accuracy on the test set using the best available model

classifier.load_state_dict(torch.load(train_state['model_filename']))
classifier = classifier.to(args.device)

dataset.set_split('test')
batch_generator = generate_batches(dataset, 
                                   batch_size=args.batch_size, 
                                   device=args.device)
running_loss = 0.0
running_acc = 0.0

classifier.eval()

for batch_index, batch_dict in enumerate(batch_generator):
    
        # step 1: compute the output
        y_pred = classifier(x_in=batch_dict['x_data'].float())
            
        # step 2: compute the loss
        loss = loss_fn(y_pred, batch_dict['y_target'].long())
        loss_t = loss.item()
        running_loss += (loss_t - running_loss) / (batch_index + 1)

        # compute the accuracy
        acc_t = compute_accuracy(y_pred, batch_dict['y_target'])
        running_acc += (acc_t - running_acc) / (batch_index + 1)

train_state['test_loss'] = running_loss
train_state['test_acc'] = running_acc

In [25]:
print("Test loss: {:.3f}".format(train_state['test_loss']))
print("Test Accuracy: {:.2f}".format(train_state['test_acc']))

Test loss: 0.109
Test Accuracy: 96.56


In [39]:
def predict_category(news, classifier, vectorizer):
    """Predict the rating of a review
    
    Args:
        news (str): the text of the review
        classifier (BBCNewsClassifier): the trained model
        vectorizer (BBCNewsVectorizer): the corresponding vectorizer
    """
    news_pp = preprocess_text(news)
    news_vector = torch.tensor(vectorizer.vectorize(news))
    
    cat_npred = classifier(x_in=news_vector.float())
    cat_pred = torch.argmax(cat_npred)
    
    return vectorizer.category_vocab.lookup_index(int(cat_pred))

In [40]:
test_news = "today the political party a national emergency as a stunt"

classifier = classifier.cpu()
prediction = predict_category(test_news, classifier, vectorizer)
print("{} -> {}".format(test_news, prediction))

today the political party a national emergency as a stunt -> politics


In [32]:
classifier.fc1.weight.shape

torch.Size([5, 5504])

In [37]:
# Sort weights
for i in range(5):
    fc1_weights = classifier.fc1.weight.detach()[i]
    _, indices = torch.sort(fc1_weights, dim=0, descending=True)
    indices = indices.numpy().tolist()
    topic = vectorizer.category_vocab.lookup_index(i)
    
    # Top 20 words
    print("Most influential words in {} News:".format(topic))
    print("--------------------------------------")
    for i in range(20):
        print(vectorizer.news_vocab.lookup_index(indices[i]))
    
    print("====\n\n\n")

    # Top 20 negative words
    print("Least influential words in {} News:".format(topic))
    print("--------------------------------------")
    indices.reverse()
    for i in range(20):
        print(vectorizer.news_vocab.lookup_index(indices[i]))
        
    print("====\n\n\n")

Most influential words in business News:
--------------------------------------
firm
bank
company
state
financial
investment
economy
bn
banks
reuters
stock
firms
fall
shares
securities
giant
however
profits
higher
chain
====



Least influential words in business News:
--------------------------------------
britain
digital
straw
film
jack
commons
event
committee
form
party
music
scientists
movies
she
spam
carried
athlete
chart
users
entertainment
====



Most influential words in entertainment News:
--------------------------------------
singer
show
tv
film
films
band
star
among
chart
festival
television
movie
stars
series
music
producers
including
hollywood
pop
actor
====



Least influential words in entertainment News:
--------------------------------------
games
stores
computer
round
brand
firm
operating
offer
sport
olympic
election
legislation
game
gamers
union
championships
xbox
players
mobile
midfielder
====



Most influential words in politics News:
---------------------------