In [1]:
import torch
from torch.autograd import Variable
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import numpy as np

In [2]:
# MNIST train / test splits, each item converted with transforms.ToTensor();
# downloaded into ./data if not already present.
train_dataset, test_dataset = (
    dsets.MNIST(root='./data', train=is_train, transform=transforms.ToTensor(), download=True)
    for is_train in (True, False)
)

In [3]:
# Fix every RNG this notebook draws from: NumPy, the torch CPU generator and the CUDA generator.
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)

In [4]:
batch_size = 100
# NOTE(review): the training loader does not shuffle, so every epoch visits the
# batches in the same fixed order -- confirm this is intentional (e.g. to keep runs
# comparable); SGD is usually run with shuffle=True on the training split.
train_loader, test_loader = (
    torch.utils.data.DataLoader(dataset=split, batch_size=batch_size, shuffle=False)
    for split in (train_dataset, test_dataset)
)

In [5]:
class LogisticRegression(torch.nn.Module):
    """Multinomial logistic regression: one affine layer that produces class scores.

    No softmax is applied in ``forward``; the raw logits are returned so they can be
    fed straight into ``torch.nn.CrossEntropyLoss``.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Attribute name ``linear`` is kept so parameter / state_dict keys stay the same.
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return ``self.linear(x)``, the unnormalised logits."""
        return self.linear(x)

In [6]:
# Training hyper-parameters.
# batch_size is read back from the loader instead of being retyped, so the per-epoch
# loss averaging can never disagree with the batch size the data is served in.
batch_size = train_loader.batch_size
epochs = 30
input_dim = 28 * 28   # flattened MNIST image (the loop reshapes images with view(-1, 28 * 28))
output_dim = 10       # number of classes
lr_rate = 0.1
len_train_dataset = len(train_loader.dataset)

In [7]:
def get_alt_model_loss(x, y, model, alt_model, criterion, w_set, lr=None):
    """Evaluate the loss after one weighted SGD step from ``model``, for each weighting.

    For every weighting in ``w_set``, the parameters of ``alt_model`` are set to
    ``param - lr * w * param.grad`` (taken from ``model``). ``criterion`` is then
    evaluated on ``(x, y)``. ``model`` is not modified. ``alt_model``'s parameters
    are overwritten and hold the last weighting's step on return. A weighting of all
    ones reproduces a plain SGD step.

    Args:
        x, y: batch inputs / targets passed to ``alt_model`` and ``criterion``.
        model: source model; ``.grad`` must already be populated (``backward()`` called).
        alt_model: scratch model with the same parameter layout as ``model``.
        criterion: loss callable ``criterion(outputs, y)`` returning a scalar tensor.
        w_set: iterable of weightings; each is a sequence of tensors aligned with
            ``model.parameters()``.
        lr: step size; falls back to the global ``lr_rate`` when ``None``.

    Returns:
        1-D CPU float tensor with one loss per weighting and no autograd history.
    """
    step = lr_rate if lr is None else lr
    alt_losses = []
    # These losses are only compared, never back-propagated. Without no_grad (and
    # .item()) each stored loss would keep its whole autograd graph alive.
    with torch.no_grad():
        for w_reshape in w_set:
            for param, alt_param, w in zip(model.parameters(), alt_model.parameters(), w_reshape):
                alt_param.copy_(param - step * w * param.grad)
            alt_losses.append(criterion(alt_model(x), y).item())
    return torch.tensor(alt_losses)

In [8]:
# `model` is the network trained below. `alt_model` has the same layout and serves as a
# scratch copy whose parameters are overwritten when candidate weightings are evaluated.
model, alt_model = (LogisticRegression(input_dim, output_dim).cuda() for _ in range(2))
criterion = torch.nn.CrossEntropyLoss().cuda()

# Total number of scalar parameters = length of one flat weighting vector w.
w_dim = sum(p.numel() for p in model.parameters())

In [9]:
# Latent search space: each of the z_length columns of z is a z_dim-dimensional code
# drawn uniformly from [0, 1). The fixed random matrix A (w_dim x z_dim) maps codes
# to full parameter weightings.
z_dim, z_length = 4, 50
A = torch.randn([w_dim, z_dim])
z = torch.rand([z_dim, z_length])

