In this notebook we look at rationalized language models. 
We first consider the variant that treats the language model as a black box. 

In [1]:
# Make the parent of the current working directory importable so the
# project-local packages imported in the following cells can be resolved.
import os
import sys

module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)


In [2]:
from modules.LanguageModels.LstmLanguageModel import LSTMLanguageModel
from modules.RationalExtractors.PolicyBasedRationalExtractor import PolicyBasedRationalExtractor
from utils.utils import encode, decode

In [3]:
from daily_dialog.DialogTokenizer import get_daily_dialog_tokenizer
my_tokenizer = get_daily_dialog_tokenizer(
    tokenizer_location='../daily_dialog/tokenizer.json'
)

# Display the id the tokenizer assigns to the '[RMASK]' token; the
# extractor built in the next cell is given this id as its mask token.
encode(my_tokenizer, '[RMASK]')

tensor([4])

In [4]:
# Load the previously trained LSTM language model from disk.
location = "../models/small_lm_test.pt"
trained_language_model = LSTMLanguageModel.load(location)

# Look the '[RMASK]' id up in the tokenizer instead of hard-coding it, so the
# extractor stays in sync if the tokenizer vocabulary is ever regenerated.
# `.item()` converts the one-element id tensor to a plain int and raises if
# '[RMASK]' does not encode to exactly one token.
rmask_token_id = encode(my_tokenizer, '[RMASK]').item()
rational_extractor = PolicyBasedRationalExtractor(
    my_tokenizer.get_vocab_size(), mask_token=rmask_token_id
)


In [5]:

# A single dialogue turn, closed by the [SEP] token.
context = "Hi how are you today? [SEP]"

# Encode to token ids, then reshape into a column with one row per token
# so that the batch dimension comes second.
context_token_ids = encode(my_tokenizer, context)
ids = context_token_ids.view(-1, 1)

In [6]:
# The policy-based rational extractor can mask the tokens directly.
# As the displayed output shows, it returns a dict with the keys
# 'policy_logits', 'policy', 'chosen_policy', 'mask' (boolean, one entry per
# token) and 'masked_input' (the ids with the masked positions replaced by
# the mask token id, 4).
# NOTE(review): if the extractor is a torch nn.Module, calling .forward()
# directly skips any registered hooks — confirm whether rational_extractor(ids)
# should be used instead.
RE_out = rational_extractor.forward(ids)
RE_out

{'policy_logits': tensor([[[ 0.0587, -0.0585]],
 
         [[ 0.1070,  0.0356]],
 
         [[ 0.0642,  0.0869]],
 
         [[-0.0495, -0.1706]],
 
         [[ 0.1459, -0.2259]],
 
         [[ 0.1012, -0.0150]],
 
         [[ 0.0846, -0.0090]]], grad_fn=<AddBackward0>),
 'policy': tensor([[[0.5293, 0.4707]],
 
         [[0.5178, 0.4822]],
 
         [[0.4943, 0.5057]],
 
         [[0.5302, 0.4698]],
 
         [[0.5919, 0.4081]],
 
         [[0.5290, 0.4710]],
 
         [[0.5234, 0.4766]]], grad_fn=<SoftmaxBackward>),
 'chosen_policy': tensor([[[0.5293]],
 
         [[0.5178]],
 
         [[0.5057]],
 
         [[0.5302]],
 
         [[0.5919]],
 
         [[0.4710]],
 
         [[0.4766]]], grad_fn=<GatherBackward>),
 'mask': tensor([[False],
         [False],
         [ True],
         [False],
         [False],
         [ True],
         [ True]]),
 'masked_input': tensor([[427],
         [219],
         [  4],
         [127],
         [484],
         [  4],
         [  4]])}

In [7]:
# Flatten the masked id column and decode it back to text to see which
# tokens were replaced by [RMASK].
masked_ids = RE_out["masked_input"].flatten()
print(decode(my_tokenizer, masked_ids))

hi how [RMASK] you today [RMASK] [RMASK]


In [8]:
# The masked ids can be handed straight to the language model, which
# continues the dialogue from the masked prompt.
masked_prompt = RE_out["masked_input"].flatten()
completed_dialogue = trained_language_model.complete_dialogue(masked_prompt)

# Show the raw token ids first, then the decoded text.
print(completed_dialogue)
print(decode(my_tokenizer, completed_dialogue))

tensor([ 427,  219,    4,  127,  484,    4,    4,  157,   33,  127,  224,   73,
          58,  171,  871,  296,  130,  159,  128, 1081,   16,  208,   47,  239,
         127,   73,  206,  286, 1899,  234,   18,  127,  254,   16,  128,  296,
          16,   47,  291, 1845,    5,  176, 2072,  159,  127,  209, 1707,   33,
           1,  229,  699,   18,  152,   73,  302,  385,  724,  148,  128,  717,
          18,  155,  152,   73,  206, 1051,  148,  157,  989,   18,    1,  252,
          16,   47,  224,   11,   58,  254,  219,   47,   11,   51,   39,  402,
         442,  148,  128, 1170,   18,  157,   11,   57,  990,   16,   47,  526,
          11,   58,  288,  713])
hi how [RMASK] you today [RMASK] [RMASK] that ? you don ’ t have enough time to do the movie , but i think you ’ re very careful about . you know , the time , i am kidding ! what classes do you like movies ? [SEP] not bad . we ’ ve been working in the office . and we ’ re interested in that country . [SEP] well , i don ' t kn