## Steps:

1. MRIs with different resolutions (obtained from different instruments or imaging parameters) need to be resampled to
    a uniform voxel spacing
2. Then apply data augmentation to increase variability and the number of training samples
3. Make sure all volumes have the same dimensions
4. Train the model

### Step 1: MRIs with different resolutions obtained from different instruments or imaging parameters need to be resampled to a uniform voxel spacing

#### Change the values of the following parameters:

volume_path="Input path where your training dataset is stored"

out_path="Output path where you want to store the resampled volumes"

new_spacing = Voxel spacing (resolution) to which all MRIs are resampled, so that volumes acquired at different resolutions become comparable

size=Input volume dimensions used to train your model. Make sure your available GPU memory supports it and that it encompasses all resampled MRIs.




In [None]:
import numpy as np
import SimpleITK as sitk
from medpy.io import load,save
import os, shutil, glob
    
def center_pad_or_crop(img, size, dtype=np.uint16):
    """Place ``img`` centred in a zero-filled array of shape ``size``.

    Along each axis where ``img`` is no larger than ``size`` it is zero-padded,
    starting at ``(size - n) // 2``; any odd extra voxel ends up after the data.
    Where ``img`` is larger, it is centre-cropped instead of failing with a
    broadcast error. For an integer ``dtype`` the values are first clipped to
    that dtype's range, so negative interpolation overshoot does not wrap
    around. Float values are truncated toward zero on assignment.

    Parameters
    ----------
    img : numpy.ndarray
        Volume to place; only the first ``len(size)`` axes are considered.
    size : tuple of int
        Shape of the returned array.
    dtype : numpy dtype, optional
        dtype of the returned array (default ``np.uint16``).

    Returns
    -------
    numpy.ndarray of shape ``size`` and dtype ``dtype``.
    """
    canvas = np.zeros(size, dtype=dtype)
    if np.issubdtype(canvas.dtype, np.integer):
        info = np.iinfo(canvas.dtype)
        # Clip in float64 so out-of-range Python ints never meet a narrow int dtype.
        img = np.clip(np.asarray(img).astype(np.float64, copy=False), info.min, info.max)
    src, dst = [], []
    for n_in, n_out in zip(img.shape, size):
        if n_in <= n_out:
            start = (n_out - n_in) // 2
            src.append(slice(0, n_in))
            dst.append(slice(start, start + n_in))
        else:
            start = (n_in - n_out) // 2
            src.append(slice(start, start + n_out))
            dst.append(slice(0, n_out))
    canvas[tuple(dst)] = img[tuple(src)]
    return canvas


def resample_volume(volume_path,ismask, new_spacing,size ,out_path,basename):
    """Resample one volume to ``new_spacing`` and centre it in a ``size`` array.

    Parameters
    ----------
    volume_path : str
        Path of the volume, read with SimpleITK.
    ismask : int
        1 for a segmentation mask: nearest-neighbour interpolation, and every
        value > 0 set to 1 afterwards. Anything else means an intensity image
        (B-spline interpolation).
    new_spacing : sequence of float
        Target voxel spacing, one value per image axis.
    size : tuple of int
        Shape of the saved array. Smaller volumes are zero-padded and larger
        ones centre-cropped (see ``center_pad_or_crop``). Stored as uint16.
    out_path : str
        Output root; must end with a path separator. ``out_path + "res.nii.gz"``
        is used as a scratch file. The result is written to
        ``out_path + "inex_train/" + basename + "/" + <input file name>``.
    basename : str
        Subject directory name created under ``out_path + "inex_train/"``.
    """
    interpolator = sitk.sitkNearestNeighbor if ismask == 1 else sitk.sitkBSpline

    volume = sitk.ReadImage(volume_path)  # pixel type is kept as stored on disk
    original_spacing = volume.GetSpacing()
    original_size = volume.GetSize()
    original_direction = volume.GetDirection()
    original_origin = volume.GetOrigin()
    dim = volume.GetDimension()

    # Move the origin (centre of the first voxel) so the physical extent of the
    # image is preserved when the voxel size changes. The shift is defined along
    # the image's index axes, so it must be rotated by the direction cosines
    # before being added to the physical origin.
    direction = np.array(original_direction, dtype=float).reshape(dim, dim)
    index_offset = np.subtract(new_spacing, original_spacing) / 2
    new_origin = tuple(float(v) for v in np.add(original_origin, direction.dot(index_offset)))

    new_size = [int(round(osz * ospc / nspc))
                for osz, ospc, nspc in zip(original_size, original_spacing, new_spacing)]
    resampled_img = sitk.Resample(volume, new_size, sitk.Transform(), interpolator,
                                  new_origin, tuple(float(s) for s in new_spacing),
                                  original_direction, 0, volume.GetPixelID())

    # Round-trip through a scratch file so medpy returns the array together with
    # a header that can be reused when saving.
    tmp_path = out_path + "res.nii.gz"
    if os.path.exists(tmp_path):
        os.remove(tmp_path)
    sitk.WriteImage(resampled_img, tmp_path)
    img, h = load(tmp_path)
    if ismask == 1:
        img[img > 0] = 1

    resized_image = center_pad_or_crop(img, size, dtype=np.uint16)
    print(img.shape, resized_image.shape, volume.GetOrigin(), original_spacing)

    out_dir = out_path + "inex_train/" + basename
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    filename = os.path.basename(volume_path)
    save(resized_image, out_dir + "/" + filename, hdr=h, use_compression=False)


