In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch.autograd import Variable
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from tqdm import tqdm
from tqdm import trange
from PIL import Image
import argparse
import easydict

%matplotlib inline
# Global plotting style for the notebook.
# rc overrides handed to seaborn: draw axis labels and tick labels in white.
custom_style = {'axes.labelcolor': 'white',
                'xtick.color': 'white',
                'ytick.color': 'white'}
sns.set_style("darkgrid", rc=custom_style)
sns.set_context("notebook")
# Applied after the seaborn calls above, so matplotlib's 'dark_background'
# style sheet is layered on top of the seaborn settings.
plt.style.use('dark_background')
# Set last so the font size is not overwritten by the style calls above.
plt.rcParams["font.size"] = 18

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets, transforms
from torchvision.datasets import MNIST
from torch.utils.data import Dataset
from torch.utils.data import sampler

# PyTorchでMNISTを学習

### 参考サイト
https://github.com/pytorch/examples/blob/master/mnist/main.py

In [3]:
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--label_batch_size', type=int, default=32, metavar='N',
#                         help='input batch size for training (default: 32)')
#     parser.add_argument('--unlabel_batch_size', type=int, default=128, metavar='N',
#                         help='input batch size for unlabeled training data (default: 128)')
#     parser.add_argument('--test_batch_size', type=int, default=1000, metavar='N',
#                         help='input batch size for testing (default: 1000)')
#     parser.add_argument('--epoch', type=int, default=10, metavar='N',
#                         help='number of epochs to train (default: 10)')
#     parser.add_argument('--iters', type=int, default=400, metavar='N',
#                         help='number of iterations per epoch (default: 400)')
#     parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
#                         help='learning rate (default: 0.01)')
#     parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
#                         help='SGD momentum (default: 0.9)')
#     parser.add_argument('--alpha', type=float, default=1.0, metavar='ALPHA',
#                         help='regularization coefficient (default: 1.0)')
#     parser.add_argument('--xi', type=float, default=10.0, metavar='XI',
#                         help='hyperparameter of VAT (default: 10.0)')
#     parser.add_argument('--eps', type=float, default=1.0, metavar='EPS',
#                         help='hyperparameter of VAT (default: 1.0)')
#     parser.add_argument('--ip', type=int, default=1, metavar='IP',
#                         help='hyperparameter of VAT (default: 1)')
#     parser.add_argument('--workers', type=int, default=8, metavar='W',
#                         help='number of CPU')
#     parser.add_argument('--seed', type=int, default=123, metavar='S',
#                         help='random seed (default: 123)')
#     parser.add_argument('--log-interval', type=int, default=100, metavar='N',
#                         help='how many batches to wait before logging training status')
#     parser.add_argument('--gpu', type=int, default=1, metavar='W',
#                         help='index of the GPU to use (default: 1)')
#     args = parser.parse_args()

In [4]:
class Net(nn.Module):
    """Small CNN classifier: two conv -> ReLU -> max-pool stages, then three linear layers.

    ``forward`` returns the raw output of the last linear layer (10 values per
    sample); no softmax / log-softmax is applied inside the network.
    """

    def __init__(self):
        super(Net, self).__init__()
        # nn.Conv2d arguments are (in_channels, out_channels, kernel_size);
        # an integer kernel_size means a square kernel.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=3)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=3)

        # Channel-wise dropout (default settings) applied to the conv2 feature map.
        self.conv2_drop = nn.Dropout2d()

        # fc1's in_features must equal channels * height * width of the
        # flattened feature map: 20 * 5 * 5 = 500 for a 1 x 28 x 28 input.
        self.fc1 = nn.Linear(500, 500)
        self.fc2 = nn.Linear(500, 500)
        self.fc3 = nn.Linear(500, 10)

    def forward(self, x):
        """Run the network on ``x`` and return the un-normalized class scores.

        Shape comments assume a (batch, 1, 28, 28) input -- TODO confirm
        against the dataset actually used.
        """
        # Stage 1: 1 x 28 x 28 -> conv(3x3) 10 x 26 x 26 -> max-pool(2) 10 x 13 x 13
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))

        # Stage 2: 10 x 13 x 13 -> conv(3x3) 20 x 11 x 11 -> max-pool(2) 20 x 5 x 5
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = self.conv2_drop(x)

        # Flatten every non-batch dimension: 20 x 5 x 5 -> 500 features.
        flat = x.view(-1, self.num_flat_features(x))

        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

    def num_flat_features(self, x):
        """Return the number of elements per sample, i.e. the product of all
        dimensions of ``x`` except the leading (batch) dimension."""
        return x.size()[1:].numel()


