In [1]:
import numpy as np
import copy
import torch

from code.data_utils.dataset import DatasetLoader
from code.data_utils.utils import load_caption, load_description, save_message
from code.message_template import molhiv_message_template, molbace_message_template, molbbbp_message_template, molesol_message_template, molfreesolv_message_template, mollipo_message_template
from code.config import cfg, update_cfg
from code.utils import set_seed
from code.generate_message import generate_fs_message_cls, generate_fsd_message_cls, generate_fs_message_reg, generate_fsd_message_reg

import warnings
warnings.filterwarnings('ignore')

In [2]:
# Configuration: start from the defaults exported by code.config and seed every RNG
# that set_seed covers, using the configured seed.
# NOTE(review): `cfg = update_cfg(cfg)` is deliberately not called here; presumably it
# applies command-line overrides, which do not exist inside a notebook kernel — confirm.
set_seed(cfg.seed)

# Notebook-specific overrides. Other dataset names accepted by the template-selection
# cell: ogbg-molhiv, ogbg-molbace, ogbg-molbbbp, ogbg-molfreesolv, ogbg-mollipo.
cfg.dataset = "ogbg-molesol"
cfg.demo_test = True

In [3]:
# Load the dataset together with its raw SMILES strings and the per-dataset text
# side information (captions / descriptions) produced elsewhere in the project.
dataloader = DatasetLoader(name=cfg.dataset, text='raw')
dataset = dataloader.dataset
smiles = dataloader.text

caption = load_caption(dataset_name=cfg.dataset)
description = load_description(dataset_name=cfg.dataset)

# Train-split indices whose label equals 1 (index_pos) or 0 (index_neg). Only the
# classification branch of the few-shot cells reads these; for regression datasets
# they are computed but unused.
split_idx = dataset.get_idx_split()
train_idx = split_idx['train']
index_pos = np.intersect1d(train_idx, (dataset.y == 1).nonzero(as_tuple=True)[0])
index_neg = np.intersect1d(train_idx, (dataset.y == 0).nonzero(as_tuple=True)[0])

# Map each supported dataset name to its prompt-template collection. The keys double
# as the list of datasets this notebook can generate messages for.
MESSAGE_TEMPLATES_BY_DATASET = {
    "ogbg-molhiv": molhiv_message_template,
    "ogbg-molbace": molbace_message_template,
    "ogbg-molbbbp": molbbbp_message_template,
    "ogbg-molesol": molesol_message_template,
    "ogbg-molfreesolv": molfreesolv_message_template,
    "ogbg-mollipo": mollipo_message_template,
}

if cfg.dataset not in MESSAGE_TEMPLATES_BY_DATASET:
    # Name the offending value and the valid options so a typo in cfg.dataset is
    # diagnosable from the traceback alone.
    raise ValueError(
        "Invalid Dataset Name to find Message Set: {!r}. Expected one of: {}.".format(
            cfg.dataset, ", ".join(MESSAGE_TEMPLATES_BY_DATASET)
        )
    )
template_set = MESSAGE_TEMPLATES_BY_DATASET[cfg.dataset]

In [4]:
message_type = "IF"
print(f'Generating {message_type}')
template = template_set[message_type]

# One message list per molecule: deep-copy the template so the shared template is
# never mutated, then fill the SMILES into the "content" of the entry at index 1.
list_message = []
for smi in smiles:
    filled = copy.deepcopy(template)
    prompt = filled[1]
    prompt["content"] = prompt["content"].format(smi)
    list_message.append(filled)

save_message(
    dataset_name=cfg.dataset,
    list_message=list_message,
    message_type=message_type,
    demo_test=cfg.demo_test,
)

Generating IF


In [None]:
message_type = "IFD"
print('Generating {}'.format(message_type))
template = template_set[message_type]

# IFD is presumably the description-augmented counterpart of IF, in the same way IFC
# pairs each SMILES with its caption. `description` is loaded in the data cell but
# was never used, and formatting with only the SMILES raises IndexError if the IFD
# template has a second placeholder. str.format ignores surplus positional arguments,
# so also passing the description is harmless for a single-placeholder template.
# NOTE(review): assumes load_description returns one entry per molecule, aligned with
# `smiles` exactly as `caption` is; zip silently truncates on a length mismatch — confirm.
list_message = []
for s, d in zip(smiles, description):
    message = copy.deepcopy(template)  # never mutate the shared template
    message[1]["content"] = message[1]["content"].format(s, d)
    list_message.append(message)

save_message(
    dataset_name=cfg.dataset, list_message=list_message,
    message_type=message_type, demo_test=cfg.demo_test
)

In [5]:
message_type = "IFC"
print(f'Generating {message_type}')
template = template_set[message_type]

