In [None]:
from bread.vis import *
from bread.data import Features, SegmentationFile, Microscopy
from bread.algo.lineage import LineageGuesser, LineageGuesserNN, LineageGuesserML, LineageGuesserNearestCell, accuracy
from bread.data import Lineage

import pandas as pd
import numpy as np
import pathlib
import os
from sklearn.preprocessing import MinMaxScaler

## define functions to extract features in batches


In [None]:
# required functions for feature extraction
def extract_features(segmentation, guesser):
    candidate_features = pd.DataFrame()
    bud_ids, time_ids = segmentation.find_buds(
    ).bud_ids, segmentation.find_buds().time_ids
    f_list = []
    for i, (bud_id, time_id) in enumerate(zip(bud_ids, time_ids)):
        frame_range = guesser.segmentation.request_frame_range(
            time_id, time_id + guesser.num_frames)
        num_frames_available = guesser.num_frames
        if len(frame_range) < 2:
            # not enough frames
            continue
        if len(frame_range) < guesser.num_frames:
            num_frames_available = len(frame_range)

        # check the bud still exists !
        for time_id_ in frame_range:
            if bud_id not in guesser.segmentation.cell_ids(time_id_):
                # bud has disappeared
                continue
        selected_times = [i for i in range(
            time_id, time_id + num_frames_available)]
        try:
            candidate_parents = guesser._candidate_parents(
                time_id, nearest_neighbours_of=bud_id)
            for c_id, candidate in enumerate(candidate_parents):
                features, f_list = guesser._get_features(
                    bud_id, candidate, time_id, selected_times)
                new_row = {'bud_id': bud_id,
                           'candid_id': candidate, 'time_id': time_id}
                new_row.update(features)
                new_df = pd.DataFrame(new_row, index=[0])
                candidate_features = pd.concat([candidate_features, new_df])
        except Exception as e:
            print("Error for bud {} at time {} with candidate: {}".format(
                bud_id, time_id, e))
    return candidate_features, f_list

# turn features in to a matrix format


def get_custom_matrix_features(features_all, lineage_gt, feature_list, filling_features=[-100 for i in range(100)]):
    # Generate np array of feature sets for each bud
    lineage = lineage_gt.copy()
    # remove the rows with parent_GT = -1 (no parent) and the rows with candid_GT = -2 (disappearing buds)
    lineage = lineage.loc[lineage.parent_GT != -1]
    lineage = lineage.loc[lineage.parent_GT != -2]
    lineage = lineage.loc[lineage.parent_GT != -3]
    candidate_features = features_all.copy()
    features_list = []
    filling_features = filling_features[:len(feature_list)]
    parent_index_list = []
    candidate_list = []
    for bud, colony in lineage[['bud_id', 'colony']].values:
        bud_data = candidate_features.loc[(candidate_features['bud_id'] == bud) & (
            candidate_features['colony'] == colony)]
        candidates = bud_data['candid_id'].to_numpy()
        features = bud_data[feature_list].to_numpy()
        if (len(candidates) == 0):
            if (len(bud_data) == 0):
                # bud only appears in the last frame
                lineage.drop(lineage.loc[(lineage['bud_id'] == bud) & (lineage.colony == colony)].index,
                             inplace=True)
            else:
                print('no candidates', bud, colony, candidates)
            continue
        elif candidates.shape[0] < 4:
            n_rows = 4 - candidates.shape[0]
            if candidates.shape[0] == 1:
                # only one candidate
                # fill with -100
                candidates = np.concatenate(
                    (candidates, np.array([-3 for i in range(n_rows)])), axis=0)
                features = np.concatenate(
                    (features, np.full((n_rows, features.shape[1]), -100)), axis=0)
            else:  # more than one candidate
                parent = int(lineage.loc[(lineage['bud_id'] == bud) & (
                    lineage['colony'] == colony), 'parent_GT'].iloc[0])
                not_parent_data = candidate_features.loc[(candidate_features['bud_id'] == bud) & (
                    candidate_features['colony'] == colony) & (candidate_features['candid_id'] != parent)]
                filling_features = not_parent_data[feature_list].to_numpy()[0]
                features = np.concatenate(
                    (features, [filling_features for i in range(n_rows)]), axis=0)
                candidates = np.concatenate(
                    (candidates, [-3 for i in range(n_rows)]), axis=0)
        elif features.shape[0] > 4:
            sorted_indices = np.argsort(features[:, 0])
            print('more than 4 candidates', bud, colony, candidates, sorted_indices, int(lineage.loc[(lineage['bud_id'] == bud) & (
                lineage['colony'] == colony), 'parent_GT']))
            # slice the top 4 rows
            features = features[sorted_indices[:4]]
            candidates = candidates[sorted_indices[:4]]

        parent = int(lineage.loc[(lineage['bud_id'] == bud) & (
            lineage['colony'] == colony), 'parent_GT'].iloc[0])
        if (parent not in candidates):
            print('parent not in candidates', bud, colony, candidates, parent)
            lineage.drop(lineage.loc[(lineage['bud_id'] == bud) & (
                lineage.colony == colony)].index, inplace=True)
            continue
        else:
            parent_index = np.where(candidates == parent)[0][0]
        parent_index_list.append(parent_index)
        features_list.append(features)
        candidate_list.append(candidates)
    lineage['features'] = features_list
    lineage['candidates'] = candidate_list
    lineage['parent_index_in_candidates'] = parent_index_list
    return lineage

