<h1><span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Setup-Environment" data-toc-modified-id="Setup-Environment-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Setup Environment</a></span><ul class="toc-item"><li><span><a href="#Install-Libraries" data-toc-modified-id="Install-Libraries-1.1"><span class="toc-item-num">1.1&nbsp;&nbsp;</span>Install Libraries</a></span></li><li><span><a href="#Import-Libraries" data-toc-modified-id="Import-Libraries-1.2"><span class="toc-item-num">1.2&nbsp;&nbsp;</span>Import Libraries</a></span></li><li><span><a href="#Global-Display-Setting" data-toc-modified-id="Global-Display-Setting-1.3"><span class="toc-item-num">1.3&nbsp;&nbsp;</span>Global Display Setting</a></span></li><li><span><a href="#Global-Constants" data-toc-modified-id="Global-Constants-1.4"><span class="toc-item-num">1.4&nbsp;&nbsp;</span>Global Constants</a></span></li><li><span><a href="#Random-Seed" data-toc-modified-id="Random-Seed-1.5"><span class="toc-item-num">1.5&nbsp;&nbsp;</span>Random Seed</a></span></li></ul></li><li><span><a href="#Exploratory-Data-Analysis" data-toc-modified-id="Exploratory-Data-Analysis-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Exploratory Data Analysis</a></span><ul class="toc-item"><li><span><a href="#Loading-Data" data-toc-modified-id="Loading-Data-2.1"><span class="toc-item-num">2.1&nbsp;&nbsp;</span>Loading Data</a></span></li><li><span><a href="#Features-Analysis" data-toc-modified-id="Features-Analysis-2.2"><span class="toc-item-num">2.2&nbsp;&nbsp;</span>Features Analysis</a></span><ul class="toc-item"><li><span><a href="#Features--Type" data-toc-modified-id="Features--Type-2.2.1"><span class="toc-item-num">2.2.1&nbsp;&nbsp;</span>Features  Type</a></span></li><li><span><a href="#Univariate-and-Multivariate-Visulizations" data-toc-modified-id="Univariate-and-Multivariate-Visulizations-2.2.2"><span class="toc-item-num">2.2.2&nbsp;&nbsp;</span>Univariate and Multivariate 
Visulizations</a></span></li></ul></li><li><span><a href="#Targets-Analysis-(Multi-Label-Classification)" data-toc-modified-id="Targets-Analysis-(Multi-Label-Classification)-2.3"><span class="toc-item-num">2.3&nbsp;&nbsp;</span>Targets Analysis (Multi Label Classification)</a></span></li></ul></li></ul></div>

In [1]:
# Create a /kaggle symlink pointing at /data/kaggle unless /kaggle is already a
# symlink, so the /kaggle/... paths configured later in this notebook resolve.
! [ ! -L /kaggle ] && ln -s /data/kaggle /kaggle 

[geekysaint/demystifying-mechanism-of-action-eda][1]

[1]: https://www.kaggle.com/geekysaint/demystifying-mechanism-of-action-eda

## Setup Environment

### Install Libraries 

In [2]:
import pkg_resources

LIBS = sorted(pkg_resources.working_set)

needed_stack = [
    'numpy', 'pandas', 'seaborn', 'matplotlib',
    'torch', 'torchvision', 'scikit-learn', 'plotly',
    'iterative-stratification', 'pytorch-lightning'
]

for item in LIBS:
    if item.project_name in needed_stack:
        print(item.project_name, item.version)
        needed_stack.remove(item.project_name)
for apk in needed_stack:
    !pip install $apk

iterative-stratification 0.1.6
torchvision 0.7.0.dev20200609+cu101
pytorch-lightning 0.9.0
seaborn 0.10.1
scikit-learn 0.23.1
pandas 1.0.4
torch 1.6.0.dev20200609+cu101
numpy 1.18.5
matplotlib 3.2.1
plotly 4.10.0


### Import Libraries

