<a href="https://www.kaggle.com/code/yujansaya/gemma-7b-with-lora-prompt-recovery?scriptVersionId=165719476" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
!pip install git+https://github.com/huggingface/transformers -U
!pip install accelerate
!pip install -i https://pypi.org/simple/ bitsandbytes

Collecting git+https://github.com/huggingface/transformers
  Cloning https://github.com/huggingface/transformers to /tmp/pip-req-build-2vvyzula
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers /tmp/pip-req-build-2vvyzula
  Resolved https://github.com/huggingface/transformers to commit 9322576e2f49d1014fb0c00a7a7c8c34b6a5fd35
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Building wheels for collected packages: transformers
  Building wheel for transformers (pyproject.toml) ... [?25ldone
[?25h  Created wheel for transformers: filename=transformers-4.39.0.dev0-py3-none-any.whl size=8662453 sha256=35c4b4f922277fe66f23934ef3846f86e1fc6a8c6603d277f419ff6f823188ee
  Stored in directory: /tmp/pip-ephem-wheel-cache-ges0cwfb/wheels/c0/14/d6/6c9a5582d2ac191ec0a483be151a4495fe1eb2a6706ca49f1b
Successfully built transformers
Insta

In [2]:
from accelerate import Accelerator
import transformers
import bitsandbytes
import torch
from transformers import BitsAndBytesConfig, AutoTokenizer, AutoModelForCausalLM

In [3]:
# Single Accelerator instance for the notebook: later cells call
# accelerator.prepare(model) and read accelerator.device from it.
accelerator = Accelerator()

In [4]:
# bitsandbytes settings; passed as `quantization_config` to
# AutoModelForCausalLM.from_pretrained when the model is loaded.
quantization_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,  # compute dtype for the 4-bit layers
    bnb_4bit_quant_type="nf4",              # 4-bit quantization type
    bnb_4bit_use_double_quant=True,         # enable double quantization
    load_in_4bit=True,                      # load the weights in 4-bit
)

In [None]:
# Local Kaggle input directory of the checkpoint. Both the tokenizer and the
# model are loaded from this one constant so they cannot drift apart.
MODEL_PATH = "/kaggle/input/gemma/transformers/7b-it/2"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# Load the causal LM with the 4-bit quantization settings defined above and let
# `device_map="auto"` decide the weight placement.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map = "auto",
    trust_remote_code = True,
    quantization_config=quantization_config,
)

# Hand the loaded model to the Accelerator created earlier.
model = accelerator.prepare(model)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [None]:
import pandas as pd
from tqdm import tqdm


# --- Input locations -------------------------------------------------------
TRAIN_DF_FILE = '/kaggle/input/gemma-rewrite-nbroad/nbroad-v2.csv'
TEST_DF_FILE = '/kaggle/input/llm-prompt-recovery/test.csv'
SUB_DF_FILE = '/kaggle/input/llm-prompt-recovery/sample_submission.csv'

# Only the first NROWS rows of the training file are read.
NROWS = 1000

# --- Frames ----------------------------------------------------------------
# Training examples (used later to build the fine-tuning dataset).
train_df = pd.read_csv(TRAIN_DF_FILE, nrows=NROWS)

# Test rows to generate prompts for, restricted to the columns used below.
tdf = pd.read_csv(
    TEST_DF_FILE,
    usecols=['id', 'original_text', 'rewritten_text'],
)

# Sample submission, restricted to its two expected columns.
sub = pd.read_csv(
    SUB_DF_FILE,
    usecols=['id', 'rewrite_prompt'],
)

In [None]:
def truncate_txt(text, length):
    """Limit `text` to its first `length` whitespace-separated words.

    Text that already has `length` words or fewer is returned unchanged, with
    its original whitespace. Longer text is rebuilt from its first `length`
    words joined by single spaces.
    """
    words = text.split()
    return text if len(words) <= length else " ".join(words[:length])