## define functions for training neural network


In [None]:
import itertools
import wandb
import torch.optim as optim
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim.lr_scheduler as lr_scheduler
import torch
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score
import os


class BudDataset(Dataset):
    def __init__(self, data, augment=True):
        X = data['features'].to_numpy()
        labels = data['parent_index_in_candidates'].to_numpy()
        if (augment):
            X, labels = generate_all_permutations(X, labels)
        X = flatten_3d_array(X)
        self.data = torch.tensor(X, dtype=torch.float32)
        self.labels = torch.zeros(len(labels), 4)  # initialize labels as zeros
        for i, label in enumerate(labels):
            if label != -1:
                # set the position of the correct parent to 1
                self.labels[i][label] = 1.0

    def __getitem__(self, index):
        data = self.data[index]
        return self.data[index], self.labels[index]

    def __len__(self):
        return len(self.labels)

# lineage NN with mask


class LineageNN(nn.Module):
    def __init__(self, layers):
        super(LineageNN, self).__init__()
        self.layers = nn.ModuleList()  # create an empty nn.ModuleList
        for i in range(len(layers)-1):
            self.layers.append(nn.Linear(layers[i], layers[i+1]))

    def forward(self, x):
        mask = (x != -100).float()
        x = x * mask  # apply the mask to zero out invalid values
        for layer in self.layers[:-1]:
            x = F.relu(layer(x))
        x = self.layers[-1](x)
        return x


# functions to deal with data


def flatten_3d_array(arr):
    """
    Flattens a 3-dimensional numpy array while keeping the first dimension unchanged
    """
    if arr.ndim == 1:
        arr2 = np.array([arr[i] for i in range(len(arr))])
        arr = np.stack(arr2)
    shape = arr.shape
    new_shape = (shape[0], np.prod(shape[1:]))
    return arr.reshape(new_shape)


def permute_matrix(matrix, row_id):
    """
    Generates all possible permutations of a matrix  rows
    it takes row_index as input, which is a one-hot encoded label for the classification 
    and outputs the one-hot encoded labels of the permutated matrices
    """
    # Get the number of rows in the matrix
    rows = len(matrix)

    # Get all possible permutations of the row indices
    permutations = list(itertools.permutations(range(rows)))

    # Use list comprehension to create a list of all permuted matrices
    permuted_matrices = [np.array([matrix[i] for i in permutation])
                         for permutation in permutations]

    # Use list comprehension to find the index of the specified row in each permuted matrix
    row_indices = [list(permutation).index(row_id)
                   for permutation in permutations]

    return permuted_matrices, row_indices


