<a href="https://colab.research.google.com/github/tae-h-yang/cs229/blob/main/CS_229_Transformers_in_Pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Remember to make a copy of this colab notebook before you start editing cells!

In [1]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.3.1-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.3.1-py3-none-any.whl (484 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m484.9/484.9 kB[0m [31m20.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m10.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading multiprocess-0.70.16-py311-none-any.whl (143 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m12.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading

In [2]:
!pip install tqdm



In [3]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2025-02-20 03:50:31--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-02-20 03:50:31 (47.1 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [4]:
# DO NOT MODIFY ANY OF THIS CODE

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

In [5]:
# DO NOT MODIFY ANY OF THIS CODE

# Hyperparameters
batch_sz = 16
context_length = 32
max_iterations = 30000
log_interval = 200
init_lr = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_steps = 200
embedding_dim = 64
num_heads = 4
num_blocks = 4
drop_prob = 0.0

# Load and prepare the data
torch.manual_seed(1337)
with open('input.txt', 'r', encoding='utf-8') as file:
    text_data = file.read()

unique_chars = sorted(set(text_data))
vocab_size = len(unique_chars)
char_to_index = {ch: i for i, ch in enumerate(unique_chars)}
index_to_char = {i: ch for i, ch in enumerate(unique_chars)}

def encode_text(s): return [char_to_index[c] for c in s]
def decode_text(l): return ''.join([index_to_char[i] for i in l])

# Split data for training and validation
data_tensor = torch.tensor(encode_text(text_data), dtype=torch.long)
train_size = int(0.9 * len(data_tensor))
train_data, val_data = data_tensor[:train_size], data_tensor[train_size:]

def generate_batch(split):
    data_src = train_data if split == 'train' else val_data
    indices = torch.randint(0, len(data_src) - context_length, (batch_sz,))
    inputs = torch.stack([data_src[i:i + context_length] for i in indices])
    targets = torch.stack([data_src[i + 1:i + context_length + 1] for i in indices])
    return inputs.to(device), targets.to(device)

@torch.no_grad()
def evaluate_loss():
    model.eval()
    losses = {'train': [], 'val': []}
    for split in ['train', 'val']:
        for _ in range(eval_steps):
            batch_x, batch_y = generate_batch(split)
            _, batch_loss = model(batch_x, batch_y)
            losses[split].append(batch_loss.item())
    model.train()
    return {split: torch.tensor(losses[split]).mean().item() for split in losses}

In [9]:
# YOU WILL CHANGE CODE IN THIS CELL
# Implement a Transformer model with PyTorch. Fill out the provided skeleton.

class SelfAttention(nn.Module):
    def __init__(self, head_dim):
        super().__init__()
        # TODO: initialize key, query, and value as linear layers. Set bias=False
        # self.key_proj = ...
        # self.query_proj = ...
        # self.value_proj = ...
        self.key_proj = nn.Linear(embedding_dim, head_dim, bias=False)
        self.query_proj = nn.Linear(embedding_dim, head_dim, bias=False)
        self.value_proj = nn.Linear(embedding_dim, head_dim, bias=False)

        self.register_buffer('mask', torch.tril(torch.ones(context_length, context_length)))
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, x):
        # B, T, C = ...
        # keys, queries, values = ...
        B, T, C = x.shape
        keys = self.key_proj(x)
        queries = self.query_proj(x)
        values = self.value_proj(x)

        scores = (queries @ keys.transpose(-2, -1)) * (C ** -0.5)
        scores = scores.masked_fill(self.mask[:T, :T] == 0, float('-inf'))
        # TODO: apply softmax and dropout
        # attention_weights = ...
        # attention_weights = ...
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        return attention_weights @ values

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, num_heads, head_dim):
        super().__init__()
        self.heads = nn.ModuleList([SelfAttention(head_dim) for _ in range(num_heads)])
        self.output_proj = nn.Linear(embedding_dim, embedding_dim)
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, x):
        # TODO: combine multiple attention heads
        # x = ...
        x = torch.cat([head(x) for head in self.heads], dim=-1)

        return self.dropout(self.output_proj(x))

class FeedForwardLayer(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(emb_dim, 4 * emb_dim),
            nn.ReLU(),
            nn.Linear(4 * emb_dim, emb_dim),
            nn.Dropout(drop_prob)
        )

    def forward(self, x):
        return self.layers(x)

class TransformerBlock(nn.Module):
    def __init__(self, emb_dim, num_heads):
        super().__init__()
        # TODO: initialize the multihead self attention, feed forward layer, and two layernorms
        # self.attention = ...
        # self.feed_forward = ...
        # self.norm1 = ...
        # self.norm2 = ...
        self.attention = MultiHeadSelfAttention(num_heads, emb_dim // num_heads)
        self.feed_forward = FeedForwardLayer(emb_dim)
        self.norm1 = nn.LayerNorm(emb_dim)
        self.norm2 = nn.LayerNorm(emb_dim)

    def forward(self, x):
        # TODO: implement the forward logic
        # x = ...
        # x = ...
        x = x + self.attention(self.norm1(x))
        x = x + self.feed_forward(self.norm2(x))

        return x

class TransformerLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        # self.token_embeddings = ...
        # self.position_embeddings = ...
        self.token_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.position_embeddings = nn.Embedding(context_length, embedding_dim)

        self.transformer_blocks = nn.Sequential(*[TransformerBlock(embedding_dim, num_heads) for _ in range(num_blocks)])
        self.final_norm = nn.LayerNorm(embedding_dim)
        self.head = nn.Linear(embedding_dim, vocab_size)

    def forward(self, idx, targets=None):
        # B, T = ...
        # tok_emb = ...
        B, T = idx.shape
        tok_emb = self.token_embeddings(idx)

        pos_emb = self.position_embeddings(torch.arange(T, device=device))
        x = tok_emb + pos_emb
        x = self.final_norm(self.transformer_blocks(x))
        logits = self.head(x)

        if targets is None:
            return logits, None

        logits = logits.view(B * T, vocab_size)
        # targets = ...
        # loss = ...
        targets = targets.view(B * T)
        loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate_text(self, idx, max_tokens):
        for _ in range(max_tokens):
            idx_cond = idx[:, -context_length:]
            logits, _ = self(idx_cond)
            probs = F.softmax(logits[:, -1, :], dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_token), dim=1)
        return idx

In [10]:
# DO NOT MODIFY ANY OF THIS CODE

# Initialize and train the model
model = TransformerLanguageModel().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=init_lr)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)
best_val_loss = float('inf')
no_progress = 0
max_patience = 10