def seperate_w(w, model):
    """Split a flat weighting vector into pieces shaped like ``model``'s parameters.

    Consecutive slices of ``w`` are taken in ``model.parameters()`` order. Each slice
    is reshaped to match its parameter. Entries of ``w`` beyond the total parameter
    count are ignored.
    """
    params = list(model.parameters())
    ends, running = [], 0
    for p in params:
        running += p.numel()
        ends.append(running)
    return [w[end - p.numel():end].reshape_as(p) for p, end in zip(params, ends)]

def get_w_set(Az, model):
    """Turn each column of ``Az`` into a per-parameter weighting for ``model``.

    Args:
        Az: 2-D tensor whose columns are flat weighting vectors, each at least as long
            as the number of scalar parameters in ``model``.
        model: module whose ``parameters()`` define how each column is split.

    Returns:
        A list with one entry per column of ``Az``. Each entry is ``seperate_w``'s
        list of tensors for that column, on the same device as ``model``'s parameters.
    """
    # Follow the model's device instead of hard-coding .cuda(): the weightings must
    # live where the parameters they scale live, and this also lets the code run on CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else Az.device
    return [seperate_w(column.to(device), model) for column in Az.T]

def get_w(A, z):
    """Map latent codes to weightings: ``2 * sigmoid(A @ z)``.

    Each column of ``z`` yields one column of weights. Values lie in (0, 2) up to
    floating-point saturation, and a zero pre-activation maps to exactly 1.
    """
    return 2 * torch.sigmoid(torch.mm(A, z))

# Candidate weightings, one per column of z.
w_set = get_w_set(get_w(A, z), model)

# All-ones weighting (same shapes and device as a candidate), i.e. a plain unweighted
# SGD step. best_w starts as this very list.
best_w = init_w = [torch.ones_like(w) for w in w_set[0]]

In [10]:
def show_params(model):
    """Print the data of each of ``model``'s parameters, one ``print`` per parameter."""
    params = list(model.parameters())
    for tensor in params:
        print(tensor.data)

In [11]:
import torch
from botorch.models import SingleTaskGP
from botorch.fit import fit_gpytorch_model
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.acquisition import qUpperConfidenceBound
from botorch.optim import optimize_acqf

def get_next_z_set_and_best_z(z, increments, q=None, beta=0.1, num_restarts=20, raw_samples=512):
    """Fit a GP to ``(z, increments)`` and propose the next batch of latent codes.

    Args:
        z: ``(z_dim, n)`` tensor; column ``i`` is the latent code that produced
            ``increments[i]``.
        increments: ``n`` observed objective values, one per column of ``z``. They are
            maximised: the UCB acquisition and the ``argmax`` below both treat larger
            as better.
        q: number of candidates to propose. Defaults to ``n``, so the candidate set
            keeps the size of ``z``.
        beta: exploration weight of ``qUpperConfidenceBound``.
        num_restarts, raw_samples: forwarded to ``optimize_acqf``.

    Returns:
        ``(candidates, best_z)``. ``candidates`` holds ``q`` codes, one per row,
        inside the unit box ``[0, 1]^z_dim``. ``best_z`` is the column of ``z`` with
        the largest increment.
    """
    train_X = z.T.detach().cpu().clone()
    train_Y = increments.detach().reshape(-1, 1).cpu().clone()

    gp = SingleTaskGP(train_X, train_Y)
    mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
    # NOTE(review): newer BoTorch releases replace fit_gpytorch_model with
    # fit_gpytorch_mll -- confirm against the installed version.
    fit_gpytorch_model(mll)

    acquisition = qUpperConfidenceBound(gp, beta=beta)

    # Bounds must match the latent dimensionality the GP was trained on, so derive
    # them from z instead of hard-coding 4.
    z_dim = train_X.shape[-1]
    bounds = torch.stack([torch.zeros(z_dim), torch.ones(z_dim)])
    candidates, _ = optimize_acqf(
        acquisition,
        bounds=bounds,
        q=z.shape[1] if q is None else q,
        num_restarts=num_restarts,
        raw_samples=raw_samples,
    )
    print("increment : {0:.7f}".format(increments.max().item()))
    return candidates, z.T[increments.argmax()]

