In [2]:
import copy
import pickle

import numpy as np
import pandas as pd
import torch

In [18]:
# Raw split tables; used below to derive the per-split lists of molecule names.
train, test = (pd.read_csv(csv_path) for csv_path in ('../Data/train.csv', '../Data/test.csv'))

### node information

In [3]:
# NOTE(review): pickle.load can execute arbitrary code; only load pickles
# produced by a trusted preprocessing step.
structures_path = '../Data/structures_dict_ACSF_PCA.pickle'
with open(structures_path, 'rb') as structures_file:
    structures_dict = pickle.load(structures_file)

### bond information

In [4]:
# Bond-edge lookup tables, indexed by molecule name in create_data below.
bond_paths = ('../Data/bonds_edge_index.pickle',
              '../Data/bonds_edge_attr_expand.pickle')
bond_tables = []
for bond_path in bond_paths:
    with open(bond_path, 'rb') as bond_file:
        bond_tables.append(pickle.load(bond_file))
bonds_edge_index, bonds_edge_attr = bond_tables

### coupling information

In [9]:
# Coupling-edge lookup tables (index, attributes, expanded distances, targets,
# ids), each indexed by molecule name in create_data below.
coupling_paths = (
    '../Data/coupling_edge_index.pickle',
    '../Data/coupling_edge_attr.pickle',
    '../Data/coupling_edge_dist_expand.pickle',
    '../Data/coupling_y.pickle',
    '../Data/coupling_id.pickle',
)
coupling_tables = []
for coupling_path in coupling_paths:
    with open(coupling_path, 'rb') as coupling_file:
        coupling_tables.append(pickle.load(coupling_file))
(coupling_edge_index, coupling_edge_attr, coupling_edge_dist,
 coupling_y, coupling_id) = coupling_tables

In [19]:
# Sorted, de-duplicated molecule names for each split.
train_mol, test_mol = (np.unique(frame['molecule_name']) for frame in (train, test))

In [20]:
# Seeded shuffle so the train/validation split is reproducible across runs;
# the previous unseeded np.random.permutation produced a different split on
# every execution.
SPLIT_SEED = 42
train_mol = np.random.default_rng(SPLIT_SEED).permutation(train_mol)

In [21]:
# The first N_TRAIN_MOLS shuffled molecules are used for training; the rest
# are held out for validation.
N_TRAIN_MOLS = 70000
# This cell rebinds `train_mol`, so re-running it on its own would silently
# produce an empty validation set. Fail loudly instead.
assert len(train_mol) > N_TRAIN_MOLS, (
    f'expected more than {N_TRAIN_MOLS} molecules, got {len(train_mol)}; '
    're-run the np.unique and shuffle cells before splitting again')
train_mol, val_mol = train_mol[:N_TRAIN_MOLS], train_mol[N_TRAIN_MOLS:]

In [22]:
def create_data(mols, IsTrain, n_types=8):
    """Assemble per-molecule graph dicts, whole and bucketed by coupling type.

    Looks each molecule name up in the module-level tables `structures_dict`,
    `bonds_edge_index`, `bonds_edge_attr`, `coupling_edge_index`,
    `coupling_edge_attr`, `coupling_edge_dist`, and either `coupling_y`
    (train) or `coupling_id` (test).

    Args:
        mols: iterable of molecule names.
        IsTrain: if True, each dict also carries the target under 'y';
            otherwise coupling ids are collected and returned.
        n_types: number of coupling-type buckets. It must exceed every column
            index that is non-zero in `edge_attr3` (default 8, as before).

    Returns:
        IsTrain=True:  (tot_list, type_list)
        IsTrain=False: (tot_list, type_list, test_id, test_id_type)

        - tot_list: one dict per molecule.
        - type_list[i]: one dict per molecule that has couplings of type i.
          Each dict has an extra 'type_attr' uint8 mask over that molecule's
          couplings.
        - test_id: all molecules' coupling ids, concatenated.
        - test_id_type[i]: the ids of the type-i couplings.

        A group with no members yields an empty array instead of raising.
    """
    type_list = [[] for _ in range(n_types)]
    tot_list = []
    test_id_list = []
    test_id_type_list = [[] for _ in range(n_types)]

    for m in mols:
        # Key order matches the original train/test dict layouts.
        record = {'x': structures_dict[m],
                  'edge_index': bonds_edge_index[m],
                  'edge_attr': bonds_edge_attr[m]}
        if IsTrain:
            record['y'] = coupling_y[m]
        record['edge_index3'] = coupling_edge_index[m]
        record['edge_attr3'] = coupling_edge_attr[m]
        record['edge_attr4'] = coupling_edge_dist[m]

        # `record` is never mutated after this point, and per-type entries are
        # fresh shallow dicts. Deep-copying every feature array per type was
        # pure memory overhead, because torch.tensor copies the data
        # downstream anyway.
        tot_list.append(record)
        if not IsTrain:
            test_id_list.append(coupling_id[m])

        # Per-coupling type index; edge_attr3 is presumably a one-hot
        # coupling-type encoding (TODO confirm).
        coupling_type = record['edge_attr3'].argmax(1)
        for i in np.nonzero(record['edge_attr3'].sum(0))[0]:
            mask = coupling_type == i
            type_list[i].append(dict(record, type_attr=mask.astype(np.uint8)))
            if not IsTrain:
                test_id_type_list[i].append(coupling_id[m][mask])

    if IsTrain:
        return tot_list, type_list

    # np.concatenate([]) raises ValueError. That happens when `mols` is empty
    # or when a coupling type never occurs in this split, so fall back to
    # empty arrays. Ids are assumed to be integers (TODO confirm dtype).
    test_id = (np.concatenate(test_id_list) if test_id_list
               else np.empty(0, dtype=np.int64))
    test_id_type = [np.concatenate(ids) if ids else np.empty(0, dtype=test_id.dtype)
                    for ids in test_id_type_list]
    return tot_list, type_list, test_id, test_id_type

