<a href="https://colab.research.google.com/github/whoami-Lory271/NN-project-memorizing-transformers/blob/main/NN_project_Antonelli_DeSantis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch import nn as nn
import numpy as np
from torch.nn import functional as F
from math import sqrt
import matplotlib.pyplot as plt
from torch.autograd import Variable
from pathlib import Path
from filelock import FileLock
import random
import tqdm
import gzip
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.utils import resample
from sklearn.model_selection import train_test_split

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# KNN Memory

In [3]:
!pip install faiss-gpu

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting faiss-gpu
  Downloading faiss_gpu-1.7.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (85.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.5/85.5 MB[0m [31m11.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: faiss-gpu
Successfully installed faiss-gpu-1.7.2


In [4]:
!pip install einops

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.6/41.6 KB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.6.0


In [5]:
#import per la knn memory
import os
import math
import torch
import faiss
import numpy as np
from pathlib import Path
from functools import wraps

from contextlib import ExitStack, contextmanager

from einops import rearrange, pack, unpack

# multiprocessing

from joblib import Parallel, delayed, cpu_count

In [6]:
FAISS_INDEX_GPU_ID = int(os.getenv('FAISS_INDEX_GPU_ID', 0))

DEFAULT_KNN_MEMORY_MEMMAP_DIRECTORY = './.tmp/knn.memories'

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def cast_list(val):
    return val if isinstance(val, list) else [val]

def all_el_unique(arr):
    return len(set(arr)) == len(arr)

@contextmanager
def multi_context(*cms):
    with ExitStack() as stack:
        yield [stack.enter_context(cls) for cls in cms]

def count_intersect(x, y):
    # returns an array that shows how many times an element in x is contained in tensor y
    return np.sum(rearrange(x, 'i -> i 1') == rearrange(y, 'j -> 1 j'), axis = -1)

def check_shape(tensor, pattern, **kwargs):
    return rearrange(tensor, f"{pattern} -> {pattern}", **kwargs)

# a wrapper around faiss IndexIVFFlat
# taking care of expiring old keys automagically

class KNN():
    def __init__(
        self,
        dim,
        max_num_entries,
        cap_num_entries = False,
        M = 15,
        keep_stats = False
    ):
        index = faiss.IndexHNSWFlat(dim, M, faiss.METRIC_INNER_PRODUCT)
        self.index = index
        self.max_num_entries = max_num_entries
        self.cap_num_entries = cap_num_entries
        self.is_trained = False
        self.keep_stats = keep_stats

        self.reset()

    def __del__(self):
        if hasattr(self, 'index'):
            del self.index

    def reset(self):
        self.ids = np.empty((0,), dtype = np.int32)

        if self.keep_stats:
            self.hits = np.empty((0,), dtype = np.int32)
            self.age_num_iterations = np.empty((0,), dtype = np.int32)
            self.ages_since_last_hit = np.empty((0,), dtype = np.int32)

        self.index.reset()
        self.is_trained = False

    def train(self, x):
        self.index.train(x)
        self.is_trained = True

    def add(self, x, ids):
        if not self.is_trained:
            self.train(x)

        self.ids = np.concatenate((ids, self.ids))

        if self.keep_stats:
            self.hits = np.concatenate((np.zeros_like(ids), self.hits))
            self.age_num_iterations = np.concatenate((np.zeros_like(ids), self.age_num_iterations))
            self.ages_since_last_hit = np.concatenate((np.zeros_like(ids), self.ages_since_last_hit))

        if self.cap_num_entries and len(self.ids) > self.max_num_entries:
            self.reset()

        return self.index.add(x)

    def search(
        self,
        x,
        topk,
        nprobe = 8,
        return_distances = False,
        increment_hits = False,
        increment_age = True
    ):
        if not self.is_trained:
            return np.full((x.shape[0], topk), -1)

        distances, indices = self.index.search(x, k = topk)

        if increment_hits and self.keep_stats:
            hits = count_intersect(self.ids, rearrange(indices, '... -> (...)'))
            self.hits += hits

            self.ages_since_last_hit += 1
            self.ages_since_last_hit *= (hits == 0)

        if increment_age and self.keep_stats:
            self.age_num_iterations += 1

        if return_distances:
            return indices, distances

        return indices

# KNN memory layer, where one can store key / value memories
# can automatically take care of a collection of faiss indices (across batch dimension)

class KNNMemory():
    def __init__(
        self,
        dim,
        max_memories = 16000,
        num_indices = 1,
        memmap_filename = './knn.memory.memmap',
        multiprocessing = True
    ):
        self.dim = dim
        self.num_indices = num_indices
        self.scoped_indices = list(range(num_indices))

        self.max_memories = max_memories
        self.shape = (num_indices, max_memories, 2, dim)
        self.db_offsets = np.zeros(num_indices, dtype = np.int32)

        self.db = np.memmap(memmap_filename, mode = 'w+', dtype = np.float32, shape = self.shape)
        self.knns = [KNN(dim = dim, max_num_entries = max_memories, cap_num_entries = True) for _ in range(num_indices)]
    
        self.n_jobs = cpu_count() if multiprocessing else 1

    def set_scoped_indices(self, indices):
        indices = list(indices)
        assert all_el_unique(indices), f'all scoped batch indices must be unique, received: {indices}'
        assert all([0 <= i < self.num_indices for i in indices]), f'each batch index must be between 0 and less than {self.num_indices}: received {indices}'
        self.scoped_indices = indices

    @contextmanager
    def at_batch_indices(self, indices):
        prev_indices = self.scoped_indices
        self.set_scoped_indices(indices)
        yield self
        self.set_scoped_indices(prev_indices)

    def clear(self, batch_indices = None):
        if not exists(batch_indices):
            batch_indices = list(range(self.num_indices))

        batch_indices = cast_list(batch_indices)

        for index in batch_indices:
            knn = self.knns[index]
            knn.reset()

        self.db_offsets[batch_indices] = 0

    def add(self, memories):
        check_shape(memories, 'b n kv d', d = self.dim, kv = 2, b = len(self.scoped_indices))

        memories = memories.detach().cpu().numpy()
        memories = memories[:, -self.max_memories:]
        num_memories = memories.shape[1]

        knn_insert_ids = np.arange(num_memories)

        keys = np.ascontiguousarray(memories[..., 0, :])
        knns = [self.knns[i] for i in self.scoped_indices]
        db_offsets = [self.db_offsets[i] for i in self.scoped_indices]

        # use joblib to insert new key / value memories into faiss index

        @delayed
        def knn_add(knn, key, db_offset):
            knn.add(key, ids = knn_insert_ids + db_offset)
            return knn

        updated_knns = Parallel(n_jobs = self.n_jobs)(knn_add(*args) for args in zip(knns, keys, db_offsets))
        for knn_idx, scoped_idx in enumerate(self.scoped_indices):
            self.knns[scoped_idx] = updated_knns[knn_idx]

        # add the new memories to the memmap "database"

        add_indices = (rearrange(np.arange(num_memories), 'j -> 1 j') + rearrange(self.db_offsets[list(self.scoped_indices)], 'i -> i 1')) % self.max_memories
        self.db[rearrange(np.array(self.scoped_indices), 'i -> i 1'), add_indices] = memories
        self.db.flush()

        self.db_offsets += num_memories

    def search(
        self,
        queries,
        topk,
        nprobe = 8,
        increment_hits = True,
        increment_age = True
    ):
        check_shape(queries, 'b ... d', d = self.dim, b = len(self.scoped_indices))
        queries, ps = pack([queries], 'b * d')

        device = queries.device
        queries = queries.detach().cpu().numpy()

        all_masks = []
        all_key_values = []

        knns = [self.knns[i] for i in self.scoped_indices]

        # parallelize faiss search

        @delayed
        def knn_search(knn, query):
            return knn.search(query, topk, nprobe, increment_hits = increment_hits, increment_age = increment_age)

        fetched_indices = Parallel(n_jobs = self.n_jobs)(knn_search(*args) for args in zip(knns, queries))

        # get all the memory key / values from memmap 'database'
        # todo - remove for loop below

        for batch_index, indices in zip(self.scoped_indices, fetched_indices):
            mask = indices !=  -1
            db_indices = np.where(mask, indices, 0)

            all_masks.append(torch.from_numpy(mask))

            key_values = self.db[batch_index, db_indices % self.max_memories]
            all_key_values.append(torch.from_numpy(key_values))

        all_masks = torch.stack(all_masks)
        all_key_values = torch.stack(all_key_values)
        all_key_values = all_key_values.masked_fill(~rearrange(all_masks, '... -> ... 1 1'), 0.)

        all_key_values, = unpack(all_key_values, ps, 'b * n kv d')
        all_masks, = unpack(all_masks, ps, 'b * n')

        return all_key_values.to(device), all_masks.to(device)

    def __del__(self):
        if hasattr(self, 'knns'):
            for knn in self.knns:
                del knn
        del self.db

# extends list with some extra methods for collections of KNN memories

class KNNMemoryList(list):
    def cleanup(self):
        for memory in self:
            del memory

    @classmethod
    def create_memories(
        self,
        *,
        batch_size,
        num_memory_layers,
        memories_directory = DEFAULT_KNN_MEMORY_MEMMAP_DIRECTORY
    ):
        memories_path = Path(memories_directory)
        memories_path.mkdir(exist_ok = True, parents = True)

        def inner(*args, **kwargs):
            return self([KNNMemory(*args, num_indices = batch_size, memmap_filename = str(memories_path / f'knn.memory.layer.{ind + 1}.memmap'), **kwargs) for ind in range(num_memory_layers)])
        return inner

    @contextmanager
    def at_batch_indices(
        self,
        indices
    ):
        knn_batch_indices_contexts = [memory.at_batch_indices(indices) for memory in self]
        with multi_context(*knn_batch_indices_contexts):
            yield

    def clear_memory(
        self,
        batch_indices = None,
        memory_indices = None
    ):
        memory_indices = default(memory_indices, tuple(range(len(self))))

        for memory_index in memory_indices:
            memory = self[memory_index]
            memory.clear(batch_indices)

# Memorizing transformers

In [7]:
def attention(query, key, value, sqrt_q, device):
    t = torch.matmul(query, key.transpose(-2, -1))/sqrt_q
    i, j = t.shape[-2:]
    mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
    return torch.matmul(F.softmax(t.masked_fill_(mask, -1e-9), dim = -1), value)

def KNNattention(query, key, value, sqrt_q, mask):
    t = torch.einsum('b h i d, b h i j d -> b h i j', query, key)/sqrt_q
    return torch.einsum('b h i j, b h i j d -> b h i d', F.softmax(t.masked_fill_(mask, -1e-9), dim = -1), value)

In [8]:
class MultiHeadAttention(nn.Module):
  def __init__(self, n, d, h, batch_size):
    super(MultiHeadAttention, self).__init__()
    assert d % h == 0
    #assume q = v 
    self.q = d // h
    self.sqrt_q = sqrt(self.q)
    self.h = h
    self.batch_size = batch_size
    self.W_q = nn.Linear(d, d, bias = False) #stack of h matrices of dimension (d, q), one for each head
    self.W_k = nn.Linear(d, d, bias = False)
    self.W_v = nn.Linear(d, d, bias = False)
    self.W_o = nn.Linear(d, d, bias = False)

  def forward(self, x, device):
    query = self.W_q(x).view(self.batch_size, -1, self.h, self.q).transpose(1, 2)
    key = self.W_k(x).view(self.batch_size, -1, self.h, self.q).transpose(1, 2)
    value = self.W_v(x).view(self.batch_size, -1, self.h, self.q).transpose(1, 2)
    new_memories = torch.stack((key, value), dim = -2).detach()
    attention_value = attention(query, key, value, self.sqrt_q, device)
    return self.W_o(attention_value.transpose(1, 2).contiguous().view(self.batch_size, -1, self.h*self.q)), new_memories

In [9]:
class KNNAttention(nn.Module):
   def __init__(self, n, d, h, num_retrieved_memories, batch_size):
      super(KNNAttention, self).__init__()
      assert d % h == 0
      #assume q = v 
      self.q = d // h
      self.sqrt_q = sqrt(self.q)
      self.h = h
      self.W_q = nn.Linear(d, d, bias = False)
      self.W_k = nn.Linear(d, d, bias = False)
      self.W_v = nn.Linear(d, d, bias = False)
      self.W_o = nn.Linear(d, d, bias = False)
      self.b_g = nn.Parameter(torch.randn((h,))) #one for each head
      self.num_retrieved_memories = num_retrieved_memories
      self.batch_size = batch_size

   def forward(self, x, knn_memory, device):
      # calculate local attention 
      query = self.W_q(x).view(self.batch_size, -1, self.h, self.q).transpose(1, 2)
      key = self.W_k(x).view(self.batch_size, -1, self.h, self.q).transpose(1, 2)
      value = self.W_v(x).view(self.batch_size, -1, self.h, self.q).transpose(1, 2)
      local_attention = attention(query, key, value, self.sqrt_q, device)

      # calculate knn attention over memory
      mem_kv, mem_mask = knn_memory.search(query, self.num_retrieved_memories)
      mem_key, mem_value = mem_kv.unbind(dim = -2)
      knn_attention = KNNattention(query, mem_key, mem_value, self.sqrt_q, ~mem_mask)

      # memory to be stored
      new_kv_memories = torch.stack((key, value), dim = -2).view(self.batch_size, -1, 2, self.q).detach()

      # add to knn memory
      if new_kv_memories.numel() > 0:
        knn_memory.add(new_kv_memories)

      # combining local and memory
      g = torch.sigmoid(self.b_g)
      final_attention = torch.einsum('b h n d, h -> b h n d', knn_attention, g) + \
                        torch.einsum('b h n d, h -> b h n d', local_attention, (1 - g))
      
      return self.W_o(final_attention.transpose(1, 2).contiguous().view(self.batch_size, -1, self.h*self.q)), new_kv_memories

In [10]:
class SubLayer(nn.Module):
  def __init__(self, d, dropout, hidden_size):
    super(SubLayer, self).__init__()
    self.norm = nn.LayerNorm(d)
    self.mlp = nn.Sequential(nn.Linear(d, hidden_size, bias = True), 
                             nn.ReLU(),
                             nn.Dropout(dropout),
                             nn.Linear(hidden_size, d, bias = True))

  def forward(self, x):
    return x + self.mlp(self.norm(x)) #residual connection and normalization

In [11]:
class PositionalEncoding(nn.Module):
  def __init__(self, d_model, max_len=5000):
    super(PositionalEncoding, self).__init__()
    
    # Compute the positional encodings once in log space.
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) *
                          -(math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    pe = pe.unsqueeze(0)
    self.register_buffer('pe', pe)
      
  def forward(self, x):
    return x + Variable(self.pe[:, :x.size(1)], requires_grad=False)

In [12]:
class MemorizingTransformer(nn.Module):
    def __init__(
          self,
          num_tokens,
          d,
          heads = 8,
          depth = 10,
          knn_attn_idx = 7,
          attn_dropout = 0.,
          hidden_size = 1000,
          dropout = 0.3,
          max_knn_memories = 1000,
          num_retrieved_memories = 8,
          knn_memories_directory = DEFAULT_KNN_MEMORY_MEMMAP_DIRECTORY,
          knn_memory_multiprocessing = False,
          batch_size = 16
      ):
          # asserts
          assert d % heads == 0
          assert knn_attn_idx < depth

          super(MemorizingTransformer, self).__init__()
          self.token_emb = nn.Embedding(num_tokens, d)
          self.positional_emb = PositionalEncoding(d, max_len = 5000)
          self.dim_head = d // heads
          self.d = d
          self.heads = heads
          self.knn_attn_idx = knn_attn_idx
          self.depth = depth
          self.attn_dropout = attn_dropout
          self.hidden_size = hidden_size
          self.dropout = dropout
          self.max_knn_memories = max_knn_memories
          self.num_retrieved_memories = num_retrieved_memories
          self.knn_memories_directory = knn_memories_directory
          self.knn_memory_multiprocessing =knn_memory_multiprocessing
          self.batch_size = batch_size

          self.layers = nn.ModuleList([])
          for idx in range(depth):
              attn = KNNAttention(num_tokens, d, heads, num_retrieved_memories, self.batch_size) \
                  if idx == knn_attn_idx else MultiHeadAttention(num_tokens, d, heads, self.batch_size)

              self.layers.append(nn.ModuleList([
                  attn,
                  SubLayer(d, dropout, hidden_size)
              ]))

          self.to_out = nn.Sequential(
               nn.LayerNorm(d),
               nn.Linear(d, num_tokens)
          )

          # knn memories init

          self.knn_mem_kwargs = dict(
              dim = self.dim_head,
              max_memories = self.max_knn_memories,
              multiprocessing = knn_memory_multiprocessing
          )
          
    def forward(
        self,
        x,
        knn_memory
    ):
        batch_size, seq_len, *_, device = *x.shape, x.device
        x = self.token_emb(x)
        x = self.positional_emb(x)

        for idx, (attn, sub_l) in enumerate(self.layers):
            
            #attention

            x, mem = attn(x, knn_memory, device) if self.knn_attn_idx == idx else attn(x, device)
      
            # normalization + feedforward + residual connection

            x = sub_l(x)

        return self.to_out(x).transpose(1, 2)

    
    def create_knn_memories(
          self,
          *,
          batch_size
      ):  
          return KNNMemoryList.create_memories(
              batch_size = batch_size,
              num_memory_layers = 1,
              memories_directory = self.knn_memories_directory
          )(**self.knn_mem_kwargs)
      
    @contextmanager
    def knn_memories_context(
        self,
        **kwargs
    ):
        knn_dir = Path(self.knn_memories_directory)
        knn_dir.mkdir(exist_ok = True, parents = True)
        lock = FileLock(str(knn_dir / 'mutex'))

        with lock:
            knn_memories = self.create_knn_memories(**kwargs)
            yield knn_memories
            knn_memories.cleanup()

    def clear_memory(self, x, token_id):
        """ clears the KNN memories based on if the batch row contains the specified token id """
        """ for auto-clearing KNN memories based on start and end of strings """

        clear_memory = (x == token_id).any(dim = -1)
        batch_indices, _ = clear_memory.nonzero(as_tuple = True)
        batch_indices_to_clear = batch_indices.tolist()

        if len(batch_indices_to_clear) == 0:
            return

        knn_memories.clear_memory(batch_indices_to_clear)

# Training

In [13]:
# constants

NUM_BATCHES = int(1e5)
BATCH_SIZE = 16
SEQ_LEN = 512
SEGMENTS = 5
HEADS = 8
DIM_HEAD = SEQ_LEN // HEADS

LEARNING_RATE = 2e-4
MAX_GRAD_CLIP_NORM = 0.5

EVAL_EVERY = 20
GENERATE_EVERY  = 500
GENERATE_LENGTH = 512
CHECKPOINT = 100

In [14]:
model = MemorizingTransformer(
    num_tokens = 256,
    d = SEQ_LEN,
    heads = HEADS,
    batch_size = BATCH_SIZE,
    num_retrieved_memories = 32
).cuda()

memory = KNNMemory(
    dim = DIM_HEAD,                   # dimension of key / values
    max_memories = 1000,       # maximum number of memories to keep (will throw out the oldest memories for now if it overfills)
    num_indices = BATCH_SIZE          # this should be equivalent to batch dimension, as each batch keeps track of its own memories, expiring when it sees a new document
)

# prepare enwik8 data

#Lorenzo
with gzip.open('/content/drive/MyDrive/Secondo Anno/Neural Networks/project/enwik8.gz') as file:
    X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
    print(X.shape)
    # number of samples to take
    n_samples = math.ceil(0.6*X.shape[0])
    # take the set uniformly at random
    data = resample(X, n_samples=n_samples, replace=False)
    trX, vaX = train_test_split(data, test_size=math.ceil(0.2*data.shape[0]))
    # trX, vaX = np.split(X, [int(90e6)])
    print(trX.shape)
    print(vaX.shape)
    # assert False
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
    # data = torch.from_numpy(X)
"""

#Luigi
with gzip.open('/content/drive/MyDrive/Colab Notebooks/enwik8.gz') as file:
    X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
"""

class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# # dataset and dataloader
# dataset = TextSamplerDataset(data, SEQ_LEN)
# # test_dataset = TextSamplerDataset(data_val, SEQ_LEN)

# data_size = dataset.__len__()
# # data_test_size = test_dataset.__len__()

# perc_data = 0.3
# valid_size=0.2
# indices = list(range(data_size))
# np.random.shuffle(indices)
# data_size = int(np.floor(data_size * 0.3))
# print(data_size)
# indices = indices[:data_size]

# split = int(np.floor(valid_size * data_size))
# train_idx, valid_idx = indices[split:], indices[:split]

# # define samplers for obtaining training and validation batches
# train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
# valid_sampler = torch.utils.data.SubsetRandomSampler(valid_idx)

# train_loader  = DataLoader(dataset, batch_size = BATCH_SIZE, sampler = train_sampler, drop_last = True)

# test_loader = DataLoader(dataset, batch_size = BATCH_SIZE, sampler =valid_sampler, drop_last = True)


train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
train_loader  = DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True)
test_dataset = TextSamplerDataset(data_val, SEQ_LEN)
test_loader = DataLoader(test_dataset, batch_size = BATCH_SIZE, drop_last = True)

  X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)


(95000000,)
(45600000,)
(11400000,)


In [15]:
def print_string(a):
  seq = ""
  for word in a:
    for letter in word:
      seq += chr(letter)
    seq += " "
  return seq

In [16]:
# optimizer

optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)
loss = nn.CrossEntropyLoss()

