In [1]:
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import Dataset, DataLoader
import numpy as np
from pytorch_pretrained_bert import BertModel, BertTokenizer
import string
import re
import sys
import argparse
import torch.nn.utils.rnn as rnn
import torch.nn.functional as F

# import utils
from Interpreter import calculate_regularization, Interpreter

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
def is_word_unfeasible(word):
    def is_ascii(word):
        return all(ord(c) < 128 for c in word)
    return ("unused" in word 
            or "#" in word 
            or not is_ascii(word)
            or len(word) < 3)


class BertDataset(Dataset):
    def __init__(self, device):
        self.device = device
        # get the tokenized words.
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

        # load BERT base model
        self.bert = BertModel.from_pretrained("bert-base-uncased").to(device)
        for param in self.bert.parameters():
            param.requires_grad = False
        self.bert.eval()

        # input characters
        self.CHAR_VOCAB_SIZE = 128
        words = self.tokenizer.vocab.keys()
        self.vocabs = [word for word in words if not is_word_unfeasible(word)]
        self.chars = [torch.LongTensor([ord(c) for c in word])
                      for word in self.vocabs]
        self.chars = rnn.pad_sequence(self.chars).to(self.device).T

        # word embeddings of bert
        ids = torch.LongTensor(
            self.tokenizer.convert_tokens_to_ids(self.vocabs)).to(self.device)
        self.word_embed = self.bert.embeddings.word_embeddings(
            ids.unsqueeze(0)).squeeze(0)

    def __len__(self):
        return len(self.vocabs)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return self.chars[idx], self.word_embed[idx]


class Conv1dBlockBN(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, stride, p=0.0):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(in_channel, out_channel,
                      kernel_size=kernel_size, stride=stride),
            nn.Dropout(p),
            nn.PReLU(),
            nn.BatchNorm1d(out_channel)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class CNN_LM(nn.Module):
    def __init__(self, char_vocab_size, char_len, embed_dim, chan_size, hid_size, bert_hid_size):
        super().__init__()
        self.embedding = nn.Embedding(char_vocab_size, embed_dim)
        convs = []
        for i in range(char_len - 1):
            if i == 0:
                convs.append(Conv1dBlockBN(embed_dim, chan_size, 2, stride=1))
            else:
                convs.append(Conv1dBlockBN(chan_size, chan_size, 2, stride=1))
        self.convs = nn.Sequential(*convs)
        self.fc1 = nn.Linear(chan_size, hid_size)
        self.fc2 = nn.Linear(hid_size, bert_hid_size)

    def forward(self, x):
        # (batch_size, embed_dim, context_width)
        x = self.embedding(x).permute(0, 2, 1)
        x = self.convs(x)  # (batch_size, chan_size, 1)
        x = x.squeeze(2)  # (batch_size, chan_size)
        x = F.relu(self.fc1(x))  # (batch_size, hid_size)
        x = self.fc2(x)  # (batch_size, vocab_size)
        return x

In [3]:
dataset = BertDataset(device)
dataloader = DataLoader(dataset, batch_size=3, shuffle=True)

100%|██████████| 231508/231508 [00:00<00:00, 256904.14B/s]
100%|██████████| 407873900/407873900 [01:05<00:00, 6242065.20B/s]
Token indices sequence length is longer than the specified maximum  sequence length for this BERT model (22205 > 512). Running this sequence through BERT will result in indexing errors


In [8]:
parser = argparse.ArgumentParser()
parser.add_argument('--num_epochs', type=int, default=100)
parser.add_argument('--embed_size', type=int, default=8)
parser.add_argument('--hidden_size', type=int, default=256)
parser.add_argument('--channel_size', type=int, default=32)
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--learning_rate', type=float, default=0.001)
args = parser.parse_args(["--num_epochs", "10",
                          "--embed_size", "16", "--hidden_size", "256",
                          "--batch_size", "256", 
                          "--learning_rate", "0.001",
                          "--channel_size", "32"])

CHAR_VOCAB_SIZE = 128
BERT_EMBED_DIM = 768
model = CNN_LM(char_vocab_size=CHAR_VOCAB_SIZE, 
        char_len=dataset.chars.shape[1], embed_dim=args.embed_size,
        chan_size=args.channel_size, hid_size=args.hidden_size,
        bert_hid_size=BERT_EMBED_DIM)
model.to(device)
model.load_state_dict(torch.load("data/bert_cnn.ckpt", map_location=torch.device('cpu')))
model.eval()

CNN_LM(
  (embedding): Embedding(128, 16)
  (convs): Sequential(
    (0): Conv1dBlockBN(
      (conv): Sequential(
        (0): Conv1d(16, 32, kernel_size=(2,), stride=(1,))
        (1): Dropout(p=0.0, inplace=False)
        (2): PReLU(num_parameters=1)
        (3): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Conv1dBlockBN(
      (conv): Sequential(
        (0): Conv1d(32, 32, kernel_size=(2,), stride=(1,))
        (1): Dropout(p=0.0, inplace=False)
        (2): PReLU(num_parameters=1)
        (3): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): Conv1dBlockBN(
      (conv): Sequential(
        (0): Conv1d(32, 32, kernel_size=(2,), stride=(1,))
        (1): Dropout(p=0.0, inplace=False)
        (2): PReLU(num_parameters=1)
        (3): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (3): Conv1dBlockBN(
      (conv): Sequentia

# Test

In [27]:
for data in dataloader:
    _input, _target = data
    break
print(_input.shape)
print(_target.shape)

torch.Size([3, 18])
torch.Size([3, 768])


In [30]:
out = model(_input)

In [53]:
dataset.vocabs

['[PAD]',
 '[UNK]',
 '[CLS]',
 '[SEP]',
 '[MASK]',
 'the',
 'and',
 'was',
 'for',
 'with',
 'that',
 'his',
 'from',
 'her',
 'she',
 'you',
 'had',
 'were',
 'but',
 'this',
 'are',
 'not',
 'they',
 'one',
 'which',
 'have',
 'him',
 'first',
 'all',
 'also',
 'their',
 'has',
 'who',
 'out',
 'been',
 'when',
 'after',
 'there',
 'into',
 'new',
 'two',
 'its',
 'time',
 'would',
 'what',
 'about',
 'said',
 'over',
 'then',
 'other',
 'more',
 'can',
 'like',
 'back',
 'them',
 'only',
 'some',
 'could',
 'where',
 'just',
 'during',
 'before',
 'made',
 'school',
 'through',
 'than',
 'now',
 'years',
 'most',
 'world',
 'may',
 'between',
 'down',
 'well',
 'three',
 'year',
 'while',
 'will',
 'later',
 'city',
 'under',
 'around',
 'did',
 'such',
 'being',
 'used',
 'state',
 'people',
 'part',
 'know',
 'against',
 'your',
 'many',
 'second',
 'university',
 'both',
 'national',
 'these',
 'don',
 'known',
 'off',
 'way',
 'until',
 'how',
 'even',
 'get',
 'head',
 '...',
 