In [None]:
from src.init_params import N_ELECTRODES, MATERIAL_ORDER, batch_size, N_PEAKS
from src.represent_cv_sample import DataSample
from src.represent_tea_sample import CompoundDataSample, DataEncodeMode
from src.represent_ms_sample import MSDataSample
from src.preprocess_data import split_dataset, make_data, aug
from src.metrics import calc_metrics
from src.FCN import FCN_Model, train, plot_history, evaluate

In [None]:
from collections import defaultdict, namedtuple
from typing import Tuple, List, Dict, Optional
from pathlib import Path
import numpy as np

## Collect CV data

In [None]:
train_datadir = Path("CV")
# Group the per-electrode CV measurements by compound key. After the loop every key
# must hold exactly N_ELECTRODES DataSample objects (checked below).
# NOTE(review): key type presumably a tuple of ints (from DataSample.key) — confirm.
data: Dict[Tuple[int, ...], List[DataSample]] = defaultdict(list)

# Sort the paths so sample (and therefore compound) order does not depend on the
# filesystem; this keeps order-dependent downstream steps reproducible.
for sample_file in sorted(train_datadir.glob('*.txt')):
    ds = DataSample(sample_file)
    # -1 appears to mark an unknown target / material; such samples are skipped.
    if ds.target == -1 or ds.metadata.mat == -1:
        continue
    data[ds.key].append(ds)

# Fail loudly on an empty directory instead of letting all() pass vacuously.
assert data, f"no usable CV samples found in {train_datadir}"
assert all(len(electrodes) == N_ELECTRODES for electrodes in data.values()), \
    "every compound must have exactly N_ELECTRODES electrode samples"

## Represent the characteristic CV fingerprint

In [None]:
# Wrap each compound's list of electrode samples in a CompoundDataSample.
compound_data = list(map(CompoundDataSample, data.values()))
n_compounds = len(compound_data)
print(f'n = {n_compounds}')

In [None]:
# Inspect the NN encoding of the first compound: second CV cycle only,
# flattened vector input and a non-one-hot target.
first_compound_samples = list(data.values())[0]
cds = CompoundDataSample(first_compound_samples)
nn = cds.to_nn_sample(
    mode=DataEncodeMode.SECOND_CYCLE_ONLY,
    onehot_target=False,
    vector=True,
)
print(nn['cycles'].shape, nn['target'])

In [None]:
# Visualise one example compound fingerprint. Clamp the hard-coded index so the
# cell still runs on datasets with fewer than 18 compounds.
example_idx = min(17, len(compound_data) - 1)
compound_data[example_idx].represent()

In [None]:
# Rebuild freshly constructed compound samples before splitting.
# NOTE(review): this duplicates the earlier `compound_data` cell. It is presumably
# kept to discard state left by the exploratory cells above — confirm, or drop it.
compound_data = list(map(CompoundDataSample, data.values()))
# NOTE(review): 0.4 is presumably the validation fraction — confirm in split_dataset.
data_train, data_val = split_dataset(compound_data, 0.4)

## Train XGBoost

In [None]:
import xgboost as xgb

In [None]:
# Flat feature vectors with non-one-hot labels for XGBoost. No augmentation is used,
# and to_tf=False is presumably meant to give plain arrays rather than TF datasets.
(X_train, y_train), (X_val, y_val) = make_data(
    data_train,
    data_val,
    batch_size=batch_size,
    mode=DataEncodeMode.ALL_CYCLES,
    augmentations=[],
    vector=True,
    onehot_target=False,
    to_tf=False,
)

In [None]:
# Multiclass gradient-boosted trees trained on all CV cycles.
xgb_params = dict(objective="multi:softmax", max_depth=20, n_estimators=100)
xgb_cl = xgb.XGBClassifier(**xgb_params)
xgb_cl.fit(X_train, y_train)

In [None]:
# Predict class labels on the held-out split and score them.
# NOTE(review): confirm calc_metrics' argument order; it is called as (predictions, targets).
preds = xgb_cl.predict(X_val)
calc_metrics(preds, y_val)

## Train FCN on CV data

