In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

from fastai.conv_learner import *
from fastai.dataset import *
from fastai.models.resnet import vgg_resnet50
from fastai.models.unet import *
from torch import nn
from PIL import Image as PILImage

import json

torch.backends.cudnn.benchmark=True

torch.cuda.is_available()

True

# Data

## Definition

In [2]:
LABEL_MAP = {
0: "Nucleoplasm" ,
1: "Nuclear membrane"   ,
2: "Nucleoli"   ,
3: "Nucleoli fibrillar center",   
4: "Nuclear speckles"   ,
5: "Nuclear bodies"   ,
6: "Endoplasmic reticulum"   ,
7: "Golgi apparatus"  ,
8: "Peroxisomes"   ,
9:  "Endosomes"   ,
10: "Lysosomes"   ,
11: "Intermediate filaments"  , 
12: "Actin filaments"   ,
13: "Focal adhesion sites"  ,
14: "Microtubules"   ,
15: "Microtubule ends"   ,
16: "Cytokinetic bridge"   ,
17: "Mitotic spindle"  ,
18: "Microtubule organizing center",  
19: "Centrosome",
20: "Lipid droplets"   ,
21: "Plasma membrane"  ,
22: "Cell junctions"   ,
23: "Mitochondria"   ,
24: "Aggresome"   ,
25: "Cytosol" ,
26: "Cytoplasmic bodies",
27: "Rods & rings"}

In [3]:
class MultiBandMultiLabelDataset(Dataset):
    BANDS_NAMES = ['_red.png','_green.png','_blue.png','_yellow.png']
    
    def __len__(self):
        return len(self.images_df)
    
    def __init__(self, images_df, 
                 base_path, 
                 image_transform, 
                 augmentator=None,
                 train_mode=True    
                ):
        if not isinstance(base_path, pathlib.Path):
            base_path = pathlib.Path(base_path)
            
        self.images_df = images_df.copy()
        self.image_transform = image_transform
        self.augmentator = augmentator
        self.images_df.Id = self.images_df.Id.apply(lambda x: base_path / x)
        self.mlb = MultiLabelBinarizer(classes=list(LABEL_MAP.keys()))
        self.train_mode = train_mode

    def __getitem__(self, index):
        y = None
        X = self._load_multiband_image(index)
        if self.train_mode:
            y = self._load_multilabel_target(index)
        
        # augmentator can be for instance imgaug augmentation object
        if self.augmentator is not None:
            X = self.augmentator(X)
        X = self.image_transform(X)
            
        return X, y 
        
    def _load_multiband_image(self, index):
        row = self.images_df.iloc[index]
        image_bands = []
        for band_name in self.BANDS_NAMES:
            p = str(row.Id.absolute()) + band_name
            pil_channel = Image.open(p)
            image_bands.append(pil_channel)
            
        # lets pretend its a RBGA image to support 4 channels
        band4image = Image.merge('RGBA', bands=image_bands)
        return band4image
    
    def _load_multilabel_target(self, index):
        return list(map(int, self.images_df.iloc[index].Target.split(' ')))
    
    def collate_func(self, batch):
        labels = None
        images = [x[0] for x in batch]
        
        if self.train_mode:
            labels = [x[1] for x in batch]
            labels_one_hot  = self.mlb.fit_transform(labels)
            labels = torch.FloatTensor(labels_one_hot)
        
        return torch.stack(images)[:,:4,:,:], labels

## Instances

In [4]:
import pathlib
from sklearn.model_selection import train_test_split
from torchvision.transforms import transforms
from sklearn.preprocessing import MultiLabelBinarizer
from torch.utils.data import Dataset, DataLoader

In [5]:
PATH = pathlib.Path('data')

In [6]:
sz = 128
bs = 8

In [7]:
PATH_TO_IMAGES = 'data/train/'
PATH_TO_TEST_IMAGES = 'data/test/'
PATH_TO_META = 'data/train.csv'
SAMPLE_SUBMIT = 'data/sample_submission.csv'

train_df = pd.read_csv(PATH/'train.csv')
df_train, df_test  = train_test_split(train_df, test_size=0.2, random_state=33)
submit_df = pd.read_csv(PATH/'sample_submission.csv')

image_transform = transforms.Compose([
            transforms.Resize(sz),
            transforms.ToTensor(),])

train_set = MultiBandMultiLabelDataset(df_train, base_path=PATH_TO_IMAGES, image_transform=image_transform)
val_set = MultiBandMultiLabelDataset(df_test, base_path=PATH_TO_IMAGES, image_transform=image_transform)
submit_set = MultiBandMultiLabelDataset(submit_df, base_path=PATH_TO_TEST_IMAGES, train_mode=False, image_transform=image_transform)

