In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from typing import Optional, Tuple
from dataclasses import dataclass
torch.set_printoptions(sci_mode=False, precision=4)

%matplotlib inline

In [2]:
@dataclass
class ModelArgs:
    n_embd: int = 32
    n_layer: int = 2
    n_head: int = 2
    vocab_size: int = -1  # defined later by tokenizer

    bias: bool = False
    dropout: float = 0.1

    batch_size: int = 32
    max_seq_len: int = 2048
    dtype: torch.dtype = torch.float32

    rotary: bool = True
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'

    def __init__(self):
        if torch.cuda.is_available():
            self.dtype = torch.bfloat16
            self.n_embd = 128
            self.n_head = 4
            self.n_layer = 6

config = ModelArgs()

In [7]:
# https://github.com/meta-llama/llama/blob/llama_v2/llama/model.py

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
    and the end index 'end'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Args:
        dim (int): Dimension of the frequency tensor.
        end (int): End index for precomputing frequencies.
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.

    Returns:
        torch.Tensor: Precomputed frequency tensor with complex exponentials.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    Reshape frequency tensor for broadcasting it with another tensor.

    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    Args:
        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.

    Returns:
        torch.Tensor: Reshaped frequency tensor.

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings.
        xk (torch.Tensor): Key tensor to apply rotary embeddings.
        freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
        
    """
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias, dtype=config.dtype, device=config.device)
        # output projection
        # self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias, dtype=config.dtype, device=config.device)
        # regularization
        # self.attn_dropout = nn.Dropout(config.dropout)
        # self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        self.is_causal = False

    def forward(self, x, freqs_cis):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head)
        q = q.view(B, T, self.n_head, C // self.n_head)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        if freqs_cis is not None:
            q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)

        q = q.transpose(1, 2)  # (B, nh, T, hs)
        k = k.transpose(1, 2)  # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        # efficient attention using Flash Attention CUDA kernels
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=None,
                                           dropout_p=self.dropout if self.training else 0, is_causal=self.is_causal)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        # y = self.resid_dropout(self.c_proj(y))
        return y


class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias, dtype=config.dtype, device=config.device)
        self.gelu    = nn.GELU()
        self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias, dtype=config.dtype, device=config.device)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.rmsn_1 = nn.RMSNorm(config.n_embd, device=config.device, dtype=config.dtype)
        self.attn = CausalSelfAttention(config)
        self.rmsn_2 = nn.RMSNorm(config.n_embd, device=config.device, dtype=config.dtype)
        self.mlp = MLP(config)

    def forward(self, x, freqs_cis):
        x = x + self.attn(self.rmsn_1(x), freqs_cis)
        x = x + self.mlp(self.rmsn_2(x))
        return x

class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.wte=nn.Embedding(config.vocab_size, config.n_embd, dtype=config.dtype, device=config.device)
        self.h =nn.ModuleList([Block(config) for _ in range(config.n_layer)])         
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=config.bias, dtype=config.dtype, device=config.device)
        if config.rotary:
            self.freqs_cis = precompute_freqs_cis(config.n_embd // config.n_head, config.max_seq_len * 2).to(config.device)
        else:
            self.wpe = nn.Embedding(256, config.n_embd, device=config.device, dtype=config.dtype)
            self.pos_idx = torch.arange(0, 1024, dtype=torch.long, device=config.device)
            self.freqs_cis = None

        # self.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
        self.rmsn=nn.RMSNorm(config.n_embd, device=config.device, dtype=config.dtype)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        _bsz, seqlen = idx.shape

        x = self.wte(idx)  # token embeddings of shape (b, t, n_embd)
        if self.freqs_cis is None:
            x += self.wpe(self.pos_idx[:seqlen])
            freqs_cis = None
        else:
            freqs_cis = self.freqs_cis[:seqlen]
        for block in self.h:
            x = block(x, freqs_cis)
        x = self.rmsn(x)
        logits = self.lm_head(x) # B, T, vocab_size
        logits = logits[:, -1, :]  # last one      
        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            # loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
            loss = F.cross_entropy(logits, targets)

        return logits, loss

In [5]:
torch.manual_seed(13)
x = torch.randn(2, 3, 4)
x

tensor([[[ 0.4372,  0.3701,  1.5816, -0.1556],
         [ 0.1511, -1.3495, -0.7089, -0.2434],
         [-2.1701,  0.4161,  1.0159, -0.2136]],

        [[-0.9229, -0.2264, -0.0301,  0.4067],
         [-1.7423, -0.7882, -0.1232,  1.5986],
         [ 0.5884,  0.6930, -0.2881,  0.6280]]])

In [6]:
x[:,-1,:]

tensor([[-2.1701,  0.4161,  1.0159, -0.2136],
        [ 0.5884,  0.6930, -0.2881,  0.6280]])

In [70]:
x = torch.randint(1, 10, (2, 4))
logits, loss = gpt(x)
logits

tensor([[[ 0.0596,  0.0568,  0.2460,  0.0706,  0.0052,  0.0270, -0.1398,
           0.1677,  0.0867, -0.1406,  0.0380,  0.0434, -0.1751, -0.1319,
           0.1504, -0.1324,  0.0828,  0.0459, -0.2859, -0.0450, -0.0359,
           0.2662, -0.1456, -0.0038, -0.1007,  0.0747,  0.0604, -0.0895,
           0.2200, -0.1594,  0.2052,  0.0279,  0.0978, -0.1103,  0.0741,
           0.0199,  0.1176,  0.0953,  0.0800, -0.0324, -0.0179, -0.0270,
           0.0537, -0.0115,  0.0294, -0.0514, -0.1358, -0.1006, -0.0119,
           0.2503, -0.0704, -0.0960,  0.1120,  0.0210, -0.0770,  0.1577,
          -0.0145, -0.0904,  0.0924, -0.1659,  0.0047,  0.0839, -0.0555,
           0.0069],
         [-0.1262, -0.0099, -0.0071, -0.0194, -0.0142,  0.0764,  0.0564,
           0.0544, -0.1410, -0.0199, -0.0729, -0.1143,  0.0503,  0.1272,
           0.1048,  0.0610, -0.0093,  0.0141, -0.1129, -0.1118, -0.1458,
           0.0377, -0.1448,  0.1341,  0.1015,  0.0174,  0.0548, -0.1718,
          -0.0134,  0.0361,  0.

In [19]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars) + 1
# create a mapping from characters to integers
stoi = { ch:i + 1 for i,ch in enumerate(chars) }
stoi['<>'] = 0
itos = { i:ch for ch,i in stoi.items() }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string
data = torch.tensor(encode(text), dtype=torch.long)

# Let's now split up the data into train and validation sets
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

def get_batch(split, batch_size, block_size):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.tensor([data[i + block_size] for i in ix])
    # y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

In [20]:
get_batch('train', 4, 8)

(tensor([[53, 51, 64,  2, 41, 44,  1, 17],
         [43,  2, 52, 40, 50, 44,  2, 59],
         [44, 59,  2, 22, 58, 40, 41, 44],
         [10,  2, 24, 22, 27, 20,  2, 21]]),
 tensor([44, 57, 51, 18]))

In [69]:
block_size = 16

class Flatten(nn.Module):
    def __init__(self, n):
        super(Flatten, self).__init__() 
        self.n = n

    def __call__(self, x):
        # print(x.shape)
        B, T, C = x.shape
        if T == self.n:
            return  x.view(B, C * self.n)
        return x.view(B, T//2, C * self.n)
        out =  x.view(B, -1, C * self.n)
        print(out.shape)
        return out


    def parameters(self):
        return []

class MLPModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wte=nn.Embedding(config.vocab_size, config.n_embd, dtype=config.dtype, device=config.device)
        self.flatten = nn.Flatten(1)
        hidden = config.n_embd * block_size // 2
        self.mlp1 = nn.Linear(config.n_embd * block_size, hidden)    
        self.mlp2 = nn.Linear(hidden, hidden)    
        self.lm_head = nn.Linear(hidden, vocab_size)

    def forward(self, idx, targets=None):
        x = self.wte(idx)  # token embeddings of shape (b, t, n_embd)
        x = self.flatten(x)
        x = self.mlp1(x)
        x = self.mlp2(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            # loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
            loss = F.cross_entropy(logits, targets)

        return logits, loss        
    

class MLPModel2(nn.Module):
    def __init__(self, config):
        super().__init__()

        n_embd = config.n_embd
        self.model = nn.Sequential(
                    nn.Embedding(config.vocab_size, n_embd),
                    Flatten(2), nn.Linear(n_embd * 2, n_embd * 2), nn.Tanh(),
                    Flatten(2), nn.Linear(n_embd * 4, n_embd * 2), nn.Tanh(),
                    Flatten(2), nn.Linear(n_embd * 4, n_embd * 2), nn.Tanh(),
                    Flatten(2), nn.Linear(n_embd * 4, n_embd * 2), nn.Tanh(),
                    nn.Linear(n_embd * 2, n_embd * 2),
                    nn.Tanh(),
                    nn.Linear(n_embd * 2, vocab_size))

        # self.wte=nn.Embedding(config.vocab_size, config.n_embd, dtype=config.dtype, device=config.device)
        # self.flatten = nn.Flatten(1)
        # hidden = config.n_embd * block_size // 2
        # self.mlp1 = nn.Linear(config.n_embd * block_size, hidden)    
        # self.mlp2 = nn.Linear(hidden, hidden)    
        # self.lm_head = nn.Linear(hidden, vocab_size)

    def forward(self, idx, targets=None):
        logits = self.model(idx)
        # x = self.wte(idx)  # token embeddings of shape (b, t, n_embd)
        # x = self.flatten(x)
        # x = self.mlp1(x)
        # x = self.mlp2(x)
        # logits = self.lm_head(x)
        loss = None
        if targets is not None:
            # loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
            loss = F.cross_entropy(logits, targets)

        return logits, loss            

In [73]:
# 26000 23.94s loss: 1.59636, train: 1.69068, val: 1.84234
# parameter: 117360
config = ModelArgs()
config.vocab_size = vocab_size
config.n_embd = 32
config.n_layer = 4
config.n_head = 4
learning_rate = 1e-3
batch_size = 16
block_size = 16

#  28000 53.24s loss: 1.59238, train: 1.60147, val: 1.75652
#  304192
# config.n_embd = 64
# config.n_layer = 6

#  26000 156.38s loss: 1.53036, train: 1.34517, val: 1.55875
#  304192
# block_size = 64
model = GPT(config)



@torch.no_grad()
def eval_mode(model, split):
    model.eval()
    xb, yb = get_batch(split, 1024 * 2, block_size)
    logits, loss = model(xb, yb)
    model.train()
    return loss.item()

print(sum(e.numel() for e in model.parameters()))

#      0 0.97s loss: 4.14254, train: 4.12606, val: 4.12622
#   2000 15.41s loss: 2.43846, train: 2.50690, val: 2.50940
#   4000 15.87s loss: 2.45041, train: 2.38253, val: 2.45002
#   6000 15.50s loss: 1.87670, train: 2.33288, val: 2.29257
#   8000 16.25s loss: 2.40117, train: 2.25745, val: 2.21347
#  10000 15.27s loss: 1.96440, train: 2.14363, val: 2.20197
#  12000 15.40s loss: 2.25821, train: 2.12671, val: 2.15990
#  14000 15.18s loss: 1.41092, train: 2.08953, val: 2.13023
#  16000 15.21s loss: 1.61853, train: 2.00345, val: 2.10484
#  18000 15.52s loss: 1.47404, train: 2.06662, val: 2.08309
#  20000 15.36s loss: 2.22606, train: 2.09605, val: 2.07258
#  22000 16.15s loss: 1.57381, train: 2.02398, val: 2.02660
#  24000 16.62s loss: 2.39012, train: 2.01963, val: 2.09455
#  26000 16.92s loss: 2.43945, train: 2.03690, val: 2.10079
#  28000 15.43s loss: 1.67814, train: 2.00688, val: 2.08757
#  30000 15.56s loss: 2.29863, train: 1.99688, val: 2.02595

# above is attention. bellow is mlp

#      0 0.01s loss: 1.86977, train: 2.19873, val: 2.28758
#   2000 1.50s loss: 2.21103, train: 2.33470, val: 2.37905
#   4000 1.52s loss: 1.96309, train: 2.25205, val: 2.36831
#   6000 1.48s loss: 1.89688, train: 2.36085, val: 2.31879
#   8000 1.46s loss: 2.42841, train: 2.29361, val: 2.38357
#  10000 1.40s loss: 2.24522, train: 2.24002, val: 2.27659
#  12000 1.51s loss: 2.24394, train: 2.17114, val: 2.24774
#  14000 1.45s loss: 2.39240, train: 2.07881, val: 2.29540
#  16000 1.48s loss: 2.45173, train: 2.23459, val: 2.23733
#  18000 1.49s loss: 2.32054, train: 2.09277, val: 2.22561
#  20000 1.86s loss: 2.58712, train: 2.12597, val: 2.26913
#  22000 1.66s loss: 2.91778, train: 2.04940, val: 2.28270
#  24000 1.64s loss: 1.76361, train: 2.06376, val: 2.24237
#  26000 1.66s loss: 1.69099, train: 2.10229, val: 2.25317
#  28000 1.50s loss: 1.59344, train: 2.10322, val: 2.25436
#  30000 1.64s loss: 2.25858, train: 2.10492, val: 2.21290

# # above wavenet
#      0 0.01s loss: 4.23850, train: 4.17033, val: 4.17057
#   2000 1.47s loss: 2.26062, train: 2.46871, val: 2.51698
#   4000 1.39s loss: 2.17279, train: 2.34896, val: 2.32339
#   6000 1.40s loss: 1.75408, train: 2.28589, val: 2.28735
#   8000 1.42s loss: 2.64250, train: 2.17526, val: 2.21579
#  10000 1.41s loss: 2.19711, train: 2.15487, val: 2.17959
#  12000 1.42s loss: 2.06995, train: 2.06976, val: 2.08725
#  14000 1.41s loss: 1.37480, train: 2.03400, val: 2.07191
#  16000 1.44s loss: 2.39997, train: 2.06084, val: 2.00076
#  18000 1.41s loss: 1.99309, train: 2.02796, val: 2.02401
#  20000 1.41s loss: 1.52541, train: 2.01666, val: 2.08847
#  22000 1.38s loss: 1.57991, train: 1.98687, val: 2.06349
#  24000 1.39s loss: 1.96693, train: 1.97387, val: 1.96752
#  26000 1.39s loss: 2.00294, train: 1.96320, val: 1.91582
#  28000 1.41s loss: 2.25782, train: 1.94247, val: 2.00578
#  30000 1.43s loss: 2.71035, train: 1.93993, val: 2.00118

49568


In [70]:
model = MLPModel2(config)
print(sum(e.numel() for e in model.parameters()))
xb, yb = get_batch('train', batch_size, block_size)
model(xb, yb)

39490


(tensor([[ 0.0130, -0.0244,  0.0519,  ..., -0.1142,  0.0245, -0.0126],
         [-0.0093, -0.0133,  0.0367,  ..., -0.1408,  0.0179, -0.0367],
         [ 0.0059, -0.0176,  0.1086,  ..., -0.1233,  0.0098, -0.0457],
         ...,
         [ 0.0048, -0.0396,  0.0417,  ..., -0.1158,  0.0273, -0.0483],
         [ 0.0119, -0.0101,  0.0295,  ..., -0.1450, -0.0306,  0.0049],
         [ 0.0098, -0.0077,  0.0573,  ..., -0.0993, -0.0086,  0.0459]],
        grad_fn=<AddmmBackward0>),
 tensor(4.1754, grad_fn=<NllLossBackward0>))

In [71]:
import time
t = time.time()
opti = torch.optim.AdamW(model.parameters(), learning_rate, fused=True)
for step in range(30001):
    for g in opti.param_groups:
        if step > 20000: g['lr'] = learning_rate / 5
        elif step > 10000: g['lr'] = learning_rate / 3

    xb, yb = get_batch('train', batch_size, block_size)
    logits, loss = model(xb, yb)
    opti.zero_grad(set_to_none=True)
    loss.backward()
    opti.step()
    if step % 2000 == 0:
        n = time.time()
        train_time = n - t
        print(f'{step:6d} {train_time:.2f}s loss: {loss.item():.5f}, train: {eval_mode(model, "train"):.5f}, val: {eval_mode(model, "val"):.5f}')
        t = n

     0 0.01s loss: 4.23850, train: 4.17033, val: 4.17057
  2000 1.47s loss: 2.26062, train: 2.46871, val: 2.51698
  4000 1.39s loss: 2.17279, train: 2.34896, val: 2.32339
  6000 1.40s loss: 1.75408, train: 2.28589, val: 2.28735
  8000 1.42s loss: 2.64250, train: 2.17526, val: 2.21579
 10000 1.41s loss: 2.19711, train: 2.15487, val: 2.17959
 12000 1.42s loss: 2.06995, train: 2.06976, val: 2.08725
 14000 1.41s loss: 1.37480, train: 2.03400, val: 2.07191
 16000 1.44s loss: 2.39997, train: 2.06084, val: 2.00076
 18000 1.41s loss: 1.99309, train: 2.02796, val: 2.02401
 20000 1.41s loss: 1.52541, train: 2.01666, val: 2.08847
 22000 1.38s loss: 1.57991, train: 1.98687, val: 2.06349
 24000 1.39s loss: 1.96693, train: 1.97387, val: 1.96752
 26000 1.39s loss: 2.00294, train: 1.96320, val: 1.91582
 28000 1.41s loss: 2.25782, train: 1.94247, val: 2.00578
 30000 1.43s loss: 2.71035, train: 1.93993, val: 2.00118


In [240]:
# 26000 23.94s loss: 1.59636, train: 1.69068, val: 1.84234
# parameter: 117360
config = ModelArgs()
config.vocab_size = vocab_size
config.n_embd = 64
config.n_layer = 4
config.n_head = 4
config.dropout = 0.1
config.rotary = True
learning_rate = 1e-3
batch_size = 64
block_size = 128

#  28000 53.24s loss: 1.59238, train: 1.60147, val: 1.75652
#  304192
# config.n_embd = 64
# config.n_layer = 6

#  26000 156.38s loss: 1.53036, train: 1.34517, val: 1.55875
#  304192
# block_size = 64
model = GPT(config)
# model = torch.compile(model)

@torch.no_grad()
def eval_mode(model, split):
    model.eval()
    xb, yb = get_batch(split, 1024 * 2, block_size)
    logits, loss = model(xb, yb)
    model.train()
    return loss.item()

print(sum(e.numel() for e in model.parameters()) / 1024 / 1024)

0.22320556640625


In [None]:
import time
t = time.time()
opti = torch.optim.AdamW(model.parameters(), learning_rate)
for step in range(10001):
    for g in opti.param_groups:
        if step > 20000: g['lr'] = learning_rate / 5
        elif step > 10000: g['lr'] = learning_rate / 3

    xb, yb = get_wk_batch(batch_size, block_size, config.device)
    logits, loss = model(xb, yb)
    opti.zero_grad(set_to_none=True)
    loss.backward()
    opti.step()
    if step % 2000 == 0:
        n = time.time()
        train_time = n - t
        print(f'{step:6d} {train_time:.2f}s loss: {loss.item():.5f}, train: {eval_mode(model, "train"):.5f}, val: {eval_mode(model, "val"):.5f}')
        t = n

In [None]:
import time
t = time.time()
opti = torch.optim.AdamW(model.parameters(), learning_rate * 0.5)
for step in range(30001):
    for g in opti.param_groups:
        if step > 20000: g['lr'] = learning_rate / 5
        elif step > 10000: g['lr'] = learning_rate / 3

    xb, yb = get_batch('train', batch_size, block_size)
    logits, loss = model(xb, yb)
    opti.zero_grad(set_to_none=True)
    loss.backward()
    opti.step()
    if step % 2000 == 0:
        n = time.time()
        train_time = n - t
        print(f'{step:6d} {train_time:.2f}s loss: {loss.item():.5f}, train: {eval_mode(model, "train"):.5f}, val: {eval_mode(model, "val"):.5f}')
        t = n

In [136]:
#  28000 27.25s loss: 1.66619, train: 1.67056, val: 1.81868

import time
t = time.time()
opti = torch.optim.AdamW(model.parameters(), learning_rate)
for step in range(30001):
    for g in opti.param_groups:
        if step > 20000: g['lr'] = learning_rate / 5
        elif step > 10000: g['lr'] = learning_rate / 3

    xb, yb = get_batch('train', batch_size, block_size)
    logits, loss = model(xb, yb)
    opti.zero_grad(set_to_none=True)
    loss.backward()
    opti.step()
    if step % 2000 == 0:
        n = time.time()
        train_time = n - t
        print(f'{step:6d} {train_time:.2f}s loss: {loss.item():.5f}, train: {eval_mode(model, "train"):.5f}, val: {eval_mode(model, "val"):.5f}')
        t = n

     0 0.12s loss: 4.20371, train: 4.10815, val: 4.10997
  2000 22.33s loss: 2.16811, train: 2.12599, val: 2.16987
  4000 19.42s loss: 2.17855, train: 1.98829, val: 2.06462
  6000 21.77s loss: 1.99858, train: 1.90744, val: 2.01406
  8000 20.54s loss: 2.08288, train: 1.87279, val: 1.99812
 10000 19.68s loss: 1.82278, train: 1.84337, val: 1.95235
 12000 19.88s loss: 1.85650, train: 1.77525, val: 1.92027
 14000 23.85s loss: 1.95274, train: 1.76716, val: 1.90403
 16000 20.85s loss: 1.89785, train: 1.75859, val: 1.91004
 18000 19.18s loss: 1.95584, train: 1.74462, val: 1.88499
 20000 18.84s loss: 1.91002, train: 1.74278, val: 1.88153
 22000 22.03s loss: 1.66406, train: 1.71974, val: 1.86783
 24000 18.64s loss: 1.92112, train: 1.70893, val: 1.87611
 26000 18.82s loss: 1.77871, train: 1.70853, val: 1.85082
 28000 18.50s loss: 1.59352, train: 1.69870, val: 1.86800
 30000 18.73s loss: 1.91003, train: 1.70339, val: 1.84661


In [140]:
@torch.no_grad()
def generate(model, idx, max_new_tokens, temperature=1.0, top_k=None):
    """
    Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
    the sequence max_new_tokens times, feeding the predictions back into the model each time.
    Most likely you'll want to make sure to be in model.eval() mode of operation for this.
    """
    for _ in range(max_new_tokens):
        # if the sequence context is growing too long we must crop it at block_size
        idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]
        # forward the model to get the logits for the index in the sequence
        logits, _ = model(idx_cond)
        # pluck the logits at the final step and scale by desired temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop the logits to only the top k options
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')
        # apply softmax to convert logits to (normalized) probabilities
        probs = F.softmax(logits, dim=-1)
        # sample from the distribution
        idx_next = torch.multinomial(probs, num_samples=1)
        # append sampled index to the running sequence and continue
        idx = torch.cat((idx, idx_next), dim=1)
    return idx

In [142]:
print(decode(generate(model, idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))


<>he my comfore of ous sweet not much should of tuken a stoold.

Fisters then have sulf minince so cha


In [None]:
    #  with rotary
     0 0.03s loss: 4.24567, train: 4.09009, val: 4.09109
  1000 25.91s loss: 1.68743, train: 1.63288, val: 1.78919
  2000 25.88s loss: 1.53839, train: 1.46762, val: 1.64870
  3000 25.79s loss: 1.47733, train: 1.40543, val: 1.59762
  4000 26.56s loss: 1.47561, train: 1.37177, val: 1.56717
  5000 26.36s loss: 1.42775, train: 1.35312, val: 1.55640
  6000 25.83s loss: 1.41559, train: 1.32865, val: 1.54277
  7000 25.83s loss: 1.39557, train: 1.32097, val: 1.52121
  8000 25.54s loss: 1.39905, train: 1.30683, val: 1.53706
  9000 26.02s loss: 1.40830, train: 1.30955, val: 1.51410
 10000 25.89s loss: 1.40223, train: 1.29421, val: 1.52741
 11000 25.19s loss: 1.40270, train: 1.29250, val: 1.51724
 12000 26.01s loss: 1.33441, train: 1.29159, val: 1.52723
 13000 26.52s loss: 1.35442, train: 1.25828, val: 1.49880
 14000 26.40s loss: 1.32233, train: 1.25585, val: 1.50987
 15000 26.66s loss: 1.35553, train: 1.25042, val: 1.49551
 16000 27.02s loss: 1.33569, train: 1.24887, val: 1.49686
 17000 26.28s loss: 1.33931, train: 1.25006, val: 1.49716
 18000 25.27s loss: 1.36694, train: 1.25032, val: 1.49718
 19000 26.49s loss: 1.31716, train: 1.24690, val: 1.49094
 20000 26.55s loss: 1.37013, train: 1.25256, val: 1.48834

# 26000 23.94s loss: 1.59636, train: 1.69068, val: 1.84234
# parameter: 117360
config = ModelArgs()
config.vocab_size = vocab_size
config.n_embd = 64
config.n_layer = 4
config.n_head = 4
config.dropout = 0.1
config.rotary = False
learning_rate = 1e-3
batch_size = 64
block_size = 128

#  28000 53.24s loss: 1.59238, train: 1.60147, val: 1.75652
#  304192
# config.n_embd = 64
# config.n_layer = 6

#  26000 156.38s loss: 1.53036, train: 1.34517, val: 1.55875
#  304192
# block_size = 64
model = GPT(config)
# model = torch.compile(model)

@torch.no_grad()
def eval_mode(model, split):
    model.eval()
    xb, yb = get_batch(split, 1024 * 2, block_size)
    logits, loss = model(xb, yb)
    model.train()
    return loss.item()

print(sum(e.numel() for e in model.parameters()) / 1024 / 1024)


# no rotary
     0 0.04s loss: 4.20762, train: 4.03342, val: 4.03646
  1000 23.65s loss: 1.88716, train: 1.79410, val: 1.92691
  2000 23.82s loss: 1.66935, train: 1.56973, val: 1.76726
  3000 24.01s loss: 1.59561, train: 1.48337, val: 1.67860
  4000 24.11s loss: 1.54473, train: 1.43455, val: 1.62868
  5000 23.89s loss: 1.49493, train: 1.40338, val: 1.59355
  6000 23.99s loss: 1.46876, train: 1.37792, val: 1.58696
  7000 23.51s loss: 1.46747, train: 1.37471, val: 1.58601
  8000 23.38s loss: 1.47176, train: 1.35034, val: 1.56929
  9000 22.87s loss: 1.44583, train: 1.34326, val: 1.57503
 10000 23.65s loss: 1.44932, train: 1.33792, val: 1.55838
 11000 23.95s loss: 1.40971, train: 1.32840, val: 1.55533
 12000 24.01s loss: 1.42878, train: 1.31682, val: 1.54599
 13000 23.26s loss: 1.40214, train: 1.29594, val: 1.53102
 14000 23.75s loss: 1.37013, train: 1.28887, val: 1.53773
 15000 23.31s loss: 1.38122, train: 1.28983, val: 1.52035
 16000 23.83s loss: 1.34553, train: 1.28327, val: 1.52100
 17000 23.41s loss: 1.39873, train: 1.27728, val: 1.52637
 18000 23.08s loss: 1.40726, train: 1.27376, val: 1.52361
 19000 22.95s loss: 1.35267, train: 1.28464, val: 1.51774
 20000 23.35s loss: 1.36085, train: 1.26649, val: 1.53216

In [148]:
!pip install pandas pyarrow


Looking in indexes: http://mirrors.aliyun.com/pypi/simple/
Collecting pandas
  Downloading http://mirrors.aliyun.com/pypi/packages/e1/0c/ad295fd74bfac85358fd579e271cded3ac969de81f62dd0142c426b9da91/pandas-2.2.3-cp312-cp312-macosx_11_0_arm64.whl (11.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.4/11.4 MB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting pyarrow
  Downloading http://mirrors.aliyun.com/pypi/packages/6a/50/12829e7111b932581e51dda51d5cb39207a056c30fe31ef43f14c63c4d7e/pyarrow-18.1.0-cp312-cp312-macosx_12_0_arm64.whl (29.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.5/29.5 MB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting pytz>=2020.1 (from pandas)
  Downloading http://mirrors.aliyun.com/pypi/packages/11/c3/005fcca25ce078d2cc29fd559379817424e94885510568bc1bc53d7d5846/pytz-2024.2-py2.py3-none-any.whl (508 kB)
Collecting tzdata>=2022.7 (from pandas)
  Downloading http://mi

In [237]:
import pandas as pd
import pyarrow.parquet as pq

file =  '/Users/admin/workspace/nn2/train-00000-of-00001.parquet'
table = pq.read_table(file)

wk_chars = set()
all_txts = []
for i in table.to_pylist():
    txt = i['text'].strip()
    if len(txt) > 130:
        cs = set(txt)
        wk_chars.update(cs)
        all_txts.append(txt)

wk_chars.add('\n')
u_chars = sorted(list(wk_chars))
import math
vocab_size = math.ceil((len(u_chars) + 1) / 32) * 32 #
# create a mapping from characters to integers
stoi = { ch:i + 1 for i,ch in enumerate(u_chars) }
stoi['<>'] = 0
itos = { i:ch for ch,i in stoi.items() }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

device = 'cuda' if torch.cuda.is_available() else 'cpu'

wk_data = []
for txt in all_txts:
    data = torch.tensor(encode(txt), dtype=torch.long)
    wk_data.append(data)

print('all pretraining: ', sum(len(d) for d in wk_data))

import random
def get_wk_batch(batch_size, block_size, device):
    # generate a small batch of data of inputs x and targets y
    ix = torch.randint(len(wk_data), (batch_size,))

    xs, ys = [], []
    for i in ix:
        d = wk_data[i]
        idx = random.randint(0, len(d) - block_size - 1)
        xs.append(d[idx:idx+block_size])
        ys.append(d[idx+1:idx+1+block_size])
    return torch.stack(xs).to(device), torch.stack(ys).to(device)

with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
data = torch.tensor(encode(text), dtype=torch.long).to(device)
print('all shake data:  ', len(data))

# Let's now split up the data into train and validation sets
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

def get_batch(split, batch_size, block_size):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

all pretraining:  10393085
all shake data:   1115394


288

In [230]:
x, y = get_wk_batch(12012, 128, 'cpu')
x.shape, y.shape, len(wk_data)


(torch.Size([12012, 128]), torch.Size([12012, 128]), 14947)

In [65]:
n_embd = 32
model = nn.Sequential(
            nn.Embedding(config.vocab_size, n_embd),
            Flatten(2), nn.Linear(n_embd * 2, n_embd * 2), nn.Tanh(),
            Flatten(2), nn.Linear(n_embd * 4, n_embd * 2), nn.Tanh(),
            Flatten(2), nn.Linear(n_embd * 4, n_embd * 2), nn.Tanh(),
            Flatten(2), nn.Linear(n_embd * 4, n_embd * 2), nn.Tanh(),
            nn.Linear(n_embd * 2, n_embd * 2),
            nn.Tanh(),
            nn.Linear(n_embd * 2, vocab_size)
        )

x, y = get_batch('train', batch_size, block_size)

# model = nn.Embedding(config.vocab_size, config.n_embd)

x = model(x)
x.shape

sum(p.numel() for p in model.parameters())
# f = Flatten(2)
# f2 = nn.Linear(n_embd * 2, n_embd * 2)
# f2(f(x))

39490