# ByteNet Character Prediction Model & Experiment
This task uses the Hutter Prize version of  the Wikipedia dataset to test

In [1]:
!pip install datasets



In [2]:
# Define Imports
import json

import requests
import torch
import pickle
import torch.utils.data as data
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
from datasets import load_dataset
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from transformers import AutoTokenizer


In [3]:
from datasets import load_dataset
dataset = load_dataset('enwik8')['train']

In [4]:
from bs4 import BeautifulSoup

def remove_xml_tags(example):
    soup = BeautifulSoup(example['text'], 'lxml')
    text_content = soup.get_text()
    return {'text': text_content}

def clean_html(example):
    import re
    from bs4 import BeautifulSoup
    text = example['text'] 
    soup = BeautifulSoup(text, "html.parser")
    text = soup.get_text(separator=" ")
    text = re.sub(r"http\S+|www\S+|https\S+", '', text, flags=re.MULTILINE)
    text = re.sub(r'\[\[|\]\]|\'\'', '', text)
    return {'text': text}


In [5]:
filtered_dataset = dataset.map(remove_xml_tags)
filtered_dataset = filtered_dataset.map(clean_html)

In [6]:
print(filtered_dataset[:220])

{'text': ['', '', 'Wikipedia', '', 'MediaWiki 1.6alpha', 'first-letter', '', 'Media', 'Special', '', 'Talk', 'User', 'User talk', 'Wikipedia', 'Wikipedia talk', 'Image', 'Image talk', 'MediaWiki', 'MediaWiki talk', 'Template', 'Template talk', 'Help', 'Help talk', 'Category', 'Category talk', 'Portal', 'Portal talk', '', '', '', 'AaA', '1', '', '32899315', '2005-12-27T18:46:47Z', '', 'Jsmethers', '614213', '', '#REDIRECT AAA', '', '', '', 'AlgeriA', '5', '', '18063769', '2005-07-03T11:13:13Z', '', 'Docu', '8029', '', '', 'adding cur_id=5: {{R from CamelCase}}', '#REDIRECT Algeria{{R from CamelCase}}', '', '', '', 'AmericanSamoa', '6', '', '18063795', '2005-07-03T11:14:17Z', '', 'Docu', '8029', '', '', 'adding to cur_id=6  {{R from CamelCase}}', '#REDIRECT American Samoa{{R from CamelCase}}', '', '', '', 'AppliedEthics', '8', '', '15898943', '2002-02-25T15:43:11Z', '', 'Conversion script', '', '', 'Automated conversion', '#REDIRECT Applied ethics', '', '', '', '', 'AccessibleComputing',

In [7]:
def join_to_length(data, target_length):
    joined_data = []
    temp_str = ''
    for entry in data:
        temp_str += entry
        if len(temp_str) >= target_length:
            joined_data.append(temp_str[:target_length])
            temp_str = temp_str[target_length:]
    if temp_str:  # add remaining string if it's not empty
        joined_data.append(temp_str)
    return joined_data

In [8]:
train_dataset = join_to_length(filtered_dataset['text'],500)

In [9]:
print(train_dataset[:150])

['WikipediaMediaWiki 1.6alphafirst-letterMediaSpecialTalkUserUser talkWikipediaWikipedia talkImageImage talkMediaWikiMediaWiki talkTemplateTemplate talkHelpHelp talkCategoryCategory talkPortalPortal talkAaA1328993152005-12-27T18:46:47ZJsmethers614213#REDIRECT AAAAlgeriA5180637692005-07-03T11:13:13ZDocu8029adding cur_id=5: {{R from CamelCase}}#REDIRECT Algeria{{R from CamelCase}}AmericanSamoa6180637952005-07-03T11:14:17ZDocu8029adding to cur_id=6  {{R from CamelCase}}#REDIRECT American Samoa{{R fro', "m CamelCase}}AppliedEthics8158989432002-02-25T15:43:11ZConversion scriptAutomated conversion#REDIRECT Applied ethicsAccessibleComputing10158989452003-04-25T22:18:38ZAms807543Fixing redirect#REDIRECT Accessible_computingAdA11158989462002-09-22T16:02:58ZAndre Engels300#REDIRECT Ada programming languageAnarchism12421368312006-03-04T01:41:25ZCJames745832382/* Anarchist Communism */  too many brackets{{Anarchism}}'Anarchism' originated as a term of abuse first used against early working class r

In [10]:
print(len(train_dataset))

170085


In [11]:
train_split = int(0.9 * len(train_dataset))
validation_split = int(0.95 * len(train_dataset))
test_split = len(train_dataset)

In [12]:
# Takes too long for full dataset
train_data = train_dataset[:int(0.4*train_split)]
validation_data = train_dataset[train_split:validation_split]
test_data = train_dataset[validation_split:test_split]

In [13]:
print(len(train_data),len(validation_data),len(test_data))

61230 8504 8505


In [14]:
tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")

In [15]:
def tokenize_item(item):
    src = tokenizer(item[:100], max_length=100, padding="max_length", truncation=True, return_tensors="pt")
    trgt = tokenizer(item[100:], max_length=400, padding="max_length", truncation=True, return_tensors="pt")
    return src['input_ids'],trgt['input_ids']

In [16]:
tokenize_item(train_data[1])

(tensor([[112,  35,  70, 100, 112, 104, 111,  70, 100, 118, 104, 128, 128,  68,
          115, 115, 111, 108, 104, 103,  72, 119, 107, 108, 102, 118,  59,  52,
           56,  59,  60,  59,  60,  55,  54,  53,  51,  51,  53,  48,  51,  53,
           48,  53,  56,  87,  52,  56,  61,  55,  54,  61,  52,  52,  93,  70,
          114, 113, 121, 104, 117, 118, 108, 114, 113,  35, 118, 102, 117, 108,
          115, 119,  68, 120, 119, 114, 112, 100, 119, 104, 103,  35, 102, 114,
          113, 121, 104, 117, 118, 108, 114, 113,  38,  85,  72,  71,  76,  85,
           72,   1]]),
 tensor([[ 87,  35,  68, 115, 115, 111, 108, 104, 103,  35, 104, 119, 107, 108,
          102, 118,  68, 102, 102, 104, 118, 118, 108, 101, 111, 104,  70, 114,
          112, 115, 120, 119, 108, 113, 106,  52,  51,  52,  56,  59,  60,  59,
           60,  55,  56,  53,  51,  51,  54,  48,  51,  55,  48,  53,  56,  87,
           53,  53,  61,  52,  59,  61,  54,  59,  93,  68, 112, 118,  59,  51,
           58,  5

In [17]:
# Imports for ByteNet
import torch
from torch.utils.data import Dataset
from torch import nn
from torch.nn import init
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

In [18]:
class DynamicLayerNorm(nn.Module):
    def __init__(self):
        super(DynamicLayerNorm, self).__init__()
        self.norm = None

    def forward(self, x):
        if self.norm is None or self.norm.normalized_shape[0] != x.shape[2]:
            self.norm = nn.LayerNorm(x.shape[2]).to(x.device)
        return self.norm(x)

In [19]:
class MultiplicativeUnit(nn.Module):
    def __init__(self, d, dilation, k, masked = True):
        super(MultiplicativeUnit, self).__init__()
        self.receptive_field = (k-1)*dilation
        self.masked = masked
        self.sig_conv_1 = nn.Conv1d(d, d, k, dilation = dilation)
        self.sig_conv_2 = nn.Conv1d(d, d, k, dilation = dilation)
        self.sig_conv_3 = nn.Conv1d(d, d, k, dilation = dilation)
        self.tanh_conv_1 = nn.Conv1d(d, d, k, dilation = dilation)

        # 100 input tokens
        self.layer_norm1 = DynamicLayerNorm()
        self.layer_norm2 = DynamicLayerNorm()
        self.layer_norm3 = DynamicLayerNorm()
        self.layer_norm4 = DynamicLayerNorm()
        
    def forward(self,x):
        residual = x
        # Multiplicative Unit Block
        if self.receptive_field > 0 and self.masked:
            x = F.pad(x, (self.receptive_field, 0))
        sig_1 = torch.sigmoid(self.layer_norm1(self.sig_conv_1(x)))
        sig_2 = torch.sigmoid(self.layer_norm2(self.sig_conv_2(x)))
        sig_3 = torch.sigmoid(self.layer_norm3(self.sig_conv_3(x)))
        tanh_1 = torch.tanh(self.layer_norm4(self.tanh_conv_1(x)))
        sig_tanh = sig_2 * tanh_1
        residual_sig = sig_3 * residual
        mu_out = sig_1 * torch.tanh(sig_tanh+residual_sig)
        # Multiplicative unit block end
        # Add back the residual
        mu_out = mu_out * residual
        return mu_out


In [20]:
class ResidualBlockMu(nn.Module):
    """
    Implementation of residual Layer for Bytenet machine translation task.

    :param d: The number of input features.
    :param dilation: The initial dilation rate for the convolution layers.
    """

    def __init__(self, d, dilation, k):
        super(ResidualBlockMu, self).__init__()
        self.layer_norm_in = DynamicLayerNorm()
        self.conv_in_1 = nn.Conv1d(2*d,d,1)
        self.layer_norm_in_2 = DynamicLayerNorm()
        # Multiplicative block
        # d -> d
        self.masked_mu = MultiplicativeUnit(d,dilation,k)
        self.mu = MultiplicativeUnit(d,1,1, masked=False)
        self.conv_out = nn.Conv1d(d,2*d,1)

    def forward(self, x):
        residual = x
        x = self.layer_norm_in(x)
        x = torch.nn.ReLU()(x)
        x = self.conv_in_1(x)
        x = self.layer_norm_in_2(x)
        x = torch.nn.ReLU()(x)
        x = self.masked_mu(x)
        x = self.mu(x)
        x = self.conv_out(x)
        x = x + residual
        return x


In [21]:
class HutterDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        src, trgt = tokenize_item(item)
        return src, trgt


In [22]:
train_dataset = HutterDataset(train_data)
validation_dataset = HutterDataset(validation_data)
test_dataset = HutterDataset(test_data)

In [23]:
batch_size = 8
shuffle = True

In [24]:
train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=shuffle)
validation_loader = DataLoader(validation_dataset,batch_size=batch_size,shuffle=shuffle)
test_loader = DataLoader(test_dataset,batch_size=batch_size,shuffle=shuffle)

In [25]:
class ByteNetDecoder(nn.Module):
    def __init__(self, vocab_size, embed_size=1024, d=512, n_sets = 4, set_size = 5, masked_kernel_size=3, max_dilation_rate=16):
        super(ByteNetDecoder, self).__init__()

        self.embedding = nn.Embedding(vocab_size,embed_size)
        self.layers = nn.Sequential()

        for _ in range(n_sets):
            dilation_rate = 1
            for _ in range(set_size):
                self.layers.append(ResidualBlockMu(d,dilation_rate if dilation_rate <= max_dilation_rate else max_dilation_rate, masked_kernel_size))
                dilation_rate = dilation_rate * 2
        self.layers.append(nn.Conv1d(d * 2, d, 1))
        self.layers.append(nn.ReLU())
        self.layers.append(nn.Conv1d(d, vocab_size, 1))
        self.layers.append(nn.Dropout(p=0.1))
    def forward(self,x):
        embed_x = self.embedding(x).permute(0, 2, 1)
        x = self.layers(embed_x)
        return x

In [26]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
decoder = ByteNetDecoder(len(tokenizer.get_vocab())).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(decoder.parameters(),lr = 0.0003, weight_decay=0.0001)

In [27]:
num_epochs = 2

In [28]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

2024-06-26 15:35:07.423308: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-06-26 15:35:07.470022: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-26 15:35:07.470088: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-26 15:35:07.472729: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-06-26 15:35:07.485689: I tensorflow/core/platform/cpu_feature_guar

In [29]:
def unfold(inp, decoder):
    x = inp
    out = None
    # 100 -> 200 -> 400 output characters
    for i in range(4):
        # 100 as input -> 200 as input -> 400 as input
        out = decoder(x)
        next_gen = torch.argmax(out,1)
        # 200 characters  ( 100 predicted, 100 given as input) -> (300 predicted, 100 input)
        x =  torch.cat((x,next_gen), dim = 1)
    return out[:, :, out.size(2)//2:]

In [None]:
for epoch in range(num_epochs):
    for i, (ip,trgt) in tqdm(enumerate(train_loader), total=len(train_loader)):
        decoder.train()
        ip, trgt = ip.squeeze(1).to(device), trgt.squeeze(1).to(device)
        out = unfold(ip,decoder)
        loss = criterion(out,trgt)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        step = (i + 1)/len(train_loader)
        if i % 25 == 0:
            writer.add_scalar("Loss/train",loss,step)
            tqdm.write(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item()}')
        if i % 2000 == 0 and i > 0:
            decoder.eval()
            total_val_loss = 0
            with torch.no_grad():
                for i, (ip, trgt) in enumerate(validation_loader):
                    ip, trgt = ip.squeeze(1).to(device), trgt.squeeze(1).to(device)
                    out = unfold(ip,decoder)
                    val_loss = criterion(out, trgt)
                    total_val_loss += val_loss.item()
        
            avg_val_loss = total_val_loss / len(validation_loader)
            writer.add_scalar("Loss/val", avg_val_loss, step)
            tqdm.write(f'Epoch [{epoch + 1}/{num_epochs}], , Step [{i + 1}/{len(train_loader)}], Val Loss: {avg_val_loss}')

  0%|          | 1/7654 [00:02<6:11:05,  2.91s/it]

Epoch [1/2], Step [1/7654], Loss: 5.932135105133057


  0%|          | 26/7654 [00:50<4:20:50,  2.05s/it]

Epoch [1/2], Step [26/7654], Loss: 3.7533555030822754


  1%|          | 51/7654 [01:39<4:20:57,  2.06s/it]

Epoch [1/2], Step [51/7654], Loss: 3.7070529460906982


  1%|          | 76/7654 [02:27<4:20:43,  2.06s/it]

Epoch [1/2], Step [76/7654], Loss: 3.659820556640625


  1%|▏         | 101/7654 [03:15<4:20:24,  2.07s/it]

Epoch [1/2], Step [101/7654], Loss: 3.6643054485321045


  2%|▏         | 126/7654 [04:04<4:19:31,  2.07s/it]

Epoch [1/2], Step [126/7654], Loss: 3.638408899307251


  2%|▏         | 151/7654 [04:52<4:18:41,  2.07s/it]

Epoch [1/2], Step [151/7654], Loss: 3.5787863731384277


  2%|▏         | 176/7654 [05:40<4:18:53,  2.08s/it]

Epoch [1/2], Step [176/7654], Loss: 3.7078351974487305


  3%|▎         | 201/7654 [06:30<4:22:07,  2.11s/it]

Epoch [1/2], Step [201/7654], Loss: 3.601872444152832


  3%|▎         | 226/7654 [07:19<4:24:03,  2.13s/it]

Epoch [1/2], Step [226/7654], Loss: 3.6512341499328613


  3%|▎         | 251/7654 [08:09<4:22:18,  2.13s/it]

Epoch [1/2], Step [251/7654], Loss: 3.7560064792633057


  4%|▎         | 276/7654 [08:59<4:20:45,  2.12s/it]

Epoch [1/2], Step [276/7654], Loss: 3.7709479331970215


  4%|▍         | 301/7654 [09:48<4:17:32,  2.10s/it]

Epoch [1/2], Step [301/7654], Loss: 3.6606128215789795


  4%|▍         | 326/7654 [10:37<4:16:08,  2.10s/it]

Epoch [1/2], Step [326/7654], Loss: 3.739375114440918


  5%|▍         | 351/7654 [11:26<4:15:19,  2.10s/it]

Epoch [1/2], Step [351/7654], Loss: 3.5430426597595215


  5%|▍         | 376/7654 [12:15<4:13:03,  2.09s/it]

Epoch [1/2], Step [376/7654], Loss: 3.852949857711792


  5%|▌         | 401/7654 [13:04<4:11:51,  2.08s/it]

Epoch [1/2], Step [401/7654], Loss: 3.788132429122925


  6%|▌         | 426/7654 [13:53<4:14:00,  2.11s/it]

Epoch [1/2], Step [426/7654], Loss: 3.61944580078125


  6%|▌         | 451/7654 [14:41<4:09:49,  2.08s/it]

Epoch [1/2], Step [451/7654], Loss: 3.7573187351226807


  6%|▌         | 476/7654 [15:30<4:09:24,  2.08s/it]

Epoch [1/2], Step [476/7654], Loss: 3.508157968521118


  7%|▋         | 501/7654 [16:19<4:13:07,  2.12s/it]

Epoch [1/2], Step [501/7654], Loss: 3.946018695831299


  7%|▋         | 526/7654 [17:08<4:07:54,  2.09s/it]

Epoch [1/2], Step [526/7654], Loss: 3.67533540725708


  7%|▋         | 551/7654 [17:56<4:06:13,  2.08s/it]

Epoch [1/2], Step [551/7654], Loss: 3.57808256149292


  8%|▊         | 576/7654 [18:45<4:14:24,  2.16s/it]

Epoch [1/2], Step [576/7654], Loss: 3.5392262935638428


  8%|▊         | 601/7654 [19:34<4:04:23,  2.08s/it]

Epoch [1/2], Step [601/7654], Loss: 3.5827796459198


  8%|▊         | 626/7654 [20:22<4:03:29,  2.08s/it]

Epoch [1/2], Step [626/7654], Loss: 3.5759973526000977


  9%|▊         | 651/7654 [21:11<4:02:13,  2.08s/it]

Epoch [1/2], Step [651/7654], Loss: 3.7197155952453613


  9%|▉         | 676/7654 [22:00<4:01:35,  2.08s/it]

Epoch [1/2], Step [676/7654], Loss: 3.589160203933716


  9%|▉         | 701/7654 [22:48<4:00:43,  2.08s/it]

Epoch [1/2], Step [701/7654], Loss: 3.510446786880493


  9%|▉         | 726/7654 [23:37<3:59:54,  2.08s/it]

Epoch [1/2], Step [726/7654], Loss: 3.4395503997802734


 10%|▉         | 751/7654 [24:26<3:58:58,  2.08s/it]

Epoch [1/2], Step [751/7654], Loss: 3.5004770755767822


 10%|█         | 776/7654 [25:14<3:58:28,  2.08s/it]

Epoch [1/2], Step [776/7654], Loss: 3.5394339561462402


 10%|█         | 801/7654 [26:03<3:57:35,  2.08s/it]

Epoch [1/2], Step [801/7654], Loss: 3.5731465816497803


 11%|█         | 826/7654 [26:52<3:55:59,  2.07s/it]

Epoch [1/2], Step [826/7654], Loss: 3.841878890991211


 11%|█         | 851/7654 [27:40<3:55:56,  2.08s/it]

Epoch [1/2], Step [851/7654], Loss: 3.5975098609924316


 11%|█▏        | 876/7654 [28:29<3:54:43,  2.08s/it]

Epoch [1/2], Step [876/7654], Loss: 3.713879346847534


 12%|█▏        | 901/7654 [29:18<3:53:53,  2.08s/it]

Epoch [1/2], Step [901/7654], Loss: 3.5459067821502686


 12%|█▏        | 926/7654 [30:06<3:52:50,  2.08s/it]

Epoch [1/2], Step [926/7654], Loss: 3.657768964767456


 12%|█▏        | 951/7654 [30:55<3:52:32,  2.08s/it]

Epoch [1/2], Step [951/7654], Loss: 3.697028875350952


 13%|█▎        | 976/7654 [31:44<3:51:19,  2.08s/it]

Epoch [1/2], Step [976/7654], Loss: 3.7389965057373047


 13%|█▎        | 1001/7654 [32:32<3:50:20,  2.08s/it]

Epoch [1/2], Step [1001/7654], Loss: 3.4974594116210938


 13%|█▎        | 1026/7654 [33:21<3:49:29,  2.08s/it]

Epoch [1/2], Step [1026/7654], Loss: 3.602097749710083


 14%|█▎        | 1051/7654 [34:09<3:48:37,  2.08s/it]

Epoch [1/2], Step [1051/7654], Loss: 3.525632619857788


 14%|█▍        | 1076/7654 [34:58<3:47:46,  2.08s/it]

Epoch [1/2], Step [1076/7654], Loss: 3.7521305084228516


 14%|█▍        | 1101/7654 [35:46<3:46:43,  2.08s/it]

Epoch [1/2], Step [1101/7654], Loss: 3.638521194458008


 15%|█▍        | 1126/7654 [36:35<3:46:08,  2.08s/it]

Epoch [1/2], Step [1126/7654], Loss: 3.4474751949310303


 15%|█▌        | 1151/7654 [37:24<3:45:03,  2.08s/it]

Epoch [1/2], Step [1151/7654], Loss: 3.6853103637695312


 15%|█▌        | 1176/7654 [38:12<3:44:15,  2.08s/it]

Epoch [1/2], Step [1176/7654], Loss: 3.5008344650268555


 16%|█▌        | 1201/7654 [39:01<3:43:29,  2.08s/it]

Epoch [1/2], Step [1201/7654], Loss: 3.7366018295288086


 16%|█▌        | 1226/7654 [39:50<3:42:43,  2.08s/it]

Epoch [1/2], Step [1226/7654], Loss: 3.557490348815918


 16%|█▋        | 1251/7654 [40:38<3:41:25,  2.07s/it]

Epoch [1/2], Step [1251/7654], Loss: 3.5267717838287354


 17%|█▋        | 1276/7654 [41:27<3:41:01,  2.08s/it]

Epoch [1/2], Step [1276/7654], Loss: 3.5204246044158936


 17%|█▋        | 1301/7654 [42:15<3:39:30,  2.07s/it]

Epoch [1/2], Step [1301/7654], Loss: 3.8157970905303955


 17%|█▋        | 1326/7654 [43:04<3:38:50,  2.07s/it]

Epoch [1/2], Step [1326/7654], Loss: 3.7049224376678467


 18%|█▊        | 1351/7654 [43:52<3:39:27,  2.09s/it]

Epoch [1/2], Step [1351/7654], Loss: 3.518401861190796


 18%|█▊        | 1376/7654 [44:41<3:37:44,  2.08s/it]

Epoch [1/2], Step [1376/7654], Loss: 3.593738317489624


 18%|█▊        | 1401/7654 [45:29<3:35:56,  2.07s/it]

Epoch [1/2], Step [1401/7654], Loss: 3.5179080963134766


 19%|█▊        | 1426/7654 [46:18<3:37:59,  2.10s/it]

Epoch [1/2], Step [1426/7654], Loss: 3.8184409141540527


 19%|█▉        | 1451/7654 [47:06<3:34:15,  2.07s/it]

Epoch [1/2], Step [1451/7654], Loss: 3.798609733581543


 19%|█▉        | 1476/7654 [47:55<3:33:34,  2.07s/it]

Epoch [1/2], Step [1476/7654], Loss: 3.4416539669036865


 20%|█▉        | 1501/7654 [48:44<3:37:56,  2.13s/it]

Epoch [1/2], Step [1501/7654], Loss: 3.6953978538513184


 20%|█▉        | 1526/7654 [49:32<3:31:35,  2.07s/it]

Epoch [1/2], Step [1526/7654], Loss: 3.6603734493255615


 20%|██        | 1551/7654 [50:20<3:30:53,  2.07s/it]

Epoch [1/2], Step [1551/7654], Loss: 3.786489248275757


 21%|██        | 1576/7654 [51:09<3:41:22,  2.19s/it]

Epoch [1/2], Step [1576/7654], Loss: 3.5681729316711426


 21%|██        | 1601/7654 [51:57<3:28:30,  2.07s/it]

Epoch [1/2], Step [1601/7654], Loss: 3.566453218460083


 21%|██        | 1626/7654 [52:46<3:28:08,  2.07s/it]

Epoch [1/2], Step [1626/7654], Loss: 3.6062841415405273


 22%|██▏       | 1651/7654 [53:34<3:27:05,  2.07s/it]

Epoch [1/2], Step [1651/7654], Loss: 3.5497817993164062


 22%|██▏       | 1676/7654 [54:23<3:26:01,  2.07s/it]

Epoch [1/2], Step [1676/7654], Loss: 3.540512800216675


 22%|██▏       | 1701/7654 [55:11<3:25:31,  2.07s/it]

Epoch [1/2], Step [1701/7654], Loss: 3.480771064758301


 23%|██▎       | 1726/7654 [56:00<3:24:34,  2.07s/it]

Epoch [1/2], Step [1726/7654], Loss: 3.6797497272491455


 23%|██▎       | 1751/7654 [56:48<3:23:54,  2.07s/it]

Epoch [1/2], Step [1751/7654], Loss: 3.5935850143432617


 23%|██▎       | 1776/7654 [57:37<3:23:05,  2.07s/it]

Epoch [1/2], Step [1776/7654], Loss: 3.551464319229126


 24%|██▎       | 1801/7654 [58:25<3:22:05,  2.07s/it]

Epoch [1/2], Step [1801/7654], Loss: 3.517752170562744


 24%|██▍       | 1826/7654 [59:14<3:20:43,  2.07s/it]

Epoch [1/2], Step [1826/7654], Loss: 3.665013313293457


 24%|██▍       | 1851/7654 [1:00:02<3:20:32,  2.07s/it]

Epoch [1/2], Step [1851/7654], Loss: 3.762255907058716


 25%|██▍       | 1876/7654 [1:00:50<3:19:06,  2.07s/it]

Epoch [1/2], Step [1876/7654], Loss: 3.521975040435791


 25%|██▍       | 1901/7654 [1:01:39<3:18:20,  2.07s/it]

Epoch [1/2], Step [1901/7654], Loss: 3.580029249191284


 25%|██▌       | 1926/7654 [1:02:27<3:17:22,  2.07s/it]

Epoch [1/2], Step [1926/7654], Loss: 3.776118516921997


 25%|██▌       | 1951/7654 [1:03:16<3:16:47,  2.07s/it]

Epoch [1/2], Step [1951/7654], Loss: 3.4332435131073


 26%|██▌       | 1976/7654 [1:04:04<3:16:07,  2.07s/it]

Epoch [1/2], Step [1976/7654], Loss: 3.6815521717071533


 26%|██▌       | 2000/7654 [1:04:53<3:01:44,  1.93s/it]

Epoch [1/2], Step [2001/7654], Loss: 3.9431769847869873


 26%|██▌       | 2001/7654 [1:21:07<462:10:23, 294.33s/it]

Epoch [1/2], , Step [1063/7654], Val Loss: 3.3702664433620164


 26%|██▋       | 2026/7654 [1:21:55<3:17:50,  2.11s/it]   

Epoch [1/2], Step [2026/7654], Loss: 3.64711856842041


 27%|██▋       | 2051/7654 [1:22:44<3:13:31,  2.07s/it]

Epoch [1/2], Step [2051/7654], Loss: 3.683884382247925


 27%|██▋       | 2076/7654 [1:23:32<3:12:32,  2.07s/it]

Epoch [1/2], Step [2076/7654], Loss: 3.7578213214874268


 27%|██▋       | 2101/7654 [1:24:21<3:11:39,  2.07s/it]

Epoch [1/2], Step [2101/7654], Loss: 3.392252206802368


 28%|██▊       | 2126/7654 [1:25:09<3:10:56,  2.07s/it]

Epoch [1/2], Step [2126/7654], Loss: 3.5032544136047363


 28%|██▊       | 2151/7654 [1:25:58<3:09:54,  2.07s/it]

Epoch [1/2], Step [2151/7654], Loss: 3.6550564765930176


 28%|██▊       | 2176/7654 [1:26:46<3:09:06,  2.07s/it]

Epoch [1/2], Step [2176/7654], Loss: 3.581892967224121


 29%|██▉       | 2201/7654 [1:27:35<3:08:42,  2.08s/it]

Epoch [1/2], Step [2201/7654], Loss: 3.5744190216064453


 29%|██▉       | 2226/7654 [1:28:23<3:07:13,  2.07s/it]

Epoch [1/2], Step [2226/7654], Loss: 3.641709089279175


 29%|██▉       | 2251/7654 [1:29:12<3:06:03,  2.07s/it]

Epoch [1/2], Step [2251/7654], Loss: 3.549344062805176


 30%|██▉       | 2276/7654 [1:30:00<3:06:05,  2.08s/it]

Epoch [1/2], Step [2276/7654], Loss: 3.6602747440338135


 30%|███       | 2301/7654 [1:30:49<3:04:36,  2.07s/it]

Epoch [1/2], Step [2301/7654], Loss: 3.634732961654663


 30%|███       | 2326/7654 [1:31:37<3:03:53,  2.07s/it]

Epoch [1/2], Step [2326/7654], Loss: 3.6550798416137695


 31%|███       | 2351/7654 [1:32:26<3:04:48,  2.09s/it]

Epoch [1/2], Step [2351/7654], Loss: 3.8313000202178955


 31%|███       | 2376/7654 [1:33:14<3:02:13,  2.07s/it]

Epoch [1/2], Step [2376/7654], Loss: 3.588810920715332


 31%|███▏      | 2401/7654 [1:34:02<3:01:07,  2.07s/it]

Epoch [1/2], Step [2401/7654], Loss: 3.5711135864257812


 32%|███▏      | 2426/7654 [1:34:51<3:04:30,  2.12s/it]

Epoch [1/2], Step [2426/7654], Loss: 3.5667765140533447


 32%|███▏      | 2451/7654 [1:35:40<2:59:23,  2.07s/it]

Epoch [1/2], Step [2451/7654], Loss: 3.484030246734619


 32%|███▏      | 2476/7654 [1:36:28<2:58:33,  2.07s/it]

Epoch [1/2], Step [2476/7654], Loss: 3.457796573638916


 33%|███▎      | 2501/7654 [1:37:17<3:04:28,  2.15s/it]

Epoch [1/2], Step [2501/7654], Loss: 3.7230377197265625


 33%|███▎      | 2526/7654 [1:38:05<2:56:32,  2.07s/it]

Epoch [1/2], Step [2526/7654], Loss: 3.5808963775634766


 33%|███▎      | 2551/7654 [1:38:53<2:55:56,  2.07s/it]

Epoch [1/2], Step [2551/7654], Loss: 3.7430579662323


 34%|███▎      | 2576/7654 [1:39:41<2:55:05,  2.07s/it]

Epoch [1/2], Step [2576/7654], Loss: 3.6157970428466797


 34%|███▍      | 2601/7654 [1:40:30<2:53:56,  2.07s/it]

Epoch [1/2], Step [2601/7654], Loss: 3.7458603382110596


 34%|███▍      | 2626/7654 [1:41:18<2:53:15,  2.07s/it]

Epoch [1/2], Step [2626/7654], Loss: 3.7040939331054688


 35%|███▍      | 2651/7654 [1:42:07<2:52:34,  2.07s/it]

Epoch [1/2], Step [2651/7654], Loss: 3.554974317550659


 35%|███▍      | 2676/7654 [1:42:55<2:51:31,  2.07s/it]

Epoch [1/2], Step [2676/7654], Loss: 3.597804546356201


 35%|███▌      | 2701/7654 [1:43:44<2:50:37,  2.07s/it]

Epoch [1/2], Step [2701/7654], Loss: 3.7460150718688965


 36%|███▌      | 2726/7654 [1:44:32<2:49:50,  2.07s/it]

Epoch [1/2], Step [2726/7654], Loss: 3.4327995777130127


 36%|███▌      | 2751/7654 [1:45:21<2:48:36,  2.06s/it]

Epoch [1/2], Step [2751/7654], Loss: 3.476978063583374


 36%|███▋      | 2776/7654 [1:46:09<2:48:21,  2.07s/it]

Epoch [1/2], Step [2776/7654], Loss: 3.6380536556243896


 37%|███▋      | 2801/7654 [1:46:57<2:47:34,  2.07s/it]

Epoch [1/2], Step [2801/7654], Loss: 3.6405038833618164


 37%|███▋      | 2826/7654 [1:47:46<2:46:41,  2.07s/it]

Epoch [1/2], Step [2826/7654], Loss: 3.7385709285736084


 37%|███▋      | 2851/7654 [1:48:34<2:45:28,  2.07s/it]

Epoch [1/2], Step [2851/7654], Loss: 3.6369552612304688


 38%|███▊      | 2876/7654 [1:49:22<2:44:47,  2.07s/it]

Epoch [1/2], Step [2876/7654], Loss: 3.699734687805176


 38%|███▊      | 2901/7654 [1:50:11<2:44:03,  2.07s/it]

Epoch [1/2], Step [2901/7654], Loss: 3.650113105773926


 38%|███▊      | 2926/7654 [1:50:59<2:42:53,  2.07s/it]

Epoch [1/2], Step [2926/7654], Loss: 3.5730254650115967


 39%|███▊      | 2951/7654 [1:51:48<2:42:15,  2.07s/it]

Epoch [1/2], Step [2951/7654], Loss: 3.4884274005889893


 39%|███▉      | 2976/7654 [1:52:37<2:41:28,  2.07s/it]

Epoch [1/2], Step [2976/7654], Loss: 3.680771589279175


 39%|███▉      | 3001/7654 [1:53:25<2:41:08,  2.08s/it]

Epoch [1/2], Step [3001/7654], Loss: 3.5326333045959473


 40%|███▉      | 3026/7654 [1:54:13<2:39:13,  2.06s/it]

Epoch [1/2], Step [3026/7654], Loss: 3.585087299346924


 40%|███▉      | 3051/7654 [1:55:02<2:38:41,  2.07s/it]

Epoch [1/2], Step [3051/7654], Loss: 3.8108696937561035


 40%|████      | 3076/7654 [1:55:50<2:37:38,  2.07s/it]

Epoch [1/2], Step [3076/7654], Loss: 3.5229666233062744


 41%|████      | 3101/7654 [1:56:38<2:36:20,  2.06s/it]

Epoch [1/2], Step [3101/7654], Loss: 3.583500385284424


 41%|████      | 3126/7654 [1:57:27<2:36:39,  2.08s/it]

Epoch [1/2], Step [3126/7654], Loss: 3.709407329559326


 41%|████      | 3151/7654 [1:58:15<2:35:03,  2.07s/it]

Epoch [1/2], Step [3151/7654], Loss: 3.615478754043579


 41%|████▏     | 3176/7654 [1:59:03<2:34:16,  2.07s/it]

Epoch [1/2], Step [3176/7654], Loss: 3.6107654571533203


 42%|████▏     | 3201/7654 [1:59:52<2:33:50,  2.07s/it]

Epoch [1/2], Step [3201/7654], Loss: 3.654075860977173


 42%|████▏     | 3226/7654 [2:00:40<2:32:42,  2.07s/it]

Epoch [1/2], Step [3226/7654], Loss: 3.642332077026367


 42%|████▏     | 3251/7654 [2:01:29<2:32:21,  2.08s/it]

Epoch [1/2], Step [3251/7654], Loss: 3.715812921524048


 43%|████▎     | 3276/7654 [2:02:17<2:31:25,  2.08s/it]

Epoch [1/2], Step [3276/7654], Loss: 3.5943970680236816


 43%|████▎     | 3301/7654 [2:03:06<2:30:03,  2.07s/it]

Epoch [1/2], Step [3301/7654], Loss: 3.692753791809082


 43%|████▎     | 3326/7654 [2:03:54<2:29:09,  2.07s/it]

Epoch [1/2], Step [3326/7654], Loss: 3.652444839477539


 44%|████▍     | 3351/7654 [2:04:43<2:30:17,  2.10s/it]

Epoch [1/2], Step [3351/7654], Loss: 3.4484405517578125


 44%|████▍     | 3376/7654 [2:05:31<2:27:58,  2.08s/it]

Epoch [1/2], Step [3376/7654], Loss: 3.4691333770751953


 44%|████▍     | 3401/7654 [2:06:19<2:26:52,  2.07s/it]

Epoch [1/2], Step [3401/7654], Loss: 3.6426258087158203


 45%|████▍     | 3426/7654 [2:07:08<2:29:52,  2.13s/it]

Epoch [1/2], Step [3426/7654], Loss: 3.578839063644409


 45%|████▌     | 3451/7654 [2:07:57<2:25:07,  2.07s/it]

Epoch [1/2], Step [3451/7654], Loss: 3.6163530349731445


 45%|████▌     | 3476/7654 [2:08:45<2:24:34,  2.08s/it]

Epoch [1/2], Step [3476/7654], Loss: 3.579988479614258


 46%|████▌     | 3501/7654 [2:09:34<2:31:33,  2.19s/it]

Epoch [1/2], Step [3501/7654], Loss: 3.5897836685180664


 46%|████▌     | 3526/7654 [2:10:22<2:22:55,  2.08s/it]

Epoch [1/2], Step [3526/7654], Loss: 3.4959588050842285


 46%|████▋     | 3551/7654 [2:11:11<2:21:33,  2.07s/it]

Epoch [1/2], Step [3551/7654], Loss: 3.729555368423462


 47%|████▋     | 3576/7654 [2:11:59<2:20:51,  2.07s/it]

Epoch [1/2], Step [3576/7654], Loss: 3.6102712154388428


 47%|████▋     | 3601/7654 [2:12:48<2:19:45,  2.07s/it]

Epoch [1/2], Step [3601/7654], Loss: 3.6181600093841553


 47%|████▋     | 3626/7654 [2:13:36<2:19:20,  2.08s/it]

Epoch [1/2], Step [3626/7654], Loss: 3.6468374729156494


 48%|████▊     | 3651/7654 [2:14:24<2:18:15,  2.07s/it]

Epoch [1/2], Step [3651/7654], Loss: 3.658576250076294


 48%|████▊     | 3676/7654 [2:15:13<2:17:25,  2.07s/it]

Epoch [1/2], Step [3676/7654], Loss: 3.6709325313568115


 48%|████▊     | 3701/7654 [2:16:02<2:16:35,  2.07s/it]

Epoch [1/2], Step [3701/7654], Loss: 3.6272478103637695


 49%|████▊     | 3726/7654 [2:16:50<2:16:13,  2.08s/it]

Epoch [1/2], Step [3726/7654], Loss: 3.6940784454345703


 49%|████▉     | 3751/7654 [2:17:39<2:15:00,  2.08s/it]

Epoch [1/2], Step [3751/7654], Loss: 3.5735855102539062


 49%|████▉     | 3776/7654 [2:18:27<2:14:23,  2.08s/it]

Epoch [1/2], Step [3776/7654], Loss: 3.5428130626678467


 50%|████▉     | 3801/7654 [2:19:16<2:13:30,  2.08s/it]

Epoch [1/2], Step [3801/7654], Loss: 3.5466513633728027


 50%|████▉     | 3826/7654 [2:20:05<2:12:31,  2.08s/it]

Epoch [1/2], Step [3826/7654], Loss: 3.7237839698791504


 50%|█████     | 3851/7654 [2:20:53<2:11:18,  2.07s/it]

Epoch [1/2], Step [3851/7654], Loss: 3.624598503112793


 51%|█████     | 3876/7654 [2:21:42<2:10:56,  2.08s/it]

Epoch [1/2], Step [3876/7654], Loss: 3.705472469329834


 51%|█████     | 3901/7654 [2:22:31<2:09:53,  2.08s/it]

Epoch [1/2], Step [3901/7654], Loss: 3.4595706462860107


 51%|█████▏    | 3926/7654 [2:23:19<2:08:55,  2.07s/it]

Epoch [1/2], Step [3926/7654], Loss: 3.99259090423584


 52%|█████▏    | 3951/7654 [2:24:07<2:07:56,  2.07s/it]

Epoch [1/2], Step [3951/7654], Loss: 3.5244369506835938


 52%|█████▏    | 3976/7654 [2:24:56<2:07:23,  2.08s/it]

Epoch [1/2], Step [3976/7654], Loss: 3.538147211074829


 52%|█████▏    | 4000/7654 [2:25:45<1:58:07,  1.94s/it]

Epoch [1/2], Step [4001/7654], Loss: 3.7510364055633545


 52%|█████▏    | 4001/7654 [2:41:53<296:49:12, 292.51s/it]

Epoch [1/2], , Step [1063/7654], Val Loss: 3.377793480513349


 53%|█████▎    | 4026/7654 [2:42:41<2:06:57,  2.10s/it]   

Epoch [1/2], Step [4026/7654], Loss: 3.6991465091705322


 53%|█████▎    | 4051/7654 [2:43:30<2:04:20,  2.07s/it]

Epoch [1/2], Step [4051/7654], Loss: 3.6264476776123047


 53%|█████▎    | 4076/7654 [2:44:18<2:02:55,  2.06s/it]

Epoch [1/2], Step [4076/7654], Loss: 3.7156670093536377


 54%|█████▎    | 4101/7654 [2:45:06<2:01:59,  2.06s/it]

Epoch [1/2], Step [4101/7654], Loss: 3.5791125297546387


 54%|█████▍    | 4126/7654 [2:45:55<2:01:33,  2.07s/it]

Epoch [1/2], Step [4126/7654], Loss: 3.6624457836151123


 54%|█████▍    | 4151/7654 [2:46:43<2:00:30,  2.06s/it]

Epoch [1/2], Step [4151/7654], Loss: 3.770266532897949


 55%|█████▍    | 4176/7654 [2:47:31<1:59:33,  2.06s/it]

Epoch [1/2], Step [4176/7654], Loss: 3.5543596744537354


 55%|█████▍    | 4201/7654 [2:48:19<1:59:09,  2.07s/it]

Epoch [1/2], Step [4201/7654], Loss: 3.533764600753784


 55%|█████▌    | 4218/7654 [2:48:52<1:50:07,  1.92s/it]

In [None]:
!tensorboard --logdir=runs

In [None]:
torch.save(decoder.state_dict(), 'model_state.pth')

In [None]:
writer.close()

In [None]:
decoder = ByteNetDecoder(len(tokenizer.get_vocab())).to(device)
decoder.load_state_dict(torch.load('model_state.pth'))

In [None]:
def predict(inp, model,  tokenizer):
    model.eval()
    with torch.inference_mode():
        out = model(inp)
    print(out.shape)
    out = torch.argmax(out.squeeze(0), dim=0)

    out = tokenizer.decode(out.tolist())
    out = out.replace("<pad>", "")
    print(f"Predicted text:  {out}")

In [None]:
batch = next(iter(validation_loader))
batch =  next(iter(train_loader))

In [None]:
src = batch[0][0][0]
trgt = batch[1][0][0]
i =  tokenizer.decode(src.tolist())
i = i.replace("<pad>", "")
print(f"Input : {i}")
print(f"Ground  Truth : {tokenizer.decode(trgt.tolist())}")


In [None]:
predict(batch[0][0].to(device),decoder,tokenizer)