In [None]:
import warnings
warnings.simplefilter("ignore", FutureWarning)

In [None]:
import random

import torch

# --- Reproducibility ---------------------------------------------------------
SEED = 42
# The train/val split is seeded explicitly via train_val_split(seed=SEED), but
# later cells also draw from the global RNGs (random.choice when picking samples
# to display; torch's RNG for any torch-based randomness, presumably including
# the train-mode Augmenter -- TODO confirm). Seed both so Restart & Run All
# reproduces the same figures.
random.seed(SEED)
torch.manual_seed(SEED)

# --- Data / model shape ------------------------------------------------------
NUM_CLASSES = 10
IMAGE_SIZE = 32
INPUT_CHANNELS = 3

# --- Data loading / training -------------------------------------------------
VAL_RATIO = 0.2
BATCH_SIZE = 64
NUM_WORKERS = 0
NUM_EPOCHS = 10
LEARNING_RATE = 0.01

# --- Paths -------------------------------------------------------------------
DATA_DIR = '../data/cifar10'
LOG_DIR = '../runs'

In [None]:
from mdlw.augment import Augmenter
from mdlw.dataset import ImageDataset
from mdlw.utils.data import get_image_paths, make_class_map, train_val_split

# Collect every image path and the class-name mapping, then split the paths
# reproducibly into train/validation subsets.
img_paths = get_image_paths(DATA_DIR)
class_map = make_class_map(DATA_DIR)
train_img_paths, val_img_paths = train_val_split(img_paths, val_ratio=VAL_RATIO, seed=SEED)

# Both splits share the class map; only the Augmenter mode differs
# (train=True for the training split, train=False for validation).
train_dataset, val_dataset = (
    ImageDataset(
        image_paths=split_paths,
        class_map=class_map,
        transform=Augmenter(train=is_train, image_size=IMAGE_SIZE),
    )
    for split_paths, is_train in ((train_img_paths, True), (val_img_paths, False))
)

In [None]:
from mdlw.utils.visualize import show_image_grid

# Eyeball a few samples from each split: training images pass through the
# train-mode Augmenter, validation images through the eval-mode one.
NUM_PREVIEW_IMAGES = 9
show_image_grid(train_dataset, num_images=NUM_PREVIEW_IMAGES)
show_image_grid(val_dataset, num_images=NUM_PREVIEW_IMAGES)

In [None]:
import torch
from mdlw.utils.misc import get_device

# Checkpoint from a previous training run. Built from LOG_DIR (config cell) so
# the runs directory is defined in exactly one place.
RUN_NAME = "run_1"
MODEL_PATH = f"{LOG_DIR}/{RUN_NAME}/best_model.pt"

device = get_device()
# The checkpoint stores the whole model object (not just a state_dict), which
# requires weights_only=False. SECURITY: that makes torch.load unpickle
# arbitrary Python objects -- only load checkpoints you produced or trust.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False)

num_params = sum(p.numel() for p in model.parameters())
print(f'Total parameters of the model: {num_params}')
# Switch to inference behaviour for the evaluation cells below.
model.eval()

In [None]:
import random
import torch
from mdlw.utils.visualize import show_image
from mdlw.utils.data import reverse_class_map


def show_prediction(model, dataset, device, correct=True, max_tries=1000):
    """Display a random sample whose prediction is (in)correct.

    Repeatedly draws a random ``(img, label)`` pair from ``dataset``, runs
    ``model`` on it, and shows the first sample for which
    ``(label == prediction) == correct``, titled with both class names.

    Args:
        model: Callable mapping a batch of images to logits. The prediction is
            the argmax of the first (only) row. This function does not call
            ``model.eval()``; the caller is expected to have done so.
        dataset: Indexable sequence of ``(img, label)`` pairs with a
            ``class_map`` attribute (converted via ``reverse_class_map`` to
            look up class names).
        device: Device the image batch is moved to before the forward pass.
        correct: If True, look for a correctly classified sample; if False,
            look for a misclassified one.
        max_tries: Maximum number of random draws before giving up. This
            prevents an endless loop when no sample matches (e.g. a perfect
            model with ``correct=False``).

    Raises:
        IndexError: If ``dataset`` is empty (raised by ``random.choice``).
        RuntimeError: If no matching sample is found within ``max_tries`` draws.
    """
    reversed_map = reverse_class_map(dataset.class_map)
    # Inference only: don't build an autograd graph for every draw.
    with torch.no_grad():
        for _ in range(max_tries):
            img, label = random.choice(dataset)
            # int() accepts both plain ints and 0-d tensors, so the label
            # compares cleanly with `pred` and works as a reversed_map key.
            label = int(label)
            logits = model(img.unsqueeze(0).to(device))
            pred = torch.argmax(logits[0]).item()

            if (label == pred) == correct:
                show_image(img, title=f'label: {reversed_map[label]}, prediction: {reversed_map[pred]}')
                return

    wanted = 'correctly' if correct else 'incorrectly'
    raise RuntimeError(f'No {wanted} classified sample found in {max_tries} random draws')

In [None]:
# A validation sample the model classifies correctly.
show_prediction(model, val_dataset, device=device, correct=True)

In [None]:
# A validation sample the model gets wrong.
show_prediction(model, val_dataset, device=device, correct=False)

In [None]:
from mdlw.utils.visualize import visualize_fmap

# Pick a random validation image and plot the feature maps of layer 'bn2'
# (with use_act=True) for it.
img, label = random.choice(val_dataset)
visualize_fmap(model, img, device=device, layer_name='bn2', use_act=True)