In [6]:
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import re
from tqdm import tqdm
import pandas as pd


In [7]:
# 基本パラメータ
model_name = "cyberagent/open-calm-7b"
peft_name = "lora-calm-7b"

# モデルの準備
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_8bit=True,
    device_map="auto",
)

# トークンナイザーの準備
tokenizer = AutoTokenizer.from_pretrained(model_name)

# LoRAモデルの準備
model = PeftModel.from_pretrained(
    model, 
    peft_name, 
    device_map="auto"
)

# 評価モード
model.eval()

Loading checkpoint shards: 100%|██████████| 2/2 [00:11<00:00,  5.99s/it]


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): GPTNeoXForCausalLM(
      (gpt_neox): GPTNeoXModel(
        (embed_in): Embedding(52224, 4096)
        (layers): ModuleList(
          (0-31): 32 x GPTNeoXLayer(
            (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
            (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
            (attention): GPTNeoXAttention(
              (rotary_emb): RotaryEmbedding()
              (query_key_value): Linear8bitLt(
                in_features=4096, out_features=12288, bias=True
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=12288, bias=False)
              

In [8]:
# プロンプトテンプレートの準備
def generate_prompt(data_point):
    if data_point["input"]:
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{data_point["instruction"]}

### Input:
{data_point["input"]}

### Response:"""
    else:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{data_point["instruction"]}

### Response:"""

In [9]:
# テキスト生成関数の定義
def generate(instruction,input=None,maxTokens=256):
    # 推論
    prompt = generate_prompt({'instruction':instruction,'input':input})
    input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).input_ids.cuda()
    outputs = model.generate(
        input_ids=input_ids, 
        max_new_tokens=maxTokens, 
        do_sample=True,
        temperature=0.7, 
        top_p=0.75, 
        top_k=40,         
        no_repeat_ngram_size=10,
        pad_token_id=tokenizer.eos_token_id
    )
    outputs = outputs[0].tolist()

    # EOSトークンにヒットしたらデコード完了
    if tokenizer.eos_token_id in outputs:
        eos_index = outputs.index(tokenizer.eos_token_id)
        decoded = tokenizer.decode(outputs[:eos_index])

        # レスポンス内容のみ抽出
        sentinel = "### Response:"
        sentinelLoc = decoded.find(sentinel)
        if sentinelLoc >= 0:
            out = decoded[sentinelLoc+len(sentinel):].strip()
            return out
        else:
            print('Warning: Expected prompt template to be emitted.  Ignoring output.')
    else:
        print('Warning: no <eos> detected ignoring output')

In [10]:
instruction = 'For the compound represented by the SMILES presented, please predict the binding affinity of the E3 ubiquitin-protein ligase CBL-B to the TKB domain.'
# instruction = '提示された化合物のSMILESについて、E3ユビキチンリガーゼ(CBL-B)のTKBドメインに対する結合親和性を予測してください。'
out = generate(instruction=instruction, input='c1ccccc1CCCCCC')
out

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
  attn_scores = torch.where(causal_mask, attn_scores, mask_value)


'very low'

In [11]:
instruction = 'For the compound represented by the SMILES presented, please predict the binding affinity of the E3 ubiquitin-protein ligase CBL-B to the TKB domain.'
# instruction = '提示された化合物のSMILESについて、E3ユビキチンリガーゼ(CBL-B)のTKBドメインに対する結合親和性を予測してください。'
def extract_number_from_string(string):
    match = re.search(r'\d+', string)
    return match.group(0) if match else None

def generate_smidir(smi_dir):
    out_list = ['very high', 'high', 'normal', 'low', 'very low']
    # out_list = ['非常に高い', '高い', '普通', '低い', '非常に低い']
    smi_list, score_lang_list = [], []
    for dir in tqdm(os.listdir(smi_dir)):
        num = extract_number_from_string(dir)
        with open(f'{smi_dir}/ligand{num}.smi', 'r') as f:
            smi = f.readline()
        try:
            out = generate(instruction=instruction, input=smi)
            if out in out_list:
                smi_list.append(smi)
                score_lang_list.append(out)
            else:
                pass
        except:
            continue
    return smi_list, score_lang_list

def generate_from_csv(csv_dir, instruction):
    df = pd.read_csv(csv_dir)
    out_list = ['very high', 'high', 'normal', 'low', 'very low']
    # out_list = ['非常に高い', '高い', '普通', '低い', '非常に低い']
    success_num = 0
    correct = 0
    for smi, score_lang in tqdm(zip(df['input'], df['output'])):
        try:
            out = generate(instruction=instruction, input=smi)
            if out in out_list:
                success_num += 1
                if score_lang == out:
                    correct += 1
        except:
            continue
    accuracy = correct / success_num
    return accuracy

In [12]:
instruction = 'For the compound represented by the SMILES presented, please predict the binding affinity of the E3 ubiquitin-protein ligase CBL-B to the TKB domain.'
val_accuracy = generate_from_csv('./data_valid.csv', instruction=instruction)
val_accuracy

1683it [18:10,  1.67it/s]

In [10]:
instruction = '提示された化合物のSMILESについて、E3ユビキチンリガーゼ(CBL-B)のTKBドメインに対する結合親和性を予測してください。'
val_accuracy = generate_from_csv('./data_valid_ja.csv', instruction=instruction)
val_accuracy
# 日本語でのAccuracyは0.5310020174848689

14870it [1:57:22,  2.11it/s]


0.5310020174848689