In [75]:
import torch
import torch.nn as nn
import sys

In [154]:
def nll(X,Z,D,Y,alpha,beta,eta,gamma,estimator='LATE'):
    # print(X.type(), beta.type())
    phi = torch.sigmoid(X@beta)
    c1, c3 = phi[:,0], phi[:,1]
    if estimator == 'LATE':
        c0 = torch.tanh(X @ alpha)
    elif estimator == 'MLATE':
        c0 = torch.exp(X @ alpha)
    c5 = torch.exp(X@eta)
    # c5 = c5 * (torch.abs(c5-1) > 1e-10) + (torch.abs(c5-1) <= 1e-10) * (c5-1 >= 0.0) * (1+1e-10) + (torch.abs(c5-1) <= 1e-10) * (c5-1< 0.0) * (1-1e-10)
    # c5 = 1 + c5 - c5.clamp(1-1e-5,1+1e-5)
    c0 = c0.clamp(1e-10)
    if estimator == 'LATE':
        f0 = (c5*(2-c0)+c0-torch.sqrt(c0**2 * (c5-1)**2 + 4*c5)) / (2 * (c5-1))
        f1 = f0 + c0
    elif estimator == 'MLATE':
        f0 = torch.where(torch.abs(c5-1)>1e-10, (-(c0+1)*c5+torch.sqrt(c5**2*(c0-1)**2 + 4*c0*c5)) / (2*c0*(1-c5)), -(-c0-1+(2*(c0-1)**2+4*c0)/(c0+1)/2)/(2*c0))
        f1 = f0 * c0
    p011 = (1-c1)*c3
    p001 = (1-c1)*(1-c3)
    p110 = 0
    p100 = 0
    p111 = f1*c1 + p110
    p010 = f0*c1 + p011
    p101 = 1-p001-p011-p111
    p000 = 1-p010-p100-p110

    d = torch.sigmoid(X@gamma)
    l = D*Y*Z*p111*d + (1-D)*Y*Z*p011*d + D*(1-Y)*Z*p101*d + (1-D)*(1-Y)*Z*p001*d + (1-D)*Y*(1-Z)*p010*(1-d) + (1-D)*(1-Y)*(1-Z)*p000*(1-d)
    # if torch.mean(torch.log(l.clamp(1e-10, 1))) :
    #     print(l)
    #     sys.exit(0)
    return torch.mean(torch.log(l.clamp(1e-10)))

def square_loss(X, Z, D, Y, alpha, beta, gamma, eta, estimator, strategy='identity'):
    d = torch.sigmoid(X@gamma)
    phi = torch.sigmoid(X@beta)
    phi1, phi3 = phi[:,0], phi[:,1]
    OP = torch.exp(X@eta)
    f = (d**Z) * ((1-d)**(1-Z))
    if estimator == 'LATE':
        theta = torch.tanh(X@alpha)
        H = Y - D * theta
        f0 = (OP*(2-theta)+theta-torch.sqrt(theta**2*(OP-1)**2+4*OP))/(2*(OP-1))
        f1 = f0 + theta
        E = f0*phi1 + (1-phi1)*phi3
    elif estimator == 'MLATE':
        theta = torch.exp(X@alpha)
        H = Y * theta**(-D)
        f0 = (-(theta+1)*OP+torch.sqrt(OP**2*(theta-1)**2+4*theta*OP)) / (2*theta*(1-OP))
        f1 = f0 * theta
        E = f0*phi1 +(1-phi1)*phi3
    
    if strategy == 'identity':
        return torch.sum((torch.sum(X*((2*Z-1)*(H-E)/f).unsqueeze(1), dim=0))**2)
    elif strategy == 'optimal':
        p011 = (1-phi1)*phi3
        p001 = (1-phi1)*(1-phi3)
        p110 = 0
        p100 = 0
        p111 = f1*phi1 + p110
        p010 = f0*phi1 + p011
        p101 = 1-p001-p011-p111
        p000 = 1-p010-p100-p110
        if estimator == 'LATE':
            EH2_1 = p011+p111+ theta**2 * (p111+p101) -2*theta*p111
            EH2_0 = p110+p010+theta**2*(p110+p100) -2*theta*p110
            EH_1 = p111+p011-theta*(p111+p101)
            EH_0 = p110+p010-theta*(p110+p100)
            EZX = (EH2_1-EH_1**2) / d + (EH2_0-EH_0**2)/(1-d)
            w = -X * ((1/torch.cosh(X@alpha)**2) * phi1 / EZX).unsqueeze(1)
            return torch.sum((torch.sum(w*((2*Z-1)*(H-E)/f).unsqueeze(1), dim=0))**2)
        elif estimator == 'MLATE':
            EH2_1 = p111/theta**2+p101
            EH2_0 = p110/theta**2+p100
            EH_1 = p111/theta+p101
            EH_0 = p110/theta+p100
            EZX = (EH2_1-EH_1**2) / d + (EH2_0-EH_0**2)/(1-d)
            w = -X * (1 / theta * f1 * phi1 / EZX).unsqueeze(1)
            return torch.mean((torch.sum(w*((2*Z-1)*(H-E)/f).unsqueeze(1), dim=0))**2)

