<a href="https://colab.research.google.com/github/whoami-Lory271/NN-project-memorizing-transformers/blob/main/NN_project_Antonelli_DeSantis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [40]:
import torch
from torch import nn as nn
import numpy as np
from torch.nn import functional as F
from math import sqrt
import matplotlib.pyplot as plt
from torch.autograd import Variable
from pathlib import Path
from filelock import FileLock
import random
import tqdm
import gzip
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.utils import resample
from sklearn.model_selection import train_test_split

In [41]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# KNN Memory

In [42]:
!pip install faiss-gpu

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [43]:
!pip install einops

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [44]:
#import per la knn memory
import os
import math
import torch
import faiss
import numpy as np
from pathlib import Path
from functools import wraps

from contextlib import ExitStack, contextmanager

from einops import rearrange, pack, unpack

# multiprocessing

from joblib import Parallel, delayed, cpu_count

In [45]:
FAISS_INDEX_GPU_ID = int(os.getenv('FAISS_INDEX_GPU_ID', 0))

DEFAULT_KNN_MEMORY_MEMMAP_DIRECTORY = './.tmp/knn.memories'

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def cast_list(val):
    return val if isinstance(val, list) else [val]

def all_el_unique(arr):
    return len(set(arr)) == len(arr)

@contextmanager
def multi_context(*cms):
    with ExitStack() as stack:
        yield [stack.enter_context(cls) for cls in cms]

def count_intersect(x, y):
    # returns an array that shows how many times an element in x is contained in tensor y
    return np.sum(rearrange(x, 'i -> i 1') == rearrange(y, 'j -> 1 j'), axis = -1)

def check_shape(tensor, pattern, **kwargs):
    return rearrange(tensor, f"{pattern} -> {pattern}", **kwargs)

# a wrapper around faiss IndexIVFFlat
# taking care of expiring old keys automagically

class KNN():
    def __init__(
        self,
        dim,
        max_num_entries,
        cap_num_entries = False,
        M = 15,
        keep_stats = False
    ):
        index = faiss.IndexHNSWFlat(dim, M, faiss.METRIC_INNER_PRODUCT)
        self.index = index
        self.max_num_entries = max_num_entries
        self.cap_num_entries = cap_num_entries
        self.is_trained = False
        self.keep_stats = keep_stats

        self.reset()

    def __del__(self):
        if hasattr(self, 'index'):
            del self.index

    def reset(self):
        self.ids = np.empty((0,), dtype = np.int32)

        if self.keep_stats:
            self.hits = np.empty((0,), dtype = np.int32)
            self.age_num_iterations = np.empty((0,), dtype = np.int32)
            self.ages_since_last_hit = np.empty((0,), dtype = np.int32)

        self.index.reset()
        self.is_trained = False

    def train(self, x):
        self.index.train(x)
        self.is_trained = True

    def add(self, x, ids):
        if not self.is_trained:
            self.train(x)

        self.ids = np.concatenate((ids, self.ids))

        if self.keep_stats:
            self.hits = np.concatenate((np.zeros_like(ids), self.hits))
            self.age_num_iterations = np.concatenate((np.zeros_like(ids), self.age_num_iterations))
            self.ages_since_last_hit = np.concatenate((np.zeros_like(ids), self.ages_since_last_hit))

        if self.cap_num_entries and len(self.ids) > self.max_num_entries:
            self.reset()

        return self.index.add(x)

    def search(
        self,
        x,
        topk,
        nprobe = 8,
        return_distances = False,
        increment_hits = False,
        increment_age = True
    ):
        if not self.is_trained:
            return np.full((x.shape[0], topk), -1)

        distances, indices = self.index.search(x, k = topk)

        if increment_hits and self.keep_stats:
            hits = count_intersect(self.ids, rearrange(indices, '... -> (...)'))
            self.hits += hits

            self.ages_since_last_hit += 1
            self.ages_since_last_hit *= (hits == 0)

        if increment_age and self.keep_stats:
            self.age_num_iterations += 1

        if return_distances:
            return indices, distances

        return indices

# KNN memory layer, where one can store key / value memories
# can automatically take care of a collection of faiss indices (across batch dimension)

