In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os

os.getcwd()

'/content'

In [None]:
os.chdir("/content/drive/MyDrive/Colab Notebooks/qbe/candidate_test")

In [None]:
os.listdir()

['restaurants_train.csv',
 'restaurants_val.csv',
 'requirements.txt',
 'restaurants_holdout.csv',
 'README.pdf',
 '.~lock.restaurants_train.csv#',
 'entity_extraction.ipynb',
 '.ipynb_checkpoints',
 'rest_hold_preds_batchsize10.xlsx',
 'b_10_ss5',
 'b_10_cosine_sched',
 'bert_bs1',
 'b_1_ss50_e50',
 'b_1_ss50_e50_postag',
 'bert_bs2',
 'rest_hold_preds_batchsize10_20epochs.xlsx',
 'b_10_ss10_e50_swap']

In [None]:
!pip install -r requirements.txt



In [None]:
import pandas as pd

train = pd.read_csv('restaurants_train.csv')

val = pd.read_csv('restaurants_val.csv')

holdout = pd.read_csv('restaurants_holdout.csv')

train = train.fillna('')

val = val.fillna('')

holdout = holdout.fillna('')

In [None]:
import random
def get_replacement_label(label, label_list):
    
    cands = [x for x in label_list if x != label]
    
    return cands[random.randint(0, len(cands)-1)]

def swap_labels(txt, label, replacement):
    
    return txt.replace(label, replacement)
    
train_swap = train.loc[train.restaurant_name != "",:].copy()

label_list = train_swap["restaurant_name"].unique().tolist()

train_swap["rep_label"] = train_swap["restaurant_name"].apply(lambda x: get_replacement_label(x, label_list))

train_swap["sentence"] = train_swap.apply(lambda x: swap_labels(x["sentence"], x["restaurant_name"], 
                                                                x["rep_label"]) , axis = 1)

train_swap["restaurant_name"] = train_swap["rep_label"]

train_swap

Unnamed: 0,sentence,restaurant_name,rep_label
4,can you find bellinis dedham pizzeria that hav...,bellinis,bellinis
6,can you tell me where the nearest bar la grass...,bar la grassa,bar la grassa
14,do you think thursday has fabulous service,thursday,thursday
17,can i get hambers at maid cafe,maid cafe,maid cafe
19,do you know if there are any reviews on bar la...,bar la grassa,bar la grassa
24,are there reservations still available for bel...,bellinis,bellinis
26,are there any olive garden in the city open on...,olive garden,olive garden
27,do i need a reservation for name,name,name
35,are there any name within 5 minutes drive that...,name,name
37,are there any dominos in town,dominos,dominos


In [None]:
train = pd.concat([train, train_swap], axis = 0, ignore_index=True)

In [None]:
train.shape

(160, 3)

In [None]:
train = train.sample(frac = 1)

In [None]:
import sys
import torch
import torch.nn as nn
import random

from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR
from transformers import BartForConditionalGeneration, BartTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')

_ = model.to(device)

Downloading:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.63k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/558M [00:00<?, ?B/s]

In [None]:
def generate_outputs(model, sentence):
    inputs = tokenizer([sentence], max_length=1024, return_tensors='pt')
    output_ids = model.generate(inputs['input_ids'].to(device), num_beams=4)
    #print(output_ids, sentence)
    output = tokenizer.decode(output_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False).strip()
    return output

In [None]:
train

Unnamed: 0,sentence,restaurant_name,rep_label
117,are there any chicken wing places nearby,,
100,any good place to get a pie at an affordable p...,,
29,anything on the avenue,,
148,direct me to the nearest bar la grassa,bar la grassa,bar la grassa
63,can i dine at the barat a nossa casa,barat a nossa casa,
...,...,...,...
31,are there any restaurants around with a smokin...,,
111,are there any restaurants that are open 24 hours,,
126,are there any olive garden in the city open on...,olive garden,olive garden
147,can you help me find a starting gate restauran...,starting gate restaurant,starting gate restaurant


