In [None]:
# Part 1: Import all the libraries

# Built In Imports
import os
from collections import defaultdict
from tqdm import tqdm_notebook as tqdm
import sys; 
sys.path.insert(0,'../input/timm-nfnet')
import timm
import warnings
warnings.filterwarnings("ignore")

# Visualization Imports
import matplotlib.pyplot as plt 
import seaborn as sns
import cv2 
import torch
from PIL import Image

# Machine Learning and Data Science Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision
import torchvision.transforms as tfms
import numpy as np
import pandas as pd 

In [None]:
# Paths to the competition data: training CSV, train/test image folders and the sample submission

TRAIN_DF_PATH = "../input/hpa-single-cell-image-classification/train.csv"
TRAIN_IMG_PATH = "../input/hpa-single-cell-image-classification/train"
TEST_IMG_PATH = "../input/hpa-single-cell-image-classification/test"
SAMPLE_SUB = "../input/hpa-single-cell-image-classification/sample_submission.csv"

# Integer class id -> human readable class name.
# The position in the list is the class id (0..18).
CELL_LABEL = dict(enumerate([
    "Nucleoplasm",
    "Nuclear membrane",
    "Nucleoli",
    "Nucleoli fibrillar center",
    "Nuclear speckles",
    "Nuclear bodies",
    "Endoplasmic reticulum",
    "Golgi apparatus",
    "Intermediate filaments",
    "Actin filaments",
    "Microtubules",
    "Mitotic spindle",
    "Centrosome",
    "Plasma membrane",
    "Mitochondria",
    "Aggresome",
    "Cytosol",
    "Vesicles and punctate cytosolic patterns",
    "Negative",
]))

print("\n\n... IMPORTS COMPLETE ...\n")

In [None]:
# Load the training metadata and record how many labels each sample carries
# (the "Label" column holds class ids joined by "|").
train_df = pd.read_csv(TRAIN_DF_PATH)
train_df['label_count'] = [len(label_str.split("|")) for label_str in train_df['Label']]
train_df.head()

In [None]:
# Part 2: Visualization of the dataset
# Number of samples per label count (1 = single-label, >1 = multi-label)

plt.title("Class count")
# Pass the column as `x=` explicitly: recent seaborn versions no longer accept
# the plotted variable as the first positional argument.
sns.countplot(x=train_df['label_count'], palette="Set3")
plt.show()

# This shows that most of the samples carry a single class rather than multiple classes

In [None]:
# Compare the number of single-label vs. multi-label samples

# Summing a boolean mask counts the rows that satisfy the condition.
single_class = (train_df['label_count'] == 1).sum()
multi_class = (train_df['label_count'] > 1).sum()

plt.figure(figsize=(10, 8))
plt.title("Single VS Multi distribution")
sns.barplot(x=['Single', 'Multi'], y=[single_class, multi_class], palette='Set3')
plt.show()


In [None]:
# Count how often each class occurs across all samples
# (a multi-label sample contributes one count to each of its classes)
labels = train_df["Label"].apply(lambda x: x.split("|"))
labels_count = defaultdict(int)

# Update the counter; a single-label sample is simply a list of length one,
# so no special case is needed.
for label in labels:
    for l in label:
        labels_count[CELL_LABEL[int(l)]] += 1

plt.figure(figsize=(10, 8))
plt.title("Target counts")
# x / y must be passed by keyword: recent seaborn versions reject them positionally.
sns.barplot(x=list(labels_count.keys()), y=list(labels_count.values()), palette='Set3')
# Rotate after plotting so the rotation applies to the final tick labels.
plt.xticks(rotation=45)
plt.show()

In [None]:
# Part3 show the four channel images

