In [1]:
import model

In [2]:
# Build the predictor with no arguments
predictor = model.MaskedTokenBert()

# We haven't specified any data in particular, so it falls back to the default file
predictor.file

'KumonTaskData.csv'

Each of the following methods loads the data first, before proceeding:
``` python
predictor.tokenize_data()
predictor.model()
predictor.predict()
```

In [3]:
# To predict for one specific example instead, pass the text to .tokenize_data(example)
example = "[CLS] Who was Jim Henson ? [SEP] Jim [MASK] was a puppeteer [SEP]"
predictor.tokenize_data(example)

In [4]:
%%time
# .predict(top_n) returns a nested list: for each example, for each [MASK] in it,
# the top_n predicted tokens; top_n defaults to 5
predictor.predict(2)

CPU times: user 4.17 s, sys: 678 ms, total: 4.84 s
Wall time: 4.9 s


[[['henson', 'who']]]

In [5]:
%%time
# Let's retry with our dataset: a fresh predictor built with no arguments,
# then the predictions for the first entry only ([0])
predictor = model.MaskedTokenBert()
predictor.predict()[0]

CPU times: user 1min 23s, sys: 2.11 s, total: 1min 25s
Wall time: 25.6 s


[['cowboy', 'combat', 'tennis', 'hiking', 'football']]

In [6]:
# .data stores the text used for modelling and prediction; [0] is the first entry
predictor.data[0]

'[CLS] I remember how much she wanted a pair of football boots just like her big brother’s, and she wore them to bed every night for a week. [SEP] She wanted a pair of [MASK] boots. [SEP]'

In [7]:
# The model also saves the hidden state of the last BERT layer; [0] is the first entry
predictor.encoded_layers[0]

tensor([[[-0.2618,  0.0729, -0.6069,  ..., -0.5260,  0.5045,  0.2083],
         [ 0.0640,  0.1475, -0.1458,  ...,  0.1452,  0.4170, -0.4609],
         [ 0.0848,  0.4831,  0.0085,  ...,  0.2944, -0.1135, -0.4246],
         ...,
         [ 1.2831, -0.0580,  0.0136,  ..., -0.4035, -0.6010, -0.7674],
         [ 0.6358, -0.0315, -0.4482,  ...,  0.1102, -0.4142, -0.5990],
         [ 0.6308, -0.0439, -0.3949,  ...,  0.1086, -0.4175, -0.5768]]])

In [8]:
# .df holds the .csv as a pandas DataFrame, so we can reference its other columns
predictor.df.head(1)

Unnamed: 0,Kumon Level,Workbook Page,Average Score,Task,Masked Words
0,AII,138.0,93.0,[CLS] I remember how much she wanted a pair of...,[football]


In [9]:
# .score(show_wrong=True) scores BERT's predictions against the actual masked words
# and prints each wrong answer as it goes
scores = predictor.score(show_wrong=True)

(0,0: actual=football bert=cowboy
football was bert's #5 choice
(2,0: actual=piggy bert=dog
(2,1: actual=bank bert=rabbit
(6,1: actual=spacebus bert=day
(7,0: actual=Roboto bert=tanaka
(8,1: actual=understand bert=smile
understand was bert's #2 choice
(9,0: actual=rocked bert=caused
rocked was bert's #2 choice
(12,1: actual=started bert=was
started was bert's #2 choice
(13,1: actual=daisies bert=flowers
(17,0: actual=singing bert=still
(17,1: actual=cheerfully bert=singing
(18,0: actual=knew bert=trouble
18,1 broke... moving on
(22,0: actual=Lisa bert=she
(22,2: actual=says bert=thinks
says was bert's #2 choice
(22,3: actual=aunt bert=cousin
(23,0: actual=sliced bert=fresh
(23,1: actual=steaming bert=delicious
steaming was bert's #4 choice
(25,1: actual=snakes bert=eels
snakes was bert's #2 choice
(25,2: actual=eels bert=.
(26,0: actual=miserable bert=whole
(26,1: actual=of bert=to
of was bert's #2 choice
(26,2: actual=scraps bert=now
(27,0: actual=lumpy bert=cold
(29,0: actual=hoping 

In [10]:
# NOTE(review): imports normally live in the notebook's first cell; consider moving this one up
import pandas as pd

# Set pandas' row-display limit to predictor.size so the frame below is not truncated
# (presumably .size is the number of rows in the dataset - TODO confirm).
# NOTE(review): set_option changes the display limit globally for the rest of the session.
pd.set_option('display.max_rows', predictor.size)
# NOTE(review): 69 is hard-coded - presumably the number of rows of interest; confirm
predictor.df.head(69)

Unnamed: 0,Kumon Level,Workbook Page,Average Score,Bert Score,Task,Masked Words
0,AII,138.0,93.0,70.0,[CLS] I remember how much she wanted a pair of...,[football]
1,AII,138.0,93.0,70.0,[CLS] I remember how much she wanted a pair of...,[night]
2,AII,138.0,93.0,70.0,[CLS] And I remember the day she had a fight w...,"[piggy, bank, rabbit, nose]"
3,AII,148.0,90.5,100.0,[CLS] So he took a pillow from the bed and pul...,[apron]
4,AII,148.0,90.5,100.0,[CLS] So he took a pillow from the bed and pul...,[rubber]
5,AII,148.0,90.5,100.0,[CLS] In this disguise he set out to milk Beli...,"[looked, blinked, chewing]"
6,BI,78.0,96.9,69.0,[CLS] That morning the Wake-Up machine had rol...,"[bought, spacebus]"
7,BI,78.0,96.9,69.0,[CLS] That morning the Wake-Up machine had rol...,"[Roboto, teacher]"
8,BI,78.0,96.9,69.0,[CLS] That morning the Wake-Up machine had rol...,"[type, understand]"
9,BI,78.0,96.9,69.0,[CLS] That morning the Wake-Up machine had rol...,"[rocked, sleep]"