def generate_all_permutations(data, labels):
    """
    Generates all posible permutations for matrices in data 
    and the corresponding labels
    Labels should be integers
    """

    permuted_matrices_list = []
    permuted_labels_list = []

    for matrix, label in zip(data, labels):
        permuted_matrices, permuted_labels = permute_matrix(matrix, label)
        permuted_matrices_list.extend(permuted_matrices)
        permuted_labels_list.extend(permuted_labels)

    return np.array(permuted_matrices_list), np.array(permuted_labels_list)

# functions to train and test the model


def train_nn(train_df, eval_df, save_path='bst_nn.pth', config={}, seed=42):
    # Initialize wandb
    use_wandb = config['use_wandb']
    if (use_wandb):
        wandb.init(project="lineage_tracing", group='with_mask', config=config)

    # initialize neural network
    # manualy set the seed to enable reproducibility
    torch.manual_seed(seed)
    net = LineageNN(layers=config['layers'])

    # define your loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=config['lr'])
    scheduler = lr_scheduler.LinearLR(
        optimizer, start_factor=1, end_factor=0.01, total_iters=200)

    train_bud_dataset = BudDataset(train_df, augment=config['augment'])
    train_bud_dataloader = DataLoader(
        train_bud_dataset, batch_size=config['batch_size'], shuffle=True)
    eval_bud_dataset = BudDataset(eval_df, augment=False)
    eval_bud_dataloader = DataLoader(
        eval_bud_dataset, batch_size=config['batch_size'], shuffle=True)

    # train your neural network
    patient = 0
    best_accuracy = 0.0
    for epoch in range(config['epoch_n']):
        running_loss = 0.0
        # training loop
        predicted_all = []
        labels_all = []
        net.train()
        for i, data in enumerate(train_bud_dataloader, 0):
            inputs, labels = data
            optimizer.zero_grad()
            # forward pass
            outputs = net(inputs)
            # calculate the loss
            loss = criterion(outputs, labels)
            _, labels = torch.max(labels.data, 1)
            _, predicted = torch.max(outputs.data, 1)
            predicted_all.extend(predicted)
            labels_all.extend(labels)
            # backward pass and optimize
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            scheduler.step()
        train_accuracy = accuracy_score(labels_all, predicted_all)
        # eval loop
        predicted_all = []
        labels_all = []
        net.eval()
        for i, data in enumerate(eval_bud_dataloader, 0):
            inputs, labels = data
            with torch.no_grad():
                outputs = net(inputs)
            _, labels = torch.max(labels.data, 1)
            _, predicted = torch.max(outputs.data, 1)
            predicted_all.extend(predicted)
            labels_all.extend(labels)
        eval_accuracy = accuracy_score(labels_all, predicted_all)
        if (eval_accuracy > best_accuracy):
            best_accuracy = eval_accuracy
            best_model = net
            torch.save(net.state_dict(), save_path)
            patient = 0
        else:
            patient += 1
        if (patient > config['patience']):
            print('early stopping at ', epoch, 'LR: ',
                  optimizer.param_groups[0]['lr'])
            break
        if (use_wandb):
            wandb_log = {'epoch': epoch, 'patience': patient, 'eval_accuracy': eval_accuracy,
                         'train_accuracy': train_accuracy, 'best_accuracy': best_accuracy, 'lr': optimizer.param_groups[0]['lr']}
            wandb.log(wandb_log)
    return best_model, best_accuracy


def test_nn(model, test_df):
    bud_dataset = BudDataset(test_df, augment=False)
    bud_dataloader = DataLoader(
        bud_dataset, batch_size=len(test_df), shuffle=False)
    for i, data in enumerate(bud_dataloader, 0):
        if (i > 0):
            print('more than one batch')
        inputs, labels = data
        with torch.no_grad():
            outputs = model(inputs)
        _, labels = torch.max(labels.data, 1)
        _, predicted = torch.max(outputs.data, 1)
        accuracy = accuracy_score(predicted, labels)
    print('test accuracy', accuracy)
    test_df['predicted'] = predicted
    return test_df, accuracy


