In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Working precision: models and tensors in this notebook are cast to / created with it.
dtype = torch.float64

In [2]:
class Net(nn.Module):
    """LeNet-style CNN digit classifier for single-channel images (e.g. 28x28 MNIST).

    The submodule names (conv1, conv2, fc1, fc2) are the state-dict keys of
    the pretrained checkpoint restored with ``load_state_dict``, so they
    must not be renamed.
    """

    # Per-sample feature-map shape (channels, height, width) that reaches
    # fc1; fc1 takes exactly 50 * 4 * 4 = 800 inputs per sample.
    _FEATURE_SHAPE = (50, 4, 4)

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        """Return class log-probabilities (log_softmax over dim 1).

        For a batch of shape (N, 1, 28, 28) the result has shape (N, 10).

        Raises:
            ValueError: if the convolutional features are not 50x4x4 per
                sample, i.e. the input images have the wrong spatial size.
        """
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        # view(-1, 800) on its own accepts any element count divisible by
        # 800 and would silently fold extra spatial positions into the batch
        # dimension (returning the wrong number of rows); fail loudly instead.
        feature_shape = tuple(x.shape[-3:])
        if feature_shape != self._FEATURE_SHAPE:
            raise ValueError(
                f"expected conv features of shape {self._FEATURE_SHAPE} per sample, "
                f"got {feature_shape}; check the input image size (e.g. 1x28x28)"
            )
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [11]:
# Directory holding the pretrained MNIST checkpoints (relative to the CWD).
pretrained_loc = "pretrained_models"

cnn_model = Net().to(device=device, dtype=dtype)
# NOTE: torch.load is pickle-based -- only point it at trusted checkpoints.
cnn_state_dict = torch.load(
    f"./{pretrained_loc}/mnist_cnn.pt",
    map_location=device,
)
cnn_model.load_state_dict(cnn_state_dict)

<All keys matched successfully>

In [12]:
class VAE(nn.Module):
    """Fully connected variational autoencoder over flattened 784-pixel images.

    Encoder: 784 -> 400 -> two 20-dim heads (mean and log-variance).
    Decoder: 20 -> 400 -> 784, squashed through a sigmoid.
    The layer attribute names are the state-dict keys used when restoring
    the pretrained checkpoint, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        # Encoder: shared hidden layer, then mean / log-variance heads.
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)
        # Decoder: latent code back to pixel space.
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        """Return (mu, logvar) of the approximate posterior for inputs x."""
        hidden = torch.relu(self.fc1(x))
        mu = self.fc21(hidden)
        logvar = self.fc22(hidden)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * sigma with eps ~ N(0, I), sigma = exp(logvar / 2)."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent codes to 784 pixel intensities passed through a sigmoid."""
        hidden = torch.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(hidden))

    def forward(self, x):
        """Flatten, encode, sample, decode; returns (reconstruction, mu, logvar)."""
        flat = x.view(-1, 784)
        mu, logvar = self.encode(flat)
        sample = self.reparameterize(mu, logvar)
        return self.decode(sample), mu, logvar
    
vae_model = VAE().to(device=device, dtype=dtype)
# NOTE: torch.load is pickle-based -- only point it at trusted checkpoints.
vae_state_dict = torch.load(
    f"./{pretrained_loc}/mnist_vae.pt",
    map_location=device,
)
vae_model.load_state_dict(vae_state_dict)

<All keys matched successfully>

In [13]:
def score(y, target=3):
    """Gaussian-shaped closeness score exp(-2 * (y - target)**2).

    Equals 1 when y == target and decays quickly with distance
    (exp(-2) ~= 0.135 one unit away). ``y`` must be a tensor, since the
    result is passed through ``torch.exp``.
    """
    offset = y - target
    return torch.exp(-2 * offset ** 2)

In [14]:
def score_image_recognition(x, target=3):
    """Expected closeness-to-``target`` score of the CNN's prediction, per image.

    ``cnn_model`` returns log-probabilities, so they are exponentiated and
    used to weight ``score(k, target)`` for each digit class k = 0..9; the
    weighted scores are summed over the class dimension. The model is run
    without gradient tracking.
    """
    with torch.no_grad():
        log_probs = cnn_model(x)
        class_probs = log_probs.exp()
        digit_scores = score(torch.arange(10, device=device, dtype=dtype), target)
    weighted = class_probs * digit_scores.expand_as(class_probs)
    return weighted.sum(dim=1)

In [15]:
def decode(train_x):
    """Decode a batch of latent points with ``vae_model`` into image tensors.

    Decoding runs without gradient tracking; the flat decoder output is
    reshaped to (batch, 1, 28, 28), where batch = train_x.size(0).
    """
    with torch.no_grad():
        pixels = vae_model.decode(train_x)
    n_images = train_x.size(0)
    return pixels.view(n_images, 1, 28, 28)

In [None]:
from botorch.models import SingleTaskGP
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from botorch.utils.transforms import standardize, normalize, unnormalize

# Number of latent dimensions searched over (same width as the VAE's 20-dim latent layers).
d = 20
# 2 x d box: row 0 holds the lower bounds (-6), row 1 the upper bounds (+6),
# one column per latent coordinate.
bounds = torch.stack(
    [
        torch.full((d,), -6.0, device=device, dtype=dtype),
        torch.full((d,), 6.0, device=device, dtype=dtype),
    ]
)

def gen_initial_data(n=5):
    # Draw n random points in the d-dimensional latent box: torch.rand gives
    # samples in [0, 1)^d, which unnormalize presumably rescales onto `bounds`.
    train_x = unnormalize(torch.rand(n, d, device=device, dtype=dtype), bounds=bounds)
    # FIXME(review): score_image_recognition(x, target=3) has no default for
    # `x`, so this call raises TypeError as written. Presumably intended:
    # score_image_recognition(decode(train_x)) -- confirm. The visible body
    # also has no return statement yet.
    train_obj = score_image_recognition()
    