In [177]:
# These are all the modules we'll be using later. Make sure you can import them
# before proceeding further.
%matplotlib inline
import collections
import math
import numpy as np
import pandas as pd
import os
import random
import torch
import torch.nn as nn
import zipfile
from matplotlib import pylab
from six.moves import range
from six.moves.urllib.request import urlretrieve
from torch.nn.utils.rnn import pad_sequence
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
from torchtext.data.utils import get_tokenizer, ngrams_iterator
from torchtext.datasets import DATASETS
from torchtext.utils import download_from_url
from torchtext.vocab import build_vocab_from_iterator
import torch.nn as nn
from torchtext.data.utils import get_tokenizer
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from torchtext.vocab import FastText, CharNGram
from itertools import chain

# One seed shared by the Python, NumPy and PyTorch random number generators so that
# a fresh "Restart & Run All" draws the same random numbers every time.
seed = 54321
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

### Download the data

In [178]:
url = 'https://github.com/ZihanWangKi/CrossWeigh/raw/master/data/'
dir_name = 'data'
#https://github.com/ZihanWangKi/CrossWeigh/raw/master/data/conllpp_train.txt
def download_data(url, filename, download_dir, expected_bytes):
    """Return the local path of `filename`, fetching it from `url + filename` when missing.

    Args:
        url: Base URL; `filename` is appended to it verbatim.
        filename: Name of the remote file and of the local copy.
        download_dir: Directory that holds the local copy (created if needed).
        expected_bytes: Exact size in bytes the local file must have.

    Returns:
        Path of the verified local file.

    Raises:
        Exception: If the size of the local file differs from `expected_bytes`.
    """
    os.makedirs(download_dir, exist_ok=True)

    target = os.path.join(download_dir, filename)
    filepath = target
    if not os.path.exists(target):
        # Nothing cached yet: download straight to the target location.
        filepath, _ = urlretrieve(url + filename, target)

    actual_bytes = os.stat(filepath).st_size
    if actual_bytes != expected_bytes:
        # Show the size that was actually found before giving up.
        print(actual_bytes)
        raise Exception(
          'Failed to verify ' + filepath + '. Can you get to it with a browser?')

    print('Found and verified %s' % filepath)
    return filepath

# Filepaths to train/valid/test data
# Each call downloads the split into `dir_name` unless it is already there; the last
# argument is the exact byte size download_data requires before accepting the file.
train_filepath = download_data(url, 'conllpp_train.txt', dir_name, 3283420)
dev_filepath = download_data(url, 'conllpp_dev.txt', dir_name, 827443)
test_filepath = download_data(url, 'conllpp_test.txt', dir_name, 748737)

Found and verified data/conllpp_train.txt
Found and verified data/conllpp_dev.txt
Found and verified data/conllpp_test.txt


In [179]:
!head data/conllpp_train.txt

-DOCSTART- -X- -X- O

EU NNP B-NP B-ORG
rejects VBZ B-VP O
German JJ B-NP B-MISC
call NN I-NP O
to TO B-VP O
boycott VB I-VP O
British JJ B-NP B-MISC
lamb NN I-NP O


### Read the data

In [180]:
def read_data(filename):
    '''
    Read a CoNLL-style file and group its rows into sentences.

    Every non-blank row is expected to hold exactly four space-separated fields;
    the first is the token and the last is its NER tag. Blank rows and rows whose
    first field is -DOCSTART- separate sentences.

    Args:
        filename: Path of the file to read (decoded as latin-1).

    Returns:
        A pair (sentences, ner_labels): `sentences` is a list of strings, each the
        tokens of one sentence joined by single spaces; `ner_labels` is a list of
        lists holding one NER tag per token, aligned with `sentences`.

    Raises:
        ValueError: If a token row does not split into exactly four fields.
    '''

    print("Reading data ...")
    # Master lists - one entry per sentence
    sentences, ner_labels = [], []
    # Tokens and labels of the sentence currently being read
    sentence_tokens, sentence_labels = [], []

    with open(filename, 'r', encoding='latin-1') as f:
        for row in f:
            is_boundary = len(row.strip()) == 0 or row.split(' ')[0] == '-DOCSTART-'
            if not is_boundary:
                token, _, _, ner_tag = row.split(' ')
                sentence_tokens.append(token)
                sentence_labels.append(ner_tag.strip())
            elif sentence_tokens:
                # A boundary right after some tokens closes the current sentence.
                sentences.append(' '.join(sentence_tokens))
                ner_labels.append(sentence_labels)
                sentence_tokens, sentence_labels = [], []

    # The file may end without a trailing blank line; keep the last sentence too.
    if sentence_tokens:
        sentences.append(' '.join(sentence_tokens))
        ner_labels.append(sentence_labels)

    print('\tDone')
    return sentences, ner_labels