volume_path="/research/sharedresources/cbi/data_exchange/zakhagrp/presentations/DeepBrainIPP_dataset/tmp_inexvivo/"
out_path="/research/sharedresources/cbi/data_exchange/zakhagrp/presentations/DeepBrainIPP_dataset/resampled/"
new_spacing = [0.06, 0.06, 0.48]
size=(448,448,48)
# Alternative canvas sizes used in other experiments:
#size=(256,224,288)
#size=(448,448,384)
#size=(320,320,48)

# Each subject lives in its own directory <subject>/ containing
# <subject>_brain.nii.gz and <subject>_seg.nii.gz.
for files in sorted(glob.glob(volume_path + "*")):
    # Skip stray files (notes, archives, ...) so load() is only called on subject dirs.
    if not os.path.isdir(files):
        continue
    basename = os.path.basename(files)
    print(basename)
    imag_path = files + "/" + basename + "_brain.nii.gz"
    seg_path = files + "/" + basename + "_seg.nii.gz"
    original_image, h = load(imag_path)
    mask, m = load(seg_path)
    # The source volumes are re-written IN PLACE, uncompressed, before resampling.
    # NOTE(review): the mask is saved with the *image* header `h` (its own header
    # `m` is unused). This looks like a deliberate way to force the mask onto the
    # image's geometry — confirm before changing.
    save(original_image, imag_path, hdr=h, use_compression=False)
    save(mask, seg_path, hdr=h, use_compression=False)
    resample_volume(imag_path, 0, new_spacing, size, out_path, basename)
    resample_volume(seg_path, 1, new_spacing, size, out_path, basename)


### Step 2: Apply data augmentation to increase variability and the number of training samples

#### Change the values of the following parameters:
path= Location where you stored the resampled volumes (output of the previous step)

aug_path= Location where you want to store the augmented samples

In [None]:
import imageio
import numpy as np
import imgaug as ia
import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
import numpy as np
import pandas as pd
import cv2
from scipy.ndimage.interpolation import map_coordinates
from scipy.ndimage.filters import gaussian_filter
import matplotlib.pyplot as plt
import skimage.io as io
from medpy.io import load,save
import numpy, imageio
import glob, os
import shutil
import itertools
import random
import scipy.ndimage as ndimage
from skimage.measure import label,regionprops
import random


path="/research/sharedresources/cbi/data_exchange/zakhagrp/presentations/DeepBrainIPP_dataset/resampled/inex_train/"
aug_path="/research/sharedresources/cbi/data_exchange/zakhagrp/presentations/DeepBrainIPP_dataset/augmented_volume/"
# Always start from an empty output directory: delete whatever a previous run
# left behind, then (re)create it.
if os.path.exists(aug_path):
    shutil.rmtree(aug_path)
os.makedirs(aug_path)