def cv_nn(df, config={}, seed=42):
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    X = df['features'].to_numpy()
    y = df['parent_index_in_candidates'].to_numpy()
    # repeat df because it necessary for the function to have two arguments
    skf.get_n_splits(X, y)
    accuracies = []
    models = []
    i = 0
    # repeat df because it necessary for the function to have two arguments
    for train_index, test_index in skf.split(X, y):
        i = i+1
        config['cv_number'] = i
        print('config: ', config)
        train_df = df.iloc[train_index]
        test_df = df.iloc[test_index]
        net, accuracy = train_nn(train_df, test_df, config=config, seed=seed)
        accuracies.append(accuracy)
        models.append(net)

    print('accuracy: ', np.mean(accuracies), '+/-', np.std(accuracies))
    return models, accuracies

## A pipeline to train a neural network based on given arguments


In [None]:
# constant values and arguments
selected_keys = ['dist_0', 'dist_std', 'poly_fit_budcm_budpt', 'poly_fit_expansion_vector', 'position_bud_std', 'position_bud_max', 'position_bud_min', 'position_bud_last', 'position_bud_first',
                 'orientation_bud_std', 'orientation_bud_max', 'orientation_bud_min', 'orientation_bud_last', 'orientation_bud_first', 'orientation_bud_last_minus_first', 'plyfit_orientation_bud']
args = {'fov': 0, 'bud_distance_max': 12, 'num_frames': 8,
        'num_frames_refractory': 0, 'normalized': True, 'selected_keys': selected_keys}
# config for trianing neural network
config = {'owner': 'train_pipeline', 'epoch_n': 200, 'patience': 10, 'lr': 0.01, 'batch_size': 256, 'layers': [
    64, 64, 4], 'augment': True, 'features': selected_keys, 'dist_threshold': args['bud_distance_max'], 'num_frames': args['num_frames'], 'normalized': args['normalized'], 'use_wandb': False}
data_path = '../../../data_edited/'

res = {}

In [None]:
# load ground truth data
lin_gt_all = pd.DataFrame()
for colony in range(0, 6):
    lineage_gt_path = os.path.join(
        data_path, f'colony00{colony}_lineage_gt.csv')
    lin_truth = pd.read_csv(lineage_gt_path)
    lin_truth['colony'] = colony
    lin_gt_all = pd.concat([lin_gt_all, lin_truth])
lin_gt_all.rename(columns={'# parent_id': 'parent_GT'}, inplace=True)
lin_gt_all.reset_index(inplace=True, drop=True)

In [None]:
# extracting features and normalizing them if required
features_all = pd.DataFrame()
for colony in range(0, 6):
    print(f'Processing colony {colony}')
    segmentation_file = os.path.join(
        data_path, f'colony00{colony}_segmentation.h5')
    segmentation = SegmentationFile.from_h5(
        segmentation_file).get_segmentation('FOV'+str(args['fov']))
    guesser = LineageGuesserNN(
        segmentation=segmentation,
        nn_threshold=args['bud_distance_max'],
        num_frames_refractory=args['num_frames_refractory'],
        num_frames=args['num_frames'],
    )
    features, f_list = extract_features(segmentation, guesser)
    features['colony'] = colony
    features_all = pd.concat([features_all, features])
X = features_all[selected_keys]
scaler = MinMaxScaler()
# save this scaler to use it later for normalization
X_norm = scaler.fit_transform(X)
if (args['normalized']):
    # normalize features
    features_all_normalized = features_all.copy()
    features_all_normalized[selected_keys] = X_norm
    features_all = features_all_normalized

Processing colony 0
No model was provided




Error for bud 1 at time 0 with candidate: No candidate parents have been found for in frame #0.
Error for bud 2 at time 0 with candidate: No candidate parents have been found for in frame #0.




Error for bud 3 at time 0 with candidate: No candidate parents have been found for in frame #0.
Error for bud 4 at time 0 with candidate: No candidate parents have been found for in frame #0.




