In [None]:
# Enable IPython's autoreload extension in mode 2 so edited modules are
# re-imported before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
from fastai.vision.all import *
from fastai.data.all import *

In [None]:
import torch

In [None]:
# Report which GPU is visible. Printing `torch.cuda.device(0)` only shows the
# repr of a context-manager object, and asking for a device name without a
# usable CUDA device raises, so check availability first.
if torch.cuda.is_available():
    print(f"CUDA device 0: {torch.cuda.get_device_name(0)}")
else:
    print("CUDA is not available; training would run on the CPU.")

In [None]:
from fastai4kaggle.data import *
import json
import random

In [None]:
from pathlib import Path
# --- Run configuration ---------------------------------------------------
# Competition data and the directory holding this notebook's own artefacts.
input_path = Path('/storage/iwildcam-2020-fgvc7')
my_data_path = Path('/storage/my-iwildcam2020-data')
# `/` on a Path already yields a Path, so no extra wrapping is needed.
models_path = my_data_path / 'resnet50_checkpoints/test/models'

# Training options.
model = resnet50
n_epochs = 1
use_previous_model = True  # resume from a saved checkpoint instead of training from scratch

# Message attached to the Kaggle submission.
kaggle_msg = f"resnet50, {n_epochs}"

In [None]:
def get_verified_image_files(path):
    "Return the image files found for `path` minus those that failed verification, using pickle caches."

    def _cached(cache_file, compute):
        # Load the pickled value when the cache file exists; otherwise compute and store it.
        if cache_file.exists():
            with open(cache_file, 'rb') as fh:
                return pickle.load(fh)
        value = compute()
        with open(cache_file, 'wb') as fh:
            pickle.dump(value, fh)
        return value

    # NOTE(review): the file-list cache lives under the global `my_data_path` and is
    # not keyed by `path`, whereas the failed-image cache lives under `path` itself.
    all_files = _cached(Path(my_data_path/'image_paths.pkl'), lambda: get_image_files(path))
    failed = _cached(path/'failed_imgs_lst.pkl', lambda: verify_images(all_files))

    return list(set(all_files).difference(failed))


In [None]:
def get_annotations_iwildcam(fname, prefix=None):
    """Read a COCO-style json in `fname` and return parallel lists of file names and category ids.

    Args:
        fname: path of the json file. It must contain `images` (entries with `id`
            and `file_name`) and `annotations` (entries with `image_id` and
            `category_id`).
        prefix: optional value whose string form is prepended to every returned
            file name. `None` (the default) returns the names unchanged.

    Returns:
        Tuple `(file_names, category_ids)`, in the order image ids first appear in
        `images`. If an image id has several annotations, the last one listed wins.

    Raises:
        KeyError: if an image id has no entry in `annotations`.
    """
    # Context manager so the file handle is closed even when parsing fails.
    with open(fname) as f:
        annot_dict = json.load(f)

    id2cats = {o['image_id']: o['category_id'] for o in annot_dict['annotations']}
    id2images = {o['id']: o['file_name'] for o in annot_dict['images']}

    file_names = list(id2images.values())
    if prefix is not None:
        file_names = [f'{prefix}{name}' for name in file_names]

    return file_names, [id2cats[img_id] for img_id in id2images]


In [None]:
# images, lbls = get_annotations_iwildcam(input_path/'iwildcam2020_train_annotations.json')
# img2lbls = dict(zip(images, lbls))

In [None]:
def get_label_counts(img_labels):
    """Count how many times each label occurs in `img_labels`.

    Args:
        img_labels: iterable of hashable labels, one entry per image.

    Returns:
        dict mapping each distinct label to its number of occurrences
        (an empty dict when `img_labels` is empty).
    """
    from collections import Counter  # local import keeps this cell self-contained

    # One pass over the labels instead of rescanning the whole list once per class.
    return dict(Counter(img_labels))

In [None]:
def get_freq_label_images(path, min_count=2):
    """Return the verified images for `path` whose label occurs at least `min_count` times.

    Labels are read from `iwildcam2020_train_annotations.json` under the global
    `input_path`; the counts are taken over every annotated image in that file.

    Args:
        path: directory passed to `get_verified_image_files`.
        min_count: minimum number of annotated images a label needs to be kept.

    Returns:
        list of image paths whose label meets the threshold.

    Raises:
        KeyError: if a verified image's file name is missing from the annotations.
    """
    image_files = get_verified_image_files(path)
    images, lbls = get_annotations_iwildcam(input_path/'iwildcam2020_train_annotations.json')
    img2lbls = dict(zip(images, lbls))
    label_counts = get_label_counts(lbls)

    # `img2lbls` maps a file name to the same label values that key `label_counts`,
    # so the looked-up label is used directly rather than subscripted with `[0]`.
    return [img for img in image_files
            if label_counts[img2lbls[img.name]] >= min_count]


