In [1]:
import keras_hub
import keras
import os
import tensorflow as tf
import pandas as pd

In [2]:
# Set backend and distribution (as defined previously)
os.environ["KERAS_BACKEND"] = "jax"
keras.config.set_floatx("bfloat16")

In [None]:
# --- 2. Load Base Model and LoRA Structure ---
# Rebuild the architecture used during fine-tuning so the saved weights can be
# restored onto it: the Gemma 3 4B instruct preset plus LoRA adapters.
# NOTE(review): the preset, LoRA rank and (if used in training) int8 quantization
# presumably must all match the training run for the saved weights to line up.
GEMMA_PRESET = "gemma3_instruct_4b"
LORA_RANK = 32

gemma_lm = keras_hub.models.Gemma3CausalLM.from_preset(
    GEMMA_PRESET,
    dtype="bfloat16",
)

# Quantization is left disabled; re-enable only if the saved weights were
# produced from an int8-quantized model:
# gemma_lm.quantize("int8")

gemma_lm.backbone.enable_lora(rank=LORA_RANK)

Downloading from https://www.kaggle.com/api/v1/models/keras/gemma3/keras/gemma3_instruct_4b/1/download/config.json...


100%|██████████| 1.84k/1.84k [00:00<00:00, 2.52MB/s]


Downloading from https://www.kaggle.com/api/v1/models/keras/gemma3/keras/gemma3_instruct_4b/1/download/task.json...


100%|██████████| 5.55k/5.55k [00:00<00:00, 9.77MB/s]


Downloading from https://www.kaggle.com/api/v1/models/keras/gemma3/keras/gemma3_instruct_4b/1/download/assets/tokenizer/vocabulary.spm...


100%|██████████| 4.47M/4.47M [00:00<00:00, 10.8MB/s]


Downloading from https://www.kaggle.com/api/v1/models/keras/gemma3/keras/gemma3_instruct_4b/1/download/model.weights.h5...


100%|██████████| 8.79G/8.79G [10:13<00:00, 15.4MB/s]


In [None]:
# --- 3. Load Saved Weights ---
# Fine-tuned checkpoint stored on Google Drive (the Drive mount under
# /content/drive is assumed to be available in this Colab session).
load_path = os.path.join(
    "/content/drive/MyDrive/Colab Notebooks/Capstone/Model Weights",
    "gemma_4b_lora_weights_v6.weights.h5",
)

# Restore the saved weights onto the freshly built model (base + LoRA structure).
# NOTE(review): the file name suggests LoRA-only weights; `load_weights` expects a
# file written by `save_weights` on the same architecture -- confirm how it was saved.
gemma_lm.load_weights(load_path)

In [None]:
# --- 4. Final Inference Setup ---
# Attach a sampler for generation: top-k sampling restricted to the 5 most likely
# tokens, with a fixed seed for reproducibility.
sampler = keras_hub.samplers.TopKSampler(k=5, seed=42)
gemma_lm.compile(sampler=sampler)

# Announce readiness only after the sampler is attached; before compile() the
# model is not yet configured for the intended inference setup.
print("✅ Fine-tuned model fully loaded and ready for inference.")

✅ Fine-tuned model fully loaded and ready for inference.


# Generating and cleaning responses on the evaluation set

In [None]:
# Evaluation set: one question and reference answer per row, from the "eval_set" sheet.
# NOTE(review): the sheet appears to carry a saved index column (it shows up as
# "Unnamed: 0" in the preview); consider index_col=0 if it is not needed downstream.
df = pd.read_excel(
    io="/content/drive/Shareddrives/Capstone1/Capstone II/Data for FT/final_ft_eval.xlsx",
    sheet_name="eval_set",
)

In [None]:
# Preview the first rows of the evaluation set.
df.head(n=5)

Unnamed: 0,id,question,answer,source
0,1,"In Texas marriage law, what is the default pol...",Texas presumes a marriage is valid and will up...,fa.1.pdf
1,2,What legal capacity does a person have after m...,"Regardless of age, lawful marriage grants adul...",fa.1.pdf
2,3,"In Texas, who issues marriage licenses, and wh...",Marriage licenses are issued by a Texas county...,fa.2.pdf
3,4,"In Texas, who issues marriage licenses, and wh...",You obtain a license from the county clerk of ...,fa.2.pdf
4,5,"Under Texas marital‑property rules, what types...","Property owned or claimed before marriage, pro...",fa.3.pdf


In [None]:
# --- Prompt template with the persona instruction ---
# The persona instruction sits at the very start of the prompt, followed by the
# "Instruction:/Response:" scaffold; the model's answer follows "Response:\n".
# Adjacent string literals are joined with no separator, so every piece that
# ends a sentence must carry its own trailing space.
persona_instruction = (
    "You are a highly experienced and cautious **Texas Family Law Expert**. "
    "Your primary goal is to provide brief but legally rigorous and factually accurate "
    "responses to the legal questions. The answer should be complete and cater to the main question asked."
)

