# PEER

In [32]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer
import math
from einops import einsum
from tqdm import tqdm 
from einops.layers.torch import Rearrange
from torch.nn.utils import clip_grad_norm_
from torch.cuda.amp import autocast, GradScaler
from torch import distributed as dist
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
import time
import numpy as np
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"


In [33]:
from datasets import load_dataset
from torch.utils.data import Dataset

class PileDataset(Dataset):
    def __init__(self, tokenizer, dataset_name='c4', split='train', max_length=512, num_samples=150_000):
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Load dataset depending on name
        if dataset_name == 'c4':
            self.data = load_dataset("c4", "en", split=split, trust_remote_code=True)

        elif dataset_name == 'wikitext':
            self.data = load_dataset("wikitext", "wikitext-103-raw-v1", split=split)
        else:
            raise ValueError(f"Unsupported dataset: {dataset_name}")

        # Filter out empty text entries
        self.data = self.data.filter(lambda x: x and len(x['text']) > 0)

        # Limit the dataset size if needed
        self.data = self.data.select(range(min(num_samples, len(self.data))))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]['text']
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)
        return input_ids, attention_mask

        

In [34]:
def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return F.normalize(x, dim=-1) * self.scale * self.gamma

# ProductKeyMemory class implements the key lookup used in product-key memory
class ProductKeyMemory(nn.Module):
    def __init__(self, dim, num_keys):
        super().__init__()
        self.dim = dim
        self.num_keys = num_keys
        self.keys = nn.Parameter(torch.randn(num_keys, dim // 2))  # sub-keys of dimensionality d/2

    def forward(self, query):
        query = query.view(query.shape[0], 2, -1)  # split query into two sub-queries
        dots = torch.einsum('bkd,nd->bkn', query, self.keys)  # compute dot product with sub-keys
        return dots.view(query.shape[0], -1)

def topk_schedule(step, warmup, min_k=4, max_k=16):
    if step < warmup:
        return min_k + (max_k - min_k) * step // warmup
    return max_k

def temperature_schedule(step, warmup, start_temp=2.0, end_temp=0.5):
    if step >= warmup:
        return end_temp
    return start_temp - (start_temp - end_temp) * step / warmup


In [35]:
# PEER model module, supports multi-head product-key retrieval and singleton MLP experts
class PEER(nn.Module):
    def __init__(
        self,
        dim,
        *,
        heads=8,
        num_experts=1_000_000, #max
        num_experts_per_head=16,
        activation=nn.GELU,
        dim_key=None,
        product_key_topk=None,
        product_key_topk_schedule=None,
        temperature_schedule=None,
        separate_embed_per_head=False,
        pre_rmsnorm=False,
        dropout=0.
    ):
        super().__init__()

        self.norm = RMSNorm(dim) if pre_rmsnorm else nn.Identity()  # optional input norm
        self.heads = heads
        self.separate_embed_per_head = separate_embed_per_head
        self.num_experts = num_experts

        num_expert_sets = heads if separate_embed_per_head else 1

        # expert weights for down and up projection (singleton MLPs)
        self.weight_down_embed = nn.Embedding(num_experts * num_expert_sets, dim)
        self.weight_up_embed = nn.Embedding(num_experts * num_expert_sets, dim)

        self.activation = activation()

        assert (num_experts ** 0.5).is_integer(), '`num_experts` needs to be a square'
        assert (dim % 2) == 0, 'feature dimension should be divisible by 2'

        dim_key = default(dim_key, dim // 2)
        self.num_keys = int(num_experts ** 0.5)

        # Linear projection to queries for PKM routing (split into two parts)
        self.to_queries = nn.Sequential(
            nn.Linear(dim, dim_key * heads * 2, bias=False),
            Rearrange('b n (p h d) -> p b n h d', p=2, h=heads)
        )

        self.product_key_topk = default(product_key_topk, num_experts_per_head)
        self.num_experts_per_head = num_experts_per_head

        # Key tensor: shape (heads, sqrt(N), 2, d/2)
        self.keys = nn.Parameter(torch.randn(heads, self.num_keys, 2, dim_key))

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('step', torch.tensor(0, dtype=torch.long))  # training step counter

        self.product_key_topk_schedule = product_key_topk_schedule
        self.temperature_schedule = temperature_schedule
        self.expert_norm = nn.LayerNorm(dim)

    def forward(self, x, attention_mask=None):
        x = self.norm(x)  # optional normalization
        queries = self.to_queries(x)  # split into 2 queries, shape: (2, B, N, H, D)

        if attention_mask is not None:
            expanded_mask = attention_mask[:, :, None, None]
            expanded_mask = expanded_mask.permute(2, 0, 1, 3)
            queries = queries * expanded_mask  # apply attention mask to query

        # compute similarity between queries and keys
        sim = einsum(queries, self.keys, 'p b n h d, h k p d -> p b n h k')

        # top-k per subkey dimension
        current_topk = self.product_key_topk_schedule(self.step.item()) if self.product_key_topk_schedule else self.product_key_topk
        (scores_x, scores_y), (indices_x, indices_y) = [s.topk(current_topk, dim=-1) for s in sim]

        # form Cartesian product of topk results
        all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
        all_indices = indices_x.unsqueeze(-1) * self.num_keys + indices_y.unsqueeze(-2)

        # flatten candidate scores and indices
        all_scores = all_scores.view(*all_scores.shape[:-2], -1)
        all_indices = all_indices.view(*all_indices.shape[:-2], -1)

        # select top-k from Cartesian product set
        scores, pk_indices = all_scores.topk(self.num_experts_per_head, dim=-1)
        indices = all_indices.gather(-1, pk_indices)

        if self.separate_embed_per_head:
            head_expert_offsets = torch.arange(self.heads, device=x.device) * self.num_experts
            indices = indices + head_expert_offsets.view(1, 1, -1, 1)

        # retrieve expert weights
        weights_down = self.weight_down_embed(pk_indices)
        weights_up = self.weight_up_embed(pk_indices)

        # apply down projection and activation
        x = einsum(x, weights_down, 'b n d, b n h k d -> b n h k d')
        x = self.activation(x)

        # normalize expert outputs
        x = self.expert_norm(x.reshape(-1, x.shape[-1])).reshape_as(x)
        x = self.dropout(x)

        # temperature annealed softmax for routing scores
        current_temperature = self.temperature_schedule(self.step.item()) if self.temperature_schedule else 1.0
        score_weights = F.softmax(scores / current_temperature, dim=-1)
        score_weights = torch.nan_to_num(score_weights, nan=0.0, posinf=1.0, neginf=0.0)

        x = x * score_weights.unsqueeze(-1)  # apply score weights

        # final up projection
        x = einsum(x, weights_up, 'b n h k d, b n h k d -> b n d')

        self.step += 1
        return x


# TransformerBlock integrates attention + two PEER feedforward modules
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, num_experts, num_experts_per_head, dropout=0.1):
        super(TransformerBlock, self).__init__()

        # Standard multi-head self-attention
        self.attention = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        # First PEER feedforward layer
        self.peer1 = PEER(dim, heads=num_heads, 
                          num_experts=num_experts, 
                          num_experts_per_head=num_experts_per_head,
                          # product_key_topk_schedule=topk_schedule,
                          # temperature_schedule=temperature_schedule
                         )

        # Second PEER layer applied after GELU (stacked expert processing)
        self.peer2 = PEER(dim, heads=num_heads, 
                          num_experts=num_experts, 
                          num_experts_per_head=num_experts_per_head,
                          # product_key_topk_schedule=topk_schedule,
                          # temperature_schedule=temperature_schedule
                         )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attention_mask=None):
        # Standard attention followed by residual connection and norm
        attn_output, _ = self.attention(x, x, x, key_padding_mask=~attention_mask.bool() if attention_mask is not None else None)
        x = x + self.dropout(attn_output)
        x = self.norm1(x)

        # First PEER expert layer
        peer_output1 = self.peer1(x)
        # Second PEER expert layer with GELU activation in between
        peer_output2 = self.peer2(F.gelu(peer_output1))

        # Residual connection and norm
        x = x + self.dropout(peer_output2)
        x = self.norm2(x)

        return x

# Full PEER-based language model
class PEERLanguageModel(nn.Module):
    def __init__(self, vocab_size, dim, num_layers, num_heads, num_experts, num_experts_per_head):
        super().__init__()

        # Embedding layers for tokens and positions
        self.token_embedding = nn.Embedding(vocab_size, dim)
        self.position_embedding = nn.Embedding(512, dim)  # fixed positional encoding

        # Stack of TransformerBlocks with PEER
        self.layers = nn.ModuleList([
            TransformerBlock(dim, num_heads, num_experts, num_experts_per_head) for _ in range(num_layers)
        ])

        self.layer_norm = nn.LayerNorm(dim)

        # Final LM head
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)

    def forward(self, x, attention_mask=None):
        b, s = x.shape
        positions = torch.arange(s, device=x.device).unsqueeze(0).expand(b, s)

        # Embed tokens and positions
        x = self.token_embedding(x) + self.position_embedding(positions)

        # Pass through each transformer block
        for layer in self.layers:
            x = layer(x, attention_mask=attention_mask)

        x = self.layer_norm(x)

        # Project to vocabulary logits
        logits = self.lm_head(x)
        return logits


