This notebook shows how to train ProtoCNN on your own data.
Let's assume that the data is stored in `data/amazon/data.csv` (`../data/amazon/data.csv` relative to the notebook). Let's take a look at the first lines of the file:

In [1]:
import string
import warnings

import pandas as pd
import torch
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from sklearn.model_selection import train_test_split
from torchtext import data
from torchtext.data import BucketIterator
from torchtext.vocab import GloVe

from dataframe_dataset import DataFrameDataset
from models.protoconv.data_visualizer import DataVisualizer
from models.protoconv.lit_module import ProtoConvLitModule
from utils import plot_html

warnings.simplefilter("ignore")
seed_everything(0)

Global seed set to 0


0

In [2]:
!head ../data/amazon/data.csv

text,label
"MUCH ADO ABOUT NOTHING: There's really nothing to watch there. There's nothing to like. Don't get me wrong - I hate violence in all its manifestations in life although I like it on screen, but here we can't even empathize this woman 'cos the movie is just sooo bad. The acting is terrible, characters do very strange and unexplainable things. You can't feel with a character when you see everything in the film is false and naive. Don't believe the taglines - there's nothing shocking in there. It's just one big waste of time. You'll feel cheated and robbed. If you hate violence in cinema, I guess you are not reading this right now and if you like shocking movies, don't bother with this one, there are plenty of others that are much more shocking and I assume you already know them all by names. Go and watch ""Irreversible"" for god's sake.",0
"A Well Written and Enjoyable Travelogue: This book tells a very enjoyable story about Nashville's birth and culture as Music City. This 
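To train on your own data, it should be enough to produce a CSV with the same two columns: `text` holding the raw document and `label` holding a binary class (in the rows above, 0 marks a negative review and 1 a positive one). Below is a minimal pandas sketch. The example reviews and the output file name `my_data.csv` are hypothetical placeholders.

```python
import tempfile
from pathlib import Path

import pandas as pd

# Hypothetical example reviews; replace them with your own texts and 0/1 labels.
texts = [
    "Great product, works exactly as advertised.",
    "Broke after two days, would not buy again.",
]
labels = [1, 0]

df = pd.DataFrame({"text": texts, "label": labels})

# Hypothetical output location; in this project the file would go to data/<name>/data.csv.
out_path = Path(tempfile.gettempdir()) / "my_data.csv"
# index=False writes only the two columns text,label, like data/amazon/data.csv
df.to_csv(out_path, index=False)

print(out_path.read_text().splitlines()[0])  # header line of the written file
```

Because the index is not written, reading the file back with `pd.read_csv` gives exactly the `text` and `label` columns.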

We will start by loading the data:

In [3]:
df_dataset = pd.read_csv('../data/amazon/data.csv')
df_dataset.head()

   text                                               label
0  MUCH ADO ABOUT NOTHING: There's really nothing...  0
1  A Well Written and Enjoyable Travelogue: This ...  1
2  Really Works: This I didn't noticed how well i...  1
3  too predictable: this book couldn't hold my at...  0
4  Not my style: The book covers several generati...  0


We will split the data into a training set and a validation set (80/20, stratified by label):

In [4]:
train_df, valid_df = train_test_split(df_dataset, test_size=0.2, stratify=df_dataset['label'])
train_df.shape, valid_df.shape

((24000, 2), (6000, 2))
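Passing `stratify=df_dataset['label']` makes both splits keep the class proportions of the full dataset. This is easy to check on a small toy frame. The data below is made up for illustration, and `random_state=0` is added only to make the toy split repeatable.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Made-up, imbalanced toy data: 15 negative (0) and 5 positive (1) examples.
toy = pd.DataFrame({
    "text": [f"review {i}" for i in range(20)],
    "label": [0] * 15 + [1] * 5,
})

tr, va = train_test_split(toy, test_size=0.2, stratify=toy["label"], random_state=0)

# Both splits keep the original 25% share of positive examples.
print(tr.shape, va.shape)                      # (16, 2) (4, 2)
print(tr["label"].mean(), va["label"].mean())  # 0.25 0.25
```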

Now we will create the `torchtext` datasets. Any input format supported by `torchtext` can be used;
here we build the datasets from the pandas DataFrames.

In [5]:
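# Text field: spaCy tokenization, lowercasing, punctuation tokens removed as stop words,
# and <START>/<END> tokens added around every sequence
TEXT = data.Field(init_token='<START>', eos_token='<END>', tokenize='spacy', tokenizer_language='en',
                  batch_first=True, lower=True, stop_words=set(string.punctuation))
# Label field: one float (0 or 1) per example, used directly without a vocabulary
LABEL = data.Field(dtype=torch.float, is_target=True, unk_token=None, sequential=False, use_vocab=False)

# Map the DataFrame columns to the fields above
train_dataset = DataFrameDataset(train_df, {
    'text': TEXT,
    'label': LABEL
})

val_dataset = DataFrameDataset(valid_df, {
    'text': TEXT,
    'label': LABEL
})

# BucketIterator batches examples of similar length together to reduce padding
train_loader, val_loader = BucketIterator.splits(
    (train_dataset, val_dataset),
    batch_size=32,
    sort_key=lambda x: len(x.text),
    device='cuda'
)

# Build the vocabulary from the training split only and attach 300-dimensional GloVe (42B) vectors
TEXT.build_vocab(train_dataset.text, vectors=GloVe('42B', cache='../.vector_cache/'))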
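The plain-Python sketch below makes the `TEXT` field settings concrete. It walks through the steps they configure: tokenize, lowercase, drop punctuation tokens, wrap with `<START>`/`<END>`, build a vocabulary, and turn a batch into padded index lists (`batch_first=True` layout). This is an illustration, not torchtext's code. A simple regex tokenizer stands in for spaCy, `<unk>`/`<pad>` are assumed to be the field's default special tokens, GloVe vectors are not modelled, and the helper names are hypothetical.

```python
import re
import string
from collections import Counter

PUNCTUATION = set(string.punctuation)
SPECIALS = ["<unk>", "<pad>", "<START>", "<END>"]


def tokenize(text):
    # Stand-in for spaCy: runs of word characters, or single punctuation characters
    return re.findall(r"\w+|[^\w\s]", text)


def preprocess(text):
    # lower=True, stop_words=set(string.punctuation), init/eos tokens
    tokens = [t.lower() for t in tokenize(text)]
    tokens = [t for t in tokens if t not in PUNCTUATION]
    return ["<START>"] + tokens + ["<END>"]


def build_vocab(token_lists):
    # Special tokens first, then words by descending frequency (ties alphabetical)
    counts = Counter(t for toks in token_lists for t in toks if t not in SPECIALS)
    words = sorted(counts, key=lambda w: (-counts[w], w))
    itos = SPECIALS + words
    stoi = {tok: i for i, tok in enumerate(itos)}
    return itos, stoi


def numericalize(token_lists, stoi):
    # Map tokens to ids (unknown -> <unk>) and right-pad to the longest sequence
    max_len = max(len(toks) for toks in token_lists)
    unk, pad = stoi["<unk>"], stoi["<pad>"]
    return [[stoi.get(t, unk) for t in toks] + [pad] * (max_len - len(toks))
            for toks in token_lists]


batch = [preprocess("Great deal!"), preprocess("Not worth it.")]
itos, stoi = build_vocab(batch)
print(batch[0])                   # ['<START>', 'great', 'deal', '<END>']
print(numericalize(batch, stoi))  # [[2, 5, 4, 3, 1], [2, 7, 8, 6, 3]]
```

In the notebook the vocabulary is built from the training split only, so validation words never seen during training map to `<unk>`.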

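We add callbacks that save the model with the best validation accuracy, stop training early when the validation loss
stops improving, and log the learning rate every epoch (the learning rate itself is decreased by the module during training, see the log below). We then create the model with the parameters used in the publication.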

In [6]:
# Keep only the single checkpoint with the highest validation accuracy
model_checkpoint = ModelCheckpoint(filepath='../checkpoints/{epoch_0:02d}-{val_loss_0:.4f}-{val_acc_0:.4f}',
                                   save_weights_only=True, save_top_k=1, monitor='val_acc_0', mode='max', period=1)

callbacks = [
    LearningRateMonitor(logging_interval='epoch'),
    EarlyStopping(monitor='val_loss_0', patience=10, verbose=True, mode='min', min_delta=0.005),
    model_checkpoint
]

model = ProtoConvLitModule(vocab_size=len(TEXT.vocab), embedding_dim=TEXT.vocab.vectors.shape[1], fold_id=0, lr=1e-3,
                           itos=TEXT.vocab.itos, verbose_proto=False)
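The `EarlyStopping` settings above stop training once `val_loss_0` has failed to improve by more than `min_delta=0.005` for `patience=10` consecutive validation runs. Below is a minimal plain-Python sketch of that rule. It is our approximation of the callback's behaviour, not Lightning's code, and `early_stopping_epoch` is a hypothetical helper.

```python
def early_stopping_epoch(val_losses, patience=10, min_delta=0.005):
    """Return the 0-based epoch at which training would stop, or None if it never stops."""
    best = float("inf")
    wait = 0  # validation runs since the last sufficient improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            # Improvement larger than min_delta: remember it and reset the counter
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Loss improves for three epochs, then only creeps down by less than min_delta
losses = [0.60, 0.50, 0.45] + [0.448] * 10
print(early_stopping_epoch(losses))  # 12
```

The checkpoint callback monitors validation accuracy while early stopping monitors validation loss. As a result, the saved model is the one with the best accuracy, which is not necessarily the one with the lowest loss.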

Now we can start training:

In [7]:
trainer = Trainer(max_epochs=30, callbacks=callbacks, gpus=1, deterministic=True, num_sanity_val_steps=0)
trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)


GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Set SLURM handle signals.

  | Name       | Type               | Params
--------------------------------------------------
0 | embedding  | Embedding          | 17.1 M
1 | conv1      | ConvolutionalBlock | 96.2 K
2 | prototypes | PrototypeLayer     | 12.8 K
3 | fc1        | Linear             | 100   
4 | train_acc  | Accuracy           | 0     
5 | valid_acc  | Accuracy           | 0     
6 | loss       | BCEWithLogitsLoss  | 0     
--------------------------------------------------
102 K     Trainable params
17.1 M    Non-trainable params
17.2 M    Total params
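The parameter summary can be sanity-checked. An embedding layer holds one $d$-dimensional vector per vocabulary entry, and the 42B GloVe vectors have $d = 300$. The 17.1 M embedding parameters therefore correspond to a vocabulary of roughly 57 thousand tokens. The embedding also matches the 17.1 M non-trainable parameters, which suggests that it is kept frozen during training.

```latex
\text{params}_{\text{embedding}} = |V| \cdot d
\quad\Rightarrow\quad
|V| \approx \frac{17.1 \times 10^{6}}{300} = 57\,000
```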



Epoch    14: reducing learning rate of group 0 to 1.0000e-04.



Epoch    19: reducing learning rate of group 0 to 1.0000e-05.



1
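The "reducing learning rate of group 0" lines in the log match the message printed by PyTorch's `torch.optim.lr_scheduler.ReduceLROnPlateau`, so the module presumably uses that scheduler. The drops from 1e-3 to 1e-4 (epoch 14) and to 1e-5 (epoch 19) are consistent with `factor=0.1`. The actual scheduler settings live inside `ProtoConvLitModule` and are not shown here. The sketch below assumes `mode='min'` and `patience=2` purely for illustration and feeds the scheduler a made-up loss sequence.

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

# A single dummy parameter is enough to drive the optimizer/scheduler pair
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-3)

# factor=0.1 matches the 1e-3 -> 1e-4 -> 1e-5 steps in the log; patience=2 is an assumption
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=2)

val_losses = [0.50, 0.45, 0.46, 0.47, 0.48]  # made-up validation losses
lrs = []
for loss in val_losses:
    optimizer.step()      # stand-in for one training epoch (no gradients, so nothing changes)
    scheduler.step(loss)  # one scheduler step per validation run
    lrs.append(optimizer.param_groups[0]["lr"])

print(lrs)
```

The loss stops improving after the second value. The third consecutive run without improvement exceeds `patience=2`, so only the last recorded learning rate is reduced to 1e-4.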

The score of the best model (its validation accuracy, the metric monitored by the checkpoint callback) is stored in `model_checkpoint`:

In [8]:
'Best accuracy: ', model_checkpoint.best_model_score.tolist()


('Best accuracy: ', 0.8163333535194397)
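The model is trained with `BCEWithLogitsLoss` (see the summary above), so it outputs one logit per review. A common way to turn logits into an accuracy is to predict class 1 whenever the sigmoid of the logit exceeds 0.5, which is the same as the logit being positive. The sketch below assumes that threshold. The exact metric computation inside `ProtoConvLitModule` is not shown in this notebook, and `binary_accuracy` is a hypothetical helper.

```python
import torch


def binary_accuracy(logits: torch.Tensor, targets: torch.Tensor) -> float:
    # sigmoid(logit) > 0.5  <=>  logit > 0, so the sigmoid can be skipped
    preds = (logits > 0).float()
    return (preds == targets).float().mean().item()


logits = torch.tensor([2.0, -1.0, 0.5, -3.0])
targets = torch.tensor([1.0, 0.0, 0.0, 0.0])
print(binary_accuracy(logits, targets))  # 0.75
```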

Now we load the weights of the best model and visualize the learned prototypes, along with explanations of randomly chosen predictions:

In [9]:
best_model = ProtoConvLitModule.load_from_checkpoint(model_checkpoint.best_model_path)
data_visualizer = DataVisualizer(best_model)

In [10]:
plot_html(data_visualizer.visualize_prototypes())

In [11]:
plot_html(data_visualizer.visualize_random_predictions(val_loader, n=5))

Most similar phrase             Prototype                                Similarity * Weight
can not beat the deal           the worst                                2.46 * 1.41 = 3.47
beat the deal you get           worked the opposite where it             1.00 * 0.20 = 0.20
a large cup of milk             have purchased two other remanufactured  0.85 * 0.12 = 0.10

Most similar phrase             Prototype                                Similarity * Weight
great deal you can not          good product i                           3.72 * 1.94 = 7.21
great deal you can not          have a lot more power                    1.00 * 0.09 = 0.09
the deal you get on             novel the author gives us                1.20 * 0.07 = 0.08

Most similar phrase             Prototype                                Similarity * Weight
is one of the worst             the worst                                6.70 * 1.41 = 9.46
it is one of the                worked the opposite where it             1.11 * 0.20 = 0.23
it is one of the                have purchased two other remanufactured  1.17 * 0.12 = 0.13

Most similar phrase             Prototype                                Similarity * Weight
reviewer big zach not a         good product i                           1.34 * 1.94 = 2.59
reviewer big zach not a         have a lot more power                    1.38 * 0.09 = 0.12
reviewer big zach not a         novel the author gives us                1.08 * 0.07 = 0.08

Most similar phrase             Prototype                                Similarity * Weight
would not reccomend this movie  the worst                                5.83 * 1.41 = 8.23
movie after hearing that it     worked the opposite where it             1.01 * 0.20 = 0.20
and the plot was too            have purchased two other remanufactured  1.20 * 0.12 = 0.14

Most similar phrase             Prototype                                Similarity * Weight
it 's not worth buying          good product i                           1.94 * 1.94 = 3.76
was laughing almost the whole   have a lot more power                    0.79 * 0.09 = 0.07
laughing almost the whole time  novel the author gives us                0.97 * 0.07 = 0.07

Most similar phrase             Prototype                                Similarity * Weight
not her best ... to             the worst                                5.06 * 1.41 = 7.15
is very similar in theme        worked the opposite where it             0.97 * 0.20 = 0.20
theme and rythm to one          have purchased two other remanufactured  1.04 * 0.12 = 0.12

Most similar phrase             Prototype                                Similarity * Weight
is very similar in theme        good product i                           3.13 * 1.94 = 6.07
years is too long to            have a lot more power                    1.26 * 0.09 = 0.11
you 're new to her              novel the author gives us                1.13 * 0.07 = 0.08

Most similar phrase             Prototype                                Similarity * Weight
but not the last because        the worst                                4.04 * 1.41 = 5.70
at the apple music store        worked the opposite where it             1.01 * 0.20 = 0.21
a recent fan of the             have purchased two other remanufactured  1.35 * 0.12 = 0.16

Most similar phrase             Prototype                                Similarity * Weight
you 'll love this album         good product i                           3.61 * 1.94 = 7.00
you 'll love this album         have a lot more power                    0.91 * 0.09 = 0.08
a minute of music on            novel the author gives us                0.97 * 0.07 = 0.07
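Each row of these tables pairs a prototype with the phrase of the review that is most similar to it. The last column multiplies that similarity by a weight, presumably the prototype's weight in the final `fc1` layer, and rows are ranked by the product. Below is a plain-Python sketch of this ranking. It assumes a precomputed phrase-by-prototype similarity matrix. The function name `explain` and all inputs are hypothetical, with values loosely borrowed from the third table.

```python
def explain(phrases, prototypes, similarity, weights, top_k=3):
    """Rank prototypes by similarity * weight for one review.

    similarity[i][j] is the (hypothetical, precomputed) similarity between
    phrase i of the review and prototype j; weights[j] is prototype j's weight.
    """
    rows = []
    for j, prototype in enumerate(prototypes):
        # The phrase of this review that best matches prototype j
        best_i = max(range(len(phrases)), key=lambda i: similarity[i][j])
        sim = similarity[best_i][j]
        rows.append((phrases[best_i], prototype, sim, weights[j], sim * weights[j]))
    # Largest contribution first, as in the tables above
    rows.sort(key=lambda row: row[4], reverse=True)
    return rows[:top_k]


phrases = ["is one of the worst", "it is one of the"]
prototypes = ["the worst", "worked the opposite where it"]
similarity = [[6.70, 0.95],
              [1.05, 1.11]]
weights = [1.41, 0.20]

for phrase, proto, sim, w, score in explain(phrases, prototypes, similarity, weights):
    print(f"{phrase:<22}{proto:<30}{sim:.2f} * {w:.2f} = {score:.2f}")
```

With these rounded inputs, the first score comes out as 9.45. The table shows 9.46 for the same displayed factors, so the notebook presumably multiplies the unrounded values before rounding.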