In [None]:
val

Unnamed: 0,sentence,restaurant_name
0,are there any ice cream shops in my neighborho...,
1,are there any restaurants within 5 miles that ...,
2,are there any locally owned franchises that gi...,
3,are there any restaurants that will let me tak...,
4,are there any five star restaurants around here,
5,do you think the noodle bar is open,noodle bar
6,are there any vegetarian restaurants in this town,
7,any places around here that has a nice view,
8,are there any jazz clubs that serve food,
9,do any famous people frequent the jimmys pizza...,jimmys pizza


In [None]:
model.eval()
generate_outputs(model, train.sentence[6])

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


'can you tell me where the nearest wendys is'

In [None]:
model.train()
1

1

In [None]:
################################
# DO NOT CHANGE THIS FUNCTION! #
################################

def get_f1_score_on_test_data(model, data):
    model.eval()
    true_positives = 0
    false_positives = 0
    false_negatives = 0
    for index, row in data.iterrows():
        sentence = row.sentence
        expected = row.restaurant_name
        inputs = tokenizer([sentence], max_length=1024, return_tensors='pt')
        predicted = generate_outputs(model, sentence)
        if expected != '' and expected == predicted:
            true_positives += 1
        if expected != '' and expected != predicted:
            false_positives += 1
        if expected == '' and predicted != '':
            false_positives += 1
        if expected != '' and predicted == '':
            false_negatives += 1

    precision = 0
    recall = 0
    f1_score = 0
    if true_positives + false_positives:
        precision = true_positives / (true_positives + false_positives)
    if true_positives + false_negatives:
        recall = true_positives /(true_positives + false_negatives)
    if precision + recall:
        f1_score = 2 * precision * recall / (precision + recall)

    print(f'precision: {precision} | recall {recall} | f1_score {f1_score}')
    return f1_score

In [None]:
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')

_ = model.to(device)

num_epochs = 20
learning_rate = 2e-5
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = StepLR(optimizer, step_size=10, gamma=0.8)

In [None]:
# from torch.nn.utils.rnn import pad_sequence

In [None]:
# tokenizer.batch_encode_plus?

In [None]:
# train.sentence[0]

In [None]:
# tokenizer.batch_encode_plus([train.sentence[0]])["input_ids"]

In [None]:
# tokenizer.encode?
# #(train.sentence[0], return_tensors='pt')

In [None]:
# sents = [x.sentence for row, x in train.iterrows()]

# labels = [x.restaurant_name for row, x in train.iterrows()]

# start = 0

# tokenizer.batch_encode_plus(sents[start:start+10], padding=True)["input_ids"]

In [None]:
def train_one_batch_and_get_loss(train_model, data, optimizer, criterion, batch_size=10):
    total_loss = 0.

    sents = [x.sentence for row, x in data.iterrows()]
    labels = [x.restaurant_name for row, x in data.iterrows()]

    #sents_tok = [tokenizer.encode(x, padding=True) for x in sents]

    #labels_tok = [tokenizer.encode(x, padding=True) for x in labels]

    total_steps = int(len(sents) / batch_size)

    for i in tqdm(range(total_steps), total=total_steps): 
        start = i * batch_size       
        #print(start, start+batch_size)
        batch_sents = torch.tensor(tokenizer.batch_encode_plus(sents[start:start+batch_size], padding=True)["input_ids"]).to(device)
        
        batch_labels = torch.tensor(tokenizer.batch_encode_plus(labels[start:start+batch_size], padding=True)["input_ids"]).to(device)
        model.train()
        optimizer.zero_grad()
        
        loss = train_model(batch_sents, labels=batch_labels)[0]
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    return total_loss / data.shape[0]

