In [43]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

In [44]:
# Seed torch's global RNG first so weight initialisation and DataLoader
# shuffling are reproducible under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

input_size = 28 * 28    # 784: each 28x28 image is flattened into one vector
num_classes = 10        # one output logit per class
num_epochs = 3
batch_size = 100
learning_rate = 0.001

In [45]:
# Both splits share one ToTensor() transform, which converts each image to a
# float tensor. download=True is set on both splits so neither depends on the
# other having fetched the files first (torchvision checks for existing files
# under root before downloading).
to_tensor = transforms.ToTensor()

train_dataset = torchvision.datasets.MNIST(root='../../data',
                                           train=True,
                                           transform=to_tensor,
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='../../data',
                                          train=False,
                                          transform=to_tensor,
                                          download=True)

In [46]:
# Training batches are reshuffled every epoch. Evaluation only needs one fixed
# pass over the test set, so it is not shuffled (keeps evaluation deterministic).
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

In [47]:
# Sanity-check tensor shapes: one raw sample vs. one batch from the loader.
image1, label1 = train_dataset[0]
print(image1.size())

# DataLoader iterators follow the iterator protocol; use the builtin next().
# The old `.next()` method alias was removed in newer PyTorch releases.
data_iter = iter(train_loader)
images, labels = next(data_iter)
print(images.size())
print(labels.size())
print(len(train_dataset))

torch.Size([1, 28, 28])
torch.Size([100, 1, 28, 28])
torch.Size([100])
60000


In [48]:
# image1 is laid out as [channels, height, width]; index 1 picks the height.
a = image1.shape[1]
print(a)

28


In [49]:
# build the model
model = nn.Linear(list(image1.size())[1] * list(image1.size())[2],num_classes)
# define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(),lr = learning_rate)

In [50]:
# train the model
for epoch in range(0,num_epochs):
    
    for batch_index , data in enumerate(train_loader,0):
        
        
        # get data
        images , labels = data
        
        # currently images (one channel) has dimension as (batch_size,1,input_size[0],input_size[1])
        # we need to reshape it so its dimension is (batch_size,inputisze[0],inputsize[1])
        images = images.reshape(-1, list(image1.size())[1] * list(image1.size())[2])
        
        
        # zero buffer
        optimizer.zero_grad()
        
        # forward pass
        outputs = model(images)
        
        # calculate the loss
        loss = criterion(outputs,labels)
        
        # backward pass
        loss.backward()
        
        # update the weights
        optimizer.step()
        
        # get the loss print
        
        if batch_index % 100 == 99:
            print('at epoch %d, batch %d, the loss is %.4f' % (epoch, batch_index + 1, round(loss.item() ,4)))
            total_loss = 0
        
        

at epoch 0, batch 100, the loss is 2.2381
at epoch 0, batch 200, the loss is 2.1380
at epoch 0, batch 300, the loss is 2.0250
at epoch 0, batch 400, the loss is 1.9213
at epoch 0, batch 500, the loss is 1.8790
at epoch 0, batch 600, the loss is 1.7547
at epoch 1, batch 100, the loss is 1.7298
at epoch 1, batch 200, the loss is 1.6654
at epoch 1, batch 300, the loss is 1.6129
at epoch 1, batch 400, the loss is 1.6108
at epoch 1, batch 500, the loss is 1.5847
at epoch 1, batch 600, the loss is 1.5392
at epoch 2, batch 100, the loss is 1.4461
at epoch 2, batch 200, the loss is 1.3515
at epoch 2, batch 300, the loss is 1.3626
at epoch 2, batch 400, the loss is 1.3521
at epoch 2, batch 500, the loss is 1.2627
at epoch 2, batch 600, the loss is 1.2362


## Save the checkpoint

In [51]:
# Show what the checkpoint will contain: parameter shapes from the model and
# the optimizer's per-parameter state and hyper-parameter groups.
model_state = model.state_dict()
optimizer_state = optimizer.state_dict()

print("Model's state_dict:")
for name, tensor in model_state.items():
    print(name, "\t", tensor.size())

print("Optimizer's state_dict:")
for key, value in optimizer_state.items():
    print(key, "\t", value)

# Persist a general checkpoint so training can be resumed later. `epoch` holds
# the index of the final epoch run by the training loop, and `loss` holds the
# loss of its last batch.
PATH = 'logistic_regression_example.ckpt'
torch.save(
    {
        'epoch': epoch,
        'model_state_dict': model_state,
        'optimizer_state_dict': optimizer_state,
        'loss': loss,
    },
    PATH,
)

Model's state_dict:
weight 	 torch.Size([10, 784])
bias 	 torch.Size([10])
Optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1]}]


In [52]:
# TEST THE MODEL
net = nn.Linear(list(image1.size())[1] * list(image1.size())[2],num_classes)
optimizer = optim.SGD(net.parameters(), lr=learning_rate)


checkpoint = torch.load(PATH)
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

In [53]:
# Measure classification accuracy of the restored model on the test set.
# Gradients are not needed for inference, so autograd is disabled.
with torch.no_grad():
    correct = 0
    total = 0

    for images, labels in test_loader:
        images = images.reshape(-1, input_size)

        # Predicted class = index of the largest logit in each row.
        outputs = net(images)
        predicted = torch.argmax(outputs, dim=1)

        # Count matches on the tensor side; the builtin sum() would iterate
        # element by element in Python.
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

    # Apply the format operator. Passing the value as a second print argument
    # left the literal "%.4f" in the output.
    print('accuracy is %.4f' % (correct / total))

accuracy is %.4f 0.7992
