# Sentiment Analysis: Bag-of-Words (JAX and Haiku)

---



The notebook is roughly divided into two sections: pre-processing and the actual model.

## Setup and downloading data

We use the Kaggle API to download the data. 

**N.B.** If you want to run the notebook you need to upload a `kaggle.json` file which contains your API credentials. Instructions for downloading the file from Kaggle can be found [here](https://github.com/Kaggle/kaggle-api).


In [0]:
# install the kaggle API and haiku
!pip install kaggle -q
!pip install dm-haiku -q

In [0]:
# data wrangling
import numpy as onp
import pandas as pd
from sklearn.model_selection import train_test_split

# language model 
import spacy
spacy_en = spacy.load('en')

# jax and haiku
import jax
import jax.numpy as jnp
import haiku as hk
from jax.experimental import optix

# data pipeline
from torch.utils import data

from collections import Counter
import copy
import re
import matplotlib.pyplot as plt
import matplotlib 

In [26]:
# upload kaggle.json file
from google.colab import files
uploaded = files.upload()

Saving kaggle.json to kaggle (1).json


In [0]:
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

In [5]:
# download the dataset and unzip
!kaggle competitions download sentiment-analysis-on-movie-reviews
!unzip train.tsv.zip

Downloading test.tsv.zip to /content
  0% 0.00/494k [00:00<?, ?B/s]
100% 494k/494k [00:00<00:00, 75.7MB/s]
Downloading sampleSubmission.csv to /content
  0% 0.00/583k [00:00<?, ?B/s]
100% 583k/583k [00:00<00:00, 81.7MB/s]
Downloading train.tsv.zip to /content
  0% 0.00/1.28M [00:00<?, ?B/s]
100% 1.28M/1.28M [00:00<00:00, 87.5MB/s]
Archive:  train.tsv.zip
  inflating: train.tsv               
Archive:  test.tsv.zip
  inflating: test.tsv                


## Pre-processing 

The model cannot accept strings as input, thus in this section we convert each element of the dataset into an integer-encoded vector.

###  Text processing

Descriptions:
- `tokenizer`: function that takes as input a string of text and returns a list of tokens 
- `Vocabularly`: class that stores the vocab found in the dataset and creates a mapping from token to integer

In [0]:
def tokenizer(text): 
    text = text.lower()
    text = re.sub("-rrb-","", text)
    text = re.sub("-lrb-","", text)
    tokens = spacy_en.tokenizer(text)
    #tokens = [tok for tok in tokens if tok.is_stop == False]
    tokens = [tok.lemma_ for tok in tokens]
    return tokens

In [0]:
class Vocabulary:
      
    def __init__(self, vocabCount, min_freq):
        
        # UNK tokens 
        self.PAD_token = 0
        self.UNK_token = 1
        self.vocabCount = vocabCount
        self.min_freq = min_freq
        # initialize list of words and vocab dictionary
        self.wordlist = ["<pad>", "<unk>"]
        self.word2index = {}
        # build vocab
        self.build_vocab(vocabCount)

    def __len__(self):
        return len(self.word2index)

    def __getitem__(self, word):
        return self.word2index.get(word, 1)

    def __iter__(self):
        return iter(vocab.word2index)

    def build_vocab(self, vocabCount):
        # sort vocab s.t. words that occur most frequently added first
        svocabCount = {k: v for k, v in reversed(sorted(vocabCount.items(), 
                                                      key=lambda item: item[1]))}
        
        for word in svocabCount:
            if svocabCount[word] >= self.min_freq:
                self.wordlist.append(word)
        self.word2index.update({tok: i for i, tok in enumerate(self.wordlist)})

### Loading and processing data

In [0]:
train_data = pd.read_csv('/content/train.tsv', sep="\t", 
                         encoding="utf_8_sig")

phrases = onp.array(train_data.iloc[:, 2])
target = onp.array(train_data.iloc[:, 3])

# create train and validation sets
X_train , X_val, y_train , y_val = train_test_split(phrases, target, 
                                                    test_size = 0.2, random_state=42)

# create validation and test sets
X_val , X_test, y_val , y_test = train_test_split(X_val, y_val, 
                                                    test_size = 0.4, random_state=42)

In [0]:
X_train = [tokenizer(phrase) for phrase in X_train]
X_val = [tokenizer(phrase) for phrase in X_val]
X_test = [tokenizer(phrase) for phrase in X_test]

In [449]:
print("Length of train dataset: {} \nLength of validation dataset: {} \nLength of test dataset: {}".format(len(X_train), len(X_val), len(X_test)))

Length of train dataset: 124848 
Length of validation dataset: 18727 
Length of test dataset: 12485


### Creating vocabularly

In [0]:
vocabCount = Counter([item for sublist in X_train for item in sublist])

In [0]:
vocab = Vocabulary(vocabCount, 1)

### Converting tokens to integers

In [0]:
X_trainNum = [onp.array([vocab[word] for word in phrase]) for phrase in X_train]
X_valNum = [onp.array([vocab[word] for word in phrase]) for phrase in X_val]
X_testNum = [onp.array([vocab[word] for word in phrase]) for phrase in X_test]

Our pre-processing might have resulted in some now empty lists. These can cause problems later so we just fill them with padding (another possibility would be to remove these examples as they contain no useful information).

In [0]:
# make sure each tensor actually has values
for i, el in enumerate(X_trainNum):
    if len(el) == 0:
        X_trainNum[i] = onp.array([0])

# make sure each tensor actually has values
for i, el in enumerate(X_valNum):
    if len(el) == 0:
        X_valNum[i] = onp.array([0])

# make sure each tensor actually has values
for i, el in enumerate(X_testNum):
    if len(el) == 0:
        X_testNum[i] = onp.array([0])

##  Bag of vectors 

In [0]:
class WordDataset(data.Dataset):
    
    def __init__(self, X, y):
        self.X = X
        self.y = y
        
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self, idx):  
        X = self.X[idx]
        Y = self.y[idx]
        return X, Y

