In [None]:
import os
import os.path as osp
from torch_geometric.data import Data
import torch
import pandas as pd
import numpy as np
from torch_geometric.data import InMemoryDataset, download_url
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed, LlamaForCausalLM

# Tokenizer (and its chat template, used by getKeywords via
# apply_chat_template) comes from the local Zephyr checkpoint.
tokenizer = AutoTokenizer.from_pretrained("/home/models/zephyr-7b-alpha")

# Load exactly one causal LM, 4-bit quantized, with device placement left to
# the library (device_map="auto"). Loading a second 7B checkpoint into the same
# `model` name would only cost load time and GPU memory, because the later
# assignment wins.
# NOTE(review): the tokenizer is Zephyr's while the model is the Llama-2 base
# checkpoint. If Zephyr was the intended model, load it here instead with
# AutoModelForCausalLM.from_pretrained(<zephyr checkpoint>, ...) -- confirm.
model = LlamaForCausalLM.from_pretrained(
    "/home/models/Llama-2-7b-hf", device_map="auto", load_in_4bit=True
)

# Smoke test: generate a short continuation for one fixed prompt.
set_seed(0)  # fixed seed so any sampling done by generate() is repeatable
prompt = """How many helicopters can a human eat in one sitting? Reply as a thug."""
model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
# Remember the prompt length so the echoed prompt can be cut from the output.
input_length = model_inputs.input_ids.shape[1]
generated_ids = model.generate(**model_inputs, max_new_tokens=20)
# Decode only the newly generated tokens of the first (only) sequence.
print(tokenizer.batch_decode(generated_ids[:, input_length:], skip_special_tokens=True)[0])

# Re-seed before the keyword-extraction calls, which sample (do_sample=True).
set_seed(42)
def getKeywords(text):
    """Ask the loaded chat model for five keywords describing a paper.

    Builds a two-message conversation (a fixed system instruction followed by
    ``text`` as the user turn), renders it with the module-level tokenizer's
    chat template, samples up to 100 new tokens from the module-level
    ``model``, and decodes only the newly generated part.

    Args:
        text: The user message. The system instruction expects it to hold a
            paper's title and abstract.

    Returns:
        The decoded completion as a string, with special tokens removed and
        the prompt tokens stripped. Sampling is enabled, so the result
        depends on the current RNG state.
    """
    # NOTE(review): the instruction asks the model to "follow the output
    # format" but no format is spelled out here -- confirm whether one was
    # meant to be included.
    system_prompt = """
              You are a smart reviewer of a research conference. Given the title and abstract of the paper by a user, provide 5 keywords that you think are most suitable for describing the information.
              Strictly follow the output format.
            """
    chat = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": text},
    ]

    prompt_ids = tokenizer.apply_chat_template(
        chat, add_generation_prompt=True, return_tensors="pt"
    ).to("cuda")
    prompt_len = prompt_ids.shape[1]

    completion_ids = model.generate(prompt_ids, do_sample=True, max_new_tokens=100)

    # generate() returns prompt + continuation; slice the prompt off before
    # decoding and take the single sequence in the batch.
    new_token_ids = completion_ids[:, prompt_len:]
    return tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)[0]

# Location of the saved Cora data object.
data_path = "/content/cora.pt"

# torch.load unpickles the file, so it must only be pointed at a trusted
# artifact. The loaded object is round-tripped through a plain dict to rebuild
# it as a fresh Data instance (presumably to cope with a file written by a
# different torch_geometric version -- confirm).
raw_cora_data = Data.from_dict(torch.load(data_path).to_dict())

# Per-node raw texts and the class label names stored on the data object.
texts = raw_cora_data.raw_text
label_names = raw_cora_data.label_names

# Trial run: extract keywords for the first 10 texts only.
keywords = [getKeywords(sample) for sample in texts[:10]]