In [1]:
import torch
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# All tensors in this notebook are placed on the CPU.
device = torch.device("cpu")

In [3]:
# Read the movie-plot CSV, then keep the non-missing entries of its 'Plot' column.
wiki_movie = pd.read_csv("wiki_movie_plots_deduped.csv")
wiki_plot = wiki_movie["Plot"].dropna()

In [82]:
# The tokenizer and the causal-LM weights are loaded from the same checkpoint name.
MODEL_NAME = "google/gemma-1.1-2b-it"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.46s/it]


In [91]:
# Echo the current `device` object as the cell output.
device

device(type='cpu')

In [84]:
# Pool of rewrite instructions; entries are picked from this list with random.choice.
rewrite_prompts = [
    "Explain this to me like I'm five.",
    "Convert this into a sea shanty.",
    "Make this rhyme.",
]

In [85]:
import random
# Fixed seed so the random draws made with the `random` module repeat across runs.
random.seed(0)

# Prompt format the model expects: a user turn holding {prompt}, then an opened model turn.
USER_CHAT_TEMPLATE = (
    "<start_of_turn>user\n"
    "{prompt}<end_of_turn>\n"
    "<start_of_turn>model\n"
)

In [86]:
# Work on a small sample: the first five plots.
original_texts = wiki_plot.head(5)

# Probe: formatting a value into the chat template yields a plain string.
# NOTE(review): the whole sample is formatted here, not a single plot.
formatted_probe = USER_CHAT_TEMPLATE.format(prompt=original_texts)
type(formatted_probe)

str

In [127]:
def preprocess_plot(rewrite_prompt, plot):
    """Build the chat-formatted model input for one rewrite request.

    The instruction and the plot are joined by a single newline (instruction
    first), and the result is substituted into ``USER_CHAT_TEMPLATE`` as its
    ``prompt`` field.

    Args:
        rewrite_prompt: Instruction describing how the plot should be rewritten.
        plot: The plot text to rewrite.

    Returns:
        The filled-in chat template as a string.
    """
    user_message = '{}\n{}'.format(rewrite_prompt, plot)
    return USER_CHAT_TEMPLATE.format(prompt=user_message)

In [162]:
def generate_rewrite(input_text, max_new_tokens=500):
    """Run the model on one chat-formatted prompt and return the decoded text.

    Args:
        input_text: Prompt string, already wrapped in the chat template.
        max_new_tokens: Upper bound on the number of tokens generated after
            the prompt.

    Returns:
        The first generated sequence decoded with special tokens removed.
        The decoded string still begins with the prompt text, followed by the
        model's continuation.
    """
    # Keep the whole encoding (not only input_ids) so the attention mask is
    # forwarded to generate() together with the token ids.
    encoded = tokenizer(input_text, return_tensors="pt", padding=True).to(device)
    # Bound only the continuation. The earlier max_length=500 also counted the
    # prompt tokens, so a long plot left little or no room for the rewrite.
    outputs = model.generate(**encoded, max_new_tokens=max_new_tokens)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [163]:
def postprocess_rewrite(rewrite, input_text,
                        special_tokens=('<start_of_turn>', '<end_of_turn>')):
    """Strip the echoed prompt from the start of a decoded generation.

    Two prefixes are tried, in order: ``input_text`` exactly as given, and
    ``input_text`` with every marker in ``special_tokens`` removed. The second
    form is needed when the output was decoded with special tokens skipped, so
    the turn markers no longer appear in ``rewrite``.

    Args:
        rewrite: Decoded model output, possibly starting with the prompt.
        input_text: The prompt that was fed to the model.
        special_tokens: Marker strings to remove from ``input_text`` when
            building the second candidate prefix.

    Returns:
        ``rewrite`` with the first matching prefix removed. When neither
        prefix matches, ``rewrite`` is returned unchanged (previously the
        function fell through and returned ``None``).
    """
    plain_prefix = input_text
    for token in special_tokens:
        plain_prefix = plain_prefix.replace(token, '')

    for prefix in (input_text, plain_prefix):
        if rewrite.startswith(prefix):
            return rewrite[len(prefix):]
    return rewrite

In [164]:
# For each sampled plot: draw a random instruction, build the prompt, generate,
# and record the result. postprocess_rewrite is not applied here, so the
# 'rewrite' field holds exactly what generate_rewrite returned.
rewrite_data = []
for original_text in original_texts:
    rewrite_prompt = random.choice(rewrite_prompts)
    input_text = preprocess_plot(rewrite_prompt, original_text)
    generated_rewrite = generate_rewrite(input_text)
    record = dict(
        original_text=original_text,
        rewrite_prompt=rewrite_prompt,
        rewrite=generated_rewrite,
    )
    rewrite_data.append(record)

In [167]:
# Convert the collected records into a DataFrame (the name is rebound to the frame).
rewrite_data = pd.DataFrame(data=rewrite_data)