In [5]:
class TrainLogger(object):
    """Collects training log lines, echoing each one to the console and to ``<out>/log``.

    The log file is opened in write mode when the logger is constructed, so an
    existing ``<out>/log`` is truncated. Call :meth:`close` (or use the logger
    as a context manager) to release the file handle when training is done.
    """

    def __init__(self, out):
        """Create directory ``out`` if needed and open ``<out>/log`` for writing."""
        # exist_ok=True tolerates a pre-existing directory while still surfacing
        # real failures (e.g. permissions) instead of silently swallowing them.
        os.makedirs(out, exist_ok=True)
        self.file = open(os.path.join(out, 'log'), 'w')
        self.logs = []

    def write(self, log):
        """Print ``log`` via tqdm, append it to the log file and remember it."""
        tqdm.write(log)
        tqdm.write(log, file=self.file)
        # Flush so the on-disk log is current even if the run is interrupted.
        self.file.flush()
        self.logs.append(log)

    def state_dict(self):
        """Return the logger state as ``{'logs': [...]}``."""
        return {'logs': self.logs}

    def load_state_dict(self, state_dict):
        """Restore the log history from ``state_dict`` and replay it into the log file.

        Only the most recent entry (if any) is echoed to the console.
        """
        self.logs = state_dict['logs']
        # Guard against an empty history: indexing [-1] would raise IndexError.
        if self.logs:
            tqdm.write(self.logs[-1])
        for log in self.logs:
            tqdm.write(log, file=self.file)
        self.file.flush()

    def close(self):
        """Close the underlying log file; safe to call more than once."""
        if not self.file.closed:
            self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Never suppress exceptions raised inside the ``with`` body.
        return False

In [6]:
def checkpoint(net, optimizer, epoch, logger, out):
    """Save a training snapshot for ``epoch`` into directory ``out``.

    Three files sharing the prefix ``epoch-<epoch>`` are written:

    * ``.iter``  -- ``{'epoch': epoch + 1, 'logger': logger.state_dict()}``
      (the stored epoch is presumably the one to resume from -- verify
      against whatever code loads it)
    * ``.model`` -- ``net.state_dict()``
    * ``.state`` -- ``optimizer.state_dict()``

    Args:
        net: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved.
        epoch: epoch number used in the file names.
        logger: object exposing ``state_dict()``.
        out: output directory; created if it does not exist.
    """
    os.makedirs(out, exist_ok=True)
    prefix = os.path.join(out, 'epoch-{}'.format(epoch))
    torch.save({'epoch': epoch + 1, 'logger': logger.state_dict()}, prefix + '.iter')
    # The extension separator was missing here, yielding 'epoch-Nmodel' /
    # 'epoch-Nstate' next to 'epoch-N.iter'.
    torch.save(net.state_dict(), prefix + '.model')
    torch.save(optimizer.state_dict(), prefix + '.state')

In [7]:
def test(model, device, test_loader, criterion, logger):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    log = '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, total, 100. * correct / total)
    logger.write(log)

In [8]:
class InfiniteSampler(sampler.Sampler):
    """Sampler that yields dataset indices forever.

    Indices are produced as an endless sequence of random permutations of
    ``range(num_samples)`` drawn from NumPy's global RNG, so each pass visits
    every index exactly once before a new permutation starts. Seed
    ``np.random`` for a reproducible order.
    """

    def __init__(self, num_samples):
        """
        Args:
            num_samples: number of items in the dataset; must be positive.

        Raises:
            ValueError: if ``num_samples`` is not positive. With zero samples
                ``__iter__`` would spin forever without yielding anything.
        """
        if num_samples <= 0:
            raise ValueError(
                'num_samples must be a positive integer, got {}'.format(num_samples))
        self.num_samples = num_samples

    def __iter__(self):
        while True:
            # tolist() converts the NumPy scalars to plain Python ints.
            for index in np.random.permutation(self.num_samples).tolist():
                yield index

    def __len__(self):
        # Returning None made len() fail with an opaque
        # "'NoneType' object cannot be interpreted as an integer"; raise the
        # same exception type with an explicit message instead.
        raise TypeError('InfiniteSampler yields indices forever and has no length')

