# Environment and GPU Configuration 

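These cells prepare the runtime for fine-tuning.
The first cell authenticates with Kaggle through kagglehub.login() so that the model and dataset downloads later in the notebook can succeed. The second cell sets the Keras backend to TensorFlow, specifies which GPUs are visible, enables dynamic GPU memory growth so that TensorFlow does not reserve all GPU memory up front, and suppresses informational and warning messages from TensorFlow's native logging.

After applying these settings, we import TensorFlow, check how many GPUs are available, and enable memory growth on each one. This helps keep GPU memory usage stable when loading and training the Gemma 3 model.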

In [None]:
import kagglehub
kagglehub.login()


In [None]:
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

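# Import TensorFlow first, and only after the environment variables above are set,
# so the GPU settings take effect before Keras is loaded.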
import tensorflow as tf
gpus = tf.config.list_physical_devices('GPU')
print(f"Initial GPU check: {len(gpus)} GPUs")

if gpus:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    print("✓ GPUs configured")

# Upgrade KerasNLP to the Latest Version

This cell installs the latest version of KerasNLP, which includes full support for the Gemma 3 model family.
We run a simple pip upgrade command, and then print a confirmation.
Do not restart the runtime after installing, because TensorFlow and the GPU setup from the previous cell would reset.

In [None]:
# Upgrade to latest KerasNLP for Gemma3 support
!pip install -q --upgrade keras-nlp

print("✓ KerasNLP upgraded to latest - continue to next cell (do NOT restart)")

# Verify Installation and Environment

This cell performs several checks before we start fine-tuning:


1. Imports required libraries: keras, keras_nlp, TopKSampler, time, csv, and logging.
2. Suppresses verbose logs from sentencepiece.
3. Prints the current versions of Keras and KerasNLP.
4. Re-checks that GPUs are still available.
5. Verifies that the Gemma3CausalLM model is present in KerasNLP.


This ensures the environment is correctly set up and ready for model fine-tuning.

In [None]:
import keras
import keras_nlp
from keras_nlp.samplers import TopKSampler
from time import time
import csv
import logging

# Suppress messages
logging.getLogger("sentencepiece").setLevel(logging.ERROR)

print("="*60)
print("KerasNLP version:", keras_nlp.__version__)
print("Keras version:", keras.__version__)

# Re-verify GPU
import tensorflow as tf
gpus = tf.config.list_physical_devices('GPU')
print(f"Num GPUs: {len(gpus)}")

if gpus:
    print("✓✓✓ GPU STILL DETECTED! ✓✓✓")
else:
    print("⚠️ GPU lost")
    
# Check Gemma3
if hasattr(keras_nlp.models, 'Gemma3CausalLM'):
    print("✓ Gemma3CausalLM available!")
else:
    print("✗ Gemma3CausalLM NOT available")
    print(f"Available: {[x for x in dir(keras_nlp.models) if 'Gemma' in x]}")
print("="*60)

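The version check above imports each library before printing its version. As a minimal sketch, the installed versions can also be read from package metadata without importing anything heavy. The helper name `installed_version` is hypothetical, and the distribution names are assumed to match the PyPI names used in this notebook.

```python
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


# Distribution names are assumptions; adjust them to your environment.
for name in ("keras", "keras-nlp", "tensorflow"):
    print(f"{name}: {installed_version(name) or 'not installed'}")
```

This only reports what pip has installed; it does not confirm that the GPU is visible or that Gemma3CausalLM can be imported.
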
# Load Gemma3 270M Model

These cells download the Gemma3 270M causal language model from Kaggle with kagglehub.model_download and then load it with KerasNLP's from_preset method.
The downloaded preset contains the pre-trained weights and the model configuration.
Once loaded, the model is ready for fine-tuning.

In [None]:
import kagglehub

path = kagglehub.model_download("keras/gemma3/keras/gemma3_270m")
print("Path to model files:", path)

In [None]:
# Load the model
# We load the model gemma_3_270M using keras_nlp.
print("Loading Gemma3 270M model...")
gemma_lm = keras_nlp.models.Gemma3CausalLM.from_preset(path)
print("✓ Model loaded successfully!")


# Inspect Model Architecture

This cell displays a summary of the Gemma3 270M model, including:

* Layer types
* Output shapes
* Number of parameters

It helps us understand the model structure and verify that it loaded correctly.

In [None]:
gemma_lm.summary()

# Define Training Configuration

This cell sets up a simple configuration class CFG that contains key hyperparameters for fine-tuning:


* max_length: Maximum sequence length for input text.
* data_size: Number of training examples to use.
* lora_rank: Rank for LoRA (Low-Rank Adaptation) fine-tuning.
* epochs: Number of training epochs.
* batch_size: Number of samples per training batch.


An instance cfg is created so these parameters can be easily accessed throughout the notebook.

In [None]:
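class CFG:
    max_length = 128   # maximum sequence length per example, in tokens
    data_size = 2560   # number of training examples to keep
    lora_rank = 16     # rank of the LoRA update matrices
    epochs = 40        # number of passes over the training subset
    batch_size = 2     # examples per training batch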

cfg = CFG()

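To get a feel for what these values imply, the arithmetic below works out the number of optimizer steps. It is a rough sketch that assumes the dataset really contains at least data_size rows and that every batch is full; the variable names on the left are only for illustration.

```python
import math

# Values copied from the CFG class above.
max_length = 128
data_size = 2560
epochs = 40
batch_size = 2

steps_per_epoch = math.ceil(data_size / batch_size)  # batches per pass over the data
total_steps = steps_per_epoch * epochs               # optimizer updates over the whole run
max_tokens_per_batch = batch_size * max_length       # upper bound on tokens per step

print(f"steps per epoch:      {steps_per_epoch}")
print(f"total training steps: {total_steps}")
print(f"max tokens per batch: {max_tokens_per_batch}")
```

With these settings the run performs 1280 steps per epoch and 51200 steps in total, which is worth keeping in mind when choosing epochs.
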
# Load and Prepare Dataset

This cell reads a CSV file containing medical question-answer pairs and converts it into a format suitable for fine-tuning.

The CSV has two columns: question and answer.

Each row is transformed into a dictionary with keys prompts (from question) and responses (from answer).

All examples are collected in a list called data.

In [None]:
import csv
import kagglehub


path = kagglehub.dataset_download("gpreda/medquad")
print("Dataset downloaded to:", path)

csv_path = f"{path}/medquad.csv"      

data = []

# The CSV file contains two columns 'question' and 'answer'
with open(csv_path, mode='r', encoding='utf-8') as file:
    reader = csv.DictReader(file)
    for row in reader:
        # Rename the columns to 'prompts' and 'responses', the keys used in the rest of the notebook.
        data.append({"prompts": row["question"], "responses": row["answer"]})

In [None]:
print(f"Data size: {len(data)}")

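The same column-to-key mapping can be tried on a tiny in-memory CSV before touching the real file. The two rows below are made-up placeholders, not entries from the dataset; the sketch only shows how csv.DictReader handles the header row and a quoted field that contains a comma.

```python
import csv
import io

# Made-up stand-in for the real file; only the header names match the notebook.
sample_csv = io.StringIO(
    "question,answer\n"
    'First placeholder question?,"Placeholder answer, with a comma inside."\n'
    "Second placeholder question?,Another placeholder answer.\n"
)

# DictReader uses the first row as keys, so each row arrives as a dict.
records = [
    {"prompts": row["question"], "responses": row["answer"]}
    for row in csv.DictReader(sample_csv)
]

for record in records:
    print(record)
```

Because only the question and answer keys are read, any additional columns in a file would simply be ignored by this mapping.
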
# Limit Dataset Size

This cell trims the dataset to the first cfg.data_size examples.
This allows faster training and easier experimentation. Because the slice keeps the first rows in file order, the subset is only representative if the CSV is not grouped by topic or source.

In [None]:
data = data[:cfg.data_size]

In [None]:
print(f"Data size: {len(data)}")

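Slicing keeps the first rows in file order. If a random subset is preferred, one alternative (not what the notebook does) is to sample with a fixed seed. The sketch below uses a toy list in place of the real data, and the seed value is an arbitrary assumption.

```python
import random

# Toy stand-in for the list of {"prompts", "responses"} dicts built earlier.
data = [{"prompts": f"q{i}", "responses": f"a{i}"} for i in range(10)]
data_size = 4

# What the notebook does: keep the first data_size rows.
head = data[:data_size]

# Alternative: draw data_size distinct rows at random, reproducibly.
rng = random.Random(42)
sampled = rng.sample(data, k=min(data_size, len(data)))

print("head:   ", [d["prompts"] for d in head])
print("sampled:", [d["prompts"] for d in sampled])
```

Using a dedicated random.Random instance keeps the sampling reproducible without changing the global random state.
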
# Convert Data to TensorFlow Dataset

This cell converts the Python list data into a TensorFlow tf.data.Dataset, which is optimized for training.

We use a generator to yield each dictionary from data.

output_signature specifies the expected shape and type for each field: both prompts and responses are strings.

This allows TensorFlow to efficiently batch, shuffle, and prefetch the dataset for training.

In [None]:
import tensorflow as tf

dataset = tf.data.Dataset.from_generator(
    lambda: (item for item in data),
    output_signature={
        "prompts": tf.TensorSpec(shape=(), dtype=tf.string),
        "responses": tf.TensorSpec(shape=(), dtype=tf.string),
    }
)

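To make the batching step concrete, here is a plain-Python analogue of how per-example dictionaries are grouped into dictionaries of lists. It is not tf.data itself; the helper name `batch_records` is hypothetical, and the final partial batch is kept, as is the case when a remainder is not dropped.

```python
from itertools import islice


def batch_records(records, batch_size):
    """Yield dicts of lists, each holding up to batch_size examples."""
    iterator = iter(records)
    while True:
        chunk = list(islice(iterator, batch_size))
        if not chunk:
            return
        yield {
            "prompts": [r["prompts"] for r in chunk],
            "responses": [r["responses"] for r in chunk],
        }


# Five toy examples with batch_size=2 give batches of sizes 2, 2 and 1.
toy = [{"prompts": f"q{i}", "responses": f"a{i}"} for i in range(5)]
batches = list(batch_records(toy, batch_size=2))
for batch in batches:
    print(batch)
```
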
In [None]:
from IPython.display import display, Markdown
def colorize_text(text):
    for word, color in zip(["Category", "Question", "Answer"], ["blue", "red", "green"]):
        text = text.replace(f"{word}:", f"\n\n**<font color='{color}'>{word}:</font>**")
    return text

# Display a Sample from the Dataset

This cell displays the 4th example (data[3]) from the dataset using the colorize_text_dict function.
It shows the prompts (question) in red and responses (answer) in green for easy visual inspection.

In [None]:
def colorize_text_dict(sample):
    """
    sample: dict with keys 'prompts' and 'responses'
    """
    colored_text = ""
    colored_text += f"**<font color='red'>Question:</font>** {sample['prompts']}\n\n"
    colored_text += f"**<font color='green'>Answer:</font>** {sample['responses']}\n\n"
    return colored_text

In [None]:
print(data[3])

In [None]:
display(Markdown(colorize_text_dict(data[3])))

# Generate a Sample Response

This cell demonstrates how the Gemma3 270M model generates text:

* We create a prompt with a question (prompts) and an empty response (responses).
* The model generates a response with gemma_lm.generate.
* We format the generated answer and display it using colorize_text_dict for readability.

In [None]:
question = "What are the treatments for Glaucoma ?"
prompt = {"prompts": question, "responses": ""}
response = gemma_lm.generate(prompt, max_length=cfg.max_length)

# Slice off the echoed question using the original string, so the result
# does not depend on how generate() handles the input dictionary.
answer = {"prompts": question, "responses": response[len(question):]}
display(Markdown(colorize_text_dict(answer)))

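Separating the answer from the generated text relies on the output beginning with the question. A slightly more defensive sketch checks for that prefix first. The helper name `split_generated` and the stand-in output string are hypothetical; no model is called here.

```python
def split_generated(question, generated):
    """Return the text that follows an echoed question, or the full text if no echo is found."""
    if generated.startswith(question):
        return generated[len(question):].lstrip()
    return generated


question = "What are the treatments for Glaucoma ?"
# Stand-in for a model output that echoes the prompt before continuing.
fake_output = question + "\nPLACEHOLDER CONTINUATION"

print(split_generated(question, fake_output))
```

If the generated text does not start with the question, the helper returns it unchanged rather than cutting into the answer.
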
# Enable LoRA Fine-Tuning

This cell enables LoRA (Low-Rank Adaptation) on the Gemma3 model:

* LoRA allows parameter-efficient fine-tuning by only training low-rank matrices instead of the full model.
* We set the LoRA rank to cfg.lora_rank as defined in our configuration.
* After enabling LoRA, we display the model summary to verify the changes.

In [None]:
# Enable LoRA for the model and set the LoRA rank to cfg.lora_rank.
gemma_lm.backbone.enable_lora(rank=cfg.lora_rank)
gemma_lm.summary()

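The saving from LoRA can be illustrated with simple arithmetic. For a dense weight of shape d_in × d_out, a rank-r adapter trains two matrices of shapes d_in × r and r × d_out. The layer size below is hypothetical and is not taken from the Gemma3 architecture; compare against the trainable-parameter line in the summary above for the real numbers.

```python
def lora_param_counts(d_in, d_out, rank):
    """Return (full, lora) parameter counts for one dense weight matrix."""
    full = d_in * d_out            # parameters in the frozen weight
    lora = rank * (d_in + d_out)   # parameters in the two low-rank factors
    return full, lora


# Hypothetical 1024 x 1024 projection with the notebook's rank of 16.
full, lora = lora_param_counts(d_in=1024, d_out=1024, rank=16)
print(f"full weight: {full:,} parameters")
print(f"LoRA update: {lora:,} parameters")
print(f"fraction trained: {lora / full:.4f}")
```

For this hypothetical layer the adapter trains 32,768 parameters instead of 1,048,576.
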
# Fine-Tune Gemma3 on Medical QA Dataset

This cell performs the actual fine-tuning of the model:


* Limits input sequences to cfg.max_length to control GPU memory usage.
* Uses the AdamW optimizer, common for transformer models, with weight decay.
* Excludes bias and layer-norm scale parameters from weight decay.
* Uses SparseCategoricalCrossentropy as the loss function and tracks accuracy.
* Batches the dataset according to cfg.batch_size and trains for cfg.epochs epochs.
* Saves the training history in the history variable for later analysis.


In [None]:
# Fine-tune on the Medical QA dataset.

# Limit the input sequence length to cfg.max_length (128 here) to control memory usage.
gemma_lm.preprocessor.sequence_length = cfg.max_length
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
batched_dataset = dataset.batch(cfg.batch_size)

history = gemma_lm.fit(batched_dataset, epochs=cfg.epochs)

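The exclusion list passed to exclude_from_weight_decay works on variable names. The sketch below assumes the patterns are applied as regular-expression searches against each name, which is my understanding of Keras but should be checked against its documentation. The variable names are hypothetical and the helper `uses_weight_decay` is for illustration only.

```python
import re


def uses_weight_decay(variable_name, excluded_patterns):
    """Return True unless any pattern is found in the variable name."""
    return not any(re.search(p, variable_name) for p in excluded_patterns)


# Hypothetical variable names, not read from the actual model.
hypothetical_names = [
    "attention/query/kernel",
    "attention/query/lora_kernel_a",
    "pre_attention_norm/scale",
    "dense/bias",
]
excluded = ["bias", "scale"]

for name in hypothetical_names:
    status = "decayed" if uses_weight_decay(name, excluded) else "excluded"
    print(f"{name:32s} {status}")
```

Under this assumption, names containing lora_kernel match neither pattern, so the LoRA matrices would still receive weight decay.
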
# Visualize Training Performance

This cell plots the training loss and accuracy over epochs using Matplotlib:


* loss tracks the model’s cross-entropy loss.
* accuracy tracks the model’s Sparse Categorical Accuracy.
* Separate plots are generated to visualize how the model improved during fine-tuning.


In [None]:
import matplotlib.pyplot as plt

loss = history.history['loss']
accuracy = history.history['sparse_categorical_accuracy']

# Plot Loss
plt.plot(loss, label='Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.legend()
plt.show()

# Plot Accuracy
plt.plot(accuracy, label='Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training Accuracy')
plt.legend()
plt.show()

# Generate Response After Fine-Tuning

This cell tests the fine-tuned Gemma3 model:

* We create a prompt using a template with a question and empty answer.
* The model generates a response using gemma_lm.generate.
* The output is displayed using colorize_text for better readability in the notebook.

In [None]:
template = "Question:\n{question}\n\nAnswer:\n{answer}"
prompt = template.format(
    question="What are the complications of Paget's Disease of Bone ?",
    answer="",
)
response = gemma_lm.generate(prompt, max_length=cfg.max_length)
display(Markdown(colorize_text(response)))

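To see exactly what string is sent to the model and how colorize_text rewrites it, the template and helper can be run without the model. The function is copied from the notebook so the block is self-contained; the placeholder text appended after the prompt stands in for a generated answer and is not model output.

```python
def colorize_text(text):
    # Same helper as defined earlier in the notebook.
    for word, color in zip(["Category", "Question", "Answer"], ["blue", "red", "green"]):
        text = text.replace(f"{word}:", f"\n\n**<font color='{color}'>{word}:</font>**")
    return text


template = "Question:\n{question}\n\nAnswer:\n{answer}"
prompt = template.format(question="What causes High Blood Pressure ?", answer="")

# repr() makes the newlines in the prompt visible.
print(repr(prompt))

# Placeholder continuation in place of a real generation.
print(colorize_text(prompt + "PLACEHOLDER ANSWER"))
```

The prompt ends right after "Answer:" and a newline, so the model's continuation is expected to be the answer itself.
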
In [None]:
prompt = template.format(
    question="What are the treatments for Diabetes ?",
    answer="",
)
response = gemma_lm.generate(prompt, max_length=cfg.max_length)
display(Markdown(colorize_text(response)))

In [None]:
prompt = template.format(
    question="What are the symptoms of Glaucoma ?",
    answer="",
)
response = gemma_lm.generate(prompt, max_length=cfg.max_length)
display(Markdown(colorize_text(response)))

In [None]:
prompt = template.format(
    question="What are the treatments for Glaucoma ?",
    answer="",
)
response = gemma_lm.generate(prompt, max_length=cfg.max_length)
display(Markdown(colorize_text(response)))

In [None]:
prompt = template.format(
    question="What causes High Blood Pressure ?",
    answer="",
)
response = gemma_lm.generate(prompt, max_length=cfg.max_length)
display(Markdown(colorize_text(response)))