## Baseline-решение для предсказания свойств ФБ с помощью kNN (сравнение последовательностей)

### Подготовка к работе

In [None]:
! pip install --upgrade git+https://github.com/rimgro/biocadprotein.git

In [16]:
from fpgen.prop_prediction.dataset import FPbase
from fpgen.generation.metrics import identity

from fpgen.prop_prediction.metrics import get_regression_metrics, get_classification_metrics, bootstrap_metric_ci

from sklearn.model_selection import train_test_split, KFold

import pandas as pd
import numpy as np
from tqdm import tqdm
from collections import Counter

### Загрузка датасета

In [17]:
# Dataset wrapper built from the local FPbase CSV; the cells below use it for
# the list of targets, the train/test splits and target rescaling.
dataset = FPbase('data/fpbase.csv')
# Precomputed pairwise sequence matrix. The first CSV column becomes the row
# index, so a value can be looked up by a (row label, column label) pair.
# NOTE(review): the file name says "distance", but knn() ranks neighbours by
# the LARGEST value first -- confirm the matrix holds similarity scores.
df_ident = pd.read_csv('data/sequence_distance_matrix.csv', index_col=0)

### Реализация алгоритма KNN

Поиск расстояния между последовательностями аминокислот.

In [18]:
def ident(seq_1, seq_2):
    """Return the precomputed pairwise value for two sequences.

    Both arguments are stripped of surrounding whitespace and then used as
    the row label (``seq_1``) and the column label (``seq_2``) of the
    module-level ``df_ident`` table.
    """
    row_label = seq_1.strip()
    column_label = seq_2.strip()
    return df_ident.loc[row_label, column_label]

Реализация KNN для задач регрессии и классификации.

In [19]:
def knn(x_train, y_train, x_test, k, knn_type):
    predictions = []
    for test_seq in x_test:
        similarities = []
        for train_seq, train_target in zip(x_train, y_train):
            sim = ident(test_seq, train_seq)
            similarities.append((train_target, sim))
        similarities.sort(key=lambda x: -x[1])
        neighbors = similarities[:k]
        if knn_type == 'reg':
            pred = np.mean([neighbor[0] for neighbor in neighbors])
            predictions.append(pred)
        elif knn_type == 'class':
            neighbor_classes = [neighbor[0] for neighbor in neighbors]
            most_common_class = Counter(neighbor_classes).most_common(1)[0][0]
            predictions.append(most_common_class)
    return predictions

Подбор гиперпараметров методом кросс-валидации.

In [21]:
def cross_validate(x_train, y_train, kf_split, k_max, problem_type='class'):
    """Choose the number of neighbours ``k`` for ``knn`` by K-fold cross-validation.

    Every ``k`` in ``1..k_max`` is scored on each fold and the fold scores are
    averaged. Regression keeps the ``k`` with the lowest mean RMSE,
    classification the ``k`` with the highest mean accuracy. Comparisons are
    strict, so on a tie the smallest ``k`` is kept.

    Args:
        x_train: training items (any sequence; converted to a NumPy array).
        y_train: training targets aligned with ``x_train`` (converted likewise).
        kf_split: number of folds passed to ``KFold``.
        k_max: largest ``k`` to evaluate (inclusive).
        problem_type: ``'reg'`` or ``'class'``.

    Returns:
        The selected ``k``; ``1`` if no ``k`` produced a usable fold score.

    Raises:
        ValueError: if ``problem_type`` is neither ``'reg'`` nor ``'class'``.
    """
    # Validate first: an unsupported value used to surface only later as an
    # UnboundLocalError from inside the loop.
    if problem_type not in ('reg', 'class'):
        raise ValueError(f"problem_type must be 'reg' or 'class', got {problem_type!r}")

    # KFold yields positional indices; plain lists cannot be indexed with them
    # and a pandas Series would look them up by label, so work on arrays.
    x_train = np.asarray(x_train)
    y_train = np.asarray(y_train)

    minimize = problem_type == 'reg'
    metric_fn = get_regression_metrics if minimize else get_classification_metrics
    score_key = 'rmse' if minimize else 'accuracy'
    best_score = np.inf if minimize else -np.inf
    best_k = 1

    kf = KFold(n_splits=kf_split)

    for k in tqdm(range(1, k_max + 1)):
        fold_scores = []

        for train_index, test_index in kf.split(x_train):
            X_train, X_test = x_train[train_index], x_train[test_index]
            y_train_fold, y_test = y_train[train_index], y_train[test_index]

            predict = knn(X_train, y_train_fold, X_test, k, knn_type=problem_type)

            # Missing predictions are excluded from scoring; a fold with no
            # valid prediction contributes nothing.
            valid_mask = ~pd.isna(predict)
            if not valid_mask.any():
                continue

            metrics = metric_fn(np.array(predict)[valid_mask], y_test[valid_mask])
            fold_scores.append(metrics[score_key])

        if not fold_scores:
            continue

        mean_score = np.mean(fold_scores)
        improved = mean_score < best_score if minimize else mean_score > best_score
        if improved:
            best_score = mean_score
            best_k = k

    return best_k

## Метрики

