-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
17e86ac
commit 9a700ac
Showing
2 changed files
with
339 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,245 @@ | ||
|
||
from typing import Callable, Collection, List, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from anndata import AnnData

from .. import tl
|
||
|
||
def vdj_usage(
    adata: AnnData,
    *,
    target_cols: Union[list, tuple] = (
        "TRA_1_j_gene",
        "TRA_1_v_gene",
        "TRB_1_v_gene",
        "TRB_1_d_gene",
        "TRB_1_j_gene",
    ),
    for_cells: Union[None, list, np.ndarray, pd.Series] = None,
    cell_weights: Union[None, str, list, np.ndarray, pd.Series] = None,
    fraction_base: Union[None, str] = None,
    ax: Union[plt.Axes, None] = None,
    bar_clip: int = 5,
    top_n: Union[None, int] = 10,
    barwidth: float = 0.4,
    draw_bars: bool = True,
) -> plt.Axes:
    """Creates a ribbon plot of the most abundant VDJ combinations in a given subset of cells.

    Currently works with primary alpha and beta chains only.
    Does not search for precomputed results in `adata`.

    Parameters
    ----------
    adata
        AnnData object to work on.
    target_cols
        Columns containing gene segment information. Overwrite default only if you know what you are doing!
    for_cells
        A whitelist of cells that should be included in the analysis. Passed on
        unchanged to `tl.vdj_usage`.
    cell_weights
        The size factor for each cell. By default, each cell counts as 1, but due to normalization
        to different sample sizes for example, it is possible that one cell in a small sample
        is weighted more than a cell in a large sample. Passed on unchanged to `tl.vdj_usage`.
    fraction_base
        As an alternative to supplying ready-made cell weights, this feature can also be calculated
        on the fly if a grouping column name is supplied. The parameter `cell_weights` takes priority
        over `fraction_base`. If both are `None`, each cell will have a weight of 1.
    ax
        Custom axis if needed. A new figure is created when omitted.
    bar_clip
        Position (0-based) of the last segment that gets its own section in a bar:
        the `bar_clip + 1` most abundant V, D or J segments of every column are
        shown separately, all remaining ones are lumped into one grey section.
    top_n
        The maximum number of ribbons (individual VDJ combinations). If set to `None`,
        all ribbons are drawn.
    barwidth
        Width of bars.
    draw_bars
        If `False`, only ribbons are drawn and no bars.

    Returns
    -------
    Axes object.
    """

    # Name of the weight column in the table returned by the tool function.
    # NOTE(review): must match the column name used by `tl.vdj_usage`
    # ("cell_weights" in the version visible in this commit).
    size_column = "cell_weights"
    # Work on a list so that `+` concatenation and the default-label comparison
    # below behave the same for list and tuple input.
    target_cols = list(target_cols)

    # Execute the tool
    df = tl.vdj_usage(
        adata,
        target_cols=target_cols,
        for_cells=for_cells,
        cell_weights=cell_weights,
        fraction_base=fraction_base,
    )

    if top_n is None:
        top_n = df.shape[0]
    if ax is None:
        _, ax = plt.subplots()

    # Draw a stacked bar for every gene locus and remember the current upper
    # edge of every section; ribbons are later hung from these edges.
    gene_tops = {}
    for i, col in enumerate(target_cols):
        td = (
            df.groupby(col)
            .agg({size_column: "sum"})
            .sort_values(by=size_column, ascending=False)
            .reset_index()
        )
        genes = td[col].tolist()
        sizes = td[size_column].to_numpy()
        sector = col[2:7]

        # Number of segments that get their own section; everything beyond is
        # summed into the "other" section at the bottom of the bar.
        n_shown = max(0, min(len(sizes), bar_clip + 1))
        unct = sizes[n_shown:].sum()
        gene_tops["other_" + sector] = unct
        if draw_bars and len(sizes) > n_shown:
            ax.bar(i + 1, unct, width=barwidth, color="grey", edgecolor="black")
        bottom = unct

        # Stack from the least to the most abundant of the shown segments, so
        # that the most abundant one ends up on top.
        for idx in range(n_shown - 1, -1, -1):
            y = sizes[idx]
            gene = genes[idx]
            if gene == "None":
                gene = "No_" + sector
            gene_tops[gene] = bottom + y
            if draw_bars:
                ax.bar(
                    i + 1,
                    y,
                    width=barwidth,
                    bottom=bottom,
                    color="lightgrey",
                    edgecolor="black",
                )
                ax.text(
                    1 + i - barwidth / 2 + 0.05,
                    bottom + 0.05,
                    str(gene).replace("TRA", "").replace("TRB", ""),
                )
            bottom += y

    # Count occurrence of individual VDJ combinations
    td = df.loc[:, target_cols + [size_column]].copy()
    td["genecombination"] = [
        "|".join(map(str, combo))
        for combo in td[target_cols].itertuples(index=False, name=None)
    ]
    td = (
        td.groupby("genecombination")
        .agg({size_column: "sum"})
        .sort_values(by=size_column, ascending=False)
        .reset_index()
    )

    # Draw ribbons
    # NOTE(review): the slice starts at position 1, i.e. the most abundant
    # combination is not drawn. Kept as in the original; confirm it is intended.
    ribbons = td.iloc[1 : top_n + 1]
    for combination, ht in zip(ribbons["genecombination"], ribbons[size_column]):
        d = []
        for col, g in zip(target_cols, combination.split("|")):
            sector = col[2:7]
            if g == "None":
                g = "No_" + sector
            if g not in gene_tops:
                # Segment was lumped into the grey section of its bar
                g = "other_" + sector
            t = gene_tops[g]
            d.append([t - ht, t])
            # Lower the edge so the next ribbon attaches right below this one
            gene_tops[g] = t - ht
        gapped_ribbons(d, ax=ax, gapwidth=barwidth if draw_bars else 0.1)

    # Make tick labels nicer
    ax.set_xticks(range(1, len(target_cols) + 1))
    if target_cols == [
        "TRA_1_j_gene",
        "TRA_1_v_gene",
        "TRB_1_v_gene",
        "TRB_1_d_gene",
        "TRB_1_j_gene",
    ]:
        ax.set_xticklabels(["TRAJ", "TRAV", "TRBV", "TRBD", "TRBJ"])
    else:
        ax.set_xticklabels(target_cols)

    return ax
