In [1]:
import torch as ch

In [5]:
class TruncatedUnknownVarianceMSE(ch.autograd.Function):
    """
    Computes the gradient of negative population log likelihood for truncated linear regression
    with unknown noise variance.

    The forward pass returns the ordinary half mean-squared error. The backward pass
    does not differentiate that value. Instead it returns a Monte-Carlo estimate of
    the truncated-likelihood gradient.

    Runtime settings are read from a module-level global ``config.args``.
    NOTE(review): ``config`` is not defined or imported in this notebook; it must be
    provided as a global before ``backward`` runs. Fields used:
      - ``num_samples``: number of noisy copies drawn per prediction
      - ``phi``: truncation-set oracle, called on one (batch, dim) copy of the
        predictions; presumably returns a per-row 0/1 mask of shape (batch,)
      - ``eps``: added to the accepted-copy count to avoid division by zero
    """
    @staticmethod
    def forward(ctx, pred, targ, lambda_):
        """
        Args:
            pred: regression predictions; assumed (batch, dim) because backward
                stacks copies along a new leading axis with ``repeat(S, 1, 1)``.
            targ: observed targets, same shape as ``pred``.
            lambda_: noise precision estimate; must be a square matrix because
                backward calls ``lambda_.inverse()``.

        Returns:
            0.5 * squared error averaged over dim 0 (the batch).
        """
        ctx.save_for_backward(pred, targ, lambda_)
        return 0.5 * (pred.float() - targ.float()).pow(2).mean(0)

    @staticmethod
    def backward(ctx, grad_output):
        pred, targ, lambda_ = ctx.saved_tensors
        num_samples, eps = config.args.num_samples, config.args.eps
        # standard deviation of the current noise-distribution estimate
        sigma = ch.sqrt(lambda_.inverse())
        # num_samples copies of every prediction: (num_samples, batch, dim)
        stacked = pred[None, ...].repeat(num_samples, 1, 1)
        # Independent noise for every copy of every prediction. The draw must
        # have the full stacked shape: a (num_samples, 1) draw would broadcast
        # against (num_samples, batch, dim) as (1, num_samples, 1).
        noised = stacked + sigma * ch.randn_like(stacked)
        # 0/1 mask of the copies that land inside the truncation set
        filtered = ch.stack(
            [config.args.phi(noised_copy).unsqueeze(1) for noised_copy in noised]
        ).float()
        z = noised * filtered
        # surviving copies per prediction; eps guards against 0 / 0
        accepted = filtered.sum(dim=0) + eps
        # Monte-Carlo estimates of E[z] and E[z^2] under the truncated distribution
        z_mean = z.sum(dim=0) / accepted
        z_sq_mean = z.pow(2).sum(dim=0) / accepted
        # per-sample gradient w.r.t. the precision: 0.5 * (y^2 - E[z^2])
        lambda_grad = 0.5 * (targ.pow(2) - z_sq_mean)
        # Multiply the v gradient by lambda, because autograd computes
        # v_grad * x * variance, thus need v_grad * (1 / variance) to cancel the
        # variance factor.
        pred_grad = lambda_ * (z_mean - targ)
        batch_size = pred.size(0)
        # Chain rule: scale by the incoming gradient. lambda_grad is per-sample;
        # autograd sums it down to lambda_'s shape when lambda_ is broadcastable to it.
        # NOTE(review): the gradient returned for `targ` is kept from the original;
        # targets normally do not require grad, in which case it is discarded.
        return (grad_output * pred_grad / batch_size,
                grad_output * targ / batch_size,
                grad_output * lambda_grad / batch_size)

In [2]:
# Seed torch's global RNG so the Linear layer's random initialization (and any
# later torch sampling) is reproducible under Restart & Run All.
ch.manual_seed(0)
# 1-input / 1-output linear model
lin = ch.nn.Linear(in_features=1, out_features=1)
# plain SGD over all of the model's parameters (weight and bias)
opt = ch.optim.SGD(lin.parameters(), lr=1e-1)

In [4]:
# Second SGD optimizer over the linear model's weight and bias (registration
# order of nn.Linear), same learning rate as `opt`.
# NOTE(review): this duplicates `opt`; consider keeping only one.
opt_ = ch.optim.SGD(params=list(lin.parameters()), lr=0.1)

In [5]:
# Clear the gradients of the parameters managed by `opt_`.
# NOTE(review): whether grads are zeroed or set to None depends on the installed
# torch version's `set_to_none` default.
opt_.zero_grad()

In [6]:
# Per-parameter-group options for torch.optim: group 0 is the linear model's
# weight and bias, group 1 is an extra scalar tensor.
# The extra tensor is named and created with requires_grad=True. A tensor
# without requires_grad never receives a gradient, so the optimizer could never
# update it.
# NOTE(review): presumably this is the noise precision passed as `lambda_` to
# TruncatedUnknownVarianceMSE — confirm the intended shape (that function calls
# `.inverse()`, which needs a square matrix).
lambda_ = ch.ones(1, requires_grad=True)
params = [{'params': [lin.weight, lin.bias]}, {'params': [lambda_]}]

In [8]:
# SGD over both parameter groups in `params`; lr applies to every group since
# no group sets its own.
o = ch.optim.SGD(params=params, lr=0.1)

In [9]:
# Clear the gradients of every parameter managed by `o` (all parameter groups).
o.zero_grad()

In [None]:
# This cell is intentionally empty (a lone backtick here was a SyntaxError).