# Define the list of augmenters.
# The augmentation loop below applies each entry ON ITS OWN (wrapped in its own
# iaa.Sequential), not chained. The entry's index i becomes the "_<i>" suffix of
# its output directory, so inserting or reordering entries renames every later
# output. The [i] markers below give that index.
# NOTE(review): imgaug augmenters are 2-D. A volume of shape (x, y, z) is
# presumably treated as an image with z channels, so each transform acts
# in-plane, slice by slice — confirm.
seq = iaa.Sequential([
    iaa.Multiply((0.9,1.3 ), per_channel=0.2),  # [0] intensity x U(0.9, 1.3); per-channel factors for ~20% of images
    iaa.AllChannelsCLAHE(),  # [1] CLAHE contrast equalisation on every channel
    iaa.Fliplr(1), # [2] left-right flip of every image (p=1)
    iaa.Flipud(1), # [3] up-down flip of every image (p=1)
    iaa.Dropout([0.01, 0.05]),      # [4] zero out 1% or 5% of pixels (a list means a random choice per image)
    #iaa.Sharpen(alpha=(0, 1.0), lightness=(0.4, 3.5)),       # sharpen the image
    #affine transformation
    #iaa.PiecewiseAffine(scale=(0.01, 0.05)),
    
    #PiecewiseAffine and elastic deformation (these also deform the segmentation maps)
    iaa.PiecewiseAffine(scale=(0.01, 0.05)),  # [5]
    iaa.ElasticTransformation(alpha=30, sigma=9),  # [6] strong, smooth elastic warp
    iaa.ElasticTransformation(alpha=40, sigma=10),  # [7]
    iaa.ElasticTransformation(alpha=50, sigma=11),  # [8]
    iaa.ElasticTransformation(alpha=(2.5,3.0), sigma=1),  # [9] weak, local elastic warps from here to [13]
    iaa.ElasticTransformation(alpha=(4.0,5.0), sigma=2.0),  # [10]
    iaa.ElasticTransformation(alpha=(8.0,11.0), sigma=3.0),  # [11]
    iaa.ElasticTransformation(alpha=(13.0,15.0), sigma=5),  # [12]
    iaa.ElasticTransformation(alpha=(20.0,25.0), sigma=8),  # [13]
    # [14] Gaussian noise; NOTE(review): scale is in 0..255 units, while step 1
    # stores volumes as uint16 — confirm the intended noise strength.
    iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05*255), per_channel=0.5),
    iaa.PiecewiseAffine(scale=(0.02, 0.07)) ,  # [15]
    iaa.GaussianBlur(sigma=(1.5)),  # [16] fixed sigma ((1.5) is a plain float, not a range)
    iaa.GaussianBlur(sigma=(1.0)),  # [17]
    iaa.GaussianBlur(sigma=(0.8)),  # [18]
   
    # [19]-[24] in-plane rotation in the given degree range, random +-5% shift and
    # a fixed zoom, cubic interpolation (order=3).
    # NOTE(review): imgaug's Affine reads only the "x"/"y" keys; the "z" entries
    # are presumably ignored — confirm.
    iaa.Affine(rotate=(-20, -15),
            translate_percent={"x": (-0.05, 0.05), "y": (-0.05, 0.05),"z": (-0.05, 0.05)},
            scale={"x": (0.8), "y": (0.8),"z":(0.8)}, order=3,
            ),  # [19]
    
    iaa.Affine(rotate=(-15, -10),
            translate_percent={"x": (-0.05, 0.05), "y": (-0.05, 0.05),"z": (-0.05, 0.05)},
            scale={"x": (0.9), "y": (0.9),"z":(0.9)}, order=3,
            ),  # [20]
    iaa.Affine(rotate=(-10, -5),
            translate_percent={"x": (-0.05, 0.05), "y": (-0.05, 0.05),"z": (-0.05, 0.05)},
            scale={"x": (1.2), "y": (1.2),"z":(1.2)}, order=3,
            ),  # [21]
    
    iaa.Affine(rotate=(5, 10),
            translate_percent={"x": (-0.05, 0.05), "y": (-0.05, 0.05),"z": (-0.05, 0.05)},
            scale={"x": (1.3), "y": (1.3),"z":(1.3)}, order=3,
            ),  # [22]
    iaa.Affine(rotate=(10, 15),
            translate_percent={"x": (-0.05, 0.05), "y": (-0.05, 0.05),"z": (-0.05, 0.05)},
            scale={"x": (0.9), "y": (0.9),"z":(0.9)}, order=3,
            ),  # [23]
     iaa.Affine(rotate=(15, 20),
            translate_percent={"x": (-0.05, 0.05), "y": (-0.05, 0.05),"z": (-0.05, 0.05)},
            scale={"x": (0.8), "y": (0.8),"z":(0.8)}, order=3,
            ),  # [24]
    
              
], random_order=False)

