In [1]:
import re
import os
import gc
import dgl
import torch
import duckdb
import pandas as pd
import dask.dataframe as dd
import pyarrow.parquet as pq

from math import sqrt
from time import time
from rdkit import Chem
from dgl import function as func
from dgl.nn.functional import edge_softmax
from sklearn.model_selection import train_test_split

from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Load the preprocessed parquet of (presumably) unique molecule SMILES and report its shape.
sample = pd.read_parquet(
    "/home/pervinco/Datasets/leash-bio/preprocessed/molecule_smiles_uniques.parquet",
    engine="pyarrow",
)
print(sample.shape)

In [None]:
sample.head(n=10)

In [2]:
# Load the training parquet lazily with Dask.
data = dd.read_parquet('/home/pervinco/Datasets/leash-bio/train.parquet')

# Unique SMILES strings, sorted so the seeded split below depends only on random_state and
# not on the order in which Dask's unique() happens to return values (which can change with
# partitioning). sort_values() also tolerates missing values, unlike sorted().
unique_smiles = data['molecule_smiles'].unique().compute().sort_values().tolist()

# Molecule-level split, so a SMILES string lands in exactly one split:
# 98% train; the remaining 2% is divided 36% validation / 64% test,
# i.e. about 0.72% / 1.28% of all unique molecules.
train_smiles, temp_smiles = train_test_split(unique_smiles, train_size=0.98, random_state=42)
valid_smiles, test_smiles = train_test_split(temp_smiles, test_size=0.64, random_state=42)
assert len(train_smiles) + len(valid_smiles) + len(test_smiles) == len(unique_smiles)

# Materialize each split as a pandas DataFrame (each compute() re-reads the parquet file).
train_data = data[data['molecule_smiles'].isin(train_smiles)].compute()
valid_data = data[data['molecule_smiles'].isin(valid_smiles)].compute()
test_data = data[data['molecule_smiles'].isin(test_smiles)].compute()

print(f"Training set size: {len(train_data)}")
print(f"Validation set size: {len(valid_data)}")
print(f"Test set size: {len(test_data)}")


In [None]:
# Chunk-generation settings; the names match generate_datasets' parameters
# (n_iter = number of chunks, limit = rows per binds class per chunk).
n_iter = 1
limit = 100_000

# Source parquet, output directory for generated chunks, and unique-atom list location.
parquet_path = "/home/pervinco/Datasets/leash-bio/train.parquet"
data_save_path = "/home/pervinco/Datasets/leash-bio/GNN"
unique_atoms_path = "./unique_atoms.txt"

In [None]:
os.makedirs(name=data_save_path, exist_ok=True)  # no-op if the output directory already exists

# 1. 데이터셋을 구성하는 고유한 원자 리스트 추출

In [None]:
# Quick look at how RDKit exposes atoms, using benzoic acid as the example.
# Named example_* so the `sample` DataFrame loaded earlier is not overwritten by a string.
example_smiles = 'c1ccccc1C(=O)O'
example_molecule = Chem.MolFromSmiles(example_smiles)
# GetAtoms() returns a sequence object whose repr is not informative, so print the count instead.
print(f"Number of atoms: {example_molecule.GetNumAtoms()}")

for atom in example_molecule.GetAtoms():
    print(atom.GetSymbol())

In [None]:
def extract_unique_atoms(smiles_list):
    """Return the set of element symbols found in the SMILES strings of `smiles_list`.

    Strings that RDKit cannot parse (MolFromSmiles returns a falsy value) are skipped.
    """
    parsed_molecules = (Chem.MolFromSmiles(smiles) for smiles in smiles_list)
    return {
        atom.GetSymbol()
        for mol in parsed_molecules
        if mol
        for atom in mol.GetAtoms()
    }

def _query_split(con, parquet_path, binds, limit, offset=None, randomize=False):
    """Fetch up to `limit` rows with the given `binds` value from `parquet_path` as a DataFrame.

    With randomize=False, rows are ordered by hash(id) (ties broken by id). This is a fixed
    pseudo-random permutation, so increasing `offset` pages through disjoint slices.
    ORDER BY random() is re-evaluated on every query, so OFFSET paging over it returns
    overlapping rows. randomize=True orders by random() and returns a fresh sample on every
    call; it is only used without an offset.
    """
    order_by = "random()" if randomize else "hash(id), id"
    offset_clause = f"OFFSET {int(offset)}" if offset is not None else ""
    return con.query(f"""
        SELECT id, molecule_smiles, protein_name, binds
        FROM parquet_scan('{parquet_path}')
        WHERE binds = {int(binds)}
        ORDER BY {order_by}
        LIMIT {int(limit)} {offset_clause}
    """).df()


