In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from PIL import Image
from tensorboardX import SummaryWriter
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [2]:
%matplotlib inline

In [3]:
# Training hyperparameters and RNG seed.
batch_size = 32          # mini-batch size handed to the train and test DataLoaders
learnning_rate = 1e-3    # step size for Adam (the misspelled name is kept: the optimizer cell reads it)
num_epochs = 20          # number of full passes over the training set

# Seed PyTorch's global RNG so that weight initialisation and the shuffled
# batch order repeat on "Restart Kernel -> Run All".
# NOTE(review): some CUDA kernels may still be non-deterministic.
seed = 0
torch.manual_seed(seed)

In [4]:
# MNIST train/test splits, each image converted to a tensor by ToTensor().
data_root = './datas'
to_tensor = transforms.ToTensor()

# download=True makes the cell work on a fresh checkout: without it
# datasets.MNIST raises when the files are not already under data_root.
train_dataset = datasets.MNIST(data_root, train=True, transform=to_tensor, download=True)
test_dataset = datasets.MNIST(data_root, train=False, transform=to_tensor, download=True)

# Only the training batches are shuffled; the test order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [5]:
class Cnn(nn.Module):
    """Three-stage convolutional classifier with a visualisation head per stage.

    Each conv stage is paired with a 16/6 -> 1 channel ``ConvTranspose2d`` that
    projects the stage's feature map to a single-channel image. ``forward``
    returns those projections next to the class scores so they can be logged.

    The shape comments below assume 28x28 inputs; the first ``Linear`` layer
    expects 400 (= 16 * 5 * 5) flattened features, so other input sizes will
    not fit the classifier head.

    Args:
        in_dim: number of channels of the input images.
        out_dim: number of output scores (size of the last ``Linear`` layer).
    """

    def __init__(self, in_dim, out_dim):
        super(Cnn, self).__init__()

        # Stage 1: (b, in_dim, 28, 28) -> conv/ReLU (b, 6, 28, 28) -> pool (b, 6, 14, 14)
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_dim, 6, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.deconv1 = nn.ConvTranspose2d(6, 1, kernel_size=3, padding=1)

        # Stage 2: (b, 6, 14, 14) -> conv/ReLU (b, 16, 14, 14) -> pool (b, 16, 7, 7)
        self.conv2 = nn.Sequential(
            nn.Conv2d(6, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.deconv2 = nn.ConvTranspose2d(16, 1, kernel_size=3, padding=1)

        # Stage 3: unpadded conv, (b, 16, 7, 7) -> (b, 16, 5, 5); no pooling.
        self.conv3 = nn.Sequential(
            nn.Conv2d(16, 16, kernel_size=3, stride=1),
            nn.ReLU(inplace=True),
        )
        self.deconv3 = nn.ConvTranspose2d(16, 1, kernel_size=3, padding=1)

        # Classifier head: 400 -> 200 -> 100 -> out_dim.
        # NOTE(review): there is no activation between the Linear layers.
        self.fc = nn.Sequential(
            nn.Linear(400, 200),
            nn.Linear(200, 100),
            nn.Linear(100, out_dim),
        )

    def forward(self, x):
        """Run the network on a batch of images.

        Returns:
            A 4-tuple ``(scores, deconv1, deconv2, deconv3)``: the classifier
            output followed by the single-channel projection of each conv
            stage's feature map.
        """
        stage1 = self.conv1(x)
        stage2 = self.conv2(stage1)
        stage3 = self.conv3(stage2)

        projections = (
            self.deconv1(stage1),
            self.deconv2(stage2),
            self.deconv3(stage3),
        )

        # Collapse everything except the batch axis before the Linear layers.
        n_batch = stage3.size(0)
        scores = self.fc(stage3.view(n_batch, -1))
        return (scores,) + projections
    

In [24]:
# Build the network (1 input channel, 10 output scores), the loss and the optimizer.
# Use the GPU when one is visible and fall back to the CPU otherwise, instead of
# failing in .cuda() on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Cnn(1, 10).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learnning_rate)

In [25]:
# TensorBoard writer used by the training loop for scalars, image grids and
# histograms; constructed with './log/cnn2' as its log location.
writer = SummaryWriter('./log/cnn2')

In [26]:
# Train `model` for `num_epochs` epochs and evaluate on the test set after each one.
#
# Per training step: one optimizer update, then TensorBoard logging of the batch
# loss/accuracy, the input image grid and the three deconv projection grids.
# Every 100 steps: weight and gradient histograms.  Every 500 steps and at the
# end of each epoch: running averages are printed.

# Follow whatever device the model already lives on instead of assuming CUDA.
device = next(model.parameters()).device

for epoch in range(num_epochs):

    running_loss = 0.0   # sum over samples of the per-batch mean loss
    running_acc = 0.0    # number of correctly classified samples so far
    num_seen = 0         # samples processed so far in this epoch
    for i, (img, label) in enumerate(train_loader, 1):
        img = img.to(device)
        label = label.to(device)
        batch_n = img.size(0)

        out, deconv1, deconv2, deconv3 = model(img)
        loss = criterion(out, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        _, pred = torch.max(out, 1)
        num_correct = (pred == label).sum().item()

        # criterion returns the batch mean, so weight it by the batch size.
        running_loss += loss_value * batch_n
        running_acc += num_correct
        num_seen += batch_n

        # Global step counter, 1-based and continuous across epochs.
        step = epoch * len(train_loader) + i

        writer.add_scalar('loss', loss_value, step)
        # float() keeps this a true division on Python 2 as well.
        writer.add_scalar('accuracy', num_correct / float(batch_n), step)
        writer.add_image('images', torchvision.utils.make_grid(img), step)

        # Each projection is min-max scaled per image so it is viewable.
        for name, fmap in (('deconv1', deconv1), ('deconv2', deconv2), ('deconv3', deconv3)):
            grid = torchvision.utils.make_grid(fmap.detach(), normalize=True, scale_each=True)
            writer.add_image(name, grid.cpu(), step)

        if i % 100 == 0:
            for tag, value in model.named_parameters():
                # The deconv outputs never enter `loss`, so these parameters
                # receive no gradient from backward(); skip them.
                if tag.startswith('deconv'):
                    continue
                tag = tag.replace('.', '/')
                writer.add_histogram(tag, value.detach().cpu().numpy(), step)
                writer.add_histogram(tag + '/grad', value.grad.detach().cpu().numpy(), step)

        if i % 500 == 0:
            # Divide by the samples actually seen rather than assuming every
            # batch has the size of the current one.
            print('Epoch: [{}/{}], Loss: {:.6f}, Acc: {:.6f}'.format(
                epoch + 1, num_epochs, running_loss / num_seen, running_acc / num_seen))

    print('Finish {} Epoch, Loss: {:.6f}, Acc: {:.6f}'.format(
        epoch + 1, running_loss / len(train_dataset), running_acc / len(train_dataset)))

    # ---- evaluation on the test set ----
    model.eval()
    eval_loss = 0.0
    eval_acc = 0.0
    # No gradients are needed here; no_grad avoids building the autograd graph.
    with torch.no_grad():
        for img, label in test_loader:
            img = img.to(device)
            label = label.to(device)

            out = model(img)[0]  # the deconv projections are not used during evaluation
            loss = criterion(out, label)

            eval_loss += loss.item() * img.size(0)

            _, pred = torch.max(out, 1)
            eval_acc += (pred == label).sum().item()

    print('Eval Loss: {:.6f}, Eval Acc: {:.6f}'.format(
        eval_loss / len(test_dataset), eval_acc / len(test_dataset)))

    model.train()

Epoch: [1/3], Loss: 0.465617, Acc: 0.844812
Epoch: [1/3], Loss: 0.299766, Acc: 0.901500
Epoch: [1/3], Loss: 0.231596, Acc: 0.924479
Finish 1 Epoch, Loss: 0.202567, Acc: 0.934300
Eval Loss: 0.057321, Eval Acc: 0.980200
Epoch: [2/3], Loss: 0.071739, Acc: 0.977875
Epoch: [2/3], Loss: 0.072790, Acc: 0.977062
Epoch: [2/3], Loss: 0.069864, Acc: 0.978458
Finish 2 Epoch, Loss: 0.068024, Acc: 0.978833
Eval Loss: 0.056475, Eval Acc: 0.981400
Epoch: [3/3], Loss: 0.048726, Acc: 0.984062
Epoch: [3/3], Loss: 0.051383, Acc: 0.983563
Epoch: [3/3], Loss: 0.053005, Acc: 0.983229
Finish 3 Epoch, Loss: 0.053957, Acc: 0.983267
Eval Loss: 0.051872, Eval Acc: 0.984000


In [27]:
# Training is finished: close the SummaryWriter opened for the TensorBoard logs.
writer.close()

In [None]:
# Save the trained weights (state_dict only, not the whole module object).
checkpoint_path = './ser/cnn2.pth'

# torch.save does not create missing directories, so make sure ./ser exists.
# (isdir + makedirs rather than exist_ok=True, which Python 2 lacks.)
checkpoint_dir = os.path.dirname(checkpoint_path)
if not os.path.isdir(checkpoint_dir):
    os.makedirs(checkpoint_dir)

torch.save(model.state_dict(), checkpoint_path)

In [None]:
# grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
#                      normalize=normalize, range=range, scale_each=scale_each)
# ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
# im = Image.fromarray(ndarr)
# im.save(filename)

#####

# def log_images(tag, images, step):
#     im_summaries = []
#     for nr, img in enumerate(images):
#         s = StringIO.StringIO()
#         plt.imsave(s, img, format='png')

#         img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
#                                        height=img.shape[0],
#                                        width=img.shape[1])
#         im_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, nr),
#                                                  image=img_sum))

#     summary = tf.Summary(value=im_summaries)
#     writer.add_summary(summary, step)


# def log_images(tag, images, step):
#     im_summaries = []
#     for nr, img in enumerate(images):
#         s = StringIO.StringIO()
#         plt.imsave(s, img, format='png')

#         img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
#                                        height=img.height,
#                                        width=img.width)
#         im_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, nr),
#                                                  image=img_sum))

#     summary = tf.Summary(value=im_summaries)
#     writer.file_writer.add_summary(summary, step)

# def log_image(tag, image, step):
    
#     Summary(value=[Summary.Value(tag=tag, image=image)])
    
#     pass


# toimg_trans = transforms.ToPILImage()