In [3]:
import os
import warnings
import random
import pytorch_lightning as pl
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torchvision
import plotly.express as px
import plotly.graph_objects as go
import matplotlib.pyplot as plt

from torch import nn
from torch.nn import functional as F
from torch.utils.data import (Dataset, DataLoader)
from torchvision.transforms import (
        Resize,
        Compose,
        ToTensor,
        Normalize,
        RandomOrder,
        ColorJitter,
        RandomRotation,
        RandomGrayscale,
        RandomResizedCrop,
        RandomVerticalFlip,
        RandomHorizontalFlip)

from PIL import Image, ImageDraw, ImageFont
from sklearn.model_selection import train_test_split
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from plotly.subplots import make_subplots

### Global Display Setting

In [87]:
%matplotlib inline

# Suppress every warning for the rest of the session.
warnings.filterwarnings("ignore")

# pandas rendering: floats with three decimals, at most 30 columns and 50 rows.
pd.options.display.float_format = '{:.3f}'.format
pd.options.display.max_columns = 30
pd.options.display.max_rows = 50

# sns.set(style='white', font_scale=1.2)

### Global Constants

In [5]:
# Seed passed to every random number generator seeded in this notebook.
RNG_SEED = 9527
# Dataset slug; interpolated into the input and working directories below.
DATA_NAME = 'lish-moa'
# Directory the competition CSV files are read from.
DATA_ROOT = f'/kaggle/input/{DATA_NAME}'
# Scratch directory (created later with `mkdir -p`).
WORK_ROOT = f'/kaggle/working/{DATA_NAME}'
CKPT_PATH = f'{WORK_ROOT}/checkpoints/best.ckpt'
SUBMITCSV = '/kaggle/working/submission.csv'
# NOTE(review): FONT_PATH is not referenced in the visible part of the notebook.
FONT_PATH = '/usr/share/fonts/truetype/dejavu/DejaVuSerif-Bold.ttf'

# NOTE(review): INPUT_SIZE and NUM_CLASSES are not referenced in the visible
# cells, and NUM_CLASSES = 10 does not match the 206 target columns of
# train_targets_scored.csv (shape (23814, 207) including sig_id). They look like
# leftovers from an image-classification template - confirm before relying on them.
INPUT_SIZE = 28
BATCH_SIZE = 48
NUM_CLASSES = 10

MAX_EPOCHS = 30

# Presumably the hold-out fraction for a train/test split - TODO confirm at the use site.
TEST_SPLIT = 0.3

### Random Seed

In [6]:
# Seed Python, NumPy and PyTorch (CPU and CUDA) with the same value.
random.seed(RNG_SEED)
np.random.seed(RNG_SEED)
torch.manual_seed(RNG_SEED)
torch.cuda.manual_seed(RNG_SEED)

# Ask cuDNN for deterministic kernels and switch off its benchmark mode.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [7]:
!ls -l $DATA_ROOT
!mkdir -p $WORK_ROOT

total 210376
-rw-rw-r-- 1 1002 1002   3337770 Sep  2 15:04 sample_submission.csv
-rw-rw-r-- 1 1002 1002  26140860 Sep  2 15:04 test_features.csv
-rw-rw-r-- 1 1002 1002 156340282 Sep  2 15:04 train_features.csv
-rw-rw-r-- 1 1002 1002  19466300 Sep  2 15:04 train_targets_nonscored.csv
-rw-rw-r-- 1 1002 1002  10125786 Sep  2 15:04 train_targets_scored.csv


##  Exploratory Data Analysis

### Loading Data

In [8]:
# Read the three competition tables from DATA_ROOT; the generator yields the
# frames in the same order as the names on the left-hand side.
df_train_features, df_train_targets, df_test_features = (
    pd.read_csv(os.path.join(DATA_ROOT, csv_name))
    for csv_name in (
        'train_features.csv',
        'train_targets_scored.csv',
        'test_features.csv',
    )
)

