In [11]:
import torch
import lightning
from torch import nn
from torch.utils.data import DataLoader

class PeriodicReLU(nn.Module):
    """Element-wise activation with a different periodic form on each side of zero.

    Strictly positive inputs are mapped to ``x * cos(x)``; every other input
    (zero and negative values) is mapped to ``sin(x)``.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation element-wise and return a tensor shaped like ``x``."""
        is_positive = x > 0
        positive_branch = torch.cos(x) * x
        non_positive_branch = torch.sin(x)
        # Both branches are evaluated for every element; the mask only selects
        # which of the two results is kept.
        return torch.where(is_positive, positive_branch, non_positive_branch)
        

class Model(lightning.LightningModule):
    """Small fully connected regressor trained with a Smooth L1 loss.

    The network maps inputs with 100 features to a single value that the final
    sigmoid squashes into the range (0, 1).
    """

    def __init__(self) -> None:
        super().__init__()
        self.loss = nn.SmoothL1Loss(reduction="mean")
        # The stack ends at the sigmoid: its output is never negative, so a
        # ReLU placed after it would be an identity operation.
        self.net = nn.Sequential(
            nn.Linear(100, 784),
            PeriodicReLU(),
            nn.Linear(784, 16),
            PeriodicReLU(),
            nn.Linear(16, 1),
            nn.Sigmoid(),
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Run the network.

        Args:
            inputs: Tensor whose last dimension has size 100.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            size 1, with values in (0, 1).
        """
        return self.net(inputs)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Return an Adam optimizer over all parameters with learning rate 0.1."""
        return torch.optim.Adam(self.parameters(), lr=0.1)

    def _log_metrics(self, loss: torch.Tensor, step: str) -> None:
        """Log ``loss`` as ``"{step}_loss"``, aggregated per epoch.

        The value is shown in the progress bar only (``logger=False``).
        """
        self.log(f"{step}_loss", loss, prog_bar=True, logger=False, on_epoch=True, on_step=False)

    def training_step(self, batch: tuple[torch.Tensor, ...], batch_idx: int) -> torch.Tensor:
        """Compute and log the loss for one training batch."""
        return self._forward(batch, step="train")

    def validation_step(self, batch: tuple[torch.Tensor, ...], batch_idx: int) -> torch.Tensor:
        """Compute and log the loss for one validation batch."""
        return self._forward(batch, step="valid")

    def test_step(self, batch: tuple[torch.Tensor, ...], batch_idx: int) -> torch.Tensor:
        """Compute and log the loss for one test batch."""
        return self._forward(batch, step="test")

    def _forward(self, batch: tuple[torch.Tensor, ...], step: str) -> torch.Tensor:
        """Shared step logic: predict, compute the loss, log it and return it.

        Args:
            batch: An ``(inputs, targets)`` pair.
            step: Prefix for the logged metric name (``"train"``, ``"valid"``
                or ``"test"``).

        Returns:
            The loss tensor for this batch.
        """
        x, y = batch
        # Call the module itself rather than ``forward`` so that any
        # registered module hooks are run as well.
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        self._log_metrics(loss=loss, step=step)
        return loss


class TestDataset(torch.utils.data.Dataset):
    """Synthetic map-style dataset of 10 random ``(features, target)`` pairs.

    Samples are drawn fresh from a standard normal distribution on every
    access, so reading the same index twice yields different tensors unless
    the torch random seed is reset in between.
    """

    def __init__(self):
        super().__init__()

    def __len__(self):
        """Return the fixed number of samples (10)."""
        return 10

    def __getitem__(self, batch_idx):
        """Return a random sample for position ``batch_idx``.

        Args:
            batch_idx: Sample position. The value does not influence the data
                that is returned.

        Returns:
            A ``(features, target)`` tuple of float tensors with shapes
            ``(100,)`` and ``(1,)``.

        Raises:
            IndexError: If ``batch_idx`` is an integer outside
                ``[-len(self), len(self))``.
        """
        size = len(self)
        # Without an IndexError, sequence-style iteration (``for s in ds`` or
        # ``list(ds)``) would never terminate. Only plain integers are
        # checked so that any other index type behaves exactly as before.
        if isinstance(batch_idx, int) and not -size <= batch_idx < size:
            raise IndexError(
                f"index {batch_idx} is out of range for dataset of size {size}"
            )
        return torch.randn(100), torch.randn((1,))


In [12]:
# Serve the synthetic samples one at a time, in order, with no worker
# subprocesses (num_workers=0).
dataloader = DataLoader(
    TestDataset(),
    batch_size=1,
    shuffle=False,
    num_workers=0,
    persistent_workers=False,
    pin_memory=True,
)

model = Model()

trainer: lightning.Trainer = lightning.Trainer(
    # Hardware and numerics.
    accelerator="gpu",
    num_nodes=1,
    precision=16,
    deterministic="warn",
    benchmark=True,
    # Run length and validation cadence.
    min_epochs=1,
    max_epochs=100,
    fast_dev_run=False,
    overfit_batches=0,
    check_val_every_n_epoch=1,
    # Logging, checkpointing and console reporting.
    log_every_n_steps=100,
    enable_checkpointing=True,
    enable_progress_bar=True,
    enable_model_summary=True,
)

# Only a training dataloader is passed to fit.
trainer.fit(model, train_dataloaders=dataloader)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type         | Params
--------------------------------------
0 | loss | SmoothL1Loss | 0     
1 | net  | Sequential   | 91.8 K
--------------------------------------
91.8 K    Trainable params
0         Non-trainable params
91.8 K    Total params
0.367     Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.