In [12]:
results = []  # mean training loss per epoch, stored as plain Python floats
num_batches = len(train_loader)
for epoch in range(epochs):
    avg_loss = 0.

    # Only consumed by the (currently disabled) weighting-search code below.
    increments = torch.zeros(z_length)

    for i, (images, labels) in enumerate(train_loader):
        images = images.view(-1, 28 * 28).cuda()
        labels = labels.cuda()

        outputs = model(images)
        loss = criterion(outputs, labels)

        model.zero_grad()
        loss.backward()

        # detach(): accumulate the value only. Adding the live `loss` chains every
        # batch's autograd graph onto avg_loss for the whole epoch.
        avg_loss += loss.detach() / num_batches

        with torch.no_grad():
            #alt_losses = get_alt_model_loss(images, labels, model, alt_model, criterion, w_set)
            #show_params(alt_model)
            # Weighted SGD step: each gradient entry is scaled by the matching entry of best_w.
            for w, param in zip(best_w, model.parameters()):
                param -= lr_rate * w * param.grad

#         with torch.no_grad():
#             outputs = model(images)
#             loss = criterion(outputs, labels)
#             #show_params(model)

#         #print(loss, alt_losses)
#         increments += (loss.cpu() - alt_losses) / (len_train_dataset / batch_size)

        #print(alt_losses)
    # float() stores a plain number, so `results` can be plotted / serialised directly.
    results.append(float(avg_loss))

#     z_set, best_z = get_next_z_set_and_best_z(z, increments)
#     w_set = get_w_set(get_w(A, z_set.T), model)
#     best_w = get_w_set(get_w(A, best_z.reshape([-1, 1])), model)[0]
    # NOTE(review): the disabled search above never reassigns z (e.g. z = z_set.T), so
    # the next GP fit would pair the old z with increments measured for the new
    # w_set -- confirm before re-enabling.

    print(epoch, avg_loss)
    

