In [1]:
import os
import torch
import torch.optim as optim
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
import time

from dataset import Agulhas
from models.CycleGAN import Discriminator, Generator
from joint_transforms import Transform

from utils import save_checkpoint, load_checkpoint, save_examples, csv_writer

In [2]:
# ---------------------------------------------------------------------------
# Experiment configuration (every tunable constant lives here)
# ---------------------------------------------------------------------------
MODEL = "CycleGAN"

# Optimisation
LEARNING_RATE = 1.0e-5
NUM_EPOCHS = 30
BETAS = [0.5, 0.999]   # Adam betas shared by the generator and discriminator optimisers
# NOTE(review): MOMENTUM, WEIGHT_DECAY and POWER are not referenced by any cell shown here.
MOMENTUM = 0.0
WEIGHT_DECAY = 0.0
POWER = 0.0

# Data
INPUT_SIZE = 256
BATCH_SIZE = 2
NUM_WORKERS = 0

# Loss weights (each defined exactly once; previously both were assigned twice)
LAMBDA_CYCLE = 10
LAMBDA_IDENTITY = 0.0   # multiplies the identity loss; 0 disables that term

# Not used by the visible cells; .get() avoids a KeyError when the variable is unset.
SCRATCH_BUCKET = os.environ.get('SCRATCH_BUCKET')
SNAPSHOT_DIR = os.path.join(MODEL, 'snapshots')
print(SNAPSHOT_DIR)

# Checkpoint restore. main() writes epoch directories as f'epoch-{epoch:02d}',
# so the restore path must use the same zero-padded name ('epoch-00', not 'epoch-0').
USE_CHECKPOINT = False
RESTORE_FROM = os.path.join(SNAPSHOT_DIR, 'epoch-00')

CycleGAN/snapshots


In [3]:
def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse,
             lambda_cycle=None, lambda_identity=None):
    """Run one CycleGAN training epoch over ``loader``.

    Naming follows the horse/zebra CycleGAN convention. Each batch is unpacked as
    ``(zebra, horse)``. ``gen_H`` maps zebra -> horse and ``gen_Z`` maps horse -> zebra.
    ``disc_H`` / ``disc_Z`` score the horse / zebra domains.

    Each batch runs two updates in order:
    1. The discriminators are updated with an MSE (least-squares) adversarial loss.
    2. The generators are updated with the adversarial loss, the cycle-consistency
       loss and the optional identity loss.

    Args:
        disc_H, disc_Z: discriminators for the horse / zebra domains.
        gen_Z, gen_H: generators horse->zebra and zebra->horse.
        loader: iterable of ``(zebra, horse)`` tensor pairs supporting ``len()``.
        opt_disc: optimiser over both discriminators' parameters.
        opt_gen: optimiser over both generators' parameters.
        l1: criterion for the cycle / identity reconstruction terms.
        mse: criterion for the adversarial terms.
        lambda_cycle: cycle-loss weight. ``None`` uses the global ``LAMBDA_CYCLE``.
        lambda_identity: identity-loss weight. ``None`` uses the global
            ``LAMBDA_IDENTITY``. When 0 the identity forward passes are skipped.

    Returns:
        dict with ``'disc_loss'`` and ``'gen_loss'``: the per-batch losses averaged
        over the epoch (0.0 for an empty loader).
    """
    if lambda_cycle is None:
        lambda_cycle = LAMBDA_CYCLE
    if lambda_identity is None:
        lambda_identity = LAMBDA_IDENTITY

    # val_loop() calls model.eval() on the generators between epochs. Without this
    # they would stay in eval mode for every epoch after the first.
    for net in (disc_H, disc_Z, gen_Z, gen_H):
        net.train()

    # Send batches to wherever the models live (main() puts them on CUDA).
    device = next(gen_H.parameters()).device

    running_disc_loss = 0.0
    running_gen_loss = 0.0
    for idx, (zebra, horse) in enumerate(loader):
        zebra = zebra.to(device)
        horse = horse.to(device)

        # ---- Discriminators: real -> 1, fake -> 0.
        # Fakes are detached so this step does not backprop into the generators.
        fake_horse = gen_H(zebra)
        D_H_real = disc_H(horse)
        D_H_fake = disc_H(fake_horse.detach())
        D_H_loss = (mse(D_H_real, torch.ones_like(D_H_real))
                    + mse(D_H_fake, torch.zeros_like(D_H_fake)))

        fake_zebra = gen_Z(horse)
        D_Z_real = disc_Z(zebra)
        D_Z_fake = disc_Z(fake_zebra.detach())
        D_Z_loss = (mse(D_Z_real, torch.ones_like(D_Z_real))
                    + mse(D_Z_fake, torch.zeros_like(D_Z_fake)))

        D_loss = (D_H_loss + D_Z_loss) / 2

        opt_disc.zero_grad()
        D_loss.backward()
        opt_disc.step()

        # ---- Generators.
        # Adversarial term: the generators want the just-updated discriminators to
        # output 1. G_loss.backward() also leaves gradients on the discriminator
        # parameters. They are cleared by opt_disc.zero_grad() on the next batch.
        D_H_fake = disc_H(fake_horse)
        D_Z_fake = disc_Z(fake_zebra)
        loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
        loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))

        # Cycle consistency: zebra -> horse -> zebra and horse -> zebra -> horse.
        cycle_zebra = gen_Z(fake_horse)
        cycle_horse = gen_H(fake_zebra)
        cycle_zebra_loss = l1(zebra, cycle_zebra)
        cycle_horse_loss = l1(horse, cycle_horse)

        G_loss = (
            loss_G_Z
            + loss_G_H
            + cycle_zebra_loss * lambda_cycle
            + cycle_horse_loss * lambda_cycle
        )

        # Identity term: skip the two extra generator passes when it is disabled.
        if lambda_identity != 0:
            identity_zebra = gen_Z(zebra)
            identity_horse = gen_H(horse)
            G_loss = (
                G_loss
                + l1(horse, identity_horse) * lambda_identity
                + l1(zebra, identity_zebra) * lambda_identity
            )

        opt_gen.zero_grad()
        G_loss.backward()
        opt_gen.step()

        # ---- Statistics (plain Python floats, no graph kept alive)
        d_value = D_loss.item()
        g_value = G_loss.item()
        running_disc_loss += d_value
        running_gen_loss += g_value

        if idx % 200 == 0:
            print(f"Discriminator Loss: {d_value:.5f}, Generator Loss: {g_value:.5f}")

    num_batches = max(len(loader), 1)  # avoid ZeroDivisionError on an empty loader
    logger = dict()
    logger['disc_loss'] = running_disc_loss / num_batches
    logger['gen_loss'] = running_gen_loss / num_batches
    return logger