def generate_datasets(parquet_path, n_iter, limit, save_path, seed=None):
    """Write `n_iter` class-balanced parquet chunks sampled from `parquet_path`.

    Each chunk holds up to `limit` rows with binds=0 and up to `limit` rows with binds=1.
    The rows are shuffled and saved as `{save_path}/dataset-XXXX.parquet`, where XXXX is
    the zero-padded chunk index. Both classes are paged through a fixed hash(id) order,
    so successive chunks do not overlap. Once positives run out, a random batch of
    positives is resampled, so positives may repeat. Generation stops early when
    negatives run out or a query fails.

    The sorted set of atom symbols seen across all chunks is written to
    `{save_path}/unique_atoms.txt`, one symbol per line, followed by a blank line.

    Args:
        parquet_path: source parquet with id, molecule_smiles, protein_name, binds columns.
        n_iter: number of chunks to write.
        limit: rows per binds class per chunk.
        save_path: output directory (created if missing).
        seed: optional base seed for the in-chunk shuffle (chunk i uses seed + i).
            None keeps the shuffle unseeded.
    """
    os.makedirs(save_path, exist_ok=True)

    offset_0 = 0
    offset_1 = 0
    all_unique_atoms = set()

    # Connect before entering try/finally. If connect() itself fails there is nothing to
    # close, and `finally: con.close()` would otherwise raise NameError and hide the real error.
    con = duckdb.connect()
    try:
        for i in range(n_iter):
            try:
                data_0 = _query_split(con, parquet_path, 0, limit, offset_0)
                data_1 = _query_split(con, parquet_path, 1, limit, offset_1)
                if data_1.empty:
                    # Positives exhausted: resample a random batch of positives (may repeat rows).
                    data_1 = _query_split(con, parquet_path, 1, limit, randomize=True)
            except Exception as e:
                print(f"Query failed: {e}")
                break

            if data_0.empty:
                # Without this, the remaining chunks would contain positives only.
                print(f"Iter {i+1} : no binds=0 rows left at offset {offset_0}; stopping.")
                break

            data = pd.concat([data_0, data_1])
            all_unique_atoms.update(extract_unique_atoms(data['molecule_smiles'].tolist()))

            shuffle_seed = None if seed is None else seed + i
            data = data.sample(frac=1, random_state=shuffle_seed).reset_index(drop=True)
            binds_0_count = int((data['binds'] == 0).sum())
            binds_1_count = int((data['binds'] == 1).sum())
            print(f"Iter {i+1} : Dataset shape : {data.shape}, binds=0 count : {binds_0_count}, binds=1 count : {binds_1_count}")

            offset_0 += limit
            offset_1 += limit

            # pandas' pyarrow engine converts with pa.Table.from_pandas and writes with pq.write_table.
            data.to_parquet(os.path.join(save_path, f"dataset-{i:>04}.parquet"), engine="pyarrow")
            del data_0, data_1, data
            gc.collect()
    finally:
        con.close()

    # Sorted so the file, and any vocabulary built from it, has the same order on every run.
    # Iteration order of a set of str varies with hash randomization.
    with open(os.path.join(save_path, "unique_atoms.txt"), 'w') as file:
        for atom in sorted(all_unique_atoms):
            file.write(f"{atom}\n")
        file.write("\n")

# 2. Torch Dataset

In [None]:
## Treat one chunk of data (one parquet file) as a single Dataset.
class LeasBioDataset(Dataset):
    """Map-style dataset over one parquet chunk (e.g. a dataset-XXXX.parquet file).

    Each item is a (molecule_smiles, binds) pair taken from the file's columns of the same name.
    """

    def __init__(self, data_path):
        # Read only the two columns __getitem__ serves; id/protein_name are never used here.
        data = pq.read_table(data_path, columns=['molecule_smiles', 'binds']).to_pandas()
        self.train_smiles = data['molecule_smiles'].tolist()
        self.train_labels = data['binds'].tolist()

    def __len__(self):
        return len(self.train_labels)

    def __getitem__(self, idx):
        return self.train_smiles[idx], self.train_labels[idx]

In [None]:
# Build the path from the configured output directory (".../GNN"). The previously hard-coded
# ".../gnn" does not match it on a case-sensitive filesystem.
train_dataset = LeasBioDataset(os.path.join(data_save_path, "dataset-0000.parquet"))

