In [2]:
import torch
from torch import nn, optim
from torch.autograd import grad


def generate(n=10000, s=1, t=1, k=1):
    """Draw one synthetic dataset of ``n`` rows.

    Three independent standard-normal tensors of shape ``(n, 1)`` are drawn,
    in this order, and combined as::

        x_value = s * e1
        y_value = k * x_value + s * e2
        x_prime = y_value + t * e3

    Args:
        n: Number of rows to draw.
        s: Scale of the noise in ``x_value`` and of the noise added to
            ``y_value``.
        t: Scale of the noise added on top of ``y_value`` to form
            ``x_prime``.
        k: Coefficient multiplying ``x_value`` inside ``y_value``.

    Returns:
        A tuple ``(inputs, y_value)`` where ``inputs`` has shape ``(n, 2)``
        with columns ``[x_value, x_prime]`` and ``y_value`` has shape
        ``(n, 1)``.
    """
    # The order of the three randn() calls is kept so that a fixed seed
    # reproduces the same data.
    first_column = torch.randn(n, 1) * s
    target = first_column * k + torch.randn(n, 1) * s
    second_column = target + torch.randn(n, 1) * t

    inputs = torch.cat((first_column, second_column), 1)
    return inputs, target


def evaluate(loss, omega):
    """Return the summed product of two half-batch gradients w.r.t. ``omega``.

    ``loss`` is split along its first dimension into the even-indexed and the
    odd-indexed entries. The mean of each half is differentiated with respect
    to ``omega``, and the two gradients are multiplied elementwise and summed.

    Args:
        loss: Tensor of losses that depends on ``omega`` through the autograd
            graph.
        omega: Tensor with ``requires_grad=True`` to differentiate against.

    Returns:
        A scalar tensor. Both gradients are taken with ``create_graph=True``,
        so the result can itself be backpropagated through. Because it is a
        product of two different gradients, it may be negative.
    """
    halves = (loss[0::2], loss[1::2])
    grad_even, grad_odd = (
        grad(half.mean(), omega, create_graph=True)[0] for half in halves
    )
    return (grad_even * grad_odd).sum()


# Values of the coefficient k passed to generate(); each gets its own run.
ks = [0.3, 0.5, 0.7, 0.8, 0.9, 1.0, 1.5, 2.0, 2.5, 3.0]


for k in ks:
    print("\n---------- k = {} ----------".format(k))
    # Two datasets sharing the same k but drawn with different noise scale s.
    sample = [generate(s=1.0, k=k), generate(s=0.1, k=k)]

    # Only theta is handed to the optimizer. omega is never updated: it stays
    # at 1.0 and is the variable evaluate() differentiates against.
    theta = nn.Parameter(torch.ones(2, 1))
    omega = nn.Parameter(torch.Tensor([1.0]))
    optimizer = optim.SGD([theta], lr=1e-3)
    # reduction='none' keeps one loss per row so evaluate() can split them.
    function = nn.MSELoss(reduction='none')
    record = []

    for epoch in range(1, 100001):
        error = 0
        penalty = 0
        for x, y in sample:
            # Fresh row order every epoch, so the even/odd split inside
            # evaluate() changes from one epoch to the next.
            order = torch.randperm(len(x))
            loss = function(x[order] @ theta * omega, y[order])
            error = error + loss.mean()
            penalty = penalty + evaluate(loss, omega)

        # Objective: lightly weighted mean error plus the gradient penalty.
        objective = 1e-5 * error + penalty
        optimizer.zero_grad()
        objective.backward()
        optimizer.step()
        record.append(objective.item())

        if epoch % 5000 != 0:
            continue
        print('----- Epoch {} -----'.format(epoch))
        print('Error: {}'.format(error.item()))
        print('Penalty: {}'.format(penalty.item()))
        print(theta)



---------- k = 0.3 ----------
----- Epoch 5000 -----
Error: 0.8703827261924744
Penalty: 0.00016503059305250645
Parameter containing:
tensor([[0.4294],
        [0.0872]], requires_grad=True)
----- Epoch 10000 -----
Error: 0.891944944858551
Penalty: 8.399621583521366e-05
Parameter containing:
tensor([[0.4126],
        [0.0688]], requires_grad=True)
----- Epoch 15000 -----
Error: 0.9027113318443298
Penalty: 4.743588579003699e-05
Parameter containing:
tensor([[0.4034],
        [0.0602]], requires_grad=True)
----- Epoch 20000 -----
Error: 0.9095416069030762
Penalty: -0.00010865063813980669
Parameter containing:
tensor([[0.3973],
        [0.0549]], requires_grad=True)
----- Epoch 25000 -----
Error: 0.9143661260604858
Penalty: -4.827185330213979e-05
Parameter containing:
tensor([[0.3928],
        [0.0513]], requires_grad=True)
----- Epoch 30000 -----
Error: 0.9180030822753906
Penalty: -4.301983426557854e-05
Parameter containing:
tensor([[0.3894],
        [0.0486]], requires_grad=True)
----- 