In [7]:
import torch
import math
from torch import nn
import torch.nn.functional as F

In [8]:
#Reward.py
#Constants for the reward r(s,a) when a is not the special action END

#number of SNP loci considered on each chromosome
chrom_width = 50
#number of chromosomes modelled
num_chromosome = 44


#normalisation constant chosen so that normal_const*(sum_i a_i) < 1, i.e. the
#probabilities of all CNVs sum to a real value smaller than 1
normal_const = 1e-5
#probability of a single-locus gain/loss
single_loci_loss = normal_const * (1 - 2e-1)
#probability of a whole-genome doubling (WGD)
WGD = normal_const * 0.6

#constants of the focal-CNV probability: a CNV covering L loci gets
#probability const1/(const2 + L) (these are probabilities, not log probabilities)
const1 = single_loci_loss
const2 = 1

#probability of a whole-chromosome change and of an arm-level change
Whole_Chromosome_CNV = single_loci_loss / 4
Half_Chromosome_CNV = single_loci_loss / 3

#divisor used as ceil(diff/max_copy) to turn non-negative copy-number
#differences (up to max_copy) into 0/1 flags
max_copy = 20

def Reward(Start,End):
    """Return r(s, a) = log p(a) for CNVs spanning loci [Start, End).

    Start, End: per-sample start locus and exclusive end locus (any dtype;
    converted to float32). A focal CNV of length L = End - Start scores
    log(const1 / (const2 + L)). A CNV covering the whole chromosome scores
    log(Whole_Chromosome_CNV). An arm-level CNV scores log(Half_Chromosome_CNV);
    this covers one that ends at the last locus and starts within 1 of the
    midpoint, or one that starts at locus 0 and ends within 1 of the midpoint.
    The arm rule is applied last and overrides the other two.
    """
    start = Start.to(torch.float32)
    end = End.to(torch.float32)
    # focal CNV: probability decays with the number of loci covered
    reward = torch.log(const1 / (const2 + end - start))

    covers_all = (end - start) > chrom_width - 0.5
    reward = reward.masked_fill(covers_all, math.log(Whole_Chromosome_CNV))

    middle = chrom_width // 2
    tail_arm = ((chrom_width - end) < 0.5) & (torch.abs(middle - start) < 1.5)
    head_arm = (start < 0.5) & (torch.abs(middle - end) < 1.5)
    reward = reward.masked_fill(tail_arm | head_arm, math.log(Half_Chromosome_CNV))
    return reward



#Q-function.py
#defining the Q-function 
#Q(s,a) in manuscript


#switch structure mentioned in section 3.3.4
#numbers of output channels (filters) of the three convolution layers
nkernels_switch = [20, 40, 80]
activatiion_wgd = torch.tanh


class WGD_Net(nn.Module):
    """Switch network giving the probability that a CNP went through a WGD.

    The convolution kernels have height 1 and stride 1 along the chromosome
    axis, so the same filters scan every chromosome in the same way. The final
    sum over chromosomes makes the output independent of chromosome order
    (section 3.3.3).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, nkernels_switch[0], (1, 3), (1, 1), (0, 1))
        self.conv2 = nn.Conv2d(nkernels_switch[0], nkernels_switch[1], (1, 3), (1, 1), (0, 1))
        self.conv3 = nn.Conv2d(nkernels_switch[1], nkernels_switch[2], (1, 5), (1, 1), (0, 0))
        self.linear = nn.Linear(nkernels_switch[2], 1)

    def forward(self, x):
        """Return a (batch, 1) tensor of WGD probabilities for the CNPs ``x``.

        ``x`` is reshaped to (batch, 1, num_chromosome, chrom_width).
        """
        batch = x.shape[0]
        # residual representation (section 3.3.4): a detached empirical logit
        # that grows with the mean copy number above 1.5
        prior_logit = (20 * (x.mean((1, 2, 3)) - 1.5)).reshape(batch, 1).detach()

        feats = x.reshape(batch, 1, num_chromosome, chrom_width)
        feats = F.max_pool2d(activatiion_wgd(self.conv1(feats)), (1, 5), (1, 5), (0, 0))
        feats = F.max_pool2d(activatiion_wgd(self.conv2(feats)), (1, 2), (1, 2), (0, 0))
        # summing over chromosomes and positions removes the chromosome order
        pooled = activatiion_wgd(self.conv3(feats)).sum((2, 3))
        learned_logit = self.linear(pooled) / 2
        return torch.sigmoid(prior_logit / 2 + learned_logit)

#chromosome evaluation net
#Used in Chrom_NN (which is Q_{phi_1}(s,c) in section 3.3.1)
#numbers of output channels (filters) of the convolution layers
nkernels_chr = [80, 120, 160]
activation_cnp = torch.tanh


class CNP_Val(nn.Module):
    """Per-chromosome value network; returns one non-negative score per chromosome.

    All kernels have height 1, so each chromosome row is processed
    independently by the same filters.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, nkernels_chr[0], (1, 5), (1, 1), (0, 2))
        self.conv2 = nn.Conv2d(nkernels_chr[0], nkernels_chr[1], (1, 3), (1, 1), (0, 1))
        self.conv3 = nn.Conv2d(nkernels_chr[1], nkernels_chr[2], (1, 3), (1, 1), (0, 1))
        self.conv4 = nn.Conv2d(nkernels_chr[2], 1, (1, 5), (1, 1), (0, 0))

    def forward(self, x):
        """Map a (batch, 1, num_chromosome, width) CNP to (batch, num_chromosome) scores."""
        # each conv is followed by tanh and a width-only max-pool (padding 1)
        stages = (
            (self.conv1, (1, 3)),
            (self.conv2, (1, 2)),
            (self.conv3, (1, 2)),
        )
        h = x
        for conv, window in stages:
            h = F.max_pool2d(activation_cnp(conv(h)), window, window, (0, 1))
        # ELU with alpha=0.25 is >= -0.25, so shifting by 0.25 keeps every score >= 0
        scores = 0.25 + F.elu(self.conv4(h), 0.25)
        return scores.reshape(scores.shape[0], num_chromosome)

