In [None]:
!pip install timm
!pip install torch-summary
!pip uninstall -y pillow
!pip install pillow-simd

In [None]:
import json 
import collections
from tqdm import tqdm
import pickle

from PIL import Image
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import timm
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as F

# Use the first CUDA GPU when torch reports one is available; otherwise fall back to the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

## Dataset Preprocessing

In [None]:
# Load the MegaDetector results and the test-set image information.
with open('../input/iwildcam2021-fgvc8/metadata/iwildcam2021_megadetector_results.json', encoding='utf-8') as json_file:
    detections = json.load(json_file)

with open('../input/iwildcam2021-fgvc8/metadata/iwildcam2021_test_information.json', encoding='utf-8') as json_file:
    test_anno = json.load(json_file)

# Detections with a confidence below this value are discarded.
conf_thresh = 0.7

# image id -> bounding boxes of the detections that survive the filter
# (confidence at or above the threshold AND category '1').
# Kept as a defaultdict so that looking up an unknown image id yields [].
quick_lookup_detections = collections.defaultdict(list)
for image_entry in tqdm(detections['images']):
    quick_lookup_detections[image_entry['id']] = [
        candidate['bbox']
        for candidate in image_entry['detections']
        if candidate['conf'] >= conf_thresh and candidate['category'] == '1'
    ]

# sequence id -> {file name -> surviving bounding boxes for that file}
test_detections = collections.defaultdict(dict)
for image_info in tqdm(test_anno['images']):
    # Strip the last four characters of the file name (presumably the extension,
    # e.g. '.jpg') to obtain the key used in the detection lookup.
    image_id = image_info['file_name'][:-4]
    test_detections[image_info['seq_id']][image_info['file_name']] = quick_lookup_detections[image_id]

## Model

In [None]:
class Cnn_model(nn.Module):
    """A timm backbone whose own classifier is replaced by a new linear head.

    Args:
        backbone: model name handed to ``timm.create_model``.
        out_dim: number of output units of the new linear head.
        pretrained: forwarded unchanged to ``timm.create_model``.
    """

    def __init__(self, backbone, out_dim, pretrained=False):
        super().__init__()
        # The attribute names `bnet` and `myfc` become state_dict key prefixes,
        # so they must stay as they are for existing checkpoints to load.
        self.bnet = timm.create_model(backbone, pretrained=pretrained)
        feature_dim = self.bnet.classifier.in_features
        self.myfc = nn.Linear(feature_dim, out_dim)
        # Turn the backbone's own classifier into a pass-through so that the
        # backbone output is fed straight into `myfc`.
        self.bnet.classifier = nn.Identity()

    def forward(self, x):
        """Run ``x`` through the backbone and then through the linear head."""
        return self.myfc(self.bnet(x))

## Dataloader pipelines

In [None]:
# Side length used for the square `transforms.Resize((size, size))` step of the preprocessing pipeline.
size = 456

# Custom transform
class SquarePad:
    """Callable transform that pads the shorter image dimension towards a square.

    Padding uses the 'symmetric' mode of ``F.pad``. Half of the missing extent
    (rounded down) is added on each side, so when the width/height difference
    is odd the result is one pixel short of an exact square.
    """

    def __call__(self, image):
        width, height = image.size
        side = max(width, height)
        pad_x = (side - width) // 2
        pad_y = (side - height) // 2
        # Padding tuple is (left, top, right, bottom).
        return F.pad(image, (pad_x, pad_y, pad_x, pad_y), 0, 'symmetric')

# Preprocessing applied to every cropped detection:
# pad towards a square, resize to size x size, convert to a tensor, normalise.
preprocessing_steps = [
    SquarePad(),
    transforms.Resize((size, size)),
    # transforms.RandomHorizontalFlip(p=0.5),  # kept disabled
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.35318278, 0.35319862, 0.35318562],
        std=[0.21479782, 0.21479183, 0.21479116],
    ),
]
transform = transforms.Compose(preprocessing_steps)

def get_crop_area(bbox, image_size):
    """Scale a relative ``(x, y, w, h)`` box to a crop rectangle in image units.

    Args:
        bbox: four numbers ``(x, y, w, h)``. The x-values are multiplied by
            ``image_size[0]`` and the y-values by ``image_size[1]``.
        image_size: two-element sequence whose first entry scales x and whose
            second entry scales y (e.g. a PIL ``Image.size``).

    Returns:
        Tuple ``(x * W, y * H, (x + w) * W, (y + h) * H)``.
    """
    left, top, box_w, box_h = bbox
    scale_x, scale_y = image_size[0], image_size[1]
    right = left + box_w
    bottom = top + box_h
    return (left * scale_x, top * scale_y, right * scale_x, bottom * scale_y)

def preprocess_transform(im_name, detections, image_dir='../input/iwildcam2021-fgvc8/test/'):
    """Crop every detection out of one image and run the preprocessing pipeline on it.

    Args:
        im_name: file name of the image, appended to ``image_dir``.
        detections: iterable of relative ``(x, y, w, h)`` boxes as accepted by
            ``get_crop_area``.
        image_dir: directory prefix (including the trailing separator) the
            image is read from. Defaults to the competition test folder.

    Returns:
        A list with one ``transform``-ed crop per detection, in the order of
        ``detections``. The list is empty when there are no detections.
    """
    # Nothing to crop: skip opening the file altogether.
    if not detections:
        return []

    # The context manager closes the file handle once all crops are produced.
    with Image.open(image_dir + im_name) as img:
        # The Normalize step of `transform` uses three-channel statistics, so
        # make sure every crop really has three channels.
        if img.mode != 'RGB':
            img = img.convert('RGB')
        return [transform(img.crop(get_crop_area(bbox, img.size))) for bbox in detections]

