In [2]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import tiktoken
import os
import subprocess

# Choose the compute device, then the lowest-precision dtype that device handles well.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

if device != 'cuda':
    dtype = torch.float32          # CPU: keep full precision
elif torch.cuda.is_bf16_supported():
    dtype = torch.bfloat16         # preferred GPU dtype when the hardware supports it
else:
    dtype = torch.float16          # GPU without bf16 support

print(f"Using: {device}, dtype: {dtype}")

# detect environment: Colab is recognised by an environment variable; a /home/ubuntu
# directory outside Colab is taken to mean a Lambda instance (heuristic)
on_colab = 'COLAB_RELEASE_TAG' in os.environ
on_lambda = not on_colab and os.path.exists('/home/ubuntu')

# pick the directory that holds the token cache and weight checkpoints
if on_colab:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    data_dir, _env_msg = '/content/drive/MyDrive', "Running on Colab"
elif on_lambda:
    data_dir, _env_msg = '/home/ubuntu', "Running on Lambda"
else:
    data_dir, _env_msg = '.', "Running locally"
print(_env_msg)

# config
lr = 3e-4
train_iter = 1000000
batch_size = 32
dim = 512
n_heads = 8
# forward() reshapes (B, T, dim) into (B, T, n_heads, head_dim); fail fast here instead
# of with a cryptic .view() error if dim is not a multiple of n_heads.
assert dim % n_heads == 0, f"dim ({dim}) must be divisible by n_heads ({n_heads})"
head_dim = dim // n_heads  # 64
n_layers = 6
mlp_dim = 1024
ctx_len = 128  # tokens per sampled window; the model is fed ctx_len - 1 of them
plot_every = 5000  # NOTE(review): not referenced by any cell shown here

# load data — WikiText-103 (Wikipedia articles, ~120M cl100k_base tokens)
# cached as a tensor so it only downloads and tokenizes once
cache_path = os.path.join(data_dir, 'wikitext103_tokens.pt')

if os.path.exists(cache_path):
    print("Loading cached tokens...")
    tokens = torch.load(cache_path, map_location=device).long()
else:
    print("Downloading WikiText-103 and tokenizing (one-time)...")
    import pyarrow.parquet as pq

    # download parquet files from HuggingFace
    base_url = "https://huggingface.co/api/datasets/Salesforce/wikitext/parquet/wikitext-103-raw-v1/train"
    n_shards = 2  # train shards fetched from base_url (0.parquet, 1.parquet)
    parquet_files = []
    for i in range(n_shards):
        fname = os.path.join(data_dir, f"wikitext_train_{i}.parquet")
        if not os.path.exists(fname):
            url = f"{base_url}/{i}.parquet"
            print(f"  downloading {url}...")
            # Download under a temporary name and rename only on success. `wget -O` creates
            # the output file even when the transfer fails, so writing straight to `fname`
            # could leave a partial file that the exists() check above accepts next run.
            part_path = fname + ".part"
            try:
                subprocess.run(["wget", "-q", "-O", part_path, url], check=True)
            except BaseException:
                if os.path.exists(part_path):
                    os.remove(part_path)
                raise
            os.replace(part_path, fname)
        parquet_files.append(fname)

    # tokenize in chunks to avoid giant Python list (~2.8GB saved)
    enc = tiktoken.get_encoding("cl100k_base")
    chunk_size = 50000
    chunks = []
    for pf in parquet_files:
        table = pq.read_table(pf)
        texts = table.column("text").to_pylist()
        batch = []
        for j, line in enumerate(texts):
            if line and line.strip():
                # encode_ordinary treats special-token text such as "<|endoftext|>" as plain
                # text, whereas enc.encode raises ValueError on it.
                batch.extend(enc.encode_ordinary(line))
            if len(batch) >= chunk_size:
                chunks.append(torch.tensor(batch, dtype=torch.int32))
                batch = []
            if j % 100000 == 0:
                print(f"  tokenized {j}/{len(texts)} rows from {os.path.basename(pf)}...")
        if batch:
            chunks.append(torch.tensor(batch, dtype=torch.int32))
        del texts, table

    if not chunks:
        raise RuntimeError("No tokens were produced from the downloaded WikiText-103 shards")
    corpus = torch.cat(chunks)  # int32, still on CPU
    del chunks
    # Save the CPU int32 tensor directly (no device round-trip), via temp file + rename so
    # an interrupted save cannot leave a truncated cache that the branch above would load.
    tmp_cache_path = cache_path + ".tmp"
    torch.save(corpus, tmp_cache_path)
    os.replace(tmp_cache_path, cache_path)
    tokens = corpus.to(device=device).long()
    del corpus
    print(f"Saved {len(tokens)} tokens to {cache_path}")

    # clean up parquet files
    for pf in parquet_files:
        os.remove(pf)

