In [1]:
import sys
import os
import torch 
from pathlib import Path

def get_project_info() -> Path:
  current = Path.cwd().resolve()
  root = current
  for parent in [current, *current.parents]:
    if (parent / "toy_transformers").exists():
      root = parent
      break
  return root, current

if 'ROOT_DIR' not in globals():
	ROOT_DIR, EXPERIMENT_DIR = get_project_info()
	if str(ROOT_DIR) not in sys.path:
		sys.path.append(str(ROOT_DIR))
	if Path.cwd() != ROOT_DIR:
		os.chdir(ROOT_DIR)

from toy_transformers.models import gptv1
from toy_transformers import tokenization

In [2]:
VOCAB_SIZE = 256
BATCH_SIZE = 16
MODE = tokenization.TokenizationMode.STR
DEVICE = "mps"

config = gptv1.GPTv1Config(
	vocab_size=VOCAB_SIZE,
	block_size=256,
)

In [3]:
vocab_path = EXPERIMENT_DIR / f"vocab_{VOCAB_SIZE}.json"
raw_data_path = ROOT_DIR / "data/gutenberg/freud-interpretation-of-dreams.txt"

if not vocab_path.exists():
	raw_data = open(raw_data_path, "r")
	vocab = tokenization.create_bpe(
		raw_data, 
		VOCAB_SIZE, MODE
	)
	vocab.save(vocab_path)
else:
	vocab = tokenization.Vocabulary.load(vocab_path)

In [4]:
data = torch.tensor(
	vocab.encode(open(raw_data_path, "r").read()),
	dtype=torch.long
).to(device=DEVICE)

n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

block_size, batch_size = config.block_size, BATCH_SIZE
def get_batch(split):
  data = train_data if split == 'train' else val_data
  idxs = torch.randint(len(data) - block_size, (batch_size,), device=DEVICE)
  x = torch.stack([data[i:i+block_size] for i in idxs])
  y = torch.stack([data[i+1:i+block_size+1] for i in idxs])
  return x, y

@torch.no_grad()
def estimate_val_loss(model):
  model.eval()
  X, Y = get_batch("val")
  _, loss = model(X, Y)
  model.train()
  return loss.item()

In [5]:
torch.set_float32_matmul_precision("medium")
m = gptv1.LanguageModel(config).to(device=DEVICE)
m.compile()

optimizer = torch.optim.AdamW(m.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
  optimizer,
  mode='min',
  factor=0.1,
  patience=10
)

import time
from torch.amp import autocast
from tqdm import tqdm

In [6]:
for steps in range(4000):
	xb, yb = get_batch('train')
	with autocast(device_type=DEVICE, dtype=torch.float16):
		logits, loss = m(xb, yb)
	optimizer.zero_grad(set_to_none=True)
	loss.backward()
	optimizer.step()
	train_loss, val_loss = loss.item(), None
	if steps % 50 == 0:
		val_loss = estimate_val_loss(m)
		scheduler.step(val_loss)
	
	if steps % 25 == 0:
		print(steps, train_loss, val_loss)

W0217 19:07:53.402000 1495 torch/_inductor/utils.py:1436] [0/0] Not enough SMs to use max_autotune_gemm mode


