# In [None]:
import os
import multiprocessing
import torch
from pytorch_lightning import LightningModule, Trainer
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import MNIST
import pandas as pd
import numpy as np

# Directory the MNIST files are stored in; overridable through the
# PATH_DATASETS environment variable, defaulting to the working directory.
PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
# Path of the CSV file the trained model is asked to label.
KAGGLE_FILE = "../input/digit-recognizer/test.csv"
# Upper bound on the number of training epochs.
MAX_EPOCHS = 10
# Use at most one GPU, and none when CUDA reports no devices.
AVAIL_GPUS = min(torch.cuda.device_count(), 1)
# Smaller batches when running without a GPU.
BATCH_SIZE = 64 if AVAIL_GPUS == 0 else 256


class PermutationOperation(nn.Module):
    """Parameter-free layer that moves axis 1 of a 4-D tensor to the end.

    In this file it sits right after the convolutional stack, so an
    activation laid out as (N, C, H, W) is handed on as (N, H, W, C).
    """

    # Axis order passed to ``Tensor.permute``.
    _AXIS_ORDER = (0, 2, 3, 1)

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` with its axes reordered to ``(0, 2, 3, 1)``."""
        return x.permute(*self._AXIS_ORDER)

# Based on M7 from https://arxiv.org/pdf/2008.10400v2.pdf
class SimpleMNISTCNN(LightningModule):
    """MNIST digit classifier: four unpadded 7x7 conv blocks and a linear head.

    The module also owns its data pipeline (download, train/val split and
    dataloaders), so ``Trainer.fit(model)`` can be called without passing
    any dataloaders.

    Args:
        data_dir: Directory the MNIST files are downloaded to / read from.
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, data_dir=PATH_DATASETS, learning_rate=2e-4):
        super().__init__()

        # Set our init args as class attributes
        self.data_dir = data_dir
        self.learning_rate = learning_rate

        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.dims = (1, 28, 28)
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        # Build the convolutional stack. Each block is Conv -> BatchNorm ->
        # ReLU; the convolutions carry no bias because the BatchNorm that
        # follows has its own shift term.
        channels, height, width = self.dims
        kernel_size = 7
        layers = []
        in_channels = channels
        for out_channels in (48, 96, 144, 192):
            layers += [
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]
            in_channels = out_channels
            # An unpadded, stride-1 convolution trims kernel_size - 1 pixels
            # from each spatial dimension: 28 -> 22 -> 16 -> 10 -> 4.
            height -= kernel_size - 1
            width -= kernel_size - 1

        # 4 * 4 * 192 = 3072 features for the default 28x28 input.
        flat_features = height * width * in_channels

        # Define PyTorch model (layer order is kept stable so the indices
        # inside the Sequential, and therefore state_dict keys, do not move).
        self.model = nn.Sequential(
            *layers,
            PermutationOperation(),
            nn.Flatten(),
            nn.Linear(flat_features, self.num_classes, bias=False),
            nn.BatchNorm1d(self.num_classes),
        )

        self.accuracy = Accuracy(num_classes=self.num_classes)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images."""
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch."""
        x, y = batch
        log_probs = self(x)
        # forward() already applies log_softmax, so the matching loss is
        # nll_loss; cross_entropy would apply a second log_softmax.
        return F.nll_loss(log_probs, y)

    def validation_step(self, batch, batch_idx):
        """Compute loss and accuracy for one batch and log both."""
        x, y = batch
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)
        preds = torch.argmax(log_probs, dim=1)
        self.accuracy(preds, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.accuracy, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch.

        Reuses ``validation_step``, so test metrics are logged under the
        ``val_loss`` / ``val_acc`` names as well.
        """
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        """Return Adam plus a plateau scheduler driven by ``val_loss``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        return [optimizer], [{"scheduler": scheduler, "monitor": "val_loss"}]

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        """Download both MNIST splits into ``data_dir``."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` (all of them if None)."""
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True,
                               transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(
                self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return the training loader, reshuffled every epoch."""
        return DataLoader(
            self.mnist_train,
            batch_size=BATCH_SIZE,
            # Without shuffling, every epoch would replay the same batches
            # in the same order.
            shuffle=True,
            num_workers=multiprocessing.cpu_count(),
        )

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        return DataLoader(
            self.mnist_val,
            batch_size=BATCH_SIZE,
            num_workers=multiprocessing.cpu_count(),
        )

    def test_dataloader(self):
        """Return the test loader (fixed order)."""
        return DataLoader(
            self.mnist_test,
            batch_size=BATCH_SIZE,
            num_workers=multiprocessing.cpu_count(),
        )


def kaggle_pixels_to_tensor(pixels, mean=0.1307, std=0.3081):
    """Turn rows of raw pixel intensities into normalized image batches.

    Each row holds 784 values in the 0-255 range. The result mirrors the
    preprocessing used for training in this file (``ToTensor`` followed by
    ``Normalize((0.1307,), (0.3081,))``): values are scaled to [0, 1] and
    then standardized.

    Args:
        pixels: Array-like (e.g. a DataFrame) of shape (num_images, 784).
        mean: Mean subtracted after scaling to [0, 1].
        std: Standard deviation divided by after centering.

    Returns:
        A float32 tensor of shape (num_images, 1, 28, 28).
    """
    images = torch.as_tensor(np.asarray(pixels), dtype=torch.float32)
    images = images.reshape(-1, 1, 28, 28)
    return (images / 255.0 - mean) / std


def predict_labels(model, inputs, batch_size):
    """Run ``model`` over ``inputs`` in batches and return the argmax labels.

    The model is switched to eval mode for the duration of the call so that
    BatchNorm layers use their running statistics, and its previous mode is
    restored afterwards. Batches are moved to the device of the model's
    parameters.

    Args:
        model: A ``torch.nn.Module`` mapping a batch to per-class scores.
        inputs: Tensor whose first dimension indexes the samples.
        batch_size: Maximum number of samples per forward pass.

    Returns:
        A 1-D NumPy array with one predicted class index per sample.
    """
    if inputs.shape[0] == 0:
        return np.empty(0, dtype=np.int64)

    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Collect per-batch outputs and concatenate once at the end
            # instead of growing a tensor inside the loop.
            outputs = [
                model(batch.to(device)).cpu()
                for batch in torch.split(inputs, batch_size)
            ]
    finally:
        model.train(was_training)

    return np.argmax(torch.cat(outputs, dim=0).numpy(), axis=1)


if __name__ == '__main__':
    model = SimpleMNISTCNN()
    trainer = Trainer(
        gpus=AVAIL_GPUS,
        max_epochs=MAX_EPOCHS,
        progress_bar_refresh_rate=20,
    )
    trainer.fit(model)

    # The Kaggle file stores raw 0-255 pixels, so it needs the same scaling
    # and normalization the training images received.
    test = pd.read_csv(KAGGLE_FILE)
    tensor = kaggle_pixels_to_tensor(test)
    print(tensor.shape)

    y_pred = predict_labels(model, tensor, BATCH_SIZE)
    print(y_pred.shape)

    # Creates the submission table, one row per test image.
    # First column = ImageId, second column = Label.
    # ImageId is 1-based and follows the row order of the test file.
    submission = pd.DataFrame(
        {'ImageId': np.arange(1, len(y_pred) + 1), 'Label': y_pred})
    # Saves the submission file
    submission.to_csv('submission.csv', index=False)
