In [1]:
import torch
from torch import nn
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import math
import random

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [34]:
SEQ_LEN = 30
BATCH_SIZE = 400
D_MODEL = 256
D_HID = 512
N_HEAD = 4
N_LAYERS = 4
DROPOUT = 0.1

PAD = '<pad>'
START = '<start>'
END = '<end>'
UNK = '<unk>'

In [35]:
with open('lines.txt', encoding='utf-8') as f:
    lines_raw = f.read().splitlines()

lines = {}

for l in lines_raw:
    s = l.split(' +++$+++ ')
    lines[s[0]] = s[-1]

In [13]:
with open('conversations.txt', encoding='utf-8') as f:
    conv_raw = f.read().splitlines()

conversations = []

for conv in conv_raw:
    arr = eval(conv.split(' +++$+++ ')[-1])
    arr = [lines[i] for i in arr[:2]]
    if len(arr) != 2:
        continue
    conversations.append(arr)

In [7]:
conversations[0]

['Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.',
 "Well, I thought we'd start with pronunciation, if that's okay with you."]

In [8]:
conversations = [
    ['hello, how are you?', 'hello, i am fine'],
    ['hello', 'hello'],
    ['how are you?', 'i am fine']
]

In [14]:
tokenizer = get_tokenizer('basic_english')
tokenized_conv = [[tokenizer(line) for line in unit] for unit in tqdm(conversations)]

100%|██████████| 83097/83097 [00:02<00:00, 34141.56it/s]


In [15]:
squeezed_conv = []
for tc in tokenized_conv:
    squeezed_conv += tc

In [16]:
vocab = build_vocab_from_iterator(squeezed_conv, specials=[PAD, START, END, UNK])
vocab.set_default_index(vocab[UNK])

In [17]:
END_PUNCT = list('.?!') + ['...']
INNER_PUNCT = ',:;-"'


def random_amount(max_amount, v, f=0.5):
    return min(random.randint(0, math.ceil(1.25 * v * max_amount ** f)), max_amount)


def random_indices(length, v, f=0.5, reverse=False):
    amount = random_amount(length, v, f)
    try:
        return sorted(random.sample(range(length), amount), reverse=reverse)
    except:
        print(length, amount)
        raise Exception('no. fuck')


def binrand(p):
    return p > random.random()


class Shuffler:
    def __init__(self, vocab):
        self.vocab = vocab

    def swap(self, t, v, max_strength=3):
        strength = math.ceil(max_strength * v)
        for i in random_indices(len(t) - 1, v, f=0.4):
            s_bef = min(strength, i)
            s_aft = min(strength, len(t) - i - 1)
            diff = random.randint(0, s_bef + s_aft) - s_bef
            diff = diff if diff else 1
            t[i], t[i + diff] = t[i + diff], t[i]

    def double(self, t, v):
        for offset, i in enumerate(random_indices(len(t), v)):
            t.insert(offset + i, t[offset + i])

    def add(self, t, v):
        for offset, i in enumerate(random_indices(len(t), v)):
            array = INNER_PUNCT if binrand(0.75) or not self.vocab else self.vocab
            t.insert(offset + i, random.choice(array))

    def add_end(self, t, v):
        if not binrand(v):
            return
        t.append(random.choice(END_PUNCT))

    def shuffle(self, t, v):

        if not t:
            return t

        self.double(t, v)
        self.add(t, v)
        self.add_end(t, v)
        self.swap(t, v)

        return t


def produce_shuffled(tokens, steps, vocab=None, max_shuffle=1):
    shuffler = Shuffler(vocab or [])
    v = torch.linspace(0, max_shuffle, steps)
    result = []

    for tline in tokens:
        result += [shuffler.shuffle(tline.copy(), vi) for vi in v]

    return result

In [18]:
shuffled_conv = []

for q, a in tqdm(tokenized_conv):
    shuffled = produce_shuffled([q], 3)
    shuffled_conv += [[sh, a] for sh in shuffled]

len(tokenized_conv), len(shuffled_conv)

100%|██████████| 83097/83097 [00:26<00:00, 3168.83it/s]


(83097, 249291)

In [19]:
tokenized_conv = shuffled_conv

In [20]:
num_conv = [[vocab(line) for line in conv] for conv in tokenized_conv]