for step in tqdm(range(max_iterations)):
    if step % log_interval == 0 or step == max_iterations - 1:
        current_losses = evaluate_loss()
        print(f"Step {step}: train loss {current_losses['train']:.4f}, val loss {current_losses['val']:.4f}")
        scheduler.step(current_losses['val'])

        if current_losses['val'] < best_val_loss:
            best_val_loss = current_losses['val']
            no_progress = 0
        else:
            no_progress += 1

    batch_x, batch_y = generate_batch('train')
    _, batch_loss = model(batch_x, batch_y)
    optimizer.zero_grad()
    batch_loss.backward()
    optimizer.step()

start_context = torch.zeros((1, 1), dtype=torch.long, device=device)
generated_output = decode_text(model.generate_text(start_context, max_tokens=1000)[0].tolist())

with open("output.txt", "w", encoding="utf-8") as out_file:
    out_file.write(generated_output)

  0%|          | 0/30000 [00:00<?, ?it/s]

Step 0: train loss 4.3551, val loss 4.3576


  1%|          | 206/30000 [00:11<1:31:32,  5.42it/s]

Step 200: train loss 2.5030, val loss 2.4985


  1%|▏         | 405/30000 [00:18<1:29:20,  5.52it/s]

Step 400: train loss 2.3475, val loss 2.3689


  2%|▏         | 605/30000 [00:26<1:29:21,  5.48it/s]

Step 600: train loss 2.2566, val loss 2.2906


  3%|▎         | 805/30000 [00:33<1:46:28,  4.57it/s]

Step 800: train loss 2.1563, val loss 2.1929


  3%|▎         | 1005/30000 [00:40<1:28:04,  5.49it/s]

Step 1000: train loss 2.0935, val loss 2.1200


  4%|▍         | 1206/30000 [00:48<1:17:02,  6.23it/s]

Step 1200: train loss 2.0508, val loss 2.0911


  5%|▍         | 1406/30000 [00:55<1:03:50,  7.46it/s]

