# MDN for Classification

In [1]:
import math
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributions as TD
from torch.autograd import Variable
from collections import OrderedDict
%matplotlib inline
%config InlineBackend.figure_format='retina'
np.set_printoptions(precision=3)
torch.set_printoptions(precision=3)
print ("PyTorch version:[%s]."%(torch.__version__))
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print ("device:[%s]."%(device))

PyTorch version:[1.7.0+cu101].
device:[cuda:0].


### Helper functions

In [2]:
# Codes copied from 'https://github.com/sksq96/pytorch-summary/tree/master/torchsummary' 
def summary_string(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None):
    if dtypes == None:
        dtypes = [torch.FloatTensor]*len(input_size)
    summary_str = ''
    def register_hook(module):
        def hook(module, input, output):
            class_name = str(module.__class__).split(".")[-1].split("'")[0]
            module_idx = len(summary)

            m_key = "%s-%i" % (class_name, module_idx + 1)
            summary[m_key] = OrderedDict()
            summary[m_key]["input_shape"] = list(input[0].size())
            summary[m_key]["input_shape"][0] = batch_size
            if isinstance(output, (list, tuple)):
                summary[m_key]["output_shape"] = [
                    [-1] + list(o.size())[1:] for o in output
                ]
            else:
                summary[m_key]["output_shape"] = list(output.size())
                summary[m_key]["output_shape"][0] = batch_size

            params = 0
            if hasattr(module, "weight") and hasattr(module.weight, "size"):
                params += torch.prod(torch.LongTensor(list(module.weight.size())))
                summary[m_key]["trainable"] = module.weight.requires_grad
            if hasattr(module, "bias") and hasattr(module.bias, "size"):
                params += torch.prod(torch.LongTensor(list(module.bias.size())))
            summary[m_key]["nb_params"] = params

        if (
            not isinstance(module, nn.Sequential)
            and not isinstance(module, nn.ModuleList)
        ):
            hooks.append(module.register_forward_hook(hook))

    # multiple inputs to the network
    if isinstance(input_size, tuple):
        input_size = [input_size]

    # batch_size of 2 for batchnorm
    x = [torch.rand(2, *in_size).type(dtype).to(device=device)
         for in_size, dtype in zip(input_size, dtypes)]

    # create properties
    summary = OrderedDict()
    hooks = []

    # register hook
    model.apply(register_hook)

    # make a forward pass
    # print(x.shape)
    model(*x)

    # remove these hooks
    for h in hooks:
        h.remove()

    summary_str += "----------------------------------------------------------------" + "\n"
    line_new = "{:>20}  {:>25} {:>15}".format(
        "Layer (type)", "Output Shape", "Param #")
    summary_str += line_new + "\n"
    summary_str += "================================================================" + "\n"
    total_params = 0
    total_output = 0
    trainable_params = 0
    for layer in summary:
        # input_shape, output_shape, trainable, nb_params
        line_new = "{:>20}  {:>25} {:>15}".format(
            layer,
            str(summary[layer]["output_shape"]),
            "{0:,}".format(summary[layer]["nb_params"]),
        )
        total_params += summary[layer]["nb_params"]
        summary_str += line_new + "\n"
    # return summary
    return summary_str,summary
print ("Done.")

Done.


##  $\color{yellow}{\text{Mixture Logits Network (MLN) }}$ 
- 
`Cross Entropy Loss`
$ \mathcal{L}_{\text{CE}} = 
    -\sum_{d=1}^{D} y_d \log(\hat{\mu}_d)
$
where $y \in [0,1]^d$ is the target and $\hat{\mu} \in \mathbb{S}^d$ is the prediction result.
- `Weighted CE Loss`
$ \mathcal{L}_{\text{WCE}} = 
    -
    \sum_{k=1}^{K}
        \hat{\pi}_k
        \sum_{d=1}^{D} y_d \log(\hat{\mu}_d)
$
where $\hat{\pi}$, $\hat{\mu}$, and $y$ are mixture weights,
output predicitons, and labels, respectively. 
- 
`Gal Loss`
$
    \mathcal{L}_{\text{Gal}} 
    = \log \frac{1}{T} \sum_{t}
        \exp \left(
            \hat{x}_{t,c} - \log \sum_{c'} \exp \hat{x}_{t,c'}
            \right)
$
where $\hat{x_t} = f^{W} + \sigma^{W}\epsilon_t, ~ \epsilon_t \sim \mathcal{N}(0,I)$.
- 
`Mixture of Attenuated CE Loss`
$ \mathcal{L}_{\text{MACE}} 
    =
    -
    \sum_{k=1}^{K}
        \hat{\pi}_k
        \sum_{d=1}^{D}
        \frac
            {y_d \log(\hat{\mu}_{d,k})}
            {\hat{\sigma}_{d,k} + \sigma_{\text{min}}}
$
where $\sigma_{\text{min}}=1.0$ is the minimum standard deviation.