In [22]:
def metrics_reg(metrics):
    """Print the regression metrics stored in ``metrics``, one per line.

    Reads the keys ``rmse``, ``mae``, ``r2`` and ``mae_median`` (in that
    order) and prints each value after a tab-indented label.
    """
    labelled_keys = (
        ('RMSE', 'rmse'),
        ('MAE', 'mae'),
        ('R2', 'r2'),
        ('MAE (med.)', 'mae_median'),
    )
    for label, key in labelled_keys:
        print(f'\t {label}: {metrics[key]}')

def metrics_class(metrics):
    """Print the classification metrics stored in ``metrics``, one per line.

    Reads the keys ``accuracy``, ``precision``, ``recall`` and ``f1`` (in that
    order) and prints each value after a tab-indented label.
    """
    labelled_keys = (
        ('Accuracy', 'accuracy'),
        ('Precision', 'precision'),
        ('Recall', 'recall'),
        ('F1', 'f1'),
    )
    for label, key in labelled_keys:
        print(f'\t {label}: {metrics[key]}')

In [23]:
# Evaluate the kNN baseline on every target of the dataset.
# 'agg' and 'switch_type' are treated as classification targets (loaded with
# is_scaled=False); every other target is treated as regression.
for item in dataset.targets:
    print(item)

    if item in ('agg', 'switch_type'):
        x_train, y_train = dataset.get_train(item, is_scaled=False)
        x_test, y_test = dataset.get_test(item, is_scaled=False)

        # k is tuned on the training split only (4 folds, k from 1 to 30).
        k = cross_validate(x_train, y_train, 4, 30, 'class')
        print(f'k: {k}')
        y_pred = knn(x_train, y_train, x_test, k, 'class')

        class_metrics = get_classification_metrics(y_pred, y_test)
        metrics_class(class_metrics)
        print('')
        continue

    x_train, y_train = dataset.get_train(item)
    x_test, y_test = dataset.get_test(item)

    # k is tuned on the training split only (4 folds, k from 1 to 30).
    k = cross_validate(x_train, y_train, 4, 30, 'reg')
    print(f'k: {k}')
    y_pred = knn(x_train, y_train, x_test, k, 'reg')

    y_test_rescaled = dataset.rescale_targets(y_test, item)
    y_pred_rescaled = dataset.rescale_targets(y_pred, item)

    # Report the metrics twice: on the values as loaded and after
    # dataset.rescale_targets has been applied to both sides.
    for caption, predicted, actual in (
        ('Scaled:', y_pred, y_test),
        ('Rescaled:', y_pred_rescaled, y_test_rescaled),
    ):
        print(caption)
        metrics_reg(get_regression_metrics(predicted, actual))
    print('')

brightness


 23%|██▎       | 7/30 [00:06<00:21,  1.08it/s]


KeyboardInterrupt: 

In [26]:
# Evaluate the kNN baseline on every target of the dataset and, for regression
# targets, add bootstrap confidence intervals on the rescaled values.
# 'agg' and 'switch_type' are treated as classification targets (loaded with
# is_scaled=False); every other target is treated as regression.
for item in dataset.targets:
    print(item)

    if item in ('agg', 'switch_type'):
        x_train, y_train = dataset.get_train(item, is_scaled=False)
        x_test, y_test = dataset.get_test(item, is_scaled=False)

        # k is tuned on the training split only (4 folds, k from 1 to 30).
        k = cross_validate(x_train, y_train, 4, 30, 'class')
        print(f'k: {k}')
        y_pred = knn(x_train, y_train, x_test, k, 'class')

        class_metrics = get_classification_metrics(y_pred, y_test)
        metrics_class(class_metrics)
        continue

    x_train, y_train = dataset.get_train(item)
    x_test, y_test = dataset.get_test(item)

    # k is tuned on the training split only (4 folds, k from 1 to 30).
    k = cross_validate(x_train, y_train, 4, 30, 'reg')
    print(f'k: {k}')
    y_pred = knn(x_train, y_train, x_test, k, 'reg')

    y_test_rescaled = dataset.rescale_targets(y_test, item)
    y_pred_rescaled = dataset.rescale_targets(y_pred, item)

    # Report the metrics twice: on the values as loaded and after
    # dataset.rescale_targets has been applied to both sides.
    for caption, predicted, actual in (
        ('Scaled:', y_pred, y_test),
        ('Rescaled:', y_pred_rescaled, y_test_rescaled),
    ):
        print(caption)
        metrics_reg(get_regression_metrics(predicted, actual))
    print('')

    print('')
    metrics_ci = bootstrap_metric_ci(
        y_pred_rescaled, y_test_rescaled, get_regression_metrics,
        n_bootstrap=1000, alpha=0.05, random_state=42
    )

    print("\nMetrics with 95% confidence intervals:")
    # Each line shows the midpoint of elements [1] and [2] of the bootstrap
    # result, plus/minus half their distance.
    # NOTE(review): presumably [1]/[2] are the lower/upper interval bounds, so
    # the printed centre is the interval midpoint rather than the point
    # estimate shown above -- confirm against bootstrap_metric_ci.
    for label, key in (
        ('RMSE', 'rmse'),
        ('MAE', 'mae'),
        ('R2', 'r2'),
        ('MAE Median', 'mae_median'),
    ):
        lower = metrics_ci[key][1]
        upper = metrics_ci[key][2]
        center = (lower + upper) / 2
        print(f"{label}: {center:.2f} ± {center - lower:.2f}")
        

