In [None]:
import torch
from torch.autograd import Variable
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import numpy as np

In [None]:
# MNIST train / test splits, stored under ./data (downloaded if missing) and
# converted to tensors by ToTensor.
_mnist_kwargs = dict(root='./data', transform=transforms.ToTensor(), download=True)
train_dataset = dsets.MNIST(train=True, **_mnist_kwargs)
test_dataset = dsets.MNIST(train=False, **_mnist_kwargs)

In [None]:
# Seed every random number generator used by the notebook with the same value (0):
# PyTorch CPU, PyTorch CUDA and NumPy, in that order.
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
    _seed_fn(0)

In [None]:
# Mini-batch size shared by both loaders.
batch_size = 100

# Both loaders iterate their dataset in storage order (shuffle=False).
train_loader, test_loader = (
    torch.utils.data.DataLoader(dataset=split, batch_size=batch_size, shuffle=False)
    for split in (train_dataset, test_dataset)
)

In [None]:
class LogisticRegression(torch.nn.Module):
    """A single affine layer mapping ``input_dim`` features to ``output_dim`` raw scores.

    No activation is applied in ``forward``; the layer's output is returned as-is.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return the output of the linear layer applied to ``x``."""
        return self.linear(x)

In [None]:
class MultiLayerNN(torch.nn.Module):
    """Fully connected network: two hidden (Linear -> ReLU -> BatchNorm1d) stages and a linear head.

    Args:
        input_dim: number of input features per sample.
        hidden_units_dim: width of both hidden layers.
        output_dim: number of values produced per sample by the final layer.

    Every ``Linear`` layer is initialised with Kaiming-normal weights and zero bias.
    """

    def __init__(self, input_dim, hidden_units_dim, output_dim):
        super(MultiLayerNN, self).__init__()
        self.model = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_units_dim),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(hidden_units_dim),
            torch.nn.Linear(hidden_units_dim, hidden_units_dim),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(hidden_units_dim),
            # The head must project to output_dim; it previously projected to
            # hidden_units_dim, which left output_dim unused.
            torch.nn.Linear(hidden_units_dim, output_dim),
        )
        for m in self.model:
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.kaiming_normal_(m.weight)
                torch.nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the head's raw outputs."""
        return self.model(x)

In [None]:
# Optimisation settings.
batch_size = 100
epochs = 30
lr_rate = 0.05

# Model sizes: the training loop flattens each image to 28 * 28 values.
input_dim = 28 * 28
output_dim = 10

# Number of training samples, used to turn per-batch values into per-epoch averages.
len_train_dataset = len(train_loader.dataset)

In [None]:
def get_alt_model_loss(x, y, model, alt_model, criterion, w_set):
    """Evaluate the loss of ``alt_model`` after one scaled gradient step per candidate.

    For every entry of ``w_set`` (a list of per-parameter multipliers), the parameters of
    ``alt_model`` are overwritten with ``param - lr_rate * w * param.grad`` computed from
    ``model``, and ``criterion(alt_model(x), y)`` is recorded.

    Reads the module-level ``lr_rate``. ``model``'s parameters must already hold gradients.

    Returns:
        A tensor built with ``torch.tensor`` from the recorded losses, one per candidate.
    """

    def _loss_after_step(step_scales):
        # Write the stepped parameters into alt_model, then score it on (x, y).
        for src, dst, scale in zip(model.parameters(), alt_model.parameters(), step_scales):
            dst.data = src.data - lr_rate * scale.data * src.grad.data
        return criterion(alt_model(x), y)

    return torch.tensor([_loss_after_step(step_scales) for step_scales in w_set])

In [None]:
# The network being trained, plus a second network of the same architecture that serves as
# scratch space for trying out candidate updates. Both live on the GPU.
model = MultiLayerNN(input_dim, 1000, output_dim).cuda()
alt_model = MultiLayerNN(input_dim, 1000, output_dim).cuda()
criterion = torch.nn.CrossEntropyLoss().cuda()

# Total number of scalar parameters in `model`.
w_dim = sum(param.numel() for param in model.parameters())

In [None]:
import torch
from botorch.models import SingleTaskGP
from botorch.fit import fit_gpytorch_model
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.acquisition import qUpperConfidenceBound
from botorch.optim import optimize_acqf

