# MNIST RNN test (self-attention variant: each image is read as a sequence of 28 rows)

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data import sampler
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.datasets as dset
import torchvision.transforms as T
import numpy as np
import timeit
from torchnet import meter
%matplotlib inline

import torch.nn.functional as F

# 载入数据

In [5]:
# Data location and batch size are named so they can be tuned in one place.
DATA_DIR = r'../MNIST'
BATCH_SIZE = 512

set_train = dset.MNIST(DATA_DIR, train=True, transform=T.ToTensor(), download=False)
# Reshuffle the training set every epoch; without this every epoch sees the
# exact same batch order, which hurts SGD-style optimisation.
loader_train = DataLoader(set_train, batch_size=BATCH_SIZE, shuffle=True)
set_test = dset.MNIST(DATA_DIR, train=False, transform=T.ToTensor(), download=False)
# Evaluation order does not matter, so the test loader keeps a fixed order.
loader_test = DataLoader(set_test, batch_size=BATCH_SIZE)

### 看下大小

In [6]:
# `.data` / `.targets` are the current torchvision MNIST attributes; the old
# `train_data` / `test_data` / `train_labels` / `test_labels` names are deprecated aliases.
print("训练集大小：", set_train.data.size())
print("训练集标签：", set_train.targets.size())
print("测试集大小：", set_test.data.size())
print("测试集标签：", set_test.targets.size())
type(set_train[0])  # indexing the dataset returns a tuple (image, label)

训练集大小： torch.Size([60000, 28, 28])
训练集标签： torch.Size([60000])
测试集大小： torch.Size([10000, 28, 28])
测试集标签： torch.Size([10000])


tuple

# SelfAttn

In [16]:
class SelfAttn(nn.Module):
    """
    Parameter-free self-attention layer: softmax(x @ x^T [/ sqrt(dim)]) @ x.

    Input x: (batch, seq, dim); the output has the same shape. Each output
    step is a softmax-weighted combination of the input steps, weighted by
    their dot-product similarity (softmax over the last axis).

    Args:
        scaled: if True, divide the scores by sqrt(dim) before the softmax
            (scaled dot-product attention). Unscaled scores grow with ``dim``
            and push the softmax towards one-hot weights. Defaults to False
            so existing models keep the original, unscaled behaviour.
    """

    def __init__(self, scaled=False):
        super(SelfAttn, self).__init__()
        self.scaled = scaled

    def forward(self, x):
        # (batch, seq, dim) @ (batch, dim, seq) -> (batch, seq, seq) similarity scores
        w = torch.bmm(x, x.permute(0, 2, 1))
        if self.scaled:
            w = w / (x.size(-1) ** 0.5)
        w = F.softmax(w, dim=2)
        return torch.bmm(w, x)
    
self_attn = SelfAttn()
# Build a float tensor explicitly: torch.arange(9) is int64 in PyTorch >= 0.4,
# and softmax is not implemented for integer tensors.
x = torch.arange(9, dtype=torch.float32).view(1, 3, 3)
print(x)
# Plain tensors are autograd-aware; no Variable wrapper is needed.
self_attn(x)


(0 ,.,.) = 
  0  1  2
  3  4  5
  6  7  8
[torch.FloatTensor of size 1x3x3]



Variable containing:
(0 ,.,.) = 
  5.9996  6.9996  7.9996
  6.0000  7.0000  8.0000
  6.0000  7.0000  8.0000
[torch.FloatTensor of size (1,3,3)]

In [17]:
class LayerNormalization(nn.Module):
    '''
    Layer normalization module copied from a Transformer implementation.

    Normalizes over the last dimension:
        out = (z - mean) / (std + eps) * a_2 + b_2
    where mean/std are taken over the last dim (torch.std's default, i.e.
    the unbiased estimator) and a_2 / b_2 are learnable per-feature gain / bias.

    Expected input shape: [batch_size, seq_len, channel] with channel == d_hid;
    any shape whose last dimension equals d_hid broadcasts correctly.
    '''

    def __init__(self, d_hid, eps=1e-3):
        super(LayerNormalization, self).__init__()

        self.eps = eps
        self.a_2 = nn.Parameter(torch.ones(d_hid), requires_grad=True)
        self.b_2 = nn.Parameter(torch.zeros(d_hid), requires_grad=True)

    def forward(self, z):
        # A single feature has nothing to normalize (and the unbiased std of one
        # element is NaN), so pass it through. The check must be on the
        # normalized (last) axis: testing dim 1 (seq_len) would wrongly skip
        # normalization for length-1 sequences.
        if z.size(-1) == 1:
            return z

        mu = z.mean(dim=-1, keepdim=True)
        sigma = z.std(dim=-1, keepdim=True)
        ln_out = (z - mu) / (sigma + self.eps)
        # a_2 / b_2 have shape (d_hid,) and broadcast over the leading axes.
        return ln_out * self.a_2 + self.b_2
    
class inConv(nn.Module):
    """Swap the last two axes of a 3-D tensor: (batch, seq, dim) -> (batch, dim, seq).

    Placed before nn.Conv1d in Net, since Conv1d expects channels on axis 1.
    """

    def forward(self, x):
        return x.permute((0, 2, 1))
    
class outConv(nn.Module):
    """Swap the last two axes back: (batch, dim, seq) -> (batch, seq, dim).

    Placed after nn.Conv1d in Net to restore the (batch, seq, dim) layout.
    """

    def forward(self, x):
        return x.permute((0, 2, 1))

# 构造模型

In [18]:
class Net(nn.Module):
    """Self-attention classifier for images read as sequences of rows.

    Input x: (batch, seq_len, in_dim), e.g. (batch, 28, 28) MNIST images.
    ``n_layers`` blocks of LayerNorm -> pointwise Conv1d projection -> SelfAttn
    are stacked, then the sequence is flattened and classified by a linear layer.

    Args:
        h_dim: width produced by every Conv1d projection.
        n_layers: number of LayerNorm/Conv1d/SelfAttn blocks (default 1).
        seq_len: sequence length, used to size the final Linear layer.
        in_dim: feature size of each input step.
        n_classes: number of output classes.
        dropout: dropout probability applied before the classifier.
    """
    def __init__(self, h_dim = 256, n_layers=1, seq_len=28, in_dim=28,
                 n_classes=10, dropout=0.1):
        super(Net, self).__init__()
        if n_layers < 1:
            raise ValueError('n_layers must be >= 1, got %d' % n_layers)
        layers = []
        d_in = in_dim
        for _ in range(n_layers):
            layers += [
                LayerNormalization(d_in),
                inConv(),                  # (batch, seq, d) -> (batch, d, seq) for Conv1d
                nn.Conv1d(d_in, h_dim, 1),
                outConv(),                 # back to (batch, seq, h_dim)
                SelfAttn(),
            ]
            d_in = h_dim
        self.subnet = nn.Sequential(*layers)
        self.fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(seq_len * h_dim, n_classes),
        )

    def forward(self, x):
        n = x.size(0)
        out = self.subnet(x)            # (batch, seq_len, h_dim)
        out = out.view(n, -1)           # (batch, seq_len * h_dim)
        return self.fc(out)             # (batch, n_classes) scores
    

# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = Net().to(device)

# Shape smoke test: a random (batch, seq, dim) input must yield (batch, 10) scores.
x = torch.randn(64, 28, 28, device=device)

net(x).shape

torch.Size([64, 10])

# 辅助函数：参数重置、训练与评估 (helpers: weight reset, training, evaluation)

In [19]:
# This is a little utility that we'll use to reset the model
# if we want to re-initialize all our parameters
def reset(m):
    """Re-initialise module ``m`` in place if it defines ``reset_parameters``.

    Intended for ``model.apply(reset)``; modules without that method are left untouched.
    """
    if not hasattr(m, 'reset_parameters'):
        return
    m.reset_parameters()
        
class Flatten(nn.Module):
    """Flatten every axis except the batch axis: (N, ...) -> (N, prod(...)).

    Works for any input rank >= 1, e.g. (N, C, H, W) images or (N, seq, dim)
    sequences. ``reshape`` also accepts non-contiguous input, copying only
    when it must.
    """
    def forward(self, x):
        return x.reshape(x.size(0), -1)
    
def train(model, loss_fn, optimizer, num_epochs = 5, print_every = 200,
          train_loader=None, test_loader=None):
    """Train ``model`` for ``num_epochs`` passes over the training data.

    At the start of every epoch the current test accuracy is printed via
    ``check_accuracy`` (so epoch 1 reports the untrained model). Every
    ``print_every`` batches, the running mean loss of the current epoch is printed.

    Args:
        model: network called on the image batch with the channel axis squeezed.
        loss_fn: criterion called as ``loss_fn(scores, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over the training loader.
        print_every: print period, in batches.
        train_loader: (images, labels) batches to train on; defaults to the
            global ``loader_train``.
        test_loader: loader for the per-epoch accuracy check; defaults to
            the global ``loader_test``.
    """
    if train_loader is None:
        train_loader = loader_train
    if test_loader is None:
        test_loader = loader_test
    # Send batches to wherever the model lives instead of hard-coding CUDA.
    device = next(model.parameters()).device
    for epoch in range(num_epochs):
        print('Starting epoch %d / %d' % (epoch + 1, num_epochs))
        check_accuracy(model, test_loader)
        model.train()
        loss_meter = meter.AverageValueMeter()
        for t, (x, y) in enumerate(train_loader):
            # squeeze(1) drops only the channel axis: (B, 1, 28, 28) -> (B, 28, 28).
            # A bare squeeze() would also drop the batch axis when B == 1.
            x = x.to(device).squeeze(1)
            y = y.to(device).long()
            scores = model(x)

            loss = loss_fn(scores, y)
            # loss is a 0-dim tensor: .item() replaces the removed loss.data[0] idiom.
            loss_meter.add(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (t + 1) % print_every == 0:
                print("t:{}:loss:{:.3}".format(t, loss_meter.value()[0]))

def check_accuracy(model, loader):
    """Print and return the classification accuracy of ``model`` on ``loader``.

    The model is switched to eval mode for the measurement (so Dropout is off)
    and restored to its previous train/eval mode afterwards.

    Args:
        model: network returning class scores of shape (batch, n_classes).
        loader: iterable of (images, labels) batches whose ``dataset`` has a
            ``train`` flag (as torchvision's MNIST does); the flag only selects
            the printed message.

    Returns:
        Accuracy as a fraction in [0, 1], or NaN if the loader yields no samples.
    """
    if loader.dataset.train:
        print('Checking accuracy on validation set')
    else:
        print('Checking accuracy on test set')
    device = next(model.parameters()).device
    was_training = model.training
    model.eval()
    num_correct = 0
    num_samples = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                # Drop only the channel axis so a batch of size 1 keeps its batch axis.
                scores = model(x.to(device).squeeze(1))
                preds = scores.argmax(dim=1).cpu()
                num_correct += (preds == y).sum().item()
                num_samples += preds.size(0)
    finally:
        model.train(was_training)
    if num_samples == 0:
        print('No samples to evaluate')
        return float('nan')
    acc = float(num_correct) / num_samples
    print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))
    return acc

# 训练

In [20]:
# CrossEntropyLoss holds no parameters, so it does not need to be moved to a device.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
# Alternative optimizer: torch.optim.SGD(net.parameters(), lr=1e-3, momentum=0.9)

# For reproducible runs, seed before building the model: torch.manual_seed(123)
# To re-initialise the weights, call apply on the instance: net.apply(reset)
train(net, loss_fn, optimizer, num_epochs=20)
check_accuracy(net, loader_test)

Starting epoch 1 / 20
Checking accuracy on test set


Got 1416 / 10000 correct (14.16)


Starting epoch 2 / 20
Checking accuracy on test set


Got 9077 / 10000 correct (90.77)


Starting epoch 3 / 20
Checking accuracy on test set


Got 9209 / 10000 correct (92.09)


Starting epoch 4 / 20
Checking accuracy on test set


Got 9276 / 10000 correct (92.76)


Starting epoch 5 / 20
Checking accuracy on test set


Got 9327 / 10000 correct (93.27)


Starting epoch 6 / 20
Checking accuracy on test set


Got 9349 / 10000 correct (93.49)


Starting epoch 7 / 20
Checking accuracy on test set


Got 9351 / 10000 correct (93.51)


Starting epoch 8 / 20
Checking accuracy on test set


Got 9347 / 10000 correct (93.47)


Starting epoch 9 / 20
Checking accuracy on test set


Got 9350 / 10000 correct (93.50)


Starting epoch 10 / 20
Checking accuracy on test set


Got 9356 / 10000 correct (93.56)


Starting epoch 11 / 20
Checking accuracy on test set


Got 9346 / 10000 correct (93.46)


Starting epoch 12 / 20
Checking accuracy on test set


Got 9345 / 10000 correct (93.45)


Starting epoch 13 / 20
Checking accuracy on test set


Got 9357 / 10000 correct (93.57)


Starting epoch 14 / 20
Checking accuracy on test set


Got 9361 / 10000 correct (93.61)


Starting epoch 15 / 20
Checking accuracy on test set


Got 9345 / 10000 correct (93.45)


Starting epoch 16 / 20
Checking accuracy on test set


Got 9350 / 10000 correct (93.50)


Starting epoch 17 / 20
Checking accuracy on test set


Got 9357 / 10000 correct (93.57)


Starting epoch 18 / 20
Checking accuracy on test set


Got 9345 / 10000 correct (93.45)


Starting epoch 19 / 20
Checking accuracy on test set


Got 9351 / 10000 correct (93.51)


Starting epoch 20 / 20
Checking accuracy on test set


Got 9347 / 10000 correct (93.47)


Checking accuracy on test set


Got 9347 / 10000 correct (93.47)


In [None]:
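# attention始终达不到很好的效果 — conclusion: this single self-attention block plateaus at ~93.5% test accuracy after about 5 epochs of training (see the log above).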