# Hyperparameters

In [1]:
CFG = 'cfgs/brain/brain_pospoolxyz.yaml'#'cfgs/brain/brain_pointwisemlp.yaml'
IS_CONF = True
DATAFOLDER = 'BrainData'
FINE_TUNE = False
IS_EXPERIMENT = True
DATA_POSTFIX = '_exp_grey'
DEVICE = 0

#For early stopping
EXP_NAME = '3'
PATIENCE = 20

# For fine-tuning
PRETRAINED_MODEL_PATH = None

#For loss
LOSS_TYPE = 'BCE'
IS_KUNI = False
KUNI_AGG = 'mean'
KUNI_LAM = 1
IS_SEP_LOSS = False


# Importing libraries

In [2]:
if IS_EXPERIMENT:
    from comet_ml import Experiment

    experiment = Experiment(
        api_key="1cDG73F9830XhuYfWn4JJ2JEH",
        project_name="project-dl-bia",
        workspace="rukubrakov",
    )

COMET INFO: Experiment is live on comet.ml https://www.comet.ml/rukubrakov/project-dl-bia/d97ddb9e2f1b43d1b8813a36a1b117ef



In [3]:
from utils.fcd import *
warnings.filterwarnings("ignore")

In [4]:
device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'
torch.cuda.set_device(DEVICE)
torch.cuda.current_device()

0

In [5]:
config = config_seting(CFG)

# Data processing, train and validation

In [6]:
es = []
test_ious = []
test_lossess = []
test_dices = []
test_rocs = []
test_tprs = []
test_fprs = []

train_loader, test_loader, train_labels = get_loader(batch_size = config.batch_size,num_points = config.num_points, 
                                   data_post = DATA_POSTFIX, 
                                   datafolder = DATAFOLDER )

total_1class = np.sum([np.sum(labels) for labels in train_labels])
total = np.sum([labels.shape[0] for labels in train_labels])
weight = total_1class / (total - total_1class)
WEIGHTS = [weight, 1]

model, criterion = build_multi_part_segmentation(config,
                                                 WEIGHTS,
                                                 LOSS_TYPE,
                                                 is_kuni = IS_KUNI,
                                                 kuni_agg = KUNI_AGG,
                                                 kuni_lam = KUNI_LAM,
                                                 is_sep_loss = IS_SEP_LOSS
                                                )
if FINE_TUNE:
    model.load_state_dict(torch.load(PRETRAINED_MODEL_PATH))
model.cuda()
criterion.cuda()

n_data = len(train_loader.dataset)
print(f"length of training dataset: {n_data}")
n_data = len(test_loader.dataset)
print(f"length of testing dataset: {n_data}")


if config.optimizer == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=config.batch_size * dist.get_world_size() / 16 * config.base_learning_rate,
                                momentum=config.momentum,
                                weight_decay=config.weight_decay)
elif config.optimizer == 'adam':
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=config.base_learning_rate,
                                 weight_decay=config.weight_decay)
elif config.optimizer == 'adamW':
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=config.base_learning_rate,
                                  weight_decay=config.weight_decay)
else:
    raise NotImplementedError(f"Optimizer {config.optimizer} not supported")
scheduler = get_scheduler(optimizer, len(train_loader), config)

test_iou = []
test_losses = []
test_dice = []
test_roc = []
test_tpr = []
test_fpr = []
minimal_loss = 1e8
early_stopping = EarlyStopping(patience=PATIENCE, verbose=True, 
                               path = f'../pytorch/checkpoints/{EXP_NAME}.pth')

for epoch in tqdm(range(config.start_epoch, config.epochs + 1)):

    tic = time.time()
    loss, opt, roc = train(epoch, train_loader, model, criterion, optimizer, scheduler,
                           config,is_kuni = IS_KUNI,is_sep_loss = IS_SEP_LOSS)
    if IS_EXPERIMENT:
        experiment.log_metric('roc_train',roc, epoch = epoch)
        experiment.log_metric('lr',optimizer.param_groups[0]['lr'], epoch = epoch)

    tmp = validate(epoch, test_loader, model, criterion, config, num_votes=1, 
                   is_conf = IS_CONF,is_kuni = IS_KUNI,is_sep_loss = IS_SEP_LOSS)
    if IS_CONF:
        loss_test,acc, msIoU, mIoU,confs, roc,opt, tpr,fpr,dice_score = tmp
    else:
        loss_test,acc, msIoU, mIoU, roc,opt,tpr,fpr,dice_score = tmp


    if IS_EXPERIMENT:    
        experiment.log_metric('optimal_cutoff_test',opt, epoch = epoch)
        experiment.log_metric('tpr_test',tpr, epoch = epoch)
        experiment.log_metric('fpr_test',fpr, epoch = epoch)
        experiment.log_metric('roc_test',roc, epoch = epoch)
        experiment.log_metric('dice_score_test',dice_score, epoch = epoch)

        if IS_SEP_LOSS:
            experiment.log_metric('loss_train',loss[0]+loss[1], epoch = epoch)
            experiment.log_metric('loss_train_base',loss[0], epoch = epoch)
            experiment.log_metric('loss_train_kuni',loss[1], epoch = epoch)
            experiment.log_metric('loss_test',loss_test[0]+loss_test[1], epoch = epoch)
            experiment.log_metric('loss_test_base',loss_test[0], epoch = epoch)
            experiment.log_metric('loss_test_kuni',loss_test[1], epoch = epoch)
        else:
            experiment.log_metric('loss_train',loss, epoch = epoch)
            experiment.log_metric('loss_test',loss_test, epoch = epoch)
        experiment.log_metric('accuracy_test',acc, epoch = epoch)
        experiment.log_metric('IoU_test',msIoU, epoch = epoch)
    if IS_EXPERIMENT:
        if IS_CONF:
            experiment.log_confusion_matrix(title=f"Test confusion epoch = {epoch}", matrix = confs, labels = ['No FCD', 'FCD'])
    if IS_SEP_LOSS:   
        test_losses.append(loss_test[0])
    else:
        test_losses.append(loss_test)
    test_iou.append(msIoU)
    test_roc.append(roc)
    test_dice.append(dice_score)
    test_tpr.append(tpr)
    test_fpr.append(fpr)

    if IS_SEP_LOSS:
        early_stopping(loss_test[0], model)
    else:
        early_stopping(loss_test, model)

    if early_stopping.early_stop:
        print("Early stopping")
        break

