# In [None]:
import os
import torch
import random
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.optim import AdamW
from tqdm import tqdm

# Setup: choose the compute device and load the pretrained tokenizer + causal LM.
model_name = "facebook/opt-350m"  # swap in a larger checkpoint here if desired
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Fixed-length padding needs a pad token; if the tokenizer has none, reuse EOS.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

# Load and chunk the data
def load_and_chunk_book(file_path, chunk_size=8):
    with open(file_path, "r", encoding="utf-8") as f:
        lines = [line.strip() for line in f.readlines() if line.strip()]
    random.seed(42)
    random.shuffle(lines)
    chunks = [" ".join(lines[i:i+chunk_size]) for i in range(0, len(lines), chunk_size)]
    return chunks

file_path = "/content/sample_data/Krishna_Leelas.txt"  # Update if needed
text_chunks = load_and_chunk_book(file_path)
# Fail fast with a clear message: an empty corpus would otherwise produce an
# empty dataset and leave nothing to train on further down the script.
if not text_chunks:
    raise ValueError(f"No non-empty lines found in {file_path}")

# Tokenize
def tokenize_function(text, max_length=256):
    """Encode *text* as PyTorch tensors padded/truncated to exactly ``max_length`` tokens."""
    encode_kwargs = dict(
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )
    return tokenizer(text, **encode_kwargs)

# Dataset
class TextDataset(Dataset):
    """Pre-tokenized causal-LM dataset built from a list of text chunks.

    Each example is a dict with ``input_ids``, ``attention_mask`` and ``labels``,
    all padded/truncated by ``tokenize_function`` (256 tokens by default).
    ``labels`` copies ``input_ids`` with padded positions set to -100 so the
    loss ignores them.
    """

    def __init__(self, texts):
        self.examples = []
        for text in texts:
            tokenized = tokenize_function(text)
            # Drop the batch dimension of 1 added by return_tensors="pt".
            input_ids = tokenized["input_ids"].squeeze(0)
            attention_mask = tokenized["attention_mask"].squeeze(0)
            labels = input_ids.clone()
            # Mask by attention_mask rather than by pad_token_id: if the pad token
            # shares its id with EOS (the fallback used when a tokenizer has no pad
            # token), matching on the id would also hide genuine EOS tokens from
            # the loss.
            labels[attention_mask == 0] = -100
            self.examples.append({
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": labels,
            })

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        # Tensors are moved to the global `device` here, per item.
        return {key: val.to(device) for key, val in self.examples[idx].items()}

dataset = TextDataset(text_chunks)

# Train-test split (90/10). Seeded so the held-out set is the same on every run;
# an unseeded split would make eval losses incomparable across runs.
train_size = int(0.9 * len(dataset))
eval_size = len(dataset) - train_size
train_dataset, eval_dataset = random_split(
    dataset,
    [train_size, eval_size],
    generator=torch.Generator().manual_seed(42),
)

train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)  # Small batch size for the larger model
eval_dataloader = DataLoader(eval_dataset, batch_size=2)

# Optimizer
optimizer = AdamW(model.parameters(), lr=5e-5)

# Training
epochs = 5
gradient_accumulation_steps = 4

model.train()
for epoch in range(epochs):
    loop = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{epochs}")
    num_batches = len(train_dataloader)
    total_loss = 0.0
    optimizer.zero_grad()
    for step, batch in enumerate(loop):
        # No-op if the dataset already placed tensors on `device`.
        batch = {key: val.to(device) for key, val in batch.items()}
        outputs = model(**batch)
        raw_loss = outputs.loss
        # Scale so accumulated gradients average rather than sum over micro-batches.
        (raw_loss / gradient_accumulation_steps).backward()

        # Step on every full accumulation group AND on the final batch, so a
        # trailing partial group is applied instead of leaking into the next epoch.
        if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == num_batches:
            optimizer.step()
            optimizer.zero_grad()

        # Track the unscaled loss so the reported numbers are true per-batch losses.
        total_loss += raw_loss.item()
        loop.set_postfix(loss=raw_loss.item())

    avg_loss = total_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1} - Average Training Loss: {avg_loss:.4f}")

    # Evaluation and sampling run in eval mode (dropout off); restored below.
    model.eval()
    if len(eval_dataloader) > 0:
        eval_loss = 0.0
        with torch.no_grad():
            for batch in eval_dataloader:
                batch = {key: val.to(device) for key, val in batch.items()}
                eval_loss += model(**batch).loss.item()
        print(f"Epoch {epoch+1} - Average Eval Loss: {eval_loss / len(eval_dataloader):.4f}")

    print("🧠 Sample Output After Epoch:")
    sample_inputs = tokenizer("Lord Krishna was known for", return_tensors="pt").to(device)
    # Pass attention_mask and pad_token_id explicitly rather than bare input_ids.
    sample_ids = model.generate(
        **sample_inputs,
        max_length=60,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.8,
        pad_token_id=tokenizer.pad_token_id,
    )
    print("➤", tokenizer.decode(sample_ids[0], skip_special_tokens=True))
    model.train()

# Save model weights/config and tokenizer files side by side in one directory.
save_dir = "/content/opt350m_b_finetuned"
os.makedirs(save_dir, exist_ok=True)
for artifact in (model, tokenizer):
    artifact.save_pretrained(save_dir)
print(f"✅ Model saved to {save_dir}")

# Function to generate text
def generate_text(prompt, max_length=60, *, temperature=1.0, top_k=50, top_p=0.95):
    """Sample a continuation of *prompt* from the global fine-tuned model.

    Args:
        prompt: Input text; tokenized with truncation to at most 256 tokens.
        max_length: Maximum number of NEW tokens to generate after the prompt.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_k: Top-k sampling cutoff forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.

    Returns:
        The decoded first output sequence with special tokens removed. For this
        decoder-only model the sequence includes the prompt.
    """
    was_training = model.training
    model.eval()
    try:
        inputs = tokenizer(
            prompt, return_tensors="pt", truncation=True, padding=True, max_length=256
        ).to(device)
        with torch.no_grad():
            output = model.generate(
                **inputs,
                # Equivalent to max_length = prompt length + max_length.
                max_new_tokens=max_length,
                pad_token_id=tokenizer.eos_token_id,
                do_sample=True,
                top_k=top_k,
                top_p=top_p,
                temperature=temperature,
                num_return_sequences=1,
            )
    finally:
        # Leave the model in the mode the caller had it in.
        if was_training:
            model.train()
    return tokenizer.decode(output[0], skip_special_tokens=True)

# Smoke-test the fine-tuned model on a question-style prompt.
prompt = "How did Krishna defeat the demon?"
print("📝 Generated Text:")
generated = generate_text(prompt)
print(generated)