In [None]:
# Weights file for the CV FCN. Delete any files left by a previous run (the path itself
# plus anything with that prefix) so a fresh run cannot reuse them.
# This replaces the Windows-only `!del FCN.weights.h5*` with a portable pathlib version.
# NOTE(review): assumes `train` reads/writes this path — confirm.
all_cycles = Path('FCN.weights.h5')
for stale_file in all_cycles.parent.glob(all_cycles.name + '*'):
    if stale_file.is_file():
        stale_file.unlink()

In [None]:
# CV encoding for the FCN; DataEncodeMode.FIRST_CYCLE_ONLY or SECOND_CYCLE_ONLY also work.
mode = DataEncodeMode.ALL_CYCLES
# Augmented, one-hot-target datasets (to_tf=True, presumably tf.data pipelines).
train_dataset, val_dataset = make_data(
    data_train,
    data_val,
    batch_size=batch_size,
    mode=mode,
    augmentations=[aug],
    vector=False,
    onehot_target=True,
    to_tf=True,
)

In [None]:
# FCN for the CV encoding selected by `mode`. The three tuples presumably give
# per-stage filter counts, kernel sizes and pooling windows.
model = FCN_Model(
    mode=mode,
    n_filters_sequence=(8, 16, 32),
    kernel_size_sequence=(8, 4, 2),
    pool_window_size_sequence=(4, 4, 2),
)

In [None]:
# Train the CV FCN, writing weights to `all_cycles`.
# NOTE(review): the 5th argument is presumably the number of epochs — confirm in train().
cv_n_epochs = 200
history = train(all_cycles, model, train_dataset, val_dataset, cv_n_epochs)

In [None]:
# Plot the training history returned by train() for the CV FCN.
plot_history(history)

In [None]:
# Evaluate the CV FCN on the CV validation split.
evaluate(model, val_dataset)

## Collect MS data

In [None]:
# Load every MS sample. The paths are sorted so sample order, and hence the split, is
# reproducible. A separate name is used so `train_datadir` (the CV directory) is not
# overwritten with a one-shot generator.
ms_datadir = Path("MS")
ms_data: List[MSDataSample] = [
    MSDataSample(sample_file.as_posix())
    for sample_file in sorted(ms_datadir.glob('*.mzXML'))
]
assert ms_data, f"no .mzXML files found in {ms_datadir}"

In [None]:
# Optional: uncomment to visualise the first MS sample.
#ms_data[0].represent()

In [None]:
# Split the MS samples with the same ratio as the CV data.
# NOTE(review): 0.4 is presumably the validation fraction — confirm in split_dataset.
ms_data_train, ms_data_val = split_dataset(ms_data, 0.4)

In [None]:
# Show train / validation sizes side by side.
print(*map(len, (ms_data_train, ms_data_val)))

## Train FCN on MS data

In [None]:
mode = DataEncodeMode.MS
ms_train_dataset, ms_val_dataset = make_data(ms_data_train, ms_data_val,
                                             batch_size=batch_size,
                                             mode=mode,
                                             augmentations=[], 
                                             vector=False,
                                             onehot_target=True,
                                             to_tf=True,
                                             do_preprocess=False)

In [None]:
# FCN for MS input sequences of length N_PEAKS. The tuples presumably give per-stage
# filter counts, kernel sizes and pooling windows.
model = FCN_Model(
    mode=DataEncodeMode.MS,
    n_filters_sequence=(8, 16, 32),
    kernel_size_sequence=(4, 2, 3),
    pool_window_size_sequence=(2, 2, 2),
    input_sequence_length=N_PEAKS,
)

In [None]:
# Weights file for the MS FCN. As in the CV section, delete files left by a previous
# run so Restart & Run All starts from the same state.
# NOTE(review): assumes `train` reads/writes this path — confirm.
ms_path = Path('ms.weights.h5')
for stale_file in ms_path.parent.glob(ms_path.name + '*'):
    if stale_file.is_file():
        stale_file.unlink()

In [None]:
# Train the MS FCN, writing weights to `ms_path`.
# NOTE(review): the 5th argument is presumably the number of epochs — confirm in train().
ms_n_epochs = 100
history = train(ms_path, model, ms_train_dataset, ms_val_dataset, ms_n_epochs)

In [None]:
# Plot the training history returned by train() for the MS FCN.
plot_history(history)

In [None]:
# Evaluate the MS FCN on the MS validation split.
evaluate(model, ms_val_dataset)