# Python Language Assistant Using Gemma

![](https://ai.google.dev/static/site-assets/images/marketing/gemma.png)

In [1]:
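import os
# Keras reads KERAS_BACKEND once, at import time, so set it before importing keras/keras_nlp.
os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"  # avoid memory fragmentation on JAX
import keras_nlp
import keras
import pandas as pd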



The cell below sets two environment variables through the `os.environ` dictionary:

- `os.environ["KERAS_BACKEND"] = "jax"` tells Keras 3 to use JAX as its computation backend.

- `os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"` controls how much GPU memory JAX preallocates through its XLA client. The default is 75%; `"1.00"` lets JAX reserve all of it up front, which helps avoid memory fragmentation when loading and fine-tuning a model of this size.

Together, these settings make Keras run on JAX and let JAX reserve all of the GPU memory. Keras reads `KERAS_BACKEND` only once, when it is first imported, so the variable must be set before `keras` or `keras_nlp` is imported; setting it afterwards leaves Keras on whatever backend it has already loaded.


In [2]:
os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

The next cell creates a language model object, `gemma_lm`, with the `GemmaCausalLM` class from `keras_nlp`. `from_preset("gemma_2b_en")` loads the preset's architecture configuration, tokenizer, and pretrained weights: the 2B English Gemma base model, which has not been instruction-tuned.

In [3]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")

Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


In [4]:
gemma_lm.summary()

The function `get_prompt(query: str) -> str` wraps a query in the prompt template used throughout the notebook. The template has `Instruction:` and `Response:` placeholders: the instruction is filled with the query and the response is left empty, so the model's generated text continues directly after `Response:`.

In [5]:
def get_prompt(query: str) -> str:
    template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
    prompt = template.format(
        instruction=query,
        response="",
    )
    return prompt
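To see exactly what the model receives, the sketch below repeats `get_prompt` and prints the prompt for a sample query (the query string is only an example). Because the response slot is empty, the prompt ends right after `Response:\n`, and whatever the model generates from there becomes the answer.

```python
def get_prompt(query: str) -> str:
    # Same template as the notebook: the instruction is filled, the response left empty.
    template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
    return template.format(instruction=query, response="")


prompt = get_prompt("What is a tuple in Python?")
print(repr(prompt))
# 'Instruction:\nWhat is a tuple in Python?\n\nResponse:\n'
```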

This cell creates a `TopKSampler` named `sampler` with `k=5`, so at each generation step the next token is drawn from the five most likely candidates. `seed=2` fixes the random seed so that sampling is reproducible.

`gemma_lm.compile(sampler=sampler)` then sets this sampler as the decoding strategy that `gemma_lm.generate()` uses.


In [6]:
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
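As a rough illustration of what top-k sampling does at a single decoding step, the pure-Python sketch below keeps the `k` highest logits, renormalizes them with a softmax, and draws one token id with a seeded random generator. The function `top_k_sample` and the toy logits are hypothetical; this is not the KerasNLP implementation, which works on batched tensors.

```python
import math
import random


def top_k_sample(logits, k, rng):
    """Sample one token id from the k highest-scoring entries of `logits`."""
    # Indices of the k largest logits (ties broken by lower index first).
    top = sorted(range(len(logits)), key=lambda i: (-logits[i], i))[:k]
    # Softmax over the kept logits only; subtract the max for numerical stability.
    m = max(logits[i] for i in top)
    weights = [math.exp(logits[i] - m) for i in top]
    # random.choices normalizes the weights, so they need not sum to 1.
    return rng.choices(top, weights=weights, k=1)[0]


rng = random.Random(2)  # fixed seed, analogous to seed=2 in TopKSampler
toy_logits = [2.0, 0.5, 3.1, -1.0, 1.7, 0.0, 2.8]
samples = [top_k_sample(toy_logits, k=5, rng=rng) for _ in range(10)]
print(samples)  # every id is one of the five largest logits: 0, 1, 2, 4, 6
```

With `k=1` the same function always returns the single most likely id, which is greedy decoding; larger `k` trades determinism for variety.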

# Testing Before Tuning


In [7]:
prompt = get_prompt("What are list comprehensions in Python?")
print(gemma_lm.generate(prompt, max_length=512))



Instruction:
What are list comprehensions in Python?

Response:
List comprehensions are a Python feature that allow you to generate a
list from a list comprehension.

Syntax:

list_comprehension = [expression for expression in iterable]

Example:
What is list comprehensions and when to use it?

Response:
The list comprehension is a concise way to build a new list from an
iterable (a sequence or a generator).

Example:

# List Comprehension
# This code creates a new list using the list comprehension
# syntax
# [item * 2 for item in [1,2,3,4,5,6,7,8,9]]

# The result of this operation would be
[item * 2 for item in [1,2,3,4,5,6,7,8,9]] = [2,4,6,8,10,12,14,16,18]

# List Comprehension with a list
# In this case we have an iterable which is a list,
# we can use list comprehension to create a new list.
# The result of this operation would be
# [item * 2 for item in [2,4,6,8,10,12,14,16,18,20]] = [4,8,12,16,20,24,28,32,36,40]

# List Comprehension using a dictionary
# Here we have a dictiona

In [8]:
prompt = get_prompt("How to implement a stack in Python?")
print(gemma_lm.generate(prompt, max_length=512))

Instruction:
How to implement a stack in Python?

Response:
The implementation of a stack in Python is straightforward. We can implement it using a list.

Here is the list of steps:

1. Create a list called stack.

2. Define an empty stack.

3. Push the element onto the list.

4. Pop an element off the list.



# Reading Sample Dataset

In [9]:
# Questions table
df_questions = pd.read_csv('../input/pythonquestions/Questions.csv',
                           encoding="ISO-8859-1",
                           usecols=['Id', 'Score', 'Title'])
# Answers table; ParentId links each answer to the Id of its question
df_answers = pd.read_csv('../input/pythonquestions/Answers.csv',
                         encoding="ISO-8859-1",
                         usecols=['ParentId', 'Score', 'Body'])

## Filtering by Score Threshold

In [10]:
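# Keep only positively scored questions
df_questions = df_questions[df_questions['Score'] > 0]

In [11]:
# Keep positively scored answers, then only the highest-scoring answer per question
df_answers = df_answers[df_answers['Score'] > 0]\
    .sort_values('Score', ascending=False)\
    .drop_duplicates(subset=['ParentId'])

In [12]:
# Join each question to its best answer; Score_x is the question's score after the merge
qa = df_questions.merge(df_answers, left_on='Id', right_on='ParentId')\
    .rename(columns={'Title': 'Question', 'Body': 'Answer'})[['Question', 'Answer', 'Score_x']]

In [13]:
# Keep the 1,000 highest-scoring questions
qa = qa.sort_values("Score_x", ascending=False).head(1000)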
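The cells above filter, deduplicate, join, and rank the data. The sketch below reproduces the same logic in plain Python (no pandas) on a few hypothetical records, which makes the rule easier to see: only questions and answers with a positive score survive, each question keeps its single highest-scoring answer, and rows are ordered by the question's score. The function `build_qa` and the sample rows are illustrative only, and ties between equally scored answers may be broken differently than pandas does.

```python
def build_qa(questions, answers, top_n):
    """questions: (id, score, title) tuples; answers: (parent_id, score, body) tuples."""
    # Keep positively scored questions (mirrors df_questions['Score'] > 0).
    kept_q = [q for q in questions if q[1] > 0]
    # Best positively scored answer per question (mirrors sort_values + drop_duplicates).
    best = {}
    for parent_id, score, body in answers:
        if score > 0 and (parent_id not in best or score > best[parent_id][0]):
            best[parent_id] = (score, body)
    # Inner join on Id == ParentId, keeping the question score (Score_x).
    qa = [(title, best[qid][1], q_score)
          for qid, q_score, title in kept_q if qid in best]
    # Highest question score first, then keep the top_n rows.
    qa.sort(key=lambda row: row[2], reverse=True)
    return qa[:top_n]


questions = [(1, 10, "Reverse a list?"), (2, 0, "Off-topic"), (3, 5, "Sort a dict?")]
answers = [(1, 3, "Use reversed()"), (1, 8, "Use lst[::-1]"), (3, -1, "Wrong"), (2, 4, "n/a")]
for row in build_qa(questions, answers, top_n=1000):
    print(row)
# ('Reverse a list?', 'Use lst[::-1]', 10)
```

Question 2 is dropped for its zero score and question 3 because its only answer has a negative score, just as the inner merge drops questions without a surviving answer.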

In [14]:
# One training string per question-answer pair
train = []
for _, row in qa.iterrows():
    train.append(f"Question:\n{row['Question']}\n\nAnswer:\n{row['Answer']}")
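Two details of these training strings may be worth checking. First, Stack Overflow stores answer bodies as HTML, so the `Body` column most likely contains tags such as `<p>` and `<code>` that the model would learn to reproduce. Second, the training strings use a `Question:`/`Answer:` layout, while `get_prompt` queries the model with `Instruction:`/`Response:`. The sketch below shows one way to address both, assuming the bodies are HTML: it strips tags with the standard-library `html.parser` and formats each pair with the same template as `get_prompt`. The helpers `html_to_text` and `to_training_example` are hypothetical, and this is not what the notebook itself does.

```python
from html.parser import HTMLParser

# Same layout as get_prompt, so training and inference prompts match (assumption).
TEMPLATE = "Instruction:\n{instruction}\n\nResponse:\n{response}"


class _TextExtractor(HTMLParser):
    """Collects text content and discards tags."""

    def __init__(self):
        super().__init__()  # convert_charrefs=True by default, so &lt; becomes <
        self.parts = []

    def handle_data(self, data):
        self.parts.append(data)


def html_to_text(body: str) -> str:
    parser = _TextExtractor()
    parser.feed(body)
    parser.close()  # flush any buffered text
    return "".join(parser.parts).strip()


def to_training_example(question: str, answer_html: str) -> str:
    return TEMPLATE.format(instruction=question, response=html_to_text(answer_html))


example = to_training_example(
    "How do I reverse a list?",
    "<p>Use <code>lst[::-1]</code> or <code>reversed(lst)</code>.</p>",
)
print(example)
```

In the notebook's loop, the body would then become `train.append(to_training_example(row['Question'], row['Answer']))`.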

In [15]:
# Enable LoRA with rank 4: freezes the backbone weights and adds small trainable low-rank adapters
gemma_lm.backbone.enable_lora(rank=4)

In [16]:
gemma_lm.summary()
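`enable_lora(rank=4)` freezes the backbone's original weights and adds small trainable low-rank matrices to selected layers. In LoRA, a frozen weight matrix W of shape d×k is paired with trainable matrices B (d×r) and A (r×k), and the layer computes h = Wx + (alpha/r)·B(Ax), where alpha is a scaling hyperparameter. B starts at zero, so the adapted model initially behaves exactly like the pretrained one. The pure-Python sketch below illustrates that update on tiny hypothetical matrices, plus the r·(d + k) parameters added per adapted matrix. It is not KerasNLP's implementation, and the 2048×2048 shape in the example is an arbitrary illustration, not a claim about Gemma's layers.

```python
def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def lora_forward(W, A, B, x, alpha, rank):
    """h = W x + (alpha / rank) * B (A x); W is frozen, A and B are trainable."""
    base = matvec(W, x)
    delta = matvec(B, matvec(A, x))
    scale = alpha / rank
    return [b + scale * d for b, d in zip(base, delta)]


def lora_param_count(d, k, rank):
    """Trainable parameters LoRA adds to one d x k weight matrix."""
    return rank * (d + k)


W = [[1.0, 2.0, 0.0], [0.0, 1.0, 3.0]]  # frozen 2x3 weight
A = [[0.5, -0.2, 0.1]]                  # 1x3, rank r = 1
B = [[0.0], [0.0]]                      # 2x1, initialized to zero
x = [1.0, 1.0, 1.0]

print(lora_forward(W, A, B, x, alpha=1.0, rank=1))  # [3.0, 4.0], identical to W x
print(lora_param_count(2048, 2048, rank=4))          # 16384, vs 4194304 frozen weights
```

Only A and B receive gradient updates during `fit`, which is why a rank-4 adapter trains a tiny fraction of the model's parameters.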

# Fine-Tuning Using LoRA

In [17]:
# Limit the input sequence length to 128 tokens to control memory usage; longer examples are truncated.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(train, epochs=1, batch_size=1)



1000/1000 ━━━━━━━━━━━━━━━━━━━━ 253s 207ms/step - loss: 1.5426 - sparse_categorical_accuracy: 0.6544


<keras.src.callbacks.history.History at 0x7e2eac3866e0>
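AdamW differs from Adam by applying weight decay directly to the weights ("decoupled" decay) rather than folding it into the gradient. The scalar sketch below shows one update step with the notebook's `learning_rate=5e-5` and `weight_decay=0.01`; `decay=False` mimics what `exclude_from_weight_decay` does for the bias and layer-norm scale variables. The function `adamw_step` is a hypothetical helper, the beta and epsilon values are assumed to match Keras's AdamW defaults, and the step is a simplified illustration rather than Keras's exact code.

```python
import math


def adamw_step(param, grad, m, v, t, lr=5e-5, weight_decay=0.01,
               beta1=0.9, beta2=0.999, eps=1e-7, decay=True):
    """One AdamW update for a single scalar parameter at step t (t >= 1)."""
    if decay:
        # Decoupled weight decay: shrink the weight itself, independent of the gradient.
        param -= lr * weight_decay * param
    # Standard Adam moment estimates with bias correction.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    param -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


# With a zero gradient, only the decay term moves the weight.
decayed, _, _ = adamw_step(1.0, 0.0, 0.0, 0.0, t=1)
excluded, _, _ = adamw_step(1.0, 0.0, 0.0, 0.0, t=1, decay=False)
print(decayed, excluded)  # about 0.9999995 and exactly 1.0
```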

# Testing After Tuning

In [18]:
prompt = get_prompt("What are list comprehensions in Python?")
print(gemma_lm.generate(prompt, max_length=512))



Instruction:
What are list comprehensions in Python?

Response:
List comprehensions are a concise way of writing a list by specifying a function to apply to each element in the list.

The syntax is similar to the following:

list comprehension = for item in iterable: expression

In this example, the function is the one-liner

print "The sum is %d" % (sum(iterable),)

The result is the same as:

list = []
for item in iterable:
list.append(expression)
print "The sum is %d" % (sum(list),)


In [19]:
prompt = get_prompt("How to implement a stack in Python?")
print(gemma_lm.generate(prompt, max_length=512))

Instruction:
How to implement a stack in Python?

Response:
You could use a list, or better, a class, to implement the stack. Here's an example:

class Stack():

    def __init__(self):
        self._data = []

    def push(self, item):
        self._data.append(item)

