Skip to content

Commit

Permalink
Added VDJ function
Browse files Browse the repository at this point in the history
  • Loading branch information
szabogtamas authored and grst committed Mar 28, 2020
1 parent 17e86ac commit 9a700ac
Show file tree
Hide file tree
Showing 2 changed files with 339 additions and 0 deletions.
245 changes: 245 additions & 0 deletions sctcrpy/_plotting/_vdj_usage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@

from .. import tl
from anndata import AnnData
import matplotlib.pyplot as plt
from typing import Union, List
import pandas as pd
import numpy as np


def vdj_usage(
adata: AnnData,
*,
target_cols: list = [
"TRA_1_j_gene",
"TRA_1_v_gene",
"TRB_1_v_gene",
"TRB_1_d_gene",
"TRB_1_j_gene",
],
for_cells: Union[None, list, np.ndarray, pd.Series] = None,
cell_weights: Union[None, str, list, np.ndarray, pd.Series] = None,
#size_column: str = "cell_weights",
fraction_base: Union[None, str] = None,
ax: Union[plt.axes, None] = None,
bar_clip: int = 5,
top_n: Union[None, int] = 10,
barwidth: float = 0.4,
draw_bars: bool = True,
) -> plt.Axes:
"""Creates a ribbon plot of the most abundant VDJ combinations in a given subset of cells.
Currently works with primary alpha and beta chains only.
Does not search for precomputed results in `adata`.
Parameters
----------
adata
AnnData object to work on.
target_cols
Columns containing gene segment information. Overwrite default only if you know what you are doing!
for_cells
A whitelist of cells that should be included in the analysis. If not specified,
all cells in `adata` will be used that have at least a primary alpha or beta chain.
cell_weights
The size factor for each cell. By default, each cell count as 1, but due to normalization
to different sample sizes for example, it is possible that one cell in a small sample
is weighted more than a cell in a large sample.
size_column
The name of the column that will be used for storing cell weights. This value is used internally
and should be matched with the column name used by the tool function. Best left untouched.
fraction_base
As an alternative to supplying ready-made cell weights, this feature can also be calculated
on the fly if a grouping column name is supplied. The parameter `cell_weights` takes piority
over `fraction_base`. If both is `None`, each cell will have a weight of 1.
ax
Custom axis if needed.
bar_clip
The maximum number of stocks for bars (number of different V, D or J segments
that should be shown separately).
top_n
The maximum number of ribbons (individual VDJ combinations). If set to `None`,
all ribbons are drawn.
barwidth
Width of bars.
draw_bars
If `False`, only ribbons are drawn and no bars.
Returns
-------
Axes object.
"""

# Execute the tool
df = tl.vdj_usage(
adata,
target_cols=target_cols,
for_cells=for_cells,
cell_weights=cell_weights,
fraction_base=fraction_base,
)

if top_n is None:
top_n = df.shape[0]
if ax is None:
fig, ax = plt.subplots()

# Draw a stacked bar for every gene loci and save positions on the bar
gene_tops = dict()
for i in range(len(target_cols)):
td = (
df.groupby(target_cols[i])
.agg({size_column: "sum"})
.sort_values(by=size_column, ascending=False)
.reset_index()
)
genes = td[target_cols[i]].tolist()
td = td[size_column]
sector = target_cols[i][2:7]
#sector = sector.replace('_', '')
unct = td[bar_clip + 1 :,].sum()
if td.size > bar_clip:
if draw_bars:
ax.bar(i + 1, unct, width=barwidth, color="grey", edgecolor="black")
gene_tops["other_" + sector] = unct
bottom = unct
else:
gene_tops["other_" + sector] = 0
bottom = 0
for j in range(bar_clip + 1):
try:
y = td[bar_clip - j]
gene = genes[bar_clip - j]
if gene == "None":
gene = "No_" + sector
gene_tops[gene] = bottom + y
if draw_bars:
ax.bar(
i + 1,
y,
width=barwidth,
bottom=bottom,
color="lightgrey",
edgecolor="black",
)
ax.text(
1 + i - barwidth / 2 + 0.05,
bottom + 0.05,
gene.replace("TRA", "").replace("TRB", ""),
)
bottom += y
except:
pass

# Count occurance of individual VDJ combinations
td = df.loc[:,target_cols+[size_column]]
td["genecombination"] = td.apply(lambda x, y: '|'.join([x[e] for e in y]), y=target_cols, axis=1)
td = td.groupby("genecombination").agg({size_column: "sum"}).sort_values(by=size_column, ascending=False).reset_index()
td["genecombination"] = td.apply(lambda x: [x[size_column]] + x['genecombination'].split('|'), axis=1)

# Draw ribbons
for r in td["genecombination"][1 : top_n + 1]:
d = []
ht = r[0]
for i in range(len(r) - 1):
g = r[i + 1]
sector = target_cols[i][2:7]
if g == "None":
g = "No_" + sector
if g not in gene_tops:
g = "other_" + sector
t = gene_tops[g]
d.append([t - ht, t])
t = t - ht
gene_tops[g] = t
if draw_bars:
gapped_ribbons(d, ax=ax, gapwidth=barwidth)
else:
gapped_ribbons(d, ax=ax, gapwidth=0.1)

# Make tick labels nicer
ax.set_xticks(range(1, len(target_cols) + 1))
if target_cols == [
"TRA_1_j_gene",
"TRA_1_v_gene",
"TRB_1_v_gene",
"TRB_1_d_gene",
"TRB_1_j_gene",
]:
ax.set_xticklabels(["TRAJ", "TRAV", "TRBV", "TRBD", "TRBJ"])
else:
ax.set_xticklabels(target_cols)

