In [1]:
import torch
import math
from torch import nn
import torch.nn.functional as F

In [2]:
#Reward.py
#Define the reward function r(s,a) when a is not a special action END

#number of SNP loci considered per chromosome
chrom_width = 50
#number of chromosomes per profile (44 normal chromosomes in the human genome)
num_chromosome = 44


#normalisation constant: chosen so that normal_const*(\sum_i a_i) < 1,
#i.e. the probabilities of all CNVs sum to a real value smaller than 1
normal_const = 1e-5
#probability of a single-locus gain/loss
single_loci_loss = normal_const*(1-2e-1)
#probability of a whole genome doubling (WGD)
WGD = normal_const*0.6

#constants of the focal-CNV distribution p(a);
#Reward() computes log(const1/(const2+End-Start))
const1 = single_loci_loss
const2 = 1

#probability of a whole-chromosome change
Whole_Chromosome_CNV = single_loci_loss/4
#probability of a chromosome-arm (half chromosome) change
Half_Chromosome_CNV = single_loci_loss/3

#divisor used by the rule masks: ceil(|difference|/max_copy) turns a
#copy-number difference into a 0/1 indicator
max_copy = 20

def Reward(Start,End):
    """Return the log-probability reward r(s,a) of a CNV from Start to End.

    Start, End: tensors indexed by sample along dimension 0 (assumed to be
        1-D with one locus index per sample -- TODO confirm with callers).
    Returns a float32 tensor. The default value is
    log(const1/(const2+End-Start)); it is replaced by
    log(Whole_Chromosome_CNV) when the CNV covers the whole chromosome and by
    log(Half_Chromosome_CNV) when it covers (roughly) one chromosome arm.
    """
    Start=Start.to(torch.float32)
    End=End.to(torch.float32)
    #focal CNV: the longer the segment, the smaller the reward
    reward=torch.log(const1/(const2+End-Start))
    log_whole=math.log(Whole_Chromosome_CNV)
    log_arm=math.log(Half_Chromosome_CNV)
    middle=chrom_width//2
    for i in range(Start.shape[0]):
        begin,stop=Start[i],End[i]
        #arm ending at the last locus and starting near the middle
        right_arm=chrom_width-stop<0.5 and abs(middle-begin)<1.5
        #arm starting at the first locus and ending near the middle
        left_arm=begin<0.5 and abs(middle-stop)<1.5
        #arm-level changes take precedence over the whole-chromosome rule
        if right_arm or left_arm:
            reward[i]=log_arm
        elif stop-begin>chrom_width-0.5:
            reward[i]=log_whole
    return reward



#Q-function.py
#defining the Q-function 
#Q(s,a) in manuscript


#switch structure mentioned in section 3.3.4
#number of output channels of the three convolution layers of WGD_Net
#(these are channel counts; the kernel widths are set inside WGD_Net)
nkernels_switch = [20,40,80]
#activation applied after every convolution of WGD_Net
#NOTE: the name is misspelled ("activatiion"); WGD_Net refers to it under this name
activatiion_wgd=torch.tanh
class WGD_Net(nn.Module):
    """Switch network: maps a copy-number profile to a value in (0,1).

    The result (called sigma by the callers) weights the WGD branch against
    the no-WGD branch of the other networks.
    """

    def __init__(self):
        super().__init__()
        channels=nkernels_switch
        #chromosome permutation invariant structure as described in section 3.3.3:
        #every kernel has height 1 and stride 1 along the chromosome axis,
        #so the same filter is applied to all chromosomes in the same fashion
        self.conv1=nn.Conv2d(1,channels[0],kernel_size=(1,3),stride=(1,1),padding=(0,1))
        self.conv2=nn.Conv2d(channels[0],channels[1],kernel_size=(1,3),stride=(1,1),padding=(0,1))
        self.conv3=nn.Conv2d(channels[1],channels[2],kernel_size=(1,5),stride=(1,1),padding=(0,0))
        self.linear=nn.Linear(channels[2],1)

    def forward(self, x):
        """Return sigmoid(prior/2+learned/2) with shape (batch,1).

        x is reshaped to (batch,1,num_chromosome,chrom_width) before the
        convolutions.
        """
        batch=x.shape[0]
        #empirical estimate from the mean copy number; no gradient flows through it
        prior=20*(x.mean((1,2,3))-1.5)
        prior=prior.reshape(batch,1).detach()
        feat=x.reshape(batch,1,num_chromosome,chrom_width)
        feat=F.max_pool2d(activatiion_wgd(self.conv1(feat)),kernel_size=(1,5),stride=(1,5),padding=(0,0))
        feat=F.max_pool2d(activatiion_wgd(self.conv2(feat)),kernel_size=(1,2),stride=(1,2),padding=(0,0))
        #summing over chromosomes and positions keeps the permutation invariance
        feat=activatiion_wgd(self.conv3(feat)).sum((2,3))
        learned=self.linear(feat)
        learned=learned/2
        #residual representation as described in section 3.3.4
        logit=prior/2+learned
        return torch.sigmoid(logit)

#chromosome evaluation net
#Used in Chrom_NN (which is Q_{phi_1}(s,c) in section 3.3.1)
#number of output channels of the first three convolution layers of CNP_Val
#(channel counts, not kernel widths)
nkernels_chr = [80,120,160]
#activation applied after the first three convolutions of CNP_Val
activation_cnp=torch.tanh
class CNP_Val(nn.Module):
    """Per-chromosome evaluation network used by Chrom_NN.

    All kernels have height 1, so every chromosome (row) is processed by the
    same filters independently of the other chromosomes.
    """

    def __init__(self):
        super().__init__()
        channels=nkernels_chr
        self.conv1=nn.Conv2d(1,channels[0],kernel_size=(1,5),stride=(1,1),padding=(0,2))
        self.conv2=nn.Conv2d(channels[0],channels[1],kernel_size=(1,3),stride=(1,1),padding=(0,1))
        self.conv3=nn.Conv2d(channels[1],channels[2],kernel_size=(1,3),stride=(1,1),padding=(0,1))
        self.conv4=nn.Conv2d(channels[2],1,kernel_size=(1,5),stride=(1,1),padding=(0,0))

    def forward(self, x):
        """Return one value per chromosome, shape (batch,num_chromosome).

        x is expected to have a single input channel (conv1 takes 1 channel);
        presumably (batch,1,num_chromosome,chrom_width) -- TODO confirm.
        """
        h=F.max_pool2d(activation_cnp(self.conv1(x)),kernel_size=(1,3),stride=(1,3),padding=(0,1))
        h=F.max_pool2d(activation_cnp(self.conv2(h)),kernel_size=(1,2),stride=(1,2),padding=(0,1))
        h=F.max_pool2d(activation_cnp(self.conv3(h)),kernel_size=(1,2),stride=(1,2),padding=(0,1))
        #elu with alpha=0.25 is bounded below by -0.25, so the shifted value is
        #never negative; Chrom_NN negates it later (the quantity modelled is nonpositive)
        h=0.25+F.elu(self.conv4(h),0.25)
        #number of samples * number of chromosomes
        return h.reshape(h.shape[0],num_chromosome)

