In [1]:
%%capture output
!pip install --upgrade pip
# !pip install --upgrade pandas
!pip install tables   
# necessary for pd.read_hdf()

!pip install ipywidgets
!pip install --upgrade jupyter
!pip install IProgress
!pip install catboost
!pip install shap

In [2]:
print(output.stderr) # prints potential installation errors from cell above




In [3]:
import os
import random
import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold, GroupKFold
import anndata as ad

In [4]:
%matplotlib inline
from tqdm.notebook import tqdm
import gc
import pickle

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

device = torch.device("cuda")
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)

## data load

In [29]:
lrz_path = '/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93zoj/open-problems-multimodal-3rd-solution/'

model_path_for_now = '/dss/dsshome1/02/di93zoj/valentina/open-problems-multimodal-3rd-solution/'

raw_path =  lrz_path + 'input/raw/'  # '../../../input/raw/'

cite_target_path = lrz_path + 'input/target/cite/'   # '../../../input/target/cite/'
cite_feature_path = lrz_path + 'input/features/cite/'   # '../../../input/features/cite/'
cite_mlp_path = lrz_path + 'model/cite/mlp/'   # '../../../model/cite/mlp/'   # '../../../model/cite/mlp/'
cite_cb_path = lrz_path + 'model/cite/cb/'   # '../../../model/cite/cb/'

multi_target_path = lrz_path + 'input/target/multi/'   # '../../../input/target/multi/'
multi_feature_path = lrz_path + 'input/features/multi/'   # '../../../input/features/multi/'
multi_mlp_path = lrz_path + 'model/multi/mlp/'   # '../../../model/multi/mlp/'
multi_cb_path = lrz_path + 'model/multi/cb/'   # '../../../model/multi/cb/'

output_path = lrz_path + 'output/'   # '../../../output/'

### train datasets from model/

In [6]:
feature_dict = {}
#                               training sets           test sets
feature_dict['add_con_imp'] = ['X_add_con_imp.pickle', 'X_test_add_con_imp.pickle']
feature_dict['last_v3'] = ['X_last_v3.pickle', 'X_test_last_v3.pickle']
feature_dict['c_add_w2v_v1_mish'] = ['X_c_add_w2v_v1.pickle', 'X_test_c_add_w2v_v1.pickle']
feature_dict['c_add_w2v_v1'] = ['X_c_add_w2v_v1.pickle', 'X_test_c_add_w2v_v1.pickle']
feature_dict['c_add_84_v1'] = ['X_c_add_84_v1.pickle', 'X_test_c_add_84_v1.pickle']
feature_dict['c_add_120_v1'] = ['X_c_add_v1.pickle', 'X_test_c_add_v1.pickle']

feature_dict['w2v_cell'] = ['X_feature_w2v_cell.pickle', 'X_test_feature_w2v_cell.pickle']
feature_dict['best_cell_120'] = ['X_best_cell_128_120.pickle', 'X_test_best_cell_128_120.pickle']
feature_dict['cluster_cell'] = ['X_cluster_cell_128.pickle', 'X_test_cluster_cell_128.pickle']

feature_dict['w2v_128'] = ['X_feature_w2v.pickle', 'X_test_feature_w2v.pickle']
feature_dict['imp_w2v_128'] = ['X_feature_imp_w2v.pickle', 'X_test_feature_imp_w2v.pickle']
feature_dict['snorm'] = ['X_feature_snorm.pickle', 'X_test_feature_snorm.pickle']
feature_dict['best_128'] = ['X_best_128.pickle', 'X_test_best_128.pickle']
feature_dict['best_64'] = ['X_best_64.pickle', 'X_test_best_64.pickle']
feature_dict['cluster_128'] = ['X_cluster_128.pickle', 'X_test_cluster_128.pickle']
feature_dict['cluster_64'] = ['X_cluster_64.pickle', 'X_test_cluster_64.pickle']
feature_dict['svd_128'] = ['X_svd_128.pickle', 'X_test_svd_128.pickle']   # model #16
feature_dict['svd_64'] = ['X_svd_64.pickle', 'X_test_svd_64.pickle']

## Cite

In [7]:
# get model name
mlp_model_path = os.listdir(cite_mlp_path)
mlp_model_path

['cite_mlp_corr_add_con_imp_flg_donor_val_50',
 'cite_mlp_corr_c_add_84_v1_flg_donor_val_47',
 'cite_mlp_corr_c_add_120_v1_flg_donor_val_63',
 '.ipynb_checkpoints',
 'cite_mlp_corr_snorm_flg_donor_val_39',
 'cite_mlp_corr_c_add_w2v_v1_mish_flg_donor_val_66',
 'cite_mlp_corr_cluster_128_flg_donor_val_51',
 'cite_mlp_corr_svd_128_flg_donor_val_30',
 'cite_mlp_corr_w2v_cell_flg_donor_val_51',
 'cite_mlp_corr_cluster_64_flg_donor_val_57',
 'cite_mlp_corr_w2v_128_flg_donor_val_42',
 'cite_mlp_corr_cluster_cell_flg_donor_val_64',
 'cite_mlp_corr_imp_w2v_128_flg_donor_val_38',
 'cite_mlp_corr_best_cell_120_flg_donor_val_51',
 'cite_mlp_corr_best_128_flg_donor_val_45',
 'cite_mlp_corr_svd_64_flg_donor_val_38',
 'cite_mlp_corr_c_add_w2v_v1_flg_donor_val_66',
 'cite_mlp_corr_best_64_flg_donor_val_50',
 '.gitkeep',
 'cite_mlp_corr_last_v3_flg_donor_val_55']

In [8]:
mlp_model_name = [
    'corr_add_con_imp',
    'corr_last_v3', 
    'corr_c_add_w2v_v1_mish_flg',
    'corr_c_add_w2v_v1_flg',
    'corr_c_add_84_v1',
    'corr_c_add_120_v1',
    'corr_w2v_cell_flg',
    'corr_best_cell_120',
    'corr_cluster_cell',
    'corr_w2v_128',
    'corr_imp_w2v_128',
    'corr_snorm',
    'corr_best_128',
    'corr_best_64',
    'corr_cluster_128',
    'corr_cluster_64',
    'corr_svd_128',
    'corr_svd_64',
             ]

In [9]:
model_name_list = []

for i in mlp_model_name:
    for num, j in enumerate(os.listdir(cite_mlp_path)):
        if i in j:
            model_name_list.append(j)

len(model_name_list)
model_name_list

['cite_mlp_corr_add_con_imp_flg_donor_val_50',
 'cite_mlp_corr_last_v3_flg_donor_val_55',
 'cite_mlp_corr_c_add_w2v_v1_mish_flg_donor_val_66',
 'cite_mlp_corr_c_add_w2v_v1_flg_donor_val_66',
 'cite_mlp_corr_c_add_84_v1_flg_donor_val_47',
 'cite_mlp_corr_c_add_120_v1_flg_donor_val_63',
 'cite_mlp_corr_w2v_cell_flg_donor_val_51',
 'cite_mlp_corr_best_cell_120_flg_donor_val_51',
 'cite_mlp_corr_cluster_cell_flg_donor_val_64',
 'cite_mlp_corr_w2v_128_flg_donor_val_42',
 'cite_mlp_corr_imp_w2v_128_flg_donor_val_38',
 'cite_mlp_corr_snorm_flg_donor_val_39',
 'cite_mlp_corr_best_128_flg_donor_val_45',
 'cite_mlp_corr_best_64_flg_donor_val_50',
 'cite_mlp_corr_cluster_128_flg_donor_val_51',
 'cite_mlp_corr_cluster_64_flg_donor_val_57',
 'cite_mlp_corr_svd_128_flg_donor_val_30',
 'cite_mlp_corr_svd_64_flg_donor_val_38']

