In [2]:
import os
import numpy as np
from math import sqrt
from scipy import stats
import torch
from torch_geometric.data import InMemoryDataset
from torch_geometric import data as DATA
from torch_geometric.loader import DataLoader
import os
import pandas as pd
import numpy as np
import json,pickle
from collections import OrderedDict
from rdkit import Chem
from rdkit.Chem import MolFromSmiles
import networkx as nx

import os
import os.path as osp
import re

import torch
from torch_geometric.data import (InMemoryDataset, Data, download_url,
                                  extract_gz)

try:
    from rdkit import Chem
except ImportError:
    Chem = None

In [3]:
def smiles_features(mol):
    symbols = ['K', 'Y', 'V', 'Sm', 'Dy', 'In', 'Lu', 'Hg', 'Co', 'Mg',    #list of all elements in the dataset
        'Cu', 'Rh', 'Hf', 'O', 'As', 'Ge', 'Au', 'Mo', 'Br', 'Ce', 
        'Zr', 'Ag', 'Ba', 'N', 'Cr', 'Sr', 'Fe', 'Gd', 'I', 'Al', 
        'B', 'Se', 'Pr', 'Te', 'Cd', 'Pd', 'Si', 'Zn', 'Pb', 'Sn', 
        'Cl', 'Mn', 'Cs', 'Na', 'S', 'Ti', 'Ni', 'Ru', 'Ca', 'Nd', 
        'W', 'H', 'Li', 'Sb', 'Bi', 'La', 'Pt', 'Nb', 'P', 'F', 'C',
        'Re','Ta','Ir','Be','Tl']

    hybridizations = [
        Chem.rdchem.HybridizationType.S,
        Chem.rdchem.HybridizationType.SP,
        Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3,
        Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2,
        'other',
    ]

    stereos = [
        Chem.rdchem.BondStereo.STEREONONE,
        Chem.rdchem.BondStereo.STEREOANY,
        Chem.rdchem.BondStereo.STEREOZ,
        Chem.rdchem.BondStereo.STEREOE,
    ]
    features = []
    xs = []
    for atom in mol.GetAtoms():
        symbol = [0.] * len(symbols)
        symbol[symbols.index(atom.GetSymbol())] = 1.
        #comment degree from 6 to 8
        degree = [0.] * 8
        degree[atom.GetDegree()] = 1.
        formal_charge = atom.GetFormalCharge()
        radical_electrons = atom.GetNumRadicalElectrons()
        hybridization = [0.] * len(hybridizations)
        hybridization[hybridizations.index(
            atom.GetHybridization())] = 1.
        aromaticity = 1. if atom.GetIsAromatic() else 0.
        hydrogens = [0.] * 5
        hydrogens[atom.GetTotalNumHs()] = 1.
        chirality = 1. if atom.HasProp('_ChiralityPossible') else 0.
        chirality_type = [0.] * 2
        if atom.HasProp('_CIPCode'):
            chirality_type[['R', 'S'].index(atom.GetProp('_CIPCode'))] = 1.
    
        x = torch.tensor(symbol + degree + [formal_charge] +
                         [radical_electrons] + hybridization +
                         [aromaticity] + hydrogens + [chirality] +
                         chirality_type)
        xs.append(x)
    
        x = torch.stack(xs, dim=0)

    edge_indices = []
    edge_attrs = []
    for bond in mol.GetBonds():
        edge_indices += [[bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]]
        edge_indices += [[bond.GetEndAtomIdx(), bond.GetBeginAtomIdx()]]
    
        bond_type = bond.GetBondType()
        single = 1. if bond_type == Chem.rdchem.BondType.SINGLE else 0.
        double = 1. if bond_type == Chem.rdchem.BondType.DOUBLE else 0.
        triple = 1. if bond_type == Chem.rdchem.BondType.TRIPLE else 0.
        aromatic = 1. if bond_type == Chem.rdchem.BondType.AROMATIC else 0.
        conjugation = 1. if bond.GetIsConjugated() else 0.
        ring = 1. if bond.IsInRing() else 0.
        stereo = [0.] * 4
        stereo[stereos.index(bond.GetStereo())] = 1.
    
        edge_attr = torch.tensor(
            [single, double, triple, aromatic, conjugation, ring] + stereo)
    
        edge_attrs += [edge_attr, edge_attr]
    
    if len(edge_attrs) == 0:
        edge_index = torch.zeros((2, 0), dtype=torch.long)
        edge_attr = torch.zeros((0, 10), dtype=torch.float)
    else:
        edge_index = torch.tensor(edge_indices).t().contiguous()
        edge_attr = torch.stack(edge_attrs, dim=0)
    return x, edge_index, edge_attr

