In [1]:
import os
import os.path
import random
import gc
import itertools
import numpy as np
import pandas as pd
import scipy.sparse
from tqdm import tqdm

In [2]:
import warnings 
warnings.filterwarnings('ignore')

In [3]:
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
from IPython.core.pylabtools import figsize

In [4]:
sns.set()

In [5]:
import pickle

def dump_pickle(file, filename, protocol=pickle.HIGHEST_PROTOCOL):
    """Serialize ``file`` (any picklable object) to the path ``filename``.

    Args:
        file: object to pickle.
        filename: destination path; an existing file is overwritten.
        protocol: pickle protocol. Defaults to the highest available, because
            protocols below 4 cannot serialize single objects larger than
            4 GiB (the scaled input matrices here are in that range).
    """
    # `with` closes the handle even if pickling fails part-way through.
    with open(filename, 'wb') as outfile:
        pickle.dump(file, outfile, protocol=protocol)

def load_pickle(filename):
    """Load and return the object pickled in the file at ``filename``.

    Only use this on files this notebook wrote itself: unpickling can execute
    arbitrary code, so it must never be pointed at untrusted input.
    """
    # `with` closes the handle even if unpickling raises.
    with open(filename, 'rb') as infile:
        return pickle.load(infile)

In [6]:
DATA_DIR = '../input/open-problems-multimodal'
%ls $DATA_DIR -lh

total 27G
-rw-r--r-- 1 nobody nogroup 2.3G Sep  7 20:28 evaluation_ids.csv
-rw-r--r-- 1 nobody nogroup 9.4M Sep  7 20:28 metadata.csv
-rw-r--r-- 1 nobody nogroup 230K Sep  7 20:28 metadata_cite_day_2_donor_27678.csv
-rw-r--r-- 1 nobody nogroup 805M Sep  7 20:28 sample_submission.csv
-rw-r--r-- 1 nobody nogroup 1.6G Sep  7 20:28 test_cite_inputs.h5
-rw-r--r-- 1 nobody nogroup 294M Sep  7 20:28 test_cite_inputs_day_2_donor_27678.h5
-rw-r--r-- 1 nobody nogroup 6.1G Sep  7 20:29 test_multi_inputs.h5
-rw-r--r-- 1 nobody nogroup 2.4G Sep  7 20:29 train_cite_inputs.h5
-rw-r--r-- 1 nobody nogroup  37M Sep  7 20:28 train_cite_targets.h5
-rw-r--r-- 1 nobody nogroup  11G Sep  7 20:30 train_multi_inputs.h5
-rw-r--r-- 1 nobody nogroup 3.0G Sep  7 20:29 train_multi_targets.h5


## Read Data

In [7]:
%%time
train_inp = pd.read_hdf(f'{DATA_DIR}/train_cite_inputs.h5')
train_inp_cols = train_inp.columns

CPU times: user 16.1 s, sys: 2.42 s, total: 18.5 s
Wall time: 39.4 s


In [8]:
%%time
test_inp = pd.read_hdf(f'{DATA_DIR}/test_cite_inputs.h5')

CPU times: user 11.3 s, sys: 1.49 s, total: 12.8 s
Wall time: 27.1 s


In [9]:
%%time
train_tar = pd.read_hdf(f'{DATA_DIR}/train_cite_targets.h5')
train_tar_cols = train_tar.columns

CPU times: user 136 ms, sys: 24.8 ms, total: 160 ms
Wall time: 541 ms


## Data Preprocessing

Find constant columns: columns holding a single distinct value (typically all zeroes) in either the train or the test inputs.

In [10]:
%%time
# A column carries no information if it holds a single distinct value
# (usually all zeros, but any constant qualifies) in either the train or the
# test inputs. `nunique(dropna=False)` equals `len(unique())`, so a NaN counts
# as its own value exactly as before. The train check short-circuits the test
# check, as in the original loop.
zero_cols = [
    col
    for col in tqdm(train_inp_cols, desc='Scanning columns')
    if train_inp[col].nunique(dropna=False) == 1
    or test_inp[col].nunique(dropna=False) == 1
]
print('Number of constant columns (Train or Test):', len(zero_cols))