|
||
def gapped_ribbons(
    data: list,
    *,
    ax: Union[plt.Axes, list, None] = None,
    xstart: float = 1.2,
    gapfreq: float = 1.0,
    gapwidth: float = 0.4,
    fun: Callable = lambda x: x[3]
    + (x[4] / (1 + np.exp(-((x[5] / x[2]) * (x[0] - x[1]))))),
    figsize: Tuple[float, float] = (3.44, 2.58),
    figresolution: int = 300,
) -> plt.Axes:
    """Draws ribbons using `fill_between`

    Called by VDJ usage plot to connect bars.

    Parameters
    ----------
    data
        Breakpoints defining the ribbon as a 2D matrix. Each row is an x position,
        columns are the lower and upper extent of the ribbon at that position.
    ax
        Custom axis, almost always called with an axis supplied. If a list, tuple
        or array of axes is passed, its first element is used.
    xstart
        The x position where the first ribbon segment starts.
    gapfreq
        Frequency of bars. Normally a bar would be drawn at every integer x position, hence default is 1.
    gapwidth
        At every bar position, there will be a gap. The width of this gap is identical
        to bar widths, but can be set smaller if more continuous ribbons are needed.
    fun
        A function defining the curve of each ribbon segment from breakpoint to breakpoint.
        It is called with a single 6-tuple and must return the y values for the
        x positions in the first element. By default, it is a sigmoid. The tuple holds:
            x positions of the segment (array),
            midpoint of the segment on the x axis,
            slope scale of the curve,
            curve start y position,
            change in y over the segment (end minus start),
            compression factor of the sigmoid curve
    figsize
        Size of the resulting figure in inches. Only used if no axis is supplied.
    figresolution
        Resolution of the figure in dpi. Only used if no axis is supplied.

    Returns
    -------
    Axes object.
    """

    if ax is None:
        _, ax = plt.subplots(figsize=figsize, dpi=figresolution)
    elif isinstance(ax, np.ndarray):
        ax = ax.flat[0]
    elif isinstance(ax, (list, tuple)):
        ax = ax[0]

    spread = 10
    xw = gapfreq - gapwidth  # horizontal extent of one curved segment
    slope = xw * 0.8

    xs, lower, upper = [], [], []
    for i, ((p1, p2), (p3, p4)) in enumerate(zip(data, data[1:])):
        xmin = xstart + i * gapfreq
        tx = np.linspace(xmin, xmin + xw, 100)
        xshift = xmin + xw / 2
        xs.append(tx)
        lower.append(fun((tx, xshift, slope, p1, p3 - p1, spread)))
        upper.append(fun((tx, xshift, slope, p2, p4 - p2, spread)))
        # Across the bar itself both edges are set to zero, which leaves a
        # zero-height (invisible) stretch between consecutive segments.
        xs.append(np.linspace(xmin + xw, xstart + (i + 1) * gapfreq, 10))
        lower.append(np.zeros(10))
        upper.append(np.zeros(10))

    # Fewer than two breakpoints define no segment, so there is nothing to fill
    if xs:
        ax.fill_between(
            np.concatenate(xs),
            np.concatenate(lower),
            np.concatenate(upper),
            alpha=0.6,
        )

    return ax
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
|
||
from anndata import AnnData | ||
from typing import Union, List | ||
import pandas as pd | ||
|
||
|
||
def vdj_usage(
    adata: AnnData,
    *,
    target_cols: Collection = (
        "TRA_1_j_gene",
        "TRA_1_v_gene",
        "TRB_1_v_gene",
        "TRB_1_d_gene",
        "TRB_1_j_gene",
    ),
    for_cells: Union[None, list, np.ndarray, pd.Series] = None,
    cell_weights: Union[None, str, list, np.ndarray, pd.Series] = None,
    fraction_base: Union[None, str] = None,
    as_dict: bool = False,
) -> Union[pd.DataFrame, dict]:
    """Gives a summary of the most abundant VDJ combinations in a given subset of cells.

    Currently works with primary alpha and beta chains only.
    Does not add the result to `adata`!

    Parameters
    ----------
    adata
        AnnData object to work on.
    target_cols
        Columns containing gene segment information. Overwrite default only if you know what you are doing!
    for_cells
        A whitelist of cells that should be included in the analysis. If not specified,
        all cells in `adata` are used for which not every column in `target_cols` is missing.
    cell_weights
        A size factor for each cell. By default, each cell counts as 1, but due to normalization
        to different sample sizes for example, it is possible that one cell in a small sample
        is weighted more than a cell in a large sample. If a string is given, it is
        treated as a grouping column name, exactly like `fraction_base`.
    fraction_base
        As an alternative to supplying ready-made cell weights, this feature can also be calculated
        on the fly if a grouping column name is supplied: each cell is weighted by
        one over the size of its group. The parameter `cell_weights` takes priority
        over `fraction_base`. If both are `None`, each cell will have a weight of 1.
    as_dict
        If True, returns a dictionary instead of a dataframe. Useful for testing.

    Returns
    -------
    Depending on the value of `as_dict`, either returns a data frame (the selected rows of
    `adata.obs` plus a `cell_weights` column) or the same table as a dictionary keyed by cell.

    Raises
    ------
    ValueError
        If `cell_weights` is list-like and its length differs from the number of selected cells.
    """

    # Column in which the weight of every cell is stored
    size_column = "cell_weights"
    # `.loc` interprets a tuple as one multi-level key, so a list is required
    target_cols = list(target_cols)

    # Preprocess the data table (remove unnecessary rows)
    if for_cells is None:
        has_chain = ~_is_na(adata.obs.loc[:, target_cols]).all(axis="columns")
        for_cells = adata.obs.loc[has_chain, target_cols].index.values
    # Copy, so that adding the weight column never touches `adata.obs`
    observations = adata.obs.loc[for_cells, :].copy()

    # Check how cells should be weighted
    if isinstance(cell_weights, str):
        # A column name in `cell_weights` overrides `fraction_base`
        fraction_base = cell_weights
        cell_weights = None

    if cell_weights is not None:
        # Compare against the selected rows rather than `len(for_cells)`, which
        # is not the number of cells if `for_cells` is a boolean mask.
        if len(cell_weights) != observations.shape[0]:
            raise ValueError(
                "Although `cell_weights` appears to be a list, its length is not identical to the number of cells specified by `for_cells`."
            )
        observations[size_column] = cell_weights
    elif fraction_base is not None:
        # Calculate fractions: every cell counts as 1 / (size of its group)
        group_sizes = observations.loc[:, fraction_base].value_counts().to_dict()
        observations[size_column] = 1 / (
            observations[fraction_base].map(group_sizes).astype("int32")
        )
    else:
        observations[size_column] = 1

    # Return the requested format
    if as_dict:
        observations = observations.to_dict(orient="index")

    return observations