# training

perplexity_list = []

for i, data in enumerate(tqdm.tqdm(train_loader, desc = 'training')):
    model.train()

    train_loss = 0.
    # with model.knn_memories_context(batch_size = BATCH_SIZE) as knn_memories:
        
    seq, labels = data[:, :-1], data[:, 1:] #the labels are the same sequences shifted by one

    out = model(
          seq,
          knn_memory = memory
    )
    #loss_item = torch.exp(loss(out, labels)) #perplexity
    loss_item = loss(out, labels)
    train_loss += loss_item
    loss_item.backward() 

    print(f'training loss: {train_loss}', flush = True)
    torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_CLIP_NORM)
    optimizer.step()
    optimizer.zero_grad()

    if i % EVAL_EVERY == 0:
        model.eval()
      
        test_data = None
        for test_data in test_loader:
          break

        test_loss = 0.

        # with torch.no_grad(), model.knn_memories_context(batch_size = BATCH_SIZE) as knn_memories: 
        seq, labels = test_data[:, :-1], test_data[:, 1:]
        
        out = model(
          seq,
          knn_memory = memory
        )

        loss_item = loss(out, labels)
        test_loss +=  loss_item
            

        print(f'valid loss: {test_loss}', flush = True)
        print(f'perplexity: {torch.exp(test_loss)}', flush = True)
        perplexity_list.append(torch.exp(test_loss).to('cpu').item())
    
    if i % CHECKPOINT == 0:
      torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
      }, 'model_optimizer2.pt')
      #Lorenzo
      with open('/content/drive/MyDrive/Università/Magistrale/Secondo Anno/Neural Networks/project/perplexity_moreNN.npy', 'wb') as f:
        np.save(f, np.array(perplexity_list))

