In [13]:
import os
from tqdm import tqdm_notebook as tqdm
from typing import Tuple, Dict, List, Any

import lightning as L
import numpy as np
import torch
from torch import nn
from torch.optim import lr_scheduler, Optimizer
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchmetrics.functional.classification import multiclass_accuracy

In [None]:
class MLP_Block(nn.Module):
    """Building block for MLP-based models."""

    def __init__(
        self, hidden_size: int, activation: nn.Module, depth: int
    ) -> None:
        """Initialization of the MLP block.

        Args:
            hidden_size: Number of neurons in the linear layer.
            activation: Activation function.
            depth: Number of MLP blocks (linear layer with activation).
        """
        super(MLP_Block, self).__init__()
        layers = []
        for _ in range(depth):
            linear = nn.Linear(hidden_size, hidden_size)
            layers.append(linear)
            layers.append(activation)
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Propagates the input through the MLP block.

        Args:
            x: Input.

        Returns:
            Output of the network.
        """
        return self.layers(x)


class MLP(nn.Module):
    def __init__(
        self,
        input_shape: Tuple[int],
        output_shape: Tuple[int],
        hidden_factor: int = 1,
        depth: int = 1,
    ) -> None:
        """Initialization of the multi-layer perceptron.

        Args:
            input_shape: Shape of the input.
            output_shape: Shape of the output.
            hidden_factor: Factor for multiplying with input length to
                determine the number of neurons in each hidden layer.
                Defaults to 1.
            depth: Number of hidden layers. Defaults to 1.
        """
        super().__init__()
        self.input_shape = input_shape
        self.output_shape = output_shape
        input_len = int(np.prod(input_shape))
        output_len = int(np.prod(output_shape))
        hidden_size = int(input_len * hidden_factor)

        self.layers = nn.ModuleList(
            [   
                nn.Flatten(),
                nn.Linear(input_len, hidden_size),  # Input layer
                MLP_Block(hidden_size, nn.ReLU(), depth),
                nn.Linear(hidden_size, output_len),  # Output layer
            ]
        )

        self.layers = nn.Sequential(*self.layers)
    
    def forward(self, x):
        """Propagates the input through the MLP block.

        Args:
            x: Input.

        Returns:
            Output of the network.
        """
        return self.layers(x)

In [None]:
class LitClassificationModel(L.LightningModule):

    def __init__(
        self,
        net: str,
        lr: float,
        num_classes: int,
        criterion,
        optimizer_class,
        step_size,
        scheduler_class,
        ) -> None:
        """Initialization of the custom Lightning Module.

        Args:
            model: Neural network model name.
            config: Neural network model and training config.
        """
        super().__init__()
        self.lr = lr
        self.num_classes = num_classes
        self.criterion = criterion
        self.optimizer_class = optimizer_class
        self.step_size = step_size
        self.scheduler_class = scheduler_class
        self.net = net

    def configure_optimizers(
        self,
    ) -> Tuple[Optimizer, lr_scheduler.LRScheduler]:
        """Configures the optimizer and scheduler based on the learning rate
            and step size.

        Returns:
            Configured optimizer and scheduler.
        """
        optimizer = self.optimizer_class(self.parameters(), lr=self.lr)
        scheduler = self.scheduler_class(optimizer, self.step_size)
        return [optimizer], [scheduler]

    def infer_batch(
        self, batch: Dict[str, dict]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Propagate given batch through the Lightning Module.

        Args:
            batch: Batch containing the subjects.

        Returns:
            Model output and corresponding ground truth.
        """
        x, y = batch
        y_hat = self.net(x)
        return y_hat, y

    def training_step(self, batch: Dict[str, dict], batch_idx: int) -> float:
        """Infer batch on training data, log metrics and retrieve loss.

        Args:
            batch: Batch containing the subjects.
            batch_idx: Number displaying index of this batch.

        Returns:
            Calculated loss.
        """
        y_hat, y = self.infer_batch(batch)

        # Calculate loss
        loss = self.criterion(y_hat, y)

        self.log('train_loss', loss, prog_bar=True)
        return loss
    
    def test_step(self, batch, batch_idx):
        # this is the test loop
        y_hat, y = self.infer_batch(batch)
        loss = self.criterion(y_hat, y)
        acc =  multiclass_accuracy(y_hat, y, num_classes=self.num_classes)
        self.log('test_loss', loss, prog_bar=True)
        self.log('acc', acc, prog_bar=True)


