In [1]:
import torch
import torch.nn as nn
import sys

In [4]:
def nll(X,Z,D,Y,alpha,beta,eta,gamma,estimator='LATE'):
    # print(X.type(), beta.type())
    phi = torch.sigmoid(X@beta)
    c1, c3 = phi[:,0], phi[:,1]
    if estimator == 'MLATE':
        c0 = torch.exp(X @ alpha)
    c5 = torch.exp(X@eta)
    # c5 = c5 * (torch.abs(c5-1) > 1e-10) + (torch.abs(c5-1) <= 1e-10) * (c5-1 >= 0.0) * (1+1e-10) + (torch.abs(c5-1) <= 1e-10) * (c5-1< 0.0) * (1-1e-10)
    # c5 = 1 + c5 - c5.clamp(1-1e-5,1+1e-5)
    c0 = c0.clamp(1e-10)
    if estimator == 'MLATE':
        f0 = torch.where(torch.abs(c5-1)>1e-10, (-(c0+1)*c5+torch.sqrt(c5**2*(c0-1)**2 + 4*c0*c5)) / (2*c0*(1-c5)), -(-c0-1+(2*(c0-1)**2+4*c0)/(c0+1)/2)/(2*c0))
        f1 = f0 * c0
    p011 = (1-c1)*c3
    p001 = (1-c1)*(1-c3)
    p110 = 0
    p100 = 0
    p111 = f1*c1 + p110
    p010 = f0*c1 + p011
    p101 = 1-p001-p011-p111
    p000 = 1-p010-p100-p110

    d = torch.sigmoid(X@gamma)
    l = D*Y*Z*p111*d + (1-D)*Y*Z*p011*d + D*(1-Y)*Z*p101*d + (1-D)*(1-Y)*Z*p001*d + (1-D)*Y*(1-Z)*p010*(1-d) + (1-D)*(1-Y)*(1-Z)*p000*(1-d)
    # if torch.mean(torch.log(l.clamp(1e-10, 1))) :
    #     print(l)
    #     sys.exit(0)
    return torch.mean(torch.log(l.clamp(1e-10)))

def square_loss(X, Z, D, Y, alpha, beta, gamma, eta, estimator, strategy='identity'):
    d = torch.sigmoid(X@gamma)
    phi = torch.sigmoid(X@beta)
    phi1, phi3 = phi[:,0], phi[:,1]
    OP = torch.exp(X@eta)
    f = (d**Z) * ((1-d)**(1-Z))
    if estimator == 'MLATE':
        theta = torch.exp(X@alpha)
        H = Y * theta**(-D)
        f0 = torch.where(torch.abs(OP-1)>1e-10, (-(theta+1)*OP+torch.sqrt(OP**2*(theta-1)**2+4*theta*OP)) / (2*theta*(1-OP)), -(-theta-1+(2*(theta-1)**2+4*theta)/(theta+1)/2)/(2*theta))
        f1 = f0 * theta
        E = f0*phi1 +(1-phi1)*phi3
    
    if strategy == 'identity':
        return torch.sum((torch.sum(X*((2*Z-1)*(H-E)/f).unsqueeze(1), dim=0))**2)
    elif strategy == 'optimal':
        p011 = (1-phi1)*phi3
        p001 = (1-phi1)*(1-phi3)
        p110 = 0
        p100 = 0
        p111 = f1*phi1 + p110
        p010 = f0*phi1 + p011
        p101 = 1-p001-p011-p111
        p000 = 1-p010-p100-p110
        if estimator == 'MLATE':
            EH2_1 = p111/theta**2+p101
            EH2_0 = p110/theta**2+p100
            EH_1 = p111/theta+p101
            EH_0 = p110/theta+p100
            EZX = (EH2_1-EH_1**2) / d + (EH2_0-EH_0**2)/(1-d)
            w = -X * (1 / theta * f1 * phi1 / EZX).unsqueeze(1)
            return torch.mean((torch.mean(w*((2*Z-1)*(H-E)/f).unsqueeze(1), dim=0))**2)

