Skip to content

Commit

Permalink
IO performance of pbc.GDF initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
sunqm committed Jul 3, 2018
1 parent ee90e2a commit bfca90c
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 43 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG
Expand Up @@ -11,6 +11,8 @@ PySCF 1.6 alpha (2018-09-??)

PySCF 1.5.2
-----------
* Improved
- IO performance of pbc.GDF initialization
* Bugfix
- selected_ci 2-particle density matrices for two electron systems

Expand Down
15 changes: 10 additions & 5 deletions pyscf/pbc/df/df.py
Expand Up @@ -144,8 +144,8 @@ def _make_j3c(mydf, cell, auxcell, kptij_lst, cderi_file):
log = logger.Logger(mydf.stdout, mydf.verbose)
max_memory = max(2000, mydf.max_memory-lib.current_memory()[0])
fused_cell, fuse = fuse_auxcell(mydf, auxcell)
outcore.aux_e2(cell, fused_cell, cderi_file, 'int3c2e', aosym='s2',
kptij_lst=kptij_lst, dataname='j3c', max_memory=max_memory)
outcore._aux_e2(cell, fused_cell, cderi_file, 'int3c2e', aosym='s2',
kptij_lst=kptij_lst, dataname='j3c', max_memory=max_memory)
t1 = log.timer_debug1('3c2e', *t1)

nao = cell.nao_nr()
Expand Down Expand Up @@ -221,6 +221,7 @@ def _make_j3c(mydf, cell, auxcell, kptij_lst, cderi_file):
feri['j2c/%d'%k] = fuse(fuse(j2c[k]).T).T
j2c = coulG = None

nsegs = len(feri['j3c/0'])
def make_kpt(uniq_kptji_id): # kpt = kptj - kpti
kpt = uniq_kpts[uniq_kptji_id]
log.debug1('kpt = %s', kpt)
Expand Down Expand Up @@ -273,6 +274,8 @@ def make_kpt(uniq_kptji_id): # kpt = kptj - kpti
aosym = 's1'
nao_pair = nao**2

fswap = lib.H5TmpFile()

mem_now = lib.current_memory()[0]
log.debug2('memory = %s', mem_now)
max_memory = max(2000, mydf.max_memory-mem_now)
Expand Down Expand Up @@ -300,7 +303,8 @@ def make_kpt(uniq_kptji_id): # kpt = kptj - kpti
j3cR = []
j3cI = []
for k, idx in enumerate(adapted_ji_idx):
v = numpy.asarray(feri['j3c/%d'%idx][:,col0:col1])
v = numpy.vstack([feri['j3c/%d/%d'%(idx,i)][0,col0:col1].T
for i in range(nsegs)])
if is_zero(kpt) and cell.dimension == 3:
for i in numpy.where(vbar != 0)[0]:
v[i] -= vbar[i] * ovlp[k][col0:col1]
Expand Down Expand Up @@ -359,11 +363,12 @@ def make_kpt(uniq_kptji_id): # kpt = kptj - kpti
v = scipy.linalg.solve_triangular(j2c, v, lower=True, overwrite_b=True)
else:
v = lib.dot(j2c, v)
feri['j3c/%d'%ji][:naux0,col0:col1] = v
fswap['%d/%d'%(k,istep)] = v

del(feri['j2c/%d'%uniq_kptji_id])
nsteps = len(shranges)
for k, ji in enumerate(adapted_ji_idx):
v = feri['j3c/%d'%ji][:naux0]
v = numpy.hstack([fswap['%d/%d'%(k,i)] for i in range(nsteps)])
del(feri['j3c/%d'%ji])
feri['j3c/%d'%ji] = v

Expand Down
3 changes: 2 additions & 1 deletion pyscf/pbc/df/fft_ao2mo.py
Expand Up @@ -166,7 +166,8 @@ def fill_orbital_pair(moT, i0, i1, buf):
return out

