## Information Bottleneck in MNIST
이전에 작성한 코드는 10bit 입력을 1bit로 압축(distortion)하는 단순한 문제만 다루었다.
이번에는 같은 Information Bottleneck 분석을 이미지 데이터(MNIST)에 적용해 보고자 한다.

### Reference
https://github.com/shalomma/PytorchBottleneck

## Dataset (MNIST)

In [None]:
from dataset import MNIST
import numpy as np
import torch
from torch.utils.data import DataLoader
from random import seed

from network import FeedForward
from train_mnist import Train, TrainConfig
from plotter import Plotter



In [None]:
# Seed all three random number generators used here (Python's `random`,
# NumPy's global generator, and PyTorch) with one shared value.
RANDOM_SEED = 1234
seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Train / test splits built with the project-local `dataset.MNIST` class,
# both rooted at ./dataset. Only the train split passes `download=True`
# and `randomize=False`.
data = {
    'train': MNIST('./dataset', train=True, download=True, randomize=False),
    'test': MNIST('./dataset', train=False),
}

# One non-shuffling DataLoader per split.
# NOTE(review): the batch sizes (60000 / 10000) presumably equal the full
# split sizes, so each loader would yield a single batch — confirm against
# dataset.MNIST.
loader = {
    'train': DataLoader(data['train'], batch_size=60000, shuffle=False),
    'test': DataLoader(data['test'], batch_size=10000, shuffle=False),
}

# Setup: network dimensions, compute device, model, loss and optimizer.
input_size = 28 * 28
output_size = 10
# NOTE(review): presumably the widths of the successive hidden layers —
# confirm against network.FeedForward.
hidden_sizes = [784, 1024, 1024, 20, 20, 20, 10]

# Use CUDA when PyTorch reports it as available, otherwise the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f'to device: {device}')

net = FeedForward(input_size, hidden_sizes, output_size)
net = net.to(device)

# reduction='sum' adds up the per-sample losses instead of averaging them.
criterion = torch.nn.CrossEntropyLoss(reduction='sum')
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

# Bundle the model, loss and optimizer into a config and build the trainer.
cfg = TrainConfig(net, criterion, optimizer)
train = Train(cfg)

# Override the trainer's settings before running.
# NOTE(review): `mi_cycle` presumably is the interval (in epochs) between
# mutual-information measurements — confirm in train_mnist.Train.
train.epochs, train.mi_cycle = 4000, 20

# Run training on the train/test loaders, then call the trainer's dump().
# NOTE(review): dump() presumably persists the collected results — confirm.
train.run(loader)
train.dump()

# Plot the results held by the trainer: losses, accuracy, and then the
# information plane for each split (train first, then test).
plot = Plotter(train)
plot.plot_losses()
plot.plot_accuracy()
for split in ('train', 'test'):
    plot.plot_info_plan(split)