def MLE(X, Z, D ,Y, estimator='MLATE', dr=False):
    N, p = X.shape
    alpha = nn.Parameter(torch.rand(size=(p,))*0.2-0.1)
    beta = nn.Parameter(torch.rand(size=(p,2))*0.2-0.1) ## only phi1 and phi3
    eta = nn.Parameter(torch.rand(size=(p,))*2-1)
    gamma = nn.Parameter(torch.rand(size=(p,))*0.2-0.1)
    opt = torch.optim.Adam(params=(alpha, beta, eta, gamma), lr=1e-3, weight_decay=0)
    optloss = float('inf')
    for i in range(20000):
        opt.zero_grad()
        loss = -nll(X,Z,D,Y,alpha,beta,eta,gamma, estimator)
        if loss.item() < optloss:
            if abs(loss.item() - optloss) < 1e-7:
                break
            optloss = loss.item()
            mlealpha = alpha.detach().clone()
            mlebeta = beta.detach().clone()
            mleeta = eta.detach().clone()
            mlegamma = gamma.detach().clone()
        if i % 100 ==0: 
            # print(alpha, beta, gamma, eta)
            print('Iter {} | loss {:.04f}'.format(i+1, loss.item()))
        loss.backward()
        opt.step()
    if not dr:
        return mlealpha, mlebeta, mleeta, mlegamma
    alpha = nn.Parameter(mlealpha.clone(),requires_grad=True)
    # alpha = nn.Parameter(torch.rand(size=(2,))*2-1)
    opt = torch.optim.Adam(params=(alpha,), lr=5e-3, weight_decay=0)
    sqoptloss = float('inf')
    for i in range(100000):
        opt.zero_grad()
        sq_loss = square_loss(X, Z, D, Y, alpha, mlebeta, mlegamma, mleeta, estimator, strategy='optimal')
        if i % 100 ==0:
            print('Iter {} | sq_loss {:.04f}'.format(i+1, sq_loss.item()))
        if sq_loss.item() < sqoptloss:
            sqoptloss = sq_loss.item()
            drwalpha = alpha.detach().clone()
            if abs(sqoptloss) < 1e-6:
                break
        sq_loss.backward()
        opt.step()

In [155]:
data = torch.load('401k.pt')
data[:,1] /= 100
data[:,4] /= 10
data[:,5] 
data[:,9] /= 10000
Z = data[:,0]
X = torch.cat((torch.ones((data.shape[0],1)), data[:, [1,2,4,5,9]]), dim=-1)
D = data[:,7]
Y = data[:,8]

N, p = X.shape
NR = 1
torch.manual_seed(6971)
mlealphas = torch.zeros(size=(NR, p))
drualphas = torch.zeros(size=(NR, p))
minimums, optlosses = [], []
for i in range(NR):
    idxes = torch.multinomial(torch.ones(N), N, replacement=True)
    Xdata = X[idxes].clone()
    MLE(X,Z,D,Y)

Iter 1 | loss 1.5820
Iter 101 | loss 1.4942
Iter 201 | loss 1.4743
Iter 301 | loss 1.4635
Iter 401 | loss 1.4546
Iter 501 | loss 1.4482
Iter 601 | loss 1.4440
Iter 701 | loss 1.4400
Iter 801 | loss 1.4364
Iter 901 | loss 1.4335
Iter 1001 | loss 1.4309
Iter 1101 | loss 1.4287
Iter 1201 | loss 1.4259
Iter 1301 | loss 1.4242
Iter 1401 | loss 1.4228
Iter 1501 | loss 1.4216
Iter 1601 | loss 1.4205
Iter 1701 | loss 1.4197
Iter 1801 | loss 1.4188
Iter 1901 | loss 1.4179
Iter 2001 | loss 1.4170
Iter 2101 | loss 1.4162
Iter 2201 | loss 1.4153
Iter 2301 | loss 1.4145
Iter 2401 | loss 1.4136
Iter 2501 | loss 1.4127
Iter 2601 | loss 1.4118
Iter 2701 | loss 1.4109
Iter 2801 | loss 1.4100
Iter 2901 | loss 1.4091
Iter 3001 | loss 1.4083
Iter 3101 | loss 1.4074
Iter 3201 | loss 1.4066
Iter 3301 | loss 1.4058
Iter 3401 | loss 1.4050
Iter 3501 | loss 1.4042
Iter 3601 | loss 1.4033
Iter 3701 | loss 1.4025
Iter 3801 | loss 1.4017
Iter 3901 | loss 1.4010
Iter 4001 | loss 1.4002
Iter 4101 | loss 1.3994
Iter

KeyboardInterrupt: 

In [122]:
print(torch.min(X, dim=0))

torch.return_types.min(
values=tensor([1.0000, 0.1001, 0.0000, 2.5000, 1.0000, 0.0100]),
indices=tensor([   0, 4621,    0,   30,    0, 4621]))