In [37]:

def train(model, train_loader, optimizer, device, max_batches=None, show_progress=True, grad_accum_steps=1):
    model.train()
    total_loss = 0
    batch_losses = []

    scaler = GradScaler()
    loss_fn = torch.nn.CrossEntropyLoss()

    loop = enumerate(train_loader)
    if show_progress:
        loop = tqdm(loop, total=min(len(train_loader), max_batches or len(train_loader)),
                    desc="Training", dynamic_ncols=True, unit="batch")

    optimizer.zero_grad()

    for batch_idx, (batch) in loop:
        if max_batches is not None and batch_idx >= max_batches:
            break
        input_ids, attention_mask = batch
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = input_ids.clone()

        with autocast():  # mixed precision context
            logits = model(input_ids, attention_mask=attention_mask)
            loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
            loss = loss / grad_accum_steps

        scaler.scale(loss).backward()

        # Step optimizer every grad_accum_steps
        if (batch_idx + 1) % grad_accum_steps == 0 or (max_batches and (batch_idx + 1) == max_batches):
            clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        batch_loss = loss.item() * grad_accum_steps  # undo scaling for logging
        batch_losses.append(batch_loss)
        total_loss += batch_loss

        if show_progress:
            loop.set_postfix(loss=f"{batch_loss:.4f}")

    avg_loss = total_loss / max(len(batch_losses), 1)
    return avg_loss, batch_losses
    

