Skip to content

Commit

Permalink
CCSD restart
Browse files Browse the repository at this point in the history
  • Loading branch information
sunqm committed May 16, 2018
1 parent 929332e commit 0f9ddc2
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 8 deletions.
34 changes: 27 additions & 7 deletions cc/ccsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,23 @@ def kernel(mycc, eris=None, t1=None, t2=None, max_cycle=50, tol=1e-8,
if eris is None:
mycc.ao2mo(mycc.mo_coeff)
eris = mycc._eris

# Use the existed amplitudes as initial guess
if t1 is None: t1 = mycc.t1
if t2 is None: t2 = mycc.t2
if t1 is None and t2 is None:
t1, t2 = mycc.get_init_guess(eris)
elif t2 is None:
t2 = mycc.get_init_guess(eris)[1]

eold = 0
vec_old = 0
eccsd = 0
if mycc.diis:
adiis = diis.DistributedDIIS(mycc)
eccsd = mycc.energy(t1, t2, eris)
log.info('Init E(CCSD) = %.15g', eccsd)

if isinstance(mycc.diis, diis.DistributedDIIS):
adiis = mycc.diis
elif mycc.diis:
adiis = diis.DistributedDIIS(mycc, mycc.diis_file)
adiis.space = mycc.diis_space
else:
adiis = None
Expand Down Expand Up @@ -390,7 +397,7 @@ def _add_vvvv_tril(mycc, t1T, t2T, eris, out=None, with_ovvv=None):
ao_loc0 = ao_loc[task_sh_locs[task_id ]]
ao_loc1 = ao_loc[task_sh_locs[task_id+1]]
Ht2tril -= lib.einsum('pa,pbx->abx', t1_ao[ao_loc0:ao_loc1], buf)
time1 = log.timer_debug1('vvvv-tau ao2mo', *time0)
time1 = log.timer_debug1('contracting vvvv-tau', *time0)
else:
raise NotImplementedError
return Ht2tril
Expand Down Expand Up @@ -706,6 +713,17 @@ def _diff_norm(mycc, t1new, t2new, t1, t2):
norm1 = numpy.linalg.norm(t1new - t1)
return (norm1**2 + norm2**2)**.5

@lib.with_doc(ccsd.restore_from_diis_.__doc__)
@mpi.parallel_call
def restore_from_diis_(mycc, diis_file, inplace=True):
adiis = diis.DistributedDIIS(mycc, mycc.diis_file)
adiis.restore(diis_file, inplace=inplace)
ccvec = adiis.extrapolate()
mycc.t1, mycc.t2 = mycc.vector_to_amplitudes(ccvec)
if inplace:
mycc.diis = adiis
return mycc

# Temporarily placed here. Move it to mpi_scf module in the future
def _pack_scf(mf):
mfdic = {'verbose' : mf.verbose,
Expand All @@ -724,6 +742,7 @@ def _init_ccsd(ccsd_obj):
mpi.comm.bcast((ccsd_obj.mol.dumps(), ccsd_obj.pack()))
else:
ccsd_obj = ccsd.CCSD.__new__(ccsd.CCSD)
ccsd_obj.t1 = ccsd_obj.t2 = None
mol, cc_attr = mpi.comm.bcast(None)
ccsd_obj.mol = gto.mole.loads(mol)
ccsd_obj.unpack_(cc_attr)
Expand Down Expand Up @@ -755,6 +774,7 @@ def pack(self):
'mo_occ' : self.mo_occ,
'_nocc' : self._nocc,
'_nmo' : self._nmo,
'diis_file' : self.diis_file,
'direct' : self.direct}
def unpack_(self, ccdic):
self.__dict__.update(ccdic)
Expand All @@ -774,8 +794,6 @@ def sanity_check(self):
_add_vvvv = _add_vvvv
update_amps = update_amps

def kernel(self, t1=None, t2=None, eris=None):
return self.ccsd(t1, t2, eris)
def ccsd(self, t1=None, t2=None, eris=None):
assert(self.mo_coeff is not None)
assert(self.mo_occ is not None)
Expand Down Expand Up @@ -812,6 +830,8 @@ def vector_to_amplitudes(self, vec, nmo=None, nocc=None):
if nmo is None: nmo = self.nmo
return vector_to_amplitudes(vec, nmo, nocc)

restore_from_diis_ = restore_from_diis_

CC = RCCSD = CCSD

@mpi.parallel_call
Expand Down
9 changes: 8 additions & 1 deletion examples/04-parallel_ccsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,14 @@
mol.basis = '6-31g'
mol.verbose = 4
mol.build()
mf = scf.RHF(mol).run()
mf = scf.RHF(mol)
mf.chkfile = 'h2o.chk'
mf.run()

mycc = cc.RCCSD(mf)
mycc.diis_file = 'mpi_ccdiis.h5'
mycc.kernel()

mycc.restore_from_diis_('mpi_ccdiis.h5')
mycc.kernel()

29 changes: 29 additions & 0 deletions lib/diis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,22 @@

class DistributedDIIS(lib.diis.DIIS):

def _store(self, key, value):
if self._diisfile is None:
if isinstance(self.filename, str):
filename = self.filename + '__rank' + str(mpi.rank)
self._diisfile = lib.H5TmpFile(filename, 'w')

elif not (self.incore or value.size < lib.diis.INCORE_SIZE):
self._diisfile = lib.H5TmpFile(self.filename, 'w')

return lib.diis.DIIS._store(self, key, value)

def extrapolate(self, nd=None):
if nd is None:
nd = self.get_num_vec()
if nd == 0:
raise RuntimeError('No vector found in DIIS object.')

h = self._H[:nd+1,:nd+1].copy()
h[1:,1:] = mpi.comm.allreduce(self._H[1:nd+1,1:nd+1])
Expand Down Expand Up @@ -41,3 +54,19 @@ def extrapolate(self, nd=None):
xnew[p0:p1] += xi[p0:p1] * ci
return xnew

def restore(self, filename, inplace=True):
'''Read diis contents from a diis file and replace the attributes of
current diis object if needed, then construct the vector.
'''
filename_base = filename.split('__rank')[0]
filename = filename_base + '__rank' + str(mpi.rank)
val = lib.diis.DIIS.restore(self, filename, inplace)
if inplace:
self.filename = filename_base
return val


def restore(filename):
'''Restore/construct diis object based on a diis file'''
return DIIS().restore(filename)

0 comments on commit 0f9ddc2

Please sign in to comment.