In [1]:
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

# NOTE: the star import is kept ahead of the explicit `from` imports below, as
# in the original cell, so that those explicit names still take precedence.
from functions import *
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from pytorch_util import RAdam,trainable_parameter

In [2]:
# --- Run configuration -------------------------------------------------------
# Everything a reader might want to tune lives in this cell.

# Data / loader settings
sz = 256          # image size; selects the pickle directory and is passed to the decoder
batch_size = 8    # batch size for both the train and validation DataLoader

# Training-loop settings (forwarded to train())
epochs = 10
clip = 1.0        # presumably the gradient-clipping threshold - confirm in train()
lr = 3e-4         # learning rate handed to RAdam

# Model architecture
encoder_str = 'efficientnet-b2'             # pretrained encoder name, also keys encoder_channels
decoder_channels = (256, 128, 64, 32, 16)   # forwarded to EfficientNet_decoder
decoder_repeats = (2, 2, 2, 2, 2)           # forwarded to EfficientNet_decoder

# Mixed precision: optimisation level handed to amp.initialize
opt_level = "O1"

In [3]:
# Images and masks live side by side under a size-specific directory
# (e.g. ../Data/pickles_256/ when sz is 256).
pickles_root = "../Data/pickles_" + str(sz)
images_dir = pickles_root + "/images/"
masks_dir = pickles_root + "/masks/"

In [4]:
# Read the ImageId column of the training CSV and shuffle it in place.
#
# The shuffle uses its own seeded RandomState so that the train/val split made
# from this list is identical after every "Restart & Run All", and so that the
# global NumPy RNG state is left untouched.
# NOTE(review): if the CSV can hold several rows per ImageId, the same image
# could end up on both sides of the split - confirm and de-duplicate if so.
SPLIT_SEED = 42
imageId = pd.read_csv('../Data/stage_2_train.csv')['ImageId'].values.tolist()
np.random.RandomState(SPLIT_SEED).shuffle(imageId)

In [5]:
# Train/validation split: the first n_val shuffled ids are held out for
# validation, everything after them is used for training.
n_val = 2400
imageId_val, imageId_train = imageId[:n_val], imageId[n_val:]

In [6]:
# Training dataset. When preprocessing_dict has an entry for this image size
# it is passed straight to the dataset; otherwise the dataset is asked to work
# out its own preprocessing via _cal_preprocessing().
if sz not in preprocessing_dict:
    dataset_train = dataset(imageId_train, images_dir, masks_dir, transform)
    dataset_train._cal_preprocessing()
else:
    dataset_train = dataset(imageId_train, images_dir, masks_dir, transform,
                            preprocessing_dict[sz])

# Validation dataset: no transform is passed, and it reuses the training
# dataset's preprocessing.
dataset_val = dataset(imageId_val, images_dir, masks_dir,
                      preprocessing=dataset_train.preprocessing)

# Both loaders share batch size and worker count; only the training loader
# shuffles.
loader_kwargs = dict(batch_size=batch_size, num_workers=4)
train_loader = DataLoader(dataset_train, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(dataset_val, shuffle=False, **loader_kwargs)

In [45]:
# imgs,masks = next(iter(train_loader))

In [9]:
# U-Net built from a pretrained EfficientNet encoder and a matching decoder.
encoder = EfficientNet_encoder.from_pretrained(encoder_str)
decoder = EfficientNet_decoder(
    sz,
    encoder_channels[encoder_str],
    decoder_channels=decoder_channels,
    decoder_repeats=decoder_repeats,
)
model = Unet(encoder, decoder)
model = model.to('cuda')

# Optimiser over the parameters returned by trainable_parameter(model).
paras = trainable_parameter(model)
opt = RAdam(paras, lr=lr, weight_decay=1e-2)

# Halve the learning rate (never below 1e-5) once the monitored value has not
# decreased for 5 scheduler steps; train() decides what is monitored.
scheduler = ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=5, min_lr=1e-05)

# Wrap model and optimiser for mixed-precision training at the configured level.
model, opt = amp.initialize(model, opt, opt_level=opt_level)

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


In [11]:
# Run the training loop; train() hands back the model together with two more
# values that are bound to bestOpt and bestWeight.
model, bestOpt, bestWeight = train(
    opt, model, epochs, train_loader, valid_loader, paras, clip,
    scheduler=scheduler,
)

epoch:0, train_loss: +1.423, val_loss: +1.467

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
epoch:1, train_loss: +0.852, val_loss: +0.755

epoch:2, train_loss: +0.707, val_loss: +0.612

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
epoch:3, train_loss: +0.667, val_loss: +0.603

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 8192.0
epoch:4, train_loss: +0.658, val_loss: +0.604

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 8192.0
epoch:5, train_loss: +0.641, val_loss: +0.639

epoch:6, train_loss: +0.635, val_loss: +0.597

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 8192.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 4096.0
epoch:7, train_loss: +0.631, val_loss: +0.583

epoch:8, train_loss: +0.617, val_loss: +0.565

epoch:9, train_loss: 

In [13]:
# Save model, optimiser and amp state together so training can be resumed.
model_dir = '../Model/'
# Create the output directory up front: without it torch.save raises and the
# result of the whole training run is lost.
os.makedirs(model_dir, exist_ok=True)

checkpoint = {
    'model': model.state_dict(),
    'opt': opt.state_dict(),
    'amp': amp.state_dict()
}

# File name encodes the run configuration:
# <sz>_<encoder>_<decoder_channels>_<decoder_repeats>.pt
checkpoint_name = '_'.join(
    [str(sz), encoder_str, str(decoder_channels), str(decoder_repeats)]
) + '.pt'
torch.save(checkpoint, model_dir + checkpoint_name)

In [None]:
# Resume sketch - keys and variable names match the checkpoint dict saved by
# this notebook; replace the path with the one passed to torch.save.
# checkpoint = torch.load('amp_checkpoint.pt')
# model, opt = amp.initialize(model, opt, opt_level=opt_level)
# model.load_state_dict(checkpoint['model'])
# opt.load_state_dict(checkpoint['opt'])
# amp.load_state_dict(checkpoint['amp'])