0 tensor(0.5359, device='cuda:0')
1 tensor(0.3597, device='cuda:0')
2 tensor(0.3316, device='cuda:0')
3 tensor(0.3169, device='cuda:0')
4 tensor(0.3075, device='cuda:0')
5 tensor(0.3006, device='cuda:0')
6 tensor(0.2954, device='cuda:0')
7 tensor(0.2912, device='cuda:0')
8 tensor(0.2878, device='cuda:0')
9 tensor(0.2848, device='cuda:0')
10 tensor(0.2823, device='cuda:0')
11 tensor(0.2801, device='cuda:0')
12 tensor(0.2782, device='cuda:0')
13 tensor(0.2764, device='cuda:0')
14 tensor(0.2748, device='cuda:0')
15 tensor(0.2734, device='cuda:0')
16 tensor(0.2721, device='cuda:0')
17 tensor(0.2709, device='cuda:0')
18 tensor(0.2697, device='cuda:0')
19 tensor(0.2687, device='cuda:0')
20 tensor(0.2677, device='cuda:0')
21 tensor(0.2668, device='cuda:0')
22 tensor(0.2659, device='cuda:0')
23 tensor(0.2651, device='cuda:0')
24 tensor(0.2644, device='cuda:0')
25 tensor(0.2636, device='cuda:0')
26 tensor(0.2630, device='cuda:0')
27 tensor(0.2623, device='cuda:0')
28 tensor(0.2617, device='cuda

In [12]:
#lr_rate = 0.05
# NOTE(review): this cell keeps training the existing `model`; it does not re-create it.
# Its recorded output starts at a higher loss than the previous run ended at, so it
# was presumably produced on a fresh kernel (possibly with lr_rate = 0.05) -- confirm.
results = []  # mean training loss per epoch, stored as plain Python floats
num_batches = len(train_loader)
for epoch in range(epochs):
    avg_loss = 0.

    # Only consumed by the (currently disabled) weighting-search code below.
    increments = torch.zeros(z_length)

    for i, (images, labels) in enumerate(train_loader):
        images = images.view(-1, 28 * 28).cuda()
        labels = labels.cuda()

        outputs = model(images)
        loss = criterion(outputs, labels)

        model.zero_grad()
        loss.backward()

        # detach(): accumulate the value only. Adding the live `loss` chains every
        # batch's autograd graph onto avg_loss for the whole epoch.
        avg_loss += loss.detach() / num_batches

        with torch.no_grad():
            #alt_losses = get_alt_model_loss(images, labels, model, alt_model, criterion, w_set)
            #show_params(alt_model)
            # Weighted SGD step: each gradient entry is scaled by the matching entry of best_w.
            for w, param in zip(best_w, model.parameters()):
                param -= lr_rate * w * param.grad

#         with torch.no_grad():
#             outputs = model(images)
#             loss = criterion(outputs, labels)
#             #show_params(model)

#         #print(loss, alt_losses)
#         increments += (loss.cpu() - alt_losses) / (len_train_dataset / batch_size)

        #print(alt_losses)
    # float() stores a plain number, so `results` can be plotted / serialised directly.
    results.append(float(avg_loss))

#     z_set, best_z = get_next_z_set_and_best_z(z, increments)
#     w_set = get_w_set(get_w(A, z_set.T), model)
#     best_w = get_w_set(get_w(A, best_z.reshape([-1, 1])), model)[0]
    # NOTE(review): the disabled search above never reassigns z (e.g. z = z_set.T), so
    # the next GP fit would pair the old z with increments measured for the new
    # w_set -- confirm before re-enabling.

    print(epoch, avg_loss)
    

0 tensor(0.6582, device='cuda:0')
1 tensor(0.4077, device='cuda:0')
2 tensor(0.3679, device='cuda:0')
3 tensor(0.3475, device='cuda:0')
4 tensor(0.3344, device='cuda:0')
5 tensor(0.3250, device='cuda:0')
6 tensor(0.3179, device='cuda:0')
7 tensor(0.3122, device='cuda:0')
8 tensor(0.3075, device='cuda:0')
9 tensor(0.3036, device='cuda:0')
10 tensor(0.3002, device='cuda:0')
11 tensor(0.2973, device='cuda:0')
12 tensor(0.2947, device='cuda:0')
13 tensor(0.2924, device='cuda:0')
14 tensor(0.2903, device='cuda:0')
15 tensor(0.2884, device='cuda:0')
16 tensor(0.2867, device='cuda:0')
17 tensor(0.2852, device='cuda:0')
18 tensor(0.2837, device='cuda:0')
19 tensor(0.2823, device='cuda:0')
20 tensor(0.2811, device='cuda:0')
21 tensor(0.2799, device='cuda:0')
22 tensor(0.2788, device='cuda:0')
23 tensor(0.2778, device='cuda:0')
24 tensor(0.2768, device='cuda:0')
25 tensor(0.2759, device='cuda:0')
26 tensor(0.2750, device='cuda:0')
27 tensor(0.2742, device='cuda:0')
28 tensor(0.2734, device='cuda

In [None]:
import matplotlib.pyplot as plt

# float() accepts both plain numbers and 0-d tensors. Matplotlib cannot convert CUDA
# tensors or tensors that require grad, which is what an un-detached loss
# accumulator would leave in `results`.
epoch_losses = [float(r) for r in results]
fig, ax = plt.subplots()
ax.plot(epoch_losses)
ax.set(title="Training loss per epoch", xlabel="epoch", ylabel="mean batch loss");

In [None]:
# Test-set accuracy of the trained model, in percent.
correct = 0
total = 0

# Inference only: no_grad avoids building an autograd graph for every test batch.
with torch.no_grad():
    for images, labels in test_loader:
        images = images.view(-1, 28 * 28).cuda()
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        # .item() keeps the running count a plain int rather than a tensor.
        correct += (predicted.cpu() == labels).sum().item()

# Guard against an empty loader instead of dividing by zero.
accuracy = 100 * correct / total if total else float('nan')

In [None]:
accuracy  # test-set accuracy in percent (100 * correct / total from the evaluation cell)