# Defining an Example Model

In the next section, we define a simple 2-layer sparse deep Gaussian process (DGP) model for a regression task. We’ll use this model to demonstrate how to use the library.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import dgp_sparse as gp
from dgp_sparse.layers.linear import LinearReparameterization
from dgp_sparse.layers.activation import TMGP

## Defining a 2-layer DTMGP Model

First, we define a 2-layer DTMGP model with a single output dimension. Each layer consists of a `TMGP` feature map with a level-3 sparse grid design, followed by a Bayesian linear layer (`LinearReparameterization`).

In [2]:
# Define a 2-layer DTMGP model for regression
class SparseDGP_grid(nn.Module):
    """Two-layer sparse DGP (DTMGP) model for regression.

    Each layer is a ``TMGP`` sparse-grid feature map followed by a Bayesian
    ``LinearReparameterization`` layer. ``forward`` returns the prediction
    together with the sum of the KL terms reported by both linear layers.

    Args:
        input_dim: Number of input features.
        output_dim: Number of output dimensions.
        design_class: Sparse-grid design class passed to both ``TMGP`` layers.
        kernel: Kernel object passed to both ``TMGP`` layers.
        hidden_dim: Width of the representation between the two layers
            (default 8, the width the tutorial originally hard-coded).
        n_level: Sparse-grid level used by both ``TMGP`` layers (default 3).
    """

    def __init__(self, input_dim, output_dim, design_class, kernel,
                 hidden_dim=8, n_level=3):
        super().__init__()
        # Kept so forward() can reconcile input/output shapes.
        self.input_dim = input_dim
        self.output_dim = output_dim

        # 1st layer of DGP: input:[n, input_dim] size tensor, output:[n, hidden_dim] size tensor
        self.tmk1 = TMGP(in_features=input_dim, n_level=n_level, design_class=design_class, kernel=kernel)
        self.fc1 = self._bayesian_linear(self.tmk1.out_features, hidden_dim)

        # 2nd layer of DGP: input:[n, hidden_dim] size tensor, output:[n, output_dim] size tensor
        self.tmk2 = TMGP(in_features=hidden_dim, n_level=n_level, design_class=design_class, kernel=kernel)
        self.fc2 = self._bayesian_linear(self.tmk2.out_features, output_dim)

    @staticmethod
    def _bayesian_linear(in_features, out_features):
        """Build a LinearReparameterization layer with the shared prior/posterior settings."""
        return LinearReparameterization(
            in_features=in_features,
            out_features=out_features,
            prior_mean=0.0,
            prior_variance=1.0,
            posterior_mu_init=0.0,
            posterior_rho_init=-3.0,
            bias=True,
        )

    def forward(self, x):
        """Run both DGP layers.

        Args:
            x: Input batch of shape ``[n, input_dim]``. When ``input_dim == 1``,
                a flat ``[n]`` batch is also accepted.

        Returns:
            ``(prediction, kl)``. The prediction has shape ``[n]`` when
            ``output_dim == 1`` and ``[n, output_dim]`` otherwise. ``kl`` is the
            sum of the KL terms returned by the two linear layers.
        """
        # The layers expect [n, input_dim]. Batches drawn from a 1-D tensor of
        # scalar inputs arrive as [n], so add the feature axis explicitly.
        if x.dim() == 1 and self.input_dim == 1:
            x = x.unsqueeze(-1)

        x = self.tmk1(x)
        x, kl1 = self.fc1(x)

        x = self.tmk2(x)
        x, kl2 = self.fc2(x)

        # Drop only the trailing output axis. A bare squeeze() would also
        # collapse the batch axis when n == 1 and return a 0-d tensor.
        if self.output_dim == 1:
            x = x.squeeze(-1)
        return x, kl1 + kl2

## Preparing the Data

We set up the training data for this example. We'll use 1000 evenly spaced points in the range [0, 1] as input data. The training labels are generated from the linear function $y = 3x + 2$ with added Gaussian noise (standard deviation 0.1).

In [3]:
# 1000 evenly spaced scalar inputs on [0, 1].
train_X = torch.linspace(start=0, end=1, steps=1000)
# Targets lie on the line y = 3x + 2, perturbed by Gaussian noise with std 0.1.
noise = torch.randn(train_X.shape) * 0.1
train_y = 3 * train_X + 2 + noise

class RegressionDataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs each input ``x[i]`` with its target ``y[i]``.

    Args:
        x: Indexable, sized collection of inputs (e.g. a tensor whose first
            axis indexes samples).
        y: Indexable, sized collection of targets, the same length as ``x``.

    Raises:
        ValueError: If ``x`` and ``y`` have different lengths.
    """

    def __init__(self, x, y):
        # Reject mismatched inputs up front. Otherwise the error would surface
        # as an IndexError mid-epoch, or extra targets would be silently ignored.
        if len(x) != len(y):
            raise ValueError(
                f"x and y must have the same length, got {len(x)} and {len(y)}"
            )
        self.x = x
        self.y = y

    def __len__(self):
        """Return the number of samples."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair at position ``idx``."""
        return self.x[idx], self.y[idx]
    
# Wrap the tensors in a Dataset and serve shuffled mini-batches of 32 samples.
dataset = RegressionDataset(x=train_X, y=train_y)
train_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=32, shuffle=True)

## Initializing the Model and the Optimizer

In [4]:
from dgp_sparse.utils.sparse_design.design_class import HyperbolicCrossDesign
from dgp_sparse.kernels.laplace_kernel import LaplaceProductKernel

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using: ", device)

# One-dimensional regression: scalar input, scalar output, hyperbolic-cross
# sparse design and a Laplace product kernel constructed with argument 1.
model = SparseDGP_grid(
    input_dim=1,
    output_dim=1,
    design_class=HyperbolicCrossDesign,
    kernel=LaplaceProductKernel(1.),
)
model = model.to(device)

# Adam over every parameter registered on the model.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

## Training the Model

In the next cell, we train the 2-layer sparse DGP model with variational inference (VI). The loss combines the mean-squared error on each mini-batch with a scaled KL-divergence term.

In [5]:
num_epochs = 50
num_mc = 1  # number of forward passes averaged per mini-batch

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        # Average the predictions and KL terms over num_mc forward passes.
        samples = [model(data) for _ in range(num_mc)]
        output = torch.mean(torch.stack([out for out, _ in samples]), dim=0)
        kl = torch.mean(torch.stack([k for _, k in samples]), dim=0)

        nll_loss = F.mse_loss(output, target)
        # ELBO-style loss: data-fit term plus the KL term scaled by the batch size.
        # Read the batch size from the loader rather than hard-coding it, so the
        # two cannot drift apart.
        loss = nll_loss + (kl / train_loader.batch_size)
        loss.backward()
        optimizer.step()

        # Sample-weighted running sum: the printed value is the epoch mean rather
        # than the loss of the (possibly smaller) final batch.
        epoch_loss += loss.item() * data.size(0)

    print(f"Epoch: {epoch}, Loss: {epoch_loss / len(train_loader.dataset)}")

See our [documentation](https://sparse-dgp.readthedocs.io/en/latest/) for more information on how to use the library.