In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import json
import pprint
from functools import partial
from tqdm.auto import tqdm
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms, models
from sklearn.metrics import roc_auc_score

# Label data we want to keep
## Bird labeling criteria:
* A bird is good if it is not bad
* A bird is bad if:
    * It's flying
    * It's swimming
    * It contains complex branches, even in background (don't need the gan to learn anything this facny)
        * Sometimes bird is sitting on branch, that's OK
        * Background is blurred means its a non issue
    * It's camoflauged
    * There are multiple birds
    * There are objects unrelated to the bird (e.g. human hand)
    * Parts of the bird are missing
        * to be good, need to see Most of head, beak, body, tail
    * The bird is oriented in a non-standard way (e.g. upsidedown)
        * Bird should be approximatley head -> body -> tail (top to bottom)
        * Bird shouldn't be mostly vertical
        * Bird feet should be on a horiztonally flatish surface (e.g. not standing on the side of the image)
        * Bird should be facing forward

# 1. Label the data

In [2]:
label_path = '../../data/first_pass/quality_labels.json'
with open(label_path, 'r') as f:
    labels = json.load(f)
possible_labels = list(labels.keys())

In [3]:
# Cell for reviewing labels
# start_idx = 0
# for i in range(start_idx, 1500):
#     img = Image.open('../../data/first_pass/images/%d.jpg'%i)
#     plt.imshow(img)
#     cls = "greg"
#     for k, v in labels.items():
#         if i in v:
#             cls = k
#     plt.title('%d, %s'%(i, cls))
    
#     plt.show()
#     res = input()
#     if res.lower()=='break': break

In [4]:
# Yay stackoerflow! https://stackoverflow.com/questions/2460177/edit-distance-in-python
def levenshteinDistance(s1, s2):
    if len(s1) > len(s2):
        s1, s2 = s2, s1

    distances = range(len(s1) + 1)
    for i2, c2 in enumerate(s2):
        distances_ = [i2+1]
        for i1, c1 in enumerate(s1):
            if c1 == c2:
                distances_.append(distances[i1])
            else:
                distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
        distances = distances_
    return distances[-1]

In [5]:
# cell for labeling the data. Just look @ images in windows explorer - infinitely easeir

# for i in range(0):
#     if any([i in v for _, v in labels.items()]): continue
#     print(i)
#     res = input()
#     if res=='exit': break
#     labels[min(possible_labels, key = partial(levenshteinDistance, res))].append(i)
# with open(label_path, 'w') as f:
#     json.dump(labels, f)

In [6]:
device = torch.device('cuda')

In [7]:
data_path = '../../data/first_pass/'
image_path = data_path + 'images/'

In [8]:
def collate_skip_none(batch):
    def none_check(item):
        if hasattr(item, '__iter__'):
            return all([i is not None for i in item])
        return item is not None
    batch = list(filter(none_check, batch))
    return torch.utils.data.dataloader.default_collate(batch)

In [9]:
def train_for_epoch(model, loader, loss_fn, opt):
    model.train()
    losses = []
    pbar = tqdm(total = len(loader), leave=False)
    for batch_imgs, batch_labels in loader:
        opt.zero_grad()
        batch_imgs, batch_labels = batch_imgs.to(device), batch_labels.to(device)
        pred = model(batch_imgs)
        loss = loss_fn(pred, batch_labels)
        loss.backward()
        opt.step()
        losses.append(loss.item())
        pbar.update(1)
    pbar.close()
    return np.mean(losses)

def test_for_epoch(model, loader, loss_fn, n_classes, thresh=.5):
    model.eval()
    losses = []
    preds = []
    labels = []
    with torch.no_grad():
        pbar = tqdm(total = len(loader), leave=False)
        for batch_imgs, batch_labels in loader:
            batch_imgs, batch_labels = batch_imgs.to(device), batch_labels.to(device)
            batch_preds = model(batch_imgs)
            loss_val = loss_fn(batch_preds, batch_labels)
            losses.append(loss_val.item())
            preds.append(torch.softmax(batch_preds, dim=-1).data.cpu().numpy())
            labels.append(np.eye(n_classes)[batch_labels.data.cpu().numpy()])
            pbar.update(1)
        pbar.close()    
    
    preds = np.vstack(preds)
    labels = np.vstack(labels)
    if n_classes==2:
        pred_class = (preds[:,1] > thresh).astype(np.int)
    else:
        pred_class = np.argmax(preds, axis=1)
    label_class = np.argmax(labels, axis=1)
    
    metrics = {}
    for i in range(n_classes):
        this_auc = roc_auc_score((labels[:,i]==1).astype(np.int), preds[:,i])
        this_acc = np.mean(pred_class[label_class==i]==label_class[label_class==i])
        metrics[i] = {'auc' : this_auc, 'conditional_acc': this_acc}
    
    metric_df = pd.DataFrame.from_dict(metrics, orient='index').reset_index(drop=True)
    metric_df.rename({'index' : 'class_id', 0: "auc", 1: "conditional_acc"}, axis=1, inplace=True)
    return np.mean(losses), metric_df

In [10]:
train_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
])
test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
])
inverse_normalize_transform = transforms.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
   std=[1/0.229, 1/0.224, 1/0.225])