best_ind = np.argmax(np.array(test_tpr)-np.array(test_fpr))
es.append(e)
test_ious.append(np.max(test_iou))
test_lossess.append(np.min(test_losses))
test_dices.append(np.max(test_dice))
test_rocs.append(np.max(test_roc))
test_tprs.append(np.array(test_tpr)[best_ind])
test_fprs.append(np.array(test_fpr)[best_ind])

del model
    

data/BrainData/trainval_data_exp_grey.pkl loaded successfully
data/BrainData/test_data_exp_grey.pkl loaded successfully


  0%|          | 0/1000 [00:00<?, ?it/s]

length of training dataset: 1013
length of testing dataset: 100


  0%|          | 1/1000 [01:22<22:50:34, 82.32s/it]

Validation loss decreased (inf --> 0.483422).  Saving model ...


  0%|          | 2/1000 [02:45<22:59:22, 82.93s/it]

Validation loss decreased (0.483422 --> 0.449246).  Saving model ...


  0%|          | 3/1000 [04:09<23:02:14, 83.18s/it]

Validation loss decreased (0.449246 --> 0.436388).  Saving model ...


  0%|          | 4/1000 [05:33<23:02:43, 83.30s/it]

Validation loss decreased (0.436388 --> 0.430132).  Saving model ...


  0%|          | 5/1000 [06:57<23:03:11, 83.41s/it]

Validation loss decreased (0.430132 --> 0.424701).  Saving model ...


  1%|          | 6/1000 [08:20<23:02:26, 83.45s/it]

Validation loss decreased (0.424701 --> 0.420293).  Saving model ...


  1%|          | 7/1000 [09:44<23:01:37, 83.48s/it]

EarlyStopping counter: 1 out of 20
Validation loss decreased (0.420293 --> 0.413617).  Saving model ...


  1%|          | 9/1000 [12:32<23:00:10, 83.56s/it]

Validation loss decreased (0.413617 --> 0.410014).  Saving model ...
Validation loss decreased (0.410014 --> 0.407984).  Saving model ...


  1%|          | 11/1000 [15:19<22:58:10, 83.61s/it]

Validation loss decreased (0.407984 --> 0.405191).  Saving model ...


  1%|          | 12/1000 [16:43<22:56:48, 83.61s/it]

EarlyStopping counter: 1 out of 20


  1%|▏         | 13/1000 [18:07<22:55:28, 83.62s/it]

EarlyStopping counter: 2 out of 20


  1%|▏         | 14/1000 [19:30<22:54:08, 83.62s/it]

EarlyStopping counter: 3 out of 20


  2%|▏         | 15/1000 [20:54<22:52:59, 83.63s/it]

Validation loss decreased (0.405191 --> 0.402905).  Saving model ...
Validation loss decreased (0.402905 --> 0.402714).  Saving model ...


  2%|▏         | 17/1000 [23:42<22:50:52, 83.67s/it]

Validation loss decreased (0.402714 --> 0.394891).  Saving model ...


  2%|▏         | 18/1000 [25:06<22:49:29, 83.68s/it]

EarlyStopping counter: 1 out of 20


  2%|▏         | 19/1000 [26:29<22:48:10, 83.68s/it]

Validation loss decreased (0.394891 --> 0.393738).  Saving model ...


  2%|▏         | 20/1000 [27:53<22:46:40, 83.67s/it]

EarlyStopping counter: 1 out of 20


  2%|▏         | 21/1000 [29:17<22:45:17, 83.67s/it]

Validation loss decreased (0.393738 --> 0.389280).  Saving model ...


  2%|▏         | 22/1000 [30:40<22:43:59, 83.68s/it]

Validation loss decreased (0.389280 --> 0.389121).  Saving model ...


  2%|▏         | 23/1000 [32:04<22:42:38, 83.68s/it]

Validation loss decreased (0.389121 --> 0.389089).  Saving model ...


  2%|▏         | 24/1000 [33:28<22:41:07, 83.68s/it]

EarlyStopping counter: 1 out of 20


  2%|▎         | 25/1000 [34:51<22:39:41, 83.67s/it]

Validation loss decreased (0.389089 --> 0.387030).  Saving model ...


KeyboardInterrupt: 