In [None]:
# !pip install bitsandbytes accelerate
# !pip install sentence_transformers


import os
import pickle

import pandas as pd
import torch
import tqdm.notebook as tq
import transformers
from sentence_transformers import SentenceTransformer, util
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

In [None]:
# Mount Google Drive at /content/drive (Colab only).
# NOTE(review): later cells use relative 'drive/MyDrive/...' paths, which
# presumably resolve to this mount when the working directory is /content — confirm.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
from torch import cuda

# Run on the GPU when CUDA reports one, otherwise stay on the CPU.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Echo the chosen device as the cell output.
device

In [None]:
# Load the quote/interpretation spreadsheet and keep only quotes shorter than
# MAX_QUOTE_CHARS characters.
MAX_QUOTE_CHARS = 256

interpretation_dataset = pd.read_excel("drive/MyDrive/NLP_Final_Project/dataset/final_interpretations.xlsx")

# .str.len() yields NaN for empty / non-text cells instead of raising the
# TypeError that apply(len) would; the comparison below then drops those rows.
interpretation_dataset['quote_len'] = interpretation_dataset['Quote'].str.len()

# .copy() detaches the filtered rows from the full frame so later column
# assignments act on an independent DataFrame.
interpretation_dataset = interpretation_dataset[interpretation_dataset['quote_len'] < MAX_QUOTE_CHARS].copy()

# Every remaining length is a real number, so restore the integer dtype.
interpretation_dataset['quote_len'] = interpretation_dataset['quote_len'].astype(int)

print('Dataset shape:',interpretation_dataset.shape)

In [None]:
# Hugging Face checkpoint used for generating interpretations.
model_name = 'google/gemma-2b'

# Tokenizer and generator come from the same checkpoint; the generator is
# moved to the selected device as soon as it is loaded.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

# Sentence-embedding model, used to pick few-shot examples by similarity.
sentence_model = SentenceTransformer('all-mpnet-base-v2')

In [None]:
# text = '''Explain the meaning of the quote - \n
# \n
# Quote : "a smile is what makes a face beautiful."\n
# Answer : "This quote implies that a smile enhances one's facial attractiveness more than any other physical feature, suggesting that genuine joy and warmth are key components of beauty."\n
# \n
# Quote : "sometimes there a hundred lies behind a smile and not a single truth behind a tear."\n
# Answer :  "This quote suggests that smiles can often be used to hide true feelings or dishonest intentions, portraying a facade of happiness or contentment. In contrast, tears are depicted as more genuine expressions, less likely to be used to conceal the truth, indicating raw, unfiltered emotion."\n
# \n
# Quote : "peace begins with a smile"\n
# Answer : '''

# inputs = tokenizer(text, return_tensors="pt").to(device)
# outputs = model.generate(**inputs, max_new_tokens = 512, do_sample=True, temperature=0.5, top_p=0.95, repetition_penalty=1.6)
# print(tokenizer.decode(outputs[0], skip_special_tokens=True))

In [None]:
# Column names of the interpretation dataset, and the prefix of the columns
# that receive the generated text.
quote_string = "Quote"
tag = "Tag"
meaning = "Interpretation"
new_inter = "Generated_Interpretation"

