
# Starter Notebook: Generating More Data With Gemma
Our ultimate goal in this competition is to take an original sample of text and a new version of that text rewritten by Gemma, and to figure out what prompt was used to get the new version. A helpful first step is to be able to generate a bunch of examples of what that looks like, so we can then learn the relationships between the original text, rewrite prompt and rewritten text.

To generate examples, we'll need a few things:
1. A corpus of original texts
2. A set of rewrite prompts
3. Our model (Gemma!), which takes an original text and a rewrite prompt and generates a rewritten text

Let's tackle them one by one.

## Generating `original_text`
While we don't know much about the original texts used in the competition test set,
the Meta Kaggle dataset provides a corpus of Kaggle forum messages that we can
use as a simple example.


In [1]:
# Store the Kaggle API credentials (saved as Colab secrets) where the Kaggle CLI expects them
import json
from google.colab import userdata

!mkdir -p ~/.kaggle

api_token = {"username": userdata.get('KAGGLE_USERNAME'), "key": userdata.get('KAGGLE_KEY')}

with open('/root/.kaggle/kaggle.json', 'w') as file:
    json.dump(api_token, file)

# The Kaggle CLI warns if the credentials file is readable by other users
!chmod 600 ~/.kaggle/kaggle.json
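If you'd rather not rely on shell magics (for example, when running outside Colab), the same setup can be done in plain Python. This is a minimal sketch, assuming the Kaggle CLI's default `~/.kaggle/kaggle.json` location; `write_kaggle_credentials` is a hypothetical helper, not part of the Kaggle package.

```python
import json
import os
from pathlib import Path
from typing import Optional


def write_kaggle_credentials(username: str, key: str, path: Optional[Path] = None) -> Path:
    """Write a kaggle.json file that only the current user can read or write."""
    if path is None:
        # Default location the Kaggle CLI reads credentials from.
        path = Path.home() / ".kaggle" / "kaggle.json"
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w") as f:
        json.dump({"username": username, "key": key}, f)
    # Equivalent of `chmod 600`: owner read/write only.
    os.chmod(path, 0o600)
    return path
```

In Colab this would be called as `write_kaggle_credentials(userdata.get('KAGGLE_USERNAME'), userdata.get('KAGGLE_KEY'))`.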


In [8]:
!kaggle competitions download -c llm-prompt-recovery

Downloading llm-prompt-recovery.zip to /content
100% 1.45k/1.45k [00:00<00:00, 2.71MB/s]


In [9]:
!unzip llm-prompt-recovery.zip

Archive:  llm-prompt-recovery.zip
  inflating: sample_submission.csv   
  inflating: test.csv                
  inflating: train.csv               


In [10]:
!kaggle datasets download -d kaggle/meta-kaggle

Downloading meta-kaggle.zip to /content
100% 7.04G/7.04G [04:32<00:00, 31.9MB/s]


In [11]:
!unzip meta-kaggle.zip ForumMessages.csv

Archive:  meta-kaggle.zip
  inflating: ForumMessages.csv       


In [12]:
import pandas as pd

forum_messsages_df = pd.read_csv('ForumMessages.csv')
forum_messsages_df.head()

| | Id | ForumTopicId | PostUserId | PostDate | ReplyToForumMessageId | Message | Medal | MedalAwardDate |
|---|---|---|---|---|---|---|---|---|
| 0 | 667014 | 113221 | 2358604 | 11/06/2019 18:05:48 | | `<p>Looks really helpful ... </p>` | 3.0 | 11/13/2019 |
| 1 | 667013 | 116036 | 1788308 | 11/06/2019 18:05:43 | | `<p>Might someone downloaded train images 180+ ...` | 2.0 | 11/12/2019 |
| 2 | 667012 | 116035 | 2532029 | 11/06/2019 18:05:28 | | `<p>Nice Article, Arjit!\nJust a small point th...` | | |
| 3 | 667011 | 116032 | 413189 | 11/06/2019 18:02:30 | 666992.0 | `<p>Nope it was actually taking lot of space. S...` | | |
| 4 | 667009 | 116025 | 1939378 | 11/06/2019 18:01:09 | 666971.0 | `<p>But it's fun xd. I saw the .000 before find...` | | |


In [13]:
# Let's grab the first 5 messages to test our generation pipeline:

original_texts = forum_messsages_df['Message'][:5]
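The `Message` column holds raw HTML (note the `<p>` tags above), and that markup is passed straight into the prompt sent to Gemma. If you'd prefer plain text, here is a minimal sketch using only the standard library's `html.parser`. `clean_message` is a hypothetical helper; `max_words` is a crude word-count stand-in for real token-based truncation, and the non-string check assumes some `Message` values may be missing (NaN).

```python
import re
from html.parser import HTMLParser
from typing import Optional


class _TextExtractor(HTMLParser):
    """Collects the text content of an HTML fragment, dropping all tags."""

    def __init__(self) -> None:
        # convert_charrefs=True turns entities like &amp; into plain characters.
        super().__init__(convert_charrefs=True)
        self.parts: list[str] = []

    def handle_data(self, data: str) -> None:
        self.parts.append(data)


def clean_message(raw, max_words: Optional[int] = None) -> str:
    """Strip HTML tags, collapse whitespace, and optionally keep only the first max_words words."""
    if not isinstance(raw, str):
        # Missing values (e.g. NaN from pandas) become empty strings.
        return ""
    parser = _TextExtractor()
    parser.feed(raw)
    parser.close()
    text = re.sub(r"\s+", " ", "".join(parser.parts)).strip()
    if max_words is not None:
        text = " ".join(text.split()[:max_words])
    return text
```

Applied to this notebook, that would look like `original_texts = forum_messsages_df['Message'][:5].map(clean_message)`.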

## Generating `rewrite_prompt`
While there are lots of ways to come up with rewrite prompts, for simplicity here are a few random prompts we can use.

In [14]:
rewrite_prompts = [
    'Explain this to me like I\'m five.',
    'Convert this into a sea shanty.',
    'Make this rhyme.',
]
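Three hand-written prompts won't give a diverse dataset. One cheap way to grow the list is to combine a few building blocks into templated prompts. This is a minimal sketch; `combine_prompts`, the sentence template, and the example styles/audiences are all hypothetical choices, not prompts known to be used in the competition.

```python
import itertools
import random


def combine_prompts(styles, audiences, seed=0, k=None):
    """Build one rewrite prompt per (style, audience) pair, optionally sampling k of them."""
    prompts = [
        f"Rewrite this {style} for {audience}."
        for style, audience in itertools.product(styles, audiences)
    ]
    if k is not None:
        # A seeded generator keeps the sample reproducible.
        prompts = random.Random(seed).sample(prompts, k)
    return prompts
```

For example, `rewrite_prompts += combine_prompts(["as a poem", "formally"], ["a child", "a lawyer"])` adds four new prompts.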

## Generating `rewritten_text` with Gemma
Now for the fun part! We can use Gemma to rewrite our original text samples
using the rewrite prompts we created.
[The model card](https://www.kaggle.com/models/google/gemma/frameworks/pyTorch/variations/7b-it-quant) for the quantized 7B PyTorch variant is a useful reference; the cells below load Gemma through KerasNLP instead.
The important things to know:

We're using the `gemma_instruct_2b_en` preset, the 2B parameter instruction-tuned model, which means:

- 2B parameters: this is the smaller of the two Gemma models (the other has 7 billion parameters).
    In general we expect the larger model to perform better on complex tasks, but
    it's more resource intensive. You can see exactly how Gemma 7B compares to Gemma 2B [here](https://ai.google.dev/gemma).
- Instruction tuned: instruction tuning is an extra training step that results in a model that
    can follow user instructions better. Our rewrite prompt is a kind of instruction, so this is what we want!
- Not quantized: quantization shrinks a model by reducing the precision of each
    parameter, so a quantized 7B model still has 7 billion parameters but is easier to run on limited
    hardware. The 2B preset is small enough that we load it without quantization.
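To put the size and precision trade-offs above in numbers: the weights alone take roughly (number of parameters × bits per parameter / 8) bytes. This back-of-the-envelope sketch uses the nominal 2B/7B counts (the real parameter counts are somewhat higher) and ignores activations, the KV cache, and optimizer state, so treat the results as lower bounds.

```python
def weight_memory_gb(n_params: float, bits_per_param: int) -> float:
    """Approximate memory for the model weights alone, in decimal gigabytes."""
    return n_params * bits_per_param / 8 / 1e9


# Compare full precision, half precision, and 8-bit quantization.
for name, n_params in [("Gemma 2B", 2e9), ("Gemma 7B", 7e9)]:
    for bits in (32, 16, 8):
        print(f"{name} @ {bits}-bit: ~{weight_memory_gb(n_params, bits):.0f} GB")
```

Halving the bits per parameter halves the weight memory, which is why quantization makes the 7B model practical on smaller GPUs.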

At the end of this step, we'll have a `gemma_lm` model we can call `generate` on with a specially formatted prompt.

In [2]:
!pip install -U keras-nlp
!pip install -U keras


Collecting keras-nlp
  Downloading keras_nlp-0.8.2-py3-none-any.whl (465 kB)
Collecting keras-core (from keras-nlp)
  Downloading keras_core-0.1.7-py3-none-any.whl (950 kB)
Collecting tensorflow-text (from keras-nlp)
  Downloading tensorflow_text-2.16.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.2 MB)
Collecting namex (from keras-core->keras-nlp)
  Downloading namex-0.0.7-py3-none-any.whl (5.8 kB)
Collecting tensorflow<2.17,>=2.16.1 (from tensorflow-text->keras-nlp)
  Downloading tensorflow-2.16.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (589.8 MB)

In [3]:
import keras
import keras_nlp

In [6]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_instruct_2b_en")

Attaching 'config.json' from model 'keras/gemma/keras/gemma_instruct_2b_en/2' to your Colab notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_instruct_2b_en/2' to your Colab notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_instruct_2b_en/2' to your Colab notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_instruct_2b_en/2' to your Colab notebook...


In [84]:
# Now we can loop through our input texts, randomly select a rewrite prompt, and see Gemma in action:

import random
random.seed(0)
# This is the prompt format the model expects
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"

rewrite_data = []

for original_text in original_texts:
    rewrite_prompt = random.choice(rewrite_prompts)
    prompt = f'{rewrite_prompt}\n{original_text}'
    rewritten_text = gemma_lm.generate(
        USER_CHAT_TEMPLATE.format(prompt=prompt),
        # max_length counts the prompt tokens too, so long inputs leave little room for the rewrite
        max_length=100,
    )
    rewrite_data.append({
        'original_text': original_text,
        'rewrite_prompt': rewrite_prompt,
        'rewritten_text': rewritten_text,
    })


NotFoundError: Exception encountered when calling GemmaTokenizer.call().

{{function_node __wrapped__SentencepieceTokenizeOp_device_/job:localhost/replica:0/task:0/device:CPU:0}} Resource localhost/_0_SentencepieceOp/N10tensorflow4text12_GLOBAL__N_121SentencepieceResourceE does not exist. [Op:SentencepieceTokenizeOp]

Arguments received by GemmaTokenizer.call():
  • inputs=tf.Tensor(shape=(1,), dtype=string)
  • args=<class 'inspect._empty'>
  • training=None
  • kwargs=<class 'inspect._empty'>

In [18]:
rewrite_data

[{'original_text': '<p>Looks really helpful ... </p>',
  'rewrite_prompt': 'Convert this into a sea shanty.',
  'rewritten_text': "<start_of_turn>user\nConvert this into a sea shanty.\n<p>Looks really helpful ... </p><end_of_turn>\n<start_of_turn>model\n(Verse 1)\nHeave ho, me lads, and raise a glass,\nTo the help that's come our way.\nThe sails are set, the wind is blowin',\nWe'll sail the seas and see where we go.\n\n(Chorus)\nOh, the heavens be with us,\nThe storm's"},
 {'original_text': '<p>Might someone downloaded train images 180+ Gb. \nPlease, share if yes.</p>\n\n<p>We can have test images from <a href="/hengck23">@hengck23</a> by this link\n<a href="https://drive.google.com/open?id=16P7OizMu8i7PzRZhNkzk_I4HgwtGOrYM">https://drive.google.com/open?id=16P7OizMu8i7PzRZhNkzk_I4HgwtGOrYM</a>\nas posted in his topic\n<a href="https://www.kaggle.com/c/rsna-intracranial-hemorrhage-detection/discussion/115855">https://www.kaggle.com/c/rsna-intracranial-hemorrhage-detection/discussion/11
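Notice that every `rewritten_text` above starts with the full chat prompt, including the rewrite prompt itself: `generate` returns the prompt followed by the completion. Left in place, this leaks the answer into the training examples built later in this notebook. Below is a minimal sketch for keeping only the model's reply, assuming the output begins with the exact formatted prompt and falling back to splitting on the `<start_of_turn>model` marker otherwise. `strip_prompt` is a hypothetical helper, and the `max_length=512` in the usage comment is an assumed value.

```python
# Same chat template as used in the generation loop above.
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"


def strip_prompt(generated: str, formatted_prompt: str) -> str:
    """Return only the model's reply from a generate() output that echoes the prompt."""
    if generated.startswith(formatted_prompt):
        reply = generated[len(formatted_prompt):]
    else:
        # Fallback if the echoed prompt isn't byte-identical: keep text after the last model-turn marker.
        reply = generated.rsplit("<start_of_turn>model\n", 1)[-1]
    # Drop a trailing end-of-turn marker if the model emitted one.
    return reply.strip().removesuffix("<end_of_turn>").strip()


# Usage inside the generation loop (gemma_lm as loaded above):
#   formatted = USER_CHAT_TEMPLATE.format(prompt=prompt)
#   rewritten_text = strip_prompt(gemma_lm.generate(formatted, max_length=512), formatted)
```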

In [4]:
!pip install datasets

Collecting datasets
  Downloading datasets-2.18.0-py3-none-any.whl (510 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl (134 kB)
Installing collected packages: xxhash, dill, multiprocess, datasets
Successfully installed dataset

In [21]:
df = pd.DataFrame(rewrite_data)

In [22]:
df

| | original_text | rewrite_prompt | rewritten_text |
|---|---|---|---|
| 0 | `<p>Looks really helpful ... </p>` | Convert this into a sea shanty. | `<start_of_turn>user\nConvert this into a sea s...` |
| 1 | `<p>Might someone downloaded train images 180+ ...` | Convert this into a sea shanty. | `<start_of_turn>user\nConvert this into a sea s...` |
| 2 | `<p>Nice Article, Arjit!\nJust a small point th...` | Explain this to me like I'm five. | `<start_of_turn>user\nExplain this to me like I...` |
| 3 | `<p>Nope it was actually taking lot of space. S...` | Convert this into a sea shanty. | `<start_of_turn>user\nConvert this into a sea s...` |
| 4 | `<p>But it's fun xd. I saw the .000 before find...` | Make this rhyme. | `<start_of_turn>user\nMake this rhyme.\n<p>But ...` |


In [23]:
from datasets import Dataset


# Convert the pandas DataFrame into a Hugging Face Dataset with Dataset.from_pandas()
dataset = Dataset.from_pandas(df)

dataset


Dataset({
    features: ['original_text', 'rewrite_prompt', 'rewritten_text'],
    num_rows: 5
})

In [24]:
# Requires a Hugging Face token with write access (e.g. log in first with huggingface_hub.login())
dataset.push_to_hub('zhaizy/test')


CommitInfo(commit_url='https://huggingface.co/datasets/zhaizy/test/commit/ba183e44971c70bed7886dabc9e057f192da84b3', commit_message='Upload dataset', commit_description='', oid='ba183e44971c70bed7886dabc9e057f192da84b3', pr_url=None, pr_revision=None, pr_num=None)

## Load the dataset and start training

In [7]:
from datasets import load_dataset


# Load the dataset back from the Hugging Face Hub
dataset = load_dataset('zhaizy/test')

dataset


DatasetDict({
    train: Dataset({
        features: ['original_text', 'rewrite_prompt', 'rewritten_text'],
        num_rows: 5
    })
})

In [5]:
class CFG:
    seed = 42
    dataset_path = "/kaggle/input/llm-prompt-recovery"
    preset = "gemma_instruct_2b_en" # name of pretrained Gemma
    sequence_length = 512 # max size of input sequence for training
    batch_size = 1 # size of the input batch in training
    epochs = 1 # number of epochs to train
keras.utils.set_random_seed(CFG.seed)

In [8]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()
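LoRA keeps each adapted weight matrix W frozen and learns a low-rank update B·A instead, where A is (rank × d_in) and B is (d_out × rank). For one matrix that means training rank × (d_in + d_out) parameters instead of d_in × d_out, which is why `enable_lora(rank=4)` makes fine-tuning fit on modest hardware. The 2048 × 2048 shape below is purely illustrative, not a claim about which Gemma layers KerasNLP adapts or their exact sizes.

```python
def lora_param_counts(d_in: int, d_out: int, rank: int) -> tuple[int, int]:
    """Return (full, lora) trainable parameter counts for one d_in x d_out weight matrix."""
    full = d_in * d_out            # fine-tuning W directly
    lora = rank * (d_in + d_out)   # A: rank x d_in, plus B: d_out x rank
    return full, lora


full, lora = lora_param_counts(2048, 2048, rank=4)
print(f"full: {full:,}  LoRA r=4: {lora:,}  ({lora / full:.2%} of full)")
```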

In [9]:
template = """Instruction:\nBelow, the `Original Text` passage has been rewritten/transformed/improved into `Rewritten Text` by the `Gemma 7b-it` LLM with a certain prompt/instruction. Your task is to carefully analyze the differences between the `Original Text` and `Rewritten Text`, and try to infer the specific prompt or instruction that was likely given to the LLM to rewrite/transform/improve the text in this way.\n\nOriginal Text:\n{original_text}\n\nRewriten Text:\n{rewritten_text}\n\nResponse:\n{rewrite_prompt}"""

In [11]:
df = dataset['train'].to_pandas()

In [12]:
df["prompt"] = df.apply(
    lambda row: template.format(
        original_text=row.original_text,
        rewritten_text=row.rewritten_text,
        rewrite_prompt=row.rewrite_prompt,
    ),
    axis=1,
)
data = df.prompt.tolist()

In [13]:
def colorize_text(text):
    for word, color in zip(["Instruction", "Original Text", "Rewriten Text", "Response"],
                           ["red", "yellow", "blue", "green"]):
        text = text.replace(f"{word}:", f"\n\n**<font color='{color}'>{word}:</font>**")
    return text

In [14]:
from IPython.display import display, Markdown
# Take one sample (the third row)
sample = data[2]

# Color the Instruction, Original Text, Rewritten Text and Response headings
sample = colorize_text(sample)

# Show sample in markdown
display(Markdown(sample))



**<font color='red'>Instruction:</font>**
Below, the `Original Text` passage has been rewritten/transformed/improved into `Rewritten Text` by the `Gemma 7b-it` LLM with a certain prompt/instruction. Your task is to carefully analyze the differences between the `Original Text` and `Rewritten Text`, and try to infer the specific prompt or instruction that was likely given to the LLM to rewrite/transform/improve the text in this way.



**<font color='yellow'>Original Text:</font>**
<p>Nice Article, Arjit!
Just a small point though, lowering the exploration decay rate like .001 would have been provided better results. Not sure why, but was observing inconsistent results when I've chosen higher decay rates.</p>



**<font color='blue'>Rewriten Text:</font>**
<start_of_turn>user
Explain this to me like I'm five.
<p>Nice Article, Arjit!
Just a small point though, lowering the exploration decay rate like .001 would have been provided better results. Not sure why, but was observing inconsistent results when I've chosen higher decay rates.</p><end_of_turn>
<start_of_turn>model
Sure, here's a simplified explanation:

Imagine you're playing a game and you have a special power that helps you explore



**<font color='green'>Response:</font>**
Explain this to me like I'm five.

In [15]:
# Limit the input sequence length to 512 (to control memory usage).
gemma_lm.preprocessor.sequence_length = CFG.sequence_length

# Compile the model with loss, optimizer, and metric
gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(learning_rate=3e-5),
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train model
gemma_lm.fit(data, epochs=CFG.epochs, batch_size=CFG.batch_size)

[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m47s[0m 82ms/step - loss: 1.9342 - sparse_categorical_accuracy: 0.5715


<keras.src.callbacks.history.History at 0x7f0815513490>
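With only five examples, the training loss says little about whether the model can recover prompts it hasn't seen, and the sample used for inference below (`df.iloc[2]`) was also part of `data`. Once the dataset is larger, a simple precaution is to hold some prompts out of `fit` and run inference on those instead. This is a minimal sketch; `train_val_split` is a hypothetical helper, and passing `CFG.seed` as the seed is an assumption.

```python
import random


def train_val_split(examples, val_fraction=0.2, seed=42):
    """Shuffle a copy of `examples` and split it into (train, val) lists."""
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)
    # Keep at least one validation example when there is more than one example.
    n_val = max(1, int(len(shuffled) * val_fraction)) if len(shuffled) > 1 else 0
    return shuffled[n_val:], shuffled[:n_val]
```

For example, `train_data, val_data = train_val_split(data, seed=CFG.seed)` followed by `gemma_lm.fit(train_data, ...)`.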

In [16]:
# Take one sample
row = df.iloc[2]

# Generate Prompt using template
prompt = template.format(
    original_text=row.original_text,
    rewritten_text=row.rewritten_text,
    rewrite_prompt="",
)
print(prompt)
# Infer
output = gemma_lm.generate(prompt, max_length=512)

# Colorize
output = colorize_text(output)

# Display in markdown
display(Markdown(output))


Instruction:
Below, the `Original Text` passage has been rewritten/transformed/improved into `Rewritten Text` by the `Gemma 7b-it` LLM with a certain prompt/instruction. Your task is to carefully analyze the differences between the `Original Text` and `Rewritten Text`, and try to infer the specific prompt or instruction that was likely given to the LLM to rewrite/transform/improve the text in this way.

Original Text:
<p>Nice Article, Arjit!
Just a small point though, lowering the exploration decay rate like .001 would have been provided better results. Not sure why, but was observing inconsistent results when I've chosen higher decay rates.</p>

Rewriten Text:
<start_of_turn>user
Explain this to me like I'm five.
<p>Nice Article, Arjit!
Just a small point though, lowering the exploration decay rate like .001 would have been provided better results. Not sure why, but was observing inconsistent results when I've chosen higher decay rates.</p><end_of_turn>
<start_of_turn>model
Sure, here



**<font color='red'>Instruction:</font>**
Below, the `Original Text` passage has been rewritten/transformed/improved into `Rewritten Text` by the `Gemma 7b-it` LLM with a certain prompt/instruction. Your task is to carefully analyze the differences between the `Original Text` and `Rewritten Text`, and try to infer the specific prompt or instruction that was likely given to the LLM to rewrite/transform/improve the text in this way.



**<font color='yellow'>Original Text:</font>**
<p>Nice Article, Arjit!
Just a small point though, lowering the exploration decay rate like .001 would have been provided better results. Not sure why, but was observing inconsistent results when I've chosen higher decay rates.</p>



**<font color='blue'>Rewriten Text:</font>**
<start_of_turn>user
Explain this to me like I'm five.
<p>Nice Article, Arjit!
Just a small point though, lowering the exploration decay rate like .001 would have been provided better results. Not sure why, but was observing inconsistent results when I've chosen higher decay rates.</p><end_of_turn>
<start_of_turn>model
Sure, here's a simplified explanation:

Imagine you're playing a game and you have a special power that helps you explore



**<font color='green'>Response:</font>**
Nice Article, Arjit!
Just a small point though, lowering the exploration decay rate like .001 would have been provided better results. Not sure why, but was observing inconsistent results when I've chosen higher decay rates.

The LLM is trying to help Arjit by explaining the concept of exploration decay rate and how it affects the game. It suggests that lowering the rate might be a better option than the default setting.
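Because `generate` returns the prompt followed by the completion, the colorized output repeats the whole template; the model's actual guess is only what follows the final `Response:` heading. A minimal sketch for pulling that out, assuming the template above ends with `Response:\n`; `extract_response` is a hypothetical helper.

```python
def extract_response(generated: str, marker: str = "Response:\n") -> str:
    """Return the text after the last `marker`, i.e. the model's guess at the rewrite prompt."""
    # rpartition returns ("", "", generated) when the marker is missing,
    # so the whole string is returned in that case.
    return generated.rpartition(marker)[2].strip()
```

Applied to the cell above, `extract_response(gemma_lm.generate(prompt, max_length=512))` would give just the predicted prompt.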

# Next Steps

Huzzah! We have a dataset with original texts, rewrite prompts, and rewritten text. Here are a couple of suggestions of next steps you could take to generate a larger, more diverse dataset:
1. Add more original text data sources; besides using all of the forum messages (instead of just the first 5), Kaggle has tons of datasets that would make reasonable input text. Here are a few random datasets you could use:
    - The `Plot` column from the [Wikipedia Movie Plots dataset](https://www.kaggle.com/datasets/jrobischon/wikipedia-movie-plots).
    - The `text` column from the [Emotions dataset](https://www.kaggle.com/datasets/nelgiriyewithana/emotions).
    - The `body_text` and `abstract` columns of the [Wikibooks Dataset](https://www.kaggle.com/datasets/dhruvildave/wikibooks-dataset).
    
    Note that each of these may need different preprocessing; for example, Gemma has a context length of 8192 tokens, so if the text is long, you'll need to truncate it.
2. Use Gemma to generate original text.
3. Expand the list of rewrite prompts. You can come up with them manually, or explore having Gemma write rewrite prompts.
4. Play around with the generation of `rewritten_text`:
   - How does changing `max_length` affect the length and quality of rewrites?
   - Do rewrites with the 2B parameter model differ substantially from the 7B model?
   - Can you use [few shot prompting](https://www.promptingguide.ai/techniques/fewshot) to get higher quality rewrites?
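Few-shot prompting can be expressed in Gemma's chat format by presenting worked examples as earlier user/model turns, each closed with `<end_of_turn>` as in `USER_CHAT_TEMPLATE`, and leaving the final model turn open. This is a minimal sketch; `build_few_shot_prompt` and the (prompt, text, rewrite) example triples are hypothetical, and the multi-turn layout assumes the same turn markers used earlier in this notebook.

```python
def build_few_shot_prompt(examples, rewrite_prompt, original_text):
    """Format (prompt, text, rewrite) examples as earlier chat turns, then the real request."""
    turns = []
    for ex_prompt, ex_text, ex_rewrite in examples:
        turns.append(f"<start_of_turn>user\n{ex_prompt}\n{ex_text}<end_of_turn>\n")
        turns.append(f"<start_of_turn>model\n{ex_rewrite}<end_of_turn>\n")
    # The final model turn is left open so Gemma writes the rewrite.
    turns.append(
        f"<start_of_turn>user\n{rewrite_prompt}\n{original_text}<end_of_turn>\n"
        "<start_of_turn>model\n"
    )
    return "".join(turns)
```

With an empty `examples` list this produces exactly the single-turn prompt used in the generation loop, so it can be swapped in without other changes.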