In [1]:
import os
import math
import time
import torch
import torch.nn as nn
from torch.nn import functional as F
from dataclasses import dataclass

class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention using one fused query/key/value projection.

    ``config`` must provide ``n_embd`` and ``n_head``, and ``n_embd`` must be
    divisible by ``n_head`` so the embedding can be split evenly across heads.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # One linear layer produces q, k and v for every head at once.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # Mixes the concatenated head outputs back into the embedding width.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        """Apply causal attention to ``x`` of shape (batch, sequence, embedding)."""
        batch, seq_len, width = x.size()
        head_dim = width // self.n_head

        def to_heads(t):
            # (batch, seq, width) -> (batch, heads, seq, head_dim)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        query, key, value = self.c_attn(x).split(self.n_embd, dim=2)
        attended = F.scaled_dot_product_attention(
            to_heads(query), to_heads(key), to_heads(value), is_causal=True
        )
        # Put the heads back side by side: (batch, seq, width).
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.c_proj(merged)

class MLP(nn.Module):
    """Feed-forward sub-layer: widen to 4x ``n_embd``, tanh-approximated GELU, project back."""

    def __init__(self, config):
        super().__init__()
        hidden_width = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden_width)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(hidden_width, config.n_embd)

    def forward(self, x):
        """Return ``c_proj(gelu(c_fc(x)))``; the last dimension stays ``n_embd``."""
        return self.c_proj(self.gelu(self.c_fc(x)))

class Block(nn.Module):
    """One pre-norm transformer layer: attention and MLP, each wrapped in a residual connection."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.ln_1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(width)
        self.mlp = MLP(config)

    def forward(self, x):
        """Normalise before each sub-layer and add its output back onto the input."""
        attn_out = self.attn(self.ln_1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln_2(x))
        return x + mlp_out

@dataclass
class GPTConfig:
    """Hyper-parameters consumed by ``GPT`` and its sub-modules.

    Field order matters: the dataclass also accepts these positionally.
    """

    # Number of rows in the position-embedding table; longest sequence GPT.forward accepts.
    block_size: int = 1024
    # Number of rows in the token-embedding table and width of the output logits.
    vocab_size: int = 50257
    # How many transformer blocks are stacked.
    n_layer: int = 12
    # Attention heads per block; must divide n_embd.
    n_head: int = 12
    # Embedding / hidden width.
    n_embd: int = 768

class GPT(nn.Module):
    """Decoder-only transformer language model.

    ``config`` must provide ``vocab_size``, ``block_size``, ``n_layer``,
    ``n_head`` and ``n_embd``. The token embedding and the output projection
    share a single weight tensor.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Sub-modules are created in the same order as the dict keys below.
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(config.vocab_size, config.n_embd),
            "wpe": nn.Embedding(config.block_size, config.n_embd),
            "h": nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            "ln_f": nn.LayerNorm(config.n_embd),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Tie the input embedding to the output projection (one shared Parameter).
        self.transformer.wte.weight = self.lm_head.weight

        # Re-initialise every Linear/Embedding in the module tree.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    def forward(self, idx, targets=None):
        """Return ``(logits, loss)`` for token ids ``idx`` of shape (batch, time).

        ``loss`` is the cross-entropy against ``targets`` when they are given,
        otherwise ``None``.
        """
        B, T = idx.size()
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"

        # Sum of token and (broadcast) position embeddings.
        positions = torch.arange(0, T, dtype=torch.long, device=idx.device)
        hidden = self.transformer.wte(idx) + self.transformer.wpe(positions)

        for layer in self.transformer.h:
            hidden = layer(hidden)

        logits = self.lm_head(self.transformer.ln_f(hidden))

        if targets is None:
            return logits, None

        # Flatten batch and time so each position is one classification example.
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    def configure_optimizers(self, weight_decay, learning_rate):
        """Build AdamW with weight decay applied only to parameters of rank >= 2."""
        every_param = [param for _, param in self.named_parameters()]
        decayed = [param for param in every_param if param.dim() >= 2]
        undecayed = [param for param in every_param if param.dim() < 2]

        groups = [
            {'params': decayed, 'weight_decay': weight_decay},
            {'params': undecayed, 'weight_decay': 0.0},
        ]
        return torch.optim.AdamW(groups, lr=learning_rate, betas=(0.9, 0.95))

