# Train a MobileNet classifier using weight imprinting

*Note: this notebook cannot be run as-is because the image dataset used to train the model is not included. - PC*

## Params

In [None]:
# Dataset root: one sub-folder per class, each containing that class's images.
data_folder = '/home/pi/dataset/ttt-boxes-smalldb/'
# Fraction of each class's images held out for evaluation.
test_ratio = 0.1

# Files produced after training: the imprinted model and its label map.
output_basename = 'ttt-boxes'
output_model = output_basename + '.tflite'
output_labelmap = output_basename + '.txt'

In [None]:
# Target (width, height) that every image is resized to before use.
input_shape = (224, 224)
# Edge TPU-compiled MobileNet v1 base model that new classes are imprinted onto.
pretrained_model_path = '/home/pi/models/mobilenet_v1_1.0_224_l2norm_quant_edgetpu.tflite'
# Forwarded to ImprintingEngine; presumably False means the base model's
# original classes are not kept in the output model (confirm in edgetpu docs).
keep_classes = False

## Load pre-trained model

In [None]:
from edgetpu.learn.imprinting.engine import ImprintingEngine

# Load the pre-trained base model into an imprinting engine; the training
# cell later calls train_engine.train_all() and train_engine.save_model().
# NOTE(review): keep_classes semantics (whether base classes survive) come
# from the edgetpu API — confirm there before changing it.
train_engine = ImprintingEngine(pretrained_model_path, keep_classes=keep_classes)

## Load train/test data

In [None]:
import os

# Scan data_folder (one sub-folder per class) and split each class's images
# into a held-out test list and a training list.
#   train_set / test_set: category name -> list of full image paths
#   labels_map:           label id (0, 1, ...) -> category name
train_set, test_set = {}, {}
labels_map = {}

ci = 0

# os.listdir() returns entries in arbitrary order; sort so each category gets
# the same label id (and the same train/test split) on every run and machine.
for category in sorted(os.listdir(data_folder)):
    category_dir = os.path.join(data_folder, category)
    # Skip plain files and hidden directories such as `.ipynb_checkpoints`,
    # which Jupyter can create next to the data.
    if category.startswith('.') or not os.path.isdir(category_dir):
        continue

    # Full paths of the visible regular files in this category. Hidden files
    # (e.g. `.DS_Store`) are not images and would fail to open later.
    images = [
        os.path.join(category_dir, f)
        for f in sorted(os.listdir(category_dir))
        if not f.startswith('.') and os.path.isfile(os.path.join(category_dir, f))
    ]

    # At least one image is held out for testing and at least one must remain
    # for training; a category with fewer than two images would produce an
    # empty training class, so skip it instead.
    if len(images) < 2:
        print(f'Skipping label {category}: needs at least 2 images, found {len(images)}')
        continue

    # Hold out test_ratio of the images, but never fewer than one and never
    # so many that the training list becomes empty.
    k = min(max(int(test_ratio * len(images)), 1), len(images) - 1)

    test_set[category] = images[:k]
    train_set[category] = images[k:]

    # Ids follow train_set's insertion order, which is the order the training
    # data is built in from train_set.values().
    labels_map[ci] = category
    ci += 1

for c, train_imgs in train_set.items():
    print(f'Label {c}: train imgs {len(train_imgs)} - test imgs {len(test_set[c])}')

In [None]:
import numpy as np

from PIL import Image


def prepare_image(image_list, input_shape):
    """Load a list of images as flattened RGB pixel vectors.

    Each file is opened, converted to RGB, resized to ``input_shape`` with
    nearest-neighbour resampling, and flattened to a 1-D array.

    Args:
        image_list: Paths of the image files to load.
        input_shape: Target ``(width, height)`` passed to ``PIL.Image.resize``.

    Returns:
        A numpy array with one flattened image per row, in ``image_list``
        order. An empty ``image_list`` yields an empty 1-D array.
    """
    def _flat_pixels(path):
        # The resized image is an independent copy, so the source file can be
        # closed before its pixels are extracted.
        with Image.open(path) as source:
            resized = source.convert('RGB').resize(input_shape, Image.NEAREST)
        return np.asarray(resized).flatten()

    return np.array([_flat_pixels(path) for path in image_list])

print('Processing train images...')
# One array of flattened images per category, in train_set's insertion order
# (the same order in which labels_map assigns ids).
train_data = list(map(lambda imgs: prepare_image(imgs, input_shape), train_set.values()))
print('Done!')

## Train model

In [None]:
print('Start training...')
# Imprint the new classes from train_data (one array per category).
# NOTE(review): the evaluation cell assumes entry i of train_data becomes class
# id i (the key used in labels_map) — confirm against the edgetpu train_all
# docs, especially if keep_classes is ever set to True.
train_engine.train_all(train_data)
print('Done!')

## Save trained model

In [None]:
# Write the imprinted model to disk.
train_engine.save_model(output_model)

# Write the label map as one "<id> <label>" line per class. Labels are folder
# names and may be non-ASCII, so use an explicit UTF-8 encoding rather than the
# platform default (which can be ASCII under a C/POSIX locale).
with open(output_labelmap, 'w', encoding='utf-8') as f:
    for label_id, label in labels_map.items():
        f.write(f'{label_id} {label}\n')

## Evaluate our model

In [None]:
from edgetpu.classification.engine import ClassificationEngine

# Load the freshly saved model and measure top-1 accuracy on the held-out images.
test_engine = ClassificationEngine(output_model)

total = 0
nb_images = 0


for category, image_list in test_set.items():
    correct = 0

    for img_path in image_list:
        # test_set already stores full paths (built with os.path.join when the
        # data was loaded); re-joining them onto data_folder breaks when
        # data_folder is relative. `with` closes each file, and converting to
        # RGB matches the preprocessing applied to the training images.
        with Image.open(img_path) as img:
            results = test_engine.classify_with_image(img.convert('RGB'), threshold=0.1, top_k=1)

        # An empty result means no class scored above the threshold; count it
        # as a miss instead of raising IndexError.
        if results and labels_map[results[0][0]] == category:
            correct += 1

    print(f'Evaluating category "{category}": {correct}/{len(image_list)}')

    total += correct
    nb_images += len(image_list)

print(f'Total {total}/{nb_images}')