In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Sampler, BatchSampler, Dataset, DataLoader, Subset, SubsetRandomSampler, random_split
import torch.nn.functional as F
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, roc_curve
from matplotlib import pyplot as plt
import lmfit
from scipy import interpolate
from scipy import stats

In [179]:
def genData(iNS,iNB):
    sx=np.random.normal(0,0.5,iNS)
    sy=np.random.uniform(0.1,1,iNS)
    sz=np.random.triangular(0.,0.95, 1, iNS)
    s=np.vstack([sx,sy,sz])
    
    bx=np.random.uniform(-3,3,iNB)
    by=np.random.uniform(-1,-0.1,iNB)
    bz=np.random.triangular(0,0.05,1,iNB)
    b=np.vstack([bx,by,bz])
    return s,b

sig,bkg=genData(100,20000)

In [180]:
class DataSet(Dataset):
    def __init__(self, samples, labels):
        super(DataSet, self).__init__()
        self.labels  = labels
        self.samples = samples
        if len(samples) != len(labels):
            raise ValueError(
                f"should have the same number of samples({len(samples)}) as there are labels({len(labels)})")

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        y = self.labels[index]
        x = self.samples[index]
        return x, y

class simple_MLPFit_onelayer(torch.nn.Module):
    def __init__(self,in_data,input_size,out_channels=1,act_out=False,nhidden=16,batchnorm=False,batch_size=500,n_epochs=100):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_size, nhidden),
            nn.ReLU(),
            nn.Linear(nhidden, nhidden),
            nn.ReLU(),
            nn.Linear(nhidden, out_channels),
        )
        self.loss    = sigLoss()
        self.output  = torch.nn.Sigmoid()
        self.act_out = act_out
        self.batch_size = batch_size
        self.n_epochs     = n_epochs
        self.opt     = torch.optim.Adam(self.model.parameters(),lr=0.0002)
        self.dataloader = DataLoader(in_data, batch_size=self.batch_size, shuffle=True)#,pin_memory=True)
        
    def forward(self, x):
        x = self.model(x)        
        if self.act_out:
            x = self.output(x)
        return x
    
    def training_mse_epoch(self):
        running_loss = 0.0
        updates=0
        for batch_idx, (x, y) in enumerate(self.dataloader):
            self.opt.zero_grad()
            #x     = x.cuda(); y = y.cuda()
            #x = x.reshape((self.batch_size,1))
            x = x.reshape((len(x),1))
            x_out = self.forward(x)
            loss  = self.loss(x_out.flatten(), y.flatten())
            loss.backward()
            self.opt.step()
            running_loss += loss 
            updates = updates+1
        return running_loss/updates

    def training_mse(self):
        for epoch in range(self.n_epochs):
            self.model.train(True)
            loss_train = self.training_mse_epoch()

            #self.model2.train(False)
            #loss_valid = self.validate_mse_epoch(self.model2,self.val_dataloader_mse())
            if epoch % 10 == 0:
                print('Epoch: {} LOSS train: {} '.format(epoch,loss_train))


def makeDataSet(iD,iNS,iNB):
    #1 is going to be sideband
    #2 is going to be mass range
    pD =  torch.from_numpy(iD.T)
    pXD =  pD[torch.randperm(len(pD))]
    tot=pXD[:,2]
    tot=tot.float()
    label=pXD[:,0]
    label=label.float()
    datatrain=DataSet(samples=tot,labels=label)
    return datatrain,pD[0:iNS].float(),pD[iNS-1:-1].float()



In [211]:
#import torchsort