plt.plot(perplexity_list, label = "Memorizing Transformer Perplexity Plot")
plt.legend()
plt.show()

training:   0%|          | 0/5566 [00:00<?, ?it/s]

training loss: 5.720709800720215
valid loss: 4.837080001831055
perplexity: 126.10060119628906


training:   0%|          | 1/5566 [00:14<21:52:22, 14.15s/it]

training loss: 4.910055637359619


training:   0%|          | 2/5566 [00:17<12:25:24,  8.04s/it]

training loss: 4.385214328765869


training:   0%|          | 3/5566 [00:21<9:20:24,  6.04s/it] 

training loss: 4.145002365112305


training:   0%|          | 4/5566 [00:24<7:15:27,  4.70s/it]

training loss: 4.031399726867676


training:   0%|          | 5/5566 [00:27<6:24:20,  4.15s/it]

training loss: 3.911895751953125


training:   0%|          | 6/5566 [00:30<5:38:38,  3.65s/it]

training loss: 3.833533525466919


training:   0%|          | 7/5566 [00:33<5:30:26,  3.57s/it]

training loss: 3.790104866027832


training:   0%|          | 8/5566 [00:37<5:32:16,  3.59s/it]

training loss: 3.7582852840423584


training:   0%|          | 9/5566 [00:41<5:51:28,  3.79s/it]

training loss: 3.7120096683502197


training:   0%|          | 10/5566 [00:43<5:09:32,  3.34s/it]

training loss: 3.6956899166107178


training:   0%|          | 11/5566 [00:46<4:57:21,  3.21s/it]

training loss: 3.660554885864258


training:   0%|          | 12/5566 [00:48<4:32:50,  2.95s/it]

training loss: 3.634036064147949


training:   0%|          | 13/5566 [00:51<4:30:50,  2.93s/it]

training loss: 3.6276867389678955


training:   0%|          | 14/5566 [00:54<4:13:58,  2.74s/it]

training loss: 3.591150999069214


training:   0%|          | 15/5566 [00:56<4:16:59,  2.78s/it]

training loss: 3.571593761444092


training:   0%|          | 16/5566 [00:59<4:04:57,  2.65s/it]

training loss: 3.5715839862823486


training:   0%|          | 17/5566 [01:02<4:10:45,  2.71s/it]

training loss: 3.5801613330841064


training:   0%|          | 18/5566 [01:04<4:00:38,  2.60s/it]

training loss: 3.5886549949645996


training:   0%|          | 19/5566 [01:07<4:06:38,  2.67s/it]

training loss: 3.602724075317383


training:   0%|          | 20/5566 [01:09<3:57:13,  2.57s/it]

training loss: 3.557241678237915
valid loss: 3.528785467147827
perplexity: 34.082550048828125


training:   0%|          | 21/5566 [01:14<5:01:08,  3.26s/it]

training loss: 3.5490670204162598


training:   0%|          | 22/5566 [01:17<4:47:23,  3.11s/it]

training loss: 3.571744918823242


training:   0%|          | 23/5566 [01:19<4:28:20,  2.90s/it]

training loss: 3.570699691772461


training:   0%|          | 24/5566 [01:22<4:25:57,  2.88s/it]

training loss: 3.5438404083251953


training:   0%|          | 25/5566 [01:25<4:23:17,  2.85s/it]

training loss: 3.5474355220794678


training:   0%|          | 26/5566 [01:28<4:43:43,  3.07s/it]

training loss: 3.5350382328033447


training:   0%|          | 27/5566 [01:31<4:22:39,  2.85s/it]

training loss: 3.5216031074523926


training:   1%|          | 28/5566 [01:35<4:58:38,  3.24s/it]

training loss: 3.5443780422210693


training:   1%|          | 29/5566 [01:38<4:58:24,  3.23s/it]

training loss: 3.5513899326324463


training:   1%|          | 30/5566 [01:41<4:46:44,  3.11s/it]

training loss: 3.5256199836730957


training:   1%|          | 31/5566 [01:43<4:26:47,  2.89s/it]

training loss: 3.550621509552002


training:   1%|          | 32/5566 [01:46<4:24:38,  2.87s/it]

training loss: 3.5123114585876465


training:   1%|          | 33/5566 [01:48<4:09:39,  2.71s/it]

training loss: 3.5502259731292725


training:   1%|          | 34/5566 [01:51<4:15:00,  2.77s/it]

training loss: 3.5253894329071045


training:   1%|          | 35/5566 [01:54<4:14:43,  2.76s/it]

training loss: 3.50290584564209


training:   1%|          | 36/5566 [01:57<4:16:41,  2.79s/it]

training loss: 3.510462760925293


training:   1%|          | 37/5566 [01:59<4:05:47,  2.67s/it]

training loss: 3.4961280822753906


training:   1%|          | 38/5566 [02:02<4:09:58,  2.71s/it]

training loss: 3.5411016941070557


training:   1%|          | 39/5566 [02:05<3:59:47,  2.60s/it]

training loss: 3.5038442611694336


training:   1%|          | 40/5566 [02:07<4:04:22,  2.65s/it]

training loss: 3.5180394649505615
valid loss: 3.548358678817749
perplexity: 34.7562255859375


training:   1%|          | 41/5566 [02:12<5:01:10,  3.27s/it]

training loss: 3.5439422130584717


training:   1%|          | 42/5566 [02:14<4:33:55,  2.98s/it]

training loss: 3.5127997398376465


training:   1%|          | 43/5566 [02:17<4:26:36,  2.90s/it]

training loss: 3.515899181365967


training:   1%|          | 44/5566 [02:20<4:38:46,  3.03s/it]

training loss: 3.538006067276001


training:   1%|          | 45/5566 [02:23<4:41:05,  3.05s/it]

training loss: 3.540938377380371


training:   1%|          | 46/5566 [02:26<4:21:39,  2.84s/it]

training loss: 3.499959707260132


training:   1%|          | 47/5566 [02:29<4:20:44,  2.83s/it]

training loss: 3.5109636783599854


training:   1%|          | 48/5566 [02:31<4:06:30,  2.68s/it]

training loss: 3.5505714416503906


training:   1%|          | 49/5566 [02:34<4:08:54,  2.71s/it]

training loss: 3.520963430404663


training:   1%|          | 50/5566 [02:36<3:59:07,  2.60s/it]

training loss: 3.544567584991455


training:   1%|          | 51/5566 [02:39<4:04:33,  2.66s/it]

training loss: 3.5135903358459473


training:   1%|          | 52/5566 [02:41<3:55:56,  2.57s/it]

training loss: 3.5337581634521484


training:   1%|          | 53/5566 [02:44<4:02:30,  2.64s/it]

training loss: 3.533581018447876


training:   1%|          | 54/5566 [02:46<3:54:47,  2.56s/it]

training loss: 3.516953706741333


training:   1%|          | 55/5566 [02:49<4:00:16,  2.62s/it]

training loss: 3.532263994216919


training:   1%|          | 56/5566 [02:52<4:18:00,  2.81s/it]

training loss: 3.5293478965759277


training:   1%|          | 57/5566 [02:55<4:22:50,  2.86s/it]

training loss: 3.5435898303985596


training:   1%|          | 58/5566 [02:58<4:08:21,  2.71s/it]

training loss: 3.5384974479675293


