In [None]:
class GNN_multiHead_noEdge_Int(torch.nn.Module):
    """Graph network with two message-passing stages and no edge update.

    Node features are embedded by ``lin_node``; the bonding edge features
    (``data.edge_attr``) by ``edge1`` and the concatenated coupling edge
    features (``data.edge_attr3`` + ``data.edge_attr4``) by ``edge2``.  The
    ``conv1`` blocks run on the bonding graph and the ``conv2`` blocks on the
    coupling graph.  ``head`` produces the main prediction per coupling edge;
    ``head_mol`` / ``head_atom`` / ``head_edge`` are auxiliary heads that are
    only evaluated when an auxiliary loss ``weight`` is supplied.
    """
    # no edge update
    def __init__(self,reuse,block,head,head_mol,head_atom,head_edge,dim,layer1,layer2,factor,\
                 node_in,edge_in,edge_in4,edge_in3=8,mol_shape=4,atom_shape=10,edge_shape=4,aggr='mean',interleave=False):
        """Build the embeddings, the two block stacks and the heads.

        Args:
            reuse: if True a single ``conv1`` and a single ``conv2`` block are
                created and applied ``layer1`` / ``layer2`` times (shared
                weights); otherwise one block per layer is created.
            block, head, head_mol, head_atom, head_edge: nn.Module classes
                (not instances); ``block`` is used for the coupling stage.
            dim: hidden width of nodes and embedded edges.
            layer1, layer2: number of bonding / coupling message-passing steps.
            factor: the hidden width of the embedding MLPs is ``dim*factor``.
            node_in, edge_in: input widths for node and bonding-edge features.
            edge_in4, edge_in3: input widths of the two coupling-edge feature
                groups that are concatenated before ``edge2``.
            mol_shape, atom_shape, edge_shape: output widths handed to the
                auxiliary heads.
            aggr: aggregation mode forwarded to every block.
            interleave: alternate conv1/conv2 steps instead of running the two
                stacks one after the other (requires ``layer1 == layer2``).
        """
        # block,head are nn.Module
        # node_in,edge_in are dim for bonding and edge_in4,edge_in3 for coupling
        super(GNN_multiHead_noEdge_Int, self).__init__()
        if interleave:
            assert layer1==layer2,'layer1 needs to be the same as layer2'
        self.interleave = interleave
        # Remembered so forward() knows how often a shared block must be applied.
        self.reuse = reuse
        self.layer1 = layer1
        self.layer2 = layer2

        self.lin_node = self._mlp(node_in,dim*factor,dim)
        self.edge1 = self._mlp(edge_in,dim*factor,dim)
        self.edge2 = self._mlp(edge_in4+edge_in3,dim*factor,dim)

        if reuse:
            self.conv1 = schnet_block(dim=dim,edge_dim=dim,aggr=aggr)
            self.conv2 = block(dim=dim,edge_dim=dim,aggr=aggr)
        else:
            self.conv1 = nn.ModuleList([schnet_block(dim=dim,edge_dim=dim,aggr=aggr) for _ in range(layer1)])
            self.conv2 = nn.ModuleList([block(dim=dim,edge_dim=dim,aggr=aggr) for _ in range(layer2)])

        self.head = head(dim)
        self.head_mol = head_mol(dim,mol_shape)
        self.head_atom = head_atom(dim,atom_shape)
        self.head_edge = head_edge(dim,edge_shape)

    @staticmethod
    def _mlp(in_dim,hidden_dim,out_dim):
        """Return the BatchNorm-Linear-LeakyReLU x2 embedding used for nodes and edges."""
        return Sequential(BatchNorm1d(in_dim),Linear(in_dim,hidden_dim),LeakyReLU(),
                          BatchNorm1d(hidden_dim),Linear(hidden_dim,out_dim),LeakyReLU())

    def _conv_stacks(self):
        """Return the two block sequences to iterate over in forward()."""
        if self.reuse:
            # A single shared block is not iterable: repeat it once per layer.
            return [self.conv1]*self.layer1,[self.conv2]*self.layer2
        return self.conv1,self.conv2

    def forward(self, data,IsTrain=False,typeTrain=False,logLoss=True,weight=None):
        """Run the network on one batch.

        Args:
            data: batch object exposing ``x``, ``edge_index``, ``edge_attr``,
                ``edge_index3``, ``edge_attr3`` and ``edge_attr4``; also
                ``type_attr`` when ``typeTrain`` is set, ``y`` when ``IsTrain``
                is set, and ``batch`` / ``y_mol`` / ``y_atom`` / ``y_coupling``
                when ``weight`` is given.
            IsTrain: return losses instead of predictions.
            typeTrain: restrict prediction (and loss) to the coupling edges
                selected by ``data.type_attr``.
            logLoss: if True the scalar loss is the mean over present types of
                log(per-type MAE); otherwise the mean of the per-type MAE.
            weight: scale of the auxiliary L1 losses; ``None`` disables them.

        Returns:
            ``yhat`` when ``IsTrain`` is False, otherwise the tuple
            ``(loss + loss_other, loss_perType)`` where ``loss_perType`` holds
            log(MAE) per coupling type (0 for types absent from the batch).
        """
        out = self.lin_node(data.x)
        # edge_*3 only does not repeat for undirected graph. Hence need to add (j,i) to (i,j) in edges
        edge_index3 = torch.cat([data.edge_index3,data.edge_index3[[1,0]]],1)
        n = data.edge_attr3.shape[0]
        temp_ = self.edge2(torch.cat([data.edge_attr3,data.edge_attr4],1))
        edge_attr3 = torch.cat([temp_,temp_],0)
        # edge_attr3 doubles as the per-edge type indicator (presumably one-hot - confirm with the data pipeline)
        int_types = torch.cat([data.edge_attr3,data.edge_attr3],0)
        edge_attr = self.edge1(data.edge_attr)

        convs1,convs2 = self._conv_stacks()
        if self.interleave:
            for conv1,conv2 in zip(convs1,convs2):
                out = conv1(out,data.edge_index,edge_attr)
                out = conv2(out,edge_index3,edge_attr3,int_types)
        else:
            for conv in convs1:
                out = conv(out,data.edge_index,edge_attr)
            for conv in convs2:
                out = conv(out,edge_index3,edge_attr3,int_types)

        # Drop the mirrored (j,i) copies again: the head works on the original directed edges.
        edge_attr3 = edge_attr3[:n]
        if typeTrain:
            if IsTrain:
                y = data.y[data.type_attr]
            edge_attr3 = edge_attr3[data.type_attr]
            edge_index3 = data.edge_index3[:,data.type_attr]
            edge_attr3_old = data.edge_attr3[data.type_attr]
        else:
            if IsTrain:
                y = data.y
            edge_index3 = data.edge_index3
            edge_attr3_old = data.edge_attr3

        yhat = self.head(out,edge_index3,edge_attr3,edge_attr3_old)

        if not IsTrain:
            return yhat

        if weight is None:
            loss_other = 0
        else:
            y_mol = self.head_mol(out,data.batch)
            y_atom = self.head_atom(out)
            y_edge = self.head_edge(out,edge_index3)
            loss_other = weight * (torch.mean(torch.abs(data.y_mol - y_mol)) + \
                                   torch.mean(torch.abs(data.y_atom - y_atom)) + \
                                   torch.mean(torch.abs(data.y_coupling - y_edge)))

        # k[t] = total indicator mass of type t in this batch; only types that occur contribute.
        k = torch.sum(edge_attr3_old,0)
        nonzeroIndex = torch.nonzero(k).squeeze(1)
        abs_ = torch.abs(y-yhat).unsqueeze(1)
        mae_perType = torch.sum(abs_ * edge_attr3_old[:,nonzeroIndex],0)/k[nonzeroIndex]
        log_mae_perType = torch.log(mae_perType)

        # Allocate on the same device/dtype as the loss instead of a hard-coded 'cuda:0'.
        loss_perType = log_mae_perType.new_zeros(k.shape[0])
        loss_perType[nonzeroIndex] = log_mae_perType

        # The reported per-type values are always log(MAE); logLoss only selects the optimised scalar.
        optimised = log_mae_perType if logLoss else mae_perType
        loss = torch.sum(optimised)/nonzeroIndex.shape[0]
        return loss+loss_other,loss_perType
        

