# Cassava Leaf Disease Classification — Plan

Goal: Train a robust image classifier to produce a high-accuracy submission (≥ bronze) fast, then iterate to medal.

Data available:
- train_images/ (21,367 total; 18,717 listed here implies a subset visible; train.csv maps image_id -> label [0..4])
- test_images/ (~3,000 images)
- train_tfrecords/ and test_tfrecords/ (prepared shards)
- label_num_to_disease_map.json, sample_submission.csv

Metric: Accuracy

Hardware: 1x NVIDIA A10-24Q (~24GB). Use PyTorch + timm with AMP.

Validation protocol:
- 5-fold StratifiedKFold on train.csv (deterministic seed).
- Track OOF accuracy per fold and overall.
- Save OOF logits and test logits for ensembling.

Baseline modeling (fast):
1) Model: timm tf_efficientnet_b3_ns (strong for Cassava), img size 512, pretrained=True.
2) Augmentations:
   - Light baseline: RandomResizedCrop(512), Horizontal/Vertical Flip, ColorJitter, Normalize.
   - Mixup/CutMix (alpha=0.2, p=0.5) and Label Smoothing (0.05) for stability.
3) Optimizer/Scheduler: AdamW, CosineAnnealingLR with warmup; weight_decay ~1e-4.
4) Loss: CrossEntropyLoss with label smoothing (or timm SoftTarget for mixup).
5) Training: 5 folds x 3-5 epochs for smoke baseline to verify pipeline and CV; early stop patience 1.
6) Inference: TTA (e.g., 4-8 flips/resizes) after baseline verified.

Scaling up to medal:
- Train longer (10-15 epochs) once baseline CV is solid.
- Larger/backbones for diversity:
  • tf_efficientnet_b4_ns (size 600), tf_efficientnetv2_s (size 384-448), convnext_base (size 512), seresnext50_32x4d (size 512).
- Use stronger augs: RandAugment/AutoAugment, RandomErasing, Color/Contrast jitter tuned, CutMix/Mixup balanced.
- Multiple seeds; average logits across seeds and models.
- Calibrate/ensemble via weighted average (opt weights on OOF).
- TTA at test (8-10 aug views).

Data pipeline choices:
- Prefer direct JPEG dataloader (torchvision/timm) for flexibility and on-the-fly augs. TFRecords optional; may skip.
- Ensure transforms are fit within folds; no leakage.

Efficiency:
- Use AMP, pin_memory, num_workers, persistent_workers.
- Log elapsed time per epoch/fold; save checkpoints/logits per fold.
- Start with a 10% subsample + 2 folds to sanity-check speed, then scale up.

Milestones:
M1: GPU check + EDA + fold split (OOF plan).
M2: Smoke baseline (EffNet-B3, 512, 5-fold, 3 epochs) + OOF/test logits + submission v1.
M3: Full baseline (10-12 epochs) + TTA + submission v2.
M4: Second backbone (ConvNeXt-B or EffNetV2-S) + blend + submission v3.
M5: Further diversity + seeds + optimized blend + submission v4.

Questions for Experts:
1) Backbones that consistently medal on Cassava now: still tf_efficientnet_b3/b4_ns vs effnetv2_s vs convnext_base vs seresnext50?
2) Image size sweet spot for A10-24GB vs accuracy: 512 vs 600?
3) Best augmentation recipe: RandAugment + Mixup/CutMix + LS? Recommended probabilities/strengths?
4) CV pitfalls/leakage: any particular stratification nuances for Cassava? Any grouping needed (e.g., device/source) or plain StratifiedKFold is fine?
5) TTA: how many and which transforms are most reliable here?
6) Ensembling: typical LB gain from blending 2-3 diverse models on this comp? Weighting tips?

Next steps:
- Confirm GPU works; quick EDA (class counts, image dims).
- Implement deterministic 5-fold split, dataloaders, baseline model/training loop with AMP.
- Train smoke baseline (few epochs) and generate submission v1.
- Request expert review before scaling up.

In [1]:
# Environment/GPU check + quick EDA
import sys, subprocess, os, json, random, time, glob
from pathlib import Path

def pip_install(pkg):
    try:
        __import__(pkg.split('==')[0].replace('-', '_'))
    except Exception:
        print(f"Installing {pkg}...")
        subprocess.run([sys.executable, '-m', 'pip', 'install', pkg, '--quiet'])

# Ensure core deps
pip_install('torch')
pip_install('torchvision')
pip_install('timm')
pip_install('albumentations>=1.4.0')
pip_install('opencv-python-headless')
pip_install('pandas')
pip_install('numpy')
pip_install('Pillow')

import torch, torchvision
import timm, albumentations as A
import cv2, pandas as pd, numpy as np
from PIL import Image, ImageOps

print('GPU Available:', torch.cuda.is_available())
print('GPU Count:', torch.cuda.device_count())
if torch.cuda.is_available():
    print('GPU Name:', torch.cuda.get_device_name(0))
    props = torch.cuda.get_device_properties(0)
    print('GPU Memory (GB):', round(props.total_memory/1024**3, 1))

SEED = 42
def set_seed(seed=SEED):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed);
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
set_seed()

DATA_DIR = Path('.')
train_csv = DATA_DIR/'train.csv'
sample_csv = DATA_DIR/'sample_submission.csv'
train_dir = DATA_DIR/'train_images'
test_dir = DATA_DIR/'test_images'

df = pd.read_csv(train_csv)
print('train.csv shape:', df.shape)
print(df.head())
print('Label value_counts:')
print(df['label'].value_counts().sort_index())

train_files = sorted([p.name for p in train_dir.glob('*.jpg')])
test_files = sorted([p.name for p in test_dir.glob('*.jpg')])
print('Train images on disk:', len(train_files))
print('Test images on disk:', len(test_files))

# Check that all train image_ids exist on disk
missing = [x for x in df['image_id'].tolist() if not (train_dir/x).exists()]
print('Missing train files:', len(missing))
if len(missing) > 0:
    print('First few missing:', missing[:5])

# Sanity check sample submission columns/order
ss = pd.read_csv(sample_csv)
print('sample_submission columns:', list(ss.columns))
print('sample_submission head:')
print(ss.head())

# Probe a few image shapes and EXIF orientation handling
probe = [df['image_id'].iloc[i] for i in np.random.choice(len(df), size=min(5, len(df)), replace=False)]
for fname in probe:
    path = train_dir/fname
    try:
        with Image.open(path) as im:
            im = ImageOps.exif_transpose(im).convert('RGB')
            w, h = im.size
        arr = cv2.cvtColor(np.array(im), cv2.COLOR_RGB2BGR)
        print(f'Probe {fname}: size=({w}x{h}), dtype={arr.dtype}, shape={arr.shape}')
    except Exception as e:
        print(f'Error reading {fname}:', e)

print('ENV/EDA DONE')

Installing torch...


Installing torchvision...




Installing timm...




Installing albumentations>=1.4.0...




Installing opencv-python-headless...




Installing Pillow...




  from .autonotebook import tqdm as notebook_tqdm


GPU Available: True
GPU Count: 1
GPU Name: NVIDIA A10-24Q
GPU Memory (GB): 23.7
train.csv shape: (18721, 2)
         image_id  label
0  1000015157.jpg      0
1  1000201771.jpg      3
2   100042118.jpg      1
3  1000723321.jpg      1
4  1000812911.jpg      3
Label value_counts:
label
0      939
1     1901
2     2091
3    11523
4     2267
Name: count, dtype: int64
Train images on disk: 18721
Test images on disk: 2676
Missing train files: 0
sample_submission columns: ['image_id', 'label']
sample_submission head:
         image_id  label
0  1234294272.jpg      4
1  1234332763.jpg      4
2  1234375577.jpg      4
3  1234555380.jpg      4
4  1234571117.jpg      4
Probe 3939876859.jpg: size=(800x600), dtype=uint8, shape=(600, 800, 3)
Probe 964482896.jpg: size=(800x600), dtype=uint8, shape=(600, 800, 3)
Probe 584853244.jpg: size=(800x600), dtype=uint8, shape=(600, 800, 3)
Probe 306133807.jpg: size=(800x600), dtype=uint8, shape=(600, 800, 3)
Probe 851527428.jpg: size=(800x600), dtype=uint8, shap

In [4]:
# Tiny overfit sanity check (no mixup/cutmix) to validate pipeline
import math
from sklearn.model_selection import StratifiedShuffleSplit
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageOps

NUM_CLASSES = 5
IMG_SIZE = 512
BATCH_SIZE = 16
NUM_WORKERS = min(8, os.cpu_count() or 2)

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Use robust, unambiguous resize+crop to avoid RandomResizedCrop schema issues
train_tfms = A.Compose([
    A.SmallestMaxSize(max_size=IMG_SIZE, interpolation=cv2.INTER_LINEAR),
    A.RandomCrop(height=IMG_SIZE, width=IMG_SIZE),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.ColorJitter(0.2,0.2,0.2,0.1,p=0.5),
    A.RandomBrightnessContrast(0.2,0.2,p=0.5),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

valid_tfms = A.Compose([
    A.Resize(height=IMG_SIZE, width=IMG_SIZE, interpolation=cv2.INTER_LINEAR),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

class CassavaDS(Dataset):
    def __init__(self, df, img_dir, transforms=None):
        self.df = df.reset_index(drop=True)
        self.img_dir = Path(img_dir)
        self.transforms = transforms
    def __len__(self):
        return len(self.df)
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        path = self.img_dir / row['image_id']
        with Image.open(path) as im:
            im = ImageOps.exif_transpose(im).convert('RGB')
            img = np.array(im)
        if self.transforms is not None:
            img = self.transforms(image=img)['image']
        # HWC->CHW
        img = torch.from_numpy(img.transpose(2,0,1)).float()
        label = int(row['label'])
        return img, torch.tensor(label, dtype=torch.long)

# Stratified tiny subset
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.25, random_state=SEED)
tiny_size = 200  # small to overfit quickly
df_tiny = (df.groupby('label', group_keys=False)
            .apply(lambda x: x.sample(max(1, math.ceil(tiny_size*len(x)/len(df))), random_state=SEED))
            .reset_index(drop=True))
train_idx, val_idx = next(sss.split(df_tiny['image_id'], df_tiny['label']))
df_train_tiny = df_tiny.iloc[train_idx].reset_index(drop=True)
df_valid_tiny = df_tiny.iloc[val_idx].reset_index(drop=True)
print('Tiny train/val sizes:', len(df_train_tiny), len(df_valid_tiny))

ds_tr = CassavaDS(df_train_tiny, train_dir, transforms=train_tfms)
ds_va = CassavaDS(df_valid_tiny, train_dir, transforms=valid_tfms)
dl_tr = DataLoader(ds_tr, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True, persistent_workers=True)
dl_va = DataLoader(ds_va, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True, persistent_workers=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = timm.create_model('tf_efficientnet_b3_ns', pretrained=True, num_classes=NUM_CLASSES)
model.to(device)
criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

# Cosine schedule with warmup (3 epochs warmup, total 8)
EPOCHS = 8
warmup_epochs = max(1, int(0.2 * EPOCHS))
total_steps = EPOCHS * max(1, len(dl_tr))
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

def cosine_warmup(step, total, warmup):
    if step < warmup:
        return step / max(1, warmup)
    progress = (step - warmup) / max(1, total - warmup)
    return 0.5 * (1 + math.cos(math.pi * progress))

global_step = 0
best_acc = 0.0
start_time = time.time()
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    t0 = time.time()
    for it, (imgs, labels) in enumerate(dl_tr):
        lr_scale = cosine_warmup(global_step, total_steps, warmup_epochs*len(dl_tr))
        for pg in optimizer.param_groups:
            pg['lr'] = 1e-3 * lr_scale
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            logits = model(imgs)
            loss = criterion(logits, labels)
        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item() * imgs.size(0)
        preds = logits.argmax(1)
        correct += (preds == labels).sum().item()
        total += imgs.size(0)
        global_step += 1
        if (it+1) % 10 == 0:
            elapsed = time.time()-t0
            print(f'Epoch {epoch+1}/{EPOCHS} | it {it+1}/{len(dl_tr)} | loss {(running_loss/total):.4f} | acc {(correct/total):.4f} | {elapsed:.1f}s', flush=True)
    train_loss = running_loss/ max(1,total)
    train_acc = correct/ max(1,total)
    # Validate
    model.eval()
    v_correct = 0
    v_total = 0
    with torch.no_grad():
        for imgs, labels in dl_va:
            imgs = imgs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            logits = model(imgs)
            preds = logits.argmax(1)
            v_correct += (preds == labels).sum().item()
            v_total += imgs.size(0)
    v_acc = v_correct / max(1, v_total)
    epoch_time = time.time()-t0
    print(f'Epoch {epoch+1} done | train_loss {train_loss:.4f} | train_acc {train_acc:.4f} | val_acc {v_acc:.4f} | epoch_time {epoch_time:.1f}s | elapsed {(time.time()-start_time)/60:.1f}m', flush=True)
    if v_acc > best_acc:
        best_acc = v_acc
        torch.save({'model': model.state_dict(), 'acc': best_acc}, 'tiny_overfit_b3_best.pth')
print('Tiny overfit best val_acc:', best_acc)
print('Sanity check complete.')

  .apply(lambda x: x.sample(max(1, math.ceil(tiny_size*len(x)/len(df))), random_state=SEED))
  model = create_fn(


Tiny train/val sizes: 153 51


  scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


Epoch 1 done | train_loss 2.2376 | train_acc 0.3264 | val_acc 0.2941 | epoch_time 12.6s | elapsed 0.2m


Epoch 2 done | train_loss 1.4088 | train_acc 0.7083 | val_acc 0.6863 | epoch_time 2.1s | elapsed 0.2m


Epoch 3 done | train_loss 1.0875 | train_acc 0.7708 | val_acc 0.7059 | epoch_time 2.0s | elapsed 0.3m


Epoch 4 done | train_loss 0.8370 | train_acc 0.8264 | val_acc 0.5490 | epoch_time 2.1s | elapsed 0.3m


Epoch 5 done | train_loss 0.6201 | train_acc 0.9514 | val_acc 0.6471 | epoch_time 2.1s | elapsed 0.4m


Epoch 6 done | train_loss 0.5628 | train_acc 0.9444 | val_acc 0.6471 | epoch_time 2.1s | elapsed 0.4m


Epoch 7 done | train_loss 0.5163 | train_acc 0.9931 | val_acc 0.6863 | epoch_time 2.1s | elapsed 0.4m


Epoch 8 done | train_loss 0.4849 | train_acc 0.9931 | val_acc 0.6863 | epoch_time 2.1s | elapsed 0.5m


Tiny overfit best val_acc: 0.7058823529411765
Sanity check complete.


In [17]:
# 1-fold sanity with timm-native transforms (Option A) + CE(LS) + EMA, explicit ImageNet mean/std
import gc, math, time, os, warnings
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import DataLoader, Dataset
from PIL import Image, ImageOps
import numpy as np
import pandas as pd
import torch, timm
from timm.utils import ModelEmaV3

warnings.filterwarnings('ignore', category=UserWarning)
torch.backends.cudnn.benchmark = True
try:
    torch.set_float32_matmul_precision('high')
except Exception:
    pass

FOLDS = 2
MAX_FOLDS = 1
EPOCHS = 3
MODEL_NAME = 'tf_efficientnet_b3_ns'
BATCH_SIZE = 24
NUM_WORKERS = min(8, os.cpu_count() or 2)
BASE_LR = 1e-4
WD = 1e-4
NUM_CLASSES = 5
IMG_SIZE = 512
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class CassavaDS(Dataset):
    def __init__(self, df, img_dir, transforms=None, is_test=False):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.transforms = transforms
        self.is_test = is_test
    def __len__(self):
        return len(self.df)
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        path = os.path.join(self.img_dir, row['image_id'])
        try:
            with Image.open(path) as im:
                im = ImageOps.exif_transpose(im).convert('RGB')
        except Exception:
            ridx = np.random.randint(0, len(self.df))
            rpath = os.path.join(self.img_dir, self.df.iloc[ridx]['image_id'])
            with Image.open(rpath) as im:
                im = ImageOps.exif_transpose(im).convert('RGB')
        img = im
        if self.transforms is not None:
            img = self.transforms(img)
        if self.is_test:
            return img, row['image_id']
        else:
            label = int(row['label'])
            return img, torch.tensor(label, dtype=torch.long)

def get_model_and_transforms():
    model = timm.create_model(MODEL_NAME, pretrained=True, num_classes=NUM_CLASSES).to(device)
    cfg = timm.data.resolve_data_config({}, model=model)
    cfg['input_size'] = (3, IMG_SIZE, IMG_SIZE)
    # Critical override
    cfg['mean'] = IMAGENET_MEAN
    cfg['std'] = IMAGENET_STD
    crop_pct = float(cfg.get('crop_pct', 0.875))
    print('cfg:', {'input_size': cfg['input_size'], 'crop_pct': crop_pct, 'interpolation': cfg.get('interpolation','bicubic'), 'mean': cfg['mean'], 'std': cfg['std']})
    train_tfms = timm.data.create_transform(
        is_training=True,
        **cfg,
        auto_augment='rand-m9-mstd0.5-inc1',
        re_prob=0.25,
        re_mode='pixel'
    )
    valid_tfms = timm.data.create_transform(is_training=False, **cfg)
    return model, train_tfms, valid_tfms

def train_one_epoch(model, ema, loader, optimizer, criterion, scaler, epoch, total_steps, global_step):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    t0 = time.time()
    warmup_steps = max(1, int(0.05 * total_steps))
    for it, (imgs, labels) in enumerate(loader):
        step = global_step + it
        if step < warmup_steps:
            lr_scale = step / warmup_steps
        else:
            progress = (step - warmup_steps) / max(1, (total_steps - warmup_steps))
            lr_scale = 0.5 * (1 + math.cos(math.pi * progress))
        for pg in optimizer.param_groups:
            pg['lr'] = BASE_LR * lr_scale

        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
            logits = model(imgs)
            loss = criterion(logits, labels)
        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        if ema is not None:
            ema.update(model)

        running_loss += loss.item() * imgs.size(0)
        preds = logits.argmax(1)
        correct += (preds == labels).sum().item()
        total += imgs.size(0)

        if (it+1) % 50 == 0:
            print(f'  it {it+1}/{len(loader)} | loss {running_loss/max(1,total):.4f} | acc {correct/max(1,total):.4f} | elapsed {time.time()-t0:.1f}s', flush=True)
    tr_loss = running_loss/max(1,total)
    tr_acc = (correct/max(1,total))
    return tr_loss, tr_acc, global_step + len(loader)

def validate(model_eval, loader):
    model_eval.eval()
    correct = 0
    total = 0
    logits_all = []
    with torch.no_grad():
        for imgs, labels in loader:
            imgs = imgs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            with torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
                logits = model_eval(imgs)
            logits_all.append(logits.float().cpu().numpy())
            preds = logits.argmax(1)
            correct += (preds == labels).sum().item()
            total += imgs.size(0)
    return correct/max(1,total), np.concatenate(logits_all, axis=0)

def infer_test(model_eval, loader):
    model_eval.eval()
    logits_all = []
    with torch.no_grad():
        for imgs, _ids in loader:
            imgs = imgs.to(device, non_blocking=True)
            with torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
                logits = model_eval(imgs)
            logits_all.append(logits.float().cpu().numpy())
    return np.concatenate(logits_all, axis=0)

# Prepare dataframes
test_files = pd.DataFrame({'image_id': sorted([p.name for p in test_dir.glob('*.jpg')])})
skf = StratifiedKFold(n_splits=FOLDS, shuffle=True, random_state=SEED)
df_shuf = df.sample(frac=1.0, random_state=SEED).reset_index(drop=True)

oof_preds = np.zeros((len(df_shuf), NUM_CLASSES), dtype=np.float32)
test_logits_accum = np.zeros((len(test_files), NUM_CLASSES), dtype=np.float32)

fold_idx = 0
t_start_all = time.time()
for tr_idx, va_idx in skf.split(df_shuf['image_id'], df_shuf['label']):
    t_fold = time.time()
    fold_idx += 1
    print(f'===== Fold {fold_idx}/{FOLDS} | train {len(tr_idx)} | valid {len(va_idx)} =====', flush=True)
    df_tr = df_shuf.iloc[tr_idx].reset_index(drop=True)
    df_va = df_shuf.iloc[va_idx].reset_index(drop=True)

    model, train_tfms, valid_tfms = get_model_and_transforms()

    ds_tr = CassavaDS(df_tr, str(train_dir), transforms=train_tfms)
    ds_va = CassavaDS(df_va, str(train_dir), transforms=valid_tfms)
    dl_tr = DataLoader(ds_tr, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True, persistent_workers=False)
    dl_va = DataLoader(ds_va, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True, persistent_workers=False)

    test_ds = CassavaDS(test_files, str(test_dir), transforms=valid_tfms, is_test=True)
    test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True, persistent_workers=False)

    maj_class = int(df_shuf['label'].mode()[0])
    maj_acc = float((df_va['label'] == maj_class).mean())
    print(f'Majority class={maj_class} | majority baseline acc on val={maj_acc:.4f}', flush=True)

    ema = ModelEmaV3(model, decay=0.999)
    criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.05)
    optimizer = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=WD)
    scaler = torch.amp.GradScaler('cuda', enabled=torch.cuda.is_available())

    # Diagnostics: check val batch channel stats post-transform
    with torch.no_grad():
        imgs_chk, labels_chk = next(iter(dl_va))
        m = imgs_chk.mean(dim=[0,2,3])
        s = imgs_chk.std(dim=[0,2,3])
        print('val batch mean:', [round(float(x),4) for x in m])
        print('val batch std :', [round(float(x),4) for x in s])

    best_acc = 0.0
    best_state = None
    total_steps = EPOCHS * len(dl_tr)
    global_step = 0

    for epoch in range(1, EPOCHS+1):
        tr_loss, tr_acc, global_step = train_one_epoch(model, ema, dl_tr, optimizer, criterion, scaler, epoch, total_steps, global_step)
        eval_model = ema.module if ema is not None else model
        va_acc, _ = validate(eval_model, dl_va)
        print(f'Fold {fold_idx} | Epoch {epoch}/{EPOCHS} | tr_loss {tr_loss:.4f} tr_acc {tr_acc:.4f} | va_acc {va_acc:.4f} | elapsed_fold {(time.time()-t_fold)/60:.1f}m', flush=True)
        if va_acc > best_acc:
            best_acc = va_acc
            best_state = {k: v.cpu() for k, v in eval_model.state_dict().items()}

    eval_model = ema.module if ema is not None else model
    eval_model.eval()
    with torch.no_grad():
        for imgs, labels in dl_va:
            imgs = imgs.to(device)
            logits = eval_model(imgs)
            preds = logits.argmax(1).cpu().numpy().tolist()
            labs = labels.numpy().tolist()
            print('val preds[:16]:', preds[:16])
            print('val labs [:16]:', labs[:16])
            break

    if best_state is not None:
        eval_model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=NUM_CLASSES).to(device)
        eval_model.load_state_dict(best_state, strict=True)
    else:
        eval_model = ema.module if ema is not None else model
    va_acc, va_logits = validate(eval_model, dl_va)
    oof_preds[va_idx] = va_logits

    t_logits = infer_test(eval_model, test_dl)
    test_logits_accum += t_logits

    del model, ema, optimizer, scaler, ds_tr, ds_va, dl_tr, dl_va, eval_model, test_ds, test_dl
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print(f'Fold {fold_idx} done | best_va_acc {best_acc:.4f} | fold_time {(time.time()-t_fold)/60:.1f}m', flush=True)

    if fold_idx >= MAX_FOLDS:
        break

print(f'All folds done in {(time.time()-t_start_all)/60:.1f}m')

np.save('oof_logits.npy', oof_preds)
np.save('test_logits.npy', test_logits_accum)

test_probs = test_logits_accum / max(1, MAX_FOLDS)
test_pred = test_probs.argmax(1)
sub = pd.DataFrame({'image_id': test_files['image_id'], 'label': test_pred.astype(int)})
sub.to_csv('submission.csv', index=False)
print('Wrote submission.csv')
sub.head()

===== Fold 1/2 | train 9360 | valid 9361 =====


cfg: {'input_size': (3, 512, 512), 'crop_pct': 0.904, 'interpolation': 'bicubic', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225)}
Majority class=3 | majority baseline acc on val=0.6154


val batch mean: [-0.1807, 0.3099, -0.4598]
val batch std : [1.05, 1.1089, 0.9553]


  it 50/390 | loss 1.7315 | acc 0.4433 | elapsed 13.0s


  it 100/390 | loss 1.4343 | acc 0.5437 | elapsed 25.2s


  it 150/390 | loss 1.3034 | acc 0.5872 | elapsed 37.8s


  it 200/390 | loss 1.2173 | acc 0.6179 | elapsed 50.1s


  it 250/390 | loss 1.1572 | acc 0.6362 | elapsed 62.4s


  it 300/390 | loss 1.1096 | acc 0.6517 | elapsed 74.7s


  it 350/390 | loss 1.0790 | acc 0.6632 | elapsed 87.0s


Fold 1 | Epoch 1/3 | tr_loss 1.0549 tr_acc 0.6731 | va_acc 0.3352 | elapsed_fold 2.1m


  it 50/390 | loss 0.8344 | acc 0.7392 | elapsed 13.2s


  it 100/390 | loss 0.8135 | acc 0.7483 | elapsed 25.5s


  it 150/390 | loss 0.7870 | acc 0.7619 | elapsed 37.9s


  it 200/390 | loss 0.7708 | acc 0.7677 | elapsed 50.3s


  it 250/390 | loss 0.7657 | acc 0.7702 | elapsed 62.7s


  it 300/390 | loss 0.7626 | acc 0.7725 | elapsed 75.1s


  it 350/390 | loss 0.7605 | acc 0.7732 | elapsed 87.9s


Fold 1 | Epoch 2/3 | tr_loss 0.7598 tr_acc 0.7738 | va_acc 0.6091 | elapsed_fold 4.2m


  it 50/390 | loss 0.7062 | acc 0.7975 | elapsed 13.0s


  it 100/390 | loss 0.7093 | acc 0.8004 | elapsed 25.5s


  it 150/390 | loss 0.7106 | acc 0.7992 | elapsed 37.9s


  it 200/390 | loss 0.7010 | acc 0.8021 | elapsed 50.4s


  it 250/390 | loss 0.6934 | acc 0.8060 | elapsed 62.7s


  it 300/390 | loss 0.6940 | acc 0.8040 | elapsed 75.1s


  it 350/390 | loss 0.6902 | acc 0.8067 | elapsed 87.6s


Fold 1 | Epoch 3/3 | tr_loss 0.6917 tr_acc 0.8062 | va_acc 0.7744 | elapsed_fold 6.2m


val preds[:16]: [0, 3, 1, 2, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 2]
val labs [:16]: [0, 3, 1, 2, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 4]


Fold 1 done | best_va_acc 0.7744 | fold_time 6.9m


All folds done in 6.9m
Wrote submission.csv


Unnamed: 0,image_id,label
0,1234294272.jpg,3
1,1234332763.jpg,3
2,1234375577.jpg,2
3,1234555380.jpg,2
4,1234571117.jpg,3


In [18]:
# Full 5-fold training: tf_efficientnet_b3_ns @512 with Mixup/CutMix + EMA + cosine schedule (12 epochs)
import os, gc, math, time, warnings, numpy as np, pandas as pd, torch, timm
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageOps
from timm.utils import ModelEmaV3
from timm.data import Mixup
from timm.loss import SoftTargetCrossEntropy