training:   1%|          | 59/5566 [03:00<4:09:09,  2.71s/it]

training loss: 3.512162685394287


training:   1%|          | 60/5566 [03:03<3:59:01,  2.60s/it]

training loss: 3.5339903831481934
valid loss: 3.5419015884399414
perplexity: 34.53252410888672


training:   1%|          | 61/5566 [03:08<5:01:02,  3.28s/it]

training loss: 3.5238215923309326


training:   1%|          | 62/5566 [03:11<4:48:04,  3.14s/it]

training loss: 3.5256857872009277


training:   1%|          | 63/5566 [03:13<4:26:35,  2.91s/it]

training loss: 3.5246102809906006


training:   1%|          | 64/5566 [03:16<4:23:05,  2.87s/it]

training loss: 3.544344663619995


training:   1%|          | 65/5566 [03:18<4:08:12,  2.71s/it]

training loss: 3.5157201290130615


training:   1%|          | 66/5566 [03:21<4:11:22,  2.74s/it]

training loss: 3.5273122787475586


training:   1%|          | 67/5566 [03:23<4:01:48,  2.64s/it]

training loss: 3.5078279972076416


training:   1%|          | 68/5566 [03:26<4:06:11,  2.69s/it]

training loss: 3.5220227241516113


training:   1%|          | 69/5566 [03:28<3:59:01,  2.61s/it]

training loss: 3.5421197414398193


training:   1%|▏         | 70/5566 [03:31<4:04:01,  2.66s/it]

training loss: 3.528238296508789


training:   1%|▏         | 71/5566 [03:34<3:55:45,  2.57s/it]

training loss: 3.530125141143799


training:   1%|▏         | 72/5566 [03:36<4:02:03,  2.64s/it]

training loss: 3.512470245361328


training:   1%|▏         | 73/5566 [03:39<3:54:54,  2.57s/it]

training loss: 3.53590989112854


training:   1%|▏         | 74/5566 [03:42<4:02:50,  2.65s/it]

training loss: 3.5292563438415527


training:   1%|▏         | 75/5566 [03:44<3:55:19,  2.57s/it]

training loss: 3.5203263759613037


training:   1%|▏         | 76/5566 [03:47<4:02:29,  2.65s/it]

training loss: 3.5194859504699707


training:   1%|▏         | 77/5566 [03:49<3:53:42,  2.55s/it]

training loss: 3.5161848068237305


training:   1%|▏         | 78/5566 [03:52<4:03:26,  2.66s/it]

training loss: 3.5199434757232666


training:   1%|▏         | 79/5566 [03:54<3:54:41,  2.57s/it]

training loss: 3.524299383163452


training:   1%|▏         | 80/5566 [03:57<4:00:47,  2.63s/it]

training loss: 3.505640745162964
valid loss: 3.5162339210510254
perplexity: 33.657432556152344


training:   1%|▏         | 81/5566 [04:03<5:16:11,  3.46s/it]

training loss: 3.5187454223632812


training:   1%|▏         | 82/5566 [04:05<4:43:55,  3.11s/it]

training loss: 3.5456743240356445


training:   1%|▏         | 83/5566 [04:08<4:32:49,  2.99s/it]

training loss: 3.528416872024536


training:   2%|▏         | 84/5566 [04:10<4:15:55,  2.80s/it]

training loss: 3.5452492237091064


training:   2%|▏         | 85/5566 [04:13<4:17:12,  2.82s/it]

training loss: 3.518573045730591


training:   2%|▏         | 86/5566 [04:15<4:05:07,  2.68s/it]

training loss: 3.531587839126587


training:   2%|▏         | 87/5566 [04:18<4:08:42,  2.72s/it]

training loss: 3.474480390548706


training:   2%|▏         | 88/5566 [04:20<3:57:57,  2.61s/it]

training loss: 3.5349104404449463


training:   2%|▏         | 89/5566 [04:23<4:07:28,  2.71s/it]

training loss: 3.52778697013855


training:   2%|▏         | 90/5566 [04:26<3:58:54,  2.62s/it]

training loss: 3.536407232284546


training:   2%|▏         | 91/5566 [04:29<4:04:48,  2.68s/it]

training loss: 3.4993064403533936


training:   2%|▏         | 92/5566 [04:31<3:55:28,  2.58s/it]

training loss: 3.5082271099090576


training:   2%|▏         | 93/5566 [04:34<4:02:35,  2.66s/it]

training loss: 3.5222575664520264


training:   2%|▏         | 94/5566 [04:36<3:54:25,  2.57s/it]

training loss: 3.5194621086120605


training:   2%|▏         | 95/5566 [04:39<4:00:46,  2.64s/it]

training loss: 3.5073182582855225


training:   2%|▏         | 96/5566 [04:41<3:53:17,  2.56s/it]

training loss: 3.529202461242676


training:   2%|▏         | 97/5566 [04:44<3:59:58,  2.63s/it]

training loss: 3.5119550228118896


training:   2%|▏         | 98/5566 [04:46<3:52:37,  2.55s/it]

training loss: 3.491147041320801


training:   2%|▏         | 99/5566 [04:49<3:58:16,  2.62s/it]

training loss: 3.533970355987549


training:   2%|▏         | 100/5566 [04:52<3:51:04,  2.54s/it]

training loss: 3.5265252590179443
valid loss: 3.5313384532928467
perplexity: 34.16967010498047


training:   2%|▏         | 101/5566 [04:58<5:48:53,  3.83s/it]

training loss: 3.540898084640503


training:   2%|▏         | 102/5566 [05:01<5:23:47,  3.56s/it]

training loss: 3.527052402496338


training:   2%|▏         | 103/5566 [05:04<4:52:32,  3.21s/it]

training loss: 3.515449285507202


training:   2%|▏         | 104/5566 [05:07<4:40:52,  3.09s/it]

training loss: 3.4996683597564697


training:   2%|▏         | 105/5566 [05:09<4:20:23,  2.86s/it]

training loss: 3.5627918243408203


training:   2%|▏         | 106/5566 [05:12<4:17:40,  2.83s/it]

training loss: 3.5148916244506836


training:   2%|▏         | 107/5566 [05:14<4:04:26,  2.69s/it]

training loss: 3.5090038776397705


training:   2%|▏         | 108/5566 [05:17<4:06:26,  2.71s/it]

training loss: 3.5553359985351562


training:   2%|▏         | 109/5566 [05:19<3:56:15,  2.60s/it]

training loss: 3.5236992835998535


training:   2%|▏         | 110/5566 [05:22<3:59:51,  2.64s/it]

training loss: 3.501585006713867


training:   2%|▏         | 111/5566 [05:24<3:53:05,  2.56s/it]

training loss: 3.5327062606811523


training:   2%|▏         | 112/5566 [05:27<3:57:40,  2.61s/it]

training loss: 3.540707588195801


training:   2%|▏         | 113/5566 [05:29<3:50:20,  2.53s/it]

training loss: 3.5290024280548096


training:   2%|▏         | 114/5566 [05:32<3:56:25,  2.60s/it]

training loss: 3.5192201137542725


training:   2%|▏         | 115/5566 [05:34<3:50:19,  2.54s/it]

training loss: 3.5207016468048096


training:   2%|▏         | 116/5566 [05:37<3:56:41,  2.61s/it]

training loss: 3.5437614917755127


training:   2%|▏         | 117/5566 [05:40<3:50:19,  2.54s/it]

training loss: 3.510849952697754


training:   2%|▏         | 118/5566 [05:42<3:56:46,  2.61s/it]

training loss: 3.542234420776367


training:   2%|▏         | 119/5566 [05:45<3:50:03,  2.53s/it]

training loss: 3.522778034210205


training:   2%|▏         | 120/5566 [05:47<3:57:38,  2.62s/it]

training loss: 3.5352301597595215
valid loss: 3.5318331718444824
perplexity: 34.186580657958984


training:   2%|▏         | 121/5566 [05:52<4:55:34,  3.26s/it]

training loss: 3.5123212337493896


training:   2%|▏         | 122/5566 [05:55<4:32:11,  3.00s/it]

training loss: 3.519204616546631


training:   2%|▏         | 123/5566 [05:57<4:24:13,  2.91s/it]

training loss: 3.5341858863830566


training:   2%|▏         | 124/5566 [06:00<4:08:53,  2.74s/it]

training loss: 3.5357844829559326


training:   2%|▏         | 125/5566 [06:02<4:10:21,  2.76s/it]

training loss: 3.507946491241455


training:   2%|▏         | 126/5566 [06:05<3:58:43,  2.63s/it]

training loss: 3.547760486602783


training:   2%|▏         | 127/5566 [06:08<4:01:47,  2.67s/it]

training loss: 3.48419451713562


training:   2%|▏         | 128/5566 [06:10<3:52:56,  2.57s/it]

training loss: 3.526498794555664


training:   2%|▏         | 129/5566 [06:13<3:58:03,  2.63s/it]

training loss: 3.5340495109558105


training:   2%|▏         | 130/5566 [06:15<3:50:39,  2.55s/it]

training loss: 3.5018508434295654


training:   2%|▏         | 131/5566 [06:18<3:57:33,  2.62s/it]

training loss: 3.514620304107666


training:   2%|▏         | 132/5566 [06:20<3:50:23,  2.54s/it]

training loss: 3.5127899646759033


training:   2%|▏         | 133/5566 [06:23<3:57:41,  2.62s/it]

training loss: 3.525838851928711


training:   2%|▏         | 134/5566 [06:25<3:51:04,  2.55s/it]

training loss: 3.5340142250061035


training:   2%|▏         | 135/5566 [06:28<3:57:47,  2.63s/it]

training loss: 3.511503219604492


training:   2%|▏         | 136/5566 [06:31<3:50:32,  2.55s/it]

training loss: 3.5265512466430664


training:   2%|▏         | 137/5566 [06:33<3:54:59,  2.60s/it]

training loss: 3.5374155044555664


training:   2%|▏         | 138/5566 [06:36<3:48:28,  2.53s/it]

training loss: 3.502375602722168


training:   2%|▏         | 139/5566 [06:38<3:56:24,  2.61s/it]

training loss: 3.550943374633789


training:   3%|▎         | 140/5566 [06:41<3:49:03,  2.53s/it]

training loss: 3.511343240737915
valid loss: 3.518869161605835
perplexity: 33.746246337890625