In [10]:
weight = [1, 0.3, 1, 1, 1, 1, 1, 1, 1, 0.8, 0.8, 0.8, 0.8, 0.5, 0.5, 0.5, 1, 1, 2, 2]
weight_sum = np.array(weight).sum()
weight_sum

# dict:            model name          input-feature-cite file:    weight
model_feat_dict = {model_name_list[0]:['X_test_add_con_imp.pickle', 1],
                   model_name_list[1]:['X_test_last_v3.pickle', 0.3],
                   model_name_list[2]:['X_test_c_add_w2v_v1.pickle', 1],
                   model_name_list[3]:['X_test_c_add_w2v_v1.pickle', 1],
                   model_name_list[4]:['X_test_c_add_84_v1.pickle', 1],
                   model_name_list[5]:['X_test_c_add_v1.pickle', 1],
                   
                   model_name_list[6]:['X_test_feature_w2v_cell.pickle', 1],
                   model_name_list[7]:['X_test_best_cell_128_120.pickle', 1],
                   model_name_list[8]:['X_test_cluster_cell_128.pickle', 1],
                   
                   model_name_list[9]:['X_test_feature_w2v.pickle', 0.8],
                   model_name_list[10]:['X_test_feature_imp_w2v.pickle',0.8],
                   model_name_list[11]:['X_test_feature_snorm.pickle', 0.8],
                   model_name_list[12]:['X_test_best_128.pickle', 0.8],
                   model_name_list[13]:['X_test_best_64.pickle', 0.5],
                   model_name_list[14]:['X_test_cluster_128.pickle', 0.5],
                   model_name_list[15]:['X_test_cluster_64.pickle', 0.5],
                   model_name_list[16]:['X_test_svd_128.pickle', 1],
                   model_name_list[17]:['X_test_svd_64.pickle', 1],
                   
                   'best_128':['X_test_best_128.pickle', 2],
                   'best_64':['X_test_best_64.pickle', 2],
                  }

In [36]:
# new
for i in model_name_list:
    #i = 'cite_mlp_corr_snorm_flg_donor_val_39'
    try:
        test_file = model_feat_dict[i][0]
        X_test = pd.read_pickle(cite_feature_path  + test_file)   
        print(cite_feature_path  + test_file)
        print(X_test.shape)
    except Exception as e:
        print(e)
        print('UnpicklingError: ', i)

/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93zoj/open-problems-multimodal-3rd-solution/input/features/cite/X_test_add_con_imp.pickle
(48203, 925)
/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93zoj/open-problems-multimodal-3rd-solution/input/features/cite/X_test_last_v3.pickle
(48203, 843)
/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93zoj/open-problems-multimodal-3rd-solution/input/features/cite/X_test_c_add_w2v_v1.pickle
(48203, 843)
/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93zoj/open-problems-multimodal-3rd-solution/input/features/cite/X_test_c_add_w2v_v1.pickle
(48203, 843)
/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93zoj/open-problems-multimodal-3rd-solution/input/features/cite/X_test_c_add_84_v1.pickle
(48203, 667)
/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93zoj/open-problems-multimodal-3rd-solution/input/features/cite/X_test_c_add_v1.pickle
(48203, 811)
/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93zoj/open-problems-multimodal

### cite model

In [12]:
def std(x):
    x = np.array(x)
    return (x - x.mean(1).reshape(-1, 1)) / x.std(1).reshape(-1, 1)

In [13]:
class CiteDataset(Dataset):
    
    def __init__(self, feature, target):
        
        self.feature = feature
        self.target = target
        
    def __len__(self):
        return len(self.feature)
    
    def __getitem__(self, index):
                
        d = {
            "X": self.feature[index],
            "y" : self.target[index],
        }
        return d

In [14]:
class CiteDataset_test(Dataset):
    
    def __init__(self, feature):
        self.feature = feature
        
    def __len__(self):
        return len(self.feature)
    
    def __getitem__(self, index):
                
        d = {
            "X": self.feature[index]
        }
        return d

In [15]:
def partial_correlation_score_torch_faster(y_true, y_pred):
    """Compute the correlation between each rows of the y_true and y_pred tensors.
    Compatible with backpropagation.
    """
    y_true_centered = y_true - torch.mean(y_true, dim=1)[:,None]
    y_pred_centered = y_pred - torch.mean(y_pred, dim=1)[:,None]
    cov_tp = torch.sum(y_true_centered*y_pred_centered, dim=1)/(y_true.shape[1]-1)
    var_t = torch.sum(y_true_centered**2, dim=1)/(y_true.shape[1]-1)
    var_p = torch.sum(y_pred_centered**2, dim=1)/(y_true.shape[1]-1)
    return cov_tp/torch.sqrt(var_t*var_p)

def correl_loss(pred, tgt):
    """Loss for directly optimizing the correlation.
    """
    return -torch.mean(partial_correlation_score_torch_faster(tgt, pred))

In [37]:
class CiteModel(nn.Module):
    
    def __init__(self, feature_num):
        super(CiteModel, self).__init__()
        
        self.layer_seq_256 = nn.Sequential(nn.Linear(feature_num, 256),
                                           nn.Linear(256, 128),
                                       nn.LayerNorm(128),
                                       nn.ReLU(),
                                      )
        self.layer_seq_64 = nn.Sequential(nn.Linear(128, 64),
                                       nn.Linear(64, 32),
                                       nn.LayerNorm(32),
                                       nn.ReLU(),
                                      )
        self.layer_seq_8 = nn.Sequential(nn.Linear(32, 16),
                                         nn.Linear(16, 8),
                                       nn.LayerNorm(8),
                                       nn.ReLU(),
                                      )
        
        self.head = nn.Linear(128 + 32 + 8, 140)
                   
    def forward(self, X, y=None):
        
        ##
#         if isinstance(X, np.ndarray):
#             X = torch.from_numpy(X)
#         X = X.to(device)  # Move the input to the appropriate device if necessary
        ##
    
        X_256 = self.layer_seq_256(X)
        X_64 = self.layer_seq_64(X_256)
        X_8 = self.layer_seq_8(X_64)
        
        X = torch.cat([X_256, X_64, X_8], axis = 1)
        out = self.head(X)
        
        return out

In [17]:
class CiteModel_mish(nn.Module):
    
    def __init__(self, feature_num):
        super(CiteModel_mish, self).__init__()
        
        self.layer_seq_256 = nn.Sequential(nn.Linear(feature_num, 256),
                                           nn.Linear(256, 128),
                                       nn.LayerNorm(128),
                                       nn.Mish(),
                                      )
        self.layer_seq_64 = nn.Sequential(nn.Linear(128, 64),
                                       nn.Linear(64, 32),
                                       nn.LayerNorm(32),
                                       nn.Mish(),
                                      )
        self.layer_seq_8 = nn.Sequential(nn.Linear(32, 16),
                                         nn.Linear(16, 8),
                                       nn.LayerNorm(8),
                                       nn.Mish(),
                                      )
        
        self.head = nn.Linear(128 + 32 + 8, 140)
                   
    def forward(self, X, y=None):
    
        X_256 = self.layer_seq_256(X)
        X_64 = self.layer_seq_64(X_256)
        X_8 = self.layer_seq_8(X_64)
        
        X = torch.cat([X_256, X_64, X_8], axis = 1)
        out = self.head(X)
        
        return out