warnings.filterwarnings('ignore', category=UserWarning)
torch.backends.cudnn.benchmark = True
try:
    torch.set_float32_matmul_precision('high')
except Exception:
    pass

# Hyperparams per expert advice
MODEL_NAME = 'tf_efficientnet_b3_ns'
IMG_SIZE = 512
NUM_CLASSES = 5
FOLDS = 5
EPOCHS = 12
BATCH_SIZE = 24
BASE_LR = 2e-4
WD = 1e-4
EMA_DECAY = 0.9998
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)
NUM_WORKERS = min(8, os.cpu_count() or 2)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class CassavaDS(Dataset):
    def __init__(self, df, img_dir, transforms=None, is_test=False):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.transforms = transforms
        self.is_test = is_test
    def __len__(self):
        return len(self.df)
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        path = os.path.join(self.img_dir, row['image_id'])
        with Image.open(path) as im:
            im = ImageOps.exif_transpose(im).convert('RGB')
        img = im if self.transforms is None else self.transforms(im)
        if self.is_test:
            return img, row['image_id']
        return img, torch.tensor(int(row['label']), dtype=torch.long)

def get_model_and_transforms():
    model = timm.create_model(MODEL_NAME, pretrained=True, num_classes=NUM_CLASSES).to(device)
    cfg = timm.data.resolve_data_config({}, model=model)
    cfg['input_size'] = (3, IMG_SIZE, IMG_SIZE)
    cfg['mean'] = IMAGENET_MEAN
    cfg['std'] = IMAGENET_STD
    train_tfms = timm.data.create_transform(is_training=True, **cfg, auto_augment='rand-m9-mstd0.5-inc1', re_prob=0.25, re_mode='pixel')
    valid_tfms = timm.data.create_transform(is_training=False, **cfg)
    print('cfg:', {'input_size': cfg['input_size'], 'crop_pct': float(cfg.get('crop_pct', 0.875)), 'interpolation': cfg.get('interpolation','bicubic'), 'mean': cfg['mean'], 'std': cfg['std']})
    return model, train_tfms, valid_tfms

def validate(model_eval, loader):
    model_eval.eval()
    correct, total = 0, 0
    logits_all = []
    with torch.no_grad():
        for imgs, labels in loader:
            imgs = imgs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            with torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
                logits = model_eval(imgs)
            logits_all.append(logits.float().cpu().numpy())
            preds = logits.argmax(1)
            correct += (preds == labels).sum().item()
            total += imgs.size(0)
    return correct / max(1, total), np.concatenate(logits_all, axis=0)

def infer_test(model_eval, loader):
    model_eval.eval()
    logits_all = []
    with torch.no_grad():
        for imgs, _ids in loader:
            imgs = imgs.to(device, non_blocking=True)
            with torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
                logits = model_eval(imgs)
            logits_all.append(logits.float().cpu().numpy())
    return np.concatenate(logits_all, axis=0)

def run_training():
    skf = StratifiedKFold(n_splits=FOLDS, shuffle=True, random_state=SEED)
    df_shuf = df.sample(frac=1.0, random_state=SEED).reset_index(drop=True)
    test_df = pd.DataFrame({'image_id': sorted([p.name for p in test_dir.glob('*.jpg')])})
    oof_logits = np.zeros((len(df_shuf), NUM_CLASSES), dtype=np.float32)
    test_logits_sum = np.zeros((len(test_df), NUM_CLASSES), dtype=np.float32)

    fold = 0
    start_all = time.time()
    for tr_idx, va_idx in skf.split(df_shuf['image_id'], df_shuf['label']):
        fold += 1
        t_fold = time.time()
        print(f'===== Fold {fold}/{FOLDS} | train {len(tr_idx)} | valid {len(va_idx)} =====', flush=True)
        df_tr = df_shuf.iloc[tr_idx].reset_index(drop=True)
        df_va = df_shuf.iloc[va_idx].reset_index(drop=True)

        model, train_tfms, valid_tfms = get_model_and_transforms()
        ds_tr = CassavaDS(df_tr, str(train_dir), transforms=train_tfms)
        ds_va = CassavaDS(df_va, str(train_dir), transforms=valid_tfms)
        dl_tr = DataLoader(ds_tr, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True, persistent_workers=False)
        dl_va = DataLoader(ds_va, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True, persistent_workers=False)
        test_ds = CassavaDS(test_df, str(test_dir), transforms=valid_tfms, is_test=True)
        test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True, persistent_workers=False)

        # Mixup/CutMix + SoftTarget CE
        mixup_fn = Mixup(mixup_alpha=0.4, cutmix_alpha=1.0, prob=0.5, switch_prob=0.5, mode='batch', label_smoothing=0.0, num_classes=NUM_CLASSES)
        criterion = SoftTargetCrossEntropy()
        optimizer = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=WD)
        scaler = torch.amp.GradScaler('cuda', enabled=torch.cuda.is_available())
        ema = ModelEmaV3(model, decay=EMA_DECAY)

        total_steps = EPOCHS * len(dl_tr)
        warmup_steps = len(dl_tr)  # 1 epoch warmup
        global_step = 0
        best_acc = 0.0
        best_state = None

        for epoch in range(1, EPOCHS+1):
            model.train()
            running_loss = 0.0; correct = 0; total = 0; t0 = time.time()
            # Optionally turn off mixup last 2 epochs to sharpen
            use_mix = epoch <= EPOCHS - 2
            for it, (imgs, labels) in enumerate(dl_tr):
                step = global_step + it
                if step < warmup_steps:
                    lr_scale = step / max(1, warmup_steps)
                else:
                    progress = (step - warmup_steps) / max(1, (total_steps - warmup_steps))
                    lr_scale = 0.5 * (1 + math.cos(math.pi * progress))
                for pg in optimizer.param_groups:
                    pg['lr'] = BASE_LR * lr_scale

                imgs = imgs.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)
                if use_mix and mixup_fn is not None:
                    imgs, targets = mixup_fn(imgs, labels)
                else:
                    # one-hot targets for SoftTarget CE when mixup disabled
                    targets = torch.zeros((labels.size(0), NUM_CLASSES), device=labels.device)
                    targets.scatter_(1, labels.unsqueeze(1), 1.0)

                optimizer.zero_grad(set_to_none=True)
                with torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
                    logits = model(imgs)
                    loss = criterion(logits, targets)
                scaler.scale(loss).backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                if ema is not None:
                    ema.update(model)

                running_loss += loss.item() * imgs.size(0)
                with torch.no_grad():
                    preds = logits.argmax(1)
                    correct += (preds == labels).sum().item()
                    total += imgs.size(0)
                if (it+1) % 50 == 0:
                    print(f'  it {it+1}/{len(dl_tr)} | loss {running_loss/max(1,total):.4f} | acc {correct/max(1,total):.4f} | elapsed {time.time()-t0:.1f}s', flush=True)

            global_step += len(dl_tr)
            eval_model = ema.module if ema is not None else model
            va_acc, _ = validate(eval_model, dl_va)
            print(f'Fold {fold} | Epoch {epoch}/{EPOCHS} | tr_loss {running_loss/max(1,total):.4f} tr_acc {correct/max(1,total):.4f} | va_acc {va_acc:.4f} | elapsed_fold {(time.time()-t_fold)/60:.1f}m', flush=True)
            if va_acc > best_acc:
                best_acc = va_acc
                best_state = {k: v.cpu() for k, v in eval_model.state_dict().items()}

        # Load best EMA state
        if best_state is not None:
            best_model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=NUM_CLASSES).to(device)
            best_model.load_state_dict(best_state, strict=True)
        else:
            best_model = ema.module if ema is not None else model

        # OOF fill
        va_acc, va_logits = validate(best_model, dl_va)
        oof_logits[va_idx] = va_logits

        # Test inference accumulate
        t_logits = infer_test(best_model, test_dl)
        test_logits_sum += t_logits

        # Save best checkpoint
        ckpt_path = f'b3_fold{fold}_best.pth'
        torch.save({'state_dict': best_state, 'best_va_acc': best_acc}, ckpt_path)
        print(f'Fold {fold} done | best_va_acc {best_acc:.4f} | ckpt {ckpt_path} | fold_time {(time.time()-t_fold)/60:.1f}m', flush=True)

        # Cleanup
        del model, ema, optimizer, scaler, ds_tr, ds_va, dl_tr, dl_va, best_model, test_ds, test_dl
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f'All folds done in {(time.time()-start_all)/60:.1f}m')

    # Compute OOF accuracy
    oof_pred = oof_logits.argmax(1)
    oof_acc = (oof_pred == df_shuf['label'].values).mean()
    print(f'OOF accuracy: {oof_acc:.5f}')

    # Save artifacts
    np.save('oof_logits_b3.npy', oof_logits)
    np.save('test_logits_b3.npy', test_logits_sum)

    # Build a quick submission from averaged test logits (no TTA yet, 5-fold avg)
    test_pred = test_logits_sum.argmax(1)
    sub = pd.DataFrame({'image_id': test_df['image_id'], 'label': test_pred.astype(int)})
    sub.to_csv('submission_b3.csv', index=False)
    print('Wrote submission_b3.csv')

run_training()

===== Fold 1/5 | train 14976 | valid 3745 =====


cfg: {'input_size': (3, 512, 512), 'crop_pct': 0.904, 'interpolation': 'bicubic', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225)}


  it 50/624 | loss 1.9825 | acc 0.2775 | elapsed 12.9s


  it 100/624 | loss 1.6974 | acc 0.4058 | elapsed 25.1s


  it 150/624 | loss 1.5368 | acc 0.4661 | elapsed 37.4s


  it 200/624 | loss 1.4332 | acc 0.5008 | elapsed 49.6s


  it 250/624 | loss 1.3470 | acc 0.5325 | elapsed 61.9s


  it 300/624 | loss 1.2734 | acc 0.5542 | elapsed 74.2s


  it 350/624 | loss 1.2229 | acc 0.5711 | elapsed 86.5s


  it 400/624 | loss 1.1794 | acc 0.5822 | elapsed 98.9s


  it 450/624 | loss 1.1561 | acc 0.5883 | elapsed 111.2s


  it 500/624 | loss 1.1229 | acc 0.5985 | elapsed 123.8s


  it 550/624 | loss 1.0984 | acc 0.6067 | elapsed 136.2s


  it 600/624 | loss 1.0763 | acc 0.6125 | elapsed 148.5s


Fold 1 | Epoch 1/12 | tr_loss 1.0701 tr_acc 0.6157 | va_acc 0.1207 | elapsed_fold 2.8m


  it 50/624 | loss 0.8218 | acc 0.6800 | elapsed 13.1s


  it 100/624 | loss 0.7682 | acc 0.6950 | elapsed 25.4s


  it 150/624 | loss 0.7584 | acc 0.7081 | elapsed 37.8s


  it 200/624 | loss 0.7671 | acc 0.7013 | elapsed 50.1s


  it 250/624 | loss 0.7683 | acc 0.6987 | elapsed 62.4s


  it 300/624 | loss 0.7626 | acc 0.7032 | elapsed 74.8s


  it 350/624 | loss 0.7634 | acc 0.7031 | elapsed 87.2s


  it 400/624 | loss 0.7560 | acc 0.7083 | elapsed 99.5s


  it 450/624 | loss 0.7443 | acc 0.7137 | elapsed 111.9s


  it 500/624 | loss 0.7431 | acc 0.7129 | elapsed 124.2s


  it 550/624 | loss 0.7421 | acc 0.7137 | elapsed 136.9s


  it 600/624 | loss 0.7379 | acc 0.7156 | elapsed 149.3s


Fold 1 | Epoch 2/12 | tr_loss 0.7356 tr_acc 0.7169 | va_acc 0.1557 | elapsed_fold 5.5m


  it 50/624 | loss 0.6773 | acc 0.7317 | elapsed 13.2s


  it 100/624 | loss 0.6761 | acc 0.7479 | elapsed 25.6s


  it 150/624 | loss 0.7047 | acc 0.7364 | elapsed 38.0s


  it 200/624 | loss 0.6996 | acc 0.7419 | elapsed 50.5s


  it 250/624 | loss 0.7054 | acc 0.7423 | elapsed 62.9s


  it 300/624 | loss 0.7038 | acc 0.7431 | elapsed 75.3s


  it 350/624 | loss 0.6973 | acc 0.7399 | elapsed 87.7s


  it 400/624 | loss 0.6785 | acc 0.7471 | elapsed 100.1s


  it 450/624 | loss 0.6727 | acc 0.7476 | elapsed 112.6s


  it 500/624 | loss 0.6800 | acc 0.7468 | elapsed 125.0s


  it 550/624 | loss 0.6824 | acc 0.7451 | elapsed 137.7s


  it 600/624 | loss 0.6833 | acc 0.7446 | elapsed 150.1s


Fold 1 | Epoch 3/12 | tr_loss 0.6846 tr_acc 0.7448 | va_acc 0.2798 | elapsed_fold 8.3m


  it 50/624 | loss 0.5573 | acc 0.8108 | elapsed 13.3s


  it 100/624 | loss 0.5901 | acc 0.7879 | elapsed 25.8s


  it 150/624 | loss 0.6088 | acc 0.7756 | elapsed 38.3s


  it 200/624 | loss 0.6091 | acc 0.7727 | elapsed 50.7s


  it 250/624 | loss 0.6183 | acc 0.7733 | elapsed 63.2s


  it 300/624 | loss 0.6145 | acc 0.7750 | elapsed 75.7s


  it 350/624 | loss 0.6189 | acc 0.7718 | elapsed 88.2s


  it 400/624 | loss 0.6223 | acc 0.7697 | elapsed 100.7s


  it 450/624 | loss 0.6274 | acc 0.7681 | elapsed 113.2s


  it 500/624 | loss 0.6289 | acc 0.7641 | elapsed 125.7s


  it 550/624 | loss 0.6248 | acc 0.7677 | elapsed 138.5s


  it 600/624 | loss 0.6269 | acc 0.7662 | elapsed 151.0s


Fold 1 | Epoch 4/12 | tr_loss 0.6256 tr_acc 0.7668 | va_acc 0.5223 | elapsed_fold 11.1m


  it 50/624 | loss 0.6078 | acc 0.7967 | elapsed 13.3s


  it 100/624 | loss 0.5876 | acc 0.8037 | elapsed 25.8s


  it 150/624 | loss 0.5923 | acc 0.7894 | elapsed 38.2s


  it 200/624 | loss 0.5796 | acc 0.7950 | elapsed 50.8s


  it 250/624 | loss 0.5844 | acc 0.7875 | elapsed 63.3s


  it 300/624 | loss 0.5783 | acc 0.7899 | elapsed 75.8s


  it 350/624 | loss 0.5771 | acc 0.7860 | elapsed 88.3s


  it 400/624 | loss 0.5850 | acc 0.7861 | elapsed 100.8s


  it 450/624 | loss 0.5867 | acc 0.7844 | elapsed 113.3s


  it 500/624 | loss 0.5905 | acc 0.7816 | elapsed 125.7s


  it 550/624 | loss 0.5900 | acc 0.7790 | elapsed 138.5s


  it 600/624 | loss 0.5947 | acc 0.7773 | elapsed 151.0s


Fold 1 | Epoch 5/12 | tr_loss 0.5938 tr_acc 0.7780 | va_acc 0.7148 | elapsed_fold 13.9m


  it 50/624 | loss 0.6205 | acc 0.7725 | elapsed 13.3s


  it 100/624 | loss 0.5973 | acc 0.7633 | elapsed 25.8s


  it 150/624 | loss 0.5787 | acc 0.7794 | elapsed 38.2s


  it 200/624 | loss 0.5712 | acc 0.7833 | elapsed 50.8s


  it 250/624 | loss 0.5767 | acc 0.7817 | elapsed 63.3s


  it 300/624 | loss 0.5748 | acc 0.7817 | elapsed 75.8s


  it 350/624 | loss 0.5752 | acc 0.7844 | elapsed 88.3s


  it 400/624 | loss 0.5718 | acc 0.7864 | elapsed 100.7s


  it 450/624 | loss 0.5802 | acc 0.7826 | elapsed 113.3s


  it 500/624 | loss 0.5824 | acc 0.7808 | elapsed 125.8s


  it 550/624 | loss 0.5830 | acc 0.7773 | elapsed 138.3s


  it 600/624 | loss 0.5818 | acc 0.7801 | elapsed 151.1s


Fold 1 | Epoch 6/12 | tr_loss 0.5814 tr_acc 0.7817 | va_acc 0.7995 | elapsed_fold 16.8m


  it 50/624 | loss 0.5678 | acc 0.8017 | elapsed 13.3s


  it 100/624 | loss 0.5747 | acc 0.8033 | elapsed 25.7s


  it 150/624 | loss 0.5650 | acc 0.7925 | elapsed 38.2s


  it 200/624 | loss 0.5484 | acc 0.7954 | elapsed 50.6s


  it 250/624 | loss 0.5625 | acc 0.7838 | elapsed 63.1s


  it 300/624 | loss 0.5565 | acc 0.7828 | elapsed 75.5s


  it 350/624 | loss 0.5696 | acc 0.7770 | elapsed 87.9s


  it 400/624 | loss 0.5714 | acc 0.7742 | elapsed 100.4s


  it 450/624 | loss 0.5757 | acc 0.7727 | elapsed 112.8s


  it 500/624 | loss 0.5685 | acc 0.7752 | elapsed 125.3s


  it 550/624 | loss 0.5656 | acc 0.7767 | elapsed 137.7s


  it 600/624 | loss 0.5630 | acc 0.7758 | elapsed 150.4s


Fold 1 | Epoch 7/12 | tr_loss 0.5628 tr_acc 0.7751 | va_acc 0.8387 | elapsed_fold 19.6m


  it 50/624 | loss 0.4903 | acc 0.8150 | elapsed 13.2s


  it 100/624 | loss 0.5012 | acc 0.8179 | elapsed 25.7s


  it 150/624 | loss 0.5234 | acc 0.8106 | elapsed 38.1s


  it 200/624 | loss 0.5170 | acc 0.8056 | elapsed 50.6s


  it 250/624 | loss 0.5142 | acc 0.8040 | elapsed 63.1s


  it 300/624 | loss 0.5230 | acc 0.8033 | elapsed 75.5s


  it 350/624 | loss 0.5279 | acc 0.7985 | elapsed 88.0s


  it 400/624 | loss 0.5287 | acc 0.8015 | elapsed 100.5s


  it 450/624 | loss 0.5313 | acc 0.8005 | elapsed 113.0s


  it 500/624 | loss 0.5364 | acc 0.7927 | elapsed 125.5s


  it 550/624 | loss 0.5403 | acc 0.7933 | elapsed 138.3s


  it 600/624 | loss 0.5427 | acc 0.7923 | elapsed 150.8s


Fold 1 | Epoch 8/12 | tr_loss 0.5423 tr_acc 0.7903 | va_acc 0.8585 | elapsed_fold 22.4m


  it 50/624 | loss 0.5224 | acc 0.7867 | elapsed 13.2s


  it 100/624 | loss 0.5857 | acc 0.7808 | elapsed 25.7s


  it 150/624 | loss 0.5817 | acc 0.7814 | elapsed 38.2s


  it 200/624 | loss 0.5744 | acc 0.7815 | elapsed 50.7s


  it 250/624 | loss 0.5680 | acc 0.7865 | elapsed 63.2s


  it 300/624 | loss 0.5552 | acc 0.7879 | elapsed 75.7s


  it 350/624 | loss 0.5507 | acc 0.7874 | elapsed 88.2s


  it 400/624 | loss 0.5556 | acc 0.7841 | elapsed 100.7s


  it 450/624 | loss 0.5469 | acc 0.7844 | elapsed 113.2s


  it 500/624 | loss 0.5490 | acc 0.7836 | elapsed 125.7s


  it 550/624 | loss 0.5467 | acc 0.7847 | elapsed 138.2s


  it 600/624 | loss 0.5473 | acc 0.7813 | elapsed 151.0s


Fold 1 | Epoch 9/12 | tr_loss 0.5429 tr_acc 0.7820 | va_acc 0.8686 | elapsed_fold 25.2m


  it 50/624 | loss 0.5838 | acc 0.7675 | elapsed 13.3s


  it 100/624 | loss 0.5899 | acc 0.7783 | elapsed 25.8s


  it 150/624 | loss 0.5694 | acc 0.7725 | elapsed 38.3s


  it 200/624 | loss 0.5617 | acc 0.7831 | elapsed 50.8s


  it 250/624 | loss 0.5569 | acc 0.7832 | elapsed 63.3s


  it 300/624 | loss 0.5516 | acc 0.7810 | elapsed 75.8s


  it 350/624 | loss 0.5476 | acc 0.7805 | elapsed 88.3s


  it 400/624 | loss 0.5500 | acc 0.7820 | elapsed 100.8s


  it 450/624 | loss 0.5424 | acc 0.7846 | elapsed 113.3s


  it 500/624 | loss 0.5419 | acc 0.7820 | elapsed 125.8s


  it 550/624 | loss 0.5394 | acc 0.7847 | elapsed 138.2s


  it 600/624 | loss 0.5399 | acc 0.7852 | elapsed 150.9s


Fold 1 | Epoch 10/12 | tr_loss 0.5404 tr_acc 0.7839 | va_acc 0.8772 | elapsed_fold 28.0m


  it 50/624 | loss 0.3537 | acc 0.8800 | elapsed 13.2s


  it 100/624 | loss 0.3605 | acc 0.8738 | elapsed 25.6s


  it 150/624 | loss 0.3614 | acc 0.8736 | elapsed 38.0s


  it 200/624 | loss 0.3635 | acc 0.8758 | elapsed 50.5s


  it 250/624 | loss 0.3607 | acc 0.8758 | elapsed 62.9s


  it 300/624 | loss 0.3613 | acc 0.8742 | elapsed 75.3s


  it 350/624 | loss 0.3586 | acc 0.8752 | elapsed 87.8s


  it 400/624 | loss 0.3582 | acc 0.8758 | elapsed 100.2s


  it 450/624 | loss 0.3586 | acc 0.8769 | elapsed 112.6s


  it 500/624 | loss 0.3579 | acc 0.8773 | elapsed 125.0s


  it 550/624 | loss 0.3569 | acc 0.8776 | elapsed 137.5s


  it 600/624 | loss 0.3557 | acc 0.8776 | elapsed 150.2s


Fold 1 | Epoch 11/12 | tr_loss 0.3554 tr_acc 0.8774 | va_acc 0.8849 | elapsed_fold 30.8m


  it 50/624 | loss 0.3457 | acc 0.8892 | elapsed 13.2s


  it 100/624 | loss 0.3424 | acc 0.8883 | elapsed 25.7s


  it 150/624 | loss 0.3317 | acc 0.8861 | elapsed 38.2s


  it 200/624 | loss 0.3432 | acc 0.8812 | elapsed 50.7s


  it 250/624 | loss 0.3465 | acc 0.8818 | elapsed 63.2s


  it 300/624 | loss 0.3426 | acc 0.8847 | elapsed 75.8s


  it 350/624 | loss 0.3460 | acc 0.8852 | elapsed 88.3s


  it 400/624 | loss 0.3450 | acc 0.8832 | elapsed 100.8s


  it 450/624 | loss 0.3444 | acc 0.8824 | elapsed 113.2s


  it 500/624 | loss 0.3464 | acc 0.8812 | elapsed 125.6s


  it 550/624 | loss 0.3457 | acc 0.8814 | elapsed 138.1s


  it 600/624 | loss 0.3498 | acc 0.8807 | elapsed 150.8s


Fold 1 | Epoch 12/12 | tr_loss 0.3508 tr_acc 0.8801 | va_acc 0.8884 | elapsed_fold 33.6m


Fold 1 done | best_va_acc 0.8884 | ckpt b3_fold1_best.pth | fold_time 33.9m


===== Fold 2/5 | train 14977 | valid 3744 =====


cfg: {'input_size': (3, 512, 512), 'crop_pct': 0.904, 'interpolation': 'bicubic', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225)}


  it 50/624 | loss 2.1379 | acc 0.2567 | elapsed 13.2s


  it 100/624 | loss 1.7591 | acc 0.4138 | elapsed 25.6s


  it 150/624 | loss 1.5850 | acc 0.4708 | elapsed 38.0s


  it 200/624 | loss 1.4748 | acc 0.5008 | elapsed 50.4s


  it 250/624 | loss 1.4001 | acc 0.5293 | elapsed 62.9s


  it 300/624 | loss 1.3286 | acc 0.5536 | elapsed 75.3s


  it 350/624 | loss 1.2709 | acc 0.5739 | elapsed 87.8s


  it 400/624 | loss 1.2255 | acc 0.5857 | elapsed 100.3s


  it 450/624 | loss 1.1834 | acc 0.5960 | elapsed 112.7s


  it 500/624 | loss 1.1430 | acc 0.6048 | elapsed 125.2s


  it 550/624 | loss 1.1127 | acc 0.6142 | elapsed 137.9s


  it 600/624 | loss 1.0854 | acc 0.6214 | elapsed 150.4s


Fold 2 | Epoch 1/12 | tr_loss 1.0765 tr_acc 0.6231 | va_acc 0.0897 | elapsed_fold 2.8m


  it 50/624 | loss 0.7989 | acc 0.7308 | elapsed 13.2s


  it 100/624 | loss 0.7683 | acc 0.7262 | elapsed 25.7s


  it 150/624 | loss 0.7720 | acc 0.7269 | elapsed 38.1s


  it 200/624 | loss 0.7619 | acc 0.7273 | elapsed 50.6s


  it 250/624 | loss 0.7486 | acc 0.7310 | elapsed 63.1s


  it 300/624 | loss 0.7493 | acc 0.7344 | elapsed 75.6s


  it 350/624 | loss 0.7361 | acc 0.7393 | elapsed 88.1s


  it 400/624 | loss 0.7389 | acc 0.7379 | elapsed 100.6s


  it 450/624 | loss 0.7302 | acc 0.7370 | elapsed 113.1s


  it 500/624 | loss 0.7272 | acc 0.7376 | elapsed 125.6s


  it 550/624 | loss 0.7229 | acc 0.7366 | elapsed 138.4s


  it 600/624 | loss 0.7173 | acc 0.7378 | elapsed 150.9s


Fold 2 | Epoch 2/12 | tr_loss 0.7171 tr_acc 0.7378 | va_acc 0.1079 | elapsed_fold 5.6m


  it 50/624 | loss 0.6811 | acc 0.7400 | elapsed 13.2s


  it 100/624 | loss 0.6851 | acc 0.7529 | elapsed 25.8s


  it 150/624 | loss 0.6785 | acc 0.7608 | elapsed 38.3s


  it 200/624 | loss 0.6854 | acc 0.7552 | elapsed 50.7s


  it 250/624 | loss 0.6804 | acc 0.7558 | elapsed 63.2s


  it 300/624 | loss 0.6751 | acc 0.7525 | elapsed 75.7s


  it 350/624 | loss 0.6669 | acc 0.7538 | elapsed 88.2s


  it 400/624 | loss 0.6657 | acc 0.7508 | elapsed 100.7s


  it 450/624 | loss 0.6635 | acc 0.7532 | elapsed 113.1s


  it 500/624 | loss 0.6620 | acc 0.7506 | elapsed 125.6s


  it 550/624 | loss 0.6635 | acc 0.7516 | elapsed 138.3s


  it 600/624 | loss 0.6624 | acc 0.7504 | elapsed 150.8s