training:   3%|▎         | 141/5566 [06:46<4:51:45,  3.23s/it]

training loss: 3.5420947074890137


training:   3%|▎         | 142/5566 [06:48<4:39:17,  3.09s/it]

training loss: 3.5247724056243896


training:   3%|▎         | 143/5566 [06:51<4:18:43,  2.86s/it]

training loss: 3.5211341381073


training:   3%|▎         | 144/5566 [06:54<4:15:58,  2.83s/it]

training loss: 3.5220305919647217


training:   3%|▎         | 145/5566 [06:56<4:04:41,  2.71s/it]

training loss: 3.54026460647583


training:   3%|▎         | 146/5566 [06:59<4:06:45,  2.73s/it]

training loss: 3.5313520431518555


training:   3%|▎         | 147/5566 [07:01<3:57:29,  2.63s/it]

training loss: 3.5458288192749023


training:   3%|▎         | 148/5566 [07:04<4:00:57,  2.67s/it]

training loss: 3.512416124343872


training:   3%|▎         | 149/5566 [07:06<3:53:13,  2.58s/it]

training loss: 3.5274229049682617


training:   3%|▎         | 150/5566 [07:09<3:59:33,  2.65s/it]

training loss: 3.5381860733032227


training:   3%|▎         | 151/5566 [07:11<3:50:51,  2.56s/it]

training loss: 3.527632713317871


training:   3%|▎         | 152/5566 [07:14<3:57:09,  2.63s/it]

training loss: 3.528085947036743


training:   3%|▎         | 153/5566 [07:17<3:50:24,  2.55s/it]

training loss: 3.5373306274414062


training:   3%|▎         | 154/5566 [07:19<3:56:49,  2.63s/it]

training loss: 3.5439460277557373


training:   3%|▎         | 155/5566 [07:22<3:49:30,  2.54s/it]

training loss: 3.528449773788452


training:   3%|▎         | 156/5566 [07:25<3:57:42,  2.64s/it]

training loss: 3.517338991165161


training:   3%|▎         | 157/5566 [07:27<3:50:05,  2.55s/it]

training loss: 3.5315675735473633


training:   3%|▎         | 158/5566 [07:30<3:55:54,  2.62s/it]

training loss: 3.5029890537261963


training:   3%|▎         | 159/5566 [07:32<3:47:51,  2.53s/it]

training loss: 3.5343704223632812


training:   3%|▎         | 160/5566 [07:36<4:14:04,  2.82s/it]

training loss: 3.526670455932617
valid loss: 3.54310941696167
perplexity: 34.57426071166992


training:   3%|▎         | 161/5566 [07:41<5:13:47,  3.48s/it]

training loss: 3.526201009750366


training:   3%|▎         | 162/5566 [07:43<4:41:04,  3.12s/it]

training loss: 3.5254783630371094


training:   3%|▎         | 163/5566 [07:46<4:29:56,  3.00s/it]

training loss: 3.539426326751709


training:   3%|▎         | 164/5566 [07:48<4:12:21,  2.80s/it]

training loss: 3.5176146030426025


training:   3%|▎         | 165/5566 [07:51<4:11:21,  2.79s/it]

training loss: 3.5165741443634033


training:   3%|▎         | 166/5566 [07:53<3:58:40,  2.65s/it]

training loss: 3.515099048614502


training:   3%|▎         | 167/5566 [07:56<4:02:41,  2.70s/it]

training loss: 3.5260512828826904


training:   3%|▎         | 168/5566 [07:58<3:53:46,  2.60s/it]

training loss: 3.50551438331604


training:   3%|▎         | 169/5566 [08:01<3:58:36,  2.65s/it]

training loss: 3.525132179260254


training:   3%|▎         | 170/5566 [08:03<3:51:14,  2.57s/it]

training loss: 3.5179455280303955


training:   3%|▎         | 171/5566 [08:06<3:56:29,  2.63s/it]

training loss: 3.5295677185058594


training:   3%|▎         | 172/5566 [08:08<3:49:10,  2.55s/it]

training loss: 3.534668445587158


training:   3%|▎         | 173/5566 [08:11<3:56:53,  2.64s/it]

training loss: 3.5051610469818115


training:   3%|▎         | 174/5566 [08:14<3:58:43,  2.66s/it]

training loss: 3.5211894512176514


training:   3%|▎         | 175/5566 [08:17<4:00:42,  2.68s/it]

training loss: 3.4989681243896484


training:   3%|▎         | 176/5566 [08:19<3:51:45,  2.58s/it]

training loss: 3.5203206539154053


training:   3%|▎         | 177/5566 [08:22<3:57:14,  2.64s/it]

training loss: 3.5196213722229004


training:   3%|▎         | 178/5566 [08:24<3:48:52,  2.55s/it]

training loss: 3.543381929397583


training:   3%|▎         | 179/5566 [08:27<3:54:59,  2.62s/it]

training loss: 3.5160303115844727


training:   3%|▎         | 180/5566 [08:29<3:50:09,  2.56s/it]

training loss: 3.5039467811584473
valid loss: 3.516507148742676
perplexity: 33.666629791259766


training:   3%|▎         | 181/5566 [08:34<4:52:03,  3.25s/it]

training loss: 3.531252384185791


training:   3%|▎         | 182/5566 [08:37<4:40:01,  3.12s/it]

training loss: 3.5152649879455566


training:   3%|▎         | 183/5566 [08:39<4:19:29,  2.89s/it]

training loss: 3.530245304107666


training:   3%|▎         | 184/5566 [08:42<4:15:58,  2.85s/it]

training loss: 3.5259900093078613


training:   3%|▎         | 185/5566 [08:45<4:02:05,  2.70s/it]

training loss: 3.52459454536438


training:   3%|▎         | 186/5566 [08:47<4:03:44,  2.72s/it]

training loss: 3.5146069526672363


training:   3%|▎         | 187/5566 [08:50<3:54:06,  2.61s/it]

training loss: 3.526050329208374


training:   3%|▎         | 188/5566 [08:52<3:57:44,  2.65s/it]

training loss: 3.4979944229125977


training:   3%|▎         | 189/5566 [08:55<3:49:39,  2.56s/it]

training loss: 3.545095205307007


training:   3%|▎         | 190/5566 [08:58<3:55:10,  2.62s/it]

training loss: 3.519660472869873


training:   3%|▎         | 191/5566 [09:00<3:49:19,  2.56s/it]

training loss: 3.5273818969726562


training:   3%|▎         | 192/5566 [09:03<3:55:55,  2.63s/it]

training loss: 3.5508480072021484


training:   3%|▎         | 193/5566 [09:05<3:48:59,  2.56s/it]

training loss: 3.523473024368286


training:   3%|▎         | 194/5566 [09:08<3:54:38,  2.62s/it]

training loss: 3.565340518951416


training:   4%|▎         | 195/5566 [09:10<3:48:05,  2.55s/it]

training loss: 3.526916980743408


training:   4%|▎         | 196/5566 [09:13<3:53:30,  2.61s/it]

training loss: 3.525521755218506


training:   4%|▎         | 197/5566 [09:15<3:46:41,  2.53s/it]

training loss: 3.5378098487854004


training:   4%|▎         | 198/5566 [09:18<3:52:53,  2.60s/it]

training loss: 3.5369911193847656


training:   4%|▎         | 199/5566 [09:21<3:46:26,  2.53s/it]

training loss: 3.5374412536621094


training:   4%|▎         | 200/5566 [09:23<3:52:06,  2.60s/it]

training loss: 3.527717113494873
valid loss: 3.505425214767456
perplexity: 33.29560089111328


training:   4%|▎         | 201/5566 [09:29<5:11:40,  3.49s/it]

training loss: 3.5176501274108887


training:   4%|▎         | 202/5566 [09:31<4:40:46,  3.14s/it]

training loss: 3.532334089279175


training:   4%|▎         | 203/5566 [09:34<4:27:38,  2.99s/it]

training loss: 3.531571388244629


training:   4%|▎         | 204/5566 [09:36<4:10:34,  2.80s/it]

training loss: 3.5083959102630615


training:   4%|▎         | 205/5566 [09:39<4:09:26,  2.79s/it]

training loss: 3.50840425491333


training:   4%|▎         | 206/5566 [09:41<3:57:13,  2.66s/it]

training loss: 3.5237607955932617


training:   4%|▎         | 207/5566 [09:44<3:59:11,  2.68s/it]

training loss: 3.5260305404663086


training:   4%|▎         | 208/5566 [09:46<3:50:59,  2.59s/it]

training loss: 3.5287909507751465


training:   4%|▍         | 209/5566 [09:49<3:55:13,  2.63s/it]

training loss: 3.5199601650238037


training:   4%|▍         | 210/5566 [09:51<3:48:15,  2.56s/it]

training loss: 3.543588876724243


training:   4%|▍         | 211/5566 [09:54<3:54:20,  2.63s/it]

training loss: 3.548218250274658


training:   4%|▍         | 212/5566 [09:57<3:47:12,  2.55s/it]

training loss: 3.5394508838653564


training:   4%|▍         | 213/5566 [09:59<3:54:52,  2.63s/it]

training loss: 3.5057411193847656


training:   4%|▍         | 214/5566 [10:02<3:47:18,  2.55s/it]

training loss: 3.5253047943115234


training:   4%|▍         | 215/5566 [10:05<3:53:39,  2.62s/it]

training loss: 3.506970167160034


training:   4%|▍         | 216/5566 [10:07<3:46:30,  2.54s/it]

training loss: 3.5391910076141357


training:   4%|▍         | 217/5566 [10:10<3:52:37,  2.61s/it]

training loss: 3.5303590297698975


training:   4%|▍         | 218/5566 [10:12<3:45:40,  2.53s/it]

training loss: 3.503173351287842


training:   4%|▍         | 219/5566 [10:15<3:52:30,  2.61s/it]

training loss: 3.515212059020996


training:   4%|▍         | 220/5566 [10:18<4:03:48,  2.74s/it]

training loss: 3.525360345840454
valid loss: 3.5137338638305664
perplexity: 33.573394775390625


training:   4%|▍         | 221/5566 [10:23<5:17:33,  3.56s/it]

training loss: 3.5284008979797363


training:   4%|▍         | 222/5566 [10:26<4:57:20,  3.34s/it]

