In [None]:
# CELL 1 — Install & Setup (run once)
# Tries the CUDA 11.8 wheel index first with stderr discarded; if that command
# fails, the `||` branch falls back to a plain `pip install torch`.
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 2>/dev/null || pip install torch

# Upload your dataset (simpsons_style_50000.txt)
# `uploaded` is read by the training cell, which opens the first key of this
# mapping as the corpus filename.
from google.colab import files
uploaded = files.upload()  # ← Click here and upload your file

# CELL 2 — FULL TRAINING + GENERATION (just run this)
import torch
import torch.nn as nn
import torch.optim as optim
import time
from IPython.display import clear_output

# Read the whole corpus from the first file in the upload mapping
filename = list(uploaded)[0]
with open(filename, mode='r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()
print(f"Loaded {len(text):,} characters")

# Character-level vocabulary: each distinct character in the corpus is a token
chars = sorted(set(text))
vocab_size = len(chars)
itos = dict(enumerate(chars))
stoi = {ch: idx for idx, ch in itos.items()}


def encode(s):
    """Return the list of token ids for string ``s`` (KeyError on unseen characters)."""
    return [stoi[c] for c in s]


def decode(l):
    """Return the string spelled by the iterable of token ids ``l``."""
    return ''.join(itos[i] for i in l)


# First 90% of the encoded corpus is for training, the remainder for validation
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

# Hyperparameters (optimized for Colab free GPU)
# These are module-level globals read directly by get_batch and the model classes.
batch_size = 128      # number of windows sampled per get_batch call
block_size = 256      # context window length; also the size of the positional-embedding table
embed_dim = 256       # embedding width; Block uses embed_dim // num_heads as the per-head size
num_heads = 8         # attention heads per Block
num_layers = 6        # number of Blocks stacked in SimpsonsGPT
dropout = 0.1         # dropout probability used by the attention, projection and feed-forward layers
max_iters = 4000      # ~25-35 min on T4 (the run logged in this notebook reports 53.1 min)
eval_interval = 400   # training steps between loss evaluations / progress prints
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

def get_batch(split):
    """Draw a random batch of input/target windows from one data split.

    Args:
        split: ``'train'`` selects ``train_data``; any other value selects
            ``val_data``.

    Returns:
        A pair ``(x, y)`` of stacked windows moved to ``device``. Each row of
        ``x`` is ``block_size`` consecutive tokens and the matching row of
        ``y`` is the same window shifted one position to the right.
    """
    source = train_data if split == 'train' else val_data
    # Start offsets are chosen so that the shifted target window still fits.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack([source[s:s + block_size] for s in starts])
    targets = torch.stack([source[s + 1:s + block_size + 1] for s in starts])
    return inputs.to(device), targets.to(device)

# Model (fixed + slightly bigger for better quality)
class Head(nn.Module):
    """One head of causal (lower-triangular masked) self-attention.

    Args:
        head_size: Output width of the key, query and value projections.
            Their input width is the global ``embed_dim``.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(embed_dim, head_size, bias=False)
        self.query = nn.Linear(embed_dim, head_size, bias=False)
        self.value = nn.Linear(embed_dim, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.head_size = head_size  # kept for the attention scaling factor

    def forward(self, x):
        """Attend over ``x``, letting each position see only itself and earlier ones.

        Args:
            x: Tensor unpacked as ``(B, T, C)``; the last dimension feeds
                linear layers whose input width is ``embed_dim``.

        Returns:
            Tensor with the same leading dimensions and last dimension
            ``head_size``.
        """
        _, T, _ = x.shape
        k = self.key(x)
        q = self.query(x)
        # Scaled dot-product scores between every query and every key.
        wei = q @ k.transpose(-2, -1) * (self.head_size ** -0.5)
        # Build the mask on the input's own device rather than the global
        # `device`, so the head keeps working if the model is moved elsewhere.
        causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
        wei = wei.masked_fill(~causal, float('-inf'))
        wei = torch.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)
        return wei @ v

class MultiHeadAttention(nn.Module):
    """Several ``Head`` modules run side by side, concatenated and projected.

    Args:
        num_heads: Number of ``Head`` instances to create.
        head_size: Output width of each head.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        head_modules = [Head(head_size) for _ in range(num_heads)]
        self.heads = nn.ModuleList(head_modules)
        # Maps the concatenated head outputs back to the embedding width.
        self.proj = nn.Linear(num_heads * head_size, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply every head to ``x``, join on the last axis, project, then apply dropout."""
        per_head = [head(x) for head in self.heads]
        joined = torch.cat(per_head, dim=-1)
        projected = self.proj(joined)
        return self.dropout(projected)

class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4x ``embed_dim``, GELU, project back, dropout."""

    def __init__(self):
        super().__init__()
        hidden_dim = 4 * embed_dim
        layers = [
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result."""
        return self.net(x)

class Block(nn.Module):
    """Transformer block: pre-LayerNorm attention and feed-forward, each with a residual add."""

    def __init__(self):
        super().__init__()
        # Each head gets an equal integer share of the embedding width.
        self.sa = MultiHeadAttention(num_heads, embed_dim // num_heads)
        self.ffwd = FeedForward()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        attended = self.sa(self.ln1(x))
        x = x + attended
        transformed = self.ffwd(self.ln2(x))
        return x + transformed

class SimpsonsGPT(nn.Module):
    """Character-level GPT-style language model.

    Token and learned positional embeddings are summed, passed through
    ``num_layers`` ``Block`` modules and a final LayerNorm, and projected to
    ``vocab_size`` logits. Sizes come from the module-level hyperparameters.
    """

    def __init__(self):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Embedding(block_size, embed_dim)
        self.blocks = nn.Sequential(*[Block() for _ in range(num_layers)])
        self.ln_f = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, if targets are given, the loss.

        Args:
            idx: Integer tensor unpacked as ``(B, T)``; ``T`` must not exceed
                ``block_size`` because positions index ``pos_emb``.
            targets: Optional tensor of target token ids with the same number
                of elements as ``idx``.

        Returns:
            ``(logits, loss)`` where ``loss`` is ``None`` when ``targets`` is
            ``None`` and the cross-entropy over all positions otherwise.
        """
        B, T = idx.shape
        # Positions live on the input's device, not the global `device`, so the
        # model still works after being moved (e.g. to CPU for inference).
        positions = torch.arange(T, device=idx.device)
        x = self.token_emb(idx) + self.pos_emb(positions)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)

        if targets is None:
            loss = None
        else:
            loss = nn.functional.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=0.9, top_k=40):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        Runs without gradient tracking. The module's train/eval mode is not
        changed here; callers decide whether dropout is active.

        Args:
            idx: Integer tensor of context token ids, indexed as ``(B, T)``.
            max_new_tokens: Number of tokens to append; coerced with ``int``
                so values coming from UI sliders are accepted.
            temperature: Divisor applied to the last-position logits.
            top_k: If truthy, only the ``top_k`` highest logits stay eligible.
                Values above the vocabulary size are clamped, since
                ``torch.topk`` rejects ``k`` larger than the dimension.

        Returns:
            ``idx`` with the sampled tokens concatenated along dimension 1.
        """
        for _ in range(int(max_new_tokens)):
            # Keep at most the last block_size tokens as context.
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k:
                k = min(int(top_k), logits.size(-1))
                kth_value = torch.topk(logits, k).values[:, [-1]]
                logits = logits.masked_fill(logits < kth_value, float('-inf'))
            probs = torch.softmax(logits, dim=-1)
            next_idx = torch.multinomial(probs, 1)
            idx = torch.cat((idx, next_idx), dim=1)
        return idx

# Training
# Build the model on the selected device and hand all of its parameters to
# AdamW (learning rate 5e-4, betas 0.9 / 0.95).
model = SimpsonsGPT().to(device)
optimizer = optim.AdamW(model.parameters(), lr=5e-4, betas=(0.9, 0.95))

@torch.no_grad()
def estimate_loss(eval_iters=100):
    """Estimate the mean loss on the train and validation splits.

    The model is put in eval mode while the batches are scored and is then
    returned to whatever mode it was in before the call, even if scoring
    raises.

    Args:
        eval_iters: Number of random batches averaged per split.

    Returns:
        Dict with keys ``'train'`` and ``'val'`` mapping to the mean loss as a
        Python float.
    """
    was_training = model.training
    model.eval()
    try:
        out = {}
        for split in ('train', 'val'):
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                xb, yb = get_batch(split)
                _, loss = model(xb, yb)
                losses[k] = loss.item()
            out[split] = losses.mean().item()
        return out
    finally:
        # Restore the caller's mode instead of unconditionally switching to
        # train mode, so a call made after model.eval() does not re-enable dropout.
        model.train(was_training)

print("Training started...")
start_time = time.time()

# step runs from 0 to max_iters inclusive, so the last step also reaches the
# evaluation branch when max_iters is a multiple of eval_interval.
for step in range(max_iters + 1):
    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    # Clear gradients left over from the previous step before backpropagating.
    optimizer.zero_grad()
    loss.backward()
    # Cap the gradient norm at 1.0 before the parameter update.
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    if step % eval_interval == 0:
        # estimate_loss averages the loss over random batches from both splits.
        losses = estimate_loss()
        elapsed = (time.time() - start_time) / 60  # minutes since training started
        # Replace the previous progress report instead of appending to it.
        clear_output(wait=True)
        print(f"Step {step}/{max_iters}")
        print(f"Train loss: {losses['train']:.4f} | Val loss: {losses['val']:.4f}")
        print(f"Time: {elapsed:.1f} min")

# Save model (weights only, via the state dict)
torch.save(model.state_dict(), 'simpsons_gpt_colab.pth')
print("Model saved!")

# Generate a sample
model.eval()  # switch the model to eval mode for sampling
context = torch.tensor(encode("Homer:"), dtype=torch.long, device=device).unsqueeze(0)
# Sampling needs no gradients, so skip autograd bookkeeping; top_k is capped at
# the vocabulary size because torch.topk rejects a larger k.
with torch.no_grad():
    generated = model.generate(
        context,
        max_new_tokens=1000,
        temperature=0.9,
        top_k=min(50, vocab_size),
    )
print("\n" + "="*60)
print("GENERATED SIMPSONS SCRIPT")
print("="*60)
print(decode(generated[0].tolist()))

Step 4000/4000
Train loss: 0.1587 | Val loss: 0.1602
Time: 53.1 min
Model saved!

GENERATED SIMPSONS SCRIPT
Homer: If this is a test, I’m failing spectacularly. (deadpan)
Lou: There’s no problem that can’t be made worse by my advice. (whispering)
Reverend Lovejoy: That’s not a bug, it’s a feature! (laughing)
Kent Brockman: You had one job, and somehow, you still outdid yourself. (sarcastically)
Krusty: We’ve officially reached peak nonsense. (sighs)
Cletus: This donut tastes like bad decisions. (sighs)
Nelson: You can’t prove it was me… okay, maybe you can. (sarcastically)
Gil: If this is a test, I’m failing spectacularly. (sarcastically)
Grandpa: There’s no problem that can’t be made worse by my advice. (sarcastically)
Hans Moleman: Sometimes silence is the only sane response. (sighs)
Selma: You can’t prove it was me… okay, maybe you can. (shouting)
Gil: Sometimes silence is the only sane response. (laughing)
Patty: I could really use a nap right now. (deadpan)
Gil: That’s not a bug, 

In [None]:
# Install Gradio (needed by the `import gradio as gr` in the next cell)
!pip install gradio --quiet

In [None]:
import gradio as gr
import torch

# Ensure the model is in evaluation mode and on the correct device
# (`model` and `device` are the globals created by the training cell).
model.eval()
model.to(device)

def generate_simpsons_script_gradio(starting_scene, main_plot, tone, initial_text, max_new_tokens, temperature, top_k, character_to_count, count_dialogues_flag):
    """Gradio callback: build a prompt, sample a script from ``model``, tidy it up.

    Args:
        starting_scene: Scene heading placed under the title in the prompt.
        main_plot: Plot summary used to build the episode title.
        tone: Accepted so the signature lines up with the interface inputs;
            not used here (temperature and top-k come from their own sliders).
        initial_text: Optional text appended after the ``*FADE IN:*`` line.
        max_new_tokens: Number of tokens to sample; coerced to ``int``.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_k: Top-k cutoff forwarded to ``model.generate``; capped at the
            vocabulary size because ``torch.topk`` rejects a larger ``k``.
        character_to_count: Name whose dialogue lines should be counted.
        count_dialogues_flag: Whether to append the dialogue count footer.

    Returns:
        The generated script, optionally followed by a dialogue-count footer,
        or a short explanatory message when no character of the prompt is in
        the model vocabulary.
    """
    # Cleared Gradio inputs can arrive as None; treat them as empty text so
    # they neither raise nor put the literal "None" into the prompt.
    starting_scene = starting_scene or ""
    main_plot = main_plot or ""
    initial_text = initial_text or ""

    # Replicate prompt construction logic from the notebook
    plot_words = main_plot.split()
    if main_plot.startswith("Homer"):
        episode_title = main_plot.capitalize()
    elif plot_words:
        episode_title = f"The One Where {plot_words[0]} {' '.join(plot_words[1:])}"
    else:
        episode_title = "Untitled"
    prompt = f"Episode Title: {episode_title}\n\n{starting_scene}\n\n*FADE IN:*\n{initial_text}"

    # Filter out characters from the prompt that are not in the vocabulary
    filtered_prompt = "".join(c for c in prompt if c in stoi)
    if not filtered_prompt:
        # generate() needs at least one context token to read logits from.
        return "No character in the prompt is in the model vocabulary; please try different text."

    # dtype is explicit so the tensor holds integer ids whatever the input.
    context = torch.tensor(encode(filtered_prompt), dtype=torch.long, device=device).unsqueeze(0)

    # A falsy top_k keeps its original meaning (no top-k filtering).
    if top_k:
        top_k = min(int(top_k), len(stoi))

    # Generate the script; sampling needs no gradient tracking.
    with torch.no_grad():
        full_drama_tokens = model.generate(
            context,
            max_new_tokens=int(max_new_tokens),
            temperature=temperature,
            top_k=top_k
        )[0].tolist()

    script = decode(full_drama_tokens)

    # Clean up a bit for readability: put an extra newline before these speakers
    for speaker in ("Homer", "Marge", "Bart", "Lisa", "Mr. Burns", "Smithers", "Chief Wiggum"):
        script = script.replace(f"{speaker}:", f"\n{speaker}:")

    dialogue_count_message = ""
    character_to_count = (character_to_count or "").strip()
    if count_dialogues_flag and character_to_count:
        # Count occurrences of a newline, then the character's name, then a colon.
        # The leading newline avoids counting the name when it appears mid-line.
        dialogues = script.count(f"\n{character_to_count}:")
        dialogue_count_message = f"\n\n--- Dialogue Count for {character_to_count}: {dialogues} ---"

    return script + dialogue_count_message

# Gradio Interface
# The inputs are listed in the same order as the parameters of
# generate_simpsons_script_gradio: scene, plot, tone, initial text,
# max tokens, temperature, top-k, character to count, count flag.
iface = gr.Interface(
    fn=generate_simpsons_script_gradio,
    inputs=[
        gr.Dropdown(
            ["Springfield Nuclear Plant - Day", "The Simpsons living room - Night", "Springfield Elementary", "Moe's Tavern", "Krusty the Clown Show set", "Springfield Town Hall", "The Android's Dungeon comic book store", "Kwik-E-Mart", "Springfield Retirement Castle", "Treehouse of Horror opening"],
            label="Starting Scene",
            value="Springfield Nuclear Plant - Day"
        ),
        gr.Dropdown(
            ["Homer accidentally causes a meltdown", "Lisa discovers corruption in Springfield", "Bart becomes mayor for a day", "Marge runs for school board", "Mr. Burns tries to buy the town", "Apu faces deportation", "Krusty is cancelled", "Grandpa Simpson tells his war stories", "A new teacher changes everything", "Springfield bans something ridiculous"],
            label="Main Plot",
            value="Homer accidentally causes a meltdown"
        ),
        # NOTE(review): the callback receives this value as `tone` but its body
        # never reads it, so this choice currently has no effect on the output.
        gr.Dropdown(
            ["Dramatic with heart", "Dark comedy", "Pure chaos", "Emotional family drama", "Political satire", "Very silly", "Treehouse of Horror (horror)"],
            label="Tone",
            value="Dramatic with heart"
        ),
        gr.Textbox(
            label="Initial Text (optional, appended to prompt)",
            placeholder="Homer: D'oh!"
        ),
        # The model is character-level, so "tokens" here are single characters.
        gr.Slider(minimum=100, maximum=3000, value=800, label="Max New Tokens (length of script)", step=50),
        gr.Slider(minimum=0.1, maximum=1.5, value=0.9, label="Temperature (creativity)", step=0.05),
        # NOTE(review): the upper bound of 100 may exceed the character vocabulary size — confirm.
        gr.Slider(minimum=1, maximum=100, value=50, label="Top-K (diversity)", step=1),
        gr.Dropdown(
            ["Homer", "Marge", "Bart", "Lisa", "Mr. Burns", "Smithers", "Chief Wiggum", "Moe", "Barney", "Ned Flanders", "Principal Skinner", "Apu", "Krusty", "Grandpa", "Milhouse", "Nelson", "Comic Book Guy"],
            label="Character to count dialogues for (optional)",
            value=None, # No default selection
            allow_custom_value=True # Allow users to type other character names
        ),
        gr.Checkbox(label="Count dialogues for selected character", value=False)
    ],
    outputs=gr.Textbox(label="Generated Simpsons Script", lines=20),
    title="SimpsonsGPT Story Generator",
    description="Generate custom Simpsons stories using a fine-tuned GPT model!"
)

# Per the log output recorded below, on Colab debug=True keeps this cell
# running so errors and logs stay visible, and share=True is set automatically.
iface.launch(debug=True)

It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://5254918e7fca75f02a.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


In [None]:
#@title LONG SIMPSONS DRAMA GENERATOR (500–800 lines) — Click play!
# Change any of these to direct the drama:
starting_scene = "Springfield Nuclear Plant - Day" #@param ["Springfield Nuclear Plant - Day", "The Simpsons living room - Night", "Springfield Elementary", "Moe's Tavern", "Krusty the Clown Show set", "Springfield Town Hall", "The Android's Dungeon comic book store", "Kwik-E-Mart", "Springfield Retirement Castle", "Treehouse of Horror opening"]
main_plot = "Homer accidentally causes a meltdown" #@param ["Homer accidentally causes a meltdown", "Lisa discovers corruption in Springfield", "Bart becomes mayor for a day", "Marge runs for school board", "Mr. Burns tries to buy the town", "Apu faces deportation", "Krusty is cancelled", "Grandpa Simpson tells his war stories", "A new teacher changes everything", "Springfield bans something ridiculous"]
tone = "Dramatic with heart" #@param ["Dramatic with heart", "Dark comedy", "Pure chaos", "Emotional family drama", "Political satire", "Very silly", "Treehouse of Horror (horror)"]

# The tone only selects the sampling settings: the horror tone samples a
# little more freely than the others.
if "Treehouse of Horror" in tone:
    temperature = 1.0
    top_k = 60
else:
    temperature = 0.9
    top_k = 40
# torch.topk rejects a k larger than the vocabulary, so cap it.
top_k = min(top_k, len(stoi))

# Build the prompt: title line, scene heading, then the fade-in marker
prompt = f"""Episode Title: {"The One Where " + main_plot.split()[0] + " " + " ".join(main_plot.split()[1:]) if not main_plot.startswith("Homer") else main_plot.capitalize()}

{starting_scene}

*FADE IN:*"""

model.eval()
# Filter out characters from the prompt that are not in the vocabulary
filtered_prompt_chars = [c for c in prompt if c in stoi]
filtered_prompt = "".join(filtered_prompt_chars)
# dtype is explicit so the context holds integer token ids.
context = torch.tensor(encode(filtered_prompt), dtype=torch.long, device=device).unsqueeze(0)

print("Generating your full-length dramatic episode… (this takes 15–30 seconds)")

# Sampling needs no gradients, so run it without autograd tracking.
with torch.no_grad():
    full_drama = model.generate(
        context,
        # The model is character-level, so this is 2800 new characters: a few
        # dozen dialogue lines at the line lengths seen in the sample output,
        # not the 500–800 lines promised in the form title.
        max_new_tokens=2800,
        temperature=temperature,
        top_k=top_k
    )[0].tolist()

script = decode(full_drama)

# Clean up a bit for readability
script = script.replace("Homer:", "\nHomer:").replace("Marge:", "\nMarge:").replace("Bart:", "\nBart:").replace("Lisa:", "\nLisa:")
script = script.replace("Mr. Burns:", "\nMr. Burns:").replace("Smithers:", "\nSmithers:").replace("Chief Wiggum:", "\nChief Wiggum:")

print("\n" + "═"*80)
print(f"SIMPSONS DRAMA: {main_plot.upper()}")
print("═"*80 + "\n")
print(script)
print("\n" + "═"*80)
print("END OF EPISODE")

Generating your full-length dramatic episode… (this takes 15–30 seconds)

════════════════════════════════════════════════════════════════════════════════
SIMPSONS DRAMA: HOMER ACCIDENTALLY CAUSES A MELTDOWN
════════════════════════════════════════════════════════════════════════════════

Episode Title: Homer accidentally causes a meltdown

Springfield Nuclear Plant  Day

FADE IN: D'oh! Not again! (shouting)
Professor Frink: You had one job, and somehow, you still outdid yourself. (shouting)
Nelson: Why is it always me? (laughing)
Grandpa: You call that a plan? I call that chaos. (sarcastically)
Troy McClure: That’s not how gravity works, but go on. (confused)
Lenny: You can’t prove it was me… okay, maybe you can.
Willie: There’s nothing like mild panic to start the day. (muttering)
Edna: D'oh! Not again! (excitedly)
Maggie: You call that a plan? I call that chaos. (confused)

Homer: You can’t prove it was me… okay, maybe you can. (confused)
Brandine: You had one job, and somehow, you 