In [None]:
import torch
import numpy as np

from torch import optim, nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset

import matplotlib.pyplot as plt

In [None]:
from tqdm import  notebook
from scipy import stats

In [None]:
# CPU handle plus the compute device: first CUDA GPU if available, otherwise CPU.
cpu = torch.device("cpu")
device = torch.device("cuda:0") if torch.cuda.is_available() else cpu
print(device)

In [None]:
from scipy.stats import norm

In [None]:
### load test data

In [None]:
xlr = np.load('xlr.npy')

# column 0 is the test covariate, reshaped to an (N, 1) float32 tensor
test_x = torch.from_numpy(xlr[:, 0].reshape(-1, 1)).float()

l = xlr[:, 1]  # left interval limit
r = xlr[:, 2]  # right interval limit

ntest = test_x.shape[0]

### feed the training data 

In [None]:
class my_dataset(Dataset):
    """Map-style Dataset that pairs ``data[i]`` with ``label[i]``.

    Args:
        data: indexable sequence of inputs (e.g. an (N, 1) tensor).
        label: indexable sequence of targets, same length as ``data``.

    Raises:
        ValueError: if ``data`` and ``label`` have different lengths.
    """
    def __init__(self, data, label):
        if len(data) != len(label):
            raise ValueError(
                f"data and label must have the same length, got {len(data)} and {len(label)}")
        # store the constructor arguments, not notebook globals, so the dataset
        # wraps whatever it is actually given
        self.data = data
        self.label = label

    def __getitem__(self, index):
        return self.data[index], self.label[index]

    def __len__(self):
        return len(self.data)

### generate the training data

In [None]:
def gen3(x, q):
    """Simulate responses for scenario 3 by inverse-CDF sampling.

    ``z = norm.ppf(q)`` maps the numpy array ``q`` (values in (0, 1)) to
    standard-normal quantiles, and the response is ``y = mu + sd * z`` with

        mu = cos(2 * 2 * 3.14159 * x)     (3.14159 approximates pi)
        sd = 0.2 + 0.3 * (mu + 1)

    Args:
        x: float tensor of covariates.
        q: numpy array of probabilities that broadcasts against ``x``.

    Returns:
        Tensor of simulated responses.
    """
    mean = torch.cos(x * 2 * 2 * 3.14159)
    scale = .2 + .3 * (mean + 1)
    z = torch.from_numpy(norm.ppf(q)).float()
    return mean + scale * z


experiment = 3  # presumably tags which generator (gen3) this run uses

## Collaborating Networks: g network only, with default batch normalization (gd)

In [None]:
class cn_gd(nn.Module):
    """Collaborating-networks "g" model with batch norm after every linear layer.

    ``forward(y, x)`` concatenates ``y`` and ``x`` along dim 1 into a 2-column
    input (each is (N, 1) at the call sites in this notebook). It then runs two
    Linear -> BatchNorm1d -> ELU stages of widths ``k1`` = 100 and ``k2`` = 80,
    followed by Linear(k2, 1) and a final BatchNorm1d(1). The result is returned
    as a logit; the training cell converts it with ``torch.sigmoid``.
    """

    def __init__(self):
        super().__init__()
        self.k1, self.k2 = 100, 80
        # creation order fixes how the seeded RNG initialises the Linear layers
        self.fc1 = nn.Linear(2, self.k1)
        self.bn1 = nn.BatchNorm1d(self.k1)
        self.fc2 = nn.Linear(self.k1, self.k2)
        self.bn2 = nn.BatchNorm1d(self.k2)
        self.fc3 = nn.Linear(self.k2, 1)
        # batch norm applied to the output logit itself
        self.bn3 = nn.BatchNorm1d(1, momentum=.1)

    def forward(self, y, x):
        hidden = torch.cat([y, x], dim=1)
        for linear, bn in ((self.fc1, self.bn1), (self.fc2, self.bn2)):
            hidden = F.elu(bn(linear(hidden)))
        return self.bn3(self.fc3(hidden))