Fold 2 | Epoch 3/12 | tr_loss 0.6627 tr_acc 0.7501 | va_acc 0.2059 | elapsed_fold 8.4m


  it 50/624 | loss 0.6481 | acc 0.7517 | elapsed 13.2s


  it 100/624 | loss 0.6559 | acc 0.7617 | elapsed 25.7s


  it 150/624 | loss 0.6592 | acc 0.7633 | elapsed 38.2s


  it 200/624 | loss 0.6671 | acc 0.7556 | elapsed 50.6s


  it 250/624 | loss 0.6599 | acc 0.7578 | elapsed 63.0s


  it 300/624 | loss 0.6449 | acc 0.7617 | elapsed 75.5s


  it 350/624 | loss 0.6439 | acc 0.7620 | elapsed 88.0s


  it 400/624 | loss 0.6382 | acc 0.7616 | elapsed 100.4s


  it 450/624 | loss 0.6360 | acc 0.7583 | elapsed 112.9s


  it 500/624 | loss 0.6393 | acc 0.7570 | elapsed 125.4s


  it 550/624 | loss 0.6378 | acc 0.7593 | elapsed 138.1s


  it 600/624 | loss 0.6343 | acc 0.7591 | elapsed 150.6s


Fold 2 | Epoch 4/12 | tr_loss 0.6290 tr_acc 0.7624 | va_acc 0.4004 | elapsed_fold 11.2m


  it 50/624 | loss 0.5537 | acc 0.7958 | elapsed 13.3s


  it 100/624 | loss 0.5351 | acc 0.7987 | elapsed 25.8s


  it 150/624 | loss 0.5464 | acc 0.7869 | elapsed 38.4s


  it 200/624 | loss 0.5688 | acc 0.7769 | elapsed 50.9s


  it 250/624 | loss 0.5792 | acc 0.7750 | elapsed 63.5s


  it 300/624 | loss 0.5790 | acc 0.7715 | elapsed 76.1s


  it 350/624 | loss 0.5810 | acc 0.7690 | elapsed 88.6s


  it 400/624 | loss 0.5860 | acc 0.7609 | elapsed 101.2s


  it 450/624 | loss 0.5885 | acc 0.7645 | elapsed 113.7s


  it 500/624 | loss 0.5863 | acc 0.7695 | elapsed 126.3s


  it 550/624 | loss 0.5870 | acc 0.7694 | elapsed 139.1s


  it 600/624 | loss 0.5836 | acc 0.7707 | elapsed 151.7s


Fold 2 | Epoch 5/12 | tr_loss 0.5847 tr_acc 0.7707 | va_acc 0.6143 | elapsed_fold 14.0m


  it 50/624 | loss 0.6218 | acc 0.7783 | elapsed 13.2s


  it 100/624 | loss 0.6078 | acc 0.7771 | elapsed 25.7s


  it 150/624 | loss 0.5760 | acc 0.7844 | elapsed 38.1s


  it 200/624 | loss 0.5706 | acc 0.7873 | elapsed 50.6s


  it 250/624 | loss 0.5825 | acc 0.7777 | elapsed 63.0s


  it 300/624 | loss 0.5780 | acc 0.7839 | elapsed 75.5s


  it 350/624 | loss 0.5726 | acc 0.7869 | elapsed 87.9s


  it 400/624 | loss 0.5725 | acc 0.7873 | elapsed 100.4s


  it 450/624 | loss 0.5633 | acc 0.7891 | elapsed 112.8s


  it 500/624 | loss 0.5675 | acc 0.7858 | elapsed 125.6s


  it 550/624 | loss 0.5707 | acc 0.7866 | elapsed 138.1s


  it 600/624 | loss 0.5725 | acc 0.7837 | elapsed 150.5s


Fold 2 | Epoch 6/12 | tr_loss 0.5710 tr_acc 0.7843 | va_acc 0.7401 | elapsed_fold 16.8m


  it 50/624 | loss 0.5166 | acc 0.7817 | elapsed 13.3s


  it 100/624 | loss 0.5403 | acc 0.7804 | elapsed 25.8s


  it 150/624 | loss 0.5642 | acc 0.7778 | elapsed 38.3s


  it 200/624 | loss 0.5704 | acc 0.7735 | elapsed 50.7s


  it 250/624 | loss 0.5698 | acc 0.7772 | elapsed 63.2s


  it 300/624 | loss 0.5666 | acc 0.7749 | elapsed 75.7s


  it 350/624 | loss 0.5617 | acc 0.7743 | elapsed 88.3s


  it 400/624 | loss 0.5694 | acc 0.7698 | elapsed 100.8s


  it 450/624 | loss 0.5746 | acc 0.7686 | elapsed 113.3s


  it 500/624 | loss 0.5771 | acc 0.7686 | elapsed 126.2s


  it 550/624 | loss 0.5760 | acc 0.7692 | elapsed 138.7s


  it 600/624 | loss 0.5700 | acc 0.7722 | elapsed 151.3s


Fold 2 | Epoch 7/12 | tr_loss 0.5716 tr_acc 0.7724 | va_acc 0.8074 | elapsed_fold 19.6m


  it 50/624 | loss 0.4962 | acc 0.8033 | elapsed 13.3s


  it 100/624 | loss 0.5220 | acc 0.7983 | elapsed 25.8s


  it 150/624 | loss 0.5193 | acc 0.8025 | elapsed 38.2s


  it 200/624 | loss 0.5212 | acc 0.8017 | elapsed 50.7s


  it 250/624 | loss 0.5204 | acc 0.8048 | elapsed 63.1s


  it 300/624 | loss 0.5169 | acc 0.8007 | elapsed 75.6s


  it 350/624 | loss 0.5186 | acc 0.7939 | elapsed 88.1s


  it 400/624 | loss 0.5213 | acc 0.7925 | elapsed 100.6s


  it 450/624 | loss 0.5203 | acc 0.7956 | elapsed 113.3s


  it 500/624 | loss 0.5235 | acc 0.7938 | elapsed 125.8s


  it 550/624 | loss 0.5257 | acc 0.7927 | elapsed 138.2s


  it 600/624 | loss 0.5307 | acc 0.7898 | elapsed 150.7s


Fold 2 | Epoch 8/12 | tr_loss 0.5311 tr_acc 0.7891 | va_acc 0.8387 | elapsed_fold 22.4m


  it 50/624 | loss 0.5393 | acc 0.7933 | elapsed 13.3s


  it 100/624 | loss 0.5481 | acc 0.7950 | elapsed 25.8s


  it 150/624 | loss 0.5288 | acc 0.8058 | elapsed 38.3s


  it 200/624 | loss 0.5226 | acc 0.8054 | elapsed 50.7s


  it 250/624 | loss 0.5248 | acc 0.8005 | elapsed 63.1s


  it 300/624 | loss 0.5236 | acc 0.7992 | elapsed 75.6s


  it 350/624 | loss 0.5214 | acc 0.8001 | elapsed 88.0s


  it 400/624 | loss 0.5257 | acc 0.7955 | elapsed 100.4s


  it 450/624 | loss 0.5273 | acc 0.7914 | elapsed 112.8s


  it 500/624 | loss 0.5261 | acc 0.7948 | elapsed 125.5s


  it 550/624 | loss 0.5322 | acc 0.7904 | elapsed 137.9s


  it 600/624 | loss 0.5315 | acc 0.7915 | elapsed 150.4s


Fold 2 | Epoch 9/12 | tr_loss 0.5305 tr_acc 0.7907 | va_acc 0.8582 | elapsed_fold 25.2m


  it 50/624 | loss 0.5086 | acc 0.8117 | elapsed 13.2s


  it 100/624 | loss 0.5239 | acc 0.8033 | elapsed 25.7s


  it 150/624 | loss 0.5369 | acc 0.7883 | elapsed 38.2s


  it 200/624 | loss 0.5426 | acc 0.7927 | elapsed 50.7s


  it 250/624 | loss 0.5430 | acc 0.7910 | elapsed 63.2s


  it 300/624 | loss 0.5441 | acc 0.7856 | elapsed 75.7s


  it 350/624 | loss 0.5330 | acc 0.7942 | elapsed 88.2s


  it 400/624 | loss 0.5290 | acc 0.7991 | elapsed 100.7s


  it 450/624 | loss 0.5265 | acc 0.7967 | elapsed 113.5s


  it 500/624 | loss 0.5280 | acc 0.7964 | elapsed 126.0s


  it 550/624 | loss 0.5337 | acc 0.7939 | elapsed 138.4s


  it 600/624 | loss 0.5351 | acc 0.7936 | elapsed 151.0s


Fold 2 | Epoch 10/12 | tr_loss 0.5370 tr_acc 0.7921 | va_acc 0.8686 | elapsed_fold 28.0m


  it 50/624 | loss 0.4017 | acc 0.8550 | elapsed 13.3s


  it 100/624 | loss 0.3746 | acc 0.8667 | elapsed 25.8s


  it 150/624 | loss 0.3601 | acc 0.8758 | elapsed 38.3s


  it 200/624 | loss 0.3589 | acc 0.8767 | elapsed 50.8s


  it 250/624 | loss 0.3573 | acc 0.8773 | elapsed 63.2s


  it 300/624 | loss 0.3576 | acc 0.8774 | elapsed 75.7s


  it 350/624 | loss 0.3549 | acc 0.8781 | elapsed 88.1s


  it 400/624 | loss 0.3523 | acc 0.8792 | elapsed 100.6s


  it 450/624 | loss 0.3526 | acc 0.8782 | elapsed 113.0s


  it 500/624 | loss 0.3512 | acc 0.8791 | elapsed 125.8s


  it 550/624 | loss 0.3490 | acc 0.8800 | elapsed 138.3s


  it 600/624 | loss 0.3476 | acc 0.8805 | elapsed 150.7s


Fold 2 | Epoch 11/12 | tr_loss 0.3473 tr_acc 0.8805 | va_acc 0.8795 | elapsed_fold 30.8m


  it 50/624 | loss 0.3210 | acc 0.8892 | elapsed 13.2s


  it 100/624 | loss 0.3346 | acc 0.8800 | elapsed 25.6s


  it 150/624 | loss 0.3432 | acc 0.8767 | elapsed 38.0s


  it 200/624 | loss 0.3454 | acc 0.8781 | elapsed 50.4s


  it 250/624 | loss 0.3372 | acc 0.8798 | elapsed 62.8s


  it 300/624 | loss 0.3418 | acc 0.8783 | elapsed 75.2s


  it 350/624 | loss 0.3456 | acc 0.8771 | elapsed 87.6s


  it 400/624 | loss 0.3449 | acc 0.8774 | elapsed 100.1s


  it 450/624 | loss 0.3457 | acc 0.8781 | elapsed 112.5s


  it 500/624 | loss 0.3426 | acc 0.8792 | elapsed 124.9s


  it 550/624 | loss 0.3433 | acc 0.8785 | elapsed 137.4s


  it 600/624 | loss 0.3421 | acc 0.8797 | elapsed 150.1s


Fold 2 | Epoch 12/12 | tr_loss 0.3432 tr_acc 0.8797 | va_acc 0.8876 | elapsed_fold 33.6m


Fold 2 done | best_va_acc 0.8876 | ckpt b3_fold2_best.pth | fold_time 34.0m


===== Fold 3/5 | train 14977 | valid 3744 =====


cfg: {'input_size': (3, 512, 512), 'crop_pct': 0.904, 'interpolation': 'bicubic', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225)}


  it 50/624 | loss 2.2115 | acc 0.2442 | elapsed 13.3s


  it 100/624 | loss 1.8745 | acc 0.3829 | elapsed 25.8s


  it 150/624 | loss 1.6811 | acc 0.4431 | elapsed 38.2s


  it 200/624 | loss 1.5480 | acc 0.4806 | elapsed 50.6s


  it 250/624 | loss 1.4515 | acc 0.5098 | elapsed 63.1s


  it 300/624 | loss 1.3846 | acc 0.5304 | elapsed 75.5s


  it 350/624 | loss 1.3225 | acc 0.5456 | elapsed 88.0s


  it 400/624 | loss 1.2660 | acc 0.5621 | elapsed 100.4s


  it 450/624 | loss 1.2248 | acc 0.5743 | elapsed 112.9s


  it 500/624 | loss 1.1816 | acc 0.5887 | elapsed 125.6s


  it 550/624 | loss 1.1531 | acc 0.5955 | elapsed 138.1s


  it 600/624 | loss 1.1231 | acc 0.6042 | elapsed 150.5s


Fold 3 | Epoch 1/12 | tr_loss 1.1099 tr_acc 0.6077 | va_acc 0.1330 | elapsed_fold 2.8m


  it 50/624 | loss 0.7961 | acc 0.7067 | elapsed 13.3s


  it 100/624 | loss 0.7754 | acc 0.7129 | elapsed 25.7s


  it 150/624 | loss 0.7654 | acc 0.7128 | elapsed 38.2s


  it 200/624 | loss 0.7625 | acc 0.7200 | elapsed 50.6s


  it 250/624 | loss 0.7591 | acc 0.7180 | elapsed 63.0s


  it 300/624 | loss 0.7648 | acc 0.7169 | elapsed 75.4s


  it 350/624 | loss 0.7570 | acc 0.7236 | elapsed 87.9s


  it 400/624 | loss 0.7592 | acc 0.7259 | elapsed 100.3s


  it 450/624 | loss 0.7570 | acc 0.7246 | elapsed 112.8s


  it 500/624 | loss 0.7551 | acc 0.7275 | elapsed 125.2s


  it 550/624 | loss 0.7503 | acc 0.7295 | elapsed 137.9s


  it 600/624 | loss 0.7490 | acc 0.7286 | elapsed 150.4s


Fold 3 | Epoch 2/12 | tr_loss 0.7433 tr_acc 0.7278 | va_acc 0.2332 | elapsed_fold 5.6m


  it 50/624 | loss 0.6503 | acc 0.7800 | elapsed 13.3s


  it 100/624 | loss 0.6635 | acc 0.7450 | elapsed 25.8s


  it 150/624 | loss 0.6629 | acc 0.7436 | elapsed 38.3s


  it 200/624 | loss 0.6610 | acc 0.7508 | elapsed 50.8s


  it 250/624 | loss 0.6618 | acc 0.7470 | elapsed 63.3s


  it 300/624 | loss 0.6644 | acc 0.7514 | elapsed 75.8s


  it 350/624 | loss 0.6800 | acc 0.7462 | elapsed 88.4s


  it 400/624 | loss 0.6775 | acc 0.7477 | elapsed 100.9s


  it 450/624 | loss 0.6748 | acc 0.7464 | elapsed 113.4s


  it 500/624 | loss 0.6732 | acc 0.7476 | elapsed 126.0s


  it 550/624 | loss 0.6752 | acc 0.7460 | elapsed 138.8s


  it 600/624 | loss 0.6725 | acc 0.7479 | elapsed 151.4s


Fold 3 | Epoch 3/12 | tr_loss 0.6730 tr_acc 0.7488 | va_acc 0.4693 | elapsed_fold 8.4m


  it 50/624 | loss 0.5760 | acc 0.7525 | elapsed 13.3s


  it 100/624 | loss 0.6464 | acc 0.7462 | elapsed 25.8s


  it 150/624 | loss 0.6156 | acc 0.7558 | elapsed 38.3s


  it 200/624 | loss 0.6183 | acc 0.7581 | elapsed 50.7s


  it 250/624 | loss 0.6231 | acc 0.7525 | elapsed 63.2s


  it 300/624 | loss 0.6153 | acc 0.7556 | elapsed 75.6s


  it 350/624 | loss 0.6254 | acc 0.7531 | elapsed 88.1s


  it 400/624 | loss 0.6341 | acc 0.7498 | elapsed 100.5s


  it 450/624 | loss 0.6290 | acc 0.7532 | elapsed 113.0s


  it 500/624 | loss 0.6268 | acc 0.7541 | elapsed 125.4s


  it 550/624 | loss 0.6239 | acc 0.7560 | elapsed 137.9s


  it 600/624 | loss 0.6248 | acc 0.7562 | elapsed 150.6s


Fold 3 | Epoch 4/12 | tr_loss 0.6237 tr_acc 0.7548 | va_acc 0.6856 | elapsed_fold 11.2m


  it 50/624 | loss 0.5706 | acc 0.8058 | elapsed 13.3s


  it 100/624 | loss 0.5882 | acc 0.7812 | elapsed 25.7s


  it 150/624 | loss 0.5849 | acc 0.7842 | elapsed 38.2s


  it 200/624 | loss 0.5898 | acc 0.7858 | elapsed 50.7s


  it 250/624 | loss 0.5956 | acc 0.7802 | elapsed 63.2s


  it 300/624 | loss 0.5969 | acc 0.7743 | elapsed 75.6s


  it 350/624 | loss 0.5983 | acc 0.7795 | elapsed 88.1s


  it 400/624 | loss 0.5906 | acc 0.7826 | elapsed 100.6s


  it 450/624 | loss 0.5759 | acc 0.7864 | elapsed 113.1s


  it 500/624 | loss 0.5802 | acc 0.7811 | elapsed 125.6s


  it 550/624 | loss 0.5801 | acc 0.7800 | elapsed 138.4s


  it 600/624 | loss 0.5819 | acc 0.7789 | elapsed 151.0s


Fold 3 | Epoch 5/12 | tr_loss 0.5849 tr_acc 0.7784 | va_acc 0.7874 | elapsed_fold 14.0m


  it 50/624 | loss 0.5644 | acc 0.7917 | elapsed 13.2s


  it 100/624 | loss 0.5658 | acc 0.7879 | elapsed 25.7s


  it 150/624 | loss 0.5753 | acc 0.7814 | elapsed 38.2s


  it 200/624 | loss 0.5640 | acc 0.7715 | elapsed 50.8s


  it 250/624 | loss 0.5585 | acc 0.7770 | elapsed 63.3s


  it 300/624 | loss 0.5720 | acc 0.7738 | elapsed 75.8s


  it 350/624 | loss 0.5830 | acc 0.7730 | elapsed 88.3s


  it 400/624 | loss 0.5937 | acc 0.7649 | elapsed 100.8s


  it 450/624 | loss 0.5900 | acc 0.7696 | elapsed 113.3s


  it 500/624 | loss 0.5823 | acc 0.7713 | elapsed 125.8s


  it 550/624 | loss 0.5847 | acc 0.7711 | elapsed 138.7s


  it 600/624 | loss 0.5806 | acc 0.7732 | elapsed 151.1s


Fold 3 | Epoch 6/12 | tr_loss 0.5826 tr_acc 0.7717 | va_acc 0.8283 | elapsed_fold 16.8m


  it 50/624 | loss 0.5147 | acc 0.7942 | elapsed 13.1s


  it 100/624 | loss 0.5358 | acc 0.7896 | elapsed 25.6s


  it 150/624 | loss 0.5486 | acc 0.7953 | elapsed 38.0s


  it 200/624 | loss 0.5618 | acc 0.7921 | elapsed 50.5s


  it 250/624 | loss 0.5669 | acc 0.7883 | elapsed 63.0s


  it 300/624 | loss 0.5642 | acc 0.7861 | elapsed 75.5s


  it 350/624 | loss 0.5750 | acc 0.7820 | elapsed 88.0s


  it 400/624 | loss 0.5702 | acc 0.7839 | elapsed 100.5s


  it 450/624 | loss 0.5693 | acc 0.7824 | elapsed 113.0s


  it 500/624 | loss 0.5645 | acc 0.7811 | elapsed 125.5s


  it 550/624 | loss 0.5624 | acc 0.7832 | elapsed 138.0s


  it 600/624 | loss 0.5606 | acc 0.7826 | elapsed 150.9s


Fold 3 | Epoch 7/12 | tr_loss 0.5604 tr_acc 0.7806 | va_acc 0.8499 | elapsed_fold 19.6m


  it 50/624 | loss 0.5951 | acc 0.7608 | elapsed 13.2s


  it 100/624 | loss 0.5636 | acc 0.7775 | elapsed 25.7s


  it 150/624 | loss 0.5522 | acc 0.7783 | elapsed 38.2s


  it 200/624 | loss 0.5546 | acc 0.7673 | elapsed 50.7s


  it 250/624 | loss 0.5572 | acc 0.7642 | elapsed 63.1s


  it 300/624 | loss 0.5530 | acc 0.7714 | elapsed 75.6s


  it 350/624 | loss 0.5450 | acc 0.7737 | elapsed 88.1s


  it 400/624 | loss 0.5453 | acc 0.7768 | elapsed 100.5s


  it 450/624 | loss 0.5395 | acc 0.7785 | elapsed 113.0s


  it 500/624 | loss 0.5426 | acc 0.7784 | elapsed 125.6s


  it 550/624 | loss 0.5430 | acc 0.7814 | elapsed 138.4s


  it 600/624 | loss 0.5382 | acc 0.7833 | elapsed 150.9s


Fold 3 | Epoch 8/12 | tr_loss 0.5378 tr_acc 0.7845 | va_acc 0.8627 | elapsed_fold 22.4m


  it 50/624 | loss 0.5020 | acc 0.8233 | elapsed 13.3s


  it 100/624 | loss 0.5111 | acc 0.8154 | elapsed 25.8s


  it 150/624 | loss 0.5374 | acc 0.8058 | elapsed 38.3s


  it 200/624 | loss 0.5475 | acc 0.8023 | elapsed 50.8s


  it 250/624 | loss 0.5462 | acc 0.8005 | elapsed 63.4s


  it 300/624 | loss 0.5397 | acc 0.8025 | elapsed 75.9s


  it 350/624 | loss 0.5343 | acc 0.8045 | elapsed 88.4s


  it 400/624 | loss 0.5391 | acc 0.8017 | elapsed 101.0s


  it 450/624 | loss 0.5358 | acc 0.8027 | elapsed 113.5s


  it 500/624 | loss 0.5345 | acc 0.7995 | elapsed 126.0s


  it 550/624 | loss 0.5315 | acc 0.7972 | elapsed 138.8s


  it 600/624 | loss 0.5307 | acc 0.7978 | elapsed 151.3s


Fold 3 | Epoch 9/12 | tr_loss 0.5303 tr_acc 0.7959 | va_acc 0.8705 | elapsed_fold 25.2m


  it 50/624 | loss 0.5060 | acc 0.8033 | elapsed 13.3s


  it 100/624 | loss 0.4974 | acc 0.8167 | elapsed 25.7s


  it 150/624 | loss 0.5157 | acc 0.8067 | elapsed 38.2s


  it 200/624 | loss 0.5189 | acc 0.8044 | elapsed 50.6s


  it 250/624 | loss 0.5157 | acc 0.8058 | elapsed 63.0s


  it 300/624 | loss 0.5160 | acc 0.7999 | elapsed 75.4s


  it 350/624 | loss 0.5147 | acc 0.7971 | elapsed 87.9s


  it 400/624 | loss 0.5190 | acc 0.7966 | elapsed 100.3s


  it 450/624 | loss 0.5163 | acc 0.7970 | elapsed 112.8s


  it 500/624 | loss 0.5153 | acc 0.7950 | elapsed 125.3s


  it 550/624 | loss 0.5143 | acc 0.7949 | elapsed 138.0s


  it 600/624 | loss 0.5161 | acc 0.7942 | elapsed 150.5s


Fold 3 | Epoch 10/12 | tr_loss 0.5150 tr_acc 0.7956 | va_acc 0.8763 | elapsed_fold 28.0m


  it 50/624 | loss 0.3359 | acc 0.8783 | elapsed 13.1s


  it 100/624 | loss 0.3199 | acc 0.8829 | elapsed 25.6s


  it 150/624 | loss 0.3370 | acc 0.8800 | elapsed 38.0s


  it 200/624 | loss 0.3420 | acc 0.8790 | elapsed 50.4s


  it 250/624 | loss 0.3387 | acc 0.8812 | elapsed 62.9s


  it 300/624 | loss 0.3343 | acc 0.8839 | elapsed 75.4s


  it 350/624 | loss 0.3331 | acc 0.8842 | elapsed 87.8s


  it 400/624 | loss 0.3326 | acc 0.8847 | elapsed 100.3s


  it 450/624 | loss 0.3390 | acc 0.8819 | elapsed 112.8s


  it 500/624 | loss 0.3404 | acc 0.8811 | elapsed 125.5s


  it 550/624 | loss 0.3414 | acc 0.8805 | elapsed 138.0s


  it 600/624 | loss 0.3406 | acc 0.8818 | elapsed 150.5s


Fold 3 | Epoch 11/12 | tr_loss 0.3385 tr_acc 0.8827 | va_acc 0.8809 | elapsed_fold 30.8m


  it 50/624 | loss 0.3053 | acc 0.8900 | elapsed 13.2s


  it 100/624 | loss 0.3276 | acc 0.8842 | elapsed 25.7s


  it 150/624 | loss 0.3312 | acc 0.8847 | elapsed 38.1s


  it 200/624 | loss 0.3354 | acc 0.8850 | elapsed 50.6s


  it 250/624 | loss 0.3334 | acc 0.8827 | elapsed 63.0s


  it 300/624 | loss 0.3368 | acc 0.8803 | elapsed 75.5s


  it 350/624 | loss 0.3321 | acc 0.8820 | elapsed 87.9s


  it 400/624 | loss 0.3356 | acc 0.8805 | elapsed 100.4s


  it 450/624 | loss 0.3339 | acc 0.8813 | elapsed 112.8s


  it 500/624 | loss 0.3345 | acc 0.8818 | elapsed 125.5s


  it 550/624 | loss 0.3381 | acc 0.8805 | elapsed 138.0s


  it 600/624 | loss 0.3393 | acc 0.8802 | elapsed 150.4s


Fold 3 | Epoch 12/12 | tr_loss 0.3389 tr_acc 0.8800 | va_acc 0.8835 | elapsed_fold 33.6m


Fold 3 done | best_va_acc 0.8835 | ckpt b3_fold3_best.pth | fold_time 34.0m


===== Fold 4/5 | train 14977 | valid 3744 =====


cfg: {'input_size': (3, 512, 512), 'crop_pct': 0.904, 'interpolation': 'bicubic', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225)}


  it 50/624 | loss 2.6478 | acc 0.1700 | elapsed 13.3s


  it 100/624 | loss 2.0624 | acc 0.3525 | elapsed 25.7s


  it 150/624 | loss 1.7996 | acc 0.4244 | elapsed 38.1s


  it 200/624 | loss 1.6417 | acc 0.4683 | elapsed 50.5s


  it 250/624 | loss 1.5214 | acc 0.4998 | elapsed 62.9s


  it 300/624 | loss 1.4415 | acc 0.5196 | elapsed 75.4s


  it 350/624 | loss 1.3611 | acc 0.5423 | elapsed 87.8s


  it 400/624 | loss 1.3087 | acc 0.5564 | elapsed 100.3s


  it 450/624 | loss 1.2625 | acc 0.5706 | elapsed 112.7s


  it 500/624 | loss 1.2255 | acc 0.5813 | elapsed 125.2s


  it 550/624 | loss 1.1882 | acc 0.5924 | elapsed 138.0s


  it 600/624 | loss 1.1565 | acc 0.6038 | elapsed 150.5s


Fold 4 | Epoch 1/12 | tr_loss 1.1441 tr_acc 0.6070 | va_acc 0.1368 | elapsed_fold 2.8m


  it 50/624 | loss 0.7757 | acc 0.7100 | elapsed 13.2s


  it 100/624 | loss 0.7771 | acc 0.7017 | elapsed 25.7s


  it 150/624 | loss 0.7864 | acc 0.7067 | elapsed 38.2s


  it 200/624 | loss 0.7802 | acc 0.7108 | elapsed 50.8s


  it 250/624 | loss 0.7703 | acc 0.7145 | elapsed 63.3s


  it 300/624 | loss 0.7564 | acc 0.7142 | elapsed 75.8s


  it 350/624 | loss 0.7538 | acc 0.7120 | elapsed 88.3s


  it 400/624 | loss 0.7591 | acc 0.7110 | elapsed 100.8s


  it 450/624 | loss 0.7502 | acc 0.7145 | elapsed 113.3s


  it 500/624 | loss 0.7505 | acc 0.7157 | elapsed 125.8s


  it 550/624 | loss 0.7480 | acc 0.7180 | elapsed 138.5s


  it 600/624 | loss 0.7406 | acc 0.7197 | elapsed 150.9s


