<a href="https://colab.research.google.com/github/sbooeshaghi/azucar/blob/main/analysis/293T/obs2/select_extract.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#@title import
import os

import matplotlib.pyplot as plt
from sklearn.metrics import rand_score
from mpl_toolkits.axes_grid1 import make_axes_locatable
import json
import pandas as pd
import numpy as np
from sklearn.decomposition import PCA
from collections import defaultdict
from scipy.io import mmread, mmwrite
from scipy.sparse import csr_matrix
from sklearn.neighbors import KDTree
from scipy.stats import entropy
from itertools import combinations
import sys
import gzip
from scipy.stats import entropy
from sklearn.mixture import GaussianMixture


def nd(arr):
    """Return ``arr`` as a flat, one-dimensional NumPy array.

    The input is first converted with ``np.asarray`` and then reshaped to
    shape ``(-1,)``, so scalars become length-1 arrays and nested sequences
    are flattened in row-major order.
    """
    as_array = np.asarray(arr)
    return np.reshape(as_array, -1)

def yex(ax):
    """Draw a y = x reference line on ``ax`` and give both axes the same range.

    The shared range runs from the smallest to the largest of the current
    x- and y-limits. The line is drawn first, then the aspect is set to
    ``'equal'`` and both axes are set to the shared range.

    Returns the same ``ax`` object that was passed in.
    """
    # Stack the current (low, high) pairs of both axes and take the extremes.
    current = np.array([ax.get_xlim(), ax.get_ylim()])
    lims = [current.min(), current.max()]

    # Plot the shared limits against each other to get the diagonal.
    ax.plot(lims, lims, 'k-', alpha=0.75, zorder=0)
    ax.set_aspect('equal')
    for apply_limits in (ax.set_xlim, ax.set_ylim):
        apply_limits(lims)
    return ax

# Default font size, applied to matplotlib through rcParams below.
fsize=20

# Make `fsize` the global matplotlib font size.
plt.rcParams.update({'font.size': fsize})
# IPython magic: set the inline backend's figure format to 'retina'.
%config InlineBackend.figure_format = 'retina'

In [5]:
#@title mx select

def write_markers(fname, markers):
    """Write a marker table to ``fname``, one entry of ``markers`` per line.

    Each line holds the key, a tab character, and the key's values joined by
    commas (no trailing comma). An entry with no values is written as the
    key followed by a tab only.

    Parameters
    ----------
    fname : path of the file to create or overwrite.
    markers : mapping of key -> sequence of values; items are written in the
        mapping's iteration order and formatted with ``str``-style formatting.
    """
    with open(fname, 'w') as out:
        for group, members in markers.items():
            joined = ','.join(f'{member}' for member in members)
            out.write(f'{group}\t{joined}\n')

def read_markers(fname,
                 markers_ec=None,
                 celltype=None,
                 marker_genes=None):
    """Parse a markers file into index-based lookup tables.

    Each non-blank line of ``fname`` is expected to hold a cell type name, a
    tab, and a comma-separated list of marker gene names (the layout written
    by ``write_markers``). A cell type with no genes after the tab is
    accepted and gets an empty marker list.

    Parameters
    ----------
    fname : path of the markers file.
    markers_ec : optional mapping filled in place with
        ``cell type index -> sorted list of marker gene indices``.
    celltype : optional mapping filled in place with
        ``cell type name -> index`` (position among the non-blank lines).
    marker_genes : optional mapping filled in place with
        ``gene name -> index`` in order of first appearance; a gene shared
        by several cell types keeps the index it got the first time.

    Any container left as ``None`` is created fresh for this call, so
    repeated calls never share state.

    Returns
    -------
    tuple ``(markers_ec, celltype, marker_genes)`` - the same objects that
    were passed in, or the newly created ones.
    """
    # Fresh containers per call: mutable defaults in the signature would be
    # shared (and keep growing) across every call that relies on them.
    if markers_ec is None:
        markers_ec = {}
    if celltype is None:
        celltype = {}
    if marker_genes is None:
        marker_genes = {}

    with open(fname, 'r') as f:
        stripped = (line.strip() for line in f)
        # Blank lines (e.g. a trailing newline) are skipped and do not
        # consume a cell type index.
        for idx, row in enumerate(r for r in stripped if r):
            # partition() tolerates a line with no gene field at all.
            ct, _, gene_field = row.partition('\t')
            celltype[ct] = idx

            ec = markers_ec.setdefault(idx, [])
            for g in (gene_field.split(',') if gene_field else []):
                # Assign the next free index only to genes not seen before.
                if g not in marker_genes:
                    marker_genes[g] = len(marker_genes)
                ec.append(marker_genes[g])

            # Keep each cell type's marker indices sorted.
            markers_ec[idx] = sorted(ec)

    return markers_ec, celltype, marker_genes


