# ByteNet Character Prediction Model & Experiment
This task uses the Hutter Prize version of  the Wikipedia dataset to test

In [1]:
!pip install datasets



In [2]:
# Define Imports
import json

import requests
import torch
import pickle
import torch.utils.data as data
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
from datasets import load_dataset
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from transformers import AutoTokenizer


In [3]:
from datasets import load_dataset
dataset = load_dataset('enwik8')['train']

In [4]:
from bs4 import BeautifulSoup

def remove_xml_tags(example):
    soup = BeautifulSoup(example['text'], 'lxml')
    text_content = soup.get_text()
    return {'text': text_content}


In [5]:
filtered_dataset = dataset.map(remove_xml_tags)


In [6]:
print(filtered_dataset[:220])

{'text': ['', '', 'Wikipedia', 'http://en.wikipedia.org/wiki/Main_Page', 'MediaWiki 1.6alpha', 'first-letter', '', 'Media', 'Special', '', 'Talk', 'User', 'User talk', 'Wikipedia', 'Wikipedia talk', 'Image', 'Image talk', 'MediaWiki', 'MediaWiki talk', 'Template', 'Template talk', 'Help', 'Help talk', 'Category', 'Category talk', 'Portal', 'Portal talk', '', '', '', 'AaA', '1', '', '32899315', '2005-12-27T18:46:47Z', '', 'Jsmethers', '614213', '', '#REDIRECT [[AAA]]', '', '', '', 'AlgeriA', '5', '', '18063769', '2005-07-03T11:13:13Z', '', 'Docu', '8029', '', '', 'adding cur_id=5: {{R from CamelCase}}', '#REDIRECT [[Algeria]]{{R from CamelCase}}', '', '', '', 'AmericanSamoa', '6', '', '18063795', '2005-07-03T11:14:17Z', '', 'Docu', '8029', '', '', 'adding to cur_id=6  {{R from CamelCase}}', '#REDIRECT [[American Samoa]]{{R from CamelCase}}', '', '', '', 'AppliedEthics', '8', '', '15898943', '2002-02-25T15:43:11Z', '', 'Conversion script', '', '', 'Automated conversion', '#REDIRECT [[App

In [7]:
def join_to_length(data, target_length):
    joined_data = []
    temp_str = ''
    for entry in data:
        temp_str += entry
        if len(temp_str) >= target_length:
            joined_data.append(temp_str[:target_length])
            temp_str = temp_str[target_length:]
    if temp_str:  # add remaining string if it's not empty
        joined_data.append(temp_str)
    return joined_data

In [8]:
train_dataset = join_to_length(filtered_dataset['text'],500)

In [9]:
print(len(train_dataset))

186380


In [10]:
train_split = int(0.9 * len(train_dataset))
validation_split = int(0.95 * len(train_dataset))
test_split = len(train_dataset)

In [11]:
train_data = train_dataset[:train_split]
validation_data = train_dataset[train_split:validation_split]
test_data = train_dataset[validation_split:test_split]

In [12]:
print(len(train_data),len(validation_data),len(test_data))

167742 9319 9319


In [13]:
tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")

In [14]:
def tokenize_item(item):
    src = tokenizer(item[:100], max_length=400, padding="max_length", truncation=True, return_tensors="pt")
    trgt = tokenizer(item[100:], max_length=400, padding="max_length", truncation=True, return_tensors="pt")
    return src['input_ids'],trgt['input_ids']

In [15]:
tokenize_item(train_data[1])

(tensor([[117, 114, 112,  35,  70, 100, 112, 104, 111,  70, 100, 118, 104, 128,
          128,  38,  85,  72,  71,  76,  85,  72,  70,  87,  35,  94,  94,  68,
          112, 104, 117, 108, 102, 100, 113,  35,  86, 100, 112, 114, 100,  96,
           96, 126, 126,  85,  35, 105, 117, 114, 112,  35,  70, 100, 112, 104,
          111,  70, 100, 118, 104, 128, 128,  68, 115, 115, 111, 108, 104, 103,
           72, 119, 107, 108, 102, 118,  59,  52,  56,  59,  60,  59,  60,  55,
           54,  53,  51,  51,  53,  48,  51,  53,  48,  53,  56,  87,  52,  56,
           61,  55,   1,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
            0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
            0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
            0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
            0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
            0,   0,   0,   0,   0,   0, 

In [16]:
# Imports for ByteNet
import torch
from torch.utils.data import Dataset
from torch import nn
from torch.nn import init
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

In [17]:
class MultiplicativeUnit(nn.Module):
    def __init__(self, d, dilation, k, masked = True):
        super(MultiplicativeUnit, self).__init__()
        self.receptive_field = (k-1)*dilation
        self.masked = masked
        self.sig_conv_1 = nn.Conv1d(d, d, k, dilation = dilation)
        self.sig_conv_2 = nn.Conv1d(d, d, k, dilation = dilation)
        self.sig_conv_3 = nn.Conv1d(d, d, k, dilation = dilation)
        self.tanh_conv_1 = nn.Conv1d(d, d, k, dilation = dilation)

        # 100 input tokens
        self.layer_norm1 = nn.LayerNorm(400)
        self.layer_norm2 = nn.LayerNorm(400)
        self.layer_norm3 = nn.LayerNorm(400)
        self.layer_norm4 = nn.LayerNorm(400)
        
    def forward(self,x):
        residual = x
        # Multiplicative Unit Block
        if self.receptive_field > 0 and self.masked:
            x = F.pad(x, (self.receptive_field, 0))
        sig_1 = torch.sigmoid(self.layer_norm1(self.sig_conv_1(x)))
        sig_2 = torch.sigmoid(self.layer_norm2(self.sig_conv_2(x)))
        sig_3 = torch.sigmoid(self.layer_norm3(self.sig_conv_3(x)))
        tanh_1 = torch.tanh(self.layer_norm4(self.tanh_conv_1(x)))
        sig_tanh = sig_2 * tanh_1
        residual_sig = sig_3 * residual
        mu_out = sig_1 * torch.tanh(sig_tanh+residual_sig)
        # Multiplicative unit block end
        # Add back the residual
        mu_out = mu_out * residual
        return mu_out


In [18]:
class ResidualBlockMu(nn.Module):
    """
    Implementation of residual Layer for Bytenet machine translation task.

    :param d: The number of input features.
    :param dilation: The initial dilation rate for the convolution layers.
    """

    def __init__(self, d, dilation, k):
        super(ResidualBlockMu, self).__init__()
        self.layer_norm_in = nn.LayerNorm(400)
        self.conv_in_1 = nn.Conv1d(2*d,d,1)
        self.layer_norm_in_2 = nn.LayerNorm(400)
        # Multiplicative block
        # d -> d
        self.masked_mu = MultiplicativeUnit(d,dilation,k)
        self.mu = MultiplicativeUnit(d,1,1, masked=False)
        self.conv_out = nn.Conv1d(d,2*d,1)

    def forward(self, x):
        residual = x
        x = self.layer_norm_in(x)
        x = torch.nn.ReLU()(x)
        x = self.conv_in_1(x)
        x = self.layer_norm_in_2(x)
        x = torch.nn.ReLU()(x)
        x = self.masked_mu(x)
        x = self.mu(x)
        x = self.conv_out(x)
        return x


In [19]:
class HutterDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        src, trgt = tokenize_item(item)
        return src, trgt


In [20]:
train_dataset = HutterDataset(train_data)
validation_dataset = HutterDataset(validation_data)
test_dataset = HutterDataset(test_data)

In [21]:
batch_size = 16
shuffle = True

In [22]:
train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=shuffle)
validation_loader = DataLoader(validation_dataset,batch_size=batch_size,shuffle=shuffle)
test_loader = DataLoader(test_dataset,batch_size=batch_size,shuffle=shuffle)

In [23]:
class ByteNetDecoder(nn.Module):
    def __init__(self, vocab_size, embed_size=1024, d=512, n_sets = 6, set_size = 5, masked_kernel_size=3, max_dilation_rate=16):
        super(ByteNetDecoder, self).__init__()

        self.embedding = nn.Embedding(vocab_size,embed_size)
        self.layers = nn.Sequential()

        for _ in range(n_sets):
            dilation_rate = 1
            for _ in range(set_size):
                self.layers.append(ResidualBlockMu(d,dilation_rate if dilation_rate <= max_dilation_rate else max_dilation_rate, masked_kernel_size))
                dilation_rate = dilation_rate * 2
        self.layers.append(nn.Conv1d(d * 2, d, 1))
        self.layers.append(nn.ReLU())
        self.layers.append(nn.Conv1d(d, vocab_size, 1))
        self.layers.append(nn.Dropout(p=0.1))
    def forward(self,x):
        embed_x = self.embedding(x).permute(0, 2, 1)
        x = self.layers(embed_x)
        return x

In [24]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
decoder = ByteNetDecoder(len(tokenizer.get_vocab())).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(decoder.parameters(),lr = 0.0003, weight_decay=0.0001)

In [25]:
num_epochs = 1

In [None]:
for epoch in range(num_epochs):
    for i, (ip,trgt) in tqdm(enumerate(train_loader), total=len(train_loader)):
        decoder.train()
        ip, trgt = ip.squeeze(1).to(device), trgt.squeeze(1).to(device)
        out = decoder(ip)
        loss = criterion(out,trgt)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % 25 == 0:
            tqdm.write(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item()}')


  0%|          | 1/10484 [00:02<8:23:18,  2.88s/it]

Epoch [1/1], Step [1/10484], Loss: 5.955323696136475


  0%|          | 26/10484 [01:00<6:27:04,  2.22s/it]

Epoch [1/1], Step [26/10484], Loss: 5.3454437255859375


  0%|          | 48/10484 [01:45<6:06:14,  2.11s/it]

In [None]:
torch.cuda.empty_cache()
!nvidia-smi