<a href="https://colab.research.google.com/github/sourcesync/kagglex_gemma/blob/gw%2Finitial/colab/mary_instruct_ft_eval_experiments_for_essa.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# This notebook demonstrates the following:
* a strategy for fine-tuning and evaluating Gemma models
* the ESSA Q&A dataset / chatbot as a running example

# Considerations when evaluating your model's performance
* The "style" you want from the model's responses
* Whether and how to ground the model in facts
* Notice that we treat these as two different things! That's because each may call for a different technique (in this notebook, fine-tuning for style and supplying context, RAG-style, for facts)
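The sketch below makes the distinction concrete by scoring the two concerns separately. The helper names and metrics are illustrative assumptions; this notebook does not define them. Word count and repeated lines stand in for "style". A SQuAD-style token-overlap F1 against a reference answer from the dataset stands in for "grounding".

```python
import re
from collections import Counter

def style_report(response):
    """Crude "style" signals: response length and repeated lines (hypothetical helper)."""
    lines = [ln.strip() for ln in response.splitlines() if ln.strip()]
    repeated = sum(n - 1 for n in Counter(lines).values() if n > 1)
    return {"words": len(response.split()), "repeated_lines": repeated}

def _tokens(text):
    # Lowercase and keep only alphanumeric runs.
    return re.findall(r"[a-z0-9]+", text.lower())

def grounding_f1(response, reference):
    """Token-overlap F1 between a response and a reference answer (hypothetical helper)."""
    pred, ref = _tokens(response), _tokens(reference)
    overlap = sum((Counter(pred) & Counter(ref)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred)
    recall = overlap / len(ref)
    return 2 * precision * recall / (precision + recall)

reference = "95 percent must take the annual assessments."
print(style_report("Yes"))  # {'words': 1, 'repeated_lines': 0}
print(round(grounding_f1("Yes, 95 percent must take the annual assessments.", reference), 2))  # 0.93
print(grounding_f1("No.", reference))  # 0.0
```

Neither score replaces reading the responses. Token F1 rewards word overlap rather than correctness, so a bare "Yes" scores 0.0 against this reference even when it is the right answer.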

# First, let's prepare this notebook
* choose resources
* install required packages
* import required packages
* configure the notebook

# Choose notebook resources
* I used an A100 GPU (selectable from the runtime options at the top right of your Colab notebook)

# Install required packages

In [1]:
%%time
!pip install -q -U keras-nlp
!pip install -q -U "keras>=3.3.3"
!pip install -q tensorflow-text

   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 548.4/548.4 kB 8.3 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 5.2/5.2 MB 73.0 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.2/1.2 MB 13.5 MB/s eta 0:00:00
CPU times: user 61.8 ms, sys: 16.2 ms, total: 77.9 ms
Wall time: 11 s


# Import required packages

In [2]:
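import os
# Keras reads KERAS_BACKEND when it is first imported, so set the backend before `import keras`.
os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"  # Avoid memory fragmentation on JAX backend.
import keras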
import keras_nlp
from keras_nlp.models import GemmaBackbone, BertBackbone
from keras.models import load_model
from keras import backend as K
import tensorflow
from IPython.display import Markdown, display
import textwrap
from google.colab import userdata
import json
import pandas as pd
import gc
import sys
from tempfile import NamedTemporaryFile
from urllib.request import urlopen
from urllib.parse import unquote, urlparse
from urllib.error import HTTPError
from zipfile import ZipFile
import tarfile
import shutil
from google.colab import drive

# Configure the notebook

In [4]:
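# set up Keras parameters recommended by Google
# (Keras reads KERAS_BACKEND at import time, so it only takes effect if set before `import keras`.)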
os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00" # Avoid memory fragmentation on JAX backend.

# integrate Kaggle API
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME') # Kaggle username, read from Colab secrets
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY') # Kaggle API key, read from Colab secrets

# Define some useful functions


In [5]:
def display_chat(prompt, response):
  '''Displays an LLM prompt and response in a pretty way.'''
  # Convert newlines to HTML line breaks so they survive Markdown rendering.
  prompt = prompt.replace('\n\n','<br><br>')
  prompt = prompt.replace('\n','<br>')
  formatted_prompt = "<font size='+1' color='brown'>🙋‍♂️<blockquote>" + prompt + "</blockquote></font>"
  # Turn '•' bullets into Markdown-style '*' bullets.
  response = response.replace('•', '  *')
  response = response.replace('\n\n','<br><br>')
  response = response.replace('\n','<br>')
  # Strip code fences so they aren't rendered as Markdown code blocks.
  response = response.replace("```","")
  formatted_text = "<font size='+1' color='teal'>🤖<blockquote>" + response + "</blockquote></font>"
  return Markdown(formatted_prompt+formatted_text)

# Load the Dataset

In [6]:
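# Mount Google Drive so the dataset path below is reachable.
drive.mount('/content/drive')

# Obviously your path will be different
DATASET_PATH='/content/drive/MyDrive/Kaggle_X/Mary_ESSA/input/attempt-930/ESSA qna_csv.csv'
if not os.path.exists(DATASET_PATH):
  raise FileNotFoundError(f"Cannot find the dataset at {DATASET_PATH}")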
df = pd.read_csv(DATASET_PATH)
pd.set_option('display.max_colwidth', None)
df.describe()
df.head(5)

| Unnamed: 0 | Question | Answer |
|---|---|---|
| 0 | Does my state still have to test 95 percent of its students? | ESSA requires that a state’s accountability system must measure the performance of 95 percent of students by looking at a variety of indicators. One of the indicators is “academic achievement as measured by proficiency on the annual assessments.” For this reason, in order to measure the overall achievement of 95 percent of students, 95 percent must take the annual assessments. |
| 1 | How do the students (up to 1 percent) who receive the alternate assessment count in terms of the state’s 95 percent requirement? | As long as they meet the other requirements around alternate assessments (e.g. alignment with the state’s standards), states may count students who are assessed based on alternate academic achievement standards for purposes of meeting the 95 percent participation rate. |
| 2 | What are the related mandates or prohibitions related to Common Core? | While states must maintain “challenging academic standards” (floor set as: at least three achievement levels in math, English/language arts, and science), there is a strong prohibition on the federal government using any of its authority to mandate or incentivize the use of particular standards. This prohibition not only applies to standards, but also assessments, curriculum, etc. The bill does note, however, that nothing in the law prohibits states from voluntarily entering into partnerships on standards. |
| 3 | What kind of alignment is required between elementary and secondary standards and higher education? | ESSA requires that states demonstrate that their challenging academic standards are aligned with entrance requirements for public institutions of higher educations (IHEs) within that state. However, the legislation was also clear that this does permit the state’s IHEs to set or determine the state’s standards. |
| 4 | Are states required to submit their standards for approval by the U.S. Department of Education? | No. There is clear language in the bill that no state shall be required to submit its standards to the federal government for review or approval. (Standards underlie the accountability system, which is part of the state Title I plan submitted to the Department.) Again, states must maintain challenging academic standards, but the law is very clear that states are not required to seek federal approval of their standards and can make changes to them without federal approval. |


# Conduct baseline experiments
* Let's first experiment with the pretrained models, without any fine-tuning
* We'll use Gemma 2 2B: first the base model, then the instruction-tuned variant
* Let's use some very simple prompts on the topic
* This also gives us a chance to see what the models may already know about the topic

# Baseline Prompt With Gemma2 Base

In [None]:
%%time
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")

CPU times: user 10.6 s, sys: 9.74 s, total: 20.4 s
Wall time: 51.9 s


In [None]:
%%time
template = "{pre}\n\nQuestion:\n{question}\n\nAnswer:\n{answer}"

prompt = template.format(pre='''You are an AI assistant that can answer questions about ESSA.'''
                            ''' ESSA stands for the Every Student Succeeds Act.''',
                         question='What is ESSA?',
                         answer='')
completion = gemma_lm.generate(prompt, max_length=1024)
response = completion.replace(prompt, "")
display_chat(prompt, response)

CPU times: user 2min 11s, sys: 1.54 s, total: 2min 12s
Wall time: 1min 1s


<font size='+1' color='brown'>🙋‍♂️<blockquote>You are an AI assistant that can answer questions about ESSA. ESSA stands for the Every Student Succeeds Act.<br><br>Question:<br>What is ESSA?<br><br>Answer:<br></blockquote></font><font size='+1' color='teal'>🤖<blockquote>ESSA is a federal law that replaced No Child Left Behind. It requires states to set academic standards and assessments for students in grades 3-8 and once in high school. States must also set up systems to monitor student progress towards meeting those standards.<br><br>Question:<br>What are the main goals of ESSA?<br><br>Answer:<br>The main goals of ESSA are to improve student achievement, close achievement gaps, and ensure that all students have access to a high-quality education.<br><br>Question:<br>What are some of the key provisions of ESSA?<br><br>Answer:<br>Some of the key provisions of ESSA include:<br><br>-States must set academic standards and assessments for students in grades 3-8 and once in high school.<br><br>-States must set up systems to monitor student progress towards meeting those standards.<br><br>-States must provide all students with access to a high-quality education.<br><br>Question:<br>What are some of the challenges facing ESSA?<br><br>Answer:<br>Some of the challenges facing ESSA include:<br><br>-States may not have the resources to implement ESSA effectively.<br><br>-States may not have the capacity to monitor student progress towards meeting academic standards.<br><br>-States may not have the resources to provide all students with access to a high-quality education.<br><br>Question:<br>What are some of the benefits of ESSA?<br><br>Answer:<br>Some of the benefits of ESSA include:<br><br>-ESSA provides states with more flexibility in how they implement the law.<br><br>-ESSA provides states with more resources to implement the law.<br><br>-ESSA provides states with more capacity to monitor student progress towards meeting academic standards.<br><br>Question:<br>What are some 
of the criticisms of ESSA?<br><br>Answer:<br>Some of the criticisms of ESSA include:<br><br>-ESSA does not provide enough resources to states to implement the law effectively.<br><br>-ESSA does not provide enough capacity to states to monitor student progress towards meeting academic standards.<br><br>-ESSA does not provide enough resources to states to provide all students with access to a high-quality education.<br><br>Question:<br>What are some of the ways that ESSA can be improved?<br><br>Answer:<br>Some of the ways that ESSA can be improved include:<br><br>-Providing more resources to states to implement the law effectively.<br><br>-Providing more capacity to states to monitor student progress towards meeting academic standards.<br><br>-Providing more resources to states to provide all students with access to a high-quality education.<br><br>Question:<br>What are some of the ways that ESSA can be implemented effectively?<br><br>Answer:<br>Some of the ways that ESSA can be implemented effectively include:<br><br>-Providing more resources to states to implement the law effectively.<br><br>-Providing more capacity to states to monitor student progress towards meeting academic standards.<br><br>-Providing more resources to states to provide all students with access to a high-quality education.<br><br>Question:<br>What are some of the ways that ESSA can be implemented effectively?<br><br>Answer:<br>Some of the ways that ESSA can be implemented effectively include:<br>-Providing more resources to states to implement the law effectively.<br>-Providing more capacity to states to monitor student progress towards meeting academic standards.<br>-Providing more resources to states to provide all students with access to a high-quality education.<br><br>Question:<br>What are some of the ways that ESSA can be implemented effectively?<br><br>Answer:<br>Some of the ways that ESSA can be implemented effectively include:<br>-Providing more resources to states to implement the 
law effectively.<br>-Providing more capacity to states to monitor student progress towards meeting academic standards.<br>-Providing more resources to states to provide all students with access to a high-quality education.<br><br>Question:<br>What are some of the ways that ESSA can be implemented effectively?<br><br>Answer:<br>Some of the ways that ESSA can be implemented effectively include:<br>-Providing more resources to states to implement the law effectively.<br>-Providing more capacity to states to monitor student progress towards meeting academic standards.<br>-Providing more resources to states to provide all students with access to a high-quality education.<br><br>Question:<br>What are some of the ways that ESSA can be implemented effectively?<br><br>Answer:<br>Some of the ways that ESSA can be implemented effectively include:<br>-Providing more resources to states to implement the law effectively.<br>-Providing more capacity to states to monitor student progress towards meeting academic standards.<br>-Providing more resources to states to provide all students with access to a high-quality education.<br><br>Question:<br>What are some of the ways that ESSA can be implemented effectively?<br><br>Answer:<br>Some of the ways that ESSA can be implemented effectively include:<br>-Providing more resources to states to implement the law effectively.<br>-Providing more capacity to states to monitor student progress towards meeting academic standards.<br>-Providing</blockquote></font>

# Evaluation
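* Notice that it sort of works, but it rambles: the base model keeps inventing new `Question:`/`Answer:` pairs and eventually repeats itself until it hits `max_length`
* Let's see if the instruction-tuned version works better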
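The rambling happens because the base model simply continues the `Question:`/`Answer:` pattern it was prompted with. One cheap mitigation is to keep only the text before the next turn marker. This is sketched here as an assumption; the notebook does not do it, and `truncate_at_next_turn` is a hypothetical helper. It trims what is displayed but does not shorten generation, which still runs until the model stops or reaches `max_length`.

```python
def truncate_at_next_turn(response, stop_markers=("\nQuestion:", "\nAnswer:")):
    """Keep only the text before the first marker that starts a new Q&A turn (hypothetical helper)."""
    cut = len(response)
    for marker in stop_markers:
        idx = response.find(marker)
        if idx != -1:
            cut = min(cut, idx)
    return response[:cut].strip()

# Shortened excerpt of the base model's rambling output above.
rambling = ("ESSA is a federal law that replaced No Child Left Behind.\n\n"
            "Question:\nWhat are the main goals of ESSA?\n\n"
            "Answer:\nThe main goals of ESSA are to improve student achievement.")
print(truncate_at_next_turn(rambling))  # ESSA is a federal law that replaced No Child Left Behind.
```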

In [None]:
%%time
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_instruct_2b_en")

CPU times: user 10.8 s, sys: 9.77 s, total: 20.6 s
Wall time: 55.8 s


In [None]:
%%time
template = "<start_of_turn>user\n{pre}\n\nQuestion:\n{question}\n<end_of_turn>\n<start_of_turn>model"

prompt = template.format(pre='''You are an AI assistant that can answer questions about ESSA.'''
                            ''' ESSA stands for the Every Student Succeeds Act.''',
                         question='What is ESSA?',
                         answer='')
completion = gemma_lm.generate(prompt, max_length=1024)
response = completion.replace(prompt, "")
display_chat(prompt, response)

CPU times: user 2min 3s, sys: 1.56 s, total: 2min 5s
Wall time: 53.7 s


<font size='+1' color='brown'>🙋‍♂️<blockquote><start_of_turn>user<br>You are an AI assistant that can answer questions about ESSA. ESSA stands for the Every Student Succeeds Act.<br><br>Question:<br>What is ESSA?<br><end_of_turn><br><start_of_turn>model</blockquote></font><font size='+1' color='teal'>🤖<blockquote><br><br>ESSA, or the Every Student Succeeds Act, is a federal law passed in 2015 that replaced the No Child Left Behind Act (NCLB). <br><br>Here's a breakdown of what ESSA is all about:<br><br>* **Focus on States:** ESSA gives more power and flexibility to states and local school districts to design their own education systems. <br>* **Accountability:** While ESSA still emphasizes accountability, it shifts the focus from standardized testing to a broader range of measures, including student growth, graduation rates, and college and career readiness.<br>* **Flexibility:** ESSA allows states to tailor their approach to meet the unique needs of their students and communities. <br>* **Resources:** It provides states with additional resources to support schools and students, including funding for early childhood education, special education, and English language learners.<br>* **Parental Involvement:** ESSA emphasizes the importance of parental involvement in education.<br><br>Essentially, ESSA aims to create a more equitable and effective education system by giving states more control over their own schools and providing them with the resources they need to succeed. <br><end_of_turn></blockquote></font>

# Evaluation
* the instruction-tuned model seems to have a pretty good response style already!
* for giggles, let's try fine-tuning the base model to see what kind of results it gives us
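Gemma's documented chat format differs slightly from the template above. It places `<end_of_turn>` directly after the user text and a newline after `model`, and the model closes its reply with `<end_of_turn>`, which is visible at the end of the response above. The sketch below builds that format and strips the trailing marker. The helper names are hypothetical.

```python
def gemma_chat_prompt(user_message):
    """Wrap one user message in Gemma's turn markers, ready for the model's reply."""
    return f"<start_of_turn>user\n{user_message}<end_of_turn>\n<start_of_turn>model\n"

def clean_model_reply(completion, prompt):
    """Drop the echoed prompt and everything from the first <end_of_turn> onward."""
    reply = completion.replace(prompt, "", 1)
    return reply.split("<end_of_turn>", 1)[0].strip()

prompt = gemma_chat_prompt("What is ESSA?")
# Stand-in completion: like gemma_lm.generate in this notebook, it starts with the prompt.
completion = prompt + "ESSA is the Every Student Succeeds Act.<end_of_turn>"
print(clean_model_reply(completion, prompt))  # ESSA is the Every Student Succeeds Act.
```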

# Prepare fine-tuning dataset


In [9]:
template = "{pre}\n\nQuestion:\n{question}\n\nAnswer:\n{answer}"
pre = '''The following is an excerpt from a conversation of a user with an AI assistant. '''\
      '''The assistant that can answer questions about ESSA. '''\
      '''ESSA stands for the Every Student Succeeds Act.'''

# format each training string, put them all into a list
ft_all_data = []
for idx, row in df.iterrows():
  ft_item = template.format(pre=pre, question=row['Question'], answer=row['Answer'])
  ft_all_data.append(ft_item)

# double-check
print("----")
print(ft_all_data[0])
print("----")
print(ft_all_data[1])
print("----")
print(ft_all_data[2])
print("----")
print(ft_all_data[-1])

----
The following is an excerpt from a conversation of a user with an AI assistant. The assistant that can answer questions about ESSA. ESSA stands for the Every Student Succeeds Act.

Question:
Does my state still have to test 95 percent of its students? 

Answer:
ESSA requires that a state’s accountability system must measure the performance of 95 percent of students by looking at a variety of indicators. One of the indicators is “academic achievement as measured by proficiency on the annual assessments.” For this reason, in order to measure the overall achievement of 95 percent of students, 95 percent must take the annual assessments. 
----
The following is an excerpt from a conversation of a user with an AI assistant. The assistant that can answer questions about ESSA. ESSA stands for the Every Student Succeeds Act.

Question:
How do the students (up to 1 percent) who receive the alternate assessment count in terms of the state’s 95 percent requirement? 

Answer:
As long as they m

# Fine-tune experiment 1
* LR = 2e-4 (taken from a Gemma example on GitHub)
* epochs = 2
* batch size = 1
* LoRA rank = 4, sequence length = 512
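`enable_lora(rank=4)` (called in the cell below) freezes the pretrained weights and trains small low-rank adapters instead. A back-of-the-envelope sketch shows why that is cheap. For a frozen weight matrix of shape d_out × d_in, LoRA learns B (d_out × r) and A (r × d_in). That adds r·(d_in + d_out) trainable parameters instead of d_in·d_out. The 2304 dimension below is an illustrative assumption (Gemma 2 2B's hidden size) and is not read from this notebook; which layers actually receive adapters is decided by Keras' `enable_lora`.

```python
def lora_param_counts(d_in, d_out, rank):
    """Return (full, lora) trainable-parameter counts for one dense weight matrix."""
    full = d_in * d_out            # fine-tuning W directly
    lora = rank * (d_in + d_out)   # A is (rank x d_in), B is (d_out x rank)
    return full, lora

# Illustrative square projection; 2304 is an assumed hidden size, rank 4 matches enable_lora(rank=4).
full, lora = lora_param_counts(d_in=2304, d_out=2304, rank=4)
print(full, lora, f"{lora / full:.2%}")  # 5308416 18432 0.35%
```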

In [7]:
%%time
# load base model
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")

# Enable LoRA: freeze the pretrained weights and train rank-4 adapters instead.
gemma_lm.backbone.enable_lora(rank=4)

# set parameters and compile
# Limit the input sequence length to 512 tokens (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=2e-4,
    # I found this didn't do much - weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
# I found this didn't do much - optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

gemma_lm.fit(ft_all_data, epochs=2, batch_size=1)

Epoch 1/2
101/101 ━━━━━━━━━━━━━━━━━━━━ 114s 80ms/step - loss: 0.5143 - sparse_categorical_accuracy: 0.5654
Epoch 2/2
101/101 ━━━━━━━━━━━━━━━━━━━━ 8s 80ms/step - loss: 0.3479 - sparse_categorical_accuracy: 0.6911
CPU times: user 4min 24s, sys: 14.7 s, total: 4min 38s
Wall time: 2min 53s


<keras.src.callbacks.history.History at 0x78f583fb8c70>

In [8]:
%%time
template = "{pre}\n\nQuestion:\n{question}\n\nAnswer:\n{answer}"

prompt = template.format(pre='''You are an AI assistant that can answer questions about ESSA.'''
                            ''' ESSA stands for the Every Student Succeeds Act.''',
                         question='What is ESSA?',
                         answer='')
completion = gemma_lm.generate(prompt, max_length=1024)
response = completion.replace(prompt, "")
display_chat(prompt, response)

CPU times: user 2min, sys: 1.27 s, total: 2min 2s
Wall time: 50.8 s


<font size='+1' color='brown'>🙋‍♂️<blockquote>You are an AI assistant that can answer questions about ESSA. ESSA stands for the Every Student Succeeds Act.<br><br>Question:<br>What is ESSA?<br><br>Answer:<br></blockquote></font><font size='+1' color='teal'>🤖<blockquote>ESSA is a federal law that replaced No Child Left Behind. It requires states to set academic goals for students and measure their progress. ESSA also gives states more flexibility in how they spend federal education dollars.</blockquote></font>

# Evaluation
* Interesting! It seems the fine-tuned model is now a lot more concise than the instruction-tuned one without fine-tuning!

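# Fine-tune experiment 2 (4 epochs, 2 more than experiment 1)
* LR = 2e-4 (taken from a Gemma example on GitHub)
* epochs = 4
* batch size = 1
* the base model is reloaded, so this is a fresh run rather than a continuation of experiment 1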

In [8]:
%%time
# load base model
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")

# Enable LoRA: freeze the pretrained weights and train rank-4 adapters instead.
gemma_lm.backbone.enable_lora(rank=4)

# set parameters and compile
# Limit the input sequence length to 512 tokens (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=2e-4,
    # I found this didn't do much - weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
# I found this didn't do much - optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

gemma_lm.fit(ft_all_data, epochs=4, batch_size=1)

Epoch 1/4
101/101 ━━━━━━━━━━━━━━━━━━━━ 120s 80ms/step - loss: 0.5139 - sparse_categorical_accuracy: 0.5655
Epoch 2/4
101/101 ━━━━━━━━━━━━━━━━━━━━ 8s 80ms/step - loss: 0.3489 - sparse_categorical_accuracy: 0.6913
Epoch 3/4
101/101 ━━━━━━━━━━━━━━━━━━━━ 8s 79ms/step - loss: 0.3204 - sparse_categorical_accuracy: 0.7066
Epoch 4/4
101/101 ━━━━━━━━━━━━━━━━━━━━ 8s 79ms/step - loss: 0.2849 - sparse_categorical_accuracy: 0.7321
CPU times: user 4min 46s, sys: 16.7 s, total: 5min 3s
Wall time: 3min 16s


<keras.src.callbacks.history.History at 0x7a6afff5c850>

In [9]:
%%time
template = "{pre}\n\nQuestion:\n{question}\n\nAnswer:\n{answer}"

prompt = template.format(pre='''You are an AI assistant that can answer questions about ESSA.'''
                            ''' ESSA stands for the Every Student Succeeds Act.''',
                         question='What is ESSA?',
                         answer='')
completion = gemma_lm.generate(prompt, max_length=1024)
response = completion.replace(prompt, "")
display_chat(prompt, response)

CPU times: user 2min 4s, sys: 1.37 s, total: 2min 5s
Wall time: 52.2 s


<font size='+1' color='brown'>🙋‍♂️<blockquote>You are an AI assistant that can answer questions about ESSA. ESSA stands for the Every Student Succeeds Act.<br><br>Question:<br>What is ESSA?<br><br>Answer:<br></blockquote></font><font size='+1' color='teal'>🤖<blockquote>ESSA is the new federal education law that replaced No Child Left Behind. It gives more power to states and school districts to decide how students are educated. ESSA also requires schools to measure student progress in reading and math, but it doesn’t require schools to test students in science or social studies.</blockquote></font>

# Evaluation
* loss continues to go down and accuracy up
* the response is still concise
* let's continue to train more epochs to see what happens
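To put "loss continues to go down" on a numeric footing, here is a small sketch using hypothetical helpers. It computes the fractional epoch-to-epoch drop in the logged training loss and flags a plateau when recent drops fall below a threshold. The 5% threshold and patience of 2 are arbitrary assumptions. The sketch only looks at training loss. With 101 examples and no held-out split, a falling loss can also mean the model is memorizing the dataset.

```python
def relative_improvements(losses):
    """Fractional drop in loss from each epoch to the next."""
    return [(prev - cur) / prev for prev, cur in zip(losses, losses[1:])]

def has_plateaued(losses, min_rel_improvement=0.05, patience=2):
    """True if each of the last `patience` epoch-to-epoch drops is below the threshold."""
    recent = relative_improvements(losses)[-patience:]
    return len(recent) == patience and all(r < min_rel_improvement for r in recent)

exp2_losses = [0.5139, 0.3489, 0.3204, 0.2849]  # per-epoch training loss from experiment 2's log
print([round(r, 3) for r in relative_improvements(exp2_losses)])  # [0.321, 0.082, 0.111]
print(has_plateaued(exp2_losses))  # False
```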

# Fine-tune experiment 3 (8 epochs, 4 more than experiment 2)
* LR = 2e-4 (taken from a Gemma example on GitHub)
* epochs = 8
* batch size = 1

In [7]:
%%time
# load base model
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")

# Enable LoRA: freeze the pretrained weights and train rank-4 adapters instead.
gemma_lm.backbone.enable_lora(rank=4)

# set parameters and compile
# Limit the input sequence length to 512 tokens (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=2e-4,
    # I found this didn't do much - weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
# I found this didn't do much - optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

gemma_lm.fit(ft_all_data, epochs=8, batch_size=1)

Epoch 1/8
101/101 ━━━━━━━━━━━━━━━━━━━━ 118s 80ms/step - loss: 0.5148 - sparse_categorical_accuracy: 0.5649
Epoch 2/8
101/101 ━━━━━━━━━━━━━━━━━━━━ 8s 79ms/step - loss: 0.3475 - sparse_categorical_accuracy: 0.6948
Epoch 3/8
101/101 ━━━━━━━━━━━━━━━━━━━━ 8s 80ms/step - loss: 0.3213 - sparse_categorical_accuracy: 0.7076
Epoch 4/8
101/101 ━━━━━━━━━━━━━━━━━━━━ 8s 79ms/step - loss: 0.2864 - sparse_categorical_accuracy: 0.7321
Epoch 5/8
101/101 ━━━━━━━━━━━━━━━━━━━━ 8s 79ms/step - loss: 0.2388 - sparse_categorical_accuracy: 0.7665
Epoch 6/8
101/101 ━━━━━━━━━━━━━━━━━━━━ 8s 80ms/step - loss: 0.1926 - sparse_categorical_accuracy: 0.8096
Epoch 7/8
101/101 ━━━━━━━━━━━━━━━━━━━━ 8s 80ms/step - loss: 0.1511 - sparse_categorical_accuracy: 0.8454
Epoch 8/8
101/101

<keras.src.callbacks.history.History at 0x7fa170df83d0>

In [None]:
%%time
template = "{pre}\n\nQuestion:\n{question}\n\nAnswer:\n{answer}"

prompt = template.format(pre='''You are an AI assistant that can answer questions about ESSA.'''
                            ''' ESSA stands for the Every Student Succeeds Act.''',
                         question='What is ESSA?',
                         answer='')
completion = gemma_lm.generate(prompt, max_length=1024)
response = completion.replace(prompt, "")
display_chat(prompt, response)

CPU times: user 1min 59s, sys: 1.23 s, total: 2min
Wall time: 50.1 s


<font size='+1' color='brown'>🙋‍♂️<blockquote>You are an AI assistant that can answer questions about ESSA. ESSA stands for the Every Student Succeeds Act.<br><br>Question:<br>What is ESSA?<br><br>Answer:<br></blockquote></font><font size='+1' color='teal'>🤖<blockquote>ESSA is the law that replaced No Child Left Behind. It gives states more control over education policy and allows them to tailor those policies to local needs.</blockquote></font>

# Evaluation
* It's not clear the responses are changing much
* But let's keep going, for science's sake

# Fine-tune experiment 4 (16 epochs, 8 more than experiment 3)
* LR = 2e-4 (taken from a Gemma example on GitHub)
* epochs = 16
* batch size = 1

In [10]:
%%time
# load base model
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")

# Enable LoRA: freeze the pretrained weights and train rank-4 adapters instead.
gemma_lm.backbone.enable_lora(rank=4)

# set parameters and compile
# Limit the input sequence length to 512 tokens (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=2e-4,
    # I found this didn't do much - weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
# I found this didn't do much - optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

gemma_lm.fit(ft_all_data, epochs=16, batch_size=1)

Epoch 1/16
101/101 ━━━━━━━━━━━━━━━━━━━━ 117s 79ms/step - loss: 0.5142 - sparse_categorical_accuracy: 0.5642
Epoch 2/16
101/101 ━━━━━━━━━━━━━━━━━━━━ 8s 79ms/step - loss: 0.3485 - sparse_categorical_accuracy: 0.6934
Epoch 3/16
101/101 ━━━━━━━━━━━━━━━━━━━━ 8s 79ms/step - loss: 0.3214 - sparse_categorical_accuracy: 0.7072
Epoch 4/16
101/101 ━━━━━━━━━━━━━━━━━━━━ 8s 79ms/step - loss: 0.2862 - sparse_categorical_accuracy: 0.7321
Epoch 5/16
101/101 ━━━━━━━━━━━━━━━━━━━━ 8s 79ms/step - loss: 0.2388 - sparse_categorical_accuracy: 0.7657
Epoch 6/16
101/101 ━━━━━━━━━━━━━━━━━━━━ 8s 79ms/step - loss: 0.1926 - sparse_categorical_accuracy: 0.8081
Epoch 7/16
101/101 ━━━━━━━━━━━━━━━━━━━━ 8s 79ms/step - loss: 0.1492 - sparse_categorical_accuracy: 0.8501
Epoch 8/16
101

<keras.src.callbacks.history.History at 0x7b45903dcc10>

In [11]:
%%time
template = "{pre}\n\nQuestion:\n{question}\n\nAnswer:\n{answer}"

prompt = template.format(pre='''You are an AI assistant that can answer questions about ESSA.'''
                            ''' ESSA stands for the Every Student Succeeds Act.''',
                         question='What is ESSA?',
                         answer='')
completion = gemma_lm.generate(prompt, max_length=1024)
response = completion.replace(prompt, "")
display_chat(prompt, response)

CPU times: user 2min, sys: 1.11 s, total: 2min 1s
Wall time: 50.4 s


<font size='+1' color='brown'>🙋‍♂️<blockquote>You are an AI assistant that can answer questions about ESSA. ESSA stands for the Every Student Succeeds Act.<br><br>Question:<br>What is ESSA?<br><br>Answer:<br></blockquote></font><font size='+1' color='teal'>🤖<blockquote>ESSA is a law passed in 2015 that allows the U.S. Government to fund educational programs with taxpayer money.</blockquote></font>

# Evaluation
* The response is possibly a bit too concise
* It seems we've nearly saturated on loss and accuracy, so let's stop here
* At this point, we should try some harder prompts, looking for issues with factual accuracy
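The sketch below loops over every dataset question instead of probing one at a time. It keeps the reference answer next to each response, for review or for scoring with a metric of your choice. `run_eval` and its `generate_fn` argument are hypothetical. The stand-in generator just appends "No." so the loop runs without a model. In this notebook, `generate_fn` could wrap `gemma_lm.generate(p, max_length=1024)`, and the pairs could come from `zip(df["Question"], df["Answer"])`. All these questions were in the fine-tuning data, so this checks recall of the training answers rather than generalization.

```python
def run_eval(generate_fn, qa_pairs, pre,
             template="{pre}\n\nQuestion:\n{question}\n\nAnswer:\n{answer}"):
    """Prompt the model with each question and keep its response next to the reference answer.

    `generate_fn` maps a prompt to a completion that begins with the prompt,
    as gemma_lm.generate's output does in this notebook.
    """
    results = []
    for question, reference in qa_pairs:
        prompt = template.format(pre=pre, question=question, answer="")
        completion = generate_fn(prompt)
        response = completion.replace(prompt, "", 1).strip()
        results.append({"question": question, "reference": reference, "response": response})
    return results

def fake_generate(prompt):
    # Stand-in for a real model call, for illustration only.
    return prompt + "No."

pairs = [("Does my state still have to test 95 percent of its students?",
          "95 percent must take the annual assessments.")]
results = run_eval(fake_generate, pairs, pre="You are an AI assistant that can answer questions about ESSA.")
for row in results:
    print(row["question"], "|", row["response"], "|", row["reference"])
```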

In [None]:
%%time
template = "{pre}\n\nQuestion:\n{question}\n\nAnswer:\n{answer}"

prompt = template.format(pre='''You are an AI assistant that can answer questions about ESSA.'''
                            ''' ESSA stands for the Every Student Succeeds Act.''',
                         question='Does my state still have to test 95 percent of its students?',
                         answer='')
completion = gemma_lm.generate(prompt, max_length=1024)
response = completion.replace(prompt, "")
display_chat(prompt, response)

CPU times: user 1.02 s, sys: 1.17 ms, total: 1.02 s
Wall time: 1.01 s


<font size='+1' color='brown'>🙋‍♂️<blockquote>You are an AI assistant that can answer questions about ESSA. ESSA stands for the Every Student Succeeds Act.<br><br>Question:<br>Does my state still have to test 95 percent of its students?<br><br>Answer:<br></blockquote></font><font size='+1' color='teal'>🤖<blockquote>No. ESSA eliminated the “unmet needs” provision that forced states to test all students, even those in small, rural schools with limited resources. ESSA also eliminated the “unmet needs” provision that forced states to test 95 percent of students in order to be in compliance with the law.</blockquote></font>

# Evaluation
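* Well, it's a nicely formatted response, but it's wrong according to the dataset, even though this exact question is in the fine-tuning data!
* Let's "fake" RAG by pasting the matching fact from the dataset into the prompt as context, and see whether it then gives the right response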
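In a full RAG setup, the context would be retrieved automatically instead of pasted by hand. The sketch below is a minimal stand-in for that step. It is an assumption: retrieval is more commonly done with embedding similarity. It picks the dataset answer whose question shares the most words with the user's question, then fills the same `Context:` template used in the next cell. The two Q&A pairs are abbreviated excerpts from the dataset, and the helper names are hypothetical.

```python
import re

def _word_set(text):
    # Lowercased alphanumeric words, used for a crude overlap score.
    return set(re.findall(r"[a-z0-9]+", text.lower()))

def retrieve_context(question, qa_pairs, k=1):
    """Return the answers of the k dataset questions sharing the most words with `question`."""
    query = _word_set(question)
    ranked = sorted(qa_pairs, key=lambda pair: len(query & _word_set(pair[0])), reverse=True)
    return "\n\n".join(answer for _, answer in ranked[:k])

def build_rag_prompt(question, context, pre):
    """Fill the notebook's Context/Question/Answer template."""
    template = "{pre}\n\nContext:\n{context}\n\nQuestion:\n{question}\n\nAnswer:\n{answer}"
    return template.format(pre=pre, context=context, question=question, answer="")

qa_pairs = [
    ("Does my state still have to test 95 percent of its students?",
     "For this reason, in order to measure the overall achievement of 95 percent of students, "
     "95 percent must take the annual assessments."),
    ("Are states required to submit their standards for approval by the U.S. Department of Education?",
     "No. There is clear language in the bill that no state shall be required to submit its standards "
     "to the federal government for review or approval."),
]
question = "Does my state have to test 95 percent of students?"
context = retrieve_context(question, qa_pairs)
print(build_rag_prompt(question, context, pre="Use the following context to answer the question below."))
```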

In [None]:
%%time
template = "{pre}\n\nContext:\n{context}\n\nQuestion:\n{question}\n\nAnswer:\n{answer}"

prompt = template.format(pre='''You are an AI assistant that can answer questions about ESSA.'''
                            ''' ESSA stands for the Every Student Succeeds Act.'''
                            ''' Use the following context to answer the question below.''',
                         context='''ESSA requires that a state’s accountability system must measure the performance of '''
                                 '''95 percent of students by looking at a variety of indicators. One of the indicators '''
                                 '''is “academic achievement as measured by proficiency on the annual assessments.” '''
                                 '''For this reason, in order to measure the overall achievement of 95 percent of students, '''
                                 '''95 percent must take the annual assessments.''',
                         question='Does my state still have to test 95 percent of its students?',
                         answer='')
completion = gemma_lm.generate(prompt, max_length=1024)
response = completion.replace(prompt, "")
display_chat(prompt, response)

CPU times: user 118 ms, sys: 0 ns, total: 118 ms
Wall time: 117 ms


<font size='+1' color='brown'>🙋‍♂️<blockquote>You are an AI assistant that can answer questions about ESSA. ESSA stands for the Every Student Succeeds Act. Use the following context to answer the question below.<br><br>Context:<br>ESSA requires that a state’s accountability system must measure the performance of 95 percent of students by looking at a variety of indicators. One of the indicators is “academic achievement as measured by proficiency on the annual assessments.” For this reason, in order to measure the overall achievement of 95 percent of students, 95 percent must take the annual assessments.<br><br>Question:<br>Does my state still have to test 95 percent of its students?<br><br>Answer:<br></blockquote></font><font size='+1' color='teal'>🤖<blockquote>Yes</blockquote></font>

# Evaluation
* Cool, it got the answer right by "understanding" the provided context!!
* And it's a very concise response
* As a sanity check, let's try the exact same "with-fact-as-context" query using the Gemma 2 instruction-tuned model (i.e., not fine-tuned on our dataset)

In [None]:
%%time
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_instruct_2b_en")


CPU times: user 10.8 s, sys: 9.71 s, total: 20.5 s
Wall time: 51 s


In [None]:
%%time
template = "{pre}\n\nContext:\n{context}\n\nQuestion:\n{question}\n\nAnswer:\n{answer}"

prompt = template.format(pre='''You are an AI assistant that can answer questions about ESSA.'''
                            ''' ESSA stands for the Every Student Succeeds Act.'''
                            ''' Use the following context to answer the question below.''',
                         context='''ESSA requires that a state’s accountability system must measure the performance of '''
                                 '''95 percent of students by looking at a variety of indicators. One of the indicators '''
                                 '''is “academic achievement as measured by proficiency on the annual assessments.” '''
                                 '''For this reason, in order to measure the overall achievement of 95 percent of students, '''
                                 '''95 percent must take the annual assessments.''',
                         question='Does my state still have to test 95 percent of its students?',
                         answer='')
completion = gemma_lm.generate(prompt, max_length=1024)
response = completion.replace(prompt, "")
display_chat(prompt, response)

CPU times: user 2min 2s, sys: 1.65 s, total: 2min 3s
Wall time: 52.3 s


<font size='+1' color='brown'>🙋‍♂️<blockquote>You are an AI assistant that can answer questions about ESSA. ESSA stands for the Every Student Succeeds Act. Use the following context to answer the question below.<br><br>Context:<br>ESSA requires that a state’s accountability system must measure the performance of 95 percent of students by looking at a variety of indicators. One of the indicators is “academic achievement as measured by proficiency on the annual assessments.” For this reason, in order to measure the overall achievement of 95 percent of students, 95 percent must take the annual assessments.<br><br>Question:<br>Does my state still have to test 95 percent of its students?<br><br>Answer:<br></blockquote></font><font size='+1' color='teal'>🤖<blockquote>Yes, ESSA requires that a state’s accountability system must measure the performance of 95 percent of students by looking at a variety of indicators. One of the indicators is “academic achievement as measured by proficiency on the annual assessments.” For this reason, in order to measure the overall achievement of 95 percent of students, 95 percent must take the annual assessments. <br><br><br>**Explanation:**<br><br>The answer is yes because the context explicitly states that ESSA requires states to test 95% of students. <br><end_of_turn></blockquote></font>

# Evaluation
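* interesting that the Gemma 2 instruct model (not fine-tuned on the dataset) also gets the answer right
* that said, it adds extra detail and even an explanation (which may not be desired for the chatbot)
* note that this query used the plain `Question:`/`Answer:` template rather than the `<start_of_turn>` chat template used for the earlier instruct-model prompt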