class KNNMemory():
    def __init__(
        self,
        dim,
        max_memories = 16000,
        num_indices = 1,
        memmap_filename = './knn.memory.memmap',
        multiprocessing = True
    ):
        self.dim = dim
        self.num_indices = num_indices
        self.scoped_indices = list(range(num_indices))

        self.max_memories = max_memories
        self.shape = (num_indices, max_memories, 2, dim)
        self.db_offsets = np.zeros(num_indices, dtype = np.int32)

        self.db = np.memmap(memmap_filename, mode = 'w+', dtype = np.float32, shape = self.shape)
        self.knns = [KNN(dim = dim, max_num_entries = max_memories, cap_num_entries = True) for _ in range(num_indices)]
    
        self.n_jobs = cpu_count() if multiprocessing else 1

    def set_scoped_indices(self, indices):
        indices = list(indices)
        assert all_el_unique(indices), f'all scoped batch indices must be unique, received: {indices}'
        assert all([0 <= i < self.num_indices for i in indices]), f'each batch index must be between 0 and less than {self.num_indices}: received {indices}'
        self.scoped_indices = indices

    @contextmanager
    def at_batch_indices(self, indices):
        prev_indices = self.scoped_indices
        self.set_scoped_indices(indices)
        yield self
        self.set_scoped_indices(prev_indices)

    def clear(self, batch_indices = None):
        if not exists(batch_indices):
            batch_indices = list(range(self.num_indices))

        batch_indices = cast_list(batch_indices)

        for index in batch_indices:
            knn = self.knns[index]
            knn.reset()

        self.db_offsets[batch_indices] = 0

    def add(self, memories):
        check_shape(memories, 'b n kv d', d = self.dim, kv = 2, b = len(self.scoped_indices))

        memories = memories.detach().cpu().numpy()
        memories = memories[:, -self.max_memories:]
        num_memories = memories.shape[1]

        knn_insert_ids = np.arange(num_memories)

        keys = np.ascontiguousarray(memories[..., 0, :])
        knns = [self.knns[i] for i in self.scoped_indices]
        db_offsets = [self.db_offsets[i] for i in self.scoped_indices]

        # use joblib to insert new key / value memories into faiss index

        @delayed
        def knn_add(knn, key, db_offset):
            knn.add(key, ids = knn_insert_ids + db_offset)
            return knn

        updated_knns = Parallel(n_jobs = self.n_jobs)(knn_add(*args) for args in zip(knns, keys, db_offsets))
        for knn_idx, scoped_idx in enumerate(self.scoped_indices):
            self.knns[scoped_idx] = updated_knns[knn_idx]

        # add the new memories to the memmap "database"

        add_indices = (rearrange(np.arange(num_memories), 'j -> 1 j') + rearrange(self.db_offsets[list(self.scoped_indices)], 'i -> i 1')) % self.max_memories
        self.db[rearrange(np.array(self.scoped_indices), 'i -> i 1'), add_indices] = memories
        self.db.flush()

        self.db_offsets += num_memories

    def search(
        self,
        queries,
        topk,
        nprobe = 8,
        increment_hits = True,
        increment_age = True
    ):
        check_shape(queries, 'b ... d', d = self.dim, b = len(self.scoped_indices))
        queries, ps = pack([queries], 'b * d')

        device = queries.device
        queries = queries.detach().cpu().numpy()

        all_masks = []
        all_key_values = []

        knns = [self.knns[i] for i in self.scoped_indices]

        # parallelize faiss search

        @delayed
        def knn_search(knn, query):
            return knn.search(query, topk, nprobe, increment_hits = increment_hits, increment_age = increment_age)

        fetched_indices = Parallel(n_jobs = self.n_jobs)(knn_search(*args) for args in zip(knns, queries))

        # get all the memory key / values from memmap 'database'
        # todo - remove for loop below

        for batch_index, indices in zip(self.scoped_indices, fetched_indices):
            mask = indices !=  -1
            db_indices = np.where(mask, indices, 0)

            all_masks.append(torch.from_numpy(mask))

            key_values = self.db[batch_index, db_indices % self.max_memories]
            all_key_values.append(torch.from_numpy(key_values))

        all_masks = torch.stack(all_masks)
        all_key_values = torch.stack(all_key_values)
        all_key_values = all_key_values.masked_fill(~rearrange(all_masks, '... -> ... 1 1'), 0.)

        all_key_values, = unpack(all_key_values, ps, 'b * n kv d')
        all_masks, = unpack(all_masks, ps, 'b * n')

        return all_key_values.to(device), all_masks.to(device)

    def __del__(self):
        if hasattr(self, 'knns'):
            for knn in self.knns:
                del knn
        del self.db

# extends list with some extra methods for collections of KNN memories

