In [None]:
# Cell 1: Core Model Definition and Helper Functions
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import random
from IPython.display import clear_output

# ---- Global hyperparameters / runtime settings shared by the cells below ----
batch_size, block_size = 64, 256     # sequences per batch, tokens per sampled sequence
device = 'cpu'
if torch.cuda.is_available():        # prefer the GPU when one is visible to torch
    device = 'cuda'
eval_iters = 200                     # batches averaged per split when estimating loss
learning_rate, max_iters = 1e-3, 5_000   # optimizer step size and an iteration budget

# Define improved language model
class ImprovedLanguageModel(nn.Module):
    """Token-level language model: embedding -> 2-layer LSTM -> linear vocabulary head.

    Args:
        vocab_size: number of distinct token ids (size of the output distribution).
        embedding_dim: size of each token embedding.
        hidden_dim: LSTM hidden-state size.
    """

    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=512):
        super().__init__()
        # Attribute names are part of the saved state_dict keys; keep them stable.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=2, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def _logits(self, idx, hidden=None):
        """Run the network on ``idx`` (batch, time) starting from LSTM state ``hidden``.

        Returns ``(logits, hidden)`` where logits is (batch, time, vocab_size) and
        ``hidden`` is the LSTM ``(h, c)`` state after the last time step.
        """
        embeddings = self.embedding(idx)
        lstm_out, hidden = self.lstm(embeddings, hidden)
        return self.fc(lstm_out), hidden

    def forward(self, idx, targets=None):
        """Compute logits and, when ``targets`` is given, the cross-entropy loss.

        Args:
            idx: integer token ids of shape (batch, time).
            targets: optional ids with the same shape as ``idx``.

        Returns:
            ``(logits, None)`` with logits (batch, time, vocab) when ``targets`` is None;
            otherwise ``(logits, loss)`` with logits flattened to (batch*time, vocab).
        """
        logits, _ = self._logits(idx)
        if targets is None:
            return logits, None

        B, T, C = logits.shape
        # reshape (not view) so non-contiguous inputs are also accepted.
        logits = logits.reshape(B * T, C)
        loss = F.cross_entropy(logits, targets.reshape(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=0.8):
        """Sample ``max_new_tokens`` tokens autoregressively and append them to ``idx``.

        The LSTM state is carried between steps, so each step processes only the newly
        sampled token instead of re-running the whole sequence (linear, not quadratic).

        Args:
            idx: (batch, time) prompt token ids.
            max_new_tokens: number of tokens to sample.
            temperature: softmax temperature; must be > 0 (lower = greedier).

        Returns:
            Tensor of shape (batch, time + max_new_tokens).

        Raises:
            ValueError: if ``temperature`` is not positive.
        """
        if not temperature > 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")

        pending, hidden = idx, None  # tokens not yet fed to the LSTM, and its state
        for _ in range(max_new_tokens):
            logits, hidden = self._logits(pending, hidden)
            probs = F.softmax(logits[:, -1, :] / temperature, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
            pending = idx_next
        return idx

# Function to get a batch of data
def get_batch(split):
    """Sample a random batch of (input, next-token target) windows from one split.

    Reads the module-level ``train_data`` / ``val_data`` 1-D token tensors and the
    ``batch_size``, ``block_size`` and ``device`` settings.

    Args:
        split: ``'train'`` selects ``train_data``; any other value selects ``val_data``.

    Returns:
        ``(x, y)``: tensors of shape (batch_size, block_size) on ``device``; ``y`` is the
        same window as ``x`` shifted one token to the right.

    Raises:
        ValueError: if the chosen split has ``block_size`` tokens or fewer, so no window
            (block_size inputs + 1 extra target token) fits.
    """
    data = train_data if split == 'train' else val_data

    # A window starting at i needs tokens i .. i+block_size (inclusive) for x and y.
    num_starts = len(data) - block_size
    if num_starts <= 0:
        raise ValueError(
            f"'{split}' split has {len(data)} tokens; need more than block_size={block_size}"
        )

    ix = torch.randint(num_starts, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
    return x.to(device), y.to(device)

# Function to estimate loss
@torch.no_grad()
def estimate_loss():
    """Average the global ``model``'s loss over ``eval_iters`` random batches per split.

    Runs without autograd (no gradients are needed to measure loss) and in eval mode,
    then restores whatever train/eval mode the model was in before.

    Returns:
        dict with keys ``'train'`` and ``'val'`` mapping to 0-dim tensors of mean loss.
    """
    out = {}
    was_training = model.training
    model.eval()
    try:
        for split in ('train', 'val'):
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split)
                _, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        model.train(was_training)
    return out

# Function to calculate accuracy
def calculate_accuracy(model, data, num_samples=1000, context_len=10):
    """Estimate next-token accuracy (percent) at random positions of ``data``.

    For each sampled start ``p`` the model sees ``data[p:p+context_len]`` and its argmax
    prediction at the last step is compared with ``data[p+context_len]``. Positions are
    drawn uniformly (without replacement) from every valid start in the whole tensor.

    Args:
        model: module whose ``forward(idx)`` returns ``(logits, _)`` with logits shaped
            (batch, time, vocab).
        data: 1-D tensor of token ids.
        num_samples: maximum number of positions to evaluate.
        context_len: number of tokens fed to the model before each prediction.

    Returns:
        Accuracy in [0, 100] as a float; 0.0 when ``data`` cannot hold one context plus
        its target, or when ``num_samples`` is not positive.
    """
    # Valid starts p satisfy p + context_len <= len(data) - 1.
    num_positions = len(data) - context_len
    k = min(num_samples, num_positions)
    if k <= 0:
        return 0.0
    positions = random.sample(range(num_positions), k)

    # Run on the same device as the model's weights.
    param = next(model.parameters(), None)
    target_device = param.device if param is not None else device

    was_training = model.training
    model.eval()
    correct = 0
    chunk = 256  # positions per forward pass; bounds memory for large num_samples
    offsets = torch.arange(context_len, device=data.device)
    try:
        with torch.no_grad():
            for start in range(0, k, chunk):
                starts = torch.tensor(positions[start:start + chunk], device=data.device)
                x = data[starts.unsqueeze(1) + offsets]       # (n, context_len)
                y_true = data[starts + context_len]           # (n,)
                logits, _ = model(x.to(target_device))
                y_pred = logits[:, -1, :].argmax(dim=-1)
                correct += (y_pred == y_true.to(target_device)).sum().item()
    finally:
        model.train(was_training)

    return correct / k * 100

# Function to load a model
def load_model(model_path):
    """Rebuild an ImprovedLanguageModel from a checkpoint file.

    The checkpoint must contain the keys ``model_state_dict``, ``vocab_size``,
    ``chars``, ``string_to_int`` and ``int_to_string``.

    Returns:
        ``(model, chars, string_to_int, int_to_string, device)``; the model is moved to
        ``device`` ('cuda' when available, else 'cpu') and put in evaluation mode.
    """
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    print(f"Using device: {device}")

    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    ckpt = torch.load(model_path, map_location=device)
    state, vocab_size, chars, s2i, i2s = (
        ckpt[key]
        for key in ('model_state_dict', 'vocab_size', 'chars', 'string_to_int', 'int_to_string')
    )

    net = ImprovedLanguageModel(vocab_size)
    net.load_state_dict(state)
    net = net.to(device)
    net.eval()  # inference only from here on

    return net, chars, s2i, i2s, device

# Function to generate text
def predict_text(model, input_text, string_to_int, int_to_string, device, num_chars=100, temperature=0.8):
    """Continue ``input_text`` with ``model.generate`` and return prompt + continuation.

    Characters missing from ``string_to_int`` are encoded as id 0, and an empty prompt
    is seeded with the single id 0 (whose character then appears in the output).
    """
    if input_text:
        ids = [string_to_int.get(ch, 0) for ch in input_text]
    else:
        ids = [0]
    prompt = torch.tensor([ids], dtype=torch.long, device=device)  # shape (1, len)

    with torch.no_grad():
        generated = model.generate(prompt, max_new_tokens=num_chars, temperature=temperature)

    return ''.join(int_to_string[token] for token in generated[0].tolist())

# Function for the interactive mode
def interactive_mode(model, string_to_int, int_to_string, device):
    """Read prompts from the user and print model continuations until told to stop.

    Commands: ``:temp X`` (positive float), ``:length X`` (non-negative int), ``:exit``.
    The loop also ends on Ctrl-C or end of input (EOF).
    """
    print("\n===== Language Model Predictor =====")
    print("Type some text and the model will continue it.")
    print("Commands:")
    print("  :temp X    - Set temperature (0.1-2.0, default 0.8)")
    print("  :length X  - Set generation length (default 100)")
    print("  :exit      - Quit the program")

    temperature = 0.8
    length = 100

    while True:
        try:
            input_text = input("\nYour text: ")
        except (KeyboardInterrupt, EOFError):
            # EOF must end the loop: otherwise input() keeps failing and we spin forever.
            break

        if input_text.startswith(':'):
            parts = input_text.lower().split()
            command = parts[0]  # non-empty: the text starts with ':'
            if command == ':exit':
                break
            elif command == ':temp':
                try:
                    value = float(parts[1])
                except (IndexError, ValueError):
                    value = None
                # Reject <= 0 (division by zero when scaling logits) and NaN.
                if value is not None and value > 0:
                    temperature = value
                    print(f"Temperature set to {temperature}")
                else:
                    print("Invalid temperature value")
            elif command == ':length':
                try:
                    value = int(parts[1])
                except (IndexError, ValueError):
                    value = None
                if value is not None and value >= 0:
                    length = value
                    print(f"Length set to {length}")
                else:
                    print("Invalid length value")
            else:
                print(f"Unknown command: {input_text}")
            continue

        print("\nGenerating text...")
        try:
            generated_text = predict_text(model, input_text, string_to_int, int_to_string,
                                          device, num_chars=length, temperature=temperature)
        except KeyboardInterrupt:
            break
        except Exception as e:
            print(f"Error: {e}")
            continue

        print("\nGenerated text:")
        print("-" * 40)
        print(generated_text)
        print("-" * 40)

In [None]:
# Cell 2: Data Loading and Training Functions
import os
import torch

# Module-level training state; train_on_file() rebinds these via `global`.
model = optimizer = None                 # current LSTM model and its optimizer
train_data = val_data = None             # encoded token tensors for the two splits
string_to_int = int_to_string = None     # character <-> id lookup tables

# Function to load text data
def load_text_data(file_path='sample2.txt'):
    """Read a UTF-8 text file and build its character-level encoding.

    Side effect: creates a ``data`` directory in the working directory if missing.
    If ``file_path`` does not exist and is not ``'sample2.txt'``, the function falls
    back to loading ``'sample2.txt'``.

    Returns:
        ``(text, chars, string_to_int, int_to_string, data)`` where ``chars`` is the
        sorted list of unique characters, the dicts map characters <-> their index in
        ``chars``, and ``data`` is a 1-D ``torch.long`` tensor of ids.

    Raises:
        FileNotFoundError: if neither ``file_path`` nor the fallback exists.
        ValueError: if the file is empty (no vocabulary can be built).
    """
    print(f"Loading data from: {file_path}")

    if not os.path.exists('data'):
        os.makedirs('data', exist_ok=True)
        print("Created 'data' directory for text files")

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()

        print(f"Loaded {len(text)} characters from {file_path}")
        if not text:
            # An empty vocabulary would only fail later, deep inside model construction.
            raise ValueError(f"{file_path} is empty; cannot build a vocabulary")

        chars = sorted(set(text))
        vocabulary_size = len(chars)
        print(f"Vocabulary size: {vocabulary_size} unique characters")

        string_to_int = {c: i for i, c in enumerate(chars)}
        int_to_string = dict(enumerate(chars))

        data = torch.tensor([string_to_int[c] for c in text], dtype=torch.long)
        return text, chars, string_to_int, int_to_string, data

    except FileNotFoundError:
        print(f"Error: File {file_path} not found.")
        if file_path != 'sample2.txt':
            print("Trying to load default file instead...")
            return load_text_data('sample2.txt')
        raise
    except Exception as e:
        print(f"Error loading file: {e}")
        raise

# Function to train on a selected data file
def _report_progress(epoch, epochs, local_iter, iters_per_epoch):
    """Print loss and sampled accuracy of the global model; report (not raise) errors."""
    try:
        losses = estimate_loss()
        train_accuracy = calculate_accuracy(model, train_data[:min(10000, len(train_data))])
        val_accuracy = calculate_accuracy(model, val_data[:min(5000, len(val_data))])
    except Exception as e:
        print(f"Error during evaluation: {e}")
        return
    print(f"Epoch {epoch+1}/{epochs}, Step {local_iter+1}/{iters_per_epoch}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    print(f"Accuracies: train {train_accuracy:.2f}%, val {val_accuracy:.2f}%")


def train_on_file(file_path='sample2.txt', epochs=20):
    """Train a fresh ImprovedLanguageModel on a text file and save a checkpoint.

    The encoded text is split 80/20 into train/validation. Training uses AdamW with the
    module-level ``learning_rate``, reports progress periodically, evaluates at the end
    of every epoch and finally saves ``models/language_model_<file stem>.pt``.

    Side effects: rebinds the module globals ``model``, ``optimizer``, ``train_data``,
    ``val_data``, ``string_to_int`` and ``int_to_string`` (read by get_batch and
    estimate_loss).

    Args:
        file_path: text file to train on (see load_text_data for its fallback rule).
        epochs: number of epochs; each runs ``iters_per_epoch`` optimizer steps.

    Returns:
        ``(model, chars, string_to_int, int_to_string)``

    Raises:
        ValueError: if a split is too short to sample a ``block_size`` window.
    """
    global model, optimizer, train_data, val_data, string_to_int, int_to_string

    _text, chars, string_to_int, int_to_string, data = load_text_data(file_path)
    vocabulary_size = len(chars)

    n = int(0.8 * len(data))
    train_data, val_data = data[:n], data[n:]
    # get_batch needs more than block_size tokens per split (inputs + one target).
    for split_name, split in (('train', train_data), ('val', val_data)):
        if len(split) <= block_size:
            raise ValueError(
                f"The {split_name} split has only {len(split)} characters; it needs more "
                f"than block_size={block_size}. Use a larger text file."
            )

    print("Initializing model...")
    model = ImprovedLanguageModel(vocabulary_size).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    iters_per_epoch = max(1, min(500, len(train_data) // (block_size * batch_size)))
    total_iters = epochs * iters_per_epoch
    eval_interval = max(1, min(100, iters_per_epoch // 2))

    print(f"Training for {epochs} epochs with {iters_per_epoch} iterations per epoch")
    print(f"Total iterations: {total_iters}, evaluating every {eval_interval} steps")

    print("Starting training...")
    global_iter = 0

    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")

        for local_iter in range(iters_per_epoch):
            if global_iter % eval_interval == 0:
                _report_progress(epoch, epochs, local_iter, iters_per_epoch)

            try:
                xb, yb = get_batch('train')
                _, loss = model(xb, yb)
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()
            except Exception as e:
                print(f"Error during training: {e}")
            # Count the step even when it failed so the evaluation schedule advances;
            # skipping this made every later iteration re-run the full evaluation.
            global_iter += 1

        try:
            epoch_losses = estimate_loss()
            print(f"End of epoch {epoch+1}/{epochs}: train loss {epoch_losses['train']:.4f}, val loss {epoch_losses['val']:.4f}")
        except Exception as e:
            print(f"Error evaluating at epoch end: {e}")

    try:
        final_train_accuracy = calculate_accuracy(model, train_data)
        final_val_accuracy = calculate_accuracy(model, val_data)
        print(f"\nFinal accuracy: train {final_train_accuracy:.2f}%, val {final_val_accuracy:.2f}%")
    except Exception as e:
        print(f"Error calculating final accuracy: {e}")

    try:
        # splitext keeps dotted stems intact ("my.data.txt" -> "my.data", not "my").
        stem = os.path.splitext(os.path.basename(file_path))[0]
        save_path = f"models/language_model_{stem}.pt"
        os.makedirs('models', exist_ok=True)
        torch.save({
            'model_state_dict': model.state_dict(),
            'vocab_size': vocabulary_size,
            'chars': chars,
            'string_to_int': string_to_int,
            'int_to_string': int_to_string,
        }, save_path)
        print(f"Model saved to {save_path}")
    except Exception as e:
        print(f"Error saving model: {e}")

    return model, chars, string_to_int, int_to_string

# Interactive file selection and training
def select_and_train():
    """Let the user pick a ``data/*.txt`` file and an epoch count, train, then offer chat.

    If ``data`` does not exist it is created and the function returns so files can be
    added first. Pressing Enter, or an invalid choice, selects ``sample2.txt`` in the
    working directory. Errors during selection/training are printed, not raised.
    """
    if not os.path.exists('data'):
        os.makedirs('data')
        print("Created 'data' directory. Please add text files to this directory.")
        return

    try:
        files = [f for f in os.listdir('data') if f.endswith('.txt')]
    except OSError as e:
        print(f"Error listing files: {e}")
        files = []

    if not files:
        print("No text files found in the 'data' directory. Please add some .txt files.")
        return

    print("\nAvailable text files:")
    for number, file in enumerate(files, start=1):
        print(f"{number}. {file}")

    try:
        choice = input("\nSelect a file number (or press Enter for default sample2.txt): ").strip()
        file_path = 'sample2.txt'
        if choice:
            try:
                idx = int(choice) - 1
            except ValueError:
                print("Invalid input. Using default file.")
            else:
                if 0 <= idx < len(files):
                    file_path = f"data/{files[idx]}"
                else:
                    print("Invalid selection. Using default file.")

        raw_epochs = input("Enter number of epochs (default: 20): ").strip()
        epochs = 20
        if raw_epochs:
            # int() instead of str.isdigit(): isdigit accepts e.g. '²', which int rejects.
            try:
                parsed = int(raw_epochs)
            except ValueError:
                parsed = 0
            if parsed >= 1:
                epochs = parsed
            else:
                print("Invalid number of epochs. Using default of 20.")

        print(f"\nTraining on {file_path} for {epochs} epochs...")
        model, chars, string_to_int, int_to_string = train_on_file(file_path, epochs)

        response = input("\nWould you like to enter interactive mode? (y/n): ")
        if response.lower() in ['y', 'yes']:
            interactive_mode(model, string_to_int, int_to_string, device)

    except Exception as e:
        print(f"Error in file selection or training: {e}")

In [None]:
# Cell 3: Project Structure and Main Menu
import os
from IPython.display import clear_output

def create_project_structure():
    """Create ``data``, ``models`` and ``bitnet_cache`` and seed ``data/`` with the sample.

    If ``sample2.txt`` exists in the working directory but not in ``data/``, a copy is
    placed in ``data/`` so it appears in the menus that list ``data/*.txt``. The original
    stays in place because the menus also use ``sample2.txt`` as their default path.
    Failures are printed, not raised.
    """
    for dir_name in ('data', 'models', 'bitnet_cache'):
        if os.path.isdir(dir_name):
            continue
        try:
            os.makedirs(dir_name, exist_ok=True)
            print(f"Created '{dir_name}' directory")
        except OSError as e:
            print(f"Error creating directory {dir_name}: {e}")

    # The old check compared sample2.txt with itself and could never be true.
    target = os.path.join('data', 'sample2.txt')
    try:
        if os.path.exists('sample2.txt') and not os.path.exists(target):
            import shutil
            shutil.copy('sample2.txt', target)
            print("Copied sample2.txt to data directory")
    except OSError as e:
        print(f"Error copying sample file: {e}")

def _ensure_transformers_installed():
    """Install transformers/accelerate/datasets with this interpreter's pip if missing."""
    import importlib.util  # `import importlib` alone does not guarantee `.util` is loaded
    import subprocess
    import sys

    if importlib.util.find_spec("transformers") is None:
        print("Installing required packages for BitNet...")
        # `python -m pip` targets the running kernel's environment and, unlike the
        # IPython `!pip` magic, is valid plain Python.
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-q",
                               "transformers", "accelerate", "datasets"])


def _run_bitnet_option():
    """Menu option 2: ensure BitNet dependencies, then hand off to use_bitnet()."""
    try:
        _ensure_transformers_installed()
        print("Loading BitNet functionality...")
        # NOTE(review): use_bitnet comes from the BitNet cell, which imports transformers
        # at cell level; if that cell failed before installation it must be re-run,
        # otherwise this raises NameError (reported below).
        use_bitnet()
    except Exception as e:
        print(f"Error using BitNet: {e}")
    input("\nPress Enter to return to main menu...")


def _run_saved_model_option():
    """Menu option 3: pick a ``.pt`` checkpoint from ``models/`` and start interactive mode."""
    models_dir = 'models'
    try:
        if not os.path.exists(models_dir):
            os.makedirs(models_dir)
            print("Models directory created. Please train a model first.")
        else:
            models = sorted(f for f in os.listdir(models_dir) if f.endswith('.pt'))
            if not models:
                print("No saved models found. Please train a model first.")
            else:
                print("\nAvailable models:")
                for number, model_file in enumerate(models, start=1):
                    print(f"{number}. {model_file}")

                idx_input = input("\nSelect a model number: ")
                try:
                    idx = int(idx_input) - 1
                except ValueError:
                    print("Invalid input.")
                else:
                    # Only the number parsing is guarded above, so errors from loading
                    # are reported as load errors, not as "Invalid input".
                    if 0 <= idx < len(models):
                        model_path = os.path.join(models_dir, models[idx])
                        model, _chars, s2i, i2s, model_device = load_model(model_path)
                        interactive_mode(model, s2i, i2s, model_device)
                    else:
                        print("Invalid selection.")
    except Exception as e:
        print(f"Error loading model: {e}")
    input("\nPress Enter to return to main menu...")


def main_menu():
    """Display the main menu for model selection and training until the user exits.

    Exits on option 4 or when input is closed (EOF); other errors are reported and the
    menu is shown again.
    """
    create_project_structure()

    while True:
        try:
            clear_output(wait=True)
            print("\n===== Text Generation Model Training =====")
            print("1. Train basic LSTM model on a text file")
            print("2. Use Microsoft BitNet model")
            print("3. Load existing model and run interactive mode")
            print("4. Exit")

            choice = input("\nSelect an option (1-4): ").strip()

            if choice == '1':
                select_and_train()
                input("\nPress Enter to return to main menu...")
            elif choice == '2':
                _run_bitnet_option()
            elif choice == '3':
                _run_saved_model_option()
            elif choice == '4':
                print("Exiting program. Goodbye!")
                break
            else:
                print("Invalid choice. Please enter a number between 1 and 4.")
                input("\nPress Enter to try again...")

        except EOFError:
            # No more input (e.g. stdin closed): stop instead of erroring repeatedly.
            print("\nInput closed. Exiting program.")
            break
        except Exception as e:
            print(f"Error in main menu: {e}")
            input("\nPress Enter to try again...")

In [None]:
# Cell 4: BitNet Integration
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import Trainer, TrainingArguments
from datasets import Dataset
from tqdm.auto import tqdm

class BitNetTextGenerator:
    """Wrapper around a Hugging Face causal LM (default: Microsoft BitNet b1.58 2B4T).

    Offers data preparation and fine-tuning through ``transformers.Trainer``,
    sampling-based generation, and an interactive prompt loop like the LSTM one.
    """

    def __init__(self, model_name="microsoft/bitnet-b1.58-2B-4T", use_cache=True, cache_dir="./bitnet_cache"):
        """Load the tokenizer and model.

        Args:
            model_name: Hugging Face Hub id or local path.
            use_cache: if True, download into / read from ``cache_dir``.
            cache_dir: local cache directory, created when ``use_cache`` is True.

        Raises:
            Any exception from ``from_pretrained`` is printed and re-raised.
        """
        print(f"Loading BitNet model: {model_name}")

        if use_cache:
            os.makedirs(cache_dir, exist_ok=True)

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {self.device}")

        cache = cache_dir if use_cache else None
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache)
            # Batching the variable-length fine-tuning chunks needs a pad token; fall
            # back to EOS when the tokenizer defines none.
            if self.tokenizer.pad_token is None and self.tokenizer.eos_token is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                cache_dir=cache,
                torch_dtype=torch.float16 if self.device == 'cuda' else torch.float32,
                device_map="auto" if self.device == 'cuda' else None,
                low_cpu_mem_usage=True,
            )

            # Character -> first token id; filled by prepare_data_for_finetuning.
            self.char_to_token = {}

        except Exception as e:
            print(f"Error initializing BitNet: {e}")
            raise

    def _input_device(self):
        """Device that input tensors must live on (the model's first-parameter device)."""
        return getattr(self.model, "device", self.device)

    def prepare_data_for_finetuning(self, text_file_path, max_length=256, stride=128):
        """Tokenize a text file into overlapping chunks and split them 90/10 train/test.

        Args:
            text_file_path: UTF-8 text file.
            max_length: maximum tokens per chunk.
            stride: tokens shared by consecutive chunks (tokenizer ``stride``); should be
                smaller than ``max_length``.

        Returns:
            ``datasets.DatasetDict`` with ``train`` / ``test`` splits whose rows hold
            ``input_ids`` and ``attention_mask`` lists (labels are added by the collator
            used in ``finetune``).
        """
        print("Preparing data for fine-tuning...")

        try:
            with open(text_file_path, 'r', encoding='utf-8') as f:
                text = f.read()

            if not self.char_to_token:
                for char in sorted(set(text)):
                    token_ids = self.tokenizer.encode(char, add_special_tokens=False)
                    if token_ids:  # some characters have no standalone token
                        self.char_to_token[char] = token_ids[0]

            # truncation=True is required for return_overflowing_tokens to cut the text
            # into max_length chunks; without it the whole file is one giant example.
            encodings = self.tokenizer(
                text,
                truncation=True,
                max_length=max_length,
                stride=stride,
                return_overflowing_tokens=True,
            )

            examples = [
                {"input_ids": ids, "attention_mask": mask}
                for ids, mask in zip(encodings["input_ids"], encodings["attention_mask"])
            ]

            dataset = Dataset.from_list(examples).train_test_split(test_size=0.1)

            print(f"Created dataset with {len(dataset['train'])} training and {len(dataset['test'])} validation examples")
            return dataset

        except Exception as e:
            print(f"Error preparing data: {e}")
            raise

    def finetune(self, dataset, output_dir="./models/bitnet-finetuned", epochs=1, batch_size=2):
        """Fine-tune the model on ``dataset`` and save model + tokenizer to ``output_dir``.

        Args:
            dataset: mapping with ``train`` and ``test`` splits, as returned by
                ``prepare_data_for_finetuning``.
            output_dir: checkpoint / final save directory.
            epochs: number of training epochs.
            batch_size: per-device batch size (forced to 1 on CPU).

        Returns:
            The ``transformers.Trainer`` used.
        """
        import inspect
        from transformers import DataCollatorForLanguageModeling

        print("Starting BitNet fine-tuning...")

        try:
            os.makedirs(output_dir, exist_ok=True)

            actual_batch_size = batch_size
            if self.device == 'cpu':
                actual_batch_size = 1
                print("Running on CPU, reducing batch size to 1")

            # Newer transformers releases call this argument `eval_strategy`; older
            # ones only accept `evaluation_strategy`.
            arg_names = inspect.signature(TrainingArguments.__init__).parameters
            strategy_key = "eval_strategy" if "eval_strategy" in arg_names else "evaluation_strategy"

            # fp16 mixed precision needs fp32 master weights; enabling it on weights that
            # are already float16 makes the gradient scaler fail, so only turn it on when
            # the loaded weights are not half precision.
            # NOTE(review): training pure-fp16 weights may be numerically unstable.
            use_fp16 = self.device == 'cuda' and self.model.dtype != torch.float16

            training_args = TrainingArguments(
                output_dir=output_dir,
                overwrite_output_dir=True,
                num_train_epochs=epochs,
                per_device_train_batch_size=actual_batch_size,
                per_device_eval_batch_size=actual_batch_size,
                gradient_accumulation_steps=4,  # simulate a larger effective batch
                eval_steps=100,
                save_steps=100,
                save_total_limit=2,  # keep at most 2 checkpoints on disk
                learning_rate=5e-5,
                fp16=use_fp16,
                load_best_model_at_end=True,
                logging_steps=50,
                report_to="none",  # disable wandb and other reporters
                **{strategy_key: "steps"},
            )

            trainer = Trainer(
                model=self.model,
                args=training_args,
                train_dataset=dataset["train"],
                eval_dataset=dataset["test"],
                # Pads each batch and builds causal-LM labels (padding masked out).
                data_collator=DataCollatorForLanguageModeling(tokenizer=self.tokenizer, mlm=False),
            )

            trainer.train()

            self.model.save_pretrained(output_dir)
            self.tokenizer.save_pretrained(output_dir)

            print(f"Model fine-tuned and saved to {output_dir}")
            return trainer

        except Exception as e:
            print(f"Error during fine-tuning: {e}")
            raise

    def generate_text(self, prompt, max_new_tokens=100, temperature=0.8, top_p=0.9):
        """Sample a continuation of ``prompt`` (``"Hello"`` if empty).

        Returns:
            The decoded prompt + continuation, or ``"Error: <message>"`` on failure.
        """
        try:
            self.model.eval()

            inputs = self.tokenizer(prompt or "Hello", return_tensors="pt").to(self._input_device())

            with torch.no_grad():
                outputs = self.model.generate(
                    inputs["input_ids"],
                    attention_mask=inputs.get("attention_mask"),
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                    pad_token_id=self.tokenizer.eos_token_id,
                )

            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        except Exception as e:
            print(f"Error generating text: {e}")
            return f"Error: {str(e)}"

    def interactive_mode(self):
        """Prompt loop continuing user text until ':exit', Ctrl-C or end of input.

        Commands: ``:temp X`` (positive float), ``:length X`` (positive int), ``:exit``.
        """
        print("\n===== BitNet Text Generation =====")
        print("Type some text and press Enter to have the model continue it.")
        print("Type ':temp X' to set temperature (e.g., ':temp 0.5')")
        print("Type ':length X' to set generation length (e.g., ':length 200')")
        print("Type ':exit' to quit.")

        temperature = 0.8
        length = 100

        while True:
            try:
                user_input = input("\n> ")
            except (KeyboardInterrupt, EOFError):
                # EOF must end the loop: otherwise input() keeps failing forever.
                print("\nExiting interactive mode.")
                break

            if user_input.startswith(':'):
                parts = user_input.lower().split()
                command = parts[0]  # non-empty: the text starts with ':'
                if command == ':exit':
                    break
                elif command == ':temp':
                    try:
                        value = float(parts[1])
                    except (IndexError, ValueError):
                        value = None
                    if value is not None and value > 0:
                        temperature = value
                        print(f"Temperature set to {temperature}")
                    else:
                        print("Invalid temperature. Format: :temp 0.8")
                elif command == ':length':
                    try:
                        value = int(parts[1])
                    except (IndexError, ValueError):
                        value = None
                    if value is not None and value > 0:
                        length = value
                        print(f"Generation length set to {length}")
                    else:
                        print("Invalid length. Format: :length 100")
                else:
                    print(f"Unknown command: {user_input}")
                continue

            print("\nGenerating...")
            try:
                generated_text = self.generate_text(user_input, max_new_tokens=length, temperature=temperature)
            except KeyboardInterrupt:
                print("\nExiting interactive mode.")
                break
            print("\nOutput:")
            print("-" * 40)
            print(generated_text)
            print("-" * 40)

def use_bitnet():
    """Interactive BitNet workflow: choose a data file, optionally fine-tune, then generate.

    Files are listed from ``data/*.txt``; Enter or an invalid choice selects
    ``sample2.txt`` in the working directory. The model is loaded only after a valid,
    existing file has been chosen, so bad selections return before any model loading.
    Errors are printed, not raised.
    """
    os.makedirs('data', exist_ok=True)
    os.makedirs('models', exist_ok=True)

    try:
        files = [f for f in os.listdir('data') if f.endswith('.txt')]
        if not files:
            print("No text files found in the 'data' directory.")
            print("Please add text files to the 'data' directory and try again.")
            return

        print("\nAvailable text files for BitNet:")
        for number, file in enumerate(files, start=1):
            print(f"{number}. {file}")

        choice = input("\nSelect a file number (or press Enter for default sample2.txt): ").strip()
        file_path = 'sample2.txt'
        if choice:
            try:
                idx = int(choice) - 1
            except ValueError:
                print("Invalid input. Using default file.")
            else:
                if 0 <= idx < len(files):
                    file_path = f"data/{files[idx]}"
                else:
                    print("Invalid selection. Using default file.")

        if not os.path.exists(file_path):
            print(f"Error: File {file_path} not found.")
            return

        bitnet = BitNetTextGenerator()

        fine_tune = input("Do you want to fine-tune BitNet on your data? (y/n): ").lower().startswith('y')

        if fine_tune:
            epochs_input = input("Enter number of epochs (default: 1): ").strip()
            epochs = 1
            if epochs_input:
                try:
                    parsed = int(epochs_input)
                except ValueError:
                    parsed = 0
                if parsed >= 1:
                    epochs = parsed
                else:
                    print("Invalid number of epochs. Using default of 1.")

            try:
                dataset = bitnet.prepare_data_for_finetuning(file_path)
                # splitext keeps dotted stems intact ("my.data.txt" -> "my.data").
                stem = os.path.splitext(os.path.basename(file_path))[0]
                bitnet.finetune(dataset, output_dir=f"./models/bitnet-{stem}", epochs=epochs)
            except Exception as e:
                print(f"Error during fine-tuning process: {e}")
                print("Continuing with pre-trained model...")

        print("\nGenerating samples with BitNet:")
        for temp in [0.5, 0.8, 1.0]:
            print(f"\nSample with temperature {temp}:")
            sample = bitnet.generate_text("", max_new_tokens=100, temperature=temp)
            print(f"{sample}\n{'-'*40}")

        use_interactive = input("\nWould you like to enter interactive mode? (y/n): ")
        if use_interactive.lower().startswith('y'):
            bitnet.interactive_mode()

    except Exception as e:
        print(f"Error using BitNet: {e}")

In [None]:
# Cell 5: Run the Main Menu
try:
    main_menu()  # launch the interactive application
except Exception as err:
    # Report unexpected failures as one line instead of a notebook traceback.
    print(f"Error running main menu: {err}")