# Training setup and helper functions
def train_gpt2(
    model,
    train_data,
    val_data,
    batch_size=32,
    block_size=1024,
    epochs=1,
    learning_rate=3e-4,
    weight_decay=0.1,
    device='cuda' if torch.cuda.is_available() else 'cpu'
):
    """Train ``model`` on randomly sampled windows of a flat token tensor.

    Args:
        model: module exposing ``configure_optimizers(weight_decay, learning_rate)``
            and whose call ``model(x, y)`` returns ``(logits, loss)``.
        train_data: 1-D tensor of token ids used for training batches.
        val_data: 1-D tensor of token ids used for validation batches.
        batch_size: number of windows per batch.
        block_size: length of each window (targets are the window shifted by one).
        epochs: number of passes; one pass is ``len(train_data) // (batch_size * block_size)`` steps.
        learning_rate: learning rate handed to ``configure_optimizers``.
        weight_decay: weight decay handed to ``configure_optimizers``.
        device: device the model and batches are moved to.

    Returns:
        The trained model (already moved to ``device``).

    The per-epoch summary reports the *mean* training and validation loss.
    """
    model = model.to(device)
    optimizer = model.configure_optimizers(weight_decay=weight_decay, learning_rate=learning_rate)

    def get_batch(split):
        data = train_data if split == 'train' else val_data
        # Highest valid start leaves room for block_size inputs plus one shifted target.
        ix = torch.randint(len(data) - block_size, (batch_size,))
        x = torch.stack([data[i:i+block_size] for i in ix])
        y = torch.stack([data[i+1:i+block_size+1] for i in ix])
        return x.to(device), y.to(device)

    tokens_per_batch = batch_size * block_size
    train_steps = len(train_data) // tokens_per_batch
    # Validate on roughly a tenth of the validation tokens.
    val_steps = len(val_data) // tokens_per_batch // 10
    if val_steps == 0 and len(val_data) > block_size:
        # Small validation sets would otherwise be skipped and reported as 0.0;
        # run one batch whenever at least one window can be sampled.
        val_steps = 1

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        for step in range(train_steps):
            xb, yb = get_batch('train')

            logits, loss = model(xb, yb)
            total_loss += loss.item()

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            if step % 100 == 0:
                print(f"epoch {epoch+1} iter {step}: loss {loss.item():.4f}")

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for _ in range(val_steps):
                xb, yb = get_batch('val')
                logits, loss = model(xb, yb)
                val_loss += loss.item()

        # Report averages so the numbers are comparable across dataset sizes;
        # max(..., 1) keeps the zero-step case from dividing by zero.
        mean_train_loss = total_loss / max(train_steps, 1)
        mean_val_loss = val_loss / max(val_steps, 1)
        print(f"Epoch {epoch+1} complete. Train loss: {mean_train_loss:.4f}, Val loss: {mean_val_loss:.4f}")

    return model

In [2]:
!pip install "pyarrow==14.0.1" datasets tiktoken

Collecting pyarrow==14.0.1
  Downloading pyarrow-14.0.1-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.0 kB)
Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting tiktoken
  Downloading tiktoken-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.6 kB)
INFO: pip is looking at multiple versions of datasets to determine which version is compatible with other requirements. This could take a while.
Collecting datasets
  Downloading datasets-3.0.2-py3-none-any.whl.metadata (20 kB)
  Downloading datasets-3.0.1-py3-none-any.whl.metadata (20 kB)
  Downloading datasets-3.0.0-py3-none-any.whl.metadata (19 kB)
  Downloading datasets-2.21.0-py3-none-any.whl.metadata (21 kB)
  Downloading datasets-2.20.0-py3-none-any.whl.metadata (19 kB)
  Downloading datasets-2.19.2-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from data

In [None]:
# import os
# from datasets import load_dataset
# import tiktoken
# import torch
# import numpy as np
# from tqdm import tqdm
# from torch.utils.data import Dataset, DataLoader
# import gc
# from typing import List, Tuple
# from itertools import islice

# class WikiTextDataset(Dataset):
#     def __init__(self, encodings, block_size):
#         self.encodings = encodings
#         self.block_size = block_size

#     def __len__(self):
#         return len(self.encodings) - self.block_size