In [18]:
def train_loop(model, optimizer, loader, epoch):
    
    losses, lrs = [], []
    model.train()
    optimizer.zero_grad()
    #loss_fn = nn.MSELoss()
    
    with tqdm(total=len(loader),unit="batch") as pbar:
        pbar.set_description(f"Epoch{epoch}")
        
        for d in loader:
            X = d['X'].to(device)
            y = d['y'].to(device)
            
            logits = model(X)
            loss = correl_loss(logits, y)
            #loss = torch.sqrt(loss_fn(logits, y))
        
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            pbar.set_postfix({"loss":loss.item()})
            pbar.update(1)

    return model

In [19]:
def valid_loop(model, loader, y_val):
    
    model.eval()
    partial_correlation_scores = []
    oof_pred = []
    
    for d in loader:
        with torch.no_grad():
            val_X = d['X'].to(device).float()
            val_y = d['y'].to(device)
            logits = model(val_X)
            oof_pred.append(logits)
    
    #print(torch.cat(oof_pred).shape, torch.cat(oof_pred).detach().cpu().numpy().shape)
    cor = partial_correlation_score_torch_faster(torch.tensor(y_val).to(device), torch.cat(oof_pred))
    cor = cor.mean().item()
    logits = torch.cat(oof_pred).detach().cpu().numpy()
    
    return logits, cor

In [20]:
def test_loop(model, loader):
    
    model.eval()
    predicts=[]

    for d in tqdm(loader):
        with torch.no_grad():
            X = d['X'].to(device)
            logits = model(X)
            predicts.append(logits.detach().cpu().numpy())
            
    return np.concatenate(predicts)

### pred

In [38]:
pred = np.zeros([48203, 140])

for num, i in enumerate(model_feat_dict.keys()):
    
    #print(i)
    
    if 'mlp' in i:

        try:
            test_file = model_feat_dict[i][0]
            test_weight = model_feat_dict[i][1]
            X_test = pd.read_pickle(cite_feature_path  + test_file)   
            # print(cite_feature_path  + test_file)
            X_test = np.array(X_test)
            feature_dims = X_test.shape[1]

            test_ds = CiteDataset_test(X_test)
            test_dataloader = DataLoader(test_ds, batch_size=128, pin_memory=True, 
                                         shuffle=False, drop_last=False, num_workers=4)

            if 'mish' in i:
                model = CiteModel_mish(feature_dims)
            else:
                model = CiteModel(feature_dims)

            model = model.to(device)
            model.load_state_dict(torch.load(f'{cite_mlp_path}/{i}'))

            result = test_loop(model, test_dataloader).astype(np.float32)
            result = std(result) * test_weight / weight_sum
            pred += result

            torch.cuda.empty_cache()

        except Exception as e: 
            print(i)
            print(e)             # TODOOOOOOOOOOOOOO
        
    else:
        test_file = model_feat_dict[i][0]
        test_weight = model_feat_dict[i][1]
        X_test = pd.read_pickle(cite_feature_path  + test_file)
        
        cb_pred = np.zeros([48203, 140])
        
        for t in tqdm(range(140)): 
            cb_model_path = [j for j in os.listdir(cite_cb_path) if f'cb_{t}_{i}' in j][0]
            cb = pickle.load(open(cite_cb_path + cb_model_path, 'rb'))
            cb_pred[:,t] = cb.predict(X_test)
            
        cb_pred = cb_pred.astype(np.float32)
        pred += std(cb_pred) * test_weight / weight_sum
        
        #del cb_pred

cite_mlp_corr_add_con_imp_flg_donor_val_50
Attempting to deserialize object on CUDA device 0 but torch.cuda.device_count() is 0. Please use torch.load with map_location to map your storages to an existing device.
cite_mlp_corr_last_v3_flg_donor_val_55
Attempting to deserialize object on CUDA device 0 but torch.cuda.device_count() is 0. Please use torch.load with map_location to map your storages to an existing device.
cite_mlp_corr_c_add_w2v_v1_mish_flg_donor_val_66
Attempting to deserialize object on CUDA device 0 but torch.cuda.device_count() is 0. Please use torch.load with map_location to map your storages to an existing device.
cite_mlp_corr_c_add_w2v_v1_flg_donor_val_66
Attempting to deserialize object on CUDA device 0 but torch.cuda.device_count() is 0. Please use torch.load with map_location to map your storages to an existing device.
cite_mlp_corr_c_add_84_v1_flg_donor_val_47
Attempting to deserialize object on CUDA device 0 but torch.cuda.device_count() is 0. Please use torch

  0%|          | 0/140 [00:00<?, ?it/s]

  0%|          | 0/140 [00:00<?, ?it/s]