def read_genes(genes_fname, genes=None):
    """Read a gene list (one name per line) into a ``name -> line index`` map.

    Parameters
    ----------
    genes_fname : path of the gene list file.
    genes : optional mapping filled in place. When omitted, a new dict is
        created for this call so that separate calls never share state.

    Returns
    -------
    The mapping that was filled (the one passed in, or the new dict). The
    index is the zero-based line number; surrounding whitespace is stripped
    from each name. If a name occurs more than once the last line wins.
    """
    if genes is None:
        genes = {}
    with open(genes_fname) as f:
        for idx, line in enumerate(f):
            genes[line.strip()] = idx
    return genes


def sel_genes(genes, marker_genes, sel=None):
    """Translate marker genes into their positions in the full gene list.

    Parameters
    ----------
    genes : mapping ``gene name -> index`` in the full gene list.
    marker_genes : mapping ``gene name -> marker index``.
    sel : optional list extended in place. When omitted, a new list is
        created for this call so results never accumulate between calls.

    Returns
    -------
    The list that was filled: for each marker gene, taken in increasing
    order of its marker index, the gene's index from ``genes``.

    Raises
    ------
    KeyError
        If a marker gene is missing from ``genes``.
    """
    if sel is None:
        sel = []
    # Walk marker genes in marker-index order; sorting by index (instead of
    # assuming indices are exactly 0..n-1) also copes with gaps.
    for name in sorted(marker_genes, key=marker_genes.get):
        try:
            sel.append(genes[name])
        except KeyError:
            raise KeyError(
                f"marker gene {name!r} is not present in the gene list"
            ) from None
    return sel


def write_list(fname, lst):
    """Write every element of ``lst`` to ``fname``, one element per line.

    The file is created or overwritten. Elements are formatted with
    ``str``-style formatting and each is followed by a newline.
    """
    with open(fname, 'w') as out:
        out.writelines(f'{item}\n' for item in lst)


def mx_select(markers_fname, genes_fname, out_select_fn):
    """Write the gene-list indices of all marker genes to ``out_select_fn``.

    Reads the markers file and the gene list, looks up each marker gene's
    position in the gene list (in marker-index order, as produced by
    ``sel_genes``), and writes those positions one per line.

    Parameters
    ----------
    markers_fname : path of the markers file read by ``read_markers``.
    genes_fname : path of the gene list read by ``read_genes``.
    out_select_fn : path of the selection file to write.
    """
    # TODO: select should be extensible to axis and genes -> md (metadata).
    # NOTE: the marker parsing here duplicates what the index step does; not
    # ideal - ideally this function would be handed markers.ec directly.
    ec_by_celltype = defaultdict(list)
    celltype_index, marker_index = defaultdict(), defaultdict()
    read_markers(markers_fname, ec_by_celltype, celltype_index, marker_index)

    gene_index = defaultdict()
    read_genes(genes_fname, gene_index)

    selected = []
    sel_genes(gene_index, marker_index, selected)
    write_list(out_select_fn, selected)

In [6]:
# Small fixture for exercising mx_select / mx_extract: a 7 x 5 integer
# matrix stored as CSR. Columns line up with `test_genes` (5 names);
# presumably rows line up with `test_barcodes` (7 names) - nothing in this
# notebook indexes rows, so confirm before relying on it.
test_mtx = csr_matrix(np.array(
    [
     [1, 2, 2, 1, 5], # TODO: work on select.py for the case when a column is dropped
     [0, 1, 2, 0, 3],
     [0, 0, 0, 0, 0],
     [2, 5, 1, 0, 3],
     [4, 3, 3, 0, 1],
     [1, 2, 1, 0, 0],
     [3, 3, 0, 0, 0]
    ]
))

# Axis metadata for the fixture.
test_genes = ["tag1", "tag2", "tag3", "tag4", "tag5"]
test_barcodes = ["A", "B", "C", "D", "E", "F", "G"]

# Group -> marker genes. "tag3" is listed under both g1 and g2, which
# exercises the repeated-gene branch of read_markers.
test_markers = {
    "g1": ["tag1", "tag3"],
    "g2": ["tag3", "tag4"],
    "g3" : ["tag2", "tag5"]
}

# Output paths, relative to the current working directory.
test_matrix_fn = "test_mtx.mtx"
test_genes_fn = "test_genes.txt"
test_barcodes_fn = "test_barcodes.txt"
test_markers_fn = "test_markers.txt"

# Write the fixture to disk so the file-based functions can read it back.
mmwrite(test_matrix_fn, test_mtx)
write_list(test_genes_fn, test_genes)
write_list(test_barcodes_fn, test_barcodes)
write_markers(test_markers_fn, test_markers)