class KNNMemoryList(list):
    def cleanup(self):
        for memory in self:
            del memory

    @classmethod
    def create_memories(
        self,
        *,
        batch_size,
        num_memory_layers,
        memories_directory = DEFAULT_KNN_MEMORY_MEMMAP_DIRECTORY
    ):
        memories_path = Path(memories_directory)
        memories_path.mkdir(exist_ok = True, parents = True)

        def inner(*args, **kwargs):
            return self([KNNMemory(*args, num_indices = batch_size, memmap_filename = str(memories_path / f'knn.memory.layer.{ind + 1}.memmap'), **kwargs) for ind in range(num_memory_layers)])
        return inner

    @contextmanager
    def at_batch_indices(
        self,
        indices
    ):
        knn_batch_indices_contexts = [memory.at_batch_indices(indices) for memory in self]
        with multi_context(*knn_batch_indices_contexts):
            yield

    def clear_memory(
        self,
        batch_indices = None,
        memory_indices = None
    ):
        memory_indices = default(memory_indices, tuple(range(len(self))))

        for memory_index in memory_indices:
            memory = self[memory_index]
            memory.clear(batch_indices)

# Memorizing transformers

In [46]:
def attention(query, key, value, sqrt_q, device):
    t = torch.matmul(query, key.transpose(-2, -1))/sqrt_q
    i, j = t.shape[-2:]
    mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
    return torch.matmul(F.softmax(t.masked_fill_(mask, -1e-9), dim = -1), value)

def KNNattention(query, key, value, sqrt_q, mask):
    t = torch.einsum('b h i d, b h i j d -> b h i j', query, key)/sqrt_q
    return torch.einsum('b h i j, b h i j d -> b h i d', F.softmax(t.masked_fill_(mask, -1e-9), dim = -1), value)

In [47]:
class MultiHeadAttention(nn.Module):
  def __init__(self, n, d, h, batch_size):
    super(MultiHeadAttention, self).__init__()
    assert d % h == 0
    #assume q = v 
    self.q = d // h
    self.sqrt_q = sqrt(self.q)
    self.h = h
    self.batch_size = batch_size
    self.W_q = nn.Linear(d, d, bias = False) #stack of h matrices of dimension (d, q), one for each head
    self.W_k = nn.Linear(d, d, bias = False)
    self.W_v = nn.Linear(d, d, bias = False)
    self.W_o = nn.Linear(d, d, bias = False)

  def forward(self, x, device):
    query = self.W_q(x).view(self.batch_size, -1, self.h, self.q).transpose(1, 2)
    key = self.W_k(x).view(self.batch_size, -1, self.h, self.q).transpose(1, 2)
    value = self.W_v(x).view(self.batch_size, -1, self.h, self.q).transpose(1, 2)
    new_memories = torch.stack((key, value), dim = -2).detach()
    attention_value = attention(query, key, value, self.sqrt_q, device)
    return self.W_o(attention_value.transpose(1, 2).contiguous().view(self.batch_size, -1, self.h*self.q)), new_memories

In [48]:
class KNNAttention(nn.Module):
   def __init__(self, n, d, h, num_retrieved_memories, batch_size):
      super(KNNAttention, self).__init__()
      assert d % h == 0
      #assume q = v 
      self.q = d // h
      self.sqrt_q = sqrt(self.q)
      self.h = h
      self.W_q = nn.Linear(d, d, bias = False)
      self.W_k = nn.Linear(d, d, bias = False)
      self.W_v = nn.Linear(d, d, bias = False)
      self.W_o = nn.Linear(d, d, bias = False)
      self.b_g = nn.Parameter(torch.randn((h,))) #one for each head
      self.num_retrieved_memories = num_retrieved_memories
      self.batch_size = batch_size

   def forward(self, x, knn_memory, device):
      # calculate local attention 
      query = self.W_q(x).view(self.batch_size, -1, self.h, self.q).transpose(1, 2)
      key = self.W_k(x).view(self.batch_size, -1, self.h, self.q).transpose(1, 2)
      value = self.W_v(x).view(self.batch_size, -1, self.h, self.q).transpose(1, 2)
      local_attention = attention(query, key, value, self.sqrt_q, device)

      # calculate knn attention over memory
      mem_kv, mem_mask = knn_memory[0].search(query, self.num_retrieved_memories)
      mem_key, mem_value = mem_kv.unbind(dim = -2)
      knn_attention = KNNattention(query, mem_key, mem_value, self.sqrt_q, ~mem_mask)

      # memory to be stored
      new_kv_memories = torch.stack((key, value), dim = -2).view(self.batch_size, -1, 2, self.q).detach()

      # add to knn memory
      if new_kv_memories.numel() > 0:
        knn_memory[0].add(new_kv_memories)

      # combining local and memory
      g = torch.sigmoid(self.b_g)
      final_attention = torch.einsum('b h n d, h -> b h n d', knn_attention, g) + \
                        torch.einsum('b h n d, h -> b h n d', local_attention, (1 - g))
      
      return self.W_o(final_attention.transpose(1, 2).contiguous().view(self.batch_size, -1, self.h*self.q)), new_kv_memories