In [3]:
class MixturesOfLogits(nn.Module):
    """
        Mixture of Logits 
    """
    def __init__(self,
                 in_dim  = 64,  # input feature dimension
                 y_dim   = 10,  # output dimension
                 k       = 5,   # number of mixtures
                 sig_min = 1,   # minimum sigma
                 sig_max = None # maximum signa 
                 ):
        super(MixturesOfLogits,self).__init__()
        self.in_dim   = in_dim
        self.y_dim    = y_dim
        self.k        = k
        self.sig_min  = sig_min
        self.sig_max  = sig_max
        self.fc_pi    = nn.Linear(self.in_dim,self.k)
        self.fc_mu    = nn.Linear(self.in_dim,self.k*self.y_dim)
        self.fc_sigma = nn.Linear(self.in_dim,self.k*self.y_dim)

    def forward(self,x):
        pi_logit = self.fc_pi(x) # [N x K]
        pi       = torch.softmax(pi_logit,dim=1) # [N x K]
        mu       = self.fc_mu(x) # [N x KD]
        mu       = torch.reshape(mu,(-1,self.k,self.y_dim)) # [N x K x D]
        sigma    = self.fc_sigma(x) # [N x KD]
        sigma    = torch.reshape(sigma,(-1,self.k,self.y_dim)) # [N x K x D]
        if self.sig_max is None:
            sigma = self.sig_min + torch.exp(sigma) # [N x K x D]
        else:
            sigma = self.sig_min + (self.sig_max-self.sig_min)*torch.sigmoid(sigma) # [N x K x D]
        return pi,mu,sigma