Step 1400: train loss 1.9998, val loss 2.0577


  5%|▌         | 1606/30000 [01:03<1:04:25,  7.35it/s]

Step 1600: train loss 1.9576, val loss 2.0349


  6%|▌         | 1806/30000 [01:10<1:15:28,  6.23it/s]

Step 1800: train loss 1.9225, val loss 2.0102


  7%|▋         | 2006/30000 [01:17<1:01:48,  7.55it/s]

Step 2000: train loss 1.8955, val loss 1.9902


  7%|▋         | 2209/30000 [01:25<1:04:10,  7.22it/s]

Step 2200: train loss 1.8752, val loss 1.9816


  8%|▊         | 2404/30000 [01:32<1:30:00,  5.11it/s]

Step 2400: train loss 1.8328, val loss 1.9607


  9%|▊         | 2605/30000 [01:39<1:22:41,  5.52it/s]

Step 2600: train loss 1.8207, val loss 1.9434


  9%|▉         | 2808/30000 [01:47<1:15:36,  5.99it/s]

Step 2800: train loss 1.7919, val loss 1.9287


 10%|█         | 3008/30000 [01:54<1:00:30,  7.43it/s]

Step 3000: train loss 1.7750, val loss 1.9227


 11%|█         | 3209/30000 [02:01<59:28,  7.51it/s]  

Step 3200: train loss 1.7715, val loss 1.9120


 11%|█▏        | 3404/30000 [02:09<1:33:32,  4.74it/s]

Step 3400: train loss 1.7499, val loss 1.9202


 12%|█▏        | 3606/30000 [02:16<59:04,  7.45it/s]  

Step 3600: train loss 1.7497, val loss 1.8880


 13%|█▎        | 3806/30000 [02:23<1:13:38,  5.93it/s]

Step 3800: train loss 1.7309, val loss 1.8781


 13%|█▎        | 4006/30000 [02:30<58:01,  7.47it/s]  

Step 4000: train loss 1.7252, val loss 1.8656


 14%|█▍        | 4206/30000 [02:38<57:18,  7.50it/s]  

Step 4200: train loss 1.7075, val loss 1.8692


 15%|█▍        | 4406/30000 [02:46<1:08:19,  6.24it/s]

Step 4400: train loss 1.6996, val loss 1.8517


 15%|█▌        | 4606/30000 [02:53<56:23,  7.50it/s]  

Step 4600: train loss 1.6912, val loss 1.8533


 16%|█▌        | 4806/30000 [03:00<59:59,  7.00it/s]  

Step 4800: train loss 1.6825, val loss 1.8515


 17%|█▋        | 5005/30000 [03:07<1:01:13,  6.80it/s]

Step 5000: train loss 1.6706, val loss 1.8429


 17%|█▋        | 5205/30000 [03:15<1:15:32,  5.47it/s]

Step 5200: train loss 1.6643, val loss 1.8291


 18%|█▊        | 5405/30000 [03:22<1:28:29,  4.63it/s]

Step 5400: train loss 1.6550, val loss 1.8325


 19%|█▊        | 5606/30000 [03:29<53:45,  7.56it/s]  

Step 5600: train loss 1.6547, val loss 1.8318


 19%|█▉        | 5807/30000 [03:37<54:43,  7.37it/s]  

Step 5800: train loss 1.6476, val loss 1.8131


 20%|██        | 6006/30000 [03:44<1:01:45,  6.47it/s]

Step 6000: train loss 1.6391, val loss 1.8007


 21%|██        | 6208/30000 [03:51<53:04,  7.47it/s]  

Step 6200: train loss 1.6267, val loss 1.7968


 21%|██▏       | 6408/30000 [04:00<1:15:29,  5.21it/s]

Step 6400: train loss 1.6264, val loss 1.7943


 22%|██▏       | 6608/30000 [04:07<52:58,  7.36it/s]  

Step 6600: train loss 1.6127, val loss 1.7926


 23%|██▎       | 6808/30000 [04:14<51:12,  7.55it/s]  

Step 6800: train loss 1.6233, val loss 1.7933


 23%|██▎       | 7008/30000 [04:22<1:01:11,  6.26it/s]

Step 7000: train loss 1.6157, val loss 1.7951


 24%|██▍       | 7208/30000 [04:29<51:12,  7.42it/s]  

Step 7200: train loss 1.6074, val loss 1.7882


 25%|██▍       | 7407/30000 [04:36<57:15,  6.58it/s]  

