In [20]:
import torch
import torch.nn as nn

def generate(X, alpha, beta, eta, gamma, estimator="LATE"):
    if estimator == 'LATE':
        delta = torch.tanh(X @ alpha)
    elif estimator == 'MLATE':
        delta = torch.exp(X @ alpha)
    phi = torch.sigmoid(X @ beta)
    OP = torch.exp(X @ eta)
    VIden = torch.sigmoid(X @ gamma)

    Z = torch.bernoulli(VIden)
    D = torch.bernoulli((1-phi[:,0])*phi[:,1] + Z*phi[:,0])
    if estimator == 'LATE':
        f0 = (OP*(2-delta)+delta-torch.sqrt(delta**2 * (OP-1)**2 + 4*OP)) / (2 * (OP-1))
        f1 = f0 + delta
    elif estimator == 'MLATE':
        f0 = (-(delta+1)*OP
         + torch.sqrt(OP**2 * (delta-1)**2 + 4*delta*OP)) / (2*delta*(1-OP))
        f1 = f0 * delta
    Y = Z*(1-D)*phi[:,2] + (1-Z)*D*phi[:,3] + Z*D*(f1*phi[:,0]+(1-phi[:,0])*phi[:,1]*phi[:,3])/(1-(1-phi[:,0])*(1-phi[:,1])) + (1-Z)*(1-D)*(f0*phi[:,0]+(1-phi[:,0])*(1-phi[:,1])*phi[:,2])/(1-(1-phi[:,0])*phi[:,1])
    Y = torch.bernoulli(Y)
    return Z, D, Y

def nll(X,Z,D,Y,alpha,beta,eta,estimator='LATE'):
    phi = torch.sigmoid(X@beta)
    c1, c2, c3, c4 = phi[:,0], phi[:,1], phi[:,2], phi[:,3]
    if estimator == 'LATE':
        c0 = torch.tanh(X @ alpha)
    elif estimator == 'MLATE':
        c0 = torch.exp(X @ alpha)
    c5 = torch.exp(X@eta)
    if estimator == 'LATE':
        f0 = (c5*(2-c0)+c0-torch.sqrt(c0**2 * (c5-1)**2 + 4*c5)) / (2 * (c5-1))
        f1 = f0 + c0
    elif estimator == 'MLATE':
        f0 = (-(c0+1)*c5+torch.sqrt(c5**2*(c0-1)**2 + 4*c0*c5)) / (2*c0*(1-c5))
        f1 = f0 * c0
    p011 = (1-c1)*(1-c2)*c3
    p001 = (1-c1)*(1-c2)*(1-c3)
    p110 = (1-c1)*c2*c4
    p100 = (1-c1)*c2*(1-c4)
    p111 = f1*c1 + p110
    p010 = f0*c1 + p011
    p101 = 1-p001-p011-p111
    p000 = 1-p010-p100-p110
    l = D*Y*Z*p111 + (1-D)*Y*Z*p011 + D*(1-Y)*Z*p101 + (1-D)*(1-Y)*Z*p001 + D*Y*(1-Z)*p110 + (1-D)*Y*(1-Z)*p010 + D*(1-Y)*(1-Z)*p100 + (1-D)*(1-Y)*(1-Z)*p000
    return torch.mean(torch.log(l.clamp(1e-10)))

def MLE(X, estimator='LATE'):
    alpha0 = torch.tensor([0.0, -1.0])
    beta0 = (torch.ones(size=(4,2)) * torch.tensor([-0.4,0.8])).T
    eta0 = torch.tensor([-0.4, 1.0])
    gamma0 = torch.tensor([0.1, -1.0])
    Z, D, Y = generate(X, alpha0, beta0, eta0, gamma0, estimator)
    minimum = (-nll(X,Z,D,Y,alpha0,beta0,eta0, estimator)).item()
    # print('minimum', minimum)
    alpha = nn.Parameter(torch.tensor([0.5, 0.5]))
    beta = nn.Parameter((torch.ones(size=(4,2)) * torch.tensor([-0.5,0.5])).T)
    eta = nn.Parameter(torch.tensor([-0.5, 0.5]))
    opt = torch.optim.Adam(params=(alpha, beta, eta), lr=1e-3, weight_decay=0.1)
    optloss = float('inf')
    for i in range(5000):
        loss = -nll(X,Z,D,Y,alpha,beta,eta, estimator)
        if loss.item() < optloss:
            optloss = loss.item()
            optalpha = alpha.detach().clone()
        # print('Iter {}\t | loss {:.04f}'.format(i+1, loss.item()))
        loss.backward()
        opt.step()
    # print(optalpha)
    return optalpha, minimum, optloss



####### mle.bth 
# N = 1000
# NR = 1000
# torch.manual_seed(24)
# optalphas = torch.zeros(size=(NR, 2))
# minimums, optlosses = [], []
# X = torch.column_stack((torch.ones(N)*1.0, torch.rand(N)*2-1))
# for i in range(NR):
#     optalpha, minimum, optloss = MLE(X)
#     print('{} Experiement | Difference {:.04f} | Alpha: ({:.04f}, {:.04f})'.format(i+1, optloss-minimum, optalpha[0].item(), optalpha[1].item()))
#     optalphas[i] = optalpha
#     minimums.append(minimum)
#     optlosses.append(optloss)
## performs well

####### mle.bth 
N = 1000
NR = 1000
torch.manual_seed(24)
optalphas = torch.zeros(size=(NR, 2))
minimums, optlosses = [], []
X = torch.column_stack((torch.ones(N)*1.0, torch.rand(N)*2-1))
for i in range(NR):
    optalpha, minimum, optloss = MLE(X, estimator='MLATE')
    print('{} Experiement | Difference {:.04f} | Alpha: ({:.04f}, {:.04f})'.format(i+1, optloss-minimum, optalpha[0].item(), optalpha[1].item()))
    optalphas[i] = optalpha
    minimums.append(minimum)
    optlosses.append(optloss)
## performs well


###
##对初值敏感(没有先验知识时可能造成无法训练出来)，对学习率敏感
##添加贝叶斯先验分布时可能的改进


RuntimeError: Expected p_in >= 0 && p_in <= 1 to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)