In [6]:
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import Dataset, DataLoader
import numpy as np
from pytorch_pretrained_bert import BertModel, BertTokenizer
from bert_cnn_alphabet import *
import string
import re
import sys
import argparse
import torch.nn.utils.rnn as rnn
import torch.nn.functional as F

# import utils
from Interpreter import calculate_regularization, Interpreter

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
dataset = BertDataset(device)
dataloader = DataLoader(dataset, batch_size=512, shuffle=True)

30522 -> 20361


Token indices sequence length is longer than the specified maximum  sequence length for this BERT model (20361 > 512). Running this sequence through BERT will result in indexing errors


In [3]:
parser = argparse.ArgumentParser()
parser.add_argument('--num_epochs', type=int, default=100)
parser.add_argument('--embed_size', type=int, default=16)
parser.add_argument('--hidden_size', type=int, default=256)
parser.add_argument('--channel_size', type=int, default=32)
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--learning_rate', type=float, default=0.001)
parser.add_argument('--skip_train', type=bool, default=False)
args = parser.parse_args(["--num_epochs", "10",
                          "--embed_size", "8", "--hidden_size", "256",
                          "--batch_size", "256", 
                          "--learning_rate", "0.001",
                          "--channel_size", "32"])

CHAR_VOCAB_SIZE = 52
BERT_EMBED_DIM = 768
model = CNN_LM(char_vocab_size=CHAR_VOCAB_SIZE, 
        char_len=dataset.chars.shape[1], embed_dim=args.embed_size,
        chan_size=args.channel_size, hid_size=args.hidden_size,
        bert_hid_size=BERT_EMBED_DIM)
model.to(device)
model.load_state_dict(torch.load("data/bert_cnn.ckpt", map_location=torch.device('cpu')))
model.eval()

CNN_LM(
  (embedding): Embedding(52, 8)
  (convs): Sequential(
    (0): Conv1dBlockBN(
      (conv): Sequential(
        (0): Conv1d(8, 32, kernel_size=(2,), stride=(1,))
        (1): Dropout(p=0.0, inplace=False)
        (2): PReLU(num_parameters=1)
        (3): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Conv1dBlockBN(
      (conv): Sequential(
        (0): Conv1d(32, 32, kernel_size=(2,), stride=(1,))
        (1): Dropout(p=0.0, inplace=False)
        (2): PReLU(num_parameters=1)
        (3): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): Conv1dBlockBN(
      (conv): Sequential(
        (0): Conv1d(32, 32, kernel_size=(2,), stride=(1,))
        (1): Dropout(p=0.0, inplace=False)
        (2): PReLU(num_parameters=1)
        (3): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (3): Conv1dBlockBN(
      (conv): Sequential(


# Test

In [4]:
dataset = BertDataset(device)
dataloader = DataLoader(dataset, batch_size=512, shuffle=True)

30522 -> 20361


Token indices sequence length is longer than the specified maximum  sequence length for this BERT model (20361 > 512). Running this sequence through BERT will result in indexing errors


In [7]:
inputs, outputs, targets = run_model_through_dataset(dataloader, model)

print(inputs.shape)
print(outputs.shape)
print(targets.shape)

torch.Size([20361, 18])
torch.Size([20361, 768])
torch.Size([20361, 768])


In [10]:
for i in range(100):
    pred = find_closest_words(i, inputs, outputs, targets)
    print("target word: ", to_word(inputs[i]))
    print("top 3 nearest neighbors: ", pred[:3])
    print()

target word:  med
top 3 nearest neighbors:  ['fcilitte', 'innovtive', 'renowned']

target word:  pretty
top 3 nearest neighbors:  ['fcilitte', 'innovtive', 'fcilittes']

target word:  semifinl
top 3 nearest neighbors:  ['fcilitte', 'fcilittes', 'nnoyed']

target word:  delibertely
top 3 nearest neighbors:  ['fcilitte', 'fcilittes', 'innovtive']

target word:  philhrmonic
top 3 nearest neighbors:  ['fcilitte', 'innovtive', 'fcilittes']

target word:  muir
top 3 nearest neighbors:  ['fcilitte', 'innovtive', 'renowned']

target word:  hoisted
top 3 nearest neighbors:  ['fcilitte', 'nnoyed', 'innovtive']

target word:  dphne
top 3 nearest neighbors:  ['fcilitte', 'innovtive', 'terrifying']

target word:  performed
top 3 nearest neighbors:  ['fcilitte', 'terrifying', 'scowled']

target word:  pencil
top 3 nearest neighbors:  ['fcilitte', 'innovtive', 'terrifying']

target word:  protect
top 3 nearest neighbors:  ['fcilitte', 'innovtive', 'nnoyed']

target word:  leto
top 3 nearest neighbors

target word:  lim
top 3 nearest neighbors:  ['fcilitte', 'innovtive', 'enhnce']

target word:  seprted
top 3 nearest neighbors:  ['fcilitte', 'innovtive', 'terrifying']

target word:  crp
top 3 nearest neighbors:  ['fcilitte', 'innovtive', 'considerble']

target word:  lstly
top 3 nearest neighbors:  ['fcilitte', 'innovtive', 'fcilittes']

target word:  thumbs
top 3 nearest neighbors:  ['fcilitte', 'innovtive', 'fcilittes']



In [9]:
to_word(inputs[0])

'med'