In [None]:
def train_one_step_and_get_loss(train_model, data, optimizer, criterion):
    total_loss = 0.
    for i, row in tqdm(data.iterrows(), total=data.shape[0]):
        model.train()
        optimizer.zero_grad()
        input_ids = tokenizer.encode(row.sentence, return_tensors='pt').to(device)
        output_ids = tokenizer.encode(row.restaurant_name, return_tensors='pt').to(device)
        loss = train_model(input_ids, labels=output_ids)[0]
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    return total_loss / data.shape[0]

In [None]:
best_model = None
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')

_ = model.to(device)

num_epochs = 50
learning_rate = 2e-5
batch_size = 10
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = StepLR(optimizer, step_size=10, gamma=0.8)
min_f1 = 0
holdout_f1 = 0
MODEL_PATH = "b_10_ss10_e50_swap"
for epoch in range(num_epochs):
    loss = train_one_batch_and_get_loss(model, train, optimizer, criterion, batch_size=batch_size)
    print('Epoch:', epoch , '-' * 35)
    print('Training loss:', loss)
    f1_val = get_f1_score_on_test_data(model, val)
    f1_train = get_f1_score_on_test_data(model, train)
    f1 = get_f1_score_on_test_data(model, holdout)
    if f1_val > min_f1:
        
        print("saving model:", f1, f1_val,min_f1, epoch)
        min_f1 = f1_val
        holdout_f1 = f1
        torch.save(model.state_dict(), MODEL_PATH)
    best_model = model
    print('-' * 44)
    sys.stdout.flush()
    scheduler.step()

100%|██████████| 16/16 [00:03<00:00,  4.52it/s]
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


Epoch: 0 -----------------------------------
Training loss: 0.7854907304048538
precision: 0.0 | recall 0.0 | f1_score 0
precision: 0.012048192771084338 | recall 0.013888888888888888 | f1_score 0.012903225806451613
precision: 0.0 | recall 0.0 | f1_score 0
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.70it/s]


Epoch: 1 -----------------------------------
Training loss: 0.5994654268026351
precision: 0.25 | recall 0.375 | f1_score 0.3
precision: 0.4 | recall 0.6181818181818182 | f1_score 0.48571428571428577
precision: 0.14285714285714285 | recall 0.25 | f1_score 0.18181818181818182
saving model: 0.18181818181818182 0.3 0 1
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.53it/s]


Epoch: 2 -----------------------------------
Training loss: 0.4281005397439003
precision: 0.75 | recall 0.9 | f1_score 0.8181818181818182
precision: 0.6161616161616161 | recall 0.9384615384615385 | f1_score 0.7439024390243902
precision: 0.30612244897959184 | recall 0.6976744186046512 | f1_score 0.425531914893617
saving model: 0.425531914893617 0.8181818181818182 0.3 2
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.52it/s]


Epoch: 3 -----------------------------------
Training loss: 0.3195929333567619
precision: 0.5 | recall 0.8571428571428571 | f1_score 0.631578947368421
precision: 0.5742574257425742 | recall 0.9830508474576272 | f1_score 0.725
precision: 0.26851851851851855 | recall 0.8285714285714286 | f1_score 0.4055944055944056
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.71it/s]


Epoch: 4 -----------------------------------
Training loss: 0.23693885356187822
precision: 0.6666666666666666 | recall 0.8888888888888888 | f1_score 0.761904761904762
precision: 0.6288659793814433 | recall 0.9682539682539683 | f1_score 0.7624999999999998
precision: 0.30392156862745096 | recall 0.775 | f1_score 0.4366197183098592
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.70it/s]


Epoch: 5 -----------------------------------
Training loss: 0.16760413944721222
precision: 0.5833333333333334 | recall 0.7777777777777778 | f1_score 0.6666666666666666
precision: 0.7647058823529411 | recall 0.9420289855072463 | f1_score 0.8441558441558441
precision: 0.3125 | recall 0.5681818181818182 | f1_score 0.40322580645161293
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.70it/s]


