# MLP探索
- 探索削减精度对预测的影响
- 探索参数成倍变化对预测的影响
- 探索训练时使用批标准化，预测时不使用批标准化对预测的影响

In [72]:
import torch as pt
import torchvision as ptv
import numpy as np
train_set = ptv.datasets.MNIST("../../pytorch_database/mnist/train",train=True,transform=ptv.transforms.ToTensor(),download=True)
test_set = ptv.datasets.MNIST("../../pytorch_database/mnist/test",train=False,transform=ptv.transforms.ToTensor(),download=True)
train_dataset = pt.utils.data.DataLoader(train_set,batch_size=100)
test_dataset = pt.utils.data.DataLoader(test_set,batch_size=100)

def AccuarcyCompute(pred,label):
    pred = pred.cpu().data.numpy()
    label = label.cpu().data.numpy()
    test_np = (np.argmax(pred,1) == label)
    test_np = np.float32(test_np)
    return np.mean(test_np)

class MLP(pt.nn.Module):
    def __init__(self):
        super(MLP,self).__init__()
        self.fc1 = pt.nn.Linear(784,512)
        self.norm1 = pt.nn.BatchNorm1d(512,momentum=0.1,affine=False)
        self.fc2 = pt.nn.Linear(512,128)
        self.norm2 = pt.nn.BatchNorm2d(128,momentum=0.1,affine=False)
        self.fc3 = pt.nn.Linear(128,10)
        
    def forward(self,din):
        din = din.view(-1,28*28)
        dout = pt.nn.functional.relu(self.norm1(self.fc1(din)))
        dout = pt.nn.functional.relu(self.norm2(self.fc2(dout)))
        return pt.nn.functional.softmax(self.fc3(dout))
model_norm = MLP().cuda()
print(model_norm)

optimizer = pt.optim.SGD(model_norm.parameters(),lr=0.01,momentum=0.9)
lossfunc = pt.nn.CrossEntropyLoss().cuda()
for x in range(2):
    for i,data in enumerate(train_dataset):
    
        optimizer.zero_grad()
    
        (inputs,labels) = data
        inputs = pt.autograd.Variable(inputs).cuda()
        labels = pt.autograd.Variable(labels).cuda()
    
        outputs = model_norm(inputs)
    
        loss = lossfunc(outputs,labels)
        loss.backward()
    
        optimizer.step()
    
        if i % 200 == 0:
            print(i,":",AccuarcyCompute(outputs,labels))

MLP (
  (fc1): Linear (784 -> 512)
  (norm1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=False)
  (fc2): Linear (512 -> 128)
  (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=False)
  (fc3): Linear (128 -> 10)
)
0 : 0.11
200 : 0.9
400 : 0.96
0 : 0.96
200 : 0.96
400 : 0.97


In [73]:
optimizer.zero_grad()
accuarcy_list = []
for i,(inputs,labels) in enumerate(test_dataset):
    inputs = pt.autograd.Variable(inputs).cuda()
    labels = pt.autograd.Variable(labels).cuda()
    outputs = model_norm(inputs)
    accuarcy_list.append(AccuarcyCompute(outputs,labels))
print(sum(accuarcy_list) / len(accuarcy_list))
pt.save(model_norm.state_dict(),"../../pytorch_model/mlp/explore_params/mlp_params.pt")

0.965700003505


## 削减精度对网络的影响

In [51]:
mlp_low = MLP().cuda()
mlp_low.load_state_dict(pt.load("../../pytorch_model/mlp/explore_params/mlp_params.pt"))
print(mlp_low)
for name,f in mlp_low.named_parameters():
    f.data = f.data * 100
    f.data = f.data.int().float()
    f.data = f.data / 100
#     print(name,f)
accuarcy_list = []
for i,(inputs,labels) in enumerate(test_dataset):
    inputs = pt.autograd.Variable(inputs).cuda()
    labels = pt.autograd.Variable(labels).cuda()
    outputs = mlp_low(inputs)
    accuarcy_list.append(AccuarcyCompute(outputs,labels))
print(sum(accuarcy_list) / len(accuarcy_list))

MLP (
  (fc1): Linear (784 -> 512)
  (norm1): BatchNorm1d(512, eps=1e-05, momentum=0.5, affine=True)
  (fc2): Linear (512 -> 128)
  (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.5, affine=True)
  (fc3): Linear (128 -> 10)
)
0.962100004554