return ax

def gapped_ribbons(
data: list,
*,
ax: Union[plt.axes, list, None] = None,
xstart: float = 1.2,
gapfreq: float = 1.0,
gapwidth: float = 0.4,
fun: Callable = lambda x: x[3] + (x[4] / (1 + np.exp(-((x[5] / x[2]) * (x[0] - x[1]))))),
figsize: Tuple[float, float] = (3.44, 2.58),
figresolution: int = 300,
) -> plt.Axes:
"""Draws ribbons using `fill_between`
Called by VDJ usage plot to connect bars.
Parameters
----------
data
Breakpoints defining the ribbon as a 2D matrix. Each row is an x position, columns are the lower and upper extent of the ribbon at that position.
ax
Custom axis, almost always called with an axis supplied.
xstart
The midpoint of the first bar.
gapfreq
Frequency of bars. Normally a bar would be drawn at every integer x position, hence default is 1.
gapwidth
At every bar position, there will be a gap. The width of this gap is identical to bar widths, but could also be set to 0 if we need continous ribbons.
fun
A function defining the curve of each ribbon segment from breakpoint to breakpoint. By default, it is a sigmoid with 6 parameters:
range between x position of bars,
curve start on the x axis,
slope of curve,
curve start y position,
curve end y position,
compression factor of the sigmoid curve
figsize
Size of the resulting figure in inches.
figresolution
Resolution of the figure in dpi.
Returns
-------
Axes object.
"""

if ax is None:
fig, ax = plt.subplots(figsize=figsize, dpi=figresolution)
else:
if isinstance(ax, list):
ax = ax[0]

spread = 10
xw = gapfreq - gapwidth
slope = xw * 0.8
x, y1, y2 = [], [], []
for i in range(1, len(data)):
xmin = xstart + (i - 1) * gapfreq
tx = np.linspace(xmin, xmin + xw, 100)
xshift = xmin + xw / 2
p1, p2 = data[i - 1]
p3, p4 = data[i]
ty1 = fun((tx, xshift, slope, p1, p3 - p1, spread))
ty2 = fun((tx, xshift, slope, p2, p4 - p2, spread))
x += tx.tolist()
y1 += ty1.tolist()
y2 += ty2.tolist()
x += np.linspace(xmin + xw, xstart + i * gapfreq, 10).tolist()
y1 += np.zeros(10).tolist()
y2 += np.zeros(10).tolist()
ax.fill_between(x, y1, y2, alpha=0.6)

return ax
94 changes: 94 additions & 0 deletions sctcrpy/_tools/_vdj_usage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@

from anndata import AnnData
from typing import Union, List
import pandas as pd


def vdj_usage(
adata: AnnData,
*,
target_cols: Collection = (
"TRA_1_j_gene",
"TRA_1_v_gene",
"TRB_1_v_gene",
"TRB_1_d_gene",
"TRB_1_j_gene",
),
for_cells: Union[None, list, np.ndarray, pd.Series] = None,
cell_weights: Union[None, str, list, np.ndarray, pd.Series] = None,
fraction_base: Union[None, str] = None,
as_dict: bool = False,
) -> Union[AnnData, dict]:
"""Gives a summary of the most abundant VDJ combinations in a given subset of cells.
Currently works with primary alpha and beta chains only.
Does not add the result to `adata`!
Parameters
----------
adata
AnnData object to work on.
target_cols
Columns containing gene segment information. Overwrite default only if you know what you are doing!
for_cells
A whitelist of cells that should be included in the analysis. If not specified,
all cells in `adata` will be used that have at least a primary alpha or beta chain.
cell_weights
A size factor for each cell. By default, each cell count as 1, but due to normalization
to different sample sizes for example, it is possible that one cell in a small sample
is weighted more than a cell in a large sample.
size_column
The name of the column that will be used for storing cell weights. This value is used internally
and should be matched with the column name used by the plotting function. Best left untouched.
fraction_base
As an alternative to supplying ready-made cell weights, this feature can also be calculated
on the fly if a grouping column name is supplied. The parameter `cell_weights` takes piority
over `fraction_base`. If both is `None`, each cell will have a weight of 1.
as_dict
If True, returns a dictionary instead of a dataframe. Useful for testing.
Returns
-------
Depending on the value of `as_dict`, either returns a data frame or a dictionary.
"""

# Preproces the data table (remove unnecessary rows and columns)
size_column = 'cell_weights'
if for_cells is None:
for_cells = adata.obs.loc[
~_is_na(adata.obs.loc[:, target_cols]).all(axis="columns"), target_cols
].index.values
observations = adata.obs.loc[for_cells, :]

# Check how cells should be weighted
makefractions = False
if cell_weights is None:
if fraction_base is None:
observations[size_column] = 1
else:
makefractions = True
else:
if isinstance(cell_weights, str):
makefractions = True
fraction_base = cell_weights
else:
if len(cell_weights) == len(for_cells):
observations[size_column] = cell_weights
else:
raise ValueError(
"Although `cell_weights` appears to be a list, its length is not identical to the number of cells specified by `for_cells`."
)

# Calculate fractions if necessary
if makefractions:
group_sizes = observations.loc[:, fraction_base].value_counts().to_dict()
observations[size_column] = (
observations[fraction_base].map(group_sizes).astype("int32")
)
observations[size_column] = 1 / observations[size_column]

# Return the requested format
if as_dict:
observations = observations.to_dict(orient="index")

return observations

0 comments on commit 9a700ac

Please sign in to comment.