def _get_next_z_set_and_best_z(z, increments):
    """Fit a GP to (candidate, increment) pairs and propose the next candidate set.

    The columns of ``z`` are the GP inputs and the min-max normalised ``increments`` are the
    targets. A qUpperConfidenceBound acquisition (beta=0.1) is maximised over the unit box of
    dimension ``z.shape[0]`` for 50 joint candidates. The largest increment is printed.

    Returns:
        ``(candidates, best_z)``: the tensor returned by ``optimize_acqf`` and the column of
        ``z`` whose increment is largest.
    """
    z_rows = z.transpose(0, 1)
    train_X = z_rows.cpu().clone().data

    targets = increments.reshape([-1, 1]).clone().data
    lowest, highest = targets.min(), targets.max()
    train_Y = (targets - lowest) / (highest - lowest)

    gp = SingleTaskGP(train_X, train_Y)
    fit_gpytorch_model(ExactMarginalLogLikelihood(gp.likelihood, gp))

    acquisition = qUpperConfidenceBound(gp, beta=0.1)

    n_dims = z.shape[0]
    unit_box = torch.stack([torch.zeros(n_dims), torch.ones(n_dims)])
    candidates, _ = optimize_acqf(
        acquisition, bounds=unit_box, q=50, num_restarts=20, raw_samples=512,
    )

    print("increment : {0:.7f}".format(increments.max().item()))
    return candidates, z_rows[increments.argmax()]

In [None]:
# Scratch check of torch.cat: stack a 2x2 block of zeros on top of a 1x2 row of ones.
test = torch.cat([torch.zeros([2, 2]), torch.ones([1, 2])], 0)
test

In [None]:
def _seperate_w(w, model):
    w_reshape_list = []
    num_params = 0
    for param in model.parameters():
        w_reshape = w[num_params:num_params + param.shape.numel()].reshape_as(param)
        w_reshape_list.append(w_reshape)
        num_params += param.shape.numel()
    
    return w_reshape_list

def _get_w_set(w, model):
    """Turn each column of ``w`` into a list of parameter-shaped CUDA tensors.

    Every column is moved to the GPU and split with ``_seperate_w``.

    Returns:
        A list with one entry per column of ``w``.
    """
    return [_seperate_w(column.cuda(), model) for column in w.unbind(dim=1)]

def _get_w(A, z, c):
    sig = torch.nn.Sigmoid()
    h = 2 * sig(c * A.mm(z))
    return h.clamp(min=1.0)

def _get_alt_model_loss(x, y, model, alt_model, criterion, w_set):
    """Score each candidate multiplier set by the loss it yields after one gradient step.

    For each element of ``w_set``, ``alt_model``'s parameters are set to
    ``param - lr_rate * w * param.grad`` (taken from ``model``) and the loss of
    ``alt_model`` on ``(x, y)`` is collected. Reads the module-level ``lr_rate``.

    Returns:
        ``torch.tensor`` of the collected losses, in ``w_set`` order.
    """
    losses = []
    for scales in w_set:
        triples = zip(model.parameters(), alt_model.parameters(), scales)
        for source, target, scale in triples:
            target.data = source.data - lr_rate * scale.data * source.grad.data
        losses.append(criterion(alt_model(x), y))
    return torch.tensor(losses)