Step 7400: train loss 1.6057, val loss 1.7899


 25%|██▌       | 7607/30000 [04:43<49:59,  7.46it/s]  

Step 7600: train loss 1.6064, val loss 1.7746


 26%|██▌       | 7805/30000 [04:51<1:07:42,  5.46it/s]

Step 7800: train loss 1.5955, val loss 1.7804


 27%|██▋       | 8005/30000 [04:58<1:20:15,  4.57it/s]

Step 8000: train loss 1.6026, val loss 1.7516


 27%|██▋       | 8206/30000 [05:07<1:01:45,  5.88it/s]

Step 8200: train loss 1.5900, val loss 1.7678


 28%|██▊       | 8406/30000 [05:14<48:19,  7.45it/s]  

Step 8400: train loss 1.5890, val loss 1.7792


 29%|██▊       | 8606/30000 [05:22<56:20,  6.33it/s]  

Step 8600: train loss 1.5800, val loss 1.7557


 29%|██▉       | 8806/30000 [05:29<48:13,  7.32it/s]  

Step 8800: train loss 1.5882, val loss 1.7651


 30%|███       | 9006/30000 [05:36<53:00,  6.60it/s]  

Step 9000: train loss 1.5870, val loss 1.7448


 31%|███       | 9206/30000 [05:43<46:41,  7.42it/s]  

Step 9200: train loss 1.5716, val loss 1.7471


 31%|███▏      | 9404/30000 [05:51<1:10:26,  4.87it/s]

Step 9400: train loss 1.5737, val loss 1.7366


 32%|███▏      | 9607/30000 [05:59<54:27,  6.24it/s]  

Step 9600: train loss 1.5664, val loss 1.7475


 33%|███▎      | 9807/30000 [06:06<45:27,  7.40it/s]  

Step 9800: train loss 1.5630, val loss 1.7653


 33%|███▎      | 10008/30000 [06:13<44:15,  7.53it/s]  

Step 10000: train loss 1.5668, val loss 1.7462


 34%|███▍      | 10203/30000 [06:21<1:07:51,  4.86it/s]

Step 10200: train loss 1.5688, val loss 1.7350


 35%|███▍      | 10408/30000 [06:28<44:05,  7.40it/s]

Step 10400: train loss 1.5556, val loss 1.7291


 35%|███▌      | 10609/30000 [06:36<54:36,  5.92it/s]  

Step 10600: train loss 1.5591, val loss 1.7334


 36%|███▌      | 10804/30000 [06:42<58:11,  5.50it/s]

Step 10800: train loss 1.5543, val loss 1.7232


 37%|███▋      | 11005/30000 [06:50<57:42,  5.49it/s]

Step 11000: train loss 1.5555, val loss 1.7467


 37%|███▋      | 11205/30000 [06:58<1:08:27,  4.58it/s]

Step 11200: train loss 1.5537, val loss 1.7481


 38%|███▊      | 11405/30000 [07:05<57:10,  5.42it/s]

Step 11400: train loss 1.5478, val loss 1.7221


 39%|███▊      | 11606/30000 [07:12<46:12,  6.64it/s]  

Step 11600: train loss 1.5485, val loss 1.7302


 39%|███▉      | 11806/30000 [07:19<41:07,  7.37it/s]

Step 11800: train loss 1.5386, val loss 1.7330


 40%|████      | 12004/30000 [07:27<56:36,  5.30it/s]

Step 12000: train loss 1.5359, val loss 1.7279


 41%|████      | 12209/30000 [07:35<47:30,  6.24it/s]  

Step 12200: train loss 1.5378, val loss 1.7351


 41%|████▏     | 12409/30000 [07:42<39:26,  7.43it/s]

Step 12400: train loss 1.5289, val loss 1.7226


 42%|████▏     | 12605/30000 [07:49<52:43,  5.50it/s]

Step 12600: train loss 1.5356, val loss 1.7321


 43%|████▎     | 12805/30000 [07:56<57:25,  4.99it/s]

Step 12800: train loss 1.5082, val loss 1.6958


 43%|████▎     | 13007/30000 [08:04<38:04,  7.44it/s]

Step 13000: train loss 1.5118, val loss 1.6931


 44%|████▍     | 13208/30000 [08:11<47:22,  5.91it/s]  