class NNConv2_int(MessagePassing):
    r"""Message-passing layer that scales neighbour features element-wise
    (as in schnet) instead of using a matrix multiplication.

    ``nn`` maps every edge feature vector to ``dim * type_factor`` values; the
    entries flagged by ``int_types`` are picked out per edge and multiplied
    with the source-node features.
    """
    def __init__(self,
                 dim,
                 nn,
                 aggr='mean'):
        super().__init__(aggr=aggr)
        self.dim = dim
        # Number of weight columns produced per channel by `nn` (one per interaction type).
        self.type_factor = 8
        self.nn = nn
        self.aggr = aggr

        # The update MLP sees [aggregated message, node state] => 2*dim inputs, widened 3x inside.
        joined_width = dim * 2
        hidden_width = joined_width * 3
        self.v_update = Sequential(
            BatchNorm1d(joined_width),
            Linear(joined_width, hidden_width),
            LeakyReLU(inplace=True),
            BatchNorm1d(hidden_width),
            Linear(hidden_width, dim),
            LeakyReLU(inplace=True),
        )

    def forward(self, x, edge_index, edge_attr,int_types):
        """Promote 1-D inputs to column vectors and start propagation."""
        if x.dim() == 1:
            x = x.unsqueeze(-1)
        pseudo = edge_attr
        if pseudo.dim() == 1:
            pseudo = pseudo.unsqueeze(-1)
        return self.propagate(edge_index, x=x, pseudo=pseudo,int_types=int_types)

    def message(self, x_j, pseudo,int_types):
        """Return the per-edge message ``x_j * weight``.

        NOTE(review): argument names are kept as-is; MessagePassing presumably
        resolves them by name when dispatching from ``propagate``.
        """
        per_type = self.nn(pseudo).reshape(-1,self.dim,self.type_factor) # (n,d,k)
        # Boolean mask over the type axis, expanded across the channel axis.
        # Presumably one-hot per edge - the reshape below relies on a fixed number of hits per edge.
        selector = int_types.unsqueeze(1).to(torch.bool)
        selector = torch.broadcast_tensors(per_type,selector)[1]
        chosen = per_type[selector].reshape(-1,self.dim)
        return x_j * chosen

    def update(self, aggr_out, x):
        """Combine the aggregated messages with the current node state."""
        joined = torch.cat([aggr_out,x],1)
        return self.v_update(joined)

    def __repr__(self):
        return 'NNConv2_int'

