In [8]:
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import re
from tqdm import tqdm
import pandas as pd

In [2]:
# Base parameters: base model id and the LoRA adapter to apply on top of it
# (peft_name is presumably a local adapter directory — confirm).
model_name = "cyberagent/open-calm-7b"
peft_name = "lora-calm-7b-ja"

# Base causal LM, loaded with 8-bit weights; device_map="auto" lets the
# weights be placed on the available devices automatically.
model = AutoModelForCausalLM.from_pretrained(
    model_name, load_in_8bit=True, device_map="auto"
)

# Tokenizer matching the base model.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Wrap the base model with the LoRA adapter; from here on `model` is the PEFT wrapper.
model = PeftModel.from_pretrained(model, peft_name, device_map="auto")

# Evaluation mode disables dropout (including the LoRA dropout visible in the
# printed module tree). Left as the cell's last expression so the structure is displayed.
model.eval()

Loading checkpoint shards: 100%|██████████| 2/2 [00:13<00:00,  6.82s/it]


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): GPTNeoXForCausalLM(
      (gpt_neox): GPTNeoXModel(
        (embed_in): Embedding(52224, 4096)
        (layers): ModuleList(
          (0-31): 32 x GPTNeoXLayer(
            (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
            (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
            (attention): GPTNeoXAttention(
              (rotary_emb): RotaryEmbedding()
              (query_key_value): Linear8bitLt(
                in_features=4096, out_features=12288, bias=True
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=12288, bias=False)
              

In [3]:
# プロンプトテンプレートの準備
def generate_prompt(data_point):
    if data_point["input"]:
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{data_point["instruction"]}

### Input:
{data_point["input"]}

### Response:"""
    else:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{data_point["instruction"]}

### Response:"""

In [4]:
# Text generation helper: runs the module-level LoRA model on one prompt.
def generate(instruction, input=None, maxTokens=256):
    """Sample a response for ``instruction`` (and optional ``input``).

    The prompt is built with ``generate_prompt`` and fed to the module-level
    ``model`` / ``tokenizer``. Only the tokens produced *after* the prompt are
    decoded. The original approach searched the full decoded text for the first
    ``### Response:``, which picked the wrong spot whenever the input itself
    contained that marker.

    Args:
        instruction: Text placed under ``### Instruction:``.
        input: Optional context placed under ``### Input:``; when falsy the
            instruction-only template is used.
        maxTokens: Maximum number of new tokens to sample.

    Returns:
        The stripped response text, or ``None`` (after printing a warning)
        when generation stopped without emitting an EOS token.
    """
    prompt = generate_prompt({'instruction': instruction, 'input': input})
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True)

    # Move inputs to the device holding the model's parameters instead of
    # assuming the default CUDA device (device_map="auto" may place it elsewhere).
    device = model.device
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.get("attention_mask")
    if attention_mask is not None:
        # pad_token_id == eos_token_id below, so generate() cannot infer the
        # mask itself; pass it explicitly.
        attention_mask = attention_mask.to(device)

    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=maxTokens,
        do_sample=True,
        temperature=0.7,
        top_p=0.75,
        top_k=40,
        no_repeat_ngram_size=10,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Keep only newly generated tokens so neither the EOS search nor the
    # decoded answer can be affected by the prompt contents.
    generated = outputs[0][input_ids.shape[-1]:].tolist()
    eos_id = tokenizer.eos_token_id
    if eos_id not in generated:
        print('Warning: no <eos> detected ignoring output')
        return None

    # Decoding is complete once the EOS token is hit.
    return tokenizer.decode(generated[:generated.index(eos_id)]).strip()

In [6]:
# Task instruction (Japanese) asking for the binding affinity of the given
# SMILES to the TKB domain of E3 ubiquitin ligase CBL-B. generate_smidir reads
# this module-level name, so it must be defined before that function is called.
instruction = (
    '提示された化合物のSMILESについて、'
    'E3ユビキチンリガーゼ(CBL-B)のTKBドメインに対する結合親和性を予測してください。'
)
out = generate(instruction, 'c1ccccc1CCCCCC')
out

'非常に低い'

In [7]:
def extract_number_from_string(string):
    """Return the first run of digits in ``string`` as text, or ``None`` if absent."""
    found = re.search(r'\d+', string)
    if found is None:
        return None
    return found.group(0)

def generate_smidir(smi_dir):
    """Predict an affinity label for every ligand SMILES file in ``smi_dir``.

    For each directory entry, the first run of digits in its name is taken as
    the ligand number, and the first line of ``<smi_dir>/ligand<num>.smi`` is
    read (whitespace-stripped). That line is sent to ``generate`` with the
    module-level ``instruction``. Only answers that are one of the five
    expected labels are kept.

    Args:
        smi_dir: Directory containing ``ligand<N>.smi`` files.

    Returns:
        ``(smi_list, score_lang_list)``: parallel lists of SMILES strings and
        their predicted labels.
    """
    valid_labels = {'非常に高い', '高い', '普通', '低い', '非常に低い'}
    smi_list, score_lang_list = [], []
    # Sorted so results come out in a reproducible order (os.listdir order is arbitrary).
    for entry in tqdm(sorted(os.listdir(smi_dir))):
        num = extract_number_from_string(entry)
        if num is None:
            # No ligand number in the name; previously this opened "ligandNone.smi"
            # outside the try block and aborted the whole batch.
            continue
        path = os.path.join(smi_dir, f'ligand{num}.smi')
        try:
            with open(path, 'r') as f:
                # strip() removes the trailing newline so it does not leak into
                # the prompt or into the saved SMILES column.
                smi = f.readline().strip()
        except OSError as exc:
            print(f'Warning: could not read {path}: {exc}')
            continue
        if not smi:
            # An empty input would silently switch to the no-input prompt template.
            continue
        try:
            out = generate(instruction=instruction, input=smi)
        except Exception as exc:
            # Best-effort over a long batch: log and continue, but let
            # KeyboardInterrupt/SystemExit propagate so the run can be stopped.
            print(f'Warning: generation failed for {path}: {exc!r}')
            continue
        if out in valid_labels:
            smi_list.append(smi)
            score_lang_list.append(out)
    return smi_list, score_lang_list


In [8]:
# Score every ligand under ./enamine0 (long-running: ~11 h for 50k files in the recorded run).
smi_list, score_lang_list = generate_smidir('./enamine0')

# Pair each SMILES with its predicted affinity label as two parallel columns.
data = dict(SMILES=smi_list, affinity=score_lang_list)
df = pd.DataFrame(data)

# Persist without writing the integer index as an extra column.
df.to_csv('./csv_out0_ja.csv', index=False)

100%|██████████| 50000/50000 [10:53:59<00:00,  1.27it/s] 