train_load = DataLoader(train_set, collate_fn=train_set.collate_func, batch_size=bs, num_workers=6)
test_load = DataLoader(val_set, collate_fn=val_set.collate_func, batch_size=bs, num_workers=6)
submission_load = DataLoader(submit_set, collate_fn=submit_set.collate_func, batch_size=bs, num_workers=6)

In [8]:
def change_size(sz=128, bs=64):
    train_df = pd.read_csv(PATH/'train.csv')
    df_train, df_test  = train_test_split(train_df, test_size=0.2, random_state=33)
    submit_df = pd.read_csv(PATH/'sample_submission.csv')

    image_transform = transforms.Compose([
            transforms.Resize(sz),
            transforms.ToTensor(),])

    train_set = MultiBandMultiLabelDataset(df_train, base_path=PATH_TO_IMAGES, image_transform=image_transform)
    val_set = MultiBandMultiLabelDataset(df_test, base_path=PATH_TO_IMAGES, image_transform=image_transform)
    submit_set = MultiBandMultiLabelDataset(submit_df, base_path=PATH_TO_TEST_IMAGES, train_mode=False, image_transform=image_transform)

    train_load = DataLoader(train_set, collate_fn=train_set.collate_func, batch_size=bs, num_workers=6)
    test_load = DataLoader(val_set, collate_fn=val_set.collate_func, batch_size=bs, num_workers=6)
    submission_load = DataLoader(submit_set, collate_fn=submit_set.collate_func, batch_size=bs, num_workers=6)
    
    md = ModelData.from_dls(PATH, train_load, test_load, submission_load)
    return md

# Model

In [9]:
def get_model(n_classes, image_channels=4):
    model = resnet50(pretrained=True)
    for p in model.parameters():
        p.requires_grad = True
    inft = model.fc.in_features
    model.fc = nn.Linear(in_features=inft, out_features=n_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    model.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
    return model 

In [10]:
class CustomModel():
    def __init__(self,model,name='res50_4'):
        self.model,self.name = model,name

    def get_layer_groups(self, precompute=False):
        return [children(self.model)]

In [11]:
class CustomLearner(Learner):
    def get_layer_groups(self, precompute=False): 
        return self.models.get_layer_groups()
    
    def get_layer_opt(self, lrs, wds):
        return LayerOptimizer(self.opt_fn, self.get_layer_groups(), lrs, wds)

In [12]:
def get_learner(md):
    md = ModelData.from_dls(PATH, train_load, test_load, submission_load)
    res50_4 = to_gpu(get_model(28,4))
    models = CustomModel(res50_4)
    learn = CustomLearner(md, models)
    
    learn.crit = nn.BCEWithLogitsLoss()
    learn.opt_fn = optim.Adam
    learn.metrics=[my_f1_score]
    return learn

In [13]:
from sklearn.metrics import f1_score

def my_f1_score(y_pred, y_true):
    threshold = 0.2
    return f1_score(y_pred>0.2, y_true, average='micro')

# Test Model

In [14]:
md = ModelData.from_dls(PATH, train_load, test_load, submission_load)
res50_4 = to_gpu(get_model(28,4))

In [15]:
models = CustomModel(res50_4)

In [16]:
learn = CustomLearner(md, models)

In [17]:
learn.crit = nn.BCEWithLogitsLoss()
learn.opt_fn = optim.Adam
learn.metrics=[my_f1_score]

In [18]:
# %%time
# learn.lr_find(1e-10, 1e-3)
# learn.sched.plot()

# Train Model

## Small set

In [35]:
lr = 5e-6
md = change_size(64,128)
learn = get_learner(md)

In [36]:
learn.fit(lr,1,wds=1e-7,cycle_len=10,
          use_clr_beta=(10,10, 0.85, 0.9), 
          use_wd_sched=True, best_save_name='res50_FA64')

HBox(children=(IntProgress(value=0, description='Epoch', max=10), HTML(value='')))

                                                               

  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', avera

  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', avera

epoch      trn_loss   val_loss   my_f1_score 
    0      0.160154   0.163706   0.171548  
                                                               

  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', avera

    1      0.153417   0.158709   0.232408  
                                                               

  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)


    2      0.151063   0.155864   0.262241  
                                                               

  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)


    3      0.148282   0.152978   0.281016  


  'recall', 'true', average, warn_for)


                                                               

  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)


    4      0.144326   0.149865   0.318132  
                                                               

  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)


    5      0.139171   0.147948   0.351024  
                                                               

  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)


    6      0.133927   0.147902   0.379546  
    7      0.128878   0.148149   0.385471                      
                                                               

  'recall', 'true', average, warn_for)


    8      0.125555   0.146715   0.365573  
    9      0.124584   0.145492   0.367277                      