Fold 4 | Epoch 2/12 | tr_loss 0.7424 tr_acc 0.7208 | va_acc 0.2182 | elapsed_fold 5.6m


  it 50/624 | loss 0.6829 | acc 0.7458 | elapsed 13.3s


  it 100/624 | loss 0.6491 | acc 0.7446 | elapsed 25.8s


  it 150/624 | loss 0.6559 | acc 0.7361 | elapsed 38.2s


  it 200/624 | loss 0.6583 | acc 0.7427 | elapsed 50.7s


  it 250/624 | loss 0.6594 | acc 0.7452 | elapsed 63.1s


  it 300/624 | loss 0.6696 | acc 0.7411 | elapsed 75.6s


  it 350/624 | loss 0.6709 | acc 0.7411 | elapsed 88.0s


  it 400/624 | loss 0.6771 | acc 0.7384 | elapsed 100.5s


  it 450/624 | loss 0.6736 | acc 0.7394 | elapsed 112.9s


  it 500/624 | loss 0.6741 | acc 0.7426 | elapsed 125.4s


  it 550/624 | loss 0.6717 | acc 0.7408 | elapsed 138.2s


  it 600/624 | loss 0.6704 | acc 0.7415 | elapsed 150.6s


Fold 4 | Epoch 3/12 | tr_loss 0.6691 tr_acc 0.7426 | va_acc 0.4663 | elapsed_fold 8.4m


  it 50/624 | loss 0.6169 | acc 0.7692 | elapsed 13.3s


  it 100/624 | loss 0.6354 | acc 0.7583 | elapsed 25.7s


  it 150/624 | loss 0.6538 | acc 0.7564 | elapsed 38.2s


  it 200/624 | loss 0.6641 | acc 0.7504 | elapsed 50.7s


  it 250/624 | loss 0.6666 | acc 0.7555 | elapsed 63.2s


  it 300/624 | loss 0.6570 | acc 0.7478 | elapsed 75.6s


  it 350/624 | loss 0.6533 | acc 0.7449 | elapsed 88.1s


  it 400/624 | loss 0.6530 | acc 0.7442 | elapsed 100.6s


  it 450/624 | loss 0.6487 | acc 0.7480 | elapsed 113.1s


  it 500/624 | loss 0.6451 | acc 0.7512 | elapsed 125.6s


  it 550/624 | loss 0.6366 | acc 0.7554 | elapsed 138.4s


  it 600/624 | loss 0.6313 | acc 0.7574 | elapsed 150.9s


Fold 4 | Epoch 4/12 | tr_loss 0.6309 tr_acc 0.7585 | va_acc 0.6838 | elapsed_fold 11.2m


  it 50/624 | loss 0.5745 | acc 0.7933 | elapsed 13.2s


  it 100/624 | loss 0.5710 | acc 0.7913 | elapsed 25.7s


  it 150/624 | loss 0.5638 | acc 0.7756 | elapsed 38.1s


  it 200/624 | loss 0.5582 | acc 0.7748 | elapsed 50.6s


  it 250/624 | loss 0.5682 | acc 0.7718 | elapsed 63.1s


  it 300/624 | loss 0.5778 | acc 0.7699 | elapsed 75.6s


  it 350/624 | loss 0.5801 | acc 0.7686 | elapsed 88.0s


  it 400/624 | loss 0.5834 | acc 0.7698 | elapsed 100.5s


  it 450/624 | loss 0.5827 | acc 0.7719 | elapsed 112.9s


  it 500/624 | loss 0.5805 | acc 0.7745 | elapsed 125.3s


  it 550/624 | loss 0.5830 | acc 0.7755 | elapsed 138.0s


  it 600/624 | loss 0.5898 | acc 0.7729 | elapsed 150.5s


Fold 4 | Epoch 5/12 | tr_loss 0.5899 tr_acc 0.7738 | va_acc 0.7887 | elapsed_fold 14.0m


  it 50/624 | loss 0.5317 | acc 0.8092 | elapsed 13.2s


  it 100/624 | loss 0.5389 | acc 0.8008 | elapsed 25.7s


  it 150/624 | loss 0.5638 | acc 0.7903 | elapsed 38.1s


  it 200/624 | loss 0.5744 | acc 0.7842 | elapsed 50.6s


  it 250/624 | loss 0.5794 | acc 0.7837 | elapsed 63.1s


  it 300/624 | loss 0.5822 | acc 0.7817 | elapsed 75.6s


  it 350/624 | loss 0.5799 | acc 0.7813 | elapsed 88.1s


  it 400/624 | loss 0.5800 | acc 0.7823 | elapsed 100.5s


  it 450/624 | loss 0.5830 | acc 0.7831 | elapsed 113.0s


  it 500/624 | loss 0.5824 | acc 0.7800 | elapsed 125.5s


  it 550/624 | loss 0.5810 | acc 0.7792 | elapsed 138.4s


  it 600/624 | loss 0.5852 | acc 0.7771 | elapsed 150.9s


Fold 4 | Epoch 6/12 | tr_loss 0.5851 tr_acc 0.7781 | va_acc 0.8296 | elapsed_fold 16.8m


  it 50/624 | loss 0.6333 | acc 0.7550 | elapsed 13.3s


  it 100/624 | loss 0.5916 | acc 0.7775 | elapsed 25.8s


  it 150/624 | loss 0.5765 | acc 0.7906 | elapsed 38.4s


  it 200/624 | loss 0.5610 | acc 0.7898 | elapsed 51.0s


  it 250/624 | loss 0.5649 | acc 0.7902 | elapsed 63.5s


  it 300/624 | loss 0.5675 | acc 0.7887 | elapsed 76.1s


  it 350/624 | loss 0.5603 | acc 0.7918 | elapsed 88.6s


  it 400/624 | loss 0.5543 | acc 0.7929 | elapsed 101.1s


  it 450/624 | loss 0.5555 | acc 0.7926 | elapsed 113.6s


  it 500/624 | loss 0.5580 | acc 0.7899 | elapsed 126.1s


  it 550/624 | loss 0.5577 | acc 0.7904 | elapsed 139.0s


  it 600/624 | loss 0.5543 | acc 0.7908 | elapsed 151.5s


Fold 4 | Epoch 7/12 | tr_loss 0.5566 tr_acc 0.7889 | va_acc 0.8504 | elapsed_fold 19.6m


  it 50/624 | loss 0.5942 | acc 0.7642 | elapsed 13.3s


  it 100/624 | loss 0.5753 | acc 0.7792 | elapsed 25.7s


  it 150/624 | loss 0.5732 | acc 0.7842 | elapsed 38.1s


  it 200/624 | loss 0.5673 | acc 0.7815 | elapsed 50.6s


  it 250/624 | loss 0.5543 | acc 0.7910 | elapsed 63.1s


  it 300/624 | loss 0.5547 | acc 0.7944 | elapsed 75.5s


  it 350/624 | loss 0.5523 | acc 0.7848 | elapsed 88.0s


  it 400/624 | loss 0.5526 | acc 0.7860 | elapsed 100.4s


  it 450/624 | loss 0.5520 | acc 0.7859 | elapsed 112.9s


  it 500/624 | loss 0.5572 | acc 0.7821 | elapsed 125.4s


  it 550/624 | loss 0.5531 | acc 0.7823 | elapsed 137.8s


  it 600/624 | loss 0.5557 | acc 0.7816 | elapsed 150.6s


Fold 4 | Epoch 8/12 | tr_loss 0.5537 tr_acc 0.7818 | va_acc 0.8595 | elapsed_fold 22.4m


  it 50/624 | loss 0.5088 | acc 0.8000 | elapsed 13.3s


  it 100/624 | loss 0.5650 | acc 0.7913 | elapsed 25.8s


  it 150/624 | loss 0.5507 | acc 0.7872 | elapsed 38.3s


  it 200/624 | loss 0.5489 | acc 0.7852 | elapsed 50.8s


  it 250/624 | loss 0.5556 | acc 0.7782 | elapsed 63.3s


  it 300/624 | loss 0.5513 | acc 0.7840 | elapsed 75.8s


  it 350/624 | loss 0.5487 | acc 0.7848 | elapsed 88.3s


  it 400/624 | loss 0.5420 | acc 0.7885 | elapsed 100.8s


  it 450/624 | loss 0.5451 | acc 0.7911 | elapsed 113.3s


  it 500/624 | loss 0.5513 | acc 0.7896 | elapsed 125.8s


  it 550/624 | loss 0.5501 | acc 0.7889 | elapsed 138.6s


  it 600/624 | loss 0.5490 | acc 0.7886 | elapsed 151.1s


Fold 4 | Epoch 9/12 | tr_loss 0.5473 tr_acc 0.7895 | va_acc 0.8699 | elapsed_fold 25.2m


  it 50/624 | loss 0.5147 | acc 0.7733 | elapsed 13.2s


  it 100/624 | loss 0.5293 | acc 0.7675 | elapsed 25.7s


  it 150/624 | loss 0.5180 | acc 0.7808 | elapsed 38.2s


  it 200/624 | loss 0.5249 | acc 0.7840 | elapsed 50.7s


  it 250/624 | loss 0.5136 | acc 0.7913 | elapsed 63.2s


  it 300/624 | loss 0.5061 | acc 0.7907 | elapsed 75.6s


  it 350/624 | loss 0.5145 | acc 0.7860 | elapsed 88.1s


  it 400/624 | loss 0.5112 | acc 0.7855 | elapsed 100.6s


  it 450/624 | loss 0.5169 | acc 0.7870 | elapsed 113.1s


  it 500/624 | loss 0.5167 | acc 0.7893 | elapsed 125.5s


  it 550/624 | loss 0.5160 | acc 0.7875 | elapsed 138.3s


  it 600/624 | loss 0.5163 | acc 0.7873 | elapsed 150.7s


Fold 4 | Epoch 10/12 | tr_loss 0.5170 tr_acc 0.7881 | va_acc 0.8747 | elapsed_fold 28.0m


  it 50/624 | loss 0.3626 | acc 0.8767 | elapsed 13.2s


  it 100/624 | loss 0.3641 | acc 0.8738 | elapsed 25.6s


  it 150/624 | loss 0.3705 | acc 0.8758 | elapsed 38.0s


  it 200/624 | loss 0.3632 | acc 0.8767 | elapsed 50.4s


  it 250/624 | loss 0.3674 | acc 0.8772 | elapsed 62.8s


  it 300/624 | loss 0.3640 | acc 0.8768 | elapsed 75.2s


  it 350/624 | loss 0.3635 | acc 0.8789 | elapsed 87.6s


  it 400/624 | loss 0.3577 | acc 0.8810 | elapsed 99.9s


  it 450/624 | loss 0.3547 | acc 0.8814 | elapsed 112.3s


  it 500/624 | loss 0.3510 | acc 0.8825 | elapsed 124.7s


  it 550/624 | loss 0.3539 | acc 0.8808 | elapsed 137.5s


  it 600/624 | loss 0.3533 | acc 0.8805 | elapsed 149.9s


Fold 4 | Epoch 11/12 | tr_loss 0.3572 tr_acc 0.8787 | va_acc 0.8801 | elapsed_fold 30.8m


  it 50/624 | loss 0.3678 | acc 0.8825 | elapsed 13.3s


  it 100/624 | loss 0.3572 | acc 0.8842 | elapsed 25.8s


  it 150/624 | loss 0.3601 | acc 0.8833 | elapsed 38.4s


  it 200/624 | loss 0.3551 | acc 0.8821 | elapsed 50.9s


  it 250/624 | loss 0.3570 | acc 0.8813 | elapsed 63.4s


  it 300/624 | loss 0.3512 | acc 0.8821 | elapsed 75.8s


  it 350/624 | loss 0.3527 | acc 0.8811 | elapsed 88.4s


  it 400/624 | loss 0.3481 | acc 0.8812 | elapsed 100.9s


  it 450/624 | loss 0.3492 | acc 0.8806 | elapsed 113.4s


  it 500/624 | loss 0.3486 | acc 0.8808 | elapsed 125.9s


  it 550/624 | loss 0.3485 | acc 0.8809 | elapsed 138.7s


  it 600/624 | loss 0.3503 | acc 0.8794 | elapsed 151.1s


Fold 4 | Epoch 12/12 | tr_loss 0.3479 tr_acc 0.8802 | va_acc 0.8849 | elapsed_fold 33.6m


Fold 4 done | best_va_acc 0.8849 | ckpt b3_fold4_best.pth | fold_time 34.0m


===== Fold 5/5 | train 14977 | valid 3744 =====


cfg: {'input_size': (3, 512, 512), 'crop_pct': 0.904, 'interpolation': 'bicubic', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225)}


  it 50/624 | loss 2.0106 | acc 0.2950 | elapsed 13.2s


  it 100/624 | loss 1.7256 | acc 0.4092 | elapsed 25.7s


  it 150/624 | loss 1.5639 | acc 0.4647 | elapsed 38.1s


  it 200/624 | loss 1.4584 | acc 0.5004 | elapsed 50.6s


  it 250/624 | loss 1.3683 | acc 0.5280 | elapsed 63.0s


  it 300/624 | loss 1.3108 | acc 0.5479 | elapsed 75.5s


  it 350/624 | loss 1.2648 | acc 0.5615 | elapsed 87.9s


  it 400/624 | loss 1.2106 | acc 0.5795 | elapsed 100.4s


  it 450/624 | loss 1.1804 | acc 0.5887 | elapsed 112.8s


  it 500/624 | loss 1.1504 | acc 0.5978 | elapsed 125.3s


  it 550/624 | loss 1.1228 | acc 0.6060 | elapsed 137.7s


  it 600/624 | loss 1.0975 | acc 0.6131 | elapsed 150.5s


Fold 5 | Epoch 1/12 | tr_loss 1.0898 tr_acc 0.6159 | va_acc 0.1293 | elapsed_fold 2.8m


  it 50/624 | loss 0.7925 | acc 0.7158 | elapsed 13.1s


  it 100/624 | loss 0.7742 | acc 0.7175 | elapsed 25.5s


  it 150/624 | loss 0.8056 | acc 0.7064 | elapsed 38.0s


  it 200/624 | loss 0.7903 | acc 0.7067 | elapsed 50.6s


  it 250/624 | loss 0.7836 | acc 0.7123 | elapsed 63.1s


  it 300/624 | loss 0.7808 | acc 0.7071 | elapsed 75.7s


  it 350/624 | loss 0.7808 | acc 0.7023 | elapsed 88.3s


  it 400/624 | loss 0.7653 | acc 0.7084 | elapsed 100.8s


  it 450/624 | loss 0.7568 | acc 0.7122 | elapsed 113.4s


  it 500/624 | loss 0.7578 | acc 0.7120 | elapsed 126.0s


  it 550/624 | loss 0.7561 | acc 0.7123 | elapsed 138.5s


  it 600/624 | loss 0.7524 | acc 0.7154 | elapsed 151.4s


Fold 5 | Epoch 2/12 | tr_loss 0.7508 tr_acc 0.7133 | va_acc 0.1795 | elapsed_fold 5.6m


  it 50/624 | loss 0.7008 | acc 0.7508 | elapsed 13.3s


  it 100/624 | loss 0.7057 | acc 0.7338 | elapsed 25.9s


  it 150/624 | loss 0.6959 | acc 0.7408 | elapsed 38.4s


  it 200/624 | loss 0.6752 | acc 0.7529 | elapsed 50.9s


  it 250/624 | loss 0.6703 | acc 0.7540 | elapsed 63.4s


  it 300/624 | loss 0.6706 | acc 0.7514 | elapsed 75.9s


  it 350/624 | loss 0.6759 | acc 0.7469 | elapsed 88.4s


  it 400/624 | loss 0.6835 | acc 0.7455 | elapsed 100.9s


  it 450/624 | loss 0.6846 | acc 0.7444 | elapsed 113.4s


  it 500/624 | loss 0.6867 | acc 0.7424 | elapsed 125.9s


  it 550/624 | loss 0.6883 | acc 0.7433 | elapsed 138.4s


  it 600/624 | loss 0.6844 | acc 0.7449 | elapsed 151.2s


Fold 5 | Epoch 3/12 | tr_loss 0.6821 tr_acc 0.7442 | va_acc 0.2861 | elapsed_fold 8.4m


  it 50/624 | loss 0.6143 | acc 0.7642 | elapsed 13.1s


  it 100/624 | loss 0.6192 | acc 0.7562 | elapsed 25.5s


  it 150/624 | loss 0.6483 | acc 0.7408 | elapsed 38.0s


  it 200/624 | loss 0.6359 | acc 0.7510 | elapsed 50.5s


  it 250/624 | loss 0.6435 | acc 0.7540 | elapsed 63.0s


  it 300/624 | loss 0.6378 | acc 0.7536 | elapsed 75.4s


  it 350/624 | loss 0.6309 | acc 0.7546 | elapsed 87.9s


  it 400/624 | loss 0.6193 | acc 0.7599 | elapsed 100.4s


  it 450/624 | loss 0.6258 | acc 0.7590 | elapsed 112.9s


  it 500/624 | loss 0.6251 | acc 0.7612 | elapsed 125.3s


  it 550/624 | loss 0.6246 | acc 0.7626 | elapsed 137.8s


  it 600/624 | loss 0.6302 | acc 0.7567 | elapsed 150.6s


Fold 5 | Epoch 4/12 | tr_loss 0.6308 tr_acc 0.7553 | va_acc 0.4330 | elapsed_fold 11.2m


  it 50/624 | loss 0.6537 | acc 0.7150 | elapsed 13.2s


  it 100/624 | loss 0.6489 | acc 0.7292 | elapsed 25.7s


  it 150/624 | loss 0.6477 | acc 0.7253 | elapsed 38.2s


  it 200/624 | loss 0.6302 | acc 0.7383 | elapsed 50.7s


  it 250/624 | loss 0.6137 | acc 0.7442 | elapsed 63.1s


  it 300/624 | loss 0.6002 | acc 0.7526 | elapsed 75.6s


  it 350/624 | loss 0.6033 | acc 0.7531 | elapsed 88.0s


  it 400/624 | loss 0.6060 | acc 0.7540 | elapsed 100.4s


  it 450/624 | loss 0.5980 | acc 0.7553 | elapsed 112.9s


  it 500/624 | loss 0.6077 | acc 0.7528 | elapsed 125.3s


  it 550/624 | loss 0.6032 | acc 0.7542 | elapsed 137.7s


  it 600/624 | loss 0.6044 | acc 0.7556 | elapsed 150.2s


Fold 5 | Epoch 5/12 | tr_loss 0.6063 tr_acc 0.7566 | va_acc 0.6138 | elapsed_fold 14.0m


  it 50/624 | loss 0.6366 | acc 0.7475 | elapsed 13.2s


  it 100/624 | loss 0.6072 | acc 0.7600 | elapsed 25.7s


  it 150/624 | loss 0.6009 | acc 0.7583 | elapsed 38.2s


  it 200/624 | loss 0.5872 | acc 0.7704 | elapsed 50.8s


  it 250/624 | loss 0.5883 | acc 0.7683 | elapsed 63.3s


  it 300/624 | loss 0.5864 | acc 0.7690 | elapsed 75.7s


  it 350/624 | loss 0.5850 | acc 0.7700 | elapsed 88.2s


  it 400/624 | loss 0.5852 | acc 0.7731 | elapsed 100.7s


  it 450/624 | loss 0.5827 | acc 0.7797 | elapsed 113.2s


  it 500/624 | loss 0.5810 | acc 0.7795 | elapsed 125.7s


  it 550/624 | loss 0.5813 | acc 0.7773 | elapsed 138.2s


  it 600/624 | loss 0.5844 | acc 0.7760 | elapsed 150.7s


Fold 5 | Epoch 6/12 | tr_loss 0.5838 tr_acc 0.7751 | va_acc 0.7433 | elapsed_fold 16.8m


  it 50/624 | loss 0.5294 | acc 0.7742 | elapsed 13.5s


  it 100/624 | loss 0.5357 | acc 0.7875 | elapsed 25.9s


  it 150/624 | loss 0.5407 | acc 0.7847 | elapsed 38.4s


  it 200/624 | loss 0.5543 | acc 0.7840 | elapsed 50.9s


  it 250/624 | loss 0.5563 | acc 0.7860 | elapsed 63.4s


  it 300/624 | loss 0.5612 | acc 0.7822 | elapsed 75.9s


  it 350/624 | loss 0.5659 | acc 0.7804 | elapsed 88.5s


  it 400/624 | loss 0.5664 | acc 0.7831 | elapsed 101.0s


  it 450/624 | loss 0.5648 | acc 0.7807 | elapsed 113.6s


  it 500/624 | loss 0.5689 | acc 0.7778 | elapsed 126.1s


  it 550/624 | loss 0.5658 | acc 0.7772 | elapsed 138.6s


  it 600/624 | loss 0.5660 | acc 0.7775 | elapsed 151.1s


Fold 5 | Epoch 7/12 | tr_loss 0.5657 tr_acc 0.7743 | va_acc 0.8106 | elapsed_fold 19.6m


  it 50/624 | loss 0.5771 | acc 0.7575 | elapsed 13.4s


  it 100/624 | loss 0.5759 | acc 0.7754 | elapsed 25.9s


  it 150/624 | loss 0.5604 | acc 0.7850 | elapsed 38.4s


  it 200/624 | loss 0.5563 | acc 0.7848 | elapsed 50.9s


  it 250/624 | loss 0.5514 | acc 0.7922 | elapsed 63.4s


  it 300/624 | loss 0.5574 | acc 0.7899 | elapsed 75.9s


  it 350/624 | loss 0.5582 | acc 0.7902 | elapsed 88.4s


  it 400/624 | loss 0.5572 | acc 0.7941 | elapsed 100.8s


  it 450/624 | loss 0.5568 | acc 0.7930 | elapsed 113.3s


  it 500/624 | loss 0.5545 | acc 0.7919 | elapsed 125.7s


  it 550/624 | loss 0.5569 | acc 0.7914 | elapsed 138.2s


  it 600/624 | loss 0.5572 | acc 0.7910 | elapsed 150.7s


Fold 5 | Epoch 8/12 | tr_loss 0.5548 tr_acc 0.7919 | va_acc 0.8438 | elapsed_fold 22.4m


  it 50/624 | loss 0.5292 | acc 0.7742 | elapsed 13.3s


  it 100/624 | loss 0.5345 | acc 0.7688 | elapsed 26.0s


  it 150/624 | loss 0.5322 | acc 0.7789 | elapsed 38.5s


  it 200/624 | loss 0.5354 | acc 0.7783 | elapsed 51.0s


  it 250/624 | loss 0.5470 | acc 0.7757 | elapsed 63.4s


  it 300/624 | loss 0.5412 | acc 0.7804 | elapsed 75.8s


  it 350/624 | loss 0.5529 | acc 0.7812 | elapsed 88.3s


  it 400/624 | loss 0.5472 | acc 0.7850 | elapsed 100.7s


  it 450/624 | loss 0.5442 | acc 0.7874 | elapsed 113.1s


  it 500/624 | loss 0.5425 | acc 0.7877 | elapsed 125.6s


  it 550/624 | loss 0.5470 | acc 0.7837 | elapsed 138.0s


  it 600/624 | loss 0.5459 | acc 0.7848 | elapsed 150.4s


Fold 5 | Epoch 9/12 | tr_loss 0.5458 tr_acc 0.7845 | va_acc 0.8640 | elapsed_fold 25.2m


  it 50/624 | loss 0.5422 | acc 0.7958 | elapsed 13.3s


  it 100/624 | loss 0.5398 | acc 0.7887 | elapsed 25.7s


  it 150/624 | loss 0.5222 | acc 0.7861 | elapsed 38.5s


  it 200/624 | loss 0.5153 | acc 0.7867 | elapsed 51.0s


  it 250/624 | loss 0.5169 | acc 0.7800 | elapsed 63.5s


  it 300/624 | loss 0.5063 | acc 0.7814 | elapsed 75.9s


  it 350/624 | loss 0.5092 | acc 0.7804 | elapsed 88.4s


  it 400/624 | loss 0.5065 | acc 0.7809 | elapsed 101.0s


  it 450/624 | loss 0.5047 | acc 0.7800 | elapsed 113.5s


  it 500/624 | loss 0.5109 | acc 0.7760 | elapsed 126.1s


  it 550/624 | loss 0.5143 | acc 0.7772 | elapsed 138.6s


  it 600/624 | loss 0.5214 | acc 0.7763 | elapsed 151.2s


Fold 5 | Epoch 10/12 | tr_loss 0.5245 tr_acc 0.7758 | va_acc 0.8761 | elapsed_fold 28.0m


  it 50/624 | loss 0.3224 | acc 0.8900 | elapsed 13.2s


  it 100/624 | loss 0.3365 | acc 0.8812 | elapsed 25.7s


  it 150/624 | loss 0.3368 | acc 0.8808 | elapsed 38.5s


  it 200/624 | loss 0.3434 | acc 0.8794 | elapsed 51.0s


  it 250/624 | loss 0.3534 | acc 0.8760 | elapsed 63.4s


  it 300/624 | loss 0.3539 | acc 0.8765 | elapsed 75.9s


  it 350/624 | loss 0.3538 | acc 0.8767 | elapsed 88.3s


  it 400/624 | loss 0.3518 | acc 0.8774 | elapsed 100.7s


  it 450/624 | loss 0.3544 | acc 0.8763 | elapsed 113.1s


  it 500/624 | loss 0.3527 | acc 0.8772 | elapsed 125.5s


  it 550/624 | loss 0.3492 | acc 0.8789 | elapsed 137.9s


  it 600/624 | loss 0.3475 | acc 0.8792 | elapsed 150.3s


Fold 5 | Epoch 11/12 | tr_loss 0.3484 tr_acc 0.8787 | va_acc 0.8830 | elapsed_fold 30.8m


  it 50/624 | loss 0.3678 | acc 0.8708 | elapsed 13.1s


  it 100/624 | loss 0.3756 | acc 0.8704 | elapsed 25.6s


  it 150/624 | loss 0.3775 | acc 0.8733 | elapsed 38.3s


  it 200/624 | loss 0.3705 | acc 0.8756 | elapsed 50.7s


  it 250/624 | loss 0.3639 | acc 0.8763 | elapsed 63.2s


  it 300/624 | loss 0.3655 | acc 0.8746 | elapsed 75.6s


  it 350/624 | loss 0.3626 | acc 0.8762 | elapsed 88.0s


  it 400/624 | loss 0.3555 | acc 0.8780 | elapsed 100.5s


  it 450/624 | loss 0.3572 | acc 0.8764 | elapsed 112.9s


  it 500/624 | loss 0.3601 | acc 0.8752 | elapsed 125.3s


  it 550/624 | loss 0.3588 | acc 0.8758 | elapsed 137.8s


  it 600/624 | loss 0.3569 | acc 0.8757 | elapsed 150.2s


Fold 5 | Epoch 12/12 | tr_loss 0.3561 tr_acc 0.8757 | va_acc 0.8865 | elapsed_fold 33.6m


Fold 5 done | best_va_acc 0.8865 | ckpt b3_fold5_best.pth | fold_time 34.0m