In [None]:
class SimpleFreqSpace(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, img):
        return torch.fft.rfft2(img)


class SimpleComplex2Vec(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        n, m = x.shape[-2], x.shape[-1]
        return torch.cat(
            [
                torch.stack(
                [
                    torch.cat(
                        [
                            x[:, : n // 2 + 1, 0:1].real,
                            x[:, 1 : (n + 1) // 2, 0:1].imag
                        ], 1),
                    torch.cat(
                        [
                            x[:, : n // 2 + 1, m - 1 : m].real,
                            x[:, 1 : (n + 1) // 2, m - 1 : m].imag
                        ], 1)],
                dim=3),
              torch.view_as_real(x[:, :, 1:-1])],
            dim=2)

class BaseDataModule(L.LightningDataModule):
    
    def __init__(self, domain: str):
        super().__init__()
        self.domain = domain

        if self.domain == 'freq':
            self.domain_transform = transforms.Compose([SimpleFreqSpace(), SimpleComplex2Vec()])
        else:
            self.domain_transform = torch.nn.Identity()
        

    def train_dataloader(self) -> torch.utils.data.DataLoader:
        """Creates Dataloader for training phase.

        Returns:
            Dataloader for training phase.
        """
        return torch.utils.data.DataLoader(
            self.train_set, self.batch_size
        )

    def val_dataloader(self) -> torch.utils.data.DataLoader:
        """Creates Dataloader for validation phase.

        Returns:
            Dataloader for validation phase.
        """
        return torch.utils.data.DataLoader(
            self.val_set, self.batch_size
        )


class ImageNetDataModule(BaseDataModule):
    def __init__(self, data_dir: str, input_domain: str, batch_size: int = 32) -> None:
        super().__init__()
        self.data_dir = data_dir
        self.input_domain = input_domain
        self.batch_size = batch_size

    def setup(self):
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        traindir = os.path.join(self.data_dir, 'train')
        valdir = os.path.join(self.data_dir, 'val')

        

        self.train_set = ImageFolder(
            traindir,
            transforms.Compose(
                [
                    transforms.RandomResizedCrop(224),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize,
                    *domain_transfrom,
                ]
            ),
        )

        self.val_set = ImageFolder(
            valdir,
            transforms.Compose(
                [
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    normalize,
                    *domain_transfrom,
                ]
            ),
        )


class MNISTDataModule(BaseDataModule):
    
    def __init__(self, domain: str, batch_size: int = 32) -> None:
        super().__init__(domain=domain)
        self.batch_size = batch_size
    
    def prepare_data(self):
        # download
        datasets.MNIST(root='MNIST', download=True, train=True)
        datasets.MNIST(root='MNIST', download=True, train=False)

    def setup(self, stage: str):
        tensor_transform = transforms.ToTensor()

        self.test_set = datasets.MNIST(
            root='MNIST', download=True, train=False,
            transform= transforms.Compose([tensor_transform, self.domain_transform]))

        data_set = datasets.MNIST(
            root='MNIST', download=True, train=True,
            transform= transforms.Compose([tensor_transform, self.domain_transform]))
        
        # use 20% of training data for validation
        train_set_size = int(len(data_set) * 0.8)
        valid_set_size = len(data_set) - train_set_size
        
        self.train_set, self.val_set = torch.utils.data.random_split(
            data_set, [train_set_size, valid_set_size], generator=seed)

def CFAR10DataModule(BaseDataModule):

    def __init__(self, domain: str, batch_szie: int = 32) -> None:
        super().__init__(domain=domain)
        self.batch_size = batch_size

    def prepare_data(self):
        # download
        datasets.CIFAR10(root='CIFAR10', download=True, train=True)
        datasets.CIFAR10(root='CIFAR10', download=True, train=False)



In [8]:
from collections import  defaultdict
class FindSumPairs:

    def __init__(self, nums1: List[int], nums2: List[int]):
        self.dict1 = defaultdict(int)
        for num in nums1:
            self.dict1[num] += 1
        self.dict2 = defaultdict(int)
        for num in nums2:
            self.dict2[num] += 1
        self.nums2 = nums2

    def add(self, index: int, val: int) -> None:
        print(self.dict2)
        self.dict2[self.nums2[index] + val] += 1
        self.dict2[self.nums2[index]] -= 1 

    def count(self, tot: int) -> int:
        print(self.dict2)
        ret = 0
        for num, count in self.dict1.items():
            ret += (count * self.dict2[tot - num])
        return ret

In [9]:
transform = transforms.ToTensor()

train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)

# use 20% of training data for validation
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size

# split the train set into two
seed = torch.Generator().manual_seed(42)
train_set, valid_set = torch.utils.data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)

test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)
train_loader = DataLoader(train_set, batch_size=200)

In [10]:
net = MLP([28, 28], [10])
criterion = torch.nn.CrossEntropyLoss()
lr = 0.00001
num_classes = 10
step_size = 300
scheduler_class = torch.optim.lr_scheduler.StepLR
optimizer_class = torch.optim.Adam
lit_model = LitClassificationModel(
    net, lr, num_classes, criterion,
    optimizer_class, step_size, scheduler_class
)
datamodule = MNISTDataModule('freq')

NameError: name 'MLP' is not defined

In [11]:
trainer = L.Trainer(max_epochs=10)
trainer.fit(model=lit_model, datamodule=datamodule)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/reghbai7/repos/domain-exploration/venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


NameError: name 'lit_model' is not defined

In [9]:
trainer.test(model=lit_model, dataloaders=DataLoader(test_set))

/Users/reghbai7/repos/pyramidal/venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing: |                                                                                                    …

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
           acc              0.0949999988079071
        test_loss           2.3090193271636963
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 2.3090193271636963, 'acc': 0.0949999988079071}]