In [49]:
class SubLayer(nn.Module):
  def __init__(self, d, dropout, hidden_size):
    super(SubLayer, self).__init__()
    self.norm = nn.LayerNorm(d)
    self.mlp = nn.Sequential(nn.Linear(d, hidden_size, bias = True), 
                             nn.ReLU(),
                             nn.Dropout(dropout),
                             nn.Linear(hidden_size, d, bias = True))

  def forward(self, x):
    return x + self.mlp(self.norm(x)) #residual connection and normalization

In [50]:
class PositionalEncoding(nn.Module):
  def __init__(self, d_model, max_len=5000):
    super(PositionalEncoding, self).__init__()
    
    # Compute the positional encodings once in log space.
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) *
                          -(math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    pe = pe.unsqueeze(0)
    self.register_buffer('pe', pe)
      
  def forward(self, x):
    return x + Variable(self.pe[:, :x.size(1)], requires_grad=False)

In [51]:
class MemorizingTransformer(nn.Module):
    def __init__(
          self,
          num_tokens,
          d,
          heads = 8,
          depth = 10,
          knn_attn_idx = 2,
          attn_dropout = 0.,
          hidden_size = 1000,
          dropout = 0.3,
          max_knn_memories = 1000,
          num_retrieved_memories = 8,
          knn_memories_directory = DEFAULT_KNN_MEMORY_MEMMAP_DIRECTORY,
          knn_memory_multiprocessing = False,
          batch_size = 16
      ):
          # asserts
          assert d % heads == 0
          assert knn_attn_idx < depth

          super(MemorizingTransformer, self).__init__()
          self.token_emb = nn.Embedding(num_tokens, d)
          self.positional_emb = PositionalEncoding(d, max_len = 5000)
          self.dim_head = d // heads
          self.d = d
          self.heads = heads
          self.knn_attn_idx = knn_attn_idx
          self.depth = depth
          self.attn_dropout = attn_dropout
          self.hidden_size = hidden_size
          self.dropout = dropout
          self.max_knn_memories = max_knn_memories
          self.num_retrieved_memories = num_retrieved_memories
          self.knn_memories_directory = knn_memories_directory
          self.knn_memory_multiprocessing =knn_memory_multiprocessing
          self.batch_size = batch_size

          self.layers = nn.ModuleList([])
          for idx in range(depth):
              attn = KNNAttention(num_tokens, d, heads, num_retrieved_memories, self.batch_size) \
                  if idx == knn_attn_idx else MultiHeadAttention(num_tokens, d, heads, self.batch_size)

              self.layers.append(nn.ModuleList([
                  attn,
                  SubLayer(d, dropout, hidden_size)
              ]))

          self.to_out = nn.Sequential(
               nn.LayerNorm(d),
               nn.Linear(d, num_tokens)
          )

          # knn memories init

          self.knn_mem_kwargs = dict(
              dim = self.dim_head,
              max_memories = self.max_knn_memories,
              multiprocessing = knn_memory_multiprocessing
          )
          
    def forward(
        self,
        x,
        knn_memory
    ):
        batch_size, seq_len, *_, device = *x.shape, x.device
        x = self.token_emb(x)
        x = self.positional_emb(x)

        for idx, (attn, sub_l) in enumerate(self.layers):
            
            #attention

            x, mem = attn(x, knn_memory, device) if self.knn_attn_idx == idx else attn(x, device)
      
            # normalization + feedforward + residual connection

            x = sub_l(x)

        return self.to_out(x).transpose(1, 2)

    
    def create_knn_memories(
          self,
          *,
          batch_size
      ):  
          return KNNMemoryList.create_memories(
              batch_size = batch_size,
              num_memory_layers = 1,
              memories_directory = self.knn_memories_directory
          )(**self.knn_mem_kwargs)
      
    @contextmanager
    def knn_memories_context(
        self,
        **kwargs
    ):
        knn_dir = Path(self.knn_memories_directory)
        knn_dir.mkdir(exist_ok = True, parents = True)
        lock = FileLock(str(knn_dir / 'mutex'))

        with lock:
            knn_memories = self.create_knn_memories(**kwargs)
            yield knn_memories
            knn_memories.cleanup()

    def clear_memory(self, x, token_id):
        """ clears the KNN memories based on if the batch row contains the specified token id """
        """ for auto-clearing KNN memories based on start and end of strings """

        clear_memory = (x == token_id).any(dim = -1)
        batch_indices, _ = clear_memory.nonzero(as_tuple = True)
        batch_indices_to_clear = batch_indices.tolist()

        if len(batch_indices_to_clear) == 0:
            return

        knn_memories.clear_memory(batch_indices_to_clear)

# Training

In [52]:
# constants

NUM_BATCHES = int(1e5)
BATCH_SIZE = 16
SEQ_LEN = 512
SEGMENTS = 5
HEADS = 8
DIM_HEAD = SEQ_LEN // HEADS

LEARNING_RATE = 2e-4
MAX_GRAD_CLIP_NORM = 0.5

EVAL_EVERY = 20
GENERATE_EVERY  = 500
GENERATE_LENGTH = 512
CHECKPOINT = 10

In [53]:
model = MemorizingTransformer(
    num_tokens = 256,
    d = SEQ_LEN,
    heads = HEADS,
    batch_size = BATCH_SIZE
).cuda()

# prepare enwik8 data

#Lorenzo
with gzip.open('/content/drive/MyDrive/Secondo Anno/Neural Networks/project/enwik8.gz') as file:
    X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
    print(X.shape)
    # number of samples to take
    n_samples = math.ceil(0.6*X.shape[0])
    # take the set uniformly at random
    data = resample(X, n_samples=n_samples, replace=False)
    trX, vaX = train_test_split(data, test_size=math.ceil(0.2*data.shape[0]))
    # trX, vaX = np.split(X, [int(90e6)])
    print(trX.shape)
    print(vaX.shape)
    # assert False
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
    # data = torch.from_numpy(X)
"""

#Luigi
with gzip.open('/content/drive/MyDrive/Colab Notebooks/enwik8.gz') as file:
    X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
"""

class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# # dataset and dataloader
# dataset = TextSamplerDataset(data, SEQ_LEN)
# # test_dataset = TextSamplerDataset(data_val, SEQ_LEN)

# data_size = dataset.__len__()
# # data_test_size = test_dataset.__len__()

# perc_data = 0.3
# valid_size=0.2
# indices = list(range(data_size))
# np.random.shuffle(indices)
# data_size = int(np.floor(data_size * 0.3))
# print(data_size)
# indices = indices[:data_size]

# split = int(np.floor(valid_size * data_size))
# train_idx, valid_idx = indices[split:], indices[:split]

# # define samplers for obtaining training and validation batches
# train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
# valid_sampler = torch.utils.data.SubsetRandomSampler(valid_idx)

# train_loader  = DataLoader(dataset, batch_size = BATCH_SIZE, sampler = train_sampler, drop_last = True)

# test_loader = DataLoader(dataset, batch_size = BATCH_SIZE, sampler =valid_sampler, drop_last = True)


train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
train_loader  = DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True)
test_dataset = TextSamplerDataset(data_val, SEQ_LEN)
test_loader = DataLoader(test_dataset, batch_size = BATCH_SIZE, drop_last = True)

  X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)