class schnet_block_int(torch.nn.Module):
    """Residual block: NNConv2_int message passing followed by an MLP.

    The edge network emits ``dim * 8`` values per edge, matching the
    ``type_factor`` that NNConv2_int uses when reshaping its weights.
    """
    # use both types of edges
    def __init__(self,dim=64,edge_dim=12,aggr='mean'):
        super().__init__()
        hidden_width = dim * 3
        n_types = 8

        edge_net = Sequential(
            BatchNorm1d(edge_dim),
            Linear(edge_dim, hidden_width),
            LeakyReLU(inplace=True),
            BatchNorm1d(hidden_width),
            Linear(hidden_width, dim * n_types),
        )
        self.conv = NNConv2_int(dim, edge_net, aggr=aggr)
        self.lin_covert = Sequential(
            BatchNorm1d(dim),
            Linear(dim, hidden_width),
            LeakyReLU(inplace=True),
            BatchNorm1d(hidden_width),
            Linear(hidden_width, dim),
        )

    def forward(self, x, edge_index, edge_attr,int_types):
        """Return ``x`` plus the transformed message computed from its neighbours."""
        message = self.conv(x, edge_index, edge_attr,int_types)
        return x + self.lin_covert(message)

In [1]:
import pickle
import torch
from torch_geometric.data import Data,DataLoader
from functions_refactor import *
from pytorch_util import *
#from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [2]:
# fixed parameters (not varied between experiments)
# Auxiliary head classes: only head_edge is rebound here (to head_edge2).
head_mol, head_atom, head_edge = (head_mol, head_atom, head_edge2)
clip = 2             # forwarded positionally to train_type_earlyStop
batch_size = 32      # batch size of every DataLoader below
threshold = -1.3     # forwarded to train_type_earlyStop(threshold=...)
reuse = False        # model constructor flag: one block per layer instead of a shared block
lr = 1e-4            # learning rate of the RAdam optimizer