Epoch: 6 -----------------------------------
Training loss: 0.11888889819383622
precision: 0.5833333333333334 | recall 0.875 | f1_score 0.7000000000000001
precision: 0.6923076923076923 | recall 1.0 | f1_score 0.8181818181818181
precision: 0.25 | recall 0.75 | f1_score 0.375
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.71it/s]


Epoch: 7 -----------------------------------
Training loss: 0.08380216602236032
precision: 0.5833333333333334 | recall 0.875 | f1_score 0.7000000000000001
precision: 0.8045977011494253 | recall 1.0 | f1_score 0.8917197452229298
precision: 0.3118279569892473 | recall 0.7073170731707317 | f1_score 0.4328358208955224
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.71it/s]


Epoch: 8 -----------------------------------
Training loss: 0.07235456705093384
precision: 0.5833333333333334 | recall 0.7777777777777778 | f1_score 0.6666666666666666
precision: 0.7976190476190477 | recall 1.0 | f1_score 0.8874172185430463
precision: 0.37037037037037035 | recall 0.6666666666666666 | f1_score 0.47619047619047616
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.71it/s]


Epoch: 9 -----------------------------------
Training loss: 0.038657298870384695
precision: 0.5833333333333334 | recall 0.875 | f1_score 0.7000000000000001
precision: 0.7294117647058823 | recall 1.0 | f1_score 0.8435374149659863
precision: 0.26851851851851855 | recall 0.8529411764705882 | f1_score 0.4084507042253522
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.68it/s]


Epoch: 10 -----------------------------------
Training loss: 0.030647348798811435
precision: 0.45454545454545453 | recall 0.5555555555555556 | f1_score 0.5
precision: 0.8809523809523809 | recall 0.9736842105263158 | f1_score 0.925
precision: 0.3013698630136986 | recall 0.5116279069767442 | f1_score 0.3793103448275862
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.69it/s]


Epoch: 11 -----------------------------------
Training loss: 0.03294290397316217
precision: 0.5833333333333334 | recall 0.875 | f1_score 0.7000000000000001
precision: 0.8333333333333334 | recall 1.0 | f1_score 0.9090909090909091
precision: 0.2672413793103448 | recall 0.8857142857142857 | f1_score 0.4105960264900662
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.69it/s]


Epoch: 12 -----------------------------------
Training loss: 0.02126146461814642
precision: 0.5833333333333334 | recall 0.875 | f1_score 0.7000000000000001
precision: 0.8674698795180723 | recall 1.0 | f1_score 0.9290322580645161
precision: 0.30392156862745096 | recall 0.8378378378378378 | f1_score 0.4460431654676259
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.69it/s]


Epoch: 13 -----------------------------------
Training loss: 0.01771645280532539
precision: 0.5833333333333334 | recall 0.875 | f1_score 0.7000000000000001
precision: 0.9036144578313253 | recall 1.0 | f1_score 0.9493670886075949
precision: 0.26956521739130435 | recall 0.96875 | f1_score 0.42176870748299317
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.68it/s]


Epoch: 14 -----------------------------------
Training loss: 0.012642196612432599
precision: 0.5833333333333334 | recall 0.7777777777777778 | f1_score 0.6666666666666666
precision: 0.9397590361445783 | recall 1.0 | f1_score 0.9689440993788819
precision: 0.3253012048192771 | recall 0.6923076923076923 | f1_score 0.4426229508196721
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.71it/s]


Epoch: 15 -----------------------------------
Training loss: 0.012358259409666061
precision: 0.6666666666666666 | recall 0.8888888888888888 | f1_score 0.761904761904762
precision: 0.9036144578313253 | recall 1.0 | f1_score 0.9493670886075949
precision: 0.35714285714285715 | recall 0.875 | f1_score 0.5072463768115941
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.71it/s]


