In [None]:
import warnings
warnings.simplefilter("ignore", FutureWarning)

In [None]:
from mdlw.augment import Augmenter
from mdlw.dataset import ImageDataset
from mdlw.utils.data import get_image_paths, make_class_map, train_val_split
from mdlw.utils.misc import load_cfg

# Experiment configuration (data location, split ratio, seed, image size).
cfg = load_cfg(path="../configs/config.yaml")

# All image files under the data directory, plus the class map built from the
# same directory (presumably class name -> index; confirm in mdlw.utils.data).
img_paths = get_image_paths(cfg.data_dir)
class_map = make_class_map(cfg.data_dir)

# Seeded split of the image paths into training and validation subsets.
train_img_paths, val_img_paths = train_val_split(
    img_paths, val_ratio=cfg.val_ratio, seed=cfg.seed
)

# Training images go through the training-time Augmenter; validation images
# use the non-training variant at the same target size.
train_dataset = ImageDataset(img_paths=train_img_paths, class_map=class_map,
                             transform=Augmenter(train=True, image_size=cfg.image_size))
val_dataset = ImageDataset(img_paths=val_img_paths, class_map=class_map,
                           transform=Augmenter(train=False, image_size=cfg.image_size))

In [None]:
from mdlw.utils.visualize import show_image_grid

# Eyeball a grid of samples from each split: training images (with augmentation)
# first, then validation images.
for split_dataset, grid_title in (
    (train_dataset, 'Train dataset images'),
    (val_dataset, 'Validation dataset images'),
):
    show_image_grid(split_dataset, class_map, title=grid_title)

In [None]:
import torch
from mdlw.utils.misc import get_device

# Checkpoint to inspect.
MODEL_PATH = "../runs/run_5/best_model.pt"

device = get_device()

# The checkpoint holds a whole pickled nn.Module (the result is used directly via
# .parameters() and .eval()), not just a state_dict. It must therefore be loaded
# with weights_only=False: since PyTorch 2.6 torch.load defaults to
# weights_only=True and would refuse to unpickle the module.
# Unpickling can execute arbitrary code, so only load checkpoints you trust.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False)

print(f'Model param count: {sum(p.numel() for p in model.parameters())}')
# Put the model in evaluation mode (affects layers such as dropout / batch norm).
# The returned module is the cell's last expression, so its repr is displayed.
model.eval()

In [None]:
import random
from mdlw.utils.visualize import show_image
from mdlw.utils.data import reverse_class_map

# Draw one validation sample at random (re-run the cell to see a different one).
img, label = random.choice(val_dataset)

# Pure inference: disable autograd so no computation graph is built for this
# forward pass.
with torch.no_grad():
    logits = model(img.unsqueeze(0).to(device))  # unsqueeze(0) adds a batch dim of 1
pred = torch.argmax(logits[0]).item()

# Map label and predicted index back to readable class names for the title.
reversed_map = reverse_class_map(class_map)
show_image(img, title=f'label: {reversed_map[label]}, prediction: {reversed_map[pred]}')

In [None]:
from mdlw.utils.visualize import visualize_fmap

# Layer handed to visualize_fmap for the sample `img` drawn above (by its name,
# presumably the layer whose feature maps are plotted).
FMAP_LAYER = 'conv6'
visualize_fmap(model, img, device=device, layer_name=FMAP_LAYER)

In [3]:
from mdlw.model import ImageClassifierV3 as CLS
# Build a fresh ImageClassifierV3 only to report its parameter count. It gets its
# own name so the checkpoint loaded into `model` earlier is not silently replaced
# by a network with freshly initialised weights. Otherwise, re-running the
# prediction / feature-map cells after this one would use an untrained model.
fresh_model = CLS(10)  # NOTE(review): 10 is presumably the number of classes — confirm it matches len(class_map)
print(f'Model param count: {sum(p.numel() for p in fresh_model.parameters())}')


Model param count: 1972792
