In [20]:
import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression, Lasso, Ridge, ElasticNet
import lightgbm as lgb
from tqdm import tqdm

from sklearn.neighbors import NearestNeighbors
from sklearn.linear_model import LinearRegression, Ridge, Lasso, ElasticNet
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

### Dataset Setup

1. Get non-response indexes.
2. Load Xenium Data.
3. Strip location data (OPTIONAL)
4. Get neighborhood ligand/receptor mean summary.
5. Append it to each observation in the data.

In [21]:
NON_RESPONSE_FILE = "../spatial/non_response_blank_removed_xenium.txt"
with open(NON_RESPONSE_FILE, "r", encoding="utf-8") as f:
    non_response_genes = f.read().split(',')

non_response_genes = [int(x) for x in non_response_genes]

In [22]:
xenium_df = pd.read_csv("../data/raw/xenium.csv", index_col="cell_id")

In [23]:
locations = xenium_df.iloc[:, -4:-1] 
location_names = ["x_location", "y_location", "z_location", "qv"]

xenium_df = xenium_df.iloc[:, :-4] 

non_response_gene_names = xenium_df.columns[non_response_genes]
response_gene_names = xenium_df.columns[~xenium_df.columns.isin(non_response_gene_names | location_names)]
print(non_response_gene_names, response_gene_names)

xenium_df_inputs = xenium_df.iloc[:, non_response_genes] 

xenium_df_outputs = xenium_df.iloc[:, xenium_df.columns.isin(response_gene_names)]

