In [1]:
import argparse
import logging
import os
import sys

import numpy as np
import torch
import torch.nn as nn
from torch import optim
from tqdm import tqdm

from eval import eval_net
from unet import UNet

from torch.utils.tensorboard import SummaryWriter
#from utils.dataset import BasicDataset
from torch.utils.data import DataLoader, random_split

In [2]:
from unet import UNet

import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.utils as utils
import torchvision.datasets as dataset
import torchvision.transforms as transforms
from torch.utils.data import Dataset

from pathlib import Path

import numpy as np
import h5py

In [3]:
class BasicDataset(Dataset):
    """Segmentation dataset backed by a single HDF5 file.

    The whole file is read into memory on construction: images come from the
    ``/ortho`` dataset and masks from ``/ground_truth``, both cast to
    ``uint8``. ``__getitem__`` returns a dict with float32 ``'image'`` and
    ``'mask'`` tensors in CHW layout.

    Parameters
    ----------
    file_name : str or os.PathLike
        Path to the HDF5 file.
    scale : float, optional
        Must satisfy ``0 < scale <= 1``. NOTE(review): the value is validated
        and stored on ``self.scale`` but is not applied to the images anywhere
        in this class.
    """

    def __init__(self, file_name, scale=1):
        # Validate arguments before the (potentially large) HDF5 read so a bad
        # ``scale`` fails fast instead of after the whole file is in memory.
        assert 0 < scale <= 1, 'Scale must be between 0 and 1'

        self.hdf5_file_name = file_name
        self.scale = scale

        # load dataset (fully into memory)
        self.images, self.labels = self.read_hdf5(self.hdf5_file_name)

        logging.info(f'Creating dataset with {len(self.images)} examples')

    def __len__(self):
        """Number of samples (first axis of the ``/ortho`` array)."""
        return len(self.images)

    @classmethod
    def preprocess(cls, img_nd):
        """Convert an HW or HWC array to CHW and scale it to [0, 1].

        A 2-D array first gets a trailing channel axis, so single-channel
        masks come out as ``(1, H, W)``.

        NOTE(review): the division by 255 is decided per sample from
        ``max() > 1``. That keeps binary 0/1 masks intact. However, an image
        patch whose values are all <= 1 is returned unscaled, unlike its
        neighbours. Multi-class label masks (values 0..N) would also be
        divided by 255.
        """
        if img_nd.ndim == 2:
            img_nd = np.expand_dims(img_nd, axis=2)

        # HWC to CHW
        img_trans = img_nd.transpose((2, 0, 1))
        if img_trans.max() > 1:
            img_trans = img_trans / 255

        return img_trans

    def __getitem__(self, i):
        """Return sample ``i`` as ``{'image': FloatTensor, 'mask': FloatTensor}`` (CHW)."""
        img = self.preprocess(self.images[i])
        mask = self.preprocess(self.labels[i])

        return {
            'image': torch.from_numpy(img).type(torch.FloatTensor),
            'mask': torch.from_numpy(mask).type(torch.FloatTensor)
        }

    def read_hdf5(self, hdf5_file_name):
        """Read all images and ground-truth masks from an HDF5 file.

        Parameters
        ----------
        hdf5_file_name : str or os.PathLike
            Path of the HDF5 file to read.

        Returns
        -------
        images : np.ndarray
            Contents of the ``/ortho`` dataset cast to ``uint8``.
        labels : np.ndarray
            Contents of the ``/ground_truth`` dataset cast to ``uint8``.
        """
        # Open read-only (nothing is written) and release the file handle as
        # soon as both arrays have been copied into memory.
        with h5py.File(hdf5_file_name, "r") as hdf5_file:
            images = np.array(hdf5_file["/ortho"]).astype("uint8")
            labels = np.array(hdf5_file["/ground_truth"]).astype("uint8")

        return images, labels

In [4]:
# Directory that per-epoch checkpoints are written to.
dir_checkpoint = 'checkpoints/'
# HDF5 training dataset. Set the UNET_DATASET environment variable to point at
# a different file instead of editing this machine-specific absolute path.
ds_file_name = os.environ.get(
    "UNET_DATASET",
    "/media/philipp/ed7d22ba-5a3b-4d31-bf6c-6add6e106b3d/test/256x256/1m/dataset_256.hdf5",
)

