In [142]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

In [143]:
SEED = 0
torch.manual_seed(SEED)  # make torch.randint draws and nn.Linear initialization reproducible

N = 100  # number of atoms
# Deterministic toy positions: atom k sits at (3k, 3k+1, 3k+2).
# Random alternative: coords = torch.rand((N, 3), dtype=torch.float64)
coords = torch.arange(N * 3, dtype=torch.float64).reshape(N, 3)
atom_types = torch.ones(N, dtype=torch.long)  # every atom is type 1 in this toy system

In [144]:
sel = 50  # number of neighbor slots per central atom
# Central-atom index for every pair slot: sel copies of 0, then sel copies of 1, ...
atom_i_idxs = torch.arange(N).repeat_interleave(sel)
# Assume every atom i has at least `sel` neighbors (toy data).
# j is drawn from the other N-1 atoms: a self pair (j == i) with no periodic
# shift would have zero separation, and the descriptor later divides by the
# pair distance, which would produce inf/NaN values.
atom_j_idxs = (atom_i_idxs + torch.randint(1, N, (N * sel,))) % N

In [145]:
# Look up the real species of both atoms in every pair instead of hard-coding
# type 1, so pair types stay consistent with `atom_types`.
atom_i_types = atom_types[atom_i_idxs]
print(atom_i_types.shape)
atom_j_types = atom_types[atom_j_idxs]
atom_j_types.shape

torch.Size([5000])


torch.Size([5000])

In [146]:
# Per-pair offset added to r_j - r_i: each Cartesian component is independently 0 or 10
# (presumably a periodic image shift with box length 10 — TODO confirm).
shift = 10.0 * torch.randint(0, 2, (N * sel, 3)).float()

In [147]:
def convert_relative(
    coords,
    atom_i_idxs,
    atom_j_idxs,
    shift
):
    """Return the displacement ``coords[j] - coords[i] + shift`` for every pair.

    Args:
        coords: per-atom positions; rows are gathered by index.
        atom_i_idxs: index of the central atom of each pair.
        atom_j_idxs: index of the neighbor atom of each pair.
        shift: per-pair offset added to the raw difference (assumed to be a
            periodic-image shift — confirm against the neighbor builder).

    Returns:
        One relative-coordinate row per pair.
    """
    r_i = coords[atom_i_idxs]
    r_j = coords[atom_j_idxs]
    return r_j - r_i + shift

In [148]:
relative_coords = convert_relative(
    coords=coords,
    atom_i_idxs=atom_i_idxs,
    atom_j_idxs=atom_j_idxs,
    shift=shift,
)
relative_coords.shape  # one displacement vector per (i, j) pair

torch.Size([5000, 3])

In [149]:
def switch_func(r_norm, r_cutoff=6.0, r_cutoff_smth=0.1):
    """Smooth 1/r switching function.

    With ``u = (r - r_cutoff_smth) / (r_cutoff - r_cutoff_smth)``::

        s(r) = 1/r                                 for r < r_cutoff_smth
        s(r) = (u^3 (-6u^2 + 15u - 10) + 1) / r    for r_cutoff_smth <= r < r_cutoff
        s(r) = 0                                   for r >= r_cutoff

    Args:
        r_norm: tensor of pair distances (must be > 0 inside the cutoff).
        r_cutoff: distance at and beyond which s is 0.
        r_cutoff_smth: distance where smoothing starts; must be < r_cutoff.

    Returns:
        Tensor shaped like ``r_norm``.
    """
    s = torch.zeros_like(r_norm)
    mask = r_norm < r_cutoff
    r_in = r_norm[mask]
    # Clamp u into [0, 1]: for r < r_cutoff_smth u would be negative and the
    # polynomial would exceed 1; clamping gives polynomial == 1, i.e. s = 1/r.
    u = ((r_in - r_cutoff_smth) / (r_cutoff - r_cutoff_smth)).clamp(0.0, 1.0)
    s[mask] = (u * u * u * (-6 * u * u + 15 * u - 10) + 1) / r_in
    return s

In [150]:
def convert_general(relative_coords):
    """Build 4-column generalized coordinates for every pair.

    Row k is ``[s_k, s_k * x_k / r_k, s_k * y_k / r_k, s_k * z_k / r_k]``, where
    ``r_k`` is the Euclidean norm of row k of ``relative_coords`` and
    ``s_k = switch_func(r_k)``.

    The output is allocated with torch's default dtype, whatever the input
    dtype is. Because it divides by ``r_k``, a zero-length displacement
    produces non-finite values.
    """
    n_pairs = relative_coords.shape[0]
    r = torch.linalg.norm(relative_coords, dim=1)
    s = switch_func(r)
    out = torch.zeros((n_pairs, 4))
    out[:, 0] = s
    out[:, 1:] = relative_coords * s.unsqueeze(1) / r.unsqueeze(1)
    return out

In [151]:
# Flat per-pair rows: [s, s*x/r, s*y/r, s*z/r]
general_coords = convert_general(relative_coords=relative_coords)
general_coords.shape

torch.Size([5000, 4])

In [152]:
# x=torch.arange(6).reshape(2,3)
# print(x)
# x-=torch.concat([x.max(axis=1,keepdim=True).values]*x.size()[1],dim=1)
# print(x)
# x_exp=torch.exp(x)
# x_exp/torch.concat([x_exp.sum(dim=1,keepdim=True)]*x.size()[1],dim=1)

In [153]:
def relu(x):
    """Rectified linear unit: keep strictly positive entries, zero the rest.

    ``x > 0`` is False for NaN, so NaN inputs become 0 (``torch.relu`` would
    propagate them instead).
    """
    positive = x > 0
    zero = torch.zeros_like(x)
    return torch.where(positive, x, zero)