In [0]:
def pad_sequences(sequences):
  # adapted and simplified from https://github.com/keras-team/keras-preprocessing/blob
  # /master/keras_preprocessing/sequence.py

  num_samples = len(sequences)
  lengths = []

  for sequence in sequences:
    lengths.append(len(sequence))
  
  max_len = onp.max(lengths)

  x = onp.full((num_samples, max_len), 0)

  for idx, sequence in enumerate(sequences):
    x[idx, :len(sequence)] = sequence
  return x

In [0]:
def numpy_collate(batch):
  data = [item[0] for item in batch]
  targets = onp.array([item[1] for item in batch])


  data = pad_sequences(data)

  return jnp.asarray(data), jnp.array(targets)

class NumpyLoader(data.DataLoader):
  def __init__(self, dataset, batch_size=1,
                shuffle=False, sampler=None,
                batch_sampler=None, num_workers=0,
                pin_memory=False, drop_last=False,
                timeout=0, worker_init_fn=None):
    super(self.__class__, self).__init__(dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=numpy_collate,
        pin_memory=pin_memory,
        drop_last=drop_last,
        timeout=timeout,
        worker_init_fn=worker_init_fn)

In [0]:
trainingset = WordDataset(X_trainNum, y_train)
valset = WordDataset(X_valNum, y_val)
testset = WordDataset(X_testNum, y_test)

In [0]:
training_generator = NumpyLoader(trainingset, batch_size=64)
val_generator = NumpyLoader(valset, batch_size=len(y_val))
test_generator = NumpyLoader(testset, batch_size=len(y_test))

In [0]:
training_eval = NumpyLoader(trainingset, batch_size=len(y_train))

### Model

In [0]:
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)

In [0]:
def _forward(inputs, rng, is_training=True):

  # create mask to adjust for padding
  mask = jnp.where(inputs==0,0,1)

  # define embedding layer
  embed = hk.Embed(len(vocab), embed_dim=300)

  # pass through network
  x = embed(inputs)
  x = sum_pooling(x, mask)
  x = hk.Linear(128)(x)
  x = jax.nn.relu(x)
  x = dropout(rng, 0.5, x, is_training)
  x = hk.Linear(64)(x)
  x = jax.nn.relu(x)
  x = dropout(rng, 0.5, x, is_training)
  x = hk.Linear(5)(x)

  return x

In [0]:
@jax.jit
def sum_pooling(inputs, mask):

  return jnp.einsum("ijk, ij -> ik", inputs, mask)

In [0]:
def dropout(rng, rate, x, is_training):

  if is_training:
    return hk.dropout(rng, rate, x)
  else:
    return x

In [0]:
# convert into pure functions
net = hk.transform(_forward)

# initialise parameters (need sample inputs for this)
sample_input, _ = next(iter(training_generator))
params = net.init(jax.random.PRNGKey(42), sample_input, subkey)

### Loss and evalulation functions

In [0]:
@jax.jit
def loss(params, batch, rng, is_training=True):
  inputs, targets = batch
  preds = net.apply(params, inputs, rng, is_training)
  targets = hk.one_hot(targets, 5)
  return -np.mean(np.sum(jax.nn.log_softmax(preds) * targets, axis=1))

In [0]:
@jax.jit
def eval_batch(params, batch, rng):
  inputs, targets = batch
  logits = net.apply(params, inputs, rng, False)

  # for accuracy
  predicted_label = jnp.argmax(logits, axis=-1)
  correct = np.sum(jnp.equal(predicted_label, targets))

  # for loss
  one_hot_targets = hk.one_hot(targets, 5)
  loss = -np.sum((np.sum(jax.nn.log_softmax(logits) * one_hot_targets, axis=1)))

  return correct.astype(jnp.float32), loss.astype(jnp.float32)