def MLE(X, Z, D ,Y, estimator='MLATE', dr=True):
    N, p = X.shape
    alpha = nn.Parameter(torch.rand(size=(p,))*0.2-0.1)
    beta = nn.Parameter(torch.rand(size=(p,2))*0.2-0.1) ## only phi1 and phi3
    eta = nn.Parameter(torch.rand(size=(p,))*2-1)
    gamma = nn.Parameter(torch.rand(size=(p,))*0.2-0.1)
    opt = torch.optim.Adam(params=(alpha, beta, eta, gamma), lr=1e-3, weight_decay=0)
    optloss = float('inf')
    for i in range(50000):
        opt.zero_grad()
        loss = -nll(X,Z,D,Y,alpha,beta,eta,gamma, estimator)
        if loss.item() < optloss:
            if abs(loss.item() - optloss) < 1e-7:
                break
            optloss = loss.item()
            mlealpha = alpha.detach().clone()
            mlebeta = beta.detach().clone()
            mleeta = eta.detach().clone()
            mlegamma = gamma.detach().clone()
        if i % 100 ==0: 
            print('Iter {} | loss {:.04f}'.format(i+1, loss.item()))
        loss.backward()
        opt.step()
        if torch.sum(loss.isnan()) > 0:
            break
    if not dr:
        return mlealpha, mlebeta, mleeta, mlegamma
        
    alpha = nn.Parameter(mlealpha.clone(),requires_grad=True)
    # alpha = nn.Parameter(torch.rand(size=(2,))*2-1)
    opt = torch.optim.Adam(params=(alpha,), lr=1e-3, weight_decay=0)
    sqoptloss = float('inf')
    for i in range(50000):
        opt.zero_grad()
        sq_loss = square_loss(X, Z, D, Y, alpha, mlebeta, mlegamma, mleeta, estimator, strategy='optimal')
        if i % 100 ==0:
            print('Iter {} | sq_loss {:.08f}'.format(i+1, sq_loss.item()))
        if sq_loss.item() < sqoptloss:
            sqoptloss = sq_loss.item()
            drwalpha = alpha.detach().clone()
            if abs(sqoptloss) < 1e-6:
                break
        sq_loss.backward()
        opt.step()
        if torch.sum(sq_loss.isnan()) > 0:
            break
    return mlealpha, drwalpha, mlebeta, mleeta, mlegamma

In [5]:
data = torch.load('401k.pt')
data[:,1] /= 100
data[:,4] /= 10
data[:,9] /= 10000
Z = data[:,0]
X = torch.cat((torch.ones((data.shape[0],1)), data[:, [1,2,4,5,9]]), dim=-1)
D = data[:,7]
Y = data[:,8]

N, p = X.shape
NR = 500
torch.manual_seed(6971)
mlealphas = torch.zeros(size=(NR, p))
drwalphas = torch.zeros(size=(NR, p))
minimums, optlosses = [], []
for i in range(NR):
    print('Bootstrap {}'.format(i+1), '-'*50)
    idxes = torch.multinomial(torch.ones(N), N, replacement=True)
    Xdata = X[idxes].clone()
    mlealpha, drwalpha, mlebeta, mleeta, mlegamma = MLE(Xdata,Z,D,Y)
    mlealphas[i] = mlealpha
    drwalphas[i] = drwalpha

Bootstrap 1 --------------------------------------------------
Iter 1 | loss 1.5797
Iter 101 | loss 1.5060
Iter 201 | loss 1.4943
Iter 301 | loss 1.4887
Iter 401 | loss 1.4853
Iter 501 | loss 1.4837
Iter 601 | loss 1.4807
Iter 701 | loss 1.4795
Iter 801 | loss 1.4786
Iter 901 | loss 1.4780
Iter 1001 | loss 1.4775
Iter 1101 | loss 1.4770
Iter 1201 | loss 1.4767
Iter 1301 | loss 1.4763
Iter 1401 | loss 1.4760
Iter 1501 | loss 1.4758
Iter 1601 | loss 1.4755
Iter 1701 | loss 1.4753
Iter 1801 | loss 1.4750
Iter 1901 | loss 1.4748
Iter 2001 | loss 1.4746
Iter 2101 | loss 1.4743
Iter 2201 | loss 1.4741
Iter 2301 | loss 1.4739
Iter 2401 | loss 1.4736
Iter 2501 | loss 1.4734
Iter 2601 | loss 1.4732
Iter 2701 | loss 1.4729
Iter 2801 | loss 1.4727
Iter 2901 | loss 1.4725
Iter 3001 | loss 1.4722
Iter 3101 | loss 1.4720
Iter 3201 | loss 1.4718
Iter 3301 | loss 1.4715
Iter 3401 | loss 1.4713
Iter 3501 | loss 1.4711
Iter 3601 | loss 1.4709
Iter 3701 | loss 1.4706
Iter 3801 | loss 1.4704
Iter 3901 | l

KeyboardInterrupt: 

In [35]:
mlealphas[:19]

tensor([[ 0.0550,  1.6181, -0.0665,  0.0621, -0.0297, -1.3122],
        [-0.0221, -0.0122,  0.0806,  0.0280,  0.0153,  0.0976],
        [ 0.2240,  0.0399, -0.1309,  0.0756, -0.0317, -0.1542],
        [ 0.3013, -0.7200, -0.0963,  0.1174,  0.0084,  0.3663],
        [ 0.2606,  0.4524, -0.4480,  0.0324,  0.1154, -0.3442],
        [ 0.3778,  0.1074,  0.0386, -0.0704,  0.0302,  0.0537],
        [ 0.5544,  0.5994, -0.2681, -0.0059,  0.0099, -0.3899],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0

In [29]:
drwalphas[1]

tensor([0.0459, 0.0135, 0.0143, 0.1103, 0.0092, 0.0888])