training loss: 3.5380136966705322


training:   4%|▍         | 223/5566 [10:29<4:31:13,  3.05s/it]

training loss: 3.520019292831421


training:   4%|▍         | 224/5566 [10:31<4:25:17,  2.98s/it]

training loss: 3.516718864440918


training:   4%|▍         | 225/5566 [10:34<4:08:25,  2.79s/it]

training loss: 3.5381336212158203


training:   4%|▍         | 226/5566 [10:37<4:08:47,  2.80s/it]

training loss: 3.5349416732788086


training:   4%|▍         | 227/5566 [10:39<3:56:53,  2.66s/it]

training loss: 3.519331455230713


training:   4%|▍         | 228/5566 [10:42<4:00:19,  2.70s/it]

training loss: 3.5476691722869873


training:   4%|▍         | 229/5566 [10:44<3:50:25,  2.59s/it]

training loss: 3.529012441635132


training:   4%|▍         | 230/5566 [10:47<3:55:02,  2.64s/it]

training loss: 3.5208163261413574


training:   4%|▍         | 231/5566 [10:49<3:48:12,  2.57s/it]

training loss: 3.5255584716796875


training:   4%|▍         | 232/5566 [10:52<3:55:56,  2.65s/it]

training loss: 3.515810966491699


training:   4%|▍         | 233/5566 [10:54<3:49:14,  2.58s/it]

training loss: 3.5121564865112305


training:   4%|▍         | 234/5566 [10:57<3:55:10,  2.65s/it]

training loss: 3.511301040649414


training:   4%|▍         | 235/5566 [11:00<3:48:29,  2.57s/it]

training loss: 3.5258750915527344


training:   4%|▍         | 236/5566 [11:03<3:56:55,  2.67s/it]

training loss: 3.5495336055755615


training:   4%|▍         | 237/5566 [11:05<3:48:01,  2.57s/it]

training loss: 3.521681070327759


training:   4%|▍         | 238/5566 [11:08<3:53:58,  2.63s/it]

training loss: 3.5300114154815674


training:   4%|▍         | 239/5566 [11:10<3:47:11,  2.56s/it]

training loss: 3.5167863368988037


training:   4%|▍         | 240/5566 [11:13<3:52:45,  2.62s/it]

training loss: 3.5346460342407227
valid loss: 3.5212645530700684
perplexity: 33.827178955078125


training:   4%|▍         | 241/5566 [11:18<4:48:42,  3.25s/it]

training loss: 3.522596597671509


training:   4%|▍         | 242/5566 [11:20<4:22:32,  2.96s/it]

training loss: 3.5020534992218018


training:   4%|▍         | 243/5566 [11:23<4:15:55,  2.88s/it]

training loss: 3.546557664871216


training:   4%|▍         | 244/5566 [11:25<4:02:45,  2.74s/it]

training loss: 3.5165083408355713


training:   4%|▍         | 245/5566 [11:28<4:03:47,  2.75s/it]

training loss: 3.5244741439819336


training:   4%|▍         | 246/5566 [11:30<3:54:57,  2.65s/it]

training loss: 3.5215401649475098


training:   4%|▍         | 247/5566 [11:33<4:00:14,  2.71s/it]

training loss: 3.504451274871826


training:   4%|▍         | 248/5566 [11:35<3:50:49,  2.60s/it]

training loss: 3.5156428813934326


training:   4%|▍         | 249/5566 [11:38<3:54:57,  2.65s/it]

training loss: 3.5272982120513916


training:   4%|▍         | 250/5566 [11:40<3:47:34,  2.57s/it]

training loss: 3.5305447578430176


training:   5%|▍         | 251/5566 [11:43<3:53:58,  2.64s/it]

training loss: 3.516709327697754


training:   5%|▍         | 252/5566 [11:46<3:47:20,  2.57s/it]

training loss: 3.550971746444702


training:   5%|▍         | 253/5566 [11:48<3:53:58,  2.64s/it]

training loss: 3.5239715576171875


training:   5%|▍         | 254/5566 [11:51<3:46:15,  2.56s/it]

training loss: 3.5255038738250732


training:   5%|▍         | 255/5566 [11:54<3:52:38,  2.63s/it]

training loss: 3.532017946243286


training:   5%|▍         | 256/5566 [11:56<3:45:14,  2.55s/it]

training loss: 3.5296428203582764


training:   5%|▍         | 257/5566 [11:59<3:52:13,  2.62s/it]

training loss: 3.523723602294922


training:   5%|▍         | 258/5566 [12:01<3:47:31,  2.57s/it]

training loss: 3.523699998855591


training:   5%|▍         | 259/5566 [12:04<3:53:59,  2.65s/it]

training loss: 3.52539324760437


training:   5%|▍         | 260/5566 [12:06<3:46:10,  2.56s/it]

training loss: 3.5268146991729736
valid loss: 3.5176188945770264
perplexity: 33.704078674316406


training:   5%|▍         | 261/5566 [12:11<4:47:04,  3.25s/it]

training loss: 3.542511463165283


training:   5%|▍         | 262/5566 [12:14<4:34:46,  3.11s/it]

training loss: 3.5228095054626465


training:   5%|▍         | 263/5566 [12:16<4:15:05,  2.89s/it]

training loss: 3.5128769874572754


training:   5%|▍         | 264/5566 [12:19<4:11:21,  2.84s/it]

training loss: 3.5334041118621826


training:   5%|▍         | 265/5566 [12:22<3:59:18,  2.71s/it]

training loss: 3.506409168243408


training:   5%|▍         | 266/5566 [12:24<3:59:25,  2.71s/it]

training loss: 3.531296730041504


training:   5%|▍         | 267/5566 [12:27<3:50:23,  2.61s/it]

training loss: 3.504136800765991


training:   5%|▍         | 268/5566 [12:29<3:54:46,  2.66s/it]

training loss: 3.5170552730560303


training:   5%|▍         | 269/5566 [12:32<3:49:09,  2.60s/it]

training loss: 3.5241971015930176


training:   5%|▍         | 270/5566 [12:35<3:54:14,  2.65s/it]

training loss: 3.5497701168060303


training:   5%|▍         | 271/5566 [12:37<3:45:32,  2.56s/it]

training loss: 3.5210938453674316


training:   5%|▍         | 272/5566 [12:40<3:51:49,  2.63s/it]

training loss: 3.533083200454712


training:   5%|▍         | 273/5566 [12:42<3:44:00,  2.54s/it]

training loss: 3.516176700592041


training:   5%|▍         | 274/5566 [12:45<3:49:46,  2.61s/it]

training loss: 3.5150794982910156


training:   5%|▍         | 275/5566 [12:47<3:43:17,  2.53s/it]

training loss: 3.538898229598999


training:   5%|▍         | 276/5566 [12:50<3:49:06,  2.60s/it]

training loss: 3.5150363445281982


training:   5%|▍         | 277/5566 [12:52<3:42:44,  2.53s/it]

training loss: 3.5010368824005127


training:   5%|▍         | 278/5566 [12:55<3:50:21,  2.61s/it]

training loss: 3.5129783153533936


training:   5%|▌         | 279/5566 [12:58<3:44:08,  2.54s/it]

training loss: 3.550708055496216


training:   5%|▌         | 280/5566 [13:00<3:50:57,  2.62s/it]

training loss: 3.5250086784362793
valid loss: 3.506404399871826
perplexity: 33.328216552734375


training:   5%|▌         | 281/5566 [13:07<5:25:15,  3.69s/it]

training loss: 3.5012872219085693


training:   5%|▌         | 282/5566 [13:09<4:50:03,  3.29s/it]

training loss: 3.5329582691192627


training:   5%|▌         | 283/5566 [13:12<4:36:25,  3.14s/it]

training loss: 3.5262866020202637


training:   5%|▌         | 284/5566 [13:14<4:16:20,  2.91s/it]

training loss: 3.509516716003418


training:   5%|▌         | 285/5566 [13:17<4:12:44,  2.87s/it]

training loss: 3.5116918087005615


training:   5%|▌         | 286/5566 [13:19<3:59:19,  2.72s/it]

training loss: 3.5314853191375732


training:   5%|▌         | 287/5566 [13:22<4:00:55,  2.74s/it]

training loss: 3.50003981590271


training:   5%|▌         | 288/5566 [13:24<3:50:26,  2.62s/it]

training loss: 3.5216073989868164


training:   5%|▌         | 289/5566 [13:27<3:55:58,  2.68s/it]

training loss: 3.5119504928588867


training:   5%|▌         | 290/5566 [13:30<3:47:56,  2.59s/it]

training loss: 3.525113105773926


training:   5%|▌         | 291/5566 [13:32<3:56:42,  2.69s/it]

training loss: 3.532151699066162


training:   5%|▌         | 292/5566 [13:35<3:48:14,  2.60s/it]

training loss: 3.51466965675354


training:   5%|▌         | 293/5566 [13:38<3:53:42,  2.66s/it]

training loss: 3.5049641132354736


training:   5%|▌         | 294/5566 [13:40<3:45:59,  2.57s/it]

training loss: 3.544494152069092


training:   5%|▌         | 295/5566 [13:43<3:50:57,  2.63s/it]

training loss: 3.519364356994629


training:   5%|▌         | 296/5566 [13:45<3:44:09,  2.55s/it]

training loss: 3.508840322494507


training:   5%|▌         | 297/5566 [13:48<3:50:48,  2.63s/it]

training loss: 3.5179171562194824


training:   5%|▌         | 298/5566 [13:50<3:43:55,  2.55s/it]

training loss: 3.5208308696746826


training:   5%|▌         | 299/5566 [13:53<3:49:46,  2.62s/it]

training loss: 3.511155366897583


training:   5%|▌         | 300/5566 [13:55<3:43:41,  2.55s/it]

training loss: 3.5188236236572266
valid loss: 3.523702383041382
perplexity: 33.90974426269531


training:   5%|▌         | 301/5566 [14:01<5:03:27,  3.46s/it]

training loss: 3.5353031158447266


training:   5%|▌         | 302/5566 [14:04<4:52:17,  3.33s/it]

training loss: 3.539267063140869


training:   5%|▌         | 303/5566 [14:06<4:25:57,  3.03s/it]

training loss: 3.544179677963257