#     def __getitem__(self, idx):
#         chunk = self.encodings[idx:idx + self.block_size + 1]
#         x = chunk[:-1]
#         y = chunk[1:]
#         return x, y

# def batch_iterator(iterable, batch_size):
#     """Helper function to create batch iterator"""
#     iterator = iter(iterable)
#     while batch := list(islice(iterator, batch_size)):
#         yield batch

# def process_batch_gpu(
#     texts: List[str],
#     tokenizer,
#     device: torch.device,
#     max_length: int = None
# ) -> List[List[int]]:
#     """Process a batch of texts using GPU"""
#     # Tokenize all texts in the batch
#     tokens_list = []
#     for text in texts:
#         if text.strip():  # Only process non-empty texts
#             tokens = tokenizer.encode(text)
#             if tokens:  # Only add if we got tokens
#                 tokens.append(tokenizer.eot_token)  # Add EOT token
#                 tokens_list.extend(tokens)

#     return tokens_list

# def prepare_wikipedia_data(
#     cache_dir="wiki_cache",
#     block_size=1024,
#     train_val_split=0.95,
#     batch_size=32,
#     processing_batch_size=1000,
#     device='cuda' if torch.cuda.is_available() else 'cpu'
# ):
#     """
#     Prepare Wikipedia dataset for GPT-2 training with GPU acceleration

#     Args:
#         cache_dir: Directory to cache the processed data
#         block_size: Size of text chunks for training
#         train_val_split: Proportion of data to use for training
#         batch_size: Batch size for data loading
#         processing_batch_size: Batch size for GPU processing
#         device: Device to use for processing
#     """

#     # Create cache directory if it doesn't exist
#     os.makedirs(cache_dir, exist_ok=True)

#     # Cache files
#     train_cache = os.path.join(cache_dir, "train_tokens.pt")
#     val_cache = os.path.join(cache_dir, "val_tokens.pt")

#     # Check if processed data already exists
#     if os.path.exists(train_cache) and os.path.exists(val_cache):
#         print("Loading cached data...")
#         train_data = torch.load(train_cache)
#         val_data = torch.load(val_cache)
#         return train_data, val_data

#     print("Loading Wikipedia dataset...")
#     dataset = load_dataset("wikipedia", "20220301.en", split="train")

#     # Initialize tokenizer
#     enc = tiktoken.get_encoding("gpt2")

#     print("Processing dataset in batches...")
#     all_tokens = []
#     total_batches = len(dataset) // processing_batch_size + (1 if len(dataset) % processing_batch_size != 0 else 0)

#     for batch_idx, batch in enumerate(tqdm(
#         batch_iterator(dataset['text'], processing_batch_size),
#         total=total_batches,
#         desc="Processing batches"
#     )):
#         # Process batch using GPU
#         batch_tokens = process_batch_gpu(batch, enc, device)
#         all_tokens.extend(batch_tokens)

#         # Periodically clear GPU memory
#         if (batch_idx + 1) % 10 == 0:
#             torch.cuda.empty_cache()
#             gc.collect()

#     # Convert to tensor
#     print("Converting to tensor...")
#     all_tokens = torch.tensor(all_tokens, dtype=torch.long)

#     # Split into train and validation
#     split_idx = int(len(all_tokens) * train_val_split)
#     train_tokens = all_tokens[:split_idx]
#     val_tokens = all_tokens[split_idx:]

#     print(f"Total tokens: {len(all_tokens):,}")
#     print(f"Train tokens: {len(train_tokens):,}")
#     print(f"Val tokens: {len(val_tokens):,}")

#     # Save processed data
#     print("Saving processed data...")
#     torch.save(train_tokens, train_cache)
#     torch.save(val_tokens, val_cache)

#     # Clear memory
#     del all_tokens
#     torch.cuda.empty_cache()
#     gc.collect()

#     return train_tokens, val_tokens

# def create_dataloaders(
#     train_tokens,
#     val_tokens,
#     block_size=1024,
#     batch_size=32,
#     num_workers=4
# ):
#     """Create DataLoaders for training and validation"""
#     train_dataset = WikiTextDataset(train_tokens, block_size)
#     val_dataset = WikiTextDataset(val_tokens, block_size)

#     train_loader = DataLoader(
#         train_dataset,
#         batch_size=batch_size,
#         shuffle=True,
#         num_workers=num_workers,
#         pin_memory=True
#     )