In [22]:
cite_sub = pd.DataFrame(pred.round(6))
cite_sub

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139
0,-0.076877,-0.107536,-0.081412,0.235340,0.212383,0.508778,0.714871,-0.111080,-0.094659,-0.073315,-0.116281,-0.108862,-0.097573,-0.104481,0.483716,-0.079324,0.426531,0.302832,0.053903,-0.107653,-0.082257,0.090119,-0.112168,-0.095676,0.687361,-0.113379,-0.115403,-0.108167,-0.117696,-0.108534,-0.111692,-0.104854,-0.105364,-0.105601,-0.103499,-0.111801,-0.088663,0.869081,-0.095526,-0.093547,-0.110812,-0.105767,-0.112340,-0.092821,-0.116772,-0.105450,-0.081068,-0.082683,0.610800,-0.109826,-0.104239,-0.105984,-0.094111,-0.115045,-0.020432,0.026682,-0.105935,0.272084,-0.105307,-0.106435,-0.111940,-0.097159,-0.089999,-0.103777,-0.112325,-0.095288,-0.079516,-0.093184,0.072477,-0.109951,-0.110628,-0.118938,-0.107578,0.769669,-0.110701,0.448382,-0.105106,0.076690,-0.108419,-0.109099,0.042020,-0.113356,-0.060216,-0.099415,-0.112301,-0.113646,-0.097630,-0.107329,-0.081830,-0.101124,-0.079099,-0.110585,-0.103025,-0.086625,0.053080,-0.096740,-0.099395,0.420336,-0.103349,0.363882,0.214918,-0.115309,0.262475,-0.095930,0.022537,-0.092371,0.312263,-0.070511,0.484739,-0.072997,-0.068368,0.107584,-0.087482,-0.107399,-0.112273,0.073439,-0.038099,-0.095309,-0.102715,0.572418,-0.099416,0.070323,-0.106917,-0.111302,-0.094307,-0.092461,-0.110055,-0.072320,-0.101630,-0.110362,-0.113396,0.400050,-0.090653,-0.115911,-0.114397,-0.078690,0.011340,-0.056059,0.012302,0.078415
1,-0.070980,-0.105483,-0.078728,0.238921,0.227432,0.494300,0.868887,-0.106807,-0.090079,-0.077481,-0.116207,-0.105719,-0.089185,-0.101150,0.476003,-0.084783,0.395964,0.322830,0.022176,-0.101084,-0.087909,0.063508,-0.109683,-0.085107,0.715892,-0.106425,-0.111527,-0.107168,-0.115465,-0.103104,-0.106192,-0.099530,-0.101643,-0.100825,-0.101097,-0.109965,-0.090401,0.831216,-0.089985,-0.093495,-0.107225,-0.103154,-0.108755,-0.103866,-0.113471,-0.102514,-0.077068,-0.083564,0.589613,-0.106611,-0.099559,-0.103316,-0.106078,-0.110636,-0.025642,0.042071,-0.102003,0.180827,-0.103470,-0.103296,-0.104360,-0.095484,-0.088204,-0.098695,-0.108114,-0.093327,-0.078601,-0.091153,0.080918,-0.108536,-0.103518,-0.112984,-0.104845,0.734899,-0.107256,0.403173,-0.101109,0.039439,-0.103329,-0.103254,0.018683,-0.110621,-0.056592,-0.095719,-0.110038,-0.108561,-0.095397,-0.102081,-0.078001,-0.104646,-0.079995,-0.105553,-0.101321,-0.083420,0.043079,-0.092405,-0.097549,0.416196,-0.102180,0.356702,0.219733,-0.110612,0.246460,-0.093005,-0.011302,-0.092891,0.270515,-0.075828,0.516748,-0.078225,-0.073941,0.042847,-0.080481,-0.097804,-0.106804,0.024413,-0.055215,-0.090355,-0.100601,0.594136,-0.099683,0.170995,-0.104191,-0.107329,-0.089034,-0.090868,-0.105538,-0.057181,-0.096199,-0.107437,-0.111155,0.405383,-0.086804,-0.110228,-0.110396,-0.084210,-0.006720,-0.076285,0.012299,0.078346
2,-0.089150,-0.101540,-0.076435,0.338878,0.290691,0.323657,1.014777,-0.105268,-0.083472,-0.092418,-0.112782,-0.102729,-0.092453,-0.097368,0.790603,-0.063658,0.255258,0.259364,-0.007930,-0.099144,-0.079277,0.090837,-0.107481,-0.086070,0.681157,-0.103144,-0.109113,-0.100578,-0.122811,-0.096156,-0.102624,-0.092398,-0.097464,-0.097798,-0.095754,-0.104699,-0.085765,0.723438,-0.087787,-0.088923,-0.104203,-0.101058,-0.102143,-0.045539,-0.107397,-0.097045,-0.074226,-0.085142,0.200093,-0.104822,-0.092768,-0.099995,-0.080043,-0.107574,-0.019395,0.032191,-0.101605,0.177870,-0.100262,-0.098974,-0.051951,-0.086125,-0.083883,-0.091763,-0.102663,-0.089455,-0.072468,-0.088659,0.086627,-0.104616,-0.094370,-0.107391,-0.101939,0.314863,-0.100397,0.283052,-0.099730,0.002265,-0.099223,-0.102801,-0.019462,-0.103728,-0.036343,-0.091792,-0.109485,-0.104571,-0.091584,-0.099345,-0.069711,-0.100305,-0.074391,-0.102350,-0.090838,-0.080520,0.039465,-0.090182,-0.092018,0.526895,-0.098591,0.248171,0.198832,-0.108625,0.113670,-0.087849,0.206987,-0.087614,0.187265,-0.094517,0.585659,-0.078287,-0.062720,-0.034883,-0.078428,-0.089400,-0.101648,-0.047284,-0.077511,-0.092484,-0.060043,0.816152,-0.095282,0.000230,-0.100147,-0.103879,-0.089214,-0.085673,-0.102612,-0.048358,-0.092067,-0.100856,-0.105592,0.657454,-0.085529,-0.106070,-0.105622,-0.088992,0.080794,-0.068035,0.034081,0.154974
3,-0.128470,-0.113081,-0.063264,0.226137,0.264680,0.215720,-0.138097,-0.083314,-0.076588,-0.114464,-0.140753,-0.099502,-0.109232,-0.089628,0.474046,-0.116054,0.243485,-0.126989,-0.007528,-0.024485,-0.085576,0.149265,-0.118302,-0.085031,0.411208,-0.119124,-0.130412,-0.142772,-0.121572,-0.107517,-0.112074,-0.087340,-0.113672,-0.113528,-0.107757,-0.124679,-0.049125,0.723409,-0.068823,-0.082963,-0.110676,-0.088543,-0.105773,0.603149,-0.146289,-0.110784,-0.024087,-0.115214,-0.053530,-0.118277,-0.074823,-0.102637,-0.066636,-0.129014,0.016024,-0.050958,-0.101564,0.207655,-0.116863,-0.108798,-0.131065,-0.084988,-0.066779,-0.093403,-0.129859,-0.053791,0.004140,-0.051470,0.141465,-0.130556,-0.115734,-0.124357,-0.109641,0.041874,-0.134306,0.284571,-0.099015,-0.003217,-0.113088,-0.116721,-0.081502,-0.142544,-0.060783,-0.077973,-0.138675,-0.134785,-0.087301,-0.119694,0.048126,0.725502,-0.036445,-0.120802,-0.100612,-0.033217,-0.031502,-0.067413,-0.078095,0.485416,-0.113209,0.193753,0.126609,-0.125333,-0.011117,-0.053230,0.063069,-0.085200,0.248008,-0.113879,0.969142,0.020769,0.068446,0.239143,-0.088375,0.680371,-0.107285,-0.086793,0.050469,-0.114316,-0.122811,0.733466,-0.088834,0.273690,-0.114222,-0.116357,-0.060485,-0.080096,-0.114231,-0.073862,-0.098380,-0.112004,-0.122159,0.455277,-0.047073,-0.132610,-0.127592,-0.069297,0.386599,-0.114666,0.312522,0.101726
4,-0.098286,-0.099954,-0.073080,0.293693,0.327439,0.424508,0.117842,-0.105193,-0.083176,-0.068864,-0.113596,-0.103659,-0.095560,-0.098380,0.289770,-0.070396,0.368481,0.002567,0.077985,-0.103401,-0.072822,0.125669,-0.108211,-0.097858,0.873365,-0.102820,-0.112430,-0.101254,-0.118073,-0.091733,-0.099887,-0.093579,-0.095260,-0.098847,-0.094704,-0.104830,-0.088485,1.035236,-0.090230,-0.083635,-0.105453,-0.095568,-0.109029,-0.090898,-0.112438,-0.095990,-0.077846,-0.080340,0.385987,-0.105133,-0.090610,-0.102524,-0.080226,-0.111710,-0.022006,-0.024104,-0.102407,0.310267,-0.103706,-0.097704,-0.109111,-0.073828,-0.084230,-0.091237,-0.097104,-0.090127,-0.077063,-0.077250,0.059001,-0.101473,-0.088599,-0.093482,-0.102430,0.404472,-0.103217,0.879171,-0.100458,-0.044119,-0.103167,-0.109070,0.020951,-0.111249,-0.067570,-0.090442,-0.109529,-0.108943,-0.086852,-0.103018,-0.058942,-0.104964,-0.078793,-0.101075,-0.082314,-0.084838,0.104553,-0.101992,-0.090762,0.475220,-0.095363,0.439697,0.183189,-0.110480,0.095128,-0.084486,0.038840,-0.091652,0.352793,-0.079545,0.764845,-0.074370,-0.059496,0.128804,-0.092744,-0.111269,-0.103887,0.028374,-0.010592,-0.092135,-0.105753,0.483382,-0.094944,0.286158,-0.101178,-0.104258,-0.088571,-0.088850,-0.100128,-0.078800,-0.092968,-0.103294,-0.102136,0.284629,-0.084118,-0.102261,-0.112081,-0.079703,0.077863,-0.093023,0.055629,0.045722
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
48198,-0.016057,-0.107369,-0.078000,0.223535,0.230541,0.471154,0.740946,-0.114454,-0.088883,-0.078458,-0.118445,-0.109615,-0.099668,-0.103614,0.533335,-0.093516,0.431017,0.300968,0.044993,-0.107427,-0.092964,0.076197,-0.114851,-0.082166,0.715034,-0.115262,-0.116341,-0.106765,-0.118788,-0.100762,-0.107477,-0.102802,-0.103562,-0.101791,-0.103151,-0.109781,-0.090268,0.855044,-0.091621,-0.101769,-0.113936,-0.106684,-0.115684,-0.091435,-0.115706,-0.101562,-0.082364,-0.095820,0.549308,-0.112390,-0.098033,-0.107053,-0.089752,-0.117442,-0.039271,0.055210,-0.108676,0.262131,-0.110421,-0.104865,-0.102085,-0.091681,-0.088784,-0.098899,-0.115326,-0.098828,-0.082589,-0.082951,0.076279,-0.111164,-0.099849,-0.107379,-0.109046,0.833310,-0.108794,0.509993,-0.103918,0.117955,-0.107636,-0.108558,0.039579,-0.116049,-0.053659,-0.097434,-0.113071,-0.109205,-0.100199,-0.107072,-0.084110,-0.101162,-0.082161,-0.105851,-0.094397,-0.088448,0.050215,-0.102431,-0.099713,0.454839,-0.104848,0.262089,0.226668,-0.117845,0.250594,-0.096402,-0.029743,-0.098394,0.337215,-0.099089,0.314727,-0.078626,-0.063640,0.073811,-0.086328,-0.085208,-0.112829,0.107655,-0.084881,-0.100405,-0.109496,0.536659,-0.103923,0.127736,-0.109245,-0.112429,-0.094747,-0.092203,-0.108622,-0.075437,-0.098587,-0.096890,-0.110126,0.421573,-0.093679,-0.110951,-0.117332,-0.088202,0.029138,-0.020140,-0.005769,0.105734
48199,-0.061787,-0.099345,-0.079985,0.165232,0.196600,0.482482,0.507181,-0.118028,-0.091818,-0.062635,-0.120118,-0.105835,-0.095377,-0.101149,0.428801,-0.083683,0.413566,0.389086,0.066568,-0.108848,-0.103644,0.079706,-0.113725,-0.058829,0.653562,-0.110858,-0.116593,-0.107356,-0.121701,-0.100035,-0.102499,-0.102922,-0.102696,-0.104251,-0.099643,-0.110396,-0.093629,0.913962,-0.090032,-0.102661,-0.114622,-0.101328,-0.109310,-0.072710,-0.115118,-0.103532,-0.079873,-0.088293,0.512211,-0.112338,-0.098494,-0.110462,-0.100058,-0.117092,-0.041173,0.023582,-0.102752,0.191229,-0.111168,-0.104915,-0.110314,-0.091319,-0.087720,-0.094337,-0.113691,-0.093773,-0.084424,-0.093221,0.053221,-0.113799,-0.105581,-0.110715,-0.104822,1.113742,-0.111701,0.482917,-0.104887,0.107861,-0.108002,-0.111247,0.043480,-0.114443,-0.066124,-0.095516,-0.111691,-0.106799,-0.097110,-0.102828,-0.063411,-0.093945,-0.084202,-0.108762,-0.101275,-0.090097,0.078232,-0.097348,-0.099030,0.389888,-0.103743,0.324171,0.168776,-0.120113,0.352780,-0.093604,-0.028148,-0.091856,0.380912,-0.089963,0.321652,-0.054262,-0.062731,0.046118,-0.082305,-0.056557,-0.103437,0.089713,-0.065964,-0.100810,-0.108442,0.396370,-0.094585,0.299330,-0.105385,-0.110923,-0.092127,-0.094027,-0.107301,-0.077611,-0.098453,-0.099868,-0.112403,0.334659,-0.088464,-0.111661,-0.116378,-0.087172,0.112799,-0.044621,0.015924,0.092049
48200,-0.085076,-0.019765,-0.040139,0.082385,0.178649,0.263367,-0.078611,0.071203,-0.063911,-0.069864,-0.083776,-0.077414,-0.075413,-0.075839,-0.033571,-0.073414,0.257971,-0.079189,-0.061539,-0.076336,-0.070014,0.390727,-0.065538,-0.059764,-0.016561,-0.076601,-0.081644,-0.079352,-0.080485,-0.075533,-0.071951,-0.073284,-0.072428,-0.076604,-0.069523,-0.081019,-0.064263,0.746180,-0.056343,-0.070689,-0.077235,-0.077561,-0.075413,0.140361,-0.085641,-0.070548,-0.064903,0.045767,-0.101629,-0.078929,-0.072299,-0.074348,-0.065689,-0.082947,-0.043394,-0.024674,-0.071774,-0.056329,-0.076296,-0.073641,-0.080424,-0.048682,-0.066995,-0.070810,-0.077974,-0.074897,-0.060521,-0.063480,0.051425,-0.083099,-0.077015,-0.081671,-0.074626,-0.011034,-0.078993,0.039805,-0.069220,-0.076706,-0.072964,0.145734,-0.064544,0.140592,-0.059370,-0.063342,-0.079359,-0.081457,-0.067831,-0.075096,0.043017,1.179554,-0.063517,-0.077332,-0.067052,-0.047125,-0.027744,-0.052680,0.057587,0.327698,-0.072898,0.466532,0.093638,-0.081686,-0.075682,-0.068885,-0.044305,-0.070817,0.159998,-0.078361,0.000312,-0.061928,-0.075957,0.106712,-0.077459,1.399233,-0.078626,-0.072083,-0.030713,-0.080510,-0.076025,0.024202,-0.066634,0.977552,-0.074948,-0.076043,-0.065239,-0.067072,-0.078083,-0.072537,-0.073559,-0.016842,-0.084288,0.075517,-0.068157,-0.082042,-0.075769,-0.035932,0.073994,-0.071730,-0.055548,0.074910
48201,-0.129128,-0.097133,-0.064378,0.162504,0.062285,0.240771,-0.116835,-0.087210,-0.052549,-0.122759,-0.147196,-0.094668,-0.111958,-0.072857,0.228428,-0.128448,0.152343,-0.079240,-0.124319,-0.094852,-0.056010,0.237674,-0.138501,-0.093899,0.322706,-0.115159,-0.138109,-0.150201,-0.104187,-0.104026,-0.109871,-0.079260,-0.112356,-0.112003,-0.110327,-0.127146,-0.017746,0.389237,-0.046356,-0.084448,-0.112604,-0.068824,-0.104885,0.212656,-0.152767,-0.113220,0.018889,-0.138697,-0.079695,-0.120572,-0.048413,-0.101269,-0.118785,-0.136030,0.026797,-0.033499,-0.096910,0.092198,-0.118307,-0.105043,-0.137010,-0.069529,-0.039909,-0.074379,-0.137791,-0.032580,0.045669,-0.046871,0.210007,-0.141396,-0.114719,-0.128153,-0.103959,0.298170,-0.143355,0.436532,-0.094557,0.259603,-0.113748,-0.127084,-0.126154,-0.177546,-0.090571,-0.067359,-0.145295,-0.140265,-0.075745,-0.120720,0.103899,0.651415,-0.007870,-0.121442,-0.096351,-0.026074,-0.095603,-0.086581,-0.050615,0.532529,-0.112513,0.111921,0.001366,-0.130919,0.052534,-0.034589,0.211631,-0.080331,0.338320,-0.113341,1.001636,0.099866,0.150150,0.050946,-0.063190,0.886158,-0.104874,-0.090285,0.043555,-0.120122,-0.128133,0.803109,-0.086178,0.036741,-0.109458,-0.115373,-0.039178,-0.054202,-0.111994,-0.058026,-0.090350,-0.102356,-0.124435,0.485146,-0.011179,-0.137868,-0.138146,-0.082405,0.565141,-0.111353,0.437580,0.196572