All folds done in 169.8m
OOF accuracy: 0.88617
Wrote submission_b3.csv


In [23]:
# Full 5-fold training: convnext_base.fb_in22k_ft_in1k @512 with Mixup/CutMix + EMA + cosine schedule (12 epochs)
import os, gc, math, time, warnings, numpy as np, pandas as pd, torch, timm
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageOps
from timm.utils import ModelEmaV3
from timm.data import Mixup
from timm.loss import SoftTargetCrossEntropy

warnings.filterwarnings('ignore', category=UserWarning)
torch.backends.cudnn.benchmark = True
try:
    torch.set_float32_matmul_precision('high')
except Exception:
    pass

# Hyperparams per expert advice
MODEL_NAME_CNX = 'convnext_base.fb_in22k_ft_in1k'
IMG_SIZE = 512
NUM_CLASSES = 5
FOLDS = 5
EPOCHS = 12
BATCH_SIZE = 20  # convnext is heavier; adjust if OOM
BASE_LR = 8e-4
WD = 0.05
EMA_DECAY = 0.9998
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)
NUM_WORKERS = min(8, os.cpu_count() or 2)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class CassavaDSConv(Dataset):
    def __init__(self, df, img_dir, transforms=None, is_test=False):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.transforms = transforms
        self.is_test = is_test
    def __len__(self):
        return len(self.df)
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        path = os.path.join(self.img_dir, row['image_id'])
        with Image.open(path) as im:
            im = ImageOps.exif_transpose(im).convert('RGB')
        img = im if self.transforms is None else self.transforms(im)
        if self.is_test:
            return img, row['image_id']
        return img, torch.tensor(int(row['label']), dtype=torch.long)

def get_model_and_transforms_convnext():
    model = timm.create_model(MODEL_NAME_CNX, pretrained=True, num_classes=NUM_CLASSES).to(device)
    cfg = timm.data.resolve_data_config({}, model=model)
    cfg['input_size'] = (3, IMG_SIZE, IMG_SIZE)
    cfg['mean'] = IMAGENET_MEAN
    cfg['std'] = IMAGENET_STD
    train_tfms = timm.data.create_transform(is_training=True, **cfg, auto_augment='rand-m9-mstd0.5-inc1', re_prob=0.25, re_mode='pixel')
    valid_tfms = timm.data.create_transform(is_training=False, **cfg)
    print('CNX cfg:', {'input_size': cfg['input_size'], 'crop_pct': float(cfg.get('crop_pct', 0.875)), 'interpolation': cfg.get('interpolation','bicubic'), 'mean': cfg['mean'], 'std': cfg['std']})
    return model, train_tfms, valid_tfms

def validate(model_eval, loader):
    model_eval.eval()
    correct, total = 0, 0
    logits_all = []
    with torch.no_grad():
        for imgs, labels in loader:
            imgs = imgs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            with torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
                logits = model_eval(imgs)
            logits_all.append(logits.float().cpu().numpy())
            preds = logits.argmax(1)
            correct += (preds == labels).sum().item()
            total += imgs.size(0)
    return correct / max(1, total), np.concatenate(logits_all, axis=0)

def infer_test(model_eval, loader):
    model_eval.eval()
    logits_all = []
    with torch.no_grad():
        for imgs, _ids in loader:
            imgs = imgs.to(device, non_blocking=True)
            with torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
                logits = model_eval(imgs)
            logits_all.append(logits.float().cpu().numpy())
    return np.concatenate(logits_all, axis=0)

def run_training_convnext():
    skf = StratifiedKFold(n_splits=FOLDS, shuffle=True, random_state=SEED)
    df_shuf = df.sample(frac=1.0, random_state=SEED).reset_index(drop=True)
    test_df = pd.DataFrame({'image_id': sorted([p.name for p in test_dir.glob('*.jpg')])})
    oof_logits = np.zeros((len(df_shuf), NUM_CLASSES), dtype=np.float32)
    test_logits_sum = np.zeros((len(test_df), NUM_CLASSES), dtype=np.float32)

    fold = 0
    start_all = time.time()
    for tr_idx, va_idx in skf.split(df_shuf['image_id'], df_shuf['label']):
        fold += 1
        t_fold = time.time()
        print(f'===== ConvNeXt Fold {fold}/{FOLDS} | train {len(tr_idx)} | valid {len(va_idx)} =====', flush=True)
        df_tr = df_shuf.iloc[tr_idx].reset_index(drop=True)
        df_va = df_shuf.iloc[va_idx].reset_index(drop=True)

        model, train_tfms, valid_tfms = get_model_and_transforms_convnext()
        ds_tr = CassavaDSConv(df_tr, str(train_dir), transforms=train_tfms)
        ds_va = CassavaDSConv(df_va, str(train_dir), transforms=valid_tfms)
        dl_tr = DataLoader(ds_tr, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True, persistent_workers=False)
        dl_va = DataLoader(ds_va, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True, persistent_workers=False)
        test_ds = CassavaDSConv(test_df, str(test_dir), transforms=valid_tfms, is_test=True)
        test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True, persistent_workers=False)

        mixup_fn = Mixup(mixup_alpha=0.4, cutmix_alpha=1.0, prob=0.5, switch_prob=0.5, mode='batch', label_smoothing=0.0, num_classes=NUM_CLASSES)
        criterion = SoftTargetCrossEntropy()
        optimizer = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=WD)
        scaler = torch.amp.GradScaler('cuda', enabled=torch.cuda.is_available())
        ema = ModelEmaV3(model, decay=EMA_DECAY)

        total_steps = EPOCHS * len(dl_tr)
        warmup_steps = len(dl_tr)  # 1 epoch warmup
        global_step = 0
        best_acc = 0.0
        best_state = None

        for epoch in range(1, EPOCHS+1):
            model.train()
            running_loss = 0.0; correct = 0; total = 0; t0 = time.time()
            use_mix = epoch <= EPOCHS - 2
            for it, (imgs, labels) in enumerate(dl_tr):
                step = global_step + it
                if step < warmup_steps:
                    lr_scale = step / max(1, warmup_steps)
                else:
                    progress = (step - warmup_steps) / max(1, (total_steps - warmup_steps))
                    lr_scale = 0.5 * (1 + math.cos(math.pi * progress))
                for pg in optimizer.param_groups:
                    pg['lr'] = BASE_LR * lr_scale

                imgs = imgs.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)
                if use_mix and mixup_fn is not None:
                    imgs, targets = mixup_fn(imgs, labels)
                else:
                    targets = torch.zeros((labels.size(0), NUM_CLASSES), device=labels.device)
                    targets.scatter_(1, labels.unsqueeze(1), 1.0)

                optimizer.zero_grad(set_to_none=True)
                with torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
                    logits = model(imgs)
                    loss = criterion(logits, targets)
                scaler.scale(loss).backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                if ema is not None:
                    ema.update(model)

                running_loss += loss.item() * imgs.size(0)
                with torch.no_grad():
                    preds = logits.argmax(1)
                    correct += (preds == labels).sum().item()
                    total += imgs.size(0)
                if (it+1) % 50 == 0:
                    print(f'  it {it+1}/{len(dl_tr)} | loss {running_loss/max(1,total):.4f} | acc {correct/max(1,total):.4f} | elapsed {time.time()-t0:.1f}s', flush=True)

            global_step += len(dl_tr)
            eval_model = ema.module if ema is not None else model
            va_acc, _ = validate(eval_model, dl_va)
            print(f'ConvNeXt Fold {fold} | Epoch {epoch}/{EPOCHS} | tr_loss {running_loss/max(1,total):.4f} tr_acc {correct/max(1,total):.4f} | va_acc {va_acc:.4f} | elapsed_fold {(time.time()-t_fold)/60:.1f}m', flush=True)
            if va_acc > best_acc:
                best_acc = va_acc
                best_state = {k: v.cpu() for k, v in eval_model.state_dict().items()}

        if best_state is not None:
            best_model = timm.create_model(MODEL_NAME_CNX, pretrained=False, num_classes=NUM_CLASSES).to(device)
            best_model.load_state_dict(best_state, strict=True)
        else:
            best_model = ema.module if ema is not None else model

        va_acc, va_logits = validate(best_model, dl_va)
        oof_logits[va_idx] = va_logits
        t_logits = infer_test(best_model, test_dl)
        test_logits_sum += t_logits

        ckpt_path = f'convnext_fold{fold}_best.pth'
        torch.save({'state_dict': best_state, 'best_va_acc': best_acc}, ckpt_path)
        print(f'ConvNeXt Fold {fold} done | best_va_acc {best_acc:.4f} | ckpt {ckpt_path} | fold_time {(time.time()-t_fold)/60:.1f}m', flush=True)

        del model, ema, optimizer, scaler, ds_tr, ds_va, dl_tr, dl_va, best_model, test_ds, test_dl
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f'ConvNeXt all folds done in {(time.time()-start_all)/60:.1f}m')
    oof_pred = oof_logits.argmax(1)
    oof_acc = (oof_pred == df_shuf['label'].values).mean()
    print(f'ConvNeXt OOF accuracy: {oof_acc:.5f}')
    np.save('oof_logits_convnext.npy', oof_logits)
    np.save('test_logits_convnext.npy', test_logits_sum)
    test_pred = test_logits_sum.argmax(1)
    sub = pd.DataFrame({'image_id': test_df['image_id'], 'label': test_pred.astype(int)})
    sub.to_csv('submission_convnext.csv', index=False)
    print('Wrote submission_convnext.csv')

# To run after B3 completes: run_training_convnext()

In [27]:
# TTA inference and ensembling helpers (B3 4x TTA; optional blend later)
import os, numpy as np, pandas as pd, torch, timm, time
from pathlib import Path
from PIL import Image, ImageOps
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

class TestTTADS(Dataset):
    def __init__(self, img_dir, files, tfms):
        self.img_dir = Path(img_dir)
        self.files = files
        self.tfms = tfms
    def __len__(self): return len(self.files)
    def __getitem__(self, idx):
        fname = self.files[idx]
        with Image.open(self.img_dir / fname) as im:
            im = ImageOps.exif_transpose(im).convert('RGB')
        return im, fname

def collate_pil(batch):
    # Collate a batch of (PIL.Image, filename) into lists to avoid default_collate errors
    imgs = [b[0] for b in batch]
    names = [b[1] for b in batch]
    return imgs, names

def build_eval_tfms_for(model_name='tf_efficientnet_b3_ns', img_size=512):
    model = timm.create_model(model_name, pretrained=False, num_classes=5).to(device)
    cfg = timm.data.resolve_data_config({}, model=model)
    cfg['input_size'] = (3, img_size, img_size)
    cfg['mean'] = IMAGENET_MEAN
    cfg['std']  = IMAGENET_STD
    tfms = timm.data.create_transform(is_training=False, **cfg)
    return tfms, cfg

@torch.no_grad()
def infer_tta_b3(checkpoints, img_dir='test_images', batch_size=32, img_size=512):
    files = sorted([p.name for p in Path(img_dir).glob('*.jpg')])
    tfms, cfg = build_eval_tfms_for('tf_efficientnet_b3_ns', img_size=img_size)
    ds = TestTTADS(img_dir, files, tfms)
    dl = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,  # avoid MP issues with PIL + custom collate
        pin_memory=True,
        collate_fn=collate_pil
    )

    def apply_flip(img, flip_id):
        # 0: none, 1: hflip, 2: vflip, 3: hvflip
        if flip_id == 1: img = F.hflip(img)
        elif flip_id == 2: img = F.vflip(img)
        elif flip_id == 3: img = F.vflip(F.hflip(img))
        return img

    all_logits = np.zeros((len(files), 5), dtype=np.float32)
    for ckpt_path in checkpoints:
        if not Path(ckpt_path).exists():
            print(f'Skip missing checkpoint {ckpt_path}')
            continue
        model = timm.create_model('tf_efficientnet_b3_ns', pretrained=False, num_classes=5).to(device)
        sd = torch.load(ckpt_path, map_location='cpu')['state_dict']
        model.load_state_dict(sd, strict=True)
        model.eval()
        fold_logits = np.zeros_like(all_logits)
        ptr = 0
        t0 = time.time()
        for imgs_pil, names in dl:
            bs = len(imgs_pil)
            logits_sum = torch.zeros((bs, 5), device=device)
            for f in range(4):
                batch = [apply_flip(img, f) for img in imgs_pil]
                batch = torch.stack([ds.tfms(b) for b in batch]).to(device, non_blocking=True)
                with torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
                    logits = model(batch)
                logits_sum += logits.float()
            logits_avg = (logits_sum / 4.0).cpu().numpy()
            fold_logits[ptr:ptr+bs] = logits_avg
            ptr += bs
        all_logits += fold_logits
        print(f'Inferred {ckpt_path} in {time.time()-t0:.1f}s')
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    all_logits /= max(1, len(checkpoints))
    return files, all_logits

# Generic TTA for any timm model (e.g., convnext_base)
@torch.no_grad()
def infer_tta_model(model_name, checkpoints, img_dir='test_images', batch_size=32, img_size=512):
    files = sorted([p.name for p in Path(img_dir).glob('*.jpg')])
    tfms, cfg = build_eval_tfms_for(model_name, img_size=img_size)
    ds = TestTTADS(img_dir, files, tfms)
    dl = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        collate_fn=collate_pil
    )

    def apply_flip(img, flip_id):
        if flip_id == 1: img = F.hflip(img)
        elif flip_id == 2: img = F.vflip(img)
        elif flip_id == 3: img = F.vflip(F.hflip(img))
        return img

    all_logits = np.zeros((len(files), 5), dtype=np.float32)
    for ckpt_path in checkpoints:
        if not Path(ckpt_path).exists():
            print(f'Skip missing checkpoint {ckpt_path}')
            continue
        model = timm.create_model(model_name, pretrained=False, num_classes=5).to(device)
        sd = torch.load(ckpt_path, map_location='cpu')['state_dict']
        model.load_state_dict(sd, strict=True)
        model.eval()
        fold_logits = np.zeros_like(all_logits)
        ptr = 0
        t0 = time.time()
        for imgs_pil, names in dl:
            bs = len(imgs_pil)
            logits_sum = torch.zeros((bs, 5), device=device)
            for f in range(4):
                batch = [apply_flip(img, f) for img in imgs_pil]
                batch = torch.stack([ds.tfms(b) for b in batch]).to(device, non_blocking=True)
                with torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
                    logits = model(batch)
                logits_sum += logits.float()
            logits_avg = (logits_sum / 4.0).cpu().numpy()
            fold_logits[ptr:ptr+bs] = logits_avg
            ptr += bs
        all_logits += fold_logits
        print(f'Inferred {ckpt_path} in {time.time()-t0:.1f}s')
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    all_logits /= max(1, len(checkpoints))
    return files, all_logits

def write_submission_from_logits(files, logits, out_csv):
    pred = logits.argmax(1)
    sub = pd.DataFrame({'image_id': files, 'label': pred.astype(int)})
    sub.to_csv(out_csv, index=False)
    print(f'Wrote {out_csv}')

# If B3 folds finished, you can run:
# files, logits_b3_tta = infer_tta_b3([f'b3_fold{i}_best.pth' for i in range(1,6)], img_size=512, batch_size=24)
# np.save('test_logits_b3_tta.npy', logits_b3_tta)
# write_submission_from_logits(files, logits_b3_tta, 'submission_b3_tta.csv')

# Optional simple blend if both model logits exist:
def try_blend_and_write():
    test_df = pd.DataFrame({'image_id': sorted([p.name for p in Path('test_images').glob('*.jpg')])})
    paths = {'b3': 'test_logits_b3.npy', 'cnx': 'test_logits_convnext.npy', 'b3tta': 'test_logits_b3_tta.npy'}
    logits_list = []
    names = []
    for name, p in paths.items():
        if Path(p).exists():
            logits = np.load(p)
            logits_list.append(logits)
            names.append(name)
    if len(logits_list) == 0:
        print('No logits found to blend yet.')
        return
    # Equal-weight blend
    blend = np.mean(np.stack(logits_list, axis=0), axis=0)
    pred = blend.argmax(1)
    sub = pd.DataFrame({'image_id': test_df['image_id'], 'label': pred.astype(int)})
    sub.to_csv('submission_ensemble.csv', index=False)
    print(f'Wrote submission_ensemble.csv from components: {names}')

# After both runs, call try_blend_and_write() to produce ensemble submission.

In [22]:
# Run 4-flip TTA for B3 folds and write submission
checkpoints = [f'b3_fold{i}_best.pth' for i in range(1,6)]
print('Running B3 4x TTA on checkpoints:', checkpoints, flush=True)
files, logits_b3_tta = infer_tta_b3(checkpoints, img_dir='test_images', batch_size=24, img_size=512)
np.save('test_logits_b3_tta.npy', logits_b3_tta)
write_submission_from_logits(files, logits_b3_tta, 'submission_b3_tta.csv')
print('B3 TTA done. logits shape:', logits_b3_tta.shape)

Running B3 4x TTA on checkpoints: ['b3_fold1_best.pth', 'b3_fold2_best.pth', 'b3_fold3_best.pth', 'b3_fold4_best.pth', 'b3_fold5_best.pth']


Inferred b3_fold1_best.pth in 111.1s


Inferred b3_fold2_best.pth in 106.1s


Inferred b3_fold3_best.pth in 106.3s


Inferred b3_fold4_best.pth in 105.6s


Inferred b3_fold5_best.pth in 106.6s
Wrote submission_b3_tta.csv
B3 TTA done. logits shape: (2676, 5)


In [24]:
# Safety baseline submit + start ConvNeXt training
import os, shutil
if os.path.exists('submission_b3_tta.csv'):
    shutil.copyfile('submission_b3_tta.csv', 'submission.csv')
    print('Copied submission_b3_tta.csv -> submission.csv')
else:
    print('WARNING: submission_b3_tta.csv not found; baseline not copied')

print('Starting ConvNeXt-Base 5-fold training...')
run_training_convnext()

Copied submission_b3_tta.csv -> submission.csv
Starting ConvNeXt-Base 5-fold training...
===== ConvNeXt Fold 1/5 | train 14976 | valid 3745 =====


CNX cfg: {'input_size': (3, 512, 512), 'crop_pct': 0.875, 'interpolation': 'bicubic', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225)}


  it 50/748 | loss 1.1595 | acc 0.5960 | elapsed 47.7s


  it 100/748 | loss 1.0482 | acc 0.6245 | elapsed 68.7s


  it 150/748 | loss 0.9415 | acc 0.6573 | elapsed 89.7s


  it 200/748 | loss 0.8997 | acc 0.6720 | elapsed 110.8s


  it 250/748 | loss 0.8665 | acc 0.6816 | elapsed 131.9s


  it 300/748 | loss 0.8500 | acc 0.6880 | elapsed 153.0s


  it 350/748 | loss 0.8440 | acc 0.6900 | elapsed 174.6s


  it 400/748 | loss 0.8296 | acc 0.6954 | elapsed 195.9s


  it 450/748 | loss 0.8085 | acc 0.7003 | elapsed 217.3s


  it 500/748 | loss 0.8106 | acc 0.7005 | elapsed 238.7s


  it 550/748 | loss 0.8000 | acc 0.7058 | elapsed 260.2s


  it 600/748 | loss 0.7944 | acc 0.7079 | elapsed 281.7s


  it 650/748 | loss 0.7959 | acc 0.7078 | elapsed 303.2s


  it 700/748 | loss 0.7902 | acc 0.7104 | elapsed 324.8s


ConvNeXt Fold 1 | Epoch 1/12 | tr_loss 0.7892 tr_acc 0.7109 | va_acc 0.6665 | elapsed_fold 6.3m


  it 50/748 | loss 0.7710 | acc 0.6560 | elapsed 22.0s


  it 100/748 | loss 0.7677 | acc 0.6700 | elapsed 43.3s


  it 150/748 | loss 0.7708 | acc 0.6767 | elapsed 64.6s


  it 200/748 | loss 0.7553 | acc 0.6973 | elapsed 86.0s


  it 250/748 | loss 0.7585 | acc 0.7020 | elapsed 107.3s


  it 300/748 | loss 0.7563 | acc 0.7063 | elapsed 128.6s


  it 350/748 | loss 0.7501 | acc 0.7101 | elapsed 150.3s


  it 400/748 | loss 0.7589 | acc 0.7076 | elapsed 171.6s


  it 450/748 | loss 0.7604 | acc 0.7073 | elapsed 193.0s


  it 500/748 | loss 0.7526 | acc 0.7134 | elapsed 214.4s


  it 550/748 | loss 0.7523 | acc 0.7145 | elapsed 235.9s


  it 600/748 | loss 0.7490 | acc 0.7117 | elapsed 257.4s


  it 650/748 | loss 0.7511 | acc 0.7132 | elapsed 278.9s


  it 700/748 | loss 0.7472 | acc 0.7146 | elapsed 300.4s


ConvNeXt Fold 1 | Epoch 2/12 | tr_loss 0.7485 tr_acc 0.7161 | va_acc 0.7773 | elapsed_fold 12.1m


  it 50/748 | loss 0.6844 | acc 0.7560 | elapsed 22.3s


  it 100/748 | loss 0.6986 | acc 0.7255 | elapsed 43.8s


  it 150/748 | loss 0.6845 | acc 0.7390 | elapsed 65.2s


  it 200/748 | loss 0.7073 | acc 0.7312 | elapsed 86.5s


  it 250/748 | loss 0.7079 | acc 0.7358 | elapsed 107.9s


  it 300/748 | loss 0.7115 | acc 0.7378 | elapsed 129.3s


  it 350/748 | loss 0.7007 | acc 0.7436 | elapsed 150.7s


  it 400/748 | loss 0.6998 | acc 0.7399 | elapsed 172.5s


  it 450/748 | loss 0.7028 | acc 0.7396 | elapsed 193.9s


  it 500/748 | loss 0.7060 | acc 0.7383 | elapsed 215.4s


  it 550/748 | loss 0.7103 | acc 0.7327 | elapsed 236.9s


  it 600/748 | loss 0.7102 | acc 0.7325 | elapsed 258.4s


  it 650/748 | loss 0.7095 | acc 0.7334 | elapsed 279.9s


  it 700/748 | loss 0.7061 | acc 0.7354 | elapsed 301.3s


ConvNeXt Fold 1 | Epoch 3/12 | tr_loss 0.6975 tr_acc 0.7374 | va_acc 0.8339 | elapsed_fold 17.9m


  it 50/748 | loss 0.6037 | acc 0.7430 | elapsed 22.3s


  it 100/748 | loss 0.6029 | acc 0.7625 | elapsed 43.8s


  it 150/748 | loss 0.5783 | acc 0.7723 | elapsed 65.4s


  it 200/748 | loss 0.5782 | acc 0.7758 | elapsed 86.9s


  it 250/748 | loss 0.5914 | acc 0.7736 | elapsed 108.4s


  it 300/748 | loss 0.6016 | acc 0.7690 | elapsed 129.9s


  it 350/748 | loss 0.5960 | acc 0.7713 | elapsed 151.4s


  it 400/748 | loss 0.5986 | acc 0.7678 | elapsed 173.1s


  it 450/748 | loss 0.5988 | acc 0.7682 | elapsed 194.6s


  it 500/748 | loss 0.6019 | acc 0.7670 | elapsed 216.0s


  it 550/748 | loss 0.6061 | acc 0.7666 | elapsed 237.5s


  it 600/748 | loss 0.6035 | acc 0.7689 | elapsed 258.9s


  it 650/748 | loss 0.6035 | acc 0.7682 | elapsed 280.2s


  it 700/748 | loss 0.6035 | acc 0.7673 | elapsed 301.6s


ConvNeXt Fold 1 | Epoch 4/12 | tr_loss 0.6048 tr_acc 0.7647 | va_acc 0.8569 | elapsed_fold 23.7m


  it 50/748 | loss 0.5633 | acc 0.7530 | elapsed 22.2s


  it 100/748 | loss 0.5892 | acc 0.7420 | elapsed 43.7s


  it 150/748 | loss 0.5840 | acc 0.7583 | elapsed 65.1s


  it 200/748 | loss 0.5914 | acc 0.7625 | elapsed 86.6s


  it 250/748 | loss 0.5901 | acc 0.7632 | elapsed 108.2s


  it 300/748 | loss 0.5889 | acc 0.7627 | elapsed 129.7s


  it 350/748 | loss 0.5880 | acc 0.7607 | elapsed 151.2s


  it 400/748 | loss 0.5839 | acc 0.7652 | elapsed 173.0s


  it 450/748 | loss 0.5868 | acc 0.7603 | elapsed 194.5s


  it 500/748 | loss 0.5838 | acc 0.7614 | elapsed 215.9s


  it 550/748 | loss 0.5892 | acc 0.7605 | elapsed 237.3s


  it 600/748 | loss 0.5908 | acc 0.7606 | elapsed 258.7s


  it 650/748 | loss 0.5837 | acc 0.7652 | elapsed 280.1s


  it 700/748 | loss 0.5836 | acc 0.7662 | elapsed 301.5s


ConvNeXt Fold 1 | Epoch 5/12 | tr_loss 0.5802 tr_acc 0.7665 | va_acc 0.8732 | elapsed_fold 29.6m


  it 50/748 | loss 0.6005 | acc 0.7390 | elapsed 22.1s


  it 100/748 | loss 0.6097 | acc 0.7535 | elapsed 43.5s


  it 150/748 | loss 0.6029 | acc 0.7533 | elapsed 64.8s


  it 200/748 | loss 0.5975 | acc 0.7552 | elapsed 86.2s


  it 250/748 | loss 0.5979 | acc 0.7570 | elapsed 107.6s


  it 300/748 | loss 0.5952 | acc 0.7610 | elapsed 129.0s


  it 350/748 | loss 0.5880 | acc 0.7640 | elapsed 150.5s


  it 400/748 | loss 0.5902 | acc 0.7640 | elapsed 172.2s


  it 450/748 | loss 0.5877 | acc 0.7650 | elapsed 193.7s


  it 500/748 | loss 0.5766 | acc 0.7692 | elapsed 215.3s


  it 550/748 | loss 0.5760 | acc 0.7677 | elapsed 236.8s


  it 600/748 | loss 0.5752 | acc 0.7685 | elapsed 258.3s


  it 650/748 | loss 0.5719 | acc 0.7735 | elapsed 279.8s


  it 700/748 | loss 0.5753 | acc 0.7725 | elapsed 301.3s


ConvNeXt Fold 1 | Epoch 6/12 | tr_loss 0.5719 tr_acc 0.7745 | va_acc 0.8820 | elapsed_fold 35.4m


  it 50/748 | loss 0.5199 | acc 0.7670 | elapsed 22.2s


  it 100/748 | loss 0.5114 | acc 0.7865 | elapsed 43.6s


  it 150/748 | loss 0.5103 | acc 0.8030 | elapsed 64.9s


  it 200/748 | loss 0.5220 | acc 0.7965 | elapsed 86.4s


  it 250/748 | loss 0.5129 | acc 0.7994 | elapsed 107.8s


  it 300/748 | loss 0.5255 | acc 0.7987 | elapsed 129.3s


  it 350/748 | loss 0.5303 | acc 0.7911 | elapsed 150.7s


  it 400/748 | loss 0.5404 | acc 0.7864 | elapsed 172.4s


  it 450/748 | loss 0.5449 | acc 0.7873 | elapsed 193.8s


  it 500/748 | loss 0.5501 | acc 0.7848 | elapsed 215.3s


  it 550/748 | loss 0.5487 | acc 0.7850 | elapsed 236.8s


  it 600/748 | loss 0.5502 | acc 0.7819 | elapsed 258.4s


  it 650/748 | loss 0.5463 | acc 0.7832 | elapsed 280.1s


  it 700/748 | loss 0.5426 | acc 0.7859 | elapsed 301.7s


