In [None]:
# parent: train_vit_folds.ipynb

In [None]:
import math
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm import tqdm

from leaf.dta import LeafDataset, LeafIterableDataset, LeafDataLoader, GetPatches, TransformPatches, RandomGreen, \
    LeafDataLoader, GetRandomCrops, GetRandomResizedCrops, get_leaf_splits
from leaf.model import LeafModel, train_one_epoch, validate_one_epoch, warmup




# Per-channel mean / std used for the ImageNet-style normalization applied at
# the end of every pipeline below.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# "tta-random": GetRandomCrops(12, size=224) followed by a RandomHorizontalFlip
# wrapped in TransformPatches (presumably applied per patch - confirm in
# leaf.dta). The resulting patch sequence is stacked into a single tensor
# before it is normalized.
tta_random_pipeline = transforms.Compose([
    transforms.ToTensor(),
    GetRandomCrops(12, size=224),
    TransformPatches(transforms.RandomHorizontalFlip()),
    transforms.Lambda(torch.stack),
    transforms.Normalize(imagenet_mean, imagenet_std),
])

# "tta-12": patches produced by GetPatches(800, 600, 224, ...) with no
# flipping; stacked and normalized the same way as above.
tta_12_pipeline = transforms.Compose([
    transforms.ToTensor(),
    GetPatches(800, 600, 224, test_colors=False, include_center=False),
    transforms.Lambda(torch.stack),
    transforms.Normalize(imagenet_mean, imagenet_std),
])

# Lookup table of the transform pipelines defined in this notebook.
data_transforms = {
    'tta-random': tta_random_pipeline,
    'tta-12': tta_12_pipeline,
}

# Passed to the validation LeafDataLoader as max_samples_per_image.
val_max_samples_per_image = 12
# Forwarded to LeafModel(ragged_batches=...).
ragged_batches = False

In [None]:
# --- Directories handed to LeafModel as output_dir / logging_dir -------------
# mkdir(exist_ok=True) only creates the final path component, so the parent
# (/mnt/hdd) has to exist already.
output_dir = Path("/mnt/hdd/leaf-disease-outputs")
logging_dir = Path("/mnt/hdd/leaf-disease-runs")
for run_dir in (output_dir, logging_dir):
    run_dir.mkdir(exist_ok=True)

# --- Data loading ------------------------------------------------------------
batch_size, val_batch_size = 16, 2
num_workers = 4

# --- Optimisation ------------------------------------------------------------
learning_rate, weight_decay = 1e-4, 0.0
num_epochs = 12
log_steps = 605  # forwarded to train_one_epoch(log_steps=...)

# --- Cross-validation folds and the dataset they index into -------------------
labels_csv = "./data/images/labels.csv"
num_splits = 5
folds = get_leaf_splits(labels_csv, num_splits, random_seed=5293)
full_dset = LeafDataset("./data/images", labels_csv)


In [None]:
# Folds to continue training in this run. Each entry maps a fold index to
# (checkpoint to load, first epoch that still has to be trained). Folds that
# are not listed are skipped entirely.
# An earlier run resumed fold 0 with
#   ("vit_base_patch16_224_lr1e4_fold0_12-epochs-11", 12).
resume_points = {
    2: ("vit_base_patch16_224_lr1e4_fold2_12-epochs-6", 7),
}