#change data orientations
def permute_data(data, key):
    """Return a permuted copy of ``data`` according to ``key``.

    The permutation acts on the LAST THREE axes, so ``data`` may be a single
    volume of shape (x, y, z) or a stack of shape (n_modalities, x, y, z).

    ``key`` is ``((rotate_y, rotate_z), flip_x, flip_y, flip_z, transpose)``:

    - ``rotate_y``: number of quarter turns (``np.rot90``) in the (x, z) plane.
    - ``rotate_z``: number of quarter turns in the (y, z) plane.
    - ``flip_x`` / ``flip_y`` / ``flip_z``: reverse the x / y / z axis.
    - ``transpose``: swap the x and z axes (equivalent to ``.T`` of a 3-D volume).

    Steps are applied in that order. For example, ``((0, 1), 0, 1, 0, 1)``
    applies one quarter turn in the (y, z) plane, reverses y, then swaps x and z.

    Raises
    ------
    ValueError
        If ``data`` has fewer than three dimensions.
    """
    data = np.copy(data)
    if data.ndim < 3:
        raise ValueError("permute_data expects at least 3 dimensions, got %d" % data.ndim)
    (rotate_y, rotate_z), flip_x, flip_y, flip_z, transpose = key

    # Indices of the spatial axes (skip a leading modality axis when present).
    x_ax = data.ndim - 3
    y_ax = x_ax + 1
    z_ax = x_ax + 2

    if rotate_y != 0:
        data = np.rot90(data, rotate_y, axes=(x_ax, z_ax))
    if rotate_z != 0:
        data = np.rot90(data, rotate_z, axes=(y_ax, z_ax))
    if flip_x:
        data = np.flip(data, axis=x_ax)
    if flip_y:
        data = np.flip(data, axis=y_ax)
    if flip_z:
        data = np.flip(data, axis=z_ax)
    if transpose:
        data = np.swapaxes(data, x_ax, z_ax)
    return data
def crop_volume_2(img,mask,thresholds):
    """Crop ``img`` and ``mask`` to a box placed around the mask's centre of mass.

    Parameters
    ----------
    img, mask : numpy.ndarray
        Volumes indexed identically along their first three axes. The mask
        values act as weights for the centre of mass.
    thresholds : (hx, hy, hz, y_shift)
        Half-widths of the box along each axis, plus an extra offset added to
        the y start index.

    The box starts at ``int(com - h)`` on each axis (y additionally shifted by
    ``y_shift``). Starts are clamped at 0 but not at the upper end, so near the
    far border the returned arrays can be smaller than ``2*h``.

    Returns
    -------
    (img_crop, mask_crop) : tuple of numpy.ndarray
        Views into the inputs.

    Raises
    ------
    ValueError
        If ``mask`` has no nonzero voxels, so no centre of mass exists.
    """
    if not np.any(mask):
        raise ValueError("crop_volume_2: mask is empty; cannot locate a centre of mass")
    center_mass = ndimage.center_of_mass(mask)
    if not np.all(np.isfinite(center_mass)):
        raise ValueError("crop_volume_2: mask weights sum to zero; cannot locate a centre of mass")

    x = max(int(center_mass[0] - thresholds[0]), 0)
    y = max(int(center_mass[1] - thresholds[1]) + thresholds[3], 0)
    z = max(int(center_mass[2] - thresholds[2]), 0)

    window = (slice(x, x + 2 * thresholds[0]),
              slice(y, y + 2 * thresholds[1]),
              slice(z, z + 2 * thresholds[2]))
    return img[window], mask[window]

def crop_volume(img,mask,thresholds):
    """Copy ``img`` and ``mask`` into zero arrays of shape ``thresholds[:3]``.

    Despite the name, the data are NOT centred. Each input is placed at the
    origin corner (index 0 on every axis), with zero padding after it. Voxels
    beyond ``thresholds[:3]`` are dropped instead of raising a broadcast error.
    ``thresholds[3]`` is accepted for signature compatibility with
    ``crop_volume_2`` but is not used here.

    Returns
    -------
    (tmp_img, tmp_mask) : tuple of numpy.ndarray
        New arrays of shape ``thresholds[:3]`` with the dtypes of ``img`` and
        ``mask`` respectively.
    """
    target = tuple(thresholds[:3])
    tmp_mask = np.zeros(target, dtype=mask.dtype)
    tmp_img = np.zeros(target, dtype=img.dtype)
    # Overlap of each input with the target box. For inputs no larger than the
    # target this is the whole input.
    img_box = tuple(slice(0, min(n, t)) for n, t in zip(img.shape, target))
    mask_box = tuple(slice(0, min(n, t)) for n, t in zip(mask.shape, target))
    tmp_img[img_box] = img[img_box]
    tmp_mask[mask_box] = mask[mask_box]
    return tmp_img, tmp_mask
