# Use BERT Representations with LogisticRegression Softmax Classifier

In [1]:
from collections import Counter
import os
import numpy as np
import pandas as pd
import torch
from torch import nn, optim
import torch.nn as nn
from torch.utils.data import TensorDataset, Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from transformers import BertTokenizer, BertModel, BertForSequenceClassification


import dataset
import vsm
import sst

In [2]:
# Integer dataset selectors. TWITTER_AIRLINES is the one passed to the local
# `dataset` helpers below; the other two are presumably sibling corpora
# handled by the same module -- confirm against dataset.py.
TWITTER, TWITTER_AIRLINES, TWITTER_APPLE = 2, 3, 4

In [3]:
# Read the three airline-tweet splits, then run each through
# `dataset.prune_columns` before binding the final names.
raw_splits = dataset.dataset_reader(TWITTER_AIRLINES)
twitter_train, twitter_validate, twitter_test = [
    dataset.prune_columns(TWITTER_AIRLINES, split) for split in raw_splits
]

In [4]:
# Checkpoint name shared by the tokenizer and the encoder so that both are
# loaded from the same pretrained weights/vocabulary.
bert_weights_name = 'bert-base-cased'
bert_tokenizer = BertTokenizer.from_pretrained(bert_weights_name)
bert_model = BertModel.from_pretrained(bert_weights_name)
# Disabled alternative (BERT with its own classification head); this notebook
# feeds BERT features to a scikit-learn classifier instead.
# model = BertForSequenceClassification.from_pretrained(bert_weights_name)
# Unique values of sentiment (not referenced again in the visible cells).
twitter_sentiment_labels = twitter_train['sentiment'].unique()

In [5]:
def fit_softmax_classifier(X, y):
    """Fit and return a scikit-learn `LogisticRegression` on (X, y).

    Despite the function name, the model is configured as one-vs-rest
    (`multi_class='ovr'`) with the `liblinear` solver and an intercept term.

    Parameters
    ----------
    X : feature matrix accepted by `LogisticRegression.fit`
    y : labels aligned with the rows of `X`

    Returns
    -------
    The fitted `LogisticRegression` instance.
    """
    classifier = LogisticRegression(
        solver='liblinear',
        multi_class='ovr',
        fit_intercept=True,
    )
    # `fit` returns the estimator itself, so this hands back the fitted model.
    return classifier.fit(X, y)

In [6]:
def hf_cls_phi(text):
    """Featurize `text` as the BERT representation above the [CLS] token.

    Parameters
    ----------
    text : str
        A single raw example (presumably one tweet's text -- the caller is
        `sst.experiment` / `sst.evaluate`, which are not shown here).

    Returns
    -------
    numpy.ndarray
        The vector at position 0 of the requested layer, moved to CPU and
        converted to NumPy so it can be fed to scikit-learn.
    """
    # Subword ids with the special tokens added, so [CLS] is the first token.
    subtok_ids = vsm.hf_encode(text, bert_tokenizer, add_special_tokens=True)

    # Representations from the last layer (`layer=-1`). Per the original
    # exercise notes the result is expected to have shape (1, n, 768),
    # where n is the number of tokens.
    subtok_reps = vsm.hf_represent(subtok_ids, bert_model, layer=-1)

    # First (only) batch item, then first token ([CLS]). The original wrote
    # `[0][:][0]`; the `[:]` was a no-op and has been dropped.
    cls_rep = subtok_reps[0][0]

    return cls_rep.cpu().numpy()

In [7]:
# NOTE: `DataFrame.size` is rows x columns (total number of cells), not the
# number of examples. The classification reports further down show 1464
# validation rows, i.e. 5856 / 4 columns. Use `len(df)` or `df.shape` when a
# row count is wanted.
twitter_train.size, twitter_validate.size, twitter_test.size

(46848, 5856, 5856)

In [8]:
%%time
bert_experiment1500 = sst.experiment(
    twitter_train[:1500], # 
    hf_cls_phi,
    fit_softmax_classifier,
    assess_dataframes=[twitter_validate[:1000]],
    vectorize=False)

              precision    recall  f1-score   support

    negative      0.823     0.901     0.860       614
     neutral      0.595     0.498     0.542       207
    positive      0.748     0.648     0.695       179

    accuracy                          0.772      1000
   macro avg      0.722     0.682     0.699      1000
weighted avg      0.762     0.772     0.765      1000

CPU times: user 28min 25s, sys: 16.2 s, total: 28min 42s
Wall time: 4min 48s


In [9]:
%%time
bert_experiment3000 = sst.experiment(
    twitter_train[:3000], # 
    hf_cls_phi,
    fit_softmax_classifier,
    assess_dataframes=[twitter_validate[:1000]],
    vectorize=False)

              precision    recall  f1-score   support

    negative      0.847     0.919     0.881       614
     neutral      0.638     0.536     0.583       207
    positive      0.744     0.665     0.702       179

    accuracy                          0.794      1000
   macro avg      0.743     0.707     0.722      1000
weighted avg      0.785     0.794     0.787      1000

CPU times: user 41min 16s, sys: 22 s, total: 41min 38s
Wall time: 6min 58s


In [10]:
%%time
bert_experiment6000 = sst.experiment(
    twitter_train[:6000], # 
    hf_cls_phi,
    fit_softmax_classifier,
    assess_dataframes=[twitter_validate[:1500]],
    vectorize=False)

              precision    recall  f1-score   support

    negative      0.837     0.922     0.878       903
     neutral      0.676     0.543     0.603       300
    positive      0.798     0.697     0.744       261

    accuracy                          0.805      1464
   macro avg      0.771     0.721     0.742      1464
weighted avg      0.797     0.805     0.798      1464

CPU times: user 41min 47s, sys: 14.1 s, total: 42min 1s
Wall time: 7min 5s


In [11]:
%%time
# Final run: train on the entire training split and assess on the entire
# validation split. The returned dict (model, phi, datasets, predictions,
# scores) is reused by the test-set cells below.
bert_experiment_full = sst.experiment(
    twitter_train, # entire training split, no subsampling
    hf_cls_phi,
    fit_softmax_classifier,
    assess_dataframes=[twitter_validate],
    vectorize=False)

              precision    recall  f1-score   support

    negative      0.847     0.931     0.887       903
     neutral      0.700     0.577     0.633       300
    positive      0.826     0.709     0.763       261

    accuracy                          0.819      1464
   macro avg      0.791     0.739     0.761      1464
weighted avg      0.813     0.819     0.813      1464

CPU times: user 1h 13min 48s, sys: 24.7 s, total: 1h 14min 13s
Wall time: 12min 33s


In [12]:
# List the artifacts returned by `sst.experiment` for the full run.
bert_experiment_full.keys()

dict_keys(['model', 'phi', 'train_dataset', 'assess_datasets', 'predictions', 'metric', 'scores'])

In [13]:
# Validation score list for the full run; the single value matches the
# macro-average F1 printed in the report above.
bert_experiment_full['scores']

[0.7608528443603727]

In [14]:
# Name of the scoring function that produced 'scores'.
bert_experiment_full['metric']

'safe_macro_f1'

In [15]:
# The fitted classifier produced by `fit_softmax_classifier`.
bert_experiment_full['model']

LogisticRegression(multi_class='ovr', solver='liblinear')

# Evaluate the tweet-trained BERT-feature classifier on the test set

In [16]:
def predict_one_bert(text):
    """Predict the label of one raw text with the full-data experiment.

    The text is featurized with the experiment's stored `phi`, wrapped in a
    one-element list (the model's `predict` expects a batch), and the single
    resulting prediction is returned rather than the length-1 array.
    """
    featurize = bert_experiment_full['phi']
    classifier = bert_experiment_full['model']
    # A batch containing exactly one feature vector.
    single_example_batch = [featurize(text)]
    return classifier.predict(single_example_batch)[0]

In [17]:
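# Disabled: per-row prediction with `predict_one_bert`. The test set is scored
# in bulk with `sst.evaluate` further down instead.
# %%time
# twitter_test['prediction'] = twitter_test['text'].apply(predict_one_bert)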

In [18]:
# Development aid: re-import the local `sst` module so that edits made to
# sst.py after the first import take effect without restarting the kernel.
# NOTE(review): presumably needed because `sst.evaluate` (used next) was
# added or changed mid-session -- a fresh "Restart & Run All" would not need it.
import importlib
importlib.reload(sst)

<module 'sst' from '/mnt/c/Users/echya/Documents/XCS224U - 007 Natural Language Understanding/CS224-final-project/sst.py'>

In [19]:
%%time
bert_test = sst.evaluate(
    bert_experiment_full['model'],
    bert_experiment_full['phi'],
    assess_dataframes=[twitter_test],
    vectorizer=bert_experiment_full['assess_datasets'][0]['vectorizer'],
    vectorize=False
)

              precision    recall  f1-score   support

    negative      0.828     0.918     0.871       898
     neutral      0.720     0.570     0.636       316
    positive      0.808     0.708     0.755       250

    accuracy                          0.807      1464
   macro avg      0.785     0.732     0.754      1464
weighted avg      0.801     0.807     0.800      1464

CPU times: user 8min 14s, sys: 2.78 s, total: 8min 16s
Wall time: 1min 22s


In [20]:
# Check the container type of the predictions for the (single) test dataframe
# before converting it to a DataFrame below.
type(bert_test['predictions'][0])

numpy.ndarray

In [22]:
# Persist the raw test-set predictions. `df` is (despite its name) the NumPy
# array of predicted labels; later cells reuse it under this name.
df = bert_test['predictions'][0]
predictions_fname = 'results/BERT_predictions_on_twitter_test_airline.csv'
pd.DataFrame(data=df).to_csv(path_or_buf=predictions_fname)

In [23]:
# Persist the BERT encodings of the test set.
encoding_fname = 'results/BERT_encodings_on_twitter_test_airline.csv'
encoded_test = bert_test['assess_datasets'][0]
# Fix: the original wrote `df` (the predictions) here, so the "encodings" file
# was a duplicate of the predictions file and `encoded_test` was never used.
# NOTE(review): assumes the dataset dict stores the feature matrix under 'X'
# (it is known to be a dict -- it is indexed with 'vectorizer' above) --
# confirm the key against the local sst.py.
pd.DataFrame(encoded_test['X']).to_csv(encoding_fname)

In [24]:
# Wrap the predictions in a frame that carries the test split's own row index,
# so it can be aligned with `twitter_test` by label.
predictions_df = pd.DataFrame(df).set_index(twitter_test.index)
predictions_df

Unnamed: 0,0
11,positive
13,negative
32,negative
35,neutral
46,neutral
...,...
14605,negative
14609,negative
14621,negative
14622,negative


In [25]:
# Attach the predicted labels as a new column. `predictions_df` has a single
# column labelled 0; assigning that Series aligns on the shared row index.
twitter_test['BERT_sentiment'] = predictions_df[0]

In [26]:
# Display the test split with the new `BERT_sentiment` column next to the gold `sentiment`.
twitter_test

Unnamed: 0,tweet_id,text,sentiment,airline,BERT_sentiment
11,570289724453216256,@VirginAmerica I &lt;3 pretty graphics. so muc...,positive,Virgin America,positive
13,570287408438120448,@VirginAmerica @virginmedia I'm flying your #f...,positive,Virgin America,negative
32,570088404156698625,"@VirginAmerica help, left expensive headphones...",negative,Virgin America,negative
35,570051991277342720,Nice RT @VirginAmerica: Vibe with the moodligh...,neutral,Virgin America,neutral
46,570011341483843584,@VirginAmerica DREAM http://t.co/oA2dRfAoQ2 h...,neutral,Virgin America,neutral
...,...,...,...,...,...
14605,569592447455465472,@AmericanAir trying to book a flight on hold- ...,negative,American,negative
14609,569592148338876416,"@AmericanAir &amp; if that wasn't enough, your...",negative,American,negative
14621,569590892085915649,@AmericanAir I've been on hold for 55 mins abo...,negative,American,negative
14622,569590191758962688,I just need a place to sleep when I land witho...,negative,American,negative


In [27]:
# Save the test split together with its `BERT_sentiment` predictions.
test_predictions_fname = 'results/BERT_predictions_added_to_twitter_test_airline.csv'
twitter_test.to_csv(path_or_buf=test_predictions_fname)

In [28]:
# Test rows where the predicted label agrees with the gold label.
is_correct = twitter_test['sentiment'].eq(twitter_test['BERT_sentiment'])
correct = twitter_test.loc[is_correct]

In [29]:
# Show the correctly classified test rows.
correct

Unnamed: 0,tweet_id,text,sentiment,airline,BERT_sentiment
11,570289724453216256,@VirginAmerica I &lt;3 pretty graphics. so muc...,positive,Virgin America,positive
32,570088404156698625,"@VirginAmerica help, left expensive headphones...",negative,Virgin America,negative
35,570051991277342720,Nice RT @VirginAmerica: Vibe with the moodligh...,neutral,Virgin America,neutral
46,570011341483843584,@VirginAmerica DREAM http://t.co/oA2dRfAoQ2 h...,neutral,Virgin America,neutral
66,569976620158578688,@VirginAmerica heyyyy guyyyys.. been trying to...,negative,Virgin America,negative
...,...,...,...,...,...
14598,569593050235736064,@AmericanAir can you guys help me please?,neutral,American,neutral
14605,569592447455465472,@AmericanAir trying to book a flight on hold- ...,negative,American,negative
14609,569592148338876416,"@AmericanAir &amp; if that wasn't enough, your...",negative,American,negative
14621,569590892085915649,@AmericanAir I've been on hold for 55 mins abo...,negative,American,negative


In [30]:
# Test rows where the predicted label differs from the gold label.
is_incorrect = twitter_test['sentiment'].ne(twitter_test['BERT_sentiment'])
incorrect = twitter_test.loc[is_incorrect]

In [31]:
# Show the misclassified test rows for error analysis.
incorrect

Unnamed: 0,tweet_id,text,sentiment,airline,BERT_sentiment
13,570287408438120448,@VirginAmerica @virginmedia I'm flying your #f...,positive,Virgin America,negative
65,569982307634794497,@VirginAmerica Flight 0736 DAL to DCA 2/24 2:1...,neutral,Virgin America,negative
212,569240250192670720,@VirginAmerica twitter team. you guys killed i...,positive,Virgin America,negative
215,569236471690829825,"@VirginAmerica on iPad and iPhone, clicking th...",neutral,Virgin America,negative
244,569180623165915137,"@VirginAmerica classiq, luv Virgin America. Gr...",positive,Virgin America,neutral
...,...,...,...,...,...
14405,569620512139145216,@AmericanAir please call us back to rebook!!!...,negative,American,neutral
14501,569606135960858624,@AmericanAir I've been trying to change frm AA...,neutral,American,negative
14509,569605452230754305,@AmericanAir @BDinDallas The personal touch yo...,negative,American,positive
14521,569604400479649792,@AmericanAir how does one book a ticket online...,neutral,American,negative
