In [9]:
import torch
from torch import nn
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import numpy as np
from torch.nn import functional as F
from datetime import datetime
from torchvision import transforms as tfs

In [10]:
def data_tf(x):
    """Randomly augment one image and convert it to a tensor.

    A fresh pipeline is composed on every call: random horizontal flip,
    then colour jitter (brightness 0.5, contrast 0.5, hue 0.2), then
    ``ToTensor``. The augmented result is returned.
    """
    # NOTE(review): horizontal flips mirror the image; for digit data this may
    # not be a label-preserving augmentation -- confirm it is intended.
    pipeline = tfs.Compose([
        tfs.RandomHorizontalFlip(),
        tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.2),
        tfs.ToTensor(),
    ])
    return pipeline(x)

In [11]:
# Training images go through data_tf, which applies a random flip and colour
# jitter before converting to a tensor.
train_set = MNIST('./data', train=True, transform=data_tf, download=True)
train_data = DataLoader(train_set, batch_size=64, shuffle=True)

# Evaluation images are only converted to tensors. Random augmentation on the
# test split would make the reported test loss/accuracy noisy and
# non-reproducible from run to run.
test_set = MNIST('./data', train=False, transform=tfs.ToTensor(), download=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)

In [12]:
class conv_net(nn.Module):
    """Small CNN: two conv -> batch-norm -> ReLU -> max-pool stages and two linear layers.

    The first linear layer takes 128 flattened features, which is what the
    convolutional stages produce (8 channels x 4 x 4) for a single-channel
    28x28 input. The network returns 10 raw scores per sample; no softmax is
    applied.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 1 -> 16 channels, 5x5 kernel, then 2x2 pooling with stride 2.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5)
        self.batch_norm1 = nn.BatchNorm2d(num_features=16)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Stage 2: 16 -> 8 channels, 5x5 kernel, then 2x2 pooling with stride 2.
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=8, kernel_size=5)
        self.batch_norm2 = nn.BatchNorm2d(num_features=8)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Classifier head: 128 -> 64 -> 10.
        self.fc1 = nn.Linear(in_features=128, out_features=64)
        self.fc2 = nn.Linear(in_features=64, out_features=10)

    def forward(self, x):
        """Run the batch ``x`` through both conv stages and the linear head."""
        stage1 = self.maxpool1(F.relu(self.batch_norm1(self.conv1(x))))
        stage2 = self.maxpool2(F.relu(self.batch_norm2(self.conv2(stage1))))
        # Collapse everything except the batch dimension before the linear layers.
        flat = stage2.view(stage2.shape[0], -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

In [13]:
def set_learning_rate(optimizer, lr):
    """Set the learning rate of every parameter group in ``optimizer`` to ``lr``.

    Args:
        optimizer: a ``torch.optim`` optimizer; its ``param_groups`` list is
            updated in place.
        lr: the new learning rate to store under each group's ``'lr'`` key.
    """
    # The attribute is ``param_groups`` (a list of dicts); ``optimizer.param``
    # does not exist and raised AttributeError.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [14]:
def train(net, train_data, test_data, epoch, optimizer, criterion,
          lr_decay_epoch=15, lr_decay_value=0.001):
    """Train ``net`` and evaluate it on ``test_data`` after every epoch.

    Args:
        net: the ``nn.Module`` to optimise; moved to CUDA when available.
        train_data: sized iterable of ``(images, labels)`` batches used for training.
        test_data: sized iterable of ``(images, labels)`` batches used for evaluation.
        epoch: number of epochs to run.
        optimizer: optimizer already bound to ``net``'s parameters.
        criterion: loss callable taking ``(output, labels)``.
        lr_decay_epoch: zero-based epoch index at which the learning rate of
            every parameter group is replaced by ``lr_decay_value``.
        lr_decay_value: learning rate applied from ``lr_decay_epoch`` onwards.

    Returns:
        ``(train_losses, test_losses)``: two lists holding the mean per-batch
        loss of each epoch.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = net.to(device)

    train_losses = []
    test_losses = []
    # Avoid ZeroDivisionError when a loader yields no batches.
    num_train_batches = max(len(train_data), 1)
    num_test_batches = max(len(test_data), 1)

    prev_time = datetime.now()
    # Use a separate loop variable so the ``epoch`` argument is not shadowed.
    for epoch_idx in range(epoch):
        if epoch_idx == lr_decay_epoch:
            # Updated inline so this function does not depend on a helper
            # defined in another cell.
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_decay_value

        # ---- training pass ----
        net.train()
        train_loss = 0.0
        for im, labels in train_data:
            im = im.to(device)
            labels = labels.to(device)

            output = net(im)
            loss = criterion(output, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # Elapsed time is measured from the previous timestamp to the end of
        # this epoch's training pass. total_seconds() also counts whole days,
        # which ``.seconds`` silently dropped.
        cur_time = datetime.now()
        elapsed = int((cur_time - prev_time).total_seconds())
        h, remainder = divmod(elapsed, 3600)
        m, s = divmod(remainder, 60)
        time_str = 'Time:%02d:%02d:%02d' % (h, m, s)

        # ---- evaluation pass ----
        net.eval()
        test_loss = 0.0
        num_correct = 0
        num_seen = 0
        # No gradients are needed for evaluation.
        with torch.no_grad():
            for im, labels in test_data:
                im = im.to(device)
                labels = labels.to(device)
                output = net(im)
                # The loss is computed on the network output, not the raw images.
                test_loss += criterion(output, labels).item()
                num_correct += (output.argmax(dim=1) == labels).sum().item()
                num_seen += labels.shape[0]
        test_acc = num_correct / max(num_seen, 1)

        mean_train_loss = train_loss / num_train_batches
        mean_test_loss = test_loss / num_test_batches
        epoch_str = 'epoch %d, train_loss:%f, test_loss:%f, test_acc:%f,' % (
            epoch_idx, mean_train_loss, mean_test_loss, test_acc)
        prev_time = cur_time
        train_losses.append(mean_train_loss)
        test_losses.append(mean_test_loss)
        print(epoch_str + time_str)

    return train_losses, test_losses

In [15]:
# Build the model, the loss and an Adam optimiser over the model's parameters,
# then print the layer summary.
net = conv_net()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=net.parameters(), lr=0.01)
print(net)

conv_net(
  (conv1): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1))
  (batch_norm1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 8, kernel_size=(5, 5), stride=(1, 1))
  (batch_norm2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=128, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=10, bias=True)
)


In [16]:
# Train for 25 epochs using the model, loaders, optimiser and loss built in
# the earlier cells.
# NOTE(review): the recorded output for this cell is a CUDA "invalid device
# function" error; this looks like an environment problem (PyTorch build vs.
# GPU) rather than a notebook bug -- confirm on the target machine.
train(
    net,
    train_data,
    test_data,
    epoch=25,
    optimizer=optimizer,
    criterion=criterion,
)

RuntimeError: CUDA error: invalid device function