In [None]:
# Inspect a single example: one SMILES string and its binds label.
smiles, label = train_dataset[0]

# 3. DataLoader

In [None]:
# Atom-symbol vocabulary for the one-hot element feature in get_atom_feature.
# It is loaded from the unique-atom list (one symbol per line) and sorted for a stable
# feature order. A trailing "Unknown" bucket is appended because one_of_k_encoding maps any
# symbol not in the vocab to vocab[-1]; the former empty list made that raise IndexError.
# Node feature width is len(ATOM_VOCAB) + 18, so construct GAT with
# initial_node_dim=len(ATOM_VOCAB) + 18.
ATOM_VOCAB = []
for _vocab_path in (unique_atoms_path, os.path.join(data_save_path, "unique_atoms.txt")):
    if os.path.exists(_vocab_path):
        with open(_vocab_path) as _vocab_file:
            ATOM_VOCAB = sorted({line.strip() for line in _vocab_file if line.strip()})
        break
else:
    print("WARNING: unique_atoms.txt not found; ATOM_VOCAB only holds the 'Unknown' bucket.")
if "Unknown" not in ATOM_VOCAB:
    ATOM_VOCAB.append("Unknown")

def one_of_k_encoding(x, vocab):
    """One-hot encode `x` against `vocab` as a list of floats.

    A value that is not in `vocab` is mapped to the last entry, so the last vocab item acts
    as the catch-all bucket.

    Raises:
        ValueError: if `vocab` is empty (there is no bucket to fall back to).
    """
    if len(vocab) == 0:
        raise ValueError("one_of_k_encoding requires a non-empty vocab")
    if x not in vocab:
        x = vocab[-1]
    return [float(x == item) for item in vocab]


def get_atom_feature(atom):
    """Build the node feature list for one RDKit atom.

    The list concatenates one-hot encodings of the element symbol (against ATOM_VOCAB),
    the degree (0-5), the total hydrogen count (0-4) and the implicit valence (0-5), then
    appends the aromaticity flag. Values outside a range fall into that encoding's last
    bucket.
    """
    encodings = (
        (atom.GetSymbol(), ATOM_VOCAB),
        (atom.GetDegree(), list(range(6))),
        (atom.GetTotalNumHs(), list(range(5))),
        (atom.GetImplicitValence(), list(range(6))),
    )
    feature = []
    for value, vocab in encodings:
        feature.extend(one_of_k_encoding(value, vocab))
    feature.append(atom.GetIsAromatic())
    return feature


def get_bond_feature(bond):
    """Return the 6-element bond feature list.

    The list holds one boolean each for single, double, triple and aromatic bond type,
    followed by the conjugation flag and the in-ring flag.
    """
    bond_type = bond.GetBondType()
    bond_types = (
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC,
    )
    type_flags = [bond_type == candidate for candidate in bond_types]
    return type_flags + [bond.GetIsConjugated(), bond.IsInRing()]


