In [1]:
import numpy as np
import pandas as pd
import torch
from pathlib import Path, PurePath

from code.data_utils import DatasetLoader
from code.prompt_template import molhiv_prompt_template, molbace_prompt_template
from code.config import cfg, update_cfg
from code.utils import set_seed
from code.generate_caption import load_caption
from code.generate_prompt import save_prompt, generate_fs_prompt, generate_fsc_prompt

import warnings
# Install a global "ignore" filter: every warning category is silenced from here on,
# including deprecation warnings raised by the libraries imported above.
warnings.filterwarnings('ignore')

In [2]:
# Load cfg.
# The update_cfg call below is disabled, so `cfg` is used exactly as imported
# from code.config.
# NOTE(review): presumably update_cfg parses command-line overrides, which would
# not be meaningful inside a notebook kernel -- confirm before re-enabling.
# cfg = update_cfg(cfg)
set_seed(cfg.seed)  # project helper; called with the configured seed before any sampling cell runs

# Manual cfg settings: select the dataset that every later cell operates on.
cfg.dataset = "ogbg-molbace" # alternative value: "ogbg-molhiv"

In [3]:
# Preprocess data: load the dataset with its raw text and the matching captions.
dataloader = DatasetLoader(name=cfg.dataset, text='raw')
dataset = dataloader.dataset
smiles = dataloader.text

caption = load_caption(dataset_name=cfg.dataset)

# Split the training indices by label (y == 1 -> index_pos, y == 0 -> index_neg).
# Later cells sample their few-shot examples from these two pools.
split_idx = dataset.get_idx_split()
train_idx = split_idx['train']
index_pos = np.intersect1d(train_idx, torch.where(dataset.y == 1)[0])
index_neg = np.intersect1d(train_idx, torch.where(dataset.y == 0)[0])

# Look up the prompt template set that belongs to the configured dataset.
_template_sets = {
    "ogbg-molhiv": molhiv_prompt_template,
    "ogbg-molbace": molbace_prompt_template,
}
if cfg.dataset not in _template_sets:
    raise ValueError("Invalid Dataset Name to find Prompt Set.")
template_set = _template_sets[cfg.dataset]

In [4]:
# "IF" prompts: fill the "IF" template once per entry of `smiles`.
prompt_type = "IF"
template = template_set[prompt_type]
list_prompt = list(map(template.format, smiles))

# Hand the finished prompts to the project helper together with their type tag.
save_prompt(
    dataset_name=cfg.dataset,
    list_prompt=list_prompt,
    prompt_type=prompt_type,
)

In [9]:
# "IFC" prompts: fill the "IFC" template with each (smiles entry, caption) pair.
prompt_type = "IFC"
template = template_set[prompt_type]

# Build a fresh list here. Appending to whatever `list_prompt` an earlier cell left
# in the kernel would put that cell's prompts at the front of the saved output,
# so entries would no longer line up one-to-one with `smiles`, and every re-run
# of this cell would grow the list further.
# zip() stops at the shorter input -- assumes `caption` is aligned with `smiles`;
# TODO confirm against load_caption.
list_prompt = [template.format(s, c) for s, c in zip(smiles, caption)]

save_prompt(dataset_name=cfg.dataset, list_prompt=list_prompt, prompt_type=prompt_type)

In [8]:
# "IP" prompts: one prompt per entry of `smiles`, rendered with the "IP" template.
prompt_type = "IP"
template = template_set[prompt_type]

list_prompt = []
for smiles_entry in smiles:
    list_prompt.append(template.format(smiles_entry))

# Pass the prompts and their type tag on to the project helper.
save_prompt(
    dataset_name=cfg.dataset,
    list_prompt=list_prompt,
    prompt_type=prompt_type,
)

In [9]:
# "IPC" prompts: fill the "IPC" template with each (smiles entry, caption) pair.
prompt_type = "IPC"
template = template_set[prompt_type]

# Start from an empty result instead of appending to the `list_prompt` left over
# from a previously executed cell. Otherwise the saved output would begin with that
# cell's prompts, would be longer than the dataset, and would keep growing each
# time this cell is re-run.
# zip() stops at the shorter input -- assumes `caption` is aligned with `smiles`;
# TODO confirm against load_caption.
list_prompt = [template.format(s, c) for s, c in zip(smiles, caption)]

save_prompt(dataset_name=cfg.dataset, list_prompt=list_prompt, prompt_type=prompt_type)

In [10]:
# "IE" prompts: render the "IE" template for every entry of `smiles`.
prompt_type = "IE"
template = template_set[prompt_type]
render_prompt = template.format
list_prompt = [render_prompt(entry) for entry in smiles]

# Forward the prompts, tagged with their type, to the project helper.
save_prompt(
    dataset_name=cfg.dataset,
    list_prompt=list_prompt,
    prompt_type=prompt_type,
)

