In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import pandas as pd
import os
from pathlib import Path


model_name_or_path = "/mnt/nas1/models/llama/merged_models/llama2-7b-ner-chem_gene3"

# Default placement: let the loader decide how to spread the model over devices.
device_map = "auto"
# When LOCAL_RANK is set (distributed launch), pin the whole model to this
# process's device instead.
if os.environ.get('LOCAL_RANK') is not None:
    local_rank = int(os.environ.get('LOCAL_RANK', '0'))
    device_map = {'': local_rank}

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
# 4-bit loading is requested only through `quantization_config`. Passing
# `load_in_4bit=True` as a separate keyword next to `quantization_config` is
# redundant and is rejected with a ValueError by newer transformers releases.
# NOTE(review): trust_remote_code=True permits running Python code shipped in
# the checkpoint directory — confirm it is actually needed for this model.
model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map=device_map,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            # NOTE(review): the llm_int8_* options look like 8-bit settings —
            # confirm whether they have any effect on a 4-bit load.
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
        ),
    )


# Every column is read as a string (dtype=str), so no numeric inference happens.
file = Path('/mnt/nas1/corpus-bio-nlp/NER/PGx_CTD_chem_x_gene.csv')
df_pgx_ctd = pd.read_csv(file, dtype=str)


  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 3/3 [01:32<00:00, 30.67s/it]


In [2]:
# Instruction template: the sentence, a separator line, then the extraction
# instructions. `{sentence}` is the only placeholder filled in per row.
question = "".join((
    "{sentence}",
    "\n---------------\n",
    "please extract all Chemical and Gene in the above text, ",
    "Gene includes gene or protein, excluding Limited variation, Genomic variation, Genomic factor, Haplotype. ",
    "Chemical includes chemical, drug and amino acid, excluding disease.",
    "The output format should be '<starting index in sentence, ending index in sentence, entity name, entity type>' .",
))

# Keep one row per distinct sentence, then attach the fully rendered prompt.
df_pgx_ctd = (
    df_pgx_ctd
    .drop_duplicates(subset=["sentence"])
    .assign(
        prompt=lambda frame: frame["sentence"].apply(
            lambda sentence: question.format(sentence=sentence)
        )
    )
)

In [3]:
def chat_ner(x, verbose=True):
    """Generate a response for one prompt with the notebook-level `model`.

    The prompt is stripped, wrapped in literal ``<s>``/``</s>`` markers and
    tokenized with the notebook-level `tokenizer` without adding further
    special tokens. Decoding is greedy (``do_sample=False``) and limited to
    500 new tokens.

    Args:
        x: Prompt text.
        verbose: When True (the default, matching the previous behavior), the
            raw decoded continuation is printed before the ``</s>`` markers
            are removed.

    Returns:
        The decoded continuation only (prompt tokens are cut off), with every
        ``</s>`` removed and surrounding whitespace stripped.
    """
    # NOTE(review): the '<s>{}</s>' wrapping presumably mirrors the format the
    # checkpoint was fine-tuned on — confirm before changing it.
    text = '<s>{}</s>'.format(x.strip())
    encoded = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    # Follow the model's own device rather than hard-coding `.cuda()`: the
    # default CUDA device is not necessarily where `device_map` put the model
    # (e.g. `{'': local_rank}`), and `.cuda()` cannot work on a CPU-only host.
    input_ids = encoded.input_ids.to(model.device)
    with torch.no_grad():
        # NOTE(review): top_p/temperature/repetition_penalty are passed as 1,
        # presumably to override values stored in the checkpoint's generation
        # config — confirm.
        outputs = model.generate(
            input_ids=input_ids, max_new_tokens=500, do_sample=False,
            top_p=1, temperature=1, repetition_penalty=1,
            eos_token_id=tokenizer.eos_token_id
        )
    # generate() returns prompt + continuation; keep only the continuation.
    prompt_len = input_ids.shape[1]
    new_tokens = outputs[0, prompt_len:].tolist()
    response = tokenizer.decode(new_tokens)
    if verbose:
        print(response)
    return response.replace('</s>', "").strip()

# Smoke test on the first prompt. `.iloc[0]` selects by position, so this does
# not depend on the row label 0 still being present in the index.
input1 = df_pgx_ctd["prompt"].iloc[0]
print(input1)
r = chat_ner(input1)
print(r)

