#### Friday, March 1, 2024

This notebook was run on Kaggle and then downloaded here for reference.

[Starter Notebook: Generating More Data With Gemma](https://www.kaggle.com/code/wlifferth/starter-notebook-generating-more-data-with-gemma)

# Starter Notebook: Generating More Data With Gemma
Our ultimate goal in this competition is to take an original sample of text and a new version of that text rewritten by Gemma, and to figure out what prompt was used to get the new version. A helpful first step is to be able to generate a bunch of examples of what that looks like, so we can then learn the relationships between the original text, rewrite prompt and rewritten text.

To generate examples, we'll need a few things:
1. A corpus of original texts
2. A set of rewrite prompts
3. Our model (Gemma!) to use the original text and rewrite prompt to generate a rewritten text

Let's tackle them one by one.

## Generating `original_text`
While we don't know much about the original text used in the competition test set,
the Meta Kaggle dataset provides a corpus of Kaggle forum messages that we can
use as a simple example.


In [1]:
import pandas as pd

forum_messsages_df = pd.read_csv('/kaggle/input/meta-kaggle/ForumMessages.csv')
forum_messsages_df.head()

Unnamed: 0,Id,ForumTopicId,PostUserId,PostDate,ReplyToForumMessageId,Message,Medal,MedalAwardDate
0,653655,113090,1142262,10/20/2019 19:07:58,,<p>Awesome EDA! \nAll the insights make sense....,3.0,10/20/2019
1,653654,112341,2541293,10/20/2019 19:05:02,,"<p>HI,</p>\n\n<p>Just a query. Not sure if it ...",3.0,11/05/2019
2,653653,101554,2566546,10/20/2019 19:03:14,,<p>Thank you for sharing!!!👍 </p>,,
3,653652,113114,3802444,10/20/2019 18:57:34,653317.0,"<p>That makes sense! However, how might we det...",,
4,653651,108226,1192157,10/20/2019 18:57:06,622972.0,"<p>Hi, even I am facing the same issue for the...",,


In [2]:
# Let's grab the first 5 messages to test our generation pipeline:

original_texts = forum_messsages_df['Message'][:5]
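
The sample messages shown above contain HTML markup such as `<p>` tags, and some source texts can be long. As an optional preprocessing step, here is a minimal sketch that strips tags and caps the length by word count. The helper names `strip_html` and `truncate_words` are hypothetical (they are not part of the original notebook), and the word budget is only a rough stand-in for a real token count.

```python
from html.parser import HTMLParser


class _TextExtractor(HTMLParser):
    """Collects only the text nodes of an HTML fragment."""

    def __init__(self):
        super().__init__(convert_charrefs=True)
        self.chunks = []

    def handle_data(self, data):
        self.chunks.append(data)


def strip_html(message: str) -> str:
    """Return the message with tags removed and whitespace collapsed."""
    parser = _TextExtractor()
    parser.feed(message)
    parser.close()
    return " ".join("".join(parser.chunks).split())


def truncate_words(text: str, max_words: int) -> str:
    """Keep at most max_words whitespace-separated words (a crude length cap)."""
    words = text.split()
    if len(words) <= max_words:
        return text
    return " ".join(words[:max_words])


print(strip_html("<p>HI,</p>\n\n<p>Just a query.</p>"))
print(truncate_words("one two three four five", 3))
```

If this cleaning is wanted, it could be applied with something like `original_texts.map(strip_html)`, assuming every message is a string (missing values would need handling first).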

## Generating `rewrite_prompt`
While there are lots of ways to come up with rewrite prompts, for simplicity here are a few random prompts we can use.

In [3]:
rewrite_prompts = [
    'Explain this to me like I\'m five.',
    'Convert this into a sea shanty.',
    'Make this rhyme.',
]
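
One cheap way to grow a list like the one above is to combine a few templates with a few target styles. The sketch below is only an illustration: the `templates` and `styles` values are made up for this example and are not from the competition.

```python
import itertools

# Hypothetical building blocks; swap in your own wording and styles.
templates = [
    'Rewrite this as {style}.',
    'Turn this into {style}.',
]
styles = [
    'a sea shanty',
    'a limerick',
    'a formal business email',
]

# Every template paired with every style gives len(templates) * len(styles) prompts.
generated_prompts = [
    template.format(style=style)
    for template, style in itertools.product(templates, styles)
]

for prompt in generated_prompts:
    print(prompt)
```

The resulting strings can simply be appended to `rewrite_prompts`.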

## Generating `rewritten_text` with Gemma
Now for the fun part! We can use Gemma to rewrite our original text samples
using the rewrite prompts we created.
The code in this cell is borrowed from [the model card](https://www.kaggle.com/models/google/gemma/frameworks/pyTorch/variations/7b-it-quant).
The important things to know:

We're using the 7B parameter instruction tuned quantized model, which means:

- 7B Parameter: this is the larger of the two Gemma models (the other has 2 billion parameters).
    In general we expect the larger model to perform better on complex tasks, but
    it's more resource intensive. You can see exactly how Gemma 7B compares to Gemma 2B [here](https://ai.google.dev/gemma).
- Instruction Tuned: instruction tuning is an extra training step that results in a model that
    can follow user instructions better. Our rewrite prompt is a kind of instruction, so this is what we want!
- Quantized: quantization is a way of shrinking the size of a model by reducing the precision of each
    parameter; so while our model still has 7 billion parameters, it's easier to run on limited
    hardware.
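
To make the idea of "reducing the precision of each parameter" concrete, here is a toy sketch of symmetric 8-bit quantization in plain Python. It is an illustration of the general technique only; the scheme actually used by the quantized Gemma checkpoint may differ, and the function names are made up for this example.

```python
def quantize_int8(weights):
    """Map floats to integers in [-127, 127] plus one shared scale factor."""
    max_abs = max((abs(w) for w in weights), default=0.0)
    # Fall back to a scale of 1.0 when every weight is zero.
    scale = max_abs / 127 if max_abs > 0 else 1.0
    quantized = [round(w / scale) for w in weights]
    return quantized, scale


def dequantize(quantized, scale):
    """Recover approximate float values from the stored integers."""
    return [q * scale for q in quantized]


weights = [0.5, -1.27, 0.0, 1.0]
quantized, scale = quantize_int8(weights)
restored = dequantize(quantized, scale)

print(quantized)  # small integers that fit in one byte each
print(restored)   # close to, but not exactly, the original weights
```

Each stored value now needs one byte instead of the four bytes of a 32-bit float, at the cost of a rounding error of at most half the scale per weight.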

At the end of this cell, we'll have a `model` we can call `generate` on with a specially formatted prompt.

In [4]:
!pip install -q -U immutabledict sentencepiece
!git clone https://github.com/google/gemma_pytorch.git
!mkdir -p /kaggle/working/gemma/
!mv /kaggle/working/gemma_pytorch/gemma/* /kaggle/working/gemma/

import sys 
sys.path.append("/kaggle/working/gemma_pytorch/") 
from gemma.config import GemmaConfig, get_config_for_7b, get_config_for_2b
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os
import torch

# Load the model
VARIANT = "7b-it-quant" 
MACHINE_TYPE = "cuda" 
weights_dir = '/kaggle/input/gemma/pytorch/7b-it-quant/2' 

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
  """Temporarily sets the default torch dtype to the given dtype."""
  torch.set_default_dtype(dtype)
  try:
    yield
  finally:
    # Restore the float32 default even if model loading raises.
    torch.set_default_dtype(torch.float)

# Model Config.
model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
model_config.tokenizer = os.path.join(weights_dir, "tokenizer.model")
model_config.quant = "quant" in VARIANT

# Model.
device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
  model = GemmaForCausalLM(model_config)
  ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt')
  model.load_weights(ckpt_path)
  model = model.to(device).eval()


Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 71, done.
remote: Counting objects: 100% (16/16), done.
remote: Compressing objects: 100% (8/8), done.
remote: Total 71 (delta 12), reused 8 (delta 8), pack-reused 55
Unpacking objects: 100% (71/71), 2.13 MiB | 5.60 MiB/s, done.




In [5]:
# Now we can loop through our input texts, randomly select a rewrite prompt, and see Gemma in action:

import random
random.seed(0)
# This is the prompt format the model expects
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"

rewrite_data = []

for original_text in original_texts:
    rewrite_prompt = random.choice(rewrite_prompts)
    prompt = f'{rewrite_prompt}\n{original_text}'
    rewritten_text = model.generate(
        USER_CHAT_TEMPLATE.format(prompt=prompt),
        device=device,
        output_len=100,
    )
    rewrite_data.append({
        'original_text': original_text,
        'rewrite_prompt': rewrite_prompt,
        'rewritten_text': rewritten_text,
    })
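
It can help to look at the exact string the model receives before spending GPU time on generation. The sketch below reuses the chat template from the cell above and needs no model; `format_model_input` is a hypothetical helper name introduced only for this illustration.

```python
# Same template as in the generation cell above.
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"


def format_model_input(rewrite_prompt: str, original_text: str) -> str:
    """Join the instruction and the text, then wrap them in the chat template."""
    prompt = f"{rewrite_prompt}\n{original_text}"
    return USER_CHAT_TEMPLATE.format(prompt=prompt)


example = format_model_input(
    "Make this rhyme.",
    "<p>Thank you for sharing!!!</p>",
)
print(example)
```

The string ends with an open `model` turn, which is what cues the model to write its reply.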
    

In [6]:
# Let's turn our generated data into a dataframe, and spot check the first rewrite to see if it makes sense.
rewrite_data_df = pd.DataFrame(rewrite_data)
rewrite_data_df[:1].values

array([['<p>Awesome EDA! \nAll the insights make sense. Very well put! I recommend this as a pretty good notebook for starters! Thank you :)</p>',
        'Convert this into a sea shanty.',
        "Sure, here's the converted sea shanty:\n\n(Verse 1)\nAvast ye, me hearties, listen to me tale,\nOf a notebook that's a pleasure to hail.\nThe EDA's awesomeness, a sight to behold,\nWith insights that make sense, tales to be told.\n\n(Chorus)\nOh, the notebook's a treasure to behold,\nWith well-put ideas to make the mind fold.\nAvast ye"]],
      dtype=object)

# Next Steps

Huzzah! We have a dataset with original texts, rewrite prompts, and rewritten text. Here are a few suggestions for next steps you could take to generate a larger, more diverse dataset:
1. Add more original text data sources. Besides using all of the forum messages (instead of just the first 5), Kaggle has tons of datasets that would make reasonable input text. Here are a few random datasets you could use:
    - The `Plot` column from the [Wikipedia Movie Plots dataset](https://www.kaggle.com/datasets/jrobischon/wikipedia-movie-plots).
    - The `text` column from the [Emotions dataset](https://www.kaggle.com/datasets/nelgiriyewithana/emotions).
    - The `body_text` and `abstract` columns of the [Wikibooks Dataset](https://www.kaggle.com/datasets/dhruvildave/wikibooks-dataset).
    
    Note that each of these may need different preprocessing; for example, Gemma has a context length of 8192 tokens, so if the text is long, you'll need to truncate it.
2. Use Gemma to generate original text.
3. Expand the list of rewrite prompts. You can come up with them manually, or explore having Gemma write rewrite prompts.
4. Play around with the generation of `rewritten_text`:
   - How does changing `output_len` affect the length and quality of rewrites?
   - Do rewrites with the 2B parameter model differ substantially from the 7B model?
   - Can you use [few-shot prompting](https://www.promptingguide.ai/techniques/fewshot) to get higher quality rewrites?
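
As a starting point for the few-shot idea, here is a minimal sketch that places a handful of worked examples in front of the real request as earlier chat turns. It assumes that completed model turns are closed with `<end_of_turn>` in the same way as user turns; `build_few_shot_prompt` and the example rewrite are hypothetical and written only for this illustration.

```python
USER_TURN = "<start_of_turn>user\n{content}<end_of_turn>\n"
MODEL_TURN = "<start_of_turn>model\n{content}<end_of_turn>\n"


def build_few_shot_prompt(examples, rewrite_prompt, original_text):
    """Prefix the real request with example (request, rewrite) turn pairs."""
    parts = []
    for ex in examples:
        request = f"{ex['rewrite_prompt']}\n{ex['original_text']}"
        parts.append(USER_TURN.format(content=request))
        parts.append(MODEL_TURN.format(content=ex['rewritten_text']))
    # The final request is left with an open model turn for Gemma to complete.
    parts.append(USER_TURN.format(content=f"{rewrite_prompt}\n{original_text}"))
    parts.append("<start_of_turn>model\n")
    return "".join(parts)


# A made-up example pair; in practice these would be rewrites you consider good.
examples = [
    {
        'rewrite_prompt': 'Make this rhyme.',
        'original_text': 'Thank you for sharing!',
        'rewritten_text': 'Thank you for sharing, it was truly caring!',
    },
]

few_shot_prompt = build_few_shot_prompt(
    examples,
    'Make this rhyme.',
    'Awesome EDA! All the insights make sense.',
)
print(few_shot_prompt)
```

The returned string could be passed to `model.generate` in place of `USER_CHAT_TEMPLATE.format(prompt=prompt)`. Keep in mind that every example consumes part of the context window.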