In [1]:
#Import important library

from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer, BitsAndBytesConfig
import torch
import gc



In [2]:


import os
from dotenv import load_dotenv

# Load environment variables from .env
load_dotenv()

# Fetch Hugging Face token
hf_token = os.getenv("HF_TOKEN")



In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, BitsAndBytesConfig
import torch
from datasets import concatenate_datasets
from torch import mps
mps.empty_cache()
torch.mps.empty_cache()
# Define the Gemma 3 1B model for fine-tuning.
# 'google/gemma-3-1b-it' is the instruction-tuned version, good for generation tasks.
GPT_MODEL = "google/gemma-3-1b-it"

# 1. Helper function to load datasets safely, handling potential errors
def try_load_dataset(*args, **kwargs):
    try:
        return load_dataset(*args, **kwargs)
    except Exception as e:
        print(f"Failed to load {args[0]}: {e}")
        return None

# 2. Load the two datasets from Hugging Face Hub
# These datasets will now be used for causal language modeling fine-tuning.
dataset_geeta = try_load_dataset("nikita200/geeta")
dataset_mahabharata = try_load_dataset("aaru2330/maha-epic")

# 3. Print column names to verify structure of the loaded datasets
if dataset_geeta:
    print("Geeta columns:", dataset_geeta["train"].column_names)
if dataset_mahabharata:
    print("Mahabharata columns:", dataset_mahabharata["train"].column_names)

# 4. Initialize the tokenizer using the selected Gemma model.
# It's best practice to use the tokenizer associated with the model you're fine-tuning.
tokenizer = AutoTokenizer.from_pretrained(GPT_MODEL)

# If the tokenizer doesn't have a pad token, add one.
# For causal LMs, it's crucial to have a pad token for batching.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token # Gemma often uses EOS as pad

# 5. Tokenization function for the Geeta dataset for causal language modeling
# For causal LMs, we typically prepare inputs and the model handles shifting for labels.
def tokenize_geeta_function(examples):
    return tokenizer(
        examples["output"],
        padding="max_length",
        truncation=True,
        max_length=256 # Reduced max_length to save memory
    )

# 6. Tokenization function for the Mahabharata dataset for causal language modeling
def tokenize_mahabharata_function(examples):
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=256 # Reduced max_length to save memory
    )

# 7. Tokenize datasets
tokenized_geeta = None
tokenized_mahabharata = None

if dataset_geeta and "output" in dataset_geeta["train"].column_names:
    tokenized_geeta = dataset_geeta["train"].map(tokenize_geeta_function, batched=True)
    print("Geeta dataset tokenized.")
else:
    print("Geeta dataset not loaded or missing 'output' column. Skipping tokenization.")

if dataset_mahabharata and "text" in dataset_mahabharata["train"].column_names:
    tokenized_mahabharata = dataset_mahabharata["train"].map(tokenize_mahabharata_function, batched=True)
    print("Mahabharata dataset tokenized.")
else:
    print("Mahabharata dataset not loaded or missing 'text' column. Skipping tokenization.")

# 8. Combine tokenized datasets for training
# Ensure both datasets were successfully tokenized before concatenating
if tokenized_geeta and tokenized_mahabharata:
    combined_dataset = concatenate_datasets([tokenized_geeta, tokenized_mahabharata])
    print(f"Combined dataset created with {len(combined_dataset)} examples.")

    # For causal language modeling, labels are typically the input_ids themselves.
    # The Trainer will handle the shifting for next-token prediction.
    combined_dataset = combined_dataset.map(lambda examples: {"labels": examples["input_ids"]}, batched=True)

    print("Combined dataset columns after adding labels:", combined_dataset.column_names)
    # Remove original text columns if they are not needed for training
    combined_dataset = combined_dataset.remove_columns([col for col in combined_dataset.column_names if col not in ['input_ids', 'attention_mask', 'labels']])
else:
    print("Could not create combined dataset. Ensure both source datasets loaded and tokenized successfully.")
    combined_dataset = None # Set to None to prevent errors if not created

# 9. Load Gemma model for causal language modeling (text generation)
# AutoModelForCausalLM is used for generative tasks. No 'num_labels' needed here.
model = AutoModelForCausalLM.from_pretrained(GPT_MODEL)

# 10. Move model to MPS (Apple Silicon GPU) if available, otherwise to CPU
# This is for optimizing performance on Apple Silicon Macs
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using Apple Silicon GPU (MPS) for training.")
else:
    device = torch.device("cpu")
    print("MPS not available. Using CPU for training.")
model.to(device)

# 11. Configure Training Arguments
# 'fp16=False' and 'bf16=False' are essential to avoid mixed precision errors on MPS.
training_args = TrainingArguments(
    output_dir="./gemma-3-1b-mythology-pretrained", # Directory to save checkpoints and final model
    per_device_train_batch_size=1,                    # Batch size per device during training
    num_train_epochs=1,                               # Number of training epochs
    save_steps=500,                                   # Save checkpoint every 500 steps
    logging_steps=100,                                # Log training metrics every 100 steps
    fp16=False,                                       # Disable FP16 mixed precision
    bf16=False,                                       # Disable BF16 mixed precision
    report_to="none",                                 # Do not report training metrics to external services
    remove_unused_columns=True, 
    offload_buffers=True  
)

# 12. Initialize the Hugging Face Trainer
# The Trainer orchestrates the training loop.
if combined_dataset: # Only proceed if the combined_dataset was successfully created
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=combined_dataset,
    )

    # 13. Start training the model
    print("Starting model training...")
    trainer.train()
    print("Model training complete.")

    # 14. Save the fine-tuned model
    trainer.save_model("./gemma-3-1b-mythology-pretrained")
    print("Fine-tuned model saved to ./gemma-3-1b-mythology-pretrained")
else:
    print("Training skipped because combined_dataset could not be created.")



In [None]:
import ollama

response = ollama.chat(
    model='llama3.3:latest',
    messages=[
        {'role': 'user', 'content': 'Write a short story about a wise king in ancient India.'}
    ]
)

print(response['message']['content'])

In [60]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Path to your fine-tuned model directory
model_name = "./gemma-3-1b-mythology-pretrained"

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Move model to appropriate device (MPS for Mac, CUDA for GPU, else CPU)
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)

# Provide a prompt for the story
prompt = "Once upon a time in a magical forest,"

# Encode the prompt and move to device
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

# Generate story continuation
with torch.no_grad():
    output = model.generate(
        input_ids,
        max_length=100,         # Total length (prompt + generated)
        num_return_sequences=1, # Number of stories to generate
        do_sample=True,         # Enable sampling for creativity
        temperature=0.8         # Sampling temperature
    )

# Decode and print the generated story
story = tokenizer.decode(output[0], skip_special_tokens=True)
print(story)

HFValidationError: Repo id must use alphanumeric chars or '-', '_', '.', '--' and '..' are forbidden, '-' and '.' cannot start or end the name, max length is 96: './gemma-3-1b-mythology-pretrained'.

In [10]:
import torch

torch.mps.empty_cache()

