# 0. Outline

This notebook is a short demonstration, for learning purposes, of applying Bidirectional Encoder Representations from Transformers (BERT) from TensorFlow Hub.
We will build a natural language processing (NLP) model to predict whether a given tweet is about a real emergency or not. The data comes from a Kaggle competition at the "Getting Started" level; see https://www.kaggle.com/c/nlp-getting-started.

The following sections start with very limited data exploration and light preprocessing, so that we can quickly dive into applying BERT. The first solution fine-tunes BERT from its pretrained weights over a few epochs of training. The second solution attaches more dense layers after BERT while freezing BERT's weights. This form of transfer learning extends the model's structure for broader tasks without heavy retraining. Finally, the third solution explores further models attached after BERT to improve the predictions.


Download datasets/models and install modules

In [1]:
# pretrained BERT is downloaded from tensorflow hub by Google
!wget https://storage.googleapis.com/tfhub-modules/tensorflow/bert_en_uncased_L-24_H-1024_A-16/1.tar.gz
!mkdir bert_model
!tar -xvf '/content/1.tar.gz'  -C '/content/bert_model'
module_path = "/content/bert_model"
# direct url if one doesn't want to save the model
#module_path = "https://tfhub.dev/tensorflow/bert_en_uncased_L-24_H-1024_A-16/1"

# download dataset provided by kaggle
!wget --quiet https://raw.githubusercontent.com/whitejetyeh/NLP-with-Disaster-Tweets/master/train.csv
!wget --quiet https://raw.githubusercontent.com/whitejetyeh/NLP-with-Disaster-Tweets/master/test.csv
!wget --quiet https://raw.githubusercontent.com/whitejetyeh/NLP-with-Disaster-Tweets/master/sample_submission.csv

# Download a text cleaning function for tweets
#ref https://www.kaggle.com/gunesevitan/nlp-with-disaster-tweets-eda-full-cleaning
!wget --quiet https://raw.githubusercontent.com/whitejetyeh/NLP-with-Disaster-Tweets/master/CleanTweets.py

# the official tokenization script created by Google
!wget --quiet https://raw.githubusercontent.com/tensorflow/models/master/official/nlp/bert/tokenization.py

# for importing sentencepiece in tokenization.py
!pip install sentencepiece

--2020-02-10 02:18:51--  https://storage.googleapis.com/tfhub-modules/tensorflow/bert_en_uncased_L-24_H-1024_A-16/1.tar.gz
Resolving storage.googleapis.com (storage.googleapis.com)... 172.217.219.128, 2607:f8b0:4001:c07::80
Connecting to storage.googleapis.com (storage.googleapis.com)|172.217.219.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1244531387 (1.2G) [application/x-tar]
Saving to: ‘1.tar.gz’


2020-02-10 02:19:00 (133 MB/s) - ‘1.tar.gz’ saved [1244531387/1244531387]

./
./variables/
./variables/variables.data-00000-of-00001
./variables/variables.index
./assets/
./assets/vocab.txt
./saved_model.pb
Collecting sentencepiece
  Downloading https://files.pythonhosted.org/packages/74/f4/2d5214cbf13d06e7cb2c20d84115ca25b53ea76fa1f0ade0e3c9749de214/sentencepiece-0.1.85-cp36-cp36m-manylinux1_x86_64.whl (1.0MB)
Installing collected packages: sentencepiece
Successfully installed sentencep

In [8]:
%tensorflow_version 2.x
import re
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ModelCheckpoint
import tensorflow_hub as hub
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import tokenization

print(tf.__version__)

# check that a GPU is attached to the Colab runtime
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))

2.1.0
Found GPU at: /device:GPU:0


# 1. minimum data exploration and preprocessing

Here, we briefly explore the dataset. More detailed analyses can be found on Kaggle, for example https://www.kaggle.com/gunesevitan/nlp-with-disaster-tweets-eda-cleaning-and-bert.

## data exploration

Note: instead of reading test.csv, whose targets are unknown, we split off a portion of train.csv as df_test for validation.
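
As a rough illustration of the hold-out split described in this note, the sketch below reproduces the split sizes with the standard library only. It assumes that train.csv has 7613 rows (the sum of the 6851 and 762 rows reported in the outputs of this section) and that the held-out share is rounded up, which to our understanding is how `train_test_split` treats a fractional `test_size`. The helper name `holdout_split` is hypothetical, and its shuffling does not reproduce scikit-learn's exact row order.

```python
import math
import random


def holdout_split(n_rows, test_size=0.1, seed=39):
    """Return (train_idx, test_idx) for a shuffled hold-out split (hypothetical helper)."""
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)    # seeded shuffle, so the split is reproducible
    n_test = math.ceil(n_rows * test_size)  # the held-out share is rounded up
    return indices[n_test:], indices[:n_test]


train_idx, test_idx = holdout_split(7613)   # 7613 rows is an assumption, see above
print(len(train_idx), len(test_idx))        # prints: 6851 762
```

Fixing the seed matters here because every later comparison between the solutions relies on the same validation rows.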

In [5]:
# read data prepared by kaggle into pandas' data frames
df_train = pd.read_csv("/content/train.csv")
#df_test = pd.read_csv("/content/test.csv")
#submission = pd.read_csv("/content/sample_submission.csv") #no need
df_train, df_test = train_test_split(df_train,
                                     test_size=0.1,
                                     random_state=39)

display(df_train.head())

print('an example of a tweet')
display(df_train.text.iloc[0])

|  | id | keyword | location | text | target |
|---|---|---|---|---|---|
| 4933 | 7027 | mayhem | PG County, MD | Tonight It's Going To Be Mayhem @ #4PlayThursd... | 0 |
| 1673 | 2416 | collide | Kansas, The Free State! ~ KC | That sounds about right. Our building will hav... | 0 |
| 5654 | 8066 | rescue | Big NorthEast Litter Box | I'm on 2 blood pressure meds and it's still pr... | 0 |
| 3532 | 5049 | eyewitness | Pennsylvania | A true #TBT Eyewitness News WBRE WYOU http://... | 0 |
| 5212 | 7444 | obliterated |  | I think I'll get obliterated tonight | 0 |


Index(['id', 'keyword', 'location', 'text', 'target'], dtype='object')

an example of a tweet


"Tonight It's Going To Be Mayhem @ #4PlayThursdays. Everybody Free w/ Text. 1716 I ST NW (18+) http://t.co/cQ7jJ6Yjfz"




BERT makes predictions based on the content of the 'text' column; we will bring in the 'keyword' and 'location' columns in the booster solution to improve BERT's predictions.

In [4]:
# basic information
print('basic infomation of training data')
display(df_train.info())
print('basic infomation of test data')
display(df_test.info())

# stats of data
print('stats of training data')
display(df_train.describe(include=['object']))
print('stats of test data')
display(df_test.describe(include=['object']))

# missing values
print('missing values in the training data')
display(df_train.isnull().sum())
print('missing values in the test data')
display(df_test.isnull().sum())

basic infomation of training data
<class 'pandas.core.frame.DataFrame'>
Int64Index: 6851 entries, 4933 to 3465
Data columns (total 5 columns):
id          6851 non-null int64
keyword     6797 non-null object
location    4577 non-null object
text        6851 non-null object
target      6851 non-null int64
dtypes: int64(2), object(3)
memory usage: 321.1+ KB


None

basic infomation of test data
<class 'pandas.core.frame.DataFrame'>
Int64Index: 762 entries, 7313 to 3305
Data columns (total 5 columns):
id          762 non-null int64
keyword     755 non-null object
location    503 non-null object
text        762 non-null object
target      762 non-null int64
dtypes: int64(2), object(3)
memory usage: 35.7+ KB


None

stats of training data


|  | keyword | location | text |
|---|---|---|---|
| count | 6797 | 4577 | 6851 |
| unique | 221 | 3066 | 6764 |
| top | fatalities | USA | 11-Year-Old Boy Charged With Manslaughter of T... |
| freq | 43 | 94 | 9 |


stats of test data


|  | keyword | location | text |
|---|---|---|---|
| count | 755 | 503 | 762 |
| unique | 212 | 440 | 760 |
| top | dust%20storm | USA | To fight bioterrorism sir. |
| freq | 12 | 10 | 3 |


missing values in the training data


id             0
keyword       54
location    2274
text           0
target         0
dtype: int64

missing values in the test data


id            0
keyword       7
location    259
text          0
target        0
dtype: int64

## Text cleaning
* The most common out-of-vocabulary (OOV) words have punctuation at the start or end. Those words don't have embeddings because of the attached punctuation, so the punctuation marks #, @, !, ?, (, ), [, ], *, %, ..., ', ., :, ; are separated from words
* Special characters that are attached to words are removed completely
* Contractions are expanded
* URLs are removed
* Character entity references are replaced with their actual symbols
* Typos and slang are corrected, and informal abbreviations are written in their long forms
* Hashtags and usernames are expanded
* Some acronyms are replaced with their expanded forms

See https://raw.githubusercontent.com/whitejetyeh/NLP-with-Disaster-Tweets/master/CleanTweets.py for details.
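
The rules above can be sketched in miniature as follows. This is not the code from CleanTweets.py: the function `mini_clean`, the `CONTRACTIONS` table and the shortened `PUNCTUATION` string are hypothetical, cover only a handful of cases, and add a whitespace-collapsing step that the notebook's `clean` function does not perform. The sketch uses `html.unescape` for character entity references, whereas the notebook uses explicit substitutions.

```python
import html
import re

# A tiny, illustrative subset of the replacement rules (hypothetical table).
CONTRACTIONS = {"It's": "It is", "don't": "do not", "can't": "cannot"}
URL_PATTERN = re.compile(r"https?://t\.co/[A-Za-z0-9]+")
PUNCTUATION = "@#!?()[]*%.:;'"


def mini_clean(tweet):
    """Apply a few of the cleaning steps listed above, in a fixed order."""
    tweet = html.unescape(tweet)                  # &amp; -> &, &gt; -> >, ...
    for short, long_form in CONTRACTIONS.items():
        tweet = tweet.replace(short, long_form)   # expand contractions
    tweet = URL_PATTERN.sub("", tweet)            # drop shortened t.co links
    for mark in PUNCTUATION:
        tweet = tweet.replace(mark, f" {mark} ")  # separate punctuation from words
    return " ".join(tweet.split())                # collapse repeated whitespace


print(mini_clean("It's Mayhem @ #4PlayThursdays. Free &amp; fun http://t.co/cQ7jJ6Yjfz"))
# prints: It is Mayhem @ # 4PlayThursdays . Free & fun
```

The order of the steps matters: URLs are removed before punctuation is separated, otherwise the `:` and `.` inside a link would be split apart and the URL pattern would no longer match.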

In [21]:
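def clean(tweet):
    """Return the tweet with special characters, contractions, slang, hashtags, URLs and punctuation normalized."""
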
    # Special characters
    tweet = re.sub(r"\x89Û_", "", tweet)
    tweet = re.sub(r"\x89ÛÒ", "", tweet)
    tweet = re.sub(r"\x89ÛÓ", "", tweet)
    tweet = re.sub(r"\x89ÛÏWhen", "When", tweet)
    tweet = re.sub(r"\x89ÛÏ", "", tweet)
    tweet = re.sub(r"China\x89Ûªs", "China's", tweet)
    tweet = re.sub(r"let\x89Ûªs", "let's", tweet)
    tweet = re.sub(r"\x89Û÷", "", tweet)
    tweet = re.sub(r"\x89Ûª", "", tweet)
    tweet = re.sub(r"\x89Û\x9d", "", tweet)
    tweet = re.sub(r"å_", "", tweet)
    tweet = re.sub(r"\x89Û¢", "", tweet)
    tweet = re.sub(r"\x89Û¢åÊ", "", tweet)
    tweet = re.sub(r"fromåÊwounds", "from wounds", tweet)
    tweet = re.sub(r"åÊ", "", tweet)
    tweet = re.sub(r"åÈ", "", tweet)
    tweet = re.sub(r"JapÌ_n", "Japan", tweet)    
    tweet = re.sub(r"Ì©", "e", tweet)
    tweet = re.sub(r"å¨", "", tweet)
    tweet = re.sub(r"SuruÌ¤", "Suruc", tweet)
    tweet = re.sub(r"åÇ", "", tweet)
    
    # Contractions
    tweet = re.sub(r"he's", "he is", tweet)
    tweet = re.sub(r"there's", "there is", tweet)
    tweet = re.sub(r"We're", "We are", tweet)
    tweet = re.sub(r"That's", "That is", tweet)
    tweet = re.sub(r"won't", "will not", tweet)
    tweet = re.sub(r"they're", "they are", tweet)
    tweet = re.sub(r"Can't", "Cannot", tweet)
    tweet = re.sub(r"wasn't", "was not", tweet)
    tweet = re.sub(r"don\x89Ûªt", "do not", tweet)
    tweet = re.sub(r"aren't", "are not", tweet)
    tweet = re.sub(r"isn't", "is not", tweet)
    tweet = re.sub(r"What's", "What is", tweet)
    tweet = re.sub(r"haven't", "have not", tweet)
    tweet = re.sub(r"hasn't", "has not", tweet)
    tweet = re.sub(r"There's", "There is", tweet)
    tweet = re.sub(r"He's", "He is", tweet)
    tweet = re.sub(r"It's", "It is", tweet)
    tweet = re.sub(r"You're", "You are", tweet)
    tweet = re.sub(r"I'M", "I am", tweet)
    tweet = re.sub(r"shouldn't", "should not", tweet)
    tweet = re.sub(r"wouldn't", "would not", tweet)
    tweet = re.sub(r"i'm", "I am", tweet)
    tweet = re.sub(r"I\x89Ûªm", "I am", tweet)
    tweet = re.sub(r"I'm", "I am", tweet)
    tweet = re.sub(r"Isn't", "is not", tweet)
    tweet = re.sub(r"Here's", "Here is", tweet)
    tweet = re.sub(r"you've", "you have", tweet)
    tweet = re.sub(r"you\x89Ûªve", "you have", tweet)
    tweet = re.sub(r"we're", "we are", tweet)
    tweet = re.sub(r"what's", "what is", tweet)
    tweet = re.sub(r"couldn't", "could not", tweet)
    tweet = re.sub(r"we've", "we have", tweet)
    tweet = re.sub(r"it\x89Ûªs", "it is", tweet)
    tweet = re.sub(r"doesn\x89Ûªt", "does not", tweet)
    tweet = re.sub(r"It\x89Ûªs", "It is", tweet)
    tweet = re.sub(r"Here\x89Ûªs", "Here is", tweet)
    tweet = re.sub(r"who's", "who is", tweet)
    tweet = re.sub(r"I\x89Ûªve", "I have", tweet)
    tweet = re.sub(r"y'all", "you all", tweet)
    tweet = re.sub(r"can\x89Ûªt", "cannot", tweet)
    tweet = re.sub(r"would've", "would have", tweet)
    tweet = re.sub(r"it'll", "it will", tweet)
    tweet = re.sub(r"we'll", "we will", tweet)
    tweet = re.sub(r"wouldn\x89Ûªt", "would not", tweet)
    tweet = re.sub(r"We've", "We have", tweet)
    tweet = re.sub(r"he'll", "he will", tweet)
    tweet = re.sub(r"Y'all", "You all", tweet)
    tweet = re.sub(r"Weren't", "Were not", tweet)
    tweet = re.sub(r"Didn't", "Did not", tweet)
    tweet = re.sub(r"they'll", "they will", tweet)
    tweet = re.sub(r"they'd", "they would", tweet)
    tweet = re.sub(r"DON'T", "DO NOT", tweet)
    tweet = re.sub(r"That\x89Ûªs", "That is", tweet)
    tweet = re.sub(r"they've", "they have", tweet)
    tweet = re.sub(r"i'd", "I would", tweet)
    tweet = re.sub(r"should've", "should have", tweet)
    tweet = re.sub(r"You\x89Ûªre", "You are", tweet)
    tweet = re.sub(r"where's", "where is", tweet)
    tweet = re.sub(r"Don\x89Ûªt", "Do not", tweet)
    tweet = re.sub(r"we'd", "we would", tweet)
    tweet = re.sub(r"i'll", "I will", tweet)
    tweet = re.sub(r"weren't", "were not", tweet)
    tweet = re.sub(r"They're", "They are", tweet)
    tweet = re.sub(r"Can\x89Ûªt", "Cannot", tweet)
    tweet = re.sub(r"you\x89Ûªll", "you will", tweet)
    tweet = re.sub(r"I\x89Ûªd", "I would", tweet)
    tweet = re.sub(r"let's", "let us", tweet)
    tweet = re.sub(r"it's", "it is", tweet)
    tweet = re.sub(r"can't", "cannot", tweet)
    tweet = re.sub(r"don't", "do not", tweet)
    tweet = re.sub(r"you're", "you are", tweet)
    tweet = re.sub(r"i've", "I have", tweet)
    tweet = re.sub(r"that's", "that is", tweet)
    tweet = re.sub(r"i'll", "I will", tweet)
    tweet = re.sub(r"doesn't", "does not", tweet)
    tweet = re.sub(r"i'd", "I would", tweet)
    tweet = re.sub(r"didn't", "did not", tweet)
    tweet = re.sub(r"ain't", "am not", tweet)
    tweet = re.sub(r"you'll", "you will", tweet)
    tweet = re.sub(r"I've", "I have", tweet)
    tweet = re.sub(r"Don't", "do not", tweet)
    tweet = re.sub(r"I'll", "I will", tweet)
    tweet = re.sub(r"I'd", "I would", tweet)
    tweet = re.sub(r"Let's", "Let us", tweet)
    tweet = re.sub(r"you'd", "You would", tweet)
    tweet = re.sub(r"It's", "It is", tweet)
    tweet = re.sub(r"Ain't", "am not", tweet)
    tweet = re.sub(r"Haven't", "Have not", tweet)
    tweet = re.sub(r"Could've", "Could have", tweet)
    tweet = re.sub(r"youve", "you have", tweet)  
    tweet = re.sub(r"donå«t", "do not", tweet)   
            
    # Character entity references
    tweet = re.sub(r"&gt;", ">", tweet)
    tweet = re.sub(r"&lt;", "<", tweet)
    tweet = re.sub(r"&amp;", "&", tweet)
    
    # Typos, slang and informal abbreviations
    tweet = re.sub(r"w/e", "whatever", tweet)
    tweet = re.sub(r"w/", "with", tweet)
    tweet = re.sub(r"USAgov", "USA government", tweet)
    tweet = re.sub(r"recentlu", "recently", tweet)
    tweet = re.sub(r"Ph0tos", "Photos", tweet)
    tweet = re.sub(r"amirite", "am I right", tweet)
    tweet = re.sub(r"exp0sed", "exposed", tweet)
    tweet = re.sub(r"<3", "love", tweet)
    tweet = re.sub(r"amageddon", "armageddon", tweet)
    tweet = re.sub(r"Trfc", "Traffic", tweet)
    tweet = re.sub(r"8/5/2015", "2015-08-05", tweet)
    tweet = re.sub(r"WindStorm", "Wind Storm", tweet)
    tweet = re.sub(r"8/6/2015", "2015-08-06", tweet)
    tweet = re.sub(r"10:38PM", "10:38 PM", tweet)
    tweet = re.sub(r"10:30pm", "10:30 PM", tweet)
    tweet = re.sub(r"16yr", "16 year", tweet)
    tweet = re.sub(r"lmao", "laughing my ass off", tweet)   
    tweet = re.sub(r"TRAUMATISED", "traumatized", tweet)
    
    # Hashtags and usernames
    tweet = re.sub(r"IranDeal", "Iran Deal", tweet)
    tweet = re.sub(r"ArianaGrande", "Ariana Grande", tweet)
    tweet = re.sub(r"camilacabello97", "camila cabello", tweet) 
    tweet = re.sub(r"RondaRousey", "Ronda Rousey", tweet)     
    tweet = re.sub(r"MTVHottest", "MTV Hottest", tweet)
    tweet = re.sub(r"TrapMusic", "Trap Music", tweet)
    tweet = re.sub(r"ProphetMuhammad", "Prophet Muhammad", tweet)
    tweet = re.sub(r"PantherAttack", "Panther Attack", tweet)
    tweet = re.sub(r"StrategicPatience", "Strategic Patience", tweet)
    tweet = re.sub(r"socialnews", "social news", tweet)
    tweet = re.sub(r"NASAHurricane", "NASA Hurricane", tweet)
    tweet = re.sub(r"onlinecommunities", "online communities", tweet)
    tweet = re.sub(r"humanconsumption", "human consumption", tweet)
    tweet = re.sub(r"Typhoon-Devastated", "Typhoon Devastated", tweet)
    tweet = re.sub(r"Meat-Loving", "Meat Loving", tweet)
    tweet = re.sub(r"facialabuse", "facial abuse", tweet)
    tweet = re.sub(r"LakeCounty", "Lake County", tweet)
    tweet = re.sub(r"BeingAuthor", "Being Author", tweet)
    tweet = re.sub(r"withheavenly", "with heavenly", tweet)
    tweet = re.sub(r"thankU", "thank you", tweet)
    tweet = re.sub(r"iTunesMusic", "iTunes Music", tweet)
    tweet = re.sub(r"OffensiveContent", "Offensive Content", tweet)
    tweet = re.sub(r"WorstSummerJob", "Worst Summer Job", tweet)
    tweet = re.sub(r"HarryBeCareful", "Harry Be Careful", tweet)
    tweet = re.sub(r"NASASolarSystem", "NASA Solar System", tweet)
    tweet = re.sub(r"animalrescue", "animal rescue", tweet)
    tweet = re.sub(r"KurtSchlichter", "Kurt Schlichter", tweet)
    tweet = re.sub(r"aRmageddon", "armageddon", tweet)
    tweet = re.sub(r"Throwingknifes", "Throwing knives", tweet)
    tweet = re.sub(r"GodsLove", "God's Love", tweet)
    tweet = re.sub(r"bookboost", "book boost", tweet)
    tweet = re.sub(r"ibooklove", "I book love", tweet)
    tweet = re.sub(r"NestleIndia", "Nestle India", tweet)
    tweet = re.sub(r"realDonaldTrump", "Donald Trump", tweet)
    tweet = re.sub(r"DavidVonderhaar", "David Vonderhaar", tweet)
    tweet = re.sub(r"CecilTheLion", "Cecil The Lion", tweet)
    tweet = re.sub(r"weathernetwork", "weather network", tweet)
    tweet = re.sub(r"withBioterrorism&use", "with Bioterrorism & use", tweet)
    tweet = re.sub(r"Hostage&2", "Hostage & 2", tweet)
    tweet = re.sub(r"GOPDebate", "GOP Debate", tweet)
    tweet = re.sub(r"RickPerry", "Rick Perry", tweet)
    tweet = re.sub(r"frontpage", "front page", tweet)
    tweet = re.sub(r"NewsInTweets", "News In Tweets", tweet)
    tweet = re.sub(r"ViralSpell", "Viral Spell", tweet)
    tweet = re.sub(r"til_now", "until now", tweet)
    tweet = re.sub(r"volcanoinRussia", "volcano in Russia", tweet)
    tweet = re.sub(r"ZippedNews", "Zipped News", tweet)
    tweet = re.sub(r"MicheleBachman", "Michele Bachman", tweet)
    tweet = re.sub(r"53inch", "53 inch", tweet)
    tweet = re.sub(r"KerrickTrial", "Kerrick Trial", tweet)
    tweet = re.sub(r"abstorm", "Alberta Storm", tweet)
    tweet = re.sub(r"Beyhive", "Beyonce hive", tweet)
    tweet = re.sub(r"IDFire", "Idaho Fire", tweet)
    tweet = re.sub(r"DETECTADO", "Detected", tweet)
    tweet = re.sub(r"RockyFire", "Rocky Fire", tweet)
    tweet = re.sub(r"Listen/Buy", "Listen / Buy", tweet)
    tweet = re.sub(r"NickCannon", "Nick Cannon", tweet)
    tweet = re.sub(r"FaroeIslands", "Faroe Islands", tweet)
    tweet = re.sub(r"yycstorm", "Calgary Storm", tweet)
    tweet = re.sub(r"IDPs:", "Internally Displaced People :", tweet)
    tweet = re.sub(r"ArtistsUnited", "Artists United", tweet)
    tweet = re.sub(r"ClaytonBryant", "Clayton Bryant", tweet)
    tweet = re.sub(r"jimmyfallon", "jimmy fallon", tweet)
    tweet = re.sub(r"justinbieber", "justin bieber", tweet)  
    tweet = re.sub(r"UTC2015", "UTC 2015", tweet)
    tweet = re.sub(r"Time2015", "Time 2015", tweet)
    tweet = re.sub(r"djicemoon", "dj icemoon", tweet)
    tweet = re.sub(r"LivingSafely", "Living Safely", tweet)
    tweet = re.sub(r"FIFA16", "Fifa 2016", tweet)
    tweet = re.sub(r"thisiswhywecanthavenicethings", "this is why we cannot have nice things", tweet)
    tweet = re.sub(r"bbcnews", "bbc news", tweet)
    tweet = re.sub(r"UndergroundRailraod", "Underground Railroad", tweet)
    tweet = re.sub(r"c4news", "c4 news", tweet)
    tweet = re.sub(r"OBLITERATION", "obliteration", tweet)
    tweet = re.sub(r"MUDSLIDE", "mudslide", tweet)
    tweet = re.sub(r"NoSurrender", "No Surrender", tweet)
    tweet = re.sub(r"NotExplained", "Not Explained", tweet)
    tweet = re.sub(r"greatbritishbakeoff", "great british bake off", tweet)
    tweet = re.sub(r"LondonFire", "London Fire", tweet)
    tweet = re.sub(r"KOTAWeather", "KOTA Weather", tweet)
    tweet = re.sub(r"LuchaUnderground", "Lucha Underground", tweet)
    tweet = re.sub(r"KOIN6News", "KOIN 6 News", tweet)
    tweet = re.sub(r"LiveOnK2", "Live On K2", tweet)
    tweet = re.sub(r"9NewsGoldCoast", "9 News Gold Coast", tweet)
    tweet = re.sub(r"nikeplus", "nike plus", tweet)
    tweet = re.sub(r"david_cameron", "David Cameron", tweet)
    tweet = re.sub(r"peterjukes", "Peter Jukes", tweet)
    tweet = re.sub(r"JamesMelville", "James Melville", tweet)
    tweet = re.sub(r"megynkelly", "Megyn Kelly", tweet)
    tweet = re.sub(r"cnewslive", "C News Live", tweet)
    tweet = re.sub(r"JamaicaObserver", "Jamaica Observer", tweet)
    tweet = re.sub(r"TweetLikeItsSeptember11th2001", "Tweet like it is september 11th 2001", tweet)
    tweet = re.sub(r"cbplawyers", "cbp lawyers", tweet)
    tweet = re.sub(r"fewmoretweets", "few more tweets", tweet)
    tweet = re.sub(r"BlackLivesMatter", "Black Lives Matter", tweet)
    tweet = re.sub(r"cjoyner", "Chris Joyner", tweet)
    tweet = re.sub(r"ENGvAUS", "England vs Australia", tweet)
    tweet = re.sub(r"ScottWalker", "Scott Walker", tweet)
    tweet = re.sub(r"MikeParrActor", "Michael Parr", tweet)
    tweet = re.sub(r"4PlayThursdays", "Foreplay Thursdays", tweet)
    tweet = re.sub(r"TGF2015", "Tontitown Grape Festival", tweet)
    tweet = re.sub(r"realmandyrain", "Mandy Rain", tweet)
    tweet = re.sub(r"GraysonDolan", "Grayson Dolan", tweet)
    tweet = re.sub(r"ApolloBrown", "Apollo Brown", tweet)
    tweet = re.sub(r"saddlebrooke", "Saddlebrooke", tweet)
    tweet = re.sub(r"TontitownGrape", "Tontitown Grape", tweet)
    tweet = re.sub(r"AbbsWinston", "Abbs Winston", tweet)
    tweet = re.sub(r"ShaunKing", "Shaun King", tweet)
    tweet = re.sub(r"MeekMill", "Meek Mill", tweet)
    tweet = re.sub(r"TornadoGiveaway", "Tornado Giveaway", tweet)
    tweet = re.sub(r"GRupdates", "GR updates", tweet)
    tweet = re.sub(r"SouthDowns", "South Downs", tweet)
    tweet = re.sub(r"braininjury", "brain injury", tweet)
    tweet = re.sub(r"auspol", "Australian politics", tweet)
    tweet = re.sub(r"PlannedParenthood", "Planned Parenthood", tweet)
    tweet = re.sub(r"calgaryweather", "Calgary Weather", tweet)
    tweet = re.sub(r"weallheartonedirection", "we all heart one direction", tweet)
    tweet = re.sub(r"edsheeran", "Ed Sheeran", tweet)
    tweet = re.sub(r"TrueHeroes", "True Heroes", tweet)
    tweet = re.sub(r"S3XLEAK", "sex leak", tweet)
    tweet = re.sub(r"ComplexMag", "Complex Magazine", tweet)
    tweet = re.sub(r"TheAdvocateMag", "The Advocate Magazine", tweet)
    tweet = re.sub(r"CityofCalgary", "City of Calgary", tweet)
    tweet = re.sub(r"EbolaOutbreak", "Ebola Outbreak", tweet)
    tweet = re.sub(r"SummerFate", "Summer Fate", tweet)
    tweet = re.sub(r"RAmag", "Royal Academy Magazine", tweet)
    tweet = re.sub(r"offers2go", "offers to go", tweet)
    tweet = re.sub(r"foodscare", "food scare", tweet)
    tweet = re.sub(r"MNPDNashville", "Metropolitan Nashville Police Department", tweet)
    tweet = re.sub(r"TfLBusAlerts", "TfL Bus Alerts", tweet)
    tweet = re.sub(r"GamerGate", "Gamer Gate", tweet)
    tweet = re.sub(r"IHHen", "Humanitarian Relief", tweet)
    tweet = re.sub(r"spinningbot", "spinning bot", tweet)
    tweet = re.sub(r"ModiMinistry", "Modi Ministry", tweet)
    tweet = re.sub(r"TAXIWAYS", "taxi ways", tweet)
    tweet = re.sub(r"Calum5SOS", "Calum Hood", tweet)
    tweet = re.sub(r"po_st", "po.st", tweet)
    tweet = re.sub(r"scoopit", "scoop.it", tweet)
    tweet = re.sub(r"UltimaLucha", "Ultima Lucha", tweet)
    tweet = re.sub(r"JonathanFerrell", "Jonathan Ferrell", tweet)
    tweet = re.sub(r"aria_ahrary", "Aria Ahrary", tweet)
    tweet = re.sub(r"rapidcity", "Rapid City", tweet)
    tweet = re.sub(r"OutBid", "outbid", tweet)
    tweet = re.sub(r"lavenderpoetrycafe", "lavender poetry cafe", tweet)
    tweet = re.sub(r"EudryLantiqua", "Eudry Lantiqua", tweet)
    tweet = re.sub(r"15PM", "15 PM", tweet)
    tweet = re.sub(r"OriginalFunko", "Funko", tweet)
    tweet = re.sub(r"rightwaystan", "Richard Tan", tweet)
    tweet = re.sub(r"CindyNoonan", "Cindy Noonan", tweet)
    tweet = re.sub(r"RT_America", "RT America", tweet)
    tweet = re.sub(r"narendramodi", "Narendra Modi", tweet)
    tweet = re.sub(r"BakeOffFriends", "Bake Off Friends", tweet)
    tweet = re.sub(r"TeamHendrick", "Hendrick Motorsports", tweet)
    tweet = re.sub(r"alexbelloli", "Alex Belloli", tweet)
    tweet = re.sub(r"itsjustinstuart", "Justin Stuart", tweet)
    tweet = re.sub(r"gunsense", "gun sense", tweet)
    tweet = re.sub(r"DebateQuestionsWeWantToHear", "debate questions we want to hear", tweet)
    tweet = re.sub(r"RoyalCarribean", "Royal Caribbean", tweet)
    tweet = re.sub(r"samanthaturne19", "Samantha Turner", tweet)
    tweet = re.sub(r"JonVoyage", "Jon Stewart", tweet)
    tweet = re.sub(r"renew911health", "renew 911 health", tweet)
    tweet = re.sub(r"SuryaRay", "Surya Ray", tweet)
    tweet = re.sub(r"pattonoswalt", "Patton Oswalt", tweet)
    tweet = re.sub(r"minhazmerchant", "Minhaz Merchant", tweet)
    tweet = re.sub(r"TLVFaces", "Israel Diaspora Coalition", tweet)
    tweet = re.sub(r"pmarca", "Marc Andreessen", tweet)
    tweet = re.sub(r"pdx911", "Portland Police", tweet)
    tweet = re.sub(r"jamaicaplain", "Jamaica Plain", tweet)
    tweet = re.sub(r"Japton", "Arkansas", tweet)
    tweet = re.sub(r"RouteComplex", "Route Complex", tweet)
    tweet = re.sub(r"INSubcontinent", "Indian Subcontinent", tweet)
    tweet = re.sub(r"NJTurnpike", "New Jersey Turnpike", tweet)
    tweet = re.sub(r"Politifiact", "PolitiFact", tweet)
    tweet = re.sub(r"Hiroshima70", "Hiroshima", tweet)
    tweet = re.sub(r"GMMBC", "Greater Mt Moriah Baptist Church", tweet)
    tweet = re.sub(r"versethe", "verse the", tweet)
    tweet = re.sub(r"TubeStrike", "Tube Strike", tweet)
    tweet = re.sub(r"MissionHills", "Mission Hills", tweet)
    tweet = re.sub(r"ProtectDenaliWolves", "Protect Denali Wolves", tweet)
    tweet = re.sub(r"NANKANA", "Nankana", tweet)
    tweet = re.sub(r"SAHIB", "Sahib", tweet)
    tweet = re.sub(r"PAKPATTAN", "Pakpattan", tweet)
    tweet = re.sub(r"Newz_Sacramento", "News Sacramento", tweet)
    tweet = re.sub(r"gofundme", "go fund me", tweet)
    tweet = re.sub(r"pmharper", "Stephen Harper", tweet)
    tweet = re.sub(r"IvanBerroa", "Ivan Berroa", tweet)
    tweet = re.sub(r"LosDelSonido", "Los Del Sonido", tweet)
    tweet = re.sub(r"bancodeseries", "banco de series", tweet)
    tweet = re.sub(r"timkaine", "Tim Kaine", tweet)
    tweet = re.sub(r"IdentityTheft", "Identity Theft", tweet)
    tweet = re.sub(r"AllLivesMatter", "All Lives Matter", tweet)
    tweet = re.sub(r"mishacollins", "Misha Collins", tweet)
    tweet = re.sub(r"BillNeelyNBC", "Bill Neely", tweet)
    tweet = re.sub(r"BeClearOnCancer", "be clear on cancer", tweet)
    tweet = re.sub(r"Kowing", "Knowing", tweet)
    tweet = re.sub(r"ScreamQueens", "Scream Queens", tweet)
    tweet = re.sub(r"AskCharley", "Ask Charley", tweet)
    tweet = re.sub(r"BlizzHeroes", "Heroes of the Storm", tweet)
    tweet = re.sub(r"BradleyBrad47", "Bradley Brad", tweet)
    tweet = re.sub(r"HannaPH", "Typhoon Hanna", tweet)
    tweet = re.sub(r"meinlcymbals", "MEINL Cymbals", tweet)
    tweet = re.sub(r"Ptbo", "Peterborough", tweet)
    tweet = re.sub(r"cnnbrk", "CNN Breaking News", tweet)
    tweet = re.sub(r"IndianNews", "Indian News", tweet)
    tweet = re.sub(r"savebees", "save bees", tweet)
    tweet = re.sub(r"GreenHarvard", "Green Harvard", tweet)
    tweet = re.sub(r"StandwithPP", "Stand with planned parenthood", tweet)
    tweet = re.sub(r"hermancranston", "Herman Cranston", tweet)
    tweet = re.sub(r"WMUR9", "WMUR-TV", tweet)
    tweet = re.sub(r"RockBottomRadFM", "Rock Bottom Radio", tweet)
    tweet = re.sub(r"ameenshaikh3", "Ameen Shaikh", tweet)
    tweet = re.sub(r"ProSyn", "Project Syndicate", tweet)
    tweet = re.sub(r"Daesh", "ISIS", tweet)
    tweet = re.sub(r"s2g", "swear to god", tweet)
    tweet = re.sub(r"listenlive", "listen live", tweet)
    tweet = re.sub(r"CDCgov", "Centers for Disease Control and Prevention", tweet)
    tweet = re.sub(r"FoxNew", "Fox News", tweet)
    tweet = re.sub(r"CBSBigBrother", "Big Brother", tweet)
    tweet = re.sub(r"JulieDiCaro", "Julie DiCaro", tweet)
    tweet = re.sub(r"theadvocatemag", "The Advocate Magazine", tweet)
    tweet = re.sub(r"RohnertParkDPS", "Rohnert Park Police Department", tweet)
    tweet = re.sub(r"THISIZBWRIGHT", "Bonnie Wright", tweet)
    tweet = re.sub(r"Popularmmos", "Popular MMOs", tweet)
    tweet = re.sub(r"WildHorses", "Wild Horses", tweet)
    tweet = re.sub(r"FantasticFour", "Fantastic Four", tweet)
    tweet = re.sub(r"HORNDALE", "Horndale", tweet)
    tweet = re.sub(r"PINER", "Piner", tweet)
    tweet = re.sub(r"BathAndNorthEastSomerset", "Bath and North East Somerset", tweet)
    tweet = re.sub(r"thatswhatfriendsarefor", "that is what friends are for", tweet)
    tweet = re.sub(r"residualincome", "residual income", tweet)
    tweet = re.sub(r"YahooNewsDigest", "Yahoo News Digest", tweet)
    tweet = re.sub(r"MalaysiaAirlines", "Malaysia Airlines", tweet)
    tweet = re.sub(r"AmazonDeals", "Amazon Deals", tweet)
    tweet = re.sub(r"MissCharleyWebb", "Charley Webb", tweet)
    tweet = re.sub(r"shoalstraffic", "shoals traffic", tweet)
    tweet = re.sub(r"GeorgeFoster72", "George Foster", tweet)
    tweet = re.sub(r"pop2015", "pop 2015", tweet)
    tweet = re.sub(r"_PokemonCards_", "Pokemon Cards", tweet)
    tweet = re.sub(r"DianneG", "Dianne Gallagher", tweet)
    tweet = re.sub(r"KashmirConflict", "Kashmir Conflict", tweet)
    tweet = re.sub(r"BritishBakeOff", "British Bake Off", tweet)
    tweet = re.sub(r"FreeKashmir", "Free Kashmir", tweet)
    tweet = re.sub(r"mattmosley", "Matt Mosley", tweet)
    tweet = re.sub(r"BishopFred", "Bishop Fred", tweet)
    tweet = re.sub(r"EndConflict", "End Conflict", tweet)
    tweet = re.sub(r"EndOccupation", "End Occupation", tweet)
    tweet = re.sub(r"UNHEALED", "unhealed", tweet)
    tweet = re.sub(r"CharlesDagnall", "Charles Dagnall", tweet)
    tweet = re.sub(r"Latestnews", "Latest news", tweet)
    tweet = re.sub(r"KindleCountdown", "Kindle Countdown", tweet)
    tweet = re.sub(r"NoMoreHandouts", "No More Handouts", tweet)
    tweet = re.sub(r"datingtips", "dating tips", tweet)
    tweet = re.sub(r"charlesadler", "Charles Adler", tweet)
    tweet = re.sub(r"twia", "Texas Windstorm Insurance Association", tweet)
    tweet = re.sub(r"txlege", "Texas Legislature", tweet)
    tweet = re.sub(r"WindstormInsurer", "Windstorm Insurer", tweet)
           
    # Urls
    tweet = re.sub(r"https?://t\.co/[A-Za-z0-9]+", "", tweet)
        
    # Words with punctuations and special characters
    punctuations = '@#!?+&*[]-%.:/();$=><|{}^' + "'`"
    for p in punctuations:
        tweet = tweet.replace(p, f' {p} ')
        
    # ... and ..
    tweet = tweet.replace('...', ' ... ')
    if '...' not in tweet:
        tweet = tweet.replace('..', ' ... ')      
        
    # Acronyms
    tweet = re.sub(r"MH370", "Malaysia Airlines Flight 370", tweet)
    tweet = re.sub(r"mÌ¼sica", "music", tweet)
    tweet = re.sub(r"okwx", "Oklahoma City Weather", tweet)
    tweet = re.sub(r"arwx", "Arkansas Weather", tweet)    
    tweet = re.sub(r"gawx", "Georgia Weather", tweet)  
    tweet = re.sub(r"scwx", "South Carolina Weather", tweet)  
    tweet = re.sub(r"cawx", "California Weather", tweet)
    tweet = re.sub(r"tnwx", "Tennessee Weather", tweet)
    tweet = re.sub(r"azwx", "Arizona Weather", tweet)    
    tweet = re.sub(r"wordpressdotcom", "wordpress", tweet)    
    tweet = re.sub(r"usNWSgov", "United States National Weather Service", tweet)
    tweet = re.sub(r"Suruc", "Sanliurfa", tweet)   
    
    # Grouping same words without embeddings
    tweet = re.sub(r"Bestnaijamade", "bestnaijamade", tweet)
    tweet = re.sub(r"SOUDELOR", "Soudelor", tweet)
    
    return tweet

df_train['text_cleaned'] = df_train['text'].apply(clean)
df_test['text_cleaned'] = df_test['text'].apply(clean)

print('a tweet before cleaning')
display(df_train.text.iloc[0])
print('a tweet after cleaning')
display(df_train.text_cleaned.iloc[0])

a tweet before cleaning


"Tonight It's Going To Be Mayhem @ #4PlayThursdays. Everybody Free w/ Text. 1716 I ST NW (18+) http://t.co/cQ7jJ6Yjfz"

a tweet after cleaning


'Tonight It is Going To Be Mayhem  @   # Foreplay Thursdays .  Everybody Free with Text .  1716 I ST NW  ( 18 +  )  '

In [19]:
# clean is a handmade text cleaning function for tweets
#Reference: https://www.kaggle.com/gunesevitan/nlp-with-disaster-tweets-eda-full-cleaning
from CleanTweets import clean
df_train['text_cleaned'] = df_train['text'].apply(lambda s : clean(s))
df_test['text_cleaned'] = df_test['text'].apply(lambda s : clean(s))

print('a tweet before cleaning')
display(df_train.text.iloc[0])
print('a tweet after cleaning')
display(df_train.text_cleaned.iloc[0])

NameError: ignored

# 2. finetune a pretrained BERT model
Here, we apply the plain BERT model. The first step is to convert the text into the three input arrays BERT expects (token ids, padding mask, segment ids) with `bert_encode` and `tokenizer`. The model defined in `build_model` is then the pretrained BERT (the 24-layer `bert_en_uncased_L-24_H-1024_A-16` module downloaded above) followed by a one-node sigmoid output layer, which predicts 1 for a real disaster and 0 for a non-disaster.

This part is forked from another fine kernel on kaggle, 
https://www.kaggle.com/xhlulu/disaster-nlp-keras-bert-using-tfhub.
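
To make the three-array input format concrete, here is a minimal sketch that mimics what `bert_encode` produces. It assumes a whitespace tokenizer and a made-up seven-entry vocabulary; `toy_vocab` and `toy_encode` are illustrative names, not part of BERT or of this notebook, and the real WordPiece vocabulary assigns different ids.

```python
# Hypothetical toy vocabulary for illustration only; ids differ in the real BERT vocab.
toy_vocab = {"[PAD]": 0, "[CLS]": 1, "[SEP]": 2,
             "forest": 3, "fire": 4, "near": 5, "town": 6}

def toy_encode(text, max_len=8):
    """Return (token_ids, mask, segment_ids), each a list of length max_len."""
    # a whitespace split stands in for the WordPiece tokenizer
    words = text.lower().split()[:max_len - 2]   # leave room for [CLS] and [SEP]
    sequence = ["[CLS]"] + words + ["[SEP]"]
    pad_len = max_len - len(sequence)

    token_ids = [toy_vocab[w] for w in sequence] + [toy_vocab["[PAD]"]] * pad_len
    mask = [1] * len(sequence) + [0] * pad_len    # 1 = real token, 0 = padding
    segment_ids = [0] * max_len                   # single-sentence input: all zeros
    return token_ids, mask, segment_ids

token_ids, mask, segment_ids = toy_encode("Forest fire near town")
print(token_ids)    # [1, 3, 4, 5, 6, 2, 0, 0]
print(mask)         # [1, 1, 1, 1, 1, 1, 0, 0]
print(segment_ids)  # [0, 0, 0, 0, 0, 0, 0, 0]
```

The mask tells BERT which positions hold real tokens, and the segment ids stay at zero because each tweet is treated as a single sentence.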

## define model

In [0]:
def bert_encode(texts, tokenizer, max_len=512):
    all_tokens = []
    all_masks = []
    all_segments = []
    
    for text in texts:
        text = tokenizer.tokenize(text)
            
        text = text[:max_len-2]
        input_sequence = ["[CLS]"] + text + ["[SEP]"]
        pad_len = max_len - len(input_sequence)
        
        tokens = tokenizer.convert_tokens_to_ids(input_sequence)
        tokens += [0] * pad_len
        pad_masks = [1] * len(input_sequence) + [0] * pad_len
        segment_ids = [0] * max_len
        
        all_tokens.append(tokens)
        all_masks.append(pad_masks)
        all_segments.append(segment_ids)
    
    return [np.array(all_tokens), np.array(all_masks), np.array(all_segments)]

def build_model(bert_layer, max_len=512):
    input_word_ids = Input(shape=(max_len,), dtype=tf.int32, name="input_word_ids")
    input_mask = Input(shape=(max_len,), dtype=tf.int32, name="input_mask")
    segment_ids = Input(shape=(max_len,), dtype=tf.int32, name="segment_ids")

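    # the hub layer returns (pooled_output, sequence_output); keep the per-token output
    _, sequence_output = bert_layer([input_word_ids, input_mask, segment_ids])
    # hidden state at position 0, i.e. the [CLS] token, represents the whole tweet
    clf_output = sequence_output[:, 0, :]
    # single sigmoid unit: probability that the tweet describes a real disaster
    out = Dense(1, activation='sigmoid')(clf_output)
    
    model = Model(inputs=[input_word_ids, input_mask, segment_ids], outputs=out)
    model.compile(Adam(learning_rate=2e-6), loss='binary_crossentropy', metrics=['accuracy'])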
    
    return model
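
The slice `sequence_output[:, 0, :]` in `build_model` keeps one vector per tweet: the hidden state at position 0, where `bert_encode` placed `[CLS]`. The NumPy sketch below shows the effect on a stand-in array with made-up shapes (batch of 2, 4 tokens, hidden size 3 instead of 1024); the values are arbitrary.

```python
import numpy as np

# Stand-in for BERT's sequence_output: shape (batch, seq_len, hidden)
sequence_output = np.arange(2 * 4 * 3).reshape(2, 4, 3)

# Keep only position 0 of every sequence, i.e. the [CLS] vector
clf_output = sequence_output[:, 0, :]

print(clf_output.shape)  # (2, 3)
print(clf_output)        # [[ 0  1  2]
                         #  [12 13 14]]
```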

## initialize BERT and encode text data

In [0]:
# initialize BERT with trainable weights
bert_layer = hub.KerasLayer(module_path, trainable=True)

# establish tokenizer with bert_layer
vocab_file = bert_layer.resolved_object.vocab_file.asset_path.numpy()
do_lower_case = bert_layer.resolved_object.do_lower_case.numpy()
tokenizer = tokenization.FullTokenizer(vocab_file, do_lower_case)

# encode cleaned text data for BERT to read
# BERT accepts at most 512 tokens per input; a shorter max_len of 160 is used here
train_input = bert_encode(df_train.text_cleaned.values, tokenizer, max_len=160)
test_input = bert_encode(df_test.text_cleaned.values, tokenizer, max_len=160)
train_labels = df_train.target.values
test_labels = df_test.target.values

## train model and predict

In [24]:
%%time

train_model = False
if train_model:
    # initialize model
    n_epoch = 3
    model = build_model(bert_layer, max_len=160)
    checkpoint_path = "/content/bert_model.ckpt"
    display(model.summary())

    # start training (about 30 mins)
    # Create a callback that saves the model's weights
    cp_callback = ModelCheckpoint(filepath=checkpoint_path,
                                  save_weights_only=True,
                                  save_best_only=True,
                                  verbose=1)
    model.fit(train_input, train_labels,
              validation_split=0.1,
              epochs=n_epoch,
              batch_size=16,
              callbacks=[cp_callback])

    # predict df_test (validation data from train.csv)
    predictions = model.predict(test_input).round().astype(int)
    print(classification_report(test_labels, predictions, labels=[0, 1]))

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_word_ids (InputLayer)     [(None, 160)]        0                                            
__________________________________________________________________________________________________
input_mask (InputLayer)         [(None, 160)]        0                                            
__________________________________________________________________________________________________
segment_ids (InputLayer)        [(None, 160)]        0                                            
__________________________________________________________________________________________________
keras_layer (KerasLayer)        [(None, 1024), (None 335141889   input_word_ids[0][0]             
                                                                 input_mask[0][0]             

None

Train on 6165 samples, validate on 686 samples
Epoch 1/3
Epoch 00001: val_loss improved from inf to 0.43643, saving model to /content/bert_model.ckpt
Epoch 2/3
Epoch 00002: val_loss improved from 0.43643 to 0.42187, saving model to /content/bert_model.ckpt
Epoch 3/3
Epoch 00003: val_loss did not improve from 0.42187
              precision    recall  f1-score   support

           0       0.85      0.87      0.86       451
           1       0.81      0.77      0.79       311

    accuracy                           0.83       762
   macro avg       0.83      0.82      0.83       762
weighted avg       0.83      0.83      0.83       762

CPU times: user 16min 53s, sys: 11min 42s, total: 28min 35s
Wall time: 35min 16s
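
The `.round().astype(int)` call in the training cell turns the sigmoid outputs into class labels before they reach `classification_report`, which expects discrete labels. The sketch below spells out an equivalent 0.5 threshold in plain Python; the probabilities and the name `to_labels` are made up for illustration.

```python
def to_labels(probabilities, threshold=0.5):
    """Map sigmoid outputs in [0, 1] to hard 0/1 labels."""
    return [1 if p > threshold else 0 for p in probabilities]

probs = [0.93, 0.12, 0.51, 0.49]   # made-up model outputs
labels = to_labels(probs)
print(labels)  # [1, 0, 1, 0]
```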


## load trained model for inference

In [0]:
%%time
load_model = False
if load_model:
    # load trained model for prediction
    model_path = "/content/bert_model.ckpt"

    # initialize model with the same structure as the loaded one
    model = build_model(bert_layer, max_len=160)
    display(model.summary())
    # Load the previously saved weights
    model.load_weights(model_path)

    # predict df_test (validation data from train.csv)
    predictions = model.predict(test_input).round().astype(int)
    print(classification_report(test_labels, predictions, labels=[0, 1]))

# 3. transfer learning with BERT

## define model

In [0]:
'''Transfer learning of bert'''
def build_ext_model(module_path, max_len=512):
    input_word_ids = Input(shape=(max_len,), dtype=tf.int32, name="input_word_ids")
    input_mask = Input(shape=(max_len,), dtype=tf.int32, name="input_mask")
    segment_ids = Input(shape=(max_len,), dtype=tf.int32, name="segment_ids")

    # load BERT with frozen weights
    bert_layer = hub.KerasLayer(module_path, trainable=False)

    _, sequence_output = bert_layer([input_word_ids, input_mask, segment_ids])

    x = tf.keras.layers.GlobalAveragePooling1D()(sequence_output)
    x = tf.keras.layers.Dropout(0.2)(x)
    # dense layers stacked after bert
    x = tf.keras.layers.Dense(400, activation="relu")(x)
    x = tf.keras.layers.Dropout(0.1)(x)
    x = tf.keras.layers.Dense(200, activation="relu")(x)
    x = tf.keras.layers.Dropout(0.1)(x)
    x = tf.keras.layers.Dense(100, activation="relu")(x)
    x = tf.keras.layers.Dropout(0.1)(x)
    out = Dense(1, activation='sigmoid', name="dense_output")(x)
    
    model = Model(inputs=[input_word_ids, input_mask, segment_ids], outputs=out)
    model.compile(Adam(learning_rate=3e-5), loss='binary_crossentropy', metrics=['accuracy'])
    
    return model
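
Because `trainable=False` freezes BERT, only the dense head is updated during training. As a rough check, the number of trainable weights in that head can be worked out by hand, assuming the hidden size of 1024 implied by the module name `bert_en_uncased_L-24_H-1024_A-16` (the pooling and dropout layers carry no weights). The helper name `dense_params` is illustrative.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

# BERT hidden size -> Dense(400) -> Dense(200) -> Dense(100) -> output node
layer_sizes = [1024, 400, 200, 100, 1]
head_params = sum(dense_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:]))
print(head_params)  # 510401
```

That is roughly half a million trainable weights, against the 335,141,889 parameters that the model summary in section 2 reports for the BERT layer itself.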

## train model and predict

In [0]:
%%time

train_model = False
if train_model:
    # initialize model
    n_epoch = 5
    model = build_ext_model(module_path, max_len=160)    
    checkpoint_path = "/content/bert_ext_model.ckpt"
    display(model.summary())

    # start training (about 30 mins)
    # Create a callback that saves the model's weights
    cp_callback = ModelCheckpoint(filepath=checkpoint_path,
                                  save_weights_only=True,
                                  save_best_only=True,
                                  verbose=1)
    model.fit(train_input, train_labels,
              validation_split=0.1,
              epochs=n_epoch,
              batch_size=16,
              callbacks=[cp_callback])

    # predict df_test (validation data from train.csv)
    predictions = model.predict(test_input).round().astype(int)
    print(classification_report(test_labels, predictions, labels=[0, 1]))

              precision    recall  f1-score   support

           0       0.87      0.87      0.87       294
           1       0.82      0.82      0.82       206

    accuracy                           0.85       500
   macro avg       0.85      0.85      0.85       500
weighted avg       0.85      0.85      0.85       500

CPU times: user 14.3 s, sys: 9.03 s, total: 23.3 s
Wall time: 50.8 s


## load trained model for inference

In [0]:
%%time
load_model = False
if load_model:
    # load trained model for prediction
    model_path = "/content/bert_ext_model.ckpt"

    # initialize model with the same structure as the loaded one
    model = build_ext_model(module_path, max_len=160)
    display(model.summary())
    # Load the previously saved weights
    model.load_weights(model_path)

    # predict df_test (validation data from train.csv)
    predictions = model.predict(test_input).round().astype(int)
    print(classification_report(test_labels, predictions, labels=[0, 1]))

# MY GOOGLE DRIVE

In [29]:
!ls -l /content/

total 5144792
-rw-r--r-- 1 root root 1244531387 Oct 22 10:28 1.tar.gz
drwxr-x--T 4 root root       4096 Oct 22 08:20 bert_model
-rw-r--r-- 1 root root     410812 Feb 10 03:01 bert_model.ckpt.data-00000-of-00002
-rw-r--r-- 1 root root 4021714981 Feb 10 03:02 bert_model.ckpt.data-00001-of-00002
-rw-r--r-- 1 root root      90631 Feb 10 03:02 bert_model.ckpt.index
-rw-r--r-- 1 root root         87 Feb 10 03:02 checkpoint
-rw-r--r-- 1 root root      23231 Feb 10 02:34 CleanTweets.py
drwxr-xr-x 2 root root       4096 Feb 10 02:29 __pycache__
drwxr-xr-x 1 root root       4096 Feb  5 18:37 sample_data
-rw-r--r-- 1 root root      22746 Feb 10 02:19 sample_submission.csv
-rw-r--r-- 1 root root     420783 Feb 10 02:19 test.csv
-rw-r--r-- 1 root root      16775 Feb 10 02:19 tokenization.py
-rw-r--r-- 1 root root     987712 Feb 10 02:19 train.csv


In [27]:
from google.colab import drive
drive.mount('/gdrive')

download_weight = True
if download_weight:
    !cp -r /gdrive/'My Drive'/kaggletest/. /content/

upload_bert = False
if upload_bert:
    !cp -r /content/bert_model.ckpt* /gdrive/'My Drive'/kaggletest/
    #!cp -r /content/bert_ext_model.ckpt* /gdrive/'My Drive'/kaggletest/

Drive already mounted at /gdrive; to attempt to forcibly remount, call drive.mount("/gdrive", force_remount=True).


# 4. subsequent booster after BERT

## encode categorical features

In [0]:
'''encode categorical features'''
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
from collections import defaultdict

cat_df_train = df_train[['keyword','location']].copy()
cat_df_test = df_test[['keyword','location']].copy()
# handle missing values
for cat in ['keyword','location']:
    cat_df_train[cat] = cat_df_train[cat].fillna('NaN')
    cat_df_test[cat] = cat_df_test[cat].fillna('NaN')

# initialize the encoder
le = defaultdict(LabelEncoder)
# fit the encoder and transform the training set
fit_cat_df_train = cat_df_train.apply(lambda x: le[x.name].fit_transform(x))
# normalize the encoding
scaler = MinMaxScaler(feature_range=(0, 1))
fit_cat_df_train = scaler.fit_transform(fit_cat_df_train)
fit_cat_df_train = pd.DataFrame(fit_cat_df_train,
                                index=cat_df_train.index,
                                columns=['keyword','location'])


# Replace test set labels unseen in the training set
for cat in ['keyword','location']:
    labels_train = cat_df_train[cat].unique().tolist()
    replacement_label = cat_df_train[cat].mode()[0]
    cat_df_test.loc[~cat_df_test[cat].isin(labels_train), cat] = replacement_label

# Using the dictionary to label future data
fit_cat_df_test = cat_df_test.apply(lambda x: le[x.name].transform(x))
# normalize the encoding
fit_cat_df_test = scaler.transform(fit_cat_df_test)
fit_cat_df_test = pd.DataFrame(fit_cat_df_test,
                               index=cat_df_test.index,
                               columns=['keyword','location'])
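
The encoding above has three moving parts: map each category to an integer, rescale the integers to [0, 1], and replace test categories never seen in training with the most frequent training category. The plain-Python sketch below mirrors that logic on a made-up keyword column. It is not a drop-in replacement for `LabelEncoder` and `MinMaxScaler`, and all names and values are illustrative.

```python
from collections import Counter

train = ["fire", "flood", "fire", "NaN", "quake"]   # made-up training column
test = ["flood", "tsunami", "NaN"]                  # "tsunami" is unseen in training

# integer code per category, assigned in sorted order
classes = sorted(set(train))
code = {label: i for i, label in enumerate(classes)}

# the most frequent training category replaces unseen test categories
mode = Counter(train).most_common(1)[0][0]
test_clean = [label if label in code else mode for label in test]

# min-max scaling of the codes 0..n-1 to the range [0, 1]
scale = len(classes) - 1
train_encoded = [code[label] / scale for label in train]
test_encoded = [code[label] / scale for label in test_clean]

print(classes)     # ['NaN', 'fire', 'flood', 'quake']
print(test_clean)  # ['flood', 'fire', 'NaN']
print(test_encoded)
```

Note that the integer codes follow alphabetical order, so the numeric distance between two keywords carries no meaning of its own; tree-based models such as the boosters used next can still split on these values.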


## aggregate bert predictions with extra features in subsequent models

In [32]:
%%time
'''aggregate previous predictions with extra features(processed)'''

# preliminary prediction from BERT
bert_train_pred = model.predict(train_input)
bert_predict_df = pd.DataFrame(bert_train_pred, 
                               index=df_train.index,
                               columns=['target'])
# more features to be considered
boosting_input = pd.concat([bert_predict_df,fit_cat_df_train],axis=1)
    
bert_test_pred = model.predict(test_input)
bert_test_predict_df = pd.DataFrame(bert_test_pred, 
                                    index=df_test.index,
                                    columns=['target'])
# more features to be considered
test_boosting_input = pd.concat([bert_test_predict_df,fit_cat_df_test],axis=1)

NameError: ignored

In [39]:
%%time
'''HistGradientBoosting Classifier (lightGBM inspired)'''
# To use this experimental feature, we need to explicitly ask for it:
from sklearn.experimental import enable_hist_gradient_boosting  # noqa
from sklearn.ensemble import HistGradientBoostingClassifier
HGB_classifier = HistGradientBoostingClassifier()
HGB_classifier.fit(boosting_input, train_labels)

predictions = HGB_classifier.predict(test_boosting_input).round().astype(int)
print(classification_report(test_labels, predictions, labels=[0, 1]))
print(HGB_classifier)

              precision    recall  f1-score   support

           0       0.85      0.87      0.86       451
           1       0.80      0.78      0.79       311

    accuracy                           0.83       762
   macro avg       0.83      0.82      0.82       762
weighted avg       0.83      0.83      0.83       762

HistGradientBoostingClassifier(l2_regularization=0.0, learning_rate=0.1,
                               loss='auto', max_bins=255, max_depth=None,
                               max_iter=100, max_leaf_nodes=31,
                               min_samples_leaf=20, n_iter_no_change=None,
                               random_state=None, scoring=None, tol=1e-07,
                               validation_fraction=0.1, verbose=0,
                               warm_start=False)
CPU times: user 652 ms, sys: 22.4 ms, total: 675 ms
Wall time: 356 ms


In [40]:
%%time
'''XGBoost'''
import xgboost as xgb
XGB_classifier = xgb.XGBClassifier()
XGB_classifier.fit(boosting_input, train_labels)

predictions = XGB_classifier.predict(test_boosting_input).round().astype(int)
print(classification_report(test_labels, predictions, labels=[0, 1]))
print(XGB_classifier)

              precision    recall  f1-score   support

           0       0.85      0.87      0.86       451
           1       0.80      0.78      0.79       311

    accuracy                           0.83       762
   macro avg       0.83      0.82      0.82       762
weighted avg       0.83      0.83      0.83       762

XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
              colsample_bynode=1, colsample_bytree=1, gamma=0,
              learning_rate=0.1, max_delta_step=0, max_depth=3,
              min_child_weight=1, missing=None, n_estimators=100, n_jobs=1,
              nthread=None, objective='binary:logistic', random_state=0,
              reg_alpha=0, reg_lambda=1, scale_pos_weight=1, seed=None,
              silent=None, subsample=1, verbosity=1)
CPU times: user 276 ms, sys: 55.1 ms, total: 331 ms
Wall time: 1.01 s