#Implements Q_{phi_1}(s,c) in section 3.3.1
#It combines two chromosome evaluation nets (CNP_Val) with the switch
#structure of section 3.3.4 to form Q_{phi_1}(s,c)
class Chrom_NN(nn.Module):
    """Q_{phi_1}(s, c): score of choosing chromosome c in state s.

    Two CNP_Val networks are used, one for CNPs without WGD and one for CNPs
    with WGD. They are blended with the WGD probability ``sigma``.
    """

    def __init__(self):
        super().__init__()
        # value network for CNPs without WGD
        self.Val_noWGD = CNP_Val()
        # value network for CNPs with WGD
        self.Val_WGD = CNP_Val()

    def forward(self, x, sigma):
        """Return (batch, num_chromosome) scores for the CNPs ``x``.

        x: (batch, 1, num_chromosome, width) copy numbers.
        sigma: WGD probability from the switch network, expanded to one value per chromosome.
        """
        n = x.shape[0]
        sigma = sigma.expand(-1, num_chromosome)

        def flag(deviation):
            # 1 for a chromosome with any nonzero deviation, 0 otherwise
            # (assumes the mean deviation does not exceed max_copy -- TODO confirm)
            return torch.ceil(deviation.mean(3) / max_copy).reshape(n, num_chromosome).detach()

        def residual(mask, net_out):
            # residual representation (section 3.3.4): chromosomes needing no
            # CNV get the fixed cost; the net's outputs on the flagged
            # chromosomes are summed and shared across all entries
            shared = (mask * net_out).sum(1).reshape(n, 1).expand(-1, num_chromosome)
            return -(1 - mask) * math.log(single_loci_loss) * 2 + shared

        # without WGD: a chromosome is abnormal if any copy number differs from 1
        abnormal = flag(torch.abs(x - 1))
        val_no = residual(abnormal, self.Val_noWGD.forward(x))
        # with WGD: a chromosome is abnormal if any copy number is odd; WGD itself costs -log(WGD)
        odd = flag(torch.abs(x - 2 * (x // 2)))
        val_wgd = residual(odd, self.Val_WGD.forward(x)) - math.log(WGD)

        # switch (section 3.3.4), then negate the cost into a score
        return -(sigma * val_wgd + (1 - sigma) * val_no)



#starting point and gain or loss (defined as sp in manuscript)
#Used in CNV_NN (which is Q_{phi_2}(s,c,sp) in section 3.3.1)
#numbers of output channels of the convolution layers (the last feeds the linear layer)
nkernels_CNV = [80, 120, 160, 10]
activation_cnv = torch.tanh


class CNV_Val(nn.Module):
    """Score every (start locus, loss/gain) slot of one chromosome.

    The output has 2*chrom_width - 1 entries. The softmax over slots is
    invariant to adding a constant, so callers in this file prepend a fixed 0
    for slot 0 instead of predicting it.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, nkernels_CNV[0], (1, 7), (1, 1), (0, 3))
        self.conv2 = nn.Conv2d(nkernels_CNV[0], nkernels_CNV[1], (1, 7), (1, 1), (0, 3))
        self.conv3 = nn.Conv2d(nkernels_CNV[1], nkernels_CNV[2], (1, 7), (1, 1), (0, 3))
        self.conv4 = nn.Conv2d(nkernels_CNV[2], nkernels_CNV[3], (1, 7), (1, 1), (0, 3))
        self.linear = nn.Linear(nkernels_CNV[3] * chrom_width, 2 * chrom_width - 1)

    def forward(self, x):
        """Map a (batch, 1, 1, chrom_width) CNP to (batch, 2*chrom_width - 1) scores."""
        h = x
        # 'same'-padded convolutions keep the width at chrom_width
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            h = activation_cnv(conv(h))
        flat = h.reshape(h.shape[0], nkernels_CNV[3] * chrom_width)
        return self.linear(flat)

#Implements Q_{phi_2}(s,c,sp) in section 3.3.1
#It combines two CNV_Val nets mentioned above,
#with a switch structure in section 3.3.4 to form Q_{phi_2}(s,c,sp)
class CNV_NN(nn.Module):
    """Q_{phi_2}(s, c, sp): score of every (start locus, loss/gain) slot.

    Slot ``2*k + g`` is start locus ``k`` with ``g = 0`` (lose a copy) or
    ``g = 1`` (add a copy). Slot 0 is the reference and is not produced by
    ``forward``, which therefore returns ``2*chrom_width - 1`` scores.
    """
    def __init__(self):
        super(CNV_NN,self).__init__()
        #network used for CNPs without WGD
        self.CNV_noWGD=CNV_Val()
        #network used for CNPs with WGD
        self.CNV_WGD=CNV_Val()
    
    def forward(self,x,sigma):
        """Return (batch, 2*chrom_width - 1) slot scores relative to slot 0.

        x: (batch, chrom_width) copy numbers of the selected chromosome.
        sigma: WGD probability from the switch network; it multiplies the output.
        """
        #as in section 3.3.4
        #y is the empirical estimation, the net output is the residual
        #slot 2k (loss) is weighted by copies above 1, slot 2k+1 (gain) by copies below 1
        y=torch.Tensor(x.shape[0],chrom_width,2)
        y[:,:,0]=F.relu(x-1)
        y[:,:,1]=F.relu(1-x)
        y=y.reshape(x.shape[0],2*chrom_width)
        #express every slot relative to slot 0 and drop slot 0
        y=y[:,1:(2*chrom_width)]-y[:,0:1].expand(-1,2*chrom_width-1)
        y=-y.detach()*math.log(single_loci_loss)
        Val_no=self.CNV_noWGD.forward(x.reshape(x.shape[0],1,1,chrom_width))
        Val_no=y+Val_no
        
        #with WGD: odd copy numbers (x mod 2) are the empirical signal, for both slots
        z=((torch.abs(x-2*(x//2))).reshape(x.shape[0],chrom_width,1)).expand(-1,-1,2)
        z=z.reshape(x.shape[0],2*chrom_width)
        z=z[:,1:(2*chrom_width)]-z[:,0:1].expand(-1,2*chrom_width-1)
        z=-z.detach()*math.log(single_loci_loss)
        Val_wgd=self.CNV_WGD.forward(x.reshape(x.shape[0],1,1,chrom_width))
        Val_wgd=z+Val_wgd
        #switch
        x=sigma*Val_wgd+(1-sigma)*Val_no
        return(x)
    
    def find_one_cnv(self,chrom,sigma):
        """Return the best (start, loss/gain) slot for the first sample of ``chrom``.

        Used for finding the CNV during deconvolution, not during training.
        Before the argmax, slots that break the rule system (section 3.3.2)
        are masked with log(0) = -inf:
        * a start must lie on a breakpoint of the CNP, or at locus 0;
        * a loss is only allowed where the copy number is at least 2
          (integer copy numbers assumed).

        Returns:
            int: flattened slot index ``2*start + g`` (``g = 1`` adds a copy).
        """
        # pure argmax: no autograd graph is needed
        with torch.no_grad():
            res_cnv=self.forward(chrom,sigma)
            n=chrom.shape[0]
            #chrom_shift[k] is the copy number at locus k-1 (0 before the first locus)
            chrom_shift=torch.zeros(n,chrom_width)
            chrom_shift[:,1:]=chrom[:,:chrom_width-1]
            break_start=torch.zeros(n,chrom_width,2)
            #a gain may start wherever the copy number changes from the previous locus
            break_start[:,:,1]=torch.ceil(torch.abs(chrom-chrom_shift)/max_copy)
            #... and always at the first locus of the chromosome
            break_start[:,0,1]=1
            #a loss needs the same breakpoint and ceil((c/2-0.5)/max_copy) > 0, i.e. c >= 2
            break_start[:,:,0]=break_start[:,:,1]*torch.ceil((chrom/2-0.5)/max_copy)
            break_start=break_start.reshape(n,2*chrom_width)
            #slot 0 is the fixed reference score 0
            res_cnv_full=torch.zeros(n,2*chrom_width)
            res_cnv_full[:,1:]=res_cnv
            res_cnv_full=res_cnv_full+torch.log(break_start)
            #best cnv according to the current Q
            cnv_max=torch.max(res_cnv_full,1)[1]
        return int(cnv_max[0])
    


#end point
#Used in End_Point_NN (which is Q_{phi_3}(s,c,sp,ep) in section 3.3.1)
#numbers of output channels of the convolution layers
nkernels_End = [80, 120, 240]
activation_end = torch.tanh


class End_Point_Val(nn.Module):
    """Score every end locus given the chromosome before and after the start change.

    The output has chrom_width - 1 entries. The softmax over end points is
    invariant to adding a constant, so callers in this file prepend a fixed 0
    for the first end position.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, nkernels_End[0], (1, 7), (1, 1), (0, 3))
        self.conv2 = nn.Conv2d(nkernels_End[0], nkernels_End[1], (1, 7), (1, 1), (0, 3))
        self.conv3 = nn.Conv2d(nkernels_End[1], nkernels_End[2], (1, 7), (1, 1), (0, 3))
        self.linear = nn.Linear(nkernels_End[2] * chrom_width, chrom_width - 1)

    def forward(self, old, new):
        """Map two (batch, chrom_width) CNPs to (batch, chrom_width - 1) scores."""
        # stack old/new as two input channels of a (batch, 2, 1, chrom_width) image
        pair = torch.Tensor(old.shape[0], 2, 1, chrom_width)
        pair[:, 0, 0, :] = old
        pair[:, 1, 0, :] = new
        h = pair.detach()
        for conv in (self.conv1, self.conv2, self.conv3):
            h = activation_end(conv(h))
        flat = h.reshape(h.shape[0], nkernels_End[2] * chrom_width)
        return self.linear(flat)
    
#Implements Q_{phi_3}(s,c,sp,ep) in section 3.3.1
#It combines two End_Point_Val nets mentioned above,
#with a switch structure in section 3.3.4 to form Q_{phi_3}(s,c,sp,ep)
class End_Point_NN(nn.Module):
    """Q_{phi_3}(s, c, sp, ep): score of every end locus of a CNV.

    End position ``k`` corresponds to an exclusive end locus ``k + 1``.
    """
    def __init__(self):
        super(End_Point_NN,self).__init__()
        #network used for CNPs without WGD
        self.Val_noWGD=End_Point_Val()
        #network used for CNPs with WGD
        self.Val_WGD=End_Point_Val()
    
    def forward(self,old,new,sigma):
        """Return (batch, chrom_width - 1) end scores relative to end position 0.

        old: (batch, chrom_width) CNP of the chosen chromosome.
        new: ``old`` with the copy change applied from the start locus onwards.
        sigma: WGD probability from the switch network; it multiplies the output.
        """
        #empirical term without WGD: relu((old-1)*(old-new)), relative to position 0
        y=F.relu((old-1)*(old-new))
        y=y[:,1:chrom_width]-y[:,0:1].expand(-1,chrom_width-1)
        y=-y.detach()*math.log(single_loci_loss)
        Val_no=self.Val_noWGD.forward(old,new)
        Val_no=Val_no+y
        
        #empirical term with WGD: 1 where old is odd and new is even
        z=(old-2*(old//2))*(1-(new-2*(new//2)))
        z=z[:,1:chrom_width]-z[:,0:1].expand(-1,chrom_width-1)
        z=-z.detach()*math.log(single_loci_loss)
        Val_wgd=self.Val_WGD.forward(old,new)
        Val_wgd=Val_wgd+z
        #switch
        x=sigma*Val_wgd+(1-sigma)*Val_no
        return x

    def _end_candidates(self,old):
        """Return a (batch, chrom_width) mask of breakpoint end positions.

        Position k is allowed (nonzero) when the copy number changes between
        loci k and k+1; the last locus is always allowed (section 3.3.2).
        """
        n=old.shape[0]
        #chrom_shift[k] is the copy number at locus k+1 (0 after the last locus)
        chrom_shift=torch.zeros(n,chrom_width,requires_grad=False)
        chrom_shift[:,:chrom_width-1]=old[:,1:]
        break_end=torch.zeros(n,chrom_width,requires_grad=False)
        break_end[:,:]=torch.ceil(torch.abs(old-chrom_shift)/max_copy)
        break_end[:,chrom_width-1]=1
        return break_end

    @staticmethod
    def _restrict_row(break_end,row,old,start,cnv):
        """In place: forbid invalid end positions of sample ``row``.

        An end may not precede ``start``. For a loss (``cnv < 0.5``) the CNV
        may not reach a locus after ``start`` whose copy number is <= 1.
        """
        start=int(start)
        break_end[row,:start]=0
        if cnv<0.5:
            j=start+1
            while j<chrom_width and not old[row][j]<1.5:
                j+=1
            break_end[row,j:]=0

    def _best_end(self,old,new,sigma,break_end):
        """Return per-sample argmax end positions of Q_{phi_3} + log(break_end)."""
        with torch.no_grad():
            res_end=self.forward(old,new,sigma)
            #end position 0 is the fixed reference score 0
            res_end_full=torch.zeros(old.shape[0],chrom_width)
            res_end_full[:,1:]=res_end
            #rule system: disallowed positions get log(0) = -inf
            res_end_full=res_end_full+torch.log(break_end)
            return torch.max(res_end_full,1)[1]
    
    def find_end(self,old,new,sigma,start_loci,cnv,valid):
        """Return the best exclusive end locus for every sample (used when loading data).

        ``valid`` is accepted for interface compatibility and is not used.
        """
        break_end=self._end_candidates(old)
        for i in range(old.shape[0]):
            self._restrict_row(break_end,i,old,start_loci[i],cnv[i])
        return self._best_end(old,new,sigma,break_end)+1
    
    def find_one_end(self,old,new,sigma,start,cnv):
        """Return the best exclusive end locus (int) for the first sample (deconvolution).

        Only sample 0 is restricted by ``start`` and ``cnv``.
        """
        break_end=self._end_candidates(old)
        self._restrict_row(break_end,0,old,start,cnv)
        return int(self._best_end(old,new,sigma,break_end)[0])+1
        

#combine all separate networks
#add Rule system

#calculating the softmax
#prevent inf when taking log(exp(x)): values are kept as (max, sum of exp(v - max)) pairs
def Soft_update(val1,soft1,val2,soft2):
    """Merge two log-sum-exp accumulators element-wise.

    A pair ``(val, soft)`` encodes ``val + log(soft)``, i.e. ``log(soft * exp(val))``.
    The result ``(bias, log_exp)`` uses ``bias = max(val1, val2)``, so that
    ``bias + log(log_exp) == log(soft1*exp(val1) + soft2*exp(val2))``.

    Only non-positive differences are exponentiated, so nothing overflows.
    The inputs broadcast against each other. Elements whose values compare
    as neither ``>=`` nor ``<`` (NaN) pass ``val1``/``soft1`` through
    unchanged.

    Returns:
        tuple: new (bias, log_exp) tensors; the inputs are not modified.
    """
    take_first=torch.ge(val1,val2)
    take_second=torch.lt(val1,val2)
    bias=torch.where(take_second,val2,val1)
    # exp(smaller - larger) is in [0, 1]
    damp=torch.exp(-torch.abs(val1-val2))
    log_exp=torch.where(take_first,soft1+soft2*damp,soft1)
    log_exp=torch.where(take_second,soft2+soft1*damp,log_exp)
    return bias,log_exp



#Combine all the separate modules mentioned above
#Implementation of Q(s,a)
class Q_learning(nn.Module):
    """Q(s, a) built from the WGD switch and the three factorised heads.

    Q(s, a) is split as in section 3.3.1 into three heads:
    Q_{phi_1}(s, c) (``Chrom_model``), Q_{phi_2}(s, c, sp) (``CNV``) and
    Q_{phi_3}(s, c, sp, ep) (``End``). The WGD probability sigma produced by
    ``switch`` is passed into each head (section 3.3.4).
    """
    def __init__(self):
        super(Q_learning,self).__init__()
        #switch structure (section 3.3.4): outputs sigma, the WGD probability of a state
        self.switch=WGD_Net()
        #the output refers to Q_{\phi_1}(s,c)
        self.Chrom_model=Chrom_NN()
        #the output refers to Q_{\phi_2}(s,c,sp)
        self.CNV=CNV_NN()
        #the output refers to Q_{\phi_3}(s,sp,c,ep)
        self.End=End_Point_NN()
    
    
    def forward(self,state,next_state,chrom,chrom_new,Chr,cnv,start_loci,end_loci,valid):
        '''
        Compute the per-sample advantage whose square is the training loss (Theorem 1).

        state: s in Q(s,a)
        next_state: s' in softmax_{a'} Q(s',a')
        Chr, cnv, start_loci, end_loci: the action a (chromosome index,
            0 = lose / 1 = add a copy, start locus, exclusive end locus)
        chrom, chrom_new: parameter-free preprocessing f(s,a) that makes the
            computation faster; chrom is the CNP of chromosome Chr in s, and
            chrom_new is chrom with the copy change applied from start_loci
            onwards (as built by the data simulators)
        valid: per-sample flag, > 0.5 when the sample is a valid action
            (e.g. no negative copy numbers)

        Returns:
            x: advantage per sample, multiplied by valid
            cnv_max: argmax slot (2*start + cnv) of the masked Q_{phi_2} scores
            sigma: WGD probability of state
            res_chrom, res_cnv_full, res_end_full: Q_{phi_1} scores and the
                rule-masked Q_{phi_2} / Q_{phi_3} scores

        NOTE(review): reads the module-level global counter_global, which is
        only defined in the training cell; calling this earlier raises NameError.
        '''
        
        #computing softmaxQ(s',a')
        #It is a tradition in RL that gradient does not backpropagate through softmaxQ(s',a'), but only through Q(s,a) to make convergence faster
        #there is no theoretical guarantee behind, and it is only a practical trick
        sigma_next=self.switch(next_state)
        #Softmax returns a (max, sum of exp) pair, so x+log(y) is the log-sum-exp over a'
        x,y=self.Softmax(next_state,sigma_next)
        x=x+torch.log(y)
        #computing r(s,a)
        x=x+Reward(start_loci,end_loci)
        #the target softmaxQ(s',a')+r(s,a) is detached
        x=x.detach()
        
        #computing Q(s,a)
        sigma=self.switch.forward(state)
        #during the first 3e6 iterations the switch gets no gradient from this loss
        #(the training loop adds a separate WGD classification loss in that period)
        if counter_global<3e6:
            sigma=sigma.detach()
        #Q_{phi_1}(s,c)
        res_chrom=self.Chrom_model.forward(state,sigma)
        
        #Q_{phi_2}(s,c,sp)
        res_cnv=self.CNV.forward(chrom,sigma)
        #if there is originally a break point for start
        #real world constraint as described in section 3.3.2
        #only allow starting points (sp) to be the break points of CNP
        break_start=torch.zeros(state.shape[0],chrom_width,2,requires_grad=False)
        chrom_shift=torch.zeros(state.shape[0],chrom_width,requires_grad=False)
        chrom_shift[:,1:]=chrom[:,:(chrom_width-1)]
        #allow adding one copy for every breakpoint
        break_start[:,:,1]=torch.ceil(torch.abs(chrom-chrom_shift)/max_copy)
        #always allow adding one copy starting at the first locus
        break_start[:,0,1]=1
        #don't allow losing one copy where the copy number is <= 1
        #(ceil((c/2-0.5)/max_copy) is 0 for c <= 1), otherwise copy numbers could reach 0 or below
        break_start[:,:,0]=break_start[:,:,1]
        break_start[:,:,0]=break_start[:,:,0]*torch.ceil((chrom/2-0.5)/max_copy)
        break_start=break_start.reshape(state.shape[0],2*chrom_width)
        #slot 0 is the fixed reference score 0; disallowed slots become log(0) = -inf
        res_cnv_full=torch.zeros(state.shape[0],2*chrom_width)
        res_cnv_full[:,1:]=res_cnv
        res_cnv_full=res_cnv_full+torch.log(break_start)
        
        #Q_{phi_2}(s,c,sp)-softmax(Q_{phi_2}(s,c,sp)) as described in section 3.3.1
        #(numerically stable log-sum-exp: max + log(sum(exp(v - max))))
        cnv_max_val,cnv_max=torch.max(res_cnv_full,1)
        cnv_softmax=res_cnv_full-cnv_max_val.reshape(state.shape[0],1).expand(-1,2*chrom_width)
        cnv_softmax=torch.exp(cnv_softmax).sum(1)
        x=x+cnv_max_val+torch.log(cnv_softmax)
        
        #Q_{phi_3}(s,c,sp,ep)
        res_end=self.End.forward(chrom,chrom_new,sigma)
        #if there is originally a break point for end
        #and if this is after the starting point
        #real world constraint in section 3.3.2
        break_end=torch.zeros(state.shape[0],chrom_width,requires_grad=False)
        chrom_shift=torch.zeros(state.shape[0],chrom_width,requires_grad=False)
        chrom_shift[:,:(chrom_width-1)]=chrom[:,1:]
        #allow ending wherever the copy number changes to the next locus
        break_end[:,:]=torch.ceil(torch.abs(chrom-chrom_shift)/max_copy)
        #always allow ending at the last locus
        break_end[:,chrom_width-1]=1
        for i in range(state.shape[0]):
            #can't end before starting point
            break_end[i,:int(start_loci[i])]=0*break_end[i,:int(start_loci[i])]
            #for a loss, don't allow the CNV to reach a locus (after the start) with copy number <= 1
            if(cnv[i]<0.5):
                j=int(start_loci[i])+1
                while(j<chrom_width):
                    if(chrom[i][j]<1.5):
                        break
                    j=j+1
                break_end[i,j:chrom_width]=0*break_end[i,j:chrom_width]
            
        #end position 0 is the fixed reference score 0
        res_end_full=torch.zeros(state.shape[0],chrom_width)
        res_end_full[:,1:]=res_end
        
        #real world constraint described in section 3.3.2
        res_end_full=res_end_full+torch.log(break_end)
        end_max_val,end_max_temp=torch.max(res_end_full,1)
        end_softmax=res_end_full-end_max_val.reshape(state.shape[0],1).expand(-1,chrom_width)
        end_softmax=torch.exp(end_softmax).sum(1)
        #Q_{phi_3}(s,c,sp,ep)-softmax(Q_{phi_3}(s,c,sp,ep)) as described in section 3.3.1
        x=x+end_max_val+torch.log(end_softmax)
        
        #subtract the scores of the action actually taken, for valid samples only
        for i in range(state.shape[0]):
            if valid[i]>0.5:#check validity to prevent inf-inf which ends in nan
                x[i]=x[i]-res_chrom[i][int(Chr[i])]
                cnv_rank=int(start_loci[i]*2+cnv[i])
                x[i]=x[i]-res_cnv_full[i][cnv_rank]
                end_rank=int(end_loci[i]-1)
                x[i]=x[i]-res_end_full[i][end_rank]
        
        #remove training data which include invalid actions
        #NOTE(review): NaN*0 is still NaN. If every end position of an invalid row is
        #masked (possible for a loss), the log-sum-exp above is NaN for that row and
        #survives this product; consider torch.where -- confirm whether this is reachable.
        x=x*valid
        #return advantage as well as a best cnv and sigma used for generating training data
        #used for training in the next step
        return x,cnv_max,sigma,res_chrom,res_cnv_full,res_end_full
     
    def Softmax(self,next_state,sigma):
        '''
        Return (max, sum_exp) such that max + log(sum_exp) = log sum_{a'} exp Q(s',a').

        Candidates merged here are:
        * the Chrom_model scores;
        * the special action END;
        * reversing a WGD, for samples whose copy numbers are all even with at
          least one positive. Its value comes from a recursive call on the
          halved CNP.
        '''
        #compute softmax_{a'} Q(s',a')
        x=self.Chrom_model.forward(next_state,sigma)
        max_chrom=torch.max(x,1)[0]
        softmax_chrom=x-max_chrom.reshape(x.shape[0],1).expand(-1,num_chromosome)
        softmax_chrom=torch.exp(softmax_chrom).sum(1)
        #special action END
        #all the remaining abnormal loci are treated to be caused by several independent single locus copy number changes
        end_val=torch.sum(torch.abs(next_state-1),(1,2,3))*math.log(single_loci_loss)
        #END is a single term, so its sum-of-exp factor is 1
        max_chrom,softmax_chrom=Soft_update(max_chrom,softmax_chrom,end_val,torch.ones(x.shape[0]))
        #if there is a WGD followed immediately
        for i in range(x.shape[0]):
            #real world constraint as described in section 3.3.2
            #do not allow (reversing) WGD when the CNP contain odd numbers for some loci
            if (not torch.any(next_state[i]-2*torch.floor(next_state[i]/2)>0.5)) and torch.any(next_state[i]>0.5):
                sigma_wgd=self.switch(torch.floor(next_state[i:(i+1)]/2))
                sigma_wgd=sigma_wgd.detach()
                #recursion terminates because each call halves the copy numbers
                wgd_val,wgd_soft=self.Softmax(torch.floor(next_state[i:(i+1)]/2),sigma_wgd)
                max_chrom[i],softmax_chrom[i]=Soft_update(torch.ones(1)*max_chrom[i],torch.ones(1)*softmax_chrom[i],torch.ones(1)*wgd_val,torch.ones(1)*wgd_soft)
        
        return max_chrom,softmax_chrom
  

#Minimum example
#NOTE(review): in a notebook __name__ == "__main__" is true, so this cell runs on
#import; it constructs Q_learning() (consuming RNG) and binds globals such as
#x, y, chrom, Chr, cnv, start_loci, end_loci, valid, Q_model.
if __name__ == "__main__":
    #test different parts separately (disabled: kept inside a string literal)
    '''
    switch=WGD_Net()
    Chrom_model=Chrom_NN()
    print(Chrom_model)
    #test the structure of permutation invariant structure
    x=torch.ones(3,1,num_chromosome,50)
    x[0][0][0][0:50]=2
    x[2][0][1][0:50]=2
    prob=switch.forward(x)
    print(prob)
    res=Chrom_model.forward(x,prob)
    print(res)
    res=-float('inf')
    res=torch.LongTensor(3)
    
    print(torch.log(res.type(torch.DoubleTensor)))
    #CNV
    CNV=CNV_NN()
    res=CNV.forward(x[:,0,0,0:50],prob)
    print(CNV)
    print(res.shape)
    #END
    End=End_Point_NN()
    res=End.forward(x[:,0,0,0:50],x[:,0,0,0:50]+1,prob)
    print(End)
    print(res.shape)
    '''
    #test Q-learning
    #3 samples: sample 0 has chromosome 0 at 2 copies, sample 2 has chromosome 1 at 2 copies
    x=torch.ones(3,1,num_chromosome,50)
    y=torch.ones(3,1,num_chromosome,50)
    x[0][0][0][0:50]=2
    x[2][0][1][0:50]=2
    chrom=x[:,0,0,:]
    chrom_new=y[:,0,0,:]
    #action: chromosome 0, add one copy, whole range [0, 50)
    Chr=torch.zeros(3)
    cnv=torch.ones(3)
    start_loci=torch.zeros(3)
    end_loci=torch.ones(3)*50
    valid=torch.ones(3)
    Q_model=Q_learning()
    #NOTE(review): Q_learning.forward reads the global counter_global, which is only
    #defined in the training cell, so the call below would raise NameError if run here.
    #res,cnv_max,sigma,t,t2,t3=Q_model.forward(x,y,chrom,chrom_new,Chr,cnv,start_loci,end_loci,valid)
    #print(res)
    #print(cnv_max)
    #loss=res.pow(2).mean()
    #print(loss)
    #loss.backward()
    #params = list(Q_model.parameters())
    #print(params[0].grad[0])
    #print(Q_model.switch.conv1.weight[0])

In [9]:
#Train_data.py
import torch
import math

batch_size=30


def Simulate_train_data(first_step_flag=True,state=None,next_state=None,advantage=None,Chr=None,step=None,wgd=None,valid=None):
    """Simulate one step of the training trajectories (like a machine playing against itself).

    Data are simulated backwards: ``step == 0`` is the last step of a
    trajectory (a normal genome), and every further step applies one more CNV
    so the CNP becomes more complex. No real data are needed. By theorem 1 any
    distribution with broad support over (s, a) pairs works, as long as the
    reward resembles real-world probabilities.

    With ``first_step_flag`` True a fresh batch of normal genomes is created.
    Otherwise the tensors returned by the previous call must be passed back in;
    state, next_state, advantage, Chr, step, wgd and valid are updated in place.
    Reads the globals batch_size and counter_global.

    Returns:
        state, next_state, chrom, chrom_new, Chr, cnv, start_loci, end_loci,
        wgd, step, advantage, valid
    """
    if first_step_flag:
        #The first simulated sample: all loci at one copy
        state=torch.ones(batch_size,1,num_chromosome,chrom_width,requires_grad=False)
        next_state=torch.ones(batch_size,1,num_chromosome,chrom_width,requires_grad=False)
        Chr=torch.ones(batch_size,requires_grad=False).type(torch.LongTensor)
        step=torch.zeros(batch_size,requires_grad=False)
        advantage=torch.zeros(batch_size)
        wgd=torch.zeros(batch_size,requires_grad=False)
        valid=torch.ones(batch_size,requires_grad=False)
    
    #sample starting point, end point, gain or loss
    #because of the permutation invariant structure in section 3.3.3
    #it is not necessary to resample the chromosome everytime
    start_loci=torch.randint(high=chrom_width,size=(batch_size,),requires_grad=False)
    end_loci=torch.LongTensor(batch_size)
    cnv=torch.ones(batch_size,requires_grad=False)
    chrom=torch.Tensor(batch_size,chrom_width)
    chrom_new=torch.Tensor(batch_size,chrom_width)
    #probability of passing the random reset check below (1-step_prob resets the trajectory);
    #it grows with training progress
    step_prob=0.18+0.8/(1+math.exp(-1e-2*counter_global+2))
    for i in range(batch_size):
        #reset to step 0 at random, if the model is poorly trained on the current step,
        #if too many WGD samples exist, or if the trajectory is very long;
        #this keeps the error small on short trajectories
        if(torch.rand(1)[0]>step_prob or torch.abs(advantage[i])>=30 or wgd.sum()>24 or step[i]>90):
            state[i]=torch.ones(1,num_chromosome,chrom_width,requires_grad=False)
            next_state[i]=torch.ones(1,num_chromosome,chrom_width,requires_grad=False)
            step[i]=0
            wgd[i]=0
        #if model is fully trained for the current step
        #and there is no invalid operations been sampled
        #go to next step
        elif(valid[i]>0 and torch.abs(advantage[i])<10):
            next_state[i]=state[i].clone()
            step[i]=step[i]+1
        #stay to further train the current step
        #or resample another action
        else:
            state[i]=next_state[i].clone()
    
        #reset advantage and valid after they have been checked
        advantage[i]=0
        valid[i]=1
        end_loci[i]=1+torch.randint(low=int(start_loci[i]),high=chrom_width,size=(1,))[0]
        #resample the chromosome that the CNV is on (every step);
        #otherwise all CNVs would be on the same chromosome
        Chr[i]=torch.randint(high=num_chromosome,size=(1,))[0]
        #with probability 0.2 sample a whole-chromosome change
        if torch.rand(1)[0]>0.8:
            start_loci[i]=0
            end_loci[i]=chrom_width
        #with probability 0.7 set cnv to 0 (flipped below for WGD samples)
        if torch.rand(1)[0]>0.3:
            cnv[i]=0
        #probability of WGD grows with the trajectory length
        prob_wgd=0.1/(1+math.exp(-step[i]+15))
        #wgd: a sample carrying a WGD is never doubled again; the batch-level
        #minimum of 5 WGD samples only recruits samples without one
        draw=torch.rand(1)[0]
        if wgd[i]<1 and (draw<prob_wgd or wgd.sum()<5):
            wgd[i]=1
            state[i]=state[i]*2
            next_state[i]=next_state[i]*2
        #adding cnv effect
        #increasing copies when no wgd
        #decreasing copies when wgd
        if wgd[i]>0.5:
            cnv[i]=1-cnv[i]
        start=int(start_loci[i])
        end=int(end_loci[i])
        delta=(cnv[i]-0.5)*2
        state[i][0][Chr[i]][start:end]=state[i][0][Chr[i]][start:end]-delta
        chrom[i]=state[i][0][Chr[i]][:]
        #reverse effect on chrom_new (applied from the start locus onwards)
        chrom_new[i]=state[i][0][Chr[i]][:]
        chrom_new[i][start:]=chrom_new[i][start:]+delta
        #not going to negative values anywhere in the modified region
        if torch.any(state[i][0][Chr[i]][start:end]< -0.5):
            valid[i]=0
        #not joining breakpoints
        if start>0.5 and torch.abs(chrom[i][start]-chrom[i][start-1])<0.5:
            valid[i]=0
        if end<chrom_width-0.5 and torch.abs(chrom[i][end-1]-chrom[i][end])<0.5:
            valid[i]=0
        if cnv[i]>0.5 and torch.any(state[i][0][Chr[i]][start:end]< 0.5):
            valid[i]=0
    return state,next_state,chrom,chrom_new,Chr,cnv,start_loci,end_loci,wgd,step,advantage,valid

def Modify_data(state,chrom,Chr,valid,cnv_max,model,sigma):
    """Rebuild the batch so that every valid sample takes the currently best action.

    For each valid row the (start, gain/loss) slot ``cnv_max`` is decoded as
    ``slot = 2*start + cnv``, and the end point comes from
    ``model.find_end``. Invalid rows keep finite placeholder values (a random
    start, cnv = 1) so that no inf/nan is produced downstream.

    Returns:
        state, next_state, chrom, chrom_new, cnv, start_loci, end_loci, advantage
    """
    # placeholders of the right tensor types; overwritten for valid rows
    start_loci = torch.randint(high=chrom_width, size=(batch_size,), requires_grad=False)
    cnv = torch.ones(batch_size, requires_grad=False)
    next_state = state.clone()
    chrom_new = chrom.clone()
    advantage = torch.zeros(batch_size)

    for row in (r for r in range(batch_size) if valid[r] > 0.5):
        start_loci[row] = cnv_max[row] // 2
        cnv[row] = cnv_max[row] - start_loci[row] * 2
        first = start_loci[row]
        # apply the copy change to chrom_new from the start locus onwards
        chrom_new[row][first:] = chrom_new[row][first:] + (cnv[row] - 0.5) * 2

    end_loci = model.find_end(chrom, chrom_new, sigma, start_loci, cnv, valid)

    for row in range(batch_size):
        lo, hi = start_loci[row], end_loci[row]
        segment = next_state[row][0][Chr[row]]
        segment[lo:hi] = segment[lo:hi] + (cnv[row] - 0.5) * 2

    return state, next_state, chrom, chrom_new, cnv, start_loci, end_loci, advantage

#simulate data for testing
def Simulate_data(batch_size=15,Number_of_step=70):
    """Simulate test CNPs by applying ``Number_of_step`` random CNV attempts per sample.

    Each attempt per sample draws a fresh start locus and gain/loss, an end
    locus, and a new chromosome with probability 0.5; it may also apply a
    WGD. The attempt is kept only if it is valid; otherwise the sample rolls
    back to its previous state.

    Returns:
        state: (batch_size, 1, num_chromosome, chrom_width) simulated CNPs.
    """
    state=torch.ones(batch_size,1,num_chromosome,chrom_width,requires_grad=False)
    next_state=torch.ones(batch_size,1,num_chromosome,chrom_width,requires_grad=False)
    Chr=torch.ones(batch_size,requires_grad=False).type(torch.LongTensor)
    step=torch.zeros(batch_size,requires_grad=False)
    wgd=torch.zeros(batch_size,requires_grad=False)
    valid=torch.ones(batch_size,requires_grad=False)
    end_loci=torch.LongTensor(batch_size)
    chrom=torch.Tensor(batch_size,chrom_width)
    
    step_counter=0
    while(step_counter<Number_of_step):
        #fresh start loci and gain/loss every step (as Simulate_train_data draws per call);
        #otherwise a whole-chromosome event pins the start to 0 forever and,
        #after a WGD, cnv would flip on every step
        start_loci=torch.randint(high=chrom_width,size=(batch_size,),requires_grad=False)
        cnv=torch.ones(batch_size,requires_grad=False)
        for i in range(batch_size):
            #reset valid after they have been checked
            valid[i]=1
            end_loci[i]=1+torch.randint(low=int(start_loci[i]),high=chrom_width,size=(1,))[0]
            #change the chromosome that the CNV is on with probability 0.5
            if torch.rand(1)[0]>0.5:
                Chr[i]=torch.randint(high=num_chromosome,size=(1,))[0]
            #with probability 0.2 sample a whole-chromosome change
            if torch.rand(1)[0]>0.8:
                start_loci[i]=0
                end_loci[i]=chrom_width
            #cnv: 0 with probability 0.3 (flipped below for WGD samples)
            if torch.rand(1)[0]>0.7:
                cnv[i]=0
            #modifying cnp
            prob_wgd=0.4/(1+math.exp(-step[i]+5))
            #wgd (at most once per sample)
            if (torch.rand(1)[0]<prob_wgd and wgd[i]<1):
                wgd[i]=1
                state[i]=state[i]*2
                next_state[i]=next_state[i]*2
            #adding cnv effect
            #increasing copies when no wgd
            #decreasing copies when wgd
            if wgd[i]>0.5:
                cnv[i]=1-cnv[i]
            start=int(start_loci[i])
            end=int(end_loci[i])
            state[i][0][Chr[i]][start:end]=state[i][0][Chr[i]][start:end]-(cnv[i]-0.5)*2
            chrom[i]=state[i][0][Chr[i]][:]
            #not going to negative values anywhere in the modified region
            if torch.any(state[i][0][Chr[i]][start:end]< -0.5):
                valid[i]=0
            #not joining breakpoints
            if start>0.5 and torch.abs(chrom[i][start]-chrom[i][start-1])<0.5:
                valid[i]=0
            if end<chrom_width-0.5 and torch.abs(chrom[i][end-1]-chrom[i][end])<0.5:
                valid[i]=0
            #keep a valid step, otherwise roll back to the previous state
            if valid[i]>0 :
                next_state[i]=state[i].clone()
                step[i]=step[i]+1
            else:
                state[i]=next_state[i].clone()
        step_counter=step_counter+1
    return state



In [None]:
#main.py
# Training driver for the Q-function model (Q_learning).
# Each iteration performs two optimiser steps, both minimising the mean
# squared advantage returned by the model:
#   1. on the batch refreshed by Simulate_train_data (plus, during an initial
#      warm-up phase, a weighted binary cross-entropy term that supervises the
#      switch output `sigma` with the `wgd` labels, masked by `valid`);
#   2. on the batch built by Modify_data from the model's own outputs
#      ("training with the best action").
# Every 10 iterations the current loss/step statistics are printed and the
# model is checkpointed; every 100 iterations a line is appended to a log file.
import torch
import math
import torch.optim as optim
#import Policy
#import Data_train

# Checkpoint file (overwritten) and text log (appended).
MODEL_PATH = "/data/suzaku/ted/HOME/model"
LOG_PATH = "/data/suzaku/ted/HOME/log"
# Total number of training iterations.
MAX_ITERATIONS = int(3e8)
# The switch output is directly supervised only during the first iterations.
WGD_SUPERVISION_ITERATIONS = int(3e6)
# Weight of the switch supervision term in the first loss.
WGD_LOSS_WEIGHT = 0.1

#setting up counter
counter_global=0
#Model
Q_model=Q_learning()
#Load initial data
state,next_state,chrom,chrom_new,Chr,cnv,start_loci,end_loci,wgd,step,advantage,valid=Simulate_train_data()
#setting up optimizer
optimizer = optim.Adam(Q_model.parameters(), lr=1e-4,betas=(0.9, 0.99), eps=1e-06, weight_decay=1e-3)

#start training
while counter_global < MAX_ITERATIONS:
    counter_global += 1

    # ---- step 1: train on the (continued) simulated batch ----
    state,next_state,chrom,chrom_new,Chr,cnv,start_loci,end_loci,wgd,step,advantage,valid=Simulate_train_data(False,state,next_state,advantage,Chr,step,wgd,valid)
    optimizer.zero_grad()
    # Call the module itself (not .forward) so registered hooks also run.
    advantage,cnv_max,sigma,temp,t2,t3=Q_model(state,next_state,chrom,chrom_new,Chr,cnv,start_loci,end_loci,valid)
    loss = advantage.pow(2).mean()

    if counter_global < WGD_SUPERVISION_ITERATIONS:
        sigma = Q_model.switch(state)
        # Clamp away from exactly 0/1: otherwise a saturated sigma yields
        # log(0) = -inf, and 0 * -inf = NaN would poison the whole loss.
        # The unclamped sigma is still what gets passed to Modify_data below.
        sigma_eps = torch.finfo(sigma.dtype).eps
        sigma_safe = sigma.clamp(sigma_eps, 1 - sigma_eps)
        wgd_label = wgd.view(-1,1)
        valid_mask = valid.view(-1,1)
        wgd_bce = -wgd_label*torch.log(sigma_safe) - (1-wgd_label)*torch.log(1-sigma_safe)
        # Out-of-place add keeps the autograd graph simple.
        loss = loss + WGD_LOSS_WEIGHT*(wgd_bce*valid_mask).mean()
    loss.backward()
    optimizer.step()

    # ---- step 2: training with the best action ----
    # *_temp values are used both for this training step and for loading new data
    state_temp,next_state_temp,chrom,chrom_new,cnv,start_loci,end_loci,advantage_temp=Modify_data(state,chrom,Chr,valid,cnv_max,Q_model.End,sigma)
    optimizer.zero_grad()
    advantage,cnv_max,sigma,temp,temp2,temp3=Q_model(state_temp,next_state_temp,chrom,chrom_new,Chr,cnv,start_loci,end_loci,valid)
    loss = advantage.pow(2).mean()
    if counter_global % 10 == 9:
        # Report progress and overwrite the checkpoint.
        print(loss,step.mean(),step.max(),wgd.sum())
        torch.save(Q_model.state_dict(), MODEL_PATH)
    loss.backward()
    optimizer.step()

    if counter_global % 100 == 99:
        # Context manager guarantees the handle is closed even if write fails.
        with open(LOG_PATH, 'a') as file_object:
            file_object.write("loss:"+str(loss.item())+" step:"+str(step.mean().item())+", "+str(step.max().item())+", "+str(wgd.sum().item())+"\n")

tensor(85.0701, grad_fn=<MeanBackward0>) tensor(0.0333) tensor(1.) tensor(5.)
tensor(42.9819, grad_fn=<MeanBackward0>) tensor(0.2000) tensor(2.) tensor(5.)
tensor(28.7249, grad_fn=<MeanBackward0>) tensor(0.2667) tensor(2.) tensor(5.)
tensor(16.6450, grad_fn=<MeanBackward0>) tensor(0.2667) tensor(2.) tensor(5.)
tensor(13.9515, grad_fn=<MeanBackward0>) tensor(0.4000) tensor(2.) tensor(5.)
tensor(12.8602, grad_fn=<MeanBackward0>) tensor(0.3667) tensor(4.) tensor(5.)
tensor(13.0781, grad_fn=<MeanBackward0>) tensor(0.1667) tensor(1.) tensor(5.)
tensor(10.5207, grad_fn=<MeanBackward0>) tensor(0.2667) tensor(2.) tensor(5.)
tensor(11.5275, grad_fn=<MeanBackward0>) tensor(0.6667) tensor(3.) tensor(5.)
tensor(9.4326, grad_fn=<MeanBackward0>) tensor(0.6667) tensor(3.) tensor(5.)
tensor(9.9512, grad_fn=<MeanBackward0>) tensor(0.3333) tensor(3.) tensor(5.)
tensor(11.8939, grad_fn=<MeanBackward0>) tensor(0.3333) tensor(3.) tensor(5.)
tensor(18.8991, grad_fn=<MeanBackward0>) tensor(0.7000) tensor(4.)

tensor(93.7312, grad_fn=<MeanBackward0>) tensor(22.4000) tensor(89.) tensor(18.)
tensor(70.3995, grad_fn=<MeanBackward0>) tensor(18.1667) tensor(88.) tensor(15.)
tensor(60.5984, grad_fn=<MeanBackward0>) tensor(15.7333) tensor(34.) tensor(14.)
tensor(71.8731, grad_fn=<MeanBackward0>) tensor(15.2000) tensor(38.) tensor(12.)
tensor(83.3199, grad_fn=<MeanBackward0>) tensor(15.3667) tensor(43.) tensor(11.)
tensor(64.4313, grad_fn=<MeanBackward0>) tensor(15.5000) tensor(47.) tensor(11.)
tensor(72.7047, grad_fn=<MeanBackward0>) tensor(14.) tensor(54.) tensor(11.)
tensor(79.6723, grad_fn=<MeanBackward0>) tensor(11.8333) tensor(24.) tensor(10.)
tensor(85.3776, grad_fn=<MeanBackward0>) tensor(12.4000) tensor(24.) tensor(14.)
tensor(87.1294, grad_fn=<MeanBackward0>) tensor(12.) tensor(24.) tensor(15.)
tensor(68.0527, grad_fn=<MeanBackward0>) tensor(12.5000) tensor(23.) tensor(14.)
tensor(73.8594, grad_fn=<MeanBackward0>) tensor(13.3333) tensor(28.) tensor(14.)
tensor(78.9545, grad_fn=<MeanBackwar

tensor(115.2190, grad_fn=<MeanBackward0>) tensor(16.9333) tensor(81.) tensor(15.)
tensor(101.2908, grad_fn=<MeanBackward0>) tensor(17.5667) tensor(84.) tensor(14.)
tensor(73.7151, grad_fn=<MeanBackward0>) tensor(13.9667) tensor(85.) tensor(10.)
tensor(76.3885, grad_fn=<MeanBackward0>) tensor(13.9667) tensor(89.) tensor(11.)
tensor(74.6543, grad_fn=<MeanBackward0>) tensor(12.1333) tensor(28.) tensor(9.)
tensor(90.4655, grad_fn=<MeanBackward0>) tensor(13.8667) tensor(28.) tensor(12.)
tensor(116.3869, grad_fn=<MeanBackward0>) tensor(14.9333) tensor(29.) tensor(16.)
tensor(110.3038, grad_fn=<MeanBackward0>) tensor(14.6000) tensor(29.) tensor(15.)
tensor(108.9276, grad_fn=<MeanBackward0>) tensor(15.0333) tensor(29.) tensor(16.)
tensor(110.7410, grad_fn=<MeanBackward0>) tensor(14.6667) tensor(29.) tensor(16.)
tensor(108.0664, grad_fn=<MeanBackward0>) tensor(15.8000) tensor(29.) tensor(16.)
tensor(101.6150, grad_fn=<MeanBackward0>) tensor(14.9333) tensor(31.) tensor(16.)
tensor(60.3988, grad_

tensor(53.3547, grad_fn=<MeanBackward0>) tensor(17.5667) tensor(45.) tensor(10.)
tensor(53.4967, grad_fn=<MeanBackward0>) tensor(19.7333) tensor(53.) tensor(11.)
tensor(83.8025, grad_fn=<MeanBackward0>) tensor(21.6667) tensor(62.) tensor(12.)
tensor(113.5762, grad_fn=<MeanBackward0>) tensor(23.1667) tensor(72.) tensor(16.)
tensor(128.1151, grad_fn=<MeanBackward0>) tensor(25.2667) tensor(80.) tensor(20.)
tensor(109.2079, grad_fn=<MeanBackward0>) tensor(22.1000) tensor(87.) tensor(17.)
tensor(97.8220, grad_fn=<MeanBackward0>) tensor(19.4333) tensor(89.) tensor(15.)
tensor(93.3913, grad_fn=<MeanBackward0>) tensor(18.2000) tensor(85.) tensor(16.)
tensor(83.8008, grad_fn=<MeanBackward0>) tensor(18.2333) tensor(88.) tensor(16.)
tensor(62.6299, grad_fn=<MeanBackward0>) tensor(13.2667) tensor(29.) tensor(13.)
tensor(80.1606, grad_fn=<MeanBackward0>) tensor(14.8667) tensor(31.) tensor(16.)
tensor(68.7112, grad_fn=<MeanBackward0>) tensor(14.2333) tensor(34.) tensor(14.)
tensor(68.5264, grad_fn=<

tensor(91.5906, grad_fn=<MeanBackward0>) tensor(16.2333) tensor(61.) tensor(13.)
tensor(90.9730, grad_fn=<MeanBackward0>) tensor(16.5667) tensor(63.) tensor(12.)
tensor(91.9747, grad_fn=<MeanBackward0>) tensor(16.3333) tensor(64.) tensor(11.)
tensor(99.8771, grad_fn=<MeanBackward0>) tensor(16.8000) tensor(69.) tensor(13.)
tensor(79.3444, grad_fn=<MeanBackward0>) tensor(13.8667) tensor(72.) tensor(12.)
tensor(85.9434, grad_fn=<MeanBackward0>) tensor(16.1333) tensor(75.) tensor(13.)
tensor(108.3257, grad_fn=<MeanBackward0>) tensor(18.6333) tensor(78.) tensor(15.)
tensor(92.9110, grad_fn=<MeanBackward0>) tensor(16.1667) tensor(81.) tensor(14.)
tensor(87.1156, grad_fn=<MeanBackward0>) tensor(14.5667) tensor(87.) tensor(14.)
tensor(69.2078, grad_fn=<MeanBackward0>) tensor(10.3000) tensor(27.) tensor(10.)
tensor(43.3176, grad_fn=<MeanBackward0>) tensor(11.1333) tensor(30.) tensor(8.)
tensor(35.3015, grad_fn=<MeanBackward0>) tensor(12.2667) tensor(33.) tensor(8.)
tensor(47.0907, grad_fn=<Mean

tensor(85.9187, grad_fn=<MeanBackward0>) tensor(14.9667) tensor(79.) tensor(14.)
tensor(71.4830, grad_fn=<MeanBackward0>) tensor(14.5333) tensor(33.) tensor(15.)
tensor(65.8327, grad_fn=<MeanBackward0>) tensor(15.1667) tensor(34.) tensor(14.)
tensor(65.7570, grad_fn=<MeanBackward0>) tensor(15.6333) tensor(40.) tensor(12.)
tensor(81.1152, grad_fn=<MeanBackward0>) tensor(17.4667) tensor(46.) tensor(15.)
tensor(53.9432, grad_fn=<MeanBackward0>) tensor(14.4667) tensor(48.) tensor(10.)
tensor(66.7341, grad_fn=<MeanBackward0>) tensor(16.7000) tensor(52.) tensor(11.)
tensor(80.4245, grad_fn=<MeanBackward0>) tensor(18.8000) tensor(58.) tensor(13.)
tensor(91.2864, grad_fn=<MeanBackward0>) tensor(17.7333) tensor(65.) tensor(14.)
tensor(89.0144, grad_fn=<MeanBackward0>) tensor(19.0333) tensor(70.) tensor(15.)
tensor(97.6007, grad_fn=<MeanBackward0>) tensor(20.8000) tensor(74.) tensor(19.)
tensor(90.5366, grad_fn=<MeanBackward0>) tensor(20.2667) tensor(76.) tensor(17.)
tensor(94.7617, grad_fn=<Mea

tensor(105.6539, grad_fn=<MeanBackward0>) tensor(19.8000) tensor(45.) tensor(16.)
tensor(83.3116, grad_fn=<MeanBackward0>) tensor(17.4667) tensor(53.) tensor(13.)
tensor(96.4566, grad_fn=<MeanBackward0>) tensor(18.4667) tensor(59.) tensor(16.)
tensor(106.5476, grad_fn=<MeanBackward0>) tensor(18.8667) tensor(68.) tensor(15.)
tensor(100.0389, grad_fn=<MeanBackward0>) tensor(16.4000) tensor(72.) tensor(16.)
tensor(101.8368, grad_fn=<MeanBackward0>) tensor(17.2667) tensor(78.) tensor(18.)
tensor(98.7952, grad_fn=<MeanBackward0>) tensor(16.9667) tensor(84.) tensor(16.)
tensor(60.3927, grad_fn=<MeanBackward0>) tensor(13.0667) tensor(26.) tensor(12.)
tensor(62.3844, grad_fn=<MeanBackward0>) tensor(14.6333) tensor(29.) tensor(11.)
tensor(66.8303, grad_fn=<MeanBackward0>) tensor(17.3333) tensor(36.) tensor(12.)
tensor(89.0440, grad_fn=<MeanBackward0>) tensor(17.1667) tensor(44.) tensor(16.)
tensor(112.4866, grad_fn=<MeanBackward0>) tensor(18.5333) tensor(53.) tensor(17.)
tensor(97.3959, grad_fn

tensor(102.2316, grad_fn=<MeanBackward0>) tensor(15.8333) tensor(43.) tensor(19.)
tensor(112.3100, grad_fn=<MeanBackward0>) tensor(17.7667) tensor(48.) tensor(22.)
tensor(98.6045, grad_fn=<MeanBackward0>) tensor(17.4000) tensor(52.) tensor(20.)
tensor(88.4370, grad_fn=<MeanBackward0>) tensor(16.2000) tensor(55.) tensor(17.)
tensor(69.2667, grad_fn=<MeanBackward0>) tensor(16.6667) tensor(49.) tensor(15.)
tensor(68.6949, grad_fn=<MeanBackward0>) tensor(15.2667) tensor(36.) tensor(13.)
tensor(91.2455, grad_fn=<MeanBackward0>) tensor(17.7000) tensor(37.) tensor(16.)
tensor(75.6769, grad_fn=<MeanBackward0>) tensor(15.6000) tensor(43.) tensor(13.)
tensor(88.7596, grad_fn=<MeanBackward0>) tensor(17.8000) tensor(48.) tensor(16.)
tensor(80.9257, grad_fn=<MeanBackward0>) tensor(16.1667) tensor(54.) tensor(14.)
tensor(95.9866, grad_fn=<MeanBackward0>) tensor(18.7000) tensor(60.) tensor(17.)
tensor(77.9507, grad_fn=<MeanBackward0>) tensor(16.6000) tensor(67.) tensor(16.)
tensor(88.6579, grad_fn=<M

tensor(81.7517, grad_fn=<MeanBackward0>) tensor(14.5667) tensor(30.) tensor(14.)
tensor(78.4557, grad_fn=<MeanBackward0>) tensor(14.1667) tensor(32.) tensor(15.)
tensor(80.7405, grad_fn=<MeanBackward0>) tensor(15.6000) tensor(32.) tensor(18.)
tensor(65.4743, grad_fn=<MeanBackward0>) tensor(15.8000) tensor(38.) tensor(16.)
tensor(77.4883, grad_fn=<MeanBackward0>) tensor(17.5333) tensor(38.) tensor(16.)
tensor(66.7361, grad_fn=<MeanBackward0>) tensor(19.4667) tensor(46.) tensor(14.)
tensor(113.0746, grad_fn=<MeanBackward0>) tensor(23.4333) tensor(54.) tensor(18.)
tensor(116.7119, grad_fn=<MeanBackward0>) tensor(24.6000) tensor(62.) tensor(20.)
tensor(107.3117, grad_fn=<MeanBackward0>) tensor(23.4333) tensor(72.) tensor(18.)
tensor(67.8359, grad_fn=<MeanBackward0>) tensor(19.5333) tensor(76.) tensor(13.)
tensor(70.6346, grad_fn=<MeanBackward0>) tensor(19.5667) tensor(82.) tensor(14.)
tensor(90.9488, grad_fn=<MeanBackward0>) tensor(18.7000) tensor(87.) tensor(15.)
tensor(88.9350, grad_fn=<

tensor(64.4051, grad_fn=<MeanBackward0>) tensor(17.) tensor(89.) tensor(15.)
tensor(54.9413, grad_fn=<MeanBackward0>) tensor(13.3000) tensor(37.) tensor(13.)
tensor(65.8230, grad_fn=<MeanBackward0>) tensor(14.9333) tensor(43.) tensor(13.)
tensor(70.7504, grad_fn=<MeanBackward0>) tensor(17.5667) tensor(47.) tensor(13.)
tensor(96.1380, grad_fn=<MeanBackward0>) tensor(21.) tensor(56.) tensor(17.)
tensor(93.2712, grad_fn=<MeanBackward0>) tensor(21.5333) tensor(63.) tensor(17.)
tensor(112.5062, grad_fn=<MeanBackward0>) tensor(22.1667) tensor(70.) tensor(17.)
tensor(103.5437, grad_fn=<MeanBackward0>) tensor(24.1667) tensor(79.) tensor(16.)
tensor(122.6058, grad_fn=<MeanBackward0>) tensor(20.9667) tensor(67.) tensor(19.)
tensor(94.6251, grad_fn=<MeanBackward0>) tensor(19.5333) tensor(74.) tensor(16.)
tensor(75.4237, grad_fn=<MeanBackward0>) tensor(18.8333) tensor(79.) tensor(14.)
tensor(92.9615, grad_fn=<MeanBackward0>) tensor(21.1000) tensor(85.) tensor(16.)
tensor(116.2708, grad_fn=<MeanBac

tensor(70.0021, grad_fn=<MeanBackward0>) tensor(19.5667) tensor(58.) tensor(12.)
tensor(71.9946, grad_fn=<MeanBackward0>) tensor(19.9333) tensor(65.) tensor(12.)
tensor(58.9305, grad_fn=<MeanBackward0>) tensor(21.1000) tensor(73.) tensor(13.)
tensor(60.2217, grad_fn=<MeanBackward0>) tensor(21.1000) tensor(79.) tensor(12.)
tensor(72.0261, grad_fn=<MeanBackward0>) tensor(24.0333) tensor(83.) tensor(16.)
tensor(77.2508, grad_fn=<MeanBackward0>) tensor(22.2667) tensor(87.) tensor(17.)
tensor(74.3948, grad_fn=<MeanBackward0>) tensor(20.8667) tensor(90.) tensor(17.)
tensor(80.3233, grad_fn=<MeanBackward0>) tensor(19.5333) tensor(89.) tensor(19.)
tensor(66.9612, grad_fn=<MeanBackward0>) tensor(20.3000) tensor(90.) tensor(18.)
tensor(67.7904, grad_fn=<MeanBackward0>) tensor(17.1000) tensor(44.) tensor(14.)
tensor(58.9168, grad_fn=<MeanBackward0>) tensor(15.8333) tensor(50.) tensor(13.)
tensor(100.9477, grad_fn=<MeanBackward0>) tensor(15.9000) tensor(53.) tensor(14.)
tensor(86.9262, grad_fn=<Me

tensor(85.6857, grad_fn=<MeanBackward0>) tensor(13.8000) tensor(41.) tensor(14.)
tensor(63.9752, grad_fn=<MeanBackward0>) tensor(14.2333) tensor(41.) tensor(14.)
tensor(62.3474, grad_fn=<MeanBackward0>) tensor(14.) tensor(44.) tensor(15.)
tensor(68.1214, grad_fn=<MeanBackward0>) tensor(17.3667) tensor(46.) tensor(17.)
tensor(61.8680, grad_fn=<MeanBackward0>) tensor(20.1333) tensor(52.) tensor(16.)
tensor(61.1626, grad_fn=<MeanBackward0>) tensor(18.2000) tensor(59.) tensor(12.)
tensor(70.8711, grad_fn=<MeanBackward0>) tensor(18.4000) tensor(50.) tensor(11.)
tensor(89.8376, grad_fn=<MeanBackward0>) tensor(19.8000) tensor(58.) tensor(14.)
tensor(89.1244, grad_fn=<MeanBackward0>) tensor(17.9000) tensor(64.) tensor(14.)
tensor(92.8488, grad_fn=<MeanBackward0>) tensor(18.8667) tensor(73.) tensor(14.)
tensor(93.3546, grad_fn=<MeanBackward0>) tensor(17.9000) tensor(66.) tensor(16.)
tensor(94.7736, grad_fn=<MeanBackward0>) tensor(19.4667) tensor(70.) tensor(17.)
tensor(120.9863, grad_fn=<MeanBa

tensor(93.9241, grad_fn=<MeanBackward0>) tensor(15.8333) tensor(51.) tensor(14.)
tensor(82.9312, grad_fn=<MeanBackward0>) tensor(13.3667) tensor(57.) tensor(11.)
tensor(104.3263, grad_fn=<MeanBackward0>) tensor(14.7333) tensor(66.) tensor(13.)
tensor(94.4094, grad_fn=<MeanBackward0>) tensor(15.7000) tensor(75.) tensor(13.)
tensor(100.4583, grad_fn=<MeanBackward0>) tensor(17.5000) tensor(82.) tensor(16.)
tensor(84.5294, grad_fn=<MeanBackward0>) tensor(19.7667) tensor(87.) tensor(17.)
tensor(98.0684, grad_fn=<MeanBackward0>) tensor(20.0667) tensor(89.) tensor(20.)
tensor(64.3866, grad_fn=<MeanBackward0>) tensor(17.0667) tensor(36.) tensor(17.)
tensor(69.7096, grad_fn=<MeanBackward0>) tensor(20.3000) tensor(41.) tensor(17.)
tensor(61.6715, grad_fn=<MeanBackward0>) tensor(22.0333) tensor(47.) tensor(16.)
tensor(102.4411, grad_fn=<MeanBackward0>) tensor(22.1667) tensor(54.) tensor(16.)
tensor(100.1520, grad_fn=<MeanBackward0>) tensor(24.3667) tensor(62.) tensor(17.)
tensor(110.5488, grad_fn

tensor(73.4400, grad_fn=<MeanBackward0>) tensor(17.1000) tensor(50.) tensor(12.)
tensor(72.1355, grad_fn=<MeanBackward0>) tensor(19.5000) tensor(55.) tensor(13.)
tensor(92.7957, grad_fn=<MeanBackward0>) tensor(22.0667) tensor(60.) tensor(16.)
tensor(102.4554, grad_fn=<MeanBackward0>) tensor(21.8000) tensor(65.) tensor(18.)
tensor(84.6510, grad_fn=<MeanBackward0>) tensor(18.4000) tensor(69.) tensor(15.)
tensor(84.8043, grad_fn=<MeanBackward0>) tensor(19.6333) tensor(75.) tensor(18.)
tensor(66.9770, grad_fn=<MeanBackward0>) tensor(21.2000) tensor(81.) tensor(18.)
tensor(59.4451, grad_fn=<MeanBackward0>) tensor(24.2000) tensor(86.) tensor(19.)
tensor(68.3949, grad_fn=<MeanBackward0>) tensor(23.4667) tensor(83.) tensor(18.)
tensor(65.1017, grad_fn=<MeanBackward0>) tensor(25.3000) tensor(89.) tensor(17.)
tensor(92.1525, grad_fn=<MeanBackward0>) tensor(23.1333) tensor(71.) tensor(16.)
tensor(72.8453, grad_fn=<MeanBackward0>) tensor(22.1667) tensor(79.) tensor(13.)
tensor(70.9951, grad_fn=<Me

tensor(72.6948, grad_fn=<MeanBackward0>) tensor(12.9667) tensor(28.) tensor(11.)
tensor(94.1988, grad_fn=<MeanBackward0>) tensor(15.5000) tensor(38.) tensor(14.)
tensor(129.9717, grad_fn=<MeanBackward0>) tensor(16.7333) tensor(41.) tensor(22.)
tensor(120.5198, grad_fn=<MeanBackward0>) tensor(14.4000) tensor(28.) tensor(18.)
tensor(79.1492, grad_fn=<MeanBackward0>) tensor(16.9333) tensor(33.) tensor(18.)
tensor(52.6696, grad_fn=<MeanBackward0>) tensor(15.9000) tensor(41.) tensor(14.)
tensor(71.0208, grad_fn=<MeanBackward0>) tensor(18.7667) tensor(42.) tensor(15.)
tensor(63.0463, grad_fn=<MeanBackward0>) tensor(20.7667) tensor(51.) tensor(14.)
tensor(136.0557, grad_fn=<MeanBackward0>) tensor(22.3333) tensor(56.) tensor(21.)
tensor(111.3180, grad_fn=<MeanBackward0>) tensor(22.9333) tensor(57.) tensor(20.)
tensor(108.1513, grad_fn=<MeanBackward0>) tensor(20.7667) tensor(63.) tensor(17.)
tensor(123.3855, grad_fn=<MeanBackward0>) tensor(20.4333) tensor(68.) tensor(21.)
tensor(90.6544, grad_f

tensor(75.4637, grad_fn=<MeanBackward0>) tensor(25.2000) tensor(90.) tensor(16.)
tensor(90.1167, grad_fn=<MeanBackward0>) tensor(21.9333) tensor(84.) tensor(17.)
tensor(77.3879, grad_fn=<MeanBackward0>) tensor(19.8333) tensor(87.) tensor(17.)
tensor(71.7090, grad_fn=<MeanBackward0>) tensor(18.2667) tensor(48.) tensor(16.)
tensor(61.2386, grad_fn=<MeanBackward0>) tensor(17.4667) tensor(47.) tensor(15.)
tensor(64.2548, grad_fn=<MeanBackward0>) tensor(21.) tensor(52.) tensor(17.)
tensor(75.1416, grad_fn=<MeanBackward0>) tensor(23.1000) tensor(57.) tensor(18.)
tensor(92.0334, grad_fn=<MeanBackward0>) tensor(24.7667) tensor(63.) tensor(18.)
tensor(90.3983, grad_fn=<MeanBackward0>) tensor(25.2333) tensor(69.) tensor(18.)
tensor(104.7944, grad_fn=<MeanBackward0>) tensor(29.3333) tensor(78.) tensor(19.)
tensor(138.2672, grad_fn=<MeanBackward0>) tensor(31.9333) tensor(85.) tensor(24.)
tensor(111.0419, grad_fn=<MeanBackward0>) tensor(29.0667) tensor(89.) tensor(20.)
tensor(76.0610, grad_fn=<Mean

tensor(69.8298, grad_fn=<MeanBackward0>) tensor(18.1667) tensor(80.) tensor(14.)
tensor(93.4668, grad_fn=<MeanBackward0>) tensor(19.8667) tensor(87.) tensor(18.)
tensor(71.7561, grad_fn=<MeanBackward0>) tensor(18.6000) tensor(76.) tensor(18.)
tensor(98.0024, grad_fn=<MeanBackward0>) tensor(21.5667) tensor(78.) tensor(21.)
tensor(80.4850, grad_fn=<MeanBackward0>) tensor(22.4000) tensor(81.) tensor(20.)
tensor(82.2618, grad_fn=<MeanBackward0>) tensor(24.5000) tensor(89.) tensor(19.)
tensor(108.9164, grad_fn=<MeanBackward0>) tensor(22.) tensor(50.) tensor(20.)
tensor(92.5345, grad_fn=<MeanBackward0>) tensor(21.3000) tensor(55.) tensor(18.)
tensor(87.5163, grad_fn=<MeanBackward0>) tensor(21.2000) tensor(63.) tensor(16.)
tensor(88.9972, grad_fn=<MeanBackward0>) tensor(15.0333) tensor(65.) tensor(14.)
tensor(104.0866, grad_fn=<MeanBackward0>) tensor(15.4000) tensor(70.) tensor(16.)
tensor(81.2596, grad_fn=<MeanBackward0>) tensor(14.8000) tensor(78.) tensor(14.)
tensor(79.1916, grad_fn=<MeanB

tensor(63.3679, grad_fn=<MeanBackward0>) tensor(20.5667) tensor(83.) tensor(15.)
tensor(54.2492, grad_fn=<MeanBackward0>) tensor(17.9000) tensor(85.) tensor(13.)
tensor(62.9131, grad_fn=<MeanBackward0>) tensor(17.8333) tensor(50.) tensor(14.)
tensor(77.8852, grad_fn=<MeanBackward0>) tensor(21.2000) tensor(54.) tensor(19.)
tensor(108.6704, grad_fn=<MeanBackward0>) tensor(19.0333) tensor(54.) tensor(20.)
tensor(100.1409, grad_fn=<MeanBackward0>) tensor(19.7333) tensor(59.) tensor(19.)
tensor(120.0890, grad_fn=<MeanBackward0>) tensor(21.5667) tensor(61.) tensor(22.)
tensor(98.4357, grad_fn=<MeanBackward0>) tensor(18.1333) tensor(52.) tensor(20.)
tensor(89.6612, grad_fn=<MeanBackward0>) tensor(19.0333) tensor(58.) tensor(20.)
tensor(88.2739, grad_fn=<MeanBackward0>) tensor(18.0667) tensor(62.) tensor(19.)
tensor(81.2176, grad_fn=<MeanBackward0>) tensor(17.6000) tensor(62.) tensor(18.)
tensor(67.3760, grad_fn=<MeanBackward0>) tensor(18.8333) tensor(66.) tensor(14.)
tensor(81.8074, grad_fn=<

tensor(101.5682, grad_fn=<MeanBackward0>) tensor(20.5000) tensor(77.) tensor(20.)
tensor(91.8177, grad_fn=<MeanBackward0>) tensor(18.7000) tensor(83.) tensor(18.)
tensor(77.0235, grad_fn=<MeanBackward0>) tensor(19.3000) tensor(87.) tensor(16.)
tensor(72.1867, grad_fn=<MeanBackward0>) tensor(17.4000) tensor(88.) tensor(16.)
tensor(50.9981, grad_fn=<MeanBackward0>) tensor(16.0667) tensor(35.) tensor(13.)
tensor(58.1335, grad_fn=<MeanBackward0>) tensor(18.8333) tensor(42.) tensor(15.)
tensor(90.6382, grad_fn=<MeanBackward0>) tensor(23.9000) tensor(51.) tensor(21.)
tensor(111.8281, grad_fn=<MeanBackward0>) tensor(27.0333) tensor(59.) tensor(22.)
tensor(83.0812, grad_fn=<MeanBackward0>) tensor(19.7333) tensor(56.) tensor(16.)
tensor(83.2646, grad_fn=<MeanBackward0>) tensor(23.1667) tensor(63.) tensor(15.)
tensor(78.9495, grad_fn=<MeanBackward0>) tensor(25.4333) tensor(67.) tensor(15.)
tensor(90.1639, grad_fn=<MeanBackward0>) tensor(25.4667) tensor(74.) tensor(18.)
tensor(79.6557, grad_fn=<M

tensor(86.1212, grad_fn=<MeanBackward0>) tensor(17.5000) tensor(51.) tensor(15.)
tensor(59.9084, grad_fn=<MeanBackward0>) tensor(18.1000) tensor(53.) tensor(15.)
tensor(61.1286, grad_fn=<MeanBackward0>) tensor(21.6000) tensor(55.) tensor(17.)
tensor(71.6239, grad_fn=<MeanBackward0>) tensor(25.3000) tensor(62.) tensor(20.)
tensor(97.6405, grad_fn=<MeanBackward0>) tensor(26.6000) tensor(70.) tensor(21.)
tensor(73.3611, grad_fn=<MeanBackward0>) tensor(26.0333) tensor(77.) tensor(17.)
tensor(83.3652, grad_fn=<MeanBackward0>) tensor(24.4333) tensor(85.) tensor(15.)
tensor(79.2435, grad_fn=<MeanBackward0>) tensor(22.8667) tensor(90.) tensor(14.)
tensor(82.9294, grad_fn=<MeanBackward0>) tensor(20.7333) tensor(76.) tensor(15.)
tensor(91.8985, grad_fn=<MeanBackward0>) tensor(23.2333) tensor(85.) tensor(16.)
tensor(70.7424, grad_fn=<MeanBackward0>) tensor(21.1333) tensor(89.) tensor(14.)
tensor(72.7155, grad_fn=<MeanBackward0>) tensor(18.8667) tensor(88.) tensor(12.)
tensor(54.4700, grad_fn=<Mea

tensor(73.8830, grad_fn=<MeanBackward0>) tensor(20.9333) tensor(44.) tensor(18.)
tensor(63.8555, grad_fn=<MeanBackward0>) tensor(19.0667) tensor(52.) tensor(13.)
tensor(94.2174, grad_fn=<MeanBackward0>) tensor(20.8667) tensor(62.) tensor(16.)
tensor(81.5666, grad_fn=<MeanBackward0>) tensor(15.9667) tensor(58.) tensor(13.)
tensor(80.7550, grad_fn=<MeanBackward0>) tensor(16.4000) tensor(62.) tensor(14.)
tensor(85.1237, grad_fn=<MeanBackward0>) tensor(16.9667) tensor(68.) tensor(14.)
tensor(98.7365, grad_fn=<MeanBackward0>) tensor(18.4000) tensor(74.) tensor(17.)
tensor(101.0642, grad_fn=<MeanBackward0>) tensor(21.) tensor(79.) tensor(21.)
tensor(76.7617, grad_fn=<MeanBackward0>) tensor(17.8333) tensor(86.) tensor(17.)
tensor(81.4408, grad_fn=<MeanBackward0>) tensor(18.3000) tensor(89.) tensor(18.)
tensor(48.0487, grad_fn=<MeanBackward0>) tensor(12.9667) tensor(36.) tensor(14.)
tensor(47.7348, grad_fn=<MeanBackward0>) tensor(14.7667) tensor(33.) tensor(13.)
tensor(62.3964, grad_fn=<MeanBa

tensor(72.8051, grad_fn=<MeanBackward0>) tensor(18.8000) tensor(40.) tensor(16.)
tensor(67.5621, grad_fn=<MeanBackward0>) tensor(18.4000) tensor(45.) tensor(15.)
tensor(98.6327, grad_fn=<MeanBackward0>) tensor(20.0667) tensor(51.) tensor(19.)
tensor(127.1807, grad_fn=<MeanBackward0>) tensor(23.) tensor(59.) tensor(21.)
tensor(93.7188, grad_fn=<MeanBackward0>) tensor(20.9333) tensor(66.) tensor(17.)
tensor(98.2834, grad_fn=<MeanBackward0>) tensor(23.1000) tensor(76.) tensor(17.)
tensor(100.2476, grad_fn=<MeanBackward0>) tensor(24.0333) tensor(80.) tensor(20.)
tensor(98.2463, grad_fn=<MeanBackward0>) tensor(23.9333) tensor(86.) tensor(20.)
tensor(91.3646, grad_fn=<MeanBackward0>) tensor(22.) tensor(84.) tensor(18.)
tensor(78.3662, grad_fn=<MeanBackward0>) tensor(18.6333) tensor(82.) tensor(14.)
tensor(61.1392, grad_fn=<MeanBackward0>) tensor(16.1000) tensor(86.) tensor(12.)
tensor(64.3029, grad_fn=<MeanBackward0>) tensor(13.2667) tensor(37.) tensor(11.)
tensor(41.3080, grad_fn=<MeanBackw

tensor(108.6109, grad_fn=<MeanBackward0>) tensor(19.9000) tensor(78.) tensor(20.)
tensor(84.5190, grad_fn=<MeanBackward0>) tensor(21.0667) tensor(85.) tensor(17.)
tensor(70.0964, grad_fn=<MeanBackward0>) tensor(20.2667) tensor(90.) tensor(15.)
tensor(66.1654, grad_fn=<MeanBackward0>) tensor(18.0667) tensor(83.) tensor(12.)
tensor(62.1589, grad_fn=<MeanBackward0>) tensor(21.0333) tensor(86.) tensor(16.)
tensor(72.2324, grad_fn=<MeanBackward0>) tensor(22.1667) tensor(89.) tensor(17.)
tensor(96.1718, grad_fn=<MeanBackward0>) tensor(20.5000) tensor(86.) tensor(17.)
tensor(86.4571, grad_fn=<MeanBackward0>) tensor(21.1000) tensor(89.) tensor(16.)
tensor(81.1602, grad_fn=<MeanBackward0>) tensor(16.3333) tensor(47.) tensor(12.)
tensor(70.5495, grad_fn=<MeanBackward0>) tensor(16.7000) tensor(55.) tensor(12.)
tensor(102.4985, grad_fn=<MeanBackward0>) tensor(18.9000) tensor(60.) tensor(17.)
tensor(92.5162, grad_fn=<MeanBackward0>) tensor(19.5333) tensor(63.) tensor(15.)
tensor(81.9421, grad_fn=<M

tensor(76.1167, grad_fn=<MeanBackward0>) tensor(16.5333) tensor(81.) tensor(14.)
tensor(63.8249, grad_fn=<MeanBackward0>) tensor(16.2000) tensor(87.) tensor(13.)
tensor(62.0244, grad_fn=<MeanBackward0>) tensor(14.1000) tensor(29.) tensor(12.)
tensor(76.9648, grad_fn=<MeanBackward0>) tensor(14.2000) tensor(32.) tensor(14.)
tensor(80.8806, grad_fn=<MeanBackward0>) tensor(15.4333) tensor(32.) tensor(14.)
tensor(83.8508, grad_fn=<MeanBackward0>) tensor(16.0667) tensor(33.) tensor(17.)
tensor(85.9894, grad_fn=<MeanBackward0>) tensor(18.2667) tensor(36.) tensor(17.)
tensor(100.8548, grad_fn=<MeanBackward0>) tensor(19.6000) tensor(36.) tensor(21.)
tensor(101.0510, grad_fn=<MeanBackward0>) tensor(20.0333) tensor(36.) tensor(20.)
tensor(105.7844, grad_fn=<MeanBackward0>) tensor(18.5000) tensor(41.) tensor(18.)
tensor(135.7457, grad_fn=<MeanBackward0>) tensor(20.2000) tensor(45.) tensor(20.)
tensor(108.3906, grad_fn=<MeanBackward0>) tensor(18.9667) tensor(49.) tensor(19.)
tensor(118.5451, grad_f

tensor(109.6459, grad_fn=<MeanBackward0>) tensor(20.3667) tensor(49.) tensor(18.)
tensor(112.9275, grad_fn=<MeanBackward0>) tensor(18.4667) tensor(49.) tensor(19.)
tensor(92.6970, grad_fn=<MeanBackward0>) tensor(17.0667) tensor(53.) tensor(15.)
tensor(79.1853, grad_fn=<MeanBackward0>) tensor(17.) tensor(55.) tensor(13.)
tensor(98.8153, grad_fn=<MeanBackward0>) tensor(17.7333) tensor(60.) tensor(15.)
tensor(111.0697, grad_fn=<MeanBackward0>) tensor(16.5667) tensor(33.) tensor(17.)
tensor(99.5757, grad_fn=<MeanBackward0>) tensor(16.8667) tensor(33.) tensor(16.)
tensor(101.4604, grad_fn=<MeanBackward0>) tensor(16.3667) tensor(33.) tensor(17.)
tensor(105.0397, grad_fn=<MeanBackward0>) tensor(16.3000) tensor(33.) tensor(18.)
tensor(66.8034, grad_fn=<MeanBackward0>) tensor(14.7333) tensor(32.) tensor(14.)
tensor(59.6491, grad_fn=<MeanBackward0>) tensor(18.2333) tensor(37.) tensor(16.)
tensor(70.2269, grad_fn=<MeanBackward0>) tensor(21.3000) tensor(45.) tensor(19.)
tensor(80.4623, grad_fn=<Me

tensor(57.4222, grad_fn=<MeanBackward0>) tensor(16.8333) tensor(88.) tensor(16.)
tensor(39.5191, grad_fn=<MeanBackward0>) tensor(13.1000) tensor(30.) tensor(12.)
tensor(44.1081, grad_fn=<MeanBackward0>) tensor(15.4667) tensor(39.) tensor(12.)
tensor(74.8510, grad_fn=<MeanBackward0>) tensor(18.8333) tensor(44.) tensor(14.)
tensor(130.5756, grad_fn=<MeanBackward0>) tensor(20.5000) tensor(51.) tensor(20.)
tensor(114.6228, grad_fn=<MeanBackward0>) tensor(16.1667) tensor(60.) tensor(18.)
tensor(130.3698, grad_fn=<MeanBackward0>) tensor(16.1667) tensor(66.) tensor(18.)
tensor(108.2883, grad_fn=<MeanBackward0>) tensor(17.7000) tensor(73.) tensor(17.)
tensor(97.6682, grad_fn=<MeanBackward0>) tensor(19.9000) tensor(81.) tensor(17.)
tensor(96.4437, grad_fn=<MeanBackward0>) tensor(20.3667) tensor(89.) tensor(20.)
tensor(92.7127, grad_fn=<MeanBackward0>) tensor(18.0333) tensor(86.) tensor(19.)
tensor(79.7281, grad_fn=<MeanBackward0>) tensor(17.7000) tensor(35.) tensor(21.)
tensor(77.9615, grad_fn=

tensor(73.7203, grad_fn=<MeanBackward0>) tensor(18.4000) tensor(68.) tensor(14.)
tensor(70.0992, grad_fn=<MeanBackward0>) tensor(18.9333) tensor(75.) tensor(14.)
tensor(79.3479, grad_fn=<MeanBackward0>) tensor(18.6000) tensor(80.) tensor(16.)
tensor(108.1691, grad_fn=<MeanBackward0>) tensor(18.9000) tensor(88.) tensor(20.)
tensor(83.7982, grad_fn=<MeanBackward0>) tensor(20.6667) tensor(90.) tensor(22.)
tensor(93.4758, grad_fn=<MeanBackward0>) tensor(17.9333) tensor(37.) tensor(19.)
tensor(70.6327, grad_fn=<MeanBackward0>) tensor(17.3333) tensor(39.) tensor(15.)
tensor(85.2264, grad_fn=<MeanBackward0>) tensor(18.3667) tensor(46.) tensor(13.)
tensor(63.4703, grad_fn=<MeanBackward0>) tensor(20.3333) tensor(54.) tensor(11.)
tensor(84.6843, grad_fn=<MeanBackward0>) tensor(22.6333) tensor(62.) tensor(15.)
tensor(91.5464, grad_fn=<MeanBackward0>) tensor(22.) tensor(69.) tensor(17.)
tensor(104.8118, grad_fn=<MeanBackward0>) tensor(20.) tensor(71.) tensor(15.)
tensor(83.0174, grad_fn=<MeanBackw

tensor(72.0481, grad_fn=<MeanBackward0>) tensor(19.7667) tensor(39.) tensor(18.)
tensor(61.1676, grad_fn=<MeanBackward0>) tensor(22.6667) tensor(46.) tensor(16.)
tensor(85.0331, grad_fn=<MeanBackward0>) tensor(22.6333) tensor(54.) tensor(14.)
tensor(88.5736, grad_fn=<MeanBackward0>) tensor(22.1000) tensor(62.) tensor(14.)
tensor(93.0240, grad_fn=<MeanBackward0>) tensor(19.4000) tensor(70.) tensor(13.)
tensor(129.4355, grad_fn=<MeanBackward0>) tensor(21.9000) tensor(77.) tensor(18.)
tensor(126.5996, grad_fn=<MeanBackward0>) tensor(20.9667) tensor(86.) tensor(19.)
tensor(116.8006, grad_fn=<MeanBackward0>) tensor(17.6667) tensor(83.) tensor(18.)
tensor(115.9774, grad_fn=<MeanBackward0>) tensor(18.7667) tensor(91.) tensor(20.)
tensor(98.3714, grad_fn=<MeanBackward0>) tensor(15.2000) tensor(30.) tensor(17.)
tensor(63.9321, grad_fn=<MeanBackward0>) tensor(15.3667) tensor(30.) tensor(14.)
tensor(59.3410, grad_fn=<MeanBackward0>) tensor(15.7333) tensor(36.) tensor(12.)
tensor(76.6918, grad_fn=

tensor(58.8384, grad_fn=<MeanBackward0>) tensor(13.6333) tensor(25.) tensor(15.)
tensor(73.5243, grad_fn=<MeanBackward0>) tensor(14.6333) tensor(30.) tensor(18.)
tensor(69.6051, grad_fn=<MeanBackward0>) tensor(16.3667) tensor(38.) tensor(16.)
tensor(103.5752, grad_fn=<MeanBackward0>) tensor(18.2000) tensor(40.) tensor(17.)
tensor(103.8054, grad_fn=<MeanBackward0>) tensor(18.) tensor(46.) tensor(16.)
tensor(86.9976, grad_fn=<MeanBackward0>) tensor(18.4333) tensor(54.) tensor(14.)
tensor(90.8388, grad_fn=<MeanBackward0>) tensor(18.1333) tensor(60.) tensor(14.)
tensor(95.0464, grad_fn=<MeanBackward0>) tensor(20.8333) tensor(66.) tensor(15.)
tensor(74.6404, grad_fn=<MeanBackward0>) tensor(20.2667) tensor(70.) tensor(14.)
tensor(87.5412, grad_fn=<MeanBackward0>) tensor(18.9333) tensor(73.) tensor(17.)
tensor(66.5449, grad_fn=<MeanBackward0>) tensor(16.5333) tensor(76.) tensor(14.)
tensor(57.3789, grad_fn=<MeanBackward0>) tensor(17.4333) tensor(82.) tensor(12.)
tensor(85.3095, grad_fn=<MeanB

tensor(65.1452, grad_fn=<MeanBackward0>) tensor(15.5333) tensor(34.) tensor(16.)
tensor(57.1333, grad_fn=<MeanBackward0>) tensor(18.1667) tensor(41.) tensor(14.)
tensor(101.6532, grad_fn=<MeanBackward0>) tensor(20.1667) tensor(49.) tensor(16.)
tensor(99.6752, grad_fn=<MeanBackward0>) tensor(21.0667) tensor(55.) tensor(18.)
tensor(85.3715, grad_fn=<MeanBackward0>) tensor(17.0333) tensor(54.) tensor(14.)
tensor(124.9527, grad_fn=<MeanBackward0>) tensor(18.1667) tensor(62.) tensor(14.)
tensor(102.7223, grad_fn=<MeanBackward0>) tensor(20.4333) tensor(69.) tensor(18.)
tensor(99.7650, grad_fn=<MeanBackward0>) tensor(21.4000) tensor(76.) tensor(19.)
tensor(107.5688, grad_fn=<MeanBackward0>) tensor(22.8000) tensor(82.) tensor(22.)
tensor(93.1501, grad_fn=<MeanBackward0>) tensor(20.4333) tensor(87.) tensor(17.)
tensor(55.7651, grad_fn=<MeanBackward0>) tensor(15.3667) tensor(88.) tensor(13.)
tensor(52.6013, grad_fn=<MeanBackward0>) tensor(15.1333) tensor(40.) tensor(12.)
tensor(40.0278, grad_fn=

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



tensor(142.9411, grad_fn=<MeanBackward0>) tensor(24.3667) tensor(55.) tensor(22.)
tensor(138.7303, grad_fn=<MeanBackward0>) tensor(25.4667) tensor(62.) tensor(23.)
tensor(120.6536, grad_fn=<MeanBackward0>) tensor(25.2333) tensor(68.) tensor(22.)
tensor(104.2625, grad_fn=<MeanBackward0>) tensor(22.8333) tensor(74.) tensor(21.)
tensor(83.4097, grad_fn=<MeanBackward0>) tensor(20.2667) tensor(77.) tensor(18.)
tensor(64.3424, grad_fn=<MeanBackward0>) tensor(15.7667) tensor(77.) tensor(13.)
tensor(48.7908, grad_fn=<MeanBackward0>) tensor(16.0333) tensor(43.) tensor(12.)
tensor(62.3202, grad_fn=<MeanBackward0>) tensor(16.) tensor(35.) tensor(13.)
tensor(67.1568, grad_fn=<MeanBackward0>) tensor(15.3000) tensor(36.) tensor(14.)
tensor(66.8211, grad_fn=<MeanBackward0>) tensor(14.) tensor(39.) tensor(13.)
tensor(54.4914, grad_fn=<MeanBackward0>) tensor(14.8000) tensor(43.) tensor(11.)
tensor(91.5737, grad_fn=<MeanBackward0>) tensor(16.5333) tensor(48.) tensor(18.)
tensor(106.1763, grad_fn=<MeanBa

tensor(80.5232, grad_fn=<MeanBackward0>) tensor(17.3333) tensor(90.) tensor(15.)
tensor(90.1151, grad_fn=<MeanBackward0>) tensor(15.8667) tensor(85.) tensor(16.)
tensor(87.8127, grad_fn=<MeanBackward0>) tensor(14.2333) tensor(26.) tensor(15.)
tensor(67.7706, grad_fn=<MeanBackward0>) tensor(13.8667) tensor(28.) tensor(14.)
tensor(68.2166, grad_fn=<MeanBackward0>) tensor(14.5667) tensor(28.) tensor(14.)
tensor(52.9267, grad_fn=<MeanBackward0>) tensor(14.8667) tensor(34.) tensor(13.)
tensor(79.7781, grad_fn=<MeanBackward0>) tensor(17.9667) tensor(39.) tensor(17.)
tensor(101.3489, grad_fn=<MeanBackward0>) tensor(19.3000) tensor(47.) tensor(19.)
tensor(125.6785, grad_fn=<MeanBackward0>) tensor(15.5667) tensor(55.) tensor(15.)
tensor(106.0609, grad_fn=<MeanBackward0>) tensor(17.7000) tensor(64.) tensor(15.)
tensor(92.9980, grad_fn=<MeanBackward0>) tensor(19.3333) tensor(70.) tensor(16.)
tensor(119.2865, grad_fn=<MeanBackward0>) tensor(20.7000) tensor(75.) tensor(21.)
tensor(85.5602, grad_fn=

tensor(75.8488, grad_fn=<MeanBackward0>) tensor(14.3333) tensor(58.) tensor(11.)
tensor(55.0115, grad_fn=<MeanBackward0>) tensor(15.4667) tensor(62.) tensor(11.)
tensor(78.8136, grad_fn=<MeanBackward0>) tensor(17.9000) tensor(70.) tensor(15.)
tensor(93.7016, grad_fn=<MeanBackward0>) tensor(19.2333) tensor(74.) tensor(20.)
tensor(79.8235, grad_fn=<MeanBackward0>) tensor(20.8000) tensor(79.) tensor(21.)
tensor(96.3598, grad_fn=<MeanBackward0>) tensor(22.4000) tensor(85.) tensor(23.)
tensor(87.3448, grad_fn=<MeanBackward0>) tensor(21.1333) tensor(88.) tensor(20.)
tensor(112.1846, grad_fn=<MeanBackward0>) tensor(17.7667) tensor(51.) tensor(18.)
tensor(82.4224, grad_fn=<MeanBackward0>) tensor(19.4333) tensor(56.) tensor(15.)
tensor(118.7460, grad_fn=<MeanBackward0>) tensor(18.3000) tensor(51.) tensor(16.)
tensor(128.5021, grad_fn=<MeanBackward0>) tensor(18.9333) tensor(56.) tensor(17.)
tensor(122.7807, grad_fn=<MeanBackward0>) tensor(18.6000) tensor(62.) tensor(20.)
tensor(89.4544, grad_fn=

tensor(114.4043, grad_fn=<MeanBackward0>) tensor(17.9333) tensor(48.) tensor(18.)
tensor(109.6263, grad_fn=<MeanBackward0>) tensor(16.8000) tensor(46.) tensor(18.)
tensor(128.5475, grad_fn=<MeanBackward0>) tensor(16.5667) tensor(52.) tensor(16.)
tensor(68.0849, grad_fn=<MeanBackward0>) tensor(16.8333) tensor(56.) tensor(15.)
tensor(68.4282, grad_fn=<MeanBackward0>) tensor(15.3667) tensor(31.) tensor(15.)
tensor(56.4868, grad_fn=<MeanBackward0>) tensor(14.6000) tensor(31.) tensor(12.)
tensor(104.1518, grad_fn=<MeanBackward0>) tensor(18.4000) tensor(34.) tensor(19.)
tensor(74.6330, grad_fn=<MeanBackward0>) tensor(19.0667) tensor(41.) tensor(16.)
tensor(104.3993, grad_fn=<MeanBackward0>) tensor(21.2667) tensor(48.) tensor(18.)
tensor(119.6654, grad_fn=<MeanBackward0>) tensor(23.4000) tensor(55.) tensor(20.)
tensor(118.5029, grad_fn=<MeanBackward0>) tensor(21.1667) tensor(63.) tensor(19.)
tensor(122.3708, grad_fn=<MeanBackward0>) tensor(22.8333) tensor(69.) tensor(20.)
tensor(100.6373, gra

tensor(80.4024, grad_fn=<MeanBackward0>) tensor(27.3333) tensor(50.) tensor(15.)
tensor(112.8941, grad_fn=<MeanBackward0>) tensor(29.2667) tensor(60.) tensor(21.)
tensor(108.8939, grad_fn=<MeanBackward0>) tensor(26.3333) tensor(67.) tensor(18.)
tensor(97.8220, grad_fn=<MeanBackward0>) tensor(21.) tensor(75.) tensor(14.)
tensor(59.1063, grad_fn=<MeanBackward0>) tensor(19.0333) tensor(71.) tensor(11.)
tensor(97.3626, grad_fn=<MeanBackward0>) tensor(16.9000) tensor(77.) tensor(11.)
tensor(76.6978, grad_fn=<MeanBackward0>) tensor(16.1667) tensor(84.) tensor(12.)
tensor(65.5964, grad_fn=<MeanBackward0>) tensor(12.6333) tensor(27.) tensor(12.)
tensor(73.2303, grad_fn=<MeanBackward0>) tensor(13.4333) tensor(34.) tensor(14.)
tensor(68.7992, grad_fn=<MeanBackward0>) tensor(14.5333) tensor(40.) tensor(13.)
tensor(89.0398, grad_fn=<MeanBackward0>) tensor(17.8000) tensor(44.) tensor(18.)
tensor(72.0824, grad_fn=<MeanBackward0>) tensor(16.0667) tensor(29.) tensor(18.)
tensor(66.2851, grad_fn=<MeanB

tensor(90.2489, grad_fn=<MeanBackward0>) tensor(18.4000) tensor(64.) tensor(17.)
tensor(79.4226, grad_fn=<MeanBackward0>) tensor(16.7667) tensor(64.) tensor(16.)
tensor(117.0273, grad_fn=<MeanBackward0>) tensor(14.6333) tensor(71.) tensor(15.)
tensor(85.9130, grad_fn=<MeanBackward0>) tensor(14.7333) tensor(78.) tensor(17.)
tensor(58.5108, grad_fn=<MeanBackward0>) tensor(14.7667) tensor(82.) tensor(15.)
tensor(57.0456, grad_fn=<MeanBackward0>) tensor(16.8333) tensor(89.) tensor(14.)
tensor(74.7256, grad_fn=<MeanBackward0>) tensor(16.6000) tensor(38.) tensor(14.)
tensor(42.7344, grad_fn=<MeanBackward0>) tensor(17.5000) tensor(38.) tensor(13.)
tensor(53.7847, grad_fn=<MeanBackward0>) tensor(19.2333) tensor(44.) tensor(11.)
tensor(114.6157, grad_fn=<MeanBackward0>) tensor(22.9333) tensor(51.) tensor(18.)
tensor(100.1893, grad_fn=<MeanBackward0>) tensor(22.1000) tensor(58.) tensor(17.)
tensor(108.6990, grad_fn=<MeanBackward0>) tensor(23.2000) tensor(66.) tensor(17.)
tensor(104.5809, grad_fn

tensor(75.8988, grad_fn=<MeanBackward0>) tensor(15.7000) tensor(57.) tensor(14.)
tensor(69.7639, grad_fn=<MeanBackward0>) tensor(15.8667) tensor(64.) tensor(15.)
tensor(77.7375, grad_fn=<MeanBackward0>) tensor(18.6333) tensor(71.) tensor(16.)
tensor(72.2389, grad_fn=<MeanBackward0>) tensor(19.3667) tensor(75.) tensor(16.)
tensor(71.6752, grad_fn=<MeanBackward0>) tensor(18.1333) tensor(79.) tensor(15.)
tensor(68.7432, grad_fn=<MeanBackward0>) tensor(20.8333) tensor(83.) tensor(15.)
tensor(82.0767, grad_fn=<MeanBackward0>) tensor(21.5333) tensor(89.) tensor(19.)
tensor(90.8045, grad_fn=<MeanBackward0>) tensor(14.4333) tensor(76.) tensor(16.)
tensor(74.2610, grad_fn=<MeanBackward0>) tensor(16.7000) tensor(80.) tensor(16.)
tensor(64.2039, grad_fn=<MeanBackward0>) tensor(17.7333) tensor(82.) tensor(15.)
tensor(69.2642, grad_fn=<MeanBackward0>) tensor(19.7333) tensor(88.) tensor(16.)
tensor(84.6422, grad_fn=<MeanBackward0>) tensor(16.2000) tensor(46.) tensor(16.)
tensor(91.4421, grad_fn=<Mea

tensor(98.9331, grad_fn=<MeanBackward0>) tensor(24.2333) tensor(50.) tensor(18.)
tensor(113.6993, grad_fn=<MeanBackward0>) tensor(20.8333) tensor(58.) tensor(18.)
tensor(117.9063, grad_fn=<MeanBackward0>) tensor(23.2333) tensor(64.) tensor(20.)
tensor(127.0592, grad_fn=<MeanBackward0>) tensor(21.8000) tensor(69.) tensor(18.)
tensor(107.1076, grad_fn=<MeanBackward0>) tensor(24.9000) tensor(78.) tensor(18.)
tensor(107.5538, grad_fn=<MeanBackward0>) tensor(23.4333) tensor(83.) tensor(19.)
tensor(119.4069, grad_fn=<MeanBackward0>) tensor(21.7667) tensor(82.) tensor(21.)
tensor(90.7344, grad_fn=<MeanBackward0>) tensor(19.9333) tensor(84.) tensor(18.)
tensor(82.8157, grad_fn=<MeanBackward0>) tensor(22.3000) tensor(89.) tensor(19.)
tensor(63.5640, grad_fn=<MeanBackward0>) tensor(21.7667) tensor(80.) tensor(16.)
tensor(76.4528, grad_fn=<MeanBackward0>) tensor(20.4667) tensor(82.) tensor(18.)
tensor(102.1535, grad_fn=<MeanBackward0>) tensor(17.6667) tensor(84.) tensor(15.)
tensor(82.7782, grad_

tensor(71.7270, grad_fn=<MeanBackward0>) tensor(14.7333) tensor(45.) tensor(12.)
tensor(67.8044, grad_fn=<MeanBackward0>) tensor(13.8667) tensor(47.) tensor(12.)
tensor(102.4850, grad_fn=<MeanBackward0>) tensor(13.9000) tensor(51.) tensor(14.)
tensor(77.5254, grad_fn=<MeanBackward0>) tensor(13.3333) tensor(26.) tensor(14.)
tensor(55.1069, grad_fn=<MeanBackward0>) tensor(14.9333) tensor(28.) tensor(13.)
tensor(55.6860, grad_fn=<MeanBackward0>) tensor(18.8000) tensor(35.) tensor(15.)
tensor(82.1869, grad_fn=<MeanBackward0>) tensor(19.7667) tensor(39.) tensor(18.)
tensor(101.2934, grad_fn=<MeanBackward0>) tensor(20.4667) tensor(46.) tensor(18.)
tensor(106.3957, grad_fn=<MeanBackward0>) tensor(20.0667) tensor(48.) tensor(16.)
tensor(134.6854, grad_fn=<MeanBackward0>) tensor(20.1667) tensor(58.) tensor(17.)
tensor(118.0330, grad_fn=<MeanBackward0>) tensor(19.7333) tensor(59.) tensor(17.)
tensor(85.3684, grad_fn=<MeanBackward0>) tensor(21.1333) tensor(66.) tensor(15.)
tensor(122.6849, grad_f

tensor(87.8963, grad_fn=<MeanBackward0>) tensor(20.9333) tensor(83.) tensor(18.)
tensor(68.3362, grad_fn=<MeanBackward0>) tensor(23.5667) tensor(88.) tensor(20.)
tensor(58.5507, grad_fn=<MeanBackward0>) tensor(19.5000) tensor(48.) tensor(17.)
tensor(71.6812, grad_fn=<MeanBackward0>) tensor(23.3333) tensor(54.) tensor(19.)
tensor(65.9583, grad_fn=<MeanBackward0>) tensor(20.9667) tensor(50.) tensor(16.)
tensor(86.2180, grad_fn=<MeanBackward0>) tensor(21.) tensor(56.) tensor(16.)
tensor(109.0151, grad_fn=<MeanBackward0>) tensor(24.6667) tensor(63.) tensor(18.)
tensor(134.1599, grad_fn=<MeanBackward0>) tensor(22.7667) tensor(66.) tensor(22.)
tensor(110.6635, grad_fn=<MeanBackward0>) tensor(20.2667) tensor(72.) tensor(18.)
tensor(104.3605, grad_fn=<MeanBackward0>) tensor(17.8667) tensor(79.) tensor(14.)
tensor(71.5838, grad_fn=<MeanBackward0>) tensor(19.2333) tensor(87.) tensor(14.)
tensor(71.4536, grad_fn=<MeanBackward0>) tensor(17.2667) tensor(91.) tensor(18.)
tensor(46.5261, grad_fn=<Mea

tensor(138.8281, grad_fn=<MeanBackward0>) tensor(21.2333) tensor(68.) tensor(19.)
tensor(116.2208, grad_fn=<MeanBackward0>) tensor(20.2667) tensor(75.) tensor(17.)
tensor(103.2984, grad_fn=<MeanBackward0>) tensor(21.2333) tensor(84.) tensor(17.)
tensor(114.3069, grad_fn=<MeanBackward0>) tensor(19.1667) tensor(84.) tensor(17.)
tensor(127.1077, grad_fn=<MeanBackward0>) tensor(20.0333) tensor(91.) tensor(21.)
tensor(89.8185, grad_fn=<MeanBackward0>) tensor(15.7333) tensor(46.) tensor(16.)
tensor(80.5540, grad_fn=<MeanBackward0>) tensor(15.9333) tensor(50.) tensor(14.)
tensor(63.7551, grad_fn=<MeanBackward0>) tensor(15.5667) tensor(54.) tensor(12.)
tensor(54.5947, grad_fn=<MeanBackward0>) tensor(16.8667) tensor(57.) tensor(14.)
tensor(48.7580, grad_fn=<MeanBackward0>) tensor(17.0667) tensor(41.) tensor(13.)
tensor(61.8882, grad_fn=<MeanBackward0>) tensor(17.3667) tensor(47.) tensor(12.)
tensor(92.4006, grad_fn=<MeanBackward0>) tensor(19.9667) tensor(55.) tensor(18.)
tensor(120.6818, grad_f

tensor(91.7459, grad_fn=<MeanBackward0>) tensor(20.1333) tensor(71.) tensor(15.)
tensor(112.6743, grad_fn=<MeanBackward0>) tensor(19.3000) tensor(77.) tensor(17.)
tensor(93.5621, grad_fn=<MeanBackward0>) tensor(19.7667) tensor(80.) tensor(17.)
tensor(96.0755, grad_fn=<MeanBackward0>) tensor(20.3667) tensor(84.) tensor(20.)
tensor(103.5331, grad_fn=<MeanBackward0>) tensor(19.8000) tensor(90.) tensor(19.)
tensor(74.6447, grad_fn=<MeanBackward0>) tensor(17.6333) tensor(38.) tensor(17.)
tensor(75.3458, grad_fn=<MeanBackward0>) tensor(18.6333) tensor(41.) tensor(18.)
tensor(112.0490, grad_fn=<MeanBackward0>) tensor(20.7000) tensor(46.) tensor(18.)
tensor(88.9212, grad_fn=<MeanBackward0>) tensor(19.2333) tensor(52.) tensor(18.)
tensor(84.8260, grad_fn=<MeanBackward0>) tensor(17.2000) tensor(61.) tensor(14.)
tensor(82.3787, grad_fn=<MeanBackward0>) tensor(18.0333) tensor(51.) tensor(13.)
tensor(91.5025, grad_fn=<MeanBackward0>) tensor(16.2000) tensor(59.) tensor(16.)
tensor(106.7968, grad_fn=

tensor(109.4938, grad_fn=<MeanBackward0>) tensor(20.5667) tensor(80.) tensor(20.)
tensor(105.5522, grad_fn=<MeanBackward0>) tensor(18.7333) tensor(84.) tensor(17.)
tensor(90.8081, grad_fn=<MeanBackward0>) tensor(16.9333) tensor(58.) tensor(15.)
tensor(66.8358, grad_fn=<MeanBackward0>) tensor(17.7667) tensor(64.) tensor(14.)
tensor(74.9174, grad_fn=<MeanBackward0>) tensor(19.0667) tensor(65.) tensor(17.)
tensor(60.1731, grad_fn=<MeanBackward0>) tensor(19.3667) tensor(71.) tensor(13.)
tensor(80.9997, grad_fn=<MeanBackward0>) tensor(18.4000) tensor(74.) tensor(15.)
tensor(90.9187, grad_fn=<MeanBackward0>) tensor(19.1333) tensor(79.) tensor(15.)
tensor(77.0391, grad_fn=<MeanBackward0>) tensor(20.2333) tensor(84.) tensor(14.)
tensor(79.9776, grad_fn=<MeanBackward0>) tensor(22.0333) tensor(91.) tensor(14.)
tensor(60.9612, grad_fn=<MeanBackward0>) tensor(17.3667) tensor(74.) tensor(12.)
tensor(81.3479, grad_fn=<MeanBackward0>) tensor(19.) tensor(78.) tensor(14.)
tensor(72.4869, grad_fn=<MeanB

tensor(119.4979, grad_fn=<MeanBackward0>) tensor(21.9333) tensor(49.) tensor(18.)
tensor(103.8695, grad_fn=<MeanBackward0>) tensor(18.8667) tensor(49.) tensor(16.)
tensor(127.9175, grad_fn=<MeanBackward0>) tensor(19.8667) tensor(57.) tensor(18.)
tensor(197.5295, grad_fn=<MeanBackward0>) tensor(21.6667) tensor(63.) tensor(21.)
tensor(129.2680, grad_fn=<MeanBackward0>) tensor(19.9333) tensor(69.) tensor(22.)
tensor(98.4383, grad_fn=<MeanBackward0>) tensor(19.3333) tensor(73.) tensor(18.)
tensor(76.4583, grad_fn=<MeanBackward0>) tensor(18.1333) tensor(78.) tensor(16.)
tensor(79.2781, grad_fn=<MeanBackward0>) tensor(20.8667) tensor(81.) tensor(16.)
tensor(112.1189, grad_fn=<MeanBackward0>) tensor(20.6000) tensor(88.) tensor(17.)
tensor(67.7814, grad_fn=<MeanBackward0>) tensor(18.9333) tensor(90.) tensor(16.)
tensor(78.1419, grad_fn=<MeanBackward0>) tensor(16.1000) tensor(55.) tensor(13.)
tensor(82.5383, grad_fn=<MeanBackward0>) tensor(20.0333) tensor(60.) tensor(13.)
tensor(86.7146, grad_f

tensor(98.9021, grad_fn=<MeanBackward0>) tensor(23.3000) tensor(52.) tensor(17.)
tensor(107.3513, grad_fn=<MeanBackward0>) tensor(26.) tensor(62.) tensor(18.)
tensor(106.1775, grad_fn=<MeanBackward0>) tensor(22.6000) tensor(70.) tensor(17.)
tensor(95.3635, grad_fn=<MeanBackward0>) tensor(24.3000) tensor(77.) tensor(17.)
tensor(87.8707, grad_fn=<MeanBackward0>) tensor(22.1000) tensor(83.) tensor(16.)
tensor(76.5961, grad_fn=<MeanBackward0>) tensor(19.) tensor(81.) tensor(13.)
tensor(88.5669, grad_fn=<MeanBackward0>) tensor(17.0333) tensor(87.) tensor(16.)
tensor(109.3133, grad_fn=<MeanBackward0>) tensor(15.9000) tensor(30.) tensor(19.)
tensor(97.4089, grad_fn=<MeanBackward0>) tensor(16.3667) tensor(30.) tensor(19.)
tensor(71.5623, grad_fn=<MeanBackward0>) tensor(17.6000) tensor(30.) tensor(17.)
tensor(61.3876, grad_fn=<MeanBackward0>) tensor(18.3667) tensor(36.) tensor(16.)
tensor(73.9441, grad_fn=<MeanBackward0>) tensor(19.9667) tensor(42.) tensor(16.)
tensor(87.3337, grad_fn=<MeanBack

tensor(76.6353, grad_fn=<MeanBackward0>) tensor(19.3000) tensor(77.) tensor(16.)
tensor(90.2143, grad_fn=<MeanBackward0>) tensor(17.3333) tensor(79.) tensor(14.)
tensor(58.0372, grad_fn=<MeanBackward0>) tensor(13.9667) tensor(61.) tensor(11.)
tensor(59.4835, grad_fn=<MeanBackward0>) tensor(15.9000) tensor(65.) tensor(12.)
tensor(72.9195, grad_fn=<MeanBackward0>) tensor(17.9000) tensor(48.) tensor(16.)
tensor(81.2594, grad_fn=<MeanBackward0>) tensor(19.9667) tensor(48.) tensor(17.)
tensor(107.0797, grad_fn=<MeanBackward0>) tensor(22.) tensor(55.) tensor(22.)
tensor(85.5781, grad_fn=<MeanBackward0>) tensor(19.6333) tensor(48.) tensor(16.)
tensor(100.0280, grad_fn=<MeanBackward0>) tensor(21.) tensor(56.) tensor(15.)
tensor(112.8149, grad_fn=<MeanBackward0>) tensor(23.5667) tensor(61.) tensor(19.)
tensor(98.6582, grad_fn=<MeanBackward0>) tensor(17.5667) tensor(68.) tensor(15.)
tensor(87.5940, grad_fn=<MeanBackward0>) tensor(17.4000) tensor(74.) tensor(14.)
tensor(78.1380, grad_fn=<MeanBack

tensor(35.5537, grad_fn=<MeanBackward0>) tensor(16.5000) tensor(72.) tensor(11.)
tensor(89.8415, grad_fn=<MeanBackward0>) tensor(19.1667) tensor(77.) tensor(12.)
tensor(71.6982, grad_fn=<MeanBackward0>) tensor(21.3667) tensor(82.) tensor(16.)
tensor(101.1496, grad_fn=<MeanBackward0>) tensor(25.3333) tensor(87.) tensor(20.)
tensor(124.2178, grad_fn=<MeanBackward0>) tensor(20.7333) tensor(60.) tensor(19.)
tensor(85.4628, grad_fn=<MeanBackward0>) tensor(20.4000) tensor(67.) tensor(14.)
tensor(93.0558, grad_fn=<MeanBackward0>) tensor(24.0333) tensor(75.) tensor(16.)
tensor(99.1048, grad_fn=<MeanBackward0>) tensor(24.2667) tensor(83.) tensor(18.)
tensor(99.7855, grad_fn=<MeanBackward0>) tensor(24.4667) tensor(88.) tensor(20.)
tensor(91.3570, grad_fn=<MeanBackward0>) tensor(22.9667) tensor(88.) tensor(18.)
tensor(71.4538, grad_fn=<MeanBackward0>) tensor(16.4667) tensor(84.) tensor(12.)
tensor(67.3128, grad_fn=<MeanBackward0>) tensor(15.7000) tensor(85.) tensor(13.)
tensor(83.7360, grad_fn=<M

tensor(65.2058, grad_fn=<MeanBackward0>) tensor(15.6000) tensor(32.) tensor(15.)
tensor(64.1412, grad_fn=<MeanBackward0>) tensor(17.1667) tensor(36.) tensor(16.)
tensor(66.4904, grad_fn=<MeanBackward0>) tensor(20.3333) tensor(44.) tensor(17.)
tensor(132.0375, grad_fn=<MeanBackward0>) tensor(21.3667) tensor(52.) tensor(18.)
tensor(101.9078, grad_fn=<MeanBackward0>) tensor(19.8667) tensor(57.) tensor(17.)
tensor(124.0525, grad_fn=<MeanBackward0>) tensor(20.9000) tensor(64.) tensor(17.)
tensor(97.1219, grad_fn=<MeanBackward0>) tensor(20.7000) tensor(72.) tensor(15.)
tensor(73.9089, grad_fn=<MeanBackward0>) tensor(18.4667) tensor(80.) tensor(13.)
tensor(107.3268, grad_fn=<MeanBackward0>) tensor(20.4000) tensor(89.) tensor(17.)
tensor(91.5447, grad_fn=<MeanBackward0>) tensor(18.9667) tensor(77.) tensor(16.)
tensor(70.9870, grad_fn=<MeanBackward0>) tensor(18.0333) tensor(82.) tensor(14.)
tensor(96.7027, grad_fn=<MeanBackward0>) tensor(20.3000) tensor(87.) tensor(18.)
tensor(73.6044, grad_fn=

tensor(61.8289, grad_fn=<MeanBackward0>) tensor(18.4333) tensor(80.) tensor(16.)
tensor(176.6409, grad_fn=<MeanBackward0>) tensor(17.4000) tensor(85.) tensor(16.)
tensor(59.8805, grad_fn=<MeanBackward0>) tensor(17.8667) tensor(90.) tensor(13.)
tensor(68.8324, grad_fn=<MeanBackward0>) tensor(16.3667) tensor(44.) tensor(12.)
tensor(84.2102, grad_fn=<MeanBackward0>) tensor(19.2667) tensor(49.) tensor(14.)
tensor(98.8872, grad_fn=<MeanBackward0>) tensor(19.7667) tensor(57.) tensor(15.)
tensor(113.2776, grad_fn=<MeanBackward0>) tensor(21.3333) tensor(64.) tensor(19.)
tensor(106.1274, grad_fn=<MeanBackward0>) tensor(21.1000) tensor(70.) tensor(20.)
tensor(77.9701, grad_fn=<MeanBackward0>) tensor(17.6667) tensor(75.) tensor(14.)
tensor(79.4404, grad_fn=<MeanBackward0>) tensor(19.5000) tensor(80.) tensor(16.)
tensor(83.2703, grad_fn=<MeanBackward0>) tensor(22.2333) tensor(86.) tensor(18.)
tensor(98.2500, grad_fn=<MeanBackward0>) tensor(23.5000) tensor(88.) tensor(20.)
tensor(81.2792, grad_fn=<

tensor(97.8725, grad_fn=<MeanBackward0>) tensor(20.8333) tensor(42.) tensor(20.)
tensor(92.4252, grad_fn=<MeanBackward0>) tensor(19.9333) tensor(49.) tensor(17.)
tensor(100.7776, grad_fn=<MeanBackward0>) tensor(20.) tensor(57.) tensor(15.)
tensor(67.1282, grad_fn=<MeanBackward0>) tensor(17.3000) tensor(65.) tensor(12.)
tensor(100.1377, grad_fn=<MeanBackward0>) tensor(18.7000) tensor(73.) tensor(14.)
tensor(98.4704, grad_fn=<MeanBackward0>) tensor(20.9667) tensor(79.) tensor(15.)
tensor(102.2732, grad_fn=<MeanBackward0>) tensor(17.9000) tensor(83.) tensor(15.)
tensor(97.9122, grad_fn=<MeanBackward0>) tensor(19.4667) tensor(91.) tensor(15.)
tensor(76.4080, grad_fn=<MeanBackward0>) tensor(14.4000) tensor(35.) tensor(14.)
tensor(68.7108, grad_fn=<MeanBackward0>) tensor(15.0333) tensor(35.) tensor(14.)
tensor(78.9277, grad_fn=<MeanBackward0>) tensor(16.6000) tensor(35.) tensor(17.)
tensor(52.5886, grad_fn=<MeanBackward0>) tensor(16.1333) tensor(35.) tensor(14.)
tensor(62.3826, grad_fn=<Mean

tensor(66.3192, grad_fn=<MeanBackward0>) tensor(11.4333) tensor(23.) tensor(13.)
tensor(56.6054, grad_fn=<MeanBackward0>) tensor(13.5333) tensor(26.) tensor(13.)
tensor(57.1218, grad_fn=<MeanBackward0>) tensor(14.2333) tensor(26.) tensor(12.)
tensor(68.0263, grad_fn=<MeanBackward0>) tensor(14.1667) tensor(29.) tensor(14.)
tensor(66.2264, grad_fn=<MeanBackward0>) tensor(16.1667) tensor(34.) tensor(15.)
tensor(68.3275, grad_fn=<MeanBackward0>) tensor(14.8333) tensor(38.) tensor(14.)
tensor(75.4104, grad_fn=<MeanBackward0>) tensor(14.9667) tensor(42.) tensor(14.)
tensor(82.6199, grad_fn=<MeanBackward0>) tensor(17.3000) tensor(45.) tensor(15.)
tensor(60.1912, grad_fn=<MeanBackward0>) tensor(15.3333) tensor(30.) tensor(14.)
tensor(53.3688, grad_fn=<MeanBackward0>) tensor(15.7000) tensor(36.) tensor(14.)
tensor(75.7993, grad_fn=<MeanBackward0>) tensor(17.4000) tensor(42.) tensor(17.)
tensor(71.9356, grad_fn=<MeanBackward0>) tensor(16.2000) tensor(42.) tensor(15.)
tensor(73.9530, grad_fn=<Mea

tensor(53.7326, grad_fn=<MeanBackward0>) tensor(18.4333) tensor(87.) tensor(12.)
tensor(51.0428, grad_fn=<MeanBackward0>) tensor(18.1667) tensor(39.) tensor(13.)
tensor(71.3928, grad_fn=<MeanBackward0>) tensor(22.4000) tensor(44.) tensor(19.)
tensor(69.6560, grad_fn=<MeanBackward0>) tensor(20.9333) tensor(52.) tensor(17.)
tensor(72.2576, grad_fn=<MeanBackward0>) tensor(22.6333) tensor(60.) tensor(16.)
tensor(74.8554, grad_fn=<MeanBackward0>) tensor(24.6667) tensor(69.) tensor(18.)
tensor(88.3453, grad_fn=<MeanBackward0>) tensor(26.4667) tensor(74.) tensor(20.)
tensor(89.0051, grad_fn=<MeanBackward0>) tensor(21.4000) tensor(76.) tensor(17.)
tensor(93.5174, grad_fn=<MeanBackward0>) tensor(20.8333) tensor(85.) tensor(16.)
tensor(77.9137, grad_fn=<MeanBackward0>) tensor(16.8333) tensor(86.) tensor(14.)
tensor(71.2116, grad_fn=<MeanBackward0>) tensor(13.5667) tensor(27.) tensor(13.)
tensor(69.5014, grad_fn=<MeanBackward0>) tensor(14.8667) tensor(27.) tensor(13.)
tensor(84.2666, grad_fn=<Mea

tensor(88.6710, grad_fn=<MeanBackward0>) tensor(18.6333) tensor(32.) tensor(20.)
tensor(78.3503, grad_fn=<MeanBackward0>) tensor(22.1333) tensor(37.) tensor(20.)
tensor(90.8454, grad_fn=<MeanBackward0>) tensor(20.3000) tensor(45.) tensor(19.)
tensor(100.8676, grad_fn=<MeanBackward0>) tensor(15.4000) tensor(51.) tensor(12.)
tensor(79.0045, grad_fn=<MeanBackward0>) tensor(17.6333) tensor(57.) tensor(12.)
tensor(84.8274, grad_fn=<MeanBackward0>) tensor(15.7333) tensor(62.) tensor(11.)
tensor(94.5505, grad_fn=<MeanBackward0>) tensor(18.0667) tensor(71.) tensor(14.)
tensor(90.6183, grad_fn=<MeanBackward0>) tensor(16.3333) tensor(78.) tensor(12.)
tensor(90.2206, grad_fn=<MeanBackward0>) tensor(17.4333) tensor(85.) tensor(15.)
tensor(75.0285, grad_fn=<MeanBackward0>) tensor(16.4333) tensor(35.) tensor(14.)
tensor(82.5166, grad_fn=<MeanBackward0>) tensor(17.3333) tensor(44.) tensor(16.)
tensor(60.8627, grad_fn=<MeanBackward0>) tensor(17.1667) tensor(51.) tensor(15.)
tensor(70.5747, grad_fn=<Me

In [None]:
#Deconvolution.py
#used for deconvolution of CNP history

def Deconvolute(model,cnp,Chr,CNV,End,verbose=False):
    '''
    Greedily deconvolute copy-number profiles (CNPs) into CNV histories.

    At every step the highest-valued action under the Q-learning model is taken:
    a whole-genome doubling (WGD), a CNV on one chromosome, or the special
    action END, which stops the history for that CNP.

    model: the trained Q-learning model; used through model.switch,
           model.Chrom_model, model.CNV.find_one_cnv and model.End.find_one_end
    cnp: the input CNPs, shape: Number of CNP,1,44 (#Chr),50 (#regions for one chromosome)
         cnp itself is not modified: each profile is cloned before actions are applied
    Chr,CNV,End: output tensors, shape: Number of CNP, maximum length of history;
         they are filled in place and also returned
    verbose: if True, print the chromosome and the chosen (start, end, gain/loss)
             for every CNV step (debug output)
    output: for Chr: -1 indicates WGD, 0 indicates no action (END), 1~44 the chromosome
            for CNV: only valid if Chr is not -1 or 0
                     indicates the starting point (CNV//2) and the type of CNV (CNV%2==1 for gain and CNV%2==0 for loss)
            for End: only valid if Chr is not -1 or 0
                     indicates the end point for a CNV.
    '''
    max_step=int(Chr.shape[1])
    #pure inference: no autograd graph is needed, and without no_grad any
    #grad-carrying model output written into CNV/End may keep graphs alive
    with torch.no_grad():
        for i in range(cnp.shape[0]):
            #clone: a plain slice is a view, so in-place CNV updates below would
            #otherwise leak into the caller's cnp (only until the first WGD rebinds it)
            current_cnp=cnp[i:(i+1)].clone()
            step=0
            while(step<max_step):
                #it is also possible to manually set the switch if deemed necessary
                sigma=model.switch(current_cnp)
                res_chrom=model.Chrom_model(current_cnp,sigma)
                #find the chromosome with the maximum value
                val,temp_Chr=res_chrom.max(1)
                temp_Chr=int(temp_Chr)
                Chr[i][step]=temp_Chr+1
                #WGD: only considered when every entry is even and at least one is positive
                halved=torch.floor(current_cnp/2)
                if (not torch.any(current_cnp-2*halved>0.5)) and torch.any(current_cnp>0.5):
                    sigma_wgd=model.switch(halved)
                    res_chrom_wgd=model.Chrom_model(halved,sigma_wgd)
                    val_wgd,_=res_chrom_wgd.max(1)
                    if val_wgd>=val:
                        val=val_wgd
                        Chr[i][step]=-1
                #special action END
                val_end=torch.sum(torch.abs(current_cnp-1))*math.log(single_loci_loss)
                if val_end>=val:
                    Chr[i][step]=0
                    break
                #if WGD: reverse it by halving the profile
                if Chr[i][step]< -0.5:
                    current_cnp=halved
                #if not WGD or END: apply a CNV on chromosome temp_Chr
                elif Chr[i][step]>0.5:
                    #find best CNV (start point and gain/loss type)
                    chrom=current_cnp[:,0,temp_Chr,:]
                    CNV[i][step]=model.CNV.find_one_cnv(chrom,sigma)
                    cnv_temp=int(CNV[i][step]%2)
                    start_temp=int(CNV[i][step]//2)
                    #find best End, given the chromosome with the CNV applied from start onwards
                    chrom_new=chrom.clone()
                    chrom_new[:,start_temp:]=chrom_new[:,start_temp:]+(cnv_temp-0.5)*2
                    End[i][step]=model.End.find_one_end(chrom,chrom_new,sigma,start_temp,cnv_temp)
                    end_temp=int(End[i][step])
                    if verbose:
                        print(chrom)
                        print(start_temp,End[i][step],cnv_temp)
                    #update cnp: (cnv_temp-0.5)*2 is +1 for a gain and -1 for a loss
                    current_cnp[:,0,temp_Chr,start_temp:end_temp]=current_cnp[:,0,temp_Chr,start_temp:end_temp]+(cnv_temp-0.5)*2
                step=step+1

    return Chr,CNV,End
    




if __name__=="__main__":
    #use the Q_model trained earlier in this notebook session; to load saved weights instead:
    #    model=Q_learning(); model.load_state_dict(torch.load(PATH))
    model=Q_model
    model.eval()
    #NOTE(review): counter_global is not read in this block; presumably it is read
    #as a module global by code defined elsewhere (e.g. Simulate_data) — confirm
    counter_global=torch.randint(10000,(1,))[0]
    #maximum length of a deconvoluted history (number of columns in the output tensors)
    max_history=20
    #loading simulated data
    state=Simulate_data(batch_size=1,Number_of_step=5)
    Chr=torch.zeros(state.shape[0],max_history)
    CNV=torch.zeros(state.shape[0],max_history)
    End=torch.zeros(state.shape[0],max_history)
    #keep an untouched copy of the input for comparison with the reconstructed history
    state_copy=state.clone()
    #deconvolution
    Chr,CNV,End=Deconvolute(model,state,Chr,CNV,End)
    #check if the model picks the correct chromosome / CNV / end point
    print(Chr)
    print(CNV)
    print(End)