In [1]:
import pandas as pd
import numpy as np
import pickle
import torch

In [2]:
# Raw competition tables, read in train / test / structures order.
train, test, structures = [
    pd.read_csv(csv_path)
    for csv_path in ('../Data/train.csv', '../Data/test.csv', '../Data/structures.csv')
]
# Test rows carry no target; a NaN column lets train and test be stacked into
# one frame and test molecules be recognized later by the missing value.
test = test.assign(scalar_coupling_constant=np.nan)

In [3]:
# Stack train and test couplings so features are built identically for both;
# test rows are recognizable by their NaN target. DataFrame.append was
# deprecated in pandas 1.4 and removed in 2.0, so use pd.concat instead.
coupling = pd.concat([train, test], ignore_index=True)

def map_atom_info(df, atom_idx, structures_df=None):
    """Attach the element and xyz coordinates of one endpoint atom to each row.

    Parameters
    ----------
    df : pd.DataFrame
        Coupling rows; must contain ``molecule_name`` and
        ``atom_index_{atom_idx}``.
    atom_idx : int
        Which endpoint to look up; used as the suffix of the key column and of
        the columns added to the result.
    structures_df : pd.DataFrame, optional
        Per-atom table with ``molecule_name``, ``atom_index``, ``atom``, ``x``,
        ``y``, ``z``. Defaults to the notebook-level ``structures`` frame.

    Returns
    -------
    pd.DataFrame
        ``df`` left-joined with ``atom_{atom_idx}``, ``x_{atom_idx}``,
        ``y_{atom_idx}`` and ``z_{atom_idx}``; rows without a matching atom
        get NaN in those columns.
    """
    if structures_df is None:
        structures_df = structures
    # Merge only the columns we rename. Another cell adds one-hot element
    # columns to ``structures`` in place; merging those too would leak
    # unrenamed (and, on the second endpoint, suffixed C_x/C_y...) columns
    # whenever this cell is re-run after that one.
    atom_info = structures_df[['molecule_name', 'atom_index', 'atom', 'x', 'y', 'z']]
    merged = pd.merge(df, atom_info, how='left',
                      left_on=['molecule_name', f'atom_index_{atom_idx}'],
                      right_on=['molecule_name', 'atom_index'])
    suffixed = {col: f'{col}_{atom_idx}' for col in ('atom', 'x', 'y', 'z')}
    return merged.drop(columns='atom_index').rename(columns=suffixed)

# Attach element + coordinates for both endpoint atoms of every coupling.
coupling = map_atom_info(coupling, 0)
coupling = map_atom_info(coupling, 1)

# Euclidean distance between the two coupled atoms.
# NOTE: despite the `train_` prefix these arrays cover train AND test rows.
train_p_0 = coupling[['x_0', 'y_0', 'z_0']].values
train_p_1 = coupling[['x_1', 'y_1', 'z_1']].values
coupling['dist'] = np.linalg.norm(train_p_0 - train_p_1, axis=1)

# One-hot encode the coupling type into a FIXED, name-aligned column order
# (instead of positionally assigning get_dummies output), so the edge-feature
# layout cannot silently shift and a missing type yields an all-zero column.
COUPLING_TYPES = ['1JHC', '1JHN', '2JHC', '2JHH', '2JHN', '3JHC', '3JHH', '3JHN']
unknown_types = set(coupling['type'].unique()) - set(COUPLING_TYPES)
assert not unknown_types, f'unexpected coupling types: {sorted(unknown_types)}'
coupling[COUPLING_TYPES] = (pd.get_dummies(coupling['type'])
                            .reindex(columns=COUPLING_TYPES, fill_value=0))

# Group by the scalar key: with a 1-element list, pandas >= 2.0 yields 1-tuple
# group keys, which would not match the plain-string molecule lookups later.
coupling = coupling.groupby('molecule_name')

# Edge features per coupling: type one-hot, distance, both endpoint coords.
EDGE_FEATURES = COUPLING_TYPES + ['dist', 'x_0', 'y_0', 'z_0', 'x_1', 'y_1', 'z_1']

