-
Notifications
You must be signed in to change notification settings - Fork 38
/
Copy pathNNLDA.h
36 lines (34 loc) · 1.13 KB
/
NNLDA.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#ifndef NNLDA_H
#define NNLDA_H
#ifdef DFTFE_WITH_TORCH
# include <string>
# include <torch/torch.h>
# include <excDensityPositivityCheckTypes.h>
namespace dftfe
{
  /**
   * @brief Exchange-correlation evaluator backed by a TorchScript module.
   *
   * The name suggests a neural-network LDA functional; NOTE(review): the
   * model loading and evaluation live in the implementation file, which is
   * not visible here — confirm details there.
   *
   * The class stores a raw pointer to a torch::jit::script::Module and
   * declares its own destructor, which indicates it owns that module. An
   * implicitly generated copy would duplicate the pointer (and, if the
   * destructor deletes it, cause a double free), so copy and move
   * operations are explicitly deleted.
   */
  class NNLDA
  {
  public:
    /**
     * @brief Construct the evaluator.
     *
     * Marked explicit: every parameter after the first has a default, so
     * without it a std::string would silently convert into an NNLDA.
     *
     * @param modelFilename name of the model file to use (presumably a
     *        TorchScript file — TODO confirm in the implementation).
     * @param isSpinPolarized whether the input density is spin-polarized;
     *        defaults to false.
     * @param densityPositivityCheckType policy for non-positive densities;
     *        defaults to excDensityPositivityCheckTypes::MAKE_POSITIVE.
     */
    explicit NNLDA(std::string                          modelFilename,
                   const bool                           isSpinPolarized = false,
                   const excDensityPositivityCheckTypes densityPositivityCheckType =
                     excDensityPositivityCheckTypes::MAKE_POSITIVE);

    ~NNLDA();

    // Non-copyable and non-movable: d_model is a raw pointer that this
    // class appears to own (see class comment).
    NNLDA(const NNLDA &) = delete;
    NNLDA &
    operator=(const NNLDA &) = delete;
    NNLDA(NNLDA &&)          = delete;
    NNLDA &
    operator=(NNLDA &&) = delete;

    /**
     * @brief Evaluate the energy density exc for numPoints points.
     *
     * @param rho input density values; NOTE(review): the layout when
     *        spin-polarized is not visible here — confirm in the
     *        implementation.
     * @param numPoints number of points to evaluate.
     * @param exc output buffer (non-const, presumably written with one
     *        value per point — TODO confirm).
     */
    void
    evaluateexc(const double *rho, const unsigned int numPoints, double *exc);

    /**
     * @brief Evaluate exc together with the potential vxc.
     *
     * @param rho input density values (layout: see evaluateexc).
     * @param numPoints number of points to evaluate.
     * @param exc output buffer for the energy density (presumably written).
     * @param vxc output buffer for the potential; NOTE(review): its size
     *        for the spin-polarized case is not visible here — confirm.
     */
    void
    evaluatevxc(const double *     rho,
                const unsigned int numPoints,
                double *           exc,
                double *           vxc);

  private:
    std::string d_modelFilename;
    std::string d_ptcFilename;
    // Raw pointer to the loaded module; the user-declared destructor
    // suggests this class is responsible for releasing it.
    torch::jit::script::Module *         d_model;
    const bool                           d_isSpinPolarized;
    double                               d_rhoTol;
    const excDensityPositivityCheckTypes d_densityPositivityCheckType;
  };
} // namespace dftfe
#endif
#endif // NNLDA_H