ConvNeXt Fold 1 | Epoch 7/12 | tr_loss 0.5430 tr_acc 0.7860 | va_acc 0.8870 | elapsed_fold 41.2m


  it 50/748 | loss 0.5259 | acc 0.7800 | elapsed 22.3s


  it 100/748 | loss 0.5272 | acc 0.7880 | elapsed 43.8s


  it 150/748 | loss 0.5328 | acc 0.7867 | elapsed 65.2s


  it 200/748 | loss 0.5435 | acc 0.7933 | elapsed 86.7s


  it 250/748 | loss 0.5304 | acc 0.8004 | elapsed 108.1s


  it 300/748 | loss 0.5312 | acc 0.8012 | elapsed 129.5s


  it 350/748 | loss 0.5215 | acc 0.8029 | elapsed 150.9s


  it 400/748 | loss 0.5239 | acc 0.8050 | elapsed 172.6s


  it 450/748 | loss 0.5205 | acc 0.8063 | elapsed 193.9s


  it 500/748 | loss 0.5195 | acc 0.8058 | elapsed 215.3s


  it 550/748 | loss 0.5208 | acc 0.8048 | elapsed 236.7s


  it 600/748 | loss 0.5229 | acc 0.8007 | elapsed 258.2s


  it 650/748 | loss 0.5259 | acc 0.7978 | elapsed 279.6s


  it 700/748 | loss 0.5244 | acc 0.7974 | elapsed 301.1s


ConvNeXt Fold 1 | Epoch 8/12 | tr_loss 0.5240 tr_acc 0.7985 | va_acc 0.8900 | elapsed_fold 47.0m


  it 50/748 | loss 0.4731 | acc 0.7930 | elapsed 22.2s


  it 100/748 | loss 0.4944 | acc 0.7985 | elapsed 43.9s


  it 150/748 | loss 0.5039 | acc 0.8003 | elapsed 65.6s


  it 200/748 | loss 0.4956 | acc 0.8093 | elapsed 87.4s


  it 250/748 | loss 0.4873 | acc 0.8092 | elapsed 109.2s


  it 300/748 | loss 0.4867 | acc 0.8087 | elapsed 130.8s


  it 350/748 | loss 0.4911 | acc 0.8063 | elapsed 152.4s


  it 400/748 | loss 0.5002 | acc 0.8024 | elapsed 174.2s


  it 450/748 | loss 0.5025 | acc 0.7988 | elapsed 195.6s


  it 500/748 | loss 0.5030 | acc 0.7996 | elapsed 217.1s


  it 550/748 | loss 0.4985 | acc 0.8019 | elapsed 238.6s


  it 600/748 | loss 0.4958 | acc 0.8008 | elapsed 260.1s


  it 650/748 | loss 0.4934 | acc 0.8014 | elapsed 281.6s


  it 700/748 | loss 0.4923 | acc 0.8018 | elapsed 303.0s


ConvNeXt Fold 1 | Epoch 9/12 | tr_loss 0.4938 tr_acc 0.8037 | va_acc 0.8921 | elapsed_fold 52.9m


  it 50/748 | loss 0.4726 | acc 0.7850 | elapsed 22.2s


  it 100/748 | loss 0.5011 | acc 0.7895 | elapsed 43.7s


  it 150/748 | loss 0.5071 | acc 0.7953 | elapsed 65.1s


  it 200/748 | loss 0.4901 | acc 0.8047 | elapsed 86.6s


  it 250/748 | loss 0.4802 | acc 0.8086 | elapsed 108.2s


  it 300/748 | loss 0.4889 | acc 0.8063 | elapsed 129.7s


  it 350/748 | loss 0.4844 | acc 0.8070 | elapsed 151.3s


  it 400/748 | loss 0.4905 | acc 0.8084 | elapsed 173.1s


  it 450/748 | loss 0.4895 | acc 0.8069 | elapsed 194.6s


  it 500/748 | loss 0.4874 | acc 0.8061 | elapsed 216.1s


  it 550/748 | loss 0.4817 | acc 0.8081 | elapsed 237.5s


  it 600/748 | loss 0.4794 | acc 0.8095 | elapsed 259.0s


  it 650/748 | loss 0.4793 | acc 0.8103 | elapsed 280.4s


  it 700/748 | loss 0.4794 | acc 0.8090 | elapsed 301.8s


ConvNeXt Fold 1 | Epoch 10/12 | tr_loss 0.4767 tr_acc 0.8106 | va_acc 0.8953 | elapsed_fold 58.7m


  it 50/748 | loss 0.3197 | acc 0.8940 | elapsed 22.1s


  it 100/748 | loss 0.3021 | acc 0.8970 | elapsed 43.5s


  it 150/748 | loss 0.3081 | acc 0.8947 | elapsed 65.0s


  it 200/748 | loss 0.3069 | acc 0.8965 | elapsed 86.4s


  it 250/748 | loss 0.3037 | acc 0.8966 | elapsed 107.9s


  it 300/748 | loss 0.3132 | acc 0.8945 | elapsed 129.4s


  it 350/748 | loss 0.3151 | acc 0.8946 | elapsed 150.9s


  it 400/748 | loss 0.3144 | acc 0.8942 | elapsed 172.7s


  it 450/748 | loss 0.3225 | acc 0.8918 | elapsed 194.3s


  it 500/748 | loss 0.3219 | acc 0.8917 | elapsed 215.8s


  it 550/748 | loss 0.3221 | acc 0.8920 | elapsed 237.3s


  it 600/748 | loss 0.3207 | acc 0.8924 | elapsed 258.9s


  it 650/748 | loss 0.3216 | acc 0.8928 | elapsed 280.4s


  it 700/748 | loss 0.3201 | acc 0.8936 | elapsed 301.9s


ConvNeXt Fold 1 | Epoch 11/12 | tr_loss 0.3234 tr_acc 0.8926 | va_acc 0.8975 | elapsed_fold 64.5m


  it 50/748 | loss 0.3397 | acc 0.8870 | elapsed 22.3s


  it 100/748 | loss 0.3401 | acc 0.8860 | elapsed 43.6s


  it 150/748 | loss 0.3311 | acc 0.8897 | elapsed 65.0s


  it 200/748 | loss 0.3228 | acc 0.8915 | elapsed 86.3s


  it 250/748 | loss 0.3158 | acc 0.8928 | elapsed 107.7s


  it 300/748 | loss 0.3119 | acc 0.8940 | elapsed 129.0s


  it 350/748 | loss 0.3128 | acc 0.8934 | elapsed 150.4s


  it 400/748 | loss 0.3128 | acc 0.8936 | elapsed 171.8s


  it 450/748 | loss 0.3127 | acc 0.8930 | elapsed 193.5s


  it 500/748 | loss 0.3102 | acc 0.8936 | elapsed 214.9s


  it 550/748 | loss 0.3147 | acc 0.8930 | elapsed 236.4s


  it 600/748 | loss 0.3150 | acc 0.8932 | elapsed 257.9s


  it 650/748 | loss 0.3134 | acc 0.8940 | elapsed 279.4s


  it 700/748 | loss 0.3118 | acc 0.8939 | elapsed 300.9s


ConvNeXt Fold 1 | Epoch 12/12 | tr_loss 0.3128 tr_acc 0.8931 | va_acc 0.8977 | elapsed_fold 70.4m


ConvNeXt Fold 1 done | best_va_acc 0.8977 | ckpt convnext_fold1_best.pth | fold_time 71.3m


===== ConvNeXt Fold 2/5 | train 14977 | valid 3744 =====


CNX cfg: {'input_size': (3, 512, 512), 'crop_pct': 0.875, 'interpolation': 'bicubic', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225)}


  it 50/748 | loss 1.2792 | acc 0.5170 | elapsed 22.1s


  it 100/748 | loss 1.1132 | acc 0.5830 | elapsed 43.4s


  it 150/748 | loss 1.0282 | acc 0.6150 | elapsed 64.8s


  it 200/748 | loss 0.9641 | acc 0.6415 | elapsed 86.1s


  it 250/748 | loss 0.9243 | acc 0.6530 | elapsed 107.5s


  it 300/748 | loss 0.8925 | acc 0.6677 | elapsed 128.9s


  it 350/748 | loss 0.8798 | acc 0.6720 | elapsed 150.3s


  it 400/748 | loss 0.8677 | acc 0.6775 | elapsed 171.7s


  it 450/748 | loss 0.8606 | acc 0.6804 | elapsed 193.1s


  it 500/748 | loss 0.8467 | acc 0.6823 | elapsed 214.5s


  it 550/748 | loss 0.8408 | acc 0.6817 | elapsed 235.9s


  it 600/748 | loss 0.8331 | acc 0.6842 | elapsed 257.3s


  it 650/748 | loss 0.8234 | acc 0.6885 | elapsed 279.1s


  it 700/748 | loss 0.8188 | acc 0.6899 | elapsed 300.6s


ConvNeXt Fold 2 | Epoch 1/12 | tr_loss 0.8127 tr_acc 0.6936 | va_acc 0.4207 | elapsed_fold 5.9m


  it 50/748 | loss 0.7901 | acc 0.7220 | elapsed 22.1s


  it 100/748 | loss 0.7630 | acc 0.7220 | elapsed 43.6s


  it 150/748 | loss 0.7560 | acc 0.7213 | elapsed 65.1s


  it 200/748 | loss 0.7530 | acc 0.7198 | elapsed 86.6s


  it 250/748 | loss 0.7539 | acc 0.7206 | elapsed 108.1s


  it 300/748 | loss 0.7548 | acc 0.7195 | elapsed 129.5s


  it 350/748 | loss 0.7566 | acc 0.7140 | elapsed 150.9s


  it 400/748 | loss 0.7517 | acc 0.7161 | elapsed 172.3s


  it 450/748 | loss 0.7502 | acc 0.7180 | elapsed 193.7s


  it 500/748 | loss 0.7475 | acc 0.7182 | elapsed 215.1s


  it 550/748 | loss 0.7395 | acc 0.7217 | elapsed 236.5s


  it 600/748 | loss 0.7421 | acc 0.7208 | elapsed 257.8s


  it 650/748 | loss 0.7402 | acc 0.7202 | elapsed 279.1s


  it 700/748 | loss 0.7402 | acc 0.7197 | elapsed 300.7s


ConvNeXt Fold 2 | Epoch 2/12 | tr_loss 0.7456 tr_acc 0.7189 | va_acc 0.7465 | elapsed_fold 11.7m


  it 50/748 | loss 0.7089 | acc 0.7500 | elapsed 22.1s


  it 100/748 | loss 0.6962 | acc 0.7530 | elapsed 43.6s


  it 150/748 | loss 0.6964 | acc 0.7423 | elapsed 65.1s


  it 200/748 | loss 0.6956 | acc 0.7402 | elapsed 86.7s


  it 250/748 | loss 0.7058 | acc 0.7342 | elapsed 108.2s


  it 300/748 | loss 0.6973 | acc 0.7292 | elapsed 129.8s


  it 350/748 | loss 0.6990 | acc 0.7306 | elapsed 151.3s


  it 400/748 | loss 0.7062 | acc 0.7296 | elapsed 172.8s


  it 450/748 | loss 0.7057 | acc 0.7302 | elapsed 194.2s


  it 500/748 | loss 0.7031 | acc 0.7320 | elapsed 215.7s


  it 550/748 | loss 0.7028 | acc 0.7326 | elapsed 237.1s


  it 600/748 | loss 0.7017 | acc 0.7308 | elapsed 258.5s


  it 650/748 | loss 0.6981 | acc 0.7320 | elapsed 279.9s


  it 700/748 | loss 0.6965 | acc 0.7310 | elapsed 301.5s


ConvNeXt Fold 2 | Epoch 3/12 | tr_loss 0.6903 tr_acc 0.7309 | va_acc 0.8291 | elapsed_fold 17.5m


  it 50/748 | loss 0.6151 | acc 0.7770 | elapsed 22.2s


  it 100/748 | loss 0.6063 | acc 0.7740 | elapsed 43.6s


  it 150/748 | loss 0.6208 | acc 0.7690 | elapsed 65.0s


  it 200/748 | loss 0.6044 | acc 0.7718 | elapsed 86.5s


  it 250/748 | loss 0.5966 | acc 0.7778 | elapsed 107.9s


  it 300/748 | loss 0.6017 | acc 0.7693 | elapsed 129.5s


  it 350/748 | loss 0.5930 | acc 0.7686 | elapsed 151.0s


  it 400/748 | loss 0.5935 | acc 0.7674 | elapsed 172.5s


  it 450/748 | loss 0.5925 | acc 0.7696 | elapsed 194.1s


  it 500/748 | loss 0.5949 | acc 0.7698 | elapsed 215.6s


  it 550/748 | loss 0.5984 | acc 0.7695 | elapsed 237.2s


  it 600/748 | loss 0.6045 | acc 0.7662 | elapsed 258.7s


  it 650/748 | loss 0.6050 | acc 0.7658 | elapsed 280.2s


  it 700/748 | loss 0.6040 | acc 0.7656 | elapsed 301.9s


ConvNeXt Fold 2 | Epoch 4/12 | tr_loss 0.6047 tr_acc 0.7659 | va_acc 0.8571 | elapsed_fold 23.3m


  it 50/748 | loss 0.5742 | acc 0.7730 | elapsed 22.2s


  it 100/748 | loss 0.5737 | acc 0.7725 | elapsed 43.6s


  it 150/748 | loss 0.5680 | acc 0.7707 | elapsed 65.2s


  it 200/748 | loss 0.5841 | acc 0.7695 | elapsed 86.7s


  it 250/748 | loss 0.5879 | acc 0.7734 | elapsed 108.2s


  it 300/748 | loss 0.5893 | acc 0.7722 | elapsed 129.7s


  it 350/748 | loss 0.5820 | acc 0.7757 | elapsed 151.2s


  it 400/748 | loss 0.5799 | acc 0.7762 | elapsed 172.7s


  it 450/748 | loss 0.5805 | acc 0.7750 | elapsed 194.2s


  it 500/748 | loss 0.5872 | acc 0.7742 | elapsed 215.7s


  it 550/748 | loss 0.5820 | acc 0.7771 | elapsed 237.2s


  it 600/748 | loss 0.5814 | acc 0.7773 | elapsed 258.9s


  it 650/748 | loss 0.5805 | acc 0.7770 | elapsed 280.6s


  it 700/748 | loss 0.5841 | acc 0.7743 | elapsed 302.7s


ConvNeXt Fold 2 | Epoch 5/12 | tr_loss 0.5876 tr_acc 0.7735 | va_acc 0.8793 | elapsed_fold 29.2m


  it 50/748 | loss 0.6663 | acc 0.7300 | elapsed 22.3s


  it 100/748 | loss 0.6003 | acc 0.7535 | elapsed 43.7s


  it 150/748 | loss 0.5837 | acc 0.7590 | elapsed 65.1s


  it 200/748 | loss 0.5652 | acc 0.7695 | elapsed 86.5s


  it 250/748 | loss 0.5667 | acc 0.7730 | elapsed 107.8s


  it 300/748 | loss 0.5708 | acc 0.7712 | elapsed 129.2s


  it 350/748 | loss 0.5770 | acc 0.7729 | elapsed 150.5s


  it 400/748 | loss 0.5794 | acc 0.7666 | elapsed 171.9s


  it 450/748 | loss 0.5803 | acc 0.7644 | elapsed 193.4s


  it 500/748 | loss 0.5743 | acc 0.7675 | elapsed 214.8s


  it 550/748 | loss 0.5776 | acc 0.7664 | elapsed 236.3s


  it 600/748 | loss 0.5734 | acc 0.7690 | elapsed 257.8s


  it 650/748 | loss 0.5632 | acc 0.7729 | elapsed 279.3s


  it 700/748 | loss 0.5666 | acc 0.7732 | elapsed 301.1s


ConvNeXt Fold 2 | Epoch 6/12 | tr_loss 0.5643 tr_acc 0.7767 | va_acc 0.8878 | elapsed_fold 35.0m


  it 50/748 | loss 0.4725 | acc 0.8530 | elapsed 22.3s


  it 100/748 | loss 0.4890 | acc 0.8320 | elapsed 43.8s


  it 150/748 | loss 0.4828 | acc 0.8373 | elapsed 65.2s


  it 200/748 | loss 0.4995 | acc 0.8295 | elapsed 86.7s


  it 250/748 | loss 0.4954 | acc 0.8218 | elapsed 108.1s


  it 300/748 | loss 0.4955 | acc 0.8208 | elapsed 129.6s


  it 350/748 | loss 0.4976 | acc 0.8133 | elapsed 151.0s


  it 400/748 | loss 0.5041 | acc 0.8127 | elapsed 172.4s


  it 450/748 | loss 0.5076 | acc 0.8108 | elapsed 193.8s


  it 500/748 | loss 0.5065 | acc 0.8092 | elapsed 215.2s


  it 550/748 | loss 0.5116 | acc 0.8069 | elapsed 236.7s


  it 600/748 | loss 0.5099 | acc 0.8038 | elapsed 258.1s


  it 650/748 | loss 0.5131 | acc 0.8015 | elapsed 279.6s


  it 700/748 | loss 0.5134 | acc 0.8012 | elapsed 301.2s


ConvNeXt Fold 2 | Epoch 7/12 | tr_loss 0.5154 tr_acc 0.8017 | va_acc 0.8926 | elapsed_fold 40.8m


  it 50/748 | loss 0.5021 | acc 0.7810 | elapsed 22.2s


  it 100/748 | loss 0.4675 | acc 0.8070 | elapsed 43.8s


  it 150/748 | loss 0.4698 | acc 0.8137 | elapsed 65.4s


  it 200/748 | loss 0.4773 | acc 0.8160 | elapsed 87.0s


  it 250/748 | loss 0.4893 | acc 0.8130 | elapsed 108.6s


  it 300/748 | loss 0.4960 | acc 0.8102 | elapsed 130.2s


  it 350/748 | loss 0.4996 | acc 0.8104 | elapsed 151.8s


  it 400/748 | loss 0.5074 | acc 0.8074 | elapsed 173.3s


  it 450/748 | loss 0.5058 | acc 0.8104 | elapsed 194.8s


  it 500/748 | loss 0.5071 | acc 0.8086 | elapsed 216.2s


  it 550/748 | loss 0.5063 | acc 0.8080 | elapsed 237.7s


  it 600/748 | loss 0.5046 | acc 0.8091 | elapsed 259.2s


  it 650/748 | loss 0.5055 | acc 0.8119 | elapsed 280.7s


  it 700/748 | loss 0.5057 | acc 0.8099 | elapsed 302.4s


ConvNeXt Fold 2 | Epoch 8/12 | tr_loss 0.5059 tr_acc 0.8096 | va_acc 0.8966 | elapsed_fold 46.7m


  it 50/748 | loss 0.4838 | acc 0.8120 | elapsed 22.2s


  it 100/748 | loss 0.4886 | acc 0.8035 | elapsed 43.6s


  it 150/748 | loss 0.4919 | acc 0.7983 | elapsed 65.1s


  it 200/748 | loss 0.4952 | acc 0.8013 | elapsed 86.6s


  it 250/748 | loss 0.4960 | acc 0.7916 | elapsed 108.1s


  it 300/748 | loss 0.4975 | acc 0.7920 | elapsed 129.7s


  it 350/748 | loss 0.4926 | acc 0.7944 | elapsed 151.3s


  it 400/748 | loss 0.4898 | acc 0.7945 | elapsed 172.9s


  it 450/748 | loss 0.4899 | acc 0.7976 | elapsed 194.5s


  it 500/748 | loss 0.4887 | acc 0.7993 | elapsed 216.1s


  it 550/748 | loss 0.4876 | acc 0.7991 | elapsed 237.6s


  it 600/748 | loss 0.4895 | acc 0.7974 | elapsed 259.2s


  it 650/748 | loss 0.4849 | acc 0.7969 | elapsed 280.7s


  it 700/748 | loss 0.4827 | acc 0.7979 | elapsed 302.5s


ConvNeXt Fold 2 | Epoch 9/12 | tr_loss 0.4823 tr_acc 0.7994 | va_acc 0.9004 | elapsed_fold 52.5m


  it 50/748 | loss 0.4869 | acc 0.8110 | elapsed 22.2s


  it 100/748 | loss 0.4847 | acc 0.8085 | elapsed 43.6s


  it 150/748 | loss 0.4921 | acc 0.8067 | elapsed 65.1s


  it 200/748 | loss 0.4985 | acc 0.7985 | elapsed 86.5s


  it 250/748 | loss 0.4829 | acc 0.7954 | elapsed 108.0s


  it 300/748 | loss 0.4945 | acc 0.7915 | elapsed 129.4s


  it 350/748 | loss 0.4826 | acc 0.7946 | elapsed 150.9s


  it 400/748 | loss 0.4802 | acc 0.7984 | elapsed 172.4s


  it 450/748 | loss 0.4828 | acc 0.8008 | elapsed 193.9s


  it 500/748 | loss 0.4822 | acc 0.7999 | elapsed 215.5s


  it 550/748 | loss 0.4849 | acc 0.7998 | elapsed 237.1s


  it 600/748 | loss 0.4836 | acc 0.8033 | elapsed 258.7s


  it 650/748 | loss 0.4841 | acc 0.8029 | elapsed 280.3s


  it 700/748 | loss 0.4785 | acc 0.8046 | elapsed 302.1s


ConvNeXt Fold 2 | Epoch 10/12 | tr_loss 0.4778 tr_acc 0.8040 | va_acc 0.9004 | elapsed_fold 58.3m


  it 50/748 | loss 0.2843 | acc 0.9010 | elapsed 22.4s


  it 100/748 | loss 0.2851 | acc 0.9025 | elapsed 43.8s


  it 150/748 | loss 0.2914 | acc 0.9010 | elapsed 65.3s


  it 200/748 | loss 0.2965 | acc 0.8970 | elapsed 86.8s


  it 250/748 | loss 0.3085 | acc 0.8944 | elapsed 108.4s


  it 300/748 | loss 0.3076 | acc 0.8938 | elapsed 130.0s


  it 350/748 | loss 0.3067 | acc 0.8951 | elapsed 151.5s


  it 400/748 | loss 0.3076 | acc 0.8939 | elapsed 173.0s


  it 450/748 | loss 0.3085 | acc 0.8932 | elapsed 194.4s


  it 500/748 | loss 0.3110 | acc 0.8939 | elapsed 215.9s


  it 550/748 | loss 0.3125 | acc 0.8940 | elapsed 237.4s


  it 600/748 | loss 0.3141 | acc 0.8945 | elapsed 259.0s


  it 650/748 | loss 0.3158 | acc 0.8935 | elapsed 280.6s


  it 700/748 | loss 0.3208 | acc 0.8921 | elapsed 302.1s


ConvNeXt Fold 2 | Epoch 11/12 | tr_loss 0.3209 tr_acc 0.8923 | va_acc 0.9009 | elapsed_fold 64.2m


  it 50/748 | loss 0.2854 | acc 0.9000 | elapsed 22.3s


  it 100/748 | loss 0.2936 | acc 0.8980 | elapsed 43.7s


  it 150/748 | loss 0.3017 | acc 0.8980 | elapsed 65.2s


  it 200/748 | loss 0.3115 | acc 0.8950 | elapsed 86.7s


  it 250/748 | loss 0.3226 | acc 0.8894 | elapsed 108.1s


  it 300/748 | loss 0.3209 | acc 0.8905 | elapsed 129.6s


  it 350/748 | loss 0.3214 | acc 0.8906 | elapsed 151.0s


  it 400/748 | loss 0.3209 | acc 0.8908 | elapsed 172.4s


  it 450/748 | loss 0.3175 | acc 0.8924 | elapsed 193.8s


  it 500/748 | loss 0.3126 | acc 0.8936 | elapsed 215.2s


  it 550/748 | loss 0.3110 | acc 0.8945 | elapsed 236.5s


  it 600/748 | loss 0.3124 | acc 0.8944 | elapsed 257.9s


  it 650/748 | loss 0.3130 | acc 0.8947 | elapsed 279.2s


  it 700/748 | loss 0.3117 | acc 0.8950 | elapsed 300.6s


ConvNeXt Fold 2 | Epoch 12/12 | tr_loss 0.3105 tr_acc 0.8953 | va_acc 0.9014 | elapsed_fold 70.0m


ConvNeXt Fold 2 done | best_va_acc 0.9014 | ckpt convnext_fold2_best.pth | fold_time 70.8m


===== ConvNeXt Fold 3/5 | train 14977 | valid 3744 =====


CNX cfg: {'input_size': (3, 512, 512), 'crop_pct': 0.875, 'interpolation': 'bicubic', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225)}


  it 50/748 | loss 1.2153 | acc 0.6200 | elapsed 22.1s


  it 100/748 | loss 1.0915 | acc 0.6345 | elapsed 43.7s


  it 150/748 | loss 1.0127 | acc 0.6517 | elapsed 65.2s


  it 200/748 | loss 0.9461 | acc 0.6730 | elapsed 86.8s


  it 250/748 | loss 0.9114 | acc 0.6816 | elapsed 108.3s


  it 300/748 | loss 0.8981 | acc 0.6865 | elapsed 129.9s


  it 350/748 | loss 0.8700 | acc 0.6934 | elapsed 151.4s


  it 400/748 | loss 0.8476 | acc 0.7000 | elapsed 172.9s


  it 450/748 | loss 0.8292 | acc 0.7038 | elapsed 194.3s


  it 500/748 | loss 0.8137 | acc 0.7071 | elapsed 215.7s


  it 550/748 | loss 0.8055 | acc 0.7078 | elapsed 237.2s


  it 600/748 | loss 0.8044 | acc 0.7072 | elapsed 258.5s


  it 650/748 | loss 0.7978 | acc 0.7090 | elapsed 279.9s


  it 700/748 | loss 0.7933 | acc 0.7077 | elapsed 301.5s


ConvNeXt Fold 3 | Epoch 1/12 | tr_loss 0.7896 tr_acc 0.7074 | va_acc 0.6544 | elapsed_fold 5.8m


  it 50/748 | loss 0.6986 | acc 0.7360 | elapsed 22.1s


  it 100/748 | loss 0.7679 | acc 0.7275 | elapsed 43.6s


  it 150/748 | loss 0.7863 | acc 0.7257 | elapsed 65.1s


  it 200/748 | loss 0.7712 | acc 0.7240 | elapsed 86.7s


  it 250/748 | loss 0.7655 | acc 0.7174 | elapsed 108.3s


  it 300/748 | loss 0.7592 | acc 0.7165 | elapsed 130.0s


  it 350/748 | loss 0.7657 | acc 0.7120 | elapsed 151.7s


  it 400/748 | loss 0.7489 | acc 0.7166 | elapsed 173.4s


  it 450/748 | loss 0.7440 | acc 0.7177 | elapsed 195.1s


  it 500/748 | loss 0.7450 | acc 0.7172 | elapsed 216.8s


  it 550/748 | loss 0.7415 | acc 0.7196 | elapsed 238.4s


  it 600/748 | loss 0.7409 | acc 0.7193 | elapsed 260.0s


  it 650/748 | loss 0.7373 | acc 0.7188 | elapsed 281.5s


  it 700/748 | loss 0.7411 | acc 0.7178 | elapsed 303.2s


