In [None]:
# parent: train_vanilla_TTA-TWO

In [None]:
from pathlib import Path
from tqdm import tqdm
import math

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torchvision import transforms

from leaf.dta import LeafDataset, LeafIterableDataset, LeafDataLoader, GetPatches, TransformPatches, RandomGreen, LeafDataLoader, GetRandomCrops, GetRandomResizedCrops
from leaf.model import LeafModel, train_one_epoch, validate_one_epoch, warmup
from torch.utils.data import RandomSampler, DataLoader

from torch.optim.lr_scheduler import CyclicLR

from torch.utils.tensorboard import SummaryWriter


# Number of crops per image. This value is also passed to LeafDataLoader as
# `max_samples_per_image`, so GetRandomCrops and the loader are driven from this
# single constant instead of two copies of the literal 6.
max_samples_per_image = 6
crop_size = 380
ragged_batches = False

# ImageNet channel mean / std used for normalization.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]


def make_crop_transform():
    """Build the per-image transform pipeline shared by train and val.

    ToTensor -> GetRandomCrops(max_samples_per_image, size=crop_size) ->
    RandomHorizontalFlip applied via TransformPatches -> torch.stack -> Normalize.
    A new ``transforms.Compose`` is returned on every call, so train and val
    get separate instances (as before).
    """
    return transforms.Compose([
        transforms.ToTensor(),
        GetRandomCrops(max_samples_per_image, size=crop_size),
        TransformPatches(transforms.RandomHorizontalFlip()),
        # assumes the previous step yields a list of equally-sized CxHxW tensors — TODO confirm
        transforms.Lambda(lambda patches: torch.stack(patches)),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ])


# Transforms with normalizations for imagenet.
# NOTE(review): 'val' uses the same random crops / flips as 'train' (TTA-style),
# so validation metrics will vary from run to run.
data_transforms = {
    'train': make_crop_transform(),
    'val': make_crop_transform(),
}

In [None]:
# Output / TensorBoard roots (absolute, machine-specific paths). parents=True so
# the cell does not fail when intermediate directories do not exist yet.
output_dir = Path("/mnt/hdd/leaf-disease-outputs")
output_dir.mkdir(parents=True, exist_ok=True)
logging_dir = Path("/mnt/hdd/leaf-disease-runs")
logging_dir.mkdir(parents=True, exist_ok=True)

# NOTE(review): each image presumably expands to up to `max_samples_per_image`
# crops, so the effective number of crops per step is larger than batch_size — TODO confirm.
batch_size = 1
val_batch_size = 2
num_workers = 4

train_dset = LeafDataset("./data/train_images", "./data/train_images/labels.csv", transform=data_transforms["train"], tiny=False)
val_dset = LeafDataset("./data/val_images", "./data/val_images/labels.csv", transform=data_transforms["val"], tiny=False)

train_dataloader = LeafDataLoader(train_dset, batch_size=batch_size, shuffle=True, num_workers=num_workers, max_samples_per_image=max_samples_per_image)
val_dataloader = LeafDataLoader(val_dset, batch_size=val_batch_size, shuffle=False, num_workers=num_workers, max_samples_per_image=max_samples_per_image)
#train_dataloader = DataLoader(train_dset, batch_size=batch_size, sampler=RandomSampler(train_dset, replacement=True), num_workers=num_workers, pin_memory=True)

In [None]:
# Optimization hyperparameters.
learning_rate = 1e-4   # Adam lr and OneCycleLR max_lr
weight_decay = 1e-6
num_epochs = 1
log_steps = 1000       # passed to train_one_epoch as log_steps

# Seed the torch and numpy RNGs so model initialization and data sampling are
# repeatable across runs (full GPU determinism would additionally need cuDNN settings).
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

In [None]:
# Model wrapper for the "tf_efficientnet_b4" architecture; model_prefix also
# names the training run in the next cell.
leaf_model = LeafModel(
    "tf_efficientnet_b4",
    model_prefix="b4_tta",
    output_dir=output_dir,
    logging_dir=logging_dir,
    ragged_batches=ragged_batches,
)

