In [3]:
import torch
import torch.nn.functional as F
from torch.autograd import Variable

x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1) torch 不能处理一维数据，因此要用torch.unsqueeze

x = Variable(x) #torch运算都要用Variable


class Net(torch.nn.Module): #继承主模块):
    def __init__(self, n_features, n_hidden, n_output):   #搭建层所需要的信息, 每层信息,参数加下面层的层数信息
        super(Net, self).__init__()  #继承torch.nn.Module，官方操作
        self.hidden = torch.nn.Linear(n_features, n_hidden)   #self.hidden点后面为层名字，后面为输入输出维度
        self.prediction = torch.nn.Linear(n_hidden, n_output)
    def forward(self, x):  #前向传播过程 x参数为输入
        x = F.relu(self.hidden(x))      #定义输入前向传播过程  F.relu由F定义的relu是个功能，不会在打印层中显示
        x = self.prediction(x)
        return x     #得到最终返回值

net = Net(1, 10, 1)  #输入参数为Net类所规定的参数（__init__()中定义的输入参数）
print(net)   #打印网络层

#优化器
optmizer = torch.optm.SGD(net.parameters(), lr=0.5)  #定义优化器，优化的是net的parameters
loss_func = torch.nn.MseLoss()   #定义怎么计算误差

for t in range(100):    #迭代
    prediction = net(x)
    
    loss = loss_func(prediction, y) #要预测值在前，label在后
    '''
    优化步骤
    '''
    optimizer.zero_grad()  #net.parameters()所有参数梯度变为0
    loss.backward() #所有参数计算梯度
    optimizer.step() #optimizr优化parameters
    

Net(
  (hidden): Linear(in_features=1, out_features=10, bias=True)
  (prediction): Linear(in_features=10, out_features=1, bias=True)
)


In [5]:
#分类
# 假数据
n_data = torch.ones(100, 2)         # 数据的基本形态
x0 = torch.normal(2*n_data, 1)      # 类型0 x data (tensor), shape=(100, 2)
y0 = torch.zeros(100)               # 类型0 y data (tensor), shape=(100, )
x1 = torch.normal(-2*n_data, 1)     # 类型1 x data (tensor), shape=(100, 1)
y1 = torch.ones(100)                # 类型1 y data (tensor), shape=(100, )

# 注意 x, y 数据的数据形式是一定要像下面一样 (torch.cat 是在合并数据)
x = torch.cat((x0, x1), 0).type(torch.FloatTensor)  # FloatTensor = 32-bit floating
y = torch.cat((y0, y1), ).type(torch.LongTensor)    # LongTensor = 64-bit integer 
#注意分类问题中一定要将标签值类型设为longtensor
#x为floattensor

net = Net(2, 10, 2)  #假如是2分类问题，输出为对应类别维度(相当于one-hot)
print(net)
loss_func = torch.nn.CrossEntropyLoss()   #定义怎么计算误差,CrossEntropyLoss在计算softmax，即每一类概率

'''
比如标签值[0,0,1], loss值【0.1，0.1，0.8】
'''


Net(
  (hidden): Linear(in_features=2, out_features=10, bias=True)
  (prediction): Linear(in_features=10, out_features=2, bias=True)
)


'\n比如标签值[0,0,1], loss值【0.1，0.1，0.8】\n'

In [7]:
#第二种搭建网络方法,这种方法没有层的自定义名字
net2 = torch.nn.Sequential(
    torch.nn.Linear(2, 10),
    torch.nn.ReLU(),   # torch.nn.ReLU()这种定义方法定义的是一个类，因此在打印层的时候会出现
    torch.nn.Linear(10,2))
print(net2)

Sequential(
  (0): Linear(in_features=2, out_features=10, bias=True)
  (1): ReLU()
  (2): Linear(in_features=10, out_features=2, bias=True)
)


In [8]:
#保存
torch.save(net, 'net.pkl')  #完整保存.第一个参数是要保存的网络，第二个是名字
torch.save(net.state_dict(), 'net_parameters.pkl') #只保留整个parameters

#提取网络
net = torch.load('net.pkl')

#若保存的为参数，则提取过程为先定义网络结构
net2 = torch.nn.Sequential(
    torch.nn.Linear(2, 10),
    torch.nn.ReLU(),   # torch.nn.ReLU()这种定义方法定义的是一个类，因此在打印层的时候会出现
    torch.nn.Linear(10,2))