class MixtureLogitNetwork(nn.Module):
    def __init__(self,
                 name='mln',
                 x_dim   = [1,28,28], # iput dimension 
                 k_size  = 3,         # kernel size
                 c_dims  = [32,64],   # channel dimensions for conv layer(s)
                 p_sizes = [2,2],     # pooling sizes
                 h_dims  = [128],     # hidden dimensions for dense layer(s)
                 y_dim   = 10,        # output dimension
                 USE_BN  = True,      # whether to use batch norm   
                 k       = 5,         # number of mixtures
                 sig_min = 1,         # $\sigma_{min}$
                 sig_max = None,      # $\sigma_{max}$
                 mu_min  = -3,        # minimum $\mu$ while initializing bias 
                 mu_max  = +3,        # maximum $\mu$ while initializing bias 
                 ):
        super(MixtureLogitNetwork,self).__init__()
        self.name    = name
        self.x_dim   = x_dim
        self.k_size  = k_size
        self.c_dims  = c_dims
        self.p_sizes = p_sizes
        self.h_dims  = h_dims
        self.y_dim   = y_dim
        self.USE_BN  = USE_BN
        self.k       = k
        self.sig_min = sig_min
        self.sig_max = sig_max
        self.mu_min  = mu_min
        self.mu_max  = mu_max

        # Build graph
        self.build_graph()

        # Initialize parameters        
        self.init_param() 

    def build_graph(self):
        self.layers = []
        # Conv layers
        prev_c_dim = self.x_dim[0] # input channel 
        for (c_dim,p_size) in zip(self.c_dims,self.p_sizes):
            self.layers.append(
                nn.Conv2d(
                    in_channels  = prev_c_dim,
                    out_channels = c_dim,
                    kernel_size  = self.k_size,
                    stride       = (1,1),
                    padding      = self.k_size//2
                    ) # conv
                )
            if self.USE_BN:
                self.layers.append(
                    nn.BatchNorm2d(num_features=c_dim)
                )
            self.layers.append(nn.ReLU())
            self.layers.append(
                nn.MaxPool2d(kernel_size=(p_size,p_size),stride=(p_size,p_size))
                )
            # self.layers.append(nn.Dropout2d(p=0.1))  # p: to be zero-ed
            prev_c_dim = c_dim 
        # Dense layers
        self.layers.append(nn.Flatten())
        p_prod = np.prod(self.p_sizes)
        prev_h_dim = prev_c_dim*(self.x_dim[1]//p_prod)*(self.x_dim[2]//p_prod)
        for h_dim in self.h_dims:
            self.layers.append(
                nn.Linear(
                    in_features  = prev_h_dim,
                    out_features = h_dim,
                    bias         = True
                    )
                )
            self.layers.append(nn.ReLU(True))  # activation
            self.layers.append(nn.Dropout2d(p=0.1))  # p: to be zero-ed
            prev_h_dim = h_dim
        # Final mixture of logits layer
        mol = MixturesOfLogits(
            in_dim  = prev_h_dim,  
            y_dim   = self.y_dim, 
            k       = self.k,
            sig_min = self.sig_min,
            sig_max = self.sig_max
        )
        self.layers.append(mol)

        # Concatanate all layers
        self.net = nn.Sequential()
        for l_idx,layer in enumerate(self.layers):
            layer_name = "%s_%02d"%(type(layer).__name__.lower(),l_idx)
            self.net.add_module(layer_name,layer)

    def init_param(self): 
        for m in self.modules():
            if isinstance(m,nn.Conv2d): # init conv
                nn.init.kaiming_normal_(m.weight)
                nn.init.zeros_(m.bias)
            if isinstance(m,nn.Linear): # lnit dense
                nn.init.kaiming_normal_(m.weight)
                nn.init.zeros_(m.bias)
        """
        Heuristic: fc_mu.bias ~ Uniform(mu_min,mu_max)
        """
        self.layers[-1].fc_mu.bias.data.uniform_(self.mu_min,self.mu_max)

    def forward(self,x):
        return self.net(x)

# Instantiate mixture of logits layer 
M = MixtureLogitNetwork(
    name='mln',x_dim=[1,28,28],k_size=3,c_dims=[32,64],p_sizes=[2,2],
    h_dims=[128],y_dim=10,USE_BN=True,
    k=3,sig_min=1,sig_max=None,
    mu_min=-3,mu_max =+3).to(device)
print ("Done.")

Done.


##  $\color{yellow}{\text{Loss function}}$ 
`Mixture of Attenuated CE Loss`
$ \mathcal{L}_{\text{MACE}} 
    =
    \sum_{k=1}^{K}
        \hat{\pi}_k
        \sum_{d=1}^{D}
        \frac
            {-y_d \log(\hat{\mu}_{d,k})}
            {\hat{\sigma}_{d,k} 
            }
    + 
    \frac{1}{D}
    \sum_{d=1}^{D}
    \sum_{k=1}^{K}
    \hat{\pi}_k \hat{\sigma}_{d,k}
$

In [4]:
def np2tc(x_np): return torch.from_numpy(x_np).float().to(device)
def tc2np(x_tc): return x_tc.detach().cpu().numpy()

def mdn_gather(pi,mu,sigma):
    """
    pi:     [N x K]
    mu:     [N x K x D]
    sigma:  [N x K x D]
    """
    max_idx = torch.argmax(pi,dim=1) # [N]
    idx_gather = max_idx.unsqueeze(dim=-1).repeat(1,mu.shape[2]).unsqueeze(1) # [N x 1 x D]
    mu_sel = torch.gather(mu,dim=1,index=idx_gather).squeeze(dim=1) # [N x D]
    sigma_sel = torch.gather(sigma,dim=1,index=idx_gather).squeeze(dim=1) # [N x D]
    out = {'max_idx':max_idx,'idx_gather':idx_gather,
           'mu_sel':mu_sel,'sigma_sel':sigma_sel}
    return out

def mace_loss(pi,mu,sigma,target,alea_weight=1.0):
    """
    Mixture of attenuated CE loss
        pi:      [N x K]
        mu:      [N x K x D]
        sigma:   [N x K x D]
        target:  [N x D]
    """
    # softmax \mu
    mu_hat = torch.softmax(mu,dim=2) # logit to prob [N x K x D]
    log_mu_hat = torch.log(mu_hat+1e-5) # [N x K x D]
    
    # Expanded \pi 
    pi_usq = torch.unsqueeze(pi,2) # [N x K x 1]
    pi_exp = pi_usq.expand_as(sigma) # [N x K x D]

    # Expanded target
    target_usq =  torch.unsqueeze(target,1) # [N x 1 x D]
    target_exp =  target_usq.expand_as(sigma) # [N x K x D]

    # Loss
    # ce_loss_exp = -target_exp*log_mu_hat # [N x K x D]
    ce_loss_exp = -target_exp*log_mu_hat # [N x K x D]
    atte_ce = ce_loss_exp / sigma # attenuated CE loss [N x K x D]
    waces = torch.sum(torch.mul(pi_exp,atte_ce),dim=1) # weighted attenuated CE loss [N x D]
    wace = torch.mean(waces,dim=1) # N
    aleas = alea_weight*torch.sum(pi_exp*sigma,dim=1)# aleatoric uncertainty [N x D]
    alea = torch.mean(aleas,dim=1) # [N]

    # Accumulate loss 
    loss = wace + alea # [N]

    # Average loss
    wace_avg = torch.mean(wace) # [1]
    alea_avg = torch.mean(alea) # [1]
    loss_avg = torch.mean(loss) # [1]


    out = {'mu_hat':mu_hat,'log_mu_hat':log_mu_hat,
           'pi_usq':pi_usq,'pi_exp':pi_exp,
           'sigma':sigma,
           'target_usq':target_usq,'target_exp':target_exp,
           'ce_loss_exp':ce_loss_exp,'atte_ce':atte_ce,
           'waces':waces,'wace':wace,
           'aleas':aleas,'alea':alea,
           'loss':loss,
           'wace_avg':wace_avg,'alea_avg':alea_avg,'loss_avg':loss_avg}
    return out

def gmm_uncertainties(pi, mu, sigma):
    # Compute Epistemic Uncertainty
    M = 0.1
    # pi = torch.softmax(M*pi,1) # (optional) heuristics 
    pi_usq = torch.unsqueeze(pi,2) # [N x K x 1]
    pi_exp = pi_usq.expand_as(sigma) # [N x K x D]

    # For classification problems, we use softmax(mu) instead of me
    mu = torch.softmax(mu,dim=2) # logit to prob [N x K x D]

    mu_avg = torch.sum(torch.mul(pi_exp,mu),dim=1).unsqueeze(1) # [N x 1 x D]
    mu_exp = mu_avg.expand_as(mu) # [N x K x D]
    mu_diff_sq = torch.square(mu-mu_exp) # [N x K x D]
    epis_unct = torch.sum(torch.mul(pi_exp,mu_diff_sq), dim=1)  # [N x D]

    # Compute Aleatoric Uncertainty
    alea_unct = torch.sum(torch.mul(pi_exp,sigma), dim=1)  # [N x D]

    # Sqaure root 
    epis_unct = torch.sqrt(epis_unct) # [N x D]
    alea_unct = torch.sqrt(alea_unct) # [N x D]
    epis_unct_avg = torch.mean(epis_unct,dim=1) # [N]
    alea_unct_avg = torch.mean(alea_unct,dim=1) # [N]

    # Out
    unct_out = {'epis_unct':epis_unct,'alea_unct':alea_unct,
           'epis_unct_avg':epis_unct_avg,'alea_unct_avg':alea_unct_avg}

    return unct_out
    
# Demo run to check the loss 
M = MixtureLogitNetwork(
    name='mln',x_dim=[1,28,28],k_size=3,c_dims=[32,64],p_sizes=[2,2],
    h_dims=[128],y_dim=10,USE_BN=True).to(device)

x_np = np.random.rand(2,1,28,28)
x_tc = np2tc(x_np)
pi_tc,mu_tc,sigma_tc = M.forward(x_tc) # forward path of MLN
target_tc = F.one_hot(torch.randint(low=0,high=10,size=(2,)),num_classes=10).to(device) # random one-hot
out = mace_loss(pi_tc,mu_tc,sigma_tc,target_tc) # mixture of CE 
unct_out = gmm_uncertainties(pi_tc,mu_tc,sigma_tc)

print ('pi_tc:         %s'%(tc2np(target_tc).shape,))
print ('mu_tc:         %s'%(tc2np(mu_tc).shape,))
print ('sigma_tc:      %s'%(tc2np(sigma_tc).shape,))
print ('target_tc:     %s'%(tc2np(target_tc).shape,))
print ('=>')
print ('mu_hat:        %s'%(tc2np(out['mu_hat']).shape,))
print ('log_mu_hat:    %s'%(tc2np(out['log_mu_hat']).shape,))
print ('pi_usq:        %s'%(tc2np(out['pi_usq']).shape,))
print ('pi_exp:        %s'%(tc2np(out['pi_exp']).shape,))
print ('target_usq:    %s'%(tc2np(out['target_usq']).shape,))
print ('target_exp:    %s'%(tc2np(out['target_exp']).shape,))
print ('ce_loss_exp:   %s'%(tc2np(out['ce_loss_exp']).shape,))
print ('atte_ce:       %s'%(tc2np(out['atte_ce']).shape,))
print ('waces:         %s'%(tc2np(out['waces']).shape,))
print ('wace:          %s'%(tc2np(out['wace']).shape,))
print ('aleas:         %s'%(tc2np(out['aleas']).shape,))
print ('alea:          %s'%(tc2np(out['alea']).shape,))
print ('loss:          %s'%(tc2np(out['loss']).shape,))

pi_tc:         (2, 10)
mu_tc:         (2, 5, 10)
sigma_tc:      (2, 5, 10)
target_tc:     (2, 10)
=>
mu_hat:        (2, 5, 10)
log_mu_hat:    (2, 5, 10)
pi_usq:        (2, 5, 1)
pi_exp:        (2, 5, 10)
target_usq:    (2, 1, 10)
target_exp:    (2, 5, 10)
ce_loss_exp:   (2, 5, 10)
atte_ce:       (2, 5, 10)
waces:         (2, 10)
wace:          (2,)
aleas:         (2, 10)
alea:          (2,)
loss:          (2,)


### Summarize the model

In [5]:
summary_str,summary = summary_string(M,input_size=(1,28,28),device=device)
print (summary_str)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 28, 28]             320
       BatchNorm2d-2           [-1, 32, 28, 28]              64
              ReLU-3           [-1, 32, 28, 28]               0
         MaxPool2d-4           [-1, 32, 14, 14]               0
            Conv2d-5           [-1, 64, 14, 14]          18,496
       BatchNorm2d-6           [-1, 64, 14, 14]             128
              ReLU-7           [-1, 64, 14, 14]               0
         MaxPool2d-8             [-1, 64, 7, 7]               0
           Flatten-9                 [-1, 3136]               0
           Linear-10                  [-1, 128]         401,536
             ReLU-11                  [-1, 128]               0
        Dropout2d-12                  [-1, 128]               0
           Linear-13                    [-1, 5]             645
           Linear-14                   

