In [None]:
from pathlib import Path
from tqdm import tqdm
import math

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torchvision import transforms

from leaf.dta import LeafDataset, LeafIterableDataset, LeafDataLoader, GetPatches, TransformPatches, RandomGreen, LeafDataLoader
from leaf.model import LeafModel, train_one_epoch, validate_one_epoch, warmup
from torch.utils.data import RandomSampler, DataLoader

from torch.optim.lr_scheduler import CyclicLR

from torch.utils.tensorboard import SummaryWriter


def _make_patch_pipeline():
    """Build one image -> stacked, normalized patch tensor pipeline.

    Steps, in order: convert the image to a tensor, cut it into patches with
    GetPatches, stack the resulting patches into a single tensor, and normalize
    with the ImageNet channel mean / std.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.ToTensor(),
        # Optional GetPatches filters (min_ratio, min_value, min_hue, max_hue)
        # are left at their defaults here.
        GetPatches(800, 600, 224, include_whole=False),
        # GetPatches output is stacked into one tensor so Normalize can be
        # applied to all patches at once (instead of per patch via TransformPatches).
        transforms.Lambda(torch.stack),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ]
    return transforms.Compose(steps)


# Train and validation currently share the exact same preprocessing; each split
# still gets its own Compose instance.
data_transforms = {split: _make_patch_pipeline() for split in ("train", "val")}

In [None]:
# Directories handed to LeafModel as output_dir / logging_dir further down.
output_dir = Path("/mnt/hdd/leaf-disease-outputs")
logging_dir = Path("/mnt/hdd/leaf-disease-runs")
for _directory in (output_dir, logging_dir):
    # Only the leaf directory is created; the parent must already exist.
    _directory.mkdir(exist_ok=True)

# Batch sizes for the train / validation loaders and worker count for both.
batch_size, val_batch_size = 1, 2
num_workers = 4

# Datasets: image directory + labels CSV, each with its split-specific transform.
train_dset = LeafDataset(
    "./data/train_images",
    "./data/train_images/labels.csv",
    transform=data_transforms["train"],
    tiny=False,
)
val_dset = LeafDataset(
    "./data/val_images",
    "./data/val_images/labels.csv",
    transform=data_transforms["val"],
    tiny=False,
)
# RandomGreen variant
#train_dset = LeafDataset("./data/train_images", "./data/train_images/labels.csv", transform=data_transforms["train"])

# Settings common to both loaders; only batch size and shuffling differ.
_loader_kwargs = dict(num_workers=num_workers, max_samples_per_image=12)
train_dataloader = LeafDataLoader(train_dset, batch_size=batch_size, shuffle=True, **_loader_kwargs)
val_dataloader = LeafDataLoader(val_dset, batch_size=val_batch_size, shuffle=False, **_loader_kwargs)
# Plain torch DataLoader alternative (sampling with replacement):
#train_dataloader = DataLoader(train_dset, batch_size=batch_size, sampler=RandomSampler(train_dset, replacement=True), num_workers=num_workers, pin_memory=True)

In [None]:
# Model wrapper around the "tf_efficientnet_b0_ns" architecture; checkpoints and
# logs go to the directories configured above.
leaf_model = LeafModel(
    "tf_efficientnet_b0_ns",
    model_prefix="NEW",
    output_dir=output_dir,
    logging_dir=logging_dir,
    ragged_batches=True,
)

# Fixed-LR Adam with a small weight decay; no LR scheduler is used.
# Earlier variants, kept for reference:
#   optimizer = Adam(leaf_model.model.parameters(), lr=learning_rate, weight_decay=weight_decay)
#   scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=T_0, T_mult=1, eta_min=min_learning_rate, last_epoch=-1)
trainable_params = leaf_model.model.parameters()
optimizer = Adam(trainable_params, lr=1e-4, weight_decay=1e-6)
scheduler = None
leaf_model.update_optimizer_scheduler(optimizer, scheduler)

In [None]:
#leaf_model.load_checkpoint("NEW-epochs-5")

In [None]:
# Set the model's name prefix. This is the same value ("NEW") that was passed as
# model_prefix when leaf_model was constructed.
# NOTE(review): presumably kept as its own cell so the prefix can be changed
# after loading a checkpoint — confirm.
leaf_model.model_prefix = "NEW"

In [None]:
# learning_rate = 1e-6
# #min_learning_rate = 1e-6
# weight_decay = 5e-9

# #T_0 = 10

In [None]:
# optimizer = Adam(leaf_model.model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# scheduler = None

In [None]:
# optimizer = SGD(leaf_model.model.parameters(), lr=1e-3, weight_decay=weight_decay)
# scheduler = CyclicLR(optimizer, 1e-6, 1e-3, step_size_up=2*len(train_dataloader), mode="exp_range", gamma=0.9)

In [None]:
# leaf_model.update_optimizer_scheduler(optimizer, scheduler)

In [None]:
# Train for `num_epochs` epochs. After each epoch: run a full validation pass,
# write the validation scalars to a per-epoch TensorBoard log directory, and
# save a checkpoint named after the epoch.
# NOTE(review): the run name says "5-epochs" but only `num_epochs` (2) epochs
# are run — confirm which is intended. The string is kept as-is because it
# determines log-directory and checkpoint names.
run_name = "NEW-5-epochs"
num_epochs = 2

for epoch in range(num_epochs):
    epoch_name = f"{run_name}-{epoch}"

    train_one_epoch(
        leaf_model,
        train_dataloader,
        log_steps=2000,
        val_data_loader=val_dataloader,
        save_at_log_steps=True,
        epoch_name=epoch_name,
    )

    val_loss, raw_val_acc, val_acc = validate_one_epoch(leaf_model, val_dataloader, ensemble_patches=True)
    print(f"Validation after epoch {epoch}: loss {val_loss}, raw acc {raw_val_acc}, acc {val_acc}")

    # x-axis position for this epoch's validation scalars.
    # NOTE(review): with `epoch * len(...)` the validation run after the first
    # epoch is logged at step 0; confirm whether `(epoch + 1) * len(...)` was meant.
    val_step = epoch * len(train_dataloader)

    # Context manager guarantees the event file is flushed and closed even if
    # one of the add_scalar calls raises.
    with SummaryWriter(log_dir=leaf_model.logging_model_dir / epoch_name) as tb_writer:
        tb_writer.add_scalar("loss/val", val_loss, val_step)
        tb_writer.add_scalar("raw_acc/val", raw_val_acc, val_step)
        tb_writer.add_scalar("acc/val", val_acc, val_step)

    leaf_model.save_checkpoint(epoch_name, epoch=epoch)

In [None]:
# Standalone validation pass over the validation loader; the result is kept in `res`.
# NOTE(review): unlike the call inside the training loop, ensemble_patches is not
# passed here, so validate_one_epoch falls back to its default — confirm that is intended.
res = validate_one_epoch(leaf_model, val_dataloader)

In [None]:
# Scratch: create a pretrained "tf_efficientnet_b4_ns" model with 5 output classes via timm.
# NOTE(review): `timm` is not imported in any visible cell, so on a fresh kernel this
# raises NameError unless it is imported elsewhere — add `import timm` to the imports cell.
# NOTE(review): `foo` is rebound to a random tensor in a later cell, after which this
# model is no longer reachable through that name.
foo = timm.create_model(model_name="tf_efficientnet_b4_ns", input_size=(3, 380, 380), num_classes=5, pretrained=True)

In [None]:
# Scratch/debugging: list the attribute names available on the wrapped model object.
dir(leaf_model.model)

In [None]:
# Random tensor of shape (8, 3, 244, 244), moved to leaf_model.device.
# NOTE(review): GetPatches in the transforms cell is called with 224; if that is the
# patch size, 244 here may be a typo — confirm.
foo = torch.randn((8, 3, 244, 244)).to(leaf_model.device)

In [None]:
# Run `foo` through the model with gradient tracking disabled; output is kept in `bar`.
# NOTE(review): .forward() is called directly rather than calling the model object;
# if leaf_model.model is an nn.Module this bypasses any registered hooks — confirm that is fine.
with torch.no_grad():
    bar = leaf_model.model.forward(foo)

In [None]:
# Display the shape of the model output `bar`.
bar.shape

In [None]:
# Random tensor of shape (8, 3, 380, 380), moved to leaf_model.device.
bigfoo = torch.randn((8, 3, 380, 380)).to(leaf_model.device)

In [None]:
# Run the larger (380x380) random batch `bigfoo` through the model with gradient
# tracking disabled. Previously this forwarded `foo` again, so `bigfoo` was never
# used and `bigbar` merely duplicated `bar`.
with torch.no_grad():
    bigbar = leaf_model.model.forward(bigfoo)

In [None]:

# Display the shape of `bigbar`.
bigbar.shape