# Tokenizer shared by the cached training data and by prompts at generation time.
enc = tiktoken.get_encoding("cl100k_base")


def encode(s):
    """Encode a string into a list of cl100k_base token ids."""
    return enc.encode(s)


def decode(l):
    """Decode a list of cl100k_base token ids back into a string."""
    return enc.decode(l)


# NOTE: vocab_size comes from the largest id present in the training data, not from the
# tokenizer, so a prompt id above it cannot be embedded. The emb row count (and therefore
# saved checkpoints) depends on this value.
vocab_size = int(tokens.max()) + 1
print(f"Vocab size: {vocab_size}, dim: {dim}, heads: {n_heads}, layers: {n_layers}")
print(f"Total tokens: {len(tokens)}")

def _scaled_randn(*shape, scale):
    """Standard-normal tensor on the configured device/dtype, multiplied by `scale`."""
    return torch.randn(*shape, device=device, dtype=dtype) * scale


# Token embeddings (also used as the tied output projection in forward) and positions.
emb = _scaled_randn(vocab_size, dim, scale=0.02)
pos = _scaled_randn(ctx_len, dim, scale=0.02)

# One matrix per layer for each projection; drawn in the same order as before so the
# RNG stream (and therefore the initial weights) is unchanged.
init_scale = 0.1 / (2 ** 0.5)
wq = [_scaled_randn(dim, dim, scale=init_scale) for _ in range(n_layers)]
wk = [_scaled_randn(dim, dim, scale=init_scale) for _ in range(n_layers)]
wv = [_scaled_randn(dim, dim, scale=init_scale) for _ in range(n_layers)]
wo = [_scaled_randn(dim, dim, scale=init_scale) for _ in range(n_layers)]
w1 = [_scaled_randn(dim, mlp_dim, scale=init_scale) for _ in range(n_layers)]
w2 = [_scaled_randn(mlp_dim, dim, scale=init_scale) for _ in range(n_layers)]

# Flat list of every trainable tensor, handed to the optimizer and grad clipping.
params = [emb, pos, *wq, *wk, *wv, *wo, *w1, *w2]
for p in params:
    p.requires_grad_(True)

def rmsnorm(x, eps=1e-5):
    """Divide x by its root-mean-square over the last dimension.

    eps is added to the RMS after the square root (not inside it); there is no learnable gain.
    """
    rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
    return x / (rms + eps)

def forward(x):
    """Run the decoder-only transformer on a batch of token ids and return logits.

    x: integer token ids of shape (B, T). T must not exceed ctx_len, since positions come
    from pos[:T]. Logits are computed against the tied embedding matrix, giving shape
    (B, T, number of rows in emb).
    """
    B, T = x.shape

    def split_heads(t):
        # (B, T, dim) -> (B, n_heads, T, head_dim)
        return t.view(B, T, n_heads, head_dim).transpose(1, 2)

    h = emb[x] + pos[:T]
    for q_w, k_w, v_w, o_w, up_w, down_w in zip(wq, wk, wv, wo, w1, w2):
        # pre-norm causal multi-head self-attention, added back to the residual stream
        normed = rmsnorm(h)
        attn = F.scaled_dot_product_attention(
            split_heads(normed @ q_w),
            split_heads(normed @ k_w),
            split_heads(normed @ v_w),
            is_causal=True,
        )
        h = h + attn.transpose(1, 2).reshape(B, T, dim) @ o_w

        # pre-norm ReLU MLP, added back to the residual stream
        normed = rmsnorm(h)
        h = h + (normed @ up_w).relu() @ down_w

    return rmsnorm(h) @ emb.T

# Wrap forward with torch.compile when this PyTorch build provides it; otherwise the
# eager function is used unchanged.
_compile = getattr(torch, 'compile', None)
if _compile is not None:
    forward = _compile(forward)
    print("Using torch.compile")

# corpus length, used to bound the random window starts in the training loop
n_tokens = len(tokens)


Using: cpu, dtype: torch.float32
Running locally
Loading cached tokens...
Vocab size: 100256, dim: 512, heads: 8, layers: 6
Total tokens: 120196189
Using torch.compile


In [4]:
# load saved weights to resume training (skip this cell to train from scratch)
weights_path = os.path.join(data_dir, 'weights_tiktoken.pt')
ckpt = torch.load(weights_path, map_location=device)


