In [1]:
import pickle as pkl
import os 
import sys
import numpy as np
from xopen import xopen
import json
from tqdm import tqdm
import torch
from transformers import BertTokenizer, BertModel
import pandas as pd
from torch_geometric.data import Data

from softprompt.utility.prompting import (
    Item,
    get_prompt_tuning_prompt
)



  from .autonotebook import tqdm as notebook_tqdm


In [28]:
# Per-dataset prompt components. Each section below looks up its entry
# (e.g. PROMPT_SETTINGS['arxiv']) and passes the three fields to
# Item(desc=..., categories=..., question=...) to render the instruction prompt.
#   desc       -- opening question line of the prompt.
#   categories -- label names; the rendered prompt lists them as "Category [1](...)", ...
#   question   -- task instruction appended after the category list.
# NOTE(review): these strings are copied verbatim into the generated training data,
# so wording/grammar is intentionally left as-is; editing them changes the dataset.
PROMPT_SETTINGS = {
    'arxiv':{
        'desc': "Question: Which category from the list that the paper most likely belong to?",
        'categories': ['Artificial Intelligence', 'Computation and Language',
                        'Computational Complexity',
                        'Computational Engineering, Finance, and Science',
                        'Computational Geometry', 'Computer Science and Game Theory',
                        'Computer Vision and Pattern Recognition', 'Computers and Society',
                        'Cryptography and Security', 'Data Structures and Algorithms',
                        'Databases', 'Digital Libraries', 'Discrete Mathematics',
                        'Distributed, Parallel, and Cluster Computing',
                        'Emerging Technologies', 'Formal Languages and Automata Theory',
                        'General Literature', 'Graphics', 'Hardware Architecture',
                        'Human-Computer Interaction', 'Information Retrieval',
                        'Information Theory', 'Logic in Computer Science',
                        'Machine Learning', 'Mathematical Software', 'Multiagent Systems',
                        'Multimedia', 'Networking and Internet Architecture',
                        'Neural and Evolutionary Computing', 'Numerical Analysis',
                        'Operating Systems', 'Other Computer Science', 'Performance',
                        'Programming Languages', 'Robotics',
                        'Social and Information Networks', 'Software Engineering', 'Sound',
                        'Symbolic Computation', 'Systems and Control'],
        'question': "Given the title and abstract of a research paper, identify one category from a distinct list of research topics that you predict the paper will most likely belong to."
    },
    'cora':{
        'desc': "Question: Which category from the list that the paper most likely belong to?",
        'categories': ['Case-Based',
                        'Genetic_Algorithms',
                        'Neural_Networks',
                        'Probabilistic_Methods',
                        'Reinforcement_Learning',
                        'Rule_Learning',
                        'Theory'],
        'question': "Given the title and abstract of a research paper, identify one category from a distinct list of research topics that you predict the paper will most likely belong to."
    },
    # NOTE(review): hard-coded label list; the Cora Full section also derives the distinct
    # labels from the graph (`cora_categories`) but never compares them with this list —
    # confirm the two stay in sync.
    'cora_full':{
        'desc': "Question: Which category from the list that the paper most likely belong to?",
        'categories': ['Artificial_Intelligence/Agents/',
                        'Artificial_Intelligence/Data_Mining/',
                        'Artificial_Intelligence/Expert_Systems/',
                        'Artificial_Intelligence/Games_and_Search/',
                        'Artificial_Intelligence/Knowledge_Representation/',
                        'Artificial_Intelligence/Machine_Learning/Case-Based/',
                        'Artificial_Intelligence/Machine_Learning/Genetic_Algorithms/',
                        'Artificial_Intelligence/Machine_Learning/Neural_Networks/',
                        'Artificial_Intelligence/Machine_Learning/Probabilistic_Methods/',
                        'Artificial_Intelligence/Machine_Learning/Reinforcement_Learning/',
                        'Artificial_Intelligence/Machine_Learning/Rule_Learning/',
                        'Artificial_Intelligence/Machine_Learning/Theory/',
                        'Artificial_Intelligence/NLP/',
                        'Artificial_Intelligence/Planning/',
                        'Artificial_Intelligence/Robotics/',
                        'Artificial_Intelligence/Speech/',
                        'Artificial_Intelligence/Theorem_Proving/',
                        'Artificial_Intelligence/Vision_and_Pattern_Recognition/',
                        'Data_Structures__Algorithms_and_Theory/Computational_Complexity/',
                        'Data_Structures__Algorithms_and_Theory/Computational_Geometry/',
                        'Data_Structures__Algorithms_and_Theory/Formal_Languages/',
                        'Data_Structures__Algorithms_and_Theory/Hashing/',
                        'Data_Structures__Algorithms_and_Theory/Logic/',
                        'Data_Structures__Algorithms_and_Theory/Parallel/',
                        'Data_Structures__Algorithms_and_Theory/Quantum_Computing/',
                        'Data_Structures__Algorithms_and_Theory/Randomized/',
                        'Data_Structures__Algorithms_and_Theory/Sorting/',
                        'Databases/Concurrency/',
                        'Databases/Deductive/',
                        'Databases/Object_Oriented/',
                        'Databases/Performance/',
                        'Databases/Query_Evaluation/',
                        'Databases/Relational/',
                        'Databases/Temporal/',
                        'Encryption_and_Compression/Compression/',
                        'Encryption_and_Compression/Encryption/',
                        'Encryption_and_Compression/Security/',
                        'Hardware_and_Architecture/Distributed_Architectures/',
                        'Hardware_and_Architecture/High_Performance_Computing/',
                        'Hardware_and_Architecture/Input_Output_and_Storage/',
                        'Hardware_and_Architecture/Logic_Design/',
                        'Hardware_and_Architecture/Memory_Structures/',
                        'Hardware_and_Architecture/Microprogramming/',
                        'Hardware_and_Architecture/VLSI/',
                        'Human_Computer_Interaction/Cooperative/',
                        'Human_Computer_Interaction/Graphics_and_Virtual_Reality/',
                        'Human_Computer_Interaction/Interface_Design/',
                        'Human_Computer_Interaction/Multimedia/',
                        'Human_Computer_Interaction/Wearable_Computers/',
                        'Information_Retrieval/Digital_Library/',
                        'Information_Retrieval/Extraction/',
                        'Information_Retrieval/Filtering/',
                        'Information_Retrieval/Retrieval/',
                        'Networking/Internet/',
                        'Networking/Protocols/',
                        'Networking/Routing/',
                        'Networking/Wireless/',
                        'Operating_Systems/Distributed/',
                        'Operating_Systems/Fault_Tolerance/',
                        'Operating_Systems/Memory_Management/',
                        'Operating_Systems/Realtime/',
                        'Programming/Compiler_Design/',
                        'Programming/Debugging/',
                        'Programming/Functional/',
                        'Programming/Garbage_Collection/',
                        'Programming/Java/',
                        'Programming/Logic/',
                        'Programming/Object_Oriented/',
                        'Programming/Semantics/',
                        'Programming/Software_Development/'],
        'question': "Given the title and abstract of a research paper, identify one category from a distinct list of research topics that you predict the paper will most likely belong to."
    },
    'pubmed':{
        'desc': "Question: Which category from the list that the paper most likely belong to?",
        'categories': ['Diabetes Mellitus Type 1', 'Diabetes Mellitus Type 2','Diabetes Mellitus, Experimental'],
        'question': "Given the keywords of a research paper, identify one category from a distinct list of research topics that you predict the paper will most likely belong to."
    },
    # Graph-level task: the AIDS section labels whole compounds (text_graph_label)
    # rather than individual nodes.
    'aids':{
        'desc': "Question: Which category from the list that the input molecule most likely belong to?",
        'categories': ['HIV antiviral active compound', 'HIV antiviral inactive compound'],
        'question': "Given the atoms type and their connection structure of a compound, identify if the given compound is HIV antiviral active or not."
    },
}