(95000000,)
(45600000,)
(11400000,)


In [54]:
def print_string(a):
  seq = ""
  for word in a:
    for letter in word:
      seq += chr(letter)
    seq += " "
  return seq

In [55]:
# optimizer

optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)
loss = nn.CrossEntropyLoss()

# training

perplexity_list = []
for i, data in enumerate(tqdm.tqdm(train_loader, desc = 'training')):
    model.train()

    train_loss = 0.
    with model.knn_memories_context(batch_size = BATCH_SIZE) as knn_memories:
        
        seq, labels = data[:, :-1], data[:, 1:] #the labels are the same sequences shifted by one

        out = model(
              seq,
              knn_memory = knn_memories
        )
        #loss_item = torch.exp(loss(out, labels)) #perplexity
        loss_item = loss(out, labels)
        train_loss += loss_item
        loss_item.backward() 

    print(f'training loss: {train_loss}', flush = True)
    torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_CLIP_NORM)
    optimizer.step()
    optimizer.zero_grad()

    if i % EVAL_EVERY == 0:
        model.eval()
      
        test_data = None
        for test_data in test_loader:
          break

        test_loss = 0.

        with torch.no_grad(), model.knn_memories_context(batch_size = BATCH_SIZE) as knn_memories: 
            seq, labels = data[:, :-1], data[:, 1:]
            
            out = model(
              seq,
              knn_memory = knn_memories
            )

            loss_item = loss(out, labels)
            test_loss +=  loss_item
            

        print(f'valid loss: {test_loss}', flush = True)
        print(f'perplexity: {torch.exp(test_loss)}', flush = True)
        perplexity_list.append(torch.exp(test_loss).cpu())
        #Lorenzo
        with open('/content/drive/MyDrive/Università/Magistrale/Secondo Anno/Neural Networks/project/perplexity.npy', 'wb') as f:
          np.save(f, np.array(perplexity_list))
    
    if i % CHECKPOINT == 0:
      torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
      }, 'model_optimizer2.pt')
      #Lorenzo
      with open('/content/drive/MyDrive/Università/Magistrale/Secondo Anno/Neural Networks/project/perplexity.npy', 'wb') as f:
        np.save(f, np.array(perplexity_list))

