In [248]:
import os

import numpy as np
import pandas as pd

from PIL import Image, ImageChops
from matplotlib import pyplot as plt

import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [249]:
class Network(nn.Module):
    """Small LeNet-style CNN classifier for single-channel 28x28 images.

    Two 5x5 conv layers, each followed by 2x2 max-pool and ReLU (with 2-D
    dropout on the second conv), then two fully-connected layers mapping to
    10 classes.

    forward(x) takes an (N, 1, 28, 28) float tensor and returns (N, 10)
    log-probabilities (``log_softmax`` over the class dimension), so the
    matching training criterion is ``nn.NLLLoss``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        # 28 -> conv5 -> 24 -> pool2 -> 12 -> conv5 -> 8 -> pool2 -> 4;
        # 20 channels * 4 * 4 = 320 flattened features per sample.
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Flatten per sample. view(-1, 320) would silently fold a mis-sized
        # input into extra "samples"; flatten(1) keeps the batch dimension so
        # fc1 raises on a shape mismatch instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        # Functional dropout must be told the mode explicitly to be off in eval().
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=-1)

In [250]:
# Train on the GPU when one is available; fall back to CPU so the notebook
# still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Net = Network().to(device)
optimizer = optim.Adagrad(Net.parameters(), lr=0.01)
# Network.forward already returns log-probabilities (log_softmax), which is
# exactly what NLLLoss expects. CrossEntropyLoss would apply log_softmax a
# second time: numerically a no-op, but redundant and misleading.
# Targets passed to this loss must be int64 class indices.
loss_fun = nn.NLLLoss()

In [290]:
raw_data = pd.read_csv('data/train.csv')
# Class indices must be int64 (torch.long): NLLLoss / CrossEntropyLoss reject int32 targets.
target = torch.tensor(raw_data['label'].to_numpy(), dtype=torch.long)
# Every non-label column is a pixel; reshape to (N, 1, 28, 28) without
# hard-coding N, and scale 0..255 intensities to [0, 1] float32.
pixels = raw_data.drop(columns='label').to_numpy().reshape(-1, 1, 28, 28)
data = torch.tensor(pixels, dtype=torch.float32) / 255.0

In [284]:
# Per-image random affine augmentation. ToPILImage/ToTensor round-trip through
# PIL, so each image comes back as a float tensor in [0, 1].
# RandomAffine's translate=(a, b) is (max horizontal fraction, max vertical
# fraction), not a range like degrees/scale, so (0, 0.1) would disable
# horizontal shifts entirely. Shift up to 10% along both axes.
trans = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomAffine(degrees=(-10, 10), translate=(0.1, 0.1), scale=(0.9, 1.1)),
    transforms.ToTensor(),
])

In [291]:
# Offline augmentation: append N_AUG randomly transformed copies of every
# original image. Rows end up laid out as [originals, copy 1, ..., copy N_AUG],
# each block len(raw_data) rows long.
# The originals are taken as the first len(raw_data) rows of `data`, which keeps
# this cell idempotent: re-running it rebuilds the same layout instead of
# augmenting an already-augmented tensor.
N_AUG = 5

len_data = len(raw_data)
base_data, base_target = data[:len_data], target[:len_data]

data_chunks, target_chunks = [base_data], [base_target]
for step in range(N_AUG):
    # `trans` round-trips through PIL, so it is applied one image at a time.
    data_chunks.append(torch.stack([trans(img) for img in base_data]))
    target_chunks.append(base_target)

# Concatenate once instead of re-copying the growing tensor on every step.
DATA = torch.cat(data_chunks)
TARGET = torch.cat(target_chunks)
data = DATA
target = TARGET

0
1
2
3
4


In [293]:
# 90/10 train/test split over *source images*, not rows.
# After augmentation, `data` is laid out as [originals, copy 1, ..., copy k],
# each block len(raw_data) rows long, so row r comes from source image
# r % len(raw_data). Shuffling rows directly would put augmented copies of test
# images into the training set (leakage that inflates test accuracy); splitting
# on source images keeps every copy on the same side as its original.
SPLIT_SEED = 42
TRAIN_FRACTION = 0.9

n_source = len(raw_data)
assert len(data) % n_source == 0, "data is not a whole number of copies of raw_data"
n_copies = len(data) // n_source

rng = np.random.default_rng(SPLIT_SEED)
source_ids = rng.permutation(n_source)
n_train_sources = int(n_source * TRAIN_FRACTION)
# Row index of copy j of source image s is j * n_source + s.
block_offsets = np.arange(n_copies)[:, None] * n_source
train_idx = rng.permutation((block_offsets + source_ids[:n_train_sources]).ravel())
test_idx = rng.permutation((block_offsets + source_ids[n_train_sources:]).ravel())

# Kept for compatibility: indices[:count] are the training rows, the rest test.
indices = np.concatenate([train_idx, test_idx])
count = len(train_idx)

X_train, y_train = data[train_idx], target[train_idx]
X_test, y_test = data[test_idx], target[test_idx]