In [1]:
# Standard library. (In AllenNLP we use type annotations for just about everything.)
import random
from typing import Dict, Iterator, List

# AllenNLP is built on top of PyTorch, so we use its code freely.
import numpy as np
import torch
import torch.optim as optim

# In AllenNLP we represent each training example as an Instance 
# containing Fields of various types. 
# Here each example will have a TextField containing the sentence, 
# and a SequenceLabelField containing the corresponding part-of-speech tags.

from allennlp.data import Instance
from allennlp.data.fields import TextField, SequenceLabelField

# Typically to solve a problem like this using AllenNLP, you'll have to implement two classes. The first is a DatasetReader, 
# which contains the logic for reading a file of data and producing a stream of Instances.
from allennlp.data.dataset_readers import DatasetReader

# Frequently we'll want to load datasets or models from URLs. 
# The cached_path helper downloads such files, caches them locally, 
# and returns the local path. It also accepts local file paths (which it just returns as-is).
from allennlp.common.file_utils import cached_path

# There are various ways to represent a word as one or more indices. 
# For example, you might maintain a vocabulary of unique words and give each 
# word a corresponding id. Or you might have one id per character in the word 
# and represent each word as a sequence of ids. 
# AllenNLP uses a TokenIndexer abstraction for this representation.
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token

# Whereas a TokenIndexer represents a rule for how to turn a token into indices, 
# a Vocabulary contains the corresponding mappings from strings to integers. 
# For example, your token indexer might specify to represent a token as a sequence of character ids, 
# in which case the Vocabulary would contain the mapping {character -> id}. 
# In this particular example we use a SingleIdTokenIndexer that assigns each 
# token a unique id, and so the Vocabulary will just contain a mapping {token -> id} (as well as the reverse mapping).
from allennlp.data.vocabulary import Vocabulary

# Besides DatasetReader, the other class you'll typically need to implement is Model, 
# which is a PyTorch Module that takes tensor inputs and produces a dict 
# of tensor outputs (including the training loss you want to optimize).
from allennlp.models import Model

# As mentioned above, our model will consist of an embedding layer, followed by an LSTM,
# then by a feedforward layer. AllenNLP includes abstractions for all of these that smartly 
# handle padding and batching, as well as various utility functions.
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits

# We'll want to track accuracy on the training and validation datasets.
from allennlp.training.metrics import CategoricalAccuracy

# In our training we'll need a DataIterator that can intelligently batch our data.
from allennlp.data.iterators import BucketIterator

# And we'll use AllenNLP's full-featured Trainer.
from allennlp.training.trainer import Trainer

# Finally, we'll want to make predictions on new inputs, more about this below.
from allennlp.predictors import SentenceTaggerPredictor

  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)


In [2]:
from allennlp.predictors.predictor import Predictor

In [3]:
# Load a pretrained coreference-resolution model from its archive URL, then ask
# it to resolve the coreferent mentions in a single example sentence. The
# returned dict (spans, antecedents, clusters) is this cell's displayed output.
predictor = Predictor.from_path(
    "https://s3-us-west-2.amazonaws.com/allennlp/models/coref-model-2018.02.05.tar.gz"
)
predictor.predict(document="The woman reading a newspaper sat on the bench with her dog.")

_jsonnet not loaded, treating C:\Users\1015203\AppData\Local\Temp\tmpy2w8y00x\config.json as json
  "num_layers={}".format(dropout, num_layers))
Encountered the antecedent_indices key in the model's return dictionary which couldn't be split by the batch size. Key will be ignored.


{'top_spans': [[0, 4], [3, 4], [7, 11], [10, 10], [10, 11]],
 'predicted_antecedents': [-1, -1, -1, 2, -1],
 'document': ['The',
  'woman',
  'reading',
  'a',
  'newspaper',
  'sat',
  'on',
  'the',
  'bench',
  'with',
  'her',
  'dog',
  '.'],
 'clusters': [[[0, 4], [10, 10]]]}

In [4]:
# Seed every random number generator this notebook can touch, so a fresh
# "Restart Kernel -> Run All" reproduces the same initial weights and batch
# order. torch drives weight initialisation; Python's `random` and NumPy are
# seeded as well in case AllenNLP utilities (e.g. iterator shuffling) draw from them.
SEED = 1
random.seed(SEED)
np.random.seed(SEED)
# The trailing ';' suppresses the otherwise-displayed torch Generator repr.
torch.manual_seed(SEED);

<torch._C.Generator at 0x207fd432530>

In [5]:
# Our first order of business is to implement our DatasetReader subclass.

