In [1]:
!pip install -qU datasets transformers sentence-transformers git+https://github.com/naver/splade.git
!pip install einops

Collecting einops
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 42.2/42.2 kB 1.9 MB/s eta 0:00:00
Installing collected packages: einops
Successfully installed einops-0.6.1


In [2]:
import torch
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
from splade.models.transformer_rep import Splade
from transformers import AutoTokenizer

sparse_model_id = 'naver/splade-cocondenser-selfdistil'

# Build the SPLADE encoder (agg='max' aggregation) and place it on `device`.
# Module.to() returns the module itself, so placement can be chained onto the
# constructor. NOTE(review): relies on Splade being a torch.nn.Module.
sparse_model = Splade(sparse_model_id, agg='max').to(device)
# Inference mode; left as the last expression so the cell still displays the model.
sparse_model.eval()

Downloading:   0%|          | 0.00/670 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/418M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [4]:
# Tokenizer matching the SPLADE checkpoint loaded above (same model id).
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=sparse_model_id)

In [5]:
def _embed_text(text_data):
    """Encode one string with the SPLADE model and return its 'd_rep' output as a NumPy array.

    Uses the notebook-level ``tokenizer``, ``sparse_model`` and ``device``.
    The result is ``.squeeze()``-d, which drops the size-1 batch dimension for a
    single string (presumably leaving one vocabulary-sized vector — confirm).
    """
    input_ids = tokenizer(
        text_data, return_tensors='pt',
        padding=True, truncation=True
    )

    # Inference only: under no_grad no autograd graph is recorded, so the
    # output can be converted to NumPy without an explicit .detach().
    with torch.no_grad():
        text_embed = sparse_model(
            d_kwargs=input_ids.to(device)
        )['d_rep'].squeeze()
    return text_embed.cpu().numpy()


def process_row1(row, column='Title'):
    """Embed ``row[column]`` (the ICD title by default); meant for ``DataFrame.apply(..., axis=1)``."""
    return _embed_text(row[column])


def process_row2(row, column='assoc_cond'):
    """Embed ``row[column]`` (the associated-condition text by default); meant for ``DataFrame.apply(..., axis=1)``."""
    return _embed_text(row[column])

In [6]:
from scipy.spatial import distance
import numpy as np

def score(a, b):
    """Return the cosine *distance* between vectors ``a`` and ``b``.

    Lower means more similar (0 for parallel vectors, 1 for orthogonal, 2 for
    opposite), which is why callers look for the minimum.

    ``np.asarray`` reuses inputs that are already ndarrays instead of copying
    them; the previous ``np.array`` copied both vectors on every call, and this
    function is called once per candidate per query.
    """
    return distance.cosine(np.asarray(a), np.asarray(b))



In [7]:
def find_min_match(row, candidates=None):
    """Return the code of the candidate whose embedding is closest to ``row['embed']``.

    Parameters
    ----------
    row : pandas.Series
        A row carrying an ``'embed'`` vector.
    candidates : pandas.DataFrame, optional
        Table with ``'embed'`` and ``'Code'`` columns to search. Defaults to the
        notebook-level ``icd`` table.

    Returns
    -------
    The ``'Code'`` of the candidate with the smallest ``score`` (cosine
    distance), or ``None`` when there are no candidates or every distance is NaN.

    NOTE(perf): each candidate is scored one pair at a time in Python; the
    recorded run took ~6 s per query row. Stacking the candidate embeddings into
    one matrix once and computing all distances in a single vectorized call would
    be much faster.
    """
    if candidates is None:
        candidates = icd
    query = row['embed']
    distances = candidates['embed'].apply(lambda candidate: score(query, candidate))
    # idxmin skips individual NaNs, but on an empty or all-NaN Series it returns
    # NaN or raises (depending on the pandas version), which would abort a long
    # progress_apply run; report "no match" instead.
    if distances.isna().all():
        return None
    return candidates.at[distances.idxmin(), 'Code']

In [8]:
import pandas as pd

ICD_TABULATION_PATH = 'data/diseases/ICD11/simpletabulation.xlsx'
ASSOC_COND_PATH = 'parser/temp/DBID_AssocCondn.csv'

# ICD-11 simple tabulation, plus the table pairing DrugBank ids with associated conditions.
icd = pd.read_excel(ICD_TABULATION_PATH)
assoc_cond = pd.read_csv(ASSOC_COND_PATH)