eri = numpy.empty((nmoi*(nmoi+1)//2,nmok*(nmok+1)//2))
blksize = int(min(max(nmoi,nmok), (max_memory*1e6/8 - eri.size)/2/ngrids+1))
blksize = int(min(max(nmoi*(nmoi+1)//2, nmok*(nmok+1)//2),
(max_memory*1e6/8 - eri.size)/2/ngrids+1))
buf = numpy.empty((blksize,ngrids))
for p0, p1 in lib.prange_tril(0, nmoi, blksize):
mo_pairs_G = tools.fft(fill_orbital_pair(moiT, p0, p1, buf), mydf.mesh)
Expand Down
15 changes: 10 additions & 5 deletions pyscf/pbc/df/mdf.py
Expand Up @@ -51,8 +51,8 @@ def _make_j3c(mydf, cell, auxcell, kptij_lst, cderi_file):
log = logger.Logger(mydf.stdout, mydf.verbose)
max_memory = max(2000, mydf.max_memory-lib.current_memory()[0])
fused_cell, fuse = fuse_auxcell(mydf, auxcell)
outcore.aux_e2(cell, fused_cell, cderi_file, 'int3c2e', aosym='s2',
kptij_lst=kptij_lst, dataname='j3c', max_memory=max_memory)
outcore._aux_e2(cell, fused_cell, cderi_file, 'int3c2e', aosym='s2',
kptij_lst=kptij_lst, dataname='j3c', max_memory=max_memory)
t1 = log.timer_debug1('3c2e', *t1)

nao = cell.nao_nr()
Expand All @@ -72,6 +72,7 @@ def _make_j3c(mydf, cell, auxcell, kptij_lst, cderi_file):
# j2c ~ (-kpt_ji | kpt_ji)
j2c = fused_cell.pbc_intor('int2c2e', hermi=1, kpts=uniq_kpts)
feri = h5py.File(cderi_file)
nsegs = len(feri['j3c/0'])

for k, kpt in enumerate(uniq_kpts):
aoaux = ft_ao.ft_ao(fused_cell, Gv, None, b, gxyz, Gvbase, kpt).T
Expand Down Expand Up @@ -139,6 +140,8 @@ def make_kpt(uniq_kptji_id): # kpt = kptj - kpti
aosym = 's1'
nao_pair = nao**2

fswap = lib.H5TmpFile()

mem_now = lib.current_memory()[0]
log.debug2('memory = %s', mem_now)
max_memory = max(2000, mydf.max_memory-mem_now)
Expand Down Expand Up @@ -166,7 +169,8 @@ def make_kpt(uniq_kptji_id): # kpt = kptj - kpti
j3cR = []
j3cI = []
for k, idx in enumerate(adapted_ji_idx):
v = fuse(numpy.asarray(feri['j3c/%d'%idx][:,col0:col1]))
v = [feri['j3c/%d/%d'%(idx,i)][0,col0:col1].T for i in range(nsegs)]
v = fuse(numpy.vstack(v))
if is_zero(kpt) and cell.dimension == 3:
for i, c in enumerate(vbar):
if c != 0:
Expand Down Expand Up @@ -209,11 +213,12 @@ def make_kpt(uniq_kptji_id): # kpt = kptj - kpti
v = scipy.linalg.solve_triangular(j2c, v, lower=True, overwrite_b=True)
else:
v = lib.dot(j2c, v)
feri['j3c/%d'%ji][:naux0,col0:col1] = v
fswap['%d/%d'%(k,istep)] = v

del(feri['j2c/%d'%uniq_kptji_id])
nsteps = len(shranges)
for k, ji in enumerate(adapted_ji_idx):
v = feri['j3c/%d'%ji][:naux0]
v = numpy.hstack([fswap['%d/%d'%(k,i)] for i in range(nsteps)])
del(feri['j3c/%d'%ji])
feri['j3c/%d'%ji] = v

Expand Down
107 changes: 101 additions & 6 deletions pyscf/pbc/df/outcore.py
Expand Up @@ -25,17 +25,16 @@
from pyscf.pbc.df.incore import wrap_int3c
from pyscf import __config__

CHUNK_SIZE = getattr(__config__, 'pbc_df_outcore_chunk_size', 256)

libpbc = lib.load_library('libpbc')


def aux_e2(cell, auxcell, erifile, intor='int3c2e', aosym='s2ij', comp=None,
def aux_e1(cell, auxcell, erifile, intor='int3c2e', aosym='s2ij', comp=None,
kptij_lst=None, dataname='eri_mo', shls_slice=None, max_memory=2000,
verbose=0):
r'''3-center AO integrals (ij|L) with double lattice sum:
\sum_{lm} (i[l]j[m]|L[0]), where L is the auxiliary basis.
On diks, the integrals are stored as (kptij_idx, naux, nao_pair)
r'''3-center AO integrals (L|ij) with double lattice sum:
\sum_{lm} (L[0]|i[l]j[m]), where L is the auxiliary basis.
Three-index integral tensor (kptij_idx, naux, nao_pair) or four-index
integral tensor (kptij_idx, comp, naux, nao_pair) are stored on disk.
Args:
kptij_lst : (*,2,3) array
Expand Down Expand Up @@ -142,3 +141,99 @@ def aux_e2(cell, auxcell, erifile, intor='int3c2e', aosym='s2ij', comp=None,
return erifile


def _aux_e2(cell, auxcell, erifile, intor='int3c2e', aosym='s2ij', comp=None,
kptij_lst=None, dataname='eri_mo', shls_slice=None, max_memory=2000,
verbose=0):
r'''3-center AO integrals (ij|L) with double lattice sum:
\sum_{lm} (i[l]j[m]|L[0]), where L is the auxiliary basis.
Three-index integral tensor (kptij_idx, nao_pair, naux) or four-index
integral tensor (kptij_idx, comp, nao_pair, naux) are stored on disk.
**This function should be used by df and mdf initialization function
_make_j3c only**
Args:
kptij_lst : (*,2,3) array
A list of (kpti, kptj)
'''
intor, comp = gto.moleintor._get_intor_and_comp(cell._add_suffix(intor), comp)

if h5py.is_hdf5(erifile):
feri = h5py.File(erifile)
if dataname in feri:
del(feri[dataname])
if dataname+'-kptij' in feri:
del(feri[dataname+'-kptij'])
else:
feri = h5py.File(erifile, 'w')

if kptij_lst is None:
kptij_lst = numpy.zeros((1,2,3))
feri[dataname+'-kptij'] = kptij_lst

if shls_slice is None:
shls_slice = (0, cell.nbas, 0, cell.nbas, 0, auxcell.nbas)

ao_loc = cell.ao_loc_nr()
aux_loc = auxcell.ao_loc_nr(auxcell.cart or 'ssc' in intor)[:shls_slice[5]+1]
ni = ao_loc[shls_slice[1]] - ao_loc[shls_slice[0]]
nj = ao_loc[shls_slice[3]] - ao_loc[shls_slice[2]]
naux = aux_loc[shls_slice[5]] - aux_loc[shls_slice[4]]
nkptij = len(kptij_lst)

nii = (ao_loc[shls_slice[1]]*(ao_loc[shls_slice[1]]+1)//2 -
ao_loc[shls_slice[0]]*(ao_loc[shls_slice[0]]+1)//2)
nij = ni * nj

kpti = kptij_lst[:,0]
kptj = kptij_lst[:,1]
aosym_ks2 = abs(kpti-kptj).sum(axis=1) < KPT_DIFF_TOL
j_only = numpy.all(aosym_ks2)
#aosym_ks2 &= (aosym[:2] == 's2' and shls_slice[:2] == shls_slice[2:4])
aosym_ks2 &= aosym[:2] == 's2'

if j_only and aosym[:2] == 's2':
assert(shls_slice[2] == 0)
nao_pair = nii
else:
nao_pair = nij

if gamma_point(kptij_lst):
dtype = numpy.double
else:
dtype = numpy.complex128

buflen = max(8, int(max_memory*.47e6/16/(nkptij*ni*nj*comp)))
auxdims = aux_loc[shls_slice[4]+1:shls_slice[5]+1] - aux_loc[shls_slice[4]:shls_slice[5]]
auxranges = balance_segs(auxdims, buflen)
buflen = max([x[2] for x in auxranges])
buf = numpy.empty(nkptij*comp*ni*nj*buflen, dtype=dtype)
buf1 = numpy.empty_like(buf)

int3c = wrap_int3c(cell, auxcell, intor, aosym, comp, kptij_lst)

tril_idx = numpy.tril_indices(ni)
tril_idx = tril_idx[0] * ni + tril_idx[1]
def save(istep, mat):
for k, kptij in enumerate(kptij_lst):
v = mat[k]
if gamma_point(kptij):
v = v.real
if aosym_ks2[k] and nao_pair == ni**2:
v = v[:,tril_idx]
feri['%s/%d/%d' % (dataname,k,istep)] = v

with lib.call_in_background(save) as bsave:
for istep, auxrange in enumerate(auxranges):
sh0, sh1, nrow = auxrange
sub_slice = (shls_slice[0], shls_slice[1],
shls_slice[2], shls_slice[3],
shls_slice[4]+sh0, shls_slice[4]+sh1)
mat = numpy.ndarray((nkptij,comp,nao_pair,nrow), dtype=dtype, buffer=buf)
bsave(istep, int3c(sub_slice, mat))
buf, buf1 = buf1, buf

feri.close()
return erifile


48 changes: 24 additions & 24 deletions pyscf/pbc/df/test/test_df.py
Expand Up @@ -50,44 +50,44 @@
kmdf = df.DF(cell)
kmdf.auxbasis = 'weigend'
kmdf.kpts = kpts
kmdf.mesh = (21,)*3
kmdf.mesh = (6,)*3


def finger(a):
w = numpy.cos(numpy.arange(a.size))
return numpy.dot(a.ravel(), w)

class KnowValues(unittest.TestCase):
class KnownValues(unittest.TestCase):
def test_get_eri_gamma(self):
odf = df.DF(cell)
odf.auxbasis = 'weigend'
odf.mesh = (21,)*3
odf.mesh = (6,)*3
eri0000 = odf.get_eri()
self.assertTrue(eri0000.dtype == numpy.double)
self.assertAlmostEqual(eri0000.real.sum(), 41.612815388046052, 9)
self.assertAlmostEqual(finger(eri0000), 1.9981475967566333, 9)
self.assertAlmostEqual(eri0000.real.sum(), 41.612793785221186, 9)
self.assertAlmostEqual(finger(eri0000), 1.9981473214755234, 9)

eri1111 = kmdf.get_eri((kpts[0],kpts[0],kpts[0],kpts[0]))
self.assertTrue(eri1111.dtype == numpy.double)
self.assertAlmostEqual(eri1111.real.sum(), 41.612815388046101, 9)
self.assertAlmostEqual(eri1111.real.sum(), 41.612793785221186, 9)
self.assertAlmostEqual(eri1111.imag.sum(), 0, 9)
self.assertAlmostEqual(finger(eri1111), 1.9981475967566393, 9)
self.assertTrue(numpy.allclose(eri1111, eri0000))
self.assertAlmostEqual(finger(eri1111), 1.9981473214755234, 9)
self.assertAlmostEqual(abs(eri1111-eri0000).max(), 0, 9)

eri4444 = kmdf.get_eri((kpts[4],kpts[4],kpts[4],kpts[4]))
self.assertTrue(eri4444.dtype == numpy.complex128)
self.assertAlmostEqual(eri4444.real.sum(), 62.551238630045489, 9)
self.assertAlmostEqual(abs(eri4444.imag).sum(), 0, 7)
self.assertAlmostEqual(finger(eri4444), 0.62059866259713981-2.3427034826510518e-09j, 8)
self.assertAlmostEqual(eri4444.real.sum(), 62.551164364834513, 9)
self.assertAlmostEqual(abs(eri4444.imag).sum(), 0.0033227776224408773, 7)
self.assertAlmostEqual(finger(eri4444), 0.62060378746287159-0.00015488998139907677j, 8)
eri0000 = ao2mo.restore(1, eri0000, cell.nao_nr()).reshape(eri4444.shape)
self.assertTrue(numpy.allclose(eri0000, eri4444, atol=1e-7))
self.assertAlmostEqual(abs(eri0000-eri4444).max(), 0, 4)

def test_get_eri_1111(self):
eri1111 = kmdf.get_eri((kpts[1],kpts[1],kpts[1],kpts[1]))
self.assertTrue(eri1111.dtype == numpy.complex128)
self.assertAlmostEqual(eri1111.real.sum(), 62.549765060422182, 9)
self.assertAlmostEqual(abs(eri1111.imag).sum(), 0.0018154474705716237, 9)
self.assertAlmostEqual(finger(eri1111), 0.62039123349057901+8.7906060180183165e-05j, 9)
self.assertAlmostEqual(eri1111.real.sum(), 62.549690947485999, 9)
self.assertAlmostEqual(abs(eri1111.imag).sum(), 0.005129309666524944, 9)
self.assertAlmostEqual(finger(eri1111), 0.62039636506407403-6.6969834476153422e-05j, 9)
check2 = kmdf.get_eri((kpts[1]+5e-8,kpts[1]+5e-8,kpts[1],kpts[1]))
self.assertTrue(numpy.allclose(eri1111, check2, atol=1e-7))

Expand All @@ -99,25 +99,25 @@ def test_get_eri_1111(self):
def test_get_eri_0011(self):
eri0011 = kmdf.get_eri((kpts[0],kpts[0],kpts[1],kpts[1]))
self.assertTrue(eri0011.dtype == numpy.complex128)
self.assertAlmostEqual(eri0011.real.sum(), 62.550501761172804, 9)
self.assertAlmostEqual(abs(eri0011.imag).sum(), 0.00090808308701441343, 9)
self.assertAlmostEqual(finger(eri0011), 0.62054704928684989+7.5478295905019859e-05j, 9)
self.assertAlmostEqual(eri0011.real.sum(), 62.550427638608646, 9)
self.assertAlmostEqual(abs(eri0011.imag).sum(), 0.0036623806544977093, 9)
self.assertAlmostEqual(finger(eri0011), 0.62055221399933336+0.00017042195037428586j, 9)

def test_get_eri_0110(self):
eri0110 = kmdf.get_eri((kpts[0],kpts[1],kpts[1],kpts[0]))
self.assertTrue(eri0110.dtype == numpy.complex128)
self.assertAlmostEqual(eri0110.real.sum(), 83.113609623488301, 9)
self.assertAlmostEqual(abs(eri0110.imag).sum(), 5.0835167272062405, 9)
self.assertAlmostEqual(finger(eri0110), 0.97004623074621432-0.33188261713186479j, 9)
self.assertAlmostEqual(eri0110.real.sum(), 83.113865247801101, 9)
self.assertAlmostEqual(abs(eri0110.imag).sum(), 5.0834141934485935, 9)
self.assertAlmostEqual(finger(eri0110), 0.96998616670534554-0.33186033454783898j, 9)
check2 = kmdf.get_eri((kpts[0]+5e-8,kpts[1]+5e-8,kpts[1],kpts[0]))
self.assertTrue(numpy.allclose(eri0110, check2, atol=1e-7))

def test_get_eri_0123(self):
eri0123 = kmdf.get_eri(kpts[:4])
self.assertTrue(eri0123.dtype == numpy.complex128)
self.assertAlmostEqual(eri0123.real.sum(), 83.1094028635287, 9)
self.assertAlmostEqual(abs(eri0123.imag.sum()), 4.9901406037999863e-05, 9)
self.assertAlmostEqual(finger(eri0123), 0.96952612970275598-0.33222740866776712j, 9)
self.assertAlmostEqual(eri0123.real.sum(), 83.109500652810326, 9)
self.assertAlmostEqual(abs(eri0123.imag.sum()), 0.011093660109025016, 9)
self.assertAlmostEqual(finger(eri0123), 0.96956515144508404-0.33108517079284416j, 9)



Expand Down
4 changes: 2 additions & 2 deletions pyscf/pbc/df/test/test_outcore.py
Expand Up @@ -39,12 +39,12 @@ def finger(a):
return numpy.dot(w, a.ravel())

class KnowValues(unittest.TestCase):
def test_aux_e2(self):
def test_aux_e1(self):
tmpfile = tempfile.NamedTemporaryFile(dir=lib.param.TMPDIR)
numpy.random.seed(1)
kptij_lst = numpy.random.random((3,2,3))
kptij_lst[0] = 0
outcore.aux_e2(cell, cell, tmpfile.name, aosym='s2', comp=1,
outcore.aux_e1(cell, cell, tmpfile.name, aosym='s2', comp=1,
kptij_lst=kptij_lst, verbose=0)
refk = incore.aux_e2(cell, cell, aosym='s2', kptij_lst=kptij_lst)
with h5py.File(tmpfile.name, 'r') as f:
Expand Down

0 comments on commit bfca90c

Please sign in to comment.