In [1]:
import os
import torch
import shutil
import zipfile
import itertools
import torchvision
from PIL import Image
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset


In [2]:
# Data
def load_data():
    """
    Fetch and unpack the summer2winter Yosemite archive under ``data/``.

    The archive is downloaded with ``wget`` only when it is not already on
    disk, and it is extracted only when the dataset directory is missing, so
    repeated calls are cheap no-ops.

    Raises:
        RuntimeError: if the ``wget`` command exits with a non-zero status.
    """
    data_root = "data"
    archive_path = "data/summer2winter_yosemite.zip"
    # The archive is expected to unpack into this directory (it is the path the
    # training cell passes as ``data_dir``).
    dataset_dir = "data/summer2winter_yosemite"

    # Download data if not present
    os.makedirs(data_root, exist_ok=True)
    if not os.path.exists(archive_path):
        print("Downloading data...")
        status = os.system(
            "wget http://efrosgans.eecs.berkeley.edu/cyclegan/datasets/summer2winter_yosemite.zip --directory-prefix=data"
        )
        if status != 0:
            # Drop a partial download so the next call starts from scratch
            # instead of trying to unzip a truncated file.
            if os.path.exists(archive_path):
                os.remove(archive_path)
            raise RuntimeError(
                f"Downloading summer2winter_yosemite.zip failed (wget exit status {status})"
            )

    # Decide on the extracted directory itself rather than on how many entries
    # ``data/`` contains: any stray file there used to suppress extraction.
    if not os.path.isdir(dataset_dir):
        print("Extracting data...")
        with zipfile.ZipFile(archive_path, "r") as zip_ref:
            zip_ref.extractall(data_root)


class UnpairedImageDataset(Dataset):
    """
    Dataset yielding one image from domain A and one from domain B per index.

    Files in each directory are sorted by path and the i-th file of A is
    returned together with the i-th file of B. Every image is resized to
    256x256, converted to a tensor and normalised to the range [-1, 1].

    Args:
        root_A_dir: Directory holding the domain A images.
        root_B_dir: Directory holding the domain B images.
        max_samples: Optional cap on the dataset length, handy for quick smoke
            runs. ``None`` (default) uses every available pair.

    Raises:
        ValueError: if ``max_samples`` is negative.
    """

    def __init__(self, root_A_dir, root_B_dir, max_samples=None):
        if max_samples is not None and max_samples < 0:
            raise ValueError(f"max_samples must be >= 0, got {max_samples}")
        self.transform = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )
        self.files_A = sorted(
            os.path.join(root_A_dir, x) for x in os.listdir(root_A_dir)
        )
        self.files_B = sorted(
            os.path.join(root_B_dir, x) for x in os.listdir(root_B_dir)
        )
        self.max_samples = max_samples

    def __len__(self):
        # Both domains are indexed with the same position, so the shorter
        # listing bounds the number of usable pairs.
        n_pairs = min(len(self.files_A), len(self.files_B))
        if self.max_samples is None:
            return n_pairs
        return min(n_pairs, self.max_samples)

    def _load(self, path):
        """Read one image file and return it as a normalised 3-channel tensor."""
        # The context manager releases the file handle. ``convert("RGB")``
        # guarantees three channels, which is what ``Normalize`` above is
        # configured for (grayscale / RGBA / palette files would otherwise fail).
        with Image.open(path) as img:
            return self.transform(img.convert("RGB"))

    def __getitem__(self, index):
        img_A = self._load(self.files_A[index])
        img_B = self._load(self.files_B[index])
        return img_A, img_B


