In [None]:
from fastai.vision.all import *

In [None]:
import warnings
# Suppress every Python warning for the rest of the session to keep notebook output short.
# NOTE(review): this also hides potentially useful warnings; consider narrowing by category.
warnings.filterwarnings("ignore")

In [None]:
# Root folder of the (small) training image set, relative to the notebook's working dir.
path = Path("..") / "data" / "train_small"
path.ls()

In [None]:
# Dataset per-channel (mean, std) used to normalize batches.
# All three channels share the same values; presumably the images are grayscale
# stored as 3-channel RGB — TODO confirm.
bee_wing_stats = ([0.7641, 0.7641, 0.7641], [0.1771, 0.1771, 0.1771])
def label_func(f):
    """Return the class label of image file *f*: the first two characters of its file name."""
    file_name = f.name
    return file_name[0:2]

def create_dataloader(size, bs, resize_mode, seed=42, source=None, device='mps'):
    """Build train/validation DataLoaders for the labelled image folder.

    Args:
        size: target size passed to ``Resize``.
        bs: batch size.
        resize_mode: ``Resize`` method, e.g. ``'squish'``.
        seed: seed for ``RandomSplitter`` so every call yields the same
            train/validation split (given the same set of files). Pass ``None``
            to draw a fresh random split on each call.
        source: folder to collect images from; defaults to the global ``path``.
        device: device the DataLoaders are moved to.

    Returns:
        fastai ``DataLoaders`` with a random train/validation split.
    """
    if source is None:
        source = path

    block = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        get_items=get_image_files,
        get_y=label_func,
        # A fixed seed matters: create_learner resumes from saved weights, so an
        # unseeded split would move earlier training images into the validation
        # set on later runs and inflate the reported metrics.
        splitter=RandomSplitter(seed=seed),
        item_tfms=Resize(size, method=resize_mode),
        batch_tfms=Normalize.from_stats(*bee_wing_stats),
    )
    return block.dataloaders(
        source, bs=bs, num_workers=num_cpus(), pin_memory=True
    ).to(device)

def create_learner(dls, model_path, model_architecture):
    """Create a pretrained vision learner, resuming from saved weights when present.

    Args:
        dls: DataLoaders to train on.
        model_path: checkpoint path *without* the ``.pth`` extension; the learner
            is resumed from ``<model_path>.pth`` (model + optimizer state) when
            that file exists.
        model_architecture: model constructor passed to ``vision_learner``
            (e.g. ``resnet152``).

    Returns:
        The configured ``Learner``.
    """
    cbfs = [
        ShowGraphCallback,
        # NOTE(review): fit_one_cycle's scheduler sets the LR every batch, which
        # presumably overrides this callback's reductions — confirm it has effect.
        ReduceLROnPlateau(monitor='valid_loss', min_delta=0.01, patience=2),
    ]
    learn = vision_learner(dls, model_architecture, pretrained=True, cbs=cbfs, metrics=accuracy)
    # fastai resolves relative save/load names against learn.path / model_dir.
    learn.model_dir = '.'

    # Make the path absolute so the existence check and learn.load refer to the same
    # file regardless of learn.path (an absolute name is not re-rooted by fastai).
    weights = Path(model_path).absolute()
    if Path(f"{weights}.pth").exists():
        learn.load(weights, with_opt=True)
        print(f"Loaded pre-trained weights from {model_path}")
    return learn


# Checkpoint stem (no extension); create_learner resumes from "<model_path>.pth" if present.
model_path = Path("../models") / "prog_resnet152"

# 448px images (squish resize), batch size 32, on a pretrained ResNet-152.
dls = create_dataloader(size=448, bs=32, resize_mode='squish')
learn = create_learner(dls, model_path=model_path, model_architecture=resnet152)

# A single one-cycle epoch with peak learning rate 1e-3.
learn.fit_one_cycle(n_epoch=1, lr_max=1e-3)

# Written under a separate "_new" name, so the checkpoint resumed above is not
# overwritten; it must be renamed to "<model_path>.pth" to be resumed next time.
learn.save(f"{model_path}_new", with_opt=True)

In [None]:
# Collect predictions/losses on the validation set (ds_idx=1 is fastai's default).
interp = ClassificationInterpretation.from_learner(learn, ds_idx=1)

In [None]:
# Confusion matrix over the interpreted predictions.
interp.plot_confusion_matrix(
    figsize=(12, 12),
    dpi=60,
)

In [None]:
# List (actual, predicted, count) class pairs that were confused at least once.
interp.most_confused(min_val=1)