In [6]:
def train_net(net,
              device,
              epochs=5,
              batch_size=1,
              lr=0.001,
              val_percent=0.1,
              save_cp=True,
              img_scale=1,
              dataset_path=None,
              num_workers=8):
    """Train ``net`` on the HDF5 dataset with a random train/validation split.

    Parameters
    ----------
    net : torch.nn.Module
        Segmentation network exposing ``n_channels`` and ``n_classes``; it is
        expected to already live on ``device``.
    device : torch.device
        Device each batch is moved to.
    epochs : int
        Number of passes over the training split.
    batch_size : int
        Batch size for both the training and validation loaders.
    lr : float
        Initial learning rate for RMSprop.
    val_percent : float
        Fraction (0-1) of samples held out for validation.
    save_cp : bool
        Save ``net.state_dict()`` into ``dir_checkpoint`` after every epoch.
    img_scale : float
        Forwarded to ``BasicDataset`` as ``scale``.
    dataset_path : str or None
        HDF5 file to train on; ``None`` uses the module-level ``ds_file_name``.
    num_workers : int
        Worker processes per ``DataLoader``.
    """
    if dataset_path is None:
        dataset_path = ds_file_name

    dataset = BasicDataset(dataset_path, scale=img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True)
    # drop_last=True: validation only ever sees full batches.
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=True, drop_last=True)

    # Validate roughly ten times per epoch. Clamp to 1 so that small datasets
    # (n_train < 10 * batch_size) do not cause a modulo-by-zero below.
    eval_every = max(1, n_train // (10 * batch_size))
    global_step = 0

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images scaling:  {img_scale}
    ''')

    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    # Cross entropy (multi-class) is minimised; Dice (single class) is maximised.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)
    if net.n_classes > 1:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        net.train()

        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if net.n_classes == 1 else torch.long
                true_masks = true_masks.to(device=device, dtype=mask_type)
                if net.n_classes > 1 and true_masks.dim() == 4:
                    # CrossEntropyLoss wants class-index targets of shape
                    # (N, H, W); the dataset yields masks as (N, 1, H, W).
                    true_masks = true_masks.squeeze(1)

                masks_pred = net(imgs)
                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()

                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1
                if global_step % eval_every == 0:
                    val_score = eval_net(net, val_loader, device)
                    scheduler.step(val_score)
                    # eval_net may leave the network in eval mode (dropout /
                    # batch-norm frozen); switch back before the next step.
                    net.train()

                    if net.n_classes > 1:
                        logging.info('Validation cross entropy: {}'.format(val_score))
                    else:
                        logging.info('Validation Dice Coeff: {}'.format(val_score))

        logging.info(f'Epoch {epoch + 1} mean training loss: '
                     f'{epoch_loss / max(1, len(train_loader))}')

        if save_cp:
            if not os.path.isdir(dir_checkpoint):
                os.makedirs(dir_checkpoint, exist_ok=True)
                logging.info('Created checkpoint directory')
            torch.save(net.state_dict(),
                       os.path.join(dir_checkpoint, f'CP_epoch{epoch + 1}.pth'))
            logging.info(f'Checkpoint {epoch + 1} saved !')

In [8]:
def get_args(**overrides):
    """Return the training configuration used by this notebook as a dict.

    The values are hard-coded rather than parsed with ``argparse``. Inside
    Jupyter, ``sys.argv`` typically carries the kernel's own flags
    (e.g. ``-f <connection file>``), which would clash with a ``-f/--load``
    option.

    Keys
    ----
    epochs : int
        Number of training epochs.
    batchsize : int
        Batch size.
    lr : float
        Learning rate.
    loadfile : str or False
        Path of a ``state_dict`` to resume from; ``False`` trains from scratch.
    scale : float
        Image scaling factor passed on to training.
    val : float
        Percentage (0-100) of the data used for validation.

    Parameters
    ----------
    **overrides
        Replace individual defaults, e.g. ``get_args(epochs=10)``.

    Raises
    ------
    KeyError
        If an override names a key that is not part of the configuration.
    """
    args = {'epochs': 3, 'batchsize': 15, 'lr': 0.0001, 'loadfile': False, 'scale': 1, 'val': 10}

    unknown = set(overrides) - set(args)
    if unknown:
        raise KeyError(f'Unknown argument(s): {sorted(unknown)}')
    args.update(overrides)

    return args

In [9]:
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
args = get_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}')

# Change here to adapt to your data
# n_channels=3 for RGB images (this dataset is used with 4 input channels)
# n_classes is the number of probabilities you want to get per pixel
#   - For 1 class and background, use n_classes=1
#   - For 2 classes, use n_classes=1
#   - For N > 2 classes, use n_classes=N
net = UNet(n_channels=4, n_classes=1, bilinear=True)
logging.info(f'Network:\n'
             f'\t{net.n_channels} input channels\n'
             f'\t{net.n_classes} output channels (classes)\n'
             f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')

if args['loadfile']:
    net.load_state_dict(
        torch.load(args['loadfile'], map_location=device)
    )
    logging.info(f"Model loaded from {args['loadfile']}")

net.to(device=device)
# Optional: torch.backends.cudnn.benchmark = True gives faster convolutions
# at the cost of more memory.

try:
    train_net(net=net,
              epochs=args['epochs'],
              batch_size=args['batchsize'],
              lr=args['lr'],
              device=device,
              img_scale=args['scale'],
              val_percent=args['val'] / 100)
except KeyboardInterrupt:
    # Keep the partially trained weights. Deliberately no sys.exit()/os._exit():
    # inside Jupyter os._exit(0) terminates the kernel process and discards all
    # notebook state.
    torch.save(net.state_dict(), 'INTERRUPTED.pth')
    logging.info('Saved interrupt')

INFO: Using device cuda
INFO: Network:
	4 input channels
	1 output channels (classes)
	Bilinear upscaling
INFO: Creating dataset with 20928 examples
INFO: Starting training:
        Epochs:          3
        Batch size:      15
        Learning rate:   0.0001
        Training size:   18836
        Validation size: 2092
        Checkpoints:     True
        Device:          cuda
        Images scaling:  1
    
Epoch 1/3:  10%|▉         | 1875/18836 [00:35<05:14, 53.98img/s, loss (batch)=0.0129]
Validation round:   0%|          | 0/139 [00:00<?, ?batch/s][A
Validation round:   1%|          | 1/139 [00:00<00:45,  3.01batch/s][A
Validation round:   2%|▏         | 3/139 [00:00<00:35,  3.83batch/s][A
Validation round:   4%|▎         | 5/139 [00:00<00:28,  4.75batch/s][A
Validation round:   5%|▌         | 7/139 [00:00<00:23,  5.69batch/s][A
Validation round:   6%|▋         | 9/139 [00:01<00:19,  6.63batch/s][A
Validation round:   8%|▊         | 11/139 [00:01<00:17,  7.50batch/s][A
Val

Validation round:  13%|█▎        | 18/139 [00:01<00:12,  9.77batch/s][A
Validation round:  14%|█▍        | 20/139 [00:02<00:11, 10.03batch/s][A
Validation round:  16%|█▌        | 22/139 [00:02<00:11, 10.35batch/s][A
Validation round:  17%|█▋        | 24/139 [00:02<00:10, 10.57batch/s][A
Validation round:  19%|█▊        | 26/139 [00:02<00:10, 10.72batch/s][A
Validation round:  20%|██        | 28/139 [00:02<00:10, 10.84batch/s][A
Validation round:  22%|██▏       | 30/139 [00:03<00:10, 10.90batch/s][A
Validation round:  23%|██▎       | 32/139 [00:03<00:09, 10.91batch/s][A
Validation round:  24%|██▍       | 34/139 [00:03<00:09, 10.86batch/s][A
Validation round:  26%|██▌       | 36/139 [00:03<00:09, 10.84batch/s][A
Validation round:  27%|██▋       | 38/139 [00:03<00:09, 10.77batch/s][A
Validation round:  29%|██▉       | 40/139 [00:03<00:09, 10.75batch/s][A
Validation round:  30%|███       | 42/139 [00:04<00:09, 10.74batch/s][A
Validation round:  32%|███▏      | 44/139 [00:04<00

Validation round:  68%|██████▊   | 95/139 [00:09<00:04, 10.51batch/s][A
Validation round:  70%|██████▉   | 97/139 [00:09<00:04, 10.33batch/s][A
Validation round:  71%|███████   | 99/139 [00:09<00:03, 10.24batch/s][A
Validation round:  73%|███████▎  | 101/139 [00:09<00:03, 10.33batch/s][A
Validation round:  74%|███████▍  | 103/139 [00:09<00:03, 10.42batch/s][A
Validation round:  76%|███████▌  | 105/139 [00:10<00:03, 10.31batch/s][A
Validation round:  77%|███████▋  | 107/139 [00:10<00:03, 10.23batch/s][A
Validation round:  78%|███████▊  | 109/139 [00:10<00:02, 10.32batch/s][A
Validation round:  80%|███████▉  | 111/139 [00:10<00:02, 10.37batch/s][A
Validation round:  81%|████████▏ | 113/139 [00:10<00:02, 10.48batch/s][A
Validation round:  83%|████████▎ | 115/139 [00:11<00:02, 10.47batch/s][A
Validation round:  84%|████████▍ | 117/139 [00:11<00:02, 10.50batch/s][A
Validation round:  86%|████████▌ | 119/139 [00:11<00:01, 10.40batch/s][A
Validation round:  87%|████████▋ | 121/13

Validation round:  13%|█▎        | 18/139 [00:01<00:12,  9.84batch/s][A
Validation round:  14%|█▍        | 20/139 [00:02<00:11, 10.06batch/s][A
Validation round:  16%|█▌        | 22/139 [00:02<00:11, 10.34batch/s][A
Validation round:  17%|█▋        | 24/139 [00:02<00:10, 10.56batch/s][A
Validation round:  19%|█▊        | 26/139 [00:02<00:10, 10.71batch/s][A
Validation round:  20%|██        | 28/139 [00:02<00:10, 10.83batch/s][A
Validation round:  22%|██▏       | 30/139 [00:03<00:09, 10.92batch/s][A
Validation round:  23%|██▎       | 32/139 [00:03<00:09, 10.95batch/s][A
Validation round:  24%|██▍       | 34/139 [00:03<00:09, 10.97batch/s][A
Validation round:  26%|██▌       | 36/139 [00:03<00:09, 10.97batch/s][A
Validation round:  27%|██▋       | 38/139 [00:03<00:09, 10.95batch/s][A
Validation round:  29%|██▉       | 40/139 [00:03<00:09, 10.92batch/s][A
Validation round:  30%|███       | 42/139 [00:04<00:08, 10.89batch/s][A
Validation round:  32%|███▏      | 44/139 [00:04<00

Validation round:  68%|██████▊   | 95/139 [00:08<00:04, 10.99batch/s][A
Validation round:  70%|██████▉   | 97/139 [00:09<00:03, 10.94batch/s][A
Validation round:  71%|███████   | 99/139 [00:09<00:03, 10.94batch/s][A
Validation round:  73%|███████▎  | 101/139 [00:09<00:03, 10.93batch/s][A
Validation round:  74%|███████▍  | 103/139 [00:09<00:03, 10.92batch/s][A
Validation round:  76%|███████▌  | 105/139 [00:09<00:03, 10.92batch/s][A
Validation round:  77%|███████▋  | 107/139 [00:09<00:02, 10.93batch/s][A
Validation round:  78%|███████▊  | 109/139 [00:10<00:02, 10.96batch/s][A
Validation round:  80%|███████▉  | 111/139 [00:10<00:02, 11.02batch/s][A
Validation round:  81%|████████▏ | 113/139 [00:10<00:02, 11.05batch/s][A
Validation round:  83%|████████▎ | 115/139 [00:10<00:02, 11.08batch/s][A
Validation round:  84%|████████▍ | 117/139 [00:10<00:01, 11.11batch/s][A
Validation round:  86%|████████▌ | 119/139 [00:11<00:01, 11.11batch/s][A
Validation round:  87%|████████▋ | 121/13

Validation round:  17%|█▋        | 23/139 [00:02<00:11, 10.52batch/s][A
Validation round:  18%|█▊        | 25/139 [00:02<00:10, 10.63batch/s][A
Validation round:  19%|█▉        | 27/139 [00:02<00:10, 10.72batch/s][A
Validation round:  21%|██        | 29/139 [00:02<00:10, 10.75batch/s][A
Validation round:  22%|██▏       | 31/139 [00:03<00:10, 10.79batch/s][A
Validation round:  24%|██▎       | 33/139 [00:03<00:09, 10.82batch/s][A
Validation round:  25%|██▌       | 35/139 [00:03<00:09, 10.84batch/s][A
Validation round:  27%|██▋       | 37/139 [00:03<00:09, 10.86batch/s][A
Validation round:  28%|██▊       | 39/139 [00:03<00:09, 10.95batch/s][A
Validation round:  29%|██▉       | 41/139 [00:03<00:08, 10.99batch/s][A
Validation round:  31%|███       | 43/139 [00:04<00:08, 10.99batch/s][A
Validation round:  32%|███▏      | 45/139 [00:04<00:08, 11.03batch/s][A
Validation round:  34%|███▍      | 47/139 [00:04<00:08, 11.07batch/s][A
Validation round:  35%|███▌      | 49/139 [00:04<00

Validation round:  71%|███████   | 99/139 [00:09<00:03, 10.63batch/s][A
Validation round:  73%|███████▎  | 101/139 [00:09<00:03, 10.61batch/s][A
Validation round:  74%|███████▍  | 103/139 [00:09<00:03, 10.69batch/s][A
Validation round:  76%|███████▌  | 105/139 [00:10<00:03, 10.77batch/s][A
Validation round:  77%|███████▋  | 107/139 [00:10<00:02, 10.80batch/s][A
Validation round:  78%|███████▊  | 109/139 [00:10<00:02, 10.81batch/s][A
Validation round:  80%|███████▉  | 111/139 [00:10<00:02, 10.84batch/s][A
Validation round:  81%|████████▏ | 113/139 [00:10<00:02, 10.85batch/s][A
Validation round:  83%|████████▎ | 115/139 [00:10<00:02, 10.93batch/s][A
Validation round:  84%|████████▍ | 117/139 [00:11<00:02, 10.98batch/s][A
Validation round:  86%|████████▌ | 119/139 [00:11<00:01, 11.01batch/s][A
Validation round:  87%|████████▋ | 121/139 [00:11<00:01, 11.03batch/s][A
Validation round:  88%|████████▊ | 123/139 [00:11<00:01, 11.07batch/s][A
Validation round:  90%|████████▉ | 125/

Validation round:  12%|█▏        | 17/139 [00:01<00:12,  9.45batch/s][A
Validation round:  13%|█▎        | 18/139 [00:02<00:12,  9.61batch/s][A
Validation round:  14%|█▍        | 20/139 [00:02<00:11,  9.93batch/s][A
Validation round:  16%|█▌        | 22/139 [00:02<00:11, 10.18batch/s][A
Validation round:  17%|█▋        | 24/139 [00:02<00:11, 10.39batch/s][A
Validation round:  19%|█▊        | 26/139 [00:02<00:10, 10.50batch/s][A
Validation round:  20%|██        | 28/139 [00:02<00:10, 10.58batch/s][A
Validation round:  22%|██▏       | 30/139 [00:03<00:10, 10.66batch/s][A
Validation round:  23%|██▎       | 32/139 [00:03<00:10, 10.66batch/s][A
Validation round:  24%|██▍       | 34/139 [00:03<00:09, 10.75batch/s][A
Validation round:  26%|██▌       | 36/139 [00:03<00:09, 10.81batch/s][A
Validation round:  27%|██▋       | 38/139 [00:03<00:09, 10.82batch/s][A
Validation round:  29%|██▉       | 40/139 [00:04<00:09, 10.84batch/s][A
Validation round:  30%|███       | 42/139 [00:04<00

Validation round:  67%|██████▋   | 93/139 [00:08<00:04, 11.05batch/s][A
Validation round:  68%|██████▊   | 95/139 [00:08<00:03, 11.04batch/s][A
Validation round:  70%|██████▉   | 97/139 [00:09<00:03, 10.98batch/s][A
Validation round:  71%|███████   | 99/139 [00:09<00:03, 10.96batch/s][A
Validation round:  73%|███████▎  | 101/139 [00:09<00:03, 10.93batch/s][A
Validation round:  74%|███████▍  | 103/139 [00:09<00:03, 10.90batch/s][A
Validation round:  76%|███████▌  | 105/139 [00:09<00:03, 10.89batch/s][A
Validation round:  77%|███████▋  | 107/139 [00:10<00:02, 10.85batch/s][A
Validation round:  78%|███████▊  | 109/139 [00:10<00:02, 10.84batch/s][A
Validation round:  80%|███████▉  | 111/139 [00:10<00:02, 10.87batch/s][A
Validation round:  81%|████████▏ | 113/139 [00:10<00:02, 10.86batch/s][A
Validation round:  83%|████████▎ | 115/139 [00:10<00:02, 10.86batch/s][A
Validation round:  84%|████████▍ | 117/139 [00:10<00:02, 10.90batch/s][A
Validation round:  86%|████████▌ | 119/139

Validation round:  15%|█▌        | 21/139 [00:02<00:11, 10.22batch/s][A
Validation round:  17%|█▋        | 23/139 [00:02<00:11, 10.47batch/s][A
Validation round:  18%|█▊        | 25/139 [00:02<00:10, 10.59batch/s][A
Validation round:  19%|█▉        | 27/139 [00:02<00:10, 10.68batch/s][A
Validation round:  21%|██        | 29/139 [00:02<00:10, 10.73batch/s][A
Validation round:  22%|██▏       | 31/139 [00:03<00:10, 10.77batch/s][A
Validation round:  24%|██▎       | 33/139 [00:03<00:09, 10.80batch/s][A
Validation round:  25%|██▌       | 35/139 [00:03<00:09, 10.83batch/s][A
Validation round:  27%|██▋       | 37/139 [00:03<00:09, 10.83batch/s][A
Validation round:  28%|██▊       | 39/139 [00:03<00:09, 10.84batch/s][A
Validation round:  29%|██▉       | 41/139 [00:04<00:09, 10.84batch/s][A
Validation round:  31%|███       | 43/139 [00:04<00:08, 10.92batch/s][A
Validation round:  32%|███▏      | 45/139 [00:04<00:08, 10.98batch/s][A
Validation round:  34%|███▍      | 47/139 [00:04<00

Validation round:  65%|██████▍   | 90/139 [00:08<00:04,  9.96batch/s][A
Validation round:  65%|██████▌   | 91/139 [00:08<00:04,  9.91batch/s][A
Validation round:  66%|██████▌   | 92/139 [00:09<00:04,  9.90batch/s][A
Validation round:  68%|██████▊   | 94/139 [00:09<00:04, 10.21batch/s][A
Validation round:  69%|██████▉   | 96/139 [00:09<00:04, 10.45batch/s][A
Validation round:  71%|███████   | 98/139 [00:09<00:03, 10.62batch/s][A
Validation round:  72%|███████▏  | 100/139 [00:09<00:03, 10.73batch/s][A
Validation round:  73%|███████▎  | 102/139 [00:09<00:03, 10.81batch/s][A
Validation round:  75%|███████▍  | 104/139 [00:10<00:03, 10.79batch/s][A
Validation round:  76%|███████▋  | 106/139 [00:10<00:03, 10.71batch/s][A
Epoch 2/3:  49%|████▉     | 9285/18836 [03:54<02:57, 53.88img/s, loss (batch)=1.07e-5]
Validation round:  79%|███████▉  | 110/139 [00:10<00:02, 10.24batch/s][A
Validation round:  81%|████████  | 112/139 [00:10<00:02, 10.13batch/s][A
Validation round:  82%|████████

Validation round:   9%|▉         | 13/139 [00:01<00:15,  8.06batch/s][A
Validation round:  11%|█         | 15/139 [00:01<00:14,  8.67batch/s][A
Validation round:  12%|█▏        | 17/139 [00:01<00:13,  9.23batch/s][A
Validation round:  14%|█▎        | 19/139 [00:02<00:12,  9.69batch/s][A
Validation round:  15%|█▌        | 21/139 [00:02<00:11, 10.03batch/s][A
Validation round:  17%|█▋        | 23/139 [00:02<00:11, 10.30batch/s][A
Validation round:  18%|█▊        | 25/139 [00:02<00:10, 10.53batch/s][A
Validation round:  19%|█▉        | 27/139 [00:02<00:10, 10.69batch/s][A
Validation round:  21%|██        | 29/139 [00:02<00:10, 10.81batch/s][A
Validation round:  22%|██▏       | 31/139 [00:03<00:09, 10.91batch/s][A
Validation round:  24%|██▎       | 33/139 [00:03<00:09, 10.97batch/s][A
Validation round:  25%|██▌       | 35/139 [00:03<00:09, 10.97batch/s][A
Validation round:  27%|██▋       | 37/139 [00:03<00:09, 11.01batch/s][A
Validation round:  28%|██▊       | 39/139 [00:03<00

Validation round:  53%|█████▎    | 74/139 [00:07<00:06, 10.45batch/s][A
Validation round:  55%|█████▍    | 76/139 [00:07<00:05, 10.51batch/s][A
Validation round:  56%|█████▌    | 78/139 [00:07<00:05, 10.67batch/s][A
Validation round:  58%|█████▊    | 80/139 [00:08<00:05, 10.79batch/s][A
Validation round:  59%|█████▉    | 82/139 [00:08<00:05, 10.87batch/s][A
Validation round:  60%|██████    | 84/139 [00:08<00:05, 10.92batch/s][A
Validation round:  62%|██████▏   | 86/139 [00:08<00:04, 10.96batch/s][A
Validation round:  63%|██████▎   | 88/139 [00:08<00:04, 10.92batch/s][A
Validation round:  65%|██████▍   | 90/139 [00:08<00:04, 10.88batch/s][A
Validation round:  66%|██████▌   | 92/139 [00:09<00:04, 10.86batch/s][A
Validation round:  68%|██████▊   | 94/139 [00:09<00:04, 10.86batch/s][A
Validation round:  69%|██████▉   | 96/139 [00:09<00:03, 10.82batch/s][A
Validation round:  71%|███████   | 98/139 [00:09<00:03, 10.82batch/s][A
Validation round:  72%|███████▏  | 100/139 [00:09<0

Validation round:   1%|▏         | 2/139 [00:00<00:34,  4.02batch/s][A
Validation round:   3%|▎         | 4/139 [00:00<00:27,  4.94batch/s][A
Validation round:   4%|▍         | 6/139 [00:00<00:22,  5.93batch/s][A
Validation round:   6%|▌         | 8/139 [00:00<00:19,  6.89batch/s][A
Validation round:   7%|▋         | 10/139 [00:01<00:16,  7.78batch/s][A
Validation round:   9%|▊         | 12/139 [00:01<00:14,  8.55batch/s][A
Validation round:  10%|█         | 14/139 [00:01<00:13,  9.17batch/s][A
Validation round:  12%|█▏        | 16/139 [00:01<00:12,  9.66batch/s][A
Validation round:  13%|█▎        | 18/139 [00:01<00:12, 10.06batch/s][A
Validation round:  14%|█▍        | 20/139 [00:02<00:11, 10.36batch/s][A
Validation round:  16%|█▌        | 22/139 [00:02<00:11, 10.57batch/s][A
Validation round:  17%|█▋        | 24/139 [00:02<00:10, 10.68batch/s][A
Validation round:  19%|█▊        | 26/139 [00:02<00:10, 10.73batch/s][A
Validation round:  20%|██        | 28/139 [00:02<00:10,

Validation round:  54%|█████▍    | 75/139 [00:07<00:05, 10.83batch/s][A
Validation round:  55%|█████▌    | 77/139 [00:07<00:05, 10.84batch/s][A
Validation round:  57%|█████▋    | 79/139 [00:07<00:05, 10.82batch/s][A
Validation round:  58%|█████▊    | 81/139 [00:07<00:05, 10.81batch/s][A
Validation round:  60%|█████▉    | 83/139 [00:07<00:05, 10.82batch/s][A
Validation round:  61%|██████    | 85/139 [00:08<00:05, 10.77batch/s][A
Validation round:  63%|██████▎   | 87/139 [00:08<00:04, 10.80batch/s][A
Validation round:  64%|██████▍   | 89/139 [00:08<00:04, 10.87batch/s][A
Validation round:  65%|██████▌   | 91/139 [00:08<00:04, 10.95batch/s][A
Validation round:  67%|██████▋   | 93/139 [00:08<00:04, 10.97batch/s][A
Validation round:  68%|██████▊   | 95/139 [00:08<00:03, 11.00batch/s][A
Validation round:  70%|██████▉   | 97/139 [00:09<00:03, 11.01batch/s][A
Validation round:  71%|███████   | 99/139 [00:09<00:03, 11.03batch/s][A
Validation round:  73%|███████▎  | 101/139 [00:09<0

Validation round:   2%|▏         | 3/139 [00:00<00:35,  3.80batch/s][A
Validation round:   4%|▎         | 5/139 [00:00<00:28,  4.70batch/s][A
Validation round:   5%|▌         | 7/139 [00:00<00:23,  5.67batch/s][A
Validation round:   6%|▋         | 9/139 [00:01<00:19,  6.61batch/s][A
Validation round:   8%|▊         | 11/139 [00:01<00:17,  7.49batch/s][A
Validation round:   9%|▉         | 13/139 [00:01<00:15,  8.26batch/s][A
Validation round:  11%|█         | 15/139 [00:01<00:13,  8.91batch/s][A
Validation round:  12%|█▏        | 17/139 [00:01<00:12,  9.41batch/s][A
Validation round:  14%|█▎        | 19/139 [00:02<00:12,  9.81batch/s][A
Validation round:  15%|█▌        | 21/139 [00:02<00:11, 10.10batch/s][A
Validation round:  17%|█▋        | 23/139 [00:02<00:11, 10.32batch/s][A
Validation round:  18%|█▊        | 25/139 [00:02<00:10, 10.53batch/s][A
Validation round:  19%|█▉        | 27/139 [00:02<00:10, 10.68batch/s][A
Validation round:  21%|██        | 29/139 [00:02<00:10,

Validation round:  55%|█████▌    | 77/139 [00:07<00:05, 10.56batch/s][A
Validation round:  57%|█████▋    | 79/139 [00:07<00:05, 10.56batch/s][A
Validation round:  58%|█████▊    | 81/139 [00:07<00:05, 10.71batch/s][A
Validation round:  60%|█████▉    | 83/139 [00:08<00:05, 10.82batch/s][A
Validation round:  61%|██████    | 85/139 [00:08<00:04, 10.88batch/s][A
Validation round:  63%|██████▎   | 87/139 [00:08<00:04, 10.94batch/s][A
Validation round:  64%|██████▍   | 89/139 [00:08<00:04, 10.98batch/s][A
Validation round:  65%|██████▌   | 91/139 [00:08<00:04, 10.97batch/s][A
Validation round:  67%|██████▋   | 93/139 [00:09<00:04, 10.94batch/s][A
Validation round:  68%|██████▊   | 95/139 [00:09<00:04, 10.90batch/s][A
Validation round:  70%|██████▉   | 97/139 [00:09<00:03, 10.89batch/s][A
Validation round:  71%|███████   | 99/139 [00:09<00:03, 10.89batch/s][A
Validation round:  73%|███████▎  | 101/139 [00:09<00:03, 10.89batch/s][A
Validation round:  74%|███████▍  | 103/139 [00:09<

Validation round:   4%|▎         | 5/139 [00:00<00:27,  4.94batch/s][A
Validation round:   5%|▌         | 7/139 [00:00<00:22,  5.92batch/s][A
Validation round:   6%|▋         | 9/139 [00:01<00:18,  6.88batch/s][A
Validation round:   8%|▊         | 11/139 [00:01<00:16,  7.76batch/s][A
Validation round:   9%|▉         | 13/139 [00:01<00:14,  8.54batch/s][A
Validation round:  11%|█         | 15/139 [00:01<00:13,  9.16batch/s][A
Validation round:  12%|█▏        | 17/139 [00:01<00:12,  9.62batch/s][A
Validation round:  14%|█▎        | 19/139 [00:01<00:12,  9.96batch/s][A
Validation round:  15%|█▌        | 21/139 [00:02<00:11, 10.20batch/s][A
Validation round:  17%|█▋        | 23/139 [00:02<00:11, 10.38batch/s][A
Validation round:  18%|█▊        | 25/139 [00:02<00:10, 10.51batch/s][A
Validation round:  19%|█▉        | 27/139 [00:02<00:10, 10.62batch/s][A
Validation round:  21%|██        | 29/139 [00:02<00:10, 10.69batch/s][A
Validation round:  22%|██▏       | 31/139 [00:03<00:10

Validation round:  58%|█████▊    | 81/139 [00:07<00:05, 10.82batch/s][A
Validation round:  60%|█████▉    | 83/139 [00:07<00:05, 10.87batch/s][A
Validation round:  61%|██████    | 85/139 [00:08<00:04, 10.93batch/s][A
Validation round:  63%|██████▎   | 87/139 [00:08<00:04, 10.96batch/s][A
Validation round:  64%|██████▍   | 89/139 [00:08<00:04, 11.00batch/s][A
Validation round:  65%|██████▌   | 91/139 [00:08<00:04, 11.02batch/s][A
Validation round:  67%|██████▋   | 93/139 [00:08<00:04, 11.01batch/s][A
Validation round:  68%|██████▊   | 95/139 [00:08<00:03, 11.02batch/s][A
Validation round:  70%|██████▉   | 97/139 [00:09<00:03, 11.04batch/s][A
Validation round:  71%|███████   | 99/139 [00:09<00:03, 11.06batch/s][A
Validation round:  73%|███████▎  | 101/139 [00:09<00:03, 11.05batch/s][A
Validation round:  74%|███████▍  | 103/139 [00:09<00:03, 11.06batch/s][A
Validation round:  76%|███████▌  | 105/139 [00:09<00:03, 11.05batch/s][A
Validation round:  77%|███████▋  | 107/139 [00:0

Validation round:   6%|▋         | 9/139 [00:01<00:20,  6.30batch/s][A
Validation round:   8%|▊         | 11/139 [00:01<00:17,  7.23batch/s][A
Validation round:   9%|▉         | 13/139 [00:01<00:15,  8.07batch/s][A
Validation round:  11%|█         | 15/139 [00:01<00:14,  8.79batch/s][A
Validation round:  12%|█▏        | 17/139 [00:01<00:13,  9.38batch/s][A
Validation round:  14%|█▎        | 19/139 [00:02<00:12,  9.81batch/s][A
Validation round:  15%|█▌        | 21/139 [00:02<00:11, 10.14batch/s][A
Validation round:  17%|█▋        | 23/139 [00:02<00:11, 10.40batch/s][A
Validation round:  18%|█▊        | 25/139 [00:02<00:10, 10.59batch/s][A
Validation round:  19%|█▉        | 27/139 [00:02<00:10, 10.73batch/s][A
Validation round:  21%|██        | 29/139 [00:02<00:10, 10.84batch/s][A
Validation round:  22%|██▏       | 31/139 [00:03<00:09, 10.90batch/s][A
Validation round:  24%|██▎       | 33/139 [00:03<00:09, 10.88batch/s][A
Validation round:  25%|██▌       | 35/139 [00:03<00:

Validation round:  52%|█████▏    | 72/139 [00:07<00:06, 10.18batch/s][A
Validation round:  53%|█████▎    | 74/139 [00:07<00:06, 10.22batch/s][A
Validation round:  55%|█████▍    | 76/139 [00:07<00:06, 10.31batch/s][A
Validation round:  56%|█████▌    | 78/139 [00:07<00:05, 10.33batch/s][A
Validation round:  58%|█████▊    | 80/139 [00:08<00:05, 10.36batch/s][A
Validation round:  59%|█████▉    | 82/139 [00:08<00:05, 10.38batch/s][A
Validation round:  60%|██████    | 84/139 [00:08<00:05, 10.50batch/s][A
Validation round:  62%|██████▏   | 86/139 [00:08<00:05, 10.48batch/s][A
Validation round:  63%|██████▎   | 88/139 [00:08<00:04, 10.57batch/s][A
Validation round:  65%|██████▍   | 90/139 [00:08<00:04, 10.57batch/s][A
Validation round:  66%|██████▌   | 92/139 [00:09<00:04, 10.62batch/s][A
Validation round:  68%|██████▊   | 94/139 [00:09<00:04, 10.66batch/s][A
Validation round:  69%|██████▉   | 96/139 [00:09<00:04, 10.62batch/s][A
Validation round:  71%|███████   | 98/139 [00:09<00

NameError: name 'writer' is not defined

In [7]:
# Scratch cell: rebinds `args` to a partial configuration holding only
# epochs and lr. NOTE(review): keys such as 'batchsize' are absent afterwards,
# so re-run the get_args() cell before training again.
args = dict(epochs=3, lr=0.001)

In [13]:
# Display the learning rate currently stored in `args`.
args["lr"]

0.001

In [None]:
# Settings for a torchvision ImageFolder pipeline over ./maps/.
# NOTE(review): `lr` and `epoch` are defined here but not used in this cell.
batch_size = 15
img_size = 256
lr = 0.001
epoch = 10

img_dir = Path("./maps/")

# Resize so the shorter side equals img_size, then take a centred crop of
# height img_size and width 2 * img_size, and convert to a tensor.
img_transform = transforms.Compose([
    transforms.Resize(size=img_size),
    transforms.CenterCrop(size=(img_size, img_size * 2)),
    transforms.ToTensor(),
])

# `dataset` / `data` here are the torchvision.datasets and torch.utils.data
# aliases imported in the second cell.
img_data = dataset.ImageFolder(root=img_dir, transform=img_transform)
img_batch = data.DataLoader(img_data, batch_size=batch_size, shuffle=True)