Epoch: 16 -----------------------------------
Training loss: 0.009548271959647536
precision: 0.5 | recall 0.75 | f1_score 0.6
precision: 0.9036144578313253 | recall 0.9868421052631579 | f1_score 0.9433962264150944
precision: 0.34177215189873417 | recall 0.675 | f1_score 0.453781512605042
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.70it/s]


Epoch: 17 -----------------------------------
Training loss: 0.013199388608336448
precision: 0.75 | recall 0.9 | f1_score 0.8181818181818182
precision: 0.8674698795180723 | recall 1.0 | f1_score 0.9290322580645161
precision: 0.35789473684210527 | recall 0.85 | f1_score 0.5037037037037038
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.70it/s]


Epoch: 18 -----------------------------------
Training loss: 0.008804555283859372
precision: 0.6666666666666666 | recall 0.8888888888888888 | f1_score 0.761904761904762
precision: 0.9397590361445783 | recall 1.0 | f1_score 0.9689440993788819
precision: 0.29347826086956524 | recall 0.7297297297297297 | f1_score 0.41860465116279066
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.72it/s]


Epoch: 19 -----------------------------------
Training loss: 0.006480376329272985
precision: 0.45454545454545453 | recall 0.7142857142857143 | f1_score 0.5555555555555556
precision: 0.9397590361445783 | recall 1.0 | f1_score 0.9689440993788819
precision: 0.36 | recall 0.675 | f1_score 0.46956521739130425
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.70it/s]


Epoch: 20 -----------------------------------
Training loss: 0.006066537543665618
precision: 0.6666666666666666 | recall 0.8888888888888888 | f1_score 0.761904761904762
precision: 0.9156626506024096 | recall 1.0 | f1_score 0.9559748427672956
precision: 0.30927835051546393 | recall 0.7692307692307693 | f1_score 0.4411764705882353
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.69it/s]


Epoch: 21 -----------------------------------
Training loss: 0.0065623917384073135
precision: 0.6666666666666666 | recall 0.8888888888888888 | f1_score 0.761904761904762
precision: 0.9156626506024096 | recall 1.0 | f1_score 0.9559748427672956
precision: 0.3118279569892473 | recall 0.7631578947368421 | f1_score 0.44274809160305345
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.70it/s]


Epoch: 22 -----------------------------------
Training loss: 0.0057531386730261145
precision: 0.5 | recall 0.75 | f1_score 0.6
precision: 0.9146341463414634 | recall 1.0 | f1_score 0.9554140127388536
precision: 0.3076923076923077 | recall 0.717948717948718 | f1_score 0.43076923076923085
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.69it/s]


Epoch: 23 -----------------------------------
Training loss: 0.0053708544233813885
precision: 0.45454545454545453 | recall 0.7142857142857143 | f1_score 0.5555555555555556
precision: 0.9512195121951219 | recall 1.0 | f1_score 0.975
precision: 0.34210526315789475 | recall 0.65 | f1_score 0.4482758620689655
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.66it/s]


Epoch: 24 -----------------------------------
Training loss: 0.004148738284129649
precision: 0.5833333333333334 | recall 0.875 | f1_score 0.7000000000000001
precision: 0.9512195121951219 | recall 1.0 | f1_score 0.975
precision: 0.3125 | recall 0.8108108108108109 | f1_score 0.45112781954887216
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.73it/s]


Epoch: 25 -----------------------------------
Training loss: 0.0047230265219695864
precision: 0.5 | recall 0.8571428571428571 | f1_score 0.631578947368421
precision: 0.9634146341463414 | recall 1.0 | f1_score 0.9813664596273292
precision: 0.32142857142857145 | recall 0.7297297297297297 | f1_score 0.44628099173553715
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.68it/s]


Epoch: 26 -----------------------------------
Training loss: 0.005191712849773466
precision: 0.45454545454545453 | recall 0.7142857142857143 | f1_score 0.5555555555555556
precision: 0.9512195121951219 | recall 1.0 | f1_score 0.975
precision: 0.3333333333333333 | recall 0.6842105263157895 | f1_score 0.44827586206896547
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.70it/s]