Error for bud 44 at time 74 with candidate: No candidate parents have been found for in frame #74.
Error for bud 305 at time 140 with candidate: Unable to find cell_id=94 at time_id=147 in the segmentation.
Error for bud 306 at time 140 with candidate: Unable to find cell_id=36 at time_id=147 in the segmentation.
Error for bud 307 at time 140 with candidate: Unable to find cell_id=74 at time_id=147 in the segmentation.
Error for bud 308 at time 140 with candidate: Unable to find cell_id=160 at time_id=147 in the segmentation.
Error for bud 309 at time 140 with candidate: Unable to find cell_id=42 at time_id=147 in the segmentation.
Error for bud 310 at time 140 with candidate: Unable to find cell_id=22 at time_id=147 in the segmentation.
Error for bud 311 at time 140 with candidate: Unable to find cell_id=109 at time_id=147 in the segmentation.
Error for bud 312 at time 140 with candidate: Unable to find cell_id=1 at time_id=147 in the segmentation.
Error for bud 313 at time 140 with c



Error for bud 334 at time 142 with candidate: Unable to find cell_id=5 at time_id=147 in the segmentation.
Error for bud 335 at time 142 with candidate: Unable to find cell_id=161 at time_id=147 in the segmentation.
Error for bud 336 at time 143 with candidate: Unable to find cell_id=106 at time_id=147 in the segmentation.
Error for bud 337 at time 143 with candidate: Unable to find cell_id=333 at time_id=147 in the segmentation.
Error for bud 338 at time 143 with candidate: Unable to find cell_id=25 at time_id=147 in the segmentation.
Error for bud 339 at time 143 with candidate: Unable to find cell_id=160 at time_id=147 in the segmentation.
Error for bud 340 at time 143 with candidate: Unable to find cell_id=19 at time_id=147 in the segmentation.
Error for bud 341 at time 143 with candidate: No candidate parents have been found for in frame #143.




Error for bud 342 at time 143 with candidate: Unable to find cell_id=22 at time_id=147 in the segmentation.
Error for bud 343 at time 143 with candidate: Unable to find cell_id=48 at time_id=147 in the segmentation.
Error for bud 344 at time 143 with candidate: Unable to find cell_id=31 at time_id=147 in the segmentation.
Error for bud 345 at time 143 with candidate: Unable to find cell_id=17 at time_id=147 in the segmentation.
Error for bud 346 at time 143 with candidate: Unable to find cell_id=79 at time_id=147 in the segmentation.
Error for bud 347 at time 144 with candidate: Unable to find cell_id=54 at time_id=147 in the segmentation.
Error for bud 348 at time 144 with candidate: Unable to find cell_id=95 at time_id=147 in the segmentation.
Error for bud 349 at time 144 with candidate: Unable to find cell_id=52 at time_id=147 in the segmentation.
Error for bud 350 at time 144 with candidate: Unable to find cell_id=89 at time_id=147 in the segmentation.
Error for bud 351 at time 14



Error for bud 361 at time 146 with candidate: Unable to find cell_id=341 at time_id=147 in the segmentation.
Error for bud 362 at time 146 with candidate: Unable to find cell_id=43 at time_id=147 in the segmentation.
Error for bud 363 at time 146 with candidate: Unable to find cell_id=86 at time_id=147 in the segmentation.
Error for bud 364 at time 146 with candidate: Unable to find cell_id=88 at time_id=147 in the segmentation.
Error for bud 365 at time 146 with candidate: Unable to find cell_id=208 at time_id=147 in the segmentation.
Error for bud 366 at time 146 with candidate: Unable to find cell_id=26 at time_id=147 in the segmentation.
Error for bud 367 at time 146 with candidate: Unable to find cell_id=91 at time_id=147 in the segmentation.
Error for bud 368 at time 146 with candidate: Unable to find cell_id=34 at time_id=147 in the segmentation.
Error for bud 369 at time 146 with candidate: Unable to find cell_id=23 at time_id=147 in the segmentation.
Error for bud 370 at time 