# Cell output: (rows, columns) of each frame.
tuple(
    frame.shape
    for frame in (df_train_features, df_train_targets, df_test_features)
)

((23814, 876), (23814, 207), (3982, 876))

In [9]:
# Column labels of each frame without the first column (position 0), which the
# frame previews in this notebook show to be `sig_id`.
feature_names = df_train_features.columns[1:] 
target_names  = df_train_targets.columns[1:]

In [10]:
df_train_features.head()

Unnamed: 0,sig_id,cp_type,cp_time,cp_dose,g-0,g-1,g-2,g-3,g-4,g-5,g-6,g-7,g-8,g-9,g-10,...,c-85,c-86,c-87,c-88,c-89,c-90,c-91,c-92,c-93,c-94,c-95,c-96,c-97,c-98,c-99
0,id_000644bb2,trt_cp,24,D1,1.062,0.558,-0.248,-0.621,-0.194,-1.012,-1.022,-0.033,0.555,-0.092,1.183,...,0.18,0.537,-0.111,-1.012,0.668,0.286,0.258,0.808,0.552,-0.191,0.658,-0.398,0.214,0.38,0.418
1,id_000779bfc,trt_cp,72,D1,0.074,0.409,0.299,0.06,1.019,0.521,0.234,0.337,-0.405,0.851,-1.152,...,0.442,0.937,0.819,-0.424,0.319,-0.426,0.754,0.471,0.023,0.296,0.49,0.152,0.124,0.608,0.737
2,id_000a6266a,trt_cp,48,D1,0.628,0.582,1.554,-0.076,-0.032,1.239,0.172,0.215,0.006,1.23,-0.48,...,0.117,0.109,-0.311,0.302,-0.087,-0.725,-0.63,0.61,0.022,-1.324,-0.317,-0.642,-0.219,-1.408,0.693
3,id_0015fd391,trt_cp,48,D1,-0.514,-0.249,-0.266,0.529,4.062,-0.809,-1.959,0.179,-0.132,-1.06,-0.827,...,-1.539,-2.46,-0.942,-1.555,0.243,-2.099,-0.644,-5.63,-1.378,-0.863,-1.288,-1.621,-0.878,-0.388,-0.815
4,id_001626bd3,trt_cp,72,D2,-0.325,-0.401,0.97,0.692,1.418,-0.824,-0.28,-0.15,-0.879,0.863,-0.222,...,0.07,0.813,0.192,0.605,-0.182,0.004,0.005,0.667,1.069,0.552,-0.303,0.109,0.288,-0.379,0.713


In [11]:
# Boolean mask of rows where the first or the third target column is positive,
# i.e. samples labelled with at least one of those two targets; the last five
# matching rows are displayed.
# NOTE(review): `mask_smaples` looks like a typo for `mask_samples`; the name is
# left unchanged because it is notebook-global and other cells may use it.
mask_smaples = (df_train_targets[target_names[0]] > 0) | (df_train_targets[target_names[2]] > 0)
df_train_targets.loc[mask_smaples].tail()

Unnamed: 0,sig_id,5-alpha_reductase_inhibitor,11-beta-hsd1_inhibitor,acat_inhibitor,acetylcholine_receptor_agonist,acetylcholine_receptor_antagonist,acetylcholinesterase_inhibitor,adenosine_receptor_agonist,adenosine_receptor_antagonist,adenylyl_cyclase_activator,adrenergic_receptor_agonist,adrenergic_receptor_antagonist,akt_inhibitor,aldehyde_dehydrogenase_inhibitor,alk_inhibitor,...,tlr_agonist,tlr_antagonist,tnf_inhibitor,topoisomerase_inhibitor,transient_receptor_potential_channel_antagonist,tropomyosin_receptor_kinase_inhibitor,trpv_agonist,trpv_antagonist,tubulin_inhibitor,tyrosine_kinase_inhibitor,ubiquitin_specific_protease_inhibitor,vegfr_inhibitor,vitamin_b,vitamin_d_receptor_agonist,wnt_inhibitor
19382,id_d00440fe6,0,0,1,0,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
19571,id_d218733b4,0,0,1,0,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
20100,id_d7bb3adc4,1,0,0,0,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
21941,id_eb64d285a,1,0,0,0,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
23146,id_f89e49084,1,0,0,0,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