net2.load_state_dict(torch.load('net_parameters.pkl'))

  "type " + obj.__name__ + ". It won't be checked "


RuntimeError: Error(s) in loading state_dict for Sequential:
	Missing key(s) in state_dict: "0.weight", "0.bias", "2.weight", "2.bias". 
	Unexpected key(s) in state_dict: "hidden.weight", "hidden.bias", "prediction.weight", "prediction.bias". 

In [16]:
#批训练
import torch.utils.data as Data
torch.manual_seed(1)   #种子

BATCH_SIZE = 5
x = torch.linspace(1,10,10)
y = torch.linspace(10,1,10)

'''
形成tersor数据库
x 放特征数据
y 放标签数据
'''
torch_dataset = Data.TensorDataset(x, y) #形成tersor数据库

'''
loader 使训练变成一批一批
'''
loader = Data.DataLoader(
        dataset=torch_dataset,    #传进去数据库
        batch_size=BATCH_SIZE,    #多少批
        shuffle=True,             #是否打乱
        num_workers=2)        #多线程提取数据

for epoch in range(3):   #每一轮
    for step, (batch_x, batch_y) in enumerate(loader):   #每一批一批
        b_x = Variable(batch_x)    #转为Variable
        b_y = Variable(batch_y)
        #training.....   
        pass


In [23]:
import torch.nn as nn
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(  # input shape (1, 28, 28)
            nn.Conv2d(
                in_channels=1,      # input height
                out_channels=16,    # n_filters
                kernel_size=5,      # filter size
                stride=1,           # filter movement/step
                padding=2,      # 如果想要 con2d 出来的图片长宽没有变化, padding=(kernel_size-1)/2 当 stride=1
            ),      # output shape (16, 28, 28)
            nn.ReLU(),    # activation
            nn.MaxPool2d(kernel_size=2),    # 在 2x2 空间里向下采样, output shape (16, 14, 14)
        )
        self.conv2 = nn.Sequential(  # input shape (16, 14, 14)
            nn.Conv2d(16, 32, 5, 1, 2),  # output shape (32, 14, 14)
            nn.ReLU(),  # activation
            nn.MaxPool2d(2),  # output shape (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 10)   # fully connected layer, output 10 classes

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)     #维度（batch，32，7，7）

       # 下一步要自己进行展平，之后才能传入到Linear层
        x = x.view(x.size(0), -1)   # 展平多维的卷积图成 (batch_size, 32 * 7 * 7)
        output = self.out(x)
        return output
net = CNN()
print(net)


