In [3]:
import sys
import os

pwd = os.getcwd()
python_path = pwd[: pwd.rfind("/")]
sys.path.append(python_path)

In [4]:
import torch
import lightning
import jupyter_black
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from data import GenericDataset
from typing import List, Any
from lightning.pytorch.loggers import TensorBoardLogger
from progress.bar import ChargingBar

jupyter_black.load()

In [5]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.empty_cache()
torch.set_float32_matmul_precision("medium")

In [6]:
def get_dataloder(size: int, batch_size: int):
    substrates_dataset = GenericDataset(
        dir_path="/srv/data/raw_support_data",
        transform=transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Grayscale(),
                transforms.ToTensor(),
                transforms.Normalize(mean=0, std=1),
                transforms.CenterCrop(size=(800, 1280)),
                transforms.Resize((size, size)),
            ]
        ),
    )
    return DataLoader(dataset=substrates_dataset, batch_size=batch_size)

In [7]:
class DataModule(lightning.LightningDataModule):
    def __init__(
        self,
        size: int,
        batch_size: int = 128,
        num_workers: int = 0,
        shuffle: bool = False,
    ) -> None:
        super().__init__()
        self.save_hyperparameters(ignore=["size"])
        self.transform = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Grayscale(),
                transforms.ToTensor(),
                transforms.Normalize(mean=0, std=1),
                transforms.CenterCrop(size=(800, 1280)),
                transforms.Resize((size, size)),
            ]
        )

    def setup(self, stage: str) -> None:
        if stage == "fit" or stage is None:
            self.train = GenericDataset(
                dir_path="/srv/data/raw_support_data",
                transform=self.transform,
            )

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train, **self.hparams)

In [8]:
class Pixel_norm(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, a: torch.Tensor):
        b = a / torch.sqrt(torch.sum(a**2, dim=1, keepdim=True) + 10e-8)
        return b


class Generator(nn.Module):
    def __init__(
        self,
    ) -> None:
        super().__init__()
        self.in_channels = 512
        self.out_channels = 512
        self.model = nn.Sequential(*self._initial_block())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def grow(self, size: int) -> None:
        self._decrease_channels(size)
        self.model.extend(self._block(self.in_channels, self.out_channels))
        if size == 512:
            self.model.append(nn.Conv2d(16, 1, kernel_size=1))
        self.model = self.model.to(device)

    def _decrease_channels(self, size: int):
        if size > 16:
            self.out_channels //= 2
        if size > 32:
            self.in_channels //= 2

    def _initial_block(self) -> List[Any]:
        return [
            *self._conv(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=4,
                padding=3,
            ),
            *self._conv(
                in_channels=self.out_channels,
                out_channels=self.out_channels,
                kernel_size=3,
            ),
        ]

    def _block(self, in_channels: int, out_channels: int) -> List[Any]:
        return [
            nn.Upsample(scale_factor=2, mode="nearest"),
            *self._conv(in_channels, out_channels),
            *self._conv(out_channels, out_channels),
        ]

    def _conv(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        padding: int = 1,
    ) -> List[Any]:
        return [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=padding,
            ),
            nn.LeakyReLU(0.2),
            Pixel_norm(),
        ]

In [9]:
class Discriminator(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.in_channels = 512
        self.out_channels = 512
        self.model = nn.Sequential(*self._initial_block())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def grow(self, size: int) -> None:
        self._decrease_channels(size)
        self.model = self._block(self.in_channels, self.out_channels) + self.model
        if size == 512:
            self.model = self._end_block() + self.model
        self.model = self.model.to(device)

    def _decrease_channels(self, size: int):
        if size > 16:
            self.in_channels //= 2
        if size > 32:
            self.out_channels //= 2

    def _initial_block(self) -> List[Any]:
        return [
            *self._conv(self.in_channels, self.out_channels, kernel_size=3),
            *self._conv(self.out_channels, self.out_channels, kernel_size=4, stride=3),
            nn.Flatten(),
            nn.Linear(512, 1),
        ]

    def _block(self, in_channels: int, out_channels: int):
        return nn.Sequential(
            *self._conv(in_channels, out_channels),
            *self._conv(out_channels, out_channels),
            nn.AvgPool2d(kernel_size=2),
        )

    def _end_block(self):
        return nn.Sequential(
            nn.Sequential(nn.Conv2d(1, 16, kernel_size=1)), nn.LeakyReLU(0.2)
        )

    def _conv(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        padding: int = 1,
        stride: int = 1,
    ) -> List[Any]:
        return [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=padding,
                stride=stride,
            ),
            nn.LeakyReLU(0.2),
        ]

