In [None]:
import os
import argparse
import matplotlib.pyplot as plt
import torch
from data import StratifiedGroupKFoldDataModule
import cv2 as cv
from pathlib import Path
from models.base import BaseModel
from models.resnet import ResNetModel
from torch.nn import BCEWithLogitsLoss
from tqdm import tqdm

In [None]:
def _str2bool(value):
    """Convert a command-line token such as ``"true"``/``"False"``/``"1"`` to a bool.

    ``argparse`` with ``type=bool`` simply calls ``bool(token)``, which is True
    for every non-empty string (including ``"False"``). This converter makes
    ``--flag False`` behave as written while still accepting ``--flag True``.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in {"true", "t", "yes", "y", "1", "on"}:
        return True
    if token in {"false", "f", "no", "n", "0", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


# Command-line style configuration. The parser mirrors the training script's
# options so that the resulting ``args`` namespace can be handed to the project
# classes used further down in the notebook.
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", help="Batch size", type=int, default=16)
parser.add_argument(
    "--k",
    help="Number of folds in k-fold cross validation",
    type=int,
    default=5,
)
parser.add_argument(
    "--internal_k",
    help="Number of folds in internal k-fold cross validation",
    type=int,
    default=10,
)
parser.add_argument(
    "--data_path",
    help="Path to images",
    type=str,
    default="data",
)
parser.add_argument(
    "--metadata_path",
    help="Path to COCO metadata file",
    type=str,
    default="data/qt-coyotes-merged.json",
)
parser.add_argument(
    "--num_workers",
    help="Number of workers for dataloader",
    type=int,
    # os.cpu_count() may return None, and on machines with <= 2 CPUs the
    # subtraction would give 0 or a negative number. Keep at least one worker:
    # torch's DataLoader rejects persistent_workers=True with num_workers=0
    # (NOTE(review): assumes both options are forwarded to DataLoader).
    default=max((os.cpu_count() or 1) - 2, 1),
)
parser.add_argument(
    "--persistent_workers",
    help="If True, the data loader will not shutdown the worker processes "
    "after a dataset has been consumed once. This allows to maintain the "
    "workers Dataset instances alive.",
    type=_str2bool,
    default=True,
)
parser.add_argument(
    "--shuffle",
    help="Whether to shuffle each class's samples before splitting into "
    "batches. Note that the samples within each split will not be "
    "shuffled. This implementation can only shuffle groups that have "
    "approximately the same y distribution, no global shuffle will be "
    "performed.",
    type=_str2bool,
    default=True,
)
parser.add_argument(
    "--random_state",
    help="When shuffle is True, random_state affects the ordering of the "
    "indices, which controls the randomness of each fold for each class. "
    "Otherwise, leave random_state as None. Pass an int for reproducible "
    "output across multiple function calls.",
    type=int,
    default=42,
)
parser.add_argument(
    "--internal_group",
    help="Use grouped k-fold cross validation in internal k-fold cross validation",
    action="store_true",
)
parser.add_argument(
    "--no_crop",
    help="Disable cropping",
    action="store_true",
)
parser.add_argument(
    "--crop_size",
    help="Crop size",
    type=int,
    default=224,
)
parser.add_argument(
    "--no_external_group",
    help="Do not use grouped k-fold cross validation in external k-fold cross validation",
    action="store_true",
)
parser.add_argument(
    "--path_prefix",
    help="Prefix for image path",
    type=str
)
parser.add_argument(
    "--no_data_augmentation",
    help="Disable data augmentation",
    action="store_true",
)
parser.add_argument(
    "--learning_rate", help="Learning rate", type=float, default=1e-3
)
parser.add_argument(
    "--scheduler_factor",
    help="Factor by which the lr will be decreased",
    type=float,
    default=0.5,
)
parser.add_argument(
    "--scheduler_patience",
    help="Number of checks with no improvement after which lr will decrease",
    type=int,
    default=4,
)
parser.add_argument(
    "--resnet_model",
    help="ResNet model variant (e.g. ResNet18)",
    type=str,
    default="ResNet18",
)
parser.add_argument(
    "--nonpretrained",
    help="Do not use pretrained weights, train from scratch",
    action="store_true",
)
# An explicit argument list is passed because inside a notebook sys.argv
# belongs to the kernel launcher, not to this analysis.
args = parser.parse_args(['--metadata_path', 'data/CHIL/CHIL_uwin_mange_Marit_07242020.json'])

In [None]:
import torchvision.transforms as transforms

# Random augmentation pipeline used below to preview what augmented train/val
# batches look like: horizontal flip, small rotation, Gaussian blur and a mild
# colour jitter, applied in that order.
augmentations = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(degrees=10),
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2)),
        transforms.ColorJitter(
            brightness=0.1,
            contrast=0.1,
            saturation=0.1,
            hue=0.1,
        ),
    ]
)

In [None]:
# Criterion handed to the model wrapper; BCEWithLogitsLoss works on raw logits.
loss = BCEWithLogitsLoss()
# ResNetModel is a project class (models.resnet) built from the loss and the
# parsed options. NOTE(review): presumably args.resnet_model / args.nonpretrained
# select the backbone and its weights - confirm in models/resnet.py.
model = ResNetModel(loss, args)
# Print the feature extractor's module structure as a quick sanity check.
print(model.feature_extractor)

In [None]:
def _save_batch_grid(images, labels, stage, out_path, cols=4, min_rows=4):
    """Plot one batch as a grid of titled images and save the figure.

    Args:
        images: batch of channel-first image tensors; each one is permuted to
            channel-last before being shown.
        labels: per-image labels, used verbatim in the subplot titles.
        stage: split name ("train", "val" or "test") shown in the titles.
        out_path: ``pathlib.Path`` of the PNG to write; parent directories are
            created if needed.
        cols: number of grid columns.
        min_rows: minimum number of grid rows, so batches of up to
            ``cols * min_rows`` images keep the fixed 4x4 layout.
    """
    # Grow the grid with the batch instead of assuming 16 images: a fixed 4x4
    # grid raises IndexError as soon as --batch_size exceeds 16.
    rows = max(min_rows, (len(images) + cols - 1) // cols)
    fig, axs = plt.subplots(
        rows, cols, figsize=(4 * cols, 4 * rows), squeeze=False
    )
    for idx, (image, label) in enumerate(zip(images, labels)):
        ax = axs[idx // cols, idx % cols]
        ax.imshow(image.permute(1, 2, 0))
        ax.set_title(f"{stage} {label}")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path)
    # Close explicitly: one figure is created per batch, so leaving them open
    # would accumulate memory over the whole run.
    plt.close(fig)


# Build the cross-validation data module and dump every batch of every fold to
# disk so the splits and the augmentations can be inspected by eye.
datamodule = StratifiedGroupKFoldDataModule(args)
datamodule.prepare_data()
datamodule.setup(None)

DEBUG_PATH = Path("data/debug")

# tqdm wraps the iterable itself (not enumerate) so it can pick up the length
# when the iterable provides one.
for i, datamodule_i in enumerate(tqdm(datamodule, desc="folds")):
    train = datamodule_i.train_dataloader()
    val = datamodule_i.val_dataloader()
    test = datamodule_i.test_dataloader()
    ltrain = len(train.dataset)
    lval = len(val.dataset)
    ltest = len(test.dataset)
    print(f"Fold {i}: train={ltrain}, val={lval}, test={ltest}, total={ltrain+lval+ltest}")
    folder_path = DEBUG_PATH / f"fold_{i}"
    for dataloader, stage in zip([train, val, test], ["train", "val", "test"]):
        stage_path = folder_path / stage
        batches = tqdm(dataloader, desc=f"fold {i} {stage}", leave=False)
        for j, (X, Y) in enumerate(batches):
            # Write each image on its own, grouped in a sub-folder per label.
            # NOTE(review): cv.imwrite expects BGR channel order and normally
            # 8-bit data; confirm what dtype/range/channel order the dataset
            # yields, otherwise these files may look dark or colour-swapped.
            for k, (x, y) in enumerate(zip(X, Y)):
                image_path = stage_path / f"{y}/batch_{j}_{k}.png"
                image_path.parent.mkdir(parents=True, exist_ok=True)
                cv.imwrite(str(image_path), x.permute(1, 2, 0).numpy())

            # Overview figure of the raw batch.
            _save_batch_grid(X, Y, stage, stage_path / f"batch_{j}.png")

            # Test batches are only dumped as-is; no augmented preview.
            if stage == "test":
                continue

            # NOTE(review): the transforms are applied to the whole batch
            # tensor at once, so the same random parameters are likely shared
            # by all images of the batch - confirm this is what is wanted.
            _save_batch_grid(
                augmentations(X),
                Y,
                stage,
                stage_path / f"batch_{j}_augmented.png",
            )
    
