In [1]:
import argparse
import dataclasses
import json
import logging
import math
import pathlib
import random
import sys
import os
import time
import numpy as np
from copy import deepcopy

from tqdm import tqdm
from xopen import xopen

import torch
import transformers
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
from axolotl.prompters import AlpacaPrompter

from peft import TaskType

from softprompt.tuner import GraphPromptTuningConfig
from softprompt.mapping import get_peft_graph_model

def chunks(lst, n):
    """Lazily split ``lst`` into consecutive slices of ``n`` items.

    For a positive ``n`` every slice holds exactly ``n`` items except possibly
    the last, which holds the remainder. Because plain slicing is used, the
    slice type follows the input type (list -> lists, str -> strs, ...).

    Args:
        lst: any sequence supporting ``len`` and slicing.
        n: slice length. ``range`` rejects ``0`` with ``ValueError`` once
            iteration starts.

    Yields:
        ``lst[start:start + n]`` for ``start = 0, n, 2n, ...``.
    """
    yield from (lst[start:start + n] for start in range(0, len(lst), n))

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Experiment configuration.
# All paths hang off one project root. The default reproduces the original
# machine's layout; set the SOFTPROMPT_PROJECT_ROOT environment variable to run
# elsewhere without editing the notebook.
project_root = os.environ.get("SOFTPROMPT_PROJECT_ROOT", "/home/ubuntu/proj")
repo_root = os.path.join(project_root, "code", "axolotl_softprompt")

# Data: the loading cell reads <input_path>/<dataset_name>/train.jsonl and the
# matching train_<pos_name>.pt tensor.
input_path = os.path.join(repo_root, "data")
dataset_name = 'pubmed'
pos_type = 'textual'

# Base model location (the model cell joins model_path and model_name).
model_path = os.path.join(project_root, "llm_models")
model_name = "vicuna-7b-v1.5"
bittype = "8bit"  # used here only to build the adapter checkpoint path

order = 2
steps = 2000  # checkpoint step of the adapter to load
epochs = 50   # informational; not used by the inference cells in this notebook

# Note: the tensor file name has no underscore before the order ("order2"),
# whereas the adapter directory names do ("order_2"). Both mirror names on disk.
pos_name = f"{pos_type}_order{order}"
adapter_path = os.path.join(
    repo_root,
    "scripts",
    dataset_name,
    f"{pos_type}_order_{order}",
    f"{model_name}_{bittype}_{dataset_name}_{pos_type}_order_{order}",
    f"checkpoint-{steps}",
)

In [3]:
# Fetch every instruction/output pair from the training split.
# prompts[k] / answers[k] stay row-aligned with pos_token_tensor[k] below,
# which the generation cell relies on.
prompts = []
answers = []
DATAPATH = os.path.join(input_path, dataset_name)
if 'train.jsonl' not in os.listdir(DATAPATH):
    raise ValueError(f"Path {DATAPATH} does not have 'train.jsonl' in folder.")
with xopen(os.path.join(DATAPATH, 'train.jsonl')) as fin:
    for line in tqdm(fin):
        if not line.strip():
            continue  # tolerate blank lines (e.g. an extra newline at EOF) instead of a JSONDecodeError
        input_example = json.loads(line)
        prompts.append(input_example['instruction'])
        answers.append(input_example['output'])

# Load the precomputed per-example position/graph token tensor.
pos_tensor_name = f"train_{pos_name}.pt"
if pos_tensor_name not in os.listdir(DATAPATH):
    raise ValueError(f"Path {DATAPATH} does not have {pos_tensor_name} in folder.")
# NOTE(security): torch.load unpickles the file and can execute arbitrary code;
# only load .pt files you produced or trust.
pos_token_tensor = torch.load(os.path.join(DATAPATH, pos_tensor_name))

# Re-format every instruction with the Alpaca template. The prompter does not
# depend on the prompt, so it is built once instead of once per example.
prompter = AlpacaPrompter(prompt_style=None)
# build_prompt is a generator; next() takes its single formatted string.
formatted_prompts = [next(prompter.build_prompt(instruction=prompt)) for prompt in prompts]
prompts = formatted_prompts

18717it [00:00, 294028.73it/s]


In [4]:
# Resolve the checkpoint directory under its own name instead of overwriting
# `model_name`. The config cell builds `adapter_path` from `model_name`, so
# clobbering it would break that cell if it were re-run after this one.
model_dir = os.path.join(model_path, model_name)

