# Next sentence prediction (NSP) task

In [1]:
import torch
from transformers import BertForNextSentencePrediction, BertTokenizer

In [2]:
# Checkpoint name shared by the model and its tokenizer so the vocabulary
# always matches the weights.
model_id = "bert-base-uncased"

# BERT encoder with the next-sentence-prediction classification head on top.
bert_nsp = BertForNextSentencePrediction.from_pretrained(model_id)

# Matching tokenizer for the same checkpoint.
tokenizer = BertTokenizer.from_pretrained(model_id)

In [3]:
# Show the module tree of the loaded model (its repr is the cell output).
bert_nsp

BertForNextSentencePrediction(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [5]:
# Sentence pair where the second sentence naturally continues the first.
text1 = [
    (
        "A preface is an introductory section of a book, written by the author, "
        "that provides context for the reader."
    ),
    (
        "It typically explains the book's purpose, the author's motivation for "
        "writing it, and the inspiration behind the subject matter."
    ),
]

# Sentence pair on unrelated topics (a math word problem, then a remark on marriage).
text2 = [
    (
        'A basic math problem could be: "If you have 5 apples and you buy 3 more, '
        'how many apples do you have?"'
    ),
    (
        'A great marriage is one where each person, without agenda, celebrates '
        'the unique and distinctive characteristics of the other, and lovingly '
        'helps them be the best possible version of themselves.'
    ),
]

In [7]:
# Encode each pair as a single sequence (first sentence, then second) and
# return PyTorch tensors ready to be fed to the model.
inputs1 = tokenizer(
    text=text1[0],
    text_pair=text1[1],
    return_tensors="pt",
)
inputs2 = tokenizer(
    text=text2[0],
    text_pair=text2[1],
    return_tensors="pt",
)

In [8]:
# Inspect the encoding of the first pair: token_type_ids is 0 for the first
# sentence's tokens and 1 for the second's; attention_mask is all ones here.
inputs1

{'input_ids': tensor([[  101,  1037, 18443,  2003,  2019, 23889,  2930,  1997,  1037,  2338,
          1010,  2517,  2011,  1996,  3166,  1010,  2008,  3640,  6123,  2005,
          1996,  8068,  1012,   102,  2009,  4050,  7607,  1996,  2338,  1005,
          1055,  3800,  1010,  1996,  3166,  1005,  1055, 14354,  2005,  3015,
          2009,  1010,  1998,  1996,  7780,  2369,  1996,  3395,  3043,  1012,
           102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1]])}

In [9]:
# Inspect the encoding of the second pair (same three fields: input_ids,
# token_type_ids, attention_mask).
inputs2

{'input_ids': tensor([[  101,  1037,  3937,  8785,  3291,  2071,  2022,  1024,  1000,  2065,
          2017,  2031,  1019, 18108,  1998,  2017,  4965,  1017,  2062,  1010,
          2129,  2116, 18108,  2079,  2017,  2031,  1029,  1000,   102,  1037,
          2307,  3510,  2003,  2028,  2073,  2169,  2711,  1010,  2302, 11376,
          1010, 21566,  1996,  4310,  1998,  8200,  6459,  1997,  1996,  2060,
          1010,  1998,  8295,  2135,  7126,  2068,  2022,  1996,  2190,  2825,
          2544,  1997,  3209,  1012,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

In [28]:
# Inference only: run the forward passes under torch.no_grad() so autograd
# does not record a graph for the logits (saves memory and time; the
# returned tensors then carry no grad_fn).
with torch.no_grad():
    outputs1 = bert_nsp(**inputs1)
    outputs2 = bert_nsp(**inputs2)

In [29]:
# NSP logits for the related pair. Per the Hugging Face docs, index 0 means
# "sentence B continues sentence A" and index 1 means "B is a random sentence";
# the larger logit is the prediction.
outputs1

NextSentencePredictorOutput(loss=None, logits=tensor([[ 6.3702, -6.1489]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

In [30]:
# NSP logits for the unrelated pair. Per the Hugging Face docs, index 0 means
# "sentence B continues sentence A" and index 1 means "B is a random sentence";
# the larger logit is the prediction.
outputs2

NextSentencePredictorOutput(loss=None, logits=tensor([[-3.2097,  6.2860]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)