def validate(model, val_loader, device, max_batches=None, show_progress=True):
    model.eval()
    total_loss = 0
    total_tokens = 0
    batch_losses = []

    loop = val_loader
    if show_progress:
        loop = tqdm(loop, total=min(len(val_loader), max_batches or len(val_loader)),
                    desc="Validation", dynamic_ncols=True, unit="batch")

    with torch.no_grad():
        for batch_idx, (input_ids, attention_mask) in enumerate(loop):
            if max_batches is not None and batch_idx >= max_batches:
                break

            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)

            with autocast():  # AMP context
                outputs = model(input_ids, attention_mask=attention_mask)
                loss = F.cross_entropy(
                    outputs.view(-1, outputs.size(-1)),
                    input_ids.view(-1),
                    ignore_index=0, reduction='sum'
                )

            token_count = (input_ids != 0).sum().item()
            total_loss += loss.item()
            total_tokens += token_count
            batch_losses.append(loss.item() / token_count)

    avg_loss = total_loss / total_tokens
    perplexity = math.exp(avg_loss)
    return avg_loss, perplexity, batch_losses


In [38]:
import GPUtil

if torch.cuda.is_available():
    print('GPU is available')
    print('__CUDNN VERSION:', torch.backends.cudnn.version())
    print('__Number CUDA Devices:', torch.cuda.device_count())
    print('__CUDA Device Name:',torch.cuda.get_device_name(0))
    print('__CUDA Device Total Memory [GB]:',torch.cuda.get_device_properties(0).total_memory/1e9)
