In [None]:
# Install torch into the kernel's environment, upgrading it if already present; --quiet trims pip's output.
%pip install --upgrade --quiet torch

In [None]:
# Install transformers into the kernel's environment, upgrading it if already present; --quiet trims pip's output.
%pip install --upgrade --quiet transformers

In [None]:
# Install accelerate into the kernel's environment, upgrading it if already present; --quiet trims pip's output.
%pip install --upgrade --quiet accelerate

In [None]:
from tinydb import TinyDB, Query

# Open the local TinyDB file and read every record stored in the 'articles' table.
db = TinyDB('db.json')
table = db.table('articles')

articles = table.all()
print(f'loaded {len(articles)} articles')

# Keep only records that carry a real abstract. `.get` is used instead of
# indexing so that a record with no 'abstract' field (or with a None value) is
# dropped here rather than raising KeyError now or failing later when the
# abstract is appended to the prompt. Records whose abstract is the
# 'No abstract available.' placeholder are dropped as before.
articles = [
    x for x in articles
    if x.get('abstract') not in (None, 'No abstract available.')
]
print(f'retaining {len(articles)} articles')

In [None]:
from huggingface_hub import login

# Authenticate this session with the Hugging Face Hub before any checkpoint is requested.
# NOTE(review): presumably required because the checkpoint selected below is access-restricted — confirm.
login()

In [None]:
# Checkpoint used by the rest of the notebook. The commented-out line is an
# alternative checkpoint that can be swapped in.
# model_id = "mistralai/Mistral-Nemo-Instruct-2407"
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Loading options that are unpacked into from_pretrained() when the model is created.
model_kwargs = dict(
    low_cpu_mem_usage=True,
    # Fill the GPUs one after another instead of balancing across them,
    # which avoids memory allocation issues caused by balancing.
    device_map="sequential",
    torch_dtype="auto",
)

In [None]:
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM

# The tokenizer and the weights are both resolved from the same checkpoint id.
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Loading options (device placement, dtype, CPU-memory behaviour) come from model_kwargs.
model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)

In [None]:
import utils
# Report on the loaded model through the project-local `utils` module.
# NOTE(review): what gets printed is decided by utils.print_model_info, which is not visible here — check that module.
utils.print_model_info(model)

In [None]:
# Options unpacked into model.generate(). The numeric sampling settings are
# fixed here; the special-token ids are read from the loaded tokenizer.
generate_kwargs = dict(
    max_new_tokens=1024,
    # Sampling configuration
    do_sample=True,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    # Special-token ids; note that padding reuses the tokenizer's EOS id.
    bos_token_id=tokenizer.bos_token_id,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

In [None]:
# Fixed instruction text: asks for disease / treatment / intervention / vector
# keywords as a JSON array, and ends with the "abstract:" label.
prompt_instructions = """
Your goal is to identify important keywords in scientific paper abstracts.
For the abstract below, identify all diseases, treatments, interventions, and vectors mentioned.
List the keywords identified in a JSON array, with each item in the array including the keyword type and value.
The only valid keyword types are disease, treatment, intervention, and vector.
Only return the JSON array.

abstract:
"""

# The full prompt is the instruction text followed directly by the abstract of
# the first retained article.
first_abstract = articles[0]['abstract']
prompt = prompt_instructions + first_abstract

In [None]:
%%time

# Present the prompt as a single user turn of a chat conversation.
messages = [
    {"role": "user", "content": prompt}
]

# Render the conversation with the tokenizer's chat template and tokenize it.
# return_dict=True makes the result carry both input_ids and attention_mask, so
# generate() gets an explicit mask instead of having to infer one; inference is
# unreliable here because generate_kwargs sets pad_token_id to the EOS id.
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="pt"
).to(model.device)

# Kept under its original name; its length is used below to strip the prompt.
input_ids = inputs["input_ids"]

outputs = model.generate(
    **inputs,
    **generate_kwargs
)

# The generated sequence starts with the prompt tokens; keep only what follows them.
response = outputs[0][input_ids.shape[-1]:]

# Show the source abstract next to the decoded model output for comparison.
print(articles[0]['abstract'])
print(tokenizer.decode(response, skip_special_tokens=True))