class PosDatasetReader(DatasetReader):
    """
    DatasetReader for PoS tagging data, one sentence per line, like

        The###DET dog###NN ate###V the###DET apple###NN

    Each whitespace-separated item is a ``word###TAG`` pair; the tag is the text
    after the last ``###``. Blank lines are skipped, and an item without ``###``
    raises ``ValueError`` naming the offending file and line.
    """
    # The only parameter our DatasetReader needs is a dict of TokenIndexers that specify
    # how to convert tokens into indices. By default we'll just generate a single index
    # for each token (which we'll call "tokens") that's just a unique id for each distinct
    # token. (This is just the standard "word to index" mapping you'd use in NLP tasks.)
    def __init__(self, token_indexers: Dict[str, TokenIndexer] = None) -> None:
        # lazy=False: read() returns the instances eagerly rather than a lazy iterable.
        super().__init__(lazy=False)
        self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}

    # DatasetReader.text_to_instance takes the inputs corresponding to a training example
    # (in this case the tokens of the sentence and the corresponding part-of-speech tags),
    # instantiates the corresponding Fields (in this case a TextField for the sentence and
    # a SequenceLabelField for its tags), and returns the Instance containing those fields.
    # Notice that the tags are optional, since we'd like to be able to create instances
    # from unlabeled data to make predictions on them.
    def text_to_instance(self, tokens: List[Token], tags: List[str] = None) -> Instance:
        """Build an Instance with a "sentence" field and, if tags are given, a "labels" field."""
        sentence_field = TextField(tokens, self.token_indexers)
        fields = {"sentence": sentence_field}

        if tags:
            label_field = SequenceLabelField(labels=tags, sequence_field=sentence_field)
            fields["labels"] = label_field

        return Instance(fields)

    # The other piece we have to implement is _read, which takes a filename and produces a stream
    # of Instances. Most of the work has already been done in text_to_instance.
    def _read(self, file_path: str) -> Iterator[Instance]:
        """
        Yield one Instance per non-blank line of ``file_path``.

        Raises:
            ValueError: if an item on a line is not of the form ``word###TAG``.
        """
        # Explicit UTF-8 so parsing does not depend on the platform's default
        # encoding (e.g. cp1252 on Windows).
        with open(file_path, encoding="utf-8") as data_file:
            for line_number, line in enumerate(data_file, start=1):
                pairs = line.split()
                if not pairs:
                    # Skip blank lines (e.g. an extra empty line at the end of the
                    # file) instead of failing to unpack an empty zip().
                    continue
                tokens, tags = [], []
                for pair in pairs:
                    # rpartition keeps any earlier "###" inside the word and takes
                    # the text after the last one as the tag.
                    word, separator, tag = pair.rpartition("###")
                    if not separator:
                        raise ValueError(
                            f"{file_path}:{line_number}: expected 'word###TAG', got {pair!r}"
                        )
                    tokens.append(Token(word))
                    tags.append(tag)
                yield self.text_to_instance(tokens, tags)



In [6]:
# The other class you'll basically always have to implement is Model, 
# which is a subclass of torch.nn.Module. How it works is largely up to you, 
# it mostly just needs a forward method that takes tensor inputs and produces 
# a dict of tensor outputs that includes the loss you'll use to train the model. 
# As mentioned above, our model will consist of an embedding layer, 
# a sequence encoder, and a feedforward network.