def gen_prompt(og_text, rewritten_text, max_words=150):
    """Build the instruction prompt asking the model to recover the rewrite prompt.

    Args:
        og_text: The original essay text.
        rewritten_text: The rewritten essay text.
        max_words: Maximum number of whitespace-separated words kept from each
            essay (via `truncate_txt`) before it is placed in the prompt.
            Defaults to 150, the value previously hard-coded here.

    Returns:
        The prompt string containing both (truncated) essays followed by the
        instruction and an empty triple-quoted response slot.
    """
    # Both essays are capped at `max_words` words to keep the prompt short.
    og_text = truncate_txt(og_text, max_words)
    rewritten_text = truncate_txt(rewritten_text, max_words)
    
    return f"""    
    Original Essay:
    \"""{og_text}\"""
    
    Rewritten Essay:
    \"""{rewritten_text}\"""
    
    Instruction:
    Given are 2 essays, the Rewritten essay was created from the Original essay using the google Gemma model.
    You are trying to understand how the original essay was transformed into a new version.
    Analyzing the changes in style, theme, etc., please come up with a prompt that must have been used to guide the transformation from the original to the rewritten essay.
    Only give me the PROMPT. Start directly with the prompt, that's all I need. Output should be only line ONLY.
    
    Response: 
    \"""\"""
    """

In [None]:
import datetime
# Wall-clock reference point. The generation loops below compare
# datetime.datetime.now() against it and switch to the default answer once
# 8 hours 30 minutes have elapsed since this cell ran.
start_time = datetime.datetime.now()

In [None]:
import gc
import re

# Generate one recovered prompt per test row with the model as loaded.
# Results accumulate in `res` as [id, rewrite_prompt] pairs.
device = accelerator.device

# Fallback answer used when generation fails, returns nothing, or time runs out.
# https://www.kaggle.com/competitions/llm-prompt-recovery/discussion/481116
DEFAULT_TEXT = "Please improve the following text using the writing style of, maintaining the original meaning but altering the tone, diction, and stylistic elements to match the new style.Enhance the clarity, elegance, and impact of the following text by adopting the writing style of , ensuring the core message remains intact while transforming the tone, word choice, and stylistic features to align with the specified style."

# Total wall-clock budget measured from `start_time`.
TIME_BUDGET = datetime.timedelta(hours=8, minutes=30)

res = []

for _, row in tqdm(tdf.iterrows(), total=tdf.shape[0]):

    # Past the budget: stop generating and fill the remaining rows with the default.
    if (datetime.datetime.now() - start_time) > TIME_BUDGET:
        res.append([row["id"], DEFAULT_TEXT])
        continue

    torch.cuda.empty_cache()
    gc.collect()

    try:
        messages = [
            {
                "role": "user",
                "content": gen_prompt(row["original_text"], row["rewritten_text"])
            }
        ]
        encoded_input = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(device)

        with torch.no_grad():
            encoded_output = model.generate(encoded_input, max_new_tokens=50, do_sample=True, pad_token_id=tokenizer.eos_token_id)

        # generate() returns the prompt tokens followed by the completion.
        # Keep only the completion by slicing off the prompt length. The earlier
        # "[/INST]" regex is a Mistral/Llama-2 style marker that does not occur
        # in this prompt, so it left the entire prompt in the answer.
        prompt_len = encoded_input.shape[-1]
        decoded_output = tokenizer.batch_decode(encoded_output[:, prompt_len:], skip_special_tokens=True)[0].strip()

        # An empty completion would become a blank cell in the CSV; use the default instead.
        res.append([row["id"], decoded_output if decoded_output else DEFAULT_TEXT])

    except Exception as e:
        # Best effort: report the failure and fall back to the default for this row.
        print(f"ERROR: {e}")
        res.append([row["id"], DEFAULT_TEXT])

In [None]:
!pip install peft

In [None]:
from peft import LoraConfig, get_peft_model, TaskType