ConvNeXt Fold 3 | Epoch 2/12 | tr_loss 0.7381 tr_acc 0.7193 | va_acc 0.7668 | elapsed_fold 11.7m


  it 50/748 | loss 0.7636 | acc 0.7000 | elapsed 22.1s


  it 100/748 | loss 0.7166 | acc 0.7205 | elapsed 43.5s


  it 150/748 | loss 0.7195 | acc 0.7230 | elapsed 64.8s


  it 200/748 | loss 0.7207 | acc 0.7280 | elapsed 86.2s


  it 250/748 | loss 0.7248 | acc 0.7234 | elapsed 107.6s


  it 300/748 | loss 0.7185 | acc 0.7227 | elapsed 128.9s


  it 350/748 | loss 0.7170 | acc 0.7226 | elapsed 150.3s


  it 400/748 | loss 0.7187 | acc 0.7225 | elapsed 171.8s


  it 450/748 | loss 0.7319 | acc 0.7207 | elapsed 193.3s


  it 500/748 | loss 0.7273 | acc 0.7254 | elapsed 214.8s


  it 550/748 | loss 0.7241 | acc 0.7258 | elapsed 236.3s


  it 600/748 | loss 0.7226 | acc 0.7248 | elapsed 257.8s


  it 650/748 | loss 0.7171 | acc 0.7280 | elapsed 279.3s


  it 700/748 | loss 0.7116 | acc 0.7314 | elapsed 301.2s


ConvNeXt Fold 3 | Epoch 3/12 | tr_loss 0.7043 tr_acc 0.7340 | va_acc 0.8202 | elapsed_fold 17.5m


  it 50/748 | loss 0.5787 | acc 0.7830 | elapsed 22.1s


  it 100/748 | loss 0.5815 | acc 0.7660 | elapsed 43.6s


  it 150/748 | loss 0.5721 | acc 0.7707 | elapsed 65.0s


  it 200/748 | loss 0.5852 | acc 0.7720 | elapsed 86.4s


  it 250/748 | loss 0.5848 | acc 0.7632 | elapsed 107.9s


  it 300/748 | loss 0.5869 | acc 0.7680 | elapsed 129.3s


  it 350/748 | loss 0.5973 | acc 0.7647 | elapsed 150.8s


  it 400/748 | loss 0.6052 | acc 0.7656 | elapsed 172.2s


  it 450/748 | loss 0.6051 | acc 0.7644 | elapsed 193.6s


  it 500/748 | loss 0.5997 | acc 0.7670 | elapsed 215.0s


  it 550/748 | loss 0.5977 | acc 0.7705 | elapsed 236.5s


  it 600/748 | loss 0.5962 | acc 0.7701 | elapsed 257.9s


  it 650/748 | loss 0.5992 | acc 0.7690 | elapsed 279.4s


  it 700/748 | loss 0.6008 | acc 0.7694 | elapsed 301.2s


ConvNeXt Fold 3 | Epoch 4/12 | tr_loss 0.5987 tr_acc 0.7703 | va_acc 0.8475 | elapsed_fold 23.3m


  it 50/748 | loss 0.6156 | acc 0.7530 | elapsed 22.3s


  it 100/748 | loss 0.5950 | acc 0.7670 | elapsed 43.8s


  it 150/748 | loss 0.5968 | acc 0.7743 | elapsed 65.4s


  it 200/748 | loss 0.6159 | acc 0.7662 | elapsed 87.0s


  it 250/748 | loss 0.6074 | acc 0.7716 | elapsed 108.5s


  it 300/748 | loss 0.6014 | acc 0.7723 | elapsed 130.0s


  it 350/748 | loss 0.6074 | acc 0.7693 | elapsed 151.4s


  it 400/748 | loss 0.6174 | acc 0.7691 | elapsed 172.9s


  it 450/748 | loss 0.6164 | acc 0.7648 | elapsed 194.4s


  it 500/748 | loss 0.6104 | acc 0.7647 | elapsed 215.8s


  it 550/748 | loss 0.6089 | acc 0.7678 | elapsed 237.3s


  it 600/748 | loss 0.6036 | acc 0.7709 | elapsed 258.7s


  it 650/748 | loss 0.6070 | acc 0.7665 | elapsed 280.1s


  it 700/748 | loss 0.6039 | acc 0.7681 | elapsed 301.6s


ConvNeXt Fold 3 | Epoch 5/12 | tr_loss 0.6005 tr_acc 0.7692 | va_acc 0.8654 | elapsed_fold 29.2m


  it 50/748 | loss 0.5789 | acc 0.7640 | elapsed 22.2s


  it 100/748 | loss 0.5617 | acc 0.7900 | elapsed 43.8s


  it 150/748 | loss 0.5516 | acc 0.7900 | elapsed 65.3s


  it 200/748 | loss 0.5636 | acc 0.7782 | elapsed 86.9s


  it 250/748 | loss 0.5654 | acc 0.7810 | elapsed 108.4s


  it 300/748 | loss 0.5587 | acc 0.7847 | elapsed 130.0s


  it 350/748 | loss 0.5707 | acc 0.7830 | elapsed 151.5s


  it 400/748 | loss 0.5818 | acc 0.7734 | elapsed 173.0s


  it 450/748 | loss 0.5707 | acc 0.7782 | elapsed 194.5s


  it 500/748 | loss 0.5715 | acc 0.7769 | elapsed 215.9s


  it 550/748 | loss 0.5741 | acc 0.7729 | elapsed 237.4s


  it 600/748 | loss 0.5685 | acc 0.7743 | elapsed 258.8s


  it 650/748 | loss 0.5711 | acc 0.7736 | elapsed 280.2s


  it 700/748 | loss 0.5703 | acc 0.7749 | elapsed 301.6s


ConvNeXt Fold 3 | Epoch 6/12 | tr_loss 0.5666 tr_acc 0.7773 | va_acc 0.8737 | elapsed_fold 35.0m


  it 50/748 | loss 0.5706 | acc 0.8010 | elapsed 22.3s


  it 100/748 | loss 0.5417 | acc 0.7995 | elapsed 43.7s


  it 150/748 | loss 0.5461 | acc 0.7910 | elapsed 65.2s


  it 200/748 | loss 0.5416 | acc 0.7857 | elapsed 86.9s


  it 250/748 | loss 0.5400 | acc 0.7830 | elapsed 108.5s


  it 300/748 | loss 0.5440 | acc 0.7835 | elapsed 130.2s


  it 350/748 | loss 0.5404 | acc 0.7820 | elapsed 151.9s


  it 400/748 | loss 0.5346 | acc 0.7880 | elapsed 173.6s


  it 450/748 | loss 0.5370 | acc 0.7902 | elapsed 195.4s


  it 500/748 | loss 0.5344 | acc 0.7921 | elapsed 217.1s


  it 550/748 | loss 0.5407 | acc 0.7893 | elapsed 238.8s


  it 600/748 | loss 0.5366 | acc 0.7898 | elapsed 260.5s


  it 650/748 | loss 0.5320 | acc 0.7935 | elapsed 282.0s


  it 700/748 | loss 0.5335 | acc 0.7923 | elapsed 303.6s


ConvNeXt Fold 3 | Epoch 7/12 | tr_loss 0.5348 tr_acc 0.7906 | va_acc 0.8817 | elapsed_fold 40.9m


  it 50/748 | loss 0.5100 | acc 0.7660 | elapsed 22.2s


  it 100/748 | loss 0.5055 | acc 0.7675 | elapsed 43.5s


  it 150/748 | loss 0.4912 | acc 0.7870 | elapsed 64.8s


  it 200/748 | loss 0.4886 | acc 0.7880 | elapsed 86.2s


  it 250/748 | loss 0.4932 | acc 0.7914 | elapsed 107.6s


  it 300/748 | loss 0.4940 | acc 0.7930 | elapsed 129.0s


  it 350/748 | loss 0.4906 | acc 0.7957 | elapsed 150.4s


  it 400/748 | loss 0.5013 | acc 0.7916 | elapsed 171.8s


  it 450/748 | loss 0.5074 | acc 0.7900 | elapsed 193.3s


  it 500/748 | loss 0.5072 | acc 0.7897 | elapsed 214.8s


  it 550/748 | loss 0.5086 | acc 0.7923 | elapsed 236.4s


  it 600/748 | loss 0.5081 | acc 0.7926 | elapsed 257.9s


  it 650/748 | loss 0.5087 | acc 0.7949 | elapsed 279.5s


  it 700/748 | loss 0.5057 | acc 0.7954 | elapsed 301.1s


ConvNeXt Fold 3 | Epoch 8/12 | tr_loss 0.5023 tr_acc 0.7982 | va_acc 0.8841 | elapsed_fold 46.7m


  it 50/748 | loss 0.5161 | acc 0.7570 | elapsed 22.4s


  it 100/748 | loss 0.5103 | acc 0.7775 | elapsed 44.0s


  it 150/748 | loss 0.5140 | acc 0.7827 | elapsed 65.6s


  it 200/748 | loss 0.5024 | acc 0.7890 | elapsed 87.2s


  it 250/748 | loss 0.4991 | acc 0.7920 | elapsed 108.7s


  it 300/748 | loss 0.5038 | acc 0.7897 | elapsed 130.3s


  it 350/748 | loss 0.4986 | acc 0.7939 | elapsed 151.9s


  it 400/748 | loss 0.4982 | acc 0.7921 | elapsed 173.4s


  it 450/748 | loss 0.4925 | acc 0.7912 | elapsed 194.9s


  it 500/748 | loss 0.4939 | acc 0.7890 | elapsed 216.4s


  it 550/748 | loss 0.4984 | acc 0.7873 | elapsed 237.9s


  it 600/748 | loss 0.4949 | acc 0.7869 | elapsed 259.5s


  it 650/748 | loss 0.4937 | acc 0.7887 | elapsed 281.1s


  it 700/748 | loss 0.4916 | acc 0.7887 | elapsed 302.8s


ConvNeXt Fold 3 | Epoch 9/12 | tr_loss 0.4948 tr_acc 0.7894 | va_acc 0.8892 | elapsed_fold 52.5m


  it 50/748 | loss 0.4426 | acc 0.8500 | elapsed 22.7s


  it 100/748 | loss 0.4576 | acc 0.8285 | elapsed 44.3s


  it 150/748 | loss 0.4777 | acc 0.8150 | elapsed 65.8s


  it 200/748 | loss 0.4727 | acc 0.8103 | elapsed 87.4s


  it 250/748 | loss 0.4718 | acc 0.8120 | elapsed 109.0s


  it 300/748 | loss 0.4611 | acc 0.8175 | elapsed 130.6s


  it 350/748 | loss 0.4593 | acc 0.8144 | elapsed 152.2s


  it 400/748 | loss 0.4642 | acc 0.8043 | elapsed 173.7s


  it 450/748 | loss 0.4635 | acc 0.8048 | elapsed 195.2s


  it 500/748 | loss 0.4583 | acc 0.8071 | elapsed 216.6s


  it 550/748 | loss 0.4578 | acc 0.8093 | elapsed 238.0s


  it 600/748 | loss 0.4605 | acc 0.8089 | elapsed 259.5s


  it 650/748 | loss 0.4652 | acc 0.8085 | elapsed 280.9s


  it 700/748 | loss 0.4636 | acc 0.8076 | elapsed 302.4s


ConvNeXt Fold 3 | Epoch 10/12 | tr_loss 0.4638 tr_acc 0.8060 | va_acc 0.8902 | elapsed_fold 58.4m


  it 50/748 | loss 0.3235 | acc 0.8880 | elapsed 22.5s


  it 100/748 | loss 0.3528 | acc 0.8810 | elapsed 44.0s


  it 150/748 | loss 0.3407 | acc 0.8877 | elapsed 65.6s


  it 200/748 | loss 0.3285 | acc 0.8915 | elapsed 87.2s


  it 250/748 | loss 0.3266 | acc 0.8918 | elapsed 109.0s


  it 300/748 | loss 0.3232 | acc 0.8927 | elapsed 130.7s


  it 350/748 | loss 0.3174 | acc 0.8950 | elapsed 152.4s


  it 400/748 | loss 0.3192 | acc 0.8936 | elapsed 174.1s


  it 450/748 | loss 0.3155 | acc 0.8946 | elapsed 195.7s


  it 500/748 | loss 0.3129 | acc 0.8945 | elapsed 217.3s


  it 550/748 | loss 0.3169 | acc 0.8922 | elapsed 238.9s


  it 600/748 | loss 0.3161 | acc 0.8922 | elapsed 260.4s


  it 650/748 | loss 0.3167 | acc 0.8923 | elapsed 281.9s


  it 700/748 | loss 0.3204 | acc 0.8917 | elapsed 303.4s


ConvNeXt Fold 3 | Epoch 11/12 | tr_loss 0.3194 tr_acc 0.8918 | va_acc 0.8894 | elapsed_fold 64.2m


  it 50/748 | loss 0.3318 | acc 0.8950 | elapsed 22.5s


  it 100/748 | loss 0.3085 | acc 0.9010 | elapsed 43.9s


  it 150/748 | loss 0.3113 | acc 0.8967 | elapsed 65.3s


  it 200/748 | loss 0.3077 | acc 0.8985 | elapsed 86.7s


  it 250/748 | loss 0.3070 | acc 0.8978 | elapsed 108.2s


  it 300/748 | loss 0.3093 | acc 0.8980 | elapsed 129.7s


  it 350/748 | loss 0.3062 | acc 0.8979 | elapsed 151.3s


  it 400/748 | loss 0.3044 | acc 0.8988 | elapsed 172.9s


  it 450/748 | loss 0.3009 | acc 0.8992 | elapsed 194.6s


  it 500/748 | loss 0.3083 | acc 0.8973 | elapsed 216.2s


  it 550/748 | loss 0.3070 | acc 0.8979 | elapsed 237.8s


  it 600/748 | loss 0.3072 | acc 0.8982 | elapsed 259.4s


  it 650/748 | loss 0.3128 | acc 0.8966 | elapsed 281.0s


  it 700/748 | loss 0.3130 | acc 0.8958 | elapsed 302.6s


ConvNeXt Fold 3 | Epoch 12/12 | tr_loss 0.3144 tr_acc 0.8950 | va_acc 0.8884 | elapsed_fold 70.1m


ConvNeXt Fold 3 done | best_va_acc 0.8902 | ckpt convnext_fold3_best.pth | fold_time 70.9m


===== ConvNeXt Fold 4/5 | train 14977 | valid 3744 =====


CNX cfg: {'input_size': (3, 512, 512), 'crop_pct': 0.875, 'interpolation': 'bicubic', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225)}


  it 50/748 | loss 1.2234 | acc 0.5910 | elapsed 22.1s


  it 100/748 | loss 1.1227 | acc 0.6120 | elapsed 43.5s


  it 150/748 | loss 1.0373 | acc 0.6300 | elapsed 64.9s


  it 200/748 | loss 0.9752 | acc 0.6488 | elapsed 86.4s


  it 250/748 | loss 0.9350 | acc 0.6598 | elapsed 107.9s


  it 300/748 | loss 0.9068 | acc 0.6690 | elapsed 129.4s


  it 350/748 | loss 0.8821 | acc 0.6787 | elapsed 150.8s


  it 400/748 | loss 0.8598 | acc 0.6823 | elapsed 172.3s


  it 450/748 | loss 0.8582 | acc 0.6842 | elapsed 193.8s


  it 500/748 | loss 0.8530 | acc 0.6865 | elapsed 215.4s


  it 550/748 | loss 0.8461 | acc 0.6898 | elapsed 237.0s


  it 600/748 | loss 0.8390 | acc 0.6922 | elapsed 258.6s


  it 650/748 | loss 0.8342 | acc 0.6928 | elapsed 280.1s


  it 700/748 | loss 0.8287 | acc 0.6935 | elapsed 301.9s


ConvNeXt Fold 4 | Epoch 1/12 | tr_loss 0.8278 tr_acc 0.6942 | va_acc 0.6039 | elapsed_fold 5.8m


  it 50/748 | loss 0.7988 | acc 0.7030 | elapsed 22.4s


  it 100/748 | loss 0.7981 | acc 0.6985 | elapsed 43.8s


  it 150/748 | loss 0.7847 | acc 0.6997 | elapsed 65.1s


  it 200/748 | loss 0.7861 | acc 0.7037 | elapsed 86.5s


  it 250/748 | loss 0.7801 | acc 0.7090 | elapsed 107.9s


  it 300/748 | loss 0.7766 | acc 0.7083 | elapsed 129.2s


  it 350/748 | loss 0.7655 | acc 0.7144 | elapsed 150.5s


  it 400/748 | loss 0.7631 | acc 0.7134 | elapsed 171.9s


  it 450/748 | loss 0.7514 | acc 0.7161 | elapsed 193.2s


  it 500/748 | loss 0.7469 | acc 0.7164 | elapsed 214.6s


  it 550/748 | loss 0.7523 | acc 0.7131 | elapsed 236.0s


  it 600/748 | loss 0.7524 | acc 0.7137 | elapsed 257.3s


  it 650/748 | loss 0.7467 | acc 0.7159 | elapsed 278.8s


  it 700/748 | loss 0.7447 | acc 0.7163 | elapsed 300.5s


ConvNeXt Fold 4 | Epoch 2/12 | tr_loss 0.7439 tr_acc 0.7179 | va_acc 0.7569 | elapsed_fold 11.7m


  it 50/748 | loss 0.7745 | acc 0.7170 | elapsed 22.4s


  it 100/748 | loss 0.7433 | acc 0.7230 | elapsed 43.9s


  it 150/748 | loss 0.7428 | acc 0.7223 | elapsed 65.4s


  it 200/748 | loss 0.7433 | acc 0.7248 | elapsed 86.8s


  it 250/748 | loss 0.7327 | acc 0.7240 | elapsed 108.3s


  it 300/748 | loss 0.7208 | acc 0.7293 | elapsed 129.7s


  it 350/748 | loss 0.7166 | acc 0.7260 | elapsed 151.1s


  it 400/748 | loss 0.7115 | acc 0.7256 | elapsed 172.5s


  it 450/748 | loss 0.7141 | acc 0.7266 | elapsed 193.9s


  it 500/748 | loss 0.7092 | acc 0.7270 | elapsed 215.2s


  it 550/748 | loss 0.7060 | acc 0.7253 | elapsed 236.6s


  it 600/748 | loss 0.7040 | acc 0.7249 | elapsed 257.9s


  it 650/748 | loss 0.6971 | acc 0.7279 | elapsed 279.4s


  it 700/748 | loss 0.6929 | acc 0.7306 | elapsed 301.1s


ConvNeXt Fold 4 | Epoch 3/12 | tr_loss 0.6890 tr_acc 0.7336 | va_acc 0.8181 | elapsed_fold 17.5m


  it 50/748 | loss 0.6427 | acc 0.7380 | elapsed 22.3s


  it 100/748 | loss 0.5939 | acc 0.7655 | elapsed 43.8s


  it 150/748 | loss 0.6043 | acc 0.7510 | elapsed 65.3s


  it 200/748 | loss 0.6217 | acc 0.7500 | elapsed 86.9s


  it 250/748 | loss 0.6256 | acc 0.7450 | elapsed 108.4s


  it 300/748 | loss 0.6210 | acc 0.7455 | elapsed 129.9s


  it 350/748 | loss 0.6267 | acc 0.7394 | elapsed 151.4s


  it 400/748 | loss 0.6129 | acc 0.7448 | elapsed 172.8s


  it 450/748 | loss 0.6140 | acc 0.7438 | elapsed 194.2s


  it 500/748 | loss 0.6144 | acc 0.7469 | elapsed 215.6s


  it 550/748 | loss 0.6092 | acc 0.7478 | elapsed 237.1s


  it 600/748 | loss 0.6101 | acc 0.7492 | elapsed 258.5s


  it 650/748 | loss 0.6116 | acc 0.7493 | elapsed 279.8s


  it 700/748 | loss 0.6073 | acc 0.7511 | elapsed 301.2s


ConvNeXt Fold 4 | Epoch 4/12 | tr_loss 0.6102 tr_acc 0.7525 | va_acc 0.8528 | elapsed_fold 23.3m


  it 50/748 | loss 0.6188 | acc 0.7510 | elapsed 22.3s


  it 100/748 | loss 0.6139 | acc 0.7390 | elapsed 43.7s


  it 150/748 | loss 0.6118 | acc 0.7523 | elapsed 65.2s


  it 200/748 | loss 0.6263 | acc 0.7588 | elapsed 86.8s


  it 250/748 | loss 0.6241 | acc 0.7626 | elapsed 108.2s


  it 300/748 | loss 0.6204 | acc 0.7633 | elapsed 129.7s


  it 350/748 | loss 0.6082 | acc 0.7713 | elapsed 151.2s


  it 400/748 | loss 0.6073 | acc 0.7748 | elapsed 172.8s


  it 450/748 | loss 0.6115 | acc 0.7719 | elapsed 194.4s


  it 500/748 | loss 0.6108 | acc 0.7722 | elapsed 215.9s


  it 550/748 | loss 0.6114 | acc 0.7699 | elapsed 237.4s


  it 600/748 | loss 0.6049 | acc 0.7726 | elapsed 258.9s


  it 650/748 | loss 0.6078 | acc 0.7709 | elapsed 280.4s


  it 700/748 | loss 0.6052 | acc 0.7680 | elapsed 301.9s


ConvNeXt Fold 4 | Epoch 5/12 | tr_loss 0.6003 tr_acc 0.7711 | va_acc 0.8723 | elapsed_fold 29.1m


  it 50/748 | loss 0.5272 | acc 0.8020 | elapsed 22.3s


  it 100/748 | loss 0.5564 | acc 0.7930 | elapsed 43.7s


  it 150/748 | loss 0.5540 | acc 0.7997 | elapsed 65.2s


  it 200/748 | loss 0.5651 | acc 0.7927 | elapsed 86.7s


  it 250/748 | loss 0.5675 | acc 0.7842 | elapsed 108.1s


  it 300/748 | loss 0.5701 | acc 0.7845 | elapsed 129.6s


  it 350/748 | loss 0.5675 | acc 0.7820 | elapsed 151.0s


  it 400/748 | loss 0.5652 | acc 0.7782 | elapsed 172.5s


  it 450/748 | loss 0.5664 | acc 0.7781 | elapsed 194.1s


  it 500/748 | loss 0.5598 | acc 0.7826 | elapsed 215.7s


  it 550/748 | loss 0.5532 | acc 0.7835 | elapsed 237.3s


  it 600/748 | loss 0.5553 | acc 0.7803 | elapsed 259.0s


  it 650/748 | loss 0.5456 | acc 0.7843 | elapsed 280.8s


  it 700/748 | loss 0.5478 | acc 0.7829 | elapsed 302.5s


ConvNeXt Fold 4 | Epoch 6/12 | tr_loss 0.5459 tr_acc 0.7848 | va_acc 0.8841 | elapsed_fold 35.0m


  it 50/748 | loss 0.5613 | acc 0.7810 | elapsed 22.3s


  it 100/748 | loss 0.5396 | acc 0.7630 | elapsed 43.8s


  it 150/748 | loss 0.5309 | acc 0.7653 | elapsed 65.2s


  it 200/748 | loss 0.5215 | acc 0.7708 | elapsed 86.8s


  it 250/748 | loss 0.5283 | acc 0.7742 | elapsed 108.4s


  it 300/748 | loss 0.5183 | acc 0.7865 | elapsed 130.0s


  it 350/748 | loss 0.5192 | acc 0.7860 | elapsed 151.5s


  it 400/748 | loss 0.5194 | acc 0.7861 | elapsed 173.0s


  it 450/748 | loss 0.5260 | acc 0.7873 | elapsed 194.4s


  it 500/748 | loss 0.5249 | acc 0.7861 | elapsed 215.9s


  it 550/748 | loss 0.5212 | acc 0.7867 | elapsed 237.4s


  it 600/748 | loss 0.5210 | acc 0.7883 | elapsed 258.9s


  it 650/748 | loss 0.5239 | acc 0.7878 | elapsed 280.5s


  it 700/748 | loss 0.5288 | acc 0.7874 | elapsed 302.1s


ConvNeXt Fold 4 | Epoch 7/12 | tr_loss 0.5253 tr_acc 0.7876 | va_acc 0.8892 | elapsed_fold 40.8m


  it 50/748 | loss 0.5378 | acc 0.7890 | elapsed 22.4s


  it 100/748 | loss 0.5155 | acc 0.7970 | elapsed 44.0s


  it 150/748 | loss 0.5130 | acc 0.8037 | elapsed 65.7s


  it 200/748 | loss 0.5210 | acc 0.8017 | elapsed 87.3s


  it 250/748 | loss 0.5174 | acc 0.7976 | elapsed 108.8s


  it 300/748 | loss 0.5255 | acc 0.8003 | elapsed 130.2s


  it 350/748 | loss 0.5211 | acc 0.7997 | elapsed 151.7s


  it 400/748 | loss 0.5239 | acc 0.7976 | elapsed 173.2s


  it 450/748 | loss 0.5225 | acc 0.7993 | elapsed 194.6s


  it 500/748 | loss 0.5182 | acc 0.8004 | elapsed 215.9s


  it 550/748 | loss 0.5121 | acc 0.8009 | elapsed 237.3s


  it 600/748 | loss 0.5135 | acc 0.8014 | elapsed 258.7s


  it 650/748 | loss 0.5110 | acc 0.8015 | elapsed 280.0s


  it 700/748 | loss 0.5132 | acc 0.8017 | elapsed 301.4s


ConvNeXt Fold 4 | Epoch 8/12 | tr_loss 0.5177 tr_acc 0.8018 | va_acc 0.8926 | elapsed_fold 46.6m


  it 50/748 | loss 0.5856 | acc 0.7720 | elapsed 22.3s


  it 100/748 | loss 0.5610 | acc 0.7785 | elapsed 43.8s


  it 150/748 | loss 0.5357 | acc 0.7787 | elapsed 65.3s


  it 200/748 | loss 0.5370 | acc 0.7815 | elapsed 86.9s


  it 250/748 | loss 0.5248 | acc 0.7890 | elapsed 108.6s


  it 300/748 | loss 0.5352 | acc 0.7857 | elapsed 130.2s


  it 350/748 | loss 0.5274 | acc 0.7897 | elapsed 151.8s


  it 400/748 | loss 0.5234 | acc 0.7944 | elapsed 173.4s


  it 450/748 | loss 0.5166 | acc 0.7964 | elapsed 195.0s


  it 500/748 | loss 0.5085 | acc 0.7955 | elapsed 216.6s


  it 550/748 | loss 0.5020 | acc 0.7959 | elapsed 238.1s


  it 600/748 | loss 0.4993 | acc 0.7962 | elapsed 259.6s


  it 650/748 | loss 0.5004 | acc 0.7942 | elapsed 281.0s


  it 700/748 | loss 0.4964 | acc 0.7929 | elapsed 302.5s


ConvNeXt Fold 4 | Epoch 9/12 | tr_loss 0.5015 tr_acc 0.7918 | va_acc 0.8924 | elapsed_fold 52.5m


  it 50/748 | loss 0.5496 | acc 0.7600 | elapsed 22.2s


  it 100/748 | loss 0.4972 | acc 0.7945 | elapsed 43.6s


  it 150/748 | loss 0.4977 | acc 0.7953 | elapsed 65.0s


  it 200/748 | loss 0.4954 | acc 0.8025 | elapsed 86.5s


  it 250/748 | loss 0.4868 | acc 0.8088 | elapsed 108.0s


  it 300/748 | loss 0.4841 | acc 0.8093 | elapsed 129.5s


  it 350/748 | loss 0.4855 | acc 0.8093 | elapsed 151.0s


  it 400/748 | loss 0.4878 | acc 0.8081 | elapsed 172.6s


  it 450/748 | loss 0.4900 | acc 0.8072 | elapsed 194.2s


  it 500/748 | loss 0.4929 | acc 0.8046 | elapsed 215.8s


  it 550/748 | loss 0.4887 | acc 0.8063 | elapsed 237.4s


  it 600/748 | loss 0.4876 | acc 0.8045 | elapsed 259.0s


  it 650/748 | loss 0.4891 | acc 0.8039 | elapsed 280.5s


  it 700/748 | loss 0.4874 | acc 0.8064 | elapsed 302.1s


