# Debugging 20Newsgroups
- **Task**: Classify "Christianity" vs "Atheism" documents from the 20 Newsgroups dataset.
- **Problem**: The 20 Newsgroups dataset contains many artifacts: tokens (e.g., person names, punctuation marks) that are irrelevant to the task but strongly co-occur with one of the classes. For evaluation, we therefore also use the Religion dataset by [Ribeiro et al. (2016)](https://arxiv.org/pdf/1602.04938.pdf), which contains "Christianity" and "Atheism" web pages, as a target dataset.
- **Solution**: We use our framework to identify the CNN features that detect irrelevant words (words that do not capture the meaning of Christianity/Atheism and do not generalize to the Religion dataset) and disable those features.

In [1]:
# Notebook setup
import pickle
import os
import datetime
import random
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [14, 7]
os.environ['PYTHONHASHSEED'] = '0'

# Set random seeds to create reproducible results
the_seed = 1234
np.random.seed(the_seed)
random.seed(the_seed)
session_conf = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
from keras import backend as K
tf.set_random_seed(the_seed)
sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
K.set_session(sess)

Using TensorFlow backend.


In [2]:
import find

## Settings

- GloVe word embeddings: Please replace the string in the second line of the next cell with the path to your GloVe embeddings file, which can be downloaded [here](http://nlp.stanford.edu/data/glove.6B.zip)

In [3]:
EMBEDDING_DIM = 300
EMBEDDING_PATH = f"../../CNNAnalysis/data/glove/glove.6B.{EMBEDDING_DIM}d.txt" # Path to your glove embeddings

- Dataset

In [4]:
DATA_PATH = 'preprocessed_data/'
MAIN_DATASET = '20Newsgroups'
SECOND_DATASET = 'Religion'
THIRD_DATASET = None
GENDER_BIAS = False

- Model

In [5]:
MODEL_PATH = 'trained_models/' # Path to save your trained models
MODEL_ARCH = 'CNN'
MAXLEN = 150
FILTERS = [(10, 2), (10, 3), (10, 4)] # Ten filters of each window size [2,3,4]
BATCH_SIZE = 128
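
The `FILTERS` setting determines both the number of features inspected later and the layer shapes printed in the model summary. A minimal sketch of that arithmetic follows, assuming standard unpadded 1D convolutions with one bias term per filter. The helper name `conv1d_stats` is illustrative and not part of `find`.

```python
EMBEDDING_DIM = 300
MAXLEN = 150
FILTERS = [(10, 2), (10, 3), (10, 4)]

def conv1d_stats(num_filters, window, maxlen=MAXLEN, dim=EMBEDDING_DIM):
    """Return (output_length, parameter_count) for one unpadded Conv1D layer."""
    output_length = maxlen - window + 1
    params = (window * dim + 1) * num_filters  # weights plus one bias per filter
    return output_length, params

stats = [conv1d_stats(n, w) for n, w in FILTERS]
total_features = sum(n for n, _ in FILTERS)
print(stats)           # [(149, 6010), (148, 9010), (147, 12010)]
print(total_features)  # 30
```

These values agree with the output shapes and parameter counts of the `Conv1D` layers in the model summaries, and with the 30 features offered for inspection.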

## Model creation and training

In [6]:
# 0. Load GloVe embeddings
embedding_matrix, vocab_size, index2word, word2index = find.get_embedding_matrix(EMBEDDING_PATH, EMBEDDING_DIM, pad_initialisation = "zeros")

Loading Glove Model


400000it [00:52, 7652.66it/s]


Done. 400000  words loaded!
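
For readers without access to `find`, the following is a rough sketch of what loading GloVe vectors into an embedding matrix can look like. It is not the implementation of `find.get_embedding_matrix`: the function name `build_embedding_matrix`, the two reserved tokens (`<PAD>` at index 0, `<UNK>` at index 1), and the random initialisation of `<UNK>` are assumptions. Two reserved rows would be consistent with the 400,002 × 300 = 120,000,600 embedding parameters in the model summary, but the source does not say which tokens they are.

```python
import io
import numpy as np

def build_embedding_matrix(lines, dim, seed=0):
    """Parse GloVe-style text lines ("word v1 v2 ...") into an embedding matrix.

    Hypothetical helper: index 0 is a zero padding row, index 1 an unknown-word row.
    """
    index2word = ["<PAD>", "<UNK>"]
    vectors = [np.zeros(dim), np.random.RandomState(seed).normal(scale=0.1, size=dim)]
    for line in lines:
        parts = line.rstrip().split(" ")
        if len(parts) != dim + 1:
            continue  # skip empty or malformed lines
        index2word.append(parts[0])
        vectors.append(np.asarray(parts[1:], dtype="float32"))
    word2index = {w: i for i, w in enumerate(index2word)}
    return np.vstack(vectors).astype("float32"), len(index2word), index2word, word2index

# Tiny in-memory stand-in for a GloVe file (3-dimensional vectors).
fake_glove = io.StringIO("god 0.1 0.2 0.3\nchurch 0.4 0.5 0.6\n")
matrix, vocab_size, index2word, word2index = build_embedding_matrix(fake_glove, dim=3)
print(matrix.shape, vocab_size, word2index["church"])
```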


In [7]:
# 1. Load datasets and prepare inputs
# 1.1 Main dataset
data_1 = pickle.load(open(DATA_PATH + f'all_data_{MAIN_DATASET}.pickle', 'rb'))
class_names = data_1['class_names']
X_train_1, X_validate_1, X_test_1 = find.get_data_matrix(data_1['text_train'], word2index, MAXLEN), \
                                    find.get_data_matrix(data_1['text_validate'], word2index, MAXLEN), \
                                    find.get_data_matrix(data_1['text_test'], word2index, MAXLEN)
y_test_1 = data_1['y_test']
gender_test_1 = data_1['gender_test'] if GENDER_BIAS else None

# 1.2 Second dataset
if SECOND_DATASET is not None:
    data_2 = pickle.load(open(DATA_PATH + f'all_data_{SECOND_DATASET}.pickle', 'rb'))
    X_test_2, y_test_2 = find.get_data_matrix(data_2['text_test'], word2index, MAXLEN), data_2['y_test']
    gender_test_2 = data_2['gender_test'] if GENDER_BIAS else None
else:
    X_test_2, y_test_2, gender_test_2 = None, None, None

# 1.3 Third dataset
if THIRD_DATASET is not None:
    data_3 = pickle.load(open(DATA_PATH + f'all_data_{THIRD_DATASET}.pickle', 'rb'))
    X_test_3, y_test_3 = find.get_data_matrix(data_3['text_test'], word2index, MAXLEN), data_3['y_test']
    gender_test_3 = data_3['gender_test'] if GENDER_BIAS else None
else:
    X_test_3, y_test_3, gender_test_3 = None, None, None

100%|██████████| 863/863 [00:10<00:00, 86.19it/s] 
100%|██████████| 216/216 [00:02<00:00, 106.57it/s]
100%|██████████| 717/717 [00:07<00:00, 95.33it/s] 
100%|██████████| 1819/1819 [01:21<00:00, 22.23it/s]
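
The cell above converts each document into a fixed-length row of word indices. A minimal sketch of that idea follows, assuming whitespace tokenisation, right-padding with index 0, truncation to `MAXLEN`, and index 1 for out-of-vocabulary words. `to_data_matrix` is a hypothetical stand-in; `find.get_data_matrix` may tokenise or pad differently.

```python
import numpy as np

def to_data_matrix(texts, word2index, maxlen, pad_index=0, unk_index=1):
    """Map each text to a row of `maxlen` word indices (truncate or right-pad)."""
    matrix = np.full((len(texts), maxlen), pad_index, dtype="int32")
    for row, text in enumerate(texts):
        tokens = text.lower().split()[:maxlen]
        matrix[row, :len(tokens)] = [word2index.get(t, unk_index) for t in tokens]
    return matrix

# Toy vocabulary; the real one comes from the GloVe loading step.
toy_word2index = {"<PAD>": 0, "<UNK>": 1, "god": 2, "exists": 3}
X = to_data_matrix(["God exists", "god exists or not"], toy_word2index, maxlen=3)
print(X)
```

The first row is padded to length 3, while the second is truncated and its unknown word "or" is mapped to index 1.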


In [8]:
# 2. Create the result directory
if not os.path.exists(MODEL_PATH):
    os.makedirs(MODEL_PATH)
result_folder = MAIN_DATASET + '_' + MODEL_ARCH + '_' + datetime.datetime.now().strftime("%Y%m%d%H%M%S") + '/'
result_path = MODEL_PATH + result_folder
os.mkdir(result_path)

In [9]:
# 3. Create a model
if MODEL_ARCH == 'CNN':
    model = find.get_CNN_model(vocab_size, EMBEDDING_DIM, embedding_matrix, MAXLEN, class_names, FILTERS)
else:
    assert False, f"Unsupported model architecture: {MODEL_ARCH}"

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, None)         0                                            
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, 150, 300)     120000600   input_1[0][0]                    
__________________________________________________________________________________________________
conv1d_1 (Conv1D)               (None, 149, 10)      6010        embedding_1[0][0]                
__________________________________________________________________________________________________
conv1d_2 (Conv1D)               (None, 148, 10)      9010        embedding_1[0][0]                
__________________________________________________________________________________________________
conv1d_3 (

In [10]:
# 4. Train the model
history = find.model_train(model, result_path + f'trained_{MODEL_ARCH}.h5', X_train_1, data_1['y_train'], X_validate_1, data_1['y_validate'], BATCH_SIZE, epochs = 300)

Train on 863 samples, validate on 216 samples
Epoch 1/300
 - 3s - loss: 0.7094 - acc: 0.5272 - val_loss: 0.5997 - val_acc: 0.6759

Epoch 00001: val_loss improved from inf to 0.59968, saving model to trained_models/20Newsgroups_CNN_20201010014944/trained_CNN.h5
Epoch 2/300
 - 3s - loss: 0.5400 - acc: 0.7474 - val_loss: 0.4885 - val_acc: 0.7917

Epoch 00002: val_loss improved from 0.59968 to 0.48853, saving model to trained_models/20Newsgroups_CNN_20201010014944/trained_CNN.h5
Epoch 3/300
 - 3s - loss: 0.4448 - acc: 0.8320 - val_loss: 0.4590 - val_acc: 0.8102

Epoch 00003: val_loss improved from 0.48853 to 0.45903, saving model to trained_models/20Newsgroups_CNN_20201010014944/trained_CNN.h5
Epoch 4/300
 - 3s - loss: 0.3883 - acc: 0.8598 - val_loss: 0.4081 - val_acc: 0.8194

Epoch 00004: val_loss improved from 0.45903 to 0.40808, saving model to trained_models/20Newsgroups_CNN_20201010014944/trained_CNN.h5
Epoch 5/300
 - 3s - loss: 0.3401 - acc: 0.8876 - val_loss: 0.3951 - val_acc: 0.851

In [11]:
# 5. Evaluate the model
if not GENDER_BIAS:
    find.evaluate_all(model, class_names, BATCH_SIZE, X_test_1, y_test_1, X_test_2, y_test_2, X_test_3, y_test_3, result_path = result_path, model_name = 'original')
else:
    find.evaluate_all_gender(model, class_names, BATCH_SIZE, X_test_1, y_test_1, gender_test_1, X_test_2, y_test_2, gender_test_2, result_path = result_path, model_name = 'original')

Evaluate with the original test set:
{'per_class': {0: {'all_positive': 298,
                   'all_true': 319,
                   'class_f1': 0.826580226904376,
                   'class_name': 'alt.atheism',
                   'class_precision': 0.8557046979865772,
                   'class_recall': 0.799373040752351,
                   'true_positive': 255},
               1: {'all_positive': 419,
                   'all_true': 398,
                   'class_f1': 0.8690330477356182,
                   'class_name': 'soc.religion.christian',
                   'class_precision': 0.847255369928401,
                   'class_recall': 0.8919597989949749,
                   'true_positive': 355}},
 'total': {'accuracy': 0.8507670850767085,
           'macro_f1': 0.8485632695814168,
           'macro_precision': 0.8514800339574891,
           'macro_recall': 0.845666419873663,
           'micro_f1': 0.8507670850767085,
           'micro_precision': 0.8507670850767085,
           'micro_r
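
The reported figures can be reproduced from the raw counts in the output. The sketch below recomputes them in plain Python; `prf` is an illustrative helper, not a `find` function. The printed `macro_f1` matches the harmonic mean of macro precision and macro recall rather than the mean of the two per-class F1 scores. This is inferred from the numbers above, not from the library's source.

```python
def prf(true_positive, all_positive, all_true):
    """Precision, recall and F1 from raw counts."""
    precision = true_positive / all_positive
    recall = true_positive / all_true
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1

# Counts copied from the evaluation output of the original model.
atheism = prf(true_positive=255, all_positive=298, all_true=319)
christian = prf(true_positive=355, all_positive=419, all_true=398)

accuracy = (255 + 355) / (319 + 398)
macro_precision = (atheism[0] + christian[0]) / 2
macro_recall = (atheism[1] + christian[1]) / 2
macro_f1 = 2 * macro_precision * macro_recall / (macro_precision + macro_recall)

print(round(atheism[2], 4), round(christian[2], 4))  # 0.8266 0.869
print(round(accuracy, 4), round(macro_f1, 4))        # 0.8508 0.8486
```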

## Model understanding and debugging

In [12]:
# 6. Generate wordclouds
settings = {
    'model_arch': MODEL_ARCH,
    'filters': FILTERS,
    'maxlen': MAXLEN,
    'result_path': result_path,
    'index2word': index2word,
    'embedding_dim': EMBEDDING_DIM,
    'batch_size': BATCH_SIZE
}
all_wordclouds = find.generate_wordclouds(model, X_train_1, settings, max_examples = 2000)

 43%|████▎     | 3/7 [00:00<00:00, 19.66it/s]

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
embedded_text_input (InputLayer (None, 150, 300)     0                                            
__________________________________________________________________________________________________
conv1d_4 (Conv1D)               (None, 149, 10)      6010        embedded_text_input[0][0]        
__________________________________________________________________________________________________
conv1d_5 (Conv1D)               (None, 148, 10)      9010        embedded_text_input[0][0]        
__________________________________________________________________________________________________
conv1d_6 (Conv1D)               (None, 147, 10)      12010       embedded_text_input[0][0]        
__________________________________________________________________________________________________
global_max

100%|██████████| 7/7 [00:00<00:00, 13.34it/s]
 23%|██▎       | 7/30 [00:15<00:53,  2.31s/it]



100%|██████████| 30/30 [01:07<00:00,  2.27s/it]
100%|██████████| 30/30 [00:24<00:00,  1.22it/s]
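
The word clouds summarise which n-grams each filter responds to. The NumPy sketch below illustrates the underlying idea for a single filter under global max pooling: slide the filter over the embedded text, find the window with the highest activation, and report the words in that window. It uses made-up embeddings and weights and assumes a ReLU activation; it is a simplified illustration and not necessarily the procedure `find.generate_wordclouds` follows.

```python
import numpy as np

def top_ngram(tokens, embeddings, kernel, bias=0.0):
    """Return the n-gram whose window gives the highest ReLU activation for one filter.

    embeddings: (len(tokens), dim) array; kernel: (window, dim) array.
    """
    window = kernel.shape[0]
    activations = np.array([
        max(0.0, float(np.sum(embeddings[i:i + window] * kernel) + bias))
        for i in range(len(tokens) - window + 1)
    ])
    best = int(np.argmax(activations))  # position selected by global max pooling
    return tokens[best:best + window], activations[best]

# Toy example: 2-dimensional embeddings and a bigram filter (window size 2).
tokens = ["i", "believe", "in", "god"]
embeddings = np.array([[0.0, 0.1], [0.9, 0.2], [0.1, 0.0], [0.8, 0.9]])
kernel = np.array([[0.0, 0.0], [1.0, 1.0]])  # responds to the second word of the window
ngram, score = top_ngram(tokens, embeddings, kernel)
print(ngram, round(score, 2))
```

Collecting such top n-grams over many training documents, weighted by their activations, gives the kind of material a word cloud for one feature could be drawn from.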


- Get input from a human

In [13]:
is_feature_enabled = [True for i in range(find.num_all_filters(FILTERS))]

In [14]:
# UI components from ipywidgets
import ipywidgets as wgt

def update_screen(feature_idx):
    show_action_panel(feature_idx)
    wordcloud = all_wordclouds[feature_idx]
    f, ax = plt.subplots()
    plt.rcParams['figure.figsize'] = [14, 7]
    ax.imshow(wordcloud, interpolation='bilinear')
    ax.axis("off")
    
    W = model.layers[-1].get_weights()[0] # For the final layer
    weight_plot = find.visualize_weights(W, feature_idx, class_names, show = False)
    plt.show()

def update_action(action):
    global feature_radio_button, is_feature_enabled
    feature_idx = feature_radio_button.value
    if action == 'enabled':
        print('enable')
        is_feature_enabled[feature_idx] = True
    elif action == 'disabled':
        print('disable')
        is_feature_enabled[feature_idx] = False
    else:
        assert False
    
def show_action_panel(feature_idx):
    global action_radio_button
    action_radio_button.description = f'Current status of feature {feature_idx}:'
    action_radio_button.value = 'enabled' if is_feature_enabled[feature_idx] else 'disabled'
    
feature_radio_button = wgt.RadioButtons(options=list(range(find.num_all_filters(FILTERS))), value=0, description='Feature:', disabled=False)
action_radio_button = wgt.RadioButtons(options=['enabled', 'disabled'],
    value = 'enabled' if is_feature_enabled[feature_radio_button.value] else 'disabled',
    description = f'Current status of feature {feature_radio_button.value}:',
    style = {'description_width': 'initial'},
    disabled = False
)

wgt.interactive_output(update_action, {'action':action_radio_button})
out = wgt.interactive_output(update_screen, {'feature_idx':feature_radio_button})

In [15]:
# 7. Get input from a human
# Investigate the word clouds of these features and disable the irrelevant ones using the radio buttons under the bar plot.
# Once you are happy with the selection, proceed to the next cell.
display(wgt.HBox([feature_radio_button, wgt.VBox([out, action_radio_button])]))

HBox(children=(RadioButtons(description='Feature:', options=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,…

In [16]:
print(f"Total: {len(is_feature_enabled)} features \nEnabled: {sum(is_feature_enabled)} features \nDisabled: {len(is_feature_enabled)-sum(is_feature_enabled)} features")
print(f"Disabled features: {[i for i,s in enumerate(is_feature_enabled) if not s]}")

Total: 30 features 
Enabled: 19 features 
Disabled: 11 features
Disabled features: [1, 2, 3, 9, 12, 16, 20, 24, 25, 26, 28]


## Creating and fine-tuning an improved classifier

In [17]:
# 8. Create an improved model
# 8.1 Copy the existing CNN features
model_improved = find.get_CNN_model(vocab_size, EMBEDDING_DIM, embedding_matrix, MAXLEN, class_names, 
                                    FILTERS, trainable_filters = False)
model_improved.set_weights(model.get_weights()) 

# 8.2 Apply human decisions to disable irrelevant features
for idx, enable in enumerate(is_feature_enabled):
    if not enable:
        model_improved.layers[-1].disable_mask(idx)

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            (None, None)         0                                            
__________________________________________________________________________________________________
embedding_2 (Embedding)         (None, 150, 300)     120000600   input_2[0][0]                    
__________________________________________________________________________________________________
conv1d_7 (Conv1D)               (None, 149, 10)      6010        embedding_2[0][0]                
__________________________________________________________________________________________________
conv1d_8 (Conv1D)               (None, 148, 10)      9010        embedding_2[0][0]                
__________________________________________________________________________________________________
conv1d_9 (
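
Disabling a feature amounts to removing its contribution to the final dense layer. The sketch below shows one way to do this with a binary mask over the rows of the weight matrix. It assumes that `disable_mask(idx)` zeroes the weights connected to feature `idx`, which matches how the call is used above but is not confirmed by the source. The weights are random placeholders.

```python
import numpy as np

rng = np.random.RandomState(0)
num_features, num_classes = 30, 2
W = rng.normal(size=(num_features, num_classes))  # placeholder final-layer weights
b = np.zeros(num_classes)

disabled = [1, 2, 3, 9, 12, 16, 20, 24, 25, 26, 28]  # indices printed by cell 16
mask = np.ones((num_features, 1))
mask[disabled] = 0.0  # a zero row removes that feature's influence on every class

def logits(features, weights, bias, feature_mask):
    """Dense layer whose weights are multiplied element-wise by a fixed mask."""
    return features @ (weights * feature_mask) + bias

features = rng.uniform(size=(1, num_features))
perturbed = features.copy()
perturbed[0, disabled] += 10.0  # change only the disabled features

# The output is unchanged, so the disabled features no longer affect predictions.
print(np.allclose(logits(features, W, b, mask), logits(perturbed, W, b, mask)))
```

Because the mask is fixed, fine-tuning can adjust only the weights of the enabled features, so the disabled ones cannot come back.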

In [18]:
# 9. Fine-tune the improved model
history = find.model_train(model_improved, result_path + f'trained_{MODEL_ARCH}_improved.h5', X_train_1, data_1['y_train'], X_validate_1, data_1['y_validate'], BATCH_SIZE, epochs = 300)

Train on 863 samples, validate on 216 samples
Epoch 1/300
 - 2s - loss: 0.7720 - acc: 0.5829 - val_loss: 1.0870 - val_acc: 0.5694

Epoch 00001: val_loss improved from inf to 1.08697, saving model to trained_models/20Newsgroups_CNN_20201010014944/trained_CNN_improved.h5
Epoch 2/300
 - 2s - loss: 0.6408 - acc: 0.6188 - val_loss: 0.9389 - val_acc: 0.6019

Epoch 00002: val_loss improved from 1.08697 to 0.93887, saving model to trained_models/20Newsgroups_CNN_20201010014944/trained_CNN_improved.h5
Epoch 3/300
 - 2s - loss: 0.5213 - acc: 0.6628 - val_loss: 0.8066 - val_acc: 0.6296

Epoch 00003: val_loss improved from 0.93887 to 0.80661, saving model to trained_models/20Newsgroups_CNN_20201010014944/trained_CNN_improved.h5
Epoch 4/300
 - 2s - loss: 0.4220 - acc: 0.7323 - val_loss: 0.6912 - val_acc: 0.6620

Epoch 00004: val_loss improved from 0.80661 to 0.69116, saving model to trained_models/20Newsgroups_CNN_20201010014944/trained_CNN_improved.h5
Epoch 5/300
 - 2s - loss: 0.3405 - acc: 0.8042

Epoch 37/300
 - 2s - loss: 0.1280 - acc: 0.9884 - val_loss: 0.3151 - val_acc: 0.8935

Epoch 00037: val_loss improved from 0.31544 to 0.31510, saving model to trained_models/20Newsgroups_CNN_20201010014944/trained_CNN_improved.h5
Epoch 38/300
 - 2s - loss: 0.1271 - acc: 0.9884 - val_loss: 0.3141 - val_acc: 0.8935

Epoch 00038: val_loss improved from 0.31510 to 0.31409, saving model to trained_models/20Newsgroups_CNN_20201010014944/trained_CNN_improved.h5
Epoch 39/300
 - 2s - loss: 0.1261 - acc: 0.9896 - val_loss: 0.3140 - val_acc: 0.8935

Epoch 00039: val_loss improved from 0.31409 to 0.31402, saving model to trained_models/20Newsgroups_CNN_20201010014944/trained_CNN_improved.h5
Epoch 40/300
 - 2s - loss: 0.1252 - acc: 0.9884 - val_loss: 0.3131 - val_acc: 0.8935

Epoch 00040: val_loss improved from 0.31402 to 0.31307, saving model to trained_models/20Newsgroups_CNN_20201010014944/trained_CNN_improved.h5
Epoch 41/300
 - 2s - loss: 0.1242 - acc: 0.9896 - val_loss: 0.3125 - val_acc: 0.8935

Epoch 73/300
 - 2s - loss: 0.0987 - acc: 0.9919 - val_loss: 0.2968 - val_acc: 0.8935

Epoch 00073: val_loss improved from 0.29704 to 0.29682, saving model to trained_models/20Newsgroups_CNN_20201010014944/trained_CNN_improved.h5
Epoch 74/300
 - 2s - loss: 0.0980 - acc: 0.9919 - val_loss: 0.2962 - val_acc: 0.8935

Epoch 00074: val_loss improved from 0.29682 to 0.29621, saving model to trained_models/20Newsgroups_CNN_20201010014944/trained_CNN_improved.h5
Epoch 75/300
 - 2s - loss: 0.0973 - acc: 0.9919 - val_loss: 0.2959 - val_acc: 0.8935

Epoch 00075: val_loss improved from 0.29621 to 0.29595, saving model to trained_models/20Newsgroups_CNN_20201010014944/trained_CNN_improved.h5
Epoch 76/300
 - 2s - loss: 0.0967 - acc: 0.9919 - val_loss: 0.2959 - val_acc: 0.8935

Epoch 00076: val_loss improved from 0.29595 to 0.29589, saving model to trained_models/20Newsgroups_CNN_20201010014944/trained_CNN_improved.h5
Epoch 77/300
 - 2s - loss: 0.0960 - acc: 0.9919 - val_loss: 0.2952 - val_acc: 0.8935

In [19]:
# 10. Evaluate the improved model
if not GENDER_BIAS:
    find.evaluate_all(model_improved, class_names, BATCH_SIZE, X_test_1, y_test_1, X_test_2, y_test_2, X_test_3, y_test_3, result_path = result_path, model_name = 'debugged')
else:
    find.evaluate_all_gender(model_improved, class_names, BATCH_SIZE, X_test_1, y_test_1, gender_test_1, X_test_2, y_test_2, gender_test_2, result_path = result_path, model_name = 'debugged')

Evaluate with the original test set:
{'per_class': {0: {'all_positive': 286,
                   'all_true': 319,
                   'class_f1': 0.7867768595041322,
                   'class_name': 'alt.atheism',
                   'class_precision': 0.8321678321678322,
                   'class_recall': 0.7460815047021944,
                   'true_positive': 238},
               1: {'all_positive': 431,
                   'all_true': 398,
                   'class_f1': 0.8443908323281063,
                   'class_name': 'soc.religion.christian',
                   'class_precision': 0.8120649651972158,
                   'class_recall': 0.8793969849246231,
                   'true_positive': 350}},
 'total': {'accuracy': 0.8200836820083682,
           'macro_f1': 0.8174009291550225,
           'macro_precision': 0.822116398682524,
           'macro_recall': 0.8127392448134088,
           'micro_f1': 0.8200836820083681,
           'micro_precision': 0.8200836820083682,
           'micr