Epoch: 27 -----------------------------------
Training loss: 0.004068744368851185
precision: 0.6666666666666666 | recall 0.8888888888888888 | f1_score 0.761904761904762
precision: 0.9506172839506173 | recall 1.0 | f1_score 0.9746835443037974
precision: 0.3157894736842105 | recall 0.8571428571428571 | f1_score 0.46153846153846156
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.67it/s]


Epoch: 28 -----------------------------------
Training loss: 0.0037547939573414624
precision: 0.6666666666666666 | recall 0.8888888888888888 | f1_score 0.761904761904762
precision: 0.927710843373494 | recall 1.0 | f1_score 0.9625
precision: 0.32323232323232326 | recall 0.9411764705882353 | f1_score 0.481203007518797
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.69it/s]


Epoch: 29 -----------------------------------
Training loss: 0.003657051152549684
precision: 0.6666666666666666 | recall 0.8888888888888888 | f1_score 0.761904761904762
precision: 0.9512195121951219 | recall 1.0 | f1_score 0.975
precision: 0.34408602150537637 | recall 0.8648648648648649 | f1_score 0.49230769230769234
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.69it/s]


Epoch: 30 -----------------------------------
Training loss: 0.0029559612681623547
precision: 0.6666666666666666 | recall 0.8888888888888888 | f1_score 0.761904761904762
precision: 0.926829268292683 | recall 1.0 | f1_score 0.9620253164556963
precision: 0.32941176470588235 | recall 0.7368421052631579 | f1_score 0.45528455284552843
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.67it/s]


Epoch: 31 -----------------------------------
Training loss: 0.0023613531433511526
precision: 0.6666666666666666 | recall 0.8888888888888888 | f1_score 0.761904761904762
precision: 0.9512195121951219 | recall 1.0 | f1_score 0.975
precision: 0.32222222222222224 | recall 0.7631578947368421 | f1_score 0.453125
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.67it/s]


Epoch: 32 -----------------------------------
Training loss: 0.004588562628487125
precision: 0.6666666666666666 | recall 0.8888888888888888 | f1_score 0.761904761904762
precision: 0.9512195121951219 | recall 1.0 | f1_score 0.975
precision: 0.32967032967032966 | recall 0.7894736842105263 | f1_score 0.46511627906976744
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.67it/s]


Epoch: 33 -----------------------------------
Training loss: 0.0046714985102880744
precision: 0.5 | recall 0.75 | f1_score 0.6
precision: 0.9512195121951219 | recall 1.0 | f1_score 0.975
precision: 0.3 | recall 0.7297297297297297 | f1_score 0.4251968503937008
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.69it/s]


Epoch: 34 -----------------------------------
Training loss: 0.002790589694632217
precision: 0.5833333333333334 | recall 0.875 | f1_score 0.7000000000000001
precision: 0.9146341463414634 | recall 1.0 | f1_score 0.9554140127388536
precision: 0.32653061224489793 | recall 0.8648648648648649 | f1_score 0.47407407407407404
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.70it/s]


Epoch: 35 -----------------------------------
Training loss: 0.002912899450166151
precision: 0.5833333333333334 | recall 0.875 | f1_score 0.7000000000000001
precision: 0.926829268292683 | recall 1.0 | f1_score 0.9620253164556963
precision: 0.30303030303030304 | recall 0.8571428571428571 | f1_score 0.4477611940298507
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.70it/s]


Epoch: 36 -----------------------------------
Training loss: 0.0038120577926747503
precision: 0.5833333333333334 | recall 0.875 | f1_score 0.7000000000000001
precision: 0.9146341463414634 | recall 1.0 | f1_score 0.9554140127388536
precision: 0.31313131313131315 | recall 0.8611111111111112 | f1_score 0.45925925925925926
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.69it/s]


