Skip to content

Commit

Permalink
Merge pull request #77 from pycroscopy/gp
Browse files Browse the repository at this point in the history
Sparse image reconstructor based on the structured kernel interpolation framework
  • Loading branch information
ziatdinovmax committed Apr 19, 2023
2 parents 263eb3b + 6b9b9ea commit 7517583
Show file tree
Hide file tree
Showing 12 changed files with 533 additions and 61 deletions.
4 changes: 2 additions & 2 deletions atomai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from .regressor import Regressor
from .classifier import Classifier
from .dgm import BaseVAE, VAE, rVAE, jVAE, jrVAE
from .dklgp import dklGPR
from .dklgp import dklGPR, Reconstructor
from .loaders import load_model, load_ensemble, load_pretrained_model

__all__ = ["Segmentor", "ImSpec", "BaseVAE", "VAE", "rVAE",
"jVAE", "jrVAE", "load_model", "load_ensemble",
"load_pretrained_model", "dklGPR", "Regressor",
"Classifier"]
"Classifier", "Reconstructor"]
1 change: 1 addition & 0 deletions atomai/models/dklgp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .dklgpr import dklGPR
from .gpr import Reconstructor

__all__ = ["dklGPR"]
114 changes: 114 additions & 0 deletions atomai/models/dklgp/gpr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from typing import Tuple, Optional, List

import numpy as np
import gpytorch
import torch

from ...trainers import GPTrainer
from ...utils import prepare_gp_input, create_batches, get_lengthscale_constraints


class Reconstructor(GPTrainer):

"""
Sparse image reconstructor based on the structured kernel interpolation framework.
Keyword Args:
device:
Sets device to which model and data will be moved.
Defaults to 'cuda:0' if a GPU is available and to CPU otherwise.
precision:
Sets tensor types for 'single' (torch.float32)
or 'double' (torch.float64) precision
seed:
Seed for enforcing reproducibility
"""

def __init__(self, **kwargs):
super(Reconstructor, self).__init__(**kwargs)

def fit(self, X: torch.Tensor, y: torch.Tensor,
training_cycles: int, **kwargs):
"""
Performs model training
Args:
X: Input training data. Usually, these are indices of pixels where the sparse measurements were performed.
The dimensions of X should be (N, num_features). For 2D images, it will be (N, 2).
y: Output targets of (N,) dimensions (usually, these are pixel values)
training_cycles: Number of training epochs
Keyword Args:
grid_points_ratio: Determines a grid size for the KISS-GP kernel. Defaults to 1.0 (recommended)
lr: learning rate (Default: 0.01)
kernel_type: Type of kernel to use, either 'sparse' or 'kissgp'.
base_kernel: Name of the base kernel as a string, either 'rbf' or 'matern', or a custom base kernel object.
inducing_points: Inducing points for the sparse kernel.
lengthscale_contraints: Optional lengthscale constraints for the base kernel.
print_loss: print loss at every n-th training cycle (epoch)
"""
_ = self.run(X, y, training_cycles, **kwargs)

def predict(self, X_new: torch.Tensor, **kwargs):
"""
Prediction on new data
Args:
X_new: new inputs (usually, a full set of image indices)
Keyword Args:
batch_size: batch size for a batch-by-batch prediction (to avoid memory overflow)
device: Sets device to which model and data will be moved.
Defaults to 'cuda:0' if a GPU is available and to CPU otherwise.
Returns:
Predictive mean
"""
batch_size = kwargs.get("batch_size", len(X_new))
device = kwargs.get("device")
X_new_batches = create_batches(X_new, batch_size)
self.gp_model.eval()
self.likelihood.eval()
reconstruction = []
with torch.no_grad(), gpytorch.settings.fast_pred_var():
for x in X_new_batches:
x = self._set_data(x, device)
y_pred = self.likelihood(self.gp_model(x))
reconstruction.append(y_pred.mean)
return torch.cat(reconstruction)

def reconstruct(self, sparse_image: np.ndarray,
training_cycles: int = 100,
lengthscale_constraints: Optional[Tuple[List[float]]] = None,
grid_points_ratio: float = 1.0, **kwargs):
"""
Trains a reconstructor on sparse image pixels
and uses the trained model to reconstruct the entire image.
Args:
sparse_image: Input sparse image. The non-measured pixels must be zeros.
training_cycles: Number of training epochs
lengthscale_contraints: Optional lengthscale constraints for the base kernel.
grid_points_ratio: Determines a grid size for the KISS-GP kernel. Defaults to 1.0 (recommended)
Keyword Args:
lr: learning rate (Default: 0.01)
kernel_type: Type of kernel to use, either 'sparse' or 'kissgp'.
base_kernel: Name of the base kernel as a string, either 'rbf' or 'matern', or a custom base kernel object.
inducing_points: Inducing points for the sparse kernel.
print_loss: print loss at every n-th training cycle (epoch)
Returns:
Reconstructed image
"""
X_train, y_train, X_full = prepare_gp_input(sparse_image)
if not lengthscale_constraints:
lengthscale_constraints = get_lengthscale_constraints(X_full)
print("Model training ...\n")
self.fit(X_train, y_train, training_cycles,
lengthscale_constraints=lengthscale_constraints,
grid_points_ratio=grid_points_ratio, **kwargs)
print('\n\rPerforming reconstruction... ', end="")
reconstruction = self.predict(X_full, **kwargs)
print("Done")
return reconstruction.view(sparse_image.shape).cpu().numpy()
5 changes: 3 additions & 2 deletions atomai/nets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
convEncoderNet, coord_latent, fcDecoderNet, fcEncoderNet,
rDecoderNet, init_imspec_model, init_VAE_nets)
from .fcnn import Unet, dilnet, SegResNet, ResHedNet, init_fcnn_model
from .gp import fcFeatureExtractor, GPRegressionModel
from .gp import fcFeatureExtractor, GPRegressionModel, CustomGPModel
from .reg_cls import RegressorNet, ClassifierNet, MultiTaskClassifierNet, init_reg_model, init_cls_model


