# **Paragraph Embeddings - Petros Christodoulou**

To create a paragraph embedding, we take a different approach: one that builds on more recent developments in NLP while addressing the weaknesses of PV-DBOW and PV-DM. The steps in our approach are:

1.   Split a given paragraph up into its **sentences**
2.   **Embed each sentence** using a pre-trained **BERT** model (or another recent embedding model). This is a form of transfer learning: it lets us benefit from the massive amount of data and compute time that went into training BERT.
3.   **Feed the embedded sentences, in order, into the encoder GRU** (or LSTM), which **produces the paragraph embedding e**. We use a GRU or LSTM because (a) the sentences are in sequential order, and we want to take advantage of that ordering, and (b) a recurrent network can read any number of sentences, so the approach is independent of the length of the paragraph or document.
4.   Take the **output of the encoder GRU (e)** and **feed it into the decoder GRU**, which tries to recreate the sequence of BERT sentence embeddings.
5.   Train the combined model using the **mean squared error (MSE) between the input to the encoder GRU and the output of the decoder GRU**. This trains the model to produce a paragraph embedding e that retains as much information about the paragraph as possible, so that the decoder GRU can recreate the paragraph's sentence embeddings from e alone.
6.   Once trained, the model **generates paragraph embeddings as the output of the encoder GRU (e)**. This works for a paragraph or document of any length. No training is required at test time either: a new paragraph only needs a forward pass through BERT and the encoder, whereas PV-DBOW and PV-DM must run extra gradient-descent steps to infer a vector for every unseen paragraph. This addresses the biggest problem with those methods.
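Written out as equations (the symbols are our own notation, not taken from the diagram): let $s_1, \dots, s_n \in \mathbb{R}^{768}$ be the BERT embeddings of the $n$ sentences in a paragraph. The encoder GRU reads them in order and its final hidden state is the paragraph embedding $e$; the decoder is then unrolled for $n$ steps starting from $e$ to produce reconstructions $\hat{s}_1, \dots, \hat{s}_n$, and training minimizes the mean squared error between the two sequences:

$$
h_t = \mathrm{GRU}_{\mathrm{enc}}(s_t,\, h_{t-1}), \qquad h_0 = \mathbf{0}, \qquad e = h_n
$$

$$
(\hat{s}_1, \dots, \hat{s}_n) = \mathrm{GRU}_{\mathrm{dec}}(e;\, n)
$$

$$
\mathcal{L} = \frac{1}{n \cdot 768} \sum_{t=1}^{n} \left\lVert s_t - \hat{s}_t \right\rVert_2^2
$$

Here $\mathrm{GRU}_{\mathrm{dec}}(e;\, n)$ is shorthand for the decoder unrolled for $n$ steps from $e$. The normalization by $n \cdot 768$ matches PyTorch's default `nn.MSELoss` (the mean over all elements). Because $e$ has the same size whatever $n$ is, paragraphs of any length land in the same embedding space.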


Below we provide a diagram demonstrating the model in more detail as well as a PyTorch implementation:


![Diagram of the encoder-decoder GRU paragraph embedding model](https://drive.google.com/uc?id=1SSuGAnClC6OvblnTIB5-q292ncnvMp41)


## **PyTorch Implementation** 

```python
# We download this package so we can use a pre-trained BERT model
!pip install transformers -q
```



```python
import torch
from transformers import BertTokenizer, BertModel

paragraph = "My first sentence. My second sentence. My third sentence"
# Naive split on "."; strip whitespace and drop empty pieces (e.g. after a trailing period)
sentences = [s.strip() for s in paragraph.split(".") if s.strip()]

# Load the BERT tokenizer, which prepares a sentence for use in a BERT model
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

# Encode each sentence into token ids ([CLS] ... [SEP]) so it is ready to be embedded by BERT
tokenized_sentences = []
for ix, sentence in enumerate(sentences):
  tokenized_sentence = tokenizer.encode(sentence, add_special_tokens=True)
  print("Tokenized sentence {}: {}".format(ix, tokenized_sentence))
  tokenized_sentences.append(tokenized_sentence)
```

```
Tokenized sentence 0: [101, 1422, 1148, 5650, 102]
Tokenized sentence 1: [101, 1422, 1248, 5650, 102]
Tokenized sentence 2: [101, 1422, 1503, 5650, 102]
```
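Splitting on `"."` is fragile: a trailing period yields an empty string, and decimals such as `3.5` get cut in two. As a sketch only, here is a standard-library regex heuristic (our assumption, not part of the original notebook, and much weaker than a dedicated sentence tokenizer such as NLTK's or spaCy's) that splits only where `.`, `!` or `?` is followed by whitespace and an uppercase letter. The helper name `split_sentences` is hypothetical.

```python
import re

# Hypothetical helper: a heuristic sentence splitter, not a full tokenizer.
# It splits after ".", "!" or "?" only when followed by whitespace and an
# uppercase letter, so decimals like "3.5" stay intact. Abbreviations such as
# "Dr. Smith" will still be split incorrectly.
_SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+(?=[A-Z])")

def split_sentences(paragraph):
    parts = _SENTENCE_BOUNDARY.split(paragraph.strip())
    # Drop empty pieces and surrounding whitespace
    return [p.strip() for p in parts if p.strip()]

print(split_sentences("My first sentence. My second sentence. My third sentence"))
# → ['My first sentence.', 'My second sentence.', 'My third sentence']
print(split_sentences("The price rose 3.5 percent. Then it fell!"))
# → ['The price rose 3.5 percent.', 'Then it fell!']
```

Unlike the naive split, this keeps each sentence's final punctuation, so BERT will also see the punctuation token for each sentence.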


```python
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertModel

bert_hidden_size = 768     # size of the sentence embedding produced by bert-base
hidden_size_encode = 768   # size of the paragraph embedding e

# We load the pretrained BERT model (from_pretrained returns it in eval mode)
bert = BertModel.from_pretrained('bert-base-cased')

class Model(nn.Module):

  def __init__(self):
    super().__init__()

    # Encoder GRU: reads the sequence of BERT sentence embeddings (size 768)
    # and summarizes it in its final hidden state, the paragraph embedding e
    self.gru_encode = nn.GRU(input_size=bert_hidden_size, hidden_size=hidden_size_encode, batch_first=True)

    # Decoder GRU: tries to recover the BERT sentence embeddings from e.
    # Its output is fed back in as its next input, which only works
    # while hidden_size_encode == bert_hidden_size
    self.gru_decode = nn.GRU(input_size=hidden_size_encode, hidden_size=bert_hidden_size, batch_first=True)

  def forward(self, encoded_sentences, verbose=False):
    # encoded_sentences: (num_sentences, 768) BERT embeddings of one paragraph
    if verbose: print("Encoded sentences shape ", encoded_sentences.shape)
    num_sentences = encoded_sentences.shape[0]

    # nn.GRU returns (output, h_n), with h_n of shape (num_layers, batch, hidden).
    # The final hidden state of the single layer is the paragraph embedding: (1, hidden_size_encode)
    _, h_n = self.gru_encode(encoded_sentences.unsqueeze(0))
    encoded_paragraph = h_n[0]
    if verbose: print("Paragraph embedding shape ", encoded_paragraph.shape)

    # Unroll the decoder for one step per sentence: the first input is e and each
    # later input is the previous prediction. The hidden state is carried between
    # steps so the decoder remembers what it has already produced.
    decoded_sentences = []
    inputs = encoded_paragraph.unsqueeze(0)   # (batch=1, seq_len=1, hidden_size_encode)
    hidden = None                             # None means a zero initial hidden state
    for _ in range(num_sentences):
      output, hidden = self.gru_decode(inputs, hidden)
      decoded_sentences.append(output[0])     # (1, 768)
      inputs = output
    decoded_sentences = torch.cat(decoded_sentences)   # (num_sentences, 768)
    if verbose: print("Decoded sentences shape ", decoded_sentences.shape)
    return encoded_paragraph, decoded_sentences

model = Model()
```

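The encoder step relies on `nn.GRU` returning a tuple `(output, h_n)`. The small shape check below uses random tensors in place of BERT embeddings (an assumption so the sketch runs without downloading BERT). It shows what each element of the tuple holds, and that the paragraph embedding keeps the same size however many sentences go in.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Random vectors stand in for BERT sentence embeddings (assumption for this sketch)
encoder = nn.GRU(input_size=768, hidden_size=768, batch_first=True)

for num_sentences in (3, 7):
    sentences = torch.randn(1, num_sentences, 768)   # (batch, seq_len, features)
    output, h_n = encoder(sentences)
    # output holds the hidden state at every step: (1, num_sentences, 768)
    # h_n holds only the final hidden state:       (num_layers=1, 1, 768)
    e = h_n[0]                                       # paragraph embedding, (1, 768)
    print(num_sentences, tuple(output.shape), tuple(e.shape))
    # For one layer and one direction, the last step of output equals h_n
    assert torch.allclose(output[:, -1, :], e)
```

For a single-layer, unidirectional GRU, "the output of the encoder" and "its final hidden state" are therefore the same vector; with several layers or `bidirectional=True` one would have to choose which part of `h_n` to use.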


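```python
# The example below shows the shapes of the data as they pass through the network.
# torch.LongTensor only works here because every tokenized sentence has the same
# length (5 tokens); sentences of different lengths would need padding.
x = torch.LongTensor(tokenized_sentences)

# First we embed our sentences using BERT. Element [1] of BERT's output is the
# pooled [CLS] representation: one 768-dimensional vector per sentence
with torch.no_grad():
  encoded_sentences = bert(x)[1]

# Then we run the sentences through the model to produce our paragraph embedding
encoded_paragraph, decoded_sentences = model(encoded_sentences, verbose=True)
```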

```
Encoded sentences shape  torch.Size([3, 768])
Paragraph embedding shape  torch.Size([1, 768])
Decoded sentences shape  torch.Size([3, 768])
```
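`torch.LongTensor(tokenized_sentences)` fails as soon as two sentences tokenize to different lengths, because PyTorch cannot build a tensor from ragged lists. A minimal sketch of padding the token-id lists with PyTorch's `pad_sequence` and building the matching attention mask (1 for real tokens, 0 for padding) follows. The helper name `pad_token_ids` is hypothetical, and the padding id 0 is an assumption that matches `[PAD]` in BERT's vocabulary; `tokenizer.pad_token_id` is the safer source. In recent versions of `transformers`, calling the tokenizer directly, `tokenizer(sentences, padding=True, return_tensors="pt")`, returns both tensors at once.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_token_ids(tokenized_sentences, pad_id=0):
    """Hypothetical helper: pad lists of token ids to a common length.

    Returns (input_ids, attention_mask), both of shape (num_sentences, max_len).
    pad_id=0 assumes BERT's [PAD] id.
    """
    sequences = [torch.tensor(ids, dtype=torch.long) for ids in tokenized_sentences]
    input_ids = pad_sequence(sequences, batch_first=True, padding_value=pad_id)
    # 1 marks a real token, 0 marks padding that BERT should ignore
    attention_mask = torch.zeros_like(input_ids)
    for row, seq in enumerate(sequences):
        attention_mask[row, :len(seq)] = 1
    return input_ids, attention_mask

# Token ids of different lengths (illustrative values)
ids = [[101, 1422, 1148, 5650, 102], [101, 1422, 102]]
input_ids, attention_mask = pad_token_ids(ids)
print(input_ids)
print(attention_mask)

# With the real model one would then embed the sentences with, e.g.:
#   encoded_sentences = bert(input_ids, attention_mask=attention_mask)[1]
```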


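```python
# To train the network we run through a training loop like the one below.
# The optimizer is created once, outside the loop, so Adam keeps its running
# statistics, and gradients are zeroed before every backward pass.
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.MSELoss()

for ix in range(10):
  optimizer.zero_grad()
  encoded_paragraph, decoded_sentences = model(encoded_sentences)
  loss = loss_fn(decoded_sentences, encoded_sentences)
  loss.backward()
  optimizer.step()
  print("Iteration {} -- Loss {}".format(ix+1, loss.item()))
```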

```
Iteration 1 -- Loss 0.6273453235626221
Iteration 2 -- Loss 0.5458223223686218
Iteration 3 -- Loss 0.47797447443008423
Iteration 4 -- Loss 0.38736626505851746
Iteration 5 -- Loss 0.29094889760017395
Iteration 6 -- Loss 0.22369994223117828
Iteration 7 -- Loss 0.18369582295417786
Iteration 8 -- Loss 0.14752982556819916
Iteration 9 -- Loss 0.11424669623374939
Iteration 10 -- Loss 0.0865730494260788
```
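The loop above fits a single paragraph, which mainly shows that the loss goes down. Below is a minimal sketch of the more realistic setting: training on several paragraphs with different numbers of sentences, then embedding an unseen, longer paragraph without any further training. Assumptions: tanh-squashed random tensors stand in for BERT's pooled sentence embeddings (the pooled output also ends in a tanh), the embedding size is shrunk from 768 to 16 so it runs quickly, and the compact `ParagraphAutoencoder` class is a hypothetical restatement of the encoder-decoder above, not the original `Model`.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
EMB = 16  # stand-in for BERT's 768 so the sketch runs fast

class ParagraphAutoencoder(nn.Module):
    """Hypothetical compact version of the encoder-decoder GRU model."""
    def __init__(self, emb_size=EMB):
        super().__init__()
        self.encoder = nn.GRU(emb_size, emb_size, batch_first=True)
        self.decoder = nn.GRU(emb_size, emb_size, batch_first=True)

    def embed(self, sentences):
        # sentences: (num_sentences, emb_size) -> paragraph embedding e: (1, emb_size)
        _, h_n = self.encoder(sentences.unsqueeze(0))
        return h_n[0]

    def forward(self, sentences):
        e = self.embed(sentences)
        inputs, hidden, outputs = e.unsqueeze(0), None, []
        for _ in range(sentences.shape[0]):
            output, hidden = self.decoder(inputs, hidden)
            outputs.append(output[0])
            inputs = output  # feed the prediction back in as the next input
        return e, torch.cat(outputs)

# Toy "corpus": paragraphs with 2 to 6 sentence embeddings each
paragraphs = [torch.tanh(torch.randn(n, EMB)) for n in (2, 3, 4, 5, 6)]

model = ParagraphAutoencoder()
optimizer = optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

model.train()
for epoch in range(50):
    total = 0.0
    for sentences in paragraphs:  # one paragraph per step (batch size 1)
        optimizer.zero_grad()
        _, reconstructed = model(sentences)
        loss = loss_fn(reconstructed, sentences)
        loss.backward()
        optimizer.step()
        total += loss.item()
    if epoch % 10 == 0:
        print(f"epoch {epoch}: mean loss {total / len(paragraphs):.4f}")

# Inference: no gradient steps, just a forward pass through the encoder
model.eval()
with torch.no_grad():
    unseen = torch.tanh(torch.randn(7, EMB))  # longer than any training paragraph
    e = model.embed(unseen)
print(e.shape)  # torch.Size([1, 16])
```

One design consequence worth noting: a single-layer GRU's hidden state, and therefore the decoder's output, always lies in (-1, 1), so this decoder can only reproduce targets in that range. That suits BERT's pooled output, which is also produced by a tanh.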