# Train data
train_sentences, train_labels = read_data(train_filepath) 
# Validation data
valid_sentences, valid_labels = read_data(dev_filepath) 
# Test data
test_sentences, test_labels = read_data(test_filepath) 

# Print some stats
print(f"Train size: {len(train_labels)}")
print(f"Valid size: {len(valid_labels)}")
print(f"Test size: {len(test_labels)}")

# Print some data
print('\nSample data\n')
for v_sent, v_labels in zip(valid_sentences[:5], valid_labels[:5]):
    print(f"Sentence: {v_sent}")
    print(f"Labels: {v_labels}")
    assert(len(v_sent.split(' ')) == len(v_labels))
    print('\n')

Reading data ...
	Done
Reading data ...
	Done
Reading data ...
	Done
Train size: 14041
Valid size: 3250
Test size: 3452

Sample data

Sentence: CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY .
Labels: ['O', 'O', 'B-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']


Sentence: LONDON 1996-08-30
Labels: ['B-LOC', 'O']


Sentence: West Indian all-rounder Phil Simmons took four for 38 on Friday as Leicestershire beat Somerset by an innings and 39 runs in two days to take over at the head of the county championship .
Labels: ['B-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-ORG', 'O', 'B-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']


Sentence: Their stay on top , though , may be short-lived as title rivals Essex , Derbyshire and Surrey all closed in on victory while Kent made up for lost time in their rain-affected match against Nottinghamshire .
Labels: ['O', 'O', 'O', 'O', 'O', 'O', '

In [181]:
assert(len(train_labels) == 14041)
assert(len(valid_labels) == 3250)
assert(len(test_labels) == 3452)

In [182]:
# We build our own tokenizers because the basic English tokenizer drops some tokens that are useful here.
# Everything is lowercased to keep things simpler.
class SentenceTokenizer:
    """Callable that turns a sentence into lowercase word tokens."""

    def __call__(self, sentence):
        """Lowercase `sentence` and split it on every single space character."""
        lowered = sentence.lower()
        return lowered.split(' ')
    
class WordTokenizer:
    """Callable that turns a word into its lowercase characters."""

    def __call__(self, word):
        """Lowercase `word` and return its characters as a list."""
        return list(word.lower())

In [183]:
SENTENCE_TOKENIZER = SentenceTokenizer()
WORD_TOKENIZER = WordTokenizer()

In [184]:
assert(len(WORD_TOKENIZER("this is a sentence")) == 18)
assert(len(SENTENCE_TOKENIZER("this is a sentence")) == 4)  

In [185]:
# Pool the three splits; the vocabularies below are built from `sentences`, so they
# cover every token that occurs in the train, test and validation data.
sentences = train_sentences + test_sentences + valid_sentences
# NOTE(review): `labels` is rebound to a pandas Series in a later cell and nothing
# visible reads this concatenated list before that happens.
labels = train_labels + test_labels + valid_labels

def yield_word_tokens(sentences):
    """Lazily yield, for each sentence, the list of word tokens SENTENCE_TOKENIZER produces."""
    yield from (SENTENCE_TOKENIZER(text) for text in sentences)
        
def yield_char_tokens(sentences):
    """Lazily yield, for each word of each sentence, the list of character tokens WORD_TOKENIZER produces."""
    yield from (
        WORD_TOKENIZER(word)
        for sentence_words in yield_word_tokens(sentences)
        for word in sentence_words
    )

In [186]:
# Word-level vocabulary over the pooled sentences. The two specials are listed first;
# the cells further down show '<pad>' mapping to 0 and '<unk>' mapping to 1.
# NOTE(review): no default index is set here, so looking up a token that is not in the
# vocabulary presumably raises - confirm against the torchtext Vocab documentation.
WORD_VOCAB = build_vocab_from_iterator(
    yield_word_tokens(sentences),
    specials=('<pad>', '<unk>')
)

# Character-level vocabulary over the same sentences, with the same specials.
CHAR_VOCAB = build_vocab_from_iterator(
    yield_char_tokens(sentences),
    specials=('<pad>', '<unk>')
)

In [187]:
# Example: You should see 4 integer tokens below.
WORD_VOCAB(SENTENCE_TOKENIZER("this is a sentence"))

[64, 31, 8, 1780]

In [188]:
# Example: You should see 4 integer tokens below.
CHAR_VOCAB(WORD_TOKENIZER("Xhis"))

[42, 12, 6, 8]

In [189]:
# Get the word to idx and idx to word lookup tables
wtoi = WORD_VOCAB.get_stoi()
itow = WORD_VOCAB.get_itos()
# Get the char to idx and idx to char lookup tables
ctoi = CHAR_VOCAB.get_stoi()
itoc = CHAR_VOCAB.get_itos()

In [190]:
ctoi

{'#': 60,
 '`': 59,
 'y': 21,
 'b': 20,
 'w': 19,
 'g': 18,
 '!': 58,
 'f': 17,
 'm': 15,
 'c': 13,
 '<unk>': 1,
 'o': 7,
 'v': 27,
 '&': 51,
 'p': 16,
 'l': 10,
 '+': 48,
 ':': 45,
 'e': 2,
 'a': 3,
 '5': 33,
 '-': 24,
 '<pad>': 0,
 '/': 46,
 'i': 6,
 'k': 26,
 '=': 52,
 'n': 5,
 'd': 11,
 '"': 40,
 '1': 23,
 ',': 25,
 'u': 14,
 '*': 50,
 '6': 31,
 't': 4,
 'r': 9,
 '3': 32,
 '2': 29,
 '9': 30,
 '4': 34,
 '8': 35,
 ']': 56,
 '7': 36,
 '0': 28,
 '(': 37,
 'h': 12,
 's': 8,
 'j': 39,
 "'": 41,
 '$': 47,
 '.': 22,
 'x': 42,
 'z': 43,
 'q': 44,
 ';': 49,
 '%': 53,
 '?': 54,
 ')': 38,
 '[': 55,
 '@': 57}

In [191]:
assert(len(wtoi) == 26871)
assert(len(ctoi) == 61)

In [192]:
# You should see 0 and 0 below
WORD_VOCAB['<pad>'], CHAR_VOCAB['<pad>']

(0, 0)

In [193]:
# You should see 1 and 1 below
WORD_VOCAB['<unk>'], CHAR_VOCAB['<unk>']

(1, 1)

In [194]:
# The classes are heavily imbalanced, so every class gets its own weight.
# We use w(c) = min_l freq(l) / freq(c), where the minimum runs over all labels l.
# The rarest class therefore gets weight 1, and more frequent classes get proportionally smaller weights.
def get_label_id_map(labels):
    """Build the label <-> id lookup tables and the per-class weights.

    Args:
        labels: Iterable of label sequences (one sequence per sentence). It is
            flattened exactly once, so a one-shot iterator is fine.

    Returns:
        A tuple (ltoi, itol, itolw):
            ltoi: label -> id, ids assigned from 0 in order of first appearance.
            itol: id -> label, the inverse of ltoi.
            itolw: id -> weight, computed as (count of the rarest label) divided by
                (count of this label), so the rarest label has weight 1.
    """
    # Flatten once and reuse; every later step works on this single Series.
    flat_labels = list(chain.from_iterable(labels))
    if not flat_labels:
        return {}, {}, {}
    label_series = pd.Series(flat_labels)

    # Get the unique list of labels
    unique_labels = label_series.unique()
    # Create a dictionary label to idx, starting with idx 0
    ltoi = dict(zip(unique_labels, np.arange(unique_labels.shape[0])))
    # Make a map from idx to label
    itol = {i: label for label, i in ltoi.items()}

    label_to_count = label_series.value_counts()
    # The smallest count does not change inside the loop, so look it up once.
    min_count = label_to_count.min()
    itolw = {ltoi[label]: min_count / count for label, count in label_to_count.items()}

    return ltoi, itol, itolw

In [195]:
assert(len(pd.Series(chain(*train_labels)).unique()) == 9)

In [196]:
ltoi, itol, itolw = get_label_id_map(train_labels)

In [197]:
for l, idx in ltoi.items():
  assert(l == itol[idx])
  assert(idx in itolw)

In [198]:
# Look at the weights per tag
itolw

{1: 0.006811025015037328,
 5: 0.16176470588235295,
 3: 0.175,
 0: 0.18272425249169436,
 4: 0.25507950530035334,
 6: 0.31182505399568033,
 2: 0.33595113438045376,
 8: 0.9982713915298185,
 7: 1.0}

In [199]:
ltoi

{'B-ORG': 0,
 'O': 1,
 'B-MISC': 2,
 'B-PER': 3,
 'I-PER': 4,
 'B-LOC': 5,
 'I-ORG': 6,
 'I-MISC': 7,
 'I-LOC': 8}

In [200]:
assert(min(itolw.values()) == 0.006811025015037328)

In [201]:
# Get the weights per class as a tensor
# Position i of `weights` holds the weight of label id i, i.e. itolw[i].
weights = torch.zeros(len(itolw))
for i, lw in itolw.items():
    weights[i] = lw

In [202]:
labels = pd.Series(chain(*train_labels))

In [203]:
print(labels)

0          B-ORG
1              O
2         B-MISC
3              O
4              O
           ...  
203616         O
203617     B-ORG
203618         O
203619     B-ORG
203620         O
Length: 203621, dtype: object


In [204]:
for k, v in labels.value_counts().items():
    print(k, v)

O 169578
B-LOC 7140
B-PER 6600
B-ORG 6321
I-PER 4528
I-ORG 3704
B-MISC 3438
I-LOC 1157
I-MISC 1155


In [205]:
assert(labels.value_counts().min() == 1155)

### Check for class balance

In [206]:
# Print the value count for each label
print("Training data label counts")
print(pd.Series(chain(*train_labels)).value_counts())

print("\nValidation data label counts")
print(pd.Series(chain(*valid_labels)).value_counts())

print("\nTest data label counts")
print(pd.Series(chain(*test_labels)).value_counts())

Training data label counts
O         169578
B-LOC       7140
B-PER       6600
B-ORG       6321
I-PER       4528
I-ORG       3704
B-MISC      3438
I-LOC       1157
I-MISC      1155
dtype: int64

Validation data label counts
O         42759
B-PER      1842
B-LOC      1837
B-ORG      1341
I-PER      1307
B-MISC      922
I-ORG       751
I-MISC      346
I-LOC       257
dtype: int64

Test data label counts
O         38143
B-ORG      1714
B-LOC      1645
B-PER      1617
I-PER      1161
I-ORG       881
B-MISC      722
I-LOC       259
I-MISC      252
dtype: int64


### Series length.

In [207]:
# Display the mean sentence length for the training samples
# You should get around 15 mean ...  What about median, 95%, etc?
# .describe applied to a certain series is a good idea ...
pd.Series(train_sentences).str.split().str.len().describe(percentiles=[0.05, 0.95])

count    14041.000000
mean        14.501887
std         11.602756
min          1.000000
5%           2.000000
50%         10.000000
95%         37.000000
max        113.000000
dtype: float64

### Parameters

In [208]:
# Size of token embeddings
d_model = 300

# Number of hidden units in the GRU layer
d_hidden = 64

# Size of the character embeddings
# NOTE(review): GRUNERModel sets its own d_char to 32 and does not read this constant.
d_char = 32

# Number of output nodes in the last layer
num_classes = len(itol)

# Number of samples in a batch
BATCH_SIZE = 128

# Number of training epochs.
EPOCHS = 25

# FastText embeddings
FAST_TEXT = FastText("simple")

# Learning rate
LR = 1.0

# Get the weights per class
# NOTE(review): this is only an alias; the loss further down is built from `weights` directly.
weight = weights

# Maximum word length; critical for convolutions
MAX_WORD_LENGTH = 12

# The device to run on
# Change this to 'mps' if you are on a mac with MPS
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [209]:
assert(len(train_sentences) // BATCH_SIZE == 109)

In [210]:
def collate_batch(batch):
    """Turn a list of samples into padded tensors on DEVICE.

    Args:
        batch: Sequence of (sentence, words, labels) samples that are already
            converted to integer ids. `sentence` holds one id per word, `words`
            one row of character ids per word, `labels` one id per word.

    Returns:
        A tuple (sentences, labels, lengths, words):
            sentences: word ids, padded with 0 to the longest sentence of the batch.
            labels: label ids, padded with -1 (the value the loss ignores).
            lengths: the unpadded number of words of every sentence.
            words: character ids, padded with 0 along the sentence axis.
    """
    def as_int64(values):
        return torch.tensor(values, dtype=torch.int64)

    samples = list(batch)
    sentence_list = [as_int64(sentence) for sentence, _, _ in samples]
    word_list = [as_int64(words) for _, words, _ in samples]
    label_list = [as_int64(labels) for _, _, labels in samples]
    sentence_lengths = [len(sentence) for sentence, _, _ in samples]

    pad = nn.utils.rnn.pad_sequence

    # (N, L_sentence)
    padded_sentences = pad(sentence_list, batch_first=True).to(DEVICE)
    # Unlike the vocabulary ids, -1 is not a real label; it marks positions to skip.
    padded_labels = pad(label_list, batch_first=True, padding_value=-1).to(DEVICE)
    lengths = torch.tensor(sentence_lengths).to(DEVICE)
    # (N, L_sentence, L_word); each word row already has its final width on entry,
    # so only the sentence axis is padded here.
    padded_words = pad(word_list, batch_first=True).to(DEVICE)

    return padded_sentences, padded_labels, lengths, padded_words

In [229]:
def get_dl(sentences, labels, shuffle=True):
    """Encode sentences and labels as integer ids and wrap them in a DataLoader.

    Args:
        sentences: Iterable of sentence strings.
        labels: Iterable of label sequences, aligned with `sentences`.
        shuffle: Whether the loader reshuffles the samples on every pass. Defaults
            to True, which is what every existing call relies on; pass False for a
            fixed batch order (e.g. for evaluation splits).

    Returns:
        A DataLoader over samples of the form [word ids, character ids per word,
        label ids], batched by BATCH_SIZE and collated with collate_batch.
    """
    data = []

    for sentence, sentence_labels in zip(sentences, labels):
        word_tokens = SENTENCE_TOKENIZER(sentence)
        int_sentence = WORD_VOCAB(word_tokens)

        int_words = []
        for word_token in word_tokens:
            # Every word becomes exactly MAX_WORD_LENGTH character tokens:
            # longer words are cut, shorter ones are filled up with '<pad>'.
            char_tokens = WORD_TOKENIZER(word_token[:MAX_WORD_LENGTH])[:MAX_WORD_LENGTH]
            padding = ['<pad>'] * (MAX_WORD_LENGTH - len(char_tokens))
            int_words.append(CHAR_VOCAB(char_tokens + padding))

        label_ids = [ltoi[label] for label in sentence_labels]
        assert len(int_sentence) == len(label_ids)
        data.append([int_sentence, int_words, label_ids])

    return DataLoader(
        data,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        collate_fn=collate_batch
    )

# One DataLoader per split.
# NOTE(review): get_dl builds every loader with shuffle=True, so the validation and
# test batches are reshuffled on each pass as well.
train_dl = get_dl(train_sentences, train_labels)
valid_dl = get_dl(valid_sentences, valid_labels)
test_dl = get_dl(test_sentences, test_labels)

In [219]:
assert(len(train_dl) == 110)

In [232]:
class GRUNERModel(nn.Module):
    """Bidirectional GRU tagger that predicts one class per word.

    Every word is represented by a word embedding and, when
    `use_conv_embeddings` is set, additionally by a character-level feature
    vector (character embeddings -> 1-D convolution -> max pooling over the
    word). The concatenated features go through a bidirectional GRU followed by
    a linear layer that produces the class scores.

    Args:
        num_class: Number of output classes.
        d_model: Word embedding size. Only used when `initialize` is False; with
            `initialize=True` the size is 300, the size of the FastText vectors.
        d_hidden: Hidden size of the GRU (per direction).
        initialize: Copy the FAST_TEXT vectors into the word embedding when True,
            otherwise initialise the embeddings uniformly (see init_weights).
        fine_tune_embeddings: Freeze the word embedding when False.
        use_conv_embeddings: Add the character-level features when True.
    """

    def __init__(
        self,
        num_class,
        d_model,
        d_hidden,
        initialize = True,
        fine_tune_embeddings = True,
        use_conv_embeddings = True,
    ):

        super(GRUNERModel, self).__init__()
        self.vocab_size = len(WORD_VOCAB)
        # The FastText vectors have dimension 300, so they dictate the embedding
        # width whenever they are used; the GRU input size below follows from this.
        self.d_model = 300 if initialize else d_model
        self.d_hidden = d_hidden
        self.d_char = 32
        self.kernel = 5
        self.max_word_length = MAX_WORD_LENGTH
        self.use_conv_embeddings = use_conv_embeddings

        if self.use_conv_embeddings:
            # Input data will be (N * L_sentence, D_char, L_word); with
            # L_word = 12 and kernel 5 the output length is 12 - 5 + 1 = 8.
            self.conv = nn.Conv1d(self.d_char, self.d_char, self.kernel)
            # Pooling over all remaining positions leaves one D_char-sized
            # vector per word.
            self.max_pool = nn.MaxPool1d(self.max_word_length - self.kernel + 1)

        self.embedding = nn.Embedding(
            self.vocab_size,
            self.d_model,
            padding_idx = 0
        )

        self.char_embedding = nn.Embedding(
            len(CHAR_VOCAB),
            self.d_char,
            padding_idx = 0
        )

        if initialize:
            # Row i of the embedding becomes the FastText vector of vocabulary
            # entry i; one batched lookup replaces a per-token loop.
            with torch.no_grad():
                self.embedding.weight.copy_(
                    FAST_TEXT.get_vecs_by_tokens(
                        WORD_VOCAB.get_itos(),
                        lower_case_backup=True
                    )
                )
        else:
            self.init_weights()

        if not fine_tune_embeddings:
            self.embedding.weight.requires_grad = False

        # The GRU sees the word embedding plus, only when they are actually
        # computed in forward(), the character features.
        rnn_input_size = self.d_model
        if self.use_conv_embeddings:
            rnn_input_size += self.d_char

        self.rnn = nn.GRU(
            rnn_input_size,
            self.d_hidden,
            batch_first=True,
            bidirectional=True
        )

        # Bidirectional; so we go from 2 * d_hidden to num_class
        self.fc = nn.Linear(2 * self.d_hidden, num_class)

        # NOTE(review): this layer is created but forward() never applies it.
        self.dropout = nn.Dropout(0.3)

    def init_weights(self):
        """Fill the embeddings with uniform random values in (-0.5, 0.5).

        The whole weight matrices are overwritten, including the rows that belong
        to the padding index.
        """
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        if self.use_conv_embeddings:
            self.char_embedding.weight.data.uniform_(-initrange, initrange)

    # N = batch_size,
    # L_sentence = sequence length
    # L_word = number of character ids per word
    # D_word = word embedding length
    # D_char = char embedding length
    # Hout = hidden dimension from bidirectional GRU
    # C = number of classes
    def forward(self, sentences, lengths, words):
        """Compute the class scores for every word of every sentence.

        Args:
            sentences: (N, L_sentence) word ids.
            lengths: (N,) number of real (unpadded) words per sentence.
            words: (N, L_sentence, L_word) character ids; only read when
                `use_conv_embeddings` is set.

        Returns:
            (N, L, C) scores, where L is the longest length in `lengths`.
        """
        # (N, L_sentence, D_word)
        embedded_sentences = self.embedding(sentences.int())

        if self.use_conv_embeddings:
            # (N, L_sentence, L_word, D_char)
            embedded_words = self.char_embedding(words.int())

            N, L_sentence, L_word, D_char = embedded_words.shape

            # (N * L_sentence, L_word, D_char)
            embedded_words = embedded_words.view(N * L_sentence, L_word, D_char)

            # (N * L_sentence, D_char, L_word)
            embedded_words = embedded_words.transpose(1, 2)

            # (N * L_sentence, D_char, L_word - kernel_size + 1)
            embedded_words = self.conv(embedded_words)

            # (N * L_sentence, D_char); only the pooled axis is dropped, so a
            # batch that holds a single word keeps its leading dimension.
            embedded_words = self.max_pool(embedded_words).squeeze(-1)

            # (N, L_sentence, D_char)
            embedded_words = embedded_words.view(N, L_sentence, -1)

            # (N, L_sentence, D_word + D_char)
            embedded_sentences = torch.cat([embedded_sentences, embedded_words], dim=-1)

        # Pack the padded embeddings so the GRU skips the padded positions.
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded_sentences,
            lengths.cpu(),
            enforce_sorted=False,
            batch_first=True
        )

        hidden, _ = self.rnn(packed)

        # (N, L, Hout)
        hidden, _ = nn.utils.rnn.pad_packed_sequence(hidden, batch_first=True)

        # (N, L, C)
        logits = self.fc(hidden)

        return logits

In [233]:
# Used so we do not include padding indices.
# Also, give different weights to different classes to account for class imbalance.
# (ignore_index=-1 matches the value collate_batch pads the labels with.)
criterion = torch.nn.CrossEntropyLoss(weight = weights, ignore_index=-1).to(DEVICE)

model = GRUNERModel(
    num_classes,
    d_model,
    d_hidden,
    initialize=True,
    fine_tune_embeddings=True,
    use_conv_embeddings=True,
).to(DEVICE)

optimizer = torch.optim.SGD(model.parameters(), lr=LR)

# NOTE(review): the second argument is StepLR's step_size, passed here as the float 1.0,
# with gamma=0.1. The epoch loop calls scheduler.step() once per epoch, so the learning
# rate presumably shrinks tenfold after every epoch - confirm this is intended.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)

In [234]:
from re import escape
def train(dl, model, optimizer, criterion, epoch):
    """Run one training pass over `dl` and print running metrics.

    Args:
        dl: Iterable of (sentences, labels, lengths, words) batches; `labels`
            uses -1 for positions that must be ignored.
        model: Called as model(sentences, lengths, words); must return scores of
            shape (N, L, C).
        optimizer: Optimizer over the model's parameters.
        criterion: Loss called as criterion(input=..., target=...).
        epoch: Epoch number, only used in the log lines.

    Every 50 batches the accuracy and mean loss since the previous log line are
    printed, after which the running counters are reset.
    """
    model.train()
    total_acc, total_count = 0, 0
    total_loss, total_batches = 0.0, 0
    log_interval = 50

    for idx, (sentences, labels, lengths, words) in enumerate(dl):
        optimizer.zero_grad()

        logits = model(sentences, lengths, words)

        # Flatten to one row per word so the loss sees (N * L, C) against (N * L,)
        N, L, _ = logits.shape
        logits = logits.reshape(N * L, -1)
        labels = labels.reshape(N * L)
        loss = criterion(input=logits, target=labels)

        total_loss += loss.item()
        total_batches += 1

        # Do back propagation
        loss.backward()

        # Clip the gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)

        # Do an optimization step
        optimizer.step()

        # Bookkeeping only: the scores were already computed above, so there is
        # nothing to gain from switching the model mode here. The counters are
        # kept as plain Python numbers.
        with torch.no_grad():
            masks = (labels != -1)
            total_acc += (logits.argmax(-1) == labels)[masks].sum().item()
            total_count += masks.sum().item()

        if idx % log_interval == 0 and idx > 0:
            print(
                "| epoch {:3d} | {:5d}/{:5d} batches "
                "| accuracy {:8.3f} "
                "| loss {:8.3f}".format(
                    epoch,
                    idx,
                    len(dl),
                    total_acc / max(total_count, 1),
                    total_loss / total_batches
                )
            )
            total_acc, total_count = 0, 0
            total_loss, total_batches = 0.0, 0

In [235]:
def evaluate(dl, model, loss_fn=None):
    """Compute accuracy and mean batch loss of `model` over `dl` without gradients.

    Args:
        dl: Iterable of (sentences, labels, lengths, words) batches; `labels`
            uses -1 for positions that must be ignored.
        model: Called as model(sentences, lengths, words); must return scores of
            shape (N, L, C). It is switched to eval mode.
        loss_fn: Loss called as loss_fn(input=..., target=...). When omitted, the
            module-level `criterion` is used, as before.

    Returns:
        A pair (accuracy, loss): the fraction of non-ignored positions predicted
        correctly, and the per-batch loss averaged over the batches. Both come
        out of tensor arithmetic, so they are tensor-valued.
    """
    if loss_fn is None:
        loss_fn = criterion

    model.eval()
    total_acc, total_count = 0, 0
    total_loss, total_batches = 0.0, 0.0

    with torch.no_grad():
        for sentences, labels, lengths, words in dl:
            logits = model(sentences, lengths, words)

            # Flatten to one row per word, matching the flattened labels.
            N, L, _ = logits.shape
            logits = logits.view(N * L, -1)
            labels = labels.view(N * L)

            total_loss += loss_fn(input=logits, target=labels)
            total_batches += 1

            # Only positions with a real label count towards the accuracy.
            masks = (labels != -1)
            total_acc += (logits.argmax(-1) == labels)[masks].sum().item()
            total_count += masks.sum()

    return total_acc / total_count, total_loss / total_batches

In [236]:
from time import time
import time

# Train for EPOCHS epochs; after every epoch, score the validation split and report
# the wall-clock time of the epoch. `time` here is the module: the `import time`
# above this loop rebinds the name that `from time import time` had introduced.
for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dl, model, optimizer, criterion, epoch)
    accu_val, loss_val = evaluate(valid_dl, model)
    # NOTE(review): the scheduler is stepped unconditionally after every epoch. The
    # logged output shows validation accuracy flat at about 0.79 from epoch 2 on,
    # which would be consistent with the learning rate decaying too fast - confirm.
    scheduler.step()
    print("-" * 59)
    print(
        "| end of epoch {:3d} | time: {:5.2f}s "
        "| valid accuracy {:8.3f} "
        "| valid loss {:8.3f} ".format(
            epoch,
            time.time() - epoch_start_time,
            accu_val,
            loss_val
        )
    )
    print("-" * 59)

# Final score on the held-out test split, using the model as it is after the last epoch.
print("Checking the results of test dataset.")
accu_test, loss_test = evaluate(test_dl, model)
print("test accuracy {:8.3f} | test loss {:8.3f}".format(accu_test, loss_test))

| epoch   1 |    50/  110 batches | accuracy    0.442 | loss    1.946
| epoch   1 |   100/  110 batches | accuracy    0.736 | loss    1.468
-----------------------------------------------------------
| end of epoch   1 | time: 26.13s | valid accuracy    0.799 | valid loss    1.190 
-----------------------------------------------------------
| epoch   2 |    50/  110 batches | accuracy    0.778 | loss    1.183
| epoch   2 |   100/  110 batches | accuracy    0.778 | loss    1.161
-----------------------------------------------------------
| end of epoch   2 | time: 25.50s | valid accuracy    0.791 | valid loss    1.139 
-----------------------------------------------------------
| epoch   3 |    50/  110 batches | accuracy    0.782 | loss    1.134
| epoch   3 |   100/  110 batches | accuracy    0.782 | loss    1.145
-----------------------------------------------------------
| end of epoch   3 | time: 25.20s | valid accuracy    0.789 | valid loss    1.129 
-------------------------------