In [3]:
# Root directory containing the per-dataset graph folders (node_pubmed, node_ogb_arxiv,
# node_cora_full, node_cora). Override it with the DATA_HOME_PATH environment variable
# on machines where the data lives elsewhere; the default keeps the original location.
DATA_HOME_PATH = os.environ.get("DATA_HOME_PATH", "/home/ubuntu/proj/data/graph")

# Pubmed

In [4]:
DATA_PATH = f"{DATA_HOME_PATH}/node_pubmed"
DATA_NAME = "text_graph_pubmed"  # other graph pickles: "text_graph_aids", "text_graph_cora"
TRAIN_SPLIT_NAME = 'train_index'
TEST_SPLIT_NAME = 'test_index'

# Read the graph, then the train and test index splits, each from <DATA_PATH>/<name>.pkl.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
_loaded_pickles = []
for pickle_stem in (DATA_NAME, TRAIN_SPLIT_NAME, TEST_SPLIT_NAME):
    with open(os.path.join(DATA_PATH, f"{pickle_stem}.pkl"), 'rb') as pickle_file:
        _loaded_pickles.append(pkl.load(pickle_file))
graph, train_split, test_split = _loaded_pickles
del _loaded_pickles

In [5]:
# Build the Pubmed instruction prompt from the shared PROMPT_SETTINGS entry, matching the
# other dataset sections, so the prompt text has a single source of truth instead of a
# second hand-copied set of strings that can drift.
task_name = 'prompt_tuning'
pubmed_settings = PROMPT_SETTINGS['pubmed']
pubmed_item = Item(
    desc=pubmed_settings['desc'],
    categories=pubmed_settings['categories'],
    question=pubmed_settings['question'],
)
hard_prompt = get_prompt_tuning_prompt(
    task_name=task_name,
    task_item=pubmed_item,
)