training:   0%|          | 0/5566 [00:00<?, ?it/s]

training loss: 5.583859920501709
valid loss: 4.685384750366211
perplexity: 108.3519515991211


training:   0%|          | 1/5566 [00:03<4:43:07,  3.05s/it]

training loss: 4.761495113372803


training:   0%|          | 2/5566 [00:04<3:17:26,  2.13s/it]

training loss: 4.342834949493408


training:   0%|          | 3/5566 [00:05<2:37:04,  1.69s/it]

training loss: 4.1457061767578125


training:   0%|          | 4/5566 [00:06<2:17:44,  1.49s/it]

training loss: 3.997325897216797


training:   0%|          | 5/5566 [00:08<2:21:39,  1.53s/it]

training loss: 3.9089677333831787


training:   0%|          | 6/5566 [00:10<2:32:55,  1.65s/it]

training loss: 3.848611354827881


training:   0%|          | 7/5566 [00:12<2:35:19,  1.68s/it]

training loss: 3.8183159828186035


training:   0%|          | 8/5566 [00:14<3:06:42,  2.02s/it]

training loss: 3.7277345657348633


training:   0%|          | 9/5566 [00:16<3:10:09,  2.05s/it]

training loss: 3.7073957920074463


training:   0%|          | 10/5566 [00:18<3:00:42,  1.95s/it]

training loss: 3.6881093978881836


training:   0%|          | 11/5566 [00:20<2:48:18,  1.82s/it]

training loss: 3.66084885597229


training:   0%|          | 12/5566 [00:21<2:31:35,  1.64s/it]

training loss: 3.6330738067626953


training:   0%|          | 13/5566 [00:22<2:19:32,  1.51s/it]

training loss: 3.6036295890808105


training:   0%|          | 14/5566 [00:23<2:10:23,  1.41s/it]

training loss: 3.6230766773223877


training:   0%|          | 15/5566 [00:25<2:04:02,  1.34s/it]

training loss: 3.6110241413116455


training:   0%|          | 16/5566 [00:26<1:59:33,  1.29s/it]

training loss: 3.5530009269714355


training:   0%|          | 17/5566 [00:27<1:56:26,  1.26s/it]

training loss: 3.57905912399292


training:   0%|          | 18/5566 [00:28<1:53:55,  1.23s/it]

training loss: 3.5554442405700684


training:   0%|          | 19/5566 [00:29<1:52:16,  1.21s/it]

training loss: 3.558539867401123


training:   0%|          | 20/5566 [00:30<1:51:22,  1.20s/it]

training loss: 3.572993755340576
valid loss: 3.5608463287353516
perplexity: 35.19296646118164


training:   0%|          | 21/5566 [00:33<2:33:57,  1.67s/it]

training loss: 3.575486898422241


training:   0%|          | 22/5566 [00:34<2:21:15,  1.53s/it]

training loss: 3.546755313873291


training:   0%|          | 23/5566 [00:36<2:11:38,  1.42s/it]

training loss: 3.582848310470581


training:   0%|          | 24/5566 [00:37<2:06:41,  1.37s/it]

training loss: 3.5360851287841797


training:   0%|          | 25/5566 [00:38<2:02:25,  1.33s/it]

training loss: 3.5344812870025635


training:   0%|          | 26/5566 [00:39<2:00:04,  1.30s/it]

training loss: 3.5317583084106445


training:   0%|          | 27/5566 [00:40<1:56:46,  1.26s/it]

training loss: 3.51768159866333


training:   1%|          | 28/5566 [00:42<1:53:58,  1.23s/it]

training loss: 3.529911994934082


training:   1%|          | 29/5566 [00:43<1:52:15,  1.22s/it]

training loss: 3.5371429920196533


training:   1%|          | 30/5566 [00:44<1:51:06,  1.20s/it]

training loss: 3.532968521118164


training:   1%|          | 31/5566 [00:45<1:56:06,  1.26s/it]

training loss: 3.5428149700164795


training:   1%|          | 32/5566 [00:47<1:54:15,  1.24s/it]

training loss: 3.5460057258605957


training:   1%|          | 33/5566 [00:48<1:52:41,  1.22s/it]

training loss: 3.5042781829833984


training:   1%|          | 34/5566 [00:49<1:52:15,  1.22s/it]

training loss: 3.5222489833831787


training:   1%|          | 35/5566 [00:50<1:51:29,  1.21s/it]

training loss: 3.543365001678467