In [185]:
# Preview the first five rows.
rewrite_data.head(n=5)

Unnamed: 0,original_text,rewrite_prompt,rewrite
0,"A bartender is working at a saloon, serving dr...",Convert this into a sea shanty.,user\nConvert this into a sea shanty.\nA barte...
1,"The moon, painted with a smiling face hangs ov...",Convert this into a sea shanty.,user\nConvert this into a sea shanty.\nThe moo...
2,"The film, just over a minute long, is composed...",Make this rhyme.,"user\nMake this rhyme.\nThe film, just over a ..."
3,Lasting just 61 seconds and consisting of two ...,Explain this to me like I'm five.,user\nExplain this to me like I'm five.\nLasti...
4,The earliest known adaptation of the classic f...,Convert this into a sea shanty.,user\nConvert this into a sea shanty.\nThe ear...


In [180]:
# Check that the third newline-separated piece of the first rewrite equals its source plot.
first_rewrite = str(rewrite_data['rewrite'][0])
first_plot = str(rewrite_data['original_text'][0])
first_rewrite.split('\n')[2] == first_plot

True

In [182]:
# Show the first rewrite broken into its newline-separated pieces.
first_rewrite_text = str(rewrite_data['rewrite'][0])
first_rewrite_text.split(sep='\n')

['user',
 'Convert this into a sea shanty.',
 "A bartender is working at a saloon, serving drinks to customers. After he fills a stereotypically Irish man's bucket with beer, Carrie Nation and her followers burst inside. They assault the Irish man, pulling his hat over his eyes and then dumping the beer over his head. The group then begin wrecking the bar, smashing the fixtures, mirrors, and breaking the cash register. The bartender then sprays seltzer water in Nation's face before a group of policemen appear and order everybody to leave.[1]",
 'model',
 '(Verse 1)',
 "At the saloon's dark and grim,",
 'A tale unfolds, a salty whim.',
 'A thirsty Irishman, bold and strong,',
 'His bucket filled, a tale to be sung.',
 '',
 '(Chorus)',
 'Heave ho, heave ho, the crowd they came,',
 'With fury in their eyes, a watery shame.',
 'Pulling caps, smashing frames,',
 'The bar in chaos, a watery game.',
 '',
 '(Verse 2)',
 'Carrie Nation, with fiery grace,',
 "Joined the fray, a whirlwind's embra

In [183]:
# Plain-text framing: a 'user' line, the {input_text} slot, a 'model' line, trailing newline.
POSTPROCESS_TEMPLATE = "\n".join(["user", "{input_text}", "model", ""])

In [None]:
# Save the dataframe as CSV, leaving out the index column.
rewrite_data.to_csv(path_or_buf='rewrite_data.csv', index=False)

In [None]:
# Draw five rows at random from the dataframe.
rewrite_data.sample(n=5)

In [2]:
import pandas as pd

In [3]:
# Load a saved rewrite dataset and display it.
# NOTE(review): this file name differs from 'rewrite_data.csv' written earlier in
# the notebook; presumably a separately generated dataset, so confirm the source.
pd.read_csv(filepath_or_buffer='rewrite_data_wikimovie.csv')

Unnamed: 0,original_text,rewrite_prompt,rewrite
0,The sudden death of her mother brings Myung-eu...,"Describe the setup, dissect the plot into risi...",**Setup:**\n\n- Myung-eun returns to Jeju afte...
1,"Living in Manhattan, Tom (Zach Braff) is a coo...","Describe the setup, dissect the plot into risi...","**Setup:**\n\n- Tom, a struggling cook, and So..."
2,"Set in Nottinghamshire, Dek (Rhys Ifans) propo...","Outline the main events, focus on character mo...",**Main Events:**\n\n- Dek proposes to Shirley ...
3,Sakthi (Vijay) is the adopted son of a Madurai...,"Summarize the plot, delve into the film’s them...",**Plot Summary:**\n\nThe film depicts the tran...
4,Kate Armstrong (Catherine Zeta-Jones) is the h...,"Summarize the plot, integrate various critical...",**Plot Summary:**\n\nThe film explores the com...
...,...,...,...
995,Recording some new music in an isolated farmho...,"Provide a plot overview, detail the main chara...","## Plot Overview:\n\nThe film follows Triton, ..."
996,After Digvijay Singh (Amjad Khan) implicates h...,"Outline the main events, focus on character mo...","**Main Events:**\n\n- Ram Kumar, a simple insu..."
997,"In German-occupied Paris, an announcer reports...","Outline the main events, focus on character mo...",**Main Events:**\n\n- Five downed pilots from ...
998,"Kenta Date, a senior high school student, is t...","Summarize the plot, integrate various critical...",**Plot Summary:**\n\nThe film follows Kenta Da...