In [6]:
# Show the rendered Pubmed prompt; print() keeps its embedded newlines readable.
print(hard_prompt)

### USER: Question: Which category from the list that the paper most likely belong to? 

Belows are 3 potential categories to consider:
Category [1](Diabetes Mellitus Type 1) 
Category [2](Diabetes Mellitus Type 2) 
Category [3](Diabetes Mellitus, Experimental) 

Given the keywords of a research paper, identify one category from a distinct list of research topics that you predict the paper will most likely belong to.
### ASSISTANT:


In [24]:
# Gather each split's node embeddings and label strings, then pair every node with the
# shared instruction prompt and its target answer. Records follow the split-index order,
# so record k lines up with row k of the matching *_pos_tokens tensor.
node_labels = np.array(graph.text_node_labels)
train_pos_tokens = graph.x[torch.tensor(train_split)]
test_pos_tokens = graph.x[torch.tensor(test_split)]
train_y_labels = node_labels[np.array(train_split)].tolist()
test_y_labels = node_labels[np.array(test_split)].tolist()

train_samples = [
    {'instruction': hard_prompt, 'output': f"This paper most likely belong to {y}"}
    for y in train_y_labels
]
test_samples = [
    {'instruction': hard_prompt, 'output': f"This paper most likely belong to {y}"}
    for y in test_y_labels
]


In [27]:
# Add a singleton middle axis: (num_nodes, 768) -> (num_nodes, 1, 768).
# Assumes graph.x rows are 768-dim — TODO confirm for this dataset.
train_pos_tokens = train_pos_tokens.view(-1, 1, 768)
test_pos_tokens = test_pos_tokens.view(-1, 1, 768)
# The .pt tensors and .jsonl records are written to separate files and are presumably
# joined by position downstream, so their lengths must agree.
assert len(train_pos_tokens) == len(train_samples), "train tokens/samples misaligned"
assert len(test_pos_tokens) == len(test_samples), "test tokens/samples misaligned"
len(train_samples), len(test_samples)

(18717, 1000)

In [28]:
def data_generator(pos_token_tensor, all_input_examples):
    """Yield one record per (token row, example) pair.

    Args:
        pos_token_tensor: iterable of tensors; each element is converted with ``.numpy()``.
        all_input_examples: iterable of dicts with ``'instruction'`` and ``'output'`` keys.

    Yields:
        dict with keys ``'prompt_tokens'`` (numpy array), ``'instruction'`` and ``'output'``.
        Iteration stops at the shorter of the two inputs (``zip`` semantics).

    NOTE(review): not called in the visible notebook; the numpy value is not directly
    ``json.dumps``-serializable if these records are ever written as JSON.
    """
    paired = zip(pos_token_tensor, all_input_examples)
    for token_row, ex in paired:
        record = {'prompt_tokens': token_row.numpy()}
        record['instruction'] = ex['instruction']
        record['output'] = ex['output']
        yield record

In [30]:
# Persist the Pubmed splits: one JSON record per line, then the embedding tensors.
for split_tag, split_records in (('train', train_samples), ('test', test_samples)):
    output_path = f'../data/pubmed/{split_tag}.jsonl'
    with xopen(output_path, "w") as f:
        for record in split_records:
            f.write(json.dumps(record) + "\n")
torch.save(train_pos_tokens, '../data/pubmed/train.pt')
torch.save(test_pos_tokens, '../data/pubmed/test.pt')


# Arxiv

In [7]:
DATA_PATH = f"{DATA_HOME_PATH}/node_ogb_arxiv"
DATA_NAME = "text_graph_arxiv"  # other graph pickles: "text_graph_aids", "text_graph_pubmed", "text_graph_cora"
TRAIN_SPLIT_NAME = 'train_index'
VALID_SPLIT_NAME = 'valid_index'
TEST_SPLIT_NAME = 'test_index'