In [23]:
#cite_sub.to_csv('../../../../../summary/output/submit/cite_submit.csv')

In [32]:
# model #16: cite_mlp_corr_svd_128_flg_donor_val_30
pred_16 = np.zeros([48203, 140])

i = 'cite_mlp_corr_svd_128_flg_donor_val_30'
        
test_file = model_feat_dict[i][0]
test_weight = model_feat_dict[i][1]
X_test = pd.read_pickle(cite_feature_path  + test_file)
# columns = X_test.columns   # for SHAP
X_test = np.array(X_test)
feature_dims = X_test.shape[1]

test_ds = CiteDataset_test(X_test)
test_dataloader = DataLoader(test_ds, batch_size=128, pin_memory=True, 
                              shuffle=False, drop_last=False, num_workers=4)

if 'mish' in i:
    model = CiteModel_mish(feature_dims)
else:
    model = CiteModel(feature_dims)
    
model = model.to(device)
model.load_state_dict(torch.load(f'{cite_mlp_path}/{i}'))
#model.load_state_dict(torch.load(f'/dss/dsshome1/02/di93zoj/valentina/open-problems-multimodal-3rd-solution/code/4.model/train/cite/cite_mlp_corr_svd_128_flg_donor_val_28', 
                                #map_location='cuda:0'))  # cuda:0