ConvNeXt Fold 4 | Epoch 10/12 | tr_loss 0.4837 tr_acc 0.8070 | va_acc 0.8937 | elapsed_fold 58.3m


  it 50/748 | loss 0.3189 | acc 0.8860 | elapsed 22.3s


  it 100/748 | loss 0.2948 | acc 0.9010 | elapsed 43.8s


  it 150/748 | loss 0.3121 | acc 0.8973 | elapsed 65.3s


  it 200/748 | loss 0.3109 | acc 0.8945 | elapsed 86.9s


  it 250/748 | loss 0.3149 | acc 0.8938 | elapsed 108.4s


  it 300/748 | loss 0.3177 | acc 0.8928 | elapsed 129.8s


  it 350/748 | loss 0.3184 | acc 0.8926 | elapsed 151.4s


  it 400/748 | loss 0.3247 | acc 0.8896 | elapsed 173.0s


  it 450/748 | loss 0.3246 | acc 0.8904 | elapsed 194.6s


  it 500/748 | loss 0.3216 | acc 0.8916 | elapsed 216.2s


  it 550/748 | loss 0.3239 | acc 0.8907 | elapsed 237.9s


  it 600/748 | loss 0.3258 | acc 0.8901 | elapsed 259.6s


  it 650/748 | loss 0.3269 | acc 0.8898 | elapsed 281.3s


  it 700/748 | loss 0.3238 | acc 0.8909 | elapsed 303.0s


ConvNeXt Fold 4 | Epoch 11/12 | tr_loss 0.3217 tr_acc 0.8916 | va_acc 0.8942 | elapsed_fold 64.2m


  it 50/748 | loss 0.3222 | acc 0.8850 | elapsed 22.3s


  it 100/748 | loss 0.3050 | acc 0.8895 | elapsed 43.8s


  it 150/748 | loss 0.3109 | acc 0.8937 | elapsed 65.2s


  it 200/748 | loss 0.3175 | acc 0.8932 | elapsed 86.6s


  it 250/748 | loss 0.3166 | acc 0.8936 | elapsed 108.0s


  it 300/748 | loss 0.3124 | acc 0.8947 | elapsed 129.4s


  it 350/748 | loss 0.3156 | acc 0.8941 | elapsed 150.8s


  it 400/748 | loss 0.3102 | acc 0.8965 | elapsed 172.1s


  it 450/748 | loss 0.3072 | acc 0.8988 | elapsed 193.5s


  it 500/748 | loss 0.3062 | acc 0.8981 | elapsed 214.9s


  it 550/748 | loss 0.3079 | acc 0.8975 | elapsed 236.4s


  it 600/748 | loss 0.3092 | acc 0.8964 | elapsed 257.8s


  it 650/748 | loss 0.3120 | acc 0.8947 | elapsed 279.3s


  it 700/748 | loss 0.3108 | acc 0.8952 | elapsed 300.7s


ConvNeXt Fold 4 | Epoch 12/12 | tr_loss 0.3109 tr_acc 0.8953 | va_acc 0.8972 | elapsed_fold 70.0m


ConvNeXt Fold 4 done | best_va_acc 0.8972 | ckpt convnext_fold4_best.pth | fold_time 70.8m


===== ConvNeXt Fold 5/5 | train 14977 | valid 3744 =====


CNX cfg: {'input_size': (3, 512, 512), 'crop_pct': 0.875, 'interpolation': 'bicubic', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225)}


  it 50/748 | loss 1.2715 | acc 0.5190 | elapsed 22.2s


  it 100/748 | loss 1.1035 | acc 0.5805 | elapsed 43.8s


  it 150/748 | loss 1.0286 | acc 0.6027 | elapsed 65.3s


  it 200/748 | loss 0.9798 | acc 0.6198 | elapsed 86.8s


  it 250/748 | loss 0.9363 | acc 0.6406 | elapsed 108.3s


  it 300/748 | loss 0.9028 | acc 0.6583 | elapsed 129.8s


  it 350/748 | loss 0.8873 | acc 0.6590 | elapsed 151.2s


  it 400/748 | loss 0.8666 | acc 0.6679 | elapsed 172.5s


  it 450/748 | loss 0.8428 | acc 0.6782 | elapsed 193.9s


  it 500/748 | loss 0.8364 | acc 0.6802 | elapsed 215.3s


  it 550/748 | loss 0.8291 | acc 0.6825 | elapsed 236.7s


  it 600/748 | loss 0.8196 | acc 0.6867 | elapsed 258.2s


  it 650/748 | loss 0.8136 | acc 0.6887 | elapsed 279.6s


  it 700/748 | loss 0.8098 | acc 0.6899 | elapsed 301.3s


ConvNeXt Fold 5 | Epoch 1/12 | tr_loss 0.8028 tr_acc 0.6931 | va_acc 0.4671 | elapsed_fold 5.8m


  it 50/748 | loss 0.8118 | acc 0.6870 | elapsed 22.4s


  it 100/748 | loss 0.7710 | acc 0.6995 | elapsed 44.0s


  it 150/748 | loss 0.7670 | acc 0.7093 | elapsed 65.8s


  it 200/748 | loss 0.7618 | acc 0.7000 | elapsed 87.5s


  it 250/748 | loss 0.7651 | acc 0.7044 | elapsed 109.2s


  it 300/748 | loss 0.7552 | acc 0.7090 | elapsed 130.8s


  it 350/748 | loss 0.7610 | acc 0.7064 | elapsed 152.4s


  it 400/748 | loss 0.7582 | acc 0.7093 | elapsed 174.1s


  it 450/748 | loss 0.7625 | acc 0.7092 | elapsed 195.7s


  it 500/748 | loss 0.7609 | acc 0.7066 | elapsed 217.3s


  it 550/748 | loss 0.7593 | acc 0.7078 | elapsed 238.9s


  it 600/748 | loss 0.7536 | acc 0.7103 | elapsed 260.3s


  it 650/748 | loss 0.7468 | acc 0.7114 | elapsed 281.8s


  it 700/748 | loss 0.7488 | acc 0.7124 | elapsed 303.7s


ConvNeXt Fold 5 | Epoch 2/12 | tr_loss 0.7460 tr_acc 0.7134 | va_acc 0.7193 | elapsed_fold 11.7m


  it 50/748 | loss 0.7049 | acc 0.7460 | elapsed 22.2s


  it 100/748 | loss 0.7109 | acc 0.7400 | elapsed 43.7s


  it 150/748 | loss 0.7084 | acc 0.7393 | elapsed 65.1s


  it 200/748 | loss 0.7170 | acc 0.7370 | elapsed 86.6s


  it 250/748 | loss 0.7129 | acc 0.7370 | elapsed 108.2s


  it 300/748 | loss 0.7072 | acc 0.7377 | elapsed 129.9s


  it 350/748 | loss 0.7156 | acc 0.7337 | elapsed 151.6s


  it 400/748 | loss 0.7106 | acc 0.7320 | elapsed 173.3s


  it 450/748 | loss 0.7088 | acc 0.7331 | elapsed 194.9s


  it 500/748 | loss 0.7074 | acc 0.7362 | elapsed 216.6s


  it 550/748 | loss 0.7038 | acc 0.7373 | elapsed 238.1s


  it 600/748 | loss 0.6967 | acc 0.7408 | elapsed 259.7s


  it 650/748 | loss 0.6966 | acc 0.7400 | elapsed 281.3s


  it 700/748 | loss 0.6907 | acc 0.7402 | elapsed 302.9s


ConvNeXt Fold 5 | Epoch 3/12 | tr_loss 0.6838 tr_acc 0.7437 | va_acc 0.8088 | elapsed_fold 17.5m


  it 50/748 | loss 0.6114 | acc 0.7680 | elapsed 22.3s


  it 100/748 | loss 0.6015 | acc 0.7760 | elapsed 43.7s


  it 150/748 | loss 0.6238 | acc 0.7663 | elapsed 65.1s


  it 200/748 | loss 0.6131 | acc 0.7702 | elapsed 86.5s


  it 250/748 | loss 0.6144 | acc 0.7672 | elapsed 108.0s


  it 300/748 | loss 0.6083 | acc 0.7705 | elapsed 129.5s


  it 350/748 | loss 0.6093 | acc 0.7696 | elapsed 151.0s


  it 400/748 | loss 0.6136 | acc 0.7670 | elapsed 172.6s


  it 450/748 | loss 0.6128 | acc 0.7681 | elapsed 194.3s


  it 500/748 | loss 0.6146 | acc 0.7706 | elapsed 215.9s


  it 550/748 | loss 0.6118 | acc 0.7701 | elapsed 237.6s


  it 600/748 | loss 0.6120 | acc 0.7681 | elapsed 259.4s


  it 650/748 | loss 0.6137 | acc 0.7692 | elapsed 281.1s


  it 700/748 | loss 0.6135 | acc 0.7669 | elapsed 302.8s


ConvNeXt Fold 5 | Epoch 4/12 | tr_loss 0.6159 tr_acc 0.7646 | va_acc 0.8552 | elapsed_fold 23.4m


  it 50/748 | loss 0.6148 | acc 0.7760 | elapsed 22.3s


  it 100/748 | loss 0.5955 | acc 0.7645 | elapsed 43.7s


  it 150/748 | loss 0.5775 | acc 0.7723 | elapsed 65.2s


  it 200/748 | loss 0.5820 | acc 0.7725 | elapsed 86.7s


  it 250/748 | loss 0.5845 | acc 0.7750 | elapsed 108.1s


  it 300/748 | loss 0.5887 | acc 0.7755 | elapsed 129.5s


  it 350/748 | loss 0.5879 | acc 0.7703 | elapsed 151.0s


  it 400/748 | loss 0.5869 | acc 0.7694 | elapsed 172.5s


  it 450/748 | loss 0.5929 | acc 0.7673 | elapsed 194.0s


  it 500/748 | loss 0.5941 | acc 0.7665 | elapsed 215.6s


  it 550/748 | loss 0.5921 | acc 0.7673 | elapsed 237.1s


  it 600/748 | loss 0.5931 | acc 0.7649 | elapsed 258.7s


  it 650/748 | loss 0.5917 | acc 0.7659 | elapsed 280.3s


  it 700/748 | loss 0.5914 | acc 0.7660 | elapsed 301.9s


ConvNeXt Fold 5 | Epoch 5/12 | tr_loss 0.5933 tr_acc 0.7631 | va_acc 0.8785 | elapsed_fold 29.2m


  it 50/748 | loss 0.6255 | acc 0.7770 | elapsed 22.4s


  it 100/748 | loss 0.6327 | acc 0.7570 | elapsed 44.0s


  it 150/748 | loss 0.5953 | acc 0.7690 | elapsed 65.5s


  it 200/748 | loss 0.5925 | acc 0.7792 | elapsed 87.1s


  it 250/748 | loss 0.5826 | acc 0.7796 | elapsed 108.5s


  it 300/748 | loss 0.5799 | acc 0.7800 | elapsed 130.0s


  it 350/748 | loss 0.5721 | acc 0.7837 | elapsed 151.4s


  it 400/748 | loss 0.5758 | acc 0.7805 | elapsed 172.9s


  it 450/748 | loss 0.5700 | acc 0.7823 | elapsed 194.4s


  it 500/748 | loss 0.5639 | acc 0.7819 | elapsed 215.8s


  it 550/748 | loss 0.5643 | acc 0.7802 | elapsed 237.3s


  it 600/748 | loss 0.5641 | acc 0.7788 | elapsed 258.8s


  it 650/748 | loss 0.5651 | acc 0.7802 | elapsed 280.3s


  it 700/748 | loss 0.5673 | acc 0.7759 | elapsed 301.9s


ConvNeXt Fold 5 | Epoch 6/12 | tr_loss 0.5649 tr_acc 0.7787 | va_acc 0.8913 | elapsed_fold 35.1m


  it 50/748 | loss 0.4976 | acc 0.8130 | elapsed 22.4s


  it 100/748 | loss 0.5496 | acc 0.7900 | elapsed 44.0s


  it 150/748 | loss 0.5346 | acc 0.7930 | elapsed 65.7s


  it 200/748 | loss 0.5228 | acc 0.8035 | elapsed 87.4s


  it 250/748 | loss 0.5132 | acc 0.8066 | elapsed 109.0s


  it 300/748 | loss 0.5308 | acc 0.7955 | elapsed 130.7s


  it 350/748 | loss 0.5342 | acc 0.7954 | elapsed 152.3s


  it 400/748 | loss 0.5316 | acc 0.7967 | elapsed 173.9s


  it 450/748 | loss 0.5332 | acc 0.7957 | elapsed 195.4s


  it 500/748 | loss 0.5312 | acc 0.7952 | elapsed 217.0s


  it 550/748 | loss 0.5330 | acc 0.7913 | elapsed 238.5s


  it 600/748 | loss 0.5286 | acc 0.7893 | elapsed 260.0s


  it 650/748 | loss 0.5370 | acc 0.7847 | elapsed 281.5s


  it 700/748 | loss 0.5360 | acc 0.7866 | elapsed 303.0s


ConvNeXt Fold 5 | Epoch 7/12 | tr_loss 0.5344 tr_acc 0.7862 | va_acc 0.8964 | elapsed_fold 40.9m


  it 50/748 | loss 0.4455 | acc 0.8380 | elapsed 22.2s


  it 100/748 | loss 0.4804 | acc 0.8190 | elapsed 43.7s


  it 150/748 | loss 0.4985 | acc 0.8217 | elapsed 65.2s


  it 200/748 | loss 0.5052 | acc 0.8130 | elapsed 86.8s


  it 250/748 | loss 0.4957 | acc 0.8146 | elapsed 108.4s


  it 300/748 | loss 0.5012 | acc 0.8142 | elapsed 130.1s


  it 350/748 | loss 0.5088 | acc 0.8083 | elapsed 151.7s


  it 400/748 | loss 0.5127 | acc 0.8049 | elapsed 173.3s


  it 450/748 | loss 0.5200 | acc 0.8014 | elapsed 195.0s


  it 500/748 | loss 0.5203 | acc 0.7986 | elapsed 216.7s


  it 550/748 | loss 0.5191 | acc 0.8015 | elapsed 238.4s


  it 600/748 | loss 0.5191 | acc 0.8008 | elapsed 260.0s


  it 650/748 | loss 0.5177 | acc 0.8009 | elapsed 281.5s


  it 700/748 | loss 0.5168 | acc 0.8004 | elapsed 303.1s


ConvNeXt Fold 5 | Epoch 8/12 | tr_loss 0.5170 tr_acc 0.8015 | va_acc 0.9006 | elapsed_fold 46.8m


  it 50/748 | loss 0.5738 | acc 0.7850 | elapsed 22.2s


  it 100/748 | loss 0.5356 | acc 0.8025 | elapsed 43.6s


  it 150/748 | loss 0.5195 | acc 0.7913 | elapsed 65.0s


  it 200/748 | loss 0.5069 | acc 0.7997 | elapsed 86.4s


  it 250/748 | loss 0.5039 | acc 0.7976 | elapsed 107.8s


  it 300/748 | loss 0.5022 | acc 0.8003 | elapsed 129.2s


  it 350/748 | loss 0.4923 | acc 0.8019 | elapsed 150.6s


  it 400/748 | loss 0.4900 | acc 0.8006 | elapsed 172.1s


  it 450/748 | loss 0.4952 | acc 0.8006 | elapsed 193.5s


  it 500/748 | loss 0.4962 | acc 0.8001 | elapsed 215.0s


  it 550/748 | loss 0.4849 | acc 0.8030 | elapsed 236.6s


  it 600/748 | loss 0.4829 | acc 0.8049 | elapsed 258.1s


  it 650/748 | loss 0.4885 | acc 0.8012 | elapsed 279.7s


  it 700/748 | loss 0.4909 | acc 0.7998 | elapsed 301.2s


ConvNeXt Fold 5 | Epoch 9/12 | tr_loss 0.4909 tr_acc 0.8023 | va_acc 0.9022 | elapsed_fold 52.6m


  it 50/748 | loss 0.4616 | acc 0.8310 | elapsed 22.4s


  it 100/748 | loss 0.4681 | acc 0.8075 | elapsed 43.9s


  it 150/748 | loss 0.4861 | acc 0.8103 | elapsed 65.4s


  it 200/748 | loss 0.4771 | acc 0.8095 | elapsed 86.9s


  it 250/748 | loss 0.4846 | acc 0.8022 | elapsed 108.4s


  it 300/748 | loss 0.4776 | acc 0.8072 | elapsed 130.0s


  it 350/748 | loss 0.4673 | acc 0.8081 | elapsed 151.5s


  it 400/748 | loss 0.4650 | acc 0.8019 | elapsed 173.0s


  it 450/748 | loss 0.4661 | acc 0.8011 | elapsed 194.4s


  it 500/748 | loss 0.4637 | acc 0.8045 | elapsed 215.9s


  it 550/748 | loss 0.4650 | acc 0.8065 | elapsed 237.4s


  it 600/748 | loss 0.4609 | acc 0.8073 | elapsed 259.1s


  it 650/748 | loss 0.4656 | acc 0.8042 | elapsed 280.8s


  it 700/748 | loss 0.4696 | acc 0.8049 | elapsed 302.4s


ConvNeXt Fold 5 | Epoch 10/12 | tr_loss 0.4726 tr_acc 0.8045 | va_acc 0.9022 | elapsed_fold 58.5m


  it 50/748 | loss 0.3417 | acc 0.8840 | elapsed 22.4s


  it 100/748 | loss 0.3222 | acc 0.8930 | elapsed 43.9s


  it 150/748 | loss 0.3213 | acc 0.8903 | elapsed 65.4s


  it 200/748 | loss 0.3207 | acc 0.8902 | elapsed 86.8s


  it 250/748 | loss 0.3279 | acc 0.8898 | elapsed 108.3s


  it 300/748 | loss 0.3247 | acc 0.8892 | elapsed 129.7s


  it 350/748 | loss 0.3266 | acc 0.8883 | elapsed 151.1s


  it 400/748 | loss 0.3233 | acc 0.8885 | elapsed 172.5s


  it 450/748 | loss 0.3195 | acc 0.8899 | elapsed 193.9s


  it 500/748 | loss 0.3200 | acc 0.8900 | elapsed 215.3s


  it 550/748 | loss 0.3183 | acc 0.8905 | elapsed 236.6s


  it 600/748 | loss 0.3162 | acc 0.8912 | elapsed 258.0s


  it 650/748 | loss 0.3168 | acc 0.8910 | elapsed 279.4s


  it 700/748 | loss 0.3158 | acc 0.8915 | elapsed 300.8s


ConvNeXt Fold 5 | Epoch 11/12 | tr_loss 0.3165 tr_acc 0.8914 | va_acc 0.9046 | elapsed_fold 64.3m


  it 50/748 | loss 0.3322 | acc 0.8910 | elapsed 22.5s


  it 100/748 | loss 0.3065 | acc 0.8965 | elapsed 44.1s


  it 150/748 | loss 0.3050 | acc 0.8977 | elapsed 65.6s


  it 200/748 | loss 0.3111 | acc 0.8930 | elapsed 87.2s


  it 250/748 | loss 0.3059 | acc 0.8948 | elapsed 108.7s


  it 300/748 | loss 0.3078 | acc 0.8940 | elapsed 130.2s


  it 350/748 | loss 0.3099 | acc 0.8921 | elapsed 151.8s


  it 400/748 | loss 0.3094 | acc 0.8931 | elapsed 173.3s


  it 450/748 | loss 0.3156 | acc 0.8914 | elapsed 194.8s


  it 500/748 | loss 0.3159 | acc 0.8916 | elapsed 216.2s


  it 550/748 | loss 0.3148 | acc 0.8918 | elapsed 237.6s


  it 600/748 | loss 0.3129 | acc 0.8932 | elapsed 259.0s


  it 650/748 | loss 0.3196 | acc 0.8913 | elapsed 280.4s


  it 700/748 | loss 0.3195 | acc 0.8906 | elapsed 301.8s


ConvNeXt Fold 5 | Epoch 12/12 | tr_loss 0.3176 tr_acc 0.8918 | va_acc 0.9060 | elapsed_fold 70.1m


ConvNeXt Fold 5 done | best_va_acc 0.9060 | ckpt convnext_fold5_best.pth | fold_time 70.9m


ConvNeXt all folds done in 354.6m
ConvNeXt OOF accuracy: 0.89851
Wrote submission_convnext.csv


In [25]:
# After ConvNeXt training finishes: run 4-flip TTA for ConvNeXt and blend with B3 TTA
from pathlib import Path
import numpy as np, pandas as pd

def run_convnext_tta_and_blend():
    cnx_ckpts = [f'convnext_fold{i}_best.pth' for i in range(1,6)]
    missing = [p for p in cnx_ckpts if not Path(p).exists()]
    if missing:
        print('ConvNeXt checkpoints not all present yet, missing:', missing)
        return
    print('Running ConvNeXt 4x TTA on checkpoints:', cnx_ckpts, flush=True)
    files, logits_cnx_tta = infer_tta_model('convnext_base.fb_in22k_ft_in1k', cnx_ckpts, img_dir='test_images', batch_size=20, img_size=512)
    np.save('test_logits_convnext_tta.npy', logits_cnx_tta)
    write_submission_from_logits(files, logits_cnx_tta, 'submission_convnext_tta.csv')
    print('ConvNeXt TTA done. logits shape:', logits_cnx_tta.shape)

    # Equal-weight blend with B3 TTA if available
    b3_path = Path('test_logits_b3_tta.npy')
    cnx_path = Path('test_logits_convnext_tta.npy')
    if b3_path.exists() and cnx_path.exists():
        logits_b3 = np.load(b3_path)
        logits_cnx = np.load(cnx_path)
        if logits_b3.shape != logits_cnx.shape:
            print('Shape mismatch in logits, cannot blend:', logits_b3.shape, logits_cnx.shape)
            return
        blend = 0.5 * logits_b3 + 0.5 * logits_cnx
        pred = blend.argmax(1)
        sub = pd.DataFrame({'image_id': files, 'label': pred.astype(int)})
        sub.to_csv('submission_ensemble.csv', index=False)
        print('Wrote submission_ensemble.csv (equal-weight B3_TTA + ConvNeXt_TTA)')
    else:
        print('Missing logits for blending: b3_tta or convnext_tta not found.')

# Call run_convnext_tta_and_blend() after ConvNeXt training completes.

In [28]:
# Run ConvNeXt TTA and blend; set final submission.csv
import os, shutil
print('Starting ConvNeXt TTA + blend...', flush=True)
run_convnext_tta_and_blend()

chosen = None
if os.path.exists('submission_ensemble.csv'):
    shutil.copyfile('submission_ensemble.csv', 'submission.csv')
    chosen = 'submission_ensemble.csv'
elif os.path.exists('submission_convnext_tta.csv'):
    shutil.copyfile('submission_convnext_tta.csv', 'submission.csv')
    chosen = 'submission_convnext_tta.csv'
elif os.path.exists('submission_b3_tta.csv'):
    shutil.copyfile('submission_b3_tta.csv', 'submission.csv')
    chosen = 'submission_b3_tta.csv'
print('submission.csv set from:', chosen)

Starting ConvNeXt TTA + blend...


Running ConvNeXt 4x TTA on checkpoints: ['convnext_fold1_best.pth', 'convnext_fold2_best.pth', 'convnext_fold3_best.pth', 'convnext_fold4_best.pth', 'convnext_fold5_best.pth']


Inferred convnext_fold1_best.pth in 111.8s


Inferred convnext_fold2_best.pth in 113.4s


Inferred convnext_fold3_best.pth in 110.7s


Inferred convnext_fold4_best.pth in 110.9s


Inferred convnext_fold5_best.pth in 110.9s
Wrote submission_convnext_tta.csv
ConvNeXt TTA done. logits shape: (2676, 5)
Wrote submission_ensemble.csv (equal-weight B3_TTA + ConvNeXt_TTA)
submission.csv set from: submission_ensemble.csv


In [29]:
# Optimize ensemble weight on OOF logits and apply to TTA test logits
import numpy as np, pandas as pd, os, shutil
from pathlib import Path

def optimize_and_build_weighted_submission():
    # Load OOF logits
    b3_oof = np.load('oof_logits_b3.npy') if Path('oof_logits_b3.npy').exists() else None
    cnx_oof = np.load('oof_logits_convnext.npy') if Path('oof_logits_convnext.npy').exists() else None
    if b3_oof is None or cnx_oof is None:
        print('Missing OOF logits for optimization.'); return
    # Recreate df_shuf order to get y_true in the same order
    df_shuf = df.sample(frac=1.0, random_state=SEED).reset_index(drop=True)
    y_true = df_shuf['label'].values.astype(int)
    assert b3_oof.shape == cnx_oof.shape == (len(y_true), 5), f'OOF shape mismatch: b3 {b3_oof.shape}, cnx {cnx_oof.shape}, y {len(y_true)}'

    # Grid search weight w for ConvNeXt in [0.4, 0.6]
    best_w, best_acc = 0.5, 0.0
    for w in np.linspace(0.4, 0.6, 21):
        blend = w * cnx_oof + (1.0 - w) * b3_oof
        pred = blend.argmax(1)
        acc = (pred == y_true).mean()
        if acc > best_acc:
            best_acc, best_w = acc, float(w)
    print(f'Best OOF blend: w={best_w:.2f} (ConvNeXt weight), OOF acc={best_acc:.5f}')

    # Apply to test TTA logits
    b3_test = np.load('test_logits_b3_tta.npy') if Path('test_logits_b3_tta.npy').exists() else None
    cnx_test = np.load('test_logits_convnext_tta.npy') if Path('test_logits_convnext_tta.npy').exists() else None
    if b3_test is None or cnx_test is None:
        print('Missing test TTA logits for blending.'); return
    assert b3_test.shape == cnx_test.shape, f'Test shape mismatch: {b3_test.shape} vs {cnx_test.shape}'
    blend_test = best_w * cnx_test + (1.0 - best_w) * b3_test
    test_files = sorted([p.name for p in Path('test_images').glob('*.jpg')])
    sub = pd.DataFrame({'image_id': test_files, 'label': blend_test.argmax(1).astype(int)})
    out_csv = 'submission_weighted_ensemble.csv'
    sub.to_csv(out_csv, index=False)
    print(f'Wrote {out_csv} with w={best_w:.2f}')
    shutil.copyfile(out_csv, 'submission.csv')
    print('submission.csv updated ->', out_csv)

# Run optimization and build weighted submission
optimize_and_build_weighted_submission()

Best OOF blend: w=0.58 (ConvNeXt weight), OOF acc=0.89883
Wrote submission_weighted_ensemble.csv with w=0.58
submission.csv updated -> submission_weighted_ensemble.csv


In [30]:
# Set submission to ConvNeXt TTA-only per expert advice
import shutil, os
src = 'submission_convnext_tta.csv'
dst = 'submission.csv'
if os.path.exists(src):
    shutil.copyfile(src, dst)
    print(f'Copied {src} -> {dst}')
else:
    print(f'MISSING: {src}')

Copied submission_convnext_tta.csv -> submission.csv