In [8]:
#@title test mx select
# get the gene ids -> select.txt (selects in order of markers.ec)
# The gene list path below is the same file as `test_genes_fn`.
# With the fixture markers, the marker genes are first seen in the order
# tag1, tag3, tag4, tag2, tag5, so select.txt should contain 0, 2, 3, 1, 4.
mx_select(test_markers_fn, 
          f"test_genes.txt", 
          "./select.txt")
# Shell magic: print the selection file that was just written.
!cat select.txt

0
2
3
1
4


In [10]:
#@title mx extract

def read_int_list(fname, lst=None):
  """Read a file of integers (one per line) into a list.

  Parameters
  ----------
  fname : path of the file to read.
  lst : optional list extended in place. When omitted, a new list is
      created for this call so values never accumulate between calls.

  Returns
  -------
  The list that was filled (the one passed in, or the new one).

  Raises
  ------
  ValueError
      If a non-blank line cannot be parsed by ``int``.
  """
  if lst is None:
    lst = []
  with open(fname) as f:
    for line in f:
      entry = line.strip()
      # Skip blank lines (e.g. a trailing newline) instead of failing on int('').
      if entry:
        lst.append(int(entry))
  return lst

def read_str_list(fname, lst=None):
  """Read a text file into a list of whitespace-stripped lines.

  Parameters
  ----------
  fname : path of the file to read.
  lst : optional list extended in place. When omitted, a new list is
      created for this call. (The previous default was the ``list`` type
      itself, which made a default-argument call fail with ``TypeError``.)

  Returns
  -------
  The list that was filled (the one passed in, or the new one). Every line
  yields one entry, including blank lines, so positions stay aligned with
  line numbers.
  """
  if lst is None:
    lst = []
  with open(fname, 'r') as f:
    for line in f:
      lst.append(line.strip())
  return lst

def mx_extract(matrix_fn,
               md_fn,
               select_fn,
               out_matrix_fn,
               out_md_fn,
               axis=1):
  """Extract selected rows or columns of a matrix together with their labels.

  Parameters
  ----------
  matrix_fn : Matrix Market file to read (loaded with ``mmread`` and
      converted to a dense array).
  md_fn : text file with one metadata label per line for the chosen axis
      (in this notebook: the gene list).
  select_fn : text file with one integer index per line; the indices are
      applied in file order.
  out_matrix_fn : path of the Matrix Market file to write.
  out_md_fn : path of the label file to write.
  axis : axis passed to ``np.take`` when indexing the matrix
      (default 1, i.e. columns).

  The extracted matrix is written as a CSR matrix; the labels at the same
  indices are written one per line.
  """
  dense = mmread(matrix_fn).toarray()

  # Indices to pull out along `axis`.
  picks = []
  read_int_list(select_fn, picks)

  # Labels for that axis (in this case it is the gene list).
  labels = []
  read_str_list(md_fn, labels)

  # TODO: make sure dropped genes are used. Earlier drafts of this function
  # sketched reading markers.ec, masking rows/columns, dropping the markers
  # of removed genes and applying IPF before extraction; none of that is
  # implemented here.
  extracted = csr_matrix(np.take(dense, picks, axis=axis))
  mmwrite(out_matrix_fn, extracted)
  write_list(out_md_fn, np.take(labels, picks))

In [11]:
#@title test mx extract
# extract elements from matrix that are of interest, rows / columns (with associated metadata)
# Uses the select.txt written by the "test mx select" cell; outputs go to
# files prefixed with "extr.".
mx_extract(test_matrix_fn, 
           test_genes_fn, 
           "select.txt", 
           f"extr.{test_matrix_fn}", 
           f"extr.{test_genes_fn}", axis=1)

# Shell magics: print the original gene order on one line.
!echo "Before"
!cat $test_genes_fn | tr '\n' ' ' && echo -e '\n'
# NOTE(review): this expression is not the last one in the cell, so its
# value is never displayed - only the extracted matrix below appears in the
# output.
mmread(test_matrix_fn).toarray()

# Shell magics: print the extracted gene order on one line.
!echo "Extracted"
!cat extr.$test_genes_fn | tr '\n' ' ' && echo -e '\n'
# Last expression of the cell: the extracted matrix, shown as the cell output.
mmread(f"extr.{test_matrix_fn}").toarray()

Before
tag1 tag2 tag3 tag4 tag5 

Extracted
tag1 tag3 tag4 tag2 tag5 



array([[1, 2, 1, 2, 5],
       [0, 2, 0, 1, 3],
       [0, 0, 0, 0, 0],
       [2, 1, 0, 5, 3],
       [4, 3, 0, 3, 1],
       [1, 1, 0, 2, 0],
       [3, 0, 0, 3, 0]])