def one_of_k_encoding(x, allowable_set):
    if x not in allowable_set:
        raise Exception("input {0} not in allowable set{1}:".format(x, allowable_set))
    return list(map(lambda s: x == s, allowable_set))

def one_of_k_encoding_unk(x, allowable_set):
    """Maps inputs not in the allowable set to the last element."""
    if x not in allowable_set:
        x = allowable_set[-1]
    return list(map(lambda s: x == s, allowable_set))

def smile_to_graph(smile):
    mol = Chem.MolFromSmiles(smile)
    if mol == None:
        return None
    
    c_size = mol.GetNumAtoms()
    features, edge_index, edge_attr = smiles_features(mol)
    # features = []
    # bonds = mol.GetBonds()
    # for atom in mol.GetAtoms():
    #     feature = atom_features(atom)
    #     features.append( feature / sum(feature) )

    # edges = []
    # for bond in mol.GetBonds():
    #     edges.append([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()])
    # g = nx.Graph(edges).to_directed()
    # edge_index = []
    # for e1, e2 in g.edges:
    #     edge_index.append([e1, e2])
        
    return features, edge_index, edge_attr

In [4]:
class Molecule_data(InMemoryDataset):
    def __init__(self, root='/tmp', dataset='_drug1',xd=None, xt=None, y=None, xt_featrue=None, transform=None,
                 pre_transform=None,smile_graph=None):

        #root is required for save preprocessed data, default is '/tmp'
        super(Molecule_data, self).__init__(root, transform, pre_transform)
        # benchmark dataset, default = 'davis'
        self.dataset = dataset
        if os.path.isfile(self.processed_paths[0]):
#             print('Pre-processed data found: {}, loading ...'.format(self.processed_paths[0]))
            self.data, self.slices = torch.load(self.processed_paths[0])
        else:
            print('Pre-processed data {} not found, doing pre-processing...'.format(self.processed_paths[0]))
            self.process(xd, xt, xt_featrue, y, smile_graph)
            self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        pass
        #return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return [self.dataset + '.pt']

    def download(self):
        # Download to `self.raw_dir`.
        pass

    def _download(self):
        pass

    def _process(self):
        if not os.path.exists(self.processed_dir):
            os.makedirs(self.processed_dir)
            
    def get_cell_feature(self, cellId, cell_features):
        for row in islice(cell_features, 0, None):
            if cellId in row[0]:
                return row[1:]
        return False

    def process(self, xd, xt, xt_featrue,y,smile_graph):
        assert (len(xd) == len(xt) and len(xt) == len(y)), "The three lists must be the same length!"
        data_list = []
        data_len = len(xd)
        for i in range(data_len):
            print('Converting SMILES to graph: {}/{}'.format(i+1, data_len))
            smiles = xd[i]
            target = xt[i]
            labels = y[i]
            # convert SMILES to molecular representation using rdkit
            x, edge_index, edge_attr = smile_graph[smiles]
            # make the graph ready for PyTorch Geometrics GCN algorithms:
            GCNData = Data(x=torch.Tensor(x),
                      edge_index=edge_index,
                      edge_attr=edge_attr,
                      y=torch.FloatTensor([labels]))
        
            cell = self.get_cell_feature(target, xt_featrue)

            if cell == False : 
                print('cell', cell)
                sys.exit()

            new_cell = []
            # print('cell_feature', cell_feature)
            for n in cell:
                new_cell.append(float(n))
            GCNData.cell = torch.FloatTensor([new_cell])
            #GCNData.__setitem__('c_size', torch.LongTensor([c_size]))
            # append graph, label and target sequence to data list
            data_list.append(GCNData)

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]
        print('Graph construction done. Saving to file.')
        data, slices = self.collate(data_list)
#         print(data.shape,slices.shape)
        # save preprocessed data:
        torch.save((data, slices), self.processed_paths[0])

In [5]:
import csv
from itertools import islice

import pandas as pd
import numpy as np
import os
import json, pickle
from collections import OrderedDict
from rdkit import Chem
import networkx as nx


