# About the notebook
The purpose of this Jupyter Notebook is to use a pre-trained deep learning model to predict the neck point of 3D traces and remove the head from them.


# 00 - Special Instructions for Google Colab Users 

Run the following cells only when executing this notebook on Google Colab. They are needed to take advantage of Colab's additional features, most notably the free GPU. **If you're running the code locally, skip these cells (go to step 01 - Loading dependencies), as they pertain only to the Colab setup.**
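To keep a single notebook that runs both on Colab and locally, the Colab-only setup can be guarded by a check like the one below. This is a minimal sketch that is not part of the original repository: the helper name `running_in_colab` is hypothetical, and it assumes the `google.colab` package is importable only inside a Colab runtime.

```python
def running_in_colab() -> bool:
    """Hypothetical helper: return True if the google.colab package can be imported."""
    try:
        import google.colab  # noqa: F401  (assumed to exist only inside Colab)
    except ImportError:
        return False
    return True


if running_in_colab():
    # Colab-only setup: mount Google Drive, as in the cell below
    from google.colab import drive
    drive.mount("/content/drive")
else:
    print("Not running on Google Colab: skipping the Colab setup")
```

Shell commands such as `!pip install napari` are notebook magics, so they still need their own cells.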

## Give access to Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

## Install Napari

In [None]:
!pip install napari

## Clone the code into the current session

In [None]:
!git clone https://github.com/paul-hernandez-herrera/image_classification_pytorch
import os
workbookDir = "/content/image_classification_pytorch/"
os.chdir(workbookDir)

# 01 - Loading dependencies
Before running any other code, import the libraries and modules this notebook needs. These modules contain pre-written code for specific tasks, such as reading and processing images and traces, defining the deep learning model, and running the prediction.

In [None]:
import os
if 'workbookDir' not in globals():
    print('Updating working directory')
    workbookDir = os.path.dirname(os.getcwd())
    os.chdir(workbookDir)
print(os.getcwd())
import torch
from pathlib import Path

from core_code.regression_predict_neck_point_2D import remove_head_from_trace

# Automatically reload imported modules when their source code changes
%load_ext autoreload
%autoreload 2

# 02 - Setting required parameters
In this section, users specify the parameters needed to predict the neck point and remove the head from the input traces. The following parameters are required:

**Model path**: The path to the trained model used for the prediction.

**3D Trace path**: The path to the folder containing the raw traces in SWC format.

**3D Images path**: The path to the folder containing the 3D images corresponding to the traces.

**Device**: The device (e.g., CPU or GPU) used to run the computations.
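The traces are stored in SWC format, a plain-text format in which each non-comment line describes one point with seven whitespace-separated fields: point id, structure type, x, y, z, radius, and parent id (-1 for the root point); lines starting with `#` are comments. The sketch below reads such a file using only the standard library. It is an illustration, not the reader used by `core_code`; the names `SwcPoint`, `parse_swc` and `read_swc` are hypothetical, and it assumes the files follow the standard seven-column layout.

```python
from pathlib import Path
from typing import List, NamedTuple


class SwcPoint(NamedTuple):
    """One point (row) of an SWC trace."""
    point_id: int
    structure_type: int
    x: float
    y: float
    z: float
    radius: float
    parent_id: int  # -1 marks the root point


def parse_swc(text: str) -> List[SwcPoint]:
    """Hypothetical helper: parse the contents of an SWC file into a list of points."""
    points = []
    for line in text.splitlines():
        line = line.strip()
        # Skip blank lines and comment lines
        if not line or line.startswith("#"):
            continue
        fields = line.split()
        points.append(SwcPoint(
            point_id=int(fields[0]),
            structure_type=int(fields[1]),
            x=float(fields[2]),
            y=float(fields[3]),
            z=float(fields[4]),
            radius=float(fields[5]),
            parent_id=int(fields[6]),
        ))
    return points


def read_swc(path) -> List[SwcPoint]:
    """Read and parse an SWC file from disk."""
    return parse_swc(Path(path).read_text())
```

Each point refers to its predecessor through `parent_id`, so a trace can be walked from the root by following these links.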

In [None]:
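# Create the predictor; it holds the parameters listed above (e.g. trace_path_w and images_path_w, used in step 04)
predict = remove_head_from_trace()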
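Before running the prediction, it can help to confirm that the paths set above exist and that the trace folder actually contains SWC files. The function below is a hypothetical pre-flight check, not part of `core_code`; it assumes the model path points to a single file and that trace files use the lowercase `.swc` extension.

```python
from pathlib import Path
from typing import List


def check_inputs(model_path, trace_folder, images_folder) -> List[str]:
    """Hypothetical helper: return a list of problems found with the input paths (empty if none)."""
    problems = []
    if not Path(model_path).is_file():
        problems.append(f"Model file not found: {model_path}")
    trace_folder = Path(trace_folder)
    if not trace_folder.is_dir():
        problems.append(f"Trace folder not found: {trace_folder}")
    elif not any(trace_folder.glob("*.swc")):
        problems.append(f"No .swc traces found in: {trace_folder}")
    if not Path(images_folder).is_dir():
        problems.append(f"Image folder not found: {images_folder}")
    return problems
```

Call it with the same three paths entered above; an empty list means no problems were found.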

# 03 - Run the prediction
This cell runs the trained deep learning model on the input traces and their corresponding images.

In [None]:
output = predict.run()

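# 04 - Manually correct erroneous traces

If the automatic prediction fails for some traces, use `manually_remove_head` to remove the head from them. Set the indices of the traces to correct in the loop below.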

In [None]:
from core_code.regression_predict_neck_point_2D import manually_remove_head
from pathlib import Path

In [None]:
# Folders shared by all the traces corrected below
folder_images = Path(predict.images_path_w.value)
folder_output = Path(predict.trace_path_w.value) / "trace_head_removed"
ind = 10  # index passed to manually_remove_head (see core_code for its meaning)

# Indices of the traces to correct; range(21, 22) processes only index 21
for i in range(21, 22):
    # File name with a zero-padded, 4-digit index, e.g. Exp8_stacks_TP0021_DC_LogN_rec_FM.swc
    trace_p = Path(predict.trace_path_w.value) / f"Exp8_stacks_TP{i:04d}_DC_LogN_rec_FM.swc"
    manually_remove_head(trace_p, folder_images, folder_output, ind)
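The loop above builds each file name by zero-padding the index to four digits. To correct several non-consecutive traces, a small helper such as the hypothetical `trace_file_name` below keeps the naming pattern in one place. The prefix and suffix defaults are copied from the example file name used above and are assumptions; adjust them to match your experiment.

```python
from pathlib import Path
from typing import Iterable, List


def trace_file_name(index: int,
                    prefix: str = "Exp8_stacks_TP",
                    suffix: str = "_DC_LogN_rec_FM.swc") -> str:
    """Hypothetical helper: build a trace file name with a 4-digit, zero-padded index."""
    return f"{prefix}{index:04d}{suffix}"


def trace_paths(trace_folder, indices: Iterable[int], **name_kwargs) -> List[Path]:
    """Return the full path of each requested trace inside trace_folder."""
    return [Path(trace_folder) / trace_file_name(i, **name_kwargs) for i in indices]
```

For example, `trace_file_name(21)` returns `"Exp8_stacks_TP0021_DC_LogN_rec_FM.swc"`. To correct traces 21, 35 and 102, you could loop over `trace_paths(predict.trace_path_w.value, [21, 35, 102])` and call `manually_remove_head` on each path with the same `folder_images`, `folder_output` and `ind`.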