#     val_loader = DataLoader(
#         val_dataset,
#         batch_size=batch_size,
#         shuffle=False,
#         num_workers=num_workers,
#         pin_memory=True
#     )

#     return train_loader, val_loader

# class TrainingConfig:
#     def __init__(self):
#         self.learning_rate = 3e-4
#         self.weight_decay = 0.1
#         self.gradient_accumulation_steps = 8
#         self.warmup_steps = 1000
#         self.max_steps = None  # Will be set based on dataset size
#         self.batch_size = 32
#         self.block_size = 1024
#         self.epochs = 3
#         self.checkpoint_dir = "checkpoints"
#         self.log_interval = 100

# def train_gpt2_with_dataloader(
#     model,
#     train_loader,
#     val_loader,
#     config: TrainingConfig,
#     device='cuda' if torch.cuda.is_available() else 'cpu'
# ):
#     """Enhanced training function with learning rate scheduling and better logging"""
#     model = model.to(device)
#     optimizer = model.configure_optimizers(
#         weight_decay=config.weight_decay,
#         learning_rate=config.learning_rate
#     )

#     # Create checkpoint directory
#     os.makedirs(config.checkpoint_dir, exist_ok=True)

#     # Calculate total steps if not provided
#     if config.max_steps is None:
#         config.max_steps = len(train_loader) * config.epochs

#     # Learning rate scheduler
#     def get_lr(step):
#         if step < config.warmup_steps:
#             return config.learning_rate * step / config.warmup_steps
#         return config.learning_rate

#     # Training loop
#     global_step = 0
#     best_val_loss = float('inf')

#     for epoch in range(config.epochs):
#         model.train()
#         total_loss = 0

#         progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

#         for step, (x, y) in enumerate(progress_bar):
#             x, y = x.to(device), y.to(device)

#             # Forward pass
#             logits, loss = model(x, y)
#             loss = loss / config.gradient_accumulation_steps
#             total_loss += loss.item()

#             # Backward pass
#             loss.backward()

#             # Update weights
#             if (step + 1) % config.gradient_accumulation_steps == 0:
#                 torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

#                 # Update learning rate
#                 lr = get_lr(global_step)
#                 for param_group in optimizer.param_groups:
#                     param_group['lr'] = lr

#                 optimizer.step()
#                 optimizer.zero_grad()
#                 global_step += 1

#             # Update progress bar
#             current_loss = total_loss / (step + 1)
#             progress_bar.set_postfix({
#                 'loss': f"{current_loss:.4f}",
#                 'perplexity': f"{torch.exp(torch.tensor(current_loss)):.2f}",
#                 'lr': f"{lr:.2e}"
#             })

#             # Log training progress
#             if step % config.log_interval == 0:
#                 print(f"\nStep {global_step}: loss {current_loss:.4f}, "
#                       f"perplexity {torch.exp(torch.tensor(current_loss)):.2f}, "
#                       f"lr {lr:.2e}")

#         # Validation
#         model.eval()
#         val_loss = 0
#         with torch.no_grad():
#             for x, y in tqdm(val_loader, desc="Validation"):
#                 x, y = x.to(device), y.to(device)
#                 logits, loss = model(x, y)
#                 val_loss += loss.item()

#         val_loss /= len(val_loader)

#         # Save checkpoint if best validation loss
#         if val_loss < best_val_loss:
#             best_val_loss = val_loss
#             checkpoint_path = os.path.join(config.checkpoint_dir, f'best_model.pt')
#             torch.save({
#                 'epoch': epoch + 1,
#                 'model_state_dict': model.state_dict(),
#                 'optimizer_state_dict': optimizer.state_dict(),
#                 'val_loss': val_loss,
#                 'global_step': global_step,
#             }, checkpoint_path)

#         # Save regular checkpoint
#         checkpoint_path = os.path.join(config.checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pt')
#         torch.save({
#             'epoch': epoch + 1,
#             'model_state_dict': model.state_dict(),
#             'optimizer_state_dict': optimizer.state_dict(),
#             'val_loss': val_loss,
#             'global_step': global_step,
#         }, checkpoint_path)

