In [None]:
import os
import random
import time

import numpy as np
import torch
from scipy import ndimage
from torch import nn
from torch import optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader

from utils.logger import log
from engine.train_loop import train_3d_cnn
import spinecnn_CONFIG
from datasets.brains18 import BrainS18Dataset
from model_zoo.builder import generate_model

In [None]:
# Configuration: copy every setting this notebook uses out of spinecnn_CONFIG.py
# into module-level constants. Reading them all here, up front, means a setting
# that is missing from the config file fails immediately with an AttributeError
# instead of partway through a training run.
cfg = spinecnn_CONFIG

RNG_SEED = cfg.RNG_SEED
LEARNING_RATE = cfg.LEARNING_RATE
CHECKPOINT_PATH = cfg.CHECKPOINT_PATH
BATCH_SIZE = cfg.BATCH_SIZE
NUM_WORKERS = cfg.NUM_WORKERS
PIN_MEMORY = cfg.PIN_MEMORY
DATA_FOLDER = cfg.DATA_FOLDER
TRAINING_DATA_LIST = cfg.TRAINING_DATA_LIST
SAVE_FOLDER = cfg.SAVE_FOLDER
NUM_EPOCHS = cfg.NUM_EPOCHS
SAVE_INTERVALS = cfg.SAVE_INTERVALS
NEW_LAYER_NAMES = cfg.NEW_LAYER_NAMES
INPUT_W = cfg.INPUT_W
INPUT_H = cfg.INPUT_H
INPUT_D = cfg.INPUT_D
PRETRAIN_PATH = cfg.PRETRAIN_PATH
MODEL_TYPE = cfg.MODEL_TYPE
MODEL_DEPTH = cfg.MODEL_DEPTH
RESNET_SHORTCUT = cfg.RESNET_SHORTCUT
NO_CUDA = cfg.NO_CUDA
NUM_CLASSES = cfg.NUM_CLASSES
GPU_ID = cfg.GPU_ID

In [None]:
# getting model
# Seed every global RNG before anything stochastic happens.
# - torch.manual_seed covers weight init and, per PyTorch docs, all CUDA devices.
#   It also fixes the base seed that DataLoader worker processes derive their own
#   seeds from.
# - Python's `random` and NumPy's global RNG are NOT seeded by torch. They matter
#   when the dataset draws from them in the main process (NUM_WORKERS == 0).
#   NOTE(review): the dataset code is not visible here; seeding them is a cheap safeguard.
random.seed(RNG_SEED)
np.random.seed(RNG_SEED)
torch.manual_seed(RNG_SEED)

# generate_model returns the network plus the parameters to optimize; the
# optimizer cell below handles both a dict and a plain iterable of parameters.
model, parameters = generate_model(MODEL_TYPE, MODEL_DEPTH,
                INPUT_W, INPUT_H, INPUT_D, NUM_CLASSES,
                NEW_LAYER_NAMES, PRETRAIN_PATH,
                RESNET_SHORTCUT, NO_CUDA, GPU_ID,
                model_phase='train')
print(model)

In [None]:
# optimizer
# SGD with momentum and weight decay, decayed exponentially by the scheduler.
# When generate_model splits the parameters into 'base_parameters' and
# 'new_parameters' (presumably the layers matched by NEW_LAYER_NAMES), the new
# ones train NEW_LAYER_LR_MULT times faster than the rest.
NEW_LAYER_LR_MULT = 100   # lr multiplier for the 'new_parameters' group
MOMENTUM = 0.9
WEIGHT_DECAY = 1e-3
LR_DECAY_GAMMA = 0.99     # ExponentialLR multiplies every group's lr by this on each scheduler.step()

if isinstance(parameters, dict):
    params = [
        {'params': parameters['base_parameters'], 'lr': LEARNING_RATE},
        {'params': parameters['new_parameters'], 'lr': LEARNING_RATE * NEW_LAYER_LR_MULT},
    ]
else:
    # NOTE(review): covers a generate_model that returns a plain parameter iterable
    # (e.g. when no split into base/new layers is made). Subscripting that with
    # 'base_parameters' would raise TypeError, so train all parameters at LEARNING_RATE.
    params = [{'params': parameters, 'lr': LEARNING_RATE}]

optimizer = optim.SGD(params, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=LR_DECAY_GAMMA)

In [None]:
# train from checkpoint
# Resume model and optimizer state when CHECKPOINT_PATH is set. A path that is set
# but does not exist is reported explicitly instead of silently falling through to
# training from scratch.
if CHECKPOINT_PATH:
    if os.path.isfile(CHECKPOINT_PATH):
        print("=> loading checkpoint '{}'".format(CHECKPOINT_PATH))
        # Map to CPU first so a checkpoint written on a GPU can be read on a CPU-only
        # machine (NO_CUDA) or with a different GPU_ID. model/optimizer.load_state_dict
        # then copy the tensors onto the devices their parameters already live on.
        checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # 'epoch' is only used in this log line, so a checkpoint without it should
        # not abort after the weights have already been loaded.
        print("=> loaded checkpoint '{}' (epoch {})"
            .format(CHECKPOINT_PATH, checkpoint.get('epoch', 'unknown')))
        # NOTE(review): neither the scheduler state nor the start epoch is restored or
        # passed to train_3d_cnn, so the LR schedule and epoch count restart from
        # zero on resume — confirm this is intended.
    else:
        print("=> no checkpoint found at '{}'; training from scratch".format(CHECKPOINT_PATH))

# getting data
training_dataset = BrainS18Dataset(DATA_FOLDER, TRAINING_DATA_LIST, INPUT_D, INPUT_H, INPUT_W, phase='train')
data_loader = DataLoader(training_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

In [None]:
# Run the training loop (engine.train_loop.train_3d_cnn) with everything built in
# the cells above. NOTE(review): SAVE_INTERVALS / SAVE_FOLDER presumably control
# how often and where checkpoints are written — confirm in train_3d_cnn.
train_3d_cnn(
    data_loader,
    model,
    optimizer,
    scheduler,
    NUM_EPOCHS,
    SAVE_INTERVALS,
    SAVE_FOLDER,
    NO_CUDA,
)