In [21]:
def make_qna_sequence(q, a):
    s = [vocab[START]] + q + [vocab[END], vocab[START]]
    s = [vocab[PAD]] * (SEQ_LEN - len(s)) + s[-SEQ_LEN:] + a + [vocab[END]]
    return s

In [22]:
qna = []

for conv in tqdm(num_conv):
    seq = make_qna_sequence(conv[0], conv[1])
    qna.append(seq)

len(qna)

100%|██████████| 249291/249291 [00:02<00:00, 114952.40it/s]


249291

In [23]:
class MyDataset(Dataset):
  def __init__(self, data):
    self.data = data
    self.borders = []
    self.length = 0

    for qna in data:
      self.borders.append(self.length)
      self.length += len(qna) - SEQ_LEN

  def _bin_search(self, w):
      l, r = 0, len(self.borders)
      while r - l > 1 and w != self.borders[l]:
          i = (l + r) // 2
          if w < self.borders[i]:
              r = i
          else:
              l = i
      return l

  def __len__(self):
    return self.length

  def __getitem__(self, i):
    arr_id = self._bin_search(i)
    j = i - self.borders[arr_id]
    qna = self.data[arr_id]
    return torch.tensor(qna[j:j + SEQ_LEN]), qna[j + SEQ_LEN]

In [24]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[0, :x.size(1)]
        return self.dropout(x)

In [25]:
class TransformerModel(nn.Module):
    def __init__(self, vocab_size, device='cpu'):
        super().__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(D_MODEL, SEQ_LEN, DROPOUT).to(device)
        encoder_layers = nn.TransformerEncoderLayer(D_MODEL, N_HEAD, D_HID, DROPOUT, batch_first=True).to(device)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, N_LAYERS).to(device)
        self.embedding = nn.Embedding(vocab_size, D_MODEL).to(device)
        self.linear = nn.Linear(D_MODEL * SEQ_LEN, vocab_size).to(device)
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def forward(self, src):
        src = self.embedding(src) * math.sqrt(D_MODEL)
        #src = self.pos_encoder(src)
        output = self.transformer_encoder(src)
        output = output.reshape(-1, SEQ_LEN * D_MODEL)
        output = self.linear(output)
        return output

In [52]:
train_loader = DataLoader(MyDataset(qna[:50_000]), batch_size=BATCH_SIZE, shuffle=True)

In [27]:
model = TransformerModel(len(vocab), device=device)

In [255]:
#model.load_state_dict(torch.load('model.pt'))

In [28]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)

In [37]:
num_epochs = 10

In [53]:
model.train()

for epoch in range(num_epochs):
    print('-' * 6, epoch, '-' * 6)

    train_loss = 0

    for (x, y) in tqdm(train_loader):
        x = x.to(device)
        y = y.to(device)
        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        train_loss += loss.item()

    train_loss /= len(train_loader)

    print(f'Train loss: {train_loss:.3f}')

------ 0 ------


100%|██████████| 1784/1784 [17:24<00:00,  1.71it/s]


Train loss: 2.878
------ 1 ------


100%|██████████| 1784/1784 [17:24<00:00,  1.71it/s]


Train loss: 0.871
------ 2 ------


100%|██████████| 1784/1784 [17:24<00:00,  1.71it/s]


Train loss: 0.453
------ 3 ------


100%|██████████| 1784/1784 [17:24<00:00,  1.71it/s]


Train loss: 0.318
------ 4 ------


 25%|██▍       | 443/1784 [04:19<13:05,  1.71it/s]


KeyboardInterrupt: ignored

In [54]:
torch.save(model.state_dict(), 'model.pt')

In [55]:
def ask(q, max_tokens=40):
    model.eval()
    tokens = vocab(tokenizer(q))
    qna = make_qna_sequence(tokens, [])[:SEQ_LEN]
    for i in range(max_tokens):
        input_ = torch.tensor(qna[-SEQ_LEN:], device=device).unsqueeze(0)
        with torch.no_grad():
            next_token = torch.argmax(model(input_)[0]).item()
        qna.append(next_token)
        if next_token == vocab[END]:
            break
    return ' '.join(vocab.lookup_tokens(qna[SEQ_LEN:]))

In [122]:
q = 'shut the fuck up'

In [123]:
print(ask(q))

erik . ! i ' m not scared . <end>