In [11]:
# "IEC" prompts: fill the "IEC" template with each (smiles entry, caption) pair.
prompt_type = "IEC"
template = template_set[prompt_type]

# Rebuild `list_prompt` from scratch. The list inherited from an earlier cell must
# not be extended, or the saved output would contain that cell's prompts followed
# by these ones (and more copies after each re-run), breaking the one-to-one
# correspondence with `smiles`.
# zip() stops at the shorter input -- assumes `caption` is aligned with `smiles`;
# TODO confirm against load_caption.
list_prompt = [template.format(s, c) for s, c in zip(smiles, caption)]

save_prompt(dataset_name=cfg.dataset, list_prompt=list_prompt, prompt_type=prompt_type)

In [30]:
# Few-shot ("FS-<k>") prompts: each query is preceded by k positive and k negative
# demonstrations drawn from the training-split pools `index_pos` / `index_neg`.
prompt_type = "FS-1"
fs = int(prompt_type.split('-')[1])  # number of demonstrations per class
template = template_set[prompt_type.split('-')[0]]
knowledge_pos_template = template_set["FS_knowledge_pos"]
knowledge_neg_template = template_set["FS_knowledge_neg"]

# Each pool must hold more than `fs` indices so that `fs` distinct demonstrations
# remain even when the query itself is a member of the pool; otherwise the re-draw
# loops below could never terminate.
if fs > 0 and min(len(index_pos), len(index_neg)) <= fs:
    raise ValueError(
        "Need more than {} training examples per class to build {} prompts.".format(
            fs, prompt_type
        )
    )

list_prompt = []
for idx, s in enumerate(smiles):
    _pos = np.random.choice(index_pos, fs)
    _neg = np.random.choice(index_neg, fs)
    # Re-draw until the sample neither contains the query molecule itself nor
    # repeats a demonstration (np.random.choice samples with replacement).
    # For fs == 1 the distinctness test is always false, so the sequence of random
    # draws is exactly what it was without the test.
    while idx in _pos or len(np.unique(_pos)) < fs:
        _pos = np.random.choice(index_pos, fs)
    while idx in _neg or len(np.unique(_neg)) < fs:
        _neg = np.random.choice(index_neg, fs)

    # Interleave positive/negative demonstrations, one per line, with no trailing newline.
    shots = []
    for pos_id, neg_id in zip(_pos, _neg):
        shots.append(knowledge_pos_template.format(smiles[pos_id]))
        shots.append(knowledge_neg_template.format(smiles[neg_id]))
    knowledge = "\n".join(shots)

    list_prompt.append(template.format(knowledge, s))

save_prompt(dataset_name=cfg.dataset, list_prompt=list_prompt, prompt_type=prompt_type)

In [26]:
# Few-shot-with-caption ("FSC-<k>") prompts: each query (smiles entry + caption) is
# preceded by k positive and k negative demonstrations, each shown with its caption,
# drawn from the training-split pools `index_pos` / `index_neg`.
prompt_type = "FSC-1"
fs = int(prompt_type.split('-')[1])  # number of demonstrations per class
template = template_set[prompt_type.split('-')[0]]
knowledge_pos_template = template_set["FSC_knowledge_pos"]
knowledge_neg_template = template_set["FSC_knowledge_neg"]

# Each pool must hold more than `fs` indices so that `fs` distinct demonstrations
# remain even when the query itself is a member of the pool; otherwise the re-draw
# loops below could never terminate.
if fs > 0 and min(len(index_pos), len(index_neg)) <= fs:
    raise ValueError(
        "Need more than {} training examples per class to build {} prompts.".format(
            fs, prompt_type
        )
    )

list_prompt = []
# zip() stops at the shorter input -- assumes `caption` is aligned with `smiles`;
# TODO confirm against load_caption.
for idx, (s, c) in enumerate(zip(smiles, caption)):
    _pos = np.random.choice(index_pos, fs)
    _neg = np.random.choice(index_neg, fs)
    # Re-draw until the sample neither contains the query molecule itself nor
    # repeats a demonstration (np.random.choice samples with replacement).
    # For fs == 1 the distinctness test is always false, so the sequence of random
    # draws is exactly what it was without the test.
    while idx in _pos or len(np.unique(_pos)) < fs:
        _pos = np.random.choice(index_pos, fs)
    while idx in _neg or len(np.unique(_neg)) < fs:
        _neg = np.random.choice(index_neg, fs)

    # Interleave positive/negative demonstrations, one per line, with no trailing newline.
    shots = []
    for pos_id, neg_id in zip(_pos, _neg):
        shots.append(knowledge_pos_template.format(smiles[pos_id], caption[pos_id]))
        shots.append(knowledge_neg_template.format(smiles[neg_id], caption[neg_id]))
    knowledge = "\n".join(shots)

    list_prompt.append(template.format(knowledge, s, c))

save_prompt(dataset_name=cfg.dataset, list_prompt=list_prompt, prompt_type=prompt_type)