else:
    print('GPU is not available')

GPUtil.getAvailable()

GPU is available
__CUDNN VERSION: 90100
__Number CUDA Devices: 1
__CUDA Device Name: NVIDIA GeForce RTX 4080
__CUDA Device Total Memory [GB]: 16.747200512


[]

In [39]:
def smooth_curve(data, window=100):
    """Simple moving average."""
    data = np.array(data)
    if len(data) < window:
        return data
    return np.convolve(data, np.ones(window) / window, mode='valid')

def format_k(x, _):
    """Tick format for '10k' etc."""
    return f'{int(x/1000)}k' if x >= 1000 else str(int(x))


def plot_epoch_metrics(train_losses, val_losses, val_perplexities, save_dir, dataset_name='unknown'):
    os.makedirs(save_dir, exist_ok=True)
    epochs = list(range(1, len(train_losses) + 1))

    # Loss Plot
    plt.figure(figsize=(8, 5))
    plt.plot(epochs, train_losses, label='Train Loss', marker='o')
    plt.plot(epochs, val_losses, label='Val Loss', marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f"Epoch Loss — {dataset_name}")
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.4)
    plt.savefig(os.path.join(save_dir, f'epoch_loss_{dataset_name}.png'))
    plt.close()

    # Perplexity Plot
    plt.figure(figsize=(8, 5))
    plt.plot(epochs, val_perplexities, label='Val Perplexity', marker='o', color='purple')
    plt.xlabel('Epoch')
    plt.ylabel('Perplexity')
    plt.title(f"Epoch Perplexity — {dataset_name}")
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.4)
    plt.savefig(os.path.join(save_dir, f'epoch_perplexity_{dataset_name}.png'))
    plt.close()

    
def is_distributed_env():
    return (
        dist.is_available()
        and dist.is_nccl_available()
        and "RANK" in os.environ
        and "LOCAL_RANK" in os.environ
        and "WORLD_SIZE" in os.environ
    )

if is_distributed_env():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
else:
    print("Running in single GPU or notebook mode")
    local_rank = 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Running in single GPU or notebook mode


In [44]:
#   Hyperparameters + Configs
vocab_size = 50257
dim = 256
num_layers = 4
num_heads = 4
num_experts = 128*128
num_experts_per_head = 8
batch_size = 4
grad_accum_steps = 4  
num_epochs = 20
learning_rate = 1e-4

enable_debug_mode = True
subset_size = 50000       
max_batches = 5000 if enable_debug_mode else None 

#   Model and tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
model = PEERLanguageModel(vocab_size, dim, num_layers, num_heads, num_experts, num_experts_per_head).to(device)

#   Wrap model in DDP (if applicable)
if dist.is_initialized():
    local_rank = int(os.environ["LOCAL_RANK"])
    model = DDP(model, device_ids=[local_rank], output_device=local_rank)
else:
    print("Skipping DDP — running on single GPU or CPU")

# Print info
if local_rank == 0:
    print("Number of parameters:", sum(p.numel() for p in model.parameters()))
    os.makedirs('plots', exist_ok=True)

Skipping DDP — running on single GPU or CPU
Number of parameters: 97178624


