In [1]:
import os

from typing import Tuple, Dict, List, Any

import lightning as L
import numpy as np
import torch
from torch import nn
from torch.optim import lr_scheduler, Optimizer
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchmetrics.functional.classification import multiclass_accuracy

In [2]:
class MLP_Block(nn.Module):
    """Building block for MLP-based models."""

    def __init__(
        self, hidden_size: int, activation: nn.Module, depth: int
    ) -> None:
        """Initialization of the MLP block.

        Args:
            hidden_size: Number of neurons in the linear layer.
            activation: Activation function.
            depth: Number of MLP blocks (linear layer with activation).
        """
        super(MLP_Block, self).__init__()
        layers = []
        for _ in range(depth):
            linear = nn.Linear(hidden_size, hidden_size)
            layers.append(linear)
            layers.append(activation)
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Propagates the input through the MLP block.

        Args:
            x: Input.

        Returns:
            Output of the network.
        """
        return self.layers(x)


class MLP(nn.Module):
    def __init__(
        self,
        input_shape: Tuple[int],
        output_shape: Tuple[int],
        hidden_factor: int = 1,
        depth: int = 1,
    ) -> None:
        """Initialization of the multi-layer perceptron.

        Args:
            input_shape: Shape of the input.
            output_shape: Shape of the output.
            hidden_factor: Factor for multiplying with input length to
                determine the number of neurons in each hidden layer.
                Defaults to 1.
            depth: Number of hidden layers. Defaults to 1.
        """
        super().__init__()
        self.input_shape = input_shape
        self.output_shape = output_shape
        input_len = int(np.prod(input_shape))
        output_len = int(np.prod(output_shape))
        hidden_size = int(input_len * hidden_factor)

        self.layers = nn.ModuleList(
            [   
                nn.Flatten(),
                nn.Linear(input_len, hidden_size),  # Input layer
                MLP_Block(hidden_size, nn.ReLU(), depth),
                nn.Linear(hidden_size, output_len),  # Output layer
            ]
        )

        self.layers = nn.Sequential(*self.layers)
    
    def forward(self, x):
        """Propagates the input through the MLP block.

        Args:
            x: Input.

        Returns:
            Output of the network.
        """
        return self.layers(x)

In [3]:
class LitClassificationModel(L.LightningModule):

    def __init__(
        self,
        net: str,
        lr: float,
        num_classes: int,
        criterion,
        optimizer_class,
        step_size,
        scheduler_class,
        ) -> None:
        """Initialization of the custom Lightning Module.

        Args:
            model: Neural network model name.
            config: Neural network model and training config.
        """
        super().__init__()
        self.lr = lr
        self.num_classes = num_classes
        self.criterion = criterion
        self.optimizer_class = optimizer_class
        self.step_size = step_size
        self.scheduler_class = scheduler_class
        self.net = net

    def configure_optimizers(
        self,
    ) -> Tuple[Optimizer, lr_scheduler.LRScheduler]:
        """Configures the optimizer and scheduler based on the learning rate
            and step size.

        Returns:
            Configured optimizer and scheduler.
        """
        optimizer = self.optimizer_class(self.parameters(), lr=self.lr)
        scheduler = self.scheduler_class(optimizer, self.step_size)
        return [optimizer], [scheduler]

    def infer_batch(
        self, batch: Dict[str, dict]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Propagate given batch through the Lightning Module.

        Args:
            batch: Batch containing the subjects.

        Returns:
            Model output and corresponding ground truth.
        """
        x, y = batch
        y_hat = self.net(x)
        return y_hat, y

    def training_step(self, batch: Dict[str, dict], batch_idx: int) -> float:
        """Infer batch on training data, log metrics and retrieve loss.

        Args:
            batch: Batch containing the subjects.
            batch_idx: Number displaying index of this batch.

        Returns:
            Calculated loss.
        """
        y_hat, y = self.infer_batch(batch)

        # Calculate loss
        loss = self.criterion(y_hat, y)

        self.log('train_loss', loss, prog_bar=True)
        return loss
    
    def test_step(self, batch, batch_idx):
        # this is the test loop
        y_hat, y = self.infer_batch(batch)
        loss = self.criterion(y_hat, y)
        acc =  multiclass_accuracy(y_hat, y, num_classes=self.num_classes)
        self.log('test_loss', loss, prog_bar=True)
        self.log('acc', acc, prog_bar=True)