template = (
    f"{persona_instruction}\n\n"  # Persona at the top
    "Instruction:\n{instruction}\n\n"
    "Response:\n{response}"
)

# --- Smoke test on a single question ---
# The whole evaluation set is generated in batches in the next cell, which
# overwrites df["gemma_response"]; generating row-by-row over every question here
# would only duplicate that slow work. Sanity-check one prompt instead.
sample_prompt = template.format(instruction=df["question"].iloc[0], response="")
sample_output = gemma_lm.generate(sample_prompt, max_length=512)

# The generated text echoes the prompt; keep only what follows the first
# "Response:\n" (the template's own marker). Check and split on the same marker.
sample_response = (
    sample_output.split("Response:\n", 1)[-1].strip()
    if "Response:\n" in sample_output
    else sample_output
)
sample_response

In [None]:
# --- 3. Generate responses in efficient batches ---
from tqdm.auto import tqdm

# Adjacent string literals are joined with no separator, so sentence-ending
# pieces need their own trailing space. Keep this text identical across runs
# whose outputs are compared -- prompt changes alter generations.
persona_instruction = (
    "You are a highly experienced and cautious **Texas Family Law Expert**. "
    "Your primary goal is to provide brief but legally rigorous and factually accurate "
    "responses to the legal questions. The answer should be complete and cater to the main question asked."
)

template = (
    f"{persona_instruction}\n\n"  # Add the persona at the top
    "Instruction:\n{instruction}\n\n"
    "Response:\n{response}"
)

# The generated text echoes the prompt; the answer is everything after the first
# occurrence of this marker (the one that comes from the template).
RESPONSE_MARKER = "Response:\n"

# Prompts per generate() call. Tune to accelerator memory (e.g. 8, 16, 32); start small.
BATCH_SIZE = 16
# Sequence-length cap passed to generate(); the generated sequence includes the prompt.
MAX_LENGTH = 512


def extract_response(text: str) -> str:
    """Return the model's answer: the text after RESPONSE_MARKER, stripped.

    Falls back to the full text unchanged when the marker is absent.
    """
    _, marker, answer = text.partition(RESPONSE_MARKER)
    return answer.strip() if marker else text


all_responses = []   # one entry per df row, in row order (None for failed rows)
failed_batches = []  # starting row positions of batches whose generation raised

print(f"Starting generation with batch size {BATCH_SIZE}...")

for start in tqdm(range(0, len(df), BATCH_SIZE), desc="generating"):
    batch_questions = df["question"].iloc[start : start + BATCH_SIZE].tolist()
    batch_prompts = [
        template.format(instruction=q, response="") for q in batch_questions
    ]

    try:
        # Generate the whole batch in one call.
        generated_texts = gemma_lm.generate(batch_prompts, max_length=MAX_LENGTH)
    except Exception as e:
        # Best-effort: record the failure (e.g. out of memory) and continue so one
        # bad batch does not discard the rest; its rows get None placeholders.
        end = start + len(batch_prompts) - 1
        print(f"Error processing rows {start} to {end}: {e}")
        failed_batches.append(start)
        all_responses.extend([None] * len(batch_prompts))
        continue

    all_responses.extend(extract_response(text) for text in generated_texts)

# The column assignment below needs exactly one response (or None) per row.
assert len(all_responses) == len(df), (len(all_responses), len(df))

if failed_batches:
    print(
        f"Generation finished with {len(failed_batches)} failed batch(es) "
        f"starting at rows {failed_batches}."
    )
else:
    print("Generation complete.")

df["gemma_response"] = all_responses

In [None]:
import re