In [3]:
# changing parameters (these define the experiment and end up in `prefix`)
block = schnet_block_int                # block class used for the coupling-graph stage
head = cat3HeadInteraction_noEdge       # main prediction head class
# Path template: '{}' is filled with 'train' or 'test'; the template is also the key into data_dict.
data = '../Data/{}_data_ACSF_SOAP_atomInfo_otherInfo.pickle'
dim = 512                               # hidden width of the model
logLoss = True                          # forwarded to training as logLoss
weight = None                           # auxiliary-loss weight forwarded to training
layer1 = 3                              # number of bonding-graph blocks
layer2 = 4                              # number of coupling-graph blocks
factor = 2                              # embedding MLP hidden width is dim*factor
epochs = 150                            # epoch count handed to train_type_earlyStop
aggr = 'mean'                           # aggregation mode of the message-passing blocks
interleave = False                      # run the two block stacks sequentially

In [4]:
def _prefix_part(setting):
    """Render one experiment setting for the file-name prefix.

    Settings whose text contains '}' (the data path template) contribute only
    the part after the placeholder; everything else contributes ``str(setting)``.
    """
    text = str(setting)
    if '}' in text:
        return text.split('}')[1]
    return text

# Experiment tag used in the model checkpoint and OOF/test output file names.
prefix = '_'.join(_prefix_part(setting) for setting in
                  [block,head,data,dim,logLoss,weight,layer1,layer2,factor,epochs,aggr,interleave])

In [None]:
# Raw competition tables; the per-fold / per-type prediction columns are appended to them later.
train_df, test_df = map(pd.read_csv, ('../Data/train.csv', '../Data/test.csv'))

In [None]:
# Load the five pre-split training folds; each pickle holds a list of keyword
# dicts that are turned into torch_geometric Data objects.
# NOTE: pickle.load can execute arbitrary code - only load files produced by this project.
fold_base = data.format('train').split('pickle')[0][:-1]
folds = []
for f in range(5):
    with open(fold_base+'_f'+str(f)+'.pickle', 'rb') as handle:
        fold_records = pickle.load(handle)
    folds.append([Data(**d) for d in fold_records])

In [None]:
def _predict_one_type(model,state_dict,data_path,id_path,batch_size,device='cuda:0'):
    """Predict one coupling type with its own checkpoint.

    Loads the graphs from ``data_path`` and the matching row ids from
    ``id_path``, restores ``state_dict`` into ``model``, runs inference with
    ``typeTrain=True`` and returns a dict mapping id -> prediction.

    NOTE: pickle.load can execute arbitrary code - only use on files produced
    by this project's own preprocessing.
    """
    with open(data_path, 'rb') as handle:
        graphs = [Data(**d) for d in pickle.load(handle)]
    loader = DataLoader(graphs,batch_size,shuffle=False)

    with open(id_path, 'rb') as handle:
        ids = pickle.load(handle)

    model.load_state_dict(state_dict)
    model.eval()
    yhat_list = []
    with torch.no_grad():
        for batch in loader:
            # positional flags: IsTrain=False, typeTrain=True
            yhat_list.append(model(batch.to(device),False,True))
    yhat = torch.cat(yhat_list).cpu().detach().numpy()

    assert yhat.shape[0]==ids.shape[0],'yhat and test_id should have same shape'
    return dict(zip(ids,yhat))


