In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms

from torch.utils.tensorboard import SummaryWriter

In [2]:
def get_num_correct(preds, labels):
    """Count the rows of ``preds`` whose arg-max equals the matching label.

    Args:
        preds: score tensor whose dimension 1 indexes the classes
            (e.g. logits of shape (N, C)).
        labels: integer class indices, one per row of ``preds``.

    Returns:
        Number of correct predictions as a plain Python int.
    """
    predicted = preds.argmax(dim=1)
    hits = predicted == labels
    return int(hits.sum())

In [3]:
class Network(torch.nn.Module):
    """Small CNN classifier for single-channel images, producing 10 logits.

    Two conv -> ReLU -> 2x2 max-pool stages are followed by three fully
    connected layers. The fc1 input size of 12*4*4 corresponds to 28x28
    inputs: 28 -conv(k=5)-> 24 -pool-> 12 -conv(k=5)-> 8 -pool-> 4.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = torch.nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)

        self.fc1 = torch.nn.Linear(in_features=12 * 4 * 4, out_features=120)
        self.fc2 = torch.nn.Linear(in_features=120, out_features=60)
        self.out = torch.nn.Linear(in_features=60, out_features=10)

    def forward(self, t):
        """Compute raw class scores for a batch of images.

        Args:
            t: tensor of shape (N, 1, H, W); the layer sizes assume H == W == 28.

        Returns:
            Tensor of shape (N, 10) with unnormalized logits. No softmax is
            applied because F.cross_entropy expects raw scores.

        Raises:
            RuntimeError: if the spatial size does not reduce to 4x4, since
                the flattened features then do not match fc1.
        """
        # Conv block 1: (N, 1, 28, 28) -> (N, 6, 24, 24) -> (N, 6, 12, 12)
        t = F.max_pool2d(F.relu(self.conv1(t)), kernel_size=2, stride=2)
        # Conv block 2: -> (N, 12, 8, 8) -> (N, 12, 4, 4)
        t = F.max_pool2d(F.relu(self.conv2(t)), kernel_size=2, stride=2)

        # Flatten every dimension except the batch one. A reshape(-1, 192)
        # would silently turn extra spatial features into extra "samples" for
        # wrongly sized inputs; flatten keeps N fixed so fc1 fails loudly.
        t = t.flatten(start_dim=1)
        t = F.relu(self.fc1(t))
        t = F.relu(self.fc2(t))

        # Output layer: logits only.
        return self.out(t)

In [4]:
# FashionMNIST training split, stored under ./data (downloaded if missing).
# Every image is passed through ToTensor before being returned.
train_set = torchvision.datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

In [11]:
from itertools import product

# Hyperparameter grid for the sweep. Key order fixes the unpacking order
# (lr, batch_size, shuffle) used here and in the training cell.
parameters = {
    'lr': [.01, .001],
    'batch_size': [10, 100, 1000],
    'shuffle': [True, False],
}
param_values = [values for values in parameters.values()]

# Preview every combination the sweep will run (2 * 3 * 2 = 12 rows).
for lr, batch_size, shuffle in product(*param_values):
    print(lr, batch_size, shuffle)

0.01 10 True
0.01 10 False
0.01 100 True
0.01 100 False
0.01 1000 True
0.01 1000 False
0.001 10 True
0.001 10 False
0.001 100 True
0.001 100 False
0.001 1000 True
0.001 1000 False


In [14]:
# Hyperparameter sweep: train a fresh Network for every (lr, batch_size,
# shuffle) combination and log each run to its own TensorBoard directory.
# Uses names from earlier cells: Network, get_num_correct, train_set,
# param_values and product.
NUM_EPOCHS = 10  # epochs per hyperparameter combination
SEED = 0  # re-seeded per run so every combination starts from identical weights

for lr, batch_size, shuffle in product(*param_values):
    torch.manual_seed(SEED)
    network = Network()

    comment = f' batch_size={batch_size} lr={lr} shuffle={shuffle}'
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=shuffle)
    optimizer = torch.optim.Adam(network.parameters(), lr=lr)

    # One sample batch for the logged image grid and for tracing the model graph.
    images, labels = next(iter(train_loader))
    grid = torchvision.utils.make_grid(images)

    # The context manager closes (and flushes) the writer even if training
    # raises; a bare tb.close() after the loop would be skipped in that case.
    with SummaryWriter(comment=comment) as tb:
        tb.add_image('images', grid)
        tb.add_graph(network, images)

        for epoch in range(NUM_EPOCHS):
            total_loss = 0
            total_correct = 0

            for batch in train_loader:
                images, labels = batch
                preds = network(images)
                loss = F.cross_entropy(preds, labels)  # compute the loss
                optimizer.zero_grad()
                loss.backward()   # compute gradients
                optimizer.step()  # update weights

                # cross_entropy returns the batch mean by default; multiply by
                # the batch size so total_loss is a sum over samples.
                total_loss += loss.item() * images.shape[0]
                total_correct += get_num_correct(preds, labels)

            tb.add_scalar('Loss', total_loss, epoch)
            tb.add_scalar('Number Correct', total_correct, epoch)
            tb.add_scalar('Accuracy', total_correct / len(train_set), epoch)

            for name, weight in network.named_parameters():
                tb.add_histogram(name, weight, epoch)
                # .grad holds the last batch's gradient; skip parameters that
                # have none (grad is None) instead of passing None to the writer.
                if weight.grad is not None:
                    tb.add_histogram(f'{name}.grad', weight.grad, epoch)

            print('epochs:', epoch, 'total_loss:', total_loss, 'total_correct:', total_correct)

epochs: 0 total_loss: 38062.21141083166 total_correct: 45583
epochs: 1 total_loss: 32667.050301930867 total_correct: 47880
epochs: 2 total_loss: 32007.768761581865 total_correct: 48360
epochs: 3 total_loss: 31746.64240897866 total_correct: 48598
epochs: 4 total_loss: 32250.7545022259 total_correct: 48318
epochs: 5 total_loss: 31581.662222502055 total_correct: 48695
epochs: 6 total_loss: 30980.29316665721 total_correct: 48988
epochs: 7 total_loss: 31868.901458398905 total_correct: 48721
epochs: 8 total_loss: 32853.47199866548 total_correct: 48569
epochs: 9 total_loss: 30540.147172547877 total_correct: 49057
epochs: 0 total_loss: 36822.03358098632 total_correct: 46693
epochs: 1 total_loss: 31436.656772602946 total_correct: 48924
epochs: 2 total_loss: 33009.11470523104 total_correct: 48510
epochs: 3 total_loss: 32063.661243410315 total_correct: 49115
epochs: 4 total_loss: 31011.099695358425 total_correct: 49327
epochs: 5 total_loss: 32731.78260494955 total_correct: 48590
epochs: 6 total_l