# Pair each SMILES with its caption; the template copy at index 1 receives both
# values, in that order, through its "content" format string.
list_message = []
for smi, cap in zip(smiles, caption):
    filled = copy.deepcopy(template)
    prompt = filled[1]
    prompt["content"] = prompt["content"].format(smi, cap)
    list_message.append(filled)

save_message(
    dataset_name=cfg.dataset,
    list_message=list_message,
    message_type=message_type,
    demo_test=cfg.demo_test,
)

Generating IFC


In [6]:
message_type = "IP"
print(f'Generating {message_type}')
template = template_set[message_type]

# Fill a fresh deep copy of the template for every SMILES string; only the
# "content" field of the entry at index 1 is formatted.
list_message = []
for smi in smiles:
    filled = copy.deepcopy(template)
    prompt = filled[1]
    prompt["content"] = prompt["content"].format(smi)
    list_message.append(filled)

save_message(
    dataset_name=cfg.dataset,
    list_message=list_message,
    message_type=message_type,
    demo_test=cfg.demo_test,
)

Generating IP


In [7]:
message_type = "IPC"
print(f'Generating {message_type}')
template = template_set[message_type]

# Caption-augmented variant: SMILES and caption are formatted, in that order, into
# the "content" field at index 1 of a deep copy of the template.
list_message = []
for smi, cap in zip(smiles, caption):
    filled = copy.deepcopy(template)
    prompt = filled[1]
    prompt["content"] = prompt["content"].format(smi, cap)
    list_message.append(filled)

save_message(
    dataset_name=cfg.dataset,
    list_message=list_message,
    message_type=message_type,
    demo_test=cfg.demo_test,
)

Generating IPC


In [8]:
message_type = "IE"
print(f'Generating {message_type}')
template = template_set[message_type]

# Fill a fresh deep copy of the template for every SMILES string; only the
# "content" field of the entry at index 1 is formatted.
list_message = []
for smi in smiles:
    filled = copy.deepcopy(template)
    prompt = filled[1]
    prompt["content"] = prompt["content"].format(smi)
    list_message.append(filled)

save_message(
    dataset_name=cfg.dataset,
    list_message=list_message,
    message_type=message_type,
    demo_test=cfg.demo_test,
)

Generating IE


In [9]:
message_type = "IEC"
print(f'Generating {message_type}')
template = template_set[message_type]

# Caption-augmented variant: SMILES and caption are formatted, in that order, into
# the "content" field at index 1 of a deep copy of the template.
list_message = []
for smi, cap in zip(smiles, caption):
    filled = copy.deepcopy(template)
    prompt = filled[1]
    prompt["content"] = prompt["content"].format(smi, cap)
    list_message.append(filled)

save_message(
    dataset_name=cfg.dataset,
    list_message=list_message,
    message_type=message_type,
    demo_test=cfg.demo_test,
)

Generating IEC


In [10]:
message_type = "FS-1"
print(f'Generating {message_type}')

# Both generators share these arguments. The classification generator additionally
# receives the train-split positive/negative indices; the regression generator
# receives the raw label tensor.
shared_kwargs = dict(message_type=message_type, template_set=template_set, smiles=smiles)
if 'classification' in dataset.task_type:
    list_message = generate_fs_message_cls(**shared_kwargs, index_pos=index_pos, index_neg=index_neg)
else:
    list_message = generate_fs_message_reg(**shared_kwargs, label=dataset.y)

save_message(
    dataset_name=cfg.dataset,
    list_message=list_message,
    message_type=message_type,
    demo_test=cfg.demo_test,
)

Generating FS-1


In [11]:
message_type = "FSC-1"
print('Generating {}'.format(message_type))

# The caption-aware few-shot generators are not in the notebook's first import cell
# (it imports the `fsd` variants), so on a fresh kernel this cell raised NameError.
# NOTE(review): assumes code.generate_message defines these two names, as the calls
# below already expected — confirm, and consider moving the import to the first cell.
from code.generate_message import generate_fsc_message_cls, generate_fsc_message_reg

# Shared arguments. The classification generator additionally receives the
# train-split positive/negative indices; the regression one receives the raw labels.
shared_kwargs = dict(
    message_type=message_type, template_set=template_set,
    smiles=smiles, caption=caption,
)
if 'classification' in dataset.task_type:
    list_message = generate_fsc_message_cls(**shared_kwargs, index_pos=index_pos, index_neg=index_neg)
else:
    list_message = generate_fsc_message_reg(**shared_kwargs, label=dataset.y)

save_message(
    dataset_name=cfg.dataset, list_message=list_message,
    message_type=message_type, demo_test=cfg.demo_test
)

Generating FSC-1