def rotate_volume(img,seg_img, axes,prefix,img_header,seg_header,thresholds):
    """Write eight rotated copies of an image/mask pair.

    The rotation angles are -20, -15, -10, -5, 5, 10, 15 and 20 degrees, in the
    plane given by ``axes``. ``reshape=False`` keeps the shape; outside voxels
    are filled with 0.

    - The image uses order-5 spline interpolation.
    - The mask uses linear interpolation and is then binarised (any value > 0
      becomes 1).

    Each pair is passed through ``crop_volume(..., thresholds)`` and saved as
    ``aug_path + dirname + prefix + "_<i>/" + <same name> + "_brain.nii.gz"``
    (and ``"_seg.nii.gz"``), using the given headers.

    Relies on the module-level ``aug_path`` and ``dirname`` (set by the
    augmentation loop), and on ``crop_volume`` and medpy's ``save``.
    """
    angles = [-20, -15, -10, -5, 5, 10, 15, 20]
    for i, angl in enumerate(angles):
        rotated_img = ndimage.rotate(img, angle=angl, axes=axes, order=5, reshape=False, mode="constant")
        # Rotate the input mask itself; every angle starts from the original seg_img.
        rotated_seg = ndimage.rotate(seg_img, angle=angl, axes=axes, reshape=False, order=1, mode="constant")
        rotated_seg[rotated_seg > 0.0] = 1.0
        rotated_seg[rotated_seg <= 0.0] = 0.0
        case = dirname + prefix + "_" + str(i)
        os.makedirs(aug_path + case)
        rotated_img, rotated_seg = crop_volume(rotated_img, rotated_seg, thresholds)
        save(rotated_img, aug_path + case + "/" + case + "_brain.nii.gz", hdr=img_header, use_compression=False)
        save(rotated_seg, aug_path + case + "/" + case + "_seg.nii.gz", hdr=seg_header, use_compression=False)

def zoom_image(img,seg_img,prefix,img_header,seg_header):
    """Write rescaled copies of an image/mask pair at zoom factors 0.8, 0.9, 1.1 and 1.2.

    - The image is zoomed with order-5 splines.
    - The mask is zoomed with order 0 (nearest neighbour), so labels stay discrete.
    - Both use ``mode='nearest'`` at the borders.

    The output shape changes with the factor; no cropping is applied. Copy ``i``
    goes to ``aug_path + dirname + prefix + "_<i>/"``.

    Relies on the module-level ``aug_path`` and ``dirname`` and on medpy's ``save``.
    """
    for idx, factor in enumerate([0.8, 0.9, 1.1, 1.2]):
        zoomed_img = ndimage.zoom(img, zoom=factor, mode='nearest', order=5)
        zoomed_seg = ndimage.zoom(seg_img, zoom=factor, mode='nearest', order=0)
        case = dirname + prefix + "_" + str(idx)
        case_dir = aug_path + case
        os.makedirs(case_dir)
        stem = case_dir + "/" + case
        save(zoomed_img, stem + "_brain.nii.gz", hdr=img_header, use_compression=False)
        save(zoomed_seg, stem + "_seg.nii.gz", hdr=seg_header, use_compression=False)
    

      
