In [None]:
%matplotlib inline

import os
import time
import math
import glob
import shutil
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import h5py

from pathlib import Path
from torch.utils.data import Dataset, DataLoader, sampler
from torch.cuda.amp import GradScaler, autocast
from PIL import Image

import torch
from torch.nn import functional as F
from torch.autograd import Variable

##############################
from models import deeplab_resnet_hyper
from models import deeplab_xception
from datasets import datasets

from utils import manager as mgr
from utils import metrics
from utils import losses
from utils import log
from utils import img_utils
##############################

# ---- Experiment / model configuration --------------------------------------
nr_classes = 4
nr_channels = 7
exp_name = 'deeplab_std_resnet_dc_256_3_os8_flip_ac12'
DEVICE = "cuda:0"
# Checkpoint directory for this experiment; the trailing slash lets a file
# name be appended with plain string concatenation.
WEIGHTS_PATH = 'weights/' + exp_name + '/'

# Default backbone ('resnet' or 'xception'); the experiment-lookup cell sets it
# again for every registered exp_name.
backbone = 'resnet'

# ---- Device -----------------------------------------------------------------
# Kept as a plain string and derived from DEVICE so the two cannot drift apart.
# For a CPU fallback use:
#   device = DEVICE if torch.cuda.is_available() else "cpu"
device = DEVICE
# CUDA ordinal taken from the device string. torch.device parses multi-digit
# ordinals ("cuda:10") correctly, unlike reading the last character; the value
# is None when the string carries no index (e.g. "cpu" or plain "cuda").
device_nr = torch.device(device).index
print(device)

# ---- Data loading -------------------------------------------------------------
# Author's reference batch sizes: (256) -> 56, (512) -> 15 — presumably per
# input tile size; values also tried: 40, 56, 15.
batch_size = 24
split = {'train': 0.7, 'val': 0.1, 'test': 0.2}  # fraction of samples per subset
num_workers = 4
pin_memory = True

# ---- Optimisation schedule ----------------------------------------------------
LR = 0.001
LR_DECAY = 0.995           # presumably applied every DECAY_EVERY_N_EPOCHS — confirm in training loop
DECAY_EVERY_N_EPOCHS = 1
N_EPOCHS = 1
start_epoch = 1

In [None]:
torch.cuda.empty_cache()

# ---- Per-experiment hyper-parameters and checkpoint ---------------------------
# Defaults; the selected experiment's entry in EXPERIMENTS overrides them.
nr_classes = 4
cr = 48        # "CR" hyper-parameter (meaning defined by the model code — confirm)
dsize = 256    # decoder channels
ddepth = 3     # decoder layers
os = 16        # "OS" — presumably the DeepLab output stride.
               # NOTE(review): this rebinds `os`, shadowing the `os` module imported
               # in the first cell, so any later `os.path...` call will fail.
               # Renaming requires updating every downstream use of `os`.
up = 256       # upsample size