def creat_data(datafile, cellfile):
    file2 = cellfile
    cell_features = []
    with open(file2) as csvfile:
        csv_reader = csv.reader(csvfile)  # 使用csv.reader读取csvfile中的文件
        for row in csv_reader:
            cell_features.append(row)
    cell_features = np.array(cell_features)
    print('cell_features', cell_features)

    compound_iso_smiles = []
    df = pd.read_csv('data/smiles.csv')
    compound_iso_smiles += list(df['smile'])
    compound_iso_smiles = set(compound_iso_smiles)
    smile_graph = {}
    for smile in compound_iso_smiles:
        g = smile_to_graph(smile)
        smile_graph[smile] = g

    datasets = datafile
    # convert to PyTorch data format
    processed_data_file_train = 'data_perparation/processed/' + datasets + '_train.pt'

    if ((not os.path.isfile(processed_data_file_train))):
        df = pd.read_csv('data/' + datasets + '.csv')
        drug1, drug2, cell, label = list(df['drug1']), list(df['drug2']), list(df['cell']), list(df['label'])
        drug1, drug2, cell, label = np.asarray(drug1), np.asarray(drug2), np.asarray(cell), np.asarray(label)
        # make data PyTorch Geometric ready
        print(len(cell),len(label))
      
        Molecule_data(root='data_perparation', dataset=datafile + '_drug1', xd=drug1, xt=cell, xt_featrue=cell_features, y=label,smile_graph=smile_graph)
        Molecule_data(root='data_perparation', dataset=datafile + '_drug2', xd=drug2, xt=cell, xt_featrue=cell_features, y=label,smile_graph=smile_graph)
       
        print('preparing ', datasets + '_.pt in pytorch format!')
    #
    #     print(processed_data_file_train, ' have been created')
    #
    # else:
    #     print(processed_data_file_train, ' are already created')

if __name__ == "__main__":
    # datafile = 'prostate'
    cellfile = 'data/new_cell_features_954.csv'
    da = ['A375']
    for datafile in da:
        creat_data(datafile, cellfile)

cell_features [['gene_id' 'ENSG00000116237' 'ENSG00000162413' ... 'ENSG00000157617'
  'ENSG00000160208' 'ENSG00000141959']
 ['transcript_id' 'ENST00000343813' 'ENST00000377658' ...
  'ENST00000329623' 'ENST00000340648' 'ENST00000349048']
 ['A2058' '51.07' '20.76' ... '1.46' '23.67' '78.86']
 ...
 ['YKG1_CENTRAL_NERVOUS_SYSTEM' '41.25' '23.28' ... '1.72' '18.24'
  '71.48']
 ['ZR751_BREAST' '30.69' '4.72' ... '0.77' '6.8' '36.37']
 ['ZR7530_BREAST' '25.04' '7.31' ... '5.44' '10.06' '73.45']]
425 425
Pre-processed data data_perparation/processed/A375_drug1.pt not found, doing pre-processing...
Converting SMILES to graph: 1/425
Converting SMILES to graph: 2/425
Converting SMILES to graph: 3/425
Converting SMILES to graph: 4/425
Converting SMILES to graph: 5/425
Converting SMILES to graph: 6/425
Converting SMILES to graph: 7/425
Converting SMILES to graph: 8/425
Converting SMILES to graph: 9/425
Converting SMILES to graph: 10/425
Converting SMILES to graph: 11/425
Converting SMILES to graph



Pre-processed data data_perparation/processed/A375_drug2.pt not found, doing pre-processing...
Converting SMILES to graph: 1/425
Converting SMILES to graph: 2/425
Converting SMILES to graph: 3/425
Converting SMILES to graph: 4/425
Converting SMILES to graph: 5/425
Converting SMILES to graph: 6/425
Converting SMILES to graph: 7/425
Converting SMILES to graph: 8/425
Converting SMILES to graph: 9/425
Converting SMILES to graph: 10/425
Converting SMILES to graph: 11/425
Converting SMILES to graph: 12/425
Converting SMILES to graph: 13/425
Converting SMILES to graph: 14/425
Converting SMILES to graph: 15/425
Converting SMILES to graph: 16/425
Converting SMILES to graph: 17/425
Converting SMILES to graph: 18/425
Converting SMILES to graph: 19/425
Converting SMILES to graph: 20/425
Converting SMILES to graph: 21/425
Converting SMILES to graph: 22/425
Converting SMILES to graph: 23/425
Converting SMILES to graph: 24/425
Converting SMILES to graph: 25/425
Converting SMILES to graph: 26/425
Conv