### Features Analysis

In [12]:
df_train_features.describe()

Unnamed: 0,cp_time,g-0,g-1,g-2,g-3,g-4,g-5,g-6,g-7,g-8,g-9,g-10,g-11,g-12,g-13,...,c-85,c-86,c-87,c-88,c-89,c-90,c-91,c-92,c-93,c-94,c-95,c-96,c-97,c-98,c-99
count,23814.0,23814.0,23814.0,23814.0,23814.0,23814.0,23814.0,23814.0,23814.0,23814.0,23814.0,23814.0,23814.0,23814.0,23814.0,...,23814.0,23814.0,23814.0,23814.0,23814.0,23814.0,23814.0,23814.0,23814.0,23814.0,23814.0,23814.0,23814.0,23814.0,23814.0
mean,48.02,0.248,-0.096,0.152,0.082,0.057,-0.139,0.036,-0.203,-0.19,0.12,-0.123,0.182,0.143,0.209,...,-0.409,-0.333,-0.295,-0.328,-0.402,-0.469,-0.461,-0.513,-0.5,-0.507,-0.354,-0.463,-0.378,-0.47,-0.302
std,19.403,1.393,0.812,1.036,0.95,1.032,1.179,0.882,1.125,1.75,1.087,1.292,1.254,1.235,1.273,...,1.884,1.647,1.634,1.663,1.833,2.0,2.042,2.002,2.107,2.16,1.629,2.06,1.704,1.835,1.408
min,24.0,-5.513,-5.737,-9.104,-5.998,-6.369,-10.0,-10.0,-10.0,-10.0,-8.337,-10.0,-5.87,-8.587,-5.018,...,-10.0,-10.0,-10.0,-10.0,-10.0,-10.0,-10.0,-10.0,-10.0,-10.0,-10.0,-10.0,-10.0,-10.0,-10.0
25%,24.0,-0.473,-0.562,-0.438,-0.43,-0.471,-0.602,-0.494,-0.525,-0.512,-0.36,-0.511,-0.49,-0.448,-0.481,...,-0.56,-0.534,-0.505,-0.544,-0.569,-0.566,-0.566,-0.59,-0.569,-0.564,-0.568,-0.553,-0.561,-0.593,-0.563
50%,48.0,-0.009,-0.047,0.075,0.008,-0.027,-0.016,-0.001,-0.018,0.01,0.16,0.039,0.014,0.06,0.01,...,-0.002,0.008,-0.006,-0.021,-0.03,-0.01,0.003,-0.009,-0.014,-0.003,-0.01,-0.001,-0.007,0.014,-0.019
75%,72.0,0.526,0.403,0.664,0.463,0.465,0.51,0.529,0.412,0.549,0.698,0.525,0.575,0.604,0.576,...,0.462,0.466,0.463,0.45,0.431,0.458,0.462,0.446,0.453,0.471,0.445,0.465,0.446,0.461,0.439
max,72.0,10.0,5.039,8.257,10.0,10.0,7.282,7.333,5.473,8.887,6.433,10.0,10.0,10.0,10.0,...,3.738,3.252,5.406,3.11,3.32,4.069,3.96,3.927,3.596,3.747,2.814,3.505,2.924,3.111,3.805


#### Features  Type

In [13]:
# Partition the training features by storage dtype: `object` columns go to the
# categorical frame, every other column to the numerical frame.
df_category_features = df_train_features.select_dtypes(include=['object'])
df_numerical_features = df_train_features.select_dtypes(exclude=['object'])