def val_loop(dataloader, transform_params, model, saving_path, max_batches=5):
    """Run ``model`` on the first few validation batches and save example outputs.

    Each batch is unpacked as ``(x, y)``. ``model(x)`` is passed to ``save_examples``
    together with ``x`` and ``y``.

    Args:
        dataloader: iterable of ``(x, y)`` tensor pairs.
        transform_params: forwarded unchanged to ``save_examples``.
        model: generator to evaluate.
        saving_path: directory forwarded to ``save_examples``.
        max_batches: number of batches to save (default 5, the previous fixed value).

    The model's previous train/eval mode is restored on exit. Calling this between
    training epochs therefore does not leave the generator in eval mode.
    """
    if max_batches < 1:
        return

    was_training = model.training
    device = next(model.parameters()).device  # follow the model instead of forcing CUDA
    model.eval()
    try:
        with torch.no_grad():
            for counter, (x, y) in enumerate(dataloader, 1):
                x = x.to(device)
                y = y.to(device)

                y_fake = model(x)

                save_examples(x, y, y_fake, transform_params, counter, saving_path)
                if counter >= max_batches:
                    break
    finally:
        model.train(was_training)

In [4]:
def main():
    """Train the CycleGAN on the Agulhas dataset, logging and checkpointing per epoch.

    Each epoch does the following:
    - one ``train_fn`` pass;
    - one ``csv_writer`` log row;
    - example images from both generators under ``SNAPSHOT_DIR/epoch-XX/{Z,H}``;
    - checkpoints every ``checkpoint_every`` epochs and after the final epoch.

    Requires a CUDA device (models are moved with ``.cuda()``).
    """
    since = time.time()
    checkpoint_every = 20  # epochs between checkpoints; the final epoch is always saved

    cudnn.enabled = True
    cudnn.benchmark = True  # lets cuDNN autotune kernels; pays off when input sizes stay constant

    print(f"{MODEL} is deployed on {torch.cuda.get_device_name(0)}")

    os.makedirs(SNAPSHOT_DIR, exist_ok=True)

    # Single-channel discriminators / generators for both domains.
    disc_H = Discriminator(in_channels=1).cuda()
    disc_Z = Discriminator(in_channels=1).cuda()
    gen_Z = Generator(img_channels=1, num_residuals=9).cuda()
    gen_H = Generator(img_channels=1, num_residuals=9).cuda()

    # One optimiser per role, each spanning both networks of that role.
    opt_disc = optim.Adam(
        list(disc_H.parameters()) + list(disc_Z.parameters()),
        lr=LEARNING_RATE,
        betas=BETAS,
    )
    opt_gen = optim.Adam(
        list(gen_Z.parameters()) + list(gen_H.parameters()),
        lr=LEARNING_RATE,
        betas=BETAS,
    )

    L1 = nn.L1Loss()
    mse = nn.MSELoss()

    joint_transforms = Transform()

    train_dataset = Agulhas(split='train', joint_transform=joint_transforms)
    val_dataset = Agulhas(split='val', joint_transform=joint_transforms)

    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                                  num_workers=NUM_WORKERS, pin_memory=False, drop_last=False)
    # No shuffling: the saved validation examples come from the same samples every
    # epoch, so they can be compared across epochs.
    val_dataloader = DataLoader(val_dataset, batch_size=2, shuffle=False,
                                num_workers=NUM_WORKERS, pin_memory=False, drop_last=False)

    # Normalisation statistics of the training split, forwarded to save_examples.
    transform_params = {
        'inputs_mean': train_dataset.inps_mean_std[0],
        'inputs_std': train_dataset.inps_mean_std[1],
        'targets_mean': train_dataset.tars_mean_std[0],
        'targets_std': train_dataset.tars_mean_std[1],
    }
    # In train_fn, gen_H maps the first element of each pair to the second, and
    # gen_Z maps the second back to the first. gen_Z must therefore be evaluated on
    # swapped pairs with swapped statistics.
    # NOTE(review): assumes save_examples de-normalises x with inputs_* and y with targets_*.
    transform_params_reversed = {
        'inputs_mean': transform_params['targets_mean'],
        'inputs_std': transform_params['targets_std'],
        'targets_mean': transform_params['inputs_mean'],
        'targets_std': transform_params['inputs_std'],
    }

    # (file name, model, optimiser) triples shared by checkpoint loading and saving.
    checkpoint_files = (
        ("gen_h.pth", gen_H, opt_gen),
        ("gen_z.pth", gen_Z, opt_gen),
        ("disc_h.pth", disc_H, opt_disc),
        ("disc_z.pth", disc_Z, opt_disc),
    )

    if USE_CHECKPOINT:
        for filename, model, optimizer in checkpoint_files:
            load_checkpoint(f'{RESTORE_FROM}/{filename}', model, optimizer, LEARNING_RATE)

    for epoch in range(NUM_EPOCHS):
        logger = train_fn(
            disc_H,
            disc_Z,
            gen_Z,
            gen_H,
            train_dataloader,
            opt_disc,
            opt_gen,
            L1,
            mse
        )

        logger['epoch'] = epoch
        csv_writer(logger, SNAPSHOT_DIR)

        current_epoch_directory = os.path.join(SNAPSHOT_DIR, f'epoch-{epoch:02d}')
        # exist_ok per directory: if only Z already exists, H is still created.
        os.makedirs(os.path.join(current_epoch_directory, 'Z'), exist_ok=True)
        os.makedirs(os.path.join(current_epoch_directory, 'H'), exist_ok=True)

        # gen_Z expects the second element of each pair as input (see comment above).
        swapped_val_pairs = ((y, x) for x, y in val_dataloader)
        val_loop(swapped_val_pairs, transform_params_reversed, gen_Z, f'{current_epoch_directory}/Z')
        val_loop(val_dataloader, transform_params, gen_H, f'{current_epoch_directory}/H')

        # Save periodically and always after the last epoch, so the final weights are kept.
        if epoch % checkpoint_every == 0 or epoch == NUM_EPOCHS - 1:
            for filename, model, optimizer in checkpoint_files:
                save_checkpoint(model, optimizer,
                                filename=os.path.join(current_epoch_directory, filename))

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')



