In [3]:
import torch
import torchvision
import tqdm
import pandas as pd
import pprint
import itertools
import os
import pylab as plt
import time
import numpy as np

from src import models
from src import datasets


import argparse

from torch.utils.data import sampler
from torch.utils.data.sampler import RandomSampler
from torch.backends import cudnn
from torch.nn import functional as F
from torch.utils.data import DataLoader

In [7]:
def train(exp_dict, savedir, datadir, reset=False, num_workers=0):
    os.makedirs(savedir, exist_ok=True)
    # Dataset
    # ==================
    # train set
    train_set = datasets.get_dataset(dataset_dict=exp_dict["dataset"],
                                     split="train",
                                     datadir=datadir,
                                     exp_dict=exp_dict,
                                     dataset_size=exp_dict['dataset_size'])
    # val set
    val_set = datasets.get_dataset(dataset_dict=exp_dict["dataset"],
                                   split="val",
                                   datadir=datadir,
                                   exp_dict=exp_dict,
                                   dataset_size=exp_dict['dataset_size'])

    val_sampler = torch.utils.data.SequentialSampler(val_set)
    val_loader = DataLoader(val_set,
                            sampler=val_sampler,
                            batch_size=1,
                            num_workers=num_workers)
    # Model
    # ==================
    model = models.get_model(model_dict=exp_dict['model'],
                             exp_dict=exp_dict,
                             train_set=train_set).cuda()

    # model.opt = optimizers.get_optim(exp_dict['opt'], model)
    model_path = os.path.join(savedir, "model.pth")
    score_list_path = os.path.join(savedir, "score_list.pkl")

    # Train & Val
    # ==================
    train_sampler = torch.utils.data.RandomSampler(
        train_set, replacement=True, num_samples=2*len(val_set))

    train_loader = DataLoader(train_set,
                              sampler=train_sampler,
                              batch_size=exp_dict["batch_size"], 
                              drop_last=True, num_workers=num_workers)
    
    best_val = 100000
    epochs_without_improvement = 0
    for e in range(exp_dict['max_epoch']):
        # Validate only at the start of each cycle
        score_dict = {}

        # Train the model
        train_dict = model.train_on_loader(train_loader)
        print(f'train_dict: {train_dict}')

        # Validate and Visualize the model
        val_dict = model.val_on_loader(val_loader, 
                        savedir_images=os.path.join(savedir, "images"),
                        n_images=3)
        print(f'val_dict: {val_dict}')  
        if(val_dict['val_mae'] < best_val):
            print('Better validation')
            best_val = val_dict['val_mae']   
            epochs_without_improvement = 0
        elif(epochs_without_improvement > 3):
            print('No improvement for 4 epochs')
            break
        else:
            epochs_without_improvement += 1
        score_dict.update(val_dict)
        # model.vis_on_loader(
        #     vis_loader, savedir=os.path.join(savedir, "images"))

        # Get new score_dict
        score_dict.update(train_dict)
        score_dict["epoch"] = len(score_list)

        # Add to score_list and save checkpoint
        score_list += [score_dict]

        # Report & Save
        score_df = pd.DataFrame(score_list)
        print("\n", score_df.tail(), "\n")
        '''
        hu.torch_save(model_path, model.get_state_dict())
        hu.save_pkl(score_list_path, score_list)
        '''
        print("Checkpoint Saved: %s" % savedir)

        # Save Best Checkpoint
        '''
        if e == 0 or (score_dict.get("val_score", 0) > score_df["val_score"][:-1].fillna(0).max()):
            hu.save_pkl(os.path.join(
                savedir, "score_list_best.pkl"), score_list)
            hu.torch_save(os.path.join(savedir, "model_best.pth"),
                          model.get_state_dict())
            print("Saved Best: %s" % savedir)
        '''

    print('Experiment completed et epoch %d' % e)