Epoch: 37 -----------------------------------
Training loss: 0.003064108669059351
precision: 0.5833333333333334 | recall 0.875 | f1_score 0.7000000000000001
precision: 0.9390243902439024 | recall 1.0 | f1_score 0.9685534591194969
precision: 0.29 | recall 0.8285714285714286 | f1_score 0.42962962962962964
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.69it/s]


Epoch: 38 -----------------------------------
Training loss: 0.0024978032743092626
precision: 0.5833333333333334 | recall 0.875 | f1_score 0.7000000000000001
precision: 0.9512195121951219 | recall 1.0 | f1_score 0.975
precision: 0.27835051546391754 | recall 0.75 | f1_score 0.406015037593985
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.69it/s]


Epoch: 39 -----------------------------------
Training loss: 0.0021055003686342388
precision: 0.5 | recall 0.8571428571428571 | f1_score 0.631578947368421
precision: 0.9512195121951219 | recall 1.0 | f1_score 0.975
precision: 0.2857142857142857 | recall 0.8 | f1_score 0.4210526315789473
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.69it/s]


Epoch: 40 -----------------------------------
Training loss: 0.001990148361073807
precision: 0.5833333333333334 | recall 0.875 | f1_score 0.7000000000000001
precision: 0.9512195121951219 | recall 1.0 | f1_score 0.975
precision: 0.2857142857142857 | recall 0.7027027027027027 | f1_score 0.40624999999999994
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.67it/s]


Epoch: 41 -----------------------------------
Training loss: 0.006532393640372902
precision: 0.6666666666666666 | recall 0.8888888888888888 | f1_score 0.761904761904762
precision: 0.9146341463414634 | recall 1.0 | f1_score 0.9554140127388536
precision: 0.28431372549019607 | recall 0.90625 | f1_score 0.43283582089552236
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.68it/s]


Epoch: 42 -----------------------------------
Training loss: 0.005961420072708279
precision: 0.6666666666666666 | recall 0.8888888888888888 | f1_score 0.761904761904762
precision: 0.926829268292683 | recall 1.0 | f1_score 0.9620253164556963
precision: 0.28846153846153844 | recall 0.9375 | f1_score 0.44117647058823534
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.71it/s]


Epoch: 43 -----------------------------------
Training loss: 0.0018139202788006515
precision: 0.6666666666666666 | recall 0.8888888888888888 | f1_score 0.761904761904762
precision: 0.9634146341463414 | recall 1.0 | f1_score 0.9813664596273292
precision: 0.2828282828282828 | recall 0.9032258064516129 | f1_score 0.43076923076923074
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.72it/s]


Epoch: 44 -----------------------------------
Training loss: 0.0021531648701056836
precision: 0.5454545454545454 | recall 0.75 | f1_score 0.631578947368421
precision: 0.9634146341463414 | recall 1.0 | f1_score 0.9813664596273292
precision: 0.3037974683544304 | recall 0.6486486486486487 | f1_score 0.41379310344827586
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.71it/s]


Epoch: 45 -----------------------------------
Training loss: 0.004408880940172821
precision: 0.5454545454545454 | recall 0.75 | f1_score 0.631578947368421
precision: 0.975609756097561 | recall 1.0 | f1_score 0.9876543209876543
precision: 0.325 | recall 0.6666666666666666 | f1_score 0.4369747899159664
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.69it/s]


Epoch: 46 -----------------------------------
Training loss: 0.0021838757034856825
precision: 0.5454545454545454 | recall 0.75 | f1_score 0.631578947368421
precision: 0.9634146341463414 | recall 1.0 | f1_score 0.9813664596273292
precision: 0.313953488372093 | recall 0.6923076923076923 | f1_score 0.43200000000000005
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.69it/s]