class LstmTagger(Model):
    """
    Sequence tagger: embed each token, run a sequence encoder over the embeddings,
    then project every encoded position onto the tag ("labels") vocabulary with a
    linear layer.
    """

    # One thing that might seem unusual is that we're going to pass in the embedder
    # and the sequence encoder as constructor parameters. This allows us to experiment
    # with different embedders and encoders without having to change the model code.
    def __init__(self,

                 # The embedding layer is specified as an AllenNLP TextFieldEmbedder,
                 # which represents a general way of turning tokens into tensors.
                 # (Here we know that we want to represent each unique word with a
                 # learned tensor, but using the general class allows us to easily
                 # experiment with different types of embeddings.)
                 word_embeddings: TextFieldEmbedder,

                 # Similarly, the encoder is specified as a general Seq2SeqEncoder
                 # even though we know we want to use an LSTM. Again, this makes it
                 # easy to experiment with other sequence encoders, for example a Transformer.
                 encoder: Seq2SeqEncoder,

                 # Every AllenNLP model also expects a Vocabulary, which contains the
                 # namespaced mappings of tokens to indices and labels to indices.
                 vocab: Vocabulary) -> None:

        # Notice that we have to pass the vocab to the base class constructor.
        super().__init__(vocab)
        self.word_embeddings = word_embeddings
        self.encoder = encoder

        # The feed forward layer is not passed in as a parameter, but is constructed by us.
        # Notice that it looks at the encoder to find the correct input dimension and
        # looks at the vocabulary (and, in particular, at the label -> index mapping)
        # to find the correct output dimension.
        self.hidden2tag = torch.nn.Linear(in_features=encoder.get_output_dim(),
                                          out_features=vocab.get_vocab_size('labels'))

        # The last thing to notice is that we also instantiate a CategoricalAccuracy metric,
        # which we'll use to track accuracy during each training and validation epoch.
        self.accuracy = CategoricalAccuracy()

    # Next we need to implement forward, which is where the actual computation happens.
    # Each Instance in your dataset will get (batched with other instances and)
    # fed into forward. The forward method expects dicts of tensors as input, and it
    # expects their names to be the names of the fields in your Instance.
    # In this case we have a sentence field and (possibly) a labels field,
    # so we'll construct our forward accordingly:
    def forward(self,
                sentence: Dict[str, torch.Tensor],
                labels: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """
        Compute per-token tag logits and, when gold labels are given, the loss.

        Args:
            sentence: the batched "sentence" TextField, a dict of tensors keyed by
                token-indexer name.
            labels: optional batched gold tag ids for the "labels" field.

        Returns:
            A dict with "tag_logits" (presumably (batch, seq_len, num_tags) -- TODO
            confirm) and, if ``labels`` was given, a scalar "loss" for the trainer
            to optimise.
        """
        # AllenNLP is designed to operate on batched inputs, but different input sequences
        # have different lengths. Behind the scenes AllenNLP is padding the shorter inputs
        # so that the batch has uniform shape, which means our computations need to use
        # a mask to exclude the padding. Here we just use the utility function get_text_field_mask,
        # which returns a tensor of 0s and 1s corresponding to the padded and unpadded locations.
        mask = get_text_field_mask(sentence)

        # We start by passing the sentence tensor (each sentence a sequence of token ids)
        # to the word_embeddings module, which converts each sentence into a sequence of embedded tensors.
        embeddings = self.word_embeddings(sentence)

        # We next pass the embedded tensors (and the mask) to the LSTM,
        # which produces a sequence of encoded outputs.
        encoder_out = self.encoder(embeddings, mask)

        # Finally, we pass each encoded output tensor to the feedforward layer to
        # produce logits corresponding to the various tags.
        tag_logits = self.hidden2tag(encoder_out)
        output = {"tag_logits": tag_logits}

        # As before, the labels are optional, as we might want to run this model
        # to make predictions on unlabeled data. If we do have labels, then we use them
        # to update our accuracy metric and compute the "loss" that goes in our output.
        if labels is not None:
            self.accuracy(tag_logits, labels, mask)
            output["loss"] = sequence_cross_entropy_with_logits(tag_logits, labels, mask)

        return output

    # We included an accuracy metric that gets updated each forward pass. That means we need to
    # override a get_metrics method that pulls the data out of it. Behind the scenes,
    # the CategoricalAccuracy metric is storing the number of predictions and the number of correct predictions,
    # updating those counts during each call to forward. Each call to get_metric returns the calculated
    # accuracy and (optionally) resets the counts, which is what allows us to track accuracy anew for each epoch.
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return {"accuracy": ...}, resetting the running counts when ``reset`` is True."""
        return {"accuracy": self.accuracy.get_metric(reset)}



In [7]:

# With the DatasetReader and Model defined, training can start. First create a
# reader, then load the training and validation splits in that order. Each URL
# goes through cached_path, which caches the file locally and hands reader.read
# the local path. A local file path would work here as well.
reader = PosDatasetReader()

train_dataset, validation_dataset = (
    reader.read(cached_path(split_url))
    for split_url in (
        "https://raw.githubusercontent.com/allenai/allennlp/master/tutorials/tagger/training.txt",
        "https://raw.githubusercontent.com/allenai/allennlp/master/tutorials/tagger/validation.txt",
    )
)



2it [00:00, 133.47it/s]
2it [00:00, 181.35it/s]


In [8]:

# Model sizes: the width of each word embedding and of the LSTM hidden state.
# Both are kept tiny because the dataset is a tiny toy one.
EMBEDDING_DIM, HIDDEN_DIM = 6, 6

# Build the Vocabulary (the token -> id and label -> id mappings) from every
# instance in both the training and the validation split.
vocab = Vocabulary.from_instances(train_dataset + validation_dataset)



100%|████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<?, ?it/s]


In [17]:
# print() shows the Vocabulary's multi-line summary: the non-padded namespaces,
# followed by each namespace ("tokens", "labels") and its size.
print(vocab)

Vocabulary with namespaces:
 	Non Padded Namespaces: {'*labels', '*tags'}
 	Namespace: tokens, Size: 11 
 	Namespace: labels, Size: 3 



In [10]:

# Token embedding: BasicTextFieldEmbedder maps each token-indexer name to an
# embedding module. Our reader's default indexer is named "tokens", so one
# Embedding registered under that key is enough. It learns an EMBEDDING_DIM-wide
# vector for every entry in the "tokens" vocabulary namespace. Pre-trained vectors
# (e.g. GloVe) could be used instead, but this tiny toy dataset doesn't need them.
token_embedding = Embedding(
    num_embeddings=vocab.get_vocab_size('tokens'),
    embedding_dim=EMBEDDING_DIM,
)
word_embeddings = BasicTextFieldEmbedder(token_embedders={"tokens": token_embedding})



In [11]:

# Sequence encoder: a plain PyTorch LSTM wrapped in PytorchSeq2SeqWrapper, which
# gives it the Seq2SeqEncoder interface the model relies on. That interface lets
# forward pass a mask and lets __init__ call get_output_dim(). Configuration
# files would do this wrapping for us. AllenNLP works batch-first, so the LSTM
# must be batch-first too.
lstm = PytorchSeq2SeqWrapper(
    torch.nn.LSTM(input_size=EMBEDDING_DIM, hidden_size=HIDDEN_DIM, batch_first=True)
)

# Assemble the tagger from the embedder, the encoder and the vocabulary.
model = LstmTagger(word_embeddings=word_embeddings, encoder=lstm, vocab=vocab)



In [12]:

# Optimizer: plain stochastic gradient descent over all of the model's parameters.
optimizer = optim.SGD(params=model.parameters(), lr=0.1)

# Batching: BucketIterator sorts instances by the number of tokens in their
# "sentence" field, so each batch (of size 2) holds sequences of similar length.
# index_with(vocab) makes the iterator convert each instance's strings to
# integer ids using the vocabulary built above.
iterator = BucketIterator(sorting_keys=[("sentence", "num_tokens")], batch_size=2)
iterator.index_with(vocab)



In [13]:

# Build the Trainer. It runs for at most 1000 epochs, but stops early once 10
# epochs go by without the validation metric improving. The default validation
# metric is the loss, which improves by getting smaller. A different metric and
# direction can be chosen instead (e.g. accuracy, which should get bigger).
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=train_dataset,
    validation_dataset=validation_dataset,
    num_epochs=1000,
    patience=10,
)



In [14]:

# Launch training. Each progress bar reports both the "loss" and the "accuracy"
# metric. If the model is learning, the loss should go down and the accuracy up
# as training proceeds.
trainer.train()



accuracy: 0.3333, loss: 1.1685 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 111.22it/s]
accuracy: 0.3333, loss: 1.1592 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.82it/s]
accuracy: 0.3333, loss: 1.1604 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.84it/s]
accuracy: 0.3333, loss: 1.1516 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 166.51it/s]
accuracy: 0.3333, loss: 1.1529 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 124.89it/s]
accuracy: 0.3333, loss: 1.1445 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.81it/s]
accuracy: 0.3333, loss: 1.1458 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 111.03it/s]
accuracy: 0.3333, loss: 1.1379 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 166.49it/s]
accuracy: 0.3333, loss: 1.1391 ||: 100%|

accuracy: 0.4444, loss: 1.0612 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 111.01it/s]
accuracy: 0.4444, loss: 1.0590 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 166.67it/s]
accuracy: 0.4444, loss: 1.0605 ||: 100%|█████████████████████████████████████████████████| 1/1 [00:00<00:00, 99.90it/s]
accuracy: 0.4444, loss: 1.0584 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.85it/s]
accuracy: 0.4444, loss: 1.0598 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 142.75it/s]
accuracy: 0.4444, loss: 1.0578 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.68it/s]
accuracy: 0.4444, loss: 1.0592 ||: 100%|█████████████████████████████████████████████████| 1/1 [00:00<00:00, 99.95it/s]
accuracy: 0.4444, loss: 1.0572 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 166.65it/s]
accuracy: 0.4444, loss: 1.0586 ||: 100%|

accuracy: 0.4444, loss: 1.0479 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.84it/s]
accuracy: 0.4444, loss: 1.0463 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.12it/s]
accuracy: 0.4444, loss: 1.0477 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.90it/s]
accuracy: 0.4444, loss: 1.0460 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 499.62it/s]
accuracy: 0.4444, loss: 1.0474 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.84it/s]
accuracy: 0.4444, loss: 1.0457 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.12it/s]
accuracy: 0.4444, loss: 1.0471 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.84it/s]
accuracy: 0.4444, loss: 1.0454 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.83it/s]
accuracy: 0.4444, loss: 1.0468 ||: 100%|

accuracy: 0.4444, loss: 1.0362 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 166.52it/s]
accuracy: 0.4444, loss: 1.0345 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.81it/s]
accuracy: 0.4444, loss: 1.0357 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.95it/s]
accuracy: 0.4444, loss: 1.0341 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.09it/s]
accuracy: 0.4444, loss: 1.0353 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.78it/s]
accuracy: 0.4444, loss: 1.0336 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 166.55it/s]
accuracy: 0.4444, loss: 1.0349 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.80it/s]
accuracy: 0.4444, loss: 1.0332 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.84it/s]
accuracy: 0.4444, loss: 1.0344 ||: 100%|

accuracy: 0.4444, loss: 1.0173 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.82it/s]
accuracy: 0.4444, loss: 1.0156 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.75it/s]
accuracy: 0.4444, loss: 1.0166 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.85it/s]
accuracy: 0.4444, loss: 1.0149 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 499.44it/s]
accuracy: 0.4444, loss: 1.0159 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.79it/s]
accuracy: 0.4444, loss: 1.0141 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.81it/s]
accuracy: 0.4444, loss: 1.0151 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.86it/s]
accuracy: 0.4444, loss: 1.0134 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.09it/s]
accuracy: 0.4444, loss: 1.0144 ||: 100%|

accuracy: 0.4444, loss: 0.9846 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.85it/s]
accuracy: 0.4444, loss: 0.9826 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.78it/s]
accuracy: 0.4444, loss: 0.9834 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.85it/s]
accuracy: 0.4444, loss: 0.9813 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.78it/s]
accuracy: 0.4444, loss: 0.9821 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.84it/s]
accuracy: 0.4444, loss: 0.9800 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.01it/s]
accuracy: 0.4444, loss: 0.9807 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.82it/s]
accuracy: 0.4444, loss: 0.9787 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.12it/s]
accuracy: 0.4444, loss: 0.9794 ||: 100%|

accuracy: 0.4444, loss: 0.9270 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.85it/s]
accuracy: 0.4444, loss: 0.9242 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.20it/s]
accuracy: 0.5556, loss: 0.9248 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.78it/s]
accuracy: 0.4444, loss: 0.9220 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.07it/s]
accuracy: 0.5556, loss: 0.9225 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.86it/s]
accuracy: 0.4444, loss: 0.9197 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.15it/s]
accuracy: 0.5556, loss: 0.9203 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.83it/s]
accuracy: 0.4444, loss: 0.9175 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.01it/s]
accuracy: 0.5556, loss: 0.9180 ||: 100%|

accuracy: 0.6667, loss: 0.8340 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.84it/s]
accuracy: 0.6667, loss: 0.8306 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.84it/s]
accuracy: 0.6667, loss: 0.8307 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.84it/s]
accuracy: 0.6667, loss: 0.8273 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.84it/s]
accuracy: 0.6667, loss: 0.8274 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.79it/s]
accuracy: 0.6667, loss: 0.8241 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.07it/s]
accuracy: 0.6667, loss: 0.8242 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.79it/s]
accuracy: 0.6667, loss: 0.8208 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.84it/s]
accuracy: 0.6667, loss: 0.8208 ||: 100%|

accuracy: 0.6667, loss: 0.7148 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.81it/s]
accuracy: 0.6667, loss: 0.7119 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.07it/s]
accuracy: 0.6667, loss: 0.7112 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.82it/s]
accuracy: 0.6667, loss: 0.7082 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.07it/s]
accuracy: 0.6667, loss: 0.7075 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.87it/s]
accuracy: 0.6667, loss: 0.7046 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.81it/s]
accuracy: 0.6667, loss: 0.7039 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 142.74it/s]
accuracy: 0.6667, loss: 0.7010 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 499.44it/s]
accuracy: 0.6667, loss: 0.7002 ||: 100%|

accuracy: 0.7778, loss: 0.5912 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 166.52it/s]
accuracy: 0.7778, loss: 0.5895 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.87it/s]
accuracy: 0.7778, loss: 0.5876 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.87it/s]
accuracy: 0.7778, loss: 0.5859 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.79it/s]
accuracy: 0.7778, loss: 0.5840 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.79it/s]
accuracy: 0.7778, loss: 0.5823 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 499.80it/s]
accuracy: 0.7778, loss: 0.5804 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 166.53it/s]
accuracy: 0.7778, loss: 0.5787 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.07it/s]
accuracy: 0.7778, loss: 0.5768 ||: 100%|

accuracy: 1.0000, loss: 0.4695 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 124.88it/s]
accuracy: 1.0000, loss: 0.4682 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.84it/s]
accuracy: 1.0000, loss: 0.4660 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.84it/s]
accuracy: 1.0000, loss: 0.4647 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.07it/s]
accuracy: 1.0000, loss: 0.4625 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.81it/s]
accuracy: 1.0000, loss: 0.4612 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.12it/s]
accuracy: 1.0000, loss: 0.4591 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.76it/s]
accuracy: 1.0000, loss: 0.4578 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.81it/s]
accuracy: 1.0000, loss: 0.4556 ||: 100%|

accuracy: 1.0000, loss: 0.3577 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.82it/s]
accuracy: 1.0000, loss: 0.3561 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.09it/s]
accuracy: 1.0000, loss: 0.3547 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.83it/s]
accuracy: 1.0000, loss: 0.3530 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.81it/s]
accuracy: 1.0000, loss: 0.3517 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.82it/s]
accuracy: 1.0000, loss: 0.3500 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.07it/s]
accuracy: 1.0000, loss: 0.3487 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.84it/s]
accuracy: 1.0000, loss: 0.3470 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.79it/s]
accuracy: 1.0000, loss: 0.3457 ||: 100%|

accuracy: 1.0000, loss: 0.2661 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.79it/s]
accuracy: 1.0000, loss: 0.2644 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.84it/s]
accuracy: 1.0000, loss: 0.2638 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.78it/s]
accuracy: 1.0000, loss: 0.2620 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.07it/s]
accuracy: 1.0000, loss: 0.2614 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.83it/s]
accuracy: 1.0000, loss: 0.2597 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.07it/s]
accuracy: 1.0000, loss: 0.2591 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.84it/s]
accuracy: 1.0000, loss: 0.2574 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.09it/s]
accuracy: 1.0000, loss: 0.2569 ||: 100%|

accuracy: 1.0000, loss: 0.1976 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.83it/s]
accuracy: 1.0000, loss: 0.1961 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.07it/s]
accuracy: 1.0000, loss: 0.1959 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.83it/s]
accuracy: 1.0000, loss: 0.1944 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.07it/s]
accuracy: 1.0000, loss: 0.1943 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.84it/s]
accuracy: 1.0000, loss: 0.1927 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.81it/s]
accuracy: 1.0000, loss: 0.1926 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.85it/s]
accuracy: 1.0000, loss: 0.1911 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.82it/s]
accuracy: 1.0000, loss: 0.1910 ||: 100%|

accuracy: 1.0000, loss: 0.1490 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.85it/s]
accuracy: 1.0000, loss: 0.1478 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.78it/s]
accuracy: 1.0000, loss: 0.1478 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.81it/s]
accuracy: 1.0000, loss: 0.1466 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.78it/s]
accuracy: 1.0000, loss: 0.1466 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.82it/s]
accuracy: 1.0000, loss: 0.1454 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.12it/s]
accuracy: 1.0000, loss: 0.1455 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.84it/s]
accuracy: 1.0000, loss: 0.1443 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.09it/s]
accuracy: 1.0000, loss: 0.1443 ||: 100%|

accuracy: 1.0000, loss: 0.1151 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.81it/s]
accuracy: 1.0000, loss: 0.1142 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.81it/s]
accuracy: 1.0000, loss: 0.1143 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.84it/s]
accuracy: 1.0000, loss: 0.1134 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.01it/s]
accuracy: 1.0000, loss: 0.1135 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.81it/s]
accuracy: 1.0000, loss: 0.1126 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 499.50it/s]
accuracy: 1.0000, loss: 0.1127 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.84it/s]
accuracy: 1.0000, loss: 0.1118 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.12it/s]
accuracy: 1.0000, loss: 0.1119 ||: 100%|

accuracy: 1.0000, loss: 0.0914 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.84it/s]
accuracy: 1.0000, loss: 0.0907 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.85it/s]
accuracy: 1.0000, loss: 0.0908 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.84it/s]
accuracy: 1.0000, loss: 0.0901 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.09it/s]
accuracy: 1.0000, loss: 0.0902 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.82it/s]
accuracy: 1.0000, loss: 0.0895 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.84it/s]
accuracy: 1.0000, loss: 0.0897 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.78it/s]
accuracy: 1.0000, loss: 0.0890 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 499.50it/s]
accuracy: 1.0000, loss: 0.0891 ||: 100%|

accuracy: 1.0000, loss: 0.0744 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.84it/s]
accuracy: 1.0000, loss: 0.0739 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.07it/s]
accuracy: 1.0000, loss: 0.0740 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.84it/s]
accuracy: 1.0000, loss: 0.0735 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.09it/s]
accuracy: 1.0000, loss: 0.0736 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.84it/s]
accuracy: 1.0000, loss: 0.0731 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.78it/s]
accuracy: 1.0000, loss: 0.0732 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.83it/s]
accuracy: 1.0000, loss: 0.0726 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.01it/s]
accuracy: 1.0000, loss: 0.0728 ||: 100%|

accuracy: 1.0000, loss: 0.0620 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.81it/s]
accuracy: 1.0000, loss: 0.0616 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.07it/s]
accuracy: 1.0000, loss: 0.0617 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.81it/s]
accuracy: 1.0000, loss: 0.0613 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.17it/s]
accuracy: 1.0000, loss: 0.0614 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.84it/s]
accuracy: 1.0000, loss: 0.0610 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.09it/s]
accuracy: 1.0000, loss: 0.0611 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.78it/s]
accuracy: 1.0000, loss: 0.0607 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.77it/s]
accuracy: 1.0000, loss: 0.0608 ||: 100%|

accuracy: 1.0000, loss: 0.0527 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.84it/s]
accuracy: 1.0000, loss: 0.0523 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.20it/s]
accuracy: 1.0000, loss: 0.0524 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.84it/s]
accuracy: 1.0000, loss: 0.0521 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.77it/s]
accuracy: 1.0000, loss: 0.0522 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.81it/s]
accuracy: 1.0000, loss: 0.0518 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.81it/s]
accuracy: 1.0000, loss: 0.0519 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.85it/s]
accuracy: 1.0000, loss: 0.0516 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.75it/s]
accuracy: 1.0000, loss: 0.0517 ||: 100%|

accuracy: 1.0000, loss: 0.0454 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.85it/s]
accuracy: 1.0000, loss: 0.0452 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.09it/s]
accuracy: 1.0000, loss: 0.0453 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 166.51it/s]
accuracy: 1.0000, loss: 0.0450 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.81it/s]
accuracy: 1.0000, loss: 0.0451 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.83it/s]
accuracy: 1.0000, loss: 0.0448 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.01it/s]
accuracy: 1.0000, loss: 0.0449 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.85it/s]
accuracy: 1.0000, loss: 0.0446 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.12it/s]
accuracy: 1.0000, loss: 0.0447 ||: 100%|

accuracy: 1.0000, loss: 0.0398 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.83it/s]
accuracy: 1.0000, loss: 0.0395 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.82it/s]
accuracy: 1.0000, loss: 0.0396 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.82it/s]
accuracy: 1.0000, loss: 0.0394 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.12it/s]
accuracy: 1.0000, loss: 0.0395 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.88it/s]
accuracy: 1.0000, loss: 0.0392 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.07it/s]
accuracy: 1.0000, loss: 0.0393 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 166.53it/s]
accuracy: 1.0000, loss: 0.0391 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.17it/s]
accuracy: 1.0000, loss: 0.0392 ||: 100%|

accuracy: 1.0000, loss: 0.0352 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.84it/s]
accuracy: 1.0000, loss: 0.0350 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.84it/s]
accuracy: 1.0000, loss: 0.0351 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.87it/s]
accuracy: 1.0000, loss: 0.0349 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.81it/s]
accuracy: 1.0000, loss: 0.0349 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.81it/s]
accuracy: 1.0000, loss: 0.0348 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.81it/s]
accuracy: 1.0000, loss: 0.0348 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.81it/s]
accuracy: 1.0000, loss: 0.0347 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.78it/s]
accuracy: 1.0000, loss: 0.0347 ||: 100%|

accuracy: 1.0000, loss: 0.0315 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.81it/s]
accuracy: 1.0000, loss: 0.0313 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.87it/s]
accuracy: 1.0000, loss: 0.0314 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.85it/s]
accuracy: 1.0000, loss: 0.0312 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 499.56it/s]
accuracy: 1.0000, loss: 0.0313 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.78it/s]
accuracy: 1.0000, loss: 0.0311 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.01it/s]
accuracy: 1.0000, loss: 0.0312 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.81it/s]
accuracy: 1.0000, loss: 0.0310 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.81it/s]
accuracy: 1.0000, loss: 0.0311 ||: 100%|

accuracy: 1.0000, loss: 0.0284 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.83it/s]
accuracy: 1.0000, loss: 0.0282 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.04it/s]
accuracy: 1.0000, loss: 0.0283 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.81it/s]
accuracy: 1.0000, loss: 0.0282 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.15it/s]
accuracy: 1.0000, loss: 0.0282 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.81it/s]
accuracy: 1.0000, loss: 0.0281 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.15it/s]
accuracy: 1.0000, loss: 0.0281 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.83it/s]
accuracy: 1.0000, loss: 0.0280 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 499.44it/s]
accuracy: 1.0000, loss: 0.0280 ||: 100%|

accuracy: 1.0000, loss: 0.0258 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.85it/s]
accuracy: 1.0000, loss: 0.0257 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.81it/s]
accuracy: 1.0000, loss: 0.0257 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.80it/s]
accuracy: 1.0000, loss: 0.0256 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.07it/s]
accuracy: 1.0000, loss: 0.0256 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.82it/s]
accuracy: 1.0000, loss: 0.0255 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.78it/s]
accuracy: 1.0000, loss: 0.0256 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.85it/s]
accuracy: 1.0000, loss: 0.0255 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.07it/s]
accuracy: 1.0000, loss: 0.0255 ||: 100%|

accuracy: 1.0000, loss: 0.0236 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.85it/s]
accuracy: 1.0000, loss: 0.0235 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.07it/s]
accuracy: 1.0000, loss: 0.0235 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.78it/s]
accuracy: 1.0000, loss: 0.0234 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 499.56it/s]
accuracy: 1.0000, loss: 0.0234 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.83it/s]
accuracy: 1.0000, loss: 0.0234 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.09it/s]
accuracy: 1.0000, loss: 0.0234 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.84it/s]
accuracy: 1.0000, loss: 0.0233 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.09it/s]
accuracy: 1.0000, loss: 0.0233 ||: 100%|

accuracy: 1.0000, loss: 0.0217 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 166.54it/s]
accuracy: 1.0000, loss: 0.0216 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.07it/s]
accuracy: 1.0000, loss: 0.0216 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.78it/s]
accuracy: 1.0000, loss: 0.0215 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.07it/s]
accuracy: 1.0000, loss: 0.0216 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.87it/s]
accuracy: 1.0000, loss: 0.0215 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.07it/s]
accuracy: 1.0000, loss: 0.0215 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.84it/s]
accuracy: 1.0000, loss: 0.0214 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.07it/s]
accuracy: 1.0000, loss: 0.0215 ||: 100%|

accuracy: 1.0000, loss: 0.0200 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.84it/s]
accuracy: 1.0000, loss: 0.0200 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.15it/s]
accuracy: 1.0000, loss: 0.0200 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.79it/s]
accuracy: 1.0000, loss: 0.0199 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 499.62it/s]
accuracy: 1.0000, loss: 0.0199 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.81it/s]
accuracy: 1.0000, loss: 0.0199 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.07it/s]
accuracy: 1.0000, loss: 0.0199 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 166.53it/s]
accuracy: 1.0000, loss: 0.0198 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.01it/s]
accuracy: 1.0000, loss: 0.0198 ||: 100%|

accuracy: 1.0000, loss: 0.0186 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.83it/s]
accuracy: 1.0000, loss: 0.0185 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.15it/s]
accuracy: 1.0000, loss: 0.0186 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.89it/s]
accuracy: 1.0000, loss: 0.0185 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.12it/s]
accuracy: 1.0000, loss: 0.0185 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 166.52it/s]
accuracy: 1.0000, loss: 0.0185 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.01it/s]
accuracy: 1.0000, loss: 0.0185 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 199.86it/s]
accuracy: 1.0000, loss: 0.0184 ||: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.84it/s]
accuracy: 1.0000, loss: 0.0184 ||: 100%|

{'training_duration': '00:00:19',
 'training_start_epoch': 0,
 'training_epochs': 999,
 'epoch': 999,
 'training_accuracy': 1.0,
 'training_loss': 0.01809467002749443,
 'validation_accuracy': 1.0,
 'validation_loss': 0.01804003119468689,
 'best_epoch': 999,
 'best_validation_accuracy': 1.0,
 'best_validation_loss': 0.01804003119468689}

In [15]:

# Mirror the original PyTorch tutorial by inspecting what the trained model predicts.
# An AllenNLP Predictor bundles the round trip "raw input -> Instance -> model forward ->
# JSON-serializable dict". The built-in SentenceTaggerPredictor already fits this tagging
# task, so no custom Predictor is written here. It needs the trained model (to run
# predictions) and the dataset reader (to turn a raw sentence into an Instance).
predictor = SentenceTaggerPredictor(model=model, dataset_reader=reader)

# predict() takes a raw sentence and returns a JSON-serializable copy of the forward()
# output dict, so 'tag_logits' arrives as nested lists rather than an array. For this
# 5-word sentence it has shape (5, 3): one row per word, one column per possible tag.
# This toy dataset has 3 tags.
tag_logits = predictor.predict(sentence="The dog ate the apple")['tag_logits']

# Each word's predicted tag id is the index of the highest-scoring column in its row.
tag_ids = np.asarray(tag_logits).argmax(axis=-1)

# Translate every tag id back to its string label via the 'labels' vocabulary namespace.
print([model.vocab.get_token_from_index(tag_id, namespace='labels') for tag_id in tag_ids])



['DET', 'NN', 'V', 'DET', 'NN']