# Read the graph and its train/valid/test index splits from <DATA_PATH>/<name>.pkl, in that order.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
_loaded_pickles = []
for pickle_stem in (DATA_NAME, TRAIN_SPLIT_NAME, VALID_SPLIT_NAME, TEST_SPLIT_NAME):
    with open(os.path.join(DATA_PATH, f"{pickle_stem}.pkl"), 'rb') as pickle_file:
        _loaded_pickles.append(pkl.load(pickle_file))
graph, train_split, valid_split, test_split = _loaded_pickles
del _loaded_pickles

In [8]:
# Assemble the arxiv instruction prompt from its PROMPT_SETTINGS entry.
task_name = 'prompt_tuning'
PROMPT_SETTINGS_DICT = PROMPT_SETTINGS['arxiv']
desc = PROMPT_SETTINGS_DICT['desc']
categories = PROMPT_SETTINGS_DICT['categories']
question = PROMPT_SETTINGS_DICT['question']
input_item = Item(desc=desc, categories=categories, question=question)
hard_prompt = get_prompt_tuning_prompt(task_name=task_name, task_item=input_item)

In [9]:
# Show the rendered arxiv prompt (40 categories); print() keeps newlines readable.
print(hard_prompt)

### USER: Question: Which category from the list that the paper most likely belong to? 

