In [1]:
# Load IPython's autoreload extension (re-loading it if it is already active).
%reload_ext autoreload
# Mode 2: reload all imported modules before each cell executes, so edits to
# project modules (e.g. the `src` package imported below) take effect without
# restarting the kernel.
%autoreload 2

In [2]:
import sys
sys.path.append("..")  # make the repository root importable so `src` resolves

# NOTE: the star imports stay first on purpose; the named imports below are bound
# afterwards, so they win over any same-named objects the star imports export.
from src.data.prepare_data import *
from src.models.model import *
import os
import random
import time
from sklearn.model_selection import KFold
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data.sampler import WeightedRandomSampler
from tqdm import tqdm

In [3]:
# Single source of truth for every seed used in this notebook
# (also passed to KFold as `random_state` in the training cell).
SEED = 2718

# Seed the global generators of Python's `random`, NumPy and PyTorch.
# A loop is used so the cell ends on a statement and displays no output.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

In [4]:
# Tables read from ../data/internal: training rows, test rows and the
# sample-submission template.
train = pd.read_csv("../data/internal/train.csv")
test = pd.read_csv("../data/internal/test.csv")
sub = pd.read_csv("../data/internal/sample_submission.csv")

# Additional training rows read from ../data/external. Their `tfrecord` ids are
# shifted by 20; the fold loop later treats every row with `tfrecord >= 20`
# as training-only data.
train_ext = pd.read_csv('../data/external/train.csv')
train_ext = train_ext.assign(tfrecord=train_ext['tfrecord'] + 20)

# Stack internal and external rows into one frame with a fresh 0..n-1 index.
train = pd.concat([train, train_ext], ignore_index=True)

In [5]:
# Run configuration; every value is read by the training cell below.
config = dict(
    INPUT_DIR='',            # forwarded to train_model
    MODEL='alexnet',         # passed to load_model; also used in output file names
    SIZE=128,                # passed to ImageTransform
    BATCH_SIZE=32,           # used by train_model and get_predictions
    NUM_FOLDS=3,             # n_splits of the KFold splitter
    NUM_EPOCHS=20,           # forwarded to train_model
    FREEZED_EPOCHS=3,        # forwarded to train_model
    LEARNING_RATE=1e-3,      # initial learning rate of the Adam optimizer
    EARLY_STOPPING=20,       # forwarded to train_model
    UNIFORM_AUGMENT=True,    # passed to ImageTransform
    TTA=3,                   # passed to get_predictions
    NUM_WORKERS=16,          # forwarded to train_model
    DEVICE='cpu',            # used by train_model and get_predictions
)

In [6]:
# Cross-validated training followed by test-time prediction.
#
# For each fold: build the train/valid frames from the `tfrecord` column, train
# a freshly loaded model, then predict on the test set. Per-fold predictions are
# accumulated in `predictions` and averaged into the final submission.
t = time.time()
predictions = pd.DataFrame()
transform = ImageTransform(config['SIZE'], config['UNIFORM_AUGMENT'])

# Create the output directory up front: without it the first `to_csv` below
# would raise only after a whole fold has been trained.
os.makedirs('../submissions', exist_ok=True)