def _restore(param, saved):
    """Copy a checkpoint tensor into an existing parameter in place, cast to the model dtype.

    copy_ raises on a shape mismatch (e.g. a checkpoint built with a different vocab_size)
    instead of silently swapping in a differently shaped tensor as `.data = ...` would.
    """
    with torch.no_grad():
        param.copy_(saved.to(dtype=dtype))


_restore(emb, ckpt['emb'])
_restore(pos, ckpt['pos'])

# Restore every per-layer list the save cell writes, including 'wo' (attention output
# projection). Skipping any of them would mix restored weights with a random init.
for name, layer_list in {'wq': wq, 'wk': wk, 'wv': wv, 'wo': wo, 'w1': w1, 'w2': w2}.items():
    if name not in ckpt:
        # presumably an older checkpoint layout; keep the fresh init rather than crash
        print(f"Warning: '{name}' missing from checkpoint; keeping its random init")
        continue
    saved_layers = ckpt[name]
    if len(saved_layers) != len(layer_list):
        raise ValueError(
            f"Checkpoint has {len(saved_layers)} '{name}' layers but the model has {len(layer_list)}"
        )
    for param, saved in zip(layer_list, saved_layers):
        _restore(param, saved)
print(f"Loaded weights from {weights_path}")

Loaded weights from ./weights_tiktoken.pt


In [None]:
# train
# Fused Adam is not available for CPU tensors on every PyTorch version (older builds
# raise RuntimeError, very old ones reject the keyword); fall back to the default kernel.
try:
    opt = torch.optim.Adam(params, lr=lr, fused=True)
except (RuntimeError, TypeError):
    opt = torch.optim.Adam(params, lr=lr)
offsets = torch.arange(ctx_len, device=device)
# NOTE(review): with dtype == float16 (GPU without bf16) there is no loss scaling, so
# small gradients may underflow — confirm this path is acceptable.

for i in range(train_iter):
    # Sample batch_size windows of ctx_len consecutive tokens. randint's upper bound is
    # exclusive, so +1 also allows the final window that ends on the last token.
    idx = torch.randint(0, n_tokens - ctx_len + 1, (batch_size,), device=device)
    seqs = tokens[idx.unsqueeze(1) + offsets]
    x = seqs[:, :-1]  # inputs: first ctx_len - 1 tokens of each window
    y = seqs[:, 1:]   # targets: the same window shifted left by one
    logits = forward(x)
    # flatten (B, T, V) -> (B*T, V) using the logits' own vocab dimension
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
    opt.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(params, 1.0)
    opt.step()
    if i % 100 == 0:
        print(f"{i}: loss={loss.item():.2f}")


In [6]:
# generate
ctx = "most important"
temperature = 0.8
max_new_tokens = 10

# Use a separate name: `tokens` holds the training corpus, and overwriting it here would
# break the training cell if it is re-run after generating.
gen_tokens = encode(ctx)
print(gen_tokens)
if not gen_tokens:
    raise ValueError("ctx must encode to at least one token")
if max(gen_tokens) >= vocab_size:
    raise ValueError(f"prompt contains a token id >= vocab_size ({vocab_size}) that the model cannot embed")

with torch.no_grad():  # inference only: don't build an autograd graph
    for _ in range(max_new_tokens):
        # feed at most ctx_len - 1 tokens, the input length used during training
        x = torch.tensor([gen_tokens[-(ctx_len - 1):]], device=device)
        logits = forward(x)
        # sample in float32 for a better-conditioned softmax over the large vocabulary
        probs = F.softmax(logits[0, -1].float() / temperature, dim=-1)
        next_token = torch.multinomial(probs, 1).item()
        gen_tokens.append(next_token)

print(decode(gen_tokens))

[3646, 3062]
most important 추가Jo.ST Reading.series і grandson'M Bearingilling


In [None]:
# save weights (persists across restarts)
# NOTE(review): this writes 'weights_tiktoken2.pt' while the load cell reads
# 'weights_tiktoken.pt' — confirm which file resuming should use.
weights_path = os.path.join(data_dir, 'weights_tiktoken2.pt')

# Key layout matches what the load cell reads: 'emb'/'pos' as single tensors and each
# per-layer weight as a list ordered by layer index.
_per_layer = {'wq': wq, 'wk': wk, 'wv': wv, 'wo': wo, 'w1': w1, 'w2': w2}
_state = {'emb': emb.data, 'pos': pos.data}
_state.update({key: [w.data for w in mats] for key, mats in _per_layer.items()})
torch.save(_state, weights_path)

print(f"Saved weights to {weights_path}")