Step 13200: train loss 1.4995, val loss 1.6846


 45%|████▍     | 13408/30000 [08:18<36:40,  7.54it/s]

Step 13400: train loss 1.4985, val loss 1.6866


 45%|████▌     | 13609/30000 [08:26<37:09,  7.35it/s]

Step 13600: train loss 1.4968, val loss 1.6873


 46%|████▌     | 13804/30000 [08:33<58:21,  4.63it/s]

Step 13800: train loss 1.4986, val loss 1.6738


 47%|████▋     | 14009/30000 [08:40<35:48,  7.44it/s]

Step 14000: train loss 1.4873, val loss 1.6740


 47%|████▋     | 14207/30000 [08:48<42:52,  6.14it/s]  

Step 14200: train loss 1.4801, val loss 1.6753


 48%|████▊     | 14407/30000 [08:55<34:53,  7.45it/s]

Step 14400: train loss 1.4886, val loss 1.6801


 49%|████▊     | 14606/30000 [09:02<34:16,  7.49it/s]

Step 14600: train loss 1.4800, val loss 1.6730


 49%|████▉     | 14806/30000 [09:10<40:36,  6.24it/s]

Step 14800: train loss 1.4893, val loss 1.6610


 50%|█████     | 15006/30000 [09:17<33:20,  7.49it/s]

Step 15000: train loss 1.4836, val loss 1.6766


 51%|█████     | 15207/30000 [09:25<33:53,  7.28it/s]

Step 15200: train loss 1.4794, val loss 1.6861


 51%|█████▏    | 15406/30000 [09:32<36:05,  6.74it/s]

Step 15400: train loss 1.4740, val loss 1.6677


 52%|█████▏    | 15608/30000 [09:39<32:24,  7.40it/s]

Step 15600: train loss 1.4796, val loss 1.6558


 53%|█████▎    | 15808/30000 [09:47<37:32,  6.30it/s]

Step 15800: train loss 1.4781, val loss 1.6762


 53%|█████▎    | 16008/30000 [09:54<31:14,  7.46it/s]

Step 16000: train loss 1.4789, val loss 1.6815


 54%|█████▍    | 16208/30000 [10:01<31:21,  7.33it/s]

Step 16200: train loss 1.4831, val loss 1.6865


 55%|█████▍    | 16403/30000 [10:09<46:44,  4.85it/s]

Step 16400: train loss 1.4725, val loss 1.6627


 55%|█████▌    | 16607/30000 [10:16<29:31,  7.56it/s]

Step 16600: train loss 1.4824, val loss 1.6654


 56%|█████▌    | 16807/30000 [10:23<36:32,  6.02it/s]

Step 16800: train loss 1.4783, val loss 1.6717


 57%|█████▋    | 17007/30000 [10:31<29:39,  7.30it/s]

Step 17000: train loss 1.4507, val loss 1.6584


 57%|█████▋    | 17207/30000 [10:38<28:47,  7.41it/s]

Step 17200: train loss 1.4640, val loss 1.6449


 58%|█████▊    | 17407/30000 [10:46<33:17,  6.30it/s]

Step 17400: train loss 1.4561, val loss 1.6437


 59%|█████▊    | 17607/30000 [10:53<28:03,  7.36it/s]

Step 17600: train loss 1.4589, val loss 1.6537


 59%|█████▉    | 17806/30000 [11:00<30:50,  6.59it/s]

Step 17800: train loss 1.4523, val loss 1.6624


 60%|██████    | 18005/30000 [11:07<28:39,  6.98it/s]

Step 18000: train loss 1.4433, val loss 1.6526


 61%|██████    | 18209/30000 [11:15<26:13,  7.49it/s]

Step 18200: train loss 1.4383, val loss 1.6464


 61%|██████▏   | 18409/30000 [11:23<30:45,  6.28it/s]

Step 18400: train loss 1.4429, val loss 1.6587


 62%|██████▏   | 18609/30000 [11:30<25:31,  7.44it/s]

Step 18600: train loss 1.4497, val loss 1.6502


 63%|██████▎   | 18809/30000 [11:37<25:12,  7.40it/s]

Step 18800: train loss 1.4486, val loss 1.6454


 63%|██████▎   | 19004/30000 [11:45<36:54,  4.97it/s]