In [0]:
def evaluate(params, dataLoader, rng):
  correct = 0
  loss = 0
  total = 0
  for batch in dataLoader:
    c, l = eval_batch(params, batch, rng)
    correct += c
    loss += l
    total += batch[1].shape[0]
  acc = correct/total
  loss = loss/total
  return {"acc": acc, "loss": loss}

### Optimizer

In [0]:
def lr_schedule(step):
  
  steps_per_epoch = jnp.ceil(len(training_generator.dataset)/ 64)
  current_epoch = step / steps_per_epoch 
  factor = current_epoch // 5
 
  return 0.85**factor

In [0]:
def make_optimizer():
  
  return optix.chain(optix.adam(1e-4),
                     optix.scale_by_schedule(lr_schedule))

In [0]:
 opt_state = make_optimizer().init(params)

### Update function

In [0]:
@jax.jit
def update(params, opt_state, batch, rng):
  grads = jax.grad(loss)(params, batch, rng)
  updates, opt_state = make_optimizer().update(grads, opt_state)
  new_params = optix.apply_updates(params, updates)
  return new_params, opt_state

### Training

In [472]:
key = jax.random.PRNGKey(42)

for step in range(50):
  
  for b_idx, batch in enumerate(training_generator):

    key, subkey = jax.random.split(key)
    params, opt_state = update(params, opt_state, batch, subkey)

  # dont need a new random key for evalulation
  train_metrics = evaluate(params, training_eval, subkey)
  train_accuracy, train_loss = train_metrics["acc"], train_metrics["loss"]
  train_accuracy, train_loss = jax.device_get((train_accuracy,
                                               train_loss))
  
  val_metrics = evaluate(params, val_generator, subkey)
  val_accuracy, val_loss = val_metrics["acc"], val_metrics["loss"]
  val_accuracy, val_loss = jax.device_get((val_accuracy,
                                               val_loss))
  
  print(("Epoch: {} \t Loss (train): {:.3f} (val): {:.3f} \t" +
              "Acc (train) {:.3f} (val): {:.3f}").format(step + 1,
                            train_loss, val_loss, train_accuracy, 
                            val_accuracy))

Epoch: 1 	 Loss (train): 1.277 (val): 1.287 	Acc (train) 0.514 (val): 0.505
Epoch: 2 	 Loss (train): 1.235 (val): 1.248 	Acc (train) 0.514 (val): 0.505
Epoch: 3 	 Loss (train): 1.202 (val): 1.220 	Acc (train) 0.516 (val): 0.508
Epoch: 4 	 Loss (train): 1.170 (val): 1.192 	Acc (train) 0.523 (val): 0.516
Epoch: 5 	 Loss (train): 1.130 (val): 1.157 	Acc (train) 0.536 (val): 0.528
Epoch: 6 	 Loss (train): 1.095 (val): 1.127 	Acc (train) 0.549 (val): 0.541
Epoch: 7 	 Loss (train): 1.060 (val): 1.098 	Acc (train) 0.561 (val): 0.550
Epoch: 8 	 Loss (train): 1.025 (val): 1.069 	Acc (train) 0.574 (val): 0.561
Epoch: 9 	 Loss (train): 0.995 (val): 1.045 	Acc (train) 0.584 (val): 0.570
Epoch: 10 	 Loss (train): 0.966 (val): 1.022 	Acc (train) 0.596 (val): 0.579
Epoch: 11 	 Loss (train): 0.942 (val): 1.004 	Acc (train) 0.606 (val): 0.588
Epoch: 12 	 Loss (train): 0.922 (val): 0.989 	Acc (train) 0.613 (val): 0.592
Epoch: 13 	 Loss (train): 0.904 (val): 0.976 	Acc (train) 0.620 (val): 0.596
Epoch: 1

### Results

In [0]:
train_metrics = evaluate(params, training_eval, subkey)
val_metrics = evaluate(params, val_generator, subkey)
test_metrics = evaluate(params, test_generator, subkey)

In [475]:
print("Train metrics: \n{}".format(train_metrics))
print("Val metrics: \n{}".format(val_metrics))
print("Test metrics: \n{}".format(test_metrics))

Train metrics: 
{'acc': DeviceArray(0.7088059, dtype=float32), 'loss': DeviceArray(0.68640625, dtype=float32)}
Val metrics: 
{'acc': DeviceArray(0.6484221, dtype=float32), 'loss': DeviceArray(0.8717084, dtype=float32)}
Test metrics: 
{'acc': DeviceArray(0.64989984, dtype=float32), 'loss': DeviceArray(0.86235493, dtype=float32)}