In [45]:
#  Finalized Training Loop
def run(enable_debug_mode = True, data = 'wiki'):

    best_val_loss = float('inf')
    patience = 2  # how many bad epochs to wait before stopping
    min_delta = 0.01  # required improvement (e.g., 1% of loss)
    bad_epochs = 0

    #   Load dataset 
    if data == 'c4':
        train_dataset = PileDataset(tokenizer, dataset_name='c4', split='train', num_samples=150_000)
        val_dataset = PileDataset(tokenizer, dataset_name='c4',split='validation', num_samples=10_000)
    elif data == 'wiki':
        train_dataset = PileDataset(tokenizer, dataset_name='wikitext', split='train', num_samples=100_000)
        val_dataset = PileDataset(tokenizer, dataset_name='wikitext',split='validation', num_samples=5_000)
    
    if enable_debug_mode:
        train_dataset = Subset(train_dataset, list(range(min(subset_size, len(train_dataset)))))
    
    #   Create samplers and loaders
    if is_distributed_env():
        train_sampler = DistributedSampler(train_dataset)
        shuffle = False
    else:
        train_sampler = None
        shuffle = True
        
    # Create Dataloader
    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, shuffle=shuffle)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Compute total and warmup steps
    total_steps = len(train_loader) * num_epochs // grad_accum_steps
    warmup_steps = int(0.2 * total_steps)  # e.g., 20% of training as warmup
    
    if local_rank == 0:
        print(f"Total training steps: {total_steps}, Warmup steps: {warmup_steps}")
        
    # Override top-k and temperature schedules dynamically using warmup
    for layer in model.layers:
        for peer in [layer.peer1, layer.peer2]:
            peer.product_key_topk_schedule = lambda step, warmup=warmup_steps: topk_schedule(step, warmup)
            peer.temperature_schedule = lambda step, warmup=warmup_steps: temperature_schedule(step, warmup)


    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.95), eps=1e-8)

    if enable_debug_mode:
        max_sample = max_batches
    else:
        max_sample = None

    # Logging for plotting
    train_epoch_losses = []
    val_epoch_losses = []
    val_epoch_ppls = []
    
    best_val_loss = float('inf')
    start_time = time.time()
    
    # Start training
    for epoch in range(num_epochs):
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)
    
        if local_rank == 0:
            print(f"\nEpoch {epoch+1}/{num_epochs}")
    
        #   Call updated train() function
        train_loss, train_batch_losses = train(
            model, train_loader, optimizer, device,
            max_batches = max_sample,
            show_progress=True,
            grad_accum_steps=4)
        
        #   Validation
        if local_rank == 0:
            print(f"Epoch {epoch+1}") 
            print(f"Train Loss: {train_loss:.4f}")
            print("Validating...")
    
            with torch.no_grad():
                val_loss, val_perplexity, val_batch_losses = validate(model, val_loader, device)
    
            print(f"Val Loss: {val_loss:.4f}, Perplexity: {val_perplexity:.4f}")
        
        train_epoch_losses.append(train_loss)
        val_epoch_losses.append(val_loss)
        val_epoch_ppls.append(val_perplexity)

        # Check for loss improvement
        if val_loss < best_val_loss - min_delta:
            best_val_loss = val_loss
            bad_epochs = 0  # reset
            torch.save(model.state_dict(), 'best_peer_language_model.pth')
        else:
            bad_epochs += 1
            if local_rank == 0:
                print(f"→ No significant improvement. Bad epochs: {bad_epochs}/{patience}")
        
            if bad_epochs >= patience:
                if local_rank == 0:
                    print("Early stopping triggered.")
                break
        
    end_time = time.time()  # End timer
    elapsed = end_time - start_time
    mins, secs = divmod(elapsed, 60)
    print(f"\n Total Training + Validation Time: {int(mins)} min {int(secs)} sec")

    # plot loss curves
    plot_epoch_metrics(
        train_losses=train_epoch_losses,
        val_losses=val_epoch_losses,
        val_perplexities=val_epoch_ppls,
        save_dir='epoch_plots',
        dataset_name=data
    )
    
    # Final model save
    if local_rank == 0:
        torch.save(model.state_dict(), 'final_peer_language_model.pth')
    
    # Cleanup (only if distributed)
    if is_distributed_env():
        dist.destroy_process_group()


