In [8]:
import tiktoken
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
encoding = tiktoken.get_encoding("cl100k_base")

In [7]:
encoding.encode("test")

[1985]

In [144]:
encoding._special_tokens

{'<|endoftext|>': 100257,
 '<|fim_prefix|>': 100258,
 '<|fim_middle|>': 100259,
 '<|fim_suffix|>': 100260,
 '<|endofprompt|>': 100276}

In [73]:
raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
raw_dataset = raw_dataset.filter(lambda x: len(x["text"]) > 0)

Found cached dataset wikitext (/Users/sidbaskaran/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)
Loading cached processed dataset at /Users/sidbaskaran/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-82e63c6c3fde2311.arrow


In [74]:
raw_dataset = raw_dataset.map(lambda x: {"ids": encoding.encode(x["text"])})

Loading cached processed dataset at /Users/sidbaskaran/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-45ce7928739699e8.arrow


In [75]:
raw_dataset = raw_dataset.remove_columns(["text"])

In [133]:
tokens = raw_dataset.to_dict()["ids"]

In [134]:
max_sequence_length = 256

In [145]:
encoding._special_tokens["<|endofprompt|>"]

100276

In [135]:
import torch
from torch.nn import functional as F

In [136]:
for i in range(len(tokens)):
    if len(tokens[i]) > max_sequence_length:
        extra_seq = tokens[i][max_sequence_length:]
        tokens[i] = torch.tensor(tokens[i][:max_sequence_length])
        while len(extra_seq) > max_sequence_length:
            padded_tokens = extra_seq[:max_sequence_length]
            padded_tokens = F.pad(
                torch.tensor(padded_tokens),
                (0, max_sequence_length - len(padded_tokens)),
                value=-1,
            )
            tokens.append(padded_tokens)
            extra_seq = extra_seq[max_sequence_length:]
    else:
        tokens[i] = F.pad(
            torch.tensor(tokens[i]), (0, max_sequence_length - len(tokens[i])), value=-1
        )

In [141]:
data = torch.stack(tokens)

In [142]:
data.shape

torch.Size([23800, 256])

In [160]:
x = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0]])

In [161]:
x

tensor([[1, 1, 1, 1, 1, 0, 0, 0]])

In [162]:
x = x.expand(x.size(0), 8, 8)

In [167]:
torch.tril(x).squeeze()

tensor([[1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0]])

In [168]:
class DecoderBlock(nn.Module):
    def __init__(
        self,
        num_heads: int,
        embed_dim: int,
        block_size: int,
        attn_dropout: float,
        layer_dropout: float,
        return_attn_weights: bool = False,
    ) -> None:
        super().__init__()
        self.block_size = block_size
        self.multihead_attention = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=attn_dropout, batch_first=True
        )
        self.ln1 = nn.LayerNorm(embed_dim)
        # positionwise FFN
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),  # as in original GPT
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Dropout(layer_dropout),
        )
        self.ln2 = nn.LayerNorm(embed_dim)
        self.return_attn_weights = return_attn_weights

    def forward(
        self, x: torch.Tensor, attn_mask: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # self attention and residual, -1 denotes padding
        causal_mask = torch.tril(attn_mask).squeeze().to(x.device)
        attn_out, attn_weights = self.multihead_attention(
            x, x, x, attn_mask=causal_mask
        )
        x = self.ln1(x + attn_out)
        # feed forward and residual
        x = self.ln2(x + self.feed_forward(x))

        if self.return_attn_weights:
            return x, attn_weights
        else:
            return x

NameError: name 'nn' is not defined