In [1]:
import torch, torch.nn as nn
import numpy as np, matplotlib.pyplot as plt
import scipy.sparse as sparse
from tqdm.auto import tqdm

def seq_mlp(init, mlp, fin, act):
    """Build the layer list for a fully connected network.

    Layout: Linear(init, mlp[0]), act, ..., Linear(mlp[-1], fin), with no
    activation after the output layer.

    Parameters
    ----------
    init : int
        Number of input features.
    mlp : sequence of int
        Hidden-layer widths. May be empty, giving a single Linear(init, fin).
    fin : int
        Number of output features.
    act : nn.Module
        Activation inserted after every hidden layer. The same instance is
        reused each time, so an activation with parameters would be shared.

    Returns
    -------
    list of nn.Module
        Ready to pass to ``nn.Sequential(*modules)``.
    """
    widths = [init, *mlp, fin]
    modules = []
    for d_in, d_out in zip(widths[:-1], widths[1:]):
        modules.append(nn.Linear(d_in, d_out))
        modules.append(act)
    modules.pop()  # the output layer is linear: drop the trailing activation
    return modules

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class EvalEig(nn.Module):
    """Generate (particle positions, eigenvalues) pairs for random configurations.

    Each sample places ``p_num`` particles in the plane and builds a
    finite-difference Hamiltonian ``H = -Laplacian + V`` on an ``xn x xn`` grid
    covering ``[-xm, xm]^2``. The potential is ``V(r) = sum_k -1/|r - r_k|``.
    A few eigenvalues of ``H`` are then computed with
    ``scipy.sparse.linalg.eigsh``.

    Parameters
    ----------
    eval_para : dict
        Reads ``'batch_dim'``: number of random configurations per call.
    """

    def __init__(self, eval_para):
        super().__init__()
        self.bd = eval_para['batch_dim']

    def set_rdsc(self, xm, xn):
        """Set the grid half-width ``xm`` and the number of grid points per axis ``xn``."""
        self.xn = xn
        self.xm = xm

    def mesh_ptl(self, posx, posy):
        """Evaluate the summed ``-1/r`` potential of all particles on the grid.

        posx, posy : arrays of shape (bd, p_num) with particle coordinates.
        Returns an array of shape (bd, xn, xn). The grid uses ``'ij'`` indexing,
        so axis 1 runs along x and axis 2 along y.
        """
        X, Y = np.meshgrid(np.linspace(-self.xm, self.xm, self.xn),
                           np.linspace(-self.xm, self.xm, self.xn), indexing='ij')
        # Broadcast the grid (1, 1, xn, xn) against the particles (bd, p_num, 1, 1).
        X_broad, Y_broad = X[np.newaxis, np.newaxis, :, :], Y[np.newaxis, np.newaxis, :, :]
        posx_broad, posy_broad = posx[:, :, np.newaxis, np.newaxis], posy[:, :, np.newaxis, np.newaxis]

        dist = np.sqrt((X_broad - posx_broad)**2 + (Y_broad - posy_broad)**2)
        # A particle exactly on a grid node would give 1/0; clamp the distance to machine eps.
        dist[dist == 0] = np.finfo(float).eps

        return np.sum(-1 / dist, axis=1)  # sum over particles -> (bd, xn, xn)

    def _mesh_kin(self):
        """Return the kinetic operator ``-Laplacian`` as a sparse (xn**2, xn**2) matrix (5-point stencil)."""
        dx = 2 * self.xm / (self.xn - 1)
        off = np.full(self.xn - 1, 1 / dx**2)
        # 1-D second difference; no wrap-around terms, i.e. zero (Dirichlet) boundaries.
        lap_1d = sparse.diags([np.full(self.xn, -2 / dx**2), off, off], [0, -1, 1],
                              shape=(self.xn, self.xn))
        eye = sparse.identity(self.xn)
        laplacian = sparse.kron(eye, lap_1d) + sparse.kron(lap_1d, eye)
        # The discrete Laplacian is negative semi-definite. Kinetic energy must be
        # non-negative, hence the minus sign (units with hbar^2 / 2m = 1).
        return -laplacian

    def mesh_hml(self, term_ptl, term_kin=None):
        """Build the sparse Hamiltonian ``-Laplacian + diag(V)`` for one (xn, xn) potential.

        term_kin : optional result of ``_mesh_kin()``. Pass it when building many
            Hamiltonians on the same grid so the kinetic matrix is not rebuilt.
        """
        if term_kin is None:
            term_kin = self._mesh_kin()
        return term_kin + sparse.diags(term_ptl.ravel(), 0)

    def init_evl(self, p_num, n_evl=6, which='SM'):
        """Sample ``bd`` random configurations and compute their eigenvalues.

        p_num : particles per configuration. Coordinates are drawn uniformly
            from ``[-xm/10, xm/10]``.
        n_evl : eigenvalues per configuration (``k`` passed to eigsh; 6 is eigsh's default).
        which : eigsh selection mode (default ``'SM'``, smallest magnitude).
        Returns (posx, posy, evl) with shapes (bd, p_num), (bd, p_num), (bd, n_evl).
        """
        # `import scipy.sparse` alone does not guarantee `sparse.linalg` is loaded.
        from scipy.sparse.linalg import eigsh

        posx = np.random.uniform(-self.xm, self.xm, size=(self.bd, p_num)) / 10
        posy = np.random.uniform(-self.xm, self.xm, size=(self.bd, p_num)) / 10

        evl = np.zeros((self.bd, n_evl))
        mesh_ptl = self.mesh_ptl(posx, posy)
        term_kin = self._mesh_kin()  # same grid for every sample: build once

        # Context manager guarantees the progress bar is closed.
        with tqdm(total=self.bd, desc='Progress', leave=True, position=0, colour='blue') as pbar:
            for i in range(self.bd):
                mesh_hml = self.mesh_hml(mesh_ptl[i], term_kin)
                evl_i, _ = eigsh(mesh_hml, k=n_evl, which=which)
                evl[i] = evl_i
                pbar.update()

        return posx, posy, evl

    def forward(self, rm, rn, p_num, n_evl=6):
        """Set the grid (half-width ``rm``, ``rn`` points per axis) and return ``init_evl(p_num, n_evl)``."""
        self.set_rdsc(rm, rn)
        posx_tr, posy_tr, evl_tr = self.init_evl(p_num, n_evl)

        return posx_tr, posy_tr, evl_tr

class InvEig(EvalEig):
    """Inverse model: predict particle positions from a vector of eigenvalues.

    An MLP built by ``seq_mlp`` maps ``n_evl`` eigenvalues per sample to
    ``2 * p_num`` outputs. The first ``p_num`` outputs are x-coordinates and
    the rest are y-coordinates.

    Parameters
    ----------
    eval_para : dict
        Forwarded to ``EvalEig`` (reads ``'batch_dim'``).
    model_para : dict
        Reads ``'mlp'``: list of hidden-layer widths.
    """

    def __init__(self, eval_para, model_para):
        super().__init__(eval_para)
        self.mlp_shape = model_para['mlp']

    def set_rdsc(self, rm, rn, p_num, n_evl=6):
        """Store grid/particle sizes and (re)build the MLP.

        rm, rn : grid half-width and points per axis. They are also stored as
            ``xm``/``xn`` so the inherited mesh helpers work on this instance.
        p_num : number of particles whose coordinates are predicted.
        n_evl : number of input eigenvalues (6 matches EvalEig's eigenvalue count).

        Calling this again replaces the MLP, discarding any trained weights.
        """
        super().set_rdsc(rm, rn)
        self.rn = rn
        self.rm = rm
        self.pn = p_num

        modules = seq_mlp(init=n_evl, mlp=self.mlp_shape, fin=p_num * 2, act=nn.ReLU())
        self.mlp = nn.Sequential(*modules)

    def dist_list(self, posx, posy):
        """Placeholder - not implemented.

        NOTE(review): the original body was empty, which is a SyntaxError that
        broke the whole cell. The name suggests pairwise particle distances;
        TODO confirm the intended output and implement.
        """
        raise NotImplementedError("InvEig.dist_list is not implemented yet")

    def forward(self, evl):
        """Predict particle coordinates from eigenvalues.

        evl : tensor or array-like of shape (batch, n_evl). Input is converted
            to the MLP's parameter dtype/device, so NumPy float64 data can be
            passed directly.
        Returns (pos_x, pos_y), each of shape (batch, p_num).
        Requires ``set_rdsc`` to have been called (it creates ``self.mlp``).
        """
        param = next(self.mlp.parameters())
        evl = torch.as_tensor(evl, dtype=param.dtype, device=param.device)
        pos = self.mlp(evl)
        pos_x_md, pos_y_md = pos[:, :self.pn], pos[:, self.pn:]
        return pos_x_md, pos_y_md



In [3]:
# Seed every RNG used downstream so Restart & Run All reproduces the same data:
# EvalEig samples particle positions with np.random, and torch initialises MLP weights.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

eval_para = {
        # model specifics
        'precision' : 64, # 32 or 64 bit -- NOTE(review): not read by EvalEig, which only uses 'batch_dim'
        'batch_dim' : 1000, # random configurations generated per EvalEig call
        }

model_para = {
        # model
        'mlp' : [100, 100], # hidden-layer widths passed to seq_mlp by InvEig

        # training -- NOTE(review): not read by any cell shown here; presumably used by a training loop
        'epoch' : 5000,
        'lr' : 1e-2,

        # loss regularisation
        'reg1' : 1e-1, # V(0) sign
        'reg2' : 1, # V -> 0 as r -> infty

        }

# NOTE(review): this name shadows the built-in `eval`; later cells call it as `eval(...)`.
eval = EvalEig(eval_para)

In [4]:
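# NOTE(review): stale experiment kept for reference only. It uses `itertools` (never imported)
# and methods (fixed_tr, dsc_eigs, evl_cutoff) that EvalEig does not define.
#eval_grid = [[800], \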
#    [10000], \
#        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]] # rm, rn, para_1
#for midx in itertools.product(*eval_grid):
#for midx in zip(*eval_grid):
    #eval.set_rdsc(midx[0], midx[1])
    #ptl_tr = eval.fixed_tr(midx[2], "coulomb")
    #evl_scl_tr = eval.dsc_eigs(ptl_tr)
    #evl_tr = evl_scl_tr[:,:,:eval.evl_cutoff(evl_scl_tr)]
#    ptl_tr, evl_tr = eval(midx[0], midx[1], midx[2], "coulomb")
#    factor = torch.mean(1/evl_tr, dim = 0)
#    print(factor[0])
#    print(midx, nn.L1Loss()(factor[0],torch.arange(1,factor.shape[1]+1)**2), evl_tr[0,0,0])

In [10]:
# Generate the training set (slow: one sparse eigen-solve per configuration).
GRID_HALF_WIDTH = 1e4  # grid covers [-GRID_HALF_WIDTH, GRID_HALF_WIDTH] on both axes
GRID_POINTS = 100      # points per axis -> GRID_POINTS**2 unknowns per Hamiltonian
N_PARTICLES = 10       # particles per random configuration

posx_tr, posy_tr, evl_tr = eval(GRID_HALF_WIDTH, GRID_POINTS, N_PARTICLES)

# Check shapes and show a small sample instead of dumping the full arrays.
assert posx_tr.shape == posy_tr.shape == (eval_para['batch_dim'], N_PARTICLES)
assert evl_tr.shape[0] == eval_para['batch_dim']
print('evl_tr', evl_tr.shape, 'posx_tr', posx_tr.shape, 'posy_tr', posy_tr.shape)
print(evl_tr[:5])
print(posx_tr[:5], posy_tr[:5])

Progress:  34%|[34m███▍      [0m| 343/1000 [03:19<05:32,  1.98it/s]

In [20]:
import pickle

# Cache the generated training set so the expensive eigen-solve need not be re-run.
for fname, arr in (("posx_tr.data", posx_tr),
                   ("posy_tr.data", posy_tr),
                   ("evl_tr.data", evl_tr)):
    with open(fname, "wb") as fh:
        pickle.dump(arr, fh)

  model.load_state_dict(torch.load('1.pth'))
Progress: 100%|[34m██████████[0m| 5000/5000 [21:05<00:00,  3.95it/s]
Progress: 100%|[34m██████████[0m| 5000/5000 [15:28<00:00,  5.53it/s]

In [5]:
import pickle

# Reload the cached training set written by the save cell.
# pickle.load can execute arbitrary code: only use it on files this notebook produced.
def _unpickle(path):
    with open(path, "rb") as fh:
        return pickle.load(fh)

posx_tr = _unpickle("posx_tr.data")
posy_tr = _unpickle("posy_tr.data")
evl_tr = _unpickle("evl_tr.data")