Error for bud 77 at time 167 with candidate: No candidate parents have been found for in frame #167.
Processing colony 4
No model was provided
Error for bud 40 at time 123 with candidate: Unable to find cell_id=40 at time_id=124 in the segmentation.
Error for bud 50 at time 136 with candidate: Unable to find cell_id=34 at time_id=141 in the segmentation.
Error for bud 55 at time 137 with candidate: Unable to find cell_id=34 at time_id=141 in the segmentation.
Error for bud 82 at time 154 with candidate: Unable to find cell_id=82 at time_id=161 in the segmentation.
Error for bud 87 at time 157 with candidate: Unable to find cell_id=60 at time_id=161 in the segmentation.
Error for bud 105 at time 169 with candidate: Unable to find cell_id=67 at time_id=170 in the segmentation.




Error for bud 114 at time 173 with candidate: No candidate parents have been found for in frame #173.
Processing colony 5
No model was provided


In [None]:
# get matrix features
matrix_features = get_custom_matrix_features(
    features_all, lin_gt_all, selected_keys)
matrix_features.reset_index(inplace=True, drop=True)

  print('more than 4 candidates', bud, colony, candidates, sorted_indices, int(lineage.loc[(lineage['bud_id'] == bud) & (
  print('more than 4 candidates', bud, colony, candidates, sorted_indices, int(lineage.loc[(lineage['bud_id'] == bud) & (
  print('more than 4 candidates', bud, colony, candidates, sorted_indices, int(lineage.loc[(lineage['bud_id'] == bud) & (
  print('more than 4 candidates', bud, colony, candidates, sorted_indices, int(lineage.loc[(lineage['bud_id'] == bud) & (
  print('more than 4 candidates', bud, colony, candidates, sorted_indices, int(lineage.loc[(lineage['bud_id'] == bud) & (
  print('more than 4 candidates', bud, colony, candidates, sorted_indices, int(lineage.loc[(lineage['bud_id'] == bud) & (


more than 4 candidates 141 0 [  3  54  68  93 134] [1 0 4 3 2] 54
more than 4 candidates 146 0 [ 14  19  49 131 144] [2 0 1 4 3] 14
more than 4 candidates 152 0 [  9  23  48  91 111] [1 4 0 2 3] 23
more than 4 candidates 166 0 [  3  54  68  93 134] [4 3 0 1 2] 54
more than 4 candidates 168 0 [  1  33  77 154 160] [1 2 4 3 0] 1
parent not in candidates 168 0 [ 33  77 160 154] 1
more than 4 candidates 221 0 [  9  19  77 160 168] [2 3 1 4 0] 77


  print('more than 4 candidates', bud, colony, candidates, sorted_indices, int(lineage.loc[(lineage['bud_id'] == bud) & (
  print('more than 4 candidates', bud, colony, candidates, sorted_indices, int(lineage.loc[(lineage['bud_id'] == bud) & (
  print('more than 4 candidates', bud, colony, candidates, sorted_indices, int(lineage.loc[(lineage['bud_id'] == bud) & (
  print('more than 4 candidates', bud, colony, candidates, sorted_indices, int(lineage.loc[(lineage['bud_id'] == bud) & (


more than 4 candidates 264 0 [ 94 105 167 191 236] [2 0 1 4 3] 94
more than 4 candidates 266 0 [ 48  80 137 183 242] [2 1 4 3 0] 137
more than 4 candidates 276 0 [ 23  48 152 182 265] [4 1 3 0 2] 265
more than 4 candidates 283 0 [ 11  27 105 192 237] [2 3 1 4 0] 27


  print('more than 4 candidates', bud, colony, candidates, sorted_indices, int(lineage.loc[(lineage['bud_id'] == bud) & (


more than 4 candidates 52 1 [ 3  7 13 25 29] [3 4 0 2 1] 25


  print('more than 4 candidates', bud, colony, candidates, sorted_indices, int(lineage.loc[(lineage['bud_id'] == bud) & (
  print('more than 4 candidates', bud, colony, candidates, sorted_indices, int(lineage.loc[(lineage['bud_id'] == bud) & (
  print('more than 4 candidates', bud, colony, candidates, sorted_indices, int(lineage.loc[(lineage['bud_id'] == bud) & (


more than 4 candidates 69 3 [12 19 21 42 67] [4 0 1 2 3] 12
more than 4 candidates 87 3 [18 37 39 59 71] [1 2 0 4 3] 37
more than 4 candidates 92 3 [22 31 49 50 75] [3 4 0 1 2] 22
parent not in candidates 68 4 [ 4  8 64 -3] 39


In [37]:
# split data into train and test
# split to train and test (5-fold cross validation take care of the rest)
test_df = matrix_features.sample(frac=0.2, random_state=4)
train_df = matrix_features.drop(test_df.index)

In [None]:
# train the model using 5-fold cross validation with fixed seed
models, accuracies = cv_nn(train_df, config=config, seed=42)
res['val_accuracy'] = np.mean(accuracies)
res['val_accuracy_std'] = np.std(accuracies)
# test all of the models fom 5-fold cv on the test set
accuracies_test = []
for model in models:
    pred, accuracy = test_nn(model, test_df)
    accuracies_test.append(accuracy)
res['internal_test_accuracy'] = np.mean(accuracies_test)
res['internal_test_accuracy_std'] = np.std(accuracies_test)

config:  {'owner': 'train_pipeline', 'epoch_n': 200, 'patience': 10, 'lr': 0.01, 'batch_size': 256, 'layers': [64, 64, 4], 'augment': True, 'features': ['dist_0', 'dist_std', 'poly_fit_budcm_budpt', 'poly_fit_expansion_vector', 'position_bud_std', 'position_bud_max', 'position_bud_min', 'position_bud_last', 'position_bud_first', 'orientation_bud_std', 'orientation_bud_max', 'orientation_bud_min', 'orientation_bud_last', 'orientation_bud_first', 'orientation_bud_last_minus_first', 'plyfit_orientation_bud'], 'dist_threshold': 12, 'num_frames': 8, 'normalized': True, 'use_wandb': False, 'cv_number': 1}
early stopping at  14 LR:  9.99999999999999e-05
config:  {'owner': 'train_pipeline', 'epoch_n': 200, 'patience': 10, 'lr': 0.01, 'batch_size': 256, 'layers': [64, 64, 4], 'augment': True, 'features': ['dist_0', 'dist_std', 'poly_fit_budcm_budpt', 'poly_fit_expansion_vector', 'position_bud_std', 'position_bud_max', 'position_bud_min', 'position_bud_last', 'position_bud_first', 'orientation_b

In [39]:
res

{'val_accuracy': 0.8811008929761428,
 'val_accuracy_std': 0.01829194368117267,
 'internal_test_accuracy': 0.9006535947712418,
 'internal_test_accuracy_std': 0.007622159339667032}

In [None]:
# save the best model
best_model = 0
best_sum_accuracy = 0
for model in models:
    _, test_accuracy = test_nn(model, test_df)
    _, train_accuracy = test_nn(model, train_df)
    if (train_accuracy+test_accuracy > best_sum_accuracy):
        best_model = model
        best_sum_accuracy = train_accuracy+test_accuracy
torch.save(best_model.state_dict(), 'best_model_with_fake_candid_thresh{}_frame_num{}_normalized_{}.pth'.format(
    args['bud_distance_max'], args['num_frames'], args['normalized']))
print('best model saved :',  'best_model_fake_candid_thresh{}_frame_num{}_normalized_{}.pth'.format(
    args['bud_distance_max'], args['num_frames'], args['normalized']))

test accuracy 0.8954248366013072
test accuracy 0.8794788273615635
test accuracy 0.9150326797385621
test accuracy 0.8859934853420195
test accuracy 0.8954248366013072
test accuracy 0.8827361563517915
test accuracy 0.8954248366013072
test accuracy 0.8843648208469055
test accuracy 0.9019607843137255
test accuracy 0.8827361563517915
best model saved : best_model_fake_candid_thresh12_frame_num8_normalized_True.pth


In [41]:
# test the best model on the test set
_, test_accuracy = test_nn(best_model, test_df)

test accuracy 0.9150326797385621
