# **KFOLD 학습**

In [None]:
# K-fold training: for each fold, build the loaders, train with early stopping
# on validation accuracy, and save the model every time that accuracy improves.
#
# NOTE(review): `model` and `optimizer` are created outside this cell and are
# reused by every fold, so weights and optimizer state (including any learning
# rate already reduced by the scheduler) carry over from one fold to the next.
# If the folds are meant to be independent, re-create both at the top of the
# fold loop -- TODO confirm.
kfold_n = 5
datas = split_df(train_df)
for fold in range(1, kfold_n + 1):
    args.kfold = fold

    # Use only this fold's (train, valid) pair.
    fold_split = datas[fold - 1]
    train_data = fold_split[0]
    valid_data = fold_split[1]

    train_dataset = ArtDataset(train_data, transform)
    valid_dataset = ArtDataset(valid_data, transform)

    train_loader = DataLoader(train_dataset,
                              batch_size=hyper_parameters.batch_size,
                              shuffle=True,
                              num_workers=1)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=hyper_parameters.batch_size,
                              shuffle=False,
                              num_workers=1)

    # The dataset sizes do not change between epochs, so compute them once.
    train_len = len(train_data)
    valid_len = len(valid_data)

    # Build the scheduler per fold (it must come after the optimizer exists).
    # ReduceLROnPlateau remembers the best metric it has seen; a scheduler
    # shared by all folds would compare this fold's accuracy against the
    # previous fold's best, while best_acc below restarts for every fold.
    scheduler = ReduceLROnPlateau(optimizer, patience=3, factor=0.5, mode='max', verbose=True)

    # Early-stopping state, reset for every fold.
    early_stopping_counter = 0
    best_acc = -1
    patience = 5
    for epoch in range(hyper_parameters.epoch):
        train_acc = train(model, hyper_parameters, train_loader, train_len)
        valid_acc = validate(model, hyper_parameters, valid_loader, valid_len)

        print(f"[Epoch {epoch}] Train ACC : {train_acc}, Valid ACC : {valid_acc}")

        if valid_acc > best_acc:
            best_acc = valid_acc
            early_stopping_counter = 0

            # The checkpoint name keeps the first four characters of the
            # accuracy (e.g. "0.98"). The former `args.model == 'timm'` check
            # produced the same name in both branches, so it was dropped.
            save_name = f"fold{fold}_{str(best_acc.item())[:4]}"

            torch.save(model, os.path.join(PROJECT_PATH, save_name))
            print(f'model saved! {save_name}')
        else:
            early_stopping_counter += 1
            if early_stopping_counter >= patience:
                print(f'EarlyStopping counter: {early_stopping_counter} out of {patience}')
                break

        # best_acc is the running maximum for this fold, so the scheduler sees
        # an unchanged value on every epoch that did not improve on it.
        scheduler.step(best_acc)

# **KFOLD 추론**

In [None]:
# Checkpoint file names written by the k-fold training cell, one per fold.
fold_checkpoints = [
    "fold1_0.98",
    "fold2_0.98",
    "fold3_0.97",
    "fold4_0.97",
    "fold5_0.95",
]

# Load every saved model and switch it to evaluation mode.
checkpoint_paths = [os.path.join(PROJECT_PATH, name) for name in fold_checkpoints]
models = [torch.load(path).eval() for path in checkpoint_paths]

In [None]:
# Test-time loader: a small fixed batch size, with shuffling off so the
# batches follow the dataset order.
batch_size = 2
test_dataset = ArtDataset(test_df, transform)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=1,
)

In [None]:
# Ensemble inference: average the softmax probabilities of all loaded fold
# models for each test batch and keep the arg-max class per sample.
answers = []

# Gradients are not needed for prediction; disabling autograd avoids building
# a graph for every forward pass.
with torch.no_grad():
    # The second item of each batch is not used for prediction.
    for images, _ in test_loader:
        images = images.to(hyper_parameters.device)

        # Allocate the accumulator from the first model's output so the number
        # of classes is not hard-coded.
        predicts = None
        for model in models:
            outputs = model(images)
            outputs = F.softmax(outputs.cpu(), dim=1)
            if predicts is None:
                predicts = torch.zeros(outputs.shape)
            predicts += outputs

        # Average the predictions over the ensemble.
        predict_avg = predicts / len(models)

        _, preds = torch.max(predict_avg, 1)

        answers.extend(list(preds.cpu().numpy()))