In [None]:
# training-set sizes to sweep over
# NOTE(review): the jump 5400 -> 60000 breaks the otherwise gradual progression;
# confirm 60000 is not a typo for 6000 (it also dominates the runtime).
ns = [50, 100, 200, 400, 600, 800, 1000, 1200, 1400, 1600,
      2400, 3400, 4500, 5000, 5400, 60000]

In [None]:
allll = []  # one mean held-out interval log-likelihood per training-set size in `ns`
for n in ns:
    # --- simulate a training set of size n (both RNGs re-seeded so every n is reproducible) ---
    torch.manual_seed(42)
    x = torch.linspace(-.5, .5, n).reshape(-1, 1)
    np.random.seed(42)
    q = np.random.rand(n, 1)
    y = gen3(x, q)
    # gd is conditioned on the transformed covariate cos(2*2*3.14159*x), not on raw x
    x = torch.cos(x * 2 * 2 * 3.14159).reshape(-1, 1)

    batch_size = 500
    dataloader = DataLoader(dataset=my_dataset(x, y),
                            batch_size=batch_size,
                            shuffle=True,
                            # pinned memory only speeds up host->GPU copies; on a
                            # CPU-only machine it just triggers a warning
                            pin_memory=(device.type == "cuda"))

    # --- fit g: sigmoid(gd(yhat, x)) is trained with BCE against 1{y < yhat} ---
    pretrain_epochs = 20000
    gd = cn_gd().to(device)
    gd_loss = nn.BCELoss()
    optimizer_gd = optim.Adam(gd.parameters(), lr=1e-4)

    for epoch in notebook.trange(pretrain_epochs):
        for xs, ys in dataloader:
            xs, ys = xs.to(device), ys.to(device)
            optimizer_gd.zero_grad()

            # random query thresholds, yhat ~ N(-1.6, 4.3^2), one per sample
            yhat = torch.randn(ys.shape).to(device) * 4.3 - 1.6
            qhat_logit_c = gd(yhat, xs)

            # target is 1 where the observed y lies below the threshold
            with torch.no_grad():
                ylt = (ys < yhat).float()

            gld = gd_loss(torch.sigmoid(qhat_logit_c), ylt)
            gld.backward()
            optimizer_gd.step()

    # --- score held-out intervals: log(F(r) - F(l)) with F(t) = sigmoid(gd(t, x)) ---
    # Scored as one batch under no_grad: in eval mode BatchNorm uses its running
    # statistics, so each row is scored independently of the others.
    # Infinite endpoints are never fed to the network; F(-inf) = 0 and F(+inf) = 1
    # are substituted instead (so an interval with l = -inf, r = +inf scores log(1) = 0).
    # NOTE(review): gd is trained on cos(...) of x, but test_x is used exactly as
    # loaded from xlr[:, 0]; confirm that column already holds the transformed covariate.
    gd.eval()
    with torch.no_grad():
        x_test = test_x.to(device)
        l_t = torch.as_tensor(l, dtype=torch.float32).reshape(-1, 1)
        r_t = torch.as_tensor(r, dtype=torch.float32).reshape(-1, 1)
        # placeholder 0 for infinite endpoints; their CDF values are overwritten below
        l_in = torch.where(torch.isinf(l_t), torch.zeros_like(l_t), l_t)
        r_in = torch.where(torch.isinf(r_t), torch.zeros_like(r_t), r_t)
        lp = torch.sigmoid(gd(l_in.to(device), x_test)).cpu()
        rp = torch.sigmoid(gd(r_in.to(device), x_test)).cpu()
        lp = torch.where(l_t == -np.inf, torch.zeros_like(lp), lp)
        rp = torch.where(r_t == np.inf, torch.ones_like(rp), rp)
        # 1e-10 guards log(0); nothing in cn_gd forces F to be monotone in y,
        # so rp < lp can occur and yields NaN
        ll_est = torch.log(rp - lp + 1.e-10).reshape(-1).double().numpy()

    # NaN scores (non-monotone F) are skipped by nanmean
    ll_mean = np.nanmean(ll_est)
    print(ll_mean)
    allll.append(ll_mean)