class AdaptiveGPGradient:
    """Holds candidate step-size multipliers and refits a GP over their latent coordinates.

    A random matrix ``A`` (w_dim x z_dim) maps latent points to one multiplier per model
    parameter via ``_get_w``. ``Z`` holds ``z_length`` latent candidates, ``w_set`` their
    multipliers, and ``V`` the running average of the losses gathered for them.

    Args:
        model: network whose parameters are overwritten when candidates are evaluated.
        z_dim: dimension of the latent space.
        z_length: number of latent candidates.
        update_weight_step: number of ``update_Z`` calls skipped between GP refits; also
            the divisor used when averaging into ``V``.
        fit_gp_step: stored on the instance; not read by any method of this class.
    """

    def __init__(self, model, z_dim, z_length, update_weight_step, fit_gp_step):
        w_dim = sum(param.numel() for param in model.parameters())

        self.A = torch.randn([w_dim, z_dim])
        self.Z = torch.rand([z_dim, z_length])

        self.model = model

        self.step = 0
        self.update_weight_step = update_weight_step
        self.fit_gp_step = fit_gp_step

        self.c = 1.0

        W = _get_w(self.A, self.Z, self.c)
        self.w_set = _get_w_set(W, model)
        self.V = torch.zeros(z_length)

        # The current best latent point starts at the origin, stored as a (z_dim, 1) column.
        self.best_z = torch.zeros([z_dim, 1])
        self.best_w = _get_w_set(_get_w(self.A, self.best_z, self.c), model)[0]

        self.z_dim = z_dim

    def gather_V(self, images, labels, model, criterion):
        """Add this batch's candidate losses, divided by ``update_weight_step``, to ``V``.

        ``model`` supplies the parameters and gradients; ``self.model`` receives the
        stepped parameters and is the network that gets evaluated.
        """
        with torch.no_grad():
            V = _get_alt_model_loss(images, labels, model, self.model, criterion, self.w_set)
            self.V += V / self.update_weight_step

    def update_Z(self):
        """Every ``update_weight_step`` calls, refit the GP and pick a new ``best_z``.

        Calls in between only advance the internal step counter.
        """
        # Imported here because the notebook's import cells do not bring this name in.
        from botorch.acquisition import UpperConfidenceBound

        if self.step < self.update_weight_step:
            self.step += 1
            return
        self.step = 0

        train_X = self.Z.transpose(0, 1).cpu().data
        train_X = torch.cat([train_X, self.best_z.reshape([1, -1])], 0)

        train_Y = self.V.reshape([-1, 1]).data
        # NOTE(review): the target appended for best_z is a literal 0 -- confirm that this
        # is the intended reference value.
        train_Y = torch.cat([train_Y, torch.zeros([1, 1])], 0)
        # Min-max normalise; the clamp avoids 0/0 when all targets are equal.
        span = (train_Y.max() - train_Y.min()).clamp_min(1e-12)
        train_Y = (train_Y - train_Y.min()) / span

        gp = SingleTaskGP(train_X, train_Y)
        mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
        fit_gpytorch_model(mll)

        UCB = UpperConfidenceBound(gp, beta=0.1)

        # Use the instance's latent dimension, not a global `z`.
        bounds = torch.stack([torch.zeros(self.z_dim), torch.ones(self.z_dim)])
        candidate, _ = optimize_acqf(
            UCB, bounds=bounds, q=1, num_restarts=20, raw_samples=512,
        )
        # Keep the (z_dim, 1) column layout established in __init__.
        self.best_z = candidate.detach().reshape([-1, 1])
        # NOTE(review): best_w, w_set, Z and V are not refreshed here; confirm whether the
        # caller is expected to do that.
        
# Build the helper around `alt_model` with update_weight_step=100 and fit_gp_step=1.
# NOTE(review): `z_dim` and `z_length` are first assigned in a later cell of this notebook,
# so on a fresh kernel (Restart & Run All) this line raises NameError. It also rebinds the
# scratch name `test` used by an earlier cell.
test = AdaptiveGPGradient(alt_model, z_dim, z_length, 100, 1)

In [None]:
# Latent search space: z_length candidate points of dimension z_dim, and a fixed random
# matrix A (w_dim x z_dim) that maps a latent point to one value per model parameter.
z_dim, z_length = 4, 50
A = torch.randn([w_dim, z_dim])
z = torch.rand([z_dim, z_length])

def seperate_w(w, model):
    """Split the flat vector ``w`` into one tensor per parameter of ``model``.

    Consecutive slices of ``w`` are taken in ``model.parameters()`` order, each with the
    parameter's element count, and reshaped to the parameter's shape.

    Returns:
        The list of reshaped slices.
    """
    params = list(model.parameters())
    sizes = [p.numel() for p in params]

    chunks, start = [], 0
    for p, size in zip(params, sizes):
        stop = start + size
        chunks.append(w[start:stop].reshape_as(p))
        start = stop
    return chunks

def get_w_set(Az, model):
    """Convert every column of ``Az`` into parameter-shaped CUDA tensors via ``seperate_w``.

    Returns:
        A list with one ``seperate_w`` result per column of ``Az``.
    """
    per_candidate = []
    for col in range(Az.shape[1]):
        column_on_gpu = Az[:, col].cuda()
        per_candidate.append(seperate_w(column_on_gpu, model))
    return per_candidate

def get_w(A, z):
    """Return ``2 * sigmoid(A @ z)``, so every value lies strictly between 0 and 2."""
    projected = A.mm(z)
    return 2 * torch.sigmoid(projected)

# Multipliers for every initial latent candidate.
w_set = get_w_set(get_w(A, z), model)

# Start from all-ones multipliers, i.e. a plain gradient step scaled only by lr_rate.
init_w = [torch.ones_like(w) for w in w_set[0]]
best_w = init_w

In [None]:
def show_params(model):
    """Print the data tensor of every parameter of ``model``, one ``print`` call each."""
    for tensor in (param.data for param in model.parameters()):
        print(tensor)

In [None]:
import torch
from botorch.models import SingleTaskGP
from botorch.fit import fit_gpytorch_model
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.acquisition import qUpperConfidenceBound
from botorch.optim import optimize_acqf

