# Fit polyclonal model
Here we fit [polyclonal](https://jbloomlab.github.io/polyclonal) models to the data.

First, import Python modules:

In [1]:
import pickle

import altair as alt

import pandas as pd

import polyclonal

import yaml

In [2]:
# allow more rows for Altair
# (binding the return value to `_` keeps it from being shown as the cell output)
_ = alt.data_transformers.disable_max_rows()

In [3]:
import os

# The relative paths used below ("config.yaml", "results/...") are resolved from
# the directory two levels above this notebook.  Only change directory while
# "config.yaml" is not yet visible, so that re-running this cell does not climb
# a further two levels each time and break every later relative path.
if not os.path.isfile("config.yaml"):
    os.chdir('../../')

## Read input data

Get the parameterized variables from [papermill](https://papermill.readthedocs.io/):

In [4]:
# papermill parameters cell (tagged as `parameters`)
# These are placeholders only; the run-specific values are assigned in the
# "# Parameters" cell that follows.
prob_escape_csv = None  # path of the CSV of probabilities of escape read below
n_threads = None  # printed below; not otherwise used by the cells in this notebook view
pickle_file = None  # printed below; presumably the output pickle path -- TODO confirm, nothing in view writes it
antibody = None  # overwritten below with the single antibody found in the data

In [5]:
# Parameters
# Values for this particular run; they replace the `None` placeholders assigned
# in the preceding cell.
prob_escape_csv = "results/prob_escape/libA_221021_1_3x-1C04_5G04_1_prob_escape.csv"
pickle_file = "results/polyclonal_fits/libA_221021_1_3x-1C04_5G04_1.pickle"
n_threads = 2


Read the probabilities of escape, and filter for those with sufficient no-antibody counts:

In [6]:
print(f"\nReading probabilities of escape from {prob_escape_csv}")

# Only the literal string "nan" is parsed as missing; with the default NA
# strings disabled, empty fields stay empty strings instead of becoming NaN.
prob_escape = pd.read_csv(
    prob_escape_csv,
    na_values="nan",
    keep_default_na=False,
)

# Both names in the query are resolved as columns of the frame (a Python
# variable would need an `@` prefix), so each row is compared against the
# threshold stored in that same row.
prob_escape = prob_escape.query("`no-antibody_count` >= no_antibody_count_threshold")

# every remaining cell must hold a value
assert not prob_escape.isnull().values.any()


Reading probabilities of escape from results/prob_escape/libA_221021_1_3x-1C04_5G04_1_prob_escape.csv


Read the rest of the configuration and input data:

In [7]:
# load the top-level configuration
with open("config.yaml") as config_handle:
    config = yaml.safe_load(config_handle)

# the data must describe exactly one antibody; reduce it to that single value
unique_antibodies = prob_escape["antibody"].unique()
assert len(unique_antibodies) == 1, unique_antibodies
antibody = unique_antibodies[0]

# reference site labels, listed in the order of their sequential site numbers
site_numbering = pd.read_csv(config["site_numbering_map"])
reference_sites = site_numbering.sort_values("sequential_site")[
    "reference_site"
].tolist()

# look up the polyclonal settings that belong to this antibody
with open(config["polyclonal_config"]) as polyclonal_handle:
    polyclonal_config = yaml.safe_load(polyclonal_handle)
if antibody not in polyclonal_config:
    raise ValueError(f"`polyclonal_config` lacks configuration for {antibody=}")
antibody_config = polyclonal_config[antibody]

# echo the variables and settings, one per line
print(
    f"{antibody=}",
    f"{n_threads=}",
    f"{pickle_file=}",
    f"{antibody_config=}",
    sep="\n",
)

antibody='3x-1C04_5G04'
n_threads=2
pickle_file='results/polyclonal_fits/libA_221021_1_3x-1C04_5G04_1.pickle'
antibody_config={'max_epitopes': 2, 'n_bootstrap_samples': 50, 'reg_escape_weight': 0.1, 'reg_spread_weight': 0.25, 'reg_activity_weight': 1.0, 'times_seen': 3, 'min_epitope_activity_to_include': 0.2}


## Some summary statistics
Note that these statistics are only for the variants that passed upstream filtering in the pipeline.

Number of variants per concentration:

In [8]:
# count the distinct barcodes observed at each antibody concentration
display(
    prob_escape.groupby("antibody_concentration").agg(
        n_variants=("barcode", "nunique")
    )
)

Unnamed: 0_level_0,n_variants
antibody_concentration,Unnamed: 1_level_1
1.65,35041
3.3,35041
6.6,35041
13.2,35041


Plot mean probability of escape across all variants with the indicated number of mutations.
Note that this plot weights each variant the same in the means regardless of how many barcode counts it has.
We plot means for both censored (set to between 0 and 1) and uncensored probabilities of escape.
Also, note it uses a symlog scale for the y-axis.
Mouseover points for values:

In [9]:
max_aa_subs = 4  # substitution counts at or above this value are pooled

# Label each variant with its number of substitutions (pooled at the top end),
# average both escape measures per concentration and label with every variant
# weighted equally, then stack the two measures into long form for faceting.
mean_prob_escape = (
    prob_escape.assign(
        n_subs=lambda df: (
            df["aa_substitutions_reference"]
            .str.split()
            .map(len)
            .clip(upper=max_aa_subs)
            .map(lambda k: f">{max_aa_subs - 1}" if k >= max_aa_subs else str(k))
        )
    )
    .groupby(["antibody_concentration", "n_subs"], as_index=False)[
        ["prob_escape", "prob_escape_uncensored"]
    ]
    .mean()
    .rename(
        columns={
            "prob_escape": "censored to [0, 1]",
            "prob_escape_uncensored": "not censored",
        }
    )
    .melt(
        id_vars=["antibody_concentration", "n_subs"],
        var_name="censored",
        value_name="probability escape",
    )
)

# float columns get a compact number format in the tooltip; others are shown raw
mean_prob_escape_tooltips = [
    alt.Tooltip(col, format=".3g") if mean_prob_escape[col].dtype == float else col
    for col in mean_prob_escape.columns
]

mean_prob_escape_chart = (
    alt.Chart(mean_prob_escape)
    .mark_line(point=True, size=0.5)
    .encode(
        x=alt.X("antibody_concentration"),
        y=alt.Y(
            "probability escape",
            scale=alt.Scale(type="symlog", constant=0.05),
        ),
        color=alt.Color("n_subs", title="n substitutions"),
        column=alt.Column("censored", title=None),
        tooltip=mean_prob_escape_tooltips,
    )
    .properties(width=200, height=125)
    .configure_axis(grid=False)
)

mean_prob_escape_chart

  for col_name, dtype in df.dtypes.iteritems():


## Fit `polyclonal` model
First, get the fitting related keyword arguments from the configuration passed by `snakemake`:

In [10]:
# Read the fitting settings for this antibody from its configuration entry,
# echoing each one so the rendered notebook records what was used.
times_seen = antibody_config["times_seen"]
print(f"{times_seen=}")

max_epitopes = antibody_config["max_epitopes"]
print(f"{max_epitopes=}")

# the regularization weights are forwarded unchanged to the model's `fit` call
fit_kwargs = {
    weight_name: antibody_config[weight_name]
    for weight_name in (
        "reg_escape_weight",
        "reg_spread_weight",
        "reg_activity_weight",
    )
}
print(f"{fit_kwargs=}")

min_epitope_activity_to_include = antibody_config["min_epitope_activity_to_include"]
print(f"{min_epitope_activity_to_include=}")

times_seen=3
max_epitopes=2
fit_kwargs={'reg_escape_weight': 0.1, 'reg_spread_weight': 0.25, 'reg_activity_weight': 1.0}
min_epitope_activity_to_include=0.2


Fit a model to all the data, and keep adding epitopes until we either reach the maximum number specified or some epitope in the newest model has activity at or below `min_epitope_activity_to_include` (in which case the previous model is kept).
Note that we fit using the **reference**-based site-numbering scheme, so results are shown with those numbers:

In [11]:
models = []  # every model fit so far, in order of increasing epitope count

for n_epitopes in range(1, max_epitopes + 1):
    print(f"\nFitting model with {n_epitopes=}")

    # build the model on the data, renamed to the column names passed to polyclonal
    model = polyclonal.Polyclonal(
        data_to_fit=prob_escape.rename(
            columns={
                "aa_substitutions_reference": "aa_substitutions",
                "antibody_concentration": "concentration",
            }
        ),
        n_epitopes=n_epitopes,
        sites=reference_sites,
        alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
    )

    # optimize the model
    opt_res = model.fit(logfreq=200, **fit_kwargs)

    # report epitope activities and a per-epitope summary of escape values
    print("Activities of epitopes:")
    display(model.activity_wt_df.round(1))
    print("Max and mean absolute-value escape at each epitope:")
    escape_summary = model.mut_escape_df.groupby("epitope").aggregate(
        max_escape=pd.NamedAgg("escape", "max"),
        mean_abs_escape=pd.NamedAgg("escape", lambda s: s.abs().mean()),
    )
    display(escape_summary.round(1))

    models.append(model)

    # Once an earlier fit exists, an epitope at or below the activity threshold
    # ends the search and the previous model becomes the selected one.
    if (
        len(models) > 1
        and (model.activity_wt_df["activity"] <= min_epitope_activity_to_include).any()
    ):
        print(f"Stop fitting, epitope has activity <={min_epitope_activity_to_include}")
        model = models[-2]  # get previous model
        break

print(f"\nThe selected model has {len(model.epitopes)} epitopes")


Fitting model with n_epitopes=1
# First fitting site-level model.
# Starting optimization of 527 parameters at Mon Oct 24 12:41:28 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0     0.039337       5669.2       5668.3            0            0      0.90499
           21      0.67491       42.855       37.885     0.034902            0       4.9344
# Successfully finished at Mon Oct 24 12:41:28 2022.
# Starting optimization of 3461 parameters at Mon Oct 24 12:41:28 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0      0.03648       51.855       46.388      0.53189   3.1629e-34       4.9344
            8      0.39943       51.375       46.358      0.00737   0.00022573       5.0098
# Successfully finished at Mon Oct 24 12:41:29 2022.
Activities of epitopes:


Unnamed: 0,epitope,activity
0,1,5.1


Max and mean absolute-value escape at each epitope:


Unnamed: 0_level_0,max_escape,mean_abs_escape
epitope,Unnamed: 1_level_1,Unnamed: 2_level_1
1,0.0,0.0



Fitting model with n_epitopes=2
# First fitting site-level model.
# Starting optimization of 1054 parameters at Mon Oct 24 12:41:33 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0     0.062031       816.91          816            0            0      0.90499
           26       1.7483       42.453       38.487     0.043229            0       3.9223
# Successfully finished at Mon Oct 24 12:41:35 2022.
# Starting optimization of 6922 parameters at Mon Oct 24 12:41:35 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0     0.067547       51.565       46.986      0.65711   9.8396e-34       3.9223
            8      0.65465       50.974       46.947    0.0091113   0.00027395       4.0174
# Successfully finished at Mon Oct 24 12:41:35 2022.
Activities of epitopes:


Unnamed: 0,epitope,activity
0,1,4.1
1,2,0.1


Max and mean absolute-value escape at each epitope:


Unnamed: 0_level_0,max_escape,mean_abs_escape
epitope,Unnamed: 1_level_1,Unnamed: 2_level_1
1,0.0,0.0
2,0.0,0.0


Stop fitting, epitope has activity <=0.2

The selected model has 1 epitopes


Epitope activities:

In [12]:
# bar plot of epitope activities for the model selected by the fitting loop
model.activity_wt_barplot()

  for col_name, dtype in df.dtypes.iteritems():


Line plot of escape at each site:

In [13]:
# escape plot for the model selected by the fitting loop
model.mut_escape_plot()

  for col_name, dtype in df.dtypes.iteritems():


# Testing different selection concentrations

In [14]:
# recap: distinct barcodes at each antibody concentration before any is dropped
display(
    prob_escape.groupby("antibody_concentration").agg(
        n_variants=("barcode", "nunique")
    )
)

Unnamed: 0_level_0,n_variants
antibody_concentration,Unnamed: 1_level_1
1.65,35041
3.3,35041
6.6,35041
13.2,35041


## Start by dropping the lowest concentration (1.65)

In [15]:
# keep every row except those measured at the 1.65 concentration
prob_escape_high = prob_escape[prob_escape["antibody_concentration"].ne(1.65)]

In [16]:
max_aa_subs = 4  # group if >= this many substitutions

# Mean probability of escape per concentration and substitution-count group for
# `prob_escape_high`; every variant is weighted equally.  The two escape
# measures are then stacked into long form so the chart can facet on them.
mean_prob_escape_high = (
    prob_escape_high.assign(
        n_subs=lambda x: (
            x["aa_substitutions_reference"]
            .str.split()
            .map(len)
            .clip(upper=max_aa_subs)
            .map(lambda n: str(n) if n < max_aa_subs else f">{max_aa_subs - 1}")
        )
    )
    .groupby(["antibody_concentration", "n_subs"], as_index=False)
    .aggregate({"prob_escape": "mean", "prob_escape_uncensored": "mean"})
    .rename(
        columns={
            "prob_escape": "censored to [0, 1]",
            "prob_escape_uncensored": "not censored",
        }
    )
    .melt(
        id_vars=["antibody_concentration", "n_subs"],
        var_name="censored",
        value_name="probability escape",
    )
)

mean_prob_escape_chart_high = (
    alt.Chart(mean_prob_escape_high)
    .encode(
        x=alt.X("antibody_concentration"),
        y=alt.Y(
            "probability escape",
            scale=alt.Scale(type="symlog", constant=0.05),
        ),
        column=alt.Column("censored", title=None),
        color=alt.Color("n_subs", title="n substitutions"),
        tooltip=[
            # Column names and dtypes come from the frame being plotted, so the
            # cell no longer depends on a different frame built in another cell.
            alt.Tooltip(c, format=".3g")
            if mean_prob_escape_high[c].dtype == float
            else c
            for c in mean_prob_escape_high.columns
        ],
    )
    .mark_line(point=True, size=0.5)
    .properties(width=200, height=125)
    .configure_axis(grid=False)
)

mean_prob_escape_chart_high

  for col_name, dtype in df.dtypes.iteritems():


In [17]:
# Re-read the fitting settings from the antibody's configuration entry and echo
# them, so this section states the settings it runs with.
times_seen = antibody_config["times_seen"]
print(f"{times_seen=}")

max_epitopes = antibody_config["max_epitopes"]
print(f"{max_epitopes=}")

# the regularization weights are forwarded unchanged to the model's `fit` call
fit_kwargs = {
    weight_name: antibody_config[weight_name]
    for weight_name in (
        "reg_escape_weight",
        "reg_spread_weight",
        "reg_activity_weight",
    )
}
print(f"{fit_kwargs=}")

min_epitope_activity_to_include = antibody_config["min_epitope_activity_to_include"]
print(f"{min_epitope_activity_to_include=}")

times_seen=3
max_epitopes=2
fit_kwargs={'reg_escape_weight': 0.1, 'reg_spread_weight': 0.25, 'reg_activity_weight': 1.0}
min_epitope_activity_to_include=0.2


In [18]:
models = []  # models fit to `prob_escape_high`, in order of increasing epitope count

for n_epitopes in range(1, max_epitopes + 1):
    print(f"\nFitting model with {n_epitopes=}")

    # build the model on the data without the 1.65 concentration
    model_high = polyclonal.Polyclonal(
        data_to_fit=prob_escape_high.rename(
            columns={
                "aa_substitutions_reference": "aa_substitutions",
                "antibody_concentration": "concentration",
            }
        ),
        n_epitopes=n_epitopes,
        sites=reference_sites,
        alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
    )

    # optimize the model
    opt_res = model_high.fit(logfreq=200, **fit_kwargs)

    # report epitope activities and a per-epitope summary of escape values
    print("Activities of epitopes:")
    display(model_high.activity_wt_df.round(1))
    print("Max and mean absolute-value escape at each epitope:")
    escape_summary = model_high.mut_escape_df.groupby("epitope").aggregate(
        max_escape=pd.NamedAgg("escape", "max"),
        mean_abs_escape=pd.NamedAgg("escape", lambda s: s.abs().mean()),
    )
    display(escape_summary.round(1))

    models.append(model_high)

    # Once an earlier fit exists, an epitope at or below the activity threshold
    # ends the search and the previous model becomes the selected one.
    if (
        len(models) > 1
        and (
            model_high.activity_wt_df["activity"] <= min_epitope_activity_to_include
        ).any()
    ):
        print(f"Stop fitting, epitope has activity <={min_epitope_activity_to_include}")
        model_high = models[-2]  # get previous model
        break

print(f"\nThe selected model has {len(model_high.epitopes)} epitopes")


Fitting model with n_epitopes=1
# First fitting site-level model.
# Starting optimization of 527 parameters at Mon Oct 24 12:43:18 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0     0.027381       1954.7       1953.8            0            0      0.90499
           18      0.47462       17.944       13.554     0.021851            0       4.3685
# Successfully finished at Mon Oct 24 12:43:18 2022.
# Starting optimization of 3461 parameters at Mon Oct 24 12:43:18 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0     0.028675       21.105       16.384      0.35188    1.208e-34       4.3685
            7      0.23762       20.794       16.378    0.0044604   0.00012133       4.4117
# Successfully finished at Mon Oct 24 12:43:18 2022.
Activities of epitopes:


Unnamed: 0,epitope,activity
0,1,4.5


Max and mean absolute-value escape at each epitope:


Unnamed: 0_level_0,max_escape,mean_abs_escape
epitope,Unnamed: 1_level_1,Unnamed: 2_level_1
1,0.0,0.0



Fitting model with n_epitopes=2
# First fitting site-level model.
# Starting optimization of 1054 parameters at Mon Oct 24 12:43:22 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0      0.04751       96.177       95.272            0            0      0.90499
           19      0.95484       16.806       13.951     0.030669            0        2.824
# Successfully finished at Mon Oct 24 12:43:23 2022.
# Starting optimization of 6922 parameters at Mon Oct 24 12:43:23 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0     0.051874       20.094       16.779      0.49088   6.8073e-34        2.824
            7      0.43306       19.662       16.768    0.0066081   0.00017454       2.8869
# Successfully finished at Mon Oct 24 12:43:23 2022.
Activities of epitopes:


Unnamed: 0,epitope,activity
0,1,2.9
1,2,0.1


Max and mean absolute-value escape at each epitope:


Unnamed: 0_level_0,max_escape,mean_abs_escape
epitope,Unnamed: 1_level_1,Unnamed: 2_level_1
1,0.0,0.0
2,0.0,0.0


Stop fitting, epitope has activity <=0.2

The selected model has 1 epitopes


In [19]:
# bar plot of epitope activities for the model fit to `prob_escape_high`
model_high.activity_wt_barplot()

  for col_name, dtype in df.dtypes.iteritems():


In [20]:
# escape plot for the model fit to `prob_escape_high`
model_high.mut_escape_plot()

  for col_name, dtype in df.dtypes.iteritems():


## Now drop the highest concentration (13.2), keeping the lower ones

In [21]:
# keep every row except those measured at the 13.2 concentration
# (so `_low` means "only the lower concentrations remain")
prob_escape_low = prob_escape[prob_escape["antibody_concentration"].ne(13.20)]

In [22]:
max_aa_subs = 4  # group if >= this many substitutions

# Mean probability of escape per concentration and substitution-count group for
# `prob_escape_low`; every variant is weighted equally.  The two escape
# measures are then stacked into long form so the chart can facet on them.
mean_prob_escape_low = (
    prob_escape_low.assign(
        n_subs=lambda x: (
            x["aa_substitutions_reference"]
            .str.split()
            .map(len)
            .clip(upper=max_aa_subs)
            .map(lambda n: str(n) if n < max_aa_subs else f">{max_aa_subs - 1}")
        )
    )
    .groupby(["antibody_concentration", "n_subs"], as_index=False)
    .aggregate({"prob_escape": "mean", "prob_escape_uncensored": "mean"})
    .rename(
        columns={
            "prob_escape": "censored to [0, 1]",
            "prob_escape_uncensored": "not censored",
        }
    )
    .melt(
        id_vars=["antibody_concentration", "n_subs"],
        var_name="censored",
        value_name="probability escape",
    )
)

mean_prob_escape_chart_low = (
    alt.Chart(mean_prob_escape_low)
    .encode(
        x=alt.X("antibody_concentration"),
        y=alt.Y(
            "probability escape",
            scale=alt.Scale(type="symlog", constant=0.05),
        ),
        column=alt.Column("censored", title=None),
        color=alt.Color("n_subs", title="n substitutions"),
        tooltip=[
            # Column names and dtypes come from the frame being plotted, so the
            # cell no longer depends on a different frame built in another cell.
            alt.Tooltip(c, format=".3g")
            if mean_prob_escape_low[c].dtype == float
            else c
            for c in mean_prob_escape_low.columns
        ],
    )
    .mark_line(point=True, size=0.5)
    .properties(width=200, height=125)
    .configure_axis(grid=False)
)

mean_prob_escape_chart_low

  for col_name, dtype in df.dtypes.iteritems():


In [23]:
models = []  # models fit to `prob_escape_low`, in order of increasing epitope count

for n_epitopes in range(1, max_epitopes + 1):
    print(f"\nFitting model with {n_epitopes=}")

    # build the model on the data without the 13.2 concentration
    model_low = polyclonal.Polyclonal(
        data_to_fit=prob_escape_low.rename(
            columns={
                "aa_substitutions_reference": "aa_substitutions",
                "antibody_concentration": "concentration",
            }
        ),
        n_epitopes=n_epitopes,
        sites=reference_sites,
        alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
    )

    # optimize the model
    opt_res = model_low.fit(logfreq=200, **fit_kwargs)

    # report epitope activities and a per-epitope summary of escape values
    print("Activities of epitopes:")
    display(model_low.activity_wt_df.round(1))
    print("Max and mean absolute-value escape at each epitope:")
    escape_summary = model_low.mut_escape_df.groupby("epitope").aggregate(
        max_escape=pd.NamedAgg("escape", "max"),
        mean_abs_escape=pd.NamedAgg("escape", lambda s: s.abs().mean()),
    )
    display(escape_summary.round(1))

    models.append(model_low)

    # Once an earlier fit exists, an epitope at or below the activity threshold
    # ends the search and the previous model becomes the selected one.
    if (
        len(models) > 1
        and (
            model_low.activity_wt_df["activity"] <= min_epitope_activity_to_include
        ).any()
    ):
        print(f"Stop fitting, epitope has activity <={min_epitope_activity_to_include}")
        model_low = models[-2]  # get previous model
        break

print(f"\nThe selected model has {len(model_low.epitopes)} epitopes")


Fitting model with n_epitopes=1
# First fitting site-level model.
# Starting optimization of 527 parameters at Mon Oct 24 12:44:28 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0     0.027143       5555.1       5554.2            0            0      0.90499
           21      0.56242       38.355       33.356     0.034725            0       4.9644
# Successfully finished at Mon Oct 24 12:44:28 2022.
# Starting optimization of 3461 parameters at Mon Oct 24 12:44:29 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0     0.028637       46.534       41.037      0.53251   2.1509e-34       4.9644
            8      0.26077       46.054       41.006    0.0070033   0.00021023       5.0412
# Successfully finished at Mon Oct 24 12:44:29 2022.
Activities of epitopes:


Unnamed: 0,epitope,activity
0,1,5.1


Max and mean absolute-value escape at each epitope:


Unnamed: 0_level_0,max_escape,mean_abs_escape
epitope,Unnamed: 1_level_1,Unnamed: 2_level_1
1,0.0,0.0



Fitting model with n_epitopes=2
# First fitting site-level model.
# Starting optimization of 1054 parameters at Mon Oct 24 12:44:32 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0     0.050619       812.73       811.83            0            0      0.90499
           27       1.4206       37.748       33.775     0.043546            0       3.9299
# Successfully finished at Mon Oct 24 12:44:33 2022.
# Starting optimization of 6922 parameters at Mon Oct 24 12:44:34 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0     0.052863       46.051       41.458       0.6635   1.1393e-33       3.9299
            8      0.46792       45.454       41.419    0.0090242   0.00026887       4.0257
# Successfully finished at Mon Oct 24 12:44:34 2022.
Activities of epitopes:


Unnamed: 0,epitope,activity
0,1,4.1
1,2,0.1


Max and mean absolute-value escape at each epitope:


Unnamed: 0_level_0,max_escape,mean_abs_escape
epitope,Unnamed: 1_level_1,Unnamed: 2_level_1
1,0.0,0.0
2,0.0,0.0


Stop fitting, epitope has activity <=0.2

The selected model has 1 epitopes


In [24]:
# bar plot of epitope activities for the model fit to `prob_escape_low`
model_low.activity_wt_barplot()

  for col_name, dtype in df.dtypes.iteritems():


In [25]:
# escape plot for the model fit to `prob_escape_low`
model_low.mut_escape_plot()

  for col_name, dtype in df.dtypes.iteritems():


## Now drop an intermediate concentration (3.3)

In [31]:
# keep every row except those measured at the 3.3 concentration
prob_escape_lower = prob_escape[prob_escape["antibody_concentration"].ne(3.30)]

In [32]:
max_aa_subs = 4  # group if >= this many substitutions

# Mean probability of escape per concentration and substitution-count group for
# `prob_escape_lower`; every variant is weighted equally.  The two escape
# measures are then stacked into long form so the chart can facet on them.
mean_prob_escape_lower = (
    prob_escape_lower.assign(
        n_subs=lambda x: (
            x["aa_substitutions_reference"]
            .str.split()
            .map(len)
            .clip(upper=max_aa_subs)
            .map(lambda n: str(n) if n < max_aa_subs else f">{max_aa_subs - 1}")
        )
    )
    .groupby(["antibody_concentration", "n_subs"], as_index=False)
    .aggregate({"prob_escape": "mean", "prob_escape_uncensored": "mean"})
    .rename(
        columns={
            "prob_escape": "censored to [0, 1]",
            "prob_escape_uncensored": "not censored",
        }
    )
    .melt(
        id_vars=["antibody_concentration", "n_subs"],
        var_name="censored",
        value_name="probability escape",
    )
)

mean_prob_escape_chart_lower = (
    alt.Chart(mean_prob_escape_lower)
    .encode(
        x=alt.X("antibody_concentration"),
        y=alt.Y(
            "probability escape",
            scale=alt.Scale(type="symlog", constant=0.05),
        ),
        column=alt.Column("censored", title=None),
        color=alt.Color("n_subs", title="n substitutions"),
        tooltip=[
            # Column names and dtypes come from the frame being plotted, so the
            # cell no longer depends on a different frame built in another cell.
            alt.Tooltip(c, format=".3g")
            if mean_prob_escape_lower[c].dtype == float
            else c
            for c in mean_prob_escape_lower.columns
        ],
    )
    .mark_line(point=True, size=0.5)
    .properties(width=200, height=125)
    .configure_axis(grid=False)
)

mean_prob_escape_chart_lower

  for col_name, dtype in df.dtypes.iteritems():


In [33]:
models = []  # models fit to `prob_escape_lower`, in order of increasing epitope count

for n_epitopes in range(1, max_epitopes + 1):
    print(f"\nFitting model with {n_epitopes=}")

    # build the model on the data without the 3.3 concentration
    model_lower = polyclonal.Polyclonal(
        data_to_fit=prob_escape_lower.rename(
            columns={
                "aa_substitutions_reference": "aa_substitutions",
                "antibody_concentration": "concentration",
            }
        ),
        n_epitopes=n_epitopes,
        sites=reference_sites,
        alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
    )

    # optimize the model
    opt_res = model_lower.fit(logfreq=200, **fit_kwargs)

    # report epitope activities and a per-epitope summary of escape values
    print("Activities of epitopes:")
    display(model_lower.activity_wt_df.round(1))
    print("Max and mean absolute-value escape at each epitope:")
    escape_summary = model_lower.mut_escape_df.groupby("epitope").aggregate(
        max_escape=pd.NamedAgg("escape", "max"),
        mean_abs_escape=pd.NamedAgg("escape", lambda s: s.abs().mean()),
    )
    display(escape_summary.round(1))

    models.append(model_lower)

    # Once an earlier fit exists, an epitope at or below the activity threshold
    # ends the search and the previous model becomes the selected one.
    if (
        len(models) > 1
        and (
            model_lower.activity_wt_df["activity"] <= min_epitope_activity_to_include
        ).any()
    ):
        print(f"Stop fitting, epitope has activity <={min_epitope_activity_to_include}")
        model_lower = models[-2]  # get previous model
        break

print(f"\nThe selected model has {len(model_lower.epitopes)} epitopes")


Fitting model with n_epitopes=1
# First fitting site-level model.
# Starting optimization of 527 parameters at Mon Oct 24 12:47:03 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0     0.026771       4261.7       4260.8            0            0      0.90499
           19      0.50249       37.217       32.221     0.036197            0       4.9595
# Successfully finished at Mon Oct 24 12:47:03 2022.
# Starting optimization of 3461 parameters at Mon Oct 24 12:47:04 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0     0.028769           45       39.476      0.56482   3.1409e-34       4.9595
            7      0.23488        44.49        39.44    0.0063371   0.00018138       5.0432
# Successfully finished at Mon Oct 24 12:47:04 2022.
Activities of epitopes:


Unnamed: 0,epitope,activity
0,1,5.1


Max and mean absolute-value escape at each epitope:


Unnamed: 0_level_0,max_escape,mean_abs_escape
epitope,Unnamed: 1_level_1,Unnamed: 2_level_1
1,0.0,0.0



Fitting model with n_epitopes=2
# First fitting site-level model.
# Starting optimization of 1054 parameters at Mon Oct 24 12:47:07 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0      0.04796       733.83       732.93            0            0      0.90499
           24       1.2903       36.639       32.606     0.045203            0       3.9879
# Successfully finished at Mon Oct 24 12:47:08 2022.
# Starting optimization of 6922 parameters at Mon Oct 24 12:47:08 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0     0.052191       44.543       39.852      0.70359   1.1741e-33       3.9879
            9      0.51419        43.91       39.811     0.008018   0.00022263       4.0907
# Successfully finished at Mon Oct 24 12:47:09 2022.
Activities of epitopes:


Unnamed: 0,epitope,activity
0,1,4.2
1,2,0.1


Max and mean absolute-value escape at each epitope:


Unnamed: 0_level_0,max_escape,mean_abs_escape
epitope,Unnamed: 1_level_1,Unnamed: 2_level_1
1,0.0,0.0
2,0.0,0.0


Stop fitting, epitope has activity <=0.2

The selected model has 1 epitopes


In [34]:
# bar plot of epitope activities for the model fit to `prob_escape_lower`
model_lower.activity_wt_barplot()

  for col_name, dtype in df.dtypes.iteritems():


In [35]:
# escape plot for the model fit to `prob_escape_lower`
model_lower.mut_escape_plot()

  for col_name, dtype in df.dtypes.iteritems():