#Implements Q_{phi_1}(s,c) in section 3.3.1
#It combines two chromosome evaluation nets (CNP_Val),
#with the switch structure in section 3.3.4 to form Q_{phi_1}(s,c)
class Chrom_NN(nn.Module):
    """Chromosome-level Q-value: one value per chromosome of the profile."""

    def __init__(self):
        super().__init__()
        #evaluation net for profiles without WGD
        self.Val_noWGD=CNP_Val()
        #evaluation net for profiles with WGD
        self.Val_WGD=CNP_Val()

    def forward(self,x,sigma):
        """Return Q_{phi_1}(s,c) with shape (batch,num_chromosome).

        x: copy-number profile, indexed as (batch,channel,chromosome,locus).
        sigma: WGD probability from the switch, expanded from (batch,1).
        """
        batch=x.shape[0]
        log_single=math.log(single_loci_loss)
        #probability for WGD, which is computed by the switch structure
        sigma=sigma.expand(-1,num_chromosome)

        #---- branch without WGD ----
        #y is 1 for a chromosome containing any locus with copy number other
        #than 1 and 0 otherwise (as long as the mean deviation is at most max_copy)
        y=torch.ceil((torch.abs(x-1)).mean(3)/max_copy)
        y=y.reshape(batch,num_chromosome).detach()
        #residual representation mentioned in section 3.3.4:
        #network output for abnormal chromosomes plus an empirical term
        net_no=self.Val_noWGD.forward(x)
        pooled_no=((y*net_no).sum(1)).reshape(batch,1).expand(-1,num_chromosome)
        #chromosomes with all loci at 1 copy need no CNV and get the empirical term
        Val_no=-(1-y)*log_single*2+pooled_no

        #---- branch with WGD ----
        #z plays the role of y: it is 1 for chromosomes containing odd copy numbers
        z=torch.ceil((torch.abs(x-2*(x//2))).mean(3)/max_copy)
        z=z.reshape(batch,num_chromosome).detach()
        net_wgd=self.Val_WGD.forward(x)
        pooled_wgd=(z*net_wgd).sum(1).reshape(batch,1).expand(-1,num_chromosome)
        Val_wgd=-(1-z)*log_single*2+pooled_wgd-math.log(WGD)

        #combine the two branches with the switch as defined in section 3.3.4
        mixed=sigma*Val_wgd+(1-sigma)*Val_no
        return -mixed



#starting point and gain or loss (defined as sp in manuscript) 
#Used in CNV_NN (which is Q_{phi_2}(s,c,sp) on section 3.3.1)
#number of output channels of the four convolution layers of CNV_Val
#(channel counts, not kernel widths)
nkernels_CNV = [80,120,160,10]
#activation applied after every convolution of CNV_Val
activation_cnv=torch.tanh
class CNV_Val(nn.Module):
    """Scores every (starting locus, gain/loss) choice on one chromosome."""

    def __init__(self):
        super().__init__()
        channels=nkernels_CNV
        self.conv1=nn.Conv2d(1,channels[0],kernel_size=(1,7),stride=(1,1),padding=(0,3))
        self.conv2=nn.Conv2d(channels[0],channels[1],kernel_size=(1,7),stride=(1,1),padding=(0,3))
        self.conv3=nn.Conv2d(channels[1],channels[2],kernel_size=(1,7),stride=(1,1),padding=(0,3))
        self.conv4=nn.Conv2d(channels[2],channels[3],kernel_size=(1,7),stride=(1,1),padding=(0,3))
        self.linear=nn.Linear(channels[3]*chrom_width,2*chrom_width-1)

    def forward(self,x):
        """Return a tensor of shape (batch,2*chrom_width-1).

        Only 2*chrom_width-1 values are produced: the overall offset of these
        outputs is arbitrary because of the partitioning, so one of the
        2*chrom_width entries is fixed by the callers.
        """
        h=x
        for conv in (self.conv1,self.conv2,self.conv3,self.conv4):
            h=activation_cnv(conv(h))
        #flatten channels and positions before the fully connected layer
        flat=h.reshape(h.shape[0],nkernels_CNV[3]*chrom_width)
        return self.linear(flat)

#Implements Q_{phi_2}(s,c,sp) in section 3.3.1
#It combines two CNV_Val nets,
#with the switch structure in section 3.3.4 to form Q_{phi_2}(s,c,sp)
class CNV_NN(nn.Module):
    """Q-value of the starting point and direction (sp) of a CNV on one chromosome.

    The 2*chrom_width choices are laid out as index 2*locus+direction, where
    direction 0 removes one copy and direction 1 adds one copy. The value of
    index 0 is fixed, so the networks only output the remaining
    2*chrom_width-1 entries.
    """

    def __init__(self):
        super(CNV_NN,self).__init__()
        #two network setting: one net without WGD, one with WGD
        self.CNV_noWGD=CNV_Val()
        self.CNV_WGD=CNV_Val()

    def forward(self,x,sigma):
        """Return Q_{phi_2} with shape (batch,2*chrom_width-1).

        x: copy numbers of the selected chromosome, shape (batch,chrom_width).
        sigma: WGD probability from the switch, broadcast against the output.
        """
        batch=x.shape[0]
        log_single=math.log(single_loci_loss)
        net_input=x.reshape(batch,1,1,chrom_width)

        #as in section 3.3.4: y is the empirical estimation,
        #the network output is the residual part
        #slot 0: copies above 1 (removing helps), slot 1: copies below 1 (adding helps)
        y=torch.stack((F.relu(x-1),F.relu(1-x)),2)
        y=y.reshape(batch,2*chrom_width)
        #express every entry relative to entry 0, which is the fixed one
        y=y[:,1:]-y[:,0:1]
        y=-y.detach()*log_single
        Val_no=y+self.CNV_noWGD.forward(net_input)

        #WGD branch: the empirical term marks loci with an odd copy number
        z=((torch.abs(x-2*(x//2))).reshape(batch,chrom_width,1)).expand(-1,-1,2)
        z=z.reshape(batch,2*chrom_width)
        z=z[:,1:]-z[:,0:1]
        z=-z.detach()*log_single
        Val_wgd=z+self.CNV_WGD.forward(net_input)

        #switch
        return sigma*Val_wgd+(1-sigma)*Val_no

    def find_one_cnv(self,chrom,sigma):
        """Return the best (start, direction) index for the FIRST row of chrom.

        Used for finding the CNV during deconvolution; it is not used in the
        training process. The returned int is 2*locus+direction.
        """
        batch=chrom.shape[0]
        res_cnv=self.forward(chrom,sigma)
        #rule system: a CNV may only start where the profile already has a break point
        break_start=torch.zeros(batch,chrom_width,2,requires_grad=False)
        #chrom shifted right by one locus, so that chrom-chrom_shift is the jump at each locus
        chrom_shift=torch.zeros(batch,chrom_width,requires_grad=False)
        chrom_shift[:,1:]=chrom[:,:(chrom_width-1)]
        #allow adding one copy at every break point
        break_start[:,:,1]=torch.ceil(torch.abs(chrom-chrom_shift)/max_copy)
        #always allow starting at the first locus of the chromosome
        break_start[:,0,1]=1
        #removing one copy is additionally forbidden where the copy number is 1 or less
        break_start[:,:,0]=break_start[:,:,1]
        break_start[:,:,0]=break_start[:,:,0]*torch.ceil((chrom/2-0.5)/max_copy)
        break_start=break_start.reshape(batch,2*chrom_width)
        #entry 0 is the fixed reference value 0, the rest comes from the network
        res_cnv_full=torch.zeros(batch,2*chrom_width)
        res_cnv_full[:,1:]=res_cnv
        #log(0)=-inf removes forbidden choices from the maximum
        res_cnv_full=res_cnv_full+torch.log(break_start)
        #best cnv according to the current Q
        cnv_max_val,cnv_max=torch.max(res_cnv_full,1)
        return int(cnv_max[0])
    


#end point
#Used in End_Point_NN (which is Q_{phi_3}(s,c,sp,ep) on section 3.3.1)
#number of output channels of the three convolution layers of End_Point_Val
#(channel counts, not kernel widths)
nkernels_End = [80,120,240]
#activation applied after every convolution of End_Point_Val
activation_end=torch.tanh
class End_Point_Val(nn.Module):
    """Scores every end point of a CNV from the chromosome before/after the start."""

    def __init__(self):
        super().__init__()
        channels=nkernels_End
        self.conv1=nn.Conv2d(2,channels[0],kernel_size=(1,7),stride=(1,1),padding=(0,3))
        self.conv2=nn.Conv2d(channels[0],channels[1],kernel_size=(1,7),stride=(1,1),padding=(0,3))
        self.conv3=nn.Conv2d(channels[1],channels[2],kernel_size=(1,7),stride=(1,1),padding=(0,3))
        self.linear=nn.Linear(channels[2]*chrom_width,chrom_width-1)

    def forward(self,old,new):
        """Return a tensor of shape (batch,chrom_width-1).

        old, new: two versions of the same chromosome, each (batch,chrom_width);
        they become the two input channels of the first convolution.
        Only chrom_width-1 values are produced: the overall offset of these
        outputs is arbitrary because of the partitioning, so one entry is
        fixed by the callers.
        """
        #every entry of the buffer is overwritten below
        pair=torch.Tensor(old.shape[0],2,1,chrom_width)
        pair[:,0,0]=old
        pair[:,1,0]=new
        h=pair.detach()
        for conv in (self.conv1,self.conv2,self.conv3):
            h=activation_end(conv(h))
        flat=h.reshape(h.shape[0],nkernels_End[2]*chrom_width)
        return self.linear(flat)
    
#Implements Q_{phi_3}(s,c,sp,ep) in section 3.3.1
#It combines two End_Point_Val nets,
#with the switch structure in section 3.3.4 to form Q_{phi_3}(s,c,sp,ep)
class End_Point_NN(nn.Module):
    """Q-value of the end point (ep) of a CNV, given chromosome, start and direction.

    Index k of the full score vector stands for the end locus k+1. The value
    of index 0 is fixed to 0, so the networks only output the remaining
    chrom_width-1 entries.
    """

    def __init__(self):
        super(End_Point_NN,self).__init__()
        #two network setting: one net without WGD, one with WGD
        self.Val_noWGD=End_Point_Val()
        self.Val_WGD=End_Point_Val()

    def forward(self,old,new,sigma):
        """Return Q_{phi_3} with shape (batch,chrom_width-1).

        old: the chromosome before the action, (batch,chrom_width).
        new: the chromosome with the change applied from the start locus to
            the end of the chromosome, (batch,chrom_width).
        sigma: WGD probability from the switch.
        """
        log_single=math.log(single_loci_loss)
        #empirical estimate without WGD, relative to entry 0 (the fixed entry)
        y=F.relu((old-1)*(old-new))
        y=y[:,1:chrom_width]-y[:,0:1].expand(-1,chrom_width-1)
        y=-y.detach()*log_single
        Val_no=self.Val_noWGD.forward(old,new)
        Val_no=Val_no+y

        #empirical estimate with WGD: loci that are odd in old and even in new
        z=(old-2*(old//2))*(1-(new-2*(new//2)))
        z=z[:,1:chrom_width]-z[:,0:1].expand(-1,chrom_width-1)
        z=-z.detach()*log_single
        Val_wgd=self.Val_WGD.forward(old,new)
        Val_wgd=Val_wgd+z
        #switch
        x=sigma*Val_wgd+(1-sigma)*Val_no
        return x

    def _break_end_mask(self,old):
        #0/1 mask of admissible end positions: a CNV may only end where old has a break point
        break_end=torch.zeros(old.shape[0],chrom_width,requires_grad=False)
        #old shifted left by one locus, so that old-chrom_shift is the jump after each locus
        chrom_shift=torch.zeros(old.shape[0],chrom_width,requires_grad=False)
        chrom_shift[:,:chrom_width-1]=old[:,1:]
        break_end[:,:]=torch.ceil(torch.abs(old-chrom_shift)/max_copy)
        #always allow ending at the last locus of the chromosome
        break_end[:,chrom_width-1]=1
        return break_end

    def _restrict_row(self,break_end,old,row,start,cnv):
        #apply, in place, the constraints that depend on the start and direction of one sample
        #can't end before the starting point
        break_end[row,:start]=0
        #when removing a copy, the segment must stop before the first locus
        #(after the start) whose copy number is 1 or less
        if cnv<0.5:
            j=start+1
            while j<chrom_width and not old[row][j]<1.5:
                j=j+1
            break_end[row,j:chrom_width]=0

    def _best_end_index(self,old,new,sigma,break_end):
        #argmax over the masked scores; log(0)=-inf removes forbidden end points
        res_end=self.forward(old,new,sigma)
        res_end_full=torch.zeros(old.shape[0],chrom_width)
        res_end_full[:,1:]=res_end
        res_end_full=res_end_full+torch.log(break_end)
        end_max_val,end_max=torch.max(res_end_full,1)
        return end_max

    def find_end(self,old,new,sigma,start_loci,cnv,valid):
        """Return the best end locus (index+1) of every sample as a tensor.

        Used for finding the end while loading (generating) training data.
        start_loci and cnv hold one entry per sample. valid is accepted for
        interface compatibility but is not used.
        """
        break_end=self._break_end_mask(old)
        for i in range(old.shape[0]):
            self._restrict_row(break_end,old,i,int(start_loci[i]),cnv[i])
        return self._best_end_index(old,new,sigma,break_end)+1

    def find_one_end(self,old,new,sigma,start,cnv):
        """Return the best end locus (index+1) of the FIRST sample as an int.

        Used for finding the end during deconvolution. start (int) and cnv
        describe the first sample only; the constraints are applied to row 0.
        """
        break_end=self._break_end_mask(old)
        self._restrict_row(break_end,old,0,start,cnv)
        end_max=self._best_end_index(old,new,sigma,break_end)
        return int(end_max[0])+1
        

#combine all separate networks
#add Rule system

#calculating the softmax
#prevent inf when taking log(exp(x))
#log_exp is always gonna be between 1 and the total number of elements
def Soft_update(val1,soft1,val2,soft2):
    """Merge two partial log-sum-exp results element-wise.

    Each pair (val, soft) represents the quantity soft*exp(val). The function
    returns (bias, log_exp) such that
        log_exp*exp(bias) == soft1*exp(val1)+soft2*exp(val2)
    with bias=max(val1,val2), so that only exponentials of nonpositive
    numbers are evaluated (no overflow when taking log(exp(x))).

    All four arguments are tensors of the same shape; none of them is modified.
    """
    #index with the boolean masks themselves: wrapping a mask in a list
    #(x[[mask]]) is a deprecated non-tuple sequence index in PyTorch
    first_larger=torch.ge(val1,val2)
    second_larger=torch.lt(val1,val2)
    #start from the first operand; where it is the larger one, bias is already correct
    bias=val1.clone()
    log_exp=soft1.clone()
    log_exp[first_larger]=soft1[first_larger]+soft2[first_larger]*torch.exp(val2[first_larger]-val1[first_larger])
    bias[second_larger]=val2[second_larger]
    log_exp[second_larger]=soft2[second_larger]+soft1[second_larger]*torch.exp(val1[second_larger]-val2[second_larger])
    return bias,log_exp



#Combine all the separate modules mentioned above
#Implementation of Q(s,a)
class Q_learning(nn.Module):
    """Full Q-function: switch + chromosome, start/direction and end-point networks.

    NOTE: forward() reads the module-level global counter_global, which must
    be defined before forward() is called.
    """
    def __init__(self):
        super(Q_learning,self).__init__()
        #switch: WGD probability sigma used to mix the WGD / no-WGD branches
        self.switch=WGD_Net()
        #the output refer to Q_{\phi_1}(s,c)
        self.Chrom_model=Chrom_NN()
        #the output refer to Q_{\phi_2}(s,c,sp)
        self.CNV=CNV_NN()
        #the output refer to Q_{\phi_3}(s,sp,c,ep)
        self.End=End_Point_NN()
    
    
    def forward(self,state,next_state,chrom,chrom_new,Chr,cnv,start_loci,end_loci,valid):
        '''
        Compute the final advantage (loss) used for training,
        i.e. the loss in Theorem 1.
        state: s in Q(s,a)
        next_state: s' in softmaxQ(s',a')
        Chr,cnv,end_loci: a in Q(s,a)
        chrom,chrom_new,start_loci,end_loci: intermediate results from s,a, which are preprocessed to make computation faster
            e.g. chrom is the CNP of the chromosome Chr (part of a) in the state (s)
            They can be seen as a mapping without parameters to learn: f(s,a)
        valid: one entry per sample, indicating if a training sample is valid (e.g. has non negative copy numbers for all loci);
            entries greater than 0.5 are treated as valid
        Returns (advantage, cnv_max, sigma, res_chrom, res_cnv_full, res_end_full):
            advantage: one value per sample, zeroed for invalid samples
            cnv_max: per sample, index of the best start/direction choice (2*locus+direction)
            sigma: output of the switch for state
            res_chrom, res_cnv_full, res_end_full: the (masked) Q-values of the three sub-decisions
        '''
        
        #computing softmaxQ(s',a')
        #It is a tradition in RL that gradient does not backpropogate through softmaxQ(s',a'), but only through Q(s,a) to make convergence faster
        #there is no theoritical guarantee behind, and it is only a practical trick
        sigma_next=self.switch(next_state)
        #Softmax returns (maximum, sum of exp(value-maximum)); their combination is the log-sum-exp
        x,y=self.Softmax(next_state,sigma_next)
        x=x+torch.log(y)
        #computing r(s,a)
        x=x+Reward(start_loci,end_loci)
        x=x.detach()
        
        #computing Q(s,a)
        sigma=self.switch.forward(state)
        #during the first 3e6 iterations no gradient flows from the Q-values into the switch
        #(counter_global is a module-level global)
        if counter_global<3e6:
            sigma=sigma.detach()
        #Q_{phi_1}(s,c)
        res_chrom=self.Chrom_model.forward(state,sigma)
        
        #Q_{phi_2}(s,c,sp)
        res_cnv=self.CNV.forward(chrom,sigma)
        #if there is originally a break point for start
        #real world constraint as described in section 3.3.2
        #only allow starting points (sp) to be the break points of CNP
        break_start=torch.zeros(state.shape[0],chrom_width,2,requires_grad=False)
        #chrom shifted right by one locus: chrom-chrom_shift is the jump at each locus
        chrom_shift=torch.zeros(state.shape[0],chrom_width,requires_grad=False)
        chrom_shift[:,1:]=chrom[:,:(chrom_width-1)]
        #allow adding one copy for every breakpoint
        break_start[:,:,1]=torch.ceil(torch.abs(chrom-chrom_shift)/max_copy)
        #always allow starting at the first locus of the chromosome
        break_start[:,0,1]=1
        #don't allow losing one copy where the copy number is 1 or less
        #(the factor ceil((chrom/2-0.5)/max_copy) is 0 there)
        break_start[:,:,0]=break_start[:,:,1]
        break_start[:,:,0]=break_start[:,:,0]*torch.ceil((chrom/2-0.5)/max_copy)
        break_start=break_start.reshape(state.shape[0],2*chrom_width)
        #entry 0 is the fixed reference value 0, the network provides the other entries
        res_cnv_full=torch.zeros(state.shape[0],2*chrom_width)
        res_cnv_full[:,1:]=res_cnv
        #log(0)=-inf excludes forbidden choices from the maximum and the softmax
        res_cnv_full=res_cnv_full+torch.log(break_start)
        
        #Q_{phi_2}(s,c,sp)-softmax(Q_{phi_2}(s,c,sp)) as described in section 3.3.1
        #the softmax is computed as maximum + log(sum(exp(value-maximum))) for numerical stability
        cnv_max_val,cnv_max=torch.max(res_cnv_full,1)
        cnv_softmax=res_cnv_full-cnv_max_val.reshape(state.shape[0],1).expand(-1,2*chrom_width)
        cnv_softmax=torch.exp(cnv_softmax).sum(1)
        x=x+cnv_max_val+torch.log(cnv_softmax)
        
        #Q_{phi_3}(s,c,sp,ep)
        res_end=self.End.forward(chrom,chrom_new,sigma)
        #if there is originally a break point for end
        #and if this is after the starting point
        #real world constraint in section 3.3.2
        break_end=torch.zeros(state.shape[0],chrom_width,requires_grad=False)
        #chrom shifted left by one locus: chrom-chrom_shift is the jump after each locus
        chrom_shift=torch.zeros(state.shape[0],chrom_width,requires_grad=False)
        chrom_shift[:,:(chrom_width-1)]=chrom[:,1:]
        #allow ending at every breakpoint
        break_end[:,:]=torch.ceil(torch.abs(chrom-chrom_shift)/max_copy)
        #always allow ending at the last locus of the chromosome
        break_end[:,chrom_width-1]=1
        for i in range(state.shape[0]):
            #can't end before starting point
            break_end[i,:int(start_loci[i])]=0*break_end[i,:int(start_loci[i])]
            #when losing a copy (cnv<0.5), the segment must stop before the first locus
            #after the start whose copy number is 1 or less
            if(cnv[i]<0.5):
                j=int(start_loci[i])+1
                while(j<chrom_width):
                    if(chrom[i][j]<1.5):
                        break
                    j=j+1
                break_end[i,j:chrom_width]=0*break_end[i,j:chrom_width]
            
        #entry 0 is the fixed reference value 0, the network provides the other entries
        res_end_full=torch.zeros(state.shape[0],chrom_width)
        res_end_full[:,1:]=res_end
        
        #real world constraint described in section 3.3.2
        res_end_full=res_end_full+torch.log(break_end)
        end_max_val,end_max_temp=torch.max(res_end_full,1)
        end_softmax=res_end_full-end_max_val.reshape(state.shape[0],1).expand(-1,chrom_width)
        end_softmax=torch.exp(end_softmax).sum(1)
        #Q_{phi_3}(s,c,sp,ep)-softmax(Q_{phi_3}(s,c,sp,ep)) as described in section 3.3.1
        x=x+end_max_val+torch.log(end_softmax)
        
        #subtract the Q-values of the action that was actually taken
        for i in range(state.shape[0]):
            if valid[i]>0.5:#check validity to prevent inf-inf which ends in nan
                x[i]=x[i]-res_chrom[i][int(Chr[i])]
                #start/direction index is 2*locus+direction
                cnv_rank=int(start_loci[i]*2+cnv[i])
                x[i]=x[i]-res_cnv_full[i][cnv_rank]
                #index k of res_end_full stands for end locus k+1
                end_rank=int(end_loci[i]-1)
                x[i]=x[i]-res_end_full[i][end_rank]
        
        #remove training data which include invalid actions
        x=x*valid
        #return advantage as well as a best cnv and sigma used for generating training data
        #used for training in the next step
        return x,cnv_max,sigma,res_chrom,res_cnv_full,res_end_full
     
    def Softmax(self,next_state,sigma):
        #compute softmax_{a'} Q(s',a')
        #returns (max_chrom, softmax_chrom), one entry per sample, such that
        #the softmax equals max_chrom+log(softmax_chrom)
        x=self.Chrom_model.forward(next_state,sigma)
        max_chrom=torch.max(x,1)[0]
        softmax_chrom=x-max_chrom.reshape(x.shape[0],1).expand(-1,num_chromosome)
        softmax_chrom=torch.exp(softmax_chrom).sum(1)
        #special action END
        #all the remaining abnormal loci are treated to be caused by several independent single locus copy number changes
        end_val=torch.sum(torch.abs(next_state-1),(1,2,3))*math.log(single_loci_loss)
        max_chrom,softmax_chrom=Soft_update(max_chrom,softmax_chrom,end_val,torch.ones(x.shape[0]))
        #if there is a WGD followed immediately
        for i in range(x.shape[0]):
            #real world constraint as described in section 3.3.2
            #do not allow (reversing) WGD when the CNP contain odd numbers for some loci
            #(and require at least one locus with a nonzero copy number)
            if (not torch.any(next_state[i]-2*torch.floor(next_state[i]/2)>0.5)) and torch.any(next_state[i]>0.5):
                #evaluate the halved profile recursively and merge its softmax into the result
                sigma_wgd=self.switch(torch.floor(next_state[i:(i+1)]/2))
                sigma_wgd=sigma_wgd.detach()
                wgd_val,wgd_soft=self.Softmax(torch.floor(next_state[i:(i+1)]/2),sigma_wgd)
                max_chrom[i],softmax_chrom[i]=Soft_update(torch.ones(1)*max_chrom[i],torch.ones(1)*softmax_chrom[i],torch.ones(1)*wgd_val,torch.ones(1)*wgd_soft)
        
        return max_chrom,softmax_chrom
  

#Minimal example: builds a tiny batch and the full model
if __name__ == "__main__":
    #the string below holds older checks of the separate parts; it is never executed
    '''
    switch=WGD_Net()
    Chrom_model=Chrom_NN()
    print(Chrom_model)
    #test the structure of permutation invariant structure
    x=torch.ones(3,1,num_chromosome,50)
    x[0][0][0][0:50]=2
    x[2][0][1][0:50]=2
    prob=switch.forward(x)
    print(prob)
    res=Chrom_model.forward(x,prob)
    print(res)
    res=-float('inf')
    res=torch.LongTensor(3)
    
    print(torch.log(res.type(torch.DoubleTensor)))
    #CNV
    CNV=CNV_NN()
    res=CNV.forward(x[:,0,0,0:50],prob)
    print(CNV)
    print(res.shape)
    #END
    End=End_Point_NN()
    res=End.forward(x[:,0,0,0:50],x[:,0,0,0:50]+1,prob)
    print(End)
    print(res.shape)
    '''
    #inputs for a Q-learning check: three samples, all loci at one copy ...
    x=torch.ones(3,1,num_chromosome,50)
    y=torch.ones(3,1,num_chromosome,50)
    #... except chromosome 0 of sample 0 and chromosome 1 of sample 2, which have two copies
    x[0,0,0,0:50]=2
    x[2,0,1,0:50]=2
    #the action acts on chromosome 0 and spans it completely (start 0, end 50)
    chrom=x[:,0,0,:]
    chrom_new=y[:,0,0,:]
    Chr=torch.zeros(3)
    cnv=torch.ones(3)
    start_loci=torch.zeros(3)
    end_loci=torch.full((3,),50.0)
    valid=torch.ones(3)
    Q_model=Q_learning()
    #the forward/backward check itself is disabled:
    #res,cnv_max,sigma,t,t2,t3=Q_model.forward(x,y,chrom,chrom_new,Chr,cnv,start_loci,end_loci,valid)
    #print(res)
    #print(cnv_max)
    #loss=res.pow(2).mean()
    #print(loss)
    #loss.backward()
    #params = list(Q_model.parameters())
    #print(params[0].grad[0])
    #print(Q_model.switch.conv1.weight[0])

In [3]:
#Train_data.py
import torch
import math

#number of trajectories sampled and trained on at the same time
batch_size=30
#
#during training
#data are sampled backwards
#when step==0, it is the last step of the trajectory,
#and step is increased to make the CNP more complex
def Sample_train_data(first_step_flag=True,state=None,next_state=None,advantage=None,Chr=None,step=None,wgd=None,valid=None):
    """Sample one batch of (s,a) training pairs, continuing the previous batch.

    Data are sampled by self-play (similar to a machine playing a game
    against itself), so no real world data is needed during training, as long
    as the reward is similar to the real world probability. As in theorem 1,
    no specific distribution over (s,a) pairs is required; any distribution
    with broad support over all (s,a) will do the job.

    first_step_flag: if True, all other arguments are ignored and a fresh
        batch (all copy numbers 1, step 0) is created.
    state,next_state,advantage,Chr,step,wgd,valid: results of the previous
        call / training step; they are modified in place.
    Returns (state,next_state,chrom,chrom_new,Chr,cnv,start_loci,end_loci,
        wgd,step,advantage,valid).

    NOTE: reads the module-level globals batch_size and counter_global.
    """
    if first_step_flag:
        #The first sample
        state=torch.ones(batch_size,1,num_chromosome,chrom_width,requires_grad=False)
        next_state=torch.ones(batch_size,1,num_chromosome,chrom_width,requires_grad=False)
        Chr=torch.ones(batch_size,requires_grad=False).type(torch.LongTensor)
        step=torch.zeros(batch_size,requires_grad=False)
        advantage=torch.zeros(batch_size)
        wgd=torch.zeros(batch_size,requires_grad=False)
        valid=torch.ones(batch_size,requires_grad=False)

    #sample starting point, end point, gain or loss
    #because of the permutation invariant structure in section 3.3.3
    #it is not necessary to resample the chromosome everytime
    start_loci=torch.randint(high=chrom_width,size=(batch_size,),requires_grad=False)
    end_loci=torch.LongTensor(batch_size)
    cnv=torch.ones(batch_size,requires_grad=False)
    #chrom and chrom_new are filled row by row in the loop below
    chrom=torch.Tensor(batch_size,chrom_width)
    chrom_new=torch.Tensor(batch_size,chrom_width)
    #probability of NOT resetting a trajectory back to step=0 by chance;
    #it grows with counter_global, so later iterations explore longer trajectories
    step_prob=0.18+0.8/(1+math.exp(-1e-2*counter_global+2))
    for i in range(batch_size):
        #if the model is poorly trained until the current step
        #go back to the state 0
        #to ensure small error for short trajectories
        if(torch.rand(1)[0]>step_prob or torch.abs(advantage[i])>=30 or wgd.sum()>24 or step[i]>90):
            state[i]=torch.ones(1,num_chromosome,chrom_width,requires_grad=False)
            next_state[i]=torch.ones(1,num_chromosome,chrom_width,requires_grad=False)
            step[i]=0
            wgd[i]=0
        #if model is fully trained for the current step
        #and there is no invalid operations been sampled
        #go to next step
        elif(valid[i]>0 and torch.abs(advantage[i])<10):
            next_state[i]=state[i].clone()
            step[i]=step[i]+1
        #stay to further train the current step
        #or resample another action
        else:
            state[i]=next_state[i].clone()

        #reset advantage and valid after they have been checked
        advantage[i]=0
        valid[i]=1
        #end locus is exclusive and lies in [start+1, chrom_width]
        end_loci[i]=1+torch.randint(low=int(start_loci[i]),high=chrom_width,size=(1,))[0]
        #change the chromosone that CNV is on
        #otherwise, all CNV will be on the same chromosome
        Chr[i]=torch.randint(high=num_chromosome,size=(1,))[0]
        #with probability 0.2 sample a whole chromosomal change
        if torch.rand(1)[0]>0.8:
            start_loci[i]=0
            end_loci[i]=chrom_width
        #with probability 0.7 switch the direction of the change to cnv=0
        if torch.rand(1)[0]>0.3:
            cnv[i]=0
        #the probability to sample a WGD increases with the step
        prob_wgd=0.1/(1+math.exp(-step[i]+15))
        #starting to modify state and next state
        #extract preprocessing data
        #wgd: sampled by chance, or forced while fewer than 5 trajectories of the batch carry one
        if (torch.rand(1)[0]<prob_wgd and wgd[i]<1) or (sum(wgd)<5):
            wgd[i]=1
            state[i]=state[i]*2
            next_state[i]=next_state[i]*2
        #adding cnv effect
        #the direction of the change is flipped for trajectories with wgd
        if wgd[i]>0.5:
            cnv[i]=1-cnv[i]
        #state is next_state with the sampled change applied: cnv=1 removes a copy, cnv=0 adds one
        state[i][0][Chr[i]][(start_loci[i]):(end_loci[i])]=state[i][0][Chr[i]][(start_loci[i]):(end_loci[i])]-(cnv[i]-0.5)*2
        chrom[i]=state[i][0][Chr[i]][:]
        #reverse effect on chrom_new: undo the change from the start locus to the end of the chromosome
        chrom_new[i]=state[i][0][Chr[i]][:]
        chrom_new[i][(start_loci[i]):]=chrom_new[i][(start_loci[i]):]+(cnv[i]-0.5)*2
        #not going to negative values anywhere in the changed segment
        if(torch.any(state[i][0][Chr[i]][(start_loci[i]):(end_loci[i])]< -0.5)):
            valid[i]=0
        #not joining breakpoints: both ends of the segment must be break points of the new profile
        if(start_loci[i]>0.5 and torch.abs(chrom[i][start_loci[i]]-chrom[i][start_loci[i]-1])<0.5):
            valid[i]=0
        if(end_loci[i]<chrom_width-0.5 and torch.abs(chrom[i][end_loci[i]-1]-chrom[i][end_loci[i]])<0.5):
            valid[i]=0
        #a segment that had a copy removed must keep at least one copy at every locus
        if cnv[i]>0.5 and (torch.any(state[i][0][Chr[i]][(start_loci[i]):(end_loci[i])]< 0.5)):
            valid[i]=0
    return state,next_state,chrom,chrom_new,Chr,cnv,start_loci,end_loci,wgd,step,advantage,valid

def Modify_data(state,chrom,Chr,valid,cnv_max,model,sigma):
    """Build training data for the best action according to the current Q.

    cnv_max: per sample, the best start/direction index (2*locus+direction).
    model: object providing find_end(old,new,sigma,start_loci,cnv,valid),
        used to pick the end locus of every sample.
    Returns (state,next_state,chrom,chrom_new,cnv,start_loci,end_loci,advantage).

    For samples that are not valid, start_loci and cnv keep placeholder
    values; the placeholders have the correct tensor types and are meaningful
    values, so that no inf/nan is generated for them.
    """
    #place takers (random start, cnv=1), overwritten below for valid samples
    start_loci=torch.randint(high=chrom_width,size=(batch_size,),requires_grad=False)
    cnv=torch.ones(batch_size,requires_grad=False)
    next_state=state.clone()
    chrom_new=chrom.clone()
    advantage=torch.zeros(batch_size)
    for i in range(batch_size):
        #only deal with valid samples
        if not valid[i]>0.5:
            continue
        #decode the index 2*locus+direction
        best=cnv_max[i]
        start_loci[i]=best//2
        cnv[i]=best-start_loci[i]*2
        #update chrom_new: apply the change from the start locus to the end of the chromosome
        first=start_loci[i]
        chrom_new[i][first:]=chrom_new[i][first:]+(cnv[i]-0.5)*2

    end_loci=model.find_end(chrom,chrom_new,sigma,start_loci,cnv,valid)

    #apply the chosen action to the copy of the state
    for i in range(batch_size):
        row=next_state[i][0][Chr[i]]
        first,last=start_loci[i],end_loci[i]
        row[first:last]=row[first:last]+(cnv[i]-0.5)*2

    return state,next_state,chrom,chrom_new,cnv,start_loci,end_loci,advantage

#simulate data for testing
def Simulate_data(batch_size=15,Number_of_step=70):
    """Simulate copy-number profiles by applying random changes to normal profiles.

    batch_size: number of profiles to simulate.
    Number_of_step: number of sampling rounds; in every round one change is
        tried per profile, and kept only if it is valid.
    Returns state with shape (batch_size,1,num_chromosome,chrom_width).
    """
    state=torch.ones(batch_size,1,num_chromosome,chrom_width,requires_grad=False)
    next_state=torch.ones(batch_size,1,num_chromosome,chrom_width,requires_grad=False)
    Chr=torch.ones(batch_size,requires_grad=False).type(torch.LongTensor)
    step=torch.zeros(batch_size,requires_grad=False)
    wgd=torch.zeros(batch_size,requires_grad=False)
    valid=torch.ones(batch_size,requires_grad=False)

    #NOTE(review): start_loci and cnv are only initialised here and are not
    #resampled/reset inside the loop (unlike Sample_train_data, which draws
    #them for every batch); confirm that this persistence is intended
    start_loci=torch.randint(high=chrom_width,size=(batch_size,),requires_grad=False)
    end_loci=torch.LongTensor(batch_size)
    cnv=torch.ones(batch_size,requires_grad=False)
    #chrom and chrom_new are filled row by row in the loop below
    chrom=torch.Tensor(batch_size,chrom_width)
    chrom_new=torch.Tensor(batch_size,chrom_width)

    step_counter=0
    while(step_counter<Number_of_step):
        for i in range(batch_size):
            #reset valid after they have been checked
            valid[i]=1
            #end locus is exclusive and lies in [start+1, chrom_width]
            end_loci[i]=1+torch.randint(low=int(start_loci[i]),high=chrom_width,size=(1,))[0]
            #change the chromosone that CNV is on with probability 0.5
            if torch.rand(1)[0]>0.5:
                Chr[i]=torch.randint(high=num_chromosome,size=(1,))[0]
            #with probability 0.2 sample a whole chromosomal change
            if torch.rand(1)[0]>0.8:
                start_loci[i]=0
                end_loci[i]=chrom_width
            #with probability 0.3 switch the direction of the change to cnv=0
            if torch.rand(1)[0]>0.7:
                cnv[i]=0
            #modifying cnp
            #the probability to sample a WGD increases with the step
            prob_wgd=0.4/(1+math.exp(-step[i]+5))
            #wgd: at most one per profile
            if (torch.rand(1)[0]<prob_wgd and wgd[i]<1):
                wgd[i]=1
                state[i]=state[i]*2
                next_state[i]=next_state[i]*2
            #adding cnv effect
            #the direction of the change is flipped for profiles with wgd
            if wgd[i]>0.5:
                cnv[i]=1-cnv[i]
            #cnv=1 removes a copy from the segment, cnv=0 adds one
            state[i][0][Chr[i]][(start_loci[i]):(end_loci[i])]=state[i][0][Chr[i]][(start_loci[i]):(end_loci[i])]-(cnv[i]-0.5)*2
            chrom[i]=state[i][0][Chr[i]][:]
            #reverse effect on chrom_new
            chrom_new[i]=state[i][0][Chr[i]][:]
            chrom_new[i][(start_loci[i]):]=chrom_new[i][(start_loci[i]):]+(cnv[i]-0.5)*2
            #not going to negative values: check every locus of the changed
            #segment, not only its first locus
            if(torch.any(state[i][0][Chr[i]][(start_loci[i]):(end_loci[i])]< -0.5)):
                valid[i]=0
            #not joining breakpoints: both ends of the segment must be break points of the new profile
            if(start_loci[i]>0.5 and torch.abs(chrom[i][start_loci[i]]-chrom[i][start_loci[i]-1])<0.5):
                valid[i]=0
            if(end_loci[i]<chrom_width-0.5 and torch.abs(chrom[i][end_loci[i]-1]-chrom[i][end_loci[i]])<0.5):
                valid[i]=0
            #keep a valid change ...
            if valid[i]>0 :
                next_state[i]=state[i].clone()
                step[i]=step[i]+1
            #... and undo an invalid one, so that another change is tried in the next round
            else:
                state[i]=next_state[i].clone()
        step_counter=step_counter+1
    return state



In [None]:
#main.py
import torch
import math
import torch.optim as optim
#import Policy
#import Data_train
# Training driver: every iteration performs two optimisation steps on Q_model,
# one on the batch returned by Sample_train_data and one on the batch that
# Modify_data derives from it. Q_learning, Sample_train_data and Modify_data
# are defined in earlier cells of this file.

#setting up counter
counter_global=0
#Model
Q_model=Q_learning()
#Load initial data (default arguments)
state,next_state,chrom,chrom_new,Chr,cnv,start_loci,end_loci,wgd,step,advantage,valid=Sample_train_data()
#setting up optimizer
optimizer = optim.Adam(Q_model.parameters(), lr=1e-4,betas=(0.9, 0.99), eps=1e-06, weight_decay=1e-3)

#start training
while(counter_global< 3e8):
    counter_global=counter_global+1
    #load data, feeding the previous iteration's tensors back into the sampler
    state,next_state,chrom,chrom_new,Chr,cnv,start_loci,end_loci,wgd,step,advantage,valid=Sample_train_data(False,state,next_state,advantage,Chr,step,wgd,valid)
    #compute advantage
    optimizer.zero_grad()
    advantage,cnv_max,sigma,temp,t2,t3=Q_model.forward(state,next_state,chrom,chrom_new,Chr,cnv,start_loci,end_loci,valid)
    #compute loss: mean squared advantage
    loss=advantage.pow(2).mean()

    #for the first 3e6 iterations, add a 0.1-weighted cross-entropy term between
    #the switch output sigma and the wgd label, masked by valid
    if counter_global<3e6:
        #this overwrites the sigma returned by Q_model.forward above; the new
        #value is what gets passed to Modify_data below
        sigma=Q_model.switch.forward(state)
        #NOTE(review): torch.log gives -inf if sigma reaches exactly 0 or 1 --
        #confirm the switch output is bounded away from both
        loss+=0.1*((-wgd.view(-1,1)*torch.log(sigma)-(1-wgd.view(-1,1))*torch.log(1-sigma))*valid.view(-1,1)).mean()
    #train the model
    loss.backward()
    optimizer.step()
    #print(loss)

    #training with the best action
    #temp for the values both used in training and loading new data
    state_temp,next_state_temp,chrom,chrom_new,cnv,start_loci,end_loci,advantage_temp=Modify_data(state,chrom,Chr,valid,cnv_max,Q_model.End,sigma)
    #compute advantage
    optimizer.zero_grad()
    advantage,cnv_max,sigma,temp,temp2,temp3=Q_model.forward(state_temp,next_state_temp,chrom,chrom_new,Chr,cnv,start_loci,end_loci,valid)
    #compute loss
    loss=advantage.pow(2).mean()
    #print(loss)
    #every 10th iteration: report progress and checkpoint the weights
    #(the checkpoint is written before this iteration's second optimizer step)
    if(counter_global%10==9):
        print(loss,step.mean(),step.max(),wgd.sum())
        torch.save(Q_model.state_dict(),"/data/suzaku/ted/HOME/model")
        #print(temp[0])
    #train the model
    loss.backward()
    optimizer.step()

    #every 100th iteration: append a summary line to the log file
    if(counter_global%100==99):
        #the with-block closes the handle even if the write raises
        with open('/data/suzaku/ted/HOME/log', 'a') as file_object:
            file_object.write(f"loss:{loss.item()} step:{step.mean().item()}, {step.max().item()}, {wgd.sum().item()}\n")

tensor(91.8366, grad_fn=<MeanBackward0>) tensor(0.1000) tensor(1.) tensor(5.)
tensor(47.9540, grad_fn=<MeanBackward0>) tensor(0.3333) tensor(2.) tensor(5.)
tensor(24.8770, grad_fn=<MeanBackward0>) tensor(0.3333) tensor(2.) tensor(5.)
tensor(14.6047, grad_fn=<MeanBackward0>) tensor(0.2333) tensor(3.) tensor(5.)
tensor(15.9143, grad_fn=<MeanBackward0>) tensor(0.5333) tensor(3.) tensor(5.)
tensor(12.8373, grad_fn=<MeanBackward0>) tensor(0.3667) tensor(2.) tensor(5.)
tensor(9.9179, grad_fn=<MeanBackward0>) tensor(0.4000) tensor(2.) tensor(5.)
tensor(10.4999, grad_fn=<MeanBackward0>) tensor(0.4000) tensor(3.) tensor(5.)
tensor(10.8216, grad_fn=<MeanBackward0>) tensor(0.3667) tensor(2.) tensor(5.)
tensor(11.4125, grad_fn=<MeanBackward0>) tensor(0.5667) tensor(3.) tensor(5.)
tensor(11.6870, grad_fn=<MeanBackward0>) tensor(0.3667) tensor(4.) tensor(5.)
tensor(13.8307, grad_fn=<MeanBackward0>) tensor(0.5000) tensor(5.) tensor(5.)
tensor(15.2095, grad_fn=<MeanBackward0>) tensor(0.7667) tensor(5.

tensor(70.7433, grad_fn=<MeanBackward0>) tensor(16.7000) tensor(48.) tensor(12.)
tensor(104.0608, grad_fn=<MeanBackward0>) tensor(18.5667) tensor(56.) tensor(15.)
tensor(109.2484, grad_fn=<MeanBackward0>) tensor(19.5667) tensor(64.) tensor(15.)
tensor(105.2842, grad_fn=<MeanBackward0>) tensor(17.4333) tensor(60.) tensor(14.)
tensor(86.8160, grad_fn=<MeanBackward0>) tensor(15.9667) tensor(64.) tensor(12.)
tensor(75.8379, grad_fn=<MeanBackward0>) tensor(16.1333) tensor(71.) tensor(10.)
tensor(116.9907, grad_fn=<MeanBackward0>) tensor(16.3000) tensor(74.) tensor(16.)
tensor(117.8548, grad_fn=<MeanBackward0>) tensor(17.) tensor(80.) tensor(16.)
tensor(109.5750, grad_fn=<MeanBackward0>) tensor(19.0667) tensor(86.) tensor(19.)
tensor(110.2727, grad_fn=<MeanBackward0>) tensor(14.9667) tensor(25.) tensor(19.)
tensor(79.4746, grad_fn=<MeanBackward0>) tensor(17.0667) tensor(28.) tensor(19.)
tensor(69.5146, grad_fn=<MeanBackward0>) tensor(19.3667) tensor(32.) tensor(18.)
tensor(76.8879, grad_fn=<

tensor(98.2554, grad_fn=<MeanBackward0>) tensor(13.5333) tensor(27.) tensor(13.)
tensor(87.3943, grad_fn=<MeanBackward0>) tensor(13.0333) tensor(27.) tensor(12.)
tensor(101.2092, grad_fn=<MeanBackward0>) tensor(13.2333) tensor(27.) tensor(13.)
tensor(72.2121, grad_fn=<MeanBackward0>) tensor(13.4333) tensor(31.) tensor(14.)
tensor(47.6074, grad_fn=<MeanBackward0>) tensor(16.7000) tensor(35.) tensor(12.)
tensor(84.6623, grad_fn=<MeanBackward0>) tensor(18.4000) tensor(40.) tensor(16.)
tensor(106.6975, grad_fn=<MeanBackward0>) tensor(21.6000) tensor(46.) tensor(20.)
tensor(96.2251, grad_fn=<MeanBackward0>) tensor(21.0667) tensor(53.) tensor(17.)
tensor(106.7708, grad_fn=<MeanBackward0>) tensor(21.) tensor(59.) tensor(16.)
tensor(66.4063, grad_fn=<MeanBackward0>) tensor(18.6333) tensor(61.) tensor(12.)
tensor(84.5220, grad_fn=<MeanBackward0>) tensor(16.4333) tensor(62.) tensor(13.)
tensor(77.1094, grad_fn=<MeanBackward0>) tensor(15.5667) tensor(65.) tensor(9.)
tensor(84.5939, grad_fn=<MeanB

tensor(80.7185, grad_fn=<MeanBackward0>) tensor(19.6667) tensor(78.) tensor(15.)
tensor(62.0680, grad_fn=<MeanBackward0>) tensor(18.9333) tensor(80.) tensor(15.)
tensor(81.7259, grad_fn=<MeanBackward0>) tensor(20.5667) tensor(84.) tensor(17.)
tensor(74.0301, grad_fn=<MeanBackward0>) tensor(21.8333) tensor(86.) tensor(16.)
tensor(67.5408, grad_fn=<MeanBackward0>) tensor(20.1333) tensor(90.) tensor(16.)
tensor(96.2989, grad_fn=<MeanBackward0>) tensor(18.4333) tensor(45.) tensor(15.)
tensor(112.6129, grad_fn=<MeanBackward0>) tensor(18.6333) tensor(51.) tensor(19.)
tensor(134.3785, grad_fn=<MeanBackward0>) tensor(19.0333) tensor(60.) tensor(19.)
tensor(116.2442, grad_fn=<MeanBackward0>) tensor(18.7000) tensor(68.) tensor(16.)
tensor(95.7699, grad_fn=<MeanBackward0>) tensor(16.4000) tensor(70.) tensor(13.)
tensor(81.4867, grad_fn=<MeanBackward0>) tensor(17.4667) tensor(77.) tensor(12.)
tensor(85.0177, grad_fn=<MeanBackward0>) tensor(17.4000) tensor(79.) tensor(14.)
tensor(91.7914, grad_fn=<

tensor(106.6987, grad_fn=<MeanBackward0>) tensor(16.9667) tensor(89.) tensor(15.)
tensor(100.7183, grad_fn=<MeanBackward0>) tensor(14.5333) tensor(33.) tensor(14.)
tensor(100.7635, grad_fn=<MeanBackward0>) tensor(14.2667) tensor(35.) tensor(15.)
tensor(120.3371, grad_fn=<MeanBackward0>) tensor(14.9667) tensor(35.) tensor(16.)
tensor(84.7282, grad_fn=<MeanBackward0>) tensor(12.6667) tensor(35.) tensor(13.)
tensor(96.3692, grad_fn=<MeanBackward0>) tensor(13.5667) tensor(26.) tensor(14.)
tensor(62.4506, grad_fn=<MeanBackward0>) tensor(16.2000) tensor(31.) tensor(15.)
tensor(65.4244, grad_fn=<MeanBackward0>) tensor(15.4333) tensor(32.) tensor(14.)
tensor(43.9708, grad_fn=<MeanBackward0>) tensor(11.7000) tensor(35.) tensor(9.)
tensor(73.9106, grad_fn=<MeanBackward0>) tensor(13.6667) tensor(41.) tensor(9.)
tensor(108.7262, grad_fn=<MeanBackward0>) tensor(16.5000) tensor(47.) tensor(12.)
tensor(96.1750, grad_fn=<MeanBackward0>) tensor(15.9667) tensor(54.) tensor(13.)
tensor(70.5947, grad_fn=<

tensor(114.2871, grad_fn=<MeanBackward0>) tensor(17.6333) tensor(73.) tensor(17.)
tensor(98.5362, grad_fn=<MeanBackward0>) tensor(13.9000) tensor(37.) tensor(13.)
tensor(81.4990, grad_fn=<MeanBackward0>) tensor(12.2667) tensor(34.) tensor(12.)
tensor(90.3320, grad_fn=<MeanBackward0>) tensor(14.1000) tensor(34.) tensor(13.)
tensor(104.5066, grad_fn=<MeanBackward0>) tensor(14.5667) tensor(34.) tensor(15.)
tensor(124.8674, grad_fn=<MeanBackward0>) tensor(16.7667) tensor(34.) tensor(18.)
tensor(119.5256, grad_fn=<MeanBackward0>) tensor(15.) tensor(30.) tensor(18.)
tensor(88.5413, grad_fn=<MeanBackward0>) tensor(15.8000) tensor(30.) tensor(17.)
tensor(70.7461, grad_fn=<MeanBackward0>) tensor(17.9667) tensor(32.) tensor(17.)
tensor(69.8941, grad_fn=<MeanBackward0>) tensor(20.1667) tensor(38.) tensor(18.)
tensor(94.2449, grad_fn=<MeanBackward0>) tensor(21.2333) tensor(45.) tensor(21.)
tensor(96.5704, grad_fn=<MeanBackward0>) tensor(20.0333) tensor(54.) tensor(18.)
tensor(113.8525, grad_fn=<Me

tensor(67.6335, grad_fn=<MeanBackward0>) tensor(16.6000) tensor(42.) tensor(14.)
tensor(94.6922, grad_fn=<MeanBackward0>) tensor(18.0667) tensor(50.) tensor(15.)
tensor(117.1872, grad_fn=<MeanBackward0>) tensor(20.8667) tensor(58.) tensor(16.)
tensor(102.7715, grad_fn=<MeanBackward0>) tensor(20.4000) tensor(65.) tensor(16.)
tensor(96.8664, grad_fn=<MeanBackward0>) tensor(15.5667) tensor(52.) tensor(12.)
tensor(103.1394, grad_fn=<MeanBackward0>) tensor(16.5333) tensor(55.) tensor(14.)
tensor(100.0652, grad_fn=<MeanBackward0>) tensor(16.7000) tensor(57.) tensor(15.)
tensor(99.0097, grad_fn=<MeanBackward0>) tensor(13.6667) tensor(33.) tensor(15.)
tensor(102.9748, grad_fn=<MeanBackward0>) tensor(14.5000) tensor(33.) tensor(15.)
tensor(103.3876, grad_fn=<MeanBackward0>) tensor(13.9000) tensor(33.) tensor(15.)
tensor(106.9536, grad_fn=<MeanBackward0>) tensor(14.1667) tensor(29.) tensor(16.)
tensor(128.2899, grad_fn=<MeanBackward0>) tensor(13.8667) tensor(29.) tensor(14.)
tensor(94.3014, grad

tensor(115.7222, grad_fn=<MeanBackward0>) tensor(21.4333) tensor(62.) tensor(17.)
tensor(107.2598, grad_fn=<MeanBackward0>) tensor(16.0333) tensor(64.) tensor(13.)
tensor(118.3458, grad_fn=<MeanBackward0>) tensor(15.3333) tensor(58.) tensor(14.)
tensor(96.5721, grad_fn=<MeanBackward0>) tensor(17.2667) tensor(65.) tensor(14.)
tensor(106.2083, grad_fn=<MeanBackward0>) tensor(17.8000) tensor(70.) tensor(16.)
tensor(118.1947, grad_fn=<MeanBackward0>) tensor(17.9000) tensor(71.) tensor(18.)
tensor(113.0539, grad_fn=<MeanBackward0>) tensor(18.2000) tensor(75.) tensor(19.)
tensor(84.8883, grad_fn=<MeanBackward0>) tensor(14.5000) tensor(35.) tensor(15.)
tensor(78.2402, grad_fn=<MeanBackward0>) tensor(15.8333) tensor(44.) tensor(15.)
tensor(62.0781, grad_fn=<MeanBackward0>) tensor(16.3667) tensor(53.) tensor(14.)
tensor(71.0429, grad_fn=<MeanBackward0>) tensor(14.7667) tensor(31.) tensor(14.)
tensor(67.0351, grad_fn=<MeanBackward0>) tensor(13.6333) tensor(37.) tensor(12.)
tensor(73.8542, grad_f

tensor(108.0760, grad_fn=<MeanBackward0>) tensor(17.7000) tensor(57.) tensor(18.)
tensor(114.6763, grad_fn=<MeanBackward0>) tensor(19.3333) tensor(62.) tensor(20.)
tensor(116.2133, grad_fn=<MeanBackward0>) tensor(17.7333) tensor(65.) tensor(20.)
tensor(101.9022, grad_fn=<MeanBackward0>) tensor(17.7000) tensor(69.) tensor(18.)
tensor(71.0026, grad_fn=<MeanBackward0>) tensor(11.6333) tensor(27.) tensor(13.)
tensor(53.6083, grad_fn=<MeanBackward0>) tensor(11.9000) tensor(26.) tensor(10.)
tensor(68.7558, grad_fn=<MeanBackward0>) tensor(14.7000) tensor(32.) tensor(14.)
tensor(90.5522, grad_fn=<MeanBackward0>) tensor(16.6333) tensor(34.) tensor(17.)
tensor(111.8487, grad_fn=<MeanBackward0>) tensor(18.4000) tensor(39.) tensor(21.)
tensor(120.6988, grad_fn=<MeanBackward0>) tensor(18.0333) tensor(47.) tensor(18.)
tensor(91.8974, grad_fn=<MeanBackward0>) tensor(16.8333) tensor(55.) tensor(14.)
tensor(91.6650, grad_fn=<MeanBackward0>) tensor(19.0333) tensor(62.) tensor(14.)
tensor(112.2119, grad_

tensor(124.5750, grad_fn=<MeanBackward0>) tensor(20.2000) tensor(72.) tensor(20.)
tensor(113.6947, grad_fn=<MeanBackward0>) tensor(20.4000) tensor(75.) tensor(19.)
tensor(115.9801, grad_fn=<MeanBackward0>) tensor(19.9000) tensor(80.) tensor(20.)
tensor(87.2146, grad_fn=<MeanBackward0>) tensor(18.1000) tensor(84.) tensor(17.)
tensor(66.4580, grad_fn=<MeanBackward0>) tensor(15.5333) tensor(45.) tensor(14.)
tensor(71.5693, grad_fn=<MeanBackward0>) tensor(19.2667) tensor(49.) tensor(17.)
tensor(91.1358, grad_fn=<MeanBackward0>) tensor(19.7000) tensor(52.) tensor(20.)
tensor(89.3224, grad_fn=<MeanBackward0>) tensor(20.7333) tensor(58.) tensor(18.)
tensor(107.1436, grad_fn=<MeanBackward0>) tensor(20.5000) tensor(64.) tensor(17.)
tensor(135.1665, grad_fn=<MeanBackward0>) tensor(20.1667) tensor(72.) tensor(17.)
tensor(125.0833, grad_fn=<MeanBackward0>) tensor(21.5667) tensor(79.) tensor(17.)
tensor(110.3557, grad_fn=<MeanBackward0>) tensor(22.8333) tensor(86.) tensor(16.)
tensor(132.3097, grad

tensor(69.7166, grad_fn=<MeanBackward0>) tensor(14.7000) tensor(26.) tensor(16.)
tensor(69.2050, grad_fn=<MeanBackward0>) tensor(16.6667) tensor(30.) tensor(17.)
tensor(85.6175, grad_fn=<MeanBackward0>) tensor(17.9667) tensor(33.) tensor(17.)
tensor(99.9387, grad_fn=<MeanBackward0>) tensor(17.1000) tensor(40.) tensor(15.)
tensor(124.8911, grad_fn=<MeanBackward0>) tensor(18.1333) tensor(44.) tensor(19.)
tensor(106.1983, grad_fn=<MeanBackward0>) tensor(17.0667) tensor(47.) tensor(15.)
tensor(105.2703, grad_fn=<MeanBackward0>) tensor(16.2667) tensor(52.) tensor(15.)
tensor(100.2679, grad_fn=<MeanBackward0>) tensor(16.6333) tensor(47.) tensor(14.)
tensor(96.1537, grad_fn=<MeanBackward0>) tensor(16.2333) tensor(47.) tensor(17.)
tensor(86.2062, grad_fn=<MeanBackward0>) tensor(17.2667) tensor(47.) tensor(16.)
tensor(89.2387, grad_fn=<MeanBackward0>) tensor(18.3000) tensor(48.) tensor(18.)
tensor(77.2334, grad_fn=<MeanBackward0>) tensor(19.6333) tensor(51.) tensor(20.)
tensor(90.1819, grad_fn=

tensor(70.8729, grad_fn=<MeanBackward0>) tensor(17.4667) tensor(79.) tensor(17.)
tensor(67.8012, grad_fn=<MeanBackward0>) tensor(16.1000) tensor(27.) tensor(16.)
tensor(76.4482, grad_fn=<MeanBackward0>) tensor(18.4667) tensor(31.) tensor(17.)
tensor(96.1632, grad_fn=<MeanBackward0>) tensor(19.2000) tensor(35.) tensor(18.)
tensor(111.5399, grad_fn=<MeanBackward0>) tensor(20.9000) tensor(41.) tensor(18.)
tensor(146.3779, grad_fn=<MeanBackward0>) tensor(24.0667) tensor(49.) tensor(20.)
tensor(125.5864, grad_fn=<MeanBackward0>) tensor(22.2333) tensor(53.) tensor(17.)
tensor(95.2368, grad_fn=<MeanBackward0>) tensor(21.1000) tensor(57.) tensor(16.)
tensor(96.7455, grad_fn=<MeanBackward0>) tensor(17.6000) tensor(61.) tensor(15.)
tensor(81.0193, grad_fn=<MeanBackward0>) tensor(16.7000) tensor(65.) tensor(12.)
tensor(72.4535, grad_fn=<MeanBackward0>) tensor(15.6000) tensor(73.) tensor(11.)
tensor(62.8933, grad_fn=<MeanBackward0>) tensor(17.3333) tensor(77.) tensor(12.)
tensor(89.1570, grad_fn=<

tensor(105.5169, grad_fn=<MeanBackward0>) tensor(17.7667) tensor(65.) tensor(16.)
tensor(91.8557, grad_fn=<MeanBackward0>) tensor(18.2333) tensor(69.) tensor(15.)
tensor(111.4034, grad_fn=<MeanBackward0>) tensor(18.3667) tensor(65.) tensor(16.)
tensor(76.5720, grad_fn=<MeanBackward0>) tensor(14.7333) tensor(67.) tensor(12.)
tensor(102.8071, grad_fn=<MeanBackward0>) tensor(18.3667) tensor(72.) tensor(15.)
tensor(106.3714, grad_fn=<MeanBackward0>) tensor(16.8667) tensor(76.) tensor(17.)
tensor(85.9652, grad_fn=<MeanBackward0>) tensor(17.5000) tensor(80.) tensor(18.)
tensor(46.8112, grad_fn=<MeanBackward0>) tensor(18.1333) tensor(83.) tensor(13.)
tensor(53.3141, grad_fn=<MeanBackward0>) tensor(20.1667) tensor(86.) tensor(12.)
tensor(75.7405, grad_fn=<MeanBackward0>) tensor(18.6333) tensor(45.) tensor(13.)
tensor(80.3385, grad_fn=<MeanBackward0>) tensor(18.1667) tensor(52.) tensor(13.)
tensor(99.6888, grad_fn=<MeanBackward0>) tensor(21.5000) tensor(59.) tensor(16.)
tensor(109.7095, grad_fn

tensor(125.0352, grad_fn=<MeanBackward0>) tensor(17.6667) tensor(33.) tensor(20.)
tensor(79.4764, grad_fn=<MeanBackward0>) tensor(18.6667) tensor(41.) tensor(19.)
tensor(83.4039, grad_fn=<MeanBackward0>) tensor(23.6667) tensor(42.) tensor(22.)
tensor(81.7840, grad_fn=<MeanBackward0>) tensor(21.3667) tensor(43.) tensor(18.)
tensor(92.4358, grad_fn=<MeanBackward0>) tensor(20.9667) tensor(49.) tensor(15.)
tensor(107.0037, grad_fn=<MeanBackward0>) tensor(21.5000) tensor(52.) tensor(15.)
tensor(87.6659, grad_fn=<MeanBackward0>) tensor(18.5000) tensor(57.) tensor(13.)
tensor(102.7318, grad_fn=<MeanBackward0>) tensor(19.4333) tensor(60.) tensor(16.)
tensor(81.0822, grad_fn=<MeanBackward0>) tensor(20.7667) tensor(66.) tensor(14.)
tensor(98.2001, grad_fn=<MeanBackward0>) tensor(22.0333) tensor(73.) tensor(16.)
tensor(75.2510, grad_fn=<MeanBackward0>) tensor(16.3333) tensor(79.) tensor(13.)
tensor(84.4256, grad_fn=<MeanBackward0>) tensor(13.7333) tensor(39.) tensor(16.)
tensor(101.2049, grad_fn=

tensor(86.1960, grad_fn=<MeanBackward0>) tensor(23.8333) tensor(67.) tensor(14.)
tensor(95.9357, grad_fn=<MeanBackward0>) tensor(24.2667) tensor(71.) tensor(17.)
tensor(116.4273, grad_fn=<MeanBackward0>) tensor(21.7000) tensor(75.) tensor(18.)
tensor(130.6338, grad_fn=<MeanBackward0>) tensor(23.3333) tensor(81.) tensor(19.)
tensor(113.4358, grad_fn=<MeanBackward0>) tensor(22.8333) tensor(86.) tensor(19.)
tensor(106.3807, grad_fn=<MeanBackward0>) tensor(23.7333) tensor(90.) tensor(18.)
tensor(87.5961, grad_fn=<MeanBackward0>) tensor(18.1667) tensor(81.) tensor(15.)
tensor(77.7145, grad_fn=<MeanBackward0>) tensor(15.2333) tensor(84.) tensor(12.)
tensor(84.4709, grad_fn=<MeanBackward0>) tensor(15.6667) tensor(88.) tensor(13.)
tensor(86.7092, grad_fn=<MeanBackward0>) tensor(13.3333) tensor(34.) tensor(13.)
tensor(82.0430, grad_fn=<MeanBackward0>) tensor(12.8333) tensor(22.) tensor(12.)
tensor(86.1795, grad_fn=<MeanBackward0>) tensor(14.1000) tensor(27.) tensor(13.)
tensor(76.0531, grad_fn=

tensor(146.5210, grad_fn=<MeanBackward0>) tensor(24.1667) tensor(59.) tensor(21.)
tensor(131.8130, grad_fn=<MeanBackward0>) tensor(26.5667) tensor(65.) tensor(23.)
tensor(118.5720, grad_fn=<MeanBackward0>) tensor(26.9000) tensor(69.) tensor(21.)
tensor(109.2149, grad_fn=<MeanBackward0>) tensor(22.2000) tensor(73.) tensor(18.)
tensor(107.7699, grad_fn=<MeanBackward0>) tensor(19.5333) tensor(77.) tensor(18.)
tensor(110.0034, grad_fn=<MeanBackward0>) tensor(18.9000) tensor(82.) tensor(18.)
tensor(102.7507, grad_fn=<MeanBackward0>) tensor(19.8000) tensor(88.) tensor(16.)
tensor(89.0830, grad_fn=<MeanBackward0>) tensor(16.2000) tensor(77.) tensor(13.)
tensor(87.9150, grad_fn=<MeanBackward0>) tensor(16.3667) tensor(80.) tensor(14.)
tensor(54.3090, grad_fn=<MeanBackward0>) tensor(13.9000) tensor(83.) tensor(9.)
tensor(62.8088, grad_fn=<MeanBackward0>) tensor(16.7333) tensor(86.) tensor(9.)
tensor(82.5696, grad_fn=<MeanBackward0>) tensor(16.3667) tensor(89.) tensor(12.)
tensor(77.9462, grad_fn

tensor(97.5525, grad_fn=<MeanBackward0>) tensor(18.5000) tensor(59.) tensor(13.)
tensor(85.4678, grad_fn=<MeanBackward0>) tensor(19.6333) tensor(65.) tensor(13.)
tensor(117.4211, grad_fn=<MeanBackward0>) tensor(20.9667) tensor(68.) tensor(16.)
tensor(127.8938, grad_fn=<MeanBackward0>) tensor(20.2333) tensor(73.) tensor(19.)
tensor(133.6071, grad_fn=<MeanBackward0>) tensor(19.3333) tensor(80.) tensor(19.)
tensor(104.4967, grad_fn=<MeanBackward0>) tensor(16.0667) tensor(80.) tensor(17.)
tensor(90.1008, grad_fn=<MeanBackward0>) tensor(15.) tensor(34.) tensor(16.)
tensor(95.7104, grad_fn=<MeanBackward0>) tensor(14.6667) tensor(34.) tensor(16.)
tensor(115.9171, grad_fn=<MeanBackward0>) tensor(15.6333) tensor(34.) tensor(18.)
tensor(123.4468, grad_fn=<MeanBackward0>) tensor(17.2000) tensor(34.) tensor(20.)
tensor(55.1440, grad_fn=<MeanBackward0>) tensor(16.2000) tensor(37.) tensor(16.)
tensor(53.0739, grad_fn=<MeanBackward0>) tensor(18.8667) tensor(43.) tensor(15.)
tensor(87.6544, grad_fn=<M

tensor(72.4179, grad_fn=<MeanBackward0>) tensor(16.1000) tensor(33.) tensor(16.)
tensor(97.9157, grad_fn=<MeanBackward0>) tensor(16.8333) tensor(40.) tensor(16.)
tensor(94.1371, grad_fn=<MeanBackward0>) tensor(15.4667) tensor(48.) tensor(14.)
tensor(107.4935, grad_fn=<MeanBackward0>) tensor(18.3667) tensor(57.) tensor(16.)
tensor(105.3095, grad_fn=<MeanBackward0>) tensor(18.8667) tensor(62.) tensor(15.)
tensor(132.8719, grad_fn=<MeanBackward0>) tensor(18.2333) tensor(68.) tensor(18.)
tensor(110.8859, grad_fn=<MeanBackward0>) tensor(16.4333) tensor(71.) tensor(16.)
tensor(95.8859, grad_fn=<MeanBackward0>) tensor(18.3333) tensor(74.) tensor(17.)
tensor(102.7413, grad_fn=<MeanBackward0>) tensor(20.5000) tensor(77.) tensor(18.)
tensor(112.0948, grad_fn=<MeanBackward0>) tensor(20.1000) tensor(83.) tensor(21.)
tensor(80.8090, grad_fn=<MeanBackward0>) tensor(17.9000) tensor(88.) tensor(17.)
tensor(53.4912, grad_fn=<MeanBackward0>) tensor(15.3667) tensor(81.) tensor(15.)
tensor(51.5117, grad_f

tensor(105.5301, grad_fn=<MeanBackward0>) tensor(19.9667) tensor(65.) tensor(16.)
tensor(104.1307, grad_fn=<MeanBackward0>) tensor(19.7333) tensor(70.) tensor(16.)
tensor(113.6551, grad_fn=<MeanBackward0>) tensor(20.7000) tensor(72.) tensor(16.)
tensor(116.7825, grad_fn=<MeanBackward0>) tensor(19.3667) tensor(73.) tensor(17.)
tensor(126.5056, grad_fn=<MeanBackward0>) tensor(19.7000) tensor(77.) tensor(18.)
tensor(121.4515, grad_fn=<MeanBackward0>) tensor(20.9667) tensor(79.) tensor(17.)
tensor(138.4320, grad_fn=<MeanBackward0>) tensor(19.9000) tensor(84.) tensor(19.)
tensor(134.3071, grad_fn=<MeanBackward0>) tensor(18.6333) tensor(86.) tensor(19.)
tensor(114.8885, grad_fn=<MeanBackward0>) tensor(17.4333) tensor(90.) tensor(16.)
tensor(117.7091, grad_fn=<MeanBackward0>) tensor(14.7000) tensor(46.) tensor(17.)
tensor(109.4142, grad_fn=<MeanBackward0>) tensor(13.5000) tensor(32.) tensor(17.)
tensor(108.6402, grad_fn=<MeanBackward0>) tensor(14.1333) tensor(32.) tensor(15.)
tensor(107.5146,

tensor(75.6244, grad_fn=<MeanBackward0>) tensor(19.6667) tensor(45.) tensor(18.)
tensor(98.2761, grad_fn=<MeanBackward0>) tensor(21.9667) tensor(51.) tensor(16.)
tensor(102.0959, grad_fn=<MeanBackward0>) tensor(20.7000) tensor(51.) tensor(17.)
tensor(102.2282, grad_fn=<MeanBackward0>) tensor(20.4333) tensor(55.) tensor(18.)
tensor(109.8990, grad_fn=<MeanBackward0>) tensor(19.7000) tensor(55.) tensor(19.)
tensor(135.1276, grad_fn=<MeanBackward0>) tensor(19.7000) tensor(57.) tensor(20.)
tensor(98.1257, grad_fn=<MeanBackward0>) tensor(17.8000) tensor(59.) tensor(15.)
tensor(102.1533, grad_fn=<MeanBackward0>) tensor(19.3000) tensor(63.) tensor(16.)
tensor(107.7466, grad_fn=<MeanBackward0>) tensor(17.7333) tensor(69.) tensor(16.)
tensor(99.6915, grad_fn=<MeanBackward0>) tensor(18.7333) tensor(70.) tensor(16.)
tensor(105.4557, grad_fn=<MeanBackward0>) tensor(15.7667) tensor(73.) tensor(14.)
tensor(88.9583, grad_fn=<MeanBackward0>) tensor(13.4000) tensor(30.) tensor(12.)
tensor(73.1739, grad_

tensor(122.6075, grad_fn=<MeanBackward0>) tensor(20.6000) tensor(70.) tensor(17.)
tensor(118.2463, grad_fn=<MeanBackward0>) tensor(20.8667) tensor(67.) tensor(17.)
tensor(96.6884, grad_fn=<MeanBackward0>) tensor(18.1333) tensor(74.) tensor(15.)
tensor(118.0315, grad_fn=<MeanBackward0>) tensor(20.3000) tensor(79.) tensor(18.)
tensor(123.6091, grad_fn=<MeanBackward0>) tensor(19.5000) tensor(82.) tensor(19.)
tensor(122.8801, grad_fn=<MeanBackward0>) tensor(20.4000) tensor(87.) tensor(19.)
tensor(97.1526, grad_fn=<MeanBackward0>) tensor(18.4000) tensor(90.) tensor(16.)
tensor(149.7385, grad_fn=<MeanBackward0>) tensor(17.8000) tensor(38.) tensor(21.)
tensor(129.1655, grad_fn=<MeanBackward0>) tensor(16.3333) tensor(38.) tensor(19.)
tensor(102.8997, grad_fn=<MeanBackward0>) tensor(14.8333) tensor(38.) tensor(16.)
tensor(97.4292, grad_fn=<MeanBackward0>) tensor(15.1333) tensor(38.) tensor(15.)
tensor(93.3868, grad_fn=<MeanBackward0>) tensor(16.1333) tensor(38.) tensor(13.)
tensor(73.2299, grad

tensor(98.7603, grad_fn=<MeanBackward0>) tensor(23.7333) tensor(53.) tensor(19.)
tensor(108.7550, grad_fn=<MeanBackward0>) tensor(24.) tensor(54.) tensor(19.)
tensor(123.8651, grad_fn=<MeanBackward0>) tensor(24.8667) tensor(55.) tensor(20.)
tensor(129.8079, grad_fn=<MeanBackward0>) tensor(25.0667) tensor(58.) tensor(22.)
tensor(140.9136, grad_fn=<MeanBackward0>) tensor(24.2667) tensor(60.) tensor(23.)
tensor(138.0569, grad_fn=<MeanBackward0>) tensor(22.9333) tensor(63.) tensor(22.)
tensor(118.6779, grad_fn=<MeanBackward0>) tensor(20.6333) tensor(67.) tensor(19.)
tensor(92.8524, grad_fn=<MeanBackward0>) tensor(16.) tensor(70.) tensor(14.)
tensor(124.2642, grad_fn=<MeanBackward0>) tensor(18.1333) tensor(71.) tensor(17.)
tensor(127.2378, grad_fn=<MeanBackward0>) tensor(19.3667) tensor(75.) tensor(18.)
tensor(128.9404, grad_fn=<MeanBackward0>) tensor(20.1000) tensor(79.) tensor(19.)
tensor(128.2642, grad_fn=<MeanBackward0>) tensor(20.8333) tensor(82.) tensor(19.)
tensor(112.1250, grad_fn=<

tensor(68.0648, grad_fn=<MeanBackward0>) tensor(25.) tensor(46.) tensor(22.)
tensor(102.5449, grad_fn=<MeanBackward0>) tensor(25.4667) tensor(46.) tensor(20.)
tensor(100.4538, grad_fn=<MeanBackward0>) tensor(25.3667) tensor(52.) tensor(19.)
tensor(134.4492, grad_fn=<MeanBackward0>) tensor(30.0333) tensor(58.) tensor(23.)
tensor(114.0796, grad_fn=<MeanBackward0>) tensor(31.4333) tensor(64.) tensor(21.)
tensor(108.2382, grad_fn=<MeanBackward0>) tensor(30.7000) tensor(67.) tensor(21.)
tensor(105.4575, grad_fn=<MeanBackward0>) tensor(30.4000) tensor(71.) tensor(21.)
tensor(83.9681, grad_fn=<MeanBackward0>) tensor(31.5000) tensor(73.) tensor(20.)
tensor(77.2464, grad_fn=<MeanBackward0>) tensor(25.4333) tensor(77.) tensor(16.)
tensor(88.1384, grad_fn=<MeanBackward0>) tensor(24.7333) tensor(79.) tensor(15.)
tensor(101.0348, grad_fn=<MeanBackward0>) tensor(21.8667) tensor(84.) tensor(17.)
tensor(111.1923, grad_fn=<MeanBackward0>) tensor(20.9333) tensor(86.) tensor(17.)
tensor(102.1348, grad_fn

tensor(105.6599, grad_fn=<MeanBackward0>) tensor(15.3667) tensor(30.) tensor(15.)
tensor(111.9243, grad_fn=<MeanBackward0>) tensor(15.7667) tensor(30.) tensor(17.)
tensor(113.1276, grad_fn=<MeanBackward0>) tensor(14.6333) tensor(29.) tensor(16.)
tensor(107.0512, grad_fn=<MeanBackward0>) tensor(15.5333) tensor(30.) tensor(15.)
tensor(99.3884, grad_fn=<MeanBackward0>) tensor(16.2667) tensor(37.) tensor(14.)
tensor(110.7981, grad_fn=<MeanBackward0>) tensor(17.8667) tensor(46.) tensor(17.)
tensor(130.3285, grad_fn=<MeanBackward0>) tensor(18.3333) tensor(53.) tensor(20.)
tensor(96.3280, grad_fn=<MeanBackward0>) tensor(17.2000) tensor(56.) tensor(19.)
tensor(73.2792, grad_fn=<MeanBackward0>) tensor(16.3333) tensor(58.) tensor(16.)
tensor(38.0061, grad_fn=<MeanBackward0>) tensor(18.4667) tensor(63.) tensor(13.)
tensor(50.9658, grad_fn=<MeanBackward0>) tensor(19.9667) tensor(69.) tensor(11.)
tensor(81.7572, grad_fn=<MeanBackward0>) tensor(18.2667) tensor(77.) tensor(11.)
tensor(109.0796, grad_