# 2. Look at the data

In [11]:
labels.keys()

dict_keys(['good', 'flying', 'swimming', 'branches', 'camoflauge', 'multiple', 'objects', 'orientation', 'missing'])

In [12]:
label_key_mapper = {k : 1 if k=="multiple" else 0 for i, k in enumerate(labels.keys())}
num_classes = len(set(list(label_key_mapper.values())))
n_labeled = sum([len(v) for _, v in labels.items()])
img_id_to_class = {i : label_key_mapper[k] for k, v in labels.items() for i in v}

In [13]:
pd.DataFrame.from_dict(img_id_to_class, orient='index')[0].value_counts()

0    1025
1     123
Name: 0, dtype: int64

In [14]:
train_dict = {}
test_dict = {}
train_counter = 0
test_counter = 0
for i in img_id_to_class.keys():
    if np.random.rand() < .8:
        train_dict[train_counter] = {'img' : i, 'class': img_id_to_class[i]}
        train_counter+=1
    else:
        test_dict[test_counter] = {'img' : i, 'class': img_id_to_class[i]}
        test_counter+=1

In [15]:
pd.Series({k: v['class'] for k, v in train_dict.items()}).value_counts(normalize=True)

0    0.894019
1    0.105981
dtype: float64

In [16]:
pd.Series({k: v['class'] for k, v in test_dict.items()}).value_counts(normalize=True)

0    0.887179
1    0.112821
dtype: float64

# Setup the daataloaders

In [17]:
class LabelQualityDataset(Dataset):
    def __init__(self, data_dictionary, load_path, transforms=None):
        self.data = data_dictionary
        self.load_path = load_path
        self.transforms = transforms
        
    def __getitem__(self, idx):
        try:
            img = Image.open(image_path + '%d.jpg'%self.data[idx]['img'])
        except FileNotFoundError:
            return None, None
        label = self.data[idx]['class']
        if self.transforms is not None:
            img = self.transforms(img)
        return img, label
        
    def __len__(self):
        return len(self.data)

In [18]:
train_dataset = LabelQualityDataset(train_dict, image_path, train_transforms)
test_dataset = LabelQualityDataset(test_dict, image_path, test_transforms)

In [None]:
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False,)

In [None]:
n_classes = 2

In [None]:
class_count_show = 10
shown_dict = {i : 0 for i in range(n_classes)}
for batch_imgs, batch_labels in train_dataloader:
    if not np.sum(batch_labels.data.cpu().numpy()): continue
    for img, l in zip(batch_imgs, batch_labels):
        if shown_dict[l.item()]>=class_count_show: continue
        img = inverse_normalize_transform(img).permute(1,2,0)
        plt.imshow(img.data.cpu().numpy())
        plt.title(l.item())
        plt.show()
        shown_dict[l.item()]+=1
    if all([v>=class_count_show for _, v in shown_dict.items()]): break

# 4. Build a model

In [None]:
resnet = models.resnet18(pretrained=True)
resnet.fc = nn.Linear(resnet.fc.in_features, n_classes)
resnet.to(device)
loss_fn = nn.CrossEntropyLoss(weight = torch.tensor([1.,2.]).to(device))
opt = optim.Adam(resnet.fc.parameters())
opt_full = optim.Adam(resnet.parameters(), lr=10**-5)

* Different models used different just_head_steps vs not that

In [None]:
just_head_steps = 2

In [None]:
nb_epochs = 4
for i in range(nb_epochs):
    if i < just_head_steps:
        train_loss = train_for_epoch(resnet, train_dataloader, loss_fn, opt)
    else:
        train_loss = train_for_epoch(resnet, train_dataloader, loss_fn, opt_full)
    print('mean train loss: %.4f'%train_loss)
    test_loss, metric_df = test_for_epoch(resnet, test_dataloader, loss_fn, n_classes, thresh=.5)
    test_auc = metric_df.iloc[1]['auc']
    display(metric_df)
    print('mean test loss: %.4f'%test_loss)
    print('='*100)
#     if test_auc > .9:
#         break


In [None]:
# resnet.load_state_dict(torch.load('missing_predictor.pt'))
# display(test_for_epoch(resnet, test_dataloader, loss_fn, 2)[1])

In [None]:
torch.save(resnet.state_dict(), 'multiple_predictor.pt')