In [52]:
mlp_low = MLP().cuda()
mlp_low.load_state_dict(pt.load("../../pytorch_model/mlp/explore_params/mlp_params.pt"))
print(mlp_low)
for name,f in mlp_low.named_parameters():
    f.data = f.data * 10
    f.data = f.data.int().float()
    f.data = f.data / 10
#     print(name,f)
accuarcy_list = []
for i,(inputs,labels) in enumerate(test_dataset):
    inputs = pt.autograd.Variable(inputs).cuda()
    labels = pt.autograd.Variable(labels).cuda()
    outputs = mlp_low(inputs)
    accuarcy_list.append(AccuarcyCompute(outputs,labels))
print(sum(accuarcy_list) / len(accuarcy_list))

MLP (
  (fc1): Linear (784 -> 512)
  (norm1): BatchNorm1d(512, eps=1e-05, momentum=0.5, affine=True)
  (fc2): Linear (512 -> 128)
  (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.5, affine=True)
  (fc3): Linear (128 -> 10)
)
0.0974000002816


由上面的比较发现，mlp中的精度削减可能存在一个阈值。以上面为例，削减参数精度为小数点后2位时对结果几乎没有影响；而削减为小数点后1位时，结果已无法接受

## 参数成倍变化对网络是否有影响

In [56]:
mlp_double = MLP().cuda()
mlp_double.load_state_dict(pt.load("../../pytorch_model/mlp/explore_params/mlp_params.pt"))
print(mlp_double)
for name,f in mlp_double.named_parameters():
    f.data = f.data * 2
    f.data = f.data.int().float()
#     print(name,f)
accuarcy_list = []
for i,(inputs,labels) in enumerate(test_dataset):
    inputs = pt.autograd.Variable(inputs).cuda()
    labels = pt.autograd.Variable(labels).cuda()
    outputs = mlp_double(inputs)
    accuarcy_list.append(AccuarcyCompute(outputs,labels))
print(sum(accuarcy_list) / len(accuarcy_list))

MLP (
  (fc1): Linear (784 -> 512)
  (norm1): BatchNorm1d(512, eps=1e-05, momentum=0.5, affine=True)
  (fc2): Linear (512 -> 128)
  (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.5, affine=True)
  (fc3): Linear (128 -> 10)
)
0.0979999999329


由以上看出参数成倍变化可能对结果造成灾难性影响

## 预测时去除批标准化层对结果的影响
### 直接去除批标准化

In [74]:
class MLP_no(pt.nn.Module):
    def __init__(self):
        super(MLP_no,self).__init__()
        self.fc1 = pt.nn.Linear(784,512)
        self.fc2 = pt.nn.Linear(512,128)
        self.fc3 = pt.nn.Linear(128,10)
        
    def forward(self,din):
        din = din.view(-1,28*28)
        dout = pt.nn.functional.relu(self.fc1(din))
        dout = pt.nn.functional.relu(self.fc2(dout))
        return pt.nn.functional.softmax(self.fc3(dout))
model_no = MLP_no().cuda()
print(model_no)

MLP_no (
  (fc1): Linear (784 -> 512)
  (fc2): Linear (512 -> 128)
  (fc3): Linear (128 -> 10)
)


In [75]:
mlp_double = MLP().cuda()
mlp_double.load_state_dict(pt.load("../../pytorch_model/mlp/explore_params/mlp_params.pt"))
param_dict = {}
for name,f in mlp_double.named_parameters():
    param_dict[name] = f
# print(param_dict)
for name,f in model_no.named_parameters():
    if name in param_dict:
        f.data = param_dict[name].data

In [76]:
accuarcy_list = []
for i,(inputs,labels) in enumerate(test_dataset):
    inputs = pt.autograd.Variable(inputs).cuda()
    labels = pt.autograd.Variable(labels).cuda()
    outputs = model_no(inputs)
    accuarcy_list.append(AccuarcyCompute(outputs,labels))
print(sum(accuarcy_list) / len(accuarcy_list))

0.809399998188


可以发现，在预测过程中将批标准化移除后，会有一部分性能损失，但是还没到灾难的地步
### 使用减去平均值的方法代替批标准化