# Adam with a one-cycle schedule peaking at `learning_rate`. The schedule spans
# num_epochs * len(train_dataloader) steps, so it presumably expects one
# scheduler.step() per training batch inside train_one_epoch — TODO confirm.
optimizer = Adam(
    leaf_model.model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=learning_rate,
    epochs=num_epochs,
    steps_per_epoch=len(train_dataloader),
)
leaf_model.update_optimizer_scheduler(optimizer, scheduler)

In [None]:
run_name = f"{leaf_model.model_prefix}-{num_epochs}-epochs"
for epoch in range(1, num_epochs + 1):
    epoch_name = f"{run_name}-{epoch}"
    # train_one_epoch also receives the val loader and save_at_log_steps=True,
    # so it may validate / checkpoint every `log_steps` batches.
    train_one_epoch(leaf_model, train_dataloader, log_steps=log_steps, val_data_loader=val_dataloader, save_at_log_steps=True, epoch_name=epoch_name)

    # End-of-epoch validation.
    val_loss, val_raw_acc, val_logits_mean_acc, val_probs_mean_acc = validate_one_epoch(leaf_model, val_dataloader)
    print(f"Validation after epoch {epoch}: loss {val_loss}, raw acc {val_raw_acc}, logits mean acc {val_logits_mean_acc}, probs mean acc {val_probs_mean_acc}")

    # Log the epoch-end metrics at the step count reached by the end of this epoch.
    # The `with` block guarantees the writer is closed even if logging raises.
    val_step = epoch * len(train_dataloader)
    val_scalars = {
        "loss/val": val_loss,
        "raw_acc/val": val_raw_acc,
        "logits_mean_acc/val": val_logits_mean_acc,
        "probs_mean_acc/val": val_probs_mean_acc,
    }
    with SummaryWriter(log_dir=leaf_model.logging_model_dir / epoch_name) as tb_writer:
        for tag, value in val_scalars.items():
            tb_writer.add_scalar(tag, value, val_step)

    leaf_model.save_checkpoint(epoch_name, epoch=epoch)

In [None]:
#torch.cuda.empty_cache()

In [None]:
# Alternative evaluation pipeline: GetPatches(800, 600, 380, ...) instead of
# GetRandomCrops, and no random flip — presumably deterministic patches; TODO confirm.
alt_transform = transforms.Compose([
    transforms.ToTensor(),
    GetPatches(800, 600, 380, test_colors=False, include_center=False),
    transforms.Lambda(lambda patches: torch.stack(patches)),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
alt_val_dset = LeafDataset(
    "./data/val_images",
    "./data/val_images/labels.csv",
    transform=alt_transform,
    tiny=False,
)
alt_val_dataloader = LeafDataLoader(
    alt_val_dset,
    batch_size=3,
    shuffle=False,
    num_workers=num_workers,
    max_samples_per_image=6,
)

In [None]:
# Reload each epoch's checkpoint and evaluate it with both validation pipelines:
# the random-crop loader used during training, and the GetPatches-based
# `alt_val_dataloader` from the previous cell (which is otherwise unused).
eval_loaders = {
    "random crops": val_dataloader,
    "patches": alt_val_dataloader,
}
run_name = f"{leaf_model.model_prefix}-{num_epochs}-epochs"
for epoch in range(1, num_epochs + 1):
    epoch_name = f"{run_name}-{epoch}"
    leaf_model.load_checkpoint(epoch_name)
    for loader_name, loader in eval_loaders.items():
        val_loss, val_raw_acc, val_logits_mean_acc, val_probs_mean_acc = validate_one_epoch(leaf_model, loader)
        print(f"Validation after epoch {epoch} ({loader_name}): loss {val_loss}, raw acc {val_raw_acc}, logits mean acc {val_logits_mean_acc}, probs mean acc {val_probs_mean_acc}")