In [1]:
import torch
import torchvision
import numpy as np
import sys

In [2]:
batch_size =256
num_workers = 4
mnist_train = torchvision.datasets.FashionMNIST(root='./Datasets',
               train=True,download=True,transform=torchvision.transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='./Datasets',
               train=False,download=True,transform=torchvision.transforms.ToTensor())
train_iter = torch.utils.data.DataLoader(mnist_train,batch_size=batch_size,shuffle=True,num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(mnist_test,batch_size=batch_size,shuffle=True,num_workers=num_workers)
print(mnist_train[0][0].size())  #公司自动下载的是对的，家里自动下载的是错的，所以网上自己找了资源放在相同的地方

torch.Size([1, 28, 28])


In [3]:
num_inputs = 784
num_outputs = 10

w = torch.randn(num_inputs,num_outputs)
#w = torch.FloatTensor(np.random.normal(0,1,(num_inputs,num_outputs)))
b = torch.zeros(num_outputs)
print(w.dtype,b.dtype)
print(w.size(),b.size())
w.requires_grad_(requires_grad=True)
b.requires_grad_(requires_grad=True)
params = [w,b]
print(params[0].data.size())
print(w.grad)

torch.float32 torch.float32
torch.Size([784, 10]) torch.Size([10])
torch.Size([784, 10])
None


In [4]:
X = torch.tensor([[1,2,3],[4,5,6]])
print(X.size())
print(X.sum(0,keepdim=True).size())
print(X.sum(1,keepdim=True).size())

torch.Size([2, 3])
torch.Size([1, 3])
torch.Size([2, 1])


In [5]:
def softmax(X):
    X_exp = X.exp()  #X还要用，不能在原地计算  原地计算是exp_() 
    #=赋值在函数结束后不会改变变量值，exp_会彻底改变
    partition = X_exp.sum(dim=1,keepdim=True)
    #print(X_exp.size(),partition.size())
    return X_exp/partition   #广播
X = torch.rand(2, 5)
X_prob = softmax(X)
print(X_prob, X_prob.sum(dim=1))
print(X_prob.sum())

tensor([[0.2424, 0.2582, 0.1414, 0.1493, 0.2087],
        [0.1984, 0.1437, 0.1494, 0.2555, 0.2530]]) tensor([1., 1.])
tensor(2.0000)


In [6]:
def net(X):
    return softmax(torch.mm(X.view(-1,num_inputs),w)+b)
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y = torch.LongTensor([0, 2])
#print(y.size(),y.view(-1,1).size())
y_hat.gather(1, y.view(-1, 1))
#print((y_hat.argmax(dim=1)==y).float().mean().item())

def cross_entropy(y_hat,y):
    return -torch.log(y_hat.gather(1,y.view(-1,1)))

def accuracy(y_hat,y):
    return (y_hat.argmax(dim=1)==y).float().mean().item()

def evaluate_accuracy(data_iter,net):
    acc_sum,n=0.0,0
    for X,y in data_iter:
        acc_sum += (net(X).argmax(dim=1)==y).float().sum().item()
        n +=y.shape[0]
        print(acc_sum,n)
        break
    return acc_sum/n
#print(evaluate_accuracy(train_iter,net))

def sgd(params,lr,batch_size):
    for param in params:
        param.data -= lr *param.grad/batch_size  #tensor.data就是数据

In [7]:
epochs = 10
lr = 0.1
for epoch in range(1,epochs+1):
    train_l_sum,train_acc_sum,n=0.0,0.0,0
    for X,y in train_iter:
        y_hat = net(X)
        loss = cross_entropy(y_hat,y).sum()
        loss.backward()
        sgd([w,b],lr,batch_size)
        w.grad.data.zero_()   #w.grad在backward之前为None,要先backword()
        b.grad.data.zero_()
        train_l_sum +=loss.item()
        train_acc_sum +=(y_hat.argmax(dim=1)==y).sum().item()
        n += y.shape[0]
    test_acc = evaluate_accuracy(test_iter,net)
    print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
         %(epoch,train_l_sum/n,train_acc_sum/n,test_acc))  
    #下的数据库test居然不对，而且训练越训越糟，应该是数据库的问题

148.0 256
epoch 1, loss 3.9116, train acc 0.431, test acc 0.578
162.0 256
epoch 2, loss 1.9144, train acc 0.621, test acc 0.633
176.0 256
epoch 3, loss 1.5771, train acc 0.672, test acc 0.688
167.0 256
epoch 4, loss 1.4020, train acc 0.699, test acc 0.652
193.0 256
epoch 5, loss 1.2852, train acc 0.718, test acc 0.754
196.0 256
epoch 6, loss 1.2039, train acc 0.732, test acc 0.766
194.0 256
epoch 7, loss 1.1418, train acc 0.741, test acc 0.758
191.0 256
epoch 8, loss 1.0908, train acc 0.749, test acc 0.746
189.0 256
epoch 9, loss 1.0504, train acc 0.755, test acc 0.738
193.0 256
epoch 10, loss 1.0147, train acc 0.761, test acc 0.754


In [8]:
from matplotlib import pyplot as plt
def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]
def show_fashion_mnist(images,labels):
    figs,axs = plt.subplots(1,len(images),figsize=(12,12))
    for ax,img,lbl in zip(axs,images,labels):
        ax.imshow(img.view(28,28))
        ax.set_title(lbl)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False) 

X,y = iter(train_iter).next()
X,y = iter(train_iter).next()
true_labels = get_fashion_mnist_labels(y)  #y.numpy()不是必须的
pred_labels = get_fashion_mnist_labels(net(X).argmax(dim=1))
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]

show_fashion_mnist(X[0:9],titles[0:9])    