# 1. Data Exploration

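This notebook explores the PrimeKG and PINNACLE datasets: node and edge statistics of the processed PrimeKG tables, the cell-type-specific PPI networks and protein embeddings from PINNACLE, and the drug-disease association labels.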

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
sys.path.insert(0, '..')

## 1.1 PrimeKG Node Statistics

In [None]:
# Load the processed node table; each row is one node.
node_mapping = pd.read_csv('../data/processed/node_mapping.csv')
print(f'Total nodes: {node_mapping.shape[0]}')

# Number of nodes per value of the 'node_type' column, most frequent first.
print('\nNode type distribution:')
node_type_counts = node_mapping['node_type'].value_counts()
print(node_type_counts)

In [None]:
# Bar chart of how many nodes fall into each node type.
fig, ax = plt.subplots(figsize=(8, 5))
type_counts = node_mapping['node_type'].value_counts()
type_counts.plot(kind='bar', ax=ax)
ax.set_title('Node Type Distribution')
ax.set_ylabel('Count')
# Tilt the category labels on the x axis.
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

## 1.2 PrimeKG Edge Statistics

In [None]:
# Processed edge tables to summarise.
edge_files = [
    'edges_drug_gene.csv',
    'edges_gene_disease.csv',
    'edges_drug_disease_gold.csv',
    'edges_ppi_general.csv',
]

# Print the row count of every table that is present on disk;
# files that do not exist are skipped without a message.
for f in edge_files:
    path = f'../data/processed/{f}'
    if not os.path.exists(path):
        continue
    df = pd.read_csv(path)
    print(f'{f}: {df.shape[0]} edges')

## 1.3 PINNACLE Cell Types

In [None]:
# Directory of PPI edge lists; every '*.txt' file in it is treated as one cell type.
ppi_dir = '../data/raw/pinnacle/networks/ppi_edgelists'

# Always bind cell_types so that cells executed after this one do not fail
# with a NameError when the raw PINNACLE data is not available.
cell_types = []
if os.path.isdir(ppi_dir):
    # os.listdir returns entries in arbitrary order; sorting makes the sample
    # printed below (and any later slice of cell_types) reproducible.
    cell_types = sorted(f for f in os.listdir(ppi_dir) if f.endswith('.txt'))
    print(f'Available cell types: {len(cell_types)}')
    print('\nSample cell types:')
    for ct in cell_types[:10]:
        print(f'  - {ct}')
else:
    print(f'PPI edge-list directory not found: {ppi_dir}')

In [None]:
# Edge count per cell type, limited to the first 20 edge-list files.
edge_counts = {}
# Guard on the directory (as the other PINNACLE cells do) so a missing
# dataset yields a message instead of a NameError on cell_types.
if os.path.isdir(ppi_dir):
    for ct in cell_types[:20]:
        with open(os.path.join(ppi_dir, ct)) as edge_file:
            # Every line counts as one edge, so header or blank lines
            # (if the files contain any) are included in the total.
            # splitext drops only the trailing extension from the label.
            edge_counts[os.path.splitext(ct)[0]] = sum(1 for _ in edge_file)

if edge_counts:
    plt.figure(figsize=(12, 5))
    # Materialise the dict views as lists before handing them to matplotlib.
    plt.bar(list(edge_counts), list(edge_counts.values()))
    plt.xticks(rotation=90)
    plt.ylabel('Number of edges')
    plt.title('PPI Edges per Cell Type')
    plt.tight_layout()
    plt.show()
else:
    print('No PPI edge lists found; skipping plot.')

## 1.4 PINNACLE Embeddings

In [None]:
import torch

# Serialized PINNACLE protein embeddings; the cell is a no-op if the file is absent.
embed_path = '../data/raw/pinnacle/pinnacle_embeds/pinnacle_protein_embed.pth'
if os.path.exists(embed_path):
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
    # objects, which can execute code. Only load this file from a trusted source.
    embeds = torch.load(embed_path, weights_only=False)
    n_embeds = len(embeds)
    print(f'Number of cell type embeddings: {n_embeds}')

    # Inspect one entry as a sanity check; skip when the container is empty
    # instead of raising on the first-key lookup.
    if n_embeds:
        first_key = next(iter(embeds))
        print(f'Embedding shape for cell {first_key}: {embeds[first_key].shape}')

## 1.5 Drug-Disease Label Analysis

In [None]:
# Load the drug-disease edge table; each row is one association.
dd_df = pd.read_csv('../data/processed/edges_drug_disease_gold.csv')
print(f'Total drug-disease associations: {dd_df.shape[0]}')

# Distinct values of 'x_id' are reported as drugs, 'y_id' as diseases.
n_drugs = dd_df['x_id'].nunique()
print(f'Unique drugs: {n_drugs}')
n_diseases = dd_df['y_id'].nunique()
print(f'Unique diseases: {n_diseases}')

In [None]:
# Rank diseases by the number of association rows recorded under each 'y_name'.
rows_per_disease = dd_df.groupby('y_name').size()
disease_counts = rows_per_disease.sort_values(ascending=False)
print('Top 10 diseases by drug associations:')
print(disease_counts.head(10))