In [46]:
run('c4') # 20 epoch 

Loading dataset shards:   0%|          | 0/1654 [00:00<?, ?it/s]

  scaler = GradScaler()



Epoch 1/20


  with autocast():  # mixed precision context
Training: 100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [02:11<00:00,  7.61batch/s, loss=9.8354]


Epoch 1
Train Loss: 10.1845
Validating...


  with autocast():  # AMP context
Validation: 100%|█████████████████████████████████████████████████████████████████████████████████| 2500/2500 [02:08<00:00, 19.46batch/s]


Val Loss: 9.6746, Perplexity: 15907.8447

Epoch 2/20


Training: 100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [02:11<00:00,  7.59batch/s, loss=7.5845]


Epoch 2
Train Loss: 7.9237
Validating...


Validation: 100%|█████████████████████████████████████████████████████████████████████████████████| 2500/2500 [02:07<00:00, 19.64batch/s]


Val Loss: 6.4196, Perplexity: 613.7863

Epoch 3/20


Training: 100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [02:11<00:00,  7.60batch/s, loss=2.9154]


Epoch 3
Train Loss: 6.1679
Validating...


Validation: 100%|█████████████████████████████████████████████████████████████████████████████████| 2500/2500 [02:04<00:00, 20.05batch/s]


Val Loss: 5.7727, Perplexity: 321.3898

Epoch 4/20


Training: 100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [01:18<00:00, 12.73batch/s, loss=8.0158]


Epoch 4
Train Loss: 5.6084
Validating...


Validation: 100%|█████████████████████████████████████████████████████████████████████████████████| 2500/2500 [01:15<00:00, 33.11batch/s]


Val Loss: 5.5577, Perplexity: 259.2278

Epoch 5/20


Training: 100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [01:18<00:00, 12.73batch/s, loss=3.6751]


Epoch 5
Train Loss: 5.5023
Validating...


Validation: 100%|█████████████████████████████████████████████████████████████████████████████████| 2500/2500 [01:15<00:00, 33.27batch/s]


Val Loss: 5.4851, Perplexity: 241.0753

Epoch 6/20


Training: 100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [01:18<00:00, 12.73batch/s, loss=6.2475]


Epoch 6
Train Loss: 5.4468
Validating...


Validation: 100%|█████████████████████████████████████████████████████████████████████████████████| 2500/2500 [01:15<00:00, 33.11batch/s]


Val Loss: 5.3977, Perplexity: 220.8877

Epoch 7/20


Training: 100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [01:18<00:00, 12.76batch/s, loss=4.1228]


Epoch 7
Train Loss: 5.3423
Validating...


Validation: 100%|█████████████████████████████████████████████████████████████████████████████████| 2500/2500 [01:15<00:00, 33.20batch/s]


Val Loss: 5.3275, Perplexity: 205.9266

Epoch 8/20


Training: 100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [01:18<00:00, 12.77batch/s, loss=3.5648]


Epoch 8
Train Loss: 5.2365
Validating...


Validation: 100%|█████████████████████████████████████████████████████████████████████████████████| 2500/2500 [01:15<00:00, 33.14batch/s]


Val Loss: 5.2650, Perplexity: 193.4527

Epoch 9/20


Training: 100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [01:18<00:00, 12.74batch/s, loss=5.8049]


Epoch 9
Train Loss: 5.2028
Validating...


Validation: 100%|█████████████████████████████████████████████████████████████████████████████████| 2500/2500 [01:15<00:00, 33.12batch/s]


Val Loss: 5.2407, Perplexity: 188.8028

Epoch 10/20


Training: 100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [01:18<00:00, 12.75batch/s, loss=3.8011]


Epoch 10
Train Loss: 5.1743
Validating...