In [None]:
def make_train_df(annotation_file, outputs_path, min_cat_count=2):
    """Build the training table of verified images whose category is frequent enough.

    Args:
        annotation_file: COCO-style json read with `get_annotations_iwildcam`.
        outputs_path: directory for the pickle caches `image_paths.pkl` and
            `failed_imgs_lst.pkl`; created (with parents) when missing.
        min_cat_count: categories with fewer annotated images than this are dropped.

    Returns:
        `pd.DataFrame` with columns `file` (image file name) and `category`,
        sorted by image path so the row order is identical on every run.

    Raises:
        KeyError: if a verified image's file name is missing from the annotations.
    """
    images, lbls = get_annotations_iwildcam(annotation_file)
    img2lbls = dict(zip(images, lbls))
    label_counts = get_label_counts(lbls)

    outputs_path.mkdir(parents=True, exist_ok=True)

    def _cached(cache_file, compute):
        # Load the pickled value when the cache file exists; otherwise compute and store it.
        # The caches are written by this notebook; do not point this at untrusted pickles.
        if cache_file.exists():
            with open(cache_file, 'rb') as f:
                return pickle.load(f)
        value = compute()
        with open(cache_file, 'wb') as f:
            pickle.dump(value, f)
        return value

    # The image directory comes from the global `input_path`.
    image_files = _cached(outputs_path/'image_paths.pkl',
                          lambda: get_image_files(input_path/'train'))
    blacklist = _cached(outputs_path/'failed_imgs_lst.pkl',
                        lambda: verify_images(image_files))

    # Sort the surviving paths: iteration order of a set of paths can differ between
    # interpreter runs, and a seeded splitter applied to rows in a different order
    # would pick a different validation set each time.
    kept = [img for img in sorted(set(image_files).difference(blacklist))
            if label_counts[img2lbls[img.name]] >= min_cat_count]

    return pd.DataFrame({'file': [img.name for img in kept],
                         'category': [img2lbls[img.name] for img in kept]})
    

    

In [None]:
# Build the training table; the pickle caches are kept under `my_data_path`.
train_df = make_train_df(
    input_path/'iwildcam2020_train_annotations.json',
    my_data_path,
    min_cat_count=2,
)

In [None]:
# Number of rows (images) in the training table.
len(train_df)

In [None]:
# Peek at the first rows of the training table.
train_df.head()

In [None]:
# Number of images per category in the training table.
train_df.category.value_counts()

In [None]:
# Data pipeline: column 0 of each row is read as the image (with the train
# directory as prefix) and column 1 as the category label.
db = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_x=ColReader(0, pref=input_path/"train"),
    get_y=ColReader(1),
    splitter=RandomSplitter(valid_pct=0.025, seed=42),
    item_tfms=Resize(800),
    batch_tfms=aug_transforms(size=460),
)

In [None]:
# Create the dataloaders from the training table with a batch size of 32.
dls = db.dataloaders(train_df, bs=32)

In [None]:
def get_valid_test_images(path, outputs_path):
    """Return the image files under `path` that pass verification, in sorted order.

    Args:
        path: directory searched with `get_image_files`.
        outputs_path: directory holding the pickle cache `failed_test_imgs_lst.pkl`
            with the images that failed `verify_images`; the cache is written on
            the first call and reused afterwards.

    Returns:
        Sorted list of the image paths that are not in the failed list.
    """
    image_files = get_image_files(path)
    cache_file = outputs_path/'failed_test_imgs_lst.pkl'

    # The cache is written by this notebook; do not point this at untrusted pickles.
    if cache_file.exists():
        with open(cache_file, 'rb') as f:
            blacklist = pickle.load(f)
    else:
        blacklist = verify_images(image_files)
        with open(cache_file, 'wb') as f:
            pickle.dump(blacklist, f)

    # Sorted so that positional lookups such as `test_images[i]` hit the same file
    # on every run; plain set iteration order can differ between interpreter runs.
    return sorted(set(image_files).difference(blacklist))

In [None]:
class ForwardTrainEvalCallback(TrainEvalCallback):
    "`TrainEvalCallback` subclass that resets the progress counters at fit start and sets training/eval mode"
    run_valid = False
    # Overwrites the learner's epoch count with 1 as soon as the callback is attached.
    def after_create(self): self.learn.n_epoch = 1

    def before_fit(self):
        "Set the iter and epoch counters to 0, put the model on the right device"
        self.learn.epoch,self.learn.loss = 0,tensor(0.)
        self.learn.train_iter,self.learn.pct_train = 0,0.
        if hasattr(self.dls, 'device'): self.model.to(self.dls.device)
        if hasattr(self.model, 'reset'): self.model.reset()

    def after_batch(self):
        "Advance `pct_train` by one batch's share of the fit and increment the iter counter"
        self.learn.pct_train += 1./(self.n_iter*self.n_epoch)
        self.learn.train_iter += 1

    def before_train(self):
        "Set the model in training mode"
        # A stray bare name (`dd`) used to sit here and raised NameError on every call.
        self.model.train()
        self.learn.training=True

    def before_validate(self):
        "Set the model in validation mode"
        self.model.eval()
        self.learn.training=False
        
