In [None]:
import collections
import json
import pickle

from fastai.vision.all import *
from fastai.data.all import *
from fastai.vision.widgets import ImageClassifierCleaner

In [None]:
import torch

In [None]:
# Sanity-check GPU visibility before training. `torch.cuda.device(0)` is a
# context manager, so printing it only shows an object repr; report the active
# device index and name instead, and don't crash on CPU-only machines.
if torch.cuda.is_available():
    current = torch.cuda.current_device()
    print(current)
    print(torch.cuda.get_device_name(current))
else:
    print('CUDA is not available; training will run on the CPU.')

In [None]:
from pathlib import Path

# `input_path`: the iWildCam 2020 competition data (annotations json, `train` images).
# `my_data_path`: this notebook's own directory for cached pickles.
input_path, my_data_path = (
    Path(location)
    for location in ('/storage/iwildcam-2020-fgvc7', '/storage/my-iwildcam2020-data')
)


In [None]:
def get_verified_image_files(path, cache_path=None):
    """Return the image files under `path` that passed `verify_images`, using pickle caches.

    Both expensive steps are cached as pickles in the cache directory:
    ``image_paths.pkl`` holds the result of ``get_image_files(path)`` and
    ``failed_imgs_lst.pkl`` holds the result of ``verify_images`` on those files.

    Args:
        path: directory scanned with fastai's ``get_image_files`` on a cache miss.
        cache_path: directory for both cache pickles (created if missing);
            defaults to the notebook-level ``my_data_path``.

    Returns:
        Sorted list of image paths that are not in the failed-image list.

    Note:
        The caches are not keyed on ``path``: delete them after changing the
        image directory. Only unpickle files you created yourself.
    """
    # Keep both caches together in a writable directory instead of writing
    # into the (input) image directory.
    cache_dir = Path(my_data_path if cache_path is None else cache_path)
    cache_dir.mkdir(parents=True, exist_ok=True)

    paths_pkl = cache_dir/'image_paths.pkl'
    if paths_pkl.exists():
        with open(paths_pkl, 'rb') as f:
            files = pickle.load(f)
    else:
        files = get_image_files(path)
        with open(paths_pkl, 'wb') as f:
            pickle.dump(files, f)

    failed_pkl = cache_dir/'failed_imgs_lst.pkl'
    if failed_pkl.exists():
        with open(failed_pkl, 'rb') as f:
            blacklist = pickle.load(f)
    else:
        blacklist = verify_images(files)
        with open(failed_pkl, 'wb') as f:
            pickle.dump(blacklist, f)

    # Sorting makes the result reproducible: iterating a set of Paths follows
    # string hashing, which is randomized per interpreter process.
    return sorted(set(files).difference(blacklist))


In [None]:
def get_annotations_iwildcam(fname, prefix=None):
    """Open a COCO-style json in `fname` and return per-image filenames and category names.

    Args:
        fname: path to a COCO-style json containing ``images``, ``annotations``
            and ``categories`` lists.
        prefix: optional string prepended to every returned filename; ``None``
            leaves the filenames unchanged.

    Returns:
        ``(filenames, labels)``: two parallel lists in the order images appear
        in the json's ``images`` list. ``labels[i]`` is the list of category
        *names* annotated for image ``i``; it is empty when an image has no
        annotation.
    """
    # Context manager so the file handle is closed; JSON is UTF-8 by spec.
    with open(fname, encoding='utf-8') as f:
        annot_dict = json.load(f)

    classes = {o['id']: o['name'] for o in annot_dict['categories']}
    id2cats = collections.defaultdict(list)
    for o in annot_dict['annotations']:
        id2cats[o['image_id']].append(classes[o['category_id']])

    prefix = '' if prefix is None else prefix
    id2images = {o['id']: prefix + o['file_name'] for o in annot_dict['images']}
    return [id2images[k] for k in id2images], [id2cats[k] for k in id2images]


In [None]:
# Notebook-level view of the training annotations: file name -> list of category names.
# (`get_freq_label_images` and `make_train_df` build their own copies.)
images, lbls = get_annotations_iwildcam(input_path / 'iwildcam2020_train_annotations.json')
img2lbls = {file_name: cats for file_name, cats in zip(images, lbls)}

In [None]:
def get_label_counts(img_labels):
    """Count images per class, using only the first label of each image.

    Args:
        img_labels: iterable of per-image label lists (e.g. the second list
            returned by ``get_annotations_iwildcam``).

    Returns:
        dict mapping class name -> number of images whose first label is that class.

    Raises:
        IndexError: if any image has an empty label list.
    """
    # One O(n) pass; the previous per-class rescan was O(n * n_classes),
    # which is slow for hundreds of thousands of images.
    return dict(collections.Counter(labels[0] for labels in img_labels))

In [None]:
def get_freq_label_images(path, min_count=2, annotation_file=None):
    """Return verified images under `path` whose class has at least `min_count` images.

    Class frequency is counted from the first label of every annotated image
    in `annotation_file`, not only from the verified files.

    Args:
        path: image directory passed to ``get_verified_image_files``.
        min_count: minimum number of images a class needs for its images to be kept.
        annotation_file: COCO-style annotation json; defaults to the training
            annotations under the notebook-level ``input_path``.

    Returns:
        List of image paths, in the order returned by ``get_verified_image_files``.

    Raises:
        KeyError: if an image's file name does not appear in the annotations.
        IndexError: if an annotated image has an empty label list.
    """
    if annotation_file is None:
        annotation_file = input_path/'iwildcam2020_train_annotations.json'

    image_files = get_verified_image_files(path)
    images, lbls = get_annotations_iwildcam(annotation_file)
    img2lbls = dict(zip(images, lbls))
    label_counts = get_label_counts(lbls)

    return [img for img in image_files
            if label_counts[img2lbls[img.name][0]] >= min_count]