In [81]:
class SimpleFreqSpace(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, img):
        return torch.fft.rfft2(img)


class SimpleComplex2Vec(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        return torch.view_as_real(x)

class BaseDataModule(L.LightningDataModule):
    
    def __init__(self):
        super().__init__()

    def train_dataloader(self) -> torch.utils.data.DataLoader:
        """Creates Dataloader for training phase.

        Returns:
            Dataloader for training phase.
        """
        return torch.utils.data.DataLoader(
            self.train_set, self.batch_size
        )

    def val_dataloader(self) -> torch.utils.data.DataLoader:
        """Creates Dataloader for validation phase.

        Returns:
            Dataloader for validation phase.
        """
        return torch.utils.data.DataLoader(
            self.val_set, self.batch_size
        )


class ImageNetDataModule(BaseDataModule):
    def __init__(self, data_dir: str, input_domain: str, batch_size: int = 32) -> None:
        super().__init__()
        self.data_dir = data_dir
        self.input_domain = input_domain
        self.batch_size = batch_size

    def setup(self):
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        traindir = os.path.join(self.data_dir, 'train')
        valdir = os.path.join(self.data_dir, 'val')

        if self.input_domain == 'freq':
            domain_transfrom = [SimpleFreqSpace(), SimpleComplex2Vec()]
        else:
            domain_transform = [lambda x: x]

        self.train_set = ImageFolder(
            traindir,
            transforms.Compose(
                [
                    transforms.RandomResizedCrop(224),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize,
                    *domain_transfrom,
                ]
            ),
        )

        self.val_set = ImageFolder(
            valdir,
            transforms.Compose(
                [
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    normalize,
                    *domain_transfrom,
                ]
            ),
        )


class MNISTDataModule(BaseDataModule):
    
    def __init__(self, input_domain: str, batch_size: int = 32) -> None:
        super().__init__()
        self.input_domain = input_domain
        self.batch_size = batch_size
    
    def prepare_data(self):
        # download
        datasets.MNIST(
            root="MNIST", download=True, train=False)
        datasets.MNIST(
            root="MNIST", download=True, train=False)

    def setup(self, stage: str):
        tensor_transform = transforms.ToTensor()
        
        if self.input_domain == 'freq':
            domain_transfrom = [SimpleFreqSpace(), SimpleComplex2Vec()]
        else:
            domain_transform = [lambda x: x]

        self.test_set = datasets.MNIST(
            root="MNIST", download=True, train=False, transform= transforms.Compose([tensor_transform ]))
        
        # use 20% of training data for validation
        train_set_size = int(len(train_set) * 0.8)
        valid_set_size = len(train_set) - train_set_size
        
        self.train_set, self.val_set = torch.utils.data.random_split(
            train_set, [train_set_size, valid_set_size], generator=seed)
        
        self.test_set = datasets.MNIST(
            root="MNIST", download=True, train=False, transform=transform)


In [82]:
class MNISTDataModule2(L.LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

    def prepare_data(self):
        # download
        datasets.MNIST(self.data_dir, train=True, download=True)
        datasets.MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit":
            mnist_full = datasets.MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = torch.utils.data.random_split(
                mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
            )

        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.mnist_test = datasets.MNIST(self.data_dir, train=False, transform=self.transform)

        if stage == "predict":
            self.mnist_predict = datasets.MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=32)

In [83]:
from collections import  defaultdict
class FindSumPairs:

    def __init__(self, nums1: List[int], nums2: List[int]):
        self.dict1 = defaultdict(int)
        for num in nums1:
            self.dict1[num] += 1
        self.dict2 = defaultdict(int)
        for num in nums2:
            self.dict2[num] += 1
        self.nums2 = nums2

    def add(self, index: int, val: int) -> None:
        print(self.dict2)
        self.dict2[self.nums2[index] + val] += 1
        self.dict2[self.nums2[index]] -= 1 

    def count(self, tot: int) -> int:
        print(self.dict2)
        ret = 0
        for num, count in self.dict1.items():
            ret += (count * self.dict2[tot - num])
        return ret

In [84]:
transform = transforms.ToTensor()

train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)