def apply_augmentation(seq_picked,image,segmap,dirname,i,img_header,seg_header, thresholds):
    """Augment one image/mask pair with ``seq_picked`` and save the result.

    The image and its ``SegmentationMapsOnImage`` are transformed together.
    Both are then passed through ``crop_volume(..., thresholds)`` and written to
    ``aug_path + dirname + "_" + str(i) + "/"`` as ``<case>_seg.nii.gz`` and
    ``<case>_brain.nii.gz``, using the given headers. Relies on the
    module-level ``aug_path``.

    Returns the augmented image after ``crop_volume``.
    """
    case = dirname + "_" + str(i)
    augmented = seq_picked(image=image, segmentation_maps=segmap)
    aug_image, aug_segmap = augmented
    os.makedirs(aug_path + case)
    aug_image, aug_mask = crop_volume(aug_image, aug_segmap.get_arr(), thresholds)
    stem = aug_path + case + "/" + case
    save(aug_mask, stem + "_seg.nii.gz", hdr=seg_header, use_compression=False)
    save(aug_image, stem + "_brain.nii.gz", hdr=img_header, use_compression=False)
    return aug_image

thresholds=(448,448,48,0)
#thresholds=(128,112,144,0)

# For every resampled subject, write one output per augmenter in `seq`
# (directory "<subject>_<index>") plus the un-augmented volume itself
# (directory "<subject>").
for file in sorted(glob.glob(path + "*")):
    # Each subject is a directory; skip stray files so load() is never called on them.
    if not os.path.isdir(file):
        continue
    dirname = os.path.basename(file)  # also read as a global by rotate_volume / zoom_image
    print(dirname)
    image_path = file + "/" + dirname + "_brain.nii.gz"
    mask_path = file + "/" + dirname + "_seg.nii.gz"
    image, img_header = load(image_path)
    mask, seg_header = load(mask_path)

    segmap = SegmentationMapsOnImage(mask, shape=image.shape)

    # Apply each augmenter on its own (not chained); its index becomes the output suffix.
    for i, augmenter in enumerate(seq):
        apply_augmentation(iaa.Sequential([augmenter]), image, segmap, dirname, i,
                           img_header, seg_header, thresholds)

    # Optional extra augmentations (disabled):
    #zoom_image(image,mask,"zoom",img_header,seg_header)
    #rotate_volume(image,mask, (0,2),"roty",img_header,seg_header,thresholds)
    #rotate_volume(image,mask, (1,2),"rotx",img_header,seg_header,thresholds)

    # Keep the original (un-augmented) subject as a training sample too.
    os.makedirs(aug_path + dirname)
    image, mask = crop_volume(image, mask, thresholds)
    save(image, aug_path + dirname + "/" + dirname + "_brain.nii.gz", hdr=img_header, use_compression=False)
    save(mask, aug_path + dirname + "/" + dirname + "_seg.nii.gz", hdr=seg_header, use_compression=False)

    # Random axis permutation (disabled):
    #keys=set(itertools.product(itertools.combinations_with_replacement(range(2), 2), range(2), range(2), range(2), range(2)))
    #key=random.choice(list(keys))
    #dist_image = permute_data(image,key)
    #dist_mask = permute_data(mask,key)
    #os.makedirs(aug_path+dirname+"rotx")
    #save(dist_image,aug_path+dirname+"rotx/"+dirname+"rotx_brain.nii.gz")
    #save(dist_mask,aug_path+dirname+"rotx/"+dirname+"rotx_seg.nii.gz")



### Step 3: Check that all volumes have the same dimensions

In [7]:
#check incorrect shape or corrupted data
import glob, sys, os
import numpy as np
from medpy.io import load,save
path="/research/sharedresources/cbi/data_exchange/zakhagrp/presentations/DeepBrainIPP_dataset/augmented_volume/"
shape=(448,448,48)  # expected spatial shape (first three axes) of every volume


def volume_pair_is_corrupted(img, seg, shape=(448, 448, 48)):
    """Return True if an image/mask pair fails the sanity checks.

    A pair is flagged when any of these holds:

    - the first three axes of ``img`` differ from ``shape``;
    - the first three axes of ``seg`` differ from those of ``img``;
    - ``seg`` does not span exactly {0, ..., 1} (max != 1 or min != 0);
    - ``img`` is entirely zero (max == 0).
    """
    return (
        tuple(img.shape[:3]) != tuple(shape)
        or tuple(seg.shape[:3]) != tuple(img.shape[:3])
        or np.max(seg) != 1
        or np.min(seg) != 0
        or np.max(img) == 0
    )


for files in glob.glob(path+"*"):
    basename=os.path.basename(files)
    img,h=load(files+"/"+basename+"_brain.nii.gz")
    seg,h=load(files+"/"+basename+"_seg.nii.gz")
    if len(seg.shape) > 2:
        if volume_pair_is_corrupted(img, seg, shape):
            print("corrupted",basename)
    else:
        # Mask with fewer than 3 dimensions: print the name so it can be inspected.
        print(basename)