# Folds are formed over tfrecord ids 0..14; rows with `tfrecord >= 20` are never
# held out and are added to the training part of every fold.
skf = KFold(n_splits=config['NUM_FOLDS'], shuffle=True, random_state=SEED)
for i, (idxT,idxV) in enumerate(skf.split(np.arange(15))):
    t_fold = time.time()
    # Chained reset_index returns new frames instead of mutating the `.loc`
    # selections in place.
    tr = train.loc[train.tfrecord.isin(idxT) | (train.tfrecord >= 20)].reset_index(drop=True)
    va = train.loc[train.tfrecord.isin(idxV)].reset_index(drop=True)

    # create datasets
    dataset_train = MelanomaDataset("../data/internal/train", 
                                    tr, 
                                    transform=transform, 
                                    phase='train', 
                                    external_base_dir='../data/external/train')
    dataset_valid = MelanomaDataset("../data/internal/train", va, transform=transform, phase='valid')
    
    # load a pretrained model
    net = load_model(config['MODEL'], 2)

    # define a loss function
    criterion = nn.CrossEntropyLoss()

    # define an optimizer
    optimizer = optim.Adam(net.parameters(), lr=config['LEARNING_RATE'])

    # define a scheduler (mode='max': reduces the LR when the monitored value stops increasing)
    scheduler = ReduceLROnPlateau(optimizer=optimizer, mode='max', patience=2, factor=0.2)

    # create a sampler
    sampler = create_weighted_random_sampler(tr)

    # train the network
    print(f"---- fold: {i + 1} ------------")
    train_model(
        f"{config['MODEL']}_{i + 1}",
        dataset_train,
        dataset_valid,
        config['BATCH_SIZE'],
        net,
        criterion,
        optimizer,
        scheduler,
        config['NUM_EPOCHS'],
        config['FREEZED_EPOCHS'],
        config['INPUT_DIR'],
        config['NUM_WORKERS'],
        sampler,
        config['DEVICE'],
        config['EARLY_STOPPING']
    )

    # predict on test dataset
    # NOTE(review): `net` is used exactly as train_model leaves it. If train_model
    # does not restore the best saved checkpoint, these predictions come from the
    # last-epoch weights -- confirm against train_model.
    # Placeholder label column -- presumably required by MelanomaDataset; confirm.
    test['target'] = 0
    dataset_test = MelanomaDataset("../data/internal/test", test, transform=transform, phase='test')
    tta_time = time.time()
    predictions = get_predictions(dataset_test, 
                                  config["BATCH_SIZE"], 
                                  net, 
                                  config["TTA"], 
                                  predictions, 
                                  config["DEVICE"])
    print(f"TTA took {round(time.time() - tta_time, 2)}")
    # Checkpoint the predictions accumulated so far after every fold.
    predictions.to_csv(f'../submissions/{config["MODEL"]}_fold{i+1}.csv')
    print(f"fold took {round(time.time() - t_fold, 2)}")
    
# output
# Average across the columns of `predictions` to get one score per row.
# NOTE(review): the assignment aligns on index -- assumes `predictions` shares
# `sub`'s index; confirm against get_predictions.
sub['target'] = predictions.mean(axis=1)
sub.to_csv(f"../submissions/submission{config['MODEL']}.csv", index=False)
print(f"total time: {round(time.time() - t, 4)}")

  0%|          | 0/1471 [00:00<?, ?it/s]

---- fold: 1 ------------


100%|██████████| 1471/1471 [06:29<00:00,  3.77it/s]
  0%|          | 0/1471 [00:00<?, ?it/s]

epoch: 1, loss_train: 0.7134, loss_valid: 0.3265, auc_valid: 0.6973, saved: True, 461.3580sec


100%|██████████| 1471/1471 [06:26<00:00,  3.81it/s]
  0%|          | 0/1471 [00:00<?, ?it/s]

epoch: 2, loss_train: 0.6688, loss_valid: 0.4668, auc_valid: 0.7240, saved: True, 457.5376sec


100%|██████████| 1471/1471 [06:26<00:00,  3.81it/s]
  0%|          | 0/1471 [00:00<?, ?it/s]

epoch: 3, loss_train: 0.6583, loss_valid: 0.1878, auc_valid: 0.7041, saved: False, 456.8285sec


100%|██████████| 1471/1471 [25:31<00:00,  1.04s/it]
  0%|          | 0/1471 [00:00<?, ?it/s]

epoch: 4, loss_train: 0.4921, loss_valid: 0.1371, auc_valid: 0.6773, saved: False, 1599.7361sec


100%|██████████| 1471/1471 [27:06<00:00,  1.11s/it]
  0%|          | 0/1471 [00:00<?, ?it/s]

epoch: 5, loss_train: 0.4489, loss_valid: 0.1742, auc_valid: 0.3867, saved: False, 1693.9566sec


 94%|█████████▍| 1380/1471 [26:39<01:45,  1.16s/it]


KeyboardInterrupt: 