training:   1%|          | 36/5566 [00:51<1:51:15,  1.21s/it]

training loss: 3.516097068786621


training:   1%|          | 37/5566 [00:52<1:50:47,  1.20s/it]

training loss: 3.5341320037841797


training:   1%|          | 38/5566 [00:54<1:50:26,  1.20s/it]

training loss: 3.54929518699646


training:   1%|          | 39/5566 [00:55<1:50:02,  1.19s/it]

training loss: 3.550807476043701


training:   1%|          | 40/5566 [00:56<1:49:56,  1.19s/it]

training loss: 3.542736530303955
valid loss: 3.538811445236206
perplexity: 34.42597961425781


training:   1%|          | 41/5566 [00:59<2:25:47,  1.58s/it]

training loss: 3.5230836868286133


training:   1%|          | 42/5566 [01:00<2:22:29,  1.55s/it]

training loss: 3.5214385986328125


training:   1%|          | 43/5566 [01:01<2:12:54,  1.44s/it]

training loss: 3.5257785320281982


training:   1%|          | 44/5566 [01:02<2:05:25,  1.36s/it]

training loss: 3.5382394790649414


training:   1%|          | 45/5566 [01:04<2:00:12,  1.31s/it]

training loss: 3.5578837394714355


training:   1%|          | 46/5566 [01:05<1:56:48,  1.27s/it]

training loss: 3.5271456241607666


training:   1%|          | 47/5566 [01:06<1:54:22,  1.24s/it]

training loss: 3.5218896865844727


training:   1%|          | 48/5566 [01:07<1:52:29,  1.22s/it]

training loss: 3.5510761737823486


training:   1%|          | 49/5566 [01:08<1:51:48,  1.22s/it]

training loss: 3.498243808746338


training:   1%|          | 50/5566 [01:10<1:51:40,  1.21s/it]

training loss: 3.511568546295166


training:   1%|          | 51/5566 [01:11<1:56:50,  1.27s/it]

training loss: 3.5203421115875244


training:   1%|          | 52/5566 [01:12<1:55:07,  1.25s/it]

training loss: 3.518507957458496


training:   1%|          | 53/5566 [01:13<1:53:13,  1.23s/it]

training loss: 3.5314698219299316


training:   1%|          | 54/5566 [01:15<1:52:12,  1.22s/it]

training loss: 3.4933252334594727


training:   1%|          | 55/5566 [01:16<1:55:10,  1.25s/it]

training loss: 3.5192105770111084


training:   1%|          | 56/5566 [01:17<1:55:11,  1.25s/it]

training loss: 3.519406318664551


training:   1%|          | 57/5566 [01:18<1:53:56,  1.24s/it]

training loss: 3.5255396366119385


training:   1%|          | 58/5566 [01:20<1:54:42,  1.25s/it]

training loss: 3.525297164916992


training:   1%|          | 59/5566 [01:22<2:14:13,  1.46s/it]

training loss: 3.5337727069854736


training:   1%|          | 60/5566 [01:24<2:34:33,  1.68s/it]

training loss: 3.522975444793701
valid loss: 3.5206260681152344
perplexity: 33.80558395385742


training:   1%|          | 61/5566 [01:27<3:19:02,  2.17s/it]

training loss: 3.5432262420654297


training:   1%|          | 62/5566 [01:28<2:53:55,  1.90s/it]

training loss: 3.543565273284912


training:   1%|          | 63/5566 [01:29<2:34:16,  1.68s/it]

training loss: 3.5243663787841797


training:   1%|          | 64/5566 [01:31<2:21:00,  1.54s/it]

training loss: 3.5358896255493164


training:   1%|          | 65/5566 [01:32<2:22:20,  1.55s/it]

training loss: 3.5102109909057617


training:   1%|          | 66/5566 [01:34<2:17:56,  1.50s/it]

training loss: 3.4920716285705566


training:   1%|          | 67/5566 [01:35<2:08:50,  1.41s/it]

training loss: 3.5212621688842773


training:   1%|          | 68/5566 [01:36<2:02:13,  1.33s/it]

training loss: 3.551018476486206


training:   1%|          | 69/5566 [01:37<1:58:05,  1.29s/it]

training loss: 3.5422005653381348


training:   1%|▏         | 70/5566 [01:38<1:55:26,  1.26s/it]

training loss: 3.5328776836395264


training:   1%|▏         | 71/5566 [01:40<2:00:53,  1.32s/it]

training loss: 3.4938557147979736


training:   1%|▏         | 72/5566 [01:41<1:57:27,  1.28s/it]

training loss: 3.527597427368164


training:   1%|▏         | 73/5566 [01:42<1:54:26,  1.25s/it]

training loss: 3.5164458751678467