Validation: 100%|█████████████████████████████████████████████████████████████████████████████████| 2500/2500 [01:15<00:00, 33.27batch/s]


Val Loss: 5.2121, Perplexity: 183.4799

Epoch 11/20


Training: 100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [01:18<00:00, 12.75batch/s, loss=4.7078]


Epoch 11
Train Loss: 5.0953
Validating...


Validation: 100%|█████████████████████████████████████████████████████████████████████████████████| 2500/2500 [01:15<00:00, 33.28batch/s]


Val Loss: 5.1176, Perplexity: 166.9386

Epoch 12/20


Training: 100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [01:18<00:00, 12.73batch/s, loss=4.7742]


Epoch 12
Train Loss: 5.1150
Validating...


Validation: 100%|█████████████████████████████████████████████████████████████████████████████████| 2500/2500 [01:15<00:00, 33.13batch/s]


Val Loss: 4.9897, Perplexity: 146.8945

Epoch 13/20


Training: 100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [01:18<00:00, 12.73batch/s, loss=7.6620]


Epoch 13
Train Loss: 4.9248
Validating...


Validation: 100%|█████████████████████████████████████████████████████████████████████████████████| 2500/2500 [01:15<00:00, 33.10batch/s]


Val Loss: 4.8869, Perplexity: 132.5430

Epoch 14/20


Training: 100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [01:18<00:00, 12.72batch/s, loss=4.8937]


Epoch 14
Train Loss: 4.8450
Validating...


Validation: 100%|█████████████████████████████████████████████████████████████████████████████████| 2500/2500 [01:15<00:00, 33.13batch/s]


Val Loss: 4.7786, Perplexity: 118.9434

Epoch 15/20


Training: 100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [01:18<00:00, 12.75batch/s, loss=4.2224]


Epoch 15
Train Loss: 4.7578
Validating...


Validation: 100%|█████████████████████████████████████████████████████████████████████████████████| 2500/2500 [01:15<00:00, 33.13batch/s]


Val Loss: 4.6886, Perplexity: 108.7026

Epoch 16/20


Training: 100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [01:18<00:00, 12.73batch/s, loss=3.3717]


Epoch 16
Train Loss: 4.6749
Validating...


Validation: 100%|█████████████████████████████████████████████████████████████████████████████████| 2500/2500 [01:15<00:00, 33.11batch/s]


Val Loss: 4.6043, Perplexity: 99.9104

Epoch 17/20


Training: 100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [01:18<00:00, 12.73batch/s, loss=5.4247]


Epoch 17
Train Loss: 4.5966
Validating...


Validation: 100%|█████████████████████████████████████████████████████████████████████████████████| 2500/2500 [01:15<00:00, 33.16batch/s]


Val Loss: 4.5342, Perplexity: 93.1462

Epoch 18/20


Training: 100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [01:18<00:00, 12.74batch/s, loss=2.1768]


Epoch 18
Train Loss: 4.4918
Validating...


Validation: 100%|█████████████████████████████████████████████████████████████████████████████████| 2500/2500 [01:15<00:00, 33.16batch/s]


Val Loss: 4.4962, Perplexity: 89.6743

Epoch 19/20


Training: 100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [01:18<00:00, 12.71batch/s, loss=2.2873]


Epoch 19
Train Loss: 4.5032
Validating...


Validation: 100%|█████████████████████████████████████████████████████████████████████████████████| 2500/2500 [01:15<00:00, 33.26batch/s]


Val Loss: 4.4579, Perplexity: 86.3080

Epoch 20/20


Training: 100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [01:18<00:00, 12.75batch/s, loss=5.3701]


Epoch 20
Train Loss: 4.4345
Validating...


Validation: 100%|█████████████████████████████████████████████████████████████████████████████████| 2500/2500 [01:15<00:00, 33.18batch/s]


Val Loss: 4.4345, Perplexity: 84.3070

 Total Training + Validation Time: 56 min 39 sec