Step 19000: train loss 1.4418, val loss 1.6458


 64%|██████▍   | 19207/30000 [11:52<24:21,  7.39it/s]

Step 19200: train loss 1.4379, val loss 1.6560


 65%|██████▍   | 19406/30000 [12:00<30:09,  5.86it/s]

Step 19400: train loss 1.4362, val loss 1.6439


 65%|██████▌   | 19606/30000 [12:07<23:28,  7.38it/s]

Step 19600: train loss 1.4344, val loss 1.6305


 66%|██████▌   | 19807/30000 [12:14<23:08,  7.34it/s]

Step 19800: train loss 1.4298, val loss 1.6395


 67%|██████▋   | 20007/30000 [12:22<26:13,  6.35it/s]

Step 20000: train loss 1.4405, val loss 1.6328


 67%|██████▋   | 20207/30000 [12:29<22:22,  7.30it/s]

Step 20200: train loss 1.4254, val loss 1.6459


 68%|██████▊   | 20406/30000 [12:37<24:31,  6.52it/s]

Step 20400: train loss 1.4241, val loss 1.6467


 69%|██████▊   | 20605/30000 [12:44<22:38,  6.92it/s]

Step 20600: train loss 1.4351, val loss 1.6401


 69%|██████▉   | 20809/30000 [12:51<20:43,  7.39it/s]

Step 20800: train loss 1.4294, val loss 1.6329


 70%|███████   | 21009/30000 [12:59<23:52,  6.28it/s]

Step 21000: train loss 1.4148, val loss 1.6514


 71%|███████   | 21209/30000 [13:06<19:48,  7.39it/s]

Step 21200: train loss 1.4237, val loss 1.6457


 71%|███████▏  | 21405/30000 [13:14<26:19,  5.44it/s]

Step 21400: train loss 1.4194, val loss 1.6478


 72%|███████▏  | 21605/30000 [13:21<28:22,  4.93it/s]

Step 21600: train loss 1.4257, val loss 1.6454


 73%|███████▎  | 21808/30000 [13:28<18:31,  7.37it/s]

Step 21800: train loss 1.4205, val loss 1.6337


 73%|███████▎  | 22008/30000 [13:36<22:59,  5.79it/s]

Step 22000: train loss 1.4316, val loss 1.6310


 74%|███████▍  | 22208/30000 [13:43<17:32,  7.40it/s]

Step 22200: train loss 1.4217, val loss 1.6427


 75%|███████▍  | 22408/30000 [13:51<16:52,  7.50it/s]

Step 22400: train loss 1.4211, val loss 1.6351


 75%|███████▌  | 22608/30000 [13:58<19:34,  6.29it/s]

Step 22600: train loss 1.4281, val loss 1.6331


 76%|███████▌  | 22808/30000 [14:05<16:04,  7.46it/s]

Step 22800: train loss 1.4194, val loss 1.6375


 77%|███████▋  | 23008/30000 [14:13<18:35,  6.27it/s]

Step 23000: train loss 1.4168, val loss 1.6405


 77%|███████▋  | 23208/30000 [14:20<15:20,  7.38it/s]

Step 23200: train loss 1.4170, val loss 1.6457


 78%|███████▊  | 23405/30000 [14:27<15:49,  6.95it/s]

Step 23400: train loss 1.4193, val loss 1.6427


 79%|███████▊  | 23605/30000 [14:35<23:18,  4.57it/s]

Step 23600: train loss 1.4198, val loss 1.6279


 79%|███████▉  | 23805/30000 [14:42<18:46,  5.50it/s]

Step 23800: train loss 1.4212, val loss 1.6445


 80%|████████  | 24005/30000 [14:50<18:34,  5.38it/s]

Step 24000: train loss 1.4133, val loss 1.6431


 81%|████████  | 24205/30000 [14:57<19:08,  5.05it/s]

Step 24200: train loss 1.4227, val loss 1.6367


 81%|████████▏ | 24409/30000 [15:04<12:42,  7.34it/s]

Step 24400: train loss 1.4229, val loss 1.6229


 82%|████████▏ | 24606/30000 [15:12<15:23,  5.84it/s]

Step 24600: train loss 1.4221, val loss 1.6424


 83%|████████▎ | 24806/30000 [15:19<11:42,  7.40it/s]

Step 24800: train loss 1.4205, val loss 1.6127


 83%|████████▎ | 25006/30000 [15:26<11:11,  7.44it/s]