def get_molecular_graph(smiles):
    """Convert a SMILES string into a DGL graph of atoms (nodes) and bonds (edges).

    Node features (ndata['h']) come from get_atom_feature, and edge features
    (edata['e_ij']) come from get_bond_feature; both are float64 tensors. Each bond is
    stored as two directed edges, begin->end followed by end->begin, carrying the same
    features.

    Raises:
        ValueError: if RDKit cannot parse `smiles`.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        # Fail with a clear message instead of an AttributeError on None further down.
        raise ValueError(f"RDKit could not parse SMILES: {smiles!r}")

    # Nodes: one per atom, with its feature vector.
    num_atoms = molecule.GetNumAtoms()
    node_features = torch.tensor(
        [get_atom_feature(atom) for atom in molecule.GetAtoms()], dtype=torch.float64
    )

    # Edges: DGL graphs are directed, so every bond is added in both directions to let
    # messages flow both ways.
    src_ids, dst_ids, edge_features = [], [], []
    for bond in molecule.GetBonds():
        bond_feature = get_bond_feature(bond)
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        src_ids += [begin, end]
        dst_ids += [end, begin]
        edge_features += [bond_feature, bond_feature]

    # Build all edges at once instead of calling add_edges twice per bond.
    graph = dgl.graph(
        (torch.tensor(src_ids, dtype=torch.int64), torch.tensor(dst_ids, dtype=torch.int64)),
        num_nodes=num_atoms,
    )
    graph.ndata['h'] = node_features
    # reshape keeps a (0, 6) shape for bond-free molecules (e.g. a single ion), so their edge
    # features still batch and embed like the others. 6 is the length of get_bond_feature's output.
    graph.edata['e_ij'] = torch.tensor(edge_features, dtype=torch.float64).reshape(-1, 6)
    return graph


def collate_fn(batch):
    """Turn a list of (smiles, label) items into a batched DGL graph and a label tensor.

    Returns:
        (batched_graph, labels): `dgl.batch` of one molecular graph per item, and a
        float32 tensor of shape (len(batch),). float32 matches GAT's output dtype (GAT casts
        its inputs with .float()). The former float64 labels clash with float32
        predictions in losses such as F.binary_cross_entropy.
    """
    graphs = [get_molecular_graph(item[0]) for item in batch]
    labels = torch.tensor([item[1] for item in batch], dtype=torch.float32)
    # NOTE(review): GAT returns shape (num_graphs, 1); align shapes (e.g. labels.unsqueeze(1))
    # before computing the loss — confirm in the training loop.
    return dgl.batch(graphs), labels


In [None]:
# collate_fn turns each batch of (smiles, label) pairs into a batched DGL graph plus a label
# tensor. Without it, the default collate would hand the model raw SMILES strings.
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4, collate_fn=collate_fn)

# 4. Graph Attention Layer

In [None]:
class MultiLayerPerceptron(nn.Module):
    """Two-layer feed-forward network: linear -> activation -> linear.

    Args:
        input_dim: width of the incoming features.
        hidden_dim: width of the intermediate layer.
        output_dim: width of the output.
        bias: whether both linear layers use a bias term.
        activation: callable applied between the two linear layers.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, bias=False, activation=F.relu):
        super().__init__()
        # Keep the configuration around for introspection.
        self.input_dim, self.hidden_dim, self.output_dim = input_dim, hidden_dim, output_dim
        self.activation = activation

        self.linear1 = nn.Linear(input_dim, hidden_dim, bias=bias)
        self.linear2 = nn.Linear(hidden_dim, output_dim, bias=bias)

    def forward(self, x):
        hidden = self.activation(self.linear1(x))
        return self.linear2(hidden)

class GraphAttention(nn.Module):
    """One graph-attention block with edge-conditioned attention, residual paths and an MLP.

    Reads node features from graph.ndata['h'] and edge features from graph.edata['e_ij']
    (both with last dimension `hidden_dim`). Writes the updated node features back to
    graph.ndata['h'] and returns the same graph. The intermediate tensors 'u', 'v', 'k'
    (nodes) and 'x_ij', 'm', 'attn' (edges) are also left on the graph.

    Args:
        hidden_dim: node/edge feature width; must be divisible by num_heads.
        num_heads: number of attention heads, each of width hidden_dim // num_heads.
        mlp_bias: whether the feed-forward MLP uses bias terms.
        drop_prob: dropout probability applied to the block output.
        activation: activation used inside the MLP.
    """

    def __init__(self, hidden_dim, num_heads=4, mlp_bias=False, drop_prob=0.2, activation=F.relu):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.dk = hidden_dim // num_heads  # per-head width
        self.prob = drop_prob
        self.activation = activation

        self.mlp = MultiLayerPerceptron(input_dim=hidden_dim, hidden_dim=2*hidden_dim, output_dim=hidden_dim, bias=mlp_bias, activation=activation)
        # NOTE(review): one LayerNorm instance is shared by both residual sub-layers in forward().
        self.norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(drop_prob)

        # w1 -> 'u' (source side), w2 -> 'v' and w3 -> edge term: combined into attention logits.
        # w4 -> 'k' and w5 -> edge term: combined into messages.
        # w6: projection of the node's own features for the residual path.
        self.w1 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.w4 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.w5 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.w6 = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def forward(self, graph:dgl.DGLGraph):
        h0 = graph.ndata['h']      # node features
        e_ij = graph.edata['e_ij']  # edge features

        # Project and split into heads: (num_nodes or num_edges, num_heads, dk).
        graph.ndata['u'] = self.w1(h0).view(-1, self.num_heads, self.dk)
        graph.ndata['v'] = self.w2(h0).view(-1, self.num_heads, self.dk)
        graph.edata['x_ij'] = self.w3(e_ij).view(-1, self.num_heads, self.dk)

        # Attention logits per edge src->dst: u[src] * (v[dst] + x_ij), elementwise, scaled by
        # sqrt(dk), then softmax-normalized over each destination node's incoming edges.
        graph.apply_edges(func.v_add_e('v', 'x_ij', 'm'))
        graph.apply_edges(func.u_mul_e('u', 'm', 'attn'))
        graph.edata['attn'] = edge_softmax(graph, graph.edata['attn'] / sqrt(self.dk))

        # Messages: k[dst] + x_ij, using a fresh edge projection.
        # NOTE(review): v_add_e reads the *destination* node's 'k'. If messages were meant to
        # carry the source node's features this should be u_add_e — confirm.
        graph.ndata['k'] = self.w4(h0).view(-1, self.num_heads, self.dk)
        graph.edata['x_ij'] = self.w5(e_ij).view(-1, self.num_heads, self.dk)
        graph.apply_edges(func.v_add_e('k', 'x_ij', 'm'))

        # Weight messages by attention and sum them into each destination node's 'h'.
        # copy_e is the supported built-in; copy_edge was a deprecated alias that recent DGL
        # releases no longer provide.
        graph.edata['m'] = graph.edata['attn'] * graph.edata['m']
        graph.update_all(func.copy_e('m', 'm'), func.sum('m', 'h'))

        # Self-projection residual + norm, then MLP residual + norm, then dropout.
        h = self.w6(h0) + graph.ndata['h'].view(-1, self.hidden_dim)
        h = self.norm(h)

        h = h + self.mlp(h)
        h = self.norm(h)
        h = self.dropout(h)

        graph.ndata['h'] = h
        return graph