In [5]:
exp_dict = {"dataset":
               {'name': 'trancos',
                'transform': 'rgb_normalize'},
            'model':
               {'name': 'lcfcn',
                'base': 'fcn8_vgg16'},
            'batch_size': 8,
            'max_epoch': 10,
            'dataset_size': 
               {'train': 'all',
                'val': 'all'},
            'optimizer': 'adam',
            'lr': 1e-5}
datadir = 'TRANCOS_v3/'

In [9]:
train(
    exp_dict=exp_dict,
    savedir='custom_output',
    datadir=datadir,
    num_workers=1)

  return torch.LongTensor(np.asarray(x))
Training. Loss: 5.6921: 100%|██████████| 105/105 [01:45<00:00,  1.00s/it]
  0%|          | 0/420 [00:00<?, ?it/s]
  0%|          | 0/420 [00:00<?, ?it/s][A

train_dict: {'train_loss': 5.692069660978658}


  return torch.LongTensor(np.asarray(x))
Validating. MAE: 13.0000:   0%|          | 1/420 [00:00<01:01,  6.85it/s]
Validating. MAE: 9.0000:   0%|          | 2/420 [00:00<00:58,  7.16it/s] 
Validating. MAE: 8.0000:   1%|          | 3/420 [00:00<00:51,  8.16it/s]
Validating. MAE: 8.0000:   1%|          | 4/420 [00:00<00:47,  8.70it/s]
Validating. MAE: 8.0000:   2%|▏         | 7/420 [00:00<00:31, 13.17it/s]
Validating. MAE: 8.0000:   2%|▏         | 9/420 [00:00<00:31, 12.89it/s]
Validating. MAE: 8.0000:   3%|▎         | 11/420 [00:00<00:30, 13.22it/s]
Validating. MAE: 8.0000:   3%|▎         | 13/420 [00:01<00:30, 13.35it/s]
Validating. MAE: 8.0000:   4%|▎         | 15/420 [00:01<00:30, 13.34it/s]
Validating. MAE: 8.0000:   4%|▍         | 17/420 [00:01<00:27, 14.40it/s]
Validating. MAE: 8.0000:   5%|▍         | 19/420 [00:01<00:26, 15.29it/s]
Validating. MAE: 8.0000:   5%|▌         | 21/420 [00:01<00:28, 14.12it/s]
Validating. MAE: 8.0000:   5%|▌         | 23/420 [00:01<00:30, 13.19it/s]
V

Validating. MAE: 8.0000:  48%|████▊     | 200/420 [00:15<00:15, 14.14it/s]
Validating. MAE: 8.0000:  48%|████▊     | 202/420 [00:15<00:15, 14.44it/s]
Validating. MAE: 8.0000:  49%|████▊     | 204/420 [00:15<00:15, 14.03it/s]
Validating. MAE: 8.0000:  49%|████▉     | 206/420 [00:15<00:15, 13.47it/s]
Validating. MAE: 8.0000:  50%|████▉     | 208/420 [00:15<00:15, 13.29it/s]
Validating. MAE: 8.0000:  50%|█████     | 210/420 [00:15<00:16, 13.00it/s]
Validating. MAE: 8.0000:  50%|█████     | 212/420 [00:15<00:16, 12.80it/s]
Validating. MAE: 8.0000:  51%|█████     | 214/420 [00:16<00:16, 12.72it/s]
Validating. MAE: 8.0000:  51%|█████▏    | 216/420 [00:16<00:15, 12.76it/s]
Validating. MAE: 8.0000:  52%|█████▏    | 218/420 [00:16<00:14, 13.90it/s]
Validating. MAE: 8.0000:  52%|█████▏    | 220/420 [00:16<00:13, 14.91it/s]
Validating. MAE: 8.0000:  53%|█████▎    | 222/420 [00:16<00:12, 15.68it/s]
Validating. MAE: 8.0000:  53%|█████▎    | 224/420 [00:16<00:11, 16.34it/s]
Validating. MAE: 8.0000: 

val_dict: {'val_mae': 1.7166666666666666, 'val_score': -1.7166666666666666}
Better validation





UnboundLocalError: local variable 'score_list' referenced before assignment