training:   1%|▏         | 74/5566 [01:43<1:52:37,  1.23s/it]

training loss: 3.5559284687042236


training:   1%|▏         | 75/5566 [01:45<1:50:56,  1.21s/it]

training loss: 3.516507625579834


training:   1%|▏         | 76/5566 [01:46<1:49:57,  1.20s/it]

training loss: 3.5478367805480957


training:   1%|▏         | 77/5566 [01:47<1:49:08,  1.19s/it]

training loss: 3.5155773162841797


training:   1%|▏         | 78/5566 [01:48<1:49:00,  1.19s/it]

training loss: 3.513805866241455


training:   1%|▏         | 79/5566 [01:49<1:48:40,  1.19s/it]

training loss: 3.5315678119659424


training:   1%|▏         | 80/5566 [01:50<1:48:21,  1.19s/it]

training loss: 3.5278818607330322
valid loss: 3.523977041244507
perplexity: 33.91905975341797


training:   1%|▏         | 81/5566 [01:53<2:23:33,  1.57s/it]

training loss: 3.520968198776245


training:   1%|▏         | 82/5566 [01:54<2:14:16,  1.47s/it]

training loss: 3.541003942489624


training:   1%|▏         | 83/5566 [01:55<2:06:20,  1.38s/it]

training loss: 3.5459842681884766


training:   2%|▏         | 84/5566 [01:57<2:00:47,  1.32s/it]

training loss: 3.4964165687561035


training:   2%|▏         | 85/5566 [01:58<1:56:53,  1.28s/it]

training loss: 3.50215744972229


training:   2%|▏         | 86/5566 [01:59<1:54:14,  1.25s/it]

training loss: 3.531989336013794


training:   2%|▏         | 87/5566 [02:00<1:52:23,  1.23s/it]

training loss: 3.5176868438720703


training:   2%|▏         | 88/5566 [02:01<1:51:10,  1.22s/it]

training loss: 3.5254030227661133


training:   2%|▏         | 89/5566 [02:02<1:50:10,  1.21s/it]

training loss: 3.5072922706604004


training:   2%|▏         | 90/5566 [02:04<1:49:18,  1.20s/it]

training loss: 3.505013942718506


training:   2%|▏         | 91/5566 [02:05<1:55:03,  1.26s/it]

training loss: 3.4944331645965576


training:   2%|▏         | 92/5566 [02:06<1:53:20,  1.24s/it]

training loss: 3.538832187652588


training:   2%|▏         | 93/5566 [02:07<1:51:36,  1.22s/it]

training loss: 3.5165345668792725


training:   2%|▏         | 94/5566 [02:09<1:50:50,  1.22s/it]

training loss: 3.521580934524536


training:   2%|▏         | 95/5566 [02:10<1:50:21,  1.21s/it]

training loss: 3.533860683441162


training:   2%|▏         | 96/5566 [02:11<1:49:46,  1.20s/it]

training loss: 3.520343065261841


training:   2%|▏         | 97/5566 [02:12<1:49:51,  1.21s/it]

training loss: 3.5520544052124023


training:   2%|▏         | 98/5566 [02:14<2:05:10,  1.37s/it]


KeyboardInterrupt: ignored

In [None]:
class TransformerDecoder(nn.Module):
    def __init__(
          self,
          num_tokens,
          d,
          heads = 8,
          depth = 4,
          hidden_size = 1000,
          dropout = 0.3,
          batch_size = 16
      ):
          # asserts
          assert d % heads == 0

          super(TransformerDecoder, self).__init__()
          self.token_emb = nn.Embedding(num_tokens, d)
          self.positional_emb = PositionalEncoding(d, max_len = 5000)
          self.dim_head = d // heads
          self.d = d
          self.heads = heads
          self.depth = depth
          self.hidden_size = hidden_size
          self.dropout = dropout
          self.batch_size = batch_size

          self.layers = nn.ModuleList([])
          for idx in range(depth):
              attn = MultiHeadAttention(num_tokens, d, heads, self.batch_size)

              self.layers.append(nn.ModuleList([
                  attn,
                  SubLayer(d, dropout, hidden_size)
              ]))

          self.to_out = nn.Sequential(
               nn.LayerNorm(d),
               nn.Linear(d, num_tokens)
          )
          
    def forward(
        self,
        x
    ):
        batch_size, seq_len, *_, device = *x.shape, x.device
        x = self.token_emb(x)
        x = self.positional_emb(x)

        for idx, (attn, sub_l) in enumerate(self.layers):
            
            #attention
            x, mem = attn(x, device)
      
            # normalization + feedforward + residual connection
            x = sub_l(x)

        return self.to_out(x).transpose(1, 2)