0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
Number of columns with zero values only (Train or Test): 1194
CPU times: user 30.6 s, sys: 104 ms, total: 30.7 s
Wall time: 30.7 s


In [11]:
%%time
# Remove the constant columns from both splits and refresh the cached column
# list so later cells see the reduced feature set.
train_inp = train_inp.drop(zero_cols, axis=1)
train_inp_cols = train_inp.columns
test_inp = test_inp.drop(zero_cols, axis=1)
# The scaler and the network address features by position, so both splits must
# expose identical columns in identical order. Warnings are silenced globally,
# so sklearn's own feature-name check could otherwise go unnoticed.
assert test_inp.columns.equals(train_inp_cols), \
    'train/test feature columns differ'
train_inp.shape, test_inp.shape

CPU times: user 585 ms, sys: 1.1 s, total: 1.69 s
Wall time: 1.69 s


((70988, 20856), (48663, 20856))

In [12]:
# %%time
# np.min(train_inp.min().unique()), np.max(train_inp.max().unique())

In [13]:
from sklearn.preprocessing import StandardScaler, Normalizer

In [14]:
# Standardize every feature using statistics fitted on the training split only,
# then apply that same transform to the test split. The results are NumPy
# arrays (they are passed to torch.from_numpy later), so column labels are
# gone from here on.
sc = StandardScaler()
train_inp = sc.fit_transform(train_inp)
test_inp = sc.transform(test_inp)

In [15]:
del sc
gc.collect()

42

In [16]:
# Scratch directory for intermediates spilled to disk to free RAM.
# exist_ok keeps the cell re-runnable (plain `mkdir` fails if it exists).
os.makedirs('../tmp', exist_ok=True)

In [17]:
# Park the scaled test matrix on disk (it is reloaded in the Prediction
# section) and release it from memory while the model trains.
dump_pickle(test_inp, '../tmp/test_inp')
del test_inp
gc.collect()

42

In [18]:
train_tar = train_tar.values

## Modeling

In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset, random_split

In [20]:
# torch.manual_seed(42)
# torch.backends.cudnn.deterministic = True

In [21]:
%%time
# torch.from_numpy wraps the arrays without copying (the tensors share memory
# with the NumPy buffers), which is why this cell is near-instant; dtypes carry
# over unchanged.
train_inp = torch.from_numpy(train_inp)
train_tar = torch.from_numpy(train_tar)

CPU times: user 26 µs, sys: 5 µs, total: 31 µs
Wall time: 34.3 µs


In [22]:
# Pair inputs with targets and hold out everything beyond `train_sz` rows for
# validation. 56832 = 111 * 512, so the training split divides evenly into
# batches of 512 (the batch size used below). The seeded generator makes the
# split reproducible across runs.
full_ds = TensorDataset(train_inp, train_tar)
train_sz = 56832 # 111*512
val_sz = len(full_ds) - train_sz
train_ds, val_ds = random_split(full_ds, 
                                [train_sz, val_sz],
                                generator=torch.Generator().manual_seed(42))

In [23]:
# Reshuffle the training split every epoch. Keep the validation order fixed so
# the batch composition, and therefore the per-batch correlation average, is
# identical from epoch to epoch.
batch_size = 512
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