# Images are given in the form of red, green, blue and yellow channel files
def show_image(img_path, img_dir='../input/hpa-single-cell-image-classification/train'):
    """Display the four channel images of one sample in a 2x2 grid.

    Parameters
    ----------
    img_path : str
        Image ID, i.e. the file-name prefix shared by the
        ``<id>_green.png``, ``<id>_red.png``, ``<id>_blue.png`` and
        ``<id>_yellow.png`` files.
    img_dir : str, optional
        Directory containing the channel files. Defaults to the competition
        training image folder.

    Raises
    ------
    FileNotFoundError
        If one of the four channel files cannot be read.
    """
    # Imported locally so this cell only relies on matplotlib itself.
    from matplotlib.colors import LinearSegmentedColormap

    # Restore matplotlib's original rcParams before drawing the image panels.
    sns.reset_orig()

    # Segment data of one colour component: either ramp from 0 up to full
    # intensity (reached at 0.75 and held until 1.0) or stay at 0 throughout.
    ramp = ((0.0, 0.0, 0.0),
            (0.75, 1.0, 1.0),
            (1.0, 1.0, 1.0))
    off = ((0.0, 0.0, 0.0),
           (1.0, 0.0, 0.0))

    def make_cmap(name, red, green, blue):
        # Colormap objects are handed to imshow directly; registering them by
        # name is unnecessary and fails when the name is registered twice.
        return LinearSegmentedColormap(name, {'red': red, 'green': green, 'blue': blue})

    # (file suffix, panel title, colormap) in display order, row by row
    channels = [
        ('green', "Protein of interest", make_cmap('greens', off, ramp, off)),
        ('red', "Microtubules", make_cmap('reds', ramp, off, off)),
        ('blue', "Nucleus", make_cmap('blues', off, off, ramp)),
        ('yellow', "Endoplasmic reticulum", make_cmap('yellows', ramp, ramp, off)),
    ]

    # Read every channel first so a missing file is reported before any figure is created.
    images = []
    for colour, _, _ in channels:
        file_path = os.path.join(img_dir, '{}_{}.png'.format(img_path, colour))
        # second argument 0 in imread -> load as a greyscale image
        channel = cv2.imread(file_path, 0)
        if channel is None:  # cv2.imread returns None instead of raising
            raise FileNotFoundError("Could not read image file: {}".format(file_path))
        images.append(channel)

    # display each channel separately
    fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(15, 15))
    for axis, channel, (_, title, cmap) in zip(ax.flat, images, channels):
        axis.imshow(channel, cmap=cmap)
        axis.set_title(title, fontsize=18)
        # hide tick labels and tick marks; only the picture matters
        axis.set_xticklabels([])
        axis.set_yticklabels([])
        axis.tick_params(left=False, bottom=False)
    plt.show()
    

    
# Every sample is represented by four filters, each stored as an individual file:
# the protein of interest (green) plus three cellular landmarks -
# nucleus (blue), microtubules (red) and endoplasmic reticulum (yellow).
# The green filter should be used to predict the label; the other filters are used as references.
# Display the sample whose ID is stored in row 1, column 0 of train_df.
show_image(train_df.iloc[1,0])

In [None]:
# Hyperparameters and run configuration
CLASS = 19        # number of target classes (length of the multi-hot label vector)
BATCH_SIZE = 64
EPOCHS = 5
LR = 1e-4
RESIZE = 256      # images are resized to RESIZE x RESIZE by the dataset
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Competition data locations
PATH = '../input/hpa-single-cell-image-classification/'
TRAIN_DIR = PATH + 'train/'
TEST_DIR = PATH + 'test/'