result = test_loop(model, test_dataloader).astype(np.float32)
# result = std(result) * test_weight / weight_sum  # double check 
pred_16 += result

torch.cuda.empty_cache()
        
pd.DataFrame(pred_16)   # double check train_cite_targets.h5  -> omnipath -> then maybe shap
# TODO SHAP on this

RuntimeError: Attempting to deserialize object on CUDA device 0 but torch.cuda.device_count() is 0. Please use torch.load with map_location to map your storages to an existing device.

In [None]:
pd.read_pickle(cite_feature_path  + test_file)

In [None]:
import shap

# X_train for model #16: 'X_svd_128.pickle'
X_train = pd.read_pickle(cite_feature_path  + 'X_svd_128.pickle')
X_train = np.array(X_train)

# X_train = torch.tensor(X_train.values, dtype=torch.float32)
# X_train = X_train.to('cuda')

# X_test for model #16:
# X_test = torch.from_numpy(X_test)
# X_test = X_test.to('cuda')

# Explainer, KernelExplainer, don't rely on differentiable model
# shap beeswarm -> screenshot

In [None]:
# compute SHAP values
# explainer = shap.DeepExplainer(model, X_train)     # Warning: unrecognized nn.Module: LayerNorm
# shap_values = explainer.shap_values(X_test)
# shap.summary_plot(shap_values[0], plot_type = 'bar', feature_names = columns)

In [None]:
X_train.dtype

In [None]:
# https://stackoverflow.com/questions/70510341/shap-values-with-pytorch-kernelexplainer-vs-deepexplainer
# Get features
train_features_df = ... # pandas dataframe
test_features_df = ... # pandas dataframe

# Define function to wrap model to transform data to tensor
f = lambda x: model( torch.from_numpy(x) ).detach().numpy()   # model_list[0]

# Convert my pandas dataframe to numpy
# data = test_features_df.to_numpy(dtype=np.float32)
data = X_train

# The explainer doesn't like tensors, hence the f function
explainer = shap.KernelExplainer(f, data)

# Get the shap values from my test data
shap_values = explainer.shap_values(data)

In [None]:
# model: no tensor
# X_train: np.array
# X_test: np.array

explainer = shap.KernelExplainer(model, X_train)    #, keep_index=True)
shap_values = explainer.shap_values(X_test)
shap.summary_plot(shap_values[0], plot_type = 'bar', feature_names = columns)

In [None]:
# X_tensor = torch.from_numpy(X_train).to('cuda')
# explainer = shap.Explainer(model, X_tensor)
# shap_values = explainer.explain_row(X_tensor, max_evals, main_effects, error_bounds, outputs, silent )  # X_test?
# shap_values

In [None]:
train_cite_targets = pd.read_hdf('/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93zoj/neurips_competition_data/train_cite_targets.h5')    # train_cite_targets.h5 - Surface protein levels for the same cells that have been dsb normalized.
print(train_cite_targets.shape)   # more rows; how to match rows -> metadata for cell_id
train_cite_targets  

In [None]:
metadata = pd.read_csv('/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93zoj/neurips_competition_data/metadata.csv')
metadata[metadata['cell_id'].isin(['45006fe3e4c8','d02759a80ba2','c016c6b0efa5','ba7f733a4f75','fbcf2443ffb2'])]

In [None]:
# filter dataset by IDs in list
metadata_filtered = metadata[metadata['cell_id'].isin(train_cite_targets.index.values)].set_index('cell_id')
metadata_filtered.head(3)

In [None]:
train_cite_targets = ad.AnnData(train_cite_targets, obs=metadata_filtered)
train_cite_targets

In [None]:
# run after preprocessing below
# cite_sub['cell_id']
pd.DataFrame(cite_sub)[0]

In [None]:
target_cell_ids = train_cite_targets.obs.index.tolist()
target_cell_ids[:10]  # 70988

print(set(pd.DataFrame(cite_sub)[0]).issubset(set(target_cell_ids)))


mask = pd.DataFrame(cite_sub)[0].isin(target_cell_ids)
result = pd.DataFrame(cite_sub).loc[mask]

result

In [None]:
set(test_sub_ids).issubset(set(target_cell_ids))

## Multi

In [None]:
mlp_model_name = [
    'multi_mlp_all_con_16',
    'multi_mlp_all_con_32', 
    'multi_mlp_all_binary_16',
    'multi_mlp_all_last_cluster',
    'multi_mlp_all_lsi_w2v_col_128_flg',
    'multi_mlp_all_lsi_w2v_128_flg',
    'multi_mlp_all_lsi_128_flg',
    'multi_mlp_all_lsi_w2v_col_64_flg',
    'multi_mlp_all_lsi_w2v_64_flg',
    'multi_mlp_all_lsi_64_flg',
    'multi_mlp_all_okapi_128_flg',
    'multi_mlp_all_okapi_64_flg',
    'multi_mlp_all_colmean_64_flg',
    'multi_mlp_corr_con_16_flg',
    'multi_mlp_corr_con_32_flg',
    'multi_mlp_corr_binary_16',
    'multi_mlp_corr_lsi_add_lc_svd_flg',
    
    'multi_mlp_corr_lsi_w2v_col_128_flg',
    'multi_mlp_corr_lsi_w2v_col_64_flg',
    'multi_mlp_corr_lsi_w2v_128_flg',
    'multi_mlp_corr_lsi_w2v_64_flg',
    
    'multi_mlp_corr_lsi_128_flg',
    'multi_mlp_corr_lsi_64_flg',
    
    'multi_mlp_corr_colmean_64_flg',
    'multi_mlp_corr_okapi_w2v_64_flg',
    'multi_mlp_corr_okapi_64_flg',
    
             ]