In [9]:
# Keep only rows whose ClassKind is 'category'. .copy() gives the filtered frame
# its own data, so the column assignment below does not trigger
# SettingWithCopyWarning.
icd = icd[icd['ClassKind'] == 'category'].copy()
# Strip only leading dash markers from the titles. The former
# str.replace('-', '') also removed hyphens inside names (e.g. "Non-Hodgkin"
# became "NonHodgkin"), which changes how a title tokenizes.
# NOTE(review): assumes the dashes being targeted are the "- - " depth prefixes
# of the simple tabulation — confirm against the spreadsheet.
icd['Title'] = icd['Title'].str.lstrip('- ')
# Keep the reset index; before, reset_index was only displayed, never assigned.
icd = icd[['Code', 'Title']].reset_index(drop=True)
icd

Unnamed: 0,Code,Title
0,1A00,Cholera
1,1A01,Intestinal infection due to other Vibrio
2,1A02,Intestinal infections due to Shigella
3,1A03,Intestinal infections due to Escherichia coli
4,1A03.0,Enteropathogenic Escherichia coli infection
...,...,...
34074,XD36Q1,"Infusion Pumps, Syringe"
34075,XD1N14,"Infusion Pumps, Syringe, Nuclear Magnetic ..."
34076,XD80Z7,Medical/medicinal gas systems and relative ...
34077,XD4U38,General purpose electrocardiographs


In [10]:
# Embed every ICD title. assign() returns a new frame instead of writing a
# column into one pandas may consider a slice of another.
icd = icd.assign(embed=icd.apply(process_row1, axis=1))

In [19]:
# Embed and match each distinct condition text only once; the results are mapped
# back onto every original (drug, condition) pair later by merging on 'assoc_cond'.
assoc_cond = (
    assoc_cond
    .drop_duplicates(subset='assoc_cond', keep='first')
    .assign(ICDCode='')        # placeholder column, filled in by the matching step
    .reset_index(drop=True)    # assigned now; before, it was only displayed
)
assoc_cond

Unnamed: 0,drugbank-id,assoc_cond,ICDCode
0,DB01598,Lower respiratory tract infection bacterial,
1,DB00537,Lower respiratory tract infection caused by En...,
2,DB00537,Lower respiratory tract infection caused by Es...,
3,DB00537,Lower respiratory tract infection caused by Ha...,
4,DB00537,Lower respiratory tract infection caused by Ha...,
...,...,...,...
2661,DB00618,Yaws,
2662,DB10805,Yellow Fever,
2663,DB02659,Zellweger Spectrum Disorder,
2664,DB01593,Zinc Deficiency,


In [20]:
# Embed every distinct condition text, producing a new frame rather than mutating in place.
assoc_cond = assoc_cond.assign(embed=assoc_cond.apply(process_row2, axis=1))

In [21]:
# tqdm.auto uses the Jupyter widget progress bar when the notebook front-end
# supports it and the plain text bar otherwise. tqdm.pandas() registers
# DataFrame.progress_apply, which the matching step uses.
from tqdm.auto import tqdm
tqdm.pandas()

In [None]:
# Nearest ICD code for every distinct condition. Expensive: the recorded run
# took ~6 s per row, roughly 4.5 h for all rows.
assoc_cond = assoc_cond.assign(ICDCode=assoc_cond.progress_apply(find_min_match, axis=1))

  9%|▉         | 244/2666 [24:37<4:06:34,  6.11s/it]

In [None]:
# Re-read the untouched association table (assoc_cond was de-duplicated), so
# every original drug/condition pair can receive its ICD code.
assoc_org = pd.read_csv(filepath_or_buffer='parser/temp/DBID_AssocCondn.csv')

In [None]:
# Attach to each original (drug, condition) pair the code found for its
# condition text. Only the key and result columns are merged, so no
# drugbank-id_x/_y suffixes arise and the embeddings stay behind.
# validate='many_to_one' fails loudly if the lookup side is not unique per
# condition text, which would otherwise silently duplicate rows.
merged_df = assoc_org.merge(
    assoc_cond[['assoc_cond', 'ICDCode']],
    on='assoc_cond',
    how='left',
    validate='many_to_one',
)

In [None]:
# Write the DrugBank-id -> ICD-11 code table; the frame's row index is not exported.
merged_df.to_excel(excel_writer='data/diseases/DB_to_ICD11/DB_to_ICD.xlsx', index=False)