In [362]:
import torch 
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import math

torch.manual_seed(1) #reproducible
EPOCH = 4
BATCH_SIZE = 128

train_data = torchvision.datasets.MNIST(
    root='./mnist', #保存位置
    train=True, #training set
    transform=torchvision.transforms.ToTensor(), #converts a PIL.Image to torch.FloatTensor(C*H*W) in range(0.0,1.0)
    download=True
)
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_data = torchvision.datasets.MNIST(
    root='./MNIST',
    train=False,
    transform=torchvision.transforms.ToTensor()
)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE,shuffle=False)

In [372]:
torch.manual_seed(27)
class CNN(nn.Module):
    def __init__(self,D_in,H,D_out):
        super(CNN, self).__init__()
        self.fc1 = nn.Linear(D_in,H)
        torch.nn.init.normal(self.fc1.weight, mean=0, std=0.01)
        #nn.init.xavier_normal(self.fc1.weight,gain = 1)
        nn.init.constant(self.fc1.bias, 0.1)
        
        self.fc2 = nn.Linear(H,D_out)
        torch.nn.init.normal(self.fc2.weight, mean=0, std=0.01)
        #nn.init.xavier_normal(self.fc2.weight, gain = 1)
        nn.init.constant(self.fc2.bias, 0.1)
       # self.out = nn.Linear(10,10)
        
    def forward(self, x):
        x = x.view(-1, 784)
        x = F.relu(self.fc1(x))
        x = F.softmax(self.fc2(x))
        output = x
        #output = self.out(x)
        return output

D_in,H,D_out = 784,10,10
cnn = CNN(D_in,H,D_out)
print(cnn)

para = list(cnn.parameters())
print(len(para))
print(para[0].size())  # conv1's .weight

learning_rate = 0.9
u0 = 1
#loss function:cross-entropy with l2 regularizaiton
loss_func = nn.CrossEntropyLoss()
# inital all using viariables
EMAg = []
EMAg_2 = []
EMAx = []
EMAx_2 = []
EMAxg = []
EMAu = []