In [3]:
# Generator, adopts the UNET architecture
class Generator(nn.Module):
    """
    U-Net style image-to-image generator.

    Pipeline: a 7x7 stem convolution, ``n_down_blocks`` downsampling blocks
    (each doubles the channel count), ``n_res_blocks`` residual blocks at the
    bottleneck, ``n_down_blocks`` upsampling blocks fed with skip connections,
    and a final 3x3 convolution followed by ``tanh``.

    Args:
        emb_d: Number of channels produced by the stem convolution.
        n_down_blocks: Number of down blocks (and of matching up blocks).
        n_res_blocks: Number of residual blocks at the bottleneck.
    """

    def __init__(self, emb_d: int, n_down_blocks: int, n_res_blocks: int) -> None:
        super().__init__()
        # Channel width per level, shallowest first: emb_d, 2*emb_d, 4*emb_d, ...
        widths = [emb_d]
        for _ in range(n_down_blocks):
            widths.append(widths[-1] * 2)

        stem_layers = [
            nn.Conv2d(3, emb_d, kernel_size=7, stride=1, padding=3),
            nn.BatchNorm2d(emb_d),
            nn.ReLU(),
            nn.Dropout(),
        ]
        self.in_conv = nn.Sequential(*stem_layers)

        self.down_blocks = nn.ModuleList(
            [DownBlock(shallow, deep) for shallow, deep in zip(widths, widths[1:])]
        )

        bottleneck = [ResBlock(widths[-1]) for _ in range(n_res_blocks)]
        self.res_blocks = nn.Sequential(*bottleneck)

        # Every up block receives its input concatenated with a skip tensor of
        # the same width (hence ``width * 2`` in) and halves the width on output.
        self.up_blocks = nn.ModuleList(
            [UpBlock(width * 2, width // 2) for width in reversed(widths[1:])]
        )

        head = nn.Conv2d(widths[0], 3, kernel_size=3, stride=1, padding=1)
        self.out_conv = nn.Sequential(head, nn.Tanh())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder, bottleneck and skip-connected decoder on ``x``."""
        features = self.in_conv(x)

        # Down: remember the output of every down block for the skip connections
        skips = []
        for down_block in self.down_blocks:
            features = down_block(features)
            skips.append(features)

        # Res
        features = self.res_blocks(features)

        # Up: the deepest skip is consumed first
        for up_block, skip in zip(self.up_blocks, reversed(skips)):
            features = up_block(torch.cat([features, skip], dim=1))

        return self.out_conv(features)


class DownBlock(nn.Module):
    """
    Downsampling stage: stride-2 3x3 convolution, instance norm, LeakyReLU and
    dropout. Even spatial sizes are halved.

    Args:
        in_d: Number of input channels.
        out_d: Number of output channels.
    """

    def __init__(self, in_d, out_d) -> None:
        super().__init__()
        stages = [
            nn.Conv2d(in_d, out_d, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(out_d),
            nn.LeakyReLU(),
            nn.Dropout(),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the downsampling stage to ``x``."""
        downsampled = self.conv(x)
        return downsampled


class ResBlock(nn.Module):
    """
    Residual block: two identical conv stages whose output is added back onto
    the block input. Channel count and spatial size are preserved.

    Args:
        in_d: Number of input (and output) channels.
    """

    def __init__(self, in_d) -> None:
        super().__init__()
        self.conv_1 = self._conv_stage(in_d)
        self.conv_2 = self._conv_stage(in_d)

    @staticmethod
    def _conv_stage(channels):
        """Build one 3x3 conv -> batch norm -> ReLU -> dropout stage."""
        return nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Dropout(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``conv_2(conv_1(x)) + x``."""
        branch = self.conv_2(self.conv_1(x))
        return branch + x


class UpBlock(nn.Module):
    """
    Upsampling stage: stride-2 transposed 3x3 convolution (doubles the spatial
    size), instance norm, ReLU and dropout.

    Args:
        in_d: Number of input channels.
        out_d: Number of output channels.
    """

    def __init__(self, in_d, out_d) -> None:
        super().__init__()
        upsample = nn.ConvTranspose2d(
            in_d, out_d, kernel_size=3, stride=2, padding=1, output_padding=1
        )
        self.conv = nn.Sequential(
            upsample,
            nn.InstanceNorm2d(out_d),
            nn.ReLU(),
            nn.Dropout(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the upsampling stage to ``x``."""
        upsampled = self.conv(x)
        return upsampled


In [4]:
# Discriminator, adopts the PatchGAN architecture
class Discriminator(nn.Module):
    """
    Patch-based discriminator: a stack of stride-2 convolutions followed by a
    single-channel convolution, returning one score per output location.

    Args:
        emb_d: Number of channels produced by the first convolution.
        n_layers: Number of additional stride-2 stages after the first
            convolution; each doubles the channel count. Defaults to 4, the
            previously hard-coded depth.
    """

    def __init__(self, emb_d: int = 128, n_layers: int = 4) -> None:
        super().__init__()
        # NOTE(review): the stem uses a plain ReLU while every later stage uses
        # LeakyReLU(0.2) - confirm whether the difference is intentional.
        self.in_conv = nn.Sequential(
            nn.Conv2d(3, emb_d, 4, 2, 1),
            nn.ReLU(),
        )
        self.layers = nn.ModuleList()
        for _ in range(n_layers):
            self.layers.append(
                nn.Sequential(
                    nn.Conv2d(emb_d, emb_d * 2, 4, 2, 1),
                    nn.InstanceNorm2d(emb_d * 2),
                    nn.LeakyReLU(0.2, inplace=True),
                )
            )
            emb_d *= 2
        self.out_conv = nn.Sequential(nn.Conv2d(emb_d, 1, 3, 1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Score ``x`` and return a ``(batch, n_patches)`` tensor of raw,
        unbounded values (no sigmoid is applied).
        """
        x = self.in_conv(x)
        for layer in self.layers:
            x = layer(x)
        x = self.out_conv(x)
        # Collapse the (1, H', W') score map into one vector per sample.
        x = torch.flatten(x, 1, -1)
        return x


In [5]:
class Trainer:
    """
    Training harness for a CycleGAN-style setup: two generators (A->B and
    B->A) and one discriminator per image domain.

    Each training step first updates both generators on a cycle
    reconstruction (L1) plus adversarial (MSE against ones) objective, then
    updates both discriminators on real-vs-generated MSE objectives. Sample
    images are written to ``results/logs`` and checkpoints to ``results``,
    both relative to the current working directory.
    """

    # Output locations, relative to the current working directory.
    RESULTS_DIR = "results"
    LOG_DIR = "results/logs"

    def __init__(
        self,
        n_ckpt_steps: int,
        n_log_steps: int,
        epochs: int,
        # Data parameters
        data_dir: str,
        batch_size: int,
        # Optimiser parameters
        lr: float,
        beta_1: float,
        beta_2: float,
        # Model parameters
        emb_d: int,
        n_down_blocks: int,
        n_res_blocks: int,
        flush_prev_logs: bool = True,
    ):
        """
        n_ckpt_steps: Saves a checkpoint every n_ckpt_steps
        n_log_steps: Writes sample images every n_log_steps
        epochs: Number of epochs to train for
        data_dir: Directory containing the data, must contain trainA, trainB,
            testA and testB folders
        batch_size: Batch size
        lr: Learning rate
        beta_1: Beta 1 for Adam optimiser
        beta_2: Beta 2 for Adam optimiser
        emb_d: Channel width of the generators' first convolution
        n_down_blocks: Number of down (and up) blocks in each generator
        n_res_blocks: Number of residual blocks in each generator
        flush_prev_logs: Delete results/logs from a previous run before starting
        """
        if flush_prev_logs:
            shutil.rmtree(self.LOG_DIR, ignore_errors=True)
        os.makedirs(self.LOG_DIR, exist_ok=True)
        torch.cuda.empty_cache()
        # Converts original image to stylised image
        self.generator_A2B = Generator(emb_d, n_down_blocks, n_res_blocks)
        # Converts stylised image back to original image
        self.generator_B2A = Generator(emb_d, n_down_blocks, n_res_blocks)
        # Discriminator for original image (built with its default width)
        self.discriminator_A = Discriminator()
        # Discriminator for stylised image (built with its default width)
        self.discriminator_B = Discriminator()
        # Optimisers: one shared by both generators, one by both discriminators
        self.optim_G = torch.optim.Adam(
            itertools.chain(
                self.generator_A2B.parameters(), self.generator_B2A.parameters()
            ),
            lr=lr,
            betas=(beta_1, beta_2),
        )
        self.optim_D = torch.optim.Adam(
            itertools.chain(
                self.discriminator_A.parameters(), self.discriminator_B.parameters()
            ),
            lr=lr,
            betas=(beta_1, beta_2),
        )
        # Dataloaders
        for split in ("trainA", "trainB", "testA", "testB"):
            split_dir = f"{data_dir}/{split}"
            assert os.path.exists(
                split_dir
            ), f"Data directory {split_dir} does not exist"
        train_dataset = UnpairedImageDataset(f"{data_dir}/trainA", f"{data_dir}/trainB")
        self.train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
        )
        test_dataset = UnpairedImageDataset(f"{data_dir}/testA", f"{data_dir}/testB")
        # Held-out data is iterated in a fixed order so evaluations are comparable.
        self.test_loader = DataLoader(
            test_dataset,
            batch_size=batch_size,
            shuffle=False,
        )
        # Hyperparameters
        self.n_log_steps = n_log_steps
        self.n_ckpt_steps = n_ckpt_steps
        self.epochs = epochs
        self.batch_size = batch_size
        # Cast to device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.generator_A2B.to(self.device)
        self.generator_B2A.to(self.device)
        self.discriminator_A.to(self.device)
        self.discriminator_B.to(self.device)
        # Init variables
        self.global_step = 0

    def to_device(self, *args):
        """Return a list with every given tensor moved to ``self.device``."""
        return [arg.to(self.device) for arg in args]

    def run(self):
        """Train for ``self.epochs`` passes over the training loader."""
        for e in range(self.epochs):
            self.train_iter(e)

    def train_iter(self, epoch: int):
        """
        Run one epoch: for every batch update the generators, then the
        discriminators, and log sample images / save checkpoints on schedule.

        epoch: Epoch index, only used for the progress-bar label
        """
        # Set model to training mode
        self.generator_A2B.train()
        self.generator_B2A.train()
        self.discriminator_A.train()
        self.discriminator_B.train()
        # Train on a batch of images
        for img_A, img_B in tqdm(
            self.train_loader, desc=f"Epoch: {epoch}", leave=False
        ):
            img_A, img_B = self.to_device(img_A, img_B)
            # Train the generators
            self.optim_G.zero_grad()
            loss_G = self.compute_generator_loss(img_A, img_B)
            loss_G.backward()
            self.optim_G.step()
            # Train the discriminators. The fakes are produced without autograd
            # tracking: the discriminator objective must not backpropagate into
            # the generators, and skipping the graph saves memory and compute.
            self.optim_D.zero_grad()
            with torch.no_grad():
                fake_A = self.generator_B2A(img_B)
                fake_B = self.generator_A2B(img_A)
            loss_D_A = self.compute_discriminator_loss(
                img_A, fake_A, self.discriminator_A
            )
            loss_D_B = self.compute_discriminator_loss(
                img_B, fake_B, self.discriminator_B
            )
            loss_D = loss_D_A + loss_D_B
            loss_D.backward()
            self.optim_D.step()
            # Log a sample translation next to its source image
            if self.global_step % self.n_log_steps == 0:
                with torch.no_grad():
                    preview = self.generator_A2B(img_A)
                self.log_image(preview, f"{self.global_step}_img_A_generated.png")
                self.log_image(img_A, f"{self.global_step}_img_A.png")

            # Increment the global step
            self.global_step += 1
            if self.global_step % self.n_ckpt_steps == 0:
                self.save_checkpoint()

    def compute_generator_loss(self, img_A, img_B) -> torch.Tensor:
        """
        Compute the generator loss for CycleGAN
        this is the loss for both generators ie: A -> B and B -> A

        For each direction the loss is the L1 distance between the input and
        its round-trip reconstruction, plus the MSE between the discriminator
        score of the translated image and a tensor of ones. Both terms are
        summed with equal weight.
        """
        # A -> B
        # Generate the stylised image
        generated_img_B = self.generator_A2B(img_A)
        # Reconstruct the original image
        generated_img_A = self.generator_B2A(generated_img_B)
        # Pass generated image through discriminator
        pred_D_B = self.discriminator_B(generated_img_B)
        # Compute the losses
        recon_loss = F.l1_loss(generated_img_A, img_A)  # Reconstruction loss
        adv_loss = F.mse_loss(pred_D_B, torch.ones_like(pred_D_B))  # Adversarial loss
        # B -> A
        generated_img_A = self.generator_B2A(img_B)
        generated_img_B = self.generator_A2B(generated_img_A)
        pred_D_A = self.discriminator_A(generated_img_A)
        recon_loss = recon_loss + F.l1_loss(generated_img_B, img_B)
        adv_loss = adv_loss + F.mse_loss(pred_D_A, torch.ones_like(pred_D_A))
        # Return the total loss
        return recon_loss + adv_loss

    def compute_discriminator_loss(
        self,
        real_img: torch.Tensor,
        generated_img: torch.Tensor,
        discriminator: nn.Module,
    ) -> torch.Tensor:
        """
        Computes the loss for each discriminator (A & B)
        this consists of the loss for the real images and the fake images

        Real images are scored against a target of ones and generated images
        against a target of zeros, both with MSE; the two terms are summed.
        """
        real_D = discriminator(real_img)
        fake_D = discriminator(generated_img)
        real_loss = F.mse_loss(real_D, torch.ones_like(real_D))
        fake_loss = F.mse_loss(fake_D, torch.zeros_like(fake_D))
        return real_loss + fake_loss

    def save_checkpoint(self):
        """
        Write models, optimisers and the step counter to
        ``results/checkpoint_<global_step>.pth``.
        """
        # The directory may be missing if the instance was not set up through
        # __init__ or the folder was removed while training.
        os.makedirs(self.RESULTS_DIR, exist_ok=True)
        torch.save(
            {
                "global_step": self.global_step,
                "generator_A2B": self.generator_A2B.state_dict(),
                "generator_B2A": self.generator_B2A.state_dict(),
                "discriminator_A": self.discriminator_A.state_dict(),
                "discriminator_B": self.discriminator_B.state_dict(),
                "optim_G": self.optim_G.state_dict(),
                "optim_D": self.optim_D.state_dict(),
            },
            os.path.join(self.RESULTS_DIR, f"checkpoint_{self.global_step}.pth"),
        )

    def load_checkpoint(self, path: str):
        """
        Restore models, optimisers and the step counter from a file written by
        ``save_checkpoint``.

        path: Location of the checkpoint file
        """
        # map_location lets a checkpoint saved on a GPU be restored on a
        # CPU-only machine (and vice versa).
        checkpoint = torch.load(path, map_location=self.device)
        self.global_step = checkpoint["global_step"]
        self.generator_A2B.load_state_dict(checkpoint["generator_A2B"])
        self.generator_B2A.load_state_dict(checkpoint["generator_B2A"])
        self.discriminator_A.load_state_dict(checkpoint["discriminator_A"])
        self.discriminator_B.load_state_dict(checkpoint["discriminator_B"])
        self.optim_G.load_state_dict(checkpoint["optim_G"])
        self.optim_D.load_state_dict(checkpoint["optim_D"])

    def log_image(self, img: torch.Tensor, name: str):
        """
        Save a batch of images as a grid under ``results/logs``.

        img: Image batch with values in [-1, 1]
        name: File name (including extension) inside the log directory
        """
        torchvision.utils.save_image(
            img,
            # Join with a separator: plain concatenation produced
            # "results/logs<name>", i.e. a file outside the log directory.
            os.path.join(self.LOG_DIR, name),
            normalize=True,
            # ``value_range`` is the current name of the former ``range`` keyword.
            value_range=(-1, 1),
        )


In [6]:
# Seed PyTorch so weight initialisation, batch shuffling and dropout start from
# the same state on every Restart & Run All (as far as the backend allows).
SEED = 0
torch.manual_seed(SEED)

load_data()

# Every tunable setting of the run, collected in one place.
TRAINER_CONFIG = dict(
    # Schedule
    n_ckpt_steps=1000,
    n_log_steps=1000,
    epochs=10000,
    # Data
    data_dir="data/summer2winter_yosemite",
    batch_size=4,
    # Adam optimiser
    lr=1e-4,
    beta_1=0.9,
    beta_2=0.999,
    # Generator architecture
    emb_d=64,
    n_down_blocks=4,
    n_res_blocks=1,
    # Housekeeping
    flush_prev_logs=True,
)
trainer = Trainer(**TRAINER_CONFIG)

# Check model code
# Check train code


In [1]:
# Start training. This depends on `trainer` created in the setup cell above:
# run that cell first in the same kernel, otherwise this line raises the
# NameError recorded in the output below.
trainer.run()

NameError: name 'trainer' is not defined