In [23]:
# Whole-molecule and per-coupling-type datasets for every split; only the
# test split also returns coupling ids for building the submission.
tot_list_train, type_list_train = create_data(train_mol, IsTrain=True)
tot_list_val, type_list_val = create_data(val_mol, IsTrain=True)
tot_list_test, type_list_test, test_id, test_id_type = create_data(test_mol, IsTrain=False)

In [24]:
def _records_to_torch(records):
    """Return a new list in which every value of every dict is passed through torch.tensor (which copies the data)."""
    return [{key: torch.tensor(value) for key, value in record.items()} for record in records]

# Converted one split at a time, so each split's numpy version can be freed
# before the next split is converted.
tot_list_train = _records_to_torch(tot_list_train)
tot_list_val = _records_to_torch(tot_list_val)
tot_list_test = _records_to_torch(tot_list_test)

In [25]:
def numpy2torch(type_list):
    """Convert a list of per-type record lists from numpy values to torch tensors.

    Args:
        type_list: a list of lists of dicts, e.g. as returned by create_data.

    Returns:
        A new nested list with the same structure, in which each dict value
        has been passed through torch.tensor.
    """
    return [
        [{key: torch.tensor(value) for key, value in record.items()} for record in records]
        for records in type_list
    ]

In [26]:
# Per-coupling-type datasets, converted one split at a time.
type_list_train = numpy2torch(type_list=type_list_train)
type_list_val = numpy2torch(type_list=type_list_val)
type_list_test = numpy2torch(type_list=type_list_test)

In [27]:
# Persist the whole-molecule datasets, one pickle per split.
for split_records, split_path in (
    (tot_list_train, '../Data/train_data_ACSF_expand_PCA.pickle'),
    (tot_list_val, '../Data/val_data_ACSF_expand_PCA.pickle'),
    (tot_list_test, '../Data/test_data_ACSF_expand_PCA.pickle'),
):
    with open(split_path, 'wb') as split_file:
        pickle.dump(split_records, split_file, protocol=pickle.HIGHEST_PROTOCOL)

In [28]:
# Coupling ids of the whole test split, concatenated in molecule order.
test_id_path = '../Data/test_data_ACSF_expand_id_PCA.pickle'
with open(test_id_path, 'wb') as id_file:
    pickle.dump(test_id, id_file, protocol=pickle.HIGHEST_PROTOCOL)

In [29]:
def save_type(prefix, type_list):
    """Pickle each entry of `type_list` to its own file.

    Entry i is written to ``prefix + '_type_' + str(i) + '.pickle'`` with the
    highest pickle protocol. Existing files are overwritten.
    """
    for type_index, payload in enumerate(type_list):
        out_path = prefix + '_type_' + str(type_index) + '.pickle'
        with open(out_path, 'wb') as out_file:
            pickle.dump(payload, out_file, protocol=pickle.HIGHEST_PROTOCOL)

In [30]:
# One pickle per (split, coupling type).
for split_prefix, split_types in (
    ('../Data/train_data_ACSF_expand_PCA', type_list_train),
    ('../Data/val_data_ACSF_expand_PCA', type_list_val),
    ('../Data/test_data_ACSF_expand_PCA', type_list_test),
):
    save_type(split_prefix, split_types)

In [31]:
# Per-type test coupling ids. NOTE(review): this prefix ("..._PCA_id") is
# ordered differently from the whole-split id file ("..._expand_id_PCA");
# it is kept as-is because readers of these files may depend on it.
save_type(prefix='../Data/test_data_ACSF_expand_PCA_id', type_list=test_id_type)