In [11]:
for image, label in datasets.MNIST(root='MNIST', download=True, train=True):
    transforms.Compose([transforms.ToTensor(), SimpleFreqSpace(), SimpleComplex2Vec()])(image)

In [12]:
transforms.ToTensor()(image).shape

torch.Size([1, 28, 28])

In [41]:
n = 28
m = 28
x = transforms.Compose([SimpleFreqSpace()])(torch.rand([1, n, m]))
x[:, 0, :]

tensor([[386.2591+0.0000j,   3.3026+4.0290j,  -1.0987-14.9681j,
           3.3526+11.1068j,   2.9130-7.9891j,   4.3957-10.9407j,
          -7.0289+0.7846j,  -0.5100-0.4097j,  -2.1512+9.7052j,
          -3.0485-2.7976j,   5.1947-1.9727j,   3.0755+6.0210j,
           2.3637+4.5284j,  -0.9809+11.7598j,   3.5509+0.0000j]])

In [52]:
.shape

torch.Size([1, 28, 14, 2])

In [43]:
 torch.view_as_real(x[:, :, 1:-1]).shape


torch.Size([1, 28, 13, 2])

In [39]:
r[:, : n // 2 + 1, 0:1].imag

tensor([[[ 0.0000],
         [ 7.9913],
         [-5.1613],
         [ 0.2433],
         [-0.7990],
         [ 4.1034],
         [ 4.6483],
         [ 1.1475],
         [-7.2214],
         [ 1.2170],
         [ 3.7070],
         [ 4.6968],
         [-6.3063],
         [-8.0839],
         [ 0.0000]]])