In [1]:
# Load (or re-load, if it is already active) IPython's autoreload extension.
%reload_ext autoreload
# Mode 2: reload all imported modules before each cell runs, so edits to the
# project modules imported below are picked up without restarting the kernel.
%autoreload 2

In [2]:
import sys
sys.path.append("..")

from src.data.prepare_data import *
from src.models.model import *
import time
import random
from sklearn.model_selection import KFold
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data.sampler import WeightedRandomSampler
from tqdm import tqdm

In [3]:
# Single source of truth for every random number generator used in this notebook.
SEED = 2718

# Seed Python's `random`, NumPy's legacy global RNG and torch with the same value.
# A loop (rather than a trailing bare call) keeps the cell from echoing the
# generator object that torch.manual_seed returns.
for seed_everything in (random.seed, np.random.seed, torch.manual_seed):
    seed_everything(SEED)

In [4]:
# Read the three CSV tables used below: the training metadata, the test
# metadata, and the submission template that receives the final predictions.
train, test, sub = map(
    pd.read_csv,
    (
        "../data/internal/train.csv",
        "../data/internal/test.csv",
        "../data/internal/sample_submission.csv",
    ),
)

In [17]:
# Run configuration consumed by the training cell below.
config = dict(
    INPUT_DIR='',             # forwarded unchanged to train_model
    MODEL='alexnet',          # architecture name for load_model; also used in output file names
    SIZE=128,                 # passed to ImageTransform
    BATCH_SIZE=128,           # used for both training and test-time prediction
    NUM_FOLDS=3,              # n_splits of the KFold over tfrecord ids
    NUM_EPOCHS=10,            # forwarded to train_model
    FREEZED_EPOCHS=3,         # forwarded to train_model (presumably epochs with frozen layers -- confirm)
    LEARNING_RATE=1e-3,       # Adam learning rate
    EARLY_STOPPING=3,         # forwarded to train_model (presumably patience in epochs -- confirm)
    UNIFORM_AUGMENT=True,     # passed to ImageTransform
    TTA=3,                    # forwarded to get_predictions
    NUM_WORKERS=16,           # forwarded to train_model
    DEVICE='cpu',             # device string for training and prediction
)

In [19]:
# Cross-validated training followed by test-time inference.
#
# For every fold: split `train` by tfrecord id, train a freshly loaded model
# with a class-balancing sampler, then append that model's test predictions to
# `predictions`. After the last fold the row-wise mean of `predictions` is
# written out as the submission.
t = time.time()
predictions = pd.DataFrame()
transform = ImageTransform(config['SIZE'], config['UNIFORM_AUGMENT'])

# The test frame is given a constant `target` column before it is wrapped in a
# dataset. The value does not depend on the fold, so it is set once here
# instead of once per iteration.
test['target'] = 0

# KFold is run over np.arange(15), so the returned positions are themselves the
# ids 0..14 and can be matched directly against `train.tfrecord`.
# NOTE(review): assumes `train.tfrecord` takes values in 0..14 -- rows with any
# other value end up in neither the training nor the validation split.
skf = KFold(n_splits=config['NUM_FOLDS'], shuffle=True, random_state=SEED)
for i, (idxT, idxV) in enumerate(skf.split(np.arange(15))):
    t_fold = time.time()
    # Chaining reset_index yields fresh frames; the previous in-place reset on
    # a `.loc` slice of `train` could raise SettingWithCopyWarning.
    tr = train.loc[train.tfrecord.isin(idxT)].reset_index(drop=True)
    va = train.loc[train.tfrecord.isin(idxV)].reset_index(drop=True)

    # create datasets
    dataset_train = MelanomaDataset("../data/internal/train", tr, transform=transform, phase='train')
    dataset_valid = MelanomaDataset("../data/internal/train", va, transform=transform, phase='valid')

    # load a pretrained model
    net = load_model(config['MODEL'], 2)

    # define a loss function
    criterion = nn.CrossEntropyLoss()

    # define an optimizer
    optimizer = optim.Adam(net.parameters(), lr=config['LEARNING_RATE'])

    # define a scheduler (mode='max': the monitored quantity is expected to increase)
    scheduler = ReduceLROnPlateau(optimizer=optimizer, mode='max', patience=2, factor=0.2)

    # create a sampler: every row is weighted by 1 / (size of its class).
    # `class_index` maps each row to the position of its label among the sorted
    # unique labels, so the lookup stays correct even if the labels are not
    # exactly 0..K-1 (indexing the weights by raw label value would not be).
    _, class_index, class_sample_count = np.unique(
        tr['target'].to_numpy(), return_inverse=True, return_counts=True
    )
    weight = 1. / class_sample_count
    samples_weight = weight[class_index]
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

    # train the network
    print(f"---- fold: {i + 1} ------------")
    train_model(
        f"{config['MODEL']}_{i + 1}",
        dataset_train,
        dataset_valid,
        config['BATCH_SIZE'],
        net,
        criterion,
        optimizer,
        scheduler,
        config['NUM_EPOCHS'],
        config['FREEZED_EPOCHS'],
        config['INPUT_DIR'],
        config['NUM_WORKERS'],
        sampler,
        config['DEVICE'],
        config['EARLY_STOPPING']
    )

    # predict on test dataset
    # NOTE(review): `net` is used exactly as train_model leaves it. Whether that
    # is the best checkpoint or the last-epoch weights depends on train_model,
    # which is not visible here -- confirm.
    dataset_test = MelanomaDataset("../data/internal/test", test, transform=transform, phase='test')
    predictions = get_predictions(dataset_test,
                                  config["BATCH_SIZE"],
                                  net,
                                  config["TTA"],
                                  predictions,
                                  config["DEVICE"])
    # Snapshot of everything accumulated in `predictions` up to this fold.
    predictions.to_csv(f'../submissions/{config["MODEL"]}_fold{i+1}.csv')
    print(f"fold took {round(time.time() - t_fold, 2)}")

# output: the submission target is the row-wise mean over all prediction columns
sub['target'] = predictions.mean(axis=1)
sub.to_csv(f"../submissions/submission{config['MODEL']}.csv", index=False)
print(f"total time: {round(time.time() - t, 4)}")

  0%|          | 0/171 [00:00<?, ?it/s]

---- fold: 1 ------------


100%|██████████| 171/171 [02:17<00:00,  1.24it/s]
  0%|          | 0/171 [00:00<?, ?it/s]

epoch: 1, loss_train: 0.9565, loss_valid: 0.3698, auc_valid: 0.7127, saved: True, 197.9558sec


100%|██████████| 171/171 [02:15<00:00,  1.26it/s]
  0%|          | 0/171 [00:00<?, ?it/s]

epoch: 2, loss_train: 0.5129, loss_valid: 0.3258, auc_valid: 0.7129, saved: True, 196.5378sec


100%|██████████| 171/171 [02:15<00:00,  1.26it/s]
  0%|          | 0/171 [00:00<?, ?it/s]

epoch: 3, loss_train: 0.4967, loss_valid: 0.3986, auc_valid: 0.7112, saved: False, 195.0844sec


100%|██████████| 171/171 [06:33<00:00,  2.30s/it]
  0%|          | 0/171 [00:00<?, ?it/s]

epoch: 4, loss_train: 0.4401, loss_valid: 0.2136, auc_valid: 0.7541, saved: True, 452.9181sec


100%|██████████| 171/171 [06:21<00:00,  2.23s/it]
  0%|          | 0/171 [00:00<?, ?it/s]

epoch: 5, loss_train: 0.3585, loss_valid: 0.1937, auc_valid: 0.7444, saved: False, 440.3519sec


100%|██████████| 171/171 [06:16<00:00,  2.20s/it]
  0%|          | 0/171 [00:00<?, ?it/s]

epoch: 6, loss_train: 0.3148, loss_valid: 0.1595, auc_valid: 0.7023, saved: False, 435.5674sec


100%|██████████| 171/171 [06:17<00:00,  2.21s/it]
  0%|          | 0/171 [00:00<?, ?it/s]

epoch: 7, loss_train: 0.2838, loss_valid: 0.1813, auc_valid: 0.7723, saved: True, 436.6003sec


100%|██████████| 171/171 [06:16<00:00,  2.20s/it]
  0%|          | 0/171 [00:00<?, ?it/s]

epoch: 8, loss_train: 0.2537, loss_valid: 0.1730, auc_valid: 0.6736, saved: False, 435.4929sec


100%|██████████| 171/171 [06:18<00:00,  2.22s/it]
  0%|          | 0/171 [00:00<?, ?it/s]

epoch: 9, loss_train: 0.2426, loss_valid: 0.1973, auc_valid: 0.7306, saved: False, 437.4067sec


100%|██████████| 171/171 [06:17<00:00,  2.21s/it]
  0%|          | 0/3 [00:00<?, ?it/s]

epoch: 10, loss_train: 0.2298, loss_valid: 0.1600, auc_valid: 0.7474, saved: False, 436.7642sec


100%|██████████| 3/3 [03:35<00:00, 71.81s/it]


fold took 3881.07


  0%|          | 0/171 [00:00<?, ?it/s]

---- fold: 2 ------------


100%|██████████| 171/171 [02:16<00:00,  1.26it/s]
  0%|          | 0/171 [00:00<?, ?it/s]

epoch: 1, loss_train: 0.9567, loss_valid: 0.5249, auc_valid: 0.7236, saved: True, 196.9148sec


100%|██████████| 171/171 [02:15<00:00,  1.26it/s]
  0%|          | 0/171 [00:00<?, ?it/s]

epoch: 2, loss_train: 0.5271, loss_valid: 0.6315, auc_valid: 0.7291, saved: True, 197.2371sec


100%|██████████| 171/171 [02:14<00:00,  1.27it/s]
  0%|          | 0/171 [00:00<?, ?it/s]

epoch: 3, loss_train: 0.5042, loss_valid: 0.3219, auc_valid: 0.7332, saved: True, 195.6398sec


100%|██████████| 171/171 [06:22<00:00,  2.23s/it]
  0%|          | 0/171 [00:00<?, ?it/s]

epoch: 4, loss_train: 0.4578, loss_valid: 0.2962, auc_valid: 0.7351, saved: True, 442.0249sec


100%|██████████| 171/171 [06:18<00:00,  2.21s/it]
  0%|          | 0/171 [00:00<?, ?it/s]

epoch: 5, loss_train: 0.3657, loss_valid: 0.2112, auc_valid: 0.7512, saved: True, 437.5500sec


100%|██████████| 171/171 [06:12<00:00,  2.18s/it]
  0%|          | 0/171 [00:00<?, ?it/s]

epoch: 6, loss_train: 0.3291, loss_valid: 0.2023, auc_valid: 0.7640, saved: True, 431.5578sec


100%|██████████| 171/171 [06:16<00:00,  2.20s/it]
  0%|          | 0/171 [00:00<?, ?it/s]

epoch: 7, loss_train: 0.2838, loss_valid: 0.1858, auc_valid: 0.7218, saved: False, 434.6012sec


100%|██████████| 171/171 [06:12<00:00,  2.18s/it]
  0%|          | 0/171 [00:00<?, ?it/s]

epoch: 8, loss_train: 0.2673, loss_valid: 0.2080, auc_valid: 0.7577, saved: False, 430.9226sec


100%|██████████| 171/171 [06:12<00:00,  2.18s/it]
  0%|          | 0/3 [00:00<?, ?it/s]

epoch: 9, loss_train: 0.2485, loss_valid: 0.1841, auc_valid: 0.7467, saved: False, 431.5983sec


100%|██████████| 3/3 [03:35<00:00, 71.87s/it]


fold took 3414.97


  0%|          | 0/171 [00:00<?, ?it/s]

---- fold: 3 ------------


100%|██████████| 171/171 [02:14<00:00,  1.27it/s]
  0%|          | 0/171 [00:00<?, ?it/s]

epoch: 1, loss_train: 0.9615, loss_valid: 0.3615, auc_valid: 0.7382, saved: True, 194.0480sec


100%|██████████| 171/171 [02:12<00:00,  1.29it/s]
  0%|          | 0/171 [00:00<?, ?it/s]

epoch: 2, loss_train: 0.5336, loss_valid: 0.3372, auc_valid: 0.7447, saved: True, 192.9648sec


100%|██████████| 171/171 [02:13<00:00,  1.28it/s]
  0%|          | 0/171 [00:00<?, ?it/s]

epoch: 3, loss_train: 0.5282, loss_valid: 0.4504, auc_valid: 0.7320, saved: False, 192.4415sec


100%|██████████| 171/171 [06:26<00:00,  2.26s/it]
  0%|          | 0/171 [00:00<?, ?it/s]

epoch: 4, loss_train: 0.4646, loss_valid: 0.2137, auc_valid: 0.7361, saved: False, 444.9330sec


100%|██████████| 171/171 [06:25<00:00,  2.25s/it]
  0%|          | 0/171 [00:00<?, ?it/s]

epoch: 5, loss_train: 0.3649, loss_valid: 0.1971, auc_valid: 0.7735, saved: True, 445.3052sec


100%|██████████| 171/171 [06:22<00:00,  2.24s/it]
  0%|          | 0/171 [00:00<?, ?it/s]

epoch: 6, loss_train: 0.3210, loss_valid: 0.2333, auc_valid: 0.7466, saved: False, 441.1519sec


100%|██████████| 171/171 [06:22<00:00,  2.24s/it]
  0%|          | 0/171 [00:00<?, ?it/s]

epoch: 7, loss_train: 0.2893, loss_valid: 0.1493, auc_valid: 0.7615, saved: False, 441.3694sec


100%|██████████| 171/171 [06:22<00:00,  2.23s/it]
  0%|          | 0/3 [00:00<?, ?it/s]

epoch: 8, loss_train: 0.2642, loss_valid: 0.1545, auc_valid: 0.7718, saved: False, 440.7455sec


100%|██████████| 3/3 [03:35<00:00, 71.77s/it]

fold took 3009.6
total time: 10305.6773





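This submission (the row-wise mean of the accumulated test predictions from all three folds) achieved a score of 0.8786. For comparison, the best validation AUC per fold in the training log above was 0.7723 (fold 1), 0.7640 (fold 2), and 0.7735 (fold 3).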