In [None]:
import pandas as pd
import pandas.util.testing as tm
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors
from sklearn.cluster import MiniBatchKMeans
import matplotlib.patches as mpatches
import seaborn as sns
%matplotlib inline

In [None]:
# Matach the transcriptional states of human spermatogonia 
cell_types = {'1':'State 0',
              '2':'State 1',
              '3':'State 2',
              '4':'State 3',
              '5':'State 4'}

num_cell_types = len(cell_types)

# Define methods

In [None]:
#convert spatial coordinates into array
def coords_to_arr(bc_loc_df):
    """Return the 'x' and 'y' columns of ``bc_loc_df`` as a 2-column NumPy array.

    Parameters
    ----------
    bc_loc_df : pandas.DataFrame
        Frame holding at least the columns 'x' and 'y'.

    Returns
    -------
    numpy.ndarray
        Array of shape (n_rows, 2); column 0 is 'x', column 1 is 'y'.

    Raises
    ------
    KeyError
        If 'x' or 'y' is missing from ``bc_loc_df``.
    """
    # Select the two columns by name rather than with a 'x':'y' label slice:
    # the slice silently returns nothing when 'y' precedes 'x' and drags in any
    # unrelated column that happens to sit between them.
    return bc_loc_df[['x', 'y']].to_numpy()

In [None]:
#perform nearest neighbor analysis and generate neighbor indices df
def nbrs_df(coords_arr, k):
    """Find the ``k`` nearest neighbours of every point in ``coords_arr``.

    The model is fitted on ``coords_arr`` and then queried with the same
    array, so every query point is also a member of the fitted set.

    Parameters
    ----------
    coords_arr : array-like
        Point coordinates, one row per point.
    k : int
        Number of neighbours to return per point.

    Returns
    -------
    pandas.DataFrame
        One row per point and ``k`` columns; each entry is the row position
        (within ``coords_arr``) of a neighbour.
    """
    knn_model = NearestNeighbors(n_neighbors=k, algorithm='auto')
    knn_model.fit(coords_arr)

    # kneighbors returns (distances, indices); only the indices are needed
    _, neighbor_idx = knn_model.kneighbors(coords_arr)
    return pd.DataFrame(neighbor_idx)

In [None]:
#create list of windows with cell type counts
def nbr_wind_dfs(nbrs_inds, bc_cell_type):
    """Build one DataFrame per window holding the rows of its neighbours.

    Parameters
    ----------
    nbrs_inds : pandas.DataFrame
        One row per window; each entry is the integer row *position* of a
        neighbour (as produced by ``nbrs_df``).
    bc_cell_type : pandas.DataFrame or pandas.Series
        Per-point annotations, in the same row order as the coordinates the
        neighbour positions refer to.

    Returns
    -------
    list of pandas.DataFrame
        ``result[i]`` contains the rows of ``bc_cell_type`` selected by row
        ``i`` of ``nbrs_inds``, in neighbour order. Empty list when
        ``nbrs_inds`` has no rows.
    """
    # Neighbour indices are positions, not index labels, so they must be
    # resolved with .iloc; .loc only works by accident when the index happens
    # to be 0..n-1. Iterating the underlying array also keeps the output in
    # row order regardless of how nbrs_inds itself is indexed.
    return [pd.DataFrame(bc_cell_type.iloc[positions])
            for positions in nbrs_inds.to_numpy()]

In [None]:
#calculate frequencies of cell types within a single window
#param: df with raw counts data -> i.e. num_arr[i]
#return: list of cell type frequencies; position i holds the frequency of cell type i+1
def calc_freq(cell_type_counts, n_types=None):
    """Return the fraction of rows in one window assigned to each cell type.

    Parameters
    ----------
    cell_type_counts : pandas.DataFrame
        One window; must contain a 'max_cell_type' column with the cell type
        label of each row.
        NOTE(review): labels are matched against the integers 1..n_types,
        while the keys of ``cell_types`` are strings -- confirm the column
        really holds integers.
    n_types : int, optional
        Number of cell types to report. Defaults to the notebook-level
        ``num_cell_types``.

    Returns
    -------
    list of float
        ``n_types`` values; element ``i`` is the share of rows labelled
        ``i + 1``. All zeros for an empty window.
    """
    if n_types is None:
        n_types = num_cell_types

    labels = cell_type_counts['max_cell_type'].tolist()

    # Normalise by the rows actually present instead of a global window-size
    # constant, so the function does not depend on notebook state defined in a
    # later cell and stays correct for windows of any size.
    window_size = len(labels)
    if window_size == 0:
        return [0.0] * n_types

    return [labels.count(ct) / window_size for ct in range(1, n_types + 1)]

In [None]:
#calculate frequency of cell types within each window
def ct_freq_wind(num_arr):
    """Apply ``calc_freq`` to every window and collect the results.

    Parameters
    ----------
    num_arr : iterable of pandas.DataFrame
        The windows, e.g. the list returned by ``nbr_wind_dfs``.

    Returns
    -------
    list
        One ``calc_freq`` result per window, in input order.
    """
    return [calc_freq(window) for window in num_arr]

# Calculate cell state frequency

In [None]:
#define window size: number of nearest neighbours collected per window
#(passed to nbrs_df as the n_neighbors setting)
k = 10

In [None]:
#read in files containing the bead location and the cell state assignment info
#NOTE(review): both reads use the same placeholder path 'file name.csv' --
#point each call at its own file before running.
#bc_loc_df must provide the 'x' and 'y' columns (read by coords_to_arr);
#cell_state_df must provide 'max_cell_type' (read by calc_freq).
#The first CSV column becomes the index of each frame (index_col=0).
bc_loc_df = pd.read_csv('file name.csv', index_col=0)
cell_state_df = pd.read_csv('file name.csv', index_col=0)

In [None]:
# Calculate cell state frequency for each window
# coordinates -> k nearest neighbours -> per-window rows -> per-window frequencies
coords_arr = coords_to_arr(bc_loc_df)
nbrs_inds = nbrs_df(coords_arr, k)
num_arr = nbr_wind_dfs(nbrs_inds, cell_state_df)
wind_freq = ct_freq_wind(num_arr)

# Take the column labels from cell_types instead of repeating the literal list,
# so the labels and the width of each frequency vector (num_cell_types) are
# always derived from the same definition.
wind_freq_df = pd.DataFrame(wind_freq, columns=list(cell_types.values()), dtype=float)
print(wind_freq_df.shape)
wind_freq_df.head(3)

In [None]:
# Attach the bead-location columns to the per-window frequencies.
# The original referenced `bc_loc`, which is never defined in this notebook;
# the bead-location frame loaded above is `bc_loc_df`.
# wind_freq_df is indexed 0..n-1 in bead row order, so drop bc_loc_df's own
# index to join the two frames by position rather than by label.
# NOTE(review): the following cell sorts by a 'cluster' column -- confirm that
# bc_loc_df actually carries it.
df_combined = pd.concat([wind_freq_df, bc_loc_df.reset_index(drop=True)], axis=1)
df_combined.head(3)

In [None]:
# order the rows by spermatogonium state (the 'cluster' column)
df_combined = df_combined.sort_values('cluster')
df_combined.head(3)

In [None]:
# keep only the per-state frequency columns
df_combined_select = df_combined.loc[:, ['State 0', 'State 1', 'State 2', 'State 3', 'State 4']]
df_combined_select.head(3)

In [None]:
df_combbined_select = df_combbined_select.reset_index(drop=True)
df_combbined_select.head(3)

In [None]:
ax = sns.heatmap(df_combbined_select, cmap="YlGnBu")