#         print(f"\nEpoch {epoch+1} complete. "
#               f"Train loss: {total_loss/len(train_loader):.4f}, "
#               f"Val loss: {val_loss:.4f}, "
#               f"Train perplexity: {torch.exp(torch.tensor(total_loss/len(train_loader))):.2f}, "
#               f"Val perplexity: {torch.exp(torch.tensor(val_loss)):.2f}")

#     return model

# # Example usage
# if __name__ == "__main__":
#     # Set device
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     print(f"Using device: {device}")

#     # Training configuration
#     config = TrainingConfig()

#     # Prepare data
#     print("Preparing Wikipedia dataset...")
#     train_tokens, val_tokens = prepare_wikipedia_data(
#         cache_dir="wiki_cache",
#         block_size=config.block_size,
#         train_val_split=0.95,
#         batch_size=config.batch_size,
#         device=device
#     )

#     # Create dataloaders
#     print("Creating DataLoaders...")
#     train_loader, val_loader = create_dataloaders(
#         train_tokens,
#         val_tokens,
#         block_size=config.block_size,
#         batch_size=config.batch_size
#     )

#     # Initialize model
#     print("Initializing model...")
#     gpt_config = GPTConfig()
#     model = GPT(gpt_config)

#     # Train model
#     print("Starting training...")
#     model = train_gpt2_with_dataloader(
#         model,
#         train_loader,
#         val_loader,
#         config,
#         device=device
#     )

Using device: cuda
Preparing Wikipedia dataset...
Loading Wikipedia dataset...


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Loading dataset shards:   0%|          | 0/41 [00:00<?, ?it/s]

In [None]:
import os
from datasets import load_dataset
import torch
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import gc
from typing import List, Tuple
from itertools import islice
from transformers import GPT2Tokenizer, GPT2TokenizerFast
from accelerate import Accelerator

class WikiTextDataset(Dataset):
    """Sliding-window view over a flat token sequence for next-token prediction.

    Item ``i`` is the pair ``(encodings[i:i+block_size], encodings[i+1:i+block_size+1])``,
    i.e. an input window and the same window shifted one token to the right.
    """

    def __init__(self, encodings, block_size):
        self.encodings = encodings
        self.block_size = block_size

    def __len__(self):
        # Clamp at zero: a sequence shorter than block_size has no full window,
        # and a negative length makes len() raise ValueError.
        return max(0, len(self.encodings) - self.block_size)

    def __getitem__(self, idx):
        """Return the ``(x, y)`` window pair at ``idx``.

        Negative indices count from the end.

        Raises:
            IndexError: if ``idx`` is outside ``range(len(self))``. This also
                lets plain ``for x, y in dataset`` iteration terminate.
        """
        size = len(self)
        if idx < 0:
            idx = idx + size
        if not 0 <= idx < size:
            raise IndexError(f"index {idx} out of range for dataset of size {size}")
        chunk = self.encodings[idx:idx + self.block_size + 1]
        x = chunk[:-1]
        y = chunk[1:]
        return x, y

def batch_iterator(iterable, batch_size):
    """Yield successive lists of up to ``batch_size`` items; the last list may be shorter."""
    source = iter(iterable)
    while True:
        chunk = list(islice(source, batch_size))
        if not chunk:
            # Source exhausted (or batch_size is 0): nothing more to yield.
            return
        yield chunk

def process_batch_gpu(
    texts: List[str],
    tokenizer: "GPT2TokenizerFast",
    device: torch.device,
    max_length: int = None
) -> torch.Tensor:
    """Tokenize ``texts`` and return all token ids as one flat tensor on ``device``.

    Args:
        texts: strings to tokenize.
        tokenizer: callable tokenizer whose result exposes ``["input_ids"]`` as
            one list of ids per input text.
        device: device the returned tensor is placed on.
        max_length: forwarded to the tokenizer together with ``truncation=True``.

    Returns:
        1-D ``torch.long`` tensor holding the ids of every text back to back, in
        input order. Empty (but still ``torch.long``) when there are no tokens.

    Padding is not requested: the ids are concatenated anyway, so padding every
    text to the longest one and masking it out again only costs memory.
    """
    if not texts:
        return torch.tensor([], dtype=torch.long, device=device)

    encoded = tokenizer(
        texts,
        padding=False,
        truncation=True,
        max_length=max_length,
        return_attention_mask=False
    )

    # Flatten the per-text id lists into a single stream.
    flat_ids = [token_id for ids in encoded["input_ids"] for token_id in ids]

    # Explicit dtype so the empty case is long as well (torch defaults to float).
    return torch.tensor(flat_ids, dtype=torch.long, device=device)