# File-name stems (path template without the '.pickle' suffix); independent of the fold.
train_base = data.format('train').split('pickle')[0][:-1]
test_base = data.format('test').split('pickle')[0][:-1]

for i in range(5):
    print('\nstart fold '+str(i))
    # prepare data: fold i is held out for validation, the other folds are used for training
    train_list = []
    val_list = []
    for j in range(5):
        if i == j:
            val_list.extend(folds[j])
        else:
            train_list.extend(folds[j])

    train_dl = DataLoader(train_list,batch_size,shuffle=True)
    val_dl = DataLoader(val_list,batch_size,shuffle=False)

    # train model
    model = GNN_multiHead_noEdge_Int(reuse,block,head,head_mol,head_atom,head_edge,\
                          dim,layer1,layer2,factor,**data_dict[data],aggr=aggr,interleave=interleave).to('cuda:0')
    paras = trainable_parameter(model)
    opt = RAdam(paras,lr=lr,weight_decay=1e-2)
    scheduler = ReduceLROnPlateau(opt, 'min',factor=0.5,patience=5,min_lr=1e-05)
    model,train_loss_perType,val_loss_perType,bestWeight = train_type_earlyStop(opt,model,epochs,train_dl,val_dl,paras,clip,\
                                                                    scheduler=scheduler,logLoss=logLoss,weight=weight,threshold=threshold)
    # bestWeight is indexed by coupling type: one state dict per type is stored and later restored.
    torch.save({'model_state_dict_type_'+str(j_):w for j_,w in enumerate(bestWeight)},\
                '../Model/'+prefix+'_fold'+str(i)+'.tar')

    # predict oof for each type (validation fold i, type-specific files)
    for type_i in range(8):
        oof_stem = train_base+'_f'+str(i)+'_type_'+str(type_i)
        submit_ = _predict_one_type(model,bestWeight[type_i],
                                    oof_stem+'.pickle',oof_stem+'_id.pickle',batch_size)
        train_df['fold'+str(i)+'_type'+str(type_i)] = train_df.id.map(submit_)

    # predict test for each type with the same per-type checkpoints
    for type_i in range(8):
        submit_ = _predict_one_type(model,bestWeight[type_i],
                                    test_base+'_type_'+str(type_i)+'.pickle',
                                    test_base+'_id_type_'+str(type_i)+'.pickle',batch_size)
        test_df['fold'+str(i)+'_type'+str(type_i)] = test_df.id.map(submit_)

In [None]:
# Average the per-fold / per-type prediction columns ('fold<i>_type<t>') into one
# 'yhat' column and write the tables out.  The columns are selected by name rather
# than by position so the result does not depend on the number of columns in the
# CSV files, and re-running this cell does not fold a previously written 'yhat'
# column into the mean.
# NaNs are ignored by nanmean (ids missing from a type's prediction dict map to NaN).
test_pred_cols = [c for c in test_df.columns if str(c).startswith('fold')]
test_df['yhat'] = np.nanmean(test_df[test_pred_cols],1)
test_df.to_csv('../Data/test_oof_0820_'+prefix,index=False)

train_pred_cols = [c for c in train_df.columns if str(c).startswith('fold')]
train_df['yhat'] = np.nanmean(train_df[train_pred_cols],1)
train_df.to_csv('../Data/train_oof_0820_'+prefix,index=False)