In [24]:
class Net(nn.Module):
    """Feed-forward regressor with three hidden blocks and a linear head.

    Each hidden block is Linear -> BatchNorm1d -> ReLU. The default sizes
    reproduce the original architecture:
    - 20856 input features (the CITE-seq inputs after dropping constant columns);
    - 128 hidden units;
    - 140 outputs (one per target column).
    """

    def __init__(self, in_features=20856, hidden_size=128, out_features=140):
        """Build the layers.

        Args:
            in_features: number of input features per sample.
            hidden_size: width of each of the three hidden layers.
            out_features: number of regression targets per sample.
        """
        super(Net, self).__init__()
        # Attribute names are unchanged so existing state_dicts still load.
        self.linear1 = nn.Linear(in_features, hidden_size)
        self.bn1 = nn.BatchNorm1d(hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.bn2 = nn.BatchNorm1d(hidden_size)
        self.linear3 = nn.Linear(hidden_size, hidden_size)
        self.bn3 = nn.BatchNorm1d(hidden_size)
        self.linear4 = nn.Linear(hidden_size, out_features)

    def forward(self, x):
        """Map a 2-D (batch, in_features) tensor to (batch, out_features).

        In train() mode BatchNorm1d normalizes with batch statistics, so a
        training batch needs more than one sample. In eval() mode it uses the
        running statistics instead.
        """
        hidden_blocks = ((self.linear1, self.bn1),
                         (self.linear2, self.bn2),
                         (self.linear3, self.bn3))
        for linear, bn in hidden_blocks:
            x = F.relu(bn(linear(x)))
        return self.linear4(x)

In [25]:
# preds = net(x)
# vpreds = preds - torch.mean(preds)
# vy = y - torch.mean(y)
# loss = torch.sum(vpreds * vy) / \
#        (torch.sqrt(torch.sum(vpreds ** 2)) *
#         torch.sqrt(torch.sum(vy ** 2)))
# loss

In [26]:
def train_model(train_loader, model, optimizer, scheduler):
    """Run one training epoch that maximizes Pearson correlation.

    The loss is the negative Pearson correlation between predictions and
    targets. It is computed over all elements of each batch flattened
    together, not per sample. The scheduler is stepped after every optimizer
    step.

    Args:
        train_loader: iterable of ``(x, y)`` tensor batches.
        model: module mapping ``x`` to predictions shaped like ``y``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per batch.

    Returns:
        ``(train_corr, lrs)``:
        - ``train_corr`` is the batch-size-weighted mean correlation, as a
          0-d tensor that does not track gradients;
        - ``lrs`` is the learning rate used for each batch.
    """
    model.train()
    sum_corr = 0.0
    total = 0
    lrs = []

    for x, y in train_loader:
        batch = x.shape[0]

        preds = model(x)
        vpreds = preds - torch.mean(preds)
        vy = y - torch.mean(y)
        corr = torch.sum(vpreds * vy) / \
               (torch.sqrt(torch.sum(vpreds ** 2)) *
                torch.sqrt(torch.sum(vy ** 2)))
        loss = -corr

        optimizer.zero_grad()
        loss.backward()

        optimizer.step()
        lrs.append(optimizer.param_groups[0]['lr'])
        scheduler.step()

        total += batch
        # .item() turns the metric into a plain float. Accumulating the tensor
        # itself would chain every batch's autograd graph into one object.
        sum_corr += batch * corr.item()

    # Returned as a tensor so callers using `.detach().numpy()` keep working.
    train_corr = torch.tensor(sum_corr / total)
    return train_corr, lrs

In [27]:
def model_eval(model, val_loader):
    """Score ``model`` on ``val_loader`` without updating it.

    Uses the same flattened per-batch Pearson correlation as ``train_model``,
    weighted by batch size. Runs under ``torch.no_grad()`` so no autograd
    graph is built. Leaves ``model`` in eval() mode.

    Args:
        model: module mapping ``x`` to predictions shaped like ``y``.
        val_loader: iterable of ``(x, y)`` tensor batches.

    Returns:
        The mean correlation as a 0-d tensor that does not track gradients.
    """
    model.eval()
    sum_corr = 0.0
    total = 0

    with torch.no_grad():
        for x, y in val_loader:
            batch = x.shape[0]

            preds = model(x)
            vpreds = preds - torch.mean(preds)
            vy = y - torch.mean(y)
            corr = torch.sum(vpreds * vy) / \
                   (torch.sqrt(torch.sum(vpreds ** 2)) *
                    torch.sqrt(torch.sum(vy ** 2)))

            total += batch
            sum_corr += batch * corr.item()

    # Returned as a tensor so callers using `.detach().numpy()` keep working.
    val_corr = torch.tensor(sum_corr / total)
    return val_corr

In [28]:
# One source of truth for the schedule. OneCycleLR plans its warm-up/anneal
# over epochs * steps_per_epoch steps (stepping past that total raises), and it
# peaks at `learning_rate`.
epochs = 70
net = Net()
learning_rate = 0.01
weight_decay = 0.0001
optimizer = optim.Adam(net.parameters(),
                       lr=learning_rate,
                       weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,
                                                max_lr=learning_rate,
                                                steps_per_epoch=len(train_loader),
                                                epochs=epochs)

In [29]:
%%time
epochs = 70  # must not exceed the epoch count the OneCycleLR scheduler was built with
EARLY_STOP_CORR = 0.8880  # stop as soon as validation correlation reaches this
train_corrs = []
val_corrs = []
best_val_corr = 0.0
best_epoch = 0
num_epochs_run = 0
lrs = []

for epoch in range(epochs):
    num_epochs_run += 1

    # train
    print('epoch', epoch)
    train_corr, epoch_lrs = train_model(train_loader, net, optimizer, scheduler)
    # float() works for 0-d tensors whether or not they track gradients.
    train_corr = float(train_corr)
    train_corrs.append(train_corr)
    lrs.append(epoch_lrs)
    print(f'train corr: {train_corr:.4f}')

    # val
    val_corr = float(model_eval(net, val_loader))
    val_corrs.append(val_corr)
    print(f'val corr: {val_corr:.4f}')
    print()

    # Track the best validation score seen so far. This is updated before the
    # early-stop check so the stopping epoch is also considered.
    if val_corr > best_val_corr:
        best_val_corr = val_corr
        best_epoch = epoch

    # early stoppage
    if val_corr >= EARLY_STOP_CORR:
        print(f'**Early Stoppage** We are stopping at epoch {epoch}...')
        print(f'Best epoch: epoch {best_epoch}')
        print(f'Best val correlation: {best_val_corr: .4f}')
        print()
        break
#         print()

epoch 0
train corr: 0.7291
val corr: 0.8589

epoch 1
train corr: 0.8768
val corr: 0.8747

epoch 2
train corr: 0.8949
val corr: 0.8761

epoch 3
train corr: 0.9074
val corr: 0.8755

epoch 4
train corr: 0.9126
val corr: 0.8762

epoch 5
train corr: 0.9144
val corr: 0.8763

epoch 6
train corr: 0.9150
val corr: 0.8769

epoch 7
train corr: 0.9125
val corr: 0.8792

epoch 8
train corr: 0.9100
val corr: 0.8793

epoch 9
train corr: 0.9069
val corr: 0.8821

epoch 10
train corr: 0.9037
val corr: 0.8818

epoch 11
train corr: 0.9006
val corr: 0.8828

epoch 12
train corr: 0.8979
val corr: 0.8845

epoch 13
train corr: 0.8958
val corr: 0.8854

epoch 14
train corr: 0.8937
val corr: 0.8849

epoch 15
train corr: 0.8925
val corr: 0.8840

epoch 16
train corr: 0.8917
val corr: 0.8845

epoch 17
train corr: 0.8907
val corr: 0.8845

epoch 18
train corr: 0.8901
val corr: 0.8820

epoch 19
train corr: 0.8894
val corr: 0.8825

epoch 20
train corr: 0.8892
val corr: 0.8858

epoch 21
train corr: 0.8891
val corr: 0.8820

In [30]:
# figsize(8, 8)
# plt.plot(list(itertools.chain(*lrs)))

In [31]:
# figsize(8, 8)
# plt.scatter(list(range(num_epochs_run)), train_corrs, label='training corr')
# plt.plot(list(range(num_epochs_run)), train_corrs)
# plt.scatter(list(range(num_epochs_run)), val_corrs, label='val corr')
# plt.plot(list(range(num_epochs_run)), val_corrs)
# plt.legend(loc='lower right')
# # plt.title(f'lr = {learning_rate} | wd = {weight_decay}')
# plt.ylim(0, 1)

In [32]:
# train_inp[:1]

In [33]:
# with torch.no_grad():
#     net.eval()
#     train_tar_preds = net(train_inp[:5]).detach().numpy().flatten()

In [34]:
# np.corrcoef(train_tar_preds, train_tar[:5].detach().numpy().flatten())

In [35]:
del train_inp, train_tar, full_ds
gc.collect()

126

## Prediction

In [36]:
%%time
test_inp = load_pickle('../tmp/test_inp')
test_inp = torch.from_numpy(test_inp)

CPU times: user 395 ms, sys: 6.49 s, total: 6.88 s
Wall time: 6.87 s


In [37]:
# Predict every test cell in one forward pass. eval() makes BatchNorm use its
# running statistics instead of this batch's, and no_grad() skips building an
# autograd graph.
with torch.no_grad():
    net.eval()
    test_tar_preds = net(test_inp).detach().numpy()

In [38]:
del test_inp
gc.collect()

63

## Creating Submission

In [39]:
DATA_DIR = '../input/msci-h5-sparse-transform'
%ls $DATA_DIR -lh

total 7.1G
-rw-r--r-- 1 nobody nogroup  25K Nov  2 12:50 __notebook__.ipynb
-rw-r--r-- 1 nobody nogroup  25K Nov  2 12:50 __output__.json
-rw-r--r-- 1 nobody nogroup 293K Nov  2 12:50 __results__.html
-rw-r--r-- 1 nobody nogroup    0 Nov  2 12:50 custom.css
-rw-r--r-- 1 nobody nogroup 359M Nov  2 12:50 evaluation_ids.parquet
-rw-r--r-- 1 nobody nogroup 3.8M Nov  2 12:50 metadata.parquet
-rw-r--r-- 1 nobody nogroup 108K Nov  2 12:50 metadata_cite_day_2_donor_27678.parquet
-rw-r--r-- 1 nobody nogroup 252M Nov  2 12:50 sample_submission.parquet
-rw-r--r-- 1 nobody nogroup 856K Nov  2 12:50 test_cite_inputs_day_2_donor_27678_idx.npz
-rw-r--r-- 1 nobody nogroup  78M Nov  2 12:50 test_cite_inputs_day_2_donor_27678_val.sparse.npz
-rw-r--r-- 1 nobody nogroup 1.8M Nov  2 12:50 test_cite_inputs_idx.npz
-rw-r--r-- 1 nobody nogroup 488M Nov  2 12:51 test_cite_inputs_val.sparse.npz
-rw-r--r-- 1 nobody nogroup 8.4M Nov  2 12:50 test_multi_inputs_idx.npz
-rw-r--r-- 1 nobody nogroup 1.7G

In [40]:
# Column labels and row (cell id) order for the prediction matrix.
# NOTE(review): allow_pickle=True executes pickled objects from an external
# dataset; only keep it if that dataset is trusted.
# `with` closes the underlying .npz file handles.
with np.load(f'{DATA_DIR}/train_cite_targets_idx.npz',
             allow_pickle=True) as npz:
    test_tar_cols = npz['columns']
with np.load(f'{DATA_DIR}/test_cite_inputs_idx.npz',
             allow_pickle=True) as npz:
    test_tar_idx = npz['index']
# The network's outputs follow the column order of train_cite_targets.h5
# (train_tar_cols), so this lookup must list the targets in the same order.
assert list(test_tar_cols) == list(train_tar_cols), \
    'target column order differs from the one the model was trained on'
# NOTE(review): rows are assumed to follow test_cite_inputs.h5 order; the
# original DataFrame index was dropped by the scaler, so only the count is
# checked here.
assert test_tar_preds.shape == (len(test_tar_idx), len(test_tar_cols))
test_tar_cols.shape, test_tar_idx.shape, test_tar_preds.shape

((140,), (48663,), (48663, 140))

In [41]:
%%time
print('Start Eval...')
eval_ids = pd.read_parquet(f'{DATA_DIR}/evaluation_ids.parquet')
eval_ids.cell_id = eval_ids.cell_id.astype(pd.CategoricalDtype())
eval_ids.gene_id = eval_ids.gene_id.astype(pd.CategoricalDtype())

Start Eval...
CPU times: user 31.4 s, sys: 16.4 s, total: 47.9 s
Wall time: 38.9 s


In [42]:
%%time
sub = pd.Series(name='target',
                index=pd.MultiIndex.from_frame(eval_ids), 
                dtype=np.float32)
sub

CPU times: user 12.2 s, sys: 4.49 s, total: 16.7 s
Wall time: 16.7 s


row_id    cell_id       gene_id        
0         c2150f55becb  CD86              NaN
1         c2150f55becb  CD274             NaN
2         c2150f55becb  CD270             NaN
3         c2150f55becb  CD155             NaN
4         c2150f55becb  CD112             NaN
                                           ..
65744175  2c53aa67933d  ENSG00000134419   NaN
65744176  2c53aa67933d  ENSG00000186862   NaN
65744177  2c53aa67933d  ENSG00000170959   NaN
65744178  2c53aa67933d  ENSG00000107874   NaN
65744179  2c53aa67933d  ENSG00000166012   NaN
Name: target, Length: 65744180, dtype: float32

In [43]:
cell_id_dict = {cell_id: idx 
                for idx, cell_id in enumerate(test_tar_idx, 0)}
gene_id_dict = {gene_id: idx 
                for idx, gene_id in enumerate(test_tar_cols, 0)}

In [44]:
# Translate each evaluation row's cell_id / gene_id into a row / column index
# of `test_tar_preds`. Ids absent from the CITE test set map to -1, and
# `valid_cite_rows` marks the evaluation rows this model can fill.
eid_cid_idx = eval_ids['cell_id']\
              .apply(lambda x: cell_id_dict.get(x, -1))
eid_gid_idx = eval_ids['gene_id']\
              .apply(lambda x: gene_id_dict.get(x, -1))
valid_cite_rows = (eid_cid_idx != -1) & (eid_gid_idx != -1)

In [45]:
%%time
# Gather one prediction per valid evaluation row by paired (row, column) fancy
# indexing into `test_tar_preds`; every other row keeps its NaN placeholder.
sub.iloc[valid_cite_rows] = test_tar_preds\
                             [eid_cid_idx[valid_cite_rows].to_numpy(),
                              eid_gid_idx[valid_cite_rows].to_numpy()]

CPU times: user 195 ms, sys: 60 ms, total: 255 ms
Wall time: 254 ms


In [46]:
del eval_ids, test_tar_idx, test_tar_cols
del eid_cid_idx, eid_gid_idx, valid_cite_rows
gc.collect() 

97

In [47]:
# Rows without a CITE prediction (still NaN) are written as 0. Only row_id and
# target go into the CSV; `sub` itself keeps cell_id/gene_id for the preview
# in the next cell.
sub = pd.DataFrame(sub).fillna(0).reset_index()
sub.drop(['cell_id', 'gene_id'], axis=1)\
   .to_csv('cite_sub.csv', index=False)

In [48]:
sub.head()

Unnamed: 0,row_id,cell_id,gene_id,target
0,0,c2150f55becb,CD86,-0.072532
1,1,c2150f55becb,CD274,-0.063299
2,2,c2150f55becb,CD270,-0.057439
3,3,c2150f55becb,CD155,0.157071
4,4,c2150f55becb,CD112,0.174948