def prepare_wikipedia_data(
    cache_dir="wiki_cache",
    block_size=1024,
    train_val_split=0.95,
    batch_size=32,
    processing_batch_size=1000,
    device='cuda' if torch.cuda.is_available() else 'cpu'
):
    """
    Prepare Wikipedia dataset for GPT-2 training with GPU acceleration

    Args:
        cache_dir: directory holding ``train_tokens.pt`` / ``val_tokens.pt``.
        block_size: used as the tokenizer's ``model_max_length``.
        train_val_split: fraction of the token stream assigned to training.
        batch_size: unused here; kept so existing callers keep working.
        processing_batch_size: number of articles tokenized per call.
        device: device handed to ``process_batch_gpu``.

    Returns:
        ``(train_tokens, val_tokens)`` as CPU tensors.

    Raises:
        RuntimeError: if no tokens at all were produced.

    If tokenization fails part-way, the tokens gathered so far are still
    returned, but they are written to ``*.partial.pt`` files so that a later
    run does not load a truncated corpus from the cache as if it were complete.
    """
    # Create cache directory if it doesn't exist
    os.makedirs(cache_dir, exist_ok=True)

    # Cache files
    train_cache = os.path.join(cache_dir, "train_tokens.pt")
    val_cache = os.path.join(cache_dir, "val_tokens.pt")

    # Check if processed data already exists
    if os.path.exists(train_cache) and os.path.exists(val_cache):
        print("Loading cached data...")
        train_data = torch.load(train_cache)
        val_data = torch.load(val_cache)
        return train_data, val_data

    print("Loading Wikipedia dataset...")
    dataset = load_dataset("wikipedia", "20220301.en", split="train")

    print("Initializing GPU-accelerated tokenizer...")
    tokenizer = GPT2TokenizerFast.from_pretrained(
        'gpt2',
        model_max_length=block_size
    )

    # Add special tokens if needed
    special_tokens = {
        'pad_token': '<|pad|>',
        'eos_token': '<|endoftext|>'
    }
    tokenizer.add_special_tokens(special_tokens)

    print("Processing dataset in batches...")
    all_tokens = []
    token_count = 0  # running total, so progress logging does not re-sum every chunk
    total_batches = len(dataset) // processing_batch_size + (1 if len(dataset) % processing_batch_size != 0 else 0)
    failure = None

    # Stream rows one by one instead of materialising the whole text column.
    texts = (row['text'] for row in dataset)

    try:
        for batch_idx, batch in enumerate(tqdm(
            batch_iterator(texts, processing_batch_size),
            total=total_batches,
            desc="Processing batches"
        )):
            # Filter out empty strings and very short texts
            batch = [text for text in batch if len(text.strip()) > 50]

            if not batch:
                continue

            batch_tokens = process_batch_gpu(batch, tokenizer, device)

            if len(batch_tokens) > 0:
                # Move to CPU to save GPU memory
                all_tokens.append(batch_tokens.cpu())
                token_count += len(batch_tokens)

            # Periodically clear GPU memory
            if (batch_idx + 1) % 10 == 0:
                torch.cuda.empty_cache()
                gc.collect()

            # Periodically report progress
            if (batch_idx + 1) % 100 == 0:
                print(f"\nProcessed {batch_idx + 1}/{total_batches} batches")
                print(f"Current total tokens: {token_count:,}")

    except Exception as e:
        # Best effort: keep what was tokenized so far. Finalisation is done
        # below rather than in a ``finally`` with ``return``, which would also
        # swallow KeyboardInterrupt / SystemExit.
        failure = e
        print(f"Error during processing: {str(e)}")

    if not all_tokens:
        raise RuntimeError("No tokens were produced from the dataset; nothing to save") from failure

    if failure is not None:
        # Keep partial output away from the cache names checked at the top.
        print("Saving partial progress...")
        train_cache = os.path.join(cache_dir, "train_tokens.partial.pt")
        val_cache = os.path.join(cache_dir, "val_tokens.partial.pt")

    # Concatenate all tokens
    print("Concatenating tokens...")
    all_tokens = torch.cat(all_tokens)

    # Split into train and validation
    split_idx = int(len(all_tokens) * train_val_split)
    train_tokens = all_tokens[:split_idx]
    val_tokens = all_tokens[split_idx:]

    print(f"Total tokens: {len(all_tokens):,}")
    print(f"Train tokens: {len(train_tokens):,}")
    print(f"Val tokens: {len(val_tokens):,}")

    # Save processed data
    print("Saving processed data...")
    torch.save(train_tokens, train_cache)
    torch.save(val_tokens, val_cache)

    # Clear memory
    del all_tokens
    torch.cuda.empty_cache()
    gc.collect()

    return train_tokens, val_tokens

