In [7]:
import tiktoken
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F

# ----------------------------------------------------------

class DataLoaderLite:
	"""Sequential (input, target) batch loader over one tokenized text file.

	The whole of ``input.txt`` is read once, encoded with the tiktoken
	``gpt2`` encoding and kept in memory as a 1-D tensor. Batches are cut
	from that tensor in order; when the next batch would run past the end
	the read position wraps back to the start.
	"""

	def __init__(self, B, T):
		"""Load and tokenize ``input.txt``.

		Args:
			B: number of sequences per batch.
			T: number of tokens per sequence.

		Raises:
			ValueError: if the file holds fewer than ``B * T + 1`` tokens,
				i.e. not even a single batch can be formed.
		"""
		self.B = B
		self.T = T

		# at init load tokens from disk and store them in memory.
		# The encoding is pinned so the token stream does not depend on the
		# platform's default locale encoding.
		with open('input.txt', 'r', encoding='utf-8') as f:
			text = f.read()
		enc = tiktoken.get_encoding('gpt2')
		tokens = enc.encode(text)
		self.tokens = torch.tensor(tokens)
		print(f"loaded {len(self.tokens)} tokens")
		print(f"1 epoch = {len(self.tokens) // (B * T)} batches")

		# Fail fast: without this, next_batch() would die later with an
		# opaque shape error from .view(B, T).
		if len(self.tokens) < B * T + 1:
			raise ValueError(
				f"need at least {B * T + 1} tokens for B={B}, T={T}, "
				f"but input.txt has only {len(self.tokens)}"
			)

		# state
		self.current_position = 0

	def next_batch(self):
		"""Return the next ``(x, y)`` pair, each of shape ``(B, T)``.

		``y`` is ``x`` shifted one token to the left, so ``y[i, j]`` is the
		token that follows ``x[i, j]`` in the file.
		"""
		B, T = self.B, self.T
		# One extra token is needed so the last input has a target.
		buf = self.tokens[self.current_position:self.current_position + B * T + 1]
		x = (buf[:-1]).view(B, T) # inputs
		y = (buf[1:]).view(B, T) # targets

		# advance the position in the tensor
		self.current_position += B * T

		# if loading the next batch would be out of bounds, reset
		if self.current_position + (B * T + 1) > len(self.tokens):
			self.current_position = 0
		return x, y

class CausalSelfAttention(nn.Module):
	"""Multi-head self-attention in which a position attends only to itself and earlier positions."""

	def __init__(self, config):
		"""Build the projections and the causal mask.

		Reads ``n_embd``, ``n_head`` and ``block_size`` from ``config``;
		``n_embd`` must be divisible by ``n_head``.
		"""
		super().__init__()
		width, heads = config.n_embd, config.n_head
		assert width % heads == 0

		# One fused linear layer yields query, key and value for every head.
		self.c_attn = nn.Linear(width, 3 * width)
		# Output projection applied after the heads are concatenated again.
		self.c_proj = nn.Linear(width, width)

		self.n_head = heads
		self.n_embd = width

		# Lower-triangular ones, shaped (1, 1, block_size, block_size) so it
		# broadcasts over the batch and head dimensions.
		ctx = config.block_size
		self.register_buffer("bias", torch.ones(ctx, ctx).tril().view(1, 1, ctx, ctx))

	def forward(self, x):
		"""Apply causal attention to ``x`` of size (B, T, C) and return a tensor of the same size."""
		B, T, C = x.size()
		head_dim = C // self.n_head

		def to_heads(t):
			# (B, T, C) -> (B, nh, T, hs): heads become a batch-like dimension
			return t.view(B, T, self.n_head, head_dim).transpose(1, 2)

		q, k, v = (to_heads(part) for part in self.c_attn(x).split(self.n_embd, dim=2))

		# Scaled dot-product scores for every query/key pair: (B, nh, T, T)
		scale = 1.0 / math.sqrt(head_dim)
		scores = q @ k.transpose(-2, -1) * scale

		# Wherever the mask is zero (key lies after the query) force the
		# score to -inf so softmax assigns it zero weight.
		causal = self.bias[:, :, :T, :T]
		scores = scores.masked_fill(causal == 0, float('-inf'))
		weights = F.softmax(scores, dim=-1)

		# (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs), then lay the
		# heads side by side again as (B, T, C).
		mixed = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)

		return self.c_proj(mixed)

    def test():

IndentationError: unindent does not match any outer indentation level (<string>, line 79)