### Check parameters

In [6]:
n_param = 0
for p_idx,(param_name,param) in enumerate(M.named_parameters()):
    if param.requires_grad:
        param_numpy = param.detach().cpu().numpy() # to numpy array 
        n_param += len(param_numpy.reshape(-1))
        print ("[%02d] name:[%s] shape:[%s]."%(p_idx,param_name,param_numpy.shape))
        print ("     first 3 values:%s"%(param_numpy.reshape(-1)[:3]))
print ("Total number of parameters:[%s]."%(format(n_param,',d')))

[00] name:[net.conv2d_00.weight] shape:[(32, 1, 3, 3)].
     first 3 values:[-0.481 -0.531 -1.441]
[01] name:[net.conv2d_00.bias] shape:[(32,)].
     first 3 values:[0. 0. 0.]
[02] name:[net.batchnorm2d_01.weight] shape:[(32,)].
     first 3 values:[1. 1. 1.]
[03] name:[net.batchnorm2d_01.bias] shape:[(32,)].
     first 3 values:[0. 0. 0.]
[04] name:[net.conv2d_04.weight] shape:[(64, 32, 3, 3)].
     first 3 values:[-0.028  0.012 -0.068]
[05] name:[net.conv2d_04.bias] shape:[(64,)].
     first 3 values:[0. 0. 0.]