def differentiable_histogram(x, weights, bins=3, min=-3.0, max=3.0):

    #if len(x.shape) == 4:
    #    n_samples, n_chns, _, _ = x.shape
    #elif len(x.shape) == 2:
    #    n_samples, n_chns = 1, 1
    #else:
    #    raise AssertionError('The dimension of input tensor should be 2 or 4.')
    n_samples, n_chns = 1, 1
    hist_torch = torch.zeros(n_samples, n_chns, bins).to(x.device)
    delta = (max - min) / bins

    BIN_Table = torch.range(start=0, end=bins, step=1) * delta

    for dim in range(1, bins-1, 1):
        h_r = BIN_Table[dim].item()             # h_r
        h_r_sub_1 = BIN_Table[dim - 1].item()   # h_(r-1)
        h_r_plus_1 = BIN_Table[dim + 1].item()  # h_(r+1)

        mask_sub = ((h_r > x) & (x >= h_r_sub_1)).float()
        mask_plus = ((h_r_plus_1 > x) & (x >= h_r)).float()
        print("mask",mask_sub.shape,"hr",h_r,"hrsub",h_r_sub_1)
        print("val 0", (x - h_r_sub_1)* mask_sub)
        print("val 1", (h_r_plus_1 - x)*mask_plus)
        hist_torch[:, :, dim] += torch.sum(((x - h_r_sub_1) * mask_sub).view(n_samples, n_chns, -1), dim=-1)
        hist_torch[:, :, dim] += torch.sum(((h_r_plus_1 - x) * mask_plus).view(n_samples, n_chns, -1), dim=-1)
        print("dim",dim)
        print("hist torch",hist_torch.shape)
        print("diff:",x - h_r_sub_1)

    return hist_torch / delta

class sigLoss(torch.nn.Module):

    def __init__(self, sort_tolerance=1.0,sort_reg='l2'):
        super(sigLoss, self).__init__()
        self.tolerance = sort_tolerance
        self.reg       = sort_reg

    def forward(self, x, y):
        #loss = (1.0 - (torch.tanh (alpha * (x[1:] - x[:-1])) * torch.tanh (alpha * (y[1:] - y[:1])))).mean()
        #xsort=torchsort.soft_rank(x.reshape(1,-1))
        #xsort=torch.sort(x.reshape(1,-1))
        #xsort=torch.topk(x.reshape(1,-1), 25, dim=-1)
        #cut=xsort.values
        weight=0.5*(torch.tanh (5.0*(x-0.5))+1.0)
        #print("!",x,"W",weight)
        #sig = s/sqrt(s+b) 
        #yhist, bin_edges   = torch.histogram(y, bins=3,range=[-3,3],weight=weight,density=True)
        #ynorm   = torch.sum(weight*(y > -1)*(y < 1))
        #yvalues = (weight*(y > -1)*(y < 1)-0.5*weight*(y < -1)-0.5*weight*(y > 1))#
        #loss=10.0-torch.sum(yvalues)/torch.sqrt(ynorm)
        yhist = differentiable_histogram(y,weight)
        print(yhist)
        loss=yhist[1]-0.5*(yhist[0]+yhist[2])
        return loss

    def significance(self,iX,iY):
        print(iX,iY)
        NLL = deltaNLL(iX,iY,False)
        return NLL

val=torch.tensor(0.3)
print(0.5*(torch.tanh(10*(val-0.5))+1.0))

tensor(0.0180)


In [212]:
def cutval(iData, p=0.9):
    datasort=np.sort(iData)
    ndata=len(datasort)
    return datasort[int(ndata*p)]