# Cell output: the distinct dtype names found in each of the two frames.
(
    {str(dtype) for dtype in df_category_features.dtypes},
    {str(dtype) for dtype in df_numerical_features.dtypes},
)

({'object'}, {'float64', 'int64'})

In [14]:
df_category_features.columns.to_list()

['sig_id', 'cp_type', 'cp_dose']

In [15]:
# Integer columns with fewer than 10 distinct values are moved from the numerical
# frame to the categorical frame (the resulting category list in this notebook
# shows that this picks up `cp_time`).
#
# Both frames are rebuilt instead of being edited in place, so the cell can be
# re-run: the previous in-place `drop` raised KeyError on the second run because
# the column was already gone. errors='ignore' makes the repeated drop a no-op.
df_int64_features = df_train_features.select_dtypes(include=['int64'])
for col in df_int64_features.columns:
    if df_int64_features[col].nunique() < 10: # TODO
        df_category_features = df_category_features.assign(
            **{col: df_int64_features[col]})
        df_numerical_features = df_numerical_features.drop(
            columns=[col], errors='ignore')

In [214]:
# `sig_id` looks like a per-row identifier rather than a feature, so it is left
# out of the categorical feature list. The frame is rebuilt (no inplace=True) and
# errors='ignore' is used so that re-running the cell does not raise KeyError
# once the column has already been removed.
df_category_features = df_category_features.drop(columns=['sig_id'], errors='ignore')
category_features = df_category_features.columns.to_list()
category_features

['cp_type', 'cp_dose', 'cp_time']

#### Univariate and Multivariate Visulizations 