In [5]:
# Launch training with the configuration defined in the config cell.
main()

CycleGAN is deployed on Tesla T4
Discriminator Loss: 0.52753, Generator Loss: 17.37900
Discriminator Loss: 0.35851, Generator Loss: 12.86691
Discriminator Loss: 0.22450, Generator Loss: 6.39027
Discriminator Loss: 0.20615, Generator Loss: 5.70717
Discriminator Loss: 0.26843, Generator Loss: 13.51129
Discriminator Loss: 0.15081, Generator Loss: 5.61358
Discriminator Loss: 0.07458, Generator Loss: 5.54880
Discriminator Loss: 0.10078, Generator Loss: 6.13331
Saving Checkpoint...
Saving Checkpoint...
Saving Checkpoint...
Saving Checkpoint...
Discriminator Loss: 0.14689, Generator Loss: 7.61846
Discriminator Loss: 0.11928, Generator Loss: 4.76870
Discriminator Loss: 0.04901, Generator Loss: 3.30104
Discriminator Loss: 0.09374, Generator Loss: 9.10740
Discriminator Loss: 0.04731, Generator Loss: 7.80068
Discriminator Loss: 0.09533, Generator Loss: 9.12804
Discriminator Loss: 0.06687, Generator Loss: 7.44008
Discriminator Loss: 0.05881, Generator Loss: 5.96836
Discriminator Loss: 0.05224, Gen

KeyboardInterrupt: 