Step 25000: train loss 1.4173, val loss 1.6269


 84%|████████▍ | 25202/30000 [15:34<16:56,  4.72it/s]

Step 25200: train loss 1.4102, val loss 1.6376


 85%|████████▍ | 25405/30000 [15:41<13:50,  5.53it/s]

Step 25400: train loss 1.4178, val loss 1.6277


 85%|████████▌ | 25609/30000 [15:49<12:04,  6.06it/s]

Step 25600: train loss 1.4217, val loss 1.6287


 86%|████████▌ | 25809/30000 [15:56<09:25,  7.41it/s]

Step 25800: train loss 1.4205, val loss 1.6355


 87%|████████▋ | 26008/30000 [16:03<09:02,  7.35it/s]

Step 26000: train loss 1.4181, val loss 1.6354


 87%|████████▋ | 26208/30000 [16:11<10:09,  6.22it/s]

Step 26200: train loss 1.4170, val loss 1.6355


 88%|████████▊ | 26408/30000 [16:18<08:00,  7.48it/s]

Step 26400: train loss 1.4194, val loss 1.6231


 89%|████████▊ | 26609/30000 [16:26<07:48,  7.24it/s]

Step 26600: train loss 1.4099, val loss 1.6353


 89%|████████▉ | 26804/30000 [16:33<09:55,  5.37it/s]

Step 26800: train loss 1.4279, val loss 1.6302


 90%|█████████ | 27005/30000 [16:40<09:16,  5.39it/s]

Step 27000: train loss 1.4115, val loss 1.6313


 91%|█████████ | 27205/30000 [16:48<10:07,  4.60it/s]

Step 27200: train loss 1.4136, val loss 1.6295


 91%|█████████▏| 27405/30000 [16:55<07:54,  5.46it/s]

Step 27400: train loss 1.4159, val loss 1.6304


 92%|█████████▏| 27605/30000 [17:02<07:18,  5.46it/s]

Step 27600: train loss 1.4170, val loss 1.6283


 93%|█████████▎| 27805/30000 [17:10<07:34,  4.83it/s]

Step 27800: train loss 1.4066, val loss 1.6370


 93%|█████████▎| 28006/30000 [17:17<04:23,  7.56it/s]

Step 28000: train loss 1.4204, val loss 1.6364


 94%|█████████▍| 28208/30000 [17:25<05:01,  5.94it/s]

Step 28200: train loss 1.4227, val loss 1.6295


 95%|█████████▍| 28408/30000 [17:32<03:33,  7.47it/s]

Step 28400: train loss 1.4189, val loss 1.6426


 95%|█████████▌| 28608/30000 [17:39<03:08,  7.37it/s]

Step 28600: train loss 1.4190, val loss 1.6357


 96%|█████████▌| 28808/30000 [17:47<03:10,  6.26it/s]

Step 28800: train loss 1.4226, val loss 1.6304


 97%|█████████▋| 29008/30000 [17:54<02:13,  7.44it/s]

Step 29000: train loss 1.4181, val loss 1.6276


 97%|█████████▋| 29209/30000 [18:02<01:53,  6.96it/s]

Step 29200: train loss 1.4228, val loss 1.6282


 98%|█████████▊| 29404/30000 [18:09<01:51,  5.37it/s]

Step 29400: train loss 1.4152, val loss 1.6397


 99%|█████████▊| 29609/30000 [18:16<00:52,  7.48it/s]

Step 29600: train loss 1.4248, val loss 1.6084


 99%|█████████▉| 29809/30000 [18:24<00:30,  6.30it/s]

Step 29800: train loss 1.4231, val loss 1.6386


100%|██████████| 30000/30000 [18:31<00:00, 27.00it/s]

Step 29999: train loss 1.4201, val loss 1.6320





In [18]:
# DO NOT MODIFY ANY OF THIS CODE

# Generate from the model
with torch.no_grad():
    context = torch.tensor(encode_text("JULIET: "), dtype=torch.long).unsqueeze(0).to(device)
    generated_text = decode_text(model.generate_text(context, max_tokens=200)[0].tolist())
    print(generated_text)

JULIET: whole burness, on me dear litts of his pain:
If it not, God, horsek sleep not teeds will full man of your call lam.
Your contempt starve their give and mildly.
Is note, but this? Do nobly he wind,
It 