[0.14549167727810283, 0.36727740453050894]

In [37]:
learn.load('res50_FA64')

In [38]:
sz = 128
bs = 32
md = change_size(sz,bs)
learn = get_learner(md)
learn.set_data(md)

In [39]:
learn.fit(lr,1,wds=1e-7,cycle_len=20,
          use_clr_beta=(10,10, 0.85, 0.9), 
          use_wd_sched=True, best_save_name='res50_FA128')

HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))

                                                             

  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)


epoch      trn_loss   val_loss   my_f1_score 
    0      0.168632   0.168653   0.090942  
    1      0.162363   0.162927   0.274622                    
    2      0.157433   0.159459   0.304355                    
    3      0.154824   0.157833   0.320182                    
    4      0.153108   0.156926   0.329084                    
    5      0.151761   0.156779   0.333768                    
    6      0.150349   0.15624    0.333294                    
    7      0.148696   0.155443   0.333323                    
    8      0.146982   0.154088   0.327398                    
    9      0.144177   0.152413   0.323695                    
    10     0.141312   0.151328   0.329507                    
    11     0.138354   0.150346   0.339932                    
    12     0.135493   0.149937   0.34903                     
    13     0.132715   0.149923   0.354959                    
    14     0.130168   0.150026   0.359691                    
    15     0.127896   0.149833   0.358111 

[0.14967848159139038, 0.3773001058012267]

In [40]:
learn.load('res50_FA128')

In [41]:
sz = 256
bs = 16
md = change_size(sz,bs)
learn = get_learner(md)
learn.set_data(md)

In [42]:
learn.fit(lr,1,wds=1e-7,cycle_len=20,
          use_clr_beta=(10,10, 0.85, 0.9), 
          use_wd_sched=True, best_save_name='res50_FA256')

HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))

                                                               

  'recall', 'true', average, warn_for)


epoch      trn_loss   val_loss   my_f1_score 
    0      0.160957   0.162878   0.266066  
    1      0.153847   0.156229   0.299472                      
    2      0.150704   0.153812   0.327173                      
    3      0.148404   0.152319   0.350566                      
    4      0.14599    0.150846   0.375971                      
    5      0.142992   0.14819    0.386972                      
    6      0.139281   0.144314   0.403665                      
    7      0.134949   0.140658   0.42176                       
    8      0.130386   0.137816   0.444275                      
    9      0.12458    0.135552   0.462642                      
    10     0.118952   0.133772   0.478241                      
    11     0.113836   0.132825   0.491969                      
    12     0.109278   0.13171    0.498723                      
    13     0.105098   0.130728   0.507117                      
    14     0.101415   0.129872   0.506144                      
    15     0.0

[0.12621428572354734, 0.5221180184332047]

In [43]:
learn.load('res50_FA256')

In [125]:
sz = 512
bs = 8
md = change_size(sz,bs)
learn = get_learner(md)
learn.set_data(md)

In [45]:
learn.fit(lr,1,wds=1e-7,cycle_len=60,
          use_clr_beta=(10,10, 0.85, 0.9), 
          use_wd_sched=True, best_save_name='res50_FA512')

HBox(children=(IntProgress(value=0, description='Epoch', max=60), HTML(value='')))

                                                               

  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', avera

  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)


epoch      trn_loss   val_loss   my_f1_score 
    0      0.162408   0.165299   0.196964  
                                                               

  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)


    1      0.153293   0.158582   0.263599  
                                                               

  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)


    2      0.148921   0.154021   0.28057   
                                                               

  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)


    3      0.145427   0.150492   0.287302  
                                                               

  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)


    4      0.141756   0.146725   0.299962  
                                                               

  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)


    5      0.137739   0.142723   0.331596  
                                                               

  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)


    6      0.132941   0.138491   0.364157  
                                                               

  'recall', 'true', average, warn_for)


    7      0.127117   0.133928   0.395577  
                                                               

  'recall', 'true', average, warn_for)


    8      0.121186   0.128818   0.43375   
                                                               

  'recall', 'true', average, warn_for)


    9      0.115672   0.123991   0.466444  
    10     0.11091    0.120305   0.49589                       
    11     0.106779   0.117502   0.514246                      
    12     0.103008   0.115354   0.528858                      
    13     0.099546   0.113569   0.537859                       
    14     0.096228   0.112206   0.547333                       
    15     0.093133   0.111356   0.551448                       
                                                                

  'recall', 'true', average, warn_for)


    16     0.090136   0.110646   0.557276  
                                                                

  'recall', 'true', average, warn_for)


    17     0.087148   0.110224   0.559137  
                                                                

  'recall', 'true', average, warn_for)


    18     0.084108   0.110027   0.564039  
                                                                

  'recall', 'true', average, warn_for)


    19     0.081028   0.110091   0.565447  
    20     0.077855   0.110205   0.567279                       
    21     0.074686   0.110661   0.56828                        
    22     0.071333   0.111022   0.569296                       
    23     0.068041   0.111585   0.569885                       
    24     0.064577   0.112451   0.571933                       
    25     0.061219   0.113053   0.575155                       
    26     0.057749   0.114088   0.576379                       
    27     0.054077   0.114567   0.582635                       
    28     0.050188   0.117518   0.580725                       
    29     0.047754   0.120573   0.576452                       
    30     0.044061   0.123642   0.591334                       
    31     0.040444   0.125317   0.606845                       
    32     0.037092   0.128816   0.62468                        
    33     0.03458    0.14032    0.612583                       
    34     0.032469   0.141568   0.610503     