# Load the base model and tokenizer. Left padding puts any pad tokens before
# the prompt, so generation continues directly from the prompt's last token.
config = AutoConfig.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir, padding_side='left')
# Load the weights directly as fp16. The whole model is cast with .half() below
# anyway; this only avoids materialising a full fp32 copy (about 2x memory) first.
model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    device_map='auto',
    torch_dtype=torch.float16,
)

# Graph soft-prompt tuning config. These values presumably have to match the
# run that produced the adapter at `adapter_path`.
peft_config = GraphPromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    input_embedding_dim=768,  # presumably the feature width of pos_token_tensor rows - TODO confirm
    num_virtual_tokens=4,
    num_pos_tokens=1,
    encoder_hidden_size=1024,
    embed_projection=True,
)

# Wrap the base model, then load the trained adapter weights into the
# 'default' adapter slot.
model = get_peft_graph_model(model, peft_config)
print('before load adapter')
model.load_adapter(adapter_path, adapter_name='default')
print('after load adapter')

model.config.pad_token_id = model.config.eos_token_id
model.half()  # cast all parameters, including the freshly loaded adapter, to fp16
model.eval()

# Greedy decoding (do_sample=False). top_p is cleared because it only affects sampling.
generate_kwargs = dict(
    max_new_tokens=256,
    do_sample=False,
    top_p=None,
    return_dict_in_generate=True,
    use_cache=True,
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:10<00:00,  5.46s/it]


before load adapter
after load adapter


In [9]:
# Two-turn greedy generation for the first NUM_EXAMPLES training prompts:
#   1) answer the Alpaca-formatted prompt;
#   2) feed the full first-turn text back with a follow-up question and print
#      only the model's new continuation.
NUM_EXAMPLES = 10
FOLLOW_UP = " Explain why you think this is the choice."
SHOW_FIRST_PASS = False  # set True to also print ground truth vs. first-turn prediction


def generate_texts(prompt_text, pos_tokens):
    """Run one generation pass for a single prompt string.

    Args:
        prompt_text: raw prompt string; tokenized here and moved to the model's device.
        pos_tokens: graph/position tokens forwarded to ``model.generate`` as
            ``prompt_tokens`` (the caller passes a one-row slice of pos_token_tensor).

    Returns:
        A list with one ``(full_text, completion)`` pair per returned sequence.
        ``full_text`` is the decoded prompt plus generation. ``completion`` is
        what follows the decoded prompt.
    """
    encoded = tokenizer(prompt_text, return_tensors="pt", padding=True).to(model.device)
    prompt_ids = encoded["input_ids"]
    # With return_dict_in_generate=True, item [0] of the result is expected to
    # be the `sequences` tensor (presumably prompt ids followed by new tokens).
    sequences = model.generate(
        input_ids=prompt_ids,
        prompt_tokens=pos_tokens,
        **generate_kwargs,
    )[0]
    decode_kwargs = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
    results = []
    for row, generated_sequence in enumerate(sequences):
        full_text = tokenizer.decode(generated_sequence, **decode_kwargs)
        # Strip the echoed prompt by the character length of its decoded form.
        prompt_length = len(tokenizer.decode(prompt_ids[row], **decode_kwargs))
        results.append((full_text, full_text[prompt_length:]))
    return results


for index_value in range(NUM_EXAMPLES):
    # One-row slice keeps the batch dimension; cast to the fp16 model dtype.
    pos_tokens = pos_token_tensor[index_value:index_value + 1].to(model.device, dtype=torch.float16)

    first_pass = generate_texts(prompts[index_value], pos_tokens)
    first_full_text, first_completion = first_pass[-1]
    if SHOW_FIRST_PASS:
        print('groundtruth:', answers[index_value], 'prediction: ', first_completion)

    # Second turn: the *entire* first-turn text (prompt + answer) plus the follow-up.
    used_prompt = first_full_text + FOLLOW_UP
    for _, completion in generate_texts(used_prompt, pos_tokens):
        print(completion)
        print("*******************************")





I think this paper most likely belong to Diabetes Mellitus, Experimental because the topic of the paper is about a new treatment for diabetes and the paper mention that the treatment is based on a new technology that can regulate the glucose levels in the body, this technology is not yet approved by the FDA, so it is considered as experimental.

Please let me know if you have any other question or if you need any help.
*******************************

*******************************



I think this paper most likely belong to Diabetes Mellitus, Experimental because the keywords used in the instruction "Diabetes Mellitus, Experimental" and the paper most likely deal with new treatment or study on Diabetes Mellitus that not yet been approved or commercialized.
*******************************

*******************************



I think this paper most likely belong to Diabetes Mellitus, Experimental because the keywords given in the instruction are related to experimental research on di