In [None]:
# Model: build the `tf_efficientnet_b5_ns` network with a 200-unit head and
# restore its weights from the training checkpoint.
# NOTE(review): torch.load unpickles the checkpoint; only load files from a trusted source.
model = Cnn_model(backbone="tf_efficientnet_b5_ns", out_dim=200).to(device)
checkpoint = torch.load("../input/iwildcam2021-final-weighted-loss/EffnetB5Ns_34_0.7826.tar", map_location=torch.device(device))
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # switch to evaluation mode for inference

# Label encoder: its inverse_transform maps predicted class indices back to the
# labels used in the submission.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
# The `with` block makes sure the file handle is closed after loading.
with open('../input/iwildcam2021-final-weighted-loss/target_encoder.pkl', 'rb') as file:
    le = pickle.load(file)

In [None]:
# Visualize the preprocessing

from matplotlib import pyplot as plt

# Inverse of the `transforms.Normalize` step in `transform`, used only for display.
# The statistics here must be exactly the mean/std used by `transform`; otherwise
# the displayed colours are shifted. Step 1 multiplies by std, step 2 adds the mean back.
invTrans = transforms.Compose([ transforms.Normalize(mean = [ 0., 0., 0. ],
                                                     std = [1/0.21479782, 1/0.21479183, 1/0.21479116]),
                                transforms.Normalize(mean = [-0.35318278, -0.35319862, -0.35318562],
                                                     std = [ 1., 1., 1. ]),
                               ])

def show_img(img):
    """Display a normalised image tensor with matplotlib.

    Args:
        img: tensor in channel-first layout (it is transposed with
            ``(1, 2, 0)`` before plotting) that was normalised with the same
            statistics ``invTrans`` undoes.
    """
    plt.figure(figsize=(18,15))
    # Undo the normalisation so pixel values are back in the displayable range.
    img = invTrans(img)
    npimg = img.cpu().numpy()
    # Clamp to [0, 1] so that rounding noise cannot push values out of range for imshow.
    npimg = np.clip(npimg, 0., 1.)
    # matplotlib expects channel-last images.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# for idx, (sequence, images) in tqdm(enumerate(test_detections.items())):
#     grid_images = []
#     for image_name, detections in images.items():
#         ips = preprocess_transform(image_name, detections) # get all preprocessed detections for an image
#         if len(ips) == 0:
#             continue

#         grid_images.extend(ips)
#     # show images
#     if len(grid_images) == 0:
#         continue
#     grid_images = torch.stack(grid_images).to(device)
#     show_img(torchvision.utils.make_grid(grid_images))
#     if idx>10:
#         break

In [None]:
# Load the sample submission; its columns define the layout of every submission row.
samp_submission = pd.read_csv("../input/iwildcam2021-fgvc8/sample_submission.csv")
# Empty frame that has the same columns as the sample submission.
submission = pd.DataFrame(columns=samp_submission.columns)

In [None]:
def update_freq(output_labels,frq_pred):
    """Count how often an exact label list has been seen.

    Args:
        output_labels: list of labels predicted for one image. An empty list
            is ignored.
        frq_pred: list of ``[labels, count]`` pairs. It is modified in place.

    Returns:
        The same ``frq_pred`` object. The count of the first pair whose labels
        equal ``output_labels`` is incremented; if there is no such pair,
        ``[output_labels, 1]`` is appended.
    """
    if len(output_labels) == 0:
        return frq_pred

    for entry in frq_pred:
        if entry[0] == output_labels:
            entry[1] += 1
            break
    else:
        # No existing entry matched: start a new counter for this label list.
        frq_pred.append([output_labels, 1])
    return frq_pred

# Predict one submission row per sequence.
# Rows are collected in a plain list and turned into a DataFrame once at the end:
# DataFrame.append was removed in pandas 2.0 and copies the whole frame on every call.
submission_rows = []

for sequence, images in tqdm(test_detections.items()):
    frq_pred = []  # [label list, number of images that produced exactly that list]
    new_row = {key: 0 for key in samp_submission.columns}

    # `image_dets` (not `detections`) so the MegaDetector JSON loaded earlier is not overwritten.
    for image_name, image_dets in images.items():
        ips = preprocess_transform(image_name, image_dets)  # one preprocessed crop per detection
        if len(ips) == 0:
            continue

        # perform model predictions
        with torch.no_grad():
            ips = torch.stack(ips).to(device)
            outputs = model(ips)
            pred_labels = torch.argmax(outputs, dim=1).tolist()
            output_labels = le.inverse_transform(pred_labels)

        frq_pred = update_freq(output_labels.tolist(), frq_pred)

    # Choose the sequence-level prediction: take the most frequent label list if
    # it is as long as the longest label list seen, otherwise take the longest one.
    if frq_pred:
        max_pred_freq = max(frq_pred, key=lambda x: x[1])[0]
        max_pred_len = max(frq_pred, key=lambda x: len(x[0]))[0]
        preds = max_pred_freq if len(max_pred_freq) == len(max_pred_len) else max_pred_len
    else:
        preds = []

    # Fill the row: one count per predicted label; label 0 is skipped
    # (presumably the "empty" class - confirm against the label encoder).
    new_row['Id'] = sequence
    for pred in preds:
        if pred == 0:
            continue
        new_row[f'Predicted{pred}'] += 1
    submission_rows.append(new_row)

submission = pd.DataFrame(submission_rows, columns=samp_submission.columns)

In [None]:
# Print the row count of the generated submission and of the sample submission,
# one per line, so they can be compared by eye.
for frame in (submission, samp_submission):
    print(len(frame))

In [None]:
# Write the predictions to submission.csv without the DataFrame index column.
submission.to_csv("submission.csv",index=False)