Among controls , we found women with the A2/A2 genotype to have elevated levels of estrone ( +14.3 % , P = 0.01 ) , estradiol ( +13.8 % , P = 0.08 ) , testosterone ( +8.6 % , P = 0.34 ) , androstenedione ( +17.1 % , P = 0.06 ) , dehydroepiandrosterone ( +14.4 % , P = 0.02 ) , and dehydroepiandrosterone sulfate ( +7.2 % , P = 0.26 ) compared with women with the A1/A1 genotype .
---------------
please extract all Chemical and Gene in the above text, Gene includes gene or protein, excluding Limited variation, Genomic variation, Genomic factor, Haplotype. Chemical includes chemical, drug and amino acid, excluding disease.The output format should be '<starting index in sentence, ending index in sentence, entity name, entity type>' .
<129, 140, testosterone, Chemical>, <105, 126, androstenedione, Chemical>, <150, 196, dehydroepiandrosterone, Chemical>, <199, 230, dehydroepiandrosterone sulfate, Chemical>, <155, 173, estrone, Chemical>, <176, 187, estradiol, Chemical>, <147, 153, A1, Gene>, <

In [5]:
model_name_or_path_e3 = "/mnt/nas1/models/llama/merged_models/llama2-7b-ner-chem_gene-e3s6"

tokenizer_e3 = AutoTokenizer.from_pretrained(model_name_or_path_e3)
# 4-bit loading is requested only through `quantization_config`. Passing
# `load_in_4bit=True` as a separate keyword next to `quantization_config` is
# redundant and is rejected with a ValueError by newer transformers releases.
# NOTE(review): trust_remote_code=True permits running Python code shipped in
# the checkpoint directory — confirm it is actually needed for this model.
model_e3 = AutoModelForCausalLM.from_pretrained(
        model_name_or_path_e3,
        device_map=device_map,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            # NOTE(review): the llm_int8_* options look like 8-bit settings —
            # confirm whether they have any effect on a 4-bit load.
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
        ),
    )

def chat_ner2(x, verbose=True):
    """Generate a response for one prompt with the second checkpoint (`model_e3`).

    Tokenization, the EOS id and decoding all use `tokenizer_e3`, the
    tokenizer loaded from the same checkpoint directory as `model_e3`.
    Previously the first checkpoint's `tokenizer` was used here, which is only
    correct if both checkpoints share an identical vocabulary.

    The prompt is stripped, wrapped in literal ``<s>``/``</s>`` markers and
    tokenized without adding further special tokens. Decoding is greedy
    (``do_sample=False``) and limited to 500 new tokens.

    Args:
        x: Prompt text.
        verbose: When True (the default, matching the previous behavior), the
            raw decoded continuation is printed before the ``</s>`` markers
            are removed.

    Returns:
        The decoded continuation only (prompt tokens are cut off), with every
        ``</s>`` removed and surrounding whitespace stripped.
    """
    # NOTE(review): the '<s>{}</s>' wrapping presumably mirrors the format the
    # checkpoint was fine-tuned on — confirm before changing it.
    text = '<s>{}</s>'.format(x.strip())
    encoded = tokenizer_e3(text, return_tensors="pt", add_special_tokens=False)
    # Follow the model's own device rather than hard-coding `.cuda()`: the
    # default CUDA device is not necessarily where `device_map` put the model.
    input_ids = encoded.input_ids.to(model_e3.device)
    with torch.no_grad():
        # NOTE(review): top_p/temperature/repetition_penalty are passed as 1,
        # presumably to override values stored in the checkpoint's generation
        # config — confirm.
        outputs = model_e3.generate(
            input_ids=input_ids, max_new_tokens=500, do_sample=False,
            top_p=1, temperature=1, repetition_penalty=1,
            eos_token_id=tokenizer_e3.eos_token_id
        )
    # generate() returns prompt + continuation; keep only the continuation.
    prompt_len = input_ids.shape[1]
    new_tokens = outputs[0, prompt_len:].tolist()
    response = tokenizer_e3.decode(new_tokens)
    if verbose:
        print(response)
    return response.replace('</s>', "").strip()

# Run the same prompt (`input1`) through the second checkpoint so both
# responses can be compared; `r` is overwritten with the new response.
r = chat_ner2(input1)
print(r)

Loading checkpoint shards: 100%|██████████| 3/3 [00:11<00:00,  3.71s/it]


<105, 111, estrone, Chemical>, <105, 111, estradiol, Chemical>, <105, 111, testosterone, Chemical>, <105, 111, androstenedione, Chemical>, <105, 111, dehydroepiandrosterone, Chemical>, <105, 111, dehydroepiandrosterone sulfate, Chemical></s>
<105, 111, estrone, Chemical>, <105, 111, estradiol, Chemical>, <105, 111, testosterone, Chemical>, <105, 111, androstenedione, Chemical>, <105, 111, dehydroepiandrosterone, Chemical>, <105, 111, dehydroepiandrosterone sulfate, Chemical>
