## Training the CELTIC Model

This notebook demonstrates how to train the CELTIC model. Using preprocessed images and context data, we initialize the experiment, configure the model, and run training. The trained model is saved to a local folder for later use in prediction (see `predict.ipynb`).


In [2]:
from celtic.utils.functions import initialize_experiment, download_resources
from celtic.train import train
import os

# Presets
organelle = 'microtubules'
resources_dir = '../resources'
path_single_cells = f'/sise/assafzar-group/assafzar/Nitsan/hipsc_single_cell_image_dataset/{organelle}/fov_processed/cells/source'




In [5]:
# Download the shared resources only if they are not already present locally
if not os.path.exists(resources_dir):
    shared_folder_link = 'https://drive.google.com/drive/folders/1KTzb3fzwjH5ffSLtLNHuYiLiPg2p2VUf?usp=sharing'
    download_resources(shared_folder_link, os.path.dirname(resources_dir))
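A short check of what `os.path.dirname(resources_dir)` evaluates to may help when locating the download. The sketch below uses only the standard library. That `download_resources` places the shared folder inside this parent directory is an assumption based on how it is called here, not something confirmed by the library.

```python
import os

# Same preset as in the first cell of this notebook.
resources_dir = '../resources'

# The second argument passed to download_resources in the cell above.
# Assumption: the shared folder ends up inside this directory.
download_target = os.path.dirname(resources_dir)
print(download_target)  # ..
```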

### Initialize the Experiment

This step initializes the experiment by creating a local folder to store the training files. The cell also defines the paths to the CSV files that list the images and, if contexts are used, to the CSV files that hold the context data. In this example, we provide the microtubules context files. Context creation is explained in the `context_creation.ipynb` notebook.


In [6]:
path_run_dir, context_model_config = initialize_experiment(organelle, 'train', resources_dir)
print("the experiment will be saved in:", path_run_dir)

path_images_csv = [f'{resources_dir}/{organelle}/metadata/{item}_images.csv' for item in ['train', 'valid']]
path_context_csv = [f'{resources_dir}/{organelle}/metadata/{item}_context.csv' for item in ['train', 'valid']]


the experiment will be saved in: ./experiments/train/microtubules/2025-01-11-15-43-45
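Before a long training run, it can be worth confirming that the metadata CSV files actually exist. The following is a minimal sketch using only the standard library. `find_missing_files` is a hypothetical helper introduced for illustration and is not part of `celtic`.

```python
import os

def find_missing_files(paths):
    """Return the paths that do not point to an existing file."""
    return [p for p in paths if not os.path.isfile(p)]

# Same presets and path patterns as in the cells above.
organelle = 'microtubules'
resources_dir = '../resources'
path_images_csv = [f'{resources_dir}/{organelle}/metadata/{item}_images.csv' for item in ['train', 'valid']]
path_context_csv = [f'{resources_dir}/{organelle}/metadata/{item}_context.csv' for item in ['train', 'valid']]

missing = find_missing_files(path_images_csv + path_context_csv)
if missing:
    print("missing metadata files:", missing)
```

An empty `missing` list means all four CSV files were found at the expected locations.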


### Run Training

This step starts the training process using the specified parameters, including image paths, context data, and model configuration. The results are saved in the local folder of the experiment.
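The training call reads five entries from `context_model_config`. A quick check that they are all present can surface a configuration problem before training starts. This is a sketch under assumptions: `missing_config_keys` is a hypothetical helper, and `example_config` is a placeholder dictionary with dummy values, not the real configuration returned by `initialize_experiment`.

```python
# Keys that the training cell in this notebook reads from context_model_config.
REQUIRED_KEYS = (
    'transforms',
    'train_patch_size',
    'context_features',
    'daft_embedding_factor',
    'daft_scale_activation',
)

def missing_config_keys(config, required=REQUIRED_KEYS):
    """Return the required keys that are absent from the config dictionary."""
    return [key for key in required if key not in config]

# Placeholder config for illustration only; the values are dummies.
example_config = {'transforms': None, 'train_patch_size': None}
print(missing_config_keys(example_config))
```

In the notebook itself, the same check would be applied to `context_model_config` instead of the placeholder.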


In [None]:
train.run_training(
    path_run_dir,
    path_images_csv,
    path_context_csv,
    path_single_cells,
    masked=True,
    transforms=context_model_config['transforms'],
    patch_size=context_model_config['train_patch_size'],
    iterations=60_000,
    batch_size=24,
    learning_rate=0.001,
    context_features=context_model_config['context_features'],
    daft_embedding_factor=context_model_config['daft_embedding_factor'],
    daft_scale_activation=context_model_config['daft_scale_activation'],
)

bottleneck_dim=132
embedding factor: 4 | activation: Sigmoid()
Model instianted from fnet_nn_3d
fnet_nn_3d | {} | iter: 0
mask_efficieny_threshold initialized: 0.001


buffering images: 100%|██████████| 15/15 [00:08<00:00,  1.81it/s]


mask_efficieny_threshold initialized: 0.001


buffering images: 100%|██████████| 98/98 [01:04<00:00,  1.52it/s]
