In [None]:
import tensorflow as tf
import torch

# A notebook cell only echoes its final expression, so a bare
# `torch.cuda.is_available()` on an earlier line is silently discarded.
# Return both GPU checks as one tuple so each result is displayed:
# (PyTorch sees a CUDA device?, TensorFlow GPU device name or '').
torch.cuda.is_available(), tf.test.gpu_device_name()

'/device:GPU:0'

In [None]:
# Shell out to nvidia-smi to show the attached GPU, driver/CUDA versions and memory use.
!nvidia-smi


Tue May 23 15:54:58 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   34C    P0    24W / 300W |      0MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
# Print the notebook's current working directory.
!pwd

/content


In [None]:
# List the contents of the current working directory (bare `ls` relies on IPython automagic).
ls

critich.pth.tar  genh.pth.tar  [0m[01;34mimages_dry[0m/  [01;34msaved_images[0m/
criticz.pth.tar  genz.pth.tar  [01;34mimages_wet[0m/


In [None]:
# Mount Google Drive under /content/gdrive.
from google.colab import drive
drive.mount('/content/gdrive')

import os
# Switch to the project folder on Drive; the training cell uses relative paths
# (TRAIN_DIR = ".", checkpoint folders) that resolve against this directory.
%cd /content/gdrive/MyDrive/CycleGAN

# Echo the working directory to confirm the change.
cwd = os.getcwd()
print(cwd)

Mounted at /content/gdrive
/content/gdrive/MyDrive/CycleGAN
/content/gdrive/MyDrive/CycleGAN


# train.py

In [None]:
# Same Drive mount and directory change as the earlier setup cell, repeated at
# the top of the training cell.
from google.colab import drive
drive.mount('/content/gdrive')
%cd /content/gdrive/MyDrive/CycleGAN

import torch
import sys
from PIL import Image
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torchvision.utils import save_image
import albumentations as A
from torch.utils.data import Dataset
import os, random, numpy as np
from albumentations.pytorch import ToTensorV2



# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Dataset roots; "." is resolved against the current working directory.
TRAIN_DIR = "."
VAL_DIR = "."  # not referenced by any active code in this cell
BATCH_SIZE = 4  # batch size of the training DataLoader
LEARNING_RATE = 1e-4  # Adam learning rate for both optimizers; also re-applied after loading a checkpoint
LAMBDA_IDENTITY = 1  # weight of the identity L1 terms in the generator loss
LAMBDA_CYCLE = 4  # weight of the cycle-consistency L1 terms in the generator loss
NUM_WORKERS = 2  # worker count passed to the training DataLoader
NUM_EPOCHS = 20  # number of passes over the training DataLoader
LOAD_MODEL = True  # load all four checkpoints before training starts
SAVE_MODEL = True  # write all four checkpoints after every epoch
# Checkpoint file names for the two generators and the two discriminators.
CHECKPOINT_GEN_H = "genh.pth.tar"
CHECKPOINT_GEN_Z = "genz.pth.tar"
CHECKPOINT_CRITIC_H = "critich.pth.tar"
CHECKPOINT_CRITIC_Z = "criticz.pth.tar"

# Albumentations pipeline applied to each (wet, dry) pair by DryWetDataset:
# the wet image is passed as `image` and the dry image as `image0`.
transforms = A.Compose(
    [
        # Resize every image to 408 x 256 (width x height).
        A.Resize(width=408, height=256),
        # Random horizontal flip with probability 0.5.
        A.HorizontalFlip(p=0.5),
        # With mean = std = 0.5 and max_pixel_value = 255, pixel values 0..255 map
        # to [-1, 1]; train_fn undoes this with `* 0.5 + 0.5` before saving images.
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        # Convert the array to a torch tensor (presumably HWC -> CHW; confirm in the albumentations docs).
        ToTensorV2(),
    ],
    # Register "image0" as a second image target so it goes through the same pipeline.
    additional_targets={"image0": "image"},
)

def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Serialize the model and optimizer state to ``filename``.

    The file written by ``torch.save`` holds a dict with two entries:
    ``"state_dict"`` (from ``model.state_dict()``) and ``"optimizer"``
    (from ``optimizer.state_dict()``).

    Args:
        model: Object exposing ``state_dict()``.
        optimizer: Object exposing ``state_dict()``.
        filename: Destination path of the checkpoint file.
    """
    print("=> Saving checkpoint")
    torch.save(
        dict(state_dict=model.state_dict(), optimizer=optimizer.state_dict()),
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore model and optimizer state from ``checkpoint_file``.

    The checkpoint is read with ``map_location=DEVICE``. Its ``"state_dict"``
    entry is loaded into ``model`` and its ``"optimizer"`` entry into
    ``optimizer``; afterwards the learning rate of every parameter group is
    overwritten with ``lr``.

    Args:
        checkpoint_file: Path of a file with ``"state_dict"`` and ``"optimizer"`` keys.
        model: Module that receives the stored weights.
        optimizer: Optimizer that receives the stored optimizer state.
        lr: Learning rate to apply to all parameter groups after loading.
    """
    print("=> Loading checkpoint")
    state = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(state["state_dict"])
    optimizer.load_state_dict(state["optimizer"])

    # Loading the optimizer state also restores the learning rate stored in the
    # checkpoint; replace it with the one requested for this run.
    for group in optimizer.param_groups:
        group.update(lr=lr)

    print("Loading successful")


def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Sets ``PYTHONHASHSEED`` in ``os.environ``, seeds ``random``, ``numpy.random``
    and torch (CPU plus all CUDA devices), then sets
    ``cudnn.deterministic = True`` and ``cudnn.benchmark = False``.

    Args:
        seed: Integer seed applied to every generator.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

class DryWetDataset(Dataset):
    """Unpaired dataset of wet and dry images read from two directories.

    The dataset length is the size of the larger directory; the smaller one is
    cycled through with a modulo index so every index yields one image of each
    kind.

    Args:
        root_wet: Directory containing the wet images.
        root_dry: Directory containing the dry images.
        transform: Optional callable invoked as
            ``transform(image=wet, image0=dry)`` that returns a mapping with
            the keys ``"image"`` and ``"image0"``.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, root_wet, root_dry, transform=None, **kwargs):
        self.root_wet = root_wet
        self.root_dry = root_dry
        self.transform = transform

        self.wet_images = self._list_image_files(root_wet)
        self.dry_images = self._list_image_files(root_dry)
        self.length_dataset = max(len(self.wet_images), len(self.dry_images))
        self.wet_len = len(self.wet_images)
        self.dry_len = len(self.dry_images)

    @staticmethod
    def _list_image_files(root):
        """Return the names of the regular files in ``root``, sorted by name.

        ``os.listdir`` returns entries in arbitrary order, which would make the
        index -> (wet, dry) pairing differ between machines and runs. It also
        returns sub-directories (for example ``.ipynb_checkpoints`` on Colab),
        which cannot be opened as images, so those are skipped.
        """
        return sorted(
            entry
            for entry in os.listdir(root)
            if os.path.isfile(os.path.join(root, entry))
        )

    def __len__(self):
        """Return the number of files in the larger of the two directories."""
        return self.length_dataset

    def __getitem__(self, index):
        """Return the ``(wet_img, dry_img)`` pair for ``index``.

        Both images are loaded as RGB arrays; when a transform is configured
        its ``"image"`` / ``"image0"`` outputs are returned instead.
        """
        # Wrap around so the shorter list is reused for large indices.
        wet_img = self.wet_images[index % self.wet_len]
        dry_img = self.dry_images[index % self.dry_len]

        wet_path = os.path.join(self.root_wet, wet_img)
        dry_path = os.path.join(self.root_dry, dry_img)

        wet_img = np.array(Image.open(wet_path).convert("RGB"))
        dry_img = np.array(Image.open(dry_path).convert("RGB"))

        if self.transform:
            augmentations = self.transform(image=wet_img, image0=dry_img)
            wet_img = augmentations["image"]
            dry_img = augmentations["image0"]

        return wet_img, dry_img

class ConvBlock(nn.Module):
    """Convolution -> InstanceNorm2d -> optional ReLU.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the convolution.
        down: When True use ``nn.Conv2d`` with reflect padding; otherwise use
            ``nn.ConvTranspose2d``.
        use_act: When True end with an in-place ReLU, otherwise with
            ``nn.Identity``.
        **kwargs: Forwarded to the chosen convolution (kernel_size, stride, ...).
    """

    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        if down:
            conv_layer = nn.Conv2d(
                in_channels, out_channels, padding_mode="reflect", **kwargs
            )
        else:
            conv_layer = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)

        if use_act:
            activation = nn.ReLU(inplace=True)
        else:
            activation = nn.Identity()

        # Layer order fixes the state-dict keys (conv.0.*), which saved
        # checkpoints depend on.
        self.conv = nn.Sequential(
            conv_layer,
            nn.InstanceNorm2d(out_channels),
            activation,
        )

    def forward(self, x):
        """Apply the block to ``x`` and return the result."""
        return self.conv(x)


class ResidualBlock(nn.Module):
    """Two 3x3 ConvBlocks whose output is added back onto the input.

    The first ConvBlock keeps its ReLU, the second is built with
    ``use_act=False``. Both keep the channel count at ``channels``.

    Args:
        channels: Number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        conv_args = dict(kernel_size=3, padding=1)
        with_activation = ConvBlock(channels, channels, **conv_args)
        without_activation = ConvBlock(channels, channels, use_act=False, **conv_args)
        self.block = nn.Sequential(with_activation, without_activation)

    def forward(self, x):
        """Return ``x`` plus the output of the two-convolution branch."""
        residual = self.block(x)
        return x + residual


class Generator(nn.Module):
    """Encoder / residual / decoder image-to-image generator.

    Layout: a 7x7 stem convolution, two stride-2 ConvBlocks that widen the
    channels to ``num_features * 4``, ``num_residuals`` ResidualBlocks, two
    stride-2 transposed ConvBlocks back to ``num_features`` channels, and a
    final 7x7 convolution followed by ``tanh``.

    Args:
        img_channels: Channels of the input and output image.
        num_features: Channel width after the stem convolution.
        num_residuals: Number of residual blocks in the bottleneck.
    """

    def __init__(self, img_channels, num_features=64, num_residuals=22):
        super().__init__()
        base = num_features
        double = num_features * 2
        quad = num_features * 4
        stem_args = dict(kernel_size=7, stride=1, padding=3, padding_mode="reflect")
        resample_args = dict(kernel_size=3, stride=2, padding=1)

        self.initial = nn.Sequential(
            nn.Conv2d(img_channels, base, **stem_args),
            nn.InstanceNorm2d(base),
            nn.ReLU(inplace=True),
        )

        self.down_blocks = nn.ModuleList(
            [
                ConvBlock(base, double, **resample_args),
                ConvBlock(double, quad, **resample_args),
            ]
        )

        residual_stack = [ResidualBlock(quad) for _ in range(num_residuals)]
        self.res_blocks = nn.Sequential(*residual_stack)

        self.up_blocks = nn.ModuleList(
            [
                ConvBlock(quad, double, down=False, output_padding=1, **resample_args),
                ConvBlock(double, base, down=False, output_padding=1, **resample_args),
            ]
        )

        self.last = nn.Conv2d(base, img_channels, **stem_args)

    def forward(self, x):
        """Run ``x`` through all stages and squash the result with ``tanh``."""
        out = self.initial(x)
        for stage in (*self.down_blocks, self.res_blocks, *self.up_blocks):
            out = stage(out)
        return torch.tanh(self.last(out))

class Block(nn.Module):
    """Discriminator stage: 4x4 reflect-padded conv -> InstanceNorm2d -> LeakyReLU(0.2).

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the convolution.
        stride: Stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        strided_conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=True,
            padding_mode="reflect",
        )
        self.conv = nn.Sequential(
            strided_conv,
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        """Apply the stage to ``x`` and return the result."""
        return self.conv(x)


class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a sigmoid score map.

    Layout: a stride-2 4x4 convolution with LeakyReLU (no normalisation), one
    ``Block`` per remaining entry of ``features``, and a final 4x4 convolution
    down to a single channel. A ``Block`` uses stride 1 when its width equals
    ``features[-1]`` and stride 2 otherwise.

    Args:
        in_channels: Channels of the input image.
        features: Channel widths of the successive stages. The default list is
            only read, never modified.
    """

    def __init__(self, in_channels=3, features=[64, 128, 256, 512]):
        super().__init__()
        conv_args = dict(kernel_size=4, padding=1, padding_mode="reflect")

        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, features[0], stride=2, **conv_args),
            nn.LeakyReLU(0.2, inplace=True),
        )

        final_width = features[-1]
        # Pair each width with its successor: (f0, f1), (f1, f2), ...
        stages = [
            Block(prev_width, width, stride=2 if width != final_width else 1)
            for prev_width, width in zip(features, features[1:])
        ]
        stages.append(nn.Conv2d(final_width, 1, stride=1, **conv_args))
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Return the sigmoid of the final convolution's output for ``x``."""
        hidden = self.initial(x)
        return torch.sigmoid(self.model(hidden))

def train_fn(
    disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler):
    """Run one pass over ``loader``, updating both discriminators and both generators.

    For every batch the discriminators are updated first, then the generators.
    ``gen_H`` maps wet -> dry and ``gen_Z`` maps dry -> wet; ``disc_H`` scores
    dry images and ``disc_Z`` scores wet images. Every 4th batch the generated
    wet images are written to ``saved_images_morespray/wet_{idx}.png``.

    Args:
        disc_H: Discriminator applied to real and generated dry images.
        disc_Z: Discriminator applied to real and generated wet images.
        gen_Z: Generator called on dry images (and on wet ones for the identity term).
        gen_H: Generator called on wet images (and on dry ones for the identity term).
        loader: Iterable yielding ``(wet, dry)`` batches.
        opt_disc: Optimizer stepped with the discriminator loss.
        opt_gen: Optimizer stepped with the generator loss.
        l1: Loss callable used for the cycle and identity terms.
        mse: Loss callable used for the adversarial terms.
        d_scaler: Gradient scaler used for the discriminator step.
        g_scaler: Gradient scaler used for the generator step.

    Returns:
        None.
    """
    # Running sums of disc_H's mean output on real / generated dry images,
    # shown as averages in the progress bar.
    H_reals = 0
    H_fakes = 0
    loop = tqdm(loader, leave=True)

    for idx, (wet, dry) in enumerate(loop):
        wet = wet.to(DEVICE)
        dry = dry.to(DEVICE)

        # Train Discriminators H and Z
        with torch.cuda.amp.autocast():
            fake_dry = gen_H(wet)
            D_H_real = disc_H(dry)
            # detach(): the discriminator loss must not backpropagate into gen_H.
            D_H_fake = disc_H(fake_dry.detach())
            H_reals += D_H_real.mean().item()
            H_fakes += D_H_fake.mean().item()
            # Targets: ones for real images, zeros for generated ones.
            D_H_real_loss = mse(D_H_real, torch.ones_like(D_H_real))
            D_H_fake_loss = mse(D_H_fake, torch.zeros_like(D_H_fake))
            D_H_loss = D_H_real_loss + D_H_fake_loss

            fake_wet = gen_Z(dry)
            D_Z_real = disc_Z(wet)
            # detach(): the discriminator loss must not backpropagate into gen_Z.
            D_Z_fake = disc_Z(fake_wet.detach())
            D_Z_real_loss = mse(D_Z_real, torch.ones_like(D_Z_real))
            D_Z_fake_loss = mse(D_Z_fake, torch.zeros_like(D_Z_fake))
            D_Z_loss = D_Z_real_loss + D_Z_fake_loss

            # Combine: average of the two discriminator losses.
            D_loss = (D_H_loss + D_Z_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train Generators H and Z
        with torch.cuda.amp.autocast():
            # Adversarial loss for both generators: the fakes are reused without
            # detach() and the target is ones, i.e. "classified as real".
            D_H_fake = disc_H(fake_dry)
            D_Z_fake = disc_Z(fake_wet)
            loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
            loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))

            # Cycle loss: wet -> fake dry -> wet (and dry -> fake wet -> dry)
            # should reproduce the input.
            cycle_wet = gen_Z(fake_dry)
            cycle_dry = gen_H(fake_wet)
            cycle_wet_loss = l1(wet, cycle_wet)
            cycle_dry_loss = l1(dry, cycle_dry)

            # Identity loss: a generator fed an image of its own output domain
            # is compared against that unchanged image.
            identity_wet = gen_Z(wet)
            identity_dry = gen_H(dry)
            identity_wet_loss = l1(wet, identity_wet)  # faster without, but better output
            identity_dry_loss = l1(dry, identity_dry)

            # Add everything together, weighting cycle and identity terms.
            G_loss = (
                loss_G_Z + loss_G_H
                + cycle_wet_loss * LAMBDA_CYCLE
                + cycle_dry_loss * LAMBDA_CYCLE
                + identity_dry_loss * LAMBDA_IDENTITY
                + identity_wet_loss * LAMBDA_IDENTITY
            )

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        # Every 4th batch, save the generated wet images; `* 0.5 + 0.5` maps
        # the tanh range [-1, 1] back to [0, 1].
        # NOTE(review): the file name depends only on idx, so each epoch
        # overwrites the previous epoch's images; the target directory
        # presumably has to exist already - confirm.
        if idx % 4 == 0:
            #save_image(fake_dry * 0.5 + 0.5, f"saved_images/dry_{idx}.png")
            save_image(fake_wet * 0.5 + 0.5, f"saved_images_morespray/wet_{idx}.png")
            #save_image(fake_wet * 0.5 + 0.5, f"saved_images_intense/wet_{idx}.png")
            #save_image(fake_wet * 0.5 + 0.5, f"saved_images_combined/wet_{idx}.png")


        # Show the running averages of disc_H's outputs in the progress bar.
        loop.set_postfix(H_real=H_reals / (idx + 1), H_fake=H_fakes / (idx + 1))


def main():
    """Build the CycleGAN models, optionally resume from checkpoints, and train.

    Creates two discriminators and two generators (9 residual blocks each) on
    ``DEVICE``, one Adam optimizer for each pair, and a shuffled DataLoader
    over the dry / "more spray" wet image folders under ``TRAIN_DIR``. When
    ``LOAD_MODEL`` is set all four checkpoints are loaded first; when
    ``SAVE_MODEL`` is set all four are written after every epoch.

    Checkpoint files are addressed by explicit path inside ``checkpoint_dir``
    rather than by changing the working directory. Changing directory around
    each load/save leaves the process in the wrong directory if the folder is
    missing or an exception interrupts the sequence, and the dataset paths are
    relative to the working directory.
    """
    # Checkpoint folder for the "more spray" variant, relative to the working
    # directory. Other variants: "checkpoint_combined", "checkpoint_intense".
    checkpoint_dir = "checkpoint_more"

    disc_H = Discriminator(in_channels=3).to(DEVICE)
    disc_Z = Discriminator(in_channels=3).to(DEVICE)
    gen_Z = Generator(img_channels=3, num_residuals=9).to(DEVICE)
    gen_H = Generator(img_channels=3, num_residuals=9).to(DEVICE)
    opt_disc = optim.Adam(
        list(disc_H.parameters()) + list(disc_Z.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )
    opt_gen = optim.Adam(
        list(gen_Z.parameters()) + list(gen_H.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    L1 = nn.L1Loss()
    mse = nn.MSELoss()

    # (file name, model, optimizer) for every checkpoint. Each optimizer is
    # shared by two models, so its state is stored in both of their files.
    checkpoints = [
        (CHECKPOINT_GEN_H, gen_H, opt_gen),
        (CHECKPOINT_GEN_Z, gen_Z, opt_gen),
        (CHECKPOINT_CRITIC_H, disc_H, opt_disc),
        (CHECKPOINT_CRITIC_Z, disc_Z, opt_disc),
    ]

    if LOAD_MODEL:
        for filename, model, optimizer in checkpoints:
            load_checkpoint(
                os.path.join(checkpoint_dir, filename), model, optimizer, LEARNING_RATE,
            )

    dataset = DryWetDataset(
        root_dry=TRAIN_DIR+"/images_dry",
        #root_wet=TRAIN_DIR+"/images_wet_combined",
        #root_wet=TRAIN_DIR+"/images_wet_intense",
        root_wet=TRAIN_DIR+"/images_wet_morespray",
        transform=transforms,
    )
    #val_dataset = DryWetDataset(root_clear="cyclegan_test/dry1", root_wet="cyclegan_test/wet1",transform=transforms,)
    #val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, pin_memory=True,)
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    for epoch in range(NUM_EPOCHS):
        print("Epoch: ", epoch)
        train_fn(
            disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler,
        )

        if SAVE_MODEL:
            # Create the folder on first use so a fresh run does not fail here.
            os.makedirs(checkpoint_dir, exist_ok=True)
            for filename, model, optimizer in checkpoints:
                save_checkpoint(
                    model, optimizer, filename=os.path.join(checkpoint_dir, filename)
                )



# Start training when this code runs as the main module.
if __name__ == "__main__":
    main()


# Neuer Abschnitt