def aggressive_clean(text: str) -> str:
    """
    Remove LLM collapse/repetition artifacts from the start and end of a response.

    Repeatedly strips stand-alone list markers and emphasis debris that degenerate
    generations tend to emit (``A.``, ``a)``, ``A)``, ``**``, a leading ``1.``) from
    both ends until none remain. It then collapses newlines and runs of whitespace
    into single spaces.

    Markers are only removed when they stand alone:

    - A trailing ``A.``/``a)``/``A)`` must be preceded by whitespace (or be the
      whole text), so legal citations such as ``§ 6.001(a)`` and words such as
      ``USA.`` are left intact.
    - A leading number must not be part of a decimal (``1.5 million`` is kept).

    Args:
        text: Raw model response. ``None``/NaN (e.g. from a failed generation
            batch) and other non-string values are treated as empty.

    Returns:
        The cleaned response, or ``""`` when nothing meaningful remains.
    """
    if not isinstance(text, str):
        return ""

    cleaned_text = text.strip()

    # 1. Artifacts observed in collapsed outputs (e.g. "A. a) A. a) **").
    artifacts = [
        r"^\*+",                    # Leading asterisks (**)
        r"^\s*A\.\s*",              # Leading 'A.'
        r"^\s*a\)\s*",              # Leading 'a)'
        r"^\s*A\)\s*",              # Leading 'A)'
        r"^\s*\d{1,2}\.(?!\d)\s*",  # Leading list number (1., 12.), not a decimal
        r"\*+$",                    # Trailing asterisks (**)
        r"(?<!\S)A\.\s*$",          # Trailing stand-alone 'A.' (not 'USA.')
        r"(?<!\S)a\)\s*$",          # Trailing stand-alone 'a)' (not '(a)')
        r"(?<!\S)A\)\s*$",          # Trailing stand-alone 'A)' (not '(A)')
    ]

    # 2. Strip repeatedly until a full pass changes nothing, since the garbage
    #    repeats ("A. a) A. a)"). Each change shortens the text, so this terminates.
    while True:
        previous_text = cleaned_text
        for pattern in artifacts:
            cleaned_text = re.sub(pattern, "", cleaned_text).strip()
        if cleaned_text == previous_text:
            break

    # 3. Flatten newlines and collapse repeated whitespace into single spaces.
    cleaned_text = re.sub(r"\n+", " ", cleaned_text).strip()
    cleaned_text = re.sub(r"\s{2,}", " ", cleaned_text).strip()

    # 4. A lone "a."/"A." or a bare run of asterisks is pure garbage.
    if re.fullmatch(r"[aA]\.|\*+", cleaned_text):
        return ""

    return cleaned_text

# First cleaning pass: strip collapse artifacts into a new column; the raw
# generations stay in "gemma_response" for comparison.
df["gemma_response_cleaned"] = df["gemma_response"].map(aggressive_clean)

In [None]:
import nltk

# sent_tokenize (used by the repetition-truncation step) needs NLTK's Punkt
# sentence-tokenizer data. Newer NLTK releases look it up as "punkt_tab", older
# ones as "punkt", so fetch both.
# nltk.download() reports most failures (e.g. no network) by returning False
# instead of raising, so its return value is checked in addition to try/except.
for resource in ("punkt_tab", "punkt"):
    try:
        downloaded = nltk.download(resource)
    except Exception as e:
        print(f"Error downloading {resource}: {e}")
        continue

    if downloaded:
        print(f"{resource} resource downloaded successfully.")
    else:
        print(f"Error downloading {resource}: nltk.download returned False.")

In [None]:
from nltk.tokenize import sent_tokenize
import nltk
import re  # used below; imported here so this cell does not rely on another cell's import

# Requires the Punkt tokenizer data fetched by the nltk.download cell.


def truncate_repetition(text: str) -> str:
    """Truncate a response at the first sentence that repeats an earlier one.

    Degenerate generations often loop on the same sentence. Sentences are compared
    after lower-casing and removing every character that is neither a word
    character nor whitespace. Everything from the first repeated sentence onward
    is dropped. Sentences that normalize to an empty string (pure punctuation)
    are discarded.

    Args:
        text: Response text, typically already passed through ``aggressive_clean``.
            Non-string values (``None``/NaN from failed generations) yield ``""``.

    Returns:
        The kept sentences joined with single spaces.
    """
    if not isinstance(text, str):
        return ""

    seen_sentences = set()
    final_sentences = []

    for sentence in sent_tokenize(text):
        # Normalize (lower case, drop punctuation) so trivial variations still match.
        normalized_sent = re.sub(r"[^\w\s]", "", sentence.lower()).strip()

        if normalized_sent in seen_sentences:
            # First repetition found: stop and drop it and everything after it.
            break

        if normalized_sent:
            seen_sentences.add(normalized_sent)
            final_sentences.append(sentence)

    return " ".join(final_sentences).strip()

# Second cleaning pass (must run after aggressive_clean): cut each cleaned
# response at its first repeated sentence.
df["gemma_response_cleaned"] = df["gemma_response_cleaned"].map(truncate_repetition)

In [None]:
# Preview the first rows, now including the raw and cleaned model responses.
df.head(n=5)

In [None]:
# Reference answer for the row at position 1 (second row).
df["answer"].iloc[1]

In [None]:
# Cleaned model response for the same row, for side-by-side comparison.
df["gemma_response_cleaned"].iloc[1]

In [None]:
# Persist questions, reference answers and model responses for evaluation.
# `index` must be the boolean False: any non-empty string (including 'False') is
# truthy, which would make pandas write the row index as an extra unnamed column.
df.to_csv(
    "/content/drive/MyDrive/Colab Notebooks/Capstone/Model Responses/gemma3_4b_responses_finetuned_v6.csv",
    index=False,
)