class QuoteInterpreter():
    """Prompt a generative language model to explain the meaning of quotes.

    Three prompting styles are supported: "zero_shot", "few_shot" and "cot".
    For "few_shot", worked examples are picked among the rows that share the
    target quote's tag, ranked by cosine similarity of sentence embeddings.
    """

    # Instruction line shared by every prompt style.
    _INSTRUCTION = "Explain the meaning of the quote - \n"

    def __init__(self, model, tokenizer,  sentence_model) -> None:
        """Place both models on CUDA when available (else CPU) and switch them to eval mode.

        Args:
            model: generative model exposing ``to``, ``eval`` and ``generate``.
            tokenizer: tokenizer used to encode prompts and decode outputs.
            sentence_model: embedding model exposing ``to``, ``eval`` and ``encode``.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = tokenizer
        self.model = model.to(self.device)
        self.sentence_model = sentence_model.to(self.device)
        self.model.eval()
        self.sentence_model.eval()

    def generate_embeddings(self, data):
        """Return ``{row index: embedding tensor}`` for the quote of every row in ``data``."""
        embeds = {}
        with torch.no_grad():
            for row in data.itertuples(index=True):
                embedding = self.sentence_model.encode(
                    getattr(row, quote_string), convert_to_tensor=True
                ).to(self.device)
                embeds[row.Index] = embedding
        return embeds

    def prompt(self, data, method="zero_shot", n = 0):
        """Generate interpretations for every row of ``data``.

        ``data`` is modified in place: the prompt text is written to a column
        named after ``method`` and each returned sequence to
        ``Generated_Interpretation_<k>`` (k = 1..3).

        Args:
            data: DataFrame holding the quote column (plus the tag column for
                "cot"/"few_shot" and the interpretation column for "few_shot").
            method: "zero_shot", "few_shot" or "cot"; any other value falls
                back to the zero-shot prompt.
            n: number of in-context examples used by "few_shot".

        Returns:
            The same ``data`` object, with the new columns filled in.
        """
        if method == "few_shot":
            # Embeddings are recomputed for the frame passed in on every call.
            self.embeddings = self.generate_embeddings(data)

        loop = tq.tqdm(data.itertuples(index=True), total=len(data),
                      desc=method, leave=True, colour='steelblue')

        for row in loop:
            i = row.Index
            prompt = self.get_prompt(i, data, method, n)
            data.loc[i, method] = prompt
            with torch.no_grad():
                input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
                # Beam search: 5 beams, the 3 best sequences are returned,
                # 20-250 new tokens, no repeated bigrams.
                outputs = self.model.generate(
                    input_ids,
                    min_new_tokens=20,
                    num_beams=5,
                    max_new_tokens = 250,
                    num_return_sequences=3,
                    no_repeat_ngram_size=2)
            # NOTE(review): the whole returned sequence is decoded; for
            # decoder-only models this presumably includes the prompt tokens
            # as well — confirm that downstream evaluation accounts for it.
            for idx, beam_output in enumerate(outputs):
                sent = self.tokenizer.decode(beam_output, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                data.loc[i, new_inter +'_'+ str(idx+1)] = sent

        return data

    def get_prompt(self, quote_idx, data, method, n):
        """Build the prompt for the row labelled ``quote_idx`` using ``method``.

        Unknown ``method`` values fall back to the zero-shot prompt.
        """
        quote = data.loc[quote_idx, quote_string]

        if method == "cot":
            return self.cot_prompt(quote, data.loc[quote_idx, tag])

        if method == "few_shot":
            example_idx = self.get_examples(quote_idx, data, n)
            return self.few_shot_prompt(quote, data, examples=example_idx)

        return self.zero_shot_prompt(quote)

    @staticmethod
    def _quote_block(quote, answer=""):
        """Format one ``Quote``/``Answer`` pair; an empty answer leaves the slot open for the model."""
        return f"  Quote : '{quote}'\n Answer : {answer}"

    def zero_shot_prompt(self, quote):
        """Return the instruction followed by the quote and an open answer slot."""
        return self._INSTRUCTION + self._quote_block(quote)

    def cot_prompt(self, quote, quote_tag):
        """Return a chain-of-thought prompt that states the quote's tag as its topic.

        The answer slot is pre-filled with a step-by-step cue so the model
        reasons before concluding.
        """
        return (
            self._INSTRUCTION
            + f"  Quote : '{quote}'\n Topic : {quote_tag}\n"
            + " Answer : Let's think step by step. "
        )

    def get_examples(self, quote_idx, df, n):
        """Return the index labels of the ``n`` same-tag quotes most similar to ``quote_idx``.

        Requires ``self.embeddings`` (set by ``prompt`` for "few_shot").
        Returns an empty list when ``n`` is zero or negative, or when no other
        row shares the tag.
        """
        if n is not None and n <= 0:
            # Nothing requested: skip the similarity computation entirely.
            return []

        quote_embed = self.embeddings[quote_idx]
        # Candidates: rows with the same tag, excluding the target itself.
        candidates = df[df[tag] == df.loc[quote_idx, tag]].drop(quote_idx)
        cos_sims = {
            cand_idx: float(util.cos_sim(quote_embed, self.embeddings[cand_idx])[0][0].cpu())
            for cand_idx in candidates.index
        }
        top_n = sorted(cos_sims.items(), key=lambda x: x[1], reverse=True)[:n]
        return [i[0] for i in top_n]

    def few_shot_prompt(self, quote, df, examples=None):
        """Return a prompt with solved examples followed by the target quote.

        Every example and the target use the same ``Quote : '...'`` /
        ``Answer :`` layout as the zero-shot prompt, so the demonstrations
        match the query format. With no examples the result equals the
        zero-shot prompt.

        Args:
            quote: the quote to be explained.
            df: DataFrame providing the example quotes and interpretations.
            examples: index labels of the rows to use as examples, or None.
        """
        ext_prompt = ""
        if examples is not None:
            for i in examples:
                e_quote, e_meaning = df.loc[i, quote_string], df.loc[i, meaning]
                ext_prompt += self._quote_block(e_quote, e_meaning) + "\n\n"

        return self._INSTRUCTION + ext_prompt + self._quote_block(quote)


In [None]:
# Wrap the generator, its tokenizer and the embedding model in one interpreter.
interpreter = QuoteInterpreter(
    model=model,
    tokenizer=tokenizer,
    sentence_model=sentence_model,
)

In [None]:
# Zero-shot run: generate interpretations without in-context examples and
# save them as a CSV checkpoint.
checkpoint_path = 'drive/MyDrive/NLP_Final_Project/checkpoints/'
# Create the output folder before the long generation loop, so a missing
# directory cannot make the final save fail after all the work is done.
os.makedirs(checkpoint_path, exist_ok=True)

# Work on a copy so the source dataset stays untouched for the other runs.
zero_shot_df = interpreter.prompt(interpretation_dataset.copy(), method="zero_shot", n = 0)

# Optional pickle checkpoint (disabled):
# with open(checkpoint_path+f'Gemma-2b-Zero-Shot.pkl', 'wb') as f:
#     pickle.dump((zero_shot_df), f)

zero_shot_df.to_csv(checkpoint_path+'Gemma-2b-Zero-Shot.csv')

In [None]:
# One-shot run: each prompt carries the single most similar same-tag example.
checkpoint_path = 'drive/MyDrive/NLP_Final_Project/checkpoints/'
# Create the output folder before the long generation loop, so a missing
# directory cannot make the final save fail after all the work is done.
os.makedirs(checkpoint_path, exist_ok=True)

# Work on a copy so the source dataset stays untouched for the other runs.
one_shot_df = interpreter.prompt(interpretation_dataset.copy(), method="few_shot", n = 1)

# Optional pickle checkpoint (disabled):
# with open(checkpoint_path+f'Gemma-2b-One-Shot.pkl', 'wb') as f:
#     pickle.dump((one_shot_df), f)

one_shot_df.to_csv(checkpoint_path+'Gemma-2b-One-Shot.csv')

In [None]:
# Few-shot run: each prompt carries the three most similar same-tag examples.
checkpoint_path = 'drive/MyDrive/NLP_Final_Project/checkpoints/'
# Create the output folder before the long generation loop, so a missing
# directory cannot make the final save fail after all the work is done.
os.makedirs(checkpoint_path, exist_ok=True)

# Work on a copy so the source dataset stays untouched for the other runs.
few_shot_df = interpreter.prompt(interpretation_dataset.copy(), method="few_shot", n = 3)

# Optional pickle checkpoint (disabled):
# with open(checkpoint_path+f'Gemma-2b-Few-Shot.pkl', 'wb') as f:
#     pickle.dump((few_shot_df), f)

few_shot_df.to_csv(checkpoint_path+'Gemma-2b-Few-Shot.csv')