# Example usage with better error handling and GPU memory management
def create_dataloaders(
    train_tokens,
    val_tokens,
    block_size=1024,
    batch_size=32,
    num_workers=4
):
    """Create DataLoaders for training and validation.

    The training loader is shuffled, the validation loader is not. Memory is
    pinned only when CUDA is available.
    """
    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(
        WikiTextDataset(train_tokens, block_size),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    val_loader = DataLoader(
        WikiTextDataset(val_tokens, block_size),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    return train_loader, val_loader


def main():
    """Tokenize (or load cached) Wikipedia data and return ``(train_loader, val_loader)``.

    Errors during preparation are reported and re-raised; cached GPU memory is
    released afterwards in either case.
    """
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Set up accelerator for distributed training if needed
    # NOTE(review): the instance is not used below; kept in case its construction is relied on.
    accelerator = Accelerator()

    try:
        # Prepare data with larger batch size for GPU processing
        train_tokens, val_tokens = prepare_wikipedia_data(
            cache_dir="wiki_cache",
            block_size=1024,
            train_val_split=0.95,
            processing_batch_size=2000,  # Increased batch size for GPU
            device=device
        )

        # create_dataloaders is defined alongside main so this call resolves
        # without relying on a previously executed cell.
        train_loader, val_loader = create_dataloaders(
            train_tokens,
            val_tokens,
            block_size=1024,
            batch_size=32
        )

        print("Data preparation completed successfully!")
        return train_loader, val_loader

    except Exception as e:
        print(f"Error during data preparation: {str(e)}")
        raise

    finally:
        # Clean up GPU memory
        torch.cuda.empty_cache()
        gc.collect()

if __name__ == "__main__":
    main()

In [6]:
from transformers import GPT2TokenizerFast
import torch
import torch.nn.functional as F

class TextGenerator:
    """Sampling-based text generation around a model that returns ``(logits, loss)``."""

    def __init__(self, model, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.model = model.to(device)
        self.device = device
        self.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
        if not self.tokenizer.pad_token:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model.eval()

    @staticmethod
    def _top_p_filter(logits, top_p):
        """Mask logits outside the smallest set whose cumulative probability exceeds ``top_p``."""
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cumulative = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        drop_sorted = cumulative > top_p
        # Shift right so the token that crosses the threshold is kept, and the
        # most likely token is never dropped.
        drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
        drop_sorted[..., 0] = False
        # Map the mask from sorted order back to vocabulary order.
        drop = drop_sorted.scatter(-1, sorted_idx, drop_sorted)
        return logits.masked_fill(drop, float('-inf'))

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 100,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.9,
        num_return_sequences: int = 1,
    ):
        """Generate text based on a prompt.

        Args:
            prompt: text to continue.
            max_new_tokens: upper bound on the number of sampled tokens.
            temperature: divisor applied to the logits; must be positive.
            top_k: keep only the ``top_k`` most likely tokens (``0`` disables;
                values above the vocabulary size are clamped).
            top_p: nucleus threshold; applied when ``0 < top_p < 1``.
            num_return_sequences: number of independent samples.

        Returns:
            A list with ``num_return_sequences`` decoded strings (prompt included).

        Raises:
            ValueError: if ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError("temperature must be positive")

        # Tokenize the prompt
        encoded = self.tokenizer(prompt, return_tensors='pt', truncation=True)
        input_ids = encoded['input_ids'].to(self.device)
        input_ids = input_ids.repeat(num_return_sequences, 1)

        # If the model advertises a maximum context length, feed it only the
        # most recent tokens so long generations do not overflow it.
        block_size = getattr(getattr(self.model, 'config', None), 'block_size', None)
        eos_id = self.tokenizer.eos_token_id
        finished = torch.zeros(input_ids.size(0), dtype=torch.bool, device=input_ids.device)

        with torch.no_grad():
            for _ in range(max_new_tokens):
                context = input_ids if block_size is None else input_ids[:, -block_size:]

                # Get model predictions
                logits, _ = self.model(context)
                next_token_logits = logits[:, -1, :] / temperature

                # Apply top-k filtering
                if top_k > 0:
                    k = min(top_k, next_token_logits.size(-1))
                    kth_value = torch.topk(next_token_logits, k)[0][..., -1, None]
                    next_token_logits = next_token_logits.masked_fill(
                        next_token_logits < kth_value, float('-inf')
                    )

                # Apply top-p (nucleus) filtering
                if 0.0 < top_p < 1.0:
                    next_token_logits = self._top_p_filter(next_token_logits, top_p)

                # Sample from the filtered distribution
                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

                if eos_id is not None:
                    # Sequences that already ended only receive EOS padding,
                    # which decoding skips as a special token.
                    next_token[finished] = eos_id
                    finished = finished | (next_token.squeeze(-1) == eos_id)
                    # Stop once every sequence has produced EOS
                    if finished.all():
                        break

                input_ids = torch.cat([input_ids, next_token], dim=-1)

        # Decode generated sequences
        return [
            self.tokenizer.decode(seq, skip_special_tokens=True)
            for seq in input_ids
        ]

def generate_text(model_path, prompt):
    """Load a checkpoint, print samples for ``prompt`` at several temperatures, return the generator.

    Two checkpoint layouts are accepted:
      * ``{'config': ..., 'model': state_dict}``
      * ``{'model_state_dict': state_dict, ...}`` without a config, in which
        case a default ``GPTConfig()`` is used.

    Args:
        model_path: path of the checkpoint file.
        prompt: text to continue.

    Returns:
        The ``TextGenerator`` wrapping the loaded model.

    Raises:
        KeyError: if the checkpoint holds neither ``'model'`` nor ``'model_state_dict'``.
    """
    # Load onto the CPU first so a checkpoint saved from a GPU run also opens
    # on a CPU-only host; TextGenerator moves the model to its own device.
    checkpoint = torch.load(model_path, map_location='cpu')

    config = checkpoint['config'] if 'config' in checkpoint else GPTConfig()

    if 'model' in checkpoint:
        state_dict = checkpoint['model']
    elif 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        raise KeyError("checkpoint has neither 'model' nor 'model_state_dict'")

    model = GPT(config)
    model.load_state_dict(state_dict)

    # Create generator
    generator = TextGenerator(model)

    # Generate text with different temperatures
    print("\nGenerating with different temperatures:")
    temperatures = [0.5, 0.8, 1.0]

    for temp in temperatures:
        print(f"\nTemperature: {temp}")
        print("-" * 50)

        generated_texts = generator.generate(
            prompt=prompt,
            max_new_tokens=100,
            temperature=temp,
            num_return_sequences=1
        )

        print(generated_texts[0])

    return generator  # Return generator for further use if needed

# Example usage in Colab
model_path = 'checkpoints/best_model.pt'  # Update this path; the file must already exist
prompt = "Once upon a time"

# Generate text
generator = generate_text(model_path, prompt)

# For interactive use, you can now use the generator directly:
while True:
    user_prompt = input("\nEnter prompt (or 'quit' to exit): ")
    if user_prompt.lower() == 'quit':
        break

    raw_temp = input("Enter temperature (0.1-2.0, default 0.8): ").strip()
    try:
        temp = float(raw_temp) if raw_temp else 0.8
    except ValueError:
        # A typo should not end the whole session; fall back to the default.
        print(f"Invalid temperature {raw_temp!r}, using 0.8")
        temp = 0.8
    # Keep the value inside the advertised range (a temperature of 0 would divide by zero).
    temp = min(max(temp, 0.1), 2.0)

    generated = generator.generate(
        prompt=user_prompt,
        temperature=temp,
        max_new_tokens=100
    )

    print("\nGenerated text:")
    print("-" * 50)
    print(generated[0])

  checkpoint = torch.load(model_path)


FileNotFoundError: [Errno 2] No such file or directory: 'checkpoints/best_model.pt'