Belows are 40 potential categories to consider:
Category [1](Artificial Intelligence) 
Category [2](Computation and Language) 
Category [3](Computational Complexity) 
Category [4](Computational Engineering, Finance, and Science) 
Category [5](Computational Geometry) 
Category [6](Computer Science and Game Theory) 
Category [7](Computer Vision and Pattern Recognition) 
Category [8](Computers and Society) 
Category [9](Cryptography and Security) 
Category [10](Data Structures and Algorithms) 
Category [11](Databases) 
Category [12](Digital Libraries) 
Category [13](Discrete Mathematics) 
Category [14](Distributed, Parallel, and Cluster Computing) 
Category [15](Emerging Technologies) 
Category [16](Formal Languages and Automata Theory) 
Category [17](General Literature) 
Category [18](Graphics) 
Category [19](Hardware Architecture) 
Category [20](Human-Computer Interaction) 
Category [21](Information

In [10]:
# Gather each split's node embeddings and label strings, then pair every node with the
# shared instruction prompt and its target answer. Records follow the split-index order,
# so record k lines up with row k of the matching *_pos_tokens tensor.
node_labels = np.array(graph.text_node_labels)
train_pos_tokens = graph.x[torch.tensor(train_split)]
valid_pos_tokens = graph.x[torch.tensor(valid_split)]
test_pos_tokens = graph.x[torch.tensor(test_split)]
train_y_labels = node_labels[np.array(train_split)].tolist()
valid_y_labels = node_labels[np.array(valid_split)].tolist()
test_y_labels = node_labels[np.array(test_split)].tolist()

train_samples = [
    {'instruction': hard_prompt, 'output': f"This paper most likely belong to {y}"}
    for y in train_y_labels
]
valid_samples = [
    {'instruction': hard_prompt, 'output': f"This paper most likely belong to {y}"}
    for y in valid_y_labels
]
test_samples = [
    {'instruction': hard_prompt, 'output': f"This paper most likely belong to {y}"}
    for y in test_y_labels
]


In [11]:
# Add a singleton middle axis: (num_nodes, 768) -> (num_nodes, 1, 768).
# Assumes graph.x rows are 768-dim — TODO confirm for this dataset.
train_pos_tokens = train_pos_tokens.view(-1, 1, 768)
valid_pos_tokens = valid_pos_tokens.view(-1, 1, 768)
test_pos_tokens = test_pos_tokens.view(-1, 1, 768)
# Tensors and .jsonl records are saved separately and presumably joined by position
# downstream, so each split's lengths must agree.
assert len(train_pos_tokens) == len(train_samples), "train tokens/samples misaligned"
assert len(valid_pos_tokens) == len(valid_samples), "valid tokens/samples misaligned"
assert len(test_pos_tokens) == len(test_samples), "test tokens/samples misaligned"
len(train_samples), len(valid_samples), len(test_samples)

(90941, 29799, 48603)

In [12]:
# Persist each arxiv split as JSON Lines, then save the matching embedding tensors.
split_records_by_tag = (('train', train_samples), ('valid', valid_samples), ('test', test_samples))
for split_tag, split_records in split_records_by_tag:
    output_path = f'../data/arxiv/{split_tag}.jsonl'
    with xopen(output_path, "w") as f:
        for record in split_records:
            f.write(json.dumps(record) + "\n")
split_tokens_by_tag = (('train', train_pos_tokens), ('valid', valid_pos_tokens), ('test', test_pos_tokens))
for split_tag, split_tokens in split_tokens_by_tag:
    torch.save(split_tokens, f'../data/arxiv/{split_tag}.pt')


# Cora Full

In [54]:
DATA_PATH = f"{DATA_HOME_PATH}/node_cora_full"
DATA_NAME = "text_graph_cora_full"  # other graph pickles: "text_graph_aids", "text_graph_pubmed", "text_graph_cora"
TRAIN_SPLIT_NAME = 'train_index'
VALID_SPLIT_NAME = 'valid_index'
TEST_SPLIT_NAME = 'test_index'

# Read the graph and its train/valid/test index splits from <DATA_PATH>/<name>.pkl, in that order.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
_loaded_pickles = []
for pickle_stem in (DATA_NAME, TRAIN_SPLIT_NAME, VALID_SPLIT_NAME, TEST_SPLIT_NAME):
    with open(os.path.join(DATA_PATH, f"{pickle_stem}.pkl"), 'rb') as pickle_file:
        _loaded_pickles.append(pkl.load(pickle_file))
graph, train_split, valid_split, test_split = _loaded_pickles
del _loaded_pickles

In [55]:
# Distinct label strings in the data (np.unique sorts them), with surrounding whitespace removed.
# NOTE(review): not used below — the prompt uses PROMPT_SETTINGS['cora_full']['categories'].
cora_categories = list(map(str.strip, np.unique(graph.text_node_labels)))

In [56]:
# Assemble the Cora Full instruction prompt from its PROMPT_SETTINGS entry.
task_name = 'prompt_tuning'
PROMPT_SETTINGS_DICT = PROMPT_SETTINGS['cora_full']
desc = PROMPT_SETTINGS_DICT['desc']
categories = PROMPT_SETTINGS_DICT['categories']
question = PROMPT_SETTINGS_DICT['question']
input_item = Item(desc=desc, categories=categories, question=question)
hard_prompt = get_prompt_tuning_prompt(task_name=task_name, task_item=input_item)

In [57]:
# Show the rendered Cora Full prompt (70 categories); print() keeps newlines readable.
print(hard_prompt)

### USER: Question: Which category from the list that the paper most likely belong to? 

Belows are 70 potential categories to consider:
Category [1](Artificial_Intelligence/Agents/) 
Category [2](Artificial_Intelligence/Data_Mining/) 
Category [3](Artificial_Intelligence/Expert_Systems/) 
Category [4](Artificial_Intelligence/Games_and_Search/) 
Category [5](Artificial_Intelligence/Knowledge_Representation/) 
Category [6](Artificial_Intelligence/Machine_Learning/Case-Based/) 
Category [7](Artificial_Intelligence/Machine_Learning/Genetic_Algorithms/) 
Category [8](Artificial_Intelligence/Machine_Learning/Neural_Networks/) 
Category [9](Artificial_Intelligence/Machine_Learning/Probabilistic_Methods/) 
Category [10](Artificial_Intelligence/Machine_Learning/Reinforcement_Learning/) 
Category [11](Artificial_Intelligence/Machine_Learning/Rule_Learning/) 
Category [12](Artificial_Intelligence/Machine_Learning/Theory/) 
Category [13](Artificial_Intelligence/NLP/) 
Category [14](Artificial_Int

In [58]:
# Character length of the Cora Full prompt — large because all 70 categories are listed.
len(hard_prompt)

4161

In [59]:
# Gather each split's node embeddings and label strings, then pair every node with the
# shared instruction prompt and its target answer. Records follow the split-index order,
# so record k lines up with row k of the matching *_pos_tokens tensor.
node_labels = np.array(graph.text_node_labels)
train_pos_tokens = graph.x[torch.tensor(train_split)]
valid_pos_tokens = graph.x[torch.tensor(valid_split)]
test_pos_tokens = graph.x[torch.tensor(test_split)]
train_y_labels = node_labels[np.array(train_split)].tolist()
valid_y_labels = node_labels[np.array(valid_split)].tolist()
test_y_labels = node_labels[np.array(test_split)].tolist()

train_samples = [
    {'instruction': hard_prompt, 'output': f"This paper most likely belong to {y}"}
    for y in train_y_labels
]
valid_samples = [
    {'instruction': hard_prompt, 'output': f"This paper most likely belong to {y}"}
    for y in valid_y_labels
]
test_samples = [
    {'instruction': hard_prompt, 'output': f"This paper most likely belong to {y}"}
    for y in test_y_labels
]


In [60]:
# Add a singleton middle axis: (num_nodes, 768) -> (num_nodes, 1, 768).
# Assumes graph.x rows are 768-dim — TODO confirm for this dataset.
train_pos_tokens = train_pos_tokens.view(-1, 1, 768)
valid_pos_tokens = valid_pos_tokens.view(-1, 1, 768)
test_pos_tokens = test_pos_tokens.view(-1, 1, 768)
# Tensors and .jsonl records are saved separately and presumably joined by position
# downstream, so each split's lengths must agree.
assert len(train_pos_tokens) == len(train_samples), "train tokens/samples misaligned"
assert len(valid_pos_tokens) == len(valid_samples), "valid tokens/samples misaligned"
assert len(test_pos_tokens) == len(test_samples), "test tokens/samples misaligned"
len(train_samples), len(valid_samples), len(test_samples)

(700, 1400, 22704)

In [61]:
# Persist each Cora Full split as JSON Lines, then save the matching embedding tensors.
split_records_by_tag = (('train', train_samples), ('valid', valid_samples), ('test', test_samples))
for split_tag, split_records in split_records_by_tag:
    output_path = f'../data/cora_full/{split_tag}.jsonl'
    with xopen(output_path, "w") as f:
        for record in split_records:
            f.write(json.dumps(record) + "\n")

split_tokens_by_tag = (('train', train_pos_tokens), ('valid', valid_pos_tokens), ('test', test_pos_tokens))
for split_tag, split_tokens in split_tokens_by_tag:
    torch.save(split_tokens, f'../data/cora_full/{split_tag}.pt')


In [62]:
# Sanity check: shapes of the reshaped train/valid/test token tensors.
tuple(t.shape for t in (train_pos_tokens, valid_pos_tokens, test_pos_tokens))

(torch.Size([700, 1, 768]),
 torch.Size([1400, 1, 768]),
 torch.Size([22704, 1, 768]))

In [63]:
# Number of distinct indices per split; equal to the split lengths when there are no duplicates.
tuple(len(np.unique(s)) for s in (train_split, valid_split, test_split))

(700, 1400, 22704)

# Cora

In [46]:
DATA_PATH = f"{DATA_HOME_PATH}/node_cora"
DATA_NAME = "text_graph_cora"  # other graph pickles: "text_graph_aids", "text_graph_pubmed"
TRAIN_SPLIT_NAME = 'train_index'
VALID_SPLIT_NAME = 'valid_index'
TEST_SPLIT_NAME = 'test_index'

# Read the graph and its train/valid/test index splits from <DATA_PATH>/<name>.pkl, in that order.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
_loaded_pickles = []
for pickle_stem in (DATA_NAME, TRAIN_SPLIT_NAME, VALID_SPLIT_NAME, TEST_SPLIT_NAME):
    with open(os.path.join(DATA_PATH, f"{pickle_stem}.pkl"), 'rb') as pickle_file:
        _loaded_pickles.append(pkl.load(pickle_file))
graph, train_split, valid_split, test_split = _loaded_pickles
del _loaded_pickles

In [47]:
# Distinct label strings in the data (np.unique sorts them), with surrounding whitespace removed.
# NOTE(review): not used below — the prompt uses PROMPT_SETTINGS['cora']['categories'].
cora_categories = list(map(str.strip, np.unique(graph.text_node_labels)))

In [48]:
# Assemble the Cora instruction prompt from its PROMPT_SETTINGS entry.
task_name = 'prompt_tuning'
PROMPT_SETTINGS_DICT = PROMPT_SETTINGS['cora']
desc = PROMPT_SETTINGS_DICT['desc']
categories = PROMPT_SETTINGS_DICT['categories']
question = PROMPT_SETTINGS_DICT['question']
input_item = Item(desc=desc, categories=categories, question=question)
hard_prompt = get_prompt_tuning_prompt(task_name=task_name, task_item=input_item)

In [49]:
# Show the rendered Cora prompt; print() keeps its embedded newlines readable.
print(hard_prompt)

### USER: Question: Which category from the list that the paper most likely belong to? 

Belows are 7 potential categories to consider:
Category [1](Case-Based) 
Category [2](Genetic_Algorithms) 
Category [3](Neural_Networks) 
Category [4](Probabilistic_Methods) 
Category [5](Reinforcement_Learning) 
Category [6](Rule_Learning) 
Category [7](Theory) 

Given the title and abstract of a research paper, identify one category from a distinct list of research topics that you predict the paper will most likely belong to.
### ASSISTANT:


In [50]:
# Gather each split's node embeddings and label strings, then pair every node with the
# shared instruction prompt and its target answer. Records follow the split-index order,
# so record k lines up with row k of the matching *_pos_tokens tensor.
node_labels = np.array(graph.text_node_labels)
train_pos_tokens = graph.x[torch.tensor(train_split)]
valid_pos_tokens = graph.x[torch.tensor(valid_split)]
test_pos_tokens = graph.x[torch.tensor(test_split)]
train_y_labels = node_labels[np.array(train_split)].tolist()
valid_y_labels = node_labels[np.array(valid_split)].tolist()
test_y_labels = node_labels[np.array(test_split)].tolist()

train_samples = [
    {'instruction': hard_prompt, 'output': f"This paper most likely belong to {y}"}
    for y in train_y_labels
]
valid_samples = [
    {'instruction': hard_prompt, 'output': f"This paper most likely belong to {y}"}
    for y in valid_y_labels
]
test_samples = [
    {'instruction': hard_prompt, 'output': f"This paper most likely belong to {y}"}
    for y in test_y_labels
]


In [51]:
# Record counts per split (train, valid, test).
tuple(map(len, (train_samples, valid_samples, test_samples)))

(1624, 542, 542)

In [53]:
# Persist each Cora split: JSON Lines records plus the matching node-embedding tensor.
# Every other dataset in this notebook reshapes and saves its *_pos_tokens next to the
# .jsonl files; Cora was missing both the reshape and the .pt files, so they are added
# here to keep the on-disk layout consistent across datasets.
# Assumes graph.x rows are 768-dim — TODO confirm for Cora.
train_pos_tokens = train_pos_tokens.view(-1, 1, 768)
valid_pos_tokens = valid_pos_tokens.view(-1, 1, 768)
test_pos_tokens = test_pos_tokens.view(-1, 1, 768)

# Create the output directory so a fresh checkout does not fail on open().
os.makedirs('../data/cora', exist_ok=True)
for split_tag, split_records, split_tokens in (
    ('train', train_samples, train_pos_tokens),
    ('valid', valid_samples, valid_pos_tokens),
    ('test', test_samples, test_pos_tokens),
):
    # Records and tensors are stored separately, so their lengths must agree.
    assert len(split_tokens) == len(split_records), f"{split_tag}: tokens/samples misaligned"
    output_path = f'../data/cora/{split_tag}.jsonl'
    with xopen(output_path, "w") as f:
        for output_example in split_records:
            f.write(json.dumps(output_example) + "\n")
    torch.save(split_tokens, f'../data/cora/{split_tag}.pt')

# AIDS

In [67]:
# AIDS is read from its own location (not under DATA_HOME_PATH), and no split files are
# loaded for it. Instead a random train split of AIDS_TRAIN_SIZE graphs is drawn, and
# every remaining graph goes to test.
DATA_PATH = "/home/ec2-user/proj/datasets/graph/text_graph"
DATA_NAME = "text_graph_aids"
TRAIN_SPLIT_NAME = 'train_index'  # not used for AIDS: the split is sampled below
TEST_SPLIT_NAME = 'test_index'    # not used for AIDS
AIDS_TRAIN_SIZE = 1600
AIDS_SPLIT_SEED = 42  # fixed seed so the sampled split is reproducible across runs

with open(os.path.join(DATA_PATH, f"{DATA_NAME}.pkl"), 'rb') as f:
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    graph = pkl.load(f)

# Sample train indices without replacement. A seeded Generator replaces the unseeded
# global np.random state, so re-running the notebook yields the same split.
split_rng = np.random.default_rng(AIDS_SPLIT_SEED)
train_split = list(split_rng.choice(np.arange(len(graph)), AIDS_TRAIN_SIZE, replace=False))
# Set membership keeps the complement O(n) instead of scanning the train list per index.
train_index_set = set(train_split)
test_split = [i for i in range(len(graph)) if i not in train_index_set]

In [69]:
# Assemble the AIDS (graph classification) instruction prompt from its PROMPT_SETTINGS entry.
task_name = 'prompt_tuning'
PROMPT_SETTINGS_DICT = PROMPT_SETTINGS['aids']
desc = PROMPT_SETTINGS_DICT['desc']
categories = PROMPT_SETTINGS_DICT['categories']
question = PROMPT_SETTINGS_DICT['question']
input_item = Item(desc=desc, categories=categories, question=question)
hard_prompt = get_prompt_tuning_prompt(task_name=task_name, task_item=input_item)

In [70]:
# Show the rendered AIDS prompt; print() keeps its embedded newlines readable.
print(hard_prompt)

Question: Which category from the list that the input molecule most likely belong to? 

Belows are 2 potential categories to consider:
Category [1](HIV antiviral active compound) 
Category [2](HIV antiviral inactive compound) 

Given the atoms type and their connection structure of a compound, identify if the given compound is HIV antiviral active or not.


In [82]:
# One embedding per compound: the mean of its node features.
X = torch.stack([graph[i].x.mean(dim=0) for i in range(len(graph))])
# Row k of each tensor corresponds to graph index train_split[k] / test_split[k].
train_pos_tokens, test_pos_tokens = X[torch.tensor(train_split)], X[torch.tensor(test_split)]

# Target answer for every compound, indexed by graph index.
answers = [
    f"This compound most likely belong to {graph[i].text_graph_label}"
    for i in tqdm(range(len(graph)))
]
# Emit records in the SAME order as the split lists so record k matches tensor row k.
# (The previous range()-then-membership loop emitted train records in sorted graph-index
# order while train_pos_tokens follows the random order of train_split, misaligning the
# saved tokens and labels. It also did O(n) list membership tests per graph.)
train_samples = [{'instruction': hard_prompt, 'output': answers[i]} for i in train_split]
test_samples = [{'instruction': hard_prompt, 'output': answers[i]} for i in test_split]



100%|██████████| 2000/2000 [00:00<00:00, 32639.74it/s]


In [84]:
# Add a singleton middle axis: (num_graphs, 768) -> (num_graphs, 1, 768).
# Assumes node features are 768-dim — TODO confirm for AIDS.
train_pos_tokens = train_pos_tokens.view(-1, 1, 768)
test_pos_tokens = test_pos_tokens.view(-1, 1, 768)
# Tensors and .jsonl records are saved separately and presumably joined by position
# downstream, so each split's lengths must agree.
assert len(train_pos_tokens) == len(train_samples), "train tokens/samples misaligned"
assert len(test_pos_tokens) == len(test_samples), "test tokens/samples misaligned"
len(train_samples), len(test_samples)

(1600, 400)

In [85]:
# Persist the AIDS splits: one JSON record per line, then the embedding tensors.
for split_tag, split_records in (('train', train_samples), ('test', test_samples)):
    output_path = f'../data/aids/{split_tag}.jsonl'
    with xopen(output_path, "w") as f:
        for record in split_records:
            f.write(json.dumps(record) + "\n")
torch.save(train_pos_tokens, '../data/aids/train.pt')
torch.save(test_pos_tokens, '../data/aids/test.pt')


# wiki

In [26]:
DATA_PATH = "/home/ubuntu/data/graph/edge_ogbl_wiki"
DATA_NAME = "text_graph_wiki"  # other graph pickles: "text_graph_aids", "text_graph_pubmed", "text_graph_cora"
TRAIN_SPLIT_NAME = 'train_index'
TEST_SPLIT_NAME = 'test_index'

# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(os.path.join(DATA_PATH, f"{DATA_NAME}.pkl"), 'rb') as f:
    graph = pkl.load(f)

# Each split is passed through np.unique, so it becomes a sorted list of distinct indices.
_split_lists = []
for split_stem in (TRAIN_SPLIT_NAME, TEST_SPLIT_NAME):
    with open(os.path.join(DATA_PATH, f"{split_stem}.pkl"), 'rb') as f:
        _split_lists.append(list(np.unique(pkl.load(f))))
train_split, test_split = _split_lists
del _split_lists

In [27]:
# Display the loaded wiki graph object's summary (its fields and their sizes).
graph

Data(text_nodes=[46210], text_node_labels=[46210], text_graph_label='None', text_edge_labels=[34273], x=[46210, 768], y=[46210], edge_index=[2, 34273])

In [1]:
# NOTE(review): scratch arithmetic unrelated to the data pipeline; candidate for deletion.
66.64+60+1.64+44.66+56.72

229.66