Epoch: 47 -----------------------------------
Training loss: 0.0016032239422202111
precision: 0.5 | recall 0.75 | f1_score 0.6
precision: 0.9512195121951219 | recall 1.0 | f1_score 0.975
precision: 0.28421052631578947 | recall 0.8181818181818182 | f1_score 0.42187499999999994
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.68it/s]


Epoch: 48 -----------------------------------
Training loss: 0.0019000217493157835
precision: 0.5833333333333334 | recall 0.875 | f1_score 0.7000000000000001
precision: 0.9512195121951219 | recall 1.0 | f1_score 0.975
precision: 0.3 | recall 0.7297297297297297 | f1_score 0.4251968503937008
--------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.69it/s]


Epoch: 49 -----------------------------------
Training loss: 0.003466495400061831
precision: 0.5833333333333334 | recall 0.875 | f1_score 0.7000000000000001
precision: 0.9634146341463414 | recall 1.0 | f1_score 0.9813664596273292
precision: 0.29347826086956524 | recall 0.7941176470588235 | f1_score 0.42857142857142855
--------------------------------------------


In [None]:
holdout_f1

0.425531914893617

In [None]:
best_model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')

best_model.load_state_dict(torch.load(MODEL_PATH))

best_model.to(device)

BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(50265, 768, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50265, 768, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0): BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        

In [None]:
get_f1_score_on_test_data(best_model, holdout)

precision: 0.30612244897959184 | recall 0.6976744186046512 | f1_score 0.425531914893617


0.425531914893617

In [None]:
get_f1_score_on_test_data(model, holdout)

precision: 0.29347826086956524 | recall 0.7941176470588235 | f1_score 0.42857142857142855


0.42857142857142855

In [None]:
model.eval()
holdout["preds"] = holdout["sentence"].apply(lambda x: generate_outputs(model, x))

holdout

Unnamed: 0,sentence,restaurant_name,preds
0,find pizza places,,
1,find me the best rated chinese restaurant in t...,,chinese restaurant
2,what kind of food does abc cafe serve,abc cafe,abc cafe
3,how far away is the nearest steak house,,
4,i am looking for a mexican restuarant that has...,,
...,...,...,...
145,find me brazilian food with on location parking,,
146,get me to a mexican place,,
147,how far am i from the nearest bagel shop,,bagel shop
148,what time does sonic open,sonic,dominos


In [None]:
def generate_outputs(model, sentence):
    inputs = tokenizer([sentence], max_length=1024, return_tensors='pt')
    output_ids = model.generate(inputs['input_ids'].to(device), num_beams=4)
    #print(output_ids, sentence)
    output = tokenizer.decode(output_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False).strip()

    if output in sentence:
        return output
    else:
        return ""

In [None]:
get_f1_score_on_test_data(model, holdout)

precision: 0.32142857142857145 | recall 0.5510204081632653 | f1_score 0.40601503759398494


0.40601503759398494

In [None]:
holdout["preds2"] = holdout["sentence"].apply(lambda x: generate_outputs(model, x))

holdout

Unnamed: 0,sentence,restaurant_name,preds,preds2
0,find pizza places,,,
1,find me the best rated chinese restaurant in t...,,chinese restaurant,chinese restaurant
2,what kind of food does abc cafe serve,abc cafe,abc cafe,abc cafe
3,how far away is the nearest steak house,,,
4,i am looking for a mexican restuarant that has...,,,
...,...,...,...,...
145,find me brazilian food with on location parking,,,
146,get me to a mexican place,,,
147,how far am i from the nearest bagel shop,,bagel shop,bagel shop
148,what time does sonic open,sonic,dominos,


In [None]:
holdout.to_excel("rest_hold_preds_batchsize10_20epochs.xlsx", index=False)

In [None]:
num = 0
print('input:', holdout.sentence[num])
print('expected:', holdout.restaurant_name[num])
print('predicted:', generate_outputs(best_model, holdout.sentence[num]))

input: find pizza places
expected: 
predicted: pizza places