### Step 4: Train the model

#### Parameters: Several parameters need to be chosen or set. Follow the draft manuscript for the parameter values used to build the models
file_path= Path where you stored the augmented samples (output of the previous step)

data_path= Path where you want to store the models




In [None]:

import os, sys
import glob
sys.path.append('/lustre_scratch/sandbox/salam/cnn/3DUnetCNN/')
os.environ["CUDA_VISIBLE_DEVICES"]="0,1,2,3"
from unet3d.data import write_data_to_file, open_data_file
from unet3d.generator import get_training_and_validation_generators
from unet3d.model import isensee2017_model
from unet3d.training import load_old_model, train_model
from keras.utils.vis_utils import plot_model
import tensorflow as tf
print(tf.__version__)
import keras

file_path="/research/sharedresources/cbi/data_exchange/zakhagrp/presentations/DeepBrainIPP_dataset/augmented_volume/"
data_path="/research/sharedresources/cbi/data_exchange/zakhagrp/presentations/DeepBrainIPP_dataset/model/"
output_model_path=data_path
config = dict()
config["gpu"]=4  # passed to isensee2017_model as `gpu`; NOTE(review): presumably meant to match the 4 devices in CUDA_VISIBLE_DEVICES
config["image_shape"] = (448, 448, 48)  # This determines what shape the images will be cropped/resampled to.
config["patch_shape"] = None  # e.g. (128, 128, 128) to train on patches; None trains on the whole image
config["labels"] = (1,)  # tuple of label numbers in the masks (note: `(1)` would be the int 1, not a tuple)
config["n_base_filters"] = 16  # these are doubled after each downsampling
config["n_labels"] = len(config["labels"])  # == 1: a single binary foreground channel
config["all_modalities"] = ["brain"]  # file suffix of the intensity volume: <subject>_brain.nii.gz
config["training_modalities"] = config["all_modalities"]  # change this if you want to only use some of the modalities
config["nb_channels"] = len(config["training_modalities"])
if "patch_shape" in config and config["patch_shape"] is not None:
    config["input_shape"] = tuple([config["nb_channels"]] + list(config["patch_shape"]))
else:
    config["input_shape"] = tuple([config["nb_channels"]] + list(config["image_shape"]))
config["truth_channel"] = config["nb_channels"]
config["deconvolution"] = False  # if False, will use upsampling instead of deconvolution

config["batch_size"] =1
config["validation_batch_size"] =1
config["n_epochs"] = 500  # cutoff the training after this many epochs
config["patience"] = 15  # learning rate will be reduced after this many epochs if the validation loss is not improving
config["early_stop"] = 100  # training will be stopped after this many epochs without the validation loss improving
config["initial_learning_rate"] = 5e-5
config["learning_rate_drop"] = 0.7  # factor by which the learning rate will be reduced
config["validation_split"] = 0.8  # despite the key name: fraction of the data used for TRAINING (the rest is validation)
config["flip"] = False  # augments the data by randomly flipping an axis during training
config["permute"] = False  # data shape must be a cube. Augments the data by permuting in various directions
config["distort"] = False  # switch to None if you want no distortion
config["augment"] = config["flip"] or config["distort"]
config["patch_overlap"] = 0  # if > 0, during training, validation patches will be overlapping
config["training_patch_start_offset"] = (16, 16, 16)  # randomly offset the first patch index by up to this offset (presumably only relevant with patches)
config["skip_blank"] = True  # if True, then patches without any target will be skipped

config["data_file"] = os.path.abspath(data_path+"brain_data.h5")
config["model_file"] = os.path.abspath(output_model_path+"brain_unet_model-{epoch:02d}.h5")
config["training_file"] = os.path.abspath(data_path+"brain_training_ids.pkl")
config["validation_file"] = os.path.abspath(data_path+"brain_validation_ids.pkl")
config["overwrite"] = False  # If True, previous files are overwritten. If False, previously written files are reused.