# 5. Use the models!

In [18]:
model_path = '../../models/final_filter_models/'

In [19]:
def load_model(model_name):
    resnet = models.resnet18()
    resnet.fc = nn.Linear(resnet.fc.in_features, 2)
    resnet.to(device)
    resnet.load_state_dict(torch.load(model_path + model_name))
    return resnet

In [20]:
class UnlabeledDataset(Dataset):
    def __init__(self, max_val, load_path, transforms):
        self.max_val = max_val
        self.load_path = load_path
        self.transforms = transforms
        
    def __getitem__(self, idx):
        try:
            img = Image.open(image_path + '%d.jpg'%idx)
        except FileNotFoundError:
            return None, None
        
        img = self.transforms(img)
        return img, idx
    
    def __len__(self):
        return self.max_val

In [21]:
max_val = max([int(img_id[:-4]) for img_id in os.listdir(image_path)])
unlabeled_dataset = UnlabeledDataset(max_val, image_path, test_transforms)
unlabeled_dataloader = DataLoader(unlabeled_dataset, batch_size=32, collate_fn = collate_skip_none, shuffle=True)

In [22]:
# Cell for evaluating picking evaluating threshholds

# with torch.no_grad():
#     pbar = tqdm(total = len(unlabeled_dataloader), leave=False)
#     for batch, _ in unlabeled_dataloader:
#         batch = batch.to(device)
#         pred = torch.softmax(resnet(batch), dim=-1).data.cpu().numpy()
#         show_indexer = pred[:,1]>pred_thresh
#         for p, img in zip(pred[show_indexer,1], batch[show_indexer]):
# #             if p > .4: continue
#             print(p)
#             img = inverse_normalize_transform(img).permute(1,2,0).data.cpu().numpy()
#             plt.imshow(img)
#             plt.show()
#         pbar.update(1)
#     pbar.close()

In [23]:
pred_threshes = {
    "branches_predictor.pt" : .6,
    "camo_predictor.pt" : .5,
    "flying_predictor.pt" : .4,
    "missing_predictor.pt" : .4,
    "multiple_predictor.pt" : .65,
    "object_predictor.pt" : .5,
    "orientation_predictor.pt" : .3, 
    "swimming_predictor.pt" : .35,
}

In [28]:
should_skip = {}
outer_pbar = tqdm(total = len(pred_threshes), leave=False)
for model_name, thresh in pred_threshes.items():
    model = load_model(model_name)
    with torch.no_grad():
        inner_pbar = tqdm(total = len(unlabeled_dataloader), leave=False)
        predicted_imgs = []
        for batch, idx in unlabeled_dataloader:
            batch = batch.to(device)
            pred = torch.softmax(model(batch), dim=-1).data.cpu().numpy()[:,1]
            predicted_indicies = idx[pred > thresh].data.cpu().numpy().ravel()
            predicted_imgs.append(predicted_indicies)
            inner_pbar.update(1)
        should_skip[model_name] = np.hstack(predicted_imgs)
        inner_pbar.close()
    outer_pbar.update(1)
outer_pbar.close()

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5340), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5340), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5340), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5340), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5340), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5340), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5340), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5340), HTML(value='')))

In [29]:
[len(v) for _, v in should_skip.items()]

[23792, 2248, 12216, 6425, 3991, 1741, 14290, 8900]

In [39]:
def convert(o):
    if isinstance(o, np.generic): return o.item() 
    elif isinstance(o, np.ndarray): return list(o)
    raise TypeError

In [40]:
with open('should_skip.json', 'w') as f:
    json.dump(should_skip, f, default=convert)

# 6. Move the data around! 

In [19]:
from shutil import copyfile
import itertools

In [20]:
save_dir = '../../data/modeling_256/'
save_img_dir = save_dir + 'images/'
if not os.path.exists(save_dir):
    os.mkdir(save_dir)
if not os.path.exists(save_img_dir):
    os.mkdir(save_img_dir)

In [21]:
with open('should_skip.json', 'r') as f:
    should_skip = json.load(f)

In [22]:
dont_copy = set(list(itertools.chain(*[v for _,v in should_skip.items()])))

In [23]:
len(dont_copy)

66938

In [31]:
# Copy the results!
id_counter = 0
pbar = tqdm(total = 170000 - len(dont_copy))
for f in os.listdir(image_path):
    index = int(f[:-4])
    if index in dont_copy:
        continue
    save_img_path = save_img_dir + '%d.jpg'%id_counter
    if os.path.exists(save_img_path):
        id_counter +=1
        continue
    copyfile(image_path + f, save_img_path)
    id_counter +=1
    
    pbar.update(1)

HBox(children=(IntProgress(value=0, max=103062), HTML(value='')))