In [3]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

In [10]:
class Simulator:
    """Simulate data from a linear model with additive Gaussian noise.

    Each call to :meth:`run` draws ``N`` design points uniformly from
    ``design_range``, appends a column of ones (so the intercept is the last
    entry of ``theta``), and sets ``y = X @ theta + sigma * eps`` with
    ``eps`` i.i.d. standard normal.

    Parameters
    ----------
    w : array-like of shape (d,)
        True weights. Lists and scalars are converted to a 1-D array.
    b : float
        True intercept.
    sigma : float
        Standard deviation of the observation noise.
    N : int
        Number of data points to simulate.
    design_range : tuple (low, high), default (-10, 10)
        Bounds of the uniform distribution the inputs are drawn from.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. When ``None`` (the default) the global
        ``np.random`` state is used, so ``np.random.seed`` still applies.

    Attributes
    ----------
    theta : ndarray of shape (d + 1, 1)
        Weights followed by the intercept, as a column vector.
    X : ndarray of shape (N, d + 1) or None
        Design matrix including the trailing column of ones; ``None`` until
        :meth:`run` has been called.
    y_mean : ndarray of shape (N,) or None
        Noise-free responses ``X @ theta``.
    y : ndarray of shape (N,) or None
        Noisy responses.
    """

    def __init__(self, w, b, sigma, N, design_range=(-10, 10), rng=None):
        # atleast_1d lets callers pass a scalar or list; `run` relies on `.size`.
        self.w = np.atleast_1d(np.asarray(w))
        self.b = b
        self.theta = np.concatenate([self.w, [b]], axis=0)[:, np.newaxis]
        self.sigma = sigma
        self.N = N
        self.design_range = design_range
        self.rng = rng
        self.X = None
        self.y = None
        self.y_mean = None

    def run(self):
        """Draw a fresh data set and store it in ``X``, ``y_mean`` and ``y``."""
        # The np.random module exposes the same `uniform` / `standard_normal`
        # functions as Generator and RandomState, so it can stand in for them.
        rng = np.random if self.rng is None else self.rng

        designs = rng.uniform(
            self.design_range[0], self.design_range[1], size=(self.N, self.w.size)
        )
        self.X = np.concatenate([designs, np.ones((self.N, 1))], axis=1)

        # Select the single column instead of squeeze(): squeeze() would collapse
        # the result to a 0-d array when N == 1.
        self.y_mean = (self.X @ self.theta)[:, 0]

        # The noise is i.i.d., so scale standard normal draws directly rather
        # than building a dense N x N diagonal covariance matrix.
        self.y = self.y_mean + self.sigma * rng.standard_normal(self.N)

    def plot(self):
        """Scatter ``y`` against the first input and overlay ``w[0] * x + b``.

        The overlaid line only involves the first weight and the intercept, so
        it equals the full mean response only when there is a single input
        dimension (or the remaining inputs are zero).

        Raises
        ------
        RuntimeError
            If :meth:`run` has not been called yet.
        """
        if self.X is None or self.y is None:
            raise RuntimeError("No simulated data to plot; call run() first.")

        x = self.X[:, 0]
        plt.scatter(x, self.y, label="data")
        x_dense = np.linspace(self.design_range[0], self.design_range[1], 100)
        y_dense = x_dense * self.w[0] + self.b
        plt.plot(x_dense, y_dense, label="y mean")
        plt.xlabel("x")
        plt.ylabel("y")
        plt.legend()
        plt.title("Simulated data, N="+str(self.N))
        plt.show()

In [75]:
# Seed every RNG used by this notebook so that "Restart Kernel -> Run All"
# reproduces the same simulated data and the same initial network weights.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# True weight(s)
w = np.array([1.5, -1.0, 0.7])

# Input dimensionality (number of weights, excluding the intercept)
d = w.size

# True intercept
b = 0.5

# True standard deviation of the observation noise
sigma = 0.5

# Number of data points
N = 100

# Defines range of inputs x (lower bound, upper bound)
design_range = (-1.0, 1.0)

# Simulate
simulator = Simulator(w, b, sigma, N, design_range)
simulator.run()

# Design matrix (with the trailing column of ones) and noisy responses
X = simulator.X
y = simulator.y

