Skip to content

Commit

Permalink
Subcalss dklGPTrainer from GPTrainer
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Apr 19, 2023
1 parent 3d7719c commit 5d4ad0f
Showing 1 changed file with 3 additions and 41 deletions.
44 changes: 3 additions & 41 deletions atomai/trainers/gptrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def print_statistics(self, e):
'Training loss: {}'.format(np.around(self.train_loss[-1], 4)))


class dklGPTrainer:
class dklGPTrainer(GPTrainer):
"""
Deep kernel learning (DKL)-based Gaussian process regression (GPR)
Expand All @@ -167,41 +167,16 @@ def __init__(self,
"""
Initializes DKL-GPR.
"""
super(dklGPTrainer, self).__init__(**kwargs)

set_seed_and_precision(**kwargs)
self.dimdict = {"input_dim": indim, "embedim": embedim}
self.device = kwargs.get(
"device", 'cuda:0' if torch.cuda.is_available() else 'cpu')
precision = kwargs.get("precision", "double")
self.dtype = torch.float32 if precision == "single" else torch.float64
self.correlated_output = shared_embedding_space
self.gp_model = None
self.likelihood = None
self.ensemble = False
self.compiled = False
self.train_loss = []

def _set_data(self, x: Union[torch.Tensor, np.ndarray],
device: str = None) -> torch.tensor:
"""Data preprocessing."""
device_ = device if device else self.device
if isinstance(x, np.ndarray):
x = torch.from_numpy(x).to(self.dtype).to(device_)
elif isinstance(x, torch.Tensor):
x = x.to(self.dtype).to(device_)
else:
raise TypeError("Pass data as ndarray or torch tensor object")
return x

def set_data(self, x: Union[torch.Tensor, np.ndarray],
y: Optional[Union[torch.Tensor, np.ndarray]] = None,
device: str = None) -> Tuple[torch.tensor]:
"""Data preprocessing. Casts data array to a selected tensor type
and moves it to a selected devive."""
x = self._set_data(x, device)
if y is not None:
y = y[None] if y.ndim == 1 else y
y = self._set_data(y, device)
return x, y

def compile_multi_model_trainer(self,
X: Union[torch.Tensor, np.ndarray],
Expand Down Expand Up @@ -329,19 +304,6 @@ def compile_trainer(self, X: Union[torch.Tensor, np.ndarray],
self.training_cycles = training_cycles
self.compiled = True

def train_step(self) -> None:
"""
Single training step with backpropagation
to computegradients and optimizes weights.
"""
self.optimizer.zero_grad()
X, y = self.gp_model.train_inputs, self.gp_model.train_targets
output = self.gp_model(*X)
loss = -self.mll(output, y).sum()
loss.backward()
self.optimizer.step()
self.train_loss.append(loss.item())

def run(self, X: Union[torch.Tensor, np.ndarray] = None,
y: Union[torch.Tensor, np.ndarray] = None,
training_cycles: int = 1,
Expand Down

0 comments on commit 5d4ad0f

Please sign in to comment.