In [None]:
def _load_or_compute_pickle(pkl_path, compute):
    """Return the object pickled at `pkl_path`; on a miss, build it with `compute()` and pickle it there.

    Only unpickle files you created yourself: unpickling can execute arbitrary code.
    """
    if pkl_path.exists():
        with open(pkl_path, 'rb') as f:
            return pickle.load(f)
    result = compute()
    pkl_path.parent.mkdir(parents=True, exist_ok=True)
    with open(pkl_path, 'wb') as f:
        pickle.dump(result, f)
    return result


def make_train_df(annotation_file, outputs_path, min_cat_count=2, images_path=None):
    """Build a ``(file, category)`` DataFrame of usable, labelled training images.

    Lists images under `images_path` and drops those that fail ``verify_images``.
    Then drops images whose class (first label) has fewer than `min_cat_count`
    annotated images. Each remaining file is labelled with its first category name.

    Args:
        annotation_file: COCO-style json read by ``get_annotations_iwildcam``.
        outputs_path: directory for the cached ``image_paths.pkl`` and
            ``failed_imgs_lst.pkl`` pickles (created if missing). The caches
            are not keyed on `images_path`; delete them if it changes.
        min_cat_count: classes with fewer annotated images than this are excluded.
        images_path: directory scanned for images on a cache miss; defaults to
            the ``train`` folder next to `annotation_file`.

    Returns:
        pandas DataFrame with columns ``file`` (bare file name) and
        ``category``, ordered by image path.

    Raises:
        KeyError: if an image file's name does not appear in the annotations.
        IndexError: if an annotated image has an empty label list.
    """
    annotation_file, outputs_path = Path(annotation_file), Path(outputs_path)
    if images_path is None:
        # NOTE(review): assumes images sit in `train/` beside the json, matching
        # the DataBlock's `input_path/"train"` prefix -- confirm for other layouts.
        images_path = annotation_file.parent/'train'

    images, lbls = get_annotations_iwildcam(annotation_file)
    img2lbls = dict(zip(images, lbls))
    label_counts = get_label_counts(lbls)

    # Listing and verifying every image are slow, so both are cached on disk.
    image_files = _load_or_compute_pickle(
        outputs_path/'image_paths.pkl', lambda: get_image_files(images_path))
    blacklist = _load_or_compute_pickle(
        outputs_path/'failed_imgs_lst.pkl', lambda: verify_images(image_files))

    # Sorted for a reproducible row order (set iteration order over Paths is not).
    usable = sorted(set(image_files).difference(blacklist))
    # Keep only images whose class has at least `min_cat_count` annotated images.
    kept = [img for img in usable
            if label_counts[img2lbls[img.name][0]] >= min_cat_count]

    return pd.DataFrame({'file': [img.name for img in kept],
                         'category': [img2lbls[img.name][0] for img in kept]})
    

    

In [None]:
# Training table of (file, category); classes with fewer than 5 images are dropped.
train_df = make_train_df(
    annotation_file=input_path / 'iwildcam2020_train_annotations.json',
    outputs_path=my_data_path,
    min_cat_count=5,
)

In [None]:
# Preview the first rows of the (file, category) training table.
train_df.head()

In [None]:
# Fixed so the train/validation split is identical on every kernel restart.
SPLIT_SEED = 42

# Each row of `train_df` gives an image read from `input_path/train/<file>`,
# labelled with its `category`. Columns are addressed by name, which is robust
# to column reordering. 20% of rows go to validation.
wildlife = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_x=ColReader('file', pref=input_path/"train"),
    get_y=ColReader('category'),
    item_tfms=Resize(128),
    splitter=RandomSplitter(valid_pct=0.2, seed=SPLIT_SEED))

In [None]:
# Build train/validation DataLoaders from the DataFrame (no batch size given,
# so fastai's default is used).
dls = wildlife.dataloaders(train_df)

In [None]:
# Visual sanity check: display up to 20 images with their labels.
dls.show_batch(max_n=20)

In [None]:
# Image classifier on a ResNet-18 backbone, reporting error_rate each epoch.
# NOTE(review): newer fastai releases rename `cnn_learner` to `vision_learner`;
# confirm against the installed version.
learn = cnn_learner(dls, resnet18, metrics=error_rate)

In [None]:
# Fine-tune for a single epoch (fine_tune first trains with the pretrained body
# frozen, then unfreezes the whole network).
learn.fine_tune(1)

In [None]:
# Save the trained learner to `export.pkl` under `learn.path` for later inference.
learn.export()

In [None]:
# Reload the exported model for inference under a separate name. An exported
# learner carries empty DataLoaders, so the interpretation and cleaning cells
# below must keep using the trained `learn`, which still holds the validation data.
path = Path()
learn_inf = load_learner(path/'export.pkl')

In [None]:
# Collect predictions and losses on the validation set for error analysis.
interp = ClassificationInterpretation.from_learner(learn)

In [None]:
# Confusion matrix of the validation predictions. It can become unreadable
# when there are many classes; NOTE(review): `interp.most_confused()` gives a compact list.
interp.plot_confusion_matrix()

In [None]:
# Show the 5 highest-loss predictions in a single row.
interp.plot_top_losses(5, nrows=1)

In [None]:
# Interactive widget for reviewing images per category and split, to mark them
# for deletion or relabelling. `ImageClassifierCleaner` is defined in
# `fastai.vision.widgets` (make sure it is imported) and requires ipywidgets.
cleaner = ImageClassifierCleaner(learn)
cleaner