training:   5%|▌         | 304/5566 [14:09<4:19:22,  2.96s/it]

training loss: 3.5299293994903564


training:   5%|▌         | 305/5566 [14:12<4:04:10,  2.78s/it]

training loss: 3.5152857303619385


training:   5%|▌         | 306/5566 [14:14<4:03:33,  2.78s/it]

training loss: 3.5099072456359863


training:   6%|▌         | 307/5566 [14:17<3:52:11,  2.65s/it]

training loss: 3.5152902603149414


training:   6%|▌         | 308/5566 [14:20<3:56:34,  2.70s/it]

training loss: 3.528768539428711


training:   6%|▌         | 309/5566 [14:22<3:47:26,  2.60s/it]

training loss: 3.5112979412078857


training:   6%|▌         | 310/5566 [14:25<3:52:46,  2.66s/it]

training loss: 3.5126914978027344


training:   6%|▌         | 311/5566 [14:27<3:45:09,  2.57s/it]

training loss: 3.5255184173583984


training:   6%|▌         | 312/5566 [14:30<3:52:02,  2.65s/it]

training loss: 3.546659231185913


training:   6%|▌         | 313/5566 [14:32<3:44:28,  2.56s/it]

training loss: 3.5384507179260254


training:   6%|▌         | 314/5566 [14:35<3:51:03,  2.64s/it]

training loss: 3.5591635704040527


training:   6%|▌         | 315/5566 [14:37<3:43:48,  2.56s/it]

training loss: 3.5242233276367188


training:   6%|▌         | 316/5566 [14:40<3:50:21,  2.63s/it]

training loss: 3.490894317626953


training:   6%|▌         | 317/5566 [14:43<3:42:48,  2.55s/it]

training loss: 3.5109617710113525


training:   6%|▌         | 318/5566 [14:45<3:49:33,  2.62s/it]

training loss: 3.503485918045044


training:   6%|▌         | 319/5566 [14:48<3:42:59,  2.55s/it]

training loss: 3.536330223083496


training:   6%|▌         | 320/5566 [14:51<3:49:09,  2.62s/it]

training loss: 3.5219717025756836
valid loss: 3.5135085582733154
perplexity: 33.56583023071289


training:   6%|▌         | 321/5566 [14:55<4:46:31,  3.28s/it]

training loss: 3.539515972137451


training:   6%|▌         | 322/5566 [14:58<4:21:09,  2.99s/it]

training loss: 3.5227673053741455


training:   6%|▌         | 323/5566 [15:00<4:14:47,  2.92s/it]

training loss: 3.522395610809326


training:   6%|▌         | 324/5566 [15:03<4:01:07,  2.76s/it]

training loss: 3.5201499462127686


training:   6%|▌         | 325/5566 [15:06<4:01:50,  2.77s/it]

training loss: 3.520883560180664


training:   6%|▌         | 326/5566 [15:08<3:50:54,  2.64s/it]

training loss: 3.493387222290039


training:   6%|▌         | 327/5566 [15:11<3:54:10,  2.68s/it]

training loss: 3.51017689704895


training:   6%|▌         | 328/5566 [15:13<3:45:32,  2.58s/it]

training loss: 3.5241849422454834


training:   6%|▌         | 329/5566 [15:16<3:51:37,  2.65s/it]

training loss: 3.5252203941345215


training:   6%|▌         | 330/5566 [15:18<3:43:49,  2.56s/it]

training loss: 3.514580726623535


training:   6%|▌         | 331/5566 [15:21<3:50:21,  2.64s/it]

training loss: 3.549947738647461


training:   6%|▌         | 332/5566 [15:23<3:43:22,  2.56s/it]

training loss: 3.51802396774292


training:   6%|▌         | 333/5566 [15:26<3:48:35,  2.62s/it]

training loss: 3.554748296737671


training:   6%|▌         | 334/5566 [15:29<3:42:07,  2.55s/it]

training loss: 3.5256357192993164


training:   6%|▌         | 335/5566 [15:31<3:48:51,  2.63s/it]

training loss: 3.544724464416504


training:   6%|▌         | 336/5566 [15:34<3:44:37,  2.58s/it]

training loss: 3.5410726070404053


training:   6%|▌         | 337/5566 [15:37<3:51:24,  2.66s/it]

training loss: 3.5327744483947754


training:   6%|▌         | 338/5566 [15:39<3:44:29,  2.58s/it]

training loss: 3.5254015922546387


training:   6%|▌         | 339/5566 [15:42<3:50:22,  2.64s/it]

training loss: 3.5202465057373047


training:   6%|▌         | 340/5566 [15:44<3:42:09,  2.55s/it]

training loss: 3.5309739112854004
valid loss: 3.5082292556762695
perplexity: 33.38909149169922


training:   6%|▌         | 341/5566 [15:50<5:13:48,  3.60s/it]

training loss: 3.5406980514526367


training:   6%|▌         | 342/5566 [15:53<4:52:44,  3.36s/it]

training loss: 3.5260279178619385


training:   6%|▌         | 343/5566 [15:55<4:27:41,  3.08s/it]

training loss: 3.524021625518799


training:   6%|▌         | 344/5566 [15:58<4:21:44,  3.01s/it]

training loss: 3.5203983783721924


training:   6%|▌         | 345/5566 [16:01<4:04:50,  2.81s/it]

training loss: 3.53242564201355


training:   6%|▌         | 346/5566 [16:04<4:04:26,  2.81s/it]

training loss: 3.5227763652801514


training:   6%|▌         | 347/5566 [16:06<3:53:04,  2.68s/it]

training loss: 3.515712261199951


training:   6%|▋         | 348/5566 [16:09<3:56:02,  2.71s/it]

training loss: 3.5455269813537598


training:   6%|▋         | 349/5566 [16:11<3:46:50,  2.61s/it]

training loss: 3.5343129634857178


training:   6%|▋         | 350/5566 [16:14<3:51:34,  2.66s/it]

training loss: 3.516239881515503


training:   6%|▋         | 351/5566 [16:16<3:44:02,  2.58s/it]

training loss: 3.544826030731201


training:   6%|▋         | 352/5566 [16:19<3:49:13,  2.64s/it]

training loss: 3.5464115142822266


training:   6%|▋         | 353/5566 [16:21<3:41:46,  2.55s/it]

training loss: 3.5302393436431885


training:   6%|▋         | 354/5566 [16:24<3:48:57,  2.64s/it]

training loss: 3.5237510204315186


training:   6%|▋         | 355/5566 [16:27<3:41:24,  2.55s/it]

training loss: 3.5215680599212646


training:   6%|▋         | 356/5566 [16:29<3:47:08,  2.62s/it]

training loss: 3.535623550415039


training:   6%|▋         | 357/5566 [16:32<3:40:54,  2.54s/it]

training loss: 3.507612705230713


training:   6%|▋         | 358/5566 [16:35<3:49:09,  2.64s/it]

training loss: 3.5177695751190186


training:   6%|▋         | 359/5566 [16:37<3:42:58,  2.57s/it]

training loss: 3.5201869010925293


training:   6%|▋         | 360/5566 [16:40<3:48:34,  2.63s/it]

training loss: 3.5180511474609375
valid loss: 3.5132319927215576
perplexity: 33.556549072265625


training:   6%|▋         | 361/5566 [16:44<4:43:05,  3.26s/it]

training loss: 3.5421571731567383


training:   7%|▋         | 362/5566 [16:47<4:17:03,  2.96s/it]

training loss: 3.5290355682373047


training:   7%|▋         | 363/5566 [16:49<4:10:33,  2.89s/it]

training loss: 3.5268194675445557


training:   7%|▋         | 364/5566 [16:52<3:57:40,  2.74s/it]

training loss: 3.507648229598999


training:   7%|▋         | 365/5566 [16:55<4:01:04,  2.78s/it]

training loss: 3.495201826095581


training:   7%|▋         | 366/5566 [16:57<3:49:53,  2.65s/it]

training loss: 3.5187482833862305


training:   7%|▋         | 367/5566 [17:00<3:54:35,  2.71s/it]

training loss: 3.5293006896972656


training:   7%|▋         | 368/5566 [17:02<3:45:21,  2.60s/it]

training loss: 3.5435149669647217


training:   7%|▋         | 369/5566 [17:05<3:54:17,  2.70s/it]

training loss: 3.5260732173919678


training:   7%|▋         | 370/5566 [17:08<3:46:03,  2.61s/it]

training loss: 3.5357329845428467


training:   7%|▋         | 371/5566 [17:10<3:51:49,  2.68s/it]

training loss: 3.527036666870117


training:   7%|▋         | 372/5566 [17:13<3:44:27,  2.59s/it]

training loss: 3.504105567932129


training:   7%|▋         | 373/5566 [17:16<3:49:56,  2.66s/it]

training loss: 3.522382974624634


training:   7%|▋         | 374/5566 [17:18<3:42:05,  2.57s/it]

training loss: 3.508206844329834


training:   7%|▋         | 375/5566 [17:21<3:48:12,  2.64s/it]

training loss: 3.5337679386138916


training:   7%|▋         | 376/5566 [17:23<3:40:58,  2.55s/it]

training loss: 3.524688959121704


training:   7%|▋         | 377/5566 [17:26<3:46:17,  2.62s/it]

training loss: 3.522331476211548


training:   7%|▋         | 378/5566 [17:28<3:39:44,  2.54s/it]

training loss: 3.5247066020965576


training:   7%|▋         | 379/5566 [17:31<3:46:02,  2.61s/it]

training loss: 3.534961700439453


training:   7%|▋         | 380/5566 [17:33<3:39:19,  2.54s/it]

training loss: 3.5304372310638428
valid loss: 3.5407190322875977
perplexity: 34.4917106628418


training:   7%|▋         | 381/5566 [17:38<4:39:42,  3.24s/it]

training loss: 3.5493874549865723


training:   7%|▋         | 382/5566 [17:41<4:28:46,  3.11s/it]

training loss: 3.495544672012329


training:   7%|▋         | 383/5566 [17:43<4:08:56,  2.88s/it]

training loss: 3.5200915336608887


training:   7%|▋         | 384/5566 [17:46<4:06:35,  2.86s/it]

training loss: 3.52955961227417


training:   7%|▋         | 385/5566 [17:49<3:53:39,  2.71s/it]