In [None]:
model_name_list = []

for i in mlp_model_name:
    for num, j in enumerate(os.listdir(multi_mlp_path)):
        if i in j:
            model_name_list.append(j)

print(len(model_name_list))
model_name_list

In [None]:
weight = [2.5, 2.5, 2.5, 1.2, 1.2, 1.2, 1, 
          1.5, 1.5, 2.5, 0.5, 0.5, 0.5, 
          2.5, 2.5, 1.8, 0.8, 1, 0.8, 1 ,0.8, 1, 0.3, 
          0.3, 0.3, 0.3, 0.2, 0.2, 0.2]
weight_sum = np.array(weight).sum()
weight_sum

model_feat_dict = {model_name_list[0]:['multi_test_con_16.pickle', 2.5],
                   model_name_list[1]:['multi_test_con_32.pickle', 2.5],
                   model_name_list[2]:['multi_test_binary_16.pickle', 2.5],
                   
                   model_name_list[3]:['multi_test_okapi_64_last_cluster.pickle', 1.2],
                   model_name_list[4]:['multi_test_lsi_w2v_col_128.pickle', 1.2],
                   model_name_list[5]:['multi_test_lsi_w2v_128.pickle', 1.2],
                   model_name_list[6]:['multi_test_okapi_lsi_128.pickle', 1],
                   
                   model_name_list[7]:['multi_test_lsi_w2v_col_64.pickle', 1.5],
                   model_name_list[8]:['multi_test_lsi_w2v_64.pickle', 1.5],
                   model_name_list[9]:['multi_test_okapi_lsi_64.pickle', 2.5],
                   
                   model_name_list[10]:['multi_test_okapi_feature_128.pickle', 0.5],
                   model_name_list[11]:['multi_test_okapi_feature_64.pickle', 0.5],
                   model_name_list[12]:['multi_test_okapi_w2v_col_64.pickle', 0.5],
                   
                   model_name_list[13]:['multi_test_con_16.pickle', 2.5],
                   model_name_list[14]:['multi_test_con_32.pickle', 2.5],
                   model_name_list[15]:['multi_test_binary_16.pickle', 1.8],
                   model_name_list[16]:['multi_test_lc_addsvd_64.pickle', 0.8],
                   
                   model_name_list[17]:['multi_test_lsi_w2v_col_128.pickle', 1],
                   model_name_list[18]:['multi_test_lsi_w2v_col_64.pickle', 0.8],
                   model_name_list[19]:['multi_test_lsi_w2v_128.pickle', 1],
                   model_name_list[20]:['multi_test_lsi_w2v_64.pickle', 0.8],
                   model_name_list[21]:['multi_test_okapi_lsi_128.pickle', 1],
                   model_name_list[22]:['multi_test_okapi_lsi_64.pickle', 0.3],
                   
                   model_name_list[23]:['multi_test_okapi_w2v_col_64.pickle', 0.3],
                   model_name_list[24]:['multi_test_okapi_w2v_64.pickle', 0.3],
                   model_name_list[25]:['multi_test_okapi_feature_64.pickle', 0.3],
                   
                   'lsi_128':['multi_test_okapi_lsi_128.pickle', 0.2],
                   'lsi_64':['multi_test_okapi_lsi_64.pickle', 0.2],
                   'lsi_w2v_col_64':['multi_test_lsi_w2v_col_64.pickle', 0.2],
                  }


### multi model

In [None]:
class MultiDataset(Dataset):
    
    def __init__(self, feature, target):
        
        self.feature = feature
        self.target = target
        
    def __len__(self):
        return len(self.feature)
    
    def __getitem__(self, index):
                
        d = {
            "X": self.feature[index],
            "y" : self.target[index],
        }
        return d

In [None]:
class MultiDataset_test(Dataset):
    
    def __init__(self, feature):
        self.feature = feature
        
    def __len__(self):
        return len(self.feature)
    
    def __getitem__(self, index):
                
        d = {
            "X": self.feature[index]
        }
        return d

In [None]:
def partial_correlation_score_torch_faster(y_true, y_pred):
    """Compute the correlation between each rows of the y_true and y_pred tensors.
    Compatible with backpropagation.
    """
    y_true_centered = y_true - torch.mean(y_true, dim=1)[:,None]
    y_pred_centered = y_pred - torch.mean(y_pred, dim=1)[:,None]
    cov_tp = torch.sum(y_true_centered*y_pred_centered, dim=1)/(y_true.shape[1]-1)
    var_t = torch.sum(y_true_centered**2, dim=1)/(y_true.shape[1]-1)
    var_p = torch.sum(y_pred_centered**2, dim=1)/(y_true.shape[1]-1)
    return cov_tp/torch.sqrt(var_t*var_p)

def correl_loss(pred, tgt):
    """Loss for directly optimizing the correlation.
    """
    return -torch.mean(partial_correlation_score_torch_faster(tgt, pred))


def correlation_score(y_true, y_pred):
    """Scores the predictions according to the competition rules. 
    
    It is assumed that the predictions are not constant.
    
    Returns the average of each sample's Pearson correlation coefficient"""
    if type(y_true) == pd.DataFrame: y_true = y_true.values
    if type(y_pred) == pd.DataFrame: y_pred = y_pred.values
    if y_true.shape != y_pred.shape: raise ValueError("Shapes are different.")
    corrsum = 0
    for i in range(len(y_true)):
        corrsum += np.corrcoef(y_true[i], y_pred[i])[1, 0]
    return corrsum / len(y_true)

In [None]:
class MultiModel(nn.Module):
    
    def __init__(self, feature_num):
        super(MultiModel, self).__init__()
        
        self.layer_seq_128 = nn.Sequential(nn.Linear(feature_num, 128),
                                           nn.LayerNorm(128),
                                           nn.ReLU(),
                                      )
        
        self.layer_seq_64 = nn.Sequential(nn.Linear(128, 64),
                                           nn.LayerNorm(64),
                                           nn.ReLU(),
                                      )
        
        self.layer_seq_32 = nn.Sequential(nn.Linear(64, 32),
                                   nn.LayerNorm(32),
                                   nn.ReLU(),
                              )
        
        self.layer_seq_8 = nn.Sequential(nn.Linear(32, 8),
                                         nn.LayerNorm(8),
                                         nn.ReLU(),
                                      )
        
        self.head = nn.Linear(128 + 64 + 32 + 8, target_num)
                   
    def forward(self, X, y=None):
        
        X_128 = self.layer_seq_128(X)
        X_64 = self.layer_seq_64(X_128)
        X_32 = self.layer_seq_32(X_64)
        X_8 = self.layer_seq_8(X_32)
        X = torch.cat([X_128, X_64, X_32, X_8], axis = 1)
        out = self.head(X)
        
        return out

In [None]:
def train_loop(model, optimizer, loader, epoch):
    
    losses, lrs = [], []
    model.train()
    optimizer.zero_grad()
    loss_fn = nn.MSELoss()
    
    with tqdm(total=len(loader),unit="batch") as pbar:
        pbar.set_description(f"Epoch{epoch}")
        
        for d in loader:
            X = d['X'].to(device).float()
            y = d['y'].to(device)
            
            logits = model(X)
            #loss = correl_loss(logits, y)
            loss = torch.sqrt(loss_fn(logits, y))
        
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            pbar.set_postfix({"loss":loss.item()})
            pbar.update(1)

    return model