__all__ = ['ConvBlock', 'ResBlock', 'ResModule', 'UpsampleBlock', 'DilatedBlock',
'init_fcnn_model', 'SegResNet', 'Unet', 'ResHedNet', 'dilnet', 'fcEncoderNet',
'fcDecoderNet', 'convEncoderNet', 'convDecoderNet', 'rDecoderNet',
'coord_latent', 'load_model', 'load_ensemble', 'init_imspec_model',
'init_VAE_nets', 'SignalEncoder', 'SignalDecoder', 'SignalED',
'fcFeatureExtractor', 'GPRegressionModel', 'CustomBackbone', 'RegressorNet',
'ClassifierNet', 'init_reg_model', 'init_cls_model']
'ClassifierNet', 'init_reg_model', 'init_cls_model', 'CustomGPModel', 'MultiTaskClassifierNet']
73 changes: 72 additions & 1 deletion atomai/nets/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Modules for Gaussian process regression with deep kernel learning
"""

from typing import Type
from typing import Type, Optional, Union, Tuple, List

import torch
import gpytorch
Expand Down Expand Up @@ -58,3 +58,74 @@ def forward(self, x: torch.Tensor) -> gpytorch.distributions.MultivariateNormal:
mean_x = self.mean_module(embedded_x)
covar_x = self.covar_module(embedded_x)
return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


class CustomGPModel(gpytorch.models.ExactGP):
def __init__(self,
train_x: torch.Tensor, train_y: torch.Tensor,
likelihood: gpytorch.likelihoods.GaussianLikelihood,
kernel_type: str = 'kissgp',
base_kernel: Union[str, gpytorch.kernels.Kernel] = 'rbf',
inducing_points: Optional[torch.Tensor] = None, grid_points_ratio: int = 1.0,
lengthscale_constraints: Optional[Tuple[List[float]]] = None, **kwargs):
"""
Custom GP Model that allows the user to choose different base kernels, kernel types, and lengthscales.
Args:
train_x: Input training data.
train_y: Output training data.
likelihood: Gaussian likelihood object.
kernel_type: Type of kernel to use, either 'sparse' or 'kissgp'. Defaults to 'sparse'.
base_kernel: Name of the base kernel as a string, either 'rbf' or 'matern', or a custom base kernel object. Defaults to 'rbf'.
inducing_points: Inducing points for the sparse kernel. Defaults to None.
grid_points_ratio: Determines a grid size for the KISS-GP kernel. Defaults to 1.0
lengthscale_contraints: Optional lengthscale constraints for the base kernel. Defaults to None.
"""
super(CustomGPModel, self).__init__(train_x, train_y, likelihood)

self.mean_module = gpytorch.means.ConstantMean()

if isinstance(base_kernel, str):

if lengthscale_constraints:
lengthscale_constraints = gpytorch.constraints.Interval(
torch.tensor(lengthscale_constraints[0]),
torch.tensor(lengthscale_constraints[1]))

if base_kernel == 'rbf':
base_kernel = gpytorch.kernels.RBFKernel(
ard_num_dims=train_x.shape[-1],
lengthscale_constraint=lengthscale_constraints)
elif base_kernel == 'matern':
base_kernel = gpytorch.kernels.MaternKernel(
ard_num_dims=train_x.shape[-1],
lengthscale_constraint=lengthscale_constraints)
else:
raise ValueError("base_kernel must be either 'rbf', 'matern', or a custom gpytorch.kernels.Kernel object")

self.base_covar_module = gpytorch.kernels.ScaleKernel(base_kernel)

if kernel_type == 'sparse':
self.covar_module = gpytorch.kernels.InducingPointKernel(
self.base_covar_module, inducing_points=inducing_points, likelihood=likelihood)
elif kernel_type == 'kissgp':
grid_size = gpytorch.utils.grid.choose_grid_size(train_x, grid_points_ratio)
self.covar_module = gpytorch.kernels.GridInterpolationKernel(
self.base_covar_module, grid_size=grid_size, num_dims=train_x.shape[-1])
else:
raise ValueError(
f"Invalid kernel_type: {kernel_type}. Supported values are 'sparse' and 'kissgp'.")

def forward(self, x: torch.Tensor) -> gpytorch.distributions.MultivariateNormal:
"""
Forward pass for the GP model.
Args:
x (torch.Tensor): Input data.
Returns:
Multivariate normal distribution representing the predicted output.
"""
mean_x = self.mean_module(x)
covar_x = self.covar_module(x)
return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
5 changes: 3 additions & 2 deletions atomai/trainers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .trainer import SegTrainer, ImSpecTrainer, RegTrainer, clsTrainer, BaseTrainer
from .etrainer import BaseEnsembleTrainer, EnsembleTrainer
from .vitrainer import viBaseTrainer
from .gptrainer import dklGPTrainer
from .gptrainer import dklGPTrainer, GPTrainer

__all__ = ["SegTrainer", "ImSpecTrainer", "BaseTrainer", "BaseEnsembleTrainer",
"EnsembleTrainer", "viBaseTrainer", "dklGPTrainer", "RegTrainer", "clsTrainer"]
"EnsembleTrainer", "viBaseTrainer", "dklGPTrainer", "RegTrainer", "clsTrainer",
"GPTrainer"]
Loading

0 comments on commit 7517583

Please sign in to comment.