training loss: 3.5387303829193115


training:   7%|▋         | 386/5566 [17:51<3:55:47,  2.73s/it]

training loss: 3.539379119873047


training:   7%|▋         | 387/5566 [17:54<3:45:51,  2.62s/it]

training loss: 3.5206804275512695


training:   7%|▋         | 388/5566 [17:56<3:49:35,  2.66s/it]

training loss: 3.541267156600952


training:   7%|▋         | 389/5566 [17:59<3:42:08,  2.57s/it]

training loss: 3.5350565910339355


training:   7%|▋         | 390/5566 [18:02<3:47:34,  2.64s/it]

training loss: 3.5393500328063965


training:   7%|▋         | 391/5566 [18:04<3:40:33,  2.56s/it]

training loss: 3.491471767425537


training:   7%|▋         | 392/5566 [18:07<3:47:22,  2.64s/it]

training loss: 3.528920888900757


training:   7%|▋         | 393/5566 [18:09<3:40:27,  2.56s/it]

training loss: 3.5263373851776123


training:   7%|▋         | 394/5566 [18:12<3:46:26,  2.63s/it]

training loss: 3.5314905643463135


training:   7%|▋         | 395/5566 [18:15<3:43:45,  2.60s/it]

training loss: 3.5123612880706787


training:   7%|▋         | 396/5566 [18:17<3:48:07,  2.65s/it]

training loss: 3.5228431224823


training:   7%|▋         | 397/5566 [18:20<3:40:33,  2.56s/it]

training loss: 3.5203683376312256


training:   7%|▋         | 398/5566 [18:22<3:46:28,  2.63s/it]

training loss: 3.5508170127868652


training:   7%|▋         | 399/5566 [18:25<3:38:51,  2.54s/it]

training loss: 3.5240015983581543


training:   7%|▋         | 400/5566 [18:28<3:45:26,  2.62s/it]

training loss: 3.538647413253784
valid loss: 3.5178110599517822
perplexity: 33.7105598449707


training:   7%|▋         | 401/5566 [18:34<5:23:03,  3.75s/it]

training loss: 3.5247855186462402


training:   7%|▋         | 402/5566 [18:37<4:51:27,  3.39s/it]

training loss: 3.511456251144409


training:   7%|▋         | 403/5566 [18:39<4:35:38,  3.20s/it]

training loss: 3.516627788543701


training:   7%|▋         | 404/5566 [18:42<4:15:05,  2.97s/it]

training loss: 3.5109076499938965


training:   7%|▋         | 405/5566 [18:44<4:10:45,  2.92s/it]

training loss: 3.5294907093048096


training:   7%|▋         | 406/5566 [18:47<3:55:55,  2.74s/it]

training loss: 3.50346040725708


training:   7%|▋         | 407/5566 [18:50<3:55:19,  2.74s/it]

training loss: 3.546952724456787


training:   7%|▋         | 408/5566 [18:52<3:44:35,  2.61s/it]

training loss: 3.5253078937530518


training:   7%|▋         | 409/5566 [18:55<3:49:26,  2.67s/it]

training loss: 3.530890941619873


training:   7%|▋         | 410/5566 [18:57<3:41:05,  2.57s/it]

training loss: 3.5127007961273193


training:   7%|▋         | 411/5566 [19:00<3:47:14,  2.64s/it]

training loss: 3.520591974258423


training:   7%|▋         | 412/5566 [19:02<3:39:18,  2.55s/it]

training loss: 3.5168466567993164


training:   7%|▋         | 413/5566 [19:05<3:45:28,  2.63s/it]

training loss: 3.526437997817993


training:   7%|▋         | 414/5566 [19:07<3:39:58,  2.56s/it]

training loss: 3.530978202819824


training:   7%|▋         | 415/5566 [19:10<3:49:15,  2.67s/it]

training loss: 3.5078659057617188


training:   7%|▋         | 416/5566 [19:13<3:41:50,  2.58s/it]

training loss: 3.535092830657959


training:   7%|▋         | 417/5566 [19:15<3:46:59,  2.64s/it]

training loss: 3.5045313835144043


training:   8%|▊         | 418/5566 [19:18<3:39:38,  2.56s/it]

training loss: 3.5041885375976562


training:   8%|▊         | 419/5566 [19:21<3:45:29,  2.63s/it]

training loss: 3.529899835586548


training:   8%|▊         | 420/5566 [19:23<3:38:23,  2.55s/it]

training loss: 3.545961380004883
valid loss: 3.532926082611084
perplexity: 34.22396469116211


training:   8%|▊         | 421/5566 [19:28<4:38:33,  3.25s/it]

training loss: 3.5114493370056152


training:   8%|▊         | 422/5566 [19:31<4:26:36,  3.11s/it]

training loss: 3.515712261199951


training:   8%|▊         | 423/5566 [19:33<4:07:22,  2.89s/it]

training loss: 3.553297519683838


training:   8%|▊         | 424/5566 [19:36<4:05:26,  2.86s/it]

training loss: 3.515012741088867


training:   8%|▊         | 425/5566 [19:38<3:53:26,  2.72s/it]

training loss: 3.512606143951416


training:   8%|▊         | 426/5566 [19:41<3:58:42,  2.79s/it]

training loss: 3.533738136291504


training:   8%|▊         | 427/5566 [19:44<3:47:50,  2.66s/it]

training loss: 3.540376901626587


training:   8%|▊         | 428/5566 [19:46<3:52:03,  2.71s/it]

training loss: 3.528101921081543


training:   8%|▊         | 429/5566 [19:49<3:42:36,  2.60s/it]

training loss: 3.53572678565979


training:   8%|▊         | 430/5566 [19:52<3:47:47,  2.66s/it]

training loss: 3.5472238063812256


training:   8%|▊         | 431/5566 [19:54<3:39:56,  2.57s/it]

training loss: 3.5149717330932617


training:   8%|▊         | 432/5566 [19:57<3:46:28,  2.65s/it]

training loss: 3.5031447410583496


training:   8%|▊         | 433/5566 [19:59<3:39:24,  2.56s/it]

training loss: 3.5334837436676025


training:   8%|▊         | 434/5566 [20:02<3:44:16,  2.62s/it]

training loss: 3.5249624252319336


training:   8%|▊         | 435/5566 [20:04<3:37:49,  2.55s/it]

training loss: 3.516166925430298


training:   8%|▊         | 436/5566 [20:07<3:43:51,  2.62s/it]

training loss: 3.540454626083374


training:   8%|▊         | 437/5566 [20:09<3:38:26,  2.56s/it]

training loss: 3.5027971267700195


training:   8%|▊         | 438/5566 [20:12<3:45:11,  2.63s/it]

training loss: 3.5399529933929443


training:   8%|▊         | 439/5566 [20:15<3:38:47,  2.56s/it]

training loss: 3.5607237815856934


training:   8%|▊         | 440/5566 [20:17<3:44:12,  2.62s/it]

training loss: 3.5363683700561523
valid loss: 3.5150179862976074
perplexity: 33.61653137207031


training:   8%|▊         | 441/5566 [20:22<4:39:03,  3.27s/it]

training loss: 3.5038866996765137


training:   8%|▊         | 442/5566 [20:24<4:13:47,  2.97s/it]

training loss: 3.5131168365478516


training:   8%|▊         | 443/5566 [20:27<4:07:19,  2.90s/it]

training loss: 3.4993391036987305


training:   8%|▊         | 444/5566 [20:30<3:53:48,  2.74s/it]

training loss: 3.507577896118164


training:   8%|▊         | 445/5566 [20:32<3:54:26,  2.75s/it]

training loss: 3.54167103767395


training:   8%|▊         | 446/5566 [20:35<3:44:14,  2.63s/it]

training loss: 3.535661220550537


training:   8%|▊         | 447/5566 [20:37<3:49:01,  2.68s/it]

training loss: 3.517934799194336


training:   8%|▊         | 448/5566 [20:40<3:40:11,  2.58s/it]

training loss: 3.513162136077881


training:   8%|▊         | 449/5566 [20:43<3:46:17,  2.65s/it]

training loss: 3.5410304069519043


training:   8%|▊         | 450/5566 [20:45<3:39:18,  2.57s/it]

training loss: 3.513511896133423


training:   8%|▊         | 451/5566 [20:48<3:44:01,  2.63s/it]

training loss: 3.496908664703369


training:   8%|▊         | 452/5566 [20:50<3:37:27,  2.55s/it]

training loss: 3.540076494216919


training:   8%|▊         | 453/5566 [20:53<3:43:35,  2.62s/it]

training loss: 3.5061349868774414


training:   8%|▊         | 454/5566 [20:55<3:37:21,  2.55s/it]

training loss: 3.5225765705108643


training:   8%|▊         | 455/5566 [20:59<3:55:49,  2.77s/it]


KeyboardInterrupt: ignored

In [None]:
class TransformerDecoder(nn.Module):
    def __init__(
          self,
          num_tokens,
          d,
          heads = 8,
          depth = 4,
          hidden_size = 1000,
          dropout = 0.3,
          batch_size = 16
      ):
          # asserts
          assert d % heads == 0

          super(TransformerDecoder, self).__init__()
          self.token_emb = nn.Embedding(num_tokens, d)
          self.positional_emb = PositionalEncoding(d, max_len = 5000)
          self.dim_head = d // heads
          self.d = d
          self.heads = heads
          self.depth = depth
          self.hidden_size = hidden_size
          self.dropout = dropout
          self.batch_size = batch_size

          self.layers = nn.ModuleList([])
          for idx in range(depth):
              attn = MultiHeadAttention(num_tokens, d, heads, self.batch_size)

              self.layers.append(nn.ModuleList([
                  attn,
                  SubLayer(d, dropout, hidden_size)
              ]))

          self.to_out = nn.Sequential(
               nn.LayerNorm(d),
               nn.Linear(d, num_tokens)
          )
          
    def forward(
        self,
        x
    ):
        batch_size, seq_len, *_, device = *x.shape, x.device
        x = self.token_emb(x)
        x = self.positional_emb(x)

        for idx, (attn, sub_l) in enumerate(self.layers):
            
            #attention
            x, mem = attn(x, device)
      
            # normalization + feedforward + residual connection
            x = sub_l(x)

        return self.to_out(x).transpose(1, 2)