for epoch in range(EPOCH):
    running_loss = 0.0
    for i, data in enumerate(train_loader,0):
       # get the inputs
        inputs, labels = data

        # wrap them in Variable
        inputs, labels = Variable(inputs), Variable(labels)

        # zero the parameter gradients
        cnn.zero_grad()

        # forward + backward + optimize
        outputs = cnn(inputs)
        loss = loss_func(outputs, labels)
        loss.backward()
       
        # implementation of cSGD algo
        for index,params in enumerate(cnn.parameters(),0):
            
            #update iteration
            if(i==0 and epoch==0):
                
                EMAg.append(params.grad.data)
                EMAg_2.append(EMAg[index]**2)
                EMAx.append(params.data)
                EMAx_2.append(EMAx[index]**2)
                EMAxg.append(EMAx[index]*EMAg[index])
                EMAu.append(torch.Tensor(params.size()))
                
                EMAu[index] = torch.ones(EMAu[index].size())*u0  
            else:
                #cal beta
                n = params.dim()
                beta = torch.Tensor(params.size())
                if n == 1:
                    for item in range(0,params.size()[0]):
                        if (EMAg_2[index][item]==EMAg[index][item]**2)and(EMAg_2[index][item]==0):
                             beta[item] = 0.9
                        elif (EMAg_2[index][item]!=EMAg[index][item]**2)and(EMAg_2[index][item]==0):
                             beta[item] = 1000
                        else:
                             beta[item] = 0.9+(0.999-0.9)*(EMAg_2[index][item]-EMAg[index][item]**2)/(EMAg_2[index][item])
            
                if n == 2:
                    for item1 in range(0,params.size()[0]):
                        for item2 in range(0,params.size()[1]):
                            if EMAg_2[index][item1,item2]==EMAg[index][item1,item2]**2 and EMAg_2[index][item1,item2]==0:
                                beta[item1,item2] = 0.9
                            elif (EMAg_2[index][item1,item2]!=EMAg[index][item1,item2]**2)and(EMAg_2[index][item1,item2]==0):
                                beta[item1,item2] = 1000
                            else:
                                beta[item1,item2] = 0.9+(0.999-0.9)*(EMAg_2[index][item1,item2]-EMAg[index][item1,item2]**2)/(EMAg_2[index][item1,item2])
        
                #update EMA
                EMAg[index] = (beta)*EMAg[index]+(1-beta)*params.grad.data
                EMAg_2[index] = (beta)*EMAg_2[index] + (1-beta)*(params.grad.data.pow(2))
                EMAx[index] = (beta)*EMAx[index] + (1-beta)*params.data
                EMAx_2[index] = (beta)*EMAx_2[index]+(1-beta)*(params.data.pow(2))
                EMAxg[index] = (beta)*EMAxg[index] + (1-beta)*(params.grad.data*params.data)
                #print("%d"%i)
            
                #cal a,b,sigma,u*
                n = params.dim()
                a = torch.Tensor(params.size())
                b = torch.Tensor(params.size())
                sigma = torch.Tensor(params.size())
                u = torch.Tensor(params.size())
            
            
                if n == 1:
                    for item in range(0,params.size()[0]):
                        #cal a
                        if EMAxg[index][item]==EMAg[index][item]*EMAx[index][item] and (EMAx_2[index][item]-EMAx[index][item]**2)==0:
                             a[item] = 0
                        elif EMAxg[index][item]!=EMAg[index][item]*EMAx[index][item] and (EMAx_2[index][item]-EMAx[index][item]**2)==0:
                             a[item] = 10000
                        else:
                             a[item] = (EMAxg[index][item]-EMAg[index][item]*EMAx[index][item])/(EMAx_2[index][item]-EMAx[index][item]**2)
                        
                        #cal b
                        if EMAg[index][item]==0 and a[item]==0:
                             b[item] = 0
                        elif EMAg[index][item]!=0 and a[item]==0:
                             b[item] = -10000
                        else:
                             b[item] = EMAx[index][item] - EMAg[index][item]/a[item]
                        
                        #cal sigma
                        sigma[item] = EMAg_2[index][item] - EMAg[index][item]**2
                        
                        #cal u*
                        if(a[item]<= 0):
                            u[item] = 1
                        else:
                            if(sigma[item]==0.0 and EMAx[index][item]==b[item]**2):
                                u[item] = 0.0
                            elif(sigma[item]==0.0 and EMAx[index][item]!=b[item]**2):
                                u[item] = 1.0
                            else:
                                u[item] = min(1,a[item]*((EMAx[index][item]-b[item])**2)/(learning_rate*sigma[item])) 
                        
                if n == 2:
                    for item1 in range(0,params.size()[0]):
                        for item2 in range(0,params.size()[1]):
                            #cal a
                            if ((EMAxg[index][item1,item2]-EMAg[index][item1,item2]*EMAx[index][item1,item2])==0):
                                a[item1,item2] = 0
                            elif (EMAx_2[index][item1,item2]-EMAx[index][item1,item2]**2)==0:
                                a[item1,item2] = 10000
                            else:
                                a[item1,item2] =(EMAxg[index][item1,item2]-EMAg[index][item1,item2]*EMAx[index][item1,item2])/(EMAx_2[index][item1,item2]-EMAx[index][item1,item2]**2)            
                                #print(a[item1,item2])
                        
                            #cal b
                            if EMAg[index][item1,item2]==0 and a[item1,item2]==0:
                                 b[item1,item2] = 0
                            elif EMAg[index][item1,item2]!=0 and a[item1,item2]==0:
                                 b[item1,item2] = -10000
                            else:
                                 b[item1,item2] = EMAx[index][item1,item2] - EMAg[index][item1,item2]/a[item1,item2]
                        
                            #cal sigma
                            sigma[item1,item2] = EMAg_2[index][item1,item2] - math.pow(EMAg[index][item1,item2],2)
                        
                            #cal u*
                            if(a[item1,item2]<= 0):
                                u[item1,item2] = 1.0
                            else:
                                if(sigma[item1,item2]==0.0 and EMAx[index][item1,item2]==b[item1,item2]**2):
                                    u[item1,item2] = 0.0
                                elif(sigma[item1,item2]==0.0 and EMAx[index][item1,item2]!=b[item1,item2]**2):
                                    u[item1,item2] = 1.0
                                else:
                                    u[item1,item2] = min(1.0,a[item1,item2]*((EMAx[index][item1,item2]-b[item1,item2])**2)/(learning_rate*sigma[item1,item2]))
                                #print(a[item1,item2],b[item1,item2],sigma[item1,item2])
                                #print(u[item1,item2])
                                #print(a[item1,item2]*((EMAx[index][item1,item2]-b[item1,item2])**2)/(learning_rate*sigma[item1,item2]))
                EMAu[index] = (1-beta)*EMAu[index] + (beta)*u
                #print(EMAu[2])
            
            params.data -= learning_rate *EMAu[index]* params.grad.data
        
        #print("%d %d"%(epoch, i))
        running_loss += loss.data[0]
        if i % 100 == 99:    # print every 2000 mini-batches
            #print(EMAu[0])
            print('[%d, %5d] loss: %.8f' %(epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0
print('Finished Training')

CNN (
  (fc1): Linear (784 -> 10)
  (fc2): Linear (10 -> 10)
)
4
torch.Size([10, 784])
[1,   100] loss: 2.26073040
[1,   200] loss: 2.19000860
[1,   300] loss: 2.16439652
[1,   400] loss: 2.15225426
[2,   100] loss: 2.04023406
[2,   200] loss: 1.96765389
[2,   300] loss: 1.92420876
[2,   400] loss: 1.89311644
[3,   100] loss: 1.87835385
[3,   200] loss: 1.86738326
[3,   300] loss: 1.85825554
[3,   400] loss: 1.85637102
[4,   100] loss: 1.84530401
[4,   200] loss: 1.84382792
[4,   300] loss: 1.84029682
[4,   400] loss: 1.83683700


KeyboardInterrupt: 

In [373]:
correct = 0.0
total = 0.0
for data in test_loader:
    images, labels = data
    outputs = cnn(Variable(images))
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()

print('Accuracy of the network on the 10000 test images: %f' % (
    (correct / total)))

Accuracy of the network on the 10000 test images: 0.629900