In [None]:
def valid_loop(model, loader, y_val):
    
    model.eval()
    partial_correlation_scores = []
    oof_pred = []
    loss_fn = nn.MSELoss()
    
    for d in loader:
        with torch.no_grad():
            val_X = d['X'].to(device).float()
            val_y = d['y'].to(device)
            logits = model(val_X)
            #oof_pred.append(logits.detach().cpu().numpy())
            oof_pred.append(logits)
    
    y_val = torch.tensor(y_val).to(device)
    logits = torch.cat(oof_pred)
    #print(logits.shape, y_val.shape)
    loss = torch.sqrt(loss_fn(logits, y_val))
    logits = logits.detach().cpu().numpy()
    
    return logits, loss

In [None]:
def test_loop(model, loader):
    
    model.eval()
    predicts=[]

    for d in tqdm(loader):
        with torch.no_grad():
            X = d['X'].to(device).float()
            logits = model(X)
            predicts.append(logits.detach().cpu().numpy())
            
    return np.concatenate(predicts)

In [None]:
pred = np.zeros([55935, 23418])
svd = pickle.load(open(multi_target_path + 'multi_all_target_128.pkl', 'rb'))

for num, i in enumerate(model_feat_dict.keys()):
    
    print(i)
    
    if 'mlp' in i:
        
        test_file = model_feat_dict[i][0]
        test_weight = model_feat_dict[i][1]
        X_test = pd.read_pickle(multi_feature_path  + test_file)    
        X_test = np.array(X_test)
        feature_dims = X_test.shape[1]

        test_ds = MultiDataset_test(X_test)
        test_dataloader = DataLoader(test_ds, batch_size=128, pin_memory=True, 
                                     shuffle=False, drop_last=False, num_workers=4)
        
        if 'all' in i:
            target_num = 23418
        else:
            target_num = 128
        
        model = MultiModel(feature_dims)    
        model = model.to(device)
        model.load_state_dict(torch.load(f'{multi_mlp_path}/{i}'))
        
        result = test_loop(model, test_dataloader).astype(np.float32)
        
        if 'all' not in i:
            result = result@svd.components_
                
        result = result * test_weight / weight_sum
        pred += result

        torch.cuda.empty_cache()
        
    else:
        test_file = model_feat_dict[i][0]
        test_weight = model_feat_dict[i][1]
        X_test = pd.read_pickle(multi_feature_path  + test_file)
        
        cb_pred = np.zeros([55935, 128])
        
        for t in tqdm(range(128)): 
            cb_model_path = [j for j in os.listdir(multi_cb_path) if f'cb_{t}_{i}' in j][0]
            cb = pickle.load(open(multi_cb_path + cb_model_path, 'rb'))
            cb_pred[:,t] = cb.predict(X_test)
            
        cb_pred = cb_pred.astype(np.float32)
        cb_pred = cb_pred@svd.components_
        pred += cb_pred * test_weight / weight_sum
        
        #del cb_pred

In [None]:
multi_sub = pd.DataFrame(pred.round(6)).astype(np.float32)

In [None]:
del pred
gc.collect()

## Postprocess

In [None]:
preprocess_path = lrz_path + 'input/preprocess/'

#### first: fix cite output

In [None]:
test_sub_ids = np.load(preprocess_path + "test_cite_inputs_idxcol.npz", allow_pickle=True)
test_sub_ids = test_sub_ids["index"]
test_raw_ids = np.load(preprocess_path + "test_cite_raw_inputs_idxcol.npz", allow_pickle=True)
test_raw_ids = test_raw_ids["index"]
test_raw_ids

In [None]:
len(test_raw_ids)
cite_sub.shape

In [None]:
test_cite_df = pd.DataFrame(test_sub_ids, columns = ['cell_id'])
cite_sub['cell_id'] = test_raw_ids# .tolist()
test_cite_df = test_cite_df.merge(cite_sub, on = 'cell_id', how = 'left')
test_cite_df.fillna(0, inplace = True)
# test_cite_df.drop(['cell_id'], axis = 1, inplace = True)

cite_sub = test_cite_df.copy()


In [None]:
cite_sub

### preprocess

In [None]:
sub = pd.read_csv('/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93zoj/neurips_competition_data/sample_submission.csv')  
eval_ids = pd.read_csv('/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93zoj/neurips_competition_data/evaluation_ids.csv') 

cite_cols = pd.read_csv(preprocess_path + "cite_test_cols.csv") 
cite_index = pd.read_csv(preprocess_path + "cite_test_indexs.csv") 
cite_index.columns = ['cell_id']

# multi_cols = pd.read_csv(preprocess_path + "multi/multi_test_cols.csv") 
# multi_index = pd.read_csv(preprocess_path + "multi/multi_test_indexs.csv") 
# multi_index.columns = ['cell_id']

submission = pd.Series(name='target',index=pd.MultiIndex.from_frame(eval_ids), dtype=np.float32)


In [None]:
submission

### multi

In [None]:
multi_sub = np.array(multi_sub)

In [None]:
cell_dict = dict((k,v) for v,k in enumerate(np.array(multi_index['cell_id'])))
assert len(cell_dict)  == len(multi_index['cell_id'])

gene_dict = dict((k,v) for v,k in enumerate(np.array(multi_cols['gene_id']))) 
assert len(gene_dict)  == len(multi_cols['gene_id'])

eval_ids_cell_num = eval_ids.cell_id.apply(lambda x:cell_dict.get(x, -1))
eval_ids_gene_num = eval_ids.gene_id.apply(lambda x:gene_dict.get(x, -1))

valid_multi_rows = (eval_ids_gene_num !=-1) & (eval_ids_cell_num!=-1)
submission.iloc[valid_multi_rows] = multi_sub[eval_ids_cell_num[valid_multi_rows].to_numpy(),
                                                 eval_ids_gene_num[valid_multi_rows].to_numpy()]

### cite

In [None]:
cite_sub = np.array(cite_sub)
cite_sub

In [None]:
cell_dict = dict((k,v) for v,k in enumerate(np.array(cite_index['cell_id'])))
assert len(cell_dict)  == len(cite_index['cell_id'])

gene_dict = dict((k,v) for v,k in enumerate(np.array(cite_cols['gene_id']))) 
assert len(gene_dict)  == len(cite_cols['gene_id'])

eval_ids_cell_num = eval_ids.cell_id.apply(lambda x:cell_dict.get(x, -1))
eval_ids_gene_num = eval_ids.gene_id.apply(lambda x:gene_dict.get(x, -1))

valid_multi_rows = (eval_ids_gene_num !=-1) & (eval_ids_cell_num!=-1)

In [None]:
submission.iloc[valid_multi_rows] = cite_sub[eval_ids_cell_num[valid_multi_rows].to_numpy(),
                                                 eval_ids_gene_num[valid_multi_rows].to_numpy()]

### make submission

In [None]:
# submission = submission.round(6)
submission = pd.DataFrame(submission, columns = ['target'])
submission = submission.reset_index()

In [None]:
submission[['row_id', 'target']]#.dropna()        #.to_csv(output_path + 'submission.csv', index = False)

In [None]:
#!kaggle competitions submit -c open-problems-multimodal -f $sub_name_csv -m $message