# 1b. Baseline models - TEMPORAL split

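**Summary:** we compare four simple baselines for predicting per-position relative SASA from aligned sequences on a temporal train/val/test split:
- the per-position average;
- the per-position mean and median over training sequences with the same residue;
- a 3-nearest-neighbour regressor on one-hot encoded sequences.

All models are fitted on TRAIN. On VAL they are compared with pairwise histograms of deviation differences, ground truth vs. prediction plots and per-position errors. On TEST they are compared with the same pairwise histogram grid.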

## Setup

In [1]:
import collections

import matplotlib.pyplot as plt
import numpy as np 
import pandas as pd
import seaborn as sns
import matplotlib.lines as mlines
import matplotlib.transforms as mtransforms
from sklearn.neighbors import KNeighborsRegressor
from sklearn.preprocessing import OneHotEncoder

import bin.baseline_models as bm
import bin.feature_generators as fg
import bin.params as p
import bin.utils as u

In [2]:
# `sns.set_style()` without arguments only re-applies the current rc values, i.e. it is a no-op.
# Apply seaborn's default theme explicitly so figures before and after the later `sns.set(...)`
# calls share the same styling.
sns.set_theme()

In [3]:
# Chain selection taken from the project's `bin.params` module; it is used in the data file
# names below and passed to `get_dataset` when building every split.
CHAINS = p.CHAINS

**Load the data:**

In [4]:
# Per-chain metadata, cleaned aligned sequences and aligned SASA tables.
# NOTE(review): none of these three frames is referenced again in this notebook (the model
# inputs come from `get_dataset`); confirm they are still needed before keeping the I/O.
metadata_df = pd.read_csv(f'{p.DATA_DIR}/csv/metadata/metadata_{CHAINS}.csv', index_col=0)
fasta_df = pd.read_csv(f'{p.DATA_DIR}/csv/fasta_aligned_cleaned/fasta_{p.FINAL_NUMBERING_SCHEME}_{CHAINS}.csv', index_col=0)
sasa_df = pd.read_csv(f'{p.DATA_DIR}/csv/sasa_aligned/sasa_{CHAINS}.csv', index_col=0)

**Baseline wrapper functions:**

In [5]:
def predict_avg_by_pos(x_train, y_train, x_test):
    """Fit `bm.AverageForResidueAtPosition` on the training split and predict `x_test`.

    Returns whatever the baseline's `predict` returns (used as a DataFrame by the callers).
    """
    baseline = bm.AverageForResidueAtPosition()
    baseline.fit(x_train, y_train)
    return baseline.predict(x_test)

In [6]:
def predict_avg_by_same_res_pos(x_train, y_train, x_test):
    """Fit `bm.StatisticForSameResidueAtPosition` with the 'mean' statistic and predict `x_test`.

    Returns whatever the baseline's `predict` returns (used as a DataFrame by the callers).
    """
    baseline = bm.StatisticForSameResidueAtPosition(statistic='mean')
    baseline.fit(x_train, y_train)
    return baseline.predict(x_test)

In [7]:
def predict_median_by_same_res_pos(x_train, y_train, x_test):
    """Fit `bm.StatisticForSameResidueAtPosition` with the 'median' statistic and predict `x_test`.

    Returns whatever the baseline's `predict` returns (used as a DataFrame by the callers).
    """
    median_baseline = bm.StatisticForSameResidueAtPosition(statistic='median')
    median_baseline.fit(x_train, y_train)
    return median_baseline.predict(x_test)

In [8]:
def predict_knn(x_train, y_train, x_test, n_neighbors=3):
    """Predict `x_test` as the mean target of its `n_neighbors` nearest training sequences.

    Sequences are one-hot encoded (categories unseen during fitting are ignored) and the
    nearest training rows are found with `KNeighborsRegressor.kneighbors`. For every column the
    prediction is the mean of the neighbours' *non-missing* targets. Missing targets (NaN) are
    skipped rather than replaced by a sentinel: the previous `fillna(-1)` averaged -1 into every
    prediction where a neighbour had no value. If no neighbour has a value for a column, the
    prediction is NaN. Without missing values the result equals `KNeighborsRegressor.predict`
    (uniform weights).

    Returns a DataFrame indexed like `x_test` with the columns of `y_train`.
    """
    onehot = OneHotEncoder(handle_unknown='ignore')
    x_train_oh = onehot.fit_transform(x_train)
    x_test_oh = onehot.transform(x_test)

    # Only the neighbour search is used; fit on dummy targets so NaNs never reach scikit-learn.
    knn = KNeighborsRegressor(n_neighbors=n_neighbors)
    knn.fit(x_train_oh, np.zeros(len(y_train)))
    neighbour_idx = knn.kneighbors(x_test_oh, return_distance=False)  # (n_test, n_neighbors)

    # (n_test, n_neighbors, n_targets): targets of each test row's neighbours
    neighbour_y = y_train.to_numpy(dtype=float)[neighbour_idx]
    observed = ~np.isnan(neighbour_y)
    totals = np.where(observed, neighbour_y, 0.0).sum(axis=1)
    counts = observed.sum(axis=1)
    with np.errstate(invalid='ignore', divide='ignore'):
        means = totals / counts  # 0 / 0 -> NaN where no neighbour has a value

    return pd.DataFrame(means, columns=y_train.columns, index=x_test.index)

**Other functions:**

In [9]:
def generate_result_frame(x_train, y_train, x_test, y_test) -> pd.DataFrame:
    """Run every baseline on one split and collect one row per (model, sequence, position).

    A cell contributes a row only when both the prediction and the ground truth (`y_test`)
    are present. Ground-truth labels missing from `y_test` are treated as missing values.

    Columns:
      - id: row label of the prediction frame.
      - model.
      - chain: last element of the id (presumably the chain letter — confirm).
      - position: column label.
      - region: always '' for now.
      - sasa: ground truth.
      - predicted.
      - error: sasa - predicted.
      - abs_error.

    Rows are ordered by model (in the order below), then sequence, then position. The column
    set is fixed, so an empty result still has all columns.
    """
    columns = ['id', 'model', 'chain', 'position', 'region',
               'sasa', 'predicted', 'error', 'abs_error']

    def model_results(model_name, predicts, actual):
        # Put the ground truth on the prediction grid; absent labels become NaN and are skipped.
        actual = actual.reindex(index=predicts.index, columns=predicts.columns)
        predicted = predicts.to_numpy(dtype=float)
        sasa = actual.to_numpy(dtype=float)

        # np.nonzero scans row-major: sequence by sequence, then position by position.
        rows, cols = np.nonzero(~np.isnan(predicted) & ~np.isnan(sasa))
        ids = predicts.index.to_numpy()[rows]
        sasa_values = sasa[rows, cols]
        predicted_values = predicted[rows, cols]
        error = sasa_values - predicted_values

        return pd.DataFrame({
            'id': ids,
            'model': model_name,
            'chain': [label[-1] for label in ids],
            'position': predicts.columns.to_numpy()[cols],
            'region': '',
            'sasa': sasa_values,
            'predicted': predicted_values,
            'error': error,
            'abs_error': np.abs(error),
        }, columns=columns)

    baselines = [
        ('avg_by_pos', predict_avg_by_pos),
        ('avg_by_same_res_pos', predict_avg_by_same_res_pos),
        ('median_by_same_res_pos', predict_median_by_same_res_pos),
        ('knn', predict_knn),
    ]
    frames = [model_results(name, predictor(x_train, y_train, x_test), y_test)
              for name, predictor in baselines]
    return pd.concat(frames, ignore_index=True)

In [10]:
def make_baseline_hist_grid(x_train, y_train, x_test, y_test) -> tuple:
    """Fit all four baselines and plot a pairwise grid of deviation-difference histograms.

    Every baseline is trained on (x_train, y_train) and predicts `x_test`. `u.avg_deviations`
    compares those predictions with `y_test`. Panel (i, j) holds:
      - the histogram of deviations(model i) - deviations(model j);
      - a vertical line at 0;
      - as "Score", the sum of those differences.

    Diagonal panels stay empty.

    Returns the four deviation results in the order avg_by_pos, avg_by_same_res_pos,
    median_by_same_res_pos, knn.
    """
    # (grid label, progress message, prediction wrapper) per baseline, in grid order
    baselines = (
        ('avg col', 'predicting avg by pos...', predict_avg_by_pos),
        ('avg col-same-aa', 'predicting avg by same res pos...', predict_avg_by_same_res_pos),
        ('med col-same-aa', 'predicting median by same res pos...', predict_median_by_same_res_pos),
        ('knn', 'predicting knn...', predict_knn),
    )

    deviations = []
    for _, message, predictor in baselines:
        print(message)
        deviations.append(u.avg_deviations(y_test, predictor(x_train, y_train, x_test)))

    print('plotting grid... (this will take a while)')
    n_models = len(baselines)
    fig, axes = plt.subplots(n_models, n_models, sharex=True, sharey=True, figsize=(15, 10))
    for row, (row_label, _, _) in enumerate(baselines):
        print(f'prepare data for grid row {row + 1} out of {n_models}')
        for col, (col_label, _, _) in enumerate(baselines):
            if row == col:
                continue
            difference = deviations[row] - deviations[col]
            panel = axes[row, col]
            panel.hist(difference, bins=50)
            panel.axvline(x=0, c='orange')
            panel.set_title(f'{row_label} X {col_label}\n Score: {difference.sum():.2f}')
    fig.tight_layout()
    plt.show()

    return tuple(deviations)

In [11]:
def nice_sequence_bar_plot(data: pd.Series, xlabel='', ylabel='', title='', only_nth_ticklabels=5, ylogscale=False):
    """Draw `data` as a bar chart (one bar per index label) in a new 15x5 figure.

    Only every `only_nth_ticklabels`-th x tick label is kept (via `u.show_only_nth_ticklabel`).
    The y axis is switched to log scale when `ylogscale` is true. Returns None.
    """
    _, ax = plt.subplots(figsize=(15, 5))
    sns.barplot(x=data.index, y=data, ax=ax)
    ax.set(xlabel=xlabel, ylabel=ylabel, title=title)
    u.show_only_nth_ticklabel(ax, n=only_nth_ticklabels)
    if ylogscale:
        ax.set_yscale('log')

In [12]:
def get_deviations_per_position(data, model='', plot=False):
    """Mean absolute deviation |sasa - predicted| per position, ordered by ANARCI column.

    `data` must have 'position', 'sasa' and 'predicted' columns and is not modified.
    When `plot` is true the result is also drawn with `nice_sequence_bar_plot`, with `model`
    appended to the title.

    Returns a Series indexed by position label, sorted with `u.anarci_column_sorter`.
    """
    # mean absolute deviation per position
    abs_deviation = (data['sasa'] - data['predicted']).abs()
    per_position = abs_deviation.groupby(data['position']).mean()

    # re-order the positions by ANARCI numbering
    ordered_positions = sorted(per_position.index, key=u.anarci_column_sorter)
    sorted_deviations = pd.Series([per_position[position] for position in ordered_positions],
                                  index=ordered_positions)

    if plot:
        nice_sequence_bar_plot(sorted_deviations,
                               xlabel='ANARCI position',
                               ylabel='mean deviation',
                               title=f'Mean relative sasa deviation per ANARCI position | {model}')

    return sorted_deviations

In [13]:
def get_dataset(dataset: str, chains: str):
    """Load one split via `u.load_dataset` and turn it into feature/target frames indexed by id.

    `dataset` is the split name (this notebook uses 'train', 'val' and 'test'). The fixed
    feature-generation arguments are forwarded to `fg.generate`; its third return value is
    discarded.

    NOTE(review): ids are assigned positionally from the raw frames' 'Id' column. This assumes
    `fg.generate` keeps the rows of the raw frames, in order — confirm.
    """
    raw_x, raw_y = u.load_dataset(dataset, chains=chains)
    features, targets, _ = fg.generate(raw_x, raw_y, None, '', 'lco_whole_sequence', dict(raw=True))
    for frame, raw in ((features, raw_x), (targets, raw_y)):
        frame.index = raw['Id']
    return features, targets

---

## Temporal split

In [None]:
# Build feature/target frames for the TRAIN, VAL and TEST parts of the temporal split
# and report their shapes.
train_x, train_y = get_dataset('train', CHAINS)
print(f'TRAIN sets generated, X.shape: {train_x.shape} Y.shape: {train_y.shape}\n----')
val_x, val_y = get_dataset('val', CHAINS)
print(f'VAL sets generated, X.shape: {val_x.shape} Y.shape: {val_y.shape}\n----')
test_x, test_y = get_dataset('test', CHAINS)
print(f'TEST sets generated, X.shape: {test_x.shape} Y.shape: {test_y.shape}')

In [None]:
# First rows of the training features.
train_x.head(n=3)

In [None]:
# Distribution of the training targets, one box per column; labelled so the figure stands alone.
train_y_boxplot = train_y.boxplot(figsize=(15, 3))
train_y_boxplot.set(title='TRAIN target distribution per position',
                    xlabel='ANARCI position',
                    ylabel='relative SASA')
u.show_only_nth_ticklabel(train_y_boxplot, n=5)

---

## TRAIN vs. VAL

### Generate predictions and comparison grid

**Generate baseline model comparison grid (running this cell may take a while):**

In [None]:
# Pairwise comparison of the baselines trained on TRAIN and evaluated on VAL; the per-model
# deviation results returned by the grid are kept in these variables.
# NOTE(review): these four names are not used further down in this notebook.
a_dev, asr_dev, asr_med_dev, knn_dev = make_baseline_hist_grid(train_x, train_y, val_x, val_y)

**Show result dataframe:**

In [None]:
# One row per (model, sequence, position) on VAL with ground truth, prediction and errors.
results_val_df = generate_result_frame(train_x, train_y, val_x, val_y)
results_val_df

### Groundtruth vs. Prediction plots

**Generating each of the following plots may take a while:**

In [None]:
# `relplot` is figure-level: it sizes its own figure from `height`/`aspect` and ignores
# rcParams['figure.figsize']. The requested 15-inch width is therefore set via `height`
# (4 model columns x 3.75 in). The `sns.set` call is kept for its theme side effect.
sns.set(rc={'figure.figsize': (15, 10)})
sns.relplot(data=results_val_df, x='predicted', y='sasa', hue='model', col='model', s=5, height=3.75)

In [None]:
# Ground truth vs. prediction for all models overlaid in one axes, coloured by model.
sns.set(rc={'figure.figsize': (12, 8)})
sns.scatterplot(data=results_val_df, x='predicted', y='sasa', hue='model', s=5)

In [None]:
# KDE on a random subset of rows to keep it fast. The sample size is capped at the frame
# length (`DataFrame.sample` raises if n > len without replacement) and seeded so the
# figure is reproducible.
KDE_SAMPLE_SIZE = 10000
sns.set(rc={'figure.figsize': (12, 8)})
kde_sample = results_val_df.sample(n=min(KDE_SAMPLE_SIZE, len(results_val_df)), random_state=0)
sns.kdeplot(data=kde_sample, x='predicted', y='sasa', hue='model')

In [None]:
# One box plot of absolute error per position for every model. The number of rows follows
# the models actually present in the results instead of a hard-coded 4; `squeeze=False`
# keeps `axes` 2-D even when there is a single model.
rvd = results_val_df
sns.set(rc={'figure.figsize': (15, 10)})

model_names = rvd['model'].unique()
f, axes = plt.subplots(len(model_names), 1, squeeze=False)

for index, model_name in enumerate(model_names):
    model_data = rvd[rvd['model'] == model_name]
    bp = sns.boxplot(x="position", y="abs_error", data=model_data, ax=axes[index, 0])
    bp.set_title(f'Model {model_name}')
    u.show_only_nth_ticklabel(bp, n=5)

f.tight_layout()

**Compute dataframe with mean absolute error per sample and model:**

In [None]:
# Mean absolute error of every (sequence, model) pair as a flat frame: id, model, abs_error.
mean_abs_errors_df = (
    results_val_df
    .groupby(['id', 'model'], as_index=False)['abs_error']
    .mean()
)
mean_abs_errors_df.head()

In [None]:
# Distribution of the per-sequence mean absolute error for each model (violin view).
sns.set(rc={'figure.figsize': (8, 4)})
sns.violinplot(data=mean_abs_errors_df, x='model', y='abs_error')

In [None]:
# Same per-sequence mean absolute errors as above, shown as box plots.
sns.set(rc={'figure.figsize': (8, 4)})
sns.boxplot(data=mean_abs_errors_df, x='model', y='abs_error')

### Individual sequences - case study

**We randomly pick a few sequences from the validation set and display one scatter plot (ground truth vs. prediction) for each (model × sequence) combination.**

**Follow each column from top to bottom and see how the points move closer to the x = y line (shown in red) as the model changes:**

In [None]:
# Pick the validation sequences for the case study.
# - Seeded, so the figure is reproducible.
# - Drawn without replacement, so no sequence appears twice (np.random.choice samples with
#   replacement by default).
# - Capped at the number of available ids.
rv = results_val_df
N_RANDOM_SEQUENCES = 3
N_BASELINE_MODELS = 4
CASE_STUDY_SEED = 0
candidate_ids = results_val_df['id'].unique()
case_study_rng = np.random.default_rng(CASE_STUDY_SEED)
sequences = case_study_rng.choice(candidate_ids,
                                  size=min(N_RANDOM_SEQUENCES, len(candidate_ids)),
                                  replace=False)

In [None]:
# One scatter plot (predicted vs. ground truth) per (model, sequence) pair.
# The red reference line is drawn in data coordinates with `axline`, so it really is y = x.
# The previous Line2D from (0, 0) to (1, 1) used `transAxes`, i.e. the panel diagonal,
# which matches y = x only when the x and y limits coincide.
# The grid size follows the actual models/sequences; `squeeze=False` keeps `ax` 2-D.
model_names = rv['model'].unique()
fig, ax = plt.subplots(len(model_names), len(sequences), sharex=True, sharey=True,
                       figsize=(15, 10), squeeze=False)
for i, model_name in enumerate(model_names):
    for j, sequence in enumerate(sequences):
        data = rv[(rv['id'] == sequence) & (rv['model'] == model_name)]
        ax[i, j].scatter(data['predicted'], data['sasa'])
        ax[i, j].set_title(f'{sequence}, {model_name}')
        ax[i, j].axline((0, 0), slope=1, color='red')

### Average deviation per ANARCI position

In [None]:
# Mean absolute deviation per ANARCI position for every model (each call also draws its bar plot).
devs = {
    model_name: get_deviations_per_position(
        results_val_df[results_val_df['model'] == model_name],
        model=model_name,
        plot=True,
    )
    for model_name in results_val_df['model'].unique()
}

**Since the last two plots (`median_by_same_res_pos` and `knn`) were fairly similar in shape, let us view the difference in their mean deviations per ANARCI position:**

In [None]:
# Positive values: knn has the smaller mean deviation at that position.
# Series subtraction aligns on labels. When the two label sets differ, pandas returns their
# union in its own sorted order, so re-order with `reindex`. Assigning a sorted list to
# `.index` would only relabel the values, not move them, and would attach deviations to the
# wrong positions.
diffs = devs['median_by_same_res_pos'] - devs['knn']
diffs = diffs.reindex(sorted(diffs.index, key=u.anarci_column_sorter))

In [None]:
# Per-position difference of mean absolute deviations; bars above zero favour knn.
nice_sequence_bar_plot(diffs,
                       xlabel='ANARCI position',
                       ylabel='median_by_same_res_pos - knn (RSA units)',
                       title='mean difference in deviations | median_by_same_res_pos vs. knn')

**We see that for the majority of `ANARCI` positions the `knn` model performs better than the `median_by_same_res_pos` model.**

**However, there are still many positions where this is not the case. Their differences are usually close to zero, so they are barely visible on a bar plot with a linear y scale.**

**Let us view counts and means of both positive and negative differences:**

In [None]:
# Number of positions where median_by_same_res_pos is better (negative) / knn is better (positive).
negative_count = int(diffs.lt(0).sum())
positive_count = int(diffs.gt(0).sum())
negative_count, positive_count

In [None]:
# Average size of the negative and of the positive differences.
negative_mean = diffs.loc[diffs.lt(0)].mean()
positive_mean = diffs.loc[diffs.gt(0)].mean()
negative_mean, positive_mean

**With a logarithmic y-axis the plot looks a bit wild, because negative differences cannot be shown on a log scale:**

In [None]:
# Same differences on a log y-axis; negative bars cannot be represented on a log scale.
nice_sequence_bar_plot(diffs,
                       xlabel='ANARCI position',
                       ylabel='median_by_same_res_pos - knn (RSA units)',
                       title='mean difference in deviations | median_by_same_res_pos vs. knn',
                       ylogscale=True)

---

## TRAIN vs. TEST

**Generate baseline model comparison grid (running this cell may take a while):**

In [None]:
# Keep the returned per-model deviations, mirroring the VAL run, instead of leaving the 4-tuple
# as the cell's last expression, where it would be dumped below the figure.
a_dev_test, asr_dev_test, asr_med_dev_test, knn_dev_test = make_baseline_hist_grid(train_x, train_y, test_x, test_y)

---

## Further ideas

- We see different results on different data sets (test vs. val).

- Maybe use bootstrap resampling + statistical testing
to obtain a more robust model comparison?