In [29]:
class MLP(nn.Module):
    """Fully connected network ``input_dim -> 16 -> 32 -> 1`` with sigmoid hidden units.

    Parameters
    ----------
    input_dim : int, optional
        Number of input features. When omitted it falls back to the
        notebook-level ``d + 1`` (``d`` inputs plus the column of ones that
        the simulator appends to the design matrix).
    """

    def __init__(self, input_dim=None):
        super().__init__()
        if input_dim is None:
            # Resolved at construction time so MLP() keeps working as before.
            input_dim = d + 1
        self.layer_1 = nn.Linear(input_dim, 16)
        self.layer_2 = nn.Linear(16, 32)
        # Must accept the 32 features produced by layer_2.
        self.output_layer = nn.Linear(32, 1)

    def forward(self, x):
        """Map a batch of shape ``(batch, input_dim)`` to outputs of shape ``(batch, 1)``."""
        x = torch.sigmoid(self.layer_1(x))
        x = torch.sigmoid(self.layer_2(x))
        return self.output_layer(x)

In [67]:
def collect_grads(model):
    """Gather the gradients of all parameters of ``model`` into one row vector.

    Each parameter's ``.grad`` is flattened to shape ``(1, numel)`` and the
    pieces are joined in ``model.parameters()`` order, giving a tensor of
    shape ``(1, total_number_of_parameters)``. Every parameter must already
    have a gradient (i.e. ``backward()`` has been called).
    """
    flat_pieces = []
    for param in model.parameters():
        flat_pieces.append(param.grad.data.view(1, -1))
    # All pieces are 2-D rows, so joining along dim 1 is the last dimension.
    return torch.cat(flat_pieces, dim=1)

def get_param_sizes(model):
    """Return the number of scalar entries of each parameter tensor of ``model``.

    The list follows ``model.parameters()`` order, e.g. ``[weight, bias, ...]``
    element counts for a stack of linear layers.
    """
    return [param.numel() for param in model.parameters()]

def update_params(new_params, model, param_sizes):
    """Overwrite the parameters of ``model`` with values from a flat row vector.

    Parameters
    ----------
    new_params : torch.Tensor of shape (1, total)
        Concatenated new parameter values, laid out in ``model.parameters()``
        order (the layout produced by ``collect_grads``).
    model : torch.nn.Module
        Model whose parameters are updated in place.
    param_sizes : sequence of int
        Number of entries of each parameter, in ``model.parameters()`` order
        (as returned by ``get_param_sizes``).

    Notes
    -----
    Values are *copied* into the existing parameter storage. Rebinding
    ``p.data`` to a slice of ``new_params`` would make the parameters share
    memory with the caller's vector (later in-place edits of either would
    affect both) and would silently adopt its dtype/device.
    """
    start_index = 0
    # no_grad: in-place writes to leaf parameters must not be tracked by autograd.
    with torch.no_grad():
        for i, p in enumerate(model.parameters()):
            end_index = start_index + param_sizes[i]
            source_tensor = new_params[:, start_index:end_index].reshape(p.shape)
            p.copy_(source_tensor)
            start_index = end_index

In [68]:
# Instantiate the network and record how many scalar entries each of its
# parameter tensors has (in model.parameters() order).
mlp = MLP()
param_sizes = get_param_sizes(mlp)
# Bare expression so the notebook displays the list as the cell output.
param_sizes

[64, 16, 512, 32, 32, 1]

In [69]:
# Total number of scalar parameters in the network.
num_params = sum(param_sizes)
# Bare expression so the notebook displays the count as the cell output.
num_params

657

In [70]:
# Copy the simulated NumPy arrays into float32 tensors. Plain data tensors do
# not track gradients by default; y becomes a column so it lines up with the
# (N, 1) network output.
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32).reshape(-1, 1)

# Sum of squared residuals between the targets and the network predictions.
l2_loss = (y_tensor - mlp(X_tensor)).pow(2).sum()
l2_loss

tensor(203.5119, grad_fn=<SumBackward0>)

In [71]:
# Clear any gradients left over from a previous run of this cell, then
# backpropagate the loss so every parameter's .grad is populated.
mlp.zero_grad()
l2_loss.backward()

In [72]:
# Flatten all parameter gradients into a single row vector and display its
# shape as the cell output.
gradient = collect_grads(mlp)
gradient.shape

torch.Size([1, 657])

In [73]:
# Placeholder arithmetic (square, then shift by 0.4) that yields a vector with
# the same layout as the flattened gradient; run without autograd tracking.
with torch.no_grad():
    beta = gradient.pow(2).sub(0.4)

In [74]:
# Write the flat vector back into the network, one slice per parameter tensor.
update_params(new_params=beta, model=mlp, param_sizes=param_sizes)