def fetch_mouse_2020_files(modalities, group="Training", include_truth=True, return_subject_ids=False,
                           files_dir=None):
    """Collect per-subject NIfTI file paths from the augmented-volume directory.

    Each subdirectory ``<files_dir>/<subject>`` is expected to contain
    ``<subject>_<modality>.nii.gz`` for every requested modality, plus
    ``<subject>_seg.nii.gz`` when ``include_truth`` is set. Existence of these
    files is not checked here.

    Parameters
    ----------
    modalities : iterable of str
        Modality suffixes, e.g. ``["brain"]``.
    group : str
        Unused; kept for signature compatibility.
    include_truth : bool
        Append the ``"seg"`` mask path as the last entry of each tuple.
    return_subject_ids : bool
        Also return the list of subject directory names.
    files_dir : str, optional
        Directory to scan. Defaults to the module-level ``file_path``.

    Returns
    -------
    list of tuple of str, or (list of tuple of str, list of str)
        One tuple per subject directory, in sorted order. Non-directory entries
        are skipped.
    """
    if files_dir is None:
        files_dir = file_path
    modalities = list(modalities)
    if include_truth:
        modalities = modalities + ["seg"]

    training_data_files = []
    subject_ids = []
    for subject_dir in sorted(glob.glob(os.path.join(files_dir, "*"))):
        if not os.path.isdir(subject_dir):
            continue
        subject_id = os.path.basename(subject_dir)
        subject_ids.append(subject_id)
        training_data_files.append(tuple(
            os.path.join(subject_dir, subject_id + "_" + modality + ".nii.gz")
            for modality in modalities
        ))
    if return_subject_ids:
        return training_data_files, subject_ids
    return training_data_files


def fetch_training_data_files(return_subject_ids=False):
    """List training file tuples for ``config["training_modalities"]``, each with the mask appended.

    Thin wrapper around ``fetch_mouse_2020_files`` with ``include_truth=True``.
    If ``return_subject_ids`` is set, the subject ids are returned as well.
    """
    wanted = config["training_modalities"]
    return fetch_mouse_2020_files(
        modalities=wanted,
        include_truth=True,
        return_subject_ids=return_subject_ids,
    )


def main(overwrite=False):
    """Build (or reuse) the HDF5 training file, train a new 3D U-Net and return the training history.

    Parameters
    ----------
    overwrite : bool
        If True, rebuild ``config["data_file"]`` even if it already exists. The
        flag is also forwarded to ``get_training_and_validation_generators``,
        which presumably regenerates the train/validation key files — confirm.

    Returns
    -------
    The object returned by ``train_model``. The loss-curve cell reads it as a
    Keras History (``history.history['loss']``).
    """
    # Convert the input images into an HDF5 file (only if missing or overwrite requested).
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(return_subject_ids=True)
        write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"],
                           subject_ids=subject_ids)
    data_file_opened = open_data_file(config["data_file"])
    try:
        # A fresh model is always built; resuming via load_old_model(config["model_file"]) is disabled.
        model = isensee2017_model(input_shape=config["input_shape"], n_labels=config["n_labels"],
                                  initial_learning_rate=config["initial_learning_rate"],
                                  n_base_filters=config["n_base_filters"], depth=5, dropout_rate=0.1,
                                  gpu=config["gpu"])

        # get training and validation generators
        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
            data_file_opened,
            batch_size=config["batch_size"],
            data_split=config["validation_split"],
            overwrite=overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=config["validation_batch_size"],
            validation_patch_overlap=config["patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            permute=config["permute"],
            augment=config["augment"],
            skip_blank=config["skip_blank"],
            augment_flip=config["flip"],
            augment_distortion_factor=config["distort"],
            patch_overlap=config["patch_overlap"])

        # run training
        history = train_model(model=model,
                              model_file=config["model_file"],
                              training_generator=train_generator,
                              validation_generator=validation_generator,
                              steps_per_epoch=n_train_steps,
                              validation_steps=n_validation_steps,
                              initial_learning_rate=config["initial_learning_rate"],
                              learning_rate_drop=config["learning_rate_drop"],
                              learning_rate_patience=config["patience"],
                              early_stopping_patience=config["early_stop"],
                              n_epochs=config["n_epochs"])
    finally:
        # Close the HDF5 handle even if model building or training raises or is interrupted.
        data_file_opened.close()
    return history


history=main(overwrite=config["overwrite"])


### See the loss curve

In [None]:
import matplotlib.pyplot as plt
# Training and validation loss per epoch, taken from the History returned by main().
for curve_name in ('loss', 'val_loss'):
    plt.plot(history.history[curve_name])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.show()