InductorError: SyntaxError: failed to compile 
    #include <c10/metal/utils.h>
    #include <c10/metal/reduction_utils.h>
    [[max_total_threads_per_threadgroup(288)]]
    kernel void generated_kernel(
        device float* out_ptr2,
        device float* out_ptr5,
        device half* out_ptr6,
        constant half* in_ptr0,
        constant half* in_ptr1,
        constant half* in_ptr2,
        constant half* in_ptr3,
        constant half* in_ptr4,
        constant half* in_ptr5,
        constant half* in_ptr6,
        constant half* in_ptr7,
        constant half* in_ptr8,
        constant half* in_ptr9,
        constant half* in_ptr10,
        constant half* in_ptr11,
        constant half* in_ptr12,
        constant half* in_ptr13,
        constant half* in_ptr14,
        constant half* in_ptr15,
        constant half* in_ptr16,
        constant half* in_ptr17,
        constant half* in_ptr18,
        constant half* in_ptr19,
        constant half* in_ptr20,
        constant half* in_ptr21,
        constant half* in_ptr22,
        constant half* in_ptr23,
        constant float* in_ptr24,
        constant float* in_ptr25,
        constant float* in_ptr26,
        constant float* in_ptr27,
        constant bool* in_ptr28,
        uint2 thread_pos [[thread_position_in_grid]],
        uint2 group_pos [[thread_position_in_threadgroup]]
    ) {
        auto xindex = thread_pos.x;
        auto r0_index = thread_pos.y;
        int r0_1 = r0_index;
        int x0 = xindex;
        threadgroup float tmp_acc_0[9];
        threadgroup float tmp_acc_1[9];
        auto tmp0 = static_cast<float>(in_ptr0[r0_1 + 288*x0]);
        auto tmp2 = static_cast<float>(in_ptr1[r0_1 + 288*x0]);
        auto tmp5 = static_cast<float>(in_ptr2[r0_1 + 288*x0]);
        auto tmp8 = static_cast<float>(in_ptr3[r0_1 + 288*x0]);
        auto tmp11 = static_cast<float>(in_ptr4[r0_1 + 288*x0]);
        auto tmp14 = static_cast<float>(in_ptr5[r0_1 + 288*x0]);
        auto tmp17 = static_cast<float>(in_ptr6[r0_1 + 288*x0]);
        auto tmp20 = static_cast<float>(in_ptr7[r0_1 + 288*x0]);
        auto tmp23 = static_cast<float>(in_ptr8[r0_1 + 288*x0]);
        auto tmp26 = static_cast<float>(in_ptr9[r0_1 + 288*x0]);
        auto tmp29 = static_cast<float>(in_ptr10[r0_1 + 288*x0]);
        auto tmp32 = static_cast<float>(in_ptr11[r0_1 + 288*x0]);
        auto tmp35 = static_cast<float>(in_ptr12[r0_1 + 288*x0]);
        auto tmp38 = static_cast<float>(in_ptr13[r0_1 + 288*x0]);
        auto tmp41 = static_cast<float>(in_ptr14[r0_1 + 288*x0]);
        auto tmp44 = static_cast<float>(in_ptr15[r0_1 + 288*x0]);
        auto tmp47 = static_cast<float>(in_ptr16[r0_1 + 288*x0]);
        auto tmp50 = static_cast<float>(in_ptr17[r0_1 + 288*x0]);
        auto tmp53 = static_cast<float>(in_ptr18[r0_1 + 288*x0]);
        auto tmp56 = static_cast<float>(in_ptr19[r0_1 + 288*x0]);
        auto tmp59 = static_cast<float>(in_ptr20[r0_1 + 288*x0]);
        auto tmp62 = static_cast<float>(in_ptr21[r0_1 + 288*x0]);
        auto tmp65 = static_cast<float>(in_ptr22[r0_1 + 288*x0]);
        auto tmp68 = static_cast<float>(in_ptr23[r0_1 + 288*x0]);
        auto tmp71 = in_ptr24[r0_1];
        auto tmp74 = in_ptr25[r0_1 + 288*x0];
        auto tmp1 = static_cast<float>(tmp0);
        auto tmp3 = static_cast<float>(tmp2);
        auto tmp4 = tmp1 + tmp3;
        auto tmp6 = static_cast<float>(tmp5);
        auto tmp7 = tmp4 + tmp6;
        auto tmp9 = static_cast<float>(tmp8);
        auto tmp10 = tmp7 + tmp9;
        auto tmp12 = static_cast<float>(tmp11);
        auto tmp13 = tmp10 + tmp12;
        auto tmp15 = static_cast<float>(tmp14);
        auto tmp16 = tmp13 + tmp15;
        auto tmp18 = static_cast<float>(tmp17);
        auto tmp19 = tmp16 + tmp18;
        auto tmp21 = static_cast<float>(tmp20);
        auto tmp22 = tmp19 + tmp21;
        auto tmp24 = static_cast<float>(tmp23);
        auto tmp25 = tmp22 + tmp24;
        auto tmp27 = static_cast<float>(tmp26);
        auto tmp28 = tmp25 + tmp27;
        auto tmp30 = static_cast<float>(tmp29);
        auto tmp31 = tmp28 + tmp30;
        auto tmp33 = static_cast<float>(tmp32);
        auto tmp34 = tmp31 + tmp33;
        auto tmp36 = static_cast<float>(tmp35);
        auto tmp37 = tmp34 + tmp36;
        auto tmp39 = static_cast<float>(tmp38);
        auto tmp40 = tmp37 + tmp39;
        auto tmp42 = static_cast<float>(tmp41);
        auto tmp43 = tmp40 + tmp42;
        auto tmp45 = static_cast<float>(tmp44);
        auto tmp46 = tmp43 + tmp45;
        auto tmp48 = static_cast<float>(tmp47);
        auto tmp49 = tmp46 + tmp48;
        auto tmp51 = static_cast<float>(tmp50);
        auto tmp52 = tmp49 + tmp51;
        auto tmp54 = static_cast<float>(tmp53);
        auto tmp55 = tmp52 + tmp54;
        auto tmp57 = static_cast<float>(tmp56);
        auto tmp58 = tmp55 + tmp57;
        auto tmp60 = static_cast<float>(tmp59);
        auto tmp61 = tmp58 + tmp60;
        auto tmp63 = static_cast<float>(tmp62);
        auto tmp64 = tmp61 + tmp63;
        auto tmp66 = static_cast<float>(tmp65);
        auto tmp67 = tmp64 + tmp66;
        auto tmp69 = static_cast<float>(tmp68);
        auto tmp70 = tmp67 + tmp69;
        out_ptr2[r0_1 + 288*x0] = static_cast<float>(tmp70);
        auto tmp72 = tmp70 * tmp71;
        auto tmp75 = tmp72 * tmp74;
        auto tmp73 = c10::metal::threadgroup_sum(tmp_acc_0, tmp72, r0_index * 1, 288);
        auto tmp76 = c10::metal::threadgroup_sum(tmp_acc_1, tmp75, r0_index * 1, 288);
        auto tmp77 = in_ptr26[r0_1 + 288*x0];
        auto tmp78 = in_ptr27[x0];
        auto tmp87 = in_ptr28[r0_1 + 288*x0];
        auto tmp79 = 288.0;
        auto tmp80 = tmp72 * tmp79;
        auto tmp81 = tmp80 - tmp73;
        auto tmp82 = tmp74 * tmp76;
        auto tmp83 = tmp81 - tmp82;
        auto tmp84 = tmp78 * tmp83;
        auto tmp85 = tmp77 + tmp84;
        out_ptr5[r0_1 + 288*x0] = static_cast<float>(tmp85);
        auto tmp86 = static_cast<half>(tmp85);
        auto tmp88 = static_cast<half>(tmp87);
        auto tmp89 = 1.25;
        auto tmp90 = tmp88 * tmp89;
        auto tmp91 = tmp86 * tmp90;
        out_ptr6[r0_1 + 288*x0] = static_cast<half>(tmp91);
    }
 with program_source:589:24: error: no 'buffer' resource location available for 'in_ptr28'
        constant bool* in_ptr28,
                       ^


In [None]:
idx = torch.tensor([vocab.encode("The mind ")], dtype=torch.long, device=DEVICE)
print(idx)
print("The mind ", end="", flush=True)
for token in m.generate(idx, max_new_tokens=200):
	print(vocab.decode([token.item()])[0], end="", flush=True)
print()

tensor([[ 36, 126, 145, 129,  47]], device='mps:0')
The mind upill Inauentally noses which arawary noking a clf acou reccrfLef are implapodreams po with repecturness

Iftersompar

KeyboardInterrupt: 