# Using BERT for next sentence prediction

The hardest part of this is making sure you've got the Python packages you need installed. You'll need to install `torch` and `transformers`, and as usual with Python, you may run into compatibility issues.

All I can say to help there is "google the error message."

But once you've got the packages installed it's easy.
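
Before loading anything, it can help to confirm which versions are actually installed, since that is usually the first thing an error-message search needs. Here is a minimal sketch using only the standard library; the helper name `installed_version` is my own invention for illustration, not part of either package.

```python
from importlib import metadata


def installed_version(package_name):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package_name)
    except metadata.PackageNotFoundError:
        return None


# Report on the two packages this notebook needs.
for name in ("torch", "transformers"):
    version = installed_version(name)
    print(name, version if version else "NOT INSTALLED")
```

If either line says `NOT INSTALLED`, that's the package to install first.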

First we load everything and get it ready to run.

In [1]:
import torch
from transformers import BertTokenizer, BertForNextSentencePrediction

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
print('built tokenizer')
model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
model.eval()
print('built model')

built tokenizer


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForNextSentencePrediction: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


built model


Then here's a function to do next sentence prediction:

In [127]:
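# OLD FUNCTION
# Kept for reference only: it was written when the model returned a plain
# (loss, logits) pair, which is no longer what you get back.
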
def old_get_logits(firstsentence, secondsentence):
    global tokenizer, model

    encoding = tokenizer.encode_plus(firstsentence, secondsentence, return_tensors = 'pt', max_seq_length = 255)
    loss, logits = model(**encoding, next_sentence_label=torch.LongTensor([1]))

    return loss, logits

# NEW FUNCTIONS

def get_raw_output(firstsentence, secondsentence, tokenizer, model):
    '''Run BERT on a sentence pair and return the full model output object.'''
    # Encode both sentences together as a single two-segment input
    encoding = tokenizer.encode_plus(firstsentence, secondsentence, return_tensors='pt', padding=False)
    result = model(**encoding)

    return result

def get_logits(firstsentence, secondsentence, tokenizer, model):
    '''Return the two next-sentence logits as a plain Python list of floats.'''
    result_object = get_raw_output(firstsentence, secondsentence, tokenizer, model)

    # logits holds one row per input pair; we passed one pair, so take row 0
    logits = result_object['logits'].tolist()[0]

    return logits

The changes to this function were necessary because [HuggingFace has implemented a new way to wrap "model outputs"](https://huggingface.co/transformers/main_classes/output.html) since I originally wrote this.

You used to get two results back, a loss and the logits. Now you get a `NextSentencePredictorOutput`, which in turn wraps the results as PyTorch tensors. Let's look at it.

In [116]:
firstsentence = "I was walking to the store one day to buy groceries."
secondsentence = "At the store I bought bananas and milk."

In [117]:
result = get_raw_output(firstsentence, secondsentence, tokenizer, model)
result

NextSentencePredictorOutput(loss=None, logits=tensor([[ 6.2713, -6.1164]], grad_fn=<AddmmBackward>), hidden_states=None, attentions=None)

After reading [the HuggingFace documentation](https://huggingface.co/transformers/main_classes/output.html) and [searching Stack Overflow](https://stackoverflow.com/questions/53903373/convert-pytorch-tensor-to-python-list) I was able to write the new `get_logits` function, which unpacks those objects to get numbers we can deal with.

In [118]:
get_logits(firstsentence, secondsentence, tokenizer, model)

[6.271258354187012, -6.116359233856201]

The relation between logits and probability makes my head hurt to explain, so I'm just going to [point at Wikipedia.](https://en.wikipedia.org/wiki/Logit)
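
For what it's worth, the short version is that a softmax turns a set of logits into probabilities that add up to 1. Here is a minimal sketch in plain Python, using the two logits printed above; the `softmax` helper is my own illustration, not something from `transformers`.

```python
import math


def softmax(logits):
    """Convert a list of logits into probabilities that sum to 1."""
    # Subtract the largest logit first so math.exp never overflows.
    biggest = max(logits)
    exps = [math.exp(x - biggest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Logits copied from the get_logits output above.
logits = [6.271258354187012, -6.116359233856201]
probs = softmax(logits)
print(probs)  # the first entry rounds to 0.999996

# With only two classes, softmax is the same as a sigmoid of the logit gap.
diff = logits[0] - logits[1]
sigmoid = 1.0 / (1.0 + math.exp(-diff))
print(sigmoid)
```

So what really matters is the gap between the two logits: a gap of about 12 in favor of the first one means the first class is overwhelmingly likely.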

But for a quick and dirty approach I wrote this function, which *loosely* translates BERT's logits output into a probability for the sequence. It is a two-class softmax with 2.72 standing in for *e*, which is why the result is only approximate. I also checked [this blog post](https://towardsdatascience.com/bert-for-next-sentence-prediction-466b67f8226f) to confirm that the probability of "yes, this is the next sentence" is associated with the first logit.

In [138]:
import math

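def get_probability(firstsent, secondsent):
    '''
    Approximate probability that secondsent follows firstsent.

    :param firstsent: the first sentence, as a string
    :param secondsent: the candidate next sentence, as a string
    :return: probability of the first category ("is the next sentence") after an approximate softmax
    '''
    # tokenizer and model are the module-level objects loaded in the first cell
    logits = get_logits(firstsent, secondsent, tokenizer, model)

    poslogit = logits[0]
    neglogit = logits[1]

    # 2.72 is a rounded stand-in for e; math.exp(x) would give the exact softmax
    pospart = math.pow(2.72, poslogit)
    negpart = math.pow(2.72, neglogit)

    posprob = pospart / (pospart + negpart)

    return round(posprob, 6)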

In [139]:
firstsentence = "I was walking to the store one day to buy groceries."
secondsentence = "At the store I bought bananas and milk."
get_probability(firstsentence, secondsentence)

0.999996

Ah, now we can see that BERT considers that a pretty probable sequence. Let's try a less probable sequence.

We'll use the same first sentence about walking to the store, and for our second sentence:

    Psychedelics are a hallucinogenic class of psychoactive drug whose primary effect is to trigger non-ordinary states of consciousness and psychedelic experiences via serotonin 2A receptor agonism.
    
That sentence is from the Wikipedia article on "psychedelic drug."


In [143]:
firstsentence = "I was walking to the store one day to buy groceries."
secondsentence = "Psychedelics are a hallucinogenic class of psychoactive drug whose primary effect is to trigger non-ordinary states of consciousness and psychedelic experiences via serotonin 2A receptor agonism."
get_probability(firstsentence, secondsentence)

5e-05

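That's a much less probable sequence! The raw output below appears to be left over from that pair: its logits are consistent with the 5e-05 above, and this time the second logit is the large one.

In [131]:
result

NextSentencePredictorOutput(loss=None, logits=tensor([[-3.4202,  6.4734]], grad_fn=<AddmmBackward>), hidden_states=None, attentions=None)

Let's try a slightly weaker non-sequitur.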

In [132]:
firstsentence = "I was walking to the store one day to buy groceries."
secondsentence = "Everything is closed due to the pandemic."
get_probability(firstsentence, secondsentence)

0.05549650796189104

Okay, that probability is a good deal higher (roughly 0.055 versus 0.00005). Still unlikely. But not *totally* improbable.
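
To compare several candidate second sentences at once, you could sort them by score. Here is a rough sketch: `rank_candidates` and `recorded_score` are hypothetical names of my own, and the stand-in scorer just looks up the probabilities already recorded in this notebook, so the example runs without loading BERT.

```python
def rank_candidates(first_sentence, candidates, score_fn):
    """Sort candidate second sentences from most to least probable."""
    scored = [(score_fn(first_sentence, c), c) for c in candidates]
    return sorted(scored, key=lambda pair: pair[0], reverse=True)


# Probabilities copied from the outputs above (not recomputed here).
recorded = {
    "Everything is closed due to the pandemic.": 0.05549650796189104,
    "Psychedelics are a hallucinogenic class of psychoactive drug whose primary effect is to trigger non-ordinary states of consciousness and psychedelic experiences via serotonin 2A receptor agonism.": 5e-05,
    "At the store I bought bananas and milk.": 0.999996,
}


def recorded_score(first_sentence, second_sentence):
    """Stand-in scorer: look up a recorded probability instead of running BERT."""
    return recorded[second_sentence]


first = "I was walking to the store one day to buy groceries."
ranking = rank_candidates(first, list(recorded), recorded_score)
for prob, sentence in ranking:
    print(f"{prob:.6f}  {sentence[:45]}")
```

To use the real model, pass `get_probability` in place of `recorded_score`; it takes the same two arguments.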