Index(['Acvrl1', 'Adamts2', 'Adgrl4', 'Angpt1', 'Ano1', 'Aqp4', 'Bdnf',
       'Cbln1', 'Cbln4', 'Ccn2', 'Cd44', 'Cd53', 'Cd68', 'Cd93', 'Cdh13',
       'Cdh4', 'Cdh6', 'Chat', 'Chrm2', 'Cldn5', 'Cntn6', 'Cntnap4', 'Col19a1',
       'Col1a1', 'Col6a1', 'Cort', 'Crh', 'Cspg4', 'Cyp1b1', 'Dcn', 'Dkk3',
       'Dner', 'Dpyd', 'Epha4', 'Fn1', 'Gad1', 'Gad2', 'Gfra2', 'Gpr17',
       'Grik3', 'Hapln1', 'Hat1', 'Htr1f', 'Igf1', 'Igf2', 'Igfbp4', 'Inpp4b',
       'Kctd8', 'Kdr', 'Mapk4', 'Neto2', 'Npnt', 'Npy2r', 'Nr2f2', 'Nrn1',
       'Nrp2', 'Nts', 'Ntsr2', 'Nxph3', 'Opn3', 'Paqr5', 'Pcsk5', 'Pde11a',
       'Pde7b', 'Pdgfra', 'Pdyn', 'Pecam1', 'Penk', 'Pglyrp1', 'Pip5k1b',
       'Plch1', 'Prph', 'Pthlh', 'Rbp4', 'Ror1', 'Rspo1', 'Rspo2', 'Rxfp1',
       'Sdk2', 'Sema3a', 'Sema3d', 'Sema3e', 'Sema5b', 'Sema6a', 'Shisa6',
       'Slc13a4', 'Slc17a6', 'Slc39a12', 'Slc6a3', 'Slit2', 'Sorcs3', 'Spp1',
       'Sst', 'Tacr1', 'Th', 'Trbc2', 'Trem2', 'Trpc4', 'Vip', 'Vwc2l'],
      dtype='object

  response_gene_names = xenium_df.columns[~xenium_df.columns.isin(non_response_gene_names | location_names)]


In [24]:
# log1p Transforms
xenium_df_inputs = np.log1p(xenium_df_inputs)
xenium_df_outputs = np.log1p(xenium_df_outputs)

In [25]:
locations

Unnamed: 0_level_0,x_location,y_location,z_location
cell_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1,1557.532239,2528.022437,14.000948
2,1560.669312,2543.632678,14.789414
3,1570.462885,2530.810461,15.395041
4,1573.927734,2546.454529,14.478160
5,1581.344379,2557.024951,14.901122
...,...,...,...
162029,8310.558740,4345.094580,18.296561
162030,8316.195801,4321.954077,19.010482
162031,8323.133594,4271.474121,18.488743
162032,8327.415137,4348.097388,18.421222


In [26]:
(xenium_df_inputs == 0).mean().mean()

0.6890738306394375

### Modeling

1. Linear
2. Ridge
3. Lasso
4. ElasticNet
5. LightGBM

In [27]:
# Convert the input DataFrame to a NumPy array for faster slicing.
data = xenium_df_inputs.values
N, P = data.shape

# Define the radius values for the r-ball.
r_values = range(0, 41, 5)
mse_performances = {}
training_mse = {}
testing_mse = {}

# Iterate over radius values with a progress bar
for r in tqdm(r_values, desc="Radius values"):
    nn = NearestNeighbors(radius=r)
    nn.fit(locations)
    # Get neighbor indices for each row.
    indices = nn.radius_neighbors(locations, return_distance=False)
    
    # Compute neighbor means using a list comprehension.
    neighbor_means = np.array([
        data[inds].mean(axis=0) if len(inds) > 0 
        else np.full(P, np.nan) 
        for inds in indices
    ])

    if np.isnan(neighbor_means).any():  # Check if there are any NaN values in neighbor_means
        raise ValueError("neighbor_means contains NaN values.")
    
    # Combine original data and neighbor means horizontally.
    combined = np.hstack([data, neighbor_means])
    
    # Create column names for both halves.
    original_columns = xenium_df_inputs.columns
    mean_columns = [f"{col}_mean" for col in original_columns]
    combined_columns = list(original_columns) + mean_columns
    
    combined_df = pd.DataFrame(combined, columns=combined_columns)
    
    for i, response_gene in enumerate(response_gene_names, start=1):
        print(f"Processing response gene: {i}/{len(response_gene_names)}")
        y = xenium_df_outputs[response_gene]

        X_train, X_test, y_train, y_test = train_test_split(combined_df, y, test_size=0.2, random_state=42)
        
        models = [LinearRegression(), Ridge(), Lasso(), ElasticNet(), lgb.LGBMRegressor(verbosity=0)]
        model_names = ['Linear Regression', 'Ridge', 'Lasso', 'Elastic Net', 'LightGBM']

        for model, model_name in zip(models, model_names):
            model.fit(X_train, y_train)
            y_pred_train = model.predict(X_train)
            y_pred_test = model.predict(X_test)
            
            training_mse_value = mean_squared_error(y_train, y_pred_train)
            testing_mse_value = mean_squared_error(y_test, y_pred_test)
            
            if r not in training_mse:
                training_mse[r] = {}
            if response_gene not in training_mse[r]:
                training_mse[r][response_gene] = {}
            training_mse[r][response_gene][model_name] = training_mse_value
            
            if r not in testing_mse:
                testing_mse[r] = {}
            if response_gene not in testing_mse[r]:
                testing_mse[r][response_gene] = {}
            testing_mse[r][response_gene][model_name] = testing_mse_value

Processing response gene: 33/148
Processing response gene: 34/148
Processing response gene: 35/148
Processing response gene: 36/148
Processing response gene: 37/148
Processing response gene: 38/148
Processing response gene: 39/148
Processing response gene: 40/148
Processing response gene: 41/148
Processing response gene: 42/148
Processing response gene: 43/148
Processing response gene: 44/148
Processing response gene: 45/148
Processing response gene: 46/148
Processing response gene: 47/148
Processing response gene: 48/148
Processing response gene: 49/148
Processing response gene: 50/148
Processing response gene: 51/148
Processing response gene: 52/148
Processing response gene: 53/148
Processing response gene: 54/148
Processing response gene: 55/148
Processing response gene: 56/148
Processing response gene: 57/148
Processing response gene: 58/148
Processing response gene: 59/148
Processing response gene: 60/148
Processing response gene: 61/148
Processing response gene: 62/148
Processing

Radius values:  67%|██████▋   | 6/9 [1:18:47<40:31, 810.43s/it]

Processing response gene: 1/148
Processing response gene: 2/148
Processing response gene: 3/148
Processing response gene: 4/148
Processing response gene: 5/148
Processing response gene: 6/148
Processing response gene: 7/148
Processing response gene: 8/148
Processing response gene: 9/148
Processing response gene: 10/148
Processing response gene: 11/148
Processing response gene: 12/148
Processing response gene: 13/148
Processing response gene: 14/148
Processing response gene: 15/148
Processing response gene: 16/148
Processing response gene: 17/148
Processing response gene: 18/148
Processing response gene: 19/148
Processing response gene: 20/148
Processing response gene: 21/148
Processing response gene: 22/148
Processing response gene: 23/148
Processing response gene: 24/148
Processing response gene: 25/148
Processing response gene: 26/148
Processing response gene: 27/148
Processing response gene: 28/148
Processing response gene: 29/148
Processing response gene: 30/148
Processing response

Radius values:  78%|███████▊  | 7/9 [1:33:12<27:36, 828.35s/it]

Processing response gene: 1/148
Processing response gene: 2/148
Processing response gene: 3/148
Processing response gene: 4/148
Processing response gene: 5/148
Processing response gene: 6/148
Processing response gene: 7/148
Processing response gene: 8/148
Processing response gene: 9/148
Processing response gene: 10/148
Processing response gene: 11/148
Processing response gene: 12/148
Processing response gene: 13/148
Processing response gene: 14/148
Processing response gene: 15/148
Processing response gene: 16/148
Processing response gene: 17/148
Processing response gene: 18/148
Processing response gene: 19/148
Processing response gene: 20/148
Processing response gene: 21/148
Processing response gene: 22/148
Processing response gene: 23/148
Processing response gene: 24/148
Processing response gene: 25/148
Processing response gene: 26/148
Processing response gene: 27/148
Processing response gene: 28/148
Processing response gene: 29/148
Processing response gene: 30/148
Processing response

Radius values:  89%|████████▉ | 8/9 [1:47:31<13:57, 837.94s/it]

Processing response gene: 1/148
Processing response gene: 2/148
Processing response gene: 3/148
Processing response gene: 4/148
Processing response gene: 5/148
Processing response gene: 6/148
Processing response gene: 7/148
Processing response gene: 8/148
Processing response gene: 9/148
Processing response gene: 10/148
Processing response gene: 11/148
Processing response gene: 12/148
Processing response gene: 13/148
Processing response gene: 14/148
Processing response gene: 15/148
Processing response gene: 16/148
Processing response gene: 17/148
Processing response gene: 18/148
Processing response gene: 19/148
Processing response gene: 20/148
Processing response gene: 21/148
Processing response gene: 22/148
Processing response gene: 23/148
Processing response gene: 24/148
Processing response gene: 25/148
Processing response gene: 26/148
Processing response gene: 27/148
Processing response gene: 28/148
Processing response gene: 29/148
Processing response gene: 30/148
Processing response

Radius values: 100%|██████████| 9/9 [2:01:32<00:00, 810.33s/it]


In [28]:
import json

with open('xenium_competing_results_training.json', 'w') as f:
    json.dump(training_mse, f)

with open('xenium_competing_results_testing.json', 'w') as f:
    json.dump(testing_mse, f)
