In [None]:
from fastai.vision.all import *

In [None]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Root folder of the labelled training images; point it at the complete dataset.
# `create_dataloader` reads this global, so keep the name `path` for the training folder.
path = Path("..") / "data" / "train"
path.ls()

In [None]:
# Dataset (mean, std) per channel, passed to Normalize.from_stats; all three channels
# use the same value.
bee_wing_stats = ([0.7641] * 3, [0.1771] * 3)

def label_func(f):
    """Return the class label for image file `f`: the first two characters of its name."""
    return f.name[0:2]

def create_dataloader(size, bs, resize_mode, seed=42, device='mps'):
    """Build train/valid DataLoaders over the image files under the global `path`.

    Labels come from `label_func`. Each image is resized to `size` with
    `Resize(method=resize_mode)`, and each batch is normalized with `bee_wing_stats`.

    Args:
        size: target size passed to `Resize`.
        bs: batch size.
        resize_mode: `Resize` method, e.g. 'squish'.
        seed: seed for `RandomSplitter`. A fixed seed gives every call (and so every
            progressive-resizing stage) the same validation split, given the same file
            listing. Without a seed, each call draws a new split. Images trained on in
            one stage can then land in the next stage's validation set and inflate its
            metrics. Pass None to get an unseeded split.
        device: device the DataLoaders are moved to (default 'mps').

    Returns:
        fastai `DataLoaders` on `device`.
    """
    return DataBlock(blocks=(ImageBlock, CategoryBlock),
                     get_items=get_image_files,
                     get_y=label_func,
                     splitter=RandomSplitter(seed=seed),
                     item_tfms=Resize(size, method=resize_mode),
                     batch_tfms=Normalize.from_stats(*bee_wing_stats),
                     ).dataloaders(path, bs=bs, num_workers=num_cpus(), pin_memory=True).to(device)

def create_learner(dls, model_path, model_architecture):
    """Create an ImageNet-pretrained vision learner, resuming from a checkpoint if one exists.

    Args:
        dls: fastai DataLoaders to train on.
        model_path: checkpoint name without the '.pth' suffix (str or Path), in the form
            accepted by `Learner.load` / `Learner.save`.
        model_architecture: architecture passed to `vision_learner`, e.g. `resnet152`.

    Returns:
        The `Learner`, with model weights and optimizer state loaded from
        `model_path` + '.pth' when that file exists.
    """
    cbfs = [
        # NOTE(review): fit_one_cycle's scheduler sets the learning rate every batch, so the
        # reductions this callback makes are likely overwritten; confirm it has any effect.
        ReduceLROnPlateau(monitor='valid_loss', min_delta=0.01, patience=2),
    ]
    learn = vision_learner(dls, model_architecture, pretrained=True, cbs=cbfs, metrics=accuracy)
    learn.model_dir = '.'

    # Learner.load resolves `model_path` against learn.path / learn.model_dir. Check that
    # same file: checking relative to the CWD only matches while learn.path is the CWD.
    # Otherwise resuming would be silently skipped.
    checkpoint = Path(learn.path) / learn.model_dir / f'{model_path}.pth'
    if checkpoint.exists():
        learn.load(model_path, with_opt=True)
        print(f"Loaded pre-trained weights from {model_path}")
    return learn

# Progressive-resizing schedule, one dict per training stage, run in order. Each dict
# holds: image size, batch size, epoch count, learning rate, whether the whole network
# is trained (unfreeze=True) or only the head, and the Resize method.
prog_list = [
    dict(size=128, bs=256, epochs=5, lr=1e-3, unfreeze=False, resize_mode='squish'),
    dict(size=256, bs=128, epochs=10, lr=1e-4, unfreeze=False, resize_mode='squish'),
    dict(size=312, bs=64, epochs=15, lr=1e-4, unfreeze=True, resize_mode='squish'),
    dict(size=448, bs=32, epochs=1, lr=1e-5, unfreeze=True, resize_mode='squish'),
]

# Checkpoint base name, without the '.pth' suffix.
model_path = Path("..") / "models" / "prog_resnet152"

# Run the schedule. Each stage rebuilds the DataLoaders at its image size and creates a
# fresh learner, so weights must be carried between stages through a checkpoint.
# The first stage resumes from `model_path` (if that checkpoint exists). Every later
# stage resumes from the checkpoint the previous stage just saved. Reloading
# `model_path` at every stage would throw away all earlier stages' training.
checkpoint_path = str(model_path) + '_new'  # where each finished stage is saved
resume_path = model_path                    # what the next stage starts from

for item in prog_list:
    print(f"image size: {item['size']} batch size: {item['bs']} resize type: {item['resize_mode']} unfreeze: {item['unfreeze']}")

    dls = create_dataloader(item['size'], item['bs'], item['resize_mode'])
    learn = create_learner(dls, resume_path, resnet152)

    # Frozen: only the head is trained. Unfrozen: the whole network is trained.
    if item['unfreeze']:
        learn.unfreeze()
    else:
        learn.freeze()

    learn.fit_one_cycle(item['epochs'], item['lr'])

    learn.save(checkpoint_path, with_opt=True)
    resume_path = checkpoint_path

In [None]:
# Interpretation of the final learner's predictions on the validation set (ds_idx=1).
interp = ClassificationInterpretation.from_learner(learn, ds_idx=1)

In [None]:
# Validation confusion matrix as a 12x12-inch figure rendered at 60 dpi.
interp.plot_confusion_matrix(dpi=60, figsize=(12, 12))

In [None]:
# Class pairs the model confused at least `min_val` times, most frequent first.
interp.most_confused(min_val=1)

## Evaluate the model on test data that was not used during training

In [None]:
import pandas as pd
from fastai.vision.all import *
from pathlib import Path

# Held-out test images (never seen during training). These are kept in `test_path`
# rather than overwriting `path`, which `create_dataloader` reads as the training folder.
test_path = Path("../images/test")

# Only .png files directly inside test_path; sorted so the CSV row order is reproducible.
image_paths = sorted(test_path.glob('*.png'))

batch_size = 512  # images per test DataLoader; lower it if memory is tight
num_batches = len(image_paths) // batch_size + (len(image_paths) % batch_size != 0)  # ceil division

print(len(image_paths), num_batches)

data = {
    'Filename': [],
    'True Country': [],
    'Predicted Country': [],
    'Probability': [],
}

vocab = learn.dls.vocab
for start in range(0, len(image_paths), batch_size):
    batch_paths = image_paths[start:start + batch_size]
    # test_dl uses the validation transforms of learn.dls (resize + normalize).
    # Named test_dl so the training `dls` global is not overwritten.
    test_dl = learn.dls.test_dl(batch_paths)
    preds, _ = learn.get_preds(dl=test_dl)
    # A single reduction gives both the top score and its class index.
    pred_probs, pred_idxs = preds.max(dim=-1)
    data['Filename'].extend(p.name for p in batch_paths)
    # Same labelling rule as training, so true and predicted labels are comparable.
    data['True Country'].extend(label_func(p) for p in batch_paths)
    data['Predicted Country'].extend(vocab[j] for j in pred_idxs.tolist())
    data['Probability'].extend(pred_probs.tolist())

# Convert to a DataFrame and save it as CSV.
df = pd.DataFrame(data)
df.to_csv('prediction_results.csv', index=False)

In [None]:
# Test-set confusion table built from the predictions DataFrame `df`.
# Rows are the true country (filename prefix), columns the predicted country,
# and cells the image counts.
cross_tab = pd.crosstab(df['True Country'], df['Predicted Country'])
cross_tab