# Convert to a tensor, then normalise with the ImageNet per-channel mean / std
img_tfms = tfms.Compose([
    tfms.ToTensor(),
    tfms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
DEVICE

In [None]:
class HPADataset(Dataset):
    """Dataset that serves the green-channel image of each sample.

    Parameters
    ----------
    csv_path : str
        Directory that contains the ``<id>_green.png`` files (despite its
        name this is an image folder, not a CSV file).
    ids : sequence of str
        Image IDs; one item is produced per ID.
    label : sequence of str or None
        ``"|"``-separated class ids for each image, aligned with ``ids``.
        Pass ``None`` for unlabelled data; items then carry the image ID
        instead of a target vector.
    resize : int or None
        When truthy, the image is resized to ``resize x resize`` and scaled
        to the range [0, 1]. Otherwise the image is passed on as read.
    transforms : callable or None
        Applied to the image before it is returned; skipped when ``None``.
    """

    def __init__(self,csv_path,ids,label,resize=None,transforms=None):
        self.csv_path = csv_path
        self.ids = ids
        self.label = label
        self.resize = resize
        self.transforms = transforms

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, item):
        """Return ``(image, target)`` for labelled data, else ``(image, image_id)``.

        Raises
        ------
        FileNotFoundError
            If the green-channel file of the requested image cannot be read.
        """
        _ids = self.ids[item]
        image_file = os.path.join(self.csv_path, _ids + '_green.png')
        image = cv2.imread(image_file)
        if image is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError("Could not read image file: {}".format(image_file))
        if self.resize:
            image = cv2.resize(image, (self.resize, self.resize))
            image = image / 255.0

        # `transforms` defaults to None, so only call it when it was supplied
        if self.transforms is not None:
            image = self.transforms(image)

        # Decide on the presence of labels rather than on substrings of the
        # folder name, which silently produced None for any other path.
        if self.label is None:
            return image, _ids

        # setting the target to multi-hot encoded form: one row of the identity
        # matrix per class id, summed into a single vector of length CLASS
        y = self.label[item]
        y = y.split('|')
        y = list(map(int, y))
        y = np.eye(CLASS, dtype='float')[y]
        y = y.sum(axis=0)
        return image, y

In [None]:
#model
class NFNet(nn.Module):
    """timm backbone whose ``head.fc`` layer is replaced by a small MLP.

    Parameters
    ----------
    output_features : int
        Size of the final linear layer (number of logits).
    model_name : str
        Name handed to ``timm.create_model``.
    pertrained : bool
        Forwarded to ``timm.create_model`` as ``pretrained``.
    """

    def __init__(self,output_features, model_name = 'nfnet_f1', pertrained=True):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=pertrained)
        # New head: original fc input width -> 512 -> output_features
        fc_inputs = backbone.head.fc.in_features
        backbone.head.fc = nn.Sequential(
            nn.Linear(fc_inputs, 512),
            nn.ReLU(),
            nn.Linear(512, output_features),
        )
        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the backbone and return its output."""
        return self.model(x)

class CNNet(nn.Module):
    """Pretrained torchvision ResNet-34 with its ``fc`` layer replaced by a small MLP.

    Parameters
    ----------
    input_features : int
        Input width of the new ``fc`` head; the caller has to supply the
        width of the features that reach ``fc``.
    output_features : int
        Size of the final linear layer (number of logits).
    """

    def __init__(self, input_features, output_features):
        super().__init__()
        resnet = torchvision.models.resnet34(pretrained=True)
        # New head: input_features -> 100 -> output_features
        resnet.fc = nn.Sequential(
            nn.Linear(input_features, 100),
            nn.ReLU(),
            nn.Linear(100, output_features),
        )
        self.model = resnet

    def forward(self, x):
        """Run ``x`` through the network and return its output."""
        return self.model(x)
    
# class PNASNet5Large(nn.Module):
#     def __init__(self, num_classes, odel_name = 'pnasnet5large', pretrained = True):
#         super(PNASNet5Large, self).__init__()
#         model = pretrainedmodels.pnasnet5large(num_classes=1000,pretrained="imagenet")
#         model.last_linear = nn.Linear(model.last_linear.in_features,
#                                       num_classes)
        
        
#     def forward(self, x):
#         x = self.model(x)
#         return x

In [None]:
# Build the network on the target device, together with its loss and optimizer.
# (alternative that was tried: model = PNASNet5Large(CLASS))
model = NFNet(CLASS).to(DEVICE)
# One independent logit per class, hence binary cross-entropy on the logits
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

Since we are only experimenting, we will use a small random subset of 5,000 samples.

In [None]:
# Build the train / validation loaders from a reproducible random subset
SEED = 42              # fixes both the row sampling and the train/valid split
N_SAMPLES = 5000       # upper bound on the number of images used
VALID_FRACTION = 0.2   # share of the subset held out for validation

# NOTE: this re-reads the CSV and rebinds `train_df` to the shuffled subset.
train_df = pd.read_csv(PATH + "train.csv")
train_df = (train_df
            .sample(n=min(N_SAMPLES, len(train_df)), random_state=SEED)
            .reset_index(drop=True))
X_train, y_train = train_df['ID'].values, train_df['Label'].values
X_ds = HPADataset(TRAIN_DIR, X_train, y_train, RESIZE, img_tfms)

# Derive the split sizes from the dataset length so they always add up,
# even if fewer than N_SAMPLES rows are available (4000 / 1000 for 5000 rows).
n_valid = int(len(X_ds) * VALID_FRACTION)
n_train = len(X_ds) - n_valid
train_ds, valid_ds = random_split(
    X_ds, [n_train, n_valid],
    generator=torch.Generator().manual_seed(SEED))
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation does not benefit from shuffling; keep the order stable.
valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Train loop: one pass over train_dl per epoch, recording the mean batch loss
loss_hist = []
for epoch in tqdm(range(EPOCHS)):
    losses = []
    model.train()
    for image, label in train_dl:
        image = image.to(DEVICE).float()
        # The dataset builds targets with dtype 'float' (float64) while the
        # logits are float32; cast so loss inputs share one dtype (mixed
        # dtypes are rejected by some PyTorch versions).
        label = label.to(DEVICE).float()
        optimizer.zero_grad()
        output = model(image)
        loss = loss_fn(output, label)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    # Guard against an empty loader instead of dividing by zero
    epoch_loss = sum(losses) / len(losses) if losses else float('nan')
    loss_hist.append(epoch_loss)
    print(f"epoch: {epoch} loss:{epoch_loss}")

plt.figure(figsize=(15, 8))
plt.title('Train Loss')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.plot(loss_hist)
plt.show()

In [None]:
def check_accuracy(loader, model):
    """Print and return the per-label accuracy (in percent) of ``model`` on ``loader``.

    Logits are passed through a sigmoid and rounded to 0/1; every entry of
    the resulting prediction matrix is compared with the matching entry of
    the multi-hot labels, so the result is the share of correctly predicted
    label entries over all batches.

    Parameters
    ----------
    loader : iterable
        Yields ``(images, labels)`` batches; ``labels`` must have the same
        shape as the model output.
    model : torch.nn.Module
        Model to evaluate. It is switched to eval mode for the duration of
        the call and restored to its previous mode afterwards.

    Returns
    -------
    float
        Accuracy in percent, or ``nan`` when the loader yields no labels.
    """
    # Run on whatever device the model lives on (CPU if it has no parameters)
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    correct = 0.
    total = 0.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for images, labels in loader:
                outputs = model(images.to(device).float())
                predicted = torch.round(torch.sigmoid(outputs)).cpu()
                labels = labels.cpu()
                # Count label entries, not samples: `correct` is element-wise too
                total += labels.numel()
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)

    accuracy = 100 * correct / total if total else float('nan')
    print("Accuracy: {}%".format(accuracy))
    return accuracy

In [None]:
# Inference: one whole-image prediction per test sample, collected into `sub`
GREEN_SUFFIX = '_green.png'
# Cut the suffix off by slicing: str.rstrip() removes a *set of characters*,
# not a suffix, and only worked here because of the "_" separator.
# Sorted so the submission row order is deterministic.
X_Test = sorted(name[:-len(GREEN_SUFFIX)]
                for name in os.listdir(TEST_DIR)
                if name.endswith(GREEN_SUFFIX))

test_ds = HPADataset(TEST_DIR, X_Test, None, RESIZE, img_tfms)
test_dl = DataLoader(test_ds, batch_size=1, shuffle=False)

# Pre-encoded mask strings appended to the prediction, chosen by image height.
# NOTE(review): presumably each one encodes a mask covering the full image
# at that resolution - confirm against the inference pipeline they came from.
MASK_BY_HEIGHT = {
    2048: 'eNoLCAgIMAEABJkBdQ==',
    1728: 'eNoLCAjJNgIABNkBkg==',
}
DEFAULT_MASK = 'eNoLCAgIsAQABJ4Beg=='

submission_lst = []

model.eval()
with torch.no_grad():
    for image, file in test_dl:
        image_id = file[0]  # batch_size is 1, so each batch holds a single ID
        output = model(image.to(DEVICE).float())
        # The model is trained with BCEWithLogitsLoss (independent per-class
        # logits), so the matching probability is a sigmoid, not a softmax.
        prob = torch.sigmoid(output)
        p, top_class = prob.topk(1, dim=1)

        # Re-read the original file to obtain the un-resized image dimensions
        image_file = TEST_DIR + image_id + GREEN_SUFFIX
        img = cv2.imread(image_file)
        if img is None:  # cv2.imread returns None instead of raising
            raise FileNotFoundError("Could not read image file: {}".format(image_file))
        height, width = img.shape[0], img.shape[1]

        # PredictionString layout: "<class id> <confidence> <mask>"
        sp = '{} {} {}'.format(top_class[0][0].item(),
                               p[0][0].item(),
                               MASK_BY_HEIGHT.get(height, DEFAULT_MASK))

        submission_lst.append([image_id, width, height, sp])

sub = pd.DataFrame.from_records(submission_lst, columns=['ID', 'ImageWidth', 'ImageHeight', 'PredictionString'])
sub.head()

In [None]:
# Write the predictions to disk, leaving out the DataFrame index column
sub.to_csv(path_or_buf="submission.csv", index=False)

Thanks to @ateplyuk for the dataset and inference pipelines

**If you found this notebook helpful, please leave an upvote!**