In [1]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from logger import Logger

# Configure device: run on a CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed PyTorch's RNGs (CPU and CUDA) so weight initialisation and DataLoader shuffling
# are repeatable on Restart & Run All. Change SEED to draw a different run.
SEED = 0
torch.manual_seed(SEED)

In [2]:
# MNIST training split, downloaded into ../data on first use; each image is
# converted to a tensor by transforms.ToTensor().
dataset = torchvision.datasets.MNIST(
    root='../data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# Mini-batches of 100 samples, reshuffled every epoch.
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    shuffle=True,
    batch_size=100,
)

In [3]:
class NeuralNet(nn.Module):
    """Two-layer fully connected classifier: Linear -> ReLU -> Linear.

    Args:
        input_size: number of input features per sample (784 matches the
            flattened 28x28 images fed in by this notebook's training loop).
        hidden_size: width of the hidden layer.
        num_classes: number of output scores produced per sample.
    """

    def __init__(self, input_size=784, hidden_size=500, num_classes=10):
        super().__init__()
        # Keep the attribute names fc1 / fc2: the training loop logs parameter
        # histograms under the names reported by named_parameters().
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return raw (un-softmaxed) class scores for the feature vectors in x."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
    
# Build the MLP and move its parameters onto the configured device.
model = NeuralNet().to(device)

# Project-local TensorBoard logger (see logger.py); './logs' is the log directory it is given.
logger = Logger('./logs')

# Loss and optimizer: cross-entropy over the model's raw scores, Adam with lr = 1e-5.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

W0829 11:31:42.990088 140145486227200 deprecation_wrapper.py:119] From /hpc/home/ephyan/pytorch-tutorial/04_tensorboard/logger.py:12: The name tf.summary.FileWriter is deprecated. Please use tf.compat.v1.summary.FileWriter instead.



In [4]:
# Train the model.
# TensorBoard summaries are indexed by a global step (epoch * batches_per_epoch + i + 1).
# Using the per-epoch batch index i+1 would restart at 100 every epoch, so each epoch's
# points would land on top of the previous ones and the curves would zig-zag back and forth.
num_epoch = 20
total_step = len(data_loader)  # number of batches per epoch
for epoch in range(num_epoch):
    for i, (images, labels) in enumerate(data_loader):
        # Flatten each 28x28 image into a 784-feature vector for the fully connected model.
        images = images.reshape(-1, 784).to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            # Accuracy of this mini-batch. It is only reported here, so it is computed
            # here rather than on every step. torch.max(outputs, 1) already yields a
            # 1-D tensor of predicted class indices, so no squeeze() is needed.
            _, predicted = torch.max(outputs, 1)
            accuracy = (predicted == labels).float().mean()

            # Monotonically increasing x-axis value shared by every summary below.
            global_step = epoch * total_step + i + 1

            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Acc: {:.2f}' 
                   .format(epoch+1, num_epoch, i+1, total_step, loss.item(), accuracy.item()))

            # ================================================================== #
            #                        Tensorboard Logging                         #
            # ================================================================== #

            # 1. Log scalar values
            info = {'loss': loss.item(), 'accuracy': accuracy.item()}

            for tag, value in info.items():
                logger.scalar_summary(tag, value, global_step)

            # 2. Log values and gradients of the parameters (histogram summary)
            for tag, value in model.named_parameters():
                tag = tag.replace('.', '/')
                logger.histo_summary(tag, value.detach().cpu().numpy(), global_step)
                # A parameter that did not take part in the backward pass has no grad.
                if value.grad is not None:
                    logger.histo_summary(tag + '/grad', value.grad.detach().cpu().numpy(), global_step)

            # 3. Log training images (first 10 of the batch, restored to 28x28).
            #    Use a separate loop name so the batch tensor `images` is not shadowed.
            info = {'image': images.reshape(-1, 28, 28)[:10].cpu().numpy()}

            for tag, image_batch in info.items():
                logger.image_summary(tag, image_batch, global_step)


W0829 11:31:44.216355 140145486227200 deprecation_wrapper.py:119] From /hpc/home/ephyan/pytorch-tutorial/04_tensorboard/logger.py:16: The name tf.Summary is deprecated. Please use tf.compat.v1.Summary instead.

W0829 11:31:44.227200 140145486227200 deprecation_wrapper.py:119] From /hpc/home/ephyan/pytorch-tutorial/04_tensorboard/logger.py:46: The name tf.HistogramProto is deprecated. Please use tf.compat.v1.HistogramProto instead.



Epoch [1/20], Step [100/600], Loss: 2.2321, Acc: 0.31
Epoch [1/20], Step [200/600], Loss: 2.1198, Acc: 0.59
Epoch [1/20], Step [300/600], Loss: 1.9656, Acc: 0.73
Epoch [1/20], Step [400/600], Loss: 1.8628, Acc: 0.76
Epoch [1/20], Step [500/600], Loss: 1.7404, Acc: 0.77
Epoch [1/20], Step [600/600], Loss: 1.6362, Acc: 0.76
Epoch [2/20], Step [100/600], Loss: 1.4963, Acc: 0.82
Epoch [2/20], Step [200/600], Loss: 1.3676, Acc: 0.82
Epoch [2/20], Step [300/600], Loss: 1.2708, Acc: 0.81
Epoch [2/20], Step [400/600], Loss: 1.1771, Acc: 0.82
Epoch [2/20], Step [500/600], Loss: 1.1000, Acc: 0.79
Epoch [2/20], Step [600/600], Loss: 1.0305, Acc: 0.82
Epoch [3/20], Step [100/600], Loss: 0.8770, Acc: 0.89
Epoch [3/20], Step [200/600], Loss: 0.9145, Acc: 0.83
Epoch [3/20], Step [300/600], Loss: 0.9315, Acc: 0.78
Epoch [3/20], Step [400/600], Loss: 0.7569, Acc: 0.88
Epoch [3/20], Step [500/600], Loss: 0.7995, Acc: 0.82
Epoch [3/20], Step [600/600], Loss: 0.6767, Acc: 0.89
Epoch [4/20], Step [100/600]