In [1]:
import kagglehub
import pandas as pd
import re
from tqdm import tqdm
import csv
import sys
import torch
import os

In [ ]:
PATH_TO_SUMMARY = r'C:\Users\sam\summary-eval\data\summaries_train.csv'

In [2]:
kagglehub.login()

VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

Kaggle credentials set.
Kaggle credentials successfully validated.


## Install dependencies

In [3]:
!pip install -q -U torch immutabledict sentencepiece


[notice] A new release of pip is available: 23.0.1 -> 24.0
[notice] To update, run: python.exe -m pip install --upgrade pip


## Download model weights

In [4]:
# Choose variant and machine type
VARIANT = '2b-it' #@param ['2b', '2b-it', '7b', '7b-it', '7b-quant', '7b-it-quant']
MACHINE_TYPE = 'cuda' #@param ['cuda', 'cpu']

In [5]:
# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma/pyTorch/{VARIANT}')

# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

Downloading from https://www.kaggle.com/api/v1/models/google/gemma/pyTorch/2b-it/2/download...
100%|██████████| 3.75G/3.75G [04:13<00:00, 15.9MB/s]
Extracting model files...


## Download the model implementation

In [6]:
# NOTE: The "installation" is just cloning the repo.
!git clone https://github.com/google/gemma_pytorch.git

Cloning into 'gemma_pytorch'...


In [7]:
sys.path.append('gemma_pytorch')

In [8]:
from gemma_pytorch.gemma.config import get_config_for_7b, get_config_for_2b
from gemma_pytorch.gemma.model import GemmaForCausalLM

## Setup the model

In [9]:
# Set up model config.
model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
model_config.tokenizer = tokenizer_path
model_config.quant = 'quant' in VARIANT

# Instantiate the model and load the weights.
torch.set_default_dtype(model_config.get_dtype())
device = torch.device(MACHINE_TYPE)
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)
model = model.to(device).eval()

In [None]:
test_prompt = """
Reword the following passage, with a similar number of words and writing ability level.

"At the very top of the pyarmid are the Pharohs, they are responible for ruling the rest of the pyarmid. Next up is the vizier. The viziers make sure that the people pay taxes. The xcribes kept track of  government records. Then, we have nobles and priests. They take a cut of the tributes paid to the pharoh. Priests were in charge of pleasing the gods, while the nobles do almost nothing but give gifts to the god like all Egyptians but were wealthy from the donotions to the gods from everyone. Soldiers fought in wars and when there was no war to fight they would moniter the peasants, slaves, and farmers. The merchents sell things to make money and to fill the town's needs. At the very bottom of the hierarchy are the slaves and farmers. Slaves were mostly prisoners forced into labor who did everything anyone said. The farmers would make food and give 60% of their yearly harvest in taxes."
"""

In [0]:
# Generate sample
r = model.generate(
    test_prompt,
    device=device,
    output_len=200,
)
r

In [0]:
re.sub('\*\*[^>]+\*\*', '', r.replace('\n', ' ')).strip()

## SummaryEval

In [10]:
tqdm.pandas()

In [15]:
summaries_df = pd.read_csv(PATH_TO_SUMMARY)

In [16]:
summaries_df

Unnamed: 0,student_id,prompt_id,text,content,wording
0,000e8c3c7ddb,814d6b,The third wave was an experimentto see how peo...,0.205683,0.380538
1,0020ae56ffbf,ebad26,They would rub it up with soda to make the sme...,-0.548304,0.506755
2,004e978e639e,3b9047,"In Egypt, there were many occupations and soci...",3.128928,4.231226
3,005ab0199905,3b9047,The highest class was Pharaohs these people we...,-0.210614,-0.471415
4,0070c9e7af47,814d6b,The Third Wave developed rapidly because the ...,3.272894,3.219757
...,...,...,...,...,...
7160,ff7c7e70df07,ebad26,They used all sorts of chemical concoctions to...,0.205683,0.380538
7161,ffc34d056498,3b9047,The lowest classes are slaves and farmers slav...,-0.308448,0.048171
7162,ffd1576d2e1b,3b9047,they sorta made people start workin...,-1.408180,-0.493603
7163,ffe4a98093b2,39c16e,An ideal tragety has three elements that make ...,-0.393310,0.627128


In [17]:
def augment(text: str) -> str:
  llm_prompt = f"""
  Reword the following passage, with a similar number of words and writing ability level.

  "{text}"
  """
  n_words = len(text.split(' '))
  llm_response = model.generate(
    llm_prompt,
    device=device,
    output_len=n_words,
  )
  # Remove any headings (appear between **)
  return re.sub('\*\*[^>]+\*\*', '', llm_response.replace('\n', ' ')).strip()

In [18]:
outputs = []
for i, (_, row) in tqdm(enumerate(summaries_df.iterrows()), total=len(summaries_df)):
  outputs.append([
      row["student_id"], row["prompt_id"], row["text"], augment(row["text"]), row["content"], row["wording"]
  ])
  
  # Periodically save to avoid huge memory usage
  if (i+1) % 100 == 0:
    with open(f"run1_aug_{(i+1) // 100}.csv", "w") as f:
      writer = csv.writer(f)
      writer.writerows(outputs)
    outputs = []

  0%|          | 1/7165 [08:04<964:51:25, 484.85s/it]


KeyboardInterrupt: 

In [None]:
summaries_df["aug_text"] = summaries_df["text"].progress_apply(lambda x: augment(x))

In [None]:
summaries_df["aug_text"]