In [9]:
class MNISTDataSet(Dataset):
    """Image dataset indexed by a CSV file.

    The CSV must provide an ``img`` column (image path relative to
    ``root_dir``) and a ``label`` column. Each item is returned as
    ``(image, label)`` where the image is loaded with PIL, converted to
    8-bit grayscale (mode ``'L'``) and passed through ``transform`` if given.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file: path of the CSV index file.
            root_dir: directory the ``img`` paths are relative to.
            transform: optional callable applied to the PIL image.
        """
        self.image_dataframe = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Number of rows in the CSV index."""
        return len(self.image_dataframe)

    def __getitem__(self, idx):
        """Load the image and label stored in row ``idx`` of the CSV index."""
        img_name = os.path.join(self.root_dir,
                                self.image_dataframe.loc[idx, 'img'])
        # Image.open is lazy and keeps the file open; convert() loads the
        # pixels into a new image, so the source file can be closed right away.
        with Image.open(img_name) as raw_image:
            image = raw_image.convert('L')
        label = self.image_dataframe.loc[idx, 'label']

        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [None]:
class VAT(nn.Module):
    """Virtual Adversarial Training (VAT) regularization loss.

    For a batch ``x`` the loss is the KL divergence between the model's
    prediction on ``x`` and its prediction on ``x + r_adv``, where ``r_adv``
    is the perturbation of L2 norm ``eps`` (per sample) that approximately
    maximizes this divergence. The direction of ``r_adv`` is estimated with
    ``ip`` power-iteration steps starting from a random direction.

    Args:
        xi: scale of the probe perturbation used during power iteration.
        eps: L2 norm of the final adversarial perturbation.
        ip: number of power-iteration steps.
    """

    def __init__(self, xi=10.0, eps=1.0, ip=1):
        super(VAT, self).__init__()
        self.xi = xi
        self.eps = eps
        self.ip = ip

    @staticmethod
    def _l2_normalize(d):
        """Scale each sample of ``d`` to unit L2 norm (taken over all non-batch dims)."""
        norms = d.flatten(start_dim=1).norm(dim=1)
        norms = norms.reshape((d.shape[0],) + (1,) * (d.dim() - 1))
        # Small constant avoids division by zero for an all-zero direction.
        return d / (norms + 1e-8)

    @staticmethod
    def _kl_div(log_p, log_q):
        """Batch mean of KL(p || q) given log-probabilities of shape (batch, classes)."""
        return (log_p.exp() * (log_p - log_q)).sum(dim=1).mean()

    def forward(self, model, x):
        """Return the scalar VAT loss of ``model`` on the (unlabeled) batch ``x``.

        ``model(x)`` must return class logits of shape (batch, classes) and
        ``x`` must have at least two dimensions (batch first). Gradients flow
        into the model parameters only through the prediction on the perturbed
        input; parameter ``.grad`` buffers are not touched here.

        NOTE(review): if ``model`` is in train mode with dropout, the clean and
        perturbed forward passes use different dropout masks.
        """
        # The clean prediction is the fixed target of the loss, so it is
        # computed without gradient tracking.
        with torch.no_grad():
            log_pred = F.log_softmax(model(x), dim=1)

        # Random initial direction, centered around zero and normalized.
        d = self._l2_normalize(torch.rand_like(x).sub(0.5))

        # Power iteration: follow the gradient of the divergence w.r.t. the
        # probe direction to approximate the most sensitive direction.
        for _ in range(self.ip):
            d.requires_grad_()
            log_pred_hat = F.log_softmax(model(x + self.xi * d), dim=1)
            adv_distance = self._kl_div(log_pred, log_pred_hat)
            # autograd.grad returns only d's gradient and leaves the model
            # parameters' .grad buffers untouched.
            grad = torch.autograd.grad(adv_distance, d)[0]
            d = self._l2_normalize(grad.detach())

        r_adv = self.eps * d
        log_pred_hat = F.log_softmax(model(x + r_adv), dim=1)
        return self._kl_div(log_pred, log_pred_hat)
        
    
    
    

In [10]:
def train(model, device, label_loader, unlabel_loader, criterion, optimizer, epoch, logger, args,
          vat_loss=None):
    """Train ``model`` for ``args.iters`` optimization steps.

    Args:
        model: network to train; switched to train mode.
        device: device the batches are moved to.
        label_loader: iterable yielding ``(data, target)`` labeled batches.
        unlabel_loader: iterable yielding ``(data, _)`` unlabeled batches;
            only consumed when ``vat_loss`` is given.
        criterion: supervised loss callable ``criterion(output, target)``.
        optimizer: optimizer over the parameters of ``model``.
        epoch: epoch number, used only in log messages.
        logger: object with a ``write(str)`` method.
        args: namespace providing ``iters`` and ``log_interval`` (and
            ``alpha`` when ``vat_loss`` is given).
        vat_loss: optional callable ``vat_loss(model, unlabel_data)`` returning
            a scalar regularization term; it is added to the supervised loss
            with weight ``args.alpha``. With the default ``None`` only the
            supervised loss is optimized.
    """
    model.train()
    label_iter = iter(label_loader)
    # Unlabeled batches are needed only for the regularization term, so the
    # loader is not iterated at all when no vat_loss is supplied.
    unlabel_iter = iter(unlabel_loader) if vat_loss is not None else None

    for iter_id in range(args.iters):
        # Use the built-in next(): the ``.next()`` method is not part of the
        # iterator protocol and is missing from current DataLoader iterators.
        label_data, label_target = next(label_iter)
        label_data, label_target = label_data.to(device), label_target.to(device)

        optimizer.zero_grad()
        label_output = model(label_data)
        label_loss = criterion(label_output, label_target)
        loss = label_loss

        if vat_loss is not None:
            unlabel_data, _ = next(unlabel_iter)
            unlabel_data = unlabel_data.to(device)
            loss = loss + args.alpha * vat_loss(model, unlabel_data)

        loss.backward()
        optimizer.step()
        if iter_id % args.log_interval == 0:
            # The logged loss is the supervised (labeled) term only.
            log = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch,
                iter_id * len(label_data), args.iters * len(label_data),
                100. * (iter_id * len(label_data)) /(args.iters * len(label_data)),
                label_loss.item())
            logger.write(log)

In [11]:
def main():
    """Build the data loaders, model and optimizer, then run the train / test / checkpoint loop.

    All hyperparameters and paths are defined inline below. For each epoch the
    model is trained for ``args.iters`` steps, evaluated every
    ``test_interval`` epochs and checkpointed every ``resume_interval`` epochs
    into ``out_dir``.
    """
    args = easydict.EasyDict({
        "label_batch_size": 32,
        "unlabel_batch_size": 100,
        "test_batch_size": 1000,
        "epochs": 2,
        "iters": 400,
        "lr": 0.01,
        "momentum": 0.9,
        "alpha": 1.0,
        "xi": 10.0,
        "eps": 1.0,
        "ip": 1,
        "workers": 8,
        "seed": 123,
        "log_interval": 100,
        "gpu": 1
    })

    no_cuda = False
    out_dir = './result'
    train_label_csv = '../data/mnist/train_label.csv'
    train_unlabel_csv = '../data/mnist/train_unlabel.csv'
    test_csv = '../data/mnist/test_big.csv'
    train_root_dir = '../data/mnist/train'
    test_root_dir = '../data/mnist/test'
    test_interval = 1
    resume_interval = 1

    use_cuda = not no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    # The samplers draw their permutations from NumPy's global RNG, so it has
    # to be seeded as well for a reproducible batch order.
    np.random.seed(args.seed)
    if use_cuda:
        # Fall back to the first GPU when the configured index does not exist
        # on this machine (e.g. a single-GPU host).
        gpu_index = args.gpu if args.gpu < torch.cuda.device_count() else 0
        device = torch.device('cuda:{}'.format(gpu_index))
    else:
        device = torch.device('cpu')
    print(device)
    # Honor the configured worker count instead of a hard-coded value.
    kwargs = {'num_workers': args.workers, 'pin_memory': True} if use_cuda else {}

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, ), (0.5,))])

    trainset_label = MNISTDataSet(train_label_csv, train_root_dir, transform)
    trainloader_label = torch.utils.data.DataLoader(
        trainset_label, batch_size=args.label_batch_size, shuffle=False,
        sampler=InfiniteSampler(len(trainset_label)),
        **kwargs
    )

    trainset_unlabel = MNISTDataSet(train_unlabel_csv, train_root_dir, transform)
    trainloader_unlabel = torch.utils.data.DataLoader(
        trainset_unlabel, batch_size=args.unlabel_batch_size, shuffle=False,
        sampler=InfiniteSampler(len(trainset_unlabel)),
        **kwargs
    )

    testset = MNISTDataSet(test_csv, test_root_dir, transform)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=args.test_batch_size, shuffle=False, **kwargs
    )

    net = Net().to(device)
    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum)
    criterion = nn.CrossEntropyLoss()
    logger = TrainLogger(out_dir)

    try:
        for epoch in range(1, args.epochs + 1):
            train(net, device, trainloader_label, trainloader_unlabel, criterion, optimizer, epoch, logger, args)
            if epoch % test_interval == 0:
                test(net, device, testloader, criterion, logger)
            if epoch % resume_interval == 0:
                checkpoint(net, optimizer, epoch, logger, out_dir)
    finally:
        # Release the log file handle even if training is interrupted.
        logger.file.close()

In [12]:
# Run the full train / evaluate / checkpoint loop defined in main().
main()

cuda:1

Test set: Average loss: 18.7518, Accuracy: 7535/10000 (75%)


Test set: Average loss: 20.5748, Accuracy: 7636/10000 (76%)



In [14]:
%ls ../data/mnist/

[0m[01;34mtest[0m/         test_small.csv  train_big.csv    train_small.csv
test_big.csv  [01;34mtrain[0m/          train_label.csv  train_unlabel.csv
