In [1]:
from bertviz import model_view
from transformers import AutoTokenizer, AutoModel
import pickle

# TODO : torch text should be 0.11.2 // torch = 1.10.2
from torchtext.datasets import IMDB
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenizer = get_tokenizer('basic_english')
def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)

In [3]:
train_iter = IMDB(split='train')
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: 0 if x=='neg' else 1

In [25]:
sentence = "this is a book on the table"
processed_text = torch.tensor(text_pipeline(sentence), dtype=torch.int64)
with open('model.pkl', 'rb') as f:
    model = pickle.load(f)


In [26]:
processed_text= processed_text.view(1, -1)

In [27]:
model.train(False)
out, att = model(processed_text)

In [31]:
att = att.reshape(1,8,7,7)
att = (att[None,:])

In [32]:
att

tensor([[[[[0.1576, 0.1489, 0.1401, 0.1156, 0.1551, 0.1540, 0.1286],
           [0.1373, 0.1419, 0.1703, 0.1390, 0.1236, 0.1607, 0.1272],
           [0.1450, 0.1382, 0.1527, 0.1355, 0.1268, 0.1598, 0.1420],
           [0.1162, 0.1598, 0.1598, 0.1388, 0.1315, 0.1394, 0.1545],
           [0.1702, 0.1267, 0.1436, 0.1306, 0.1650, 0.1412, 0.1227],
           [0.1540, 0.1462, 0.1437, 0.1394, 0.1513, 0.1493, 0.1160],
           [0.1346, 0.1383, 0.1418, 0.1578, 0.1284, 0.1570, 0.1421]],

          [[0.1473, 0.1440, 0.1706, 0.1506, 0.1274, 0.1268, 0.1333],
           [0.1337, 0.1425, 0.1466, 0.1210, 0.1511, 0.1538, 0.1513],
           [0.1552, 0.1507, 0.1614, 0.1250, 0.1073, 0.1465, 0.1539],
           [0.1290, 0.1353, 0.1768, 0.1327, 0.1440, 0.1346, 0.1475],
           [0.1253, 0.1408, 0.1371, 0.1611, 0.1482, 0.1332, 0.1543],
           [0.1338, 0.1243, 0.1726, 0.1369, 0.1376, 0.1398, 0.1550],
           [0.1358, 0.1219, 0.1581, 0.1411, 0.1381, 0.1420, 0.1630]],

          [[0.1590, 0.1356, 0.

In [30]:
model_view(att, processed_text, sentence)

TypeError: slice indices must be integers or None or have an __index__ method

In [15]:
import former
from former import util
from former.util import d, here

import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F

from torchtext import data, datasets, vocab
# from torchtext.legacy import data, datasets, vocab

import numpy as np

from argparse import ArgumentParser, Namespace
from torch.utils.tensorboard import SummaryWriter

import random, tqdm, sys, math, gzip
import os
from copy import deepcopy
from datetime import datetime
# Used for converting between nats and bits
LOG2E = math.log2(math.e)
TEXT = data.Field(lower=True, include_lengths=True, batch_first=True)
LABEL = data.Field(sequential=False)
NUM_CLS = 2
MODEL_OUTPUT_PATH = './model_output'



def go(arg):
    """
    Creates and trains a basic transformer for the IMDB sentiment classification task.
    """
    tbw = SummaryWriter(log_dir=arg.tb_dir) # Tensorboard logging

    # load the IMDB data
    if arg.final:
        train, test = datasets.IMDB.splits(TEXT, LABEL)


        TEXT.build_vocab(train, max_size=arg.vocab_size - 2)
        LABEL.build_vocab(train)

        train_iter, test_iter = data.BucketIterator.splits((train, test), batch_size=arg.batch_size, device=util.d())
    else:
        tdata, _ = datasets.IMDB.splits(TEXT, LABEL)
        print(f'tdata ======> ',tdata.split())
        train, test = tdata.split(split_ratio=0.8)

        TEXT.build_vocab(train, max_size=arg.vocab_size - 2) # - 2 to make space for <unk> and <pad>
        LABEL.build_vocab(train)

        train_iter, test_iter = data.BucketIterator.splits((train, test), batch_size=arg.batch_size, device=util.d())

    print(f'- nr. of training examples {len(train_iter)}')
    print(f'- nr. of {"test" if arg.final else "validation"} examples {len(test_iter)}')

    if arg.max_length < 0:
        mx = max([input.text[0].size(1) for input in train_iter])
        mx = mx * 2
        print(f'- maximum sequence length: {mx}')
    else:
        mx = arg.max_length

    # create the model
    model = former.CTransformer(emb=arg.embedding_size, heads=arg.num_heads, depth=arg.depth, seq_length=mx, num_tokens=arg.vocab_size, num_classes=NUM_CLS, max_pool=arg.max_pool)
    if torch.cuda.is_available():
        model.cuda()

    model = nn.DataParallel(model, device_ids=[0, 1, 2])

    opt = torch.optim.SGD(
        lr=arg.lr, params=model.parameters(), momentum=arg.momentum)

    # opt = torch.optim.Adam(lr=arg.lr, params=model.parameters())
    sch = torch.optim.lr_scheduler.LambdaLR(opt, lambda i: min(i / (arg.lr_warmup / arg.batch_size), 1.0))

    # best output to save model
    best_accuracy = 0
    best_model_state = None
    # training loop
    seen = 0
    for e in range(arg.num_epochs):

        print(f'\n epoch {e}')
        model.train(True)

        for batch in tqdm.tqdm(train_iter):

            opt.zero_grad()

            input = batch.text[0]
            label = batch.label - 1

            if input.size(1) > mx:
                input = input[:, :mx]
            out, att = model(input)
            print()
            loss = F.nll_loss(out, label)

            loss.backward()

            # clip gradients
            # - If the total gradient vector has a length > 1, we clip it back down to 1.
            if arg.gradient_clipping > 0.0:
                nn.utils.clip_grad_norm_(model.parameters(), arg.gradient_clipping)

            opt.step()
            sch.step()

            seen += input.size(0)
            tbw.add_scalar('classification/train-loss', float(loss.item()), seen)

        with torch.no_grad():

            model.train(False)
            tot, cor= 0.0, 0.0

            for batch in test_iter:

                input = batch.text[0]
                label = batch.label - 1

                if input.size(1) > mx:
                    input = input[:, :mx]
                out = model(input).argmax(dim=1)

                tot += float(input.size(0))
                cor += float((label == out).sum().item())

            acc = cor / tot
            print(f'-- {"test" if arg.final else "validation"} accuracy {acc:.3}')
            tbw.add_scalar('classification/test-loss', float(loss.item()), e)
            # if acc > best_accuracy:
            #     best_accuracy, best_model_state = acc, deepcopy(
            #         model.state_dict())

    return model

    # if not os.path.exists(MODEL_OUTPUT_PATH):
    #     os.mkdir(MODEL_OUTPUT_PATH)
    # time_stamp = datetime.now()
    # torch.save(best_model_state, MODEL_OUTPUT_PATH +
    #            f'{time_stamp}_SGD_momentum_{arg.momentum}_acc_{best_accuracy:3f}')



In [16]:

parser = ArgumentParser()


parser.add_argument("-e", "--num-epochs",
                        dest="num_epochs",
                        help="Number of epochs.",
                        default=80, type=int)

parser.add_argument("-b", "--batch-size",
                        dest="batch_size",
                        help="The batch size.",
                        default=4, type=int)

parser.add_argument("-l", "--learn-rate",
                        dest="lr",
                        help="Learning rate",
                        default=0.0001, type=float)

parser.add_argument("-T", "--tb_dir", dest="tb_dir",
                        help="Tensorboard logging directory",
                        default='./runs')

parser.add_argument("-f", "--final", dest="final",
                        help="Whether to run on the real test set (if not included, the validation set is used).",
                        action="store_true")

parser.add_argument("--max-pool", dest="max_pool",
                    help="Use max pooling in the final classification layer.",
                    action="store_true")

parser.add_argument("-E", "--embedding", dest="embedding_size",
                    help="Size of the character embeddings.",
                    default=128, type=int)

parser.add_argument("-V", "--vocab-size", dest="vocab_size",
                    help="Number of words in the vocabulary.",
                
                    default=50_000, type=int)

parser.add_argument("-M", "--max", dest="max_length",
                    help="Max sequence length. Longer sequences are clipped (-1 for no limit).",
                    default=512, type=int)

parser.add_argument("-H", "--heads", dest="num_heads",
                    help="Number of attention heads.",
                    default=8, type=int)

parser.add_argument("-d", "--depth", dest="depth",
                    help="Depth of the network (nr. of self-attention layers)",
                    default=6, type=int)

parser.add_argument("-r", "--random-seed",
                    dest="seed",
                    help="RNG seed. Negative for random",
                    default=1, type=int)

parser.add_argument("--lr-warmup",
                    dest="lr_warmup",
                    help="Learning rate warmup.",
                    default=10_000, type=int)

parser.add_argument("--gradient-clipping",
                    dest="gradient_clipping",
                    help="Gradient clipping.",
                    default=1.0, type=float)

parser.add_argument("--momentum",
                    dest="momentum",
                    help="momentum for SGD",
                    default=0.9, type=float)

# args, unknown = parser.parse_known_args()


args = Namespace(
    num_epochs=80,
    batch_size=4,
    lr=0.0001,
    tb_dir='./runs',
    final=False,
    max_pool=False,
    embedding_size=128,
    vocab_size=50_000,
    max_length=512,
    num_heads=8,
    depth=6,
    seed=1,
    lr_warmup=10_000,
    gradient_clipping=1.0,
    momentum=0.9
)

model = go(arg=args)

aclImdb_v1.tar.gz:   4%|▎         | 3.08M/84.1M [00:00<00:02, 30.8MB/s]

downloading aclImdb_v1.tar.gz


aclImdb_v1.tar.gz: 100%|██████████| 84.1M/84.1M [00:00<00:00, 90.2MB/s]


- nr. of training examples 5000
- nr. of validation examples 1250


  0%|          | 0/5000 [00:00<?, ?it/s]


 epoch 0


  0%|          | 4/5000 [00:02<1:51:50,  1.34s/it]








  0%|          | 8/5000 [00:02<1:18:56,  1.05it/s]









  0%|          | 15/5000 [00:02<40:03,  2.07it/s] 











  0%|          | 23/5000 [00:02<20:52,  3.97it/s]









  1%|          | 31/5000 [00:02<11:29,  7.21it/s]











  1%|          | 39/5000 [00:03<06:44, 12.26it/s]











  1%|          | 47/5000 [00:03<04:34, 18.04it/s]










  1%|          | 51/5000 [00:03<03:58, 20.74it/s]









  1%|          | 59/5000 [00:03<03:18, 24.91it/s]










  1%|▏         | 67/5000 [00:03<02:49, 29.11it/s]











  2%|▏         | 75/5000 [00:04<02:36, 31.45it/s]










  2%|▏         | 79/5000 [00:04<02:37, 31.20it/s]









  2%|▏         | 87/5000 [00:04<02:49, 29.00it/s]









  2%|▏         | 95/5000 [00:04<02:38, 31.02it/s]











  2%|▏         | 99/5000 [00:04<02:41, 30.36it/s]









  2%|▏         | 107/5000 [00:05<02:40, 30.43it/s]










  2%|▏         | 115/5000 [00:05<02:36, 31.29it/s]










  2%|▏         | 123/5000 [00:05<02:30, 32.49it/s]










  3%|▎         | 127/5000 [00:05<02:32, 31.99it/s]










  3%|▎         | 135/5000 [00:06<02:31, 32.02it/s]










  3%|▎         | 143/5000 [00:06<02:27, 32.83it/s]










  3%|▎         | 151/5000 [00:06<02:22, 33.95it/s]











  3%|▎         | 155/5000 [00:06<02:26, 33.06it/s]









  3%|▎         | 163/5000 [00:06<02:32, 31.72it/s]










  3%|▎         | 171/5000 [00:07<02:39, 30.36it/s]









  4%|▎         | 175/5000 [00:07<02:36, 30.87it/s]









  4%|▎         | 183/5000 [00:07<02:43, 29.46it/s]










  4%|▍         | 191/5000 [00:07<02:32, 31.47it/s]











  4%|▍         | 199/5000 [00:08<02:29, 32.01it/s]










  4%|▍         | 204/5000 [00:08<02:24, 33.25it/s]











  4%|▍         | 212/5000 [00:08<02:31, 31.66it/s]









  4%|▍         | 216/5000 [00:08<02:33, 31.11it/s]










  4%|▍         | 224/5000 [00:08<02:33, 31.15it/s]










  5%|▍         | 232/5000 [00:09<02:44, 28.97it/s]









  5%|▍         | 238/5000 [00:09<02:47, 28.45it/s]









  5%|▍         | 244/5000 [00:09<02:48, 28.26it/s]









  5%|▍         | 248/5000 [00:09<02:40, 29.52it/s]











  5%|▌         | 257/5000 [00:09<02:28, 31.89it/s]










  5%|▌         | 261/5000 [00:10<02:27, 32.04it/s]







  5%|▌         | 268/5000 [00:10<03:31, 22.40it/s]









  6%|▌         | 277/5000 [00:10<02:46, 28.39it/s]











  6%|▌         | 281/5000 [00:10<02:45, 28.53it/s]









  6%|▌         | 289/5000 [00:11<02:44, 28.57it/s]










  6%|▌         | 297/5000 [00:11<02:25, 32.23it/s]










  6%|▌         | 305/5000 [00:11<02:18, 33.87it/s]











  6%|▌         | 310/5000 [00:11<02:09, 36.33it/s]











  6%|▋         | 318/5000 [00:12<02:22, 32.97it/s]










  7%|▋         | 326/5000 [00:12<02:27, 31.60it/s]









  7%|▋         | 330/5000 [00:12<02:33, 30.42it/s]









  7%|▋         | 338/5000 [00:12<02:34, 30.19it/s]










  7%|▋         | 346/5000 [00:12<02:26, 31.80it/s]











  7%|▋         | 350/5000 [00:13<02:26, 31.72it/s]









  7%|▋         | 357/5000 [00:13<02:40, 28.87it/s]










  7%|▋         | 364/5000 [00:13<02:39, 28.99it/s]










  7%|▋         | 372/5000 [00:13<02:32, 30.32it/s]











  8%|▊         | 380/5000 [00:14<02:19, 33.16it/s]










  8%|▊         | 388/5000 [00:14<02:17, 33.57it/s]










  8%|▊         | 397/5000 [00:14<02:14, 34.14it/s]










  8%|▊         | 401/5000 [00:14<02:18, 33.21it/s]











  8%|▊         | 410/5000 [00:14<02:21, 32.46it/s]









  8%|▊         | 414/5000 [00:15<02:21, 32.50it/s]










  9%|▊         | 426/5000 [00:15<02:10, 35.11it/s]











  9%|▊         | 430/5000 [00:15<02:23, 31.84it/s]









  9%|▉         | 438/5000 [00:15<02:26, 31.15it/s]










  9%|▉         | 446/5000 [00:16<02:13, 34.05it/s]











  9%|▉         | 450/5000 [00:16<02:19, 32.60it/s]










  9%|▉         | 458/5000 [00:16<02:16, 33.26it/s]










  9%|▉         | 466/5000 [00:16<02:28, 30.49it/s]









  9%|▉         | 470/5000 [00:16<02:36, 29.01it/s]










 10%|▉         | 478/5000 [00:17<02:31, 29.82it/s]










 10%|▉         | 486/5000 [00:17<02:30, 30.08it/s]










 10%|▉         | 490/5000 [00:17<02:26, 30.78it/s]









 10%|▉         | 498/5000 [00:17<02:24, 31.07it/s]










 10%|█         | 506/5000 [00:18<02:20, 31.98it/s]










 10%|█         | 514/5000 [00:18<02:17, 32.63it/s]










 10%|█         | 522/5000 [00:18<02:08, 34.79it/s]











 11%|█         | 526/5000 [00:18<02:11, 33.99it/s]










 11%|█         | 534/5000 [00:18<02:13, 33.35it/s]










 11%|█         | 542/5000 [00:19<02:20, 31.64it/s]










 11%|█         | 551/5000 [00:19<02:14, 33.06it/s]











 11%|█         | 559/5000 [00:19<02:10, 33.91it/s]











 11%|█▏        | 563/5000 [00:19<02:14, 32.90it/s]









 11%|█▏        | 571/5000 [00:19<02:12, 33.34it/s]











 12%|█▏        | 579/5000 [00:20<02:27, 29.96it/s]









 12%|█▏        | 586/5000 [00:20<02:29, 29.54it/s]









 12%|█▏        | 591/5000 [00:20<02:16, 32.30it/s]











 12%|█▏        | 599/5000 [00:20<02:20, 31.24it/s]










 12%|█▏        | 607/5000 [00:21<02:21, 31.09it/s]










 12%|█▏        | 611/5000 [00:21<02:17, 31.90it/s]










 12%|█▏        | 619/5000 [00:21<02:20, 31.10it/s]









 13%|█▎        | 627/5000 [00:21<02:20, 31.07it/s]










 13%|█▎        | 631/5000 [00:21<02:18, 31.54it/s]










 13%|█▎        | 639/5000 [00:22<02:16, 32.05it/s]










 13%|█▎        | 647/5000 [00:22<02:23, 30.37it/s]









 13%|█▎        | 651/5000 [00:22<02:30, 28.85it/s]









 13%|█▎        | 657/5000 [00:22<02:33, 28.25it/s]










 13%|█▎        | 665/5000 [00:23<02:25, 29.72it/s]









 13%|█▎        | 673/5000 [00:23<02:23, 30.13it/s]










 14%|█▎        | 682/5000 [00:23<02:05, 34.39it/s]












 14%|█▎        | 686/5000 [00:23<02:16, 31.52it/s]









 14%|█▍        | 694/5000 [00:23<02:17, 31.28it/s]









 14%|█▍        | 698/5000 [00:24<02:18, 31.11it/s]










 14%|█▍        | 706/5000 [00:24<02:24, 29.80it/s]









 14%|█▍        | 714/5000 [00:24<02:17, 31.13it/s]










 14%|█▍        | 718/5000 [00:24<02:14, 31.86it/s]










 15%|█▍        | 726/5000 [00:24<02:07, 33.59it/s]










 15%|█▍        | 734/5000 [00:25<02:16, 31.23it/s]









 15%|█▍        | 738/5000 [00:25<02:22, 29.83it/s]









 15%|█▍        | 746/5000 [00:25<02:16, 31.16it/s]











 15%|█▌        | 754/5000 [00:25<02:16, 31.01it/s]









 15%|█▌        | 758/5000 [00:26<02:28, 28.64it/s]









 15%|█▌        | 766/5000 [00:26<02:18, 30.59it/s]










 15%|█▌        | 774/5000 [00:26<02:08, 32.88it/s]











 16%|█▌        | 778/5000 [00:26<02:20, 30.14it/s]









 16%|█▌        | 786/5000 [00:26<02:12, 31.88it/s]










 16%|█▌        | 794/5000 [00:27<02:09, 32.36it/s]










 16%|█▌        | 802/5000 [00:27<02:14, 31.22it/s]









 16%|█▌        | 806/5000 [00:27<02:14, 31.16it/s]










 16%|█▋        | 814/5000 [00:27<02:13, 31.33it/s]










 16%|█▋        | 819/5000 [00:27<02:02, 34.24it/s]










 17%|█▋        | 827/5000 [00:28<02:13, 31.20it/s]










 17%|█▋        | 835/5000 [00:28<02:07, 32.57it/s]










 17%|█▋        | 843/5000 [00:28<02:12, 31.28it/s]










 17%|█▋        | 847/5000 [00:28<02:13, 31.01it/s]










 17%|█▋        | 855/5000 [00:29<02:15, 30.60it/s]










 17%|█▋        | 863/5000 [00:29<02:08, 32.22it/s]











 17%|█▋        | 871/5000 [00:29<02:01, 33.91it/s]











 18%|█▊        | 880/5000 [00:29<02:04, 33.17it/s]










 18%|█▊        | 884/5000 [00:29<02:12, 31.07it/s]









 18%|█▊        | 892/5000 [00:30<02:12, 30.98it/s]










 18%|█▊        | 900/5000 [00:30<02:08, 32.00it/s]










 18%|█▊        | 904/5000 [00:30<02:12, 31.03it/s]










 18%|█▊        | 912/5000 [00:30<02:15, 30.19it/s]









 18%|█▊        | 919/5000 [00:31<02:25, 27.99it/s]









 19%|█▊        | 927/5000 [00:31<02:07, 32.05it/s]











 19%|█▊        | 935/5000 [00:31<02:00, 33.68it/s]










 19%|█▉        | 939/5000 [00:31<02:02, 33.23it/s]










 19%|█▉        | 947/5000 [00:31<01:55, 35.04it/s]











 19%|█▉        | 955/5000 [00:32<02:04, 32.57it/s]









 19%|█▉        | 963/5000 [00:32<01:58, 34.07it/s]











 19%|█▉        | 967/5000 [00:32<02:05, 32.25it/s]










 20%|█▉        | 975/5000 [00:32<02:01, 33.21it/s]










 20%|█▉        | 983/5000 [00:32<02:03, 32.61it/s]










 20%|█▉        | 992/5000 [00:33<01:56, 34.31it/s]











 20%|█▉        | 996/5000 [00:33<01:53, 35.37it/s]









 20%|██        | 1004/5000 [00:33<02:05, 31.81it/s]










 20%|██        | 1012/5000 [00:33<02:04, 32.07it/s]










 20%|██        | 1016/5000 [00:34<02:04, 32.09it/s]









 20%|██        | 1024/5000 [00:34<02:04, 32.06it/s]










 21%|██        | 1032/5000 [00:34<01:53, 34.95it/s]











 21%|██        | 1040/5000 [00:34<02:02, 32.22it/s]










 21%|██        | 1044/5000 [00:34<02:04, 31.70it/s]









 21%|██        | 1052/5000 [00:35<02:10, 30.17it/s]









 21%|██        | 1056/5000 [00:35<02:20, 28.09it/s]









 21%|██▏       | 1063/5000 [00:35<02:14, 29.23it/s]










 21%|██▏       | 1071/5000 [00:35<02:05, 31.23it/s]










 22%|██▏       | 1079/5000 [00:36<01:59, 32.72it/s]











 22%|██▏       | 1087/5000 [00:36<01:57, 33.31it/s]











 22%|██▏       | 1095/5000 [00:36<01:50, 35.32it/s]










 22%|██▏       | 1099/5000 [00:36<02:05, 31.06it/s]









 22%|██▏       | 1107/5000 [00:36<01:57, 33.02it/s]











 22%|██▏       | 1116/5000 [00:37<01:56, 33.42it/s]










 22%|██▏       | 1120/5000 [00:37<01:57, 32.91it/s]









 23%|██▎       | 1128/5000 [00:37<02:12, 29.30it/s]









 23%|██▎       | 1132/5000 [00:37<02:08, 30.21it/s]










 23%|██▎       | 1140/5000 [00:37<02:06, 30.61it/s]











 23%|██▎       | 1148/5000 [00:38<02:08, 29.97it/s]









 23%|██▎       | 1152/5000 [00:38<02:08, 30.03it/s]









 23%|██▎       | 1162/5000 [00:38<02:12, 28.95it/s]









 23%|██▎       | 1166/5000 [00:38<02:05, 30.45it/s]










 23%|██▎       | 1174/5000 [00:39<02:06, 30.26it/s]










 24%|██▎       | 1182/5000 [00:39<02:02, 31.13it/s]










 24%|██▍       | 1190/5000 [00:39<01:56, 32.77it/s]











 24%|██▍       | 1194/5000 [00:39<01:56, 32.68it/s]









 24%|██▍       | 1202/5000 [00:39<01:58, 32.12it/s]










 24%|██▍       | 1210/5000 [00:40<01:59, 31.78it/s]










 24%|██▍       | 1218/5000 [00:40<01:56, 32.58it/s]










 24%|██▍       | 1222/5000 [00:40<01:54, 32.87it/s]










 25%|██▍       | 1230/5000 [00:40<01:57, 32.16it/s]











 25%|██▍       | 1238/5000 [00:41<01:54, 32.79it/s]










 25%|██▍       | 1246/5000 [00:41<02:03, 30.48it/s]









 25%|██▌       | 1250/5000 [00:41<02:03, 30.28it/s]










 25%|██▌       | 1259/5000 [00:41<02:02, 30.56it/s]










 25%|██▌       | 1267/5000 [00:41<01:54, 32.66it/s]











 25%|██▌       | 1271/5000 [00:42<01:49, 33.94it/s]










 26%|██▌       | 1279/5000 [00:42<01:59, 31.17it/s]









 26%|██▌       | 1283/5000 [00:42<02:04, 29.82it/s]









 26%|██▌       | 1292/5000 [00:42<02:02, 30.21it/s]











 26%|██▌       | 1300/5000 [00:43<01:59, 30.92it/s]










 26%|██▌       | 1308/5000 [00:43<01:55, 31.84it/s]










 26%|██▋       | 1316/5000 [00:43<01:48, 34.09it/s]











 26%|██▋       | 1320/5000 [00:43<01:52, 32.83it/s]









 27%|██▋       | 1328/5000 [00:43<01:58, 30.88it/s]









 27%|██▋       | 1332/5000 [00:44<02:02, 29.90it/s]







 27%|██▋       | 1336/5000 [00:44<02:57, 20.67it/s]










 27%|██▋       | 1344/5000 [00:44<02:26, 24.91it/s]










 27%|██▋       | 1351/5000 [00:44<02:13, 27.31it/s]









 27%|██▋       | 1357/5000 [00:45<02:15, 26.91it/s]









 27%|██▋       | 1364/5000 [00:45<02:07, 28.46it/s]










 27%|██▋       | 1373/5000 [00:45<01:53, 32.05it/s]











 28%|██▊       | 1381/5000 [00:45<01:46, 33.97it/s]











 28%|██▊       | 1385/5000 [00:45<01:48, 33.28it/s]









 28%|██▊       | 1393/5000 [00:46<01:48, 33.21it/s]










 28%|██▊       | 1397/5000 [00:46<01:56, 30.86it/s]









 28%|██▊       | 1405/5000 [00:46<01:56, 30.80it/s]











 28%|██▊       | 1413/5000 [00:46<01:52, 31.99it/s]











 28%|██▊       | 1421/5000 [00:47<01:46, 33.51it/s]









 29%|██▊       | 1429/5000 [00:47<01:46, 33.68it/s]










 29%|██▊       | 1433/5000 [00:47<01:50, 32.38it/s]










 29%|██▉       | 1441/5000 [00:47<01:55, 30.75it/s]










 29%|██▉       | 1446/5000 [00:47<01:47, 32.94it/s]









 29%|██▉       | 1454/5000 [00:48<01:52, 31.40it/s]











 29%|██▉       | 1462/5000 [00:48<01:48, 32.59it/s]









 29%|██▉       | 1470/5000 [00:48<01:52, 31.34it/s]










 29%|██▉       | 1474/5000 [00:48<01:52, 31.39it/s]









 30%|██▉       | 1482/5000 [00:48<01:50, 31.94it/s]










 30%|██▉       | 1490/5000 [00:49<01:47, 32.58it/s]










 30%|██▉       | 1494/5000 [00:49<01:50, 31.85it/s]









 30%|███       | 1502/5000 [00:49<01:52, 31.12it/s]










 30%|███       | 1506/5000 [00:49<01:53, 30.87it/s]









 30%|███       | 1514/5000 [00:49<01:51, 31.34it/s]










 30%|███       | 1522/5000 [00:50<01:53, 30.72it/s]









 31%|███       | 1526/5000 [00:50<01:54, 30.41it/s]











 31%|███       | 1535/5000 [00:50<01:53, 30.50it/s]









 31%|███       | 1543/5000 [00:50<01:55, 29.81it/s]









 31%|███       | 1547/5000 [00:51<01:59, 28.85it/s]









 31%|███       | 1555/5000 [00:51<01:51, 30.91it/s]










 31%|███       | 1559/5000 [00:51<01:50, 31.11it/s]










 31%|███▏      | 1567/5000 [00:51<01:51, 30.67it/s]










 32%|███▏      | 1575/5000 [00:51<01:50, 31.12it/s]










 32%|███▏      | 1584/5000 [00:52<01:44, 32.83it/s]











 32%|███▏      | 1588/5000 [00:52<01:52, 30.25it/s]










 32%|███▏      | 1596/5000 [00:52<01:48, 31.48it/s]










 32%|███▏      | 1604/5000 [00:52<01:42, 33.29it/s]










 32%|███▏      | 1612/5000 [00:53<01:44, 32.56it/s]










 32%|███▏      | 1616/5000 [00:53<01:44, 32.25it/s]










 32%|███▏      | 1624/5000 [00:53<01:51, 30.40it/s]









 33%|███▎      | 1632/5000 [00:53<01:51, 30.33it/s]










 33%|███▎      | 1640/5000 [00:54<01:45, 31.88it/s]











 33%|███▎      | 1648/5000 [00:54<01:39, 33.60it/s]











 33%|███▎      | 1652/5000 [00:54<01:42, 32.55it/s]










 33%|███▎      | 1660/5000 [00:54<01:40, 33.20it/s]












 33%|███▎      | 1669/5000 [00:54<01:37, 34.09it/s]










 34%|███▎      | 1677/5000 [00:55<01:39, 33.33it/s]










 34%|███▎      | 1685/5000 [00:55<01:35, 34.55it/s]











 34%|███▍      | 1693/5000 [00:55<01:43, 31.82it/s]









 34%|███▍      | 1697/5000 [00:55<01:46, 31.14it/s]









 34%|███▍      | 1705/5000 [00:55<01:47, 30.75it/s]










 34%|███▍      | 1709/5000 [00:56<01:47, 30.67it/s]










 34%|███▍      | 1717/5000 [00:56<01:38, 33.17it/s]











 34%|███▍      | 1725/5000 [00:56<01:41, 32.40it/s]










 35%|███▍      | 1733/5000 [00:56<01:38, 33.21it/s]










 35%|███▍      | 1741/5000 [00:57<01:43, 31.35it/s]









 35%|███▍      | 1745/5000 [00:57<01:41, 32.13it/s]










 35%|███▌      | 1753/5000 [00:57<01:45, 30.82it/s]










 35%|███▌      | 1761/5000 [00:57<01:36, 33.52it/s]










 35%|███▌      | 1769/5000 [00:57<01:37, 33.03it/s]










 35%|███▌      | 1773/5000 [00:58<01:36, 33.54it/s]










 36%|███▌      | 1781/5000 [00:58<01:41, 31.85it/s]










 36%|███▌      | 1789/5000 [00:58<01:43, 30.88it/s]










 36%|███▌      | 1793/5000 [00:58<01:45, 30.31it/s]









 36%|███▌      | 1801/5000 [00:58<01:41, 31.54it/s]










 36%|███▌      | 1805/5000 [00:59<01:43, 30.97it/s]








KeyboardInterrupt: 

 36%|███▌      | 1805/5000 [01:10<01:43, 30.97it/s]

In [None]:
if __name__ == "__main__":

    parser = ArgumentParser()

    parser.add_argument("-e", "--num-epochs",
                        dest="num_epochs",
                        help="Number of epochs.",
                        default=80, type=int)

    parser.add_argument("-b", "--batch-size",
                        dest="batch_size",
                        help="The batch size.",
                        default=4, type=int)

    parser.add_argument("-l", "--learn-rate",
                        dest="lr",
                        help="Learning rate",
                        default=0.0001, type=float)

    parser.add_argument("-T", "--tb_dir", dest="tb_dir",
                        help="Tensorboard logging directory",
                        default='./runs')

    parser.add_argument("-f", "--final", dest="final",
                        help="Whether to run on the real test set (if not included, the validation set is used).",
                        action="store_true")

    parser.add_argument("--max-pool", dest="max_pool",
                        help="Use max pooling in the final classification layer.",
                        action="store_true")

    parser.add_argument("-E", "--embedding", dest="embedding_size",
                        help="Size of the character embeddings.",
                        default=128, type=int)

    parser.add_argument("-V", "--vocab-size", dest="vocab_size",
                        help="Number of words in the vocabulary.",
                        default=50_000, type=int)

    parser.add_argument("-M", "--max", dest="max_length",
                        help="Max sequence length. Longer sequences are clipped (-1 for no limit).",
                        default=512, type=int)

    parser.add_argument("-H", "--heads", dest="num_heads",
                        help="Number of attention heads.",
                        default=8, type=int)

    parser.add_argument("-d", "--depth", dest="depth",
                        help="Depth of the network (nr. of self-attention layers)",
                        default=6, type=int)

    parser.add_argument("-r", "--random-seed",
                        dest="seed",
                        help="RNG seed. Negative for random",
                        default=1, type=int)

    parser.add_argument("--lr-warmup",
                        dest="lr_warmup",
                        help="Learning rate warmup.",
                        default=10_000, type=int)

    parser.add_argument("--gradient-clipping",
                        dest="gradient_clipping",
                        help="Gradient clipping.",
                        default=1.0, type=float)

    parser.add_argument("--momentum",
                        dest="momentum",
                        help="momentum for SGD",
                        default=0.9, type=float)

    options = parser.parse_args()

    print('OPTIONS ', options)

    model = go(options)