for fold, (train_idxs, val_idxs) in enumerate(folds):
    if fold not in resume_points:
        continue
    resume_checkpoint, first_epoch = resume_points[fold]

    torch.cuda.empty_cache()

    # NOTE(review): this notebook's `data_transforms` only defines the keys
    # "tta-random" and "tta-12", and `max_samples_per_image` is never assigned
    # here (only `val_max_samples_per_image` is). On a fresh kernel the next
    # lines raise KeyError / NameError. The intended values presumably come
    # from the parent notebook (train_vit_folds.ipynb) and must be restored
    # there rather than guessed, because this cell resumes a run mid-training.
    train_dset = LeafDataset.from_leaf_dataset(full_dset, train_idxs, transform=data_transforms["train"])
    val_dset = LeafDataset.from_leaf_dataset(full_dset, val_idxs, transform=data_transforms["val"])

    train_dataloader = LeafDataLoader(
        train_dset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        max_samples_per_image=max_samples_per_image,
    )
    val_dataloader = LeafDataLoader(
        val_dset,
        batch_size=val_batch_size,
        shuffle=False,
        num_workers=num_workers,
        max_samples_per_image=val_max_samples_per_image,
    )

    leaf_model = LeafModel(
        "vit_base_patch16_224",
        model_prefix=f"vit_base_patch16_224_lr1e4_fold{fold}",
        output_dir=output_dir,
        logging_dir=logging_dir,
        ragged_batches=ragged_batches,
    )
    # Referenced through torch.optim (like the scheduler below) because this
    # notebook never imports a bare `Adam`.
    optimizer = torch.optim.Adam(leaf_model.model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # The one-cycle schedule is sized for the full run (num_epochs epochs), not
    # just for the epochs that remain after the checkpoint.
    # NOTE(review): whether load_checkpoint restores optimizer / scheduler
    # state is not visible here - confirm in leaf.model.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=learning_rate,
        epochs=num_epochs,
        steps_per_epoch=len(train_dataloader),
    )
    leaf_model.update_optimizer_scheduler(optimizer, scheduler)

    leaf_model.load_checkpoint(resume_checkpoint)

    run_name = f"{leaf_model.model_prefix}_{num_epochs}-epochs"
    for epoch in range(first_epoch, num_epochs+1):
        epoch_name = f"{run_name}-{epoch}"
        train_one_epoch(leaf_model, train_dataloader, log_steps=log_steps, val_data_loader=val_dataloader, save_at_log_steps=True, epoch_name=epoch_name)
        val_loss, val_raw_acc, val_logits_mean_acc, val_probs_mean_acc = validate_one_epoch(leaf_model, val_dataloader)
        print(f"Validation after epoch {epoch}: loss {val_loss}, raw acc {val_raw_acc}, logits mean acc {val_logits_mean_acc}, probs mean acc {val_probs_mean_acc}")

        # End-of-epoch validation metrics are logged at the global step that
        # corresponds to the end of this epoch.
        val_step = epoch * len(train_dataloader)
        val_metrics = {
            "loss/val": val_loss,
            "raw_acc/val": val_raw_acc,
            "logits_mean_acc/val": val_logits_mean_acc,
            "probs_mean_acc/val": val_probs_mean_acc,
        }
        # The context manager closes the writer even if a write fails.
        with SummaryWriter(log_dir=leaf_model.logging_model_dir / epoch_name) as tb_writer:
            for tag, value in val_metrics.items():
                tb_writer.add_scalar(tag, value, val_step)

        leaf_model.save_checkpoint(f"{epoch_name}", epoch=epoch)


In [None]:
#torch.cuda.empty_cache()

In [None]:
if False:
    # Disabled by default: flip the guard to build a loader over the separate
    # ./data/val_images set. The steps mirror the 'tta-12' pipeline defined
    # in data_transforms.
    alt_steps = [
        transforms.ToTensor(),
        GetPatches(800, 600, 224, test_colors=False, include_center=False),
        transforms.Lambda(torch.stack),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    alt_transform = transforms.Compose(alt_steps)

    alt_val_dset = LeafDataset(
        "./data/val_images",
        "./data/val_images/labels.csv",
        transform=alt_transform,
        tiny=False,
    )
    alt_val_dataloader = LeafDataLoader(
        alt_val_dset,
        batch_size=val_batch_size,
        shuffle=False,
        num_workers=num_workers,
        max_samples_per_image=12,
    )

In [None]:
if False:
    # Disabled by default: reload every per-epoch checkpoint of the current
    # `leaf_model` and score it on `alt_val_dataloader`, which is only defined
    # when the (also disabled) alternative-validation cell has been run.
    #
    # The run name uses "_" between the model prefix and the epoch count so it
    # matches the names the training loop passes to save_checkpoint
    # (e.g. "vit_base_patch16_224_lr1e4_fold2_12-epochs-6").
    run_name = f"{leaf_model.model_prefix}_{num_epochs}-epochs"
    for epoch in range(1, num_epochs+1):
        epoch_name = f"{run_name}-{epoch}"
        leaf_model.load_checkpoint(epoch_name)
        val_loss, val_raw_acc, val_logits_mean_acc, val_probs_mean_acc = validate_one_epoch(leaf_model, alt_val_dataloader)
        print(f"Validation after epoch {epoch}: loss {val_loss}, raw acc {val_raw_acc}, logits mean acc {val_logits_mean_acc}, probs mean acc {val_probs_mean_acc}")