In [215]:
def bar_category_visualiz(df, features, title):
    """Show one bar chart of value counts per categorical column, side by side.

    Args:
        df: DataFrame containing every column named in ``features``.
        features: Sequence of column names; one subplot is drawn per name.
        title: Title of the whole figure.

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    # One grid column per feature instead of a hard-coded 3, so a fourth feature
    # no longer lands outside the grid. max(..., 1) keeps the grid valid when
    # ``features`` is empty.
    fig = make_subplots(rows=1, cols=max(len(features), 1))

    fig.update_layout(title_text=title)
    fig.layout.template = 'plotly_dark'

    # Subplot columns are 1-based, hence enumerate(..., 1).
    for idx, name in enumerate(features, 1):
        series = df[name].value_counts()
        fig.add_trace(
            go.Bar(
                x=series.index, y=series.values,
                text=series.values,
                textposition="outside",
                name=name),
            row=1, col=idx)
        fig.update_xaxes(title_text=name, row=1, col=idx)

    # Only the left-most subplot carries the y-axis label.
    fig.update_yaxes(title_text="Total Observations in the Dataset", row=1, col=1)

    fig.show()
    
bar_category_visualiz(df_train_features, category_features, 'Train Datasets')
bar_category_visualiz(df_test_features,  category_features, 'Test Datasets')

In [200]:
GENES = [col for col in df_numerical_features.columns if col.startswith('g-')]
CELLS = [col for col in df_numerical_features.columns if col.startswith('c-')]

nrow = 6
ncol = 2

def hist_numerical_visualiz(df, features, title, nrow, ncol):
    """Plot histograms for a random sample of numerical columns on a grid.

    The sample is drawn with the module-level ``random`` generator, so the
    selection depends on that generator's current state.

    Args:
        df: DataFrame containing every column named in ``features``.
        features: List of candidate column names.
        title: Title of the whole figure.
        nrow: Number of subplot rows.
        ncol: Number of subplot columns.

    Returns:
        The list of sampled column names in plotting order (row by row).
    """
    # random.sample raises ValueError when asked for more items than exist, so
    # the sample is capped at the number of available features; any remaining
    # grid cells are left empty.
    samples = random.sample(features, min(nrow * ncol, len(features)))

    fig = make_subplots(rows=nrow, cols=ncol, subplot_titles=samples)
    fig.layout.template = 'plotly_dark'
    fig.update_layout(showlegend=False, title_text=title)

    for position, name in enumerate(samples):
        # Fill the grid row by row; subplot coordinates are 1-based.
        r, c = divmod(position, ncol)
        fig.add_trace(
            go.Histogram(x=df[name].values, name=name),
            row=r + 1, col=c + 1)
    fig.show()

    return samples
    
gene_samples = hist_numerical_visualiz(
    df_numerical_features,
    GENES, 'Sample of Genes Distribution',
    nrow, ncol,
)

cell_samples = hist_numerical_visualiz(
    df_numerical_features,
    CELLS, 'Sample of Cells Distribution',
    nrow, ncol,
)

In [211]:
COLORS = ('blue', 'orange', 'green')

def violin_groupcmp_visualiz(feature_xs, feature_ys):
    """Compare distributions of numerical columns across categorical groups.

    One figure is shown per grouping column in ``feature_xs``. Each figure has
    one subplot per column in ``feature_ys``, and each subplot draws one violin
    per group value. Data is read from the notebook-level ``df_train_features``.

    Args:
        feature_xs: Column names to group by (one figure each).
        feature_ys: Column names whose values are plotted (one subplot each).

    Returns:
        None. Every figure is displayed with ``fig.show()``.
    """
    for feat_x in feature_xs:
        fig = make_subplots(
            rows=1, cols=len(feature_ys),
            subplot_titles=[f'{x} expression' for x in feature_ys]
        )
        fig.update_layout(autosize=True, showlegend=False)
        fig.layout.template = 'plotly_dark'

        for c, feat_y in enumerate(feature_ys):
            # Group by the bare column name: grouping by a one-element list makes
            # newer pandas versions report the group keys as 1-tuples.
            for cat_ty, values in df_train_features.groupby(feat_x)[feat_y]:
                fig.add_trace(
                    go.Violin(
                        y=values.values,
                        # Cycle through the palette so that more than
                        # len(COLORS) subplots do not raise IndexError.
                        line_color=COLORS[c % len(COLORS)],
                        # Group keys may be integers (e.g. cp_time values).
                        name=str(cat_ty),
                        box_visible=True,
                        meanline_visible=True),
                    row=1, col=c+1)
            # Each figure has a single row. The previous code passed row=r+1
            # (the index of the grouping column), which points outside the grid
            # for every figure after the first.
            fig.update_xaxes(title_text=feat_x, row=1, col=c+1)

        fig.show()
        
violin_groupcmp_visualiz(['cp_type', 'cp_dose', 'cp_time'], gene_samples[:2])
violin_groupcmp_visualiz(['cp_type', 'cp_dose', 'cp_time'], cell_samples[:2])

In [228]:
cc = df_train_features.groupby(['cp_type','cp_dose', 'cp_time'])

In [236]:
dir(cc)

['__annotations__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__getitem__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__iter__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_accessors',
 '_add_numeric_operations',
 '_agg_examples_doc',
 '_agg_see_also_doc',
 '_aggregate',
 '_aggregate_frame',
 '_aggregate_item_by_item',
 '_aggregate_multiple_funcs',
 '_apply_filter',
 '_apply_to_column_groupbys',
 '_apply_whitelist',
 '_assure_grouper',
 '_bool_agg',
 '_builtin_table',
 '_choose_path',
 '_concat_objects',
 '_constructor',
 '_cumcount_array',
 '_cython_agg_blocks',
 '_cython_agg_general',
 '_cython_table',
 '_cython_transform',
 '_define_paths',
 '_deprecations',
 '_dir_additions',
 '_dir_deletions',
 '_fill',
 '_ge

In [224]:
data = df_train_features.groupby(['cp_type','cp_dose', 'cp_time']).count()
data

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,sig_id,g-0,g-1,g-2,g-3,g-4,g-5,g-6,g-7,g-8,g-9,g-10,g-11,g-12,g-13,...,c-85,c-86,c-87,c-88,c-89,c-90,c-91,c-92,c-93,c-94,c-95,c-96,c-97,c-98,c-99
cp_type,cp_dose,cp_time,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1
ctl_vehicle,D1,24,301,301,301,301,301,301,301,301,301,301,301,301,301,301,301,...,301,301,301,301,301,301,301,301,301,301,301,301,301,301,301
ctl_vehicle,D1,48,343,343,343,343,343,343,343,343,343,343,343,343,343,343,343,...,343,343,343,343,343,343,343,343,343,343,343,343,343,343,343
ctl_vehicle,D1,72,307,307,307,307,307,307,307,307,307,307,307,307,307,307,307,...,307,307,307,307,307,307,307,307,307,307,307,307,307,307,307
ctl_vehicle,D2,24,305,305,305,305,305,305,305,305,305,305,305,305,305,305,305,...,305,305,305,305,305,305,305,305,305,305,305,305,305,305,305
ctl_vehicle,D2,48,305,305,305,305,305,305,305,305,305,305,305,305,305,305,305,...,305,305,305,305,305,305,305,305,305,305,305,305,305,305,305
ctl_vehicle,D2,72,305,305,305,305,305,305,305,305,305,305,305,305,305,305,305,...,305,305,305,305,305,305,305,305,305,305,305,305,305,305,305
trt_cp,D1,24,3585,3585,3585,3585,3585,3585,3585,3585,3585,3585,3585,3585,3585,3585,3585,...,3585,3585,3585,3585,3585,3585,3585,3585,3585,3585,3585,3585,3585,3585,3585
trt_cp,D1,48,4011,4011,4011,4011,4011,4011,4011,4011,4011,4011,4011,4011,4011,4011,4011,...,4011,4011,4011,4011,4011,4011,4011,4011,4011,4011,4011,4011,4011,4011,4011
trt_cp,D1,72,3600,3600,3600,3600,3600,3600,3600,3600,3600,3600,3600,3600,3600,3600,3600,...,3600,3600,3600,3600,3600,3600,3600,3600,3600,3600,3600,3600,3600,3600,3600
trt_cp,D2,24,3581,3581,3581,3581,3581,3581,3581,3581,3581,3581,3581,3581,3581,3581,3581,...,3581,3581,3581,3581,3581,3581,3581,3581,3581,3581,3581,3581,3581,3581,3581


In [227]:
data.iloc[:, 0]

cp_type      cp_dose  cp_time
ctl_vehicle  D1       24          301
                      48          343
                      72          307
             D2       24          305
                      48          305
                      72          305
trt_cp       D1       24         3585
                      48         4011
                      72         3600
             D2       24         3581
                      48         3591
                      72         3580
Name: sig_id, dtype: int64

In [None]:
# Stack the train and test features and tag every row with its origin in a
# `data` column, which the aggregation below groups on.
# NOTE(review): `full_data` was not defined anywhere in the visible notebook, so
# this cell raised NameError on a fresh run. The construction here is inferred
# from the columns the cell uses and from the figure title - confirm.
full_data = pd.concat(
    [
        df_train_features.assign(data='train'),
        df_test_features.assign(data='test'),
    ],
    ignore_index=True,
)

# Number of samples for each (cp_type, cp_time, cp_dose, data) combination.
data = (
    full_data
    .groupby(['cp_type', 'cp_time', 'cp_dose', 'data'])['sig_id']
    .count()
    .reset_index(name='count')
)

# The title previously listed "cp_dose" twice; the second level of the path is cp_time.
fig = px.sunburst(data, path=['cp_type', 'cp_time', 'cp_dose'], values='count',
                  color_discrete_sequence=px.colors.qualitative.G10,
                  title='Train and Test "cp_type","cp_time" and "cp_dose" distribution')
fig.layout.template = 'plotly_dark'
fig.show()

### Targets Analysis (Multi Label Classification)

In [None]:
df_train_targets.describe()