CNN(
  (conv1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (out): Linear(in_features=1568, out_features=10, bias=True)
)


In [28]:
import torch
from torch import nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torchvision


torch.manual_seed(1)    # reproducible

# Hyper Parameters
EPOCH = 1           # 训练整批数据多少次, 为了节约时间, 我们只训练一次
BATCH_SIZE = 64
TIME_STEP = 28      # rnn 时间步数 / 图片高度  ！考虑多少个时间点
INPUT_SIZE = 28     # rnn 每步输入值 / 图片每行像素    ！每个时间点传入多少数据
LR = 0.01           # learning rate
DOWNLOAD_MNIST = True  # 如果你已经下载好了mnist数据就写上 Fasle


# Mnist 手写数字
train_data = torchvision.datasets.MNIST(
    root='./mnist/',    # 保存或者提取位置
    train=True,  # this is training data
    transform=torchvision.transforms.ToTensor(),    # 转换 PIL.Image or numpy.ndarray 成
                                                    # torch.FloatTensor (C x H x W), 训练的时候 normalize 成 [0.0, 1.0] 区间
    download=DOWNLOAD_MNIST,          # 没下载就下载, 下载了就不用再下了
)

Using downloaded and verified file: ./mnist/MNIST/raw/train-images-idx3-ubyte.gz
Extracting ./mnist/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist/MNIST/raw
Using downloaded and verified file: ./mnist/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./mnist/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist/MNIST/raw


EOFError: Compressed file ended before the end-of-stream marker was reached

In [29]:
# Mnist 手写数字
train_data = torchvision.datasets.MNIST(
    root='./mnist_/',    # 保存或者提取位置
    train=True,  # this is training data
    transform=torchvision.transforms.ToTensor(),    # 转换 PIL.Image or numpy.ndarray 成
                                                    # torch.FloatTensor (C x H x W), 训练的时候 normalize 成 [0.0, 1.0] 区间
    download=DOWNLOAD_MNIST,          # 没下载就下载, 下载了就不用再下了
)



0it [00:00, ?it/s][A[A

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./mnist_/MNIST/raw/train-images-idx3-ubyte.gz




  0%|          | 0/9912422 [00:00<?, ?it/s][A[A

  0%|          | 16384/9912422 [00:01<07:16, 22663.52it/s][A[A

  0%|          | 40960/9912422 [00:01<05:46, 28464.83it/s][A[A

  1%|          | 57344/9912422 [00:01<04:52, 33665.67it/s][A[A

  1%|          | 65536/9912422 [00:02<08:36, 19055.35it/s][A[A

  1%|          | 90112/9912422 [00:03<07:24, 22077.32it/s][A[A

  1%|          | 106496/9912422 [00:04<07:04, 23093.92it/s][A[A

  1%|          | 114688/9912422 [00:04<07:08, 22858.62it/s][A[A

  1%|          | 122880/9912422 [00:06<15:51, 10287.89it/s][A[A

  1%|▏         | 131072/9912422 [00:08<25:49, 6312.17it/s] [A[A

  1%|▏         | 139264/9912422 [00:09<23:53, 6817.08it/s][A[A

  1%|▏         | 147456/9912422 [00:11<27:05, 6008.63it/s][A[A

  2%|▏         | 155648/9912422 [00:13<31:22, 5181.59it/s][A[A

  2%|▏         | 163840/9912422 [00:17<45:41, 3556.45it/s][A[A

  2%|▏         | 172032/9912422 [00:19<45:21, 3579.17it/s][A[A

  2%|▏         | 18

 11%|█         | 1073152/9912422 [03:00<17:32, 8396.90it/s][A[A

 11%|█         | 1081344/9912422 [03:01<18:31, 7946.81it/s][A[A

 11%|█         | 1089536/9912422 [03:02<17:44, 8287.78it/s][A[A

 11%|█         | 1097728/9912422 [03:04<19:44, 7439.76it/s][A[A

 11%|█         | 1105920/9912422 [03:04<17:07, 8574.36it/s][A[A

 11%|█         | 1114112/9912422 [03:05<18:00, 8140.59it/s][A[A

 11%|█▏        | 1122304/9912422 [03:06<16:07, 9084.66it/s][A[A

 11%|█▏        | 1130496/9912422 [03:08<23:29, 6232.49it/s][A[A

 11%|█▏        | 1138688/9912422 [03:09<21:44, 6723.43it/s][A[A

 12%|█▏        | 1146880/9912422 [03:14<39:05, 3736.72it/s][A[A

 12%|█▏        | 1155072/9912422 [03:14<30:28, 4789.48it/s][A[A

 12%|█▏        | 1163264/9912422 [03:15<26:56, 5410.95it/s][A[A

 12%|█▏        | 1171456/9912422 [03:16<22:07, 6585.51it/s][A[A

 12%|█▏        | 1179648/9912422 [03:18<24:15, 6000.14it/s][A[A

 12%|█▏        | 1187840/9912422 [03:19<22:09, 6564.17it/s][A

 21%|██▏       | 2129920/9912422 [05:53<15:47, 8212.31it/s][A[A

 22%|██▏       | 2138112/9912422 [05:54<13:43, 9445.19it/s][A[A

 22%|██▏       | 2146304/9912422 [05:54<12:45, 10147.50it/s][A[A

 22%|██▏       | 2154496/9912422 [05:57<20:04, 6441.02it/s] [A[A

 22%|██▏       | 2162688/9912422 [05:59<25:40, 5031.48it/s][A[A

 22%|██▏       | 2170880/9912422 [06:00<24:45, 5212.80it/s][A[A

 22%|██▏       | 2179072/9912422 [06:04<34:43, 3711.48it/s][A[A

 22%|██▏       | 2187264/9912422 [06:06<34:12, 3763.27it/s][A[A

 22%|██▏       | 2195456/9912422 [06:07<26:57, 4771.68it/s][A[A

 22%|██▏       | 2203648/9912422 [06:08<21:53, 5867.47it/s][A[A

 22%|██▏       | 2211840/9912422 [06:08<17:10, 7469.11it/s][A[A

 22%|██▏       | 2220032/9912422 [06:08<13:36, 9415.42it/s][A[A

 22%|██▏       | 2228224/9912422 [06:09<14:39, 8739.94it/s][A[A

 23%|██▎       | 2244608/9912422 [06:10<11:37, 10998.14it/s][A[A

 23%|██▎       | 2252800/9912422 [06:11<11:11, 11407.66it/s

 33%|███▎      | 3227648/9912422 [08:45<32:18, 3447.70it/s][A[A

 33%|███▎      | 3235840/9912422 [08:46<28:26, 3913.43it/s][A[A

 33%|███▎      | 3244032/9912422 [08:47<24:46, 4486.70it/s][A[A

 33%|███▎      | 3252224/9912422 [08:49<24:23, 4551.02it/s][A[A

 33%|███▎      | 3260416/9912422 [08:51<24:10, 4587.52it/s][A[A

 33%|███▎      | 3268608/9912422 [08:51<19:21, 5721.72it/s][A[A

 33%|███▎      | 3276800/9912422 [08:52<16:52, 6556.17it/s][A[A

 33%|███▎      | 3284992/9912422 [08:53<14:02, 7865.62it/s][A[A

 33%|███▎      | 3293184/9912422 [08:53<10:56, 10077.48it/s][A[A

 33%|███▎      | 3301376/9912422 [08:55<14:07, 7801.76it/s] [A[A

 33%|███▎      | 3309568/9912422 [08:55<12:39, 8696.54it/s][A[A

 33%|███▎      | 3317760/9912422 [08:56<12:02, 9131.74it/s][A[A

 34%|███▎      | 3325952/9912422 [08:57<11:07, 9874.29it/s][A[A

 34%|███▎      | 3334144/9912422 [08:58<14:20, 7646.93it/s][A[A

 34%|███▎      | 3342336/9912422 [09:03<28:49, 3799.76it/s]

 43%|████▎     | 4243456/9912422 [10:55<13:13, 7145.77it/s][A[A

 43%|████▎     | 4251648/9912422 [10:55<11:29, 8209.55it/s][A[A

 43%|████▎     | 4259840/9912422 [10:57<14:02, 6711.28it/s][A[A

 43%|████▎     | 4268032/9912422 [11:00<17:55, 5246.07it/s][A[A

 43%|████▎     | 4276224/9912422 [11:01<15:51, 5922.37it/s][A[A

 43%|████▎     | 4284416/9912422 [11:03<20:03, 4675.06it/s][A[A

 43%|████▎     | 4292608/9912422 [11:04<17:24, 5379.39it/s][A[A

 43%|████▎     | 4300800/9912422 [11:05<14:27, 6465.42it/s][A[A

 43%|████▎     | 4308992/9912422 [11:05<11:02, 8456.40it/s][A[A

 44%|████▎     | 4317184/9912422 [11:07<12:40, 7359.96it/s][A[A

 44%|████▎     | 4333568/9912422 [11:07<10:24, 8931.39it/s][A[A

 44%|████▍     | 4341760/9912422 [11:08<08:19, 11161.09it/s][A[A

 44%|████▍     | 4349952/9912422 [11:08<06:51, 13505.92it/s][A[A

 44%|████▍     | 4358144/9912422 [11:08<05:50, 15847.49it/s][A[A

 44%|████▍     | 4366336/9912422 [11:09<06:11, 14939.80it/s

 53%|█████▎    | 5267456/9912422 [12:58<06:19, 12244.36it/s][A[A

 53%|█████▎    | 5275648/9912422 [12:58<05:58, 12929.13it/s][A[A

 53%|█████▎    | 5283840/9912422 [12:59<06:09, 12528.66it/s][A[A

 53%|█████▎    | 5292032/9912422 [12:59<05:12, 14806.55it/s][A[A

 53%|█████▎    | 5300224/9912422 [13:01<09:39, 7954.52it/s] [A[A

 54%|█████▎    | 5308416/9912422 [13:02<09:19, 8227.55it/s][A[A

 54%|█████▎    | 5316608/9912422 [13:05<13:34, 5645.28it/s][A[A

 54%|█████▎    | 5324800/9912422 [13:05<10:14, 7459.57it/s][A[A

 54%|█████▍    | 5332992/9912422 [13:07<11:06, 6868.17it/s][A[A

 54%|█████▍    | 5341184/9912422 [13:08<12:59, 5862.45it/s][A[A

 54%|█████▍    | 5349376/9912422 [13:12<18:03, 4210.02it/s][A[A

 54%|█████▍    | 5357568/9912422 [13:12<14:22, 5279.38it/s][A[A

 54%|█████▍    | 5365760/9912422 [13:13<10:49, 6999.46it/s][A[A

 54%|█████▍    | 5373952/9912422 [13:14<10:23, 7284.20it/s][A[A

 54%|█████▍    | 5382144/9912422 [13:14<08:33, 8814.20it/

 63%|██████▎   | 6291456/9912422 [15:44<13:04, 4617.87it/s][A[A

 64%|██████▎   | 6299648/9912422 [15:47<15:29, 3885.96it/s][A[A

 64%|██████▎   | 6307840/9912422 [15:50<15:52, 3784.22it/s][A[A

 64%|██████▎   | 6316032/9912422 [15:51<13:11, 4544.82it/s][A[A

 64%|██████▍   | 6324224/9912422 [15:53<14:44, 4057.04it/s][A[A

 64%|██████▍   | 6332416/9912422 [15:54<11:29, 5191.69it/s][A[A

 64%|██████▍   | 6340608/9912422 [15:54<09:05, 6552.33it/s][A[A

 64%|██████▍   | 6348800/9912422 [15:55<07:53, 7520.14it/s][A[A

 64%|██████▍   | 6356992/9912422 [15:55<06:53, 8601.20it/s][A[A

 64%|██████▍   | 6365184/9912422 [15:56<05:59, 9856.87it/s][A[A

 64%|██████▍   | 6373376/9912422 [15:58<08:56, 6599.46it/s][A[A

 64%|██████▍   | 6381568/9912422 [16:00<09:57, 5907.03it/s][A[A

 64%|██████▍   | 6389760/9912422 [16:01<09:03, 6480.60it/s][A[A

 65%|██████▍   | 6397952/9912422 [16:01<06:53, 8502.42it/s][A[A

 65%|██████▍   | 6406144/9912422 [16:02<06:52, 8502.91it/s][A

 74%|███████▎  | 7290880/9912422 [18:01<04:25, 9886.74it/s] [A[A

 74%|███████▎  | 7299072/9912422 [18:02<03:57, 10985.98it/s][A[A

 74%|███████▎  | 7307264/9912422 [18:02<03:12, 13508.82it/s][A[A

 74%|███████▍  | 7315456/9912422 [18:03<03:33, 12188.01it/s][A[A

 74%|███████▍  | 7323648/9912422 [18:04<03:54, 11054.64it/s][A[A

 74%|███████▍  | 7331840/9912422 [18:05<04:09, 10341.48it/s][A[A

 74%|███████▍  | 7340032/9912422 [18:07<05:36, 7644.37it/s] [A[A

 74%|███████▍  | 7348224/9912422 [18:07<04:57, 8618.06it/s][A[A

 74%|███████▍  | 7356416/9912422 [18:08<05:14, 8125.38it/s][A[A

 74%|███████▍  | 7364608/9912422 [18:12<08:52, 4785.86it/s][A[A

 74%|███████▍  | 7372800/9912422 [18:16<12:33, 3369.17it/s][A[A

 74%|███████▍  | 7380992/9912422 [18:17<10:17, 4096.30it/s][A[A

 75%|███████▍  | 7389184/9912422 [18:18<09:41, 4342.73it/s][A[A

 75%|███████▍  | 7397376/9912422 [18:22<12:39, 3313.46it/s][A[A

 75%|███████▍  | 7405568/9912422 [18:24<11:14, 3716.79i

 84%|████████▎ | 8290304/9912422 [20:51<05:26, 4972.21it/s][A[A

 84%|████████▎ | 8298496/9912422 [20:51<04:05, 6578.21it/s][A[A

 84%|████████▍ | 8306688/9912422 [20:54<05:21, 4996.11it/s][A[A

 84%|████████▍ | 8314880/9912422 [20:55<04:20, 6122.06it/s][A[A

 84%|████████▍ | 8323072/9912422 [20:55<03:24, 7757.25it/s][A[A

 84%|████████▍ | 8331264/9912422 [20:59<06:08, 4287.99it/s][A[A

 84%|████████▍ | 8339456/9912422 [21:00<05:51, 4476.26it/s][A[A

 84%|████████▍ | 8347648/9912422 [21:02<05:13, 4986.90it/s][A[A

 84%|████████▍ | 8355840/9912422 [21:02<04:13, 6140.32it/s][A[A

 84%|████████▍ | 8364032/9912422 [21:03<03:48, 6771.70it/s][A[A

 84%|████████▍ | 8372224/9912422 [21:04<03:19, 7703.29it/s][A[A

 85%|████████▍ | 8380416/9912422 [21:09<06:45, 3779.80it/s][A[A

 85%|████████▍ | 8388608/9912422 [21:10<05:58, 4247.80it/s][A[A

 85%|████████▍ | 8396800/9912422 [21:13<06:37, 3815.58it/s][A[A

 85%|████████▍ | 8404992/9912422 [21:14<05:59, 4194.81it/s][A

 94%|█████████▍| 9322496/9912422 [23:16<02:45, 3556.67it/s][A[A

 94%|█████████▍| 9330688/9912422 [23:20<03:12, 3014.73it/s][A[A

 94%|█████████▍| 9338880/9912422 [23:20<02:28, 3863.92it/s][A[A

 94%|█████████▍| 9347072/9912422 [23:22<02:11, 4285.80it/s][A[A

 94%|█████████▍| 9355264/9912422 [23:23<01:54, 4857.18it/s][A[A

 94%|█████████▍| 9363456/9912422 [23:26<02:28, 3700.95it/s][A[A

 95%|█████████▍| 9371648/9912422 [23:27<01:50, 4896.70it/s][A[A

 95%|█████████▍| 9379840/9912422 [23:34<03:30, 2533.85it/s][A[A

 95%|█████████▍| 9388032/9912422 [23:37<03:21, 2598.00it/s][A[A

 95%|█████████▍| 9396224/9912422 [23:37<02:34, 3343.86it/s][A[A

 95%|█████████▍| 9404416/9912422 [23:38<01:54, 4455.13it/s][A[A

 95%|█████████▍| 9412608/9912422 [23:44<03:07, 2669.84it/s][A[A

 95%|█████████▌| 9420800/9912422 [23:48<03:20, 2448.21it/s][A[A

 95%|█████████▌| 9428992/9912422 [23:48<02:30, 3203.03it/s][A[A

 95%|█████████▌| 9437184/9912422 [23:49<01:51, 4252.16it/s][A

Extracting ./mnist_/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist_/MNIST/raw





0it [00:00, ?it/s][A[A[A

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./mnist_/MNIST/raw/train-labels-idx1-ubyte.gz





  0%|          | 0/28881 [00:00<?, ?it/s][A[A[A


 57%|█████▋    | 16384/28881 [00:02<00:01, 11408.56it/s][A[A[A


 85%|████████▌ | 24576/28881 [00:03<00:00, 10301.08it/s][A[A[A


32768it [00:03, 10380.05it/s]                           [A[A[A


0it [00:00, ?it/s][A[A[A

Extracting ./mnist_/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist_/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./mnist_/MNIST/raw/t10k-images-idx3-ubyte.gz





  0%|          | 0/1648877 [00:02<?, ?it/s][A[A[A


  0%|          | 8192/1648877 [00:03<02:29, 11008.84it/s][A[A[A


  1%|          | 16384/1648877 [00:05<03:24, 7995.71it/s][A[A[A


  2%|▏         | 32768/1648877 [00:06<03:02, 8845.17it/s][A[A[A


  2%|▏         | 40960/1648877 [00:06<02:25, 11066.10it/s][A[A[A


  3%|▎         | 49152/1648877 [00:07<02:23, 11183.22it/s][A[A[A


  3%|▎         | 57344/1648877 [00:07<01:57, 13524.00it/s][A[A[A


  4%|▍         | 65536/1648877 [00:08<01:58, 13388.87it/s][A[A[A


  4%|▍         | 73728/1648877 [00:08<01:41, 15487.95it/s][A[A[A


  5%|▍         | 81920/1648877 [00:10<02:36, 10000.36it/s][A[A[A

9920512it [26:09, 6668.03it/s]                             [A[A


  5%|▌         | 90112/1648877 [00:10<02:26, 10647.14it/s][A[A[A


  6%|▌         | 98304/1648877 [00:12<03:02, 8504.69it/s] [A[A[A


  6%|▋         | 106496/1648877 [00:14<04:07, 6243.16it/s][A[A[A


  7%|▋         | 114688/1648877 [00:15

 58%|█████▊    | 958464/1648877 [02:12<01:23, 8304.78it/s][A[A[A


 59%|█████▊    | 966656/1648877 [02:13<01:07, 10088.09it/s][A[A[A


 59%|█████▉    | 974848/1648877 [02:13<01:01, 11043.75it/s][A[A[A


 60%|█████▉    | 983040/1648877 [02:14<01:09, 9566.16it/s] [A[A[A


 60%|██████    | 991232/1648877 [02:16<01:19, 8220.95it/s][A[A[A


 61%|██████    | 999424/1648877 [02:16<01:11, 9063.93it/s][A[A[A


 61%|██████    | 1007616/1648877 [02:17<01:11, 8982.13it/s][A[A[A


 62%|██████▏   | 1015808/1648877 [02:19<01:33, 6762.68it/s][A[A[A


 62%|██████▏   | 1024000/1648877 [02:20<01:17, 8024.27it/s][A[A[A


 63%|██████▎   | 1032192/1648877 [02:22<01:31, 6727.80it/s][A[A[A


 63%|██████▎   | 1040384/1648877 [02:22<01:18, 7799.95it/s][A[A[A


 64%|██████▎   | 1048576/1648877 [02:23<01:06, 9045.75it/s][A[A[A


 64%|██████▍   | 1056768/1648877 [02:24<01:07, 8796.97it/s][A[A[A


 65%|██████▍   | 1064960/1648877 [02:25<01:15, 7720.55it/s][A[A[A


 65%|████

Extracting ./mnist_/MNIST/raw/t10k-images-idx3-ubyte.gz to ./mnist_/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./mnist_/MNIST/raw/t10k-labels-idx1-ubyte.gz






  0%|          | 0/4542 [00:00<?, ?it/s][A[A[A[A



8192it [00:00, 14893.46it/s]            [A[A[A[A

Extracting ./mnist_/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./mnist_/MNIST/raw
Processing...
Done!


In [None]:
#rnn用于分类
class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()

        self.rnn = nn.LSTM(     # LSTM 效果要比 nn.RNN() 好多了
            input_size=28,      # 图片每行的数据像素点   原始输入（28，28）每个时间点扫描一行，因此input_size=28
            hidden_size=64,     # rnn hidden unit    总共用多少个隐层，与常规相似
            num_layers=1,       # 有几层 RNN layers
            batch_first=True,   # input & output 会是以 batch size 为第一维度的特征集 e.g. (batch, time_step, input_size)
        )

        self.out = nn.Linear(64, 10)    # 输出层   输入为lstm的最后隐层，输出为类别

    def forward(self, x):
        # x shape (batch, time_step, input_size)
        # r_out shape (batch, time_step, output_size)
        # h_n shape (n_layers, batch, hidden_size)   LSTM 有两个 hidden states, h_n 是分线, h_c 是主线
        # h_c shape (n_layers, batch, hidden_size)
        r_out, (h_n, h_c) = self.rnn(x, None)   # None 表示 hidden state 会用全0的 state

        # 选取最后一个时间点的 r_out 输出
        # 这里 r_out[:, -1, :] 的值也是 h_n 的值
        out = self.out(r_out[:, -1, :])     #（batch，time step，input）
        return out

rnn = RNN()
print(rnn)
"""
RNN (
  (rnn): LSTM(28, 64, batch_first=True)
  (out): Linear (64 -> 10)
)
"""

In [26]:
#rnn用于回归
class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()

        self.rnn = nn.LSTM(     # LSTM 效果要比 nn.RNN() 好多了
            input_size=1,    
            hidden_size=64,     # rnn hidden unit    总共用多少个隐层，与常规相似
            num_layers=1,       # 有几层 RNN layers
            batch_first=True,   # input & output 会是以 batch size 为第一维度的特征集 e.g. (batch, time_step, input_size)
        )

        self.out = nn.Linear(64, 1)    # 输出层   输入为lstm的最后隐层，输出为类别

    def forward(self, x):
        # x shape (batch, time_step, input_size)
        # r_out shape (batch, time_step, output_size)
        # h_n shape (n_layers, batch, hidden_size)   LSTM 有两个 hidden states, h_n 是分线, h_c 是主线
        # h_c shape (n_layers, batch, hidden_size)
        r_out, (h_n, h_c) = self.rnn(x, None)   # None 表示 hidden state 会用全0的 state

        # 选取最后一个时间点的 r_out 输出
        # 这里 r_out[:, -1, :] 的值也是 h_n 的值
        out = self.out(r_out[:, -1, :])     #（batch，time step，input）
        return out

rnn = RNN()
print(rnn)

RNN(
  (rnn): LSTM(1, 64, batch_first=True)
  (out): Linear(in_features=64, out_features=1, bias=True)
)


In [None]:
#RNN
class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()

        self.rnn = nn.RNN(  # 这回一个普通的 RNN 就能胜任
            input_size=1,
            hidden_size=32,     # rnn hidden unit
            num_layers=1,       # 有几层 RNN layers
            batch_first=True,   # input & output 会是以 batch size 为第一维度的特征集 e.g. (batch, time_step, input_size)
        )
        self.out = nn.Linear(32, 1) 

    def forward(self, x, h_state):  # 因为 hidden state 是连续的, 所以我们要一直传递这一个 state
        # x (batch, time_step, input_size)
        # h_state (n_layers, batch, hidden_size)
        # r_out (batch, time_step, output_size)
        r_out, h_state = self.rnn(x, h_state)   # h_state 也要作为 RNN 的一个输入

        outs = []    # 保存所有时间点的预测值
        for time_step in range(r_out.size(1)):    # 对每一个时间点计算 output
            outs.append(self.out(r_out[:, time_step, :]))
        return torch.stack(outs, dim=1), h_state    #因为outs为list因此要将它包为tensor h_state用于下一回输入


rnn = RNN()
print(rnn
      
h_state = None #在没有训练之前，先置0

for step in range(60):
      x = Variable()
      y = Variable
      prediction, h_state = rnn(x, h_state)
      h_state = Variable(h_state.data)   # 关键，h_state要再次包进Variable里

"""
RNN (
  (rnn): RNN(1, 32, batch_first=True)
  (out): Linear (32 -> 1)
)
"""
#另一种定义方式
def forward(self, x, h_state):
    r_out, h_state = self.rnn(x, h_state)
    r_out = r_out.view(-1, 32)
    outs = self.out(r_out)
    return outs.view(-1, 32, TIME_STEP), h_state

In [None]:
#无监督学习 encoder decoder  类似bottleneck y标签值为x
#压缩和解压, 压缩后得到压缩的特征值, 再从压缩的特征值解压成原图片.

class AutoEncoder(nn.Module):
    def __init__(self):
        super(AutoEncoder, self).__init__()

        # 压缩
        self.encoder = nn.Sequential(
            nn.Linear(28*28, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 12),
            nn.Tanh(),
            nn.Linear(12, 3),   # 压缩成3个特征, 进行 3D 图像可视化
        )
        # 解压
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.Tanh(),
            nn.Linear(12, 64),
            nn.Tanh(),
            nn.Linear(64, 128),
            nn.Tanh(),
            nn.Linear(128, 28*28),
            nn.Sigmoid(),       # 激励函数让输出值在 (0, 1)
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded    #主要用encoded 即压缩完的样子

autoencoder = AutoEncoder()

optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)
loss_func = nn.MSELoss()

for epoch in range(EPOCH):
    for step, (x, b_label) in enumerate(train_loader):
        b_x = x.view(-1, 28*28)   # batch x, shape (batch, 28*28)
        b_y = x.view(-1, 28*28)   # batch y, shape (batch, 28*28)  标签y为x原始值

        encoded, decoded = autoencoder(b_x)

        loss = loss_func(decoded, b_y)      # mean square error
        optimizer.zero_grad()               # clear gradients for this training step
        loss.backward()                     # backpropagation, compute gradients
        optimizer.step()                    # apply gradients

In [None]:
#用GPU加速

#对数据进行cuda移动
x_b = Variable(x).cuda()
#转为cpu格式
pred = x_b.cpu()

#将模块移动
cnn = CNN()
cnn.cuda()



In [None]:
#Batch Normalization 添加在全连接和激励函数之间，将数据分布进行标准化，更好发挥激活函数作用