def softmax(x, dim=1):
    """Numerically stable softmax along ``dim`` (default 1: normalize each row of a 2-D tensor).

    The maximum along ``dim`` is subtracted before exponentiating to avoid
    overflow. A new tensor is returned and ``x`` is left untouched, so this is
    safe to call on leaf tensors that require grad.
    """
    # Broadcasting over keepdim=True replaces the explicit concat of copies.
    shifted = x - x.max(dim=dim, keepdim=True).values
    x_exp = torch.exp(shifted)
    return x_exp / x_exp.sum(dim=dim, keepdim=True)

class ThreeLayerPerceptron(nn.Module):
    """Fully connected network: Linear -> act -> Linear -> act -> Linear.

    Args:
        in_dim: size of the last input dimension.
        hid_dim: width of both hidden layers.
        out_dim: size of the output.
        function: element-wise activation applied after the first two layers.
    """

    def __init__(self, in_dim, hid_dim, out_dim, function=relu):
        super().__init__()
        # Creation order fixes how the global RNG is consumed during init.
        self.linear1 = nn.Linear(in_dim, hid_dim)
        self.linear2 = nn.Linear(hid_dim, hid_dim)
        self.linear3 = nn.Linear(hid_dim, out_dim)
        self.function = function

    def forward(self, x):
        hidden = self.function(self.linear1(x))
        hidden = self.function(self.linear2(hidden))
        # No output activation: raw linear outputs are returned.
        return self.linear3(hidden)

In [154]:
# ### TEST ###
# x=torch.tensor([[0,0],[1,0],[0,1],[1,1]],dtype=torch.float)
# t=torch.tensor([0,1,1,0],dtype=torch.long)

# mlp=ThreeLayerPerceptron(2,3,2)

# optimizer=optim.SGD(mlp.parameters(),lr=0.1)

# mlp.train()

# for i in range(5000):
#     t_hot=torch.eye(2)[t]
#     y_pred=mlp(x)
#     loss=-(t_hot*torch.log(y_pred)).sum(axis=1).mean()

#     optimizer.zero_grad()
#     loss.backward()

#     optimizer.step()

#     if i%1000==0:
#         print(i,loss.item())

In [155]:
# mlp.eval()
# y=mlp.forward(x)
# y

In [156]:
max_atom_type = 1             # one-hot table covers types 0..max_atom_type
atom_type_embeded_nchanl = 1  # number of embedding channels per atom type
atom_type_one_hot = torch.eye(max_atom_type + 1)  # row t is the one-hot code of type t
atom_type_net = ThreeLayerPerceptron(
    in_dim=max_atom_type + 1,
    hid_dim=atom_type_embeded_nchanl,
    out_dim=atom_type_embeded_nchanl,
)
# Row t of atom_type_matrix is atom_type_net's embedding of atom type t.
atom_type_matrix = atom_type_net(atom_type_one_hot)

In [157]:
# Type embeddings of the central and neighbor atom of every pair.
atom_i_matrix = atom_type_matrix[atom_i_types]
atom_j_matrix = atom_type_matrix[atom_j_types]

# A later cell rebinds `general_coords` to shape (N, sel, 4); flattening to
# (-1, 4) first keeps this cell valid when it is re-run after that one.
s_column = general_coords.reshape(-1, 4)[:, :1]  # s(r_ij), one row per pair

# Embedding-net input per pair: [s(r_ij), embedding(type_i), embedding(type_j)]
g_vec = torch.concat((s_column, atom_i_matrix, atom_j_matrix), dim=1)

print(g_vec.shape)

torch.Size([5000, 3])


In [158]:
M1, M2 = 100, 50  # embedding width, and number of leading features kept for Gi2s
embeded_net = ThreeLayerPerceptron(
    in_dim=1 + 2 * atom_type_embeded_nchanl,  # [s(r), type_i embedding, type_j embedding]
    hid_dim=M1,
    out_dim=M1,
)
# Gi1s[i, j, k]: k-th embedding feature of the j-th neighbor slot of atom i.
Gi1s = embeded_net(g_vec).reshape(coords.shape[0], sel, M1)
Gi2s = Gi1s[..., :M2]  # first M2 features of every neighbor

In [159]:
# Regroup the flat per-pair rows as (atom, neighbor slot, 4).
general_coords = general_coords.reshape(-1, sel, 4)

In [160]:
# general_coords.shape,Gi2s.shape

In [161]:
# Batched over atoms i:  D_i = (G1_i^T R_i) (R_i^T G2_i)
# with G1_i: (sel, M1), G2_i: (sel, M2), R_i: (sel, 4)  ->  D_i: (M1, M2)
left = torch.bmm(Gi1s.transpose(1, 2), general_coords)   # (N, M1, 4)
right = torch.bmm(general_coords.transpose(1, 2), Gi2s)  # (N, 4, M2)

Dis = torch.bmm(left, right)             # (N, M1, M2)
Dis_reshaped = Dis.reshape(N, M1 * M2)   # one flattened descriptor row per atom

In [None]:
# atom_type_matrix[atom_types]

In [163]:
# Fitting-net input per atom: flattened descriptor followed by the atom's own type embedding.
feature_vectors = torch.cat(
    [Dis_reshaped, atom_type_matrix[atom_types]],
    dim=1,
)
feature_vectors.shape

torch.Size([100, 5001])

In [165]:
fitting_net = ThreeLayerPerceptron(
    in_dim=M1 * M2 + atom_type_embeded_nchanl,  # width of each feature_vectors row
    hid_dim=100,
    out_dim=1,
)

# One scalar output per atom row, summed into a single total energy.
total_potential_energy = fitting_net(feature_vectors).sum()

print(total_potential_energy)

tensor(-11.3160, grad_fn=<SumBackward0>)
