In [1]:
import argparse
import dataclasses
import json
import logging
import math
import pathlib
import random
import sys
import os
import time
import numpy as np
from copy import deepcopy

from tqdm import tqdm
from xopen import xopen

import torch
import transformers
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
from axolotl.prompters import AlpacaPrompter

from peft import TaskType

from softprompt.tuner import GraphPromptTuningConfig
from softprompt.mapping import get_peft_graph_model

def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i : i + n]

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
input_path = '/home/ubuntu/proj/code/axolotl_softprompt/data'
dataset_name='pubmed'
pos_type='textual'
model_path = "/home/ubuntu/proj/llm_models"
model_name="vicuna-7b-v1.5"
bittype="8bit"
order = 0
pos_name = f"{pos_type}_order{order}"
adapter_path = f"/home/ubuntu/proj/code/axolotl_softprompt/scripts/{dataset_name}/{pos_type}_order_{order}/{model_name}_{bittype}_{dataset_name}_{pos_type}_order_{order}/checkpoint-1000"

In [3]:
# Fetch all of the prompts
prompts = []
answers = []
DATAPATH = os.path.join(input_path, dataset_name)
if 'train.jsonl' not in os.listdir(DATAPATH):
    raise ValueError(f"Path {DATAPATH} does not have 'train.jsonl' in folder.")
with xopen(os.path.join(DATAPATH, 'train.jsonl')) as fin:
    for i,line in tqdm(enumerate(fin)):
        input_example = json.loads(line)
        prompt = input_example['instruction']
        answer = input_example['output']
        prompts.append(prompt)
        answers.append(answer)
# load tensor 
pos_tensor_name = f"train_{pos_name}.pt"
if pos_tensor_name not in os.listdir(DATAPATH):
    raise ValueError(f"Path {DATAPATH} does not have {pos_tensor_name} in folder.")
pos_token_tensor = torch.load(os.path.join(DATAPATH, pos_tensor_name))

# re-format as alpaca format
formatted_prompts = []
for prompt in prompts:
    prompter = AlpacaPrompter(prompt_style=None)
    builded_prompt = next(prompter.build_prompt(
        instruction = prompt
    ))
    formatted_prompts.append(builded_prompt)
prompts = formatted_prompts

18717it [00:00, 294841.48it/s]


In [4]:
model_name = os.path.join(model_path, model_name)

# load model and tokenizer
config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
model = AutoModelForCausalLM.from_pretrained(model_name)

# build prompt tuning model
peft_config = GraphPromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    input_embedding_dim=768,
    num_virtual_tokens=4,
    num_pos_tokens=1,
    encoder_hidden_size=1024,
    embed_projection=True
)
adapter_path=adapter_path

# peft
model = get_peft_graph_model(model, peft_config)
print('before load adapter')
for name, para in model.prompt_encoder['default'].named_parameters():
    print(f"{name=}, {para=}")
model.load_adapter(adapter_path, adapter_name='default')
print('after load adapter')
for name, para in model.prompt_encoder['default'].named_parameters():
    print(f"{name=}, {para=}")
model.config.pad_token_id = model.config.eos_token_id
#model.half()
model.eval()

generate_kwargs = dict(
    max_new_tokens=256, 
    do_sample=False,
    top_p=None,
    return_dict_in_generate=True,
    use_cache=True,
    )

Loading checkpoint shards: 100%|██████████| 2/2 [00:07<00:00,  3.50s/it]


before load adapter
name='transform.0.weight', para=Parameter containing:
tensor([[-0.0207,  0.0266,  0.0116,  ..., -0.0145, -0.0330,  0.0097],
        [-0.0152,  0.0242, -0.0065,  ..., -0.0198, -0.0067, -0.0137],
        [-0.0357,  0.0160, -0.0045,  ...,  0.0048, -0.0304, -0.0284],
        ...,
        [ 0.0324,  0.0155,  0.0193,  ..., -0.0200,  0.0251, -0.0324],
        [-0.0216, -0.0345, -0.0196,  ..., -0.0062,  0.0300, -0.0304],
        [ 0.0065,  0.0155,  0.0037,  ...,  0.0157,  0.0154,  0.0040]],
       requires_grad=True)
name='transform.0.bias', para=Parameter containing:
tensor([ 0.0238,  0.0126,  0.0056,  ..., -0.0234, -0.0343, -0.0215],
       requires_grad=True)
name='transform.2.weight', para=Parameter containing:
tensor([[ 1.1070e-02,  2.2004e-02, -2.9223e-02,  ..., -2.7320e-02,
          1.5658e-02, -7.5257e-03],
        [ 1.2907e-02,  2.8829e-02, -2.1709e-02,  ...,  1.0715e-02,
          1.2256e-02, -1.0334e-02],
        [-1.5875e-02, -2.6800e-05, -5.5465e-03,  ...,  9.

In [7]:
input_tokens = tokenizer(prompts[0], return_tensors="pt", padding=True).to(model.device)
token_tensors = pos_token_tensor[:1]#.to(torch.float16).to(model.device)


In [8]:
outputs = model.generate(
            prompt_tokens=token_tensors, 
            **input_tokens,
            **generate_kwargs
            )[0]

In [9]:
for i, generated_sequence in enumerate(outputs):
    input_ids = input_tokens["input_ids"][i]
    #print(f"\n\n {generated_sequence.dtype}\n\n")
    text = tokenizer.decode(generated_sequence, skip_special_tokens=True, clean_up_tokenization_spaces=True)

    if input_ids is None:
        prompt_length = 0
    else:
        prompt_length = len(
            tokenizer.decode(
                input_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
        )
    print(text)

### System:
Below is an instruction that describes a task. Write a response that appropriately completes the request.



### Instruction:
### USER: Question: Which category from the list that the paper most likely belong to? 

Belows are 3 potential categories to consider:
Category [1](Diabetes Mellitus Type 1) 
Category [2](Diabetes Mellitus Type 2) 
Category [3](Diabetes Mellitus, Experimental) 

Given the keywords of a research paper, identify one category from a distinct list of research topics that you predict the paper will most likely belong to.
### ASSISTANT:

### Response:
 This paper most likely belongs to the ### System: ### System: ### System ### System ### System ### System:
Below is a

 \ \


In [None]:
generated_sequence

tensor([    1,   835,  2184, 29901,    13, 21140,   340,   338,   385, 15278,
          393, 16612,   263,  3414, 29889, 14350,   263,  2933,   393,  7128,
         2486,  1614,  2167,   278,  2009, 29889,    13,    13,    13,    13,
         2277, 29937,  2799,  4080, 29901,    13,  2277, 29937,  3148,  1001,
        29901,   894, 29901,  8449,  7663,   515,   278,  1051,   393,   278,
         5650,  1556,  5517,  6852,   304, 29973, 29871,    13,    13, 21140,
         1242,   526, 29871, 29941,  7037, 13997,   304,  2050, 29901,    13,
        10900,   518, 29896,   850, 12130,   370, 10778,   341,   514,   277,
          375,  5167, 29871, 29896, 29897, 29871,    13, 10900,   518, 29906,
          850, 12130,   370, 10778,   341,   514,   277,   375,  5167, 29871,
        29906, 29897, 29871,    13, 10900,   518, 29941,   850, 12130,   370,
        10778,   341,   514,   277,   375, 29892,  1222, 27910, 29897, 29871,
           13,    13, 29954,  5428,   278, 29361,   310,   263, 