[0.24857315290434553, 0.6642273853029884]

# Prepare Submission

## TTA

In [18]:
image_vflip = transforms.Compose([
            transforms.Resize(sz),
            transforms.RandomHorizontalFlip(p=1),
            transforms.ToTensor()])

image_hflip = transforms.Compose([
            transforms.Resize(sz),
            transforms.RandomVerticalFlip(p=1),
            transforms.ToTensor()])

class TestTimeDataset(MultiBandMultiLabelDataset):
    def check_flipped(self, flip):
        if flip==1:
            self.image_transform = image_vflip
        elif flip==2:
            self.image_transform = image_hflip

In [19]:
submit_set = TestTimeDataset(submit_df, base_path=PATH_TO_TEST_IMAGES, train_mode=False, image_transform=image_transform)

In [20]:
submission_load = DataLoader(submit_set, collate_fn=submit_set.collate_func, batch_size=bs, num_workers=6)

In [21]:
pred_array = np.zeros((11702, 28),dtype=float)

In [22]:
for o in [1,2,3]:
    print(str(o))
    submit_set.check_flipped(o)
    
    submission_load = DataLoader(submit_set, 
                                 collate_fn=submit_set.collate_func, 
                                 batch_size=16, num_workers=6)
    learn.load('res50_FA512')
    learn.data.test_dl = submission_load
    
    submission_predictions = learn.predict(is_test=True)
    pred_array = (pred_array + submission_predictions)/2

1
2
3


## Prepare Submission

In [23]:
def make_submission_file(sample_submission_df, predictions):
    submissions = []
    for row in predictions:
        subrow = ' '.join(list([str(i) for i in np.nonzero(row)[0]]))
        submissions.append(subrow)
    
    sample_submission_df['Predicted'] = submissions
    sample_submission_df.to_csv(f'sub_res50.csv', index=None)
    
    return sample_submission_df

In [24]:
THRESHOLD = 0.2
pred_array = pred_array>THRESHOLD

submission_file = make_submission_file(sample_submission_df=submit_df,
                     predictions=pred_array)

In [25]:
submission_file.head()

Unnamed: 0,Id,Predicted
0,00008af0-bad0-11e8-b2b8-ac1f6b6435d0,5
1,0000a892-bacf-11e8-b2b8-ac1f6b6435d0,3 7
2,0006faa6-bac7-11e8-b2b7-ac1f6b6435d0,0 2 25
3,0008baca-bad7-11e8-b2b9-ac1f6b6435d0,3 21
4,000cce7e-bad4-11e8-b2b8-ac1f6b6435d0,18


## No TTA

In [46]:
def make_submission_file(sample_submission_df, predictions):
    submissions = []
    for row in predictions:
        subrow = ' '.join(list([str(i) for i in np.nonzero(row)[0]]))
        submissions.append(subrow)
    
    sample_submission_df['Predicted'] = submissions
    sample_submission_df.to_csv('submission_FA.csv', index=None)
    
    return sample_submission_df

In [48]:
learn.load('res50_FA512')

In [49]:
pred = learn.predict(is_test=True)

In [50]:
pred.shape

(11702, 28)

In [51]:
pred>0.2

array([[False, False,  True, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [ True, False, False, ...,  True, False, False],
       ...,
       [False, False, False, ...,  True, False, False],
       [ True,  True, False, ..., False, False, False],
       [ True, False, False, ...,  True, False, False]])

In [52]:
THRESHOLD = 0.2
p = pred>THRESHOLD

submission_file = make_submission_file(sample_submission_df=submit_df,
                     predictions=p)