class SkipToEpoch(Callback):
    "Start a fit with the progress counters set as if `s_epoch` out of `n_epoch` epochs were already done"
    def __init__(self, s_epoch, n_epoch):
        self.s_epoch = s_epoch
        # Stored under its own name; the argument was previously discarded.
        self.total_epochs = n_epoch

    def before_fit(self):
        "Set the iter and progress counters from the starting epoch"
        # Uses the values given to `__init__`; the former `self.skip_to_epoch` was never set.
        # NOTE(review): `train_iter` receives an epoch number here, while
        # `ForwardTrainEvalCallback.after_batch` increments it once per batch.
        # Confirm whether a batch count is intended.
        self.learn.train_iter, self.learn.pct_train = self.s_epoch, self.s_epoch / self.total_epochs


In [None]:
def restart_fine_tune(dls, model, metrics, path, path_to_model, n_epochs):
    "Build a learner, load the checkpoint `path_to_model/'model'`, then fine-tune it for `n_epochs`."
    learner = cnn_learner(dls, model, metrics=metrics, path=path)
    learner.load(path_to_model/'model')

    # Only the second value returned by the LR finder is used as the learning rate.
    _, steep_lr = learner.lr_find()
    checkpoint_cb = SaveModelCallback(with_opt=True)
    learner.fine_tune(n_epochs, steep_lr, cbs=checkpoint_cb)

    return learner

In [None]:
def predict_batch(learn, test_images):
    "Predict a category for each path in `test_images`; returns a frame with columns `Id` and `Category`."
    dl = learn.dls.test_dl(test_images)
    # Only the decoded predictions (third returned element) are needed.
    _, _, decoded = learn.get_preds(dl=dl, with_decoded=True)
    vocab = learn.dls.vocab

    return pd.DataFrame({
        'Id': [img.stem for img in test_images],          # file name without extension
        'Category': [vocab[idx] for idx in decoded],      # decoded index -> vocab entry
    })

In [None]:
# Both branches use the run directory (parent of `models_path`) as the learner
# `path`, so the checkpoint written by `SaveModelCallback` and the one loaded by
# `restart_fine_tune` (`models_path/'model'`) refer to the same location.
# NOTE(review): this relies on the learner's default model directory being
# `models`, which matches the last component of `models_path` — confirm.
run_dir = models_path.parent
if use_previous_model:
    learn = restart_fine_tune(dls, model, accuracy, run_dir, models_path, n_epochs)
else:
    learn = cnn_learner(dls, model, metrics=accuracy, path=run_dir)
    # Kept for its learning-rate plot; the suggested values are not used below.
    learn.lr_find()
    # Save a checkpoint so a later run with `use_previous_model = True` has
    # something to load; previously the fresh run never wrote one.
    learn.fine_tune(10, 1e-3, cbs=SaveModelCallback(with_opt=True))

In [None]:
# Evaluate the trained learner on its validation set.
learn.validate()

In [None]:
# Test images that pass verification; the failed-image cache is kept under `my_data_path`.
test_images = get_valid_test_images(input_path/'test', my_data_path)

In [None]:
# Predict a category for every test image and write the result (columns `Id`,
# `Category`) to `submission.csv` in the working directory, without the index.
test_results = predict_batch(learn, test_images)
test_results.to_csv("submission.csv", index=False)

In [None]:
!kaggle competitions submit -c iwildcam-2020-fgvc7 -f "submission.csv" -m {kaggle_msg}

In [None]:
# Interpretation helper built from the trained learner, used below to plot the top losses.
interp = ClassificationInterpretation.from_learner(learn)

In [None]:
# Show the 9 highest-loss examples in a 3-row grid.
interp.plot_top_losses(9, nrows=3, figsize=(20, 20))

In [None]:
# Display a sample of predictions from the learner.
learn.show_results()

In [None]:
# path = Path()
# learn = load_learner(path/'export.pkl')

In [None]:
# Predict one test image and show it together with the predicted class probability.
img_idx = 2
sample_path = test_images[img_idx]

pred, pred_idx, probs = learn.predict(sample_path)
img = PILImage.create(sample_path)
img.show()

print(f'Prediction: {pred}; Probability: {probs[pred_idx]:.04f}')

In [None]:
# Peek at the first rows of the submission table.
test_results.head()

In [None]:
# Inspect the category list of the training annotations. The path is built from
# `input_path` instead of repeating the literal, and the file is opened in a
# context manager so the handle is closed after parsing.
with open(input_path/'iwildcam2020_train_annotations.json') as f:
    annot_dict = json.load(f)
annot_dict['categories']