coupling_edge_index = {}
coupling_edge_attr = {}
coupling_y = {}
for mol_name, mol_df in coupling:
    coupling_edge_index[mol_name] = mol_df[['atom_index_0', 'atom_index_1']].values
    coupling_edge_attr[mol_name] = mol_df[EDGE_FEATURES].values.astype(np.float32)

    targets = mol_df['scalar_coupling_constant'].values
    if np.isnan(targets).any():
        # Molecule without targets (test): keep the row ids instead,
        # presumably to map predictions back to submission rows.
        coupling_y[mol_name] = mol_df['id'].values
    else:
        coupling_y[mol_name] = targets.astype(np.float32)

In [4]:
# Node features per atom: xyz coordinates followed by a one-hot element
# encoding in a fixed, name-aligned column order.
ATOM_TYPES = ['C', 'F', 'H', 'N', 'O']
unknown_atoms = set(structures['atom'].unique()) - set(ATOM_TYPES)
assert not unknown_atoms, f'unexpected atom types: {sorted(unknown_atoms)}'
structures[ATOM_TYPES] = (pd.get_dummies(structures['atom'])
                          .reindex(columns=ATOM_TYPES, fill_value=0))
NODE_FEATURES = ['x', 'y', 'z'] + ATOM_TYPES

# Group by the scalar key so dict keys are plain molecule-name strings; a
# 1-element list yields 1-tuple keys in pandas >= 2.0, breaking the later
# structures_gb[mol] lookups.
structures_gb = {
    mol_name: mol_df.loc[:, NODE_FEATURES].values.astype(np.float32)
    for mol_name, mol_df in structures.groupby('molecule_name')
}

In [5]:
# split train/val
# Split train *molecules* (not rows) into train/validation, so all couplings of
# a molecule land on the same side. A seeded generator makes the split
# reproducible across Restart & Run All.
SPLIT_SEED = 42
N_TRAIN_MOL = 70000  # remaining train molecules form the validation set

train_mol = np.unique(train.molecule_name)
test_mol = np.unique(test.molecule_name)
shuffled_mol = np.random.default_rng(SPLIT_SEED).permutation(train_mol)
train_mol, val_mol = shuffled_mol[:N_TRAIN_MOL], shuffled_mol[N_TRAIN_MOL:]

In [6]:
def _split_tensors(mols):
    """Collect per-molecule tensors for one split.

    Returns four lists, parallel to ``mols``: node features from
    ``structures_gb``, edge features from ``coupling_edge_attr``, atom-index
    pairs from ``coupling_edge_index`` and targets (or row ids for molecules
    without targets) from ``coupling_y``. ``torch.from_numpy`` shares memory
    with the source arrays rather than copying them.
    """
    atom_list, coupling_list, index_list, target_list = [], [], [], []
    for mol in mols:
        atom_list.append(torch.from_numpy(structures_gb[mol]))
        coupling_list.append(torch.from_numpy(coupling_edge_attr[mol]))
        index_list.append(torch.from_numpy(coupling_edge_index[mol]))
        target_list.append(torch.from_numpy(coupling_y[mol]))
    return atom_list, coupling_list, index_list, target_list


train_atom_list, train_coupling_list, train_index_list, train_target_list = _split_tensors(train_mol)
val_atom_list, val_coupling_list, val_index_list, val_target_list = _split_tensors(val_mol)
test_atom_list, test_coupling_list, test_index_list, test_target_list = _split_tensors(test_mol)

In [7]:
# Persist each split's four tensor lists as ../Data/{split}_{part}_list.pickle,
# written in the same order as before (train, val, test; atom, coupling,
# index, target). Loading these back uses pickle, so only load trusted files.
SPLIT_LISTS = {
    'train': (train_atom_list, train_coupling_list, train_index_list, train_target_list),
    'val': (val_atom_list, val_coupling_list, val_index_list, val_target_list),
    'test': (test_atom_list, test_coupling_list, test_index_list, test_target_list),
}
for split_name, split_lists in SPLIT_LISTS.items():
    for part_name, part_list in zip(('atom', 'coupling', 'index', 'target'), split_lists):
        with open(f'../Data/{split_name}_{part_name}_list.pickle', 'wb') as handle:
            pickle.dump(part_list, handle, protocol=pickle.HIGHEST_PROTOCOL)