# exp_name -> settings. 'weights' is the checkpoint file name (presumably
# resolved against WEIGHTS_PATH); 'backbone' defaults to 'resnet'; any other key
# overrides the default of the same-named variable above.
EXPERIMENTS = {
    # backbone
    'deeplab_std_xception': dict(weights='weights-14-0.155-0.876.pth', backbone='xception'),
    'deeplab_std_resnet': dict(weights='weights-23-0.147-0.887.pth'),
    # CR
    'deeplab_std_resnet_cr32': dict(weights='weights-22-0.149-0.886.pth', cr=32),
    'deeplab_std_resnet_cr64': dict(weights='weights-11-0.148-0.885.pth', cr=64),
    # upsample
    'deeplab_std_resnet_dc_512_3_up_512': dict(weights='weights-45-0.149-0.886.pth', up=512, dsize=512),
    # decoder channels
    'deeplab_std_resnet_dc_128_3': dict(weights='weights-33-0.151-0.885.pth', dsize=128),
    'deeplab_std_resnet_dc_512_3': dict(weights='weights-15-0.148-0.886.pth', dsize=512),
    'deeplab_std_resnet_dc_1024_3': dict(weights='weights-25-0.151-0.883.pth', dsize=1024),
    # decoder layers
    'deeplab_std_resnet_dc_256_4': dict(weights='weights-43-0.149-0.886.pth', ddepth=4),
    'deeplab_std_resnet_dc_256_5': dict(weights='weights-32-0.150-0.879.pth', ddepth=5),
    'deeplab_std_resnet_dc_512_4': dict(weights='weights-30-0.148-0.882.pth', dsize=512, ddepth=4),
    # OS
    'deeplab_std_resnet_dc_256_3_os8': dict(weights='weights-40-0.147-0.881.pth', os=8),
    # input 512x512
    'deeplab_std_resnet_in_512x512': dict(weights='weights-19-0.152-0.887.pth'),
    # flip (alternative checkpoint: 'weights-46-0.134-0.897.pth')
    'deeplab_std_resnet_flip': dict(weights='weights-32-0.137-0.894.pth'),
    # flip + OS 8 (alternative checkpoints: 'weights-37-0.130-0.895.pth',
    # 'weights-46-0.129-0.898.pth')
    'deeplab_std_resnet_dc_256_3_os8_flip': dict(weights='weights-27-0.135-0.894.pth', os=8),
    # mask (nr_classes=3 was tried here and left disabled)
    'deeplab_std_resnet_dc_256_3_os8_flip_mask_std': dict(weights='weights-41-0.338-0.907.pth', os=8),
    # ac1 & ac12
    'deeplab_std_resnet_dc_256_3_os8_flip_ac1': dict(weights='weights-41-0.286-0.839.pth', os=8),
    'deeplab_std_resnet_dc_256_3_os8_flip_ac12': dict(weights='weights-42-0.273-0.824.pth', nr_classes=5, os=8),
}

_overrides = EXPERIMENTS.get(exp_name)
if _overrides is None:
    # Warn and clear WEIGHTS_FILE rather than silently keeping a value left over
    # from an earlier run of this cell with a different exp_name.
    warnings.warn(f"Unknown exp_name {exp_name!r}: no checkpoint registered; "
                  "WEIGHTS_FILE set to None and default hyper-parameters kept.")
    WEIGHTS_FILE = None
else:
    WEIGHTS_FILE = _overrides['weights']
    backbone = _overrides.get('backbone', 'resnet')
    nr_classes = _overrides.get('nr_classes', nr_classes)
    cr = _overrides.get('cr', cr)
    dsize = _overrides.get('dsize', dsize)
    ddepth = _overrides.get('ddepth', ddepth)
    os = _overrides.get('os', os)
    up = _overrides.get('up', up)

In [None]:
## Creating the dataset
# Other HDF5 variants that have been used here:
#   "/media/philipp/DATA/dataset/dataset_256_df_177.h5"
#   "/media/philipp/DATA/dataset/dataset_mask_512_df_177.h5"
path_dataset = "/media/philipp/DATA/dataset/dataset_512_df_177.h5"
dataset = datasets.ForestDataset(path_dataset, ground_truth='ground_truth_ac12')

# chunk_size passed to get_sampler: 1000 when dim 1 of the first sample's input
# is 256 (presumably the tile height), otherwise 0.
chunk_size = 1000 if dataset[0][0].shape[1] == 256 else 0

# Shuffled split into train / val / test index samplers (random_seed fixed at 399).
train_sampler, val_sampler, test_sampler = dataset.get_sampler(
    split=split,
    shuffle_dataset=True,
    random_seed=399,
    chunk_size=chunk_size,
    fold=0,
)


def _make_loader(subset_sampler):
    """Return a DataLoader over `dataset` that draws indices from `subset_sampler`,
    using the batch_size / num_workers / pin_memory settings from the config cell."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=subset_sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )


train_dl, val_dl, test_dl = (
    _make_loader(s) for s in (train_sampler, val_sampler, test_sampler)
)

# Size of the whole dataset, then the number of indices in each subset.
print(len(dataset))
for _subset in (train_sampler, val_sampler, test_sampler):
    print(len(_subset.indices))