def plotHistComp(iGSig,iGBkg,iSig,iBkg,iMin=-3,iMax=3):
    ns=len(iSig)
    nb=len(iBkg)
    ngs=len(iGSig)
    ngb=len(iGBkg)
    print(ns,nb,ngs,ngb)
    ys, bin_edges = np.histogram(iSig,density=True,bins=20,range=[iMin,iMax])#bins=bin_edges
    yb, bin_edges = np.histogram(iBkg, bins=bin_edges,density=True)
    ygs,bin_edges = np.histogram(iGSig, bins=bin_edges,density=True)
    ygb,bin_edges = np.histogram(iGBkg, bins=bin_edges,density=True)
    ygs*=len(iGSig)/len(iSig)
    ygb*=len(iGBkg)/len(iBkg)
    bin_centers = 0.5*(bin_edges[1:] + bin_edges[:-1])
    plt.errorbar(bin_centers,yb,yerr=(nb*yb)**0.5/nb,marker='.',linestyle = '-', color = 'red',label='bkg')
    plt.errorbar(bin_centers,ys,yerr=(ns*ys)**0.5/ns,marker='.',linestyle = '-', color = 'blue',label='signal')
    plt.errorbar(bin_centers,ygb,yerr=(ngb/nb)*((ngb*ygb)**0.5)/ngb,marker='.',linestyle = '--', color = 'red',label='bkg(no cuts)')
    plt.errorbar(bin_centers,ygs,yerr=(ngs/ns)*((ngs*ygs)**0.5)/ngs,marker='.',linestyle = '--', color = 'blue',label='signal(no cuts)')
    plt.xlabel("x")
    plt.ylabel("Normalized")
    plt.legend()
    plt.show()

pdata = np.hstack((sig,bkg))
data,test_sig,test_bkg=makeDataSet(pdata,len(sig[0]),len(bkg[0]))
rw_model = simple_MLPFit_onelayer(data,1,out_channels=1,act_out=True,batchnorm=False)
rw_model.training_mse()
output_sig=rw_model.forward(test_sig[:,2].reshape(len(test_sig),1))
output_bkg=rw_model.forward(test_bkg[:,2].reshape(len(test_bkg),1))
output=torch.cat((output_sig,output_bkg))
cut=cutval(output.flatten().detach().numpy(),p=0.9)
_,bins,_=plt.hist(output_sig.flatten().detach().numpy(),density=True,alpha=0.5,label='sig')
plt.hist(output_bkg.flatten().detach().numpy(),density=True,alpha=0.5,label='bkg',bins=bins)
plt.legend()
plt.show()
cutsig = output_sig.flatten().detach().numpy() > cut
cutbkg = output_bkg.flatten().detach().numpy() > cut
plotHistComp(test_sig[:,0],test_bkg[:,0],test_sig[cutsig][:,0],test_bkg[cutbkg][:,0])
plotHistComp(test_sig[:,2],test_bkg[:,2],test_sig[cutsig][:,2],test_bkg[cutbkg][:,2],0,1)



mask torch.Size([500]) hr 2.0 hrsub 0.0
val 0 tensor([0.8441, 0.0872, -0.0000, 1.4599, -0.0000, -0.0000, 0.0000, -0.0000, 0.9452,
        -0.0000, 0.0000, -0.0000, -0.0000, -0.0000, 1.8578, -0.0000, -0.0000, 0.2312,
        -0.0000, 1.5823, -0.0000, 0.0000, -0.0000, 1.7369, 0.0998, -0.0000, -0.0000,
        0.0000, 1.4957, 1.3157, -0.0000, -0.0000, 1.9034, -0.0000, -0.0000, 0.9324,
        -0.0000, 0.0000, -0.0000, -0.0000, 0.0877, 0.0000, 1.8359, 0.3853, 0.0000,
        -0.0000, 0.0000, 0.0000, 1.3902, 0.9857, 0.5066, -0.0000, -0.0000, -0.0000,
        0.0000, 0.0000, 0.0456, 0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
        1.4303, -0.0000, -0.0000, 1.0355, -0.0000, 0.4813, 1.1978, 0.0000, 0.7552,
        1.6093, -0.0000, 0.0000, 0.0000, 0.1743, -0.0000, 0.0000, -0.0000, 0.0000,
        -0.0000, -0.0000, 0.0000, 0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
        0.1557, 1.4426, 0.8321, 1.4240, 0.0000, -0.0000, 0.3013, 1.1230, 1.8218,
        -0.0000, 1.3551, -0.0

  BIN_Table = torch.range(start=0, end=bins, step=1) * delta


IndexError: index 1 is out of bounds for dimension 0 with size 1