In [109]:
class MLP_avg(pt.nn.Module):
    def __init__(self):
        super(MLP_avg,self).__init__()
        self.fc1 = pt.nn.Linear(784,512)
        self.fc2 = pt.nn.Linear(512,128)
        self.fc3 = pt.nn.Linear(128,10)
        
    def forward(self,din):
        din = din.view(-1,28*28)
        dout = self.fc1(din)
        dout.data = self.sub_average(dout)
        dout = pt.nn.functional.relu(dout)
        dout = self.fc2(dout)
        dout.data = self.sub_average(dout)
        dout = pt.nn.functional.relu(dout)
        return pt.nn.functional.softmax(self.fc3(dout))
    
    def sub_average(self,din):
        average = pt.sum(din)
        num = 0
        for i in din.size():
            num += i
        return din.data.sub_(din.data / num)

model_avg = MLP_avg().cuda()
print(model_avg)

mlp_double = MLP().cuda()
mlp_double.load_state_dict(pt.load("../../pytorch_model/mlp/explore_params/mlp_params.pt"))
param_dict = {}
for name,f in mlp_double.named_parameters():
    param_dict[name] = f
# print(param_dict)
for name,f in model_avg.named_parameters():
    if name in param_dict:
        f.data = param_dict[name].data

accuarcy_list = []
for i,(inputs,labels) in enumerate(test_dataset):
    inputs = pt.autograd.Variable(inputs).cuda()
    labels = pt.autograd.Variable(labels).cuda()
    outputs = model_avg(inputs)
    accuarcy_list.append(AccuarcyCompute(outputs,labels))
print(sum(accuarcy_list) / len(accuarcy_list))

MLP_avg (
  (fc1): Linear (784 -> 512)
  (fc2): Linear (512 -> 128)
  (fc3): Linear (128 -> 10)
)
0.809099998474


使用减去平均值的方法代替批标准化对结果没有提升
### 使用normalize函数代替

In [122]:
class MLP_normalize(pt.nn.Module):
    def __init__(self):
        super(MLP_normalize,self).__init__()
        self.fc1 = pt.nn.Linear(784,512)
        self.fc2 = pt.nn.Linear(512,128)
        self.fc3 = pt.nn.Linear(128,10)
        
    def forward(self,din):
        din = din.view(-1,28*28)
        dout = pt.nn.functional.relu(pt.nn.functional.normalize(self.fc1(din),p=2))
        dout = pt.nn.functional.relu(pt.nn.functional.normalize(self.fc2(dout),p=2))
        return pt.nn.functional.softmax(self.fc3(dout))
model_normalize = MLP_normalize().cuda()
print(model_normalize)

optimizer = pt.optim.SGD(model_normalize.parameters(),lr=0.01,momentum=0.9)
lossfunc = pt.nn.CrossEntropyLoss().cuda()

MLP_normalize (
  (fc1): Linear (784 -> 512)
  (fc2): Linear (512 -> 128)
  (fc3): Linear (128 -> 10)
)


In [126]:
for x in range(5):
    for i,data in enumerate(train_dataset):
    
        optimizer.zero_grad()
    
        (inputs,labels) = data
        inputs = pt.autograd.Variable(inputs).cuda()
        labels = pt.autograd.Variable(labels).cuda()
    
        outputs = model_normalize(inputs)
    
        loss = lossfunc(outputs,labels)
        loss.backward()
    
        optimizer.step()
    
        if i % 200 == 0:
            print(i,":",AccuarcyCompute(outputs,labels))

0 : 0.92
200 : 0.86
400 : 0.91
0 : 0.92
200 : 0.86
400 : 0.91
0 : 0.92
200 : 0.86
400 : 0.91
0 : 0.92
200 : 0.86
400 : 0.91
0 : 0.93
200 : 0.86
400 : 0.92


In [127]:
accuarcy_list = []
for i,(inputs,labels) in enumerate(test_dataset):
    inputs = pt.autograd.Variable(inputs).cuda()
    labels = pt.autograd.Variable(labels).cuda()
    outputs = model_normalize(inputs)
    accuarcy_list.append(AccuarcyCompute(outputs,labels))
print(sum(accuarcy_list) / len(accuarcy_list))

0.886699998379


可以发现，使用`torch.nn.functional.normalize()`函数代替批标准化层后，精度有一定下降，性能较直接移除而言稍高

# 结论
- 移除一定的精度MLP网络的性能影响并不大
- 参数成倍上升或下降对性能造成灾难性影响
- 直接或使用易于计算的函数代替批标准化对模型性能造成损失