In [10]:
def repeat_in_channels(x: torch.Tensor, channels: int) -> torch.Tensor:
    temp = np.array([x.cpu().numpy() for _ in range(channels)])
    shape = [*temp.shape]
    shape = (shape[1], channels, *shape[-2:])
    return torch.tensor(temp.reshape(shape)).to(device)


class PGGAN(nn.Module):
    def __init__(
        self,
    ) -> None:
        super().__init__()
        self.automatic_optimization = False

        self.D = Discriminator()
        self.G = Generator()
        self.size = 4

    def configure_optimizers(self, D_lr: float, G_lr: float):
        D_optim = optim.Adam(self.D.parameters(), lr=D_lr, betas=(0.5, 0.99))
        G_optim = optim.Adam(self.G.parameters(), lr=G_lr, betas=(0.5, 0.99))
        return [D_optim, G_optim]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.G(x)

    def grow(self, size: int) -> None:
        self.G.grow(size)
        self.D.grow(size)
        self.size = size

    def training_step(
        self,
        batch: torch.Tensor,
        batch_idx: int,
        D_optim_frequency: int = 1,
        G_optim_frequency: int = 1
    ):
        x = repeat_in_channels(
            batch, self.G.out_channels if self.size < 1024 else 1
        ).to(device)
        D_optim, G_optim = self.optimizers()  # type: ignore

        if batch_idx % D_optim_frequency == 0:
            z = torch.rand((len(batch), 512, 1, 1)).to(device)
            loss = -(torch.mean(self.D(x)) - torch.mean(self.D(self.G(z))))
            self.on_training_step_end(D_optim, loss)

        if batch_idx % G_optim_frequency == 0:
            z = torch.rand((len(batch), 512, 1, 1)).to(device)
            loss = -torch.mean(self.D(self.G(z)))
            self.on_training_step_end(G_optim, loss)

    def on_training_step_end(self, optim, loss: torch.Tensor) -> None:
        optim.zero_grad()
        loss.backward()
        optim.step()

In [12]:
class HParams:
    def __init__(self, size: int, batch_size: int, epochs: int):
        self.size = size
        self.batch_size = batch_size
        self.epochs = epochs


schedule = [
    HParams(size=4, batch_size=16, epochs=1),
    HParams(size=8, batch_size=16, epochs=1),
    HParams(size=16, batch_size=16, epochs=1),
    HParams(size=32, batch_size=16, epochs=1),
    HParams(size=64, batch_size=8, epochs=1),
    HParams(size=128, batch_size=4, epochs=1),
    HParams(size=256, batch_size=2, epochs=1),
    HParams(size=512, batch_size=1, epochs=1),
    HParams(size=1024, batch_size=1, epochs=1),
]

logger = TensorBoardLogger(save_dir="../../logs/bacterias_pggan")
pggan = PGGAN()
pggan.configure_optimizers(D_lr=0.0001, G_lr=0.002)

for hparams in schedule:
    dataloader = get_dataloder(hparams.size, batch_size=hparams.batch_size)
    for i, batch in enumerate(dataloader):
        pggan.training_step(batch, i)
    pggan.grow(hparams.size)



RuntimeError: PGGAN is not attached to a `Trainer`.

In [None]:
BATCH_SIZE = 1
SIZES = [4, 8, 16, 32, 64, 128, 256, 512, 1024]
EPOCHS = 50
LR, BETAS = 0.001, (0.0, 0.99)

G = Generator().to(device)
D = Discriminator().to(device)

G_optim = optim.Adam(G.parameters(), lr=LR, betas=BETAS)
D_optim = optim.Adam(D.parameters(), lr=LR, betas=BETAS)

tbl = TensorBoardLogger(save_dir="../../logs/substrates_pggan")
tbl.log_hyperparams(
    {
        "epochs": EPOCHS,
        "optimizer": "Adam",
        "lr": LR,
        "b1": BETAS[0],
        "b2": BETAS[1],
        "batch_size": BATCH_SIZE,
    }
)

for size in SIZES:
    dataloader = get_dataloder(size, BATCH_SIZE)
    for epoch in range(EPOCHS):
        for _, batch in enumerate(dataloader):
            z = torch.rand((len(batch), 512, 1, 1)).to(device)
            x = repeat_in_channels(batch, G.out_channels if size < 1024 else 1).to(
                device
            )

            D_loss = -(torch.mean(D(x)) - torch.mean(D(G(z))))
            on_train_step_end(D_optim, D_loss)
            G_loss = -torch.mean(D(G(z)))
            on_train_step_end(G_optim, G_loss)

    G.grow(size)
    D.grow(size)



KeyboardInterrupt: 