In [1]:
import os
from google.colab import drive
drive.mount('/content/gdrive/', force_remount=True)

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/gdrive/


In [2]:
os.chdir('/content/gdrive/My Drive/final_files/translation')
!ls

data		    helpers.py	      __pycache__	       translator.py
data_0_123	    networks.py       translated_patches       unit_trainer.py
gan_training.ipynb  new_configs.yaml  translate_patches.ipynb


In [3]:
!pip3 install tensorboardX

Collecting tensorboardX
[?25l  Downloading https://files.pythonhosted.org/packages/af/0c/4f41bcd45db376e6fe5c619c01100e9b7531c55791b7244815bac6eac32c/tensorboardX-2.1-py2.py3-none-any.whl (308kB)
[K     |█                               | 10kB 7.7MB/s eta 0:00:01[K     |██▏                             | 20kB 2.9MB/s eta 0:00:01[K     |███▏                            | 30kB 3.6MB/s eta 0:00:01[K     |████▎                           | 40kB 3.9MB/s eta 0:00:01[K     |█████▎                          | 51kB 3.4MB/s eta 0:00:01[K     |██████▍                         | 61kB 3.8MB/s eta 0:00:01[K     |███████▍                        | 71kB 4.2MB/s eta 0:00:01[K     |████████▌                       | 81kB 4.5MB/s eta 0:00:01[K     |█████████▌                      | 92kB 4.8MB/s eta 0:00:01[K     |██████████▋                     | 102kB 4.7MB/s eta 0:00:01[K     |███████████▊                    | 112kB 4.7MB/s eta 0:00:01[K     |████████████▊                   | 122kB 4.7

In [4]:
from helpers import data_loaders, output_sub_folders, write_loss, load_config, display_images, Timer
from unit_trainer import UNIT_Trainer
import torch.backends.cudnn as cudnn
import torch
try:
    from itertools import izip as zip
except ImportError: 
    pass
import os
import sys
import tensorboardX
import time

In [5]:
def main_training(opts):
    """ Main GAN training """
    cudnn.benchmark = True
    # Setup configs
    config = load_config(opts.config)
    epochs = config['epochs']
    # number of images to display during training
    display_size = config['display_size']

    # Setup model and data loader
    trainer = UNIT_Trainer(config)
    trainer.cuda()
    # domain 'a' contains healthy patches, 'b' contains unhealthy patches 
    train_loader_a, train_loader_b, test_loader_a, test_loader_b = data_loaders(config, opts.data_root)
    #trian/test patches to display during training
    train_display_images_a = torch.stack([train_loader_a.dataset[i] for i in range(display_size)]).cuda()
    train_display_images_b = torch.stack([train_loader_b.dataset[i] for i in range(display_size)]).cuda()
    test_display_images_a = torch.stack([test_loader_a.dataset[i] for i in range(display_size)]).cuda()
    test_display_images_b = torch.stack([test_loader_b.dataset[i] for i in range(display_size)]).cuda()

    # Setup logger and output folders
    config_name = os.path.splitext(os.path.basename(opts.config))[0]
    train_writer = tensorboardX.SummaryWriter(os.path.join(opts.output_path + '/logs', config_name))
    timestr = time.strftime("%Y%m%d-%H%M")
    output_directory = os.path.join(opts.output_path + '/outputs', 'translation_{}_{}/'.format(config_name, timestr))
    checkpoint_directory, image_directory = output_sub_folders(output_directory)

    # Start training
    epoch = 0
    while True:
        for i, (images_a, images_b) in enumerate(zip(train_loader_a, train_loader_b)):
            images_a, images_b = images_a.cuda().detach(), images_b.cuda().detach()

            with Timer("Time: %f"):
                # Main training code: update discriminator and then generator
                trainer.dis_update(images_a, images_b, config)
                trainer.gen_update(images_a, images_b, config)
                trainer.update_learning_rate()
                # measure time spent
                torch.cuda.synchronize()

            # Sample and display a few patches during training (saved in output directory)
            if (epoch + 1) % config['image_display_freq'] == 0:
                with torch.no_grad():
                    #translate the selected patches
                    test_image_outputs = trainer.sample_translate(test_display_images_a, test_display_images_b)
                    train_image_outputs = trainer.sample_translate(train_display_images_a, train_display_images_b)
                #display (save) the patches
                display_images(test_image_outputs, display_size, image_directory, 'test_%08d' % (epoch + 1))
                display_images(train_image_outputs, display_size, image_directory, 'train_%08d' % (epoch + 1))

            # Save network weights
            if (epoch + 1) % config['save_model_freq'] == 0:
                trainer.save(checkpoint_directory, epoch)

            # Save training stats in log file
            if (epoch + 1) % config['log_freq'] == 0:
                print("Iteration: %07d/%07d" % (epoch + 1, epochs))
                #add loss in tensorboard
                write_loss(epoch, trainer, train_writer)

            epoch += 1
            if epoch >= epochs:
                sys.exit('Finish training')

    return 

In [6]:
# training options
class options:
    #path to the config yaml file
    config = 'new_configs.yaml' 
    #path to the dataset
    data_root = 'data' 
    #path to the outputs
    output_path = 'data' 


In [8]:
opts = options()

# Main training
main_training(opts)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Time: 0.194028
Time: 0.188492
Time: 0.200214
Time: 0.186347
Iteration: 0095460/0100000
Time: 0.192608
Time: 0.191599
Time: 0.198875
Time: 0.188406
Time: 0.193589
Time: 0.189515
Time: 0.195085
Time: 0.194941
Time: 0.193424
Time: 0.199918
Iteration: 0095470/0100000
Time: 0.196553
Time: 0.200662
Time: 0.196229
Time: 0.192242
Time: 0.197051
Time: 0.193489
Time: 0.190558
Time: 0.201852
Time: 0.187171
Time: 0.203225
Iteration: 0095480/0100000
Time: 0.189604
Time: 0.193407
Time: 0.189831
Time: 0.193943
Time: 0.188239
Time: 0.195463
Time: 0.190405
Time: 0.195683
Time: 0.189867
Time: 0.199645
Iteration: 0095490/0100000
Time: 0.189723
Time: 0.193737
Time: 0.189879
Time: 0.196576
Time: 0.187449
Time: 0.201418
Time: 0.189979
Time: 0.193687
Time: 0.193188
Time: 0.197097
Iteration: 0095500/0100000
Time: 0.195359
Time: 0.197692
Time: 0.186945
Time: 0.192208
Time: 0.193287
Time: 0.197907
Time: 0.195150
Time: 0.193862
Time: 0.193413
Time:

SystemExit: ignored

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


ADDITIONAL EXPERIMENTS:

1. IF NOT ONLY SEVERE SAMPLES (DOMAIN B CONTAINS IMG1, 2, 3)

In [None]:
# training options
class options:
    #path to the config yaml file
    config = 'new_configs.yaml' 
    #path to the dataset
    data_root = 'data_0_123' 
    #path to the outputs
    output_path = 'data_0_123' 

opts = options()

# Main training
main_training(opts)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Time: 0.173966
Time: 0.171028
Time: 0.165036
Time: 0.159848
Iteration: 0095460/0100000
Time: 0.174592
Time: 0.158923
Time: 0.170306
Time: 0.171196
Time: 0.168635
Time: 0.163666
Time: 0.164917
Time: 0.161048
Time: 0.172331
Time: 0.171583
Iteration: 0095470/0100000
Time: 0.164581
Time: 0.160253
Time: 0.164961
Time: 0.160735
Time: 0.169586
Time: 0.170240
Time: 0.167606
Time: 0.168262
Time: 0.165839
Time: 0.160109
Iteration: 0095480/0100000
Time: 0.173449
Time: 0.170880
Time: 0.164809
Time: 0.163083
Time: 0.172115
Time: 0.162384
Time: 0.166371
Time: 0.169190
Time: 0.166683
Time: 0.164335
Iteration: 0095490/0100000
Time: 0.163925
Time: 0.160019
Time: 0.177219
Time: 0.169801
Time: 0.164320
Time: 0.158610
Time: 0.166420
Time: 0.158977
Time: 0.169006
Time: 0.168420
Iteration: 0095500/0100000
Time: 0.175545
Time: 0.162740
Time: 0.161831
Time: 0.162347
Time: 0.172648
Time: 0.166001
Time: 0.164701
Time: 0.159638
Time: 0.169266
Time:

SystemExit: ignored

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


ADDITIONAL EXPERIMENTS

USED MEAN AND STD COMPUTED FROM THE TRAINING SET INSTEAD OF 0.5

DATE: JULY 9

In [None]:
# training options
class options:
    #path to the config yaml file
    config = 'new_configs.yaml' 
    #path to the dataset
    data_root = 'data' 
    #path to the outputs
    output_path = 'data' 

opts = options()

# Main training
main_training(opts)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Time: 0.194404
Time: 0.185497
Time: 0.185857
Time: 0.183830
Iteration: 0095460/0100000
Time: 0.188698
Time: 0.184376
Time: 0.191289
Time: 0.185995
Time: 0.186003
Time: 0.186492
Time: 0.187979
Time: 0.187976
Time: 0.191421
Time: 0.186649
Iteration: 0095470/0100000
Time: 0.191018
Time: 0.185156
Time: 0.189295
Time: 0.187482
Time: 0.191480
Time: 0.191237
Time: 0.198406
Time: 0.192151
Time: 0.199731
Time: 0.191901
Iteration: 0095480/0100000
Time: 0.192563
Time: 0.187572
Time: 0.189068
Time: 0.183532
Time: 0.187681
Time: 0.184953
Time: 0.189423
Time: 0.185541
Time: 0.187222
Time: 0.186871
Iteration: 0095490/0100000
Time: 0.187379
Time: 0.188811
Time: 0.186288
Time: 0.185232
Time: 0.200468
Time: 0.185563
Time: 0.186342
Time: 0.184815
Time: 0.187945
Time: 0.184180
Iteration: 0095500/0100000
Time: 0.187626
Time: 0.184246
Time: 0.190568
Time: 0.186242
Time: 0.187793
Time: 0.192426
Time: 0.187620
Time: 0.187880
Time: 0.187181
Time:

SystemExit: ignored

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