def get_next_z_set_and_best_z(z, increments):
    """Fit a GP to the evaluated latent candidates and propose the next candidate set.

    Args:
        z: 2-D tensor whose columns are the latent points that were evaluated.
        increments: one score per column of ``z``; larger is treated as better.

    The scores are min-max normalised, a ``SingleTaskGP`` is fitted, and a
    ``qUpperConfidenceBound`` acquisition (beta=0.1) is maximised over the unit box.
    The box dimension and the number of proposed points are taken from ``z`` itself
    (``z.shape[0]`` and ``z.shape[1]``) rather than being hard-coded, so the function
    stays consistent if the latent dimension or candidate count changes.
    The largest raw increment is printed.

    Returns:
        ``(candidates, best_z)``: the candidates returned by ``optimize_acqf`` and the
        column of ``z`` with the largest increment.
    """
    z_dim, n_candidates = z.shape

    train_X = z.T.cpu().clone().data
    train_Y = increments.reshape([-1, 1]).clone().data
    # The clamp avoids 0/0 (NaN targets) when every increment is identical.
    span = (train_Y.max() - train_Y.min()).clamp_min(1e-12)
    train_Y = (train_Y - train_Y.min()) / span

    gp = SingleTaskGP(train_X, train_Y)
    mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
    fit_gpytorch_model(mll)

    qUCB = qUpperConfidenceBound(gp, beta=0.1)

    bounds = torch.stack([torch.zeros(z_dim), torch.ones(z_dim)])
    candidates, acq_value = optimize_acqf(
        qUCB, bounds=bounds, q=n_candidates, num_restarts=20, raw_samples=512,
    )
    print("increment : {0:.7f}".format(increments.max().item()))
    return candidates, z.T[increments.argmax()]

In [None]:
# Train `model` with per-parameter step multipliers chosen by Bayesian optimisation.
# Each batch: compute gradients, score every candidate multiplier set on `alt_model`,
# then step `model` with the current best multipliers. After each epoch the GP proposes a
# new candidate set and the best candidate so far becomes `best_w`.
results = []  # per-epoch average training loss, stored as Python floats
num_batches = len_train_dataset / batch_size

for epoch in range(epochs):
    avg_loss = 0.

    # increments[k]: average of (loss after the best_w step) - (loss after candidate k's step).
    increments = torch.zeros(z_length)

    for images, labels in train_loader:
        images = images.view(-1, 28 * 28).cuda()
        labels = labels.cuda()

        outputs = model(images)
        loss = criterion(outputs, labels)

        model.zero_grad()
        loss.backward()

        # .item() keeps the running average a plain float instead of a graph-attached
        # CUDA tensor, which would retain memory and cannot be plotted later.
        avg_loss += loss.item() / num_batches

        with torch.no_grad():
            alt_losses = get_alt_model_loss(images, labels, model, alt_model, criterion, w_set)
            for w, param in zip(best_w, model.parameters()):
                param.data -= lr_rate * w.data * param.grad.data

        with torch.no_grad():
            outputs = model(images)
            loss = criterion(outputs, labels)

        increments += (loss.cpu() - alt_losses) / num_batches

    results.append(avg_loss)

    z_set, best_z = get_next_z_set_and_best_z(z, increments)
    # The next epoch's increments are measured on z_set, so the GP must be fitted on those
    # same points next time; without this the inputs and targets would not correspond.
    z = z_set.T.detach()
    w_set = get_w_set(get_w(A, z), model)
    best_w = get_w_set(get_w(A, best_z.reshape([-1, 1])), model)[0]

    print(epoch, avg_loss)
    

In [None]:
import matplotlib.pyplot as plt

# Plot the per-epoch average training loss. Each entry goes through float() so this works
# whether `results` holds Python floats or 0-dim tensors (matplotlib cannot convert CUDA
# tensors on its own).
fig, ax = plt.subplots()
ax.plot([float(value) for value in results])
ax.set(xlabel="epoch", ylabel="average training loss");

In [None]:
# Measure classification accuracy on the test set.
correct = 0
total = 0

# The network contains BatchNorm layers, so it must be in eval mode for inference; the
# previous mode is restored afterwards so that re-running training is unaffected.
was_training = model.training
model.eval()
with torch.no_grad():
    for images, labels in test_loader:
        images = images.view(-1, 28 * 28).cuda()
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted.cpu() == labels).sum().item()
model.train(was_training)

# Percentage of correctly classified test samples (NaN if the loader was empty).
accuracy = 100.0 * correct / total if total else float("nan")

In [None]:
# Display the test accuracy (in percent) computed by the evaluation cell.
accuracy