# 5. Model

In [None]:
class GAT(nn.Module):
    """Stack of GraphAttention blocks with graph-level readout and a sigmoid output.

    Args:
        num_layers: number of GraphAttention blocks.
        hidden_dim: width of node/edge embeddings inside the network.
        num_heads, drop_prob, mlp_bias, activation: forwarded to every GraphAttention block.
        readout: pooling op passed to dgl.readout_nodes (e.g. 'sum', 'mean', 'max').
        initial_node_dim: width of the raw node features in graph.ndata['h'].
            get_atom_feature yields len(ATOM_VOCAB) + 18 features.
        initial_edge_dim: width of the raw edge features in graph.edata['e_ij'].
            get_bond_feature yields 6.
    """

    def __init__(self,
                 num_layers=5,
                 hidden_dim=64,
                 num_heads=4,
                 drop_prob=0.2,
                 mlp_bias=False,
                 readout='sum',
                 activation=F.relu,
                 initial_node_dim=58,
                 initial_edge_dim=6):
        super().__init__()
        self.num_layers = num_layers
        self.readout = readout

        self.node_embedding = nn.Linear(initial_node_dim, hidden_dim, bias=False)
        self.edge_embedding = nn.Linear(initial_edge_dim, hidden_dim, bias=False)

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            layer = GraphAttention(hidden_dim, num_heads, mlp_bias, drop_prob, activation)
            self.layers.append(layer)

        self.output = nn.Linear(hidden_dim, 1, bias=False)
        self.sigmoid = F.sigmoid

    def forward(self, graph:dgl.DGLGraph):
        """Return one probability per graph in the batch, shape (num_graphs, 1)."""
        # local_var(): the feature writes below go to a local view, so the caller's graph keeps
        # its raw features and can be passed through the model again.
        graph = graph.local_var()
        h = self.node_embedding(graph.ndata['h'].float())
        e_ij = self.edge_embedding(graph.edata['e_ij'].float())

        graph.ndata['h'] = h
        graph.edata['e_ij'] = e_ij

        for layer in self.layers:
            graph = layer(graph)

        out = dgl.readout_nodes(graph, 'h', op=self.readout)
        out = self.output(out)
        out = self.sigmoid(out)

        return out


In [None]:
import torch
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv

print("PyTorch version:", torch.__version__)
print("PyTorch Geometric version:", torch_geometric.__version__)

# Smoke-test data: a 3-node path graph 0-1-2, with both edge directions listed explicitly.
# The pyg_* names keep this check from overwriting `data` (the Dask training frame).
pyg_edge_index = torch.tensor([[0, 1, 1, 2],
                               [1, 0, 2, 1]], dtype=torch.long)
pyg_node_features = torch.tensor([[-1], [0], [1]], dtype=torch.float)

pyg_data = Data(x=pyg_node_features, edge_index=pyg_edge_index)

# Apply one GCN layer (1 -> 2 channels) and show the result.
conv = GCNConv(1, 2)
gcn_output = conv(pyg_data.x, pyg_data.edge_index)

print("GCN layer output:", gcn_output)