brightness


100%|██████████| 30/30 [00:23<00:00,  1.29it/s]


k: 5
Scaled:
	 RMSE: 0.7412173589207028
	 MAE: 0.507162649397322
	 R2: 0.48928575551555786
	 MAE (med.): 0.3323219950617097
Rescaled:
	 RMSE: 22.855105559444116
	 MAE: 15.638133333333334
	 R2: 0.48928575551555775
	 MAE (med.): 10.247



Metrics with 95% confidence intervals:
RMSE: 23.17 ± 5.63
MAE: 15.68 ± 2.73
R2: 0.47 ± 0.21
MAE Median: 11.09 ± 3.01
em_max


100%|██████████| 30/30 [00:47<00:00,  1.58s/it]


k: 1
Scaled:
	 RMSE: 0.5277660222068801
	 MAE: 0.24677351033531916
	 R2: 0.6888330674145302
	 MAE (med.): 0.03112626603172086
Rescaled:
	 RMSE: 33.91129675940145
	 MAE: 15.8562874251497
	 R2: 0.6888330674145303
	 MAE (med.): 2.0



Metrics with 95% confidence intervals:
RMSE: 33.36 ± 8.56
MAE: 15.95 ± 4.39
R2: 0.67 ± 0.16
MAE Median: 3.50 ± 1.50
ex_max


100%|██████████| 30/30 [01:05<00:00,  2.20s/it]


k: 3
Scaled:
	 RMSE: 0.560141054302693
	 MAE: 0.3526708467947944
	 R2: 0.6677765093830835
	 MAE (med.): 0.1795792685665914
Rescaled:
	 RMSE: 37.43022624652124
	 MAE: 23.566473988439306
	 R2: 0.6677765093830835
	 MAE (med.): 12.0



Metrics with 95% confidence intervals:
RMSE: 37.25 ± 6.97
MAE: 23.88 ± 4.24
R2: 0.65 ± 0.13
MAE Median: 13.33 ± 4.00
ext_coeff


100%|██████████| 30/30 [00:24<00:00,  1.21it/s]


k: 2
Scaled:
	 RMSE: 0.8337854146769339
	 MAE: 0.5829458364789102
	 R2: 0.4346668268309095
	 MAE (med.): 0.35492514186645163
Rescaled:
	 RMSE: 31479.10147451571
	 MAE: 22008.79365079365
	 R2: 0.4346668268309094
	 MAE (med.): 13400.0



Metrics with 95% confidence intervals:
RMSE: 31592.52 ± 6024.91
MAE: 22285.22 ± 3953.62
R2: 0.42 ± 0.15
MAE Median: 15725.00 ± 3725.00
lifetime


100%|██████████| 30/30 [00:02<00:00, 11.58it/s]


k: 4
Scaled:
	 RMSE: 1.5785460811112475
	 MAE: 0.6631146885927949
	 R2: 0.3601461220166019
	 MAE (med.): 0.245165641909308
Rescaled:
	 RMSE: 1.8350272477541036
	 MAE: 0.7708571428571429
	 R2: 0.3601461220166018
	 MAE (med.): 0.28500000000000014



Metrics with 95% confidence intervals:
RMSE: 1.77 ± 1.25
MAE: 0.89 ± 0.52
R2: 0.51 ± 0.30
MAE Median: 0.34 ± 0.16
maturation


100%|██████████| 30/30 [00:01<00:00, 15.20it/s]


k: 20
Scaled:
	 RMSE: 0.47439700100830845
	 MAE: 0.3670827142930148
	 R2: 0.07842265056394604
	 MAE (med.): 0.3182861769729677
Rescaled:
	 RMSE: 100.49515218407302
	 MAE: 77.76194444444444
	 R2: 0.07842265056394582
	 MAE (med.): 67.425



Metrics with 95% confidence intervals:
RMSE: 99.94 ± 24.16
MAE: 78.72 ± 20.62
R2: -1.21 ± 1.56
MAE Median: 62.17 ± 24.81
pka


100%|██████████| 30/30 [00:10<00:00,  2.96it/s]


k: 12
Scaled:
	 RMSE: 1.230376790021117
	 MAE: 0.8062922685857539
	 R2: 0.10611251314067638
	 MAE (med.): 0.48329206908236994
Rescaled:
	 RMSE: 1.3842920717574583
	 MAE: 0.9071562499999999
	 R2: 0.10611251314067627
	 MAE (med.): 0.5437499999999997



Metrics with 95% confidence intervals:
RMSE: 1.36 ± 0.35
MAE: 0.91 ± 0.23
R2: 0.09 ± 0.13
MAE Median: 0.52 ± 0.15
stokes_shift


 90%|█████████ | 27/30 [00:42<00:04,  1.57s/it]


KeyboardInterrupt: 