[06] name:[net.batchnorm2d_05.weight] shape:[(64,)].
     first 3 values:[1. 1. 1.]
[07] name:[net.batchnorm2d_05.bias] shape:[(64,)].
     first 3 values:[0. 0. 0.]
[08] name:[net.linear_09.weight] shape:[(128, 3136)].
     first 3 values:[ 0.023  0.041 -0.054]
[09] name:[net.linear_09.bias] shape:[(128,)].
     first 3 values:[0. 0. 0.]
[10] name:[net.mixturesoflogits_12.fc_pi.weight] shape:[(5, 128)].
     first 3 values:[-0.03   0.151  0.103]
[11] name:[net.mixturesoflogit

### Demo forward path

In [7]:
# Demo instantiate
M = MixtureLogitNetwork(
    name='mln',x_dim=[1,28,28],k_size=3,c_dims=[32,64],p_sizes=[2,2],
    h_dims=[128],y_dim=10,USE_BN=True,
    k=3,sig_min=1,sig_max=None,
    mu_min=-3,mu_max =+3).to(device)
# Demo forward path 
x_np = np.random.rand(2,1,28,28)
x_tc = np2tc(x_np)
pi_tc,mu_tc,sigma_tc = M.forward(x_tc) # forward path of MLN
pi_np,mu_np,sigma_np = tc2np(pi_tc),tc2np(mu_tc),tc2np(sigma_tc)
out = mdn_gather(pi_tc,mu_tc,sigma_tc)
mu_sel_np = tc2np(out['mu_sel'])
print ('x_np:      %s'%(x_np.shape,))
print ('=>')
print ('pi_np:     %s'%(pi_np.shape,)) # [N x K]
print ('mu_np:     %s'%(mu_np.shape,)) # [N x K x D]
print ('sigma_np:  %s'%(sigma_np.shape,)) # [N x K x D]
print ('=>')
print ('mu_sel_np: %s'%(mu_sel_np.shape,)) # [N x D]

x_np:      (2, 1, 28, 28)
=>
pi_np:     (2, 3)
mu_np:     (2, 3, 10)
sigma_np:  (2, 3, 10)
=>
mu_sel_np: (2, 10)


### Dataset

In [8]:
from torchvision import datasets,transforms
mnist_train = datasets.MNIST(root='./data/',train=True,transform=transforms.ToTensor(),download=True)
mnist_test = datasets.MNIST(root='./data/',train=False,transform=transforms.ToTensor(),download=True)
mnist_train.targets = mnist_train.targets # manipulate train labels
BATCH_SIZE = 64
train_iter = torch.utils.data.DataLoader(mnist_train,batch_size=BATCH_SIZE,shuffle=True,num_workers=1)
test_iter = torch.utils.data.DataLoader(mnist_test,batch_size=BATCH_SIZE,shuffle=True,num_workers=1)
print ("Done.")

Done.


### Evaluation function

In [9]:
def func_eval(model,data_iter,device):
    with torch.no_grad():
        n_total,n_correct,epis_unct_sum,alea_unct_sum = 0,0,0,0
        model.eval() # evaluate (affects DropOut and BN)
        for batch_in,batch_out in data_iter:
            # Foraward path
            y_trgt = batch_out.to(device)
            pi,mu,sigma = model.forward(batch_in.view(-1,1,28,28).to(device))
            out = mdn_gather(pi,mu,sigma)
            model_pred = out['mu_sel']

            # Uncertainty 
            unct_out = gmm_uncertainties(pi,mu,sigma)
            epis_unct = unct_out['epis_unct'] # [N]
            alea_unct = unct_out['alea_unct'] # [N]
            epis_unct_sum += torch.sum(epis_unct)
            alea_unct_sum += torch.sum(alea_unct)

            # Check
            _,y_pred = torch.max(model_pred,1)
            n_correct += (y_pred==y_trgt).sum().item()
            n_total += batch_in.size(0)
        val_accr = (n_correct/n_total)
        epis_unct_avg = (epis_unct_sum/n_total).detach().cpu().item()
        alea_unct_avg = (alea_unct_sum/n_total).detach().cpu().item()
        model.train() # back to train mode 
        out_eval = {'val_accr':val_accr,
                    'epis_unct_avg':epis_unct_avg,'alea_unct_avg':alea_unct_avg}
    return out_eval
print ("Done")

Done


In [10]:
M.init_param()
train_accr = func_eval(M,train_iter,device)['val_accr']
test_accr = func_eval(M,test_iter,device)['val_accr']
print ("train_accr:[%.3f] test_accr:[%.3f]."%(train_accr,test_accr))

train_accr:[0.169] test_accr:[0.170].


### Train with clean data

In [11]:
np.random.seed(seed=0)
torch.manual_seed(seed=0)
M = MixtureLogitNetwork(
    name='mln',x_dim=[1,28,28],k_size=3,c_dims=[32,64],p_sizes=[2,2],
    h_dims=[128],y_dim=10,USE_BN=False,k=5,
    sig_min=0.01,sig_max=3,
    mu_min=-3,mu_max=+3).to(device)
M.init_param()
optm = optim.Adam(M.parameters(),lr=1e-3,weight_decay=1e-6)
M.train() # train mode

# Re-define the train iterator
mnist_train.targets = mnist_train.targets # manipulate train labels
BATCH_SIZE = 64
train_iter = torch.utils.data.DataLoader(mnist_train,batch_size=BATCH_SIZE,shuffle=True,num_workers=1)

EPOCHS,print_every = 10,1
for epoch in range(EPOCHS):
    loss_sum,wace_sum,alea_sum = 0,0,0
    for batch_in,batch_out in train_iter:
        # Forward path
        pi,mu,sigma = M.forward(batch_in.view(-1,1,28,28).to(device)) 
        target = torch.eye(M.y_dim)[batch_out].to(device)
        mace_loss_out = mace_loss(pi,mu,sigma,target,alea_weight=0.5) # mixture of CE 
        loss_out = mace_loss_out['loss_avg']
        wace_out = mace_loss_out['wace_avg']
        alea_out = mace_loss_out['alea_avg']
        # Update 
        optm.zero_grad() # reset gradient 
        loss_out.backward() # backpropagate
        optm.step() # optimizer update
        # Track losses 
        loss_sum += loss_out
        wace_sum += wace_out
        alea_sum += alea_out
    loss_avg = loss_sum/len(train_iter)
    wace_avg = wace_sum/len(train_iter)
    alea_avg = alea_sum/len(train_iter)
    # Print
    if ((epoch%print_every)==0) or (epoch==(EPOCHS-1)):
        train_res = func_eval(M,train_iter,device)
        test_res  = func_eval(M,test_iter,device)
        print ("epoch:[%d] loss:[%.3f]=(wace:%.3f+alea:%.3f) train_accr:[%.3f] test_accr:[%.3f]."%
               (epoch,loss_avg,wace_avg,alea_avg,
                train_res['val_accr'],test_res['val_accr'])) 
        print (" [Train] alea:[%.3f] epis:[%.3f]"%
               (train_res['alea_unct_avg'],train_res['epis_unct_avg']))
        print ("  [Test] alea:[%.3f] epis:[%.3f]"%
               (test_res['alea_unct_avg'],test_res['epis_unct_avg']))

print ("Done")

epoch:[0] loss:[0.152]=(wace:0.069+alea:0.084) train_accr:[0.974] test_accr:[0.974].
 [Train] alea:[2.257] epis:[0.004]
  [Test] alea:[2.244] epis:[0.004]
epoch:[1] loss:[0.068]=(wace:0.033+alea:0.035) train_accr:[0.982] test_accr:[0.981].
 [Train] alea:[2.061] epis:[0.002]
  [Test] alea:[2.014] epis:[0.002]
epoch:[2] loss:[0.048]=(wace:0.022+alea:0.026) train_accr:[0.984] test_accr:[0.983].
 [Train] alea:[1.724] epis:[0.001]
  [Test] alea:[1.700] epis:[0.001]
epoch:[3] loss:[0.040]=(wace:0.018+alea:0.021) train_accr:[0.989] test_accr:[0.987].
 [Train] alea:[1.572] epis:[0.000]
  [Test] alea:[1.555] epis:[0.000]
epoch:[4] loss:[0.033]=(wace:0.015+alea:0.019) train_accr:[0.992] test_accr:[0.989].
 [Train] alea:[1.429] epis:[0.000]
  [Test] alea:[1.404] epis:[0.000]
epoch:[5] loss:[0.028]=(wace:0.013+alea:0.016) train_accr:[0.993] test_accr:[0.989].
 [Train] alea:[1.341] epis:[0.000]
  [Test] alea:[1.319] epis:[0.000]
epoch:[6] loss:[0.030]=(wace:0.014+alea:0.017) train_accr:[0.993] test

### Train with random shuffle noise

In [12]:
np.random.seed(seed=0)
torch.manual_seed(seed=0)
M = MixtureLogitNetwork(
    name='mln',x_dim=[1,28,28],k_size=3,c_dims=[32,64],p_sizes=[2,2],
    h_dims=[128],y_dim=10,USE_BN=False,k=5,
    sig_min=0.01,sig_max=3,
    mu_min=-3,mu_max=+3).to(device)
M.init_param()
optm = optim.Adam(M.parameters(),lr=1e-3,weight_decay=1e-6)
M.train() # train mode

# Re-define the train iterator
mnist_train = datasets.MNIST(root='./data/',train=True,transform=transforms.ToTensor(),download=True)
n_train = len(mnist_train)
corrupt_rate = 0.5 # random shuffle rate 
n_corrupt = int(n_train*corrupt_rate)
r_idx = np.random.permutation(n_train)[:n_corrupt]
mnist_train.targets[r_idx] = torch.randint(low=0,high=10,size=(n_corrupt,)) # random label 
BATCH_SIZE = 64
train_iter = torch.utils.data.DataLoader(mnist_train,batch_size=BATCH_SIZE,shuffle=True,num_workers=1)

EPOCHS,print_every = 10,1
for epoch in range(EPOCHS):
    loss_sum,wace_sum,alea_sum = 0,0,0
    for batch_in,batch_out in train_iter:
        # Forward path
        pi,mu,sigma = M.forward(batch_in.view(-1,1,28,28).to(device)) 
        target = torch.eye(M.y_dim)[batch_out].to(device)
        mace_loss_out = mace_loss(pi,mu,sigma,target,alea_weight=0.5) # mixture of CE 
        loss_out = mace_loss_out['loss_avg']
        wace_out = mace_loss_out['wace_avg']
        alea_out = mace_loss_out['alea_avg']
        # Update 
        optm.zero_grad() # reset gradient 
        loss_out.backward() # backpropagate
        optm.step() # optimizer update
        # Track losses 
        loss_sum += loss_out
        wace_sum += wace_out
        alea_sum += alea_out
    loss_avg = loss_sum/len(train_iter)
    wace_avg = wace_sum/len(train_iter)
    alea_avg = alea_sum/len(train_iter)
    # Print
    if ((epoch%print_every)==0) or (epoch==(EPOCHS-1)):
        train_res = func_eval(M,train_iter,device)
        test_res  = func_eval(M,test_iter,device)
        print ("epoch:[%d] loss:[%.3f]=(wace:%.3f+alea:%.3f) train_accr:[%.3f] test_accr:[%.3f]."%
               (epoch,loss_avg,wace_avg,alea_avg,
                train_res['val_accr'],test_res['val_accr'])) 
        print (" [Train] alea:[%.3f] epis:[%.3f]"%
               (train_res['alea_unct_avg'],train_res['epis_unct_avg']))
        print ("  [Test] alea:[%.3f] epis:[%.3f]"%
               (test_res['alea_unct_avg'],test_res['epis_unct_avg']))

print ("Done")

epoch:[0] loss:[0.619]=(wace:0.297+alea:0.322) train_accr:[0.529] test_accr:[0.957].
 [Train] alea:[7.769] epis:[0.027]
  [Test] alea:[7.765] epis:[0.027]
epoch:[1] loss:[0.589]=(wace:0.289+alea:0.300) train_accr:[0.534] test_accr:[0.964].
 [Train] alea:[7.251] epis:[0.009]
  [Test] alea:[7.240] epis:[0.009]
epoch:[2] loss:[0.583]=(wace:0.287+alea:0.296) train_accr:[0.539] test_accr:[0.972].
 [Train] alea:[7.687] epis:[0.011]
  [Test] alea:[7.693] epis:[0.011]
epoch:[3] loss:[0.578]=(wace:0.285+alea:0.293) train_accr:[0.540] test_accr:[0.972].
 [Train] alea:[7.451] epis:[0.007]
  [Test] alea:[7.443] epis:[0.007]
epoch:[4] loss:[0.573]=(wace:0.283+alea:0.290) train_accr:[0.541] test_accr:[0.969].
 [Train] alea:[7.526] epis:[0.007]
  [Test] alea:[7.531] epis:[0.007]
epoch:[5] loss:[0.566]=(wace:0.279+alea:0.287) train_accr:[0.543] test_accr:[0.967].
 [Train] alea:[7.439] epis:[0.004]
  [Test] alea:[7.440] epis:[0.004]
epoch:[6] loss:[0.558]=(wace:0.276+alea:0.282) train_accr:[0.545] test

### Train with random permutation noise

In [13]:
np.random.seed(seed=0)
torch.manual_seed(seed=0)
M = MixtureLogitNetwork(
    name='mln',x_dim=[1,28,28],k_size=3,c_dims=[32,64],p_sizes=[2,2],
    h_dims=[128],y_dim=10,USE_BN=False,k=5,
    sig_min=0.01,sig_max=3,
    mu_min=-3,mu_max=+3).to(device)
M.init_param()
optm = optim.Adam(M.parameters(),lr=1e-3,weight_decay=1e-6)
M.train() # train mode

# Re-define the train iterator
mnist_train = datasets.MNIST(root='./data/',train=True,transform=transforms.ToTensor(),download=True)
n_train = len(mnist_train)
corrupt_rate = 0.3 # random permutation rate 
targets_bu = mnist_train.targets
permute_targets = [1,2,3,4,5,6,7,8,9,0] # shift label 
for idx in range(10):
    sel_idx = torch.where(targets_bu==idx)[0]
    n_sel   = sel_idx.shape[0]
    corrupt_idx = np.random.permutation(n_sel)[:int(n_sel*corrupt_rate)]
    mnist_train.targets[sel_idx[corrupt_idx]] = permute_targets[idx]
BATCH_SIZE = 64
train_iter = torch.utils.data.DataLoader(mnist_train,batch_size=BATCH_SIZE,shuffle=True,num_workers=1)

EPOCHS,print_every = 10,1
for epoch in range(EPOCHS):
    loss_sum,wace_sum,alea_sum = 0,0,0
    for batch_in,batch_out in train_iter:
        # Forward path
        pi,mu,sigma = M.forward(batch_in.view(-1,1,28,28).to(device)) 
        target = torch.eye(M.y_dim)[batch_out].to(device)
        mace_loss_out = mace_loss(pi,mu,sigma,target,alea_weight=0.5) # mixture of CE 
        loss_out = mace_loss_out['loss_avg']
        wace_out = mace_loss_out['wace_avg']
        alea_out = mace_loss_out['alea_avg']
        # Update 
        optm.zero_grad() # reset gradient 
        loss_out.backward() # backpropagate
        optm.step() # optimizer update
        # Track losses 
        loss_sum += loss_out
        wace_sum += wace_out
        alea_sum += alea_out
    loss_avg = loss_sum/len(train_iter)
    wace_avg = wace_sum/len(train_iter)
    alea_avg = alea_sum/len(train_iter)
    # Print
    if ((epoch%print_every)==0) or (epoch==(EPOCHS-1)):
        train_res = func_eval(M,train_iter,device)
        test_res  = func_eval(M,test_iter,device)
        print ("epoch:[%d] loss:[%.3f]=(wace:%.3f+alea:%.3f) train_accr:[%.3f] test_accr:[%.3f]."%
               (epoch,loss_avg,wace_avg,alea_avg,
                train_res['val_accr'],test_res['val_accr'])) 
        print (" [Train] alea:[%.3f] epis:[%.3f]"%
               (train_res['alea_unct_avg'],train_res['epis_unct_avg']))
        print ("  [Test] alea:[%.3f] epis:[%.3f]"%
               (test_res['alea_unct_avg'],test_res['epis_unct_avg']))

print ("Done")

epoch:[0] loss:[0.411]=(wace:0.192+alea:0.219) train_accr:[0.659] test_accr:[0.930].
 [Train] alea:[5.737] epis:[0.022]
  [Test] alea:[5.717] epis:[0.022]
epoch:[1] loss:[0.335]=(wace:0.162+alea:0.174) train_accr:[0.657] test_accr:[0.918].
 [Train] alea:[5.203] epis:[0.012]
  [Test] alea:[5.164] epis:[0.012]
epoch:[2] loss:[0.317]=(wace:0.154+alea:0.163) train_accr:[0.681] test_accr:[0.962].
 [Train] alea:[4.874] epis:[0.005]
  [Test] alea:[4.881] epis:[0.005]
epoch:[3] loss:[0.303]=(wace:0.147+alea:0.156) train_accr:[0.690] test_accr:[0.976].
 [Train] alea:[4.483] epis:[0.002]
  [Test] alea:[4.490] epis:[0.002]
epoch:[4] loss:[0.293]=(wace:0.143+alea:0.150) train_accr:[0.691] test_accr:[0.973].
 [Train] alea:[4.628] epis:[0.003]
  [Test] alea:[4.638] epis:[0.003]
epoch:[5] loss:[0.283]=(wace:0.138+alea:0.145) train_accr:[0.695] test_accr:[0.982].
 [Train] alea:[4.059] epis:[0.001]
  [Test] alea:[4.067] epis:[0.001]
epoch:[6] loss:[0.275]=(wace:0.134+alea:0.141) train_accr:[0.696] test