# LoRA adapter settings for causal-LM fine-tuning; used by get_peft_model and
# handed to the trainer as `peft_config` further down.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=32,  # Rank
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    # Projection layers that receive adapters.
    target_modules=[
        "q_proj",
        "o_proj",
        "k_proj",
        "v_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)

In [None]:
# Bare expression: the notebook displays the model's repr as this cell's output.
model

In [None]:
def print_number_of_trainable_model_parameters(model):
    """Summarise how many of `model`'s parameters are trainable.

    Walks `model.named_parameters()`, totalling `numel()` over all parameters
    and separately over those with `requires_grad` set.

    Returns:
        A three-line string with the trainable count, the total count and the
        trainable percentage formatted to two decimals. Despite its name, the
        function returns the text and does not print it.
    """
    sizes = [(param.numel(), param.requires_grad) for _, param in model.named_parameters()]
    all_model_params = sum(count for count, _ in sizes)
    trainable_model_params = sum(count for count, trainable in sizes if trainable)
    percentage = 100 * trainable_model_params / all_model_params
    return (
        f"trainable model parameters: {trainable_model_params}\n"
        f"all model parameters: {all_model_params}\n"
        f"percentage of trainable model parameters: {percentage:.2f}%"
    )

# Parameter summary for `model` before get_peft_model is applied in the next cell.
print(print_number_of_trainable_model_parameters(model))

In [None]:
# Wrap the loaded model with the LoRA configuration and report the parameter
# counts of the wrapped model.
# NOTE(review): `peft_model` is not referenced again in the visible cells; the
# trainer below is given `model` together with `peft_config=lora_config`.
peft_model = get_peft_model(model, 
                            lora_config)
print(print_number_of_trainable_model_parameters(peft_model))

In [None]:
from datasets import load_dataset,Dataset

# Columns that `formatting_func` reads when it builds a training example.
TRAIN_TEXT_COLUMNS = ["original_text", "rewritten_text", "rewrite_prompt"]

# Build the training dataset from the raw text columns only.
# - Rows with a missing value in any text column are dropped, because
#   `truncate_txt` calls `.split()` and would fail on a non-string value.
# - No tokenization happens here: the trainer builds and tokenizes the full
#   prompt through `formatting_func`. The earlier three successive
#   `map(tokenizer(...))` calls each wrote the same `input_ids` /
#   `attention_mask` columns, so only the last one (rewrite_prompt) survived.
data = Dataset.from_pandas(
    train_df.dropna(subset=TRAIN_TEXT_COLUMNS).reset_index(drop=True)
)

In [None]:
!pip install trl

In [None]:
import os
# Set WANDB_DISABLED before the trainer is constructed — presumably so the
# training run does not try to report to Weights & Biases (confirm against the
# transformers integration docs).
os.environ["WANDB_DISABLED"] = "true"

In [None]:
from datasets.arrow_writer import SchemaInferenceError

In [None]:
import gc
# Release cached CUDA memory and run a garbage-collection pass before the
# trainer is built in the next cell.
torch.cuda.empty_cache()
gc.collect()

In [None]:
from trl import SFTTrainer

def formatting_func(example):
    """Turn a batch of training rows into full prompt+response training texts.

    Args:
        example: Mapping with 'original_text', 'rewritten_text' and
            'rewrite_prompt'. When the trainer maps over the dataset in batches
            each value is a list with one entry per row; a single row (plain
            strings) is also accepted.

    Returns:
        A list with one formatted string per input row. Each essay and the
        target prompt are truncated to 150 words via `truncate_txt`.

    Previously only element [0] of every column was used, so each mapped batch
    collapsed to a single training example and the other rows were discarded.
    """
    originals = example['original_text']
    rewrites = example['rewritten_text']
    prompts = example['rewrite_prompt']

    # Accept a single un-batched row as well: wrap the strings in lists.
    if isinstance(originals, str):
        originals, rewrites, prompts = [originals], [rewrites], [prompts]

    texts = []
    for og_text, rewritten_text, rewrite_prompt in zip(originals, rewrites, prompts):
        # NOTE(review): this training template is laid out differently from the
        # one gen_prompt produces at inference time (no indentation, no
        # triple-quote delimiters); left unchanged here.
        text = f"Original Essay:\n{truncate_txt(og_text, 150)}\n\nRewritten Essay:\n{truncate_txt(rewritten_text, 150)}\n\nInstruction:\n Given are 2 essays, the Rewritten essay was created from the Original essay using the google Gemma model.You are trying to understand how the original essay was transformed into a new version.Analyzing the changes in style, theme, etc., please come up with a prompt that must have been used to guide the transformation from the original to the rewritten essay.Only give me the PROMPT. Start directly with the prompt, that's all I need. Output should be only line ONLY.\n\nResponse: \n{truncate_txt(rewrite_prompt, 150)}"
        texts.append(text)
    return texts

# Optimisation settings for the short LoRA fine-tuning run.
training_args = transformers.TrainingArguments(
    output_dir="outputs",
    max_steps=5,
    warmup_steps=2,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    optim="paged_adamw_8bit",
    fp16=True,
    logging_steps=1,
)

# Supervised fine-tuning: the trainer receives the base model together with the
# LoRA config, and builds each training text through `formatting_func`.
trainer = SFTTrainer(
    model=model,
    train_dataset=data,
    args=training_args,
    peft_config=lora_config,
    formatting_func=formatting_func,
    #max_seq_length=8192
)
trainer.train()

In [None]:
# Generate one recovered prompt per test row after fine-tuning.
# Results accumulate in `res` as [id, rewrite_prompt] pairs, which the final
# cell writes out as the submission.
device = accelerator.device

# Fallback answer used when generation fails, returns nothing, or time runs out.
# https://www.kaggle.com/competitions/llm-prompt-recovery/discussion/481116
DEFAULT_TEXT = "Please improve the following text using the writing style of, maintaining the original meaning but altering the tone, diction, and stylistic elements to match the new style.Enhance the clarity, elegance, and impact of the following text by adopting the writing style of , ensuring the core message remains intact while transforming the tone, word choice, and stylistic features to align with the specified style."

# Total wall-clock budget measured from `start_time`.
TIME_BUDGET = datetime.timedelta(hours=8, minutes=30)

# Training leaves the model in train mode; switch to eval mode so dropout
# (including the LoRA dropout) is inactive while generating.
model.eval()

res = []

for _, row in tqdm(tdf.iterrows(), total=tdf.shape[0]):

    # Past the budget: stop generating and fill the remaining rows with the default.
    if (datetime.datetime.now() - start_time) > TIME_BUDGET:
        res.append([row["id"], DEFAULT_TEXT])
        continue

    torch.cuda.empty_cache()
    gc.collect()

    try:
        messages = [
            {
                "role": "user",
                "content": gen_prompt(row["original_text"], row["rewritten_text"] )
            }
        ]
        encoded_input = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(device)

        with torch.no_grad():
            encoded_output = model.generate(encoded_input, max_new_tokens=200, do_sample=True, pad_token_id=tokenizer.eos_token_id)

        # generate() returns the prompt tokens followed by the completion.
        # Keep only the completion by slicing off the prompt length. The earlier
        # "[/INST]" regex is a Mistral/Llama-2 style marker that does not occur
        # in this prompt, so it left the entire prompt in the answer.
        prompt_len = encoded_input.shape[-1]
        decoded_output = tokenizer.batch_decode(encoded_output[:, prompt_len:], skip_special_tokens=True)[0].strip()

        # An empty completion would become a blank cell in the CSV; use the default instead.
        res.append([row["id"], decoded_output if decoded_output else DEFAULT_TEXT])

    except Exception as e:
        # Best effort: report the failure and fall back to the default for this row.
        print(f"ERROR: {e}")
        res.append([row["id"], DEFAULT_TEXT])

In [None]:
# Assemble the [id, rewrite_prompt] pairs collected in `res` into the
# submission frame and write it out without the index column.
sub = pd.DataFrame(
    data=res,
    columns=['id', 'rewrite_prompt'],
)
sub.to_csv(path_or_buf="submission.csv", index=False)