In [None]:
import sys
!{sys.executable} -m pip uninstall tensorflow_datasets --y
!{sys.executable} -m pip install git+git://github.com/sputney13/datasets

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import os
import matplotlib.pyplot as plt
import sklearn

from trainer.models import unet, patchGAN
from trainer.utils import edges_2_shoes, cityscapes, oasbud, training_utils

## Edges2Shoes

In [None]:
# Build the Edges2Shoes train/test datasets.
# NOTE(review): the positional 4 is presumably the batch size — confirm against
# edges_2_shoes.create_dataset.
train_ds, test_ds = edges_2_shoes.create_dataset(4)

In [None]:
# Fresh generator (U-Net) and discriminator (PatchGAN) for the Edges2Shoes run,
# both configured for 256 x 256 x 3 inputs.
gen = unet.UNet(
    input_shape=[256, 256, 3],
    out_channels=3,
)
disc = patchGAN.PatchGAN(
    input_shape=[256, 256, 3],
)

In [None]:
# Short demonstration run: 2 epochs is just enough to see a result.
# Raise EDGES_EPOCHS to 15 for better performance.
EDGES_EPOCHS = 2
trained_gen, trained_disc = training_utils.fit(
    train_ds, gen, disc, EDGES_EPOCHS, test_ds
)

In [None]:
# Run generate_images with the trained generator on the first three elements
# of the Edges2Shoes test set.
preview_ds = test_ds.take(3)
for source_img, target_img in preview_ds:
    training_utils.generate_images(trained_gen, source_img, target_img)

## Cityscapes

In [None]:
# Build the Cityscapes train/test datasets using create_dataset's default arguments.
# Note: this rebinds train_ds / test_ds, replacing the Edges2Shoes datasets.
train_ds, test_ds = cityscapes.create_dataset()

In [None]:
# Fresh generator/discriminator pair for the Cityscapes run.
gen = unet.UNet(
    input_shape=[256, 256, 3],
    out_channels=3,
)
# NOTE(review): unlike the other sections, no input_shape is passed here, so this
# relies on PatchGAN's default — confirm it matches the generator's 256 x 256 x 3.
disc = patchGAN.PatchGAN()

In [None]:
# Short demonstration run: 5 epochs is just enough to see a result.
# Raise CITYSCAPES_EPOCHS to 200 for better performance.
CITYSCAPES_EPOCHS = 5
trained_gen, trained_disc = training_utils.fit(
    train_ds, gen, disc, CITYSCAPES_EPOCHS, test_ds
)

In [None]:
# Run generate_images with the trained generator on the first three elements
# of the Cityscapes test set.
preview_ds = test_ds.take(3)
for source_img, target_img in preview_ds:
    training_utils.generate_images(trained_gen, source_img, target_img)

## OASBUD

In [None]:
# 'oasbud/b_mode' is not yet merged into TensorFlow Datasets, so this load only
# works when the @sputney13 git version of tensorflow_datasets is installed.
dataset = tfds.load('oasbud/b_mode')

#### Train on Whole Set

In [None]:
# Turn the loaded OASBUD data into the dataset used for GAN training.
# The helper is called through the imported `oasbud` module: the bare function
# name is never imported in this notebook and raises NameError on a fresh kernel.
# NOTE(review): assumes process_oasbud_for_gan lives in trainer.utils.oasbud — confirm.
bmode_ds = oasbud.process_oasbud_for_gan(dataset)

In [None]:
# Fresh generator/discriminator pair for the whole-set OASBUD run,
# both configured for 1024 x 512 single-channel inputs.
gen = unet.UNet(
    input_shape=[1024, 512, 1],
    out_channels=1,
)
disc = patchGAN.PatchGAN(
    input_shape=[1024, 512, 1],
)

In [None]:
# Short demonstration run: 5 epochs is just enough to see a result.
# Raise OASBUD_EPOCHS to 120 for better performance.
# No test dataset is passed here, and LAMBDA is overridden to 50.
OASBUD_EPOCHS = 5
trained_gen, trained_disc = training_utils.fit(
    bmode_ds, gen, disc, OASBUD_EPOCHS, LAMBDA=50
)

In [None]:
# Run generate_images with the trained generator on the first three elements
# of the whole-set OASBUD dataset (the same data it was trained on).
preview_ds = bmode_ds.take(3)
for source_img, target_img in preview_ds:
    training_utils.generate_images(trained_gen, source_img, target_img)

#### Train by Class

In [None]:
# Split the loaded OASBUD data into separate malignant and benign datasets.
# The helper is called through the imported `oasbud` module: the bare function
# name is never imported in this notebook and raises NameError on a fresh kernel.
# NOTE(review): assumes process_oasbud_for_gan_by_class lives in trainer.utils.oasbud — confirm.
malignant_ds, benign_ds = oasbud.process_oasbud_for_gan_by_class(dataset)

Malignant GAN

In [None]:
# Fresh generator/discriminator pair for the malignant-only GAN,
# both configured for 1024 x 512 single-channel inputs.
gen = unet.UNet(
    input_shape=[1024, 512, 1],
    out_channels=1,
)
disc = patchGAN.PatchGAN(
    input_shape=[1024, 512, 1],
)

In [None]:
# Short demonstration run: 5 epochs is just enough to see a result.
# Raise MALIGNANT_EPOCHS to 120 for better performance.
# No test dataset is passed here, and LAMBDA is overridden to 50.
MALIGNANT_EPOCHS = 5
mal_gen, mal_disc = training_utils.fit(
    malignant_ds, gen, disc, MALIGNANT_EPOCHS, LAMBDA=50
)

In [None]:
# Run generate_images with the malignant generator on the first three elements
# of the malignant dataset (the same data it was trained on).
preview_ds = malignant_ds.take(3)
for source_img, target_img in preview_ds:
    training_utils.generate_images(mal_gen, source_img, target_img)

Benign GAN

In [None]:
# Fresh generator/discriminator pair for the benign-only GAN,
# both configured for 1024 x 512 single-channel inputs.
gen = unet.UNet(
    input_shape=[1024, 512, 1],
    out_channels=1,
)
disc = patchGAN.PatchGAN(
    input_shape=[1024, 512, 1],
)

In [None]:
# Short demonstration run: 5 epochs is just enough to see a result.
# Raise BENIGN_EPOCHS to 120 for better performance.
# No test dataset is passed here, and LAMBDA is overridden to 50.
BENIGN_EPOCHS = 5
ben_gen, ben_disc = training_utils.fit(
    benign_ds, gen, disc, BENIGN_EPOCHS, LAMBDA=50
)

In [None]:
# Run generate_images with the benign generator on the first three elements
# of the benign dataset (the same data it was trained on).
preview_ds = benign_ds.take(3)
for source_img, target_img in preview_ds:
    training_utils.generate_images(ben_gen, source_img, target_img)