# use 20% of training data for validation
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size

# split the train set into two
seed = torch.Generator().manual_seed(42)
train_set, valid_set = torch.utils.data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)

test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)
train_loader = DataLoader(train_set, batch_size=200)

In [91]:
net = MLP([28, 28], [10])
criterion = torch.nn.CrossEntropyLoss()
lr = 0.0001
num_classes = 10
step_size = 300
scheduler_class = torch.optim.lr_scheduler.StepLR
optimizer_class = torch.optim.Adam
lit_model = LitClassificationModel(
    net, lr, num_classes, criterion,
    optimizer_class, step_size, scheduler_class
)
datamodule = MNISTDataModule2('pixel')

In [92]:
trainer = L.Trainer(max_epochs=6)
trainer.fit(model=lit_model, datamodule=datamodule)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to pixel/MNIST/raw/train-images-idx3-ubyte.gz



  0%|                                                                                                                                                                                          | 0/9912422 [00:00<?, ?it/s][A
  1%|██▏                                                                                                                                                                     | 131072/9912422 [00:00<00:08, 1118112.55it/s][A
  6%|██████████▌                                                                                                                                                             | 622592/9912422 [00:00<00:02, 3182846.36it/s][A
 10%|████████████████▊                                                                                                                                                        | 983040/9912422 [00:01<00:13, 680625.39it/s][A
 12%|████████████████████▌                                                                                 

Extracting pixel/MNIST/raw/train-images-idx3-ubyte.gz to pixel/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to pixel/MNIST/raw/train-labels-idx1-ubyte.gz



  0%|                                                                                                                                                                                            | 0/28881 [00:00<?, ?it/s][A
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28881/28881 [00:00<00:00, 240154.66it/s][A


Extracting pixel/MNIST/raw/train-labels-idx1-ubyte.gz to pixel/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to pixel/MNIST/raw/t10k-images-idx3-ubyte.gz



  0%|                                                                                                                                                                                          | 0/1648877 [00:00<?, ?it/s][A
  2%|███▍                                                                                                                                                                      | 32768/1648877 [00:00<00:08, 193498.32it/s][A
  6%|██████████▏                                                                                                                                                               | 98304/1648877 [00:00<00:07, 199538.25it/s][A
 12%|████████████████████▏                                                                                                                                                    | 196608/1648877 [00:00<00:04, 313899.15it/s][A
 16%|██████████████████████████▊                                                                           

Extracting pixel/MNIST/raw/t10k-images-idx3-ubyte.gz to pixel/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to pixel/MNIST/raw/t10k-labels-idx1-ubyte.gz



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4542/4542 [00:00<00:00, 320834.80it/s][A

  | Name      | Type             | Params
-----------------------------------------------
0 | criterion | CrossEntropyLoss | 0     
1 | net       | MLP              | 1.2 M 
-----------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.955     Total estimated model params size (MB)


Extracting pixel/MNIST/raw/t10k-labels-idx1-ubyte.gz to pixel/MNIST/raw

Epoch 5: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1719/1719 [00:34<00:00, 49.64it/s, v_num=46, train_loss=0.0136]

`Trainer.fit` stopped: `max_epochs=6` reached.


Epoch 5: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1719/1719 [00:34<00:00, 49.60it/s, v_num=46, train_loss=0.0136]


In [93]:
trainer.test(model=lit_model, dataloaders=DataLoader(test_set))

Testing DataLoader 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:28<00:00, 355.74it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
           acc              0.9682000279426575
        test_loss           0.34421759843826294
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.34421759843826294, 'acc': 0.9682000279426575}]