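"""Run a lattice-based simulation of cancer cells and hepatocytes.

This Jupyter notebook calls the project's functions and classes to load a
lattice, seed cancer cells on it, run the simulation, and plot the recorded
snapshots.

"""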

# Library

In [None]:
import json
import numpy as np
import pandas as pd
import plotly.express as px

In [None]:
from classes_and_functions.settings import get_cell_configurations, get_simulation_parameters
from classes_and_functions.initialisation_functions import init_lattice_in_simulation, init_cell_dictionaries
from classes_and_functions.simulation_functions import update_cell_states, implicit_immune_predation
from classes_and_functions.analysis_functions import get_tumour_sizes

from classes_and_functions.cell_classes import CancerCell, Hepatocyte

# Main

## settings

In [None]:
# Input file locations.
# Both paths are derived from the date stamp embedded in the file names.

lattice_size_explore = 'small'
date_str = "2025-06-23"

# JSON file holding the lattice settings
path_to_lattice_settings = f"./files/lattice_settings_{date_str}.json"
# CSV file holding the annotated lattice before any tumour is seeded
path_to_lattice_without_tumour = f"./files/lattice_with_CVs_PTs_{date_str}_annotated_without_tumour.csv"

## read lattice 

In [None]:
# read lattice settings

# JSON text is UTF-8 by specification, so decode it explicitly instead of
# relying on the platform's default encoding (which differs between systems).
with open(path_to_lattice_settings, encoding="utf-8") as json_file:
    lattice_settings = json.load(json_file)

# bare expression: the notebook renders the loaded settings as the cell output
lattice_settings

## Simulation

In [None]:
# Fetch the cell configuration objects (site types, site states and the
# colour / marker-size maps used for plotting) and show the site types.
site_types, sites_states, color_map, markersize_map = get_cell_configurations()
print(site_types)

In [None]:
# Model probabilities.

# --- used by every model type ---
# probability of cancer cells growing
P_CC_GROW = 1
# probability of healthy hepatocytes damaged by cancer cells to become apoptotic
P_HEP_DAMAGED = 0.5
# probability of apoptotic hepatocytes becoming cleared
P_HEP_CLEARED = 0.5

# --- used by model_3 only ---
# probability of cancer cells being killed, by implicit immune predation
P_CC_KILLED = 0.5

# Bundle the probabilities under their own names so they can be handed to the
# simulation functions as a single argument.
parameters = dict(
    P_CC_GROW=P_CC_GROW,
    P_HEP_DAMAGED=P_HEP_DAMAGED,
    P_HEP_CLEARED=P_HEP_CLEARED,
    P_CC_KILLED=P_CC_KILLED,
)

In [None]:
# Run configuration.
T = 15                  # last time step; the simulation loop covers t = 0 .. T
model_type = "model_3"  # selects the model variant passed to the simulation functions
# initial number of cancer cells = seeding density x number of CVs
cancer_cells_seeding_density = 1

In [None]:
# simulation
#
# Loads the tumour-free lattice, seeds cancer cells, then advances the model
# for t = 0 .. T. At regularly spaced time points it prints cell counts and
# records (a) a snapshot of the lattice and (b) the DBSCAN-labelled tumour
# sites. Results end up in `snapshots_at_selected_times` and
# `dbscan_clusters_at_selected_times`.

# ... load the lattice without cancer cells
lattice_in_simulation_without_cancer_cells = pd.read_csv(path_to_lattice_without_tumour)
# number of lattice rows whose site_type code is 0 (counted as CVs)
n_CVs = int((lattice_in_simulation_without_cancer_cells.site_type == 0).sum())


# ... initialise cancer cells & create cell dictionaries containing CancerCell and Hepatocyte objects
cell_dictionaries, lattice_in_simulation = init_cell_dictionaries(
    lattice=lattice_in_simulation_without_cancer_cells,
    n_cancer_cells_init=int(cancer_cells_seeding_density * n_CVs),
    CancerCell=CancerCell,
    Hepatocyte=Hepatocyte
)

# data structures for the simulation
lattice_in_simulation_copy = lattice_in_simulation.copy()
# NOTE(review): dict.copy() is shallow, so the inner 'CancerCell' and
# 'Hepatocyte' dictionaries are shared with `cell_dictionaries`. If the
# simulation functions mutate them in place, the "original" changes too.
# Confirm whether an independent copy is needed.
cell_dictionaries_copy = cell_dictionaries.copy()

# Snapshot schedule: roughly T // 5 evenly spaced intervals over the run.
# The period is kept an integer (and at least 1). The previous test,
# `t % (T / (T//5)) == 0`, divided by zero for T < 5 and used a fractional
# period whenever T / (T//5) was not whole, so hardly any time point matched.
# For T that is a multiple of 5 the schedule is unchanged (every 5 steps).
n_snapshot_intervals = max(1, T // 5)
snapshot_period = max(1, round(T / n_snapshot_intervals))

# per-snapshot frames are collected in lists and concatenated once at the end
# (avoids repeatedly re-copying a growing DataFrame inside the loop)
snapshot_frames = []
dbscan_cluster_frames = []

# simulation starts
for t in np.arange(T+1):

    if t % snapshot_period == 0:

        cancer_cells = cell_dictionaries_copy['CancerCell']
        hepatocytes = cell_dictionaries_copy['Hepatocyte']

        total_number_of_cancer_cells = len(cancer_cells)
        total_number_of_hepatocytes = len(hepatocytes)
        # hepatocytes whose cell_state code is 2 are reported as apoptotic
        number_of_apoptotic_hepatocytes = sum(
            1 for hep in hepatocytes.values()
            if hep.attributes['cell_state'] == 2
        )

        print(f"t = {t}: \n > # of Cancer Cells = {total_number_of_cancer_cells}")
        print(f" > # of Hepatocytes = {total_number_of_hepatocytes}, of which {number_of_apoptotic_hepatocytes} are apoptotic.")

        # record simulation snapshots
        snapshots_at_t = lattice_in_simulation_copy.copy()
        snapshots_at_t['time'] = t
        snapshot_frames.append(snapshots_at_t)

        # perform DBSCAN clustering to get tumour sizes
        # (rows with site_type code 4 are the ones treated as tumour)
        tumour_t = snapshots_at_t.loc[snapshots_at_t.site_type==4].copy()
        tumour_t_sizes, tumour_t_labelled = get_tumour_sizes(tumour_t=tumour_t)
        tumour_t_labelled['time'] = t
        dbscan_cluster_frames.append(tumour_t_labelled)

    # cancer cell proliferating, damaging hepatocytes
    cell_dictionaries_copy, lattice_in_simulation_copy = update_cell_states(
        cell_dictionaries=cell_dictionaries_copy,
        lattice=lattice_in_simulation_copy,
        parameters=parameters,
        CancerCell=CancerCell,
        Hepatocyte=Hepatocyte,
        model_type=model_type
    )

    # immune cell killing cancer cells
    # NOTE(review): the return value is discarded here, so this presumably
    # updates its arguments in place - confirm against the function.
    if model_type=='model_3':
        implicit_immune_predation(
            cell_dictionaries=cell_dictionaries_copy,
            lattice=lattice_in_simulation_copy,
            parameters=parameters,
            model_type=model_type
        )

# data frames holding all recorded snapshots / DBSCAN-labelled tumour sites
# (empty frames if nothing was recorded, e.g. for a negative T)
snapshots_at_selected_times = (
    pd.concat(snapshot_frames) if snapshot_frames else pd.DataFrame()
)
dbscan_clusters_at_selected_times = (
    pd.concat(dbscan_cluster_frames) if dbscan_cluster_frames else pd.DataFrame()
)
    

In [None]:
# visualisation - colour by site type
#
# One scatter panel per recorded time point; every lattice site is coloured
# by the name of its site type, with peri-central 'HEP' sites shown in their
# own colour.

df_plot = snapshots_at_selected_times.copy()


def _site_type_label(site_type, zonation_type):
    """Return the display name for one lattice site.

    `site_type` is the value stored in the lattice's `site_type` column and
    is translated through `site_types`. Sites whose translated name is 'HEP'
    and whose zonation is 'peri-central' get the suffix "-PC".
    """
    name = site_types[site_type]
    # Compare the translated *name* with 'HEP'. The previous check compared
    # the raw `site_type` value with 'HEP'; the rest of the notebook compares
    # that column with integers (0, 4), in which case the check could never
    # be true and no site was ever labelled "HEP-PC".
    if name == 'HEP' and zonation_type == 'peri-central':
        return name + "-PC"
    return name


# ===== scatter plots =====
df_plot["site_type_name"] = [
    _site_type_label(site_type, zonation_type)
    for site_type, zonation_type in df_plot[['site_type', 'zonation_type']].values
]
# colour for the extra peri-central label (added to the shared colour map)
color_map['HEP-PC']='cyan'

sca = px.scatter(
    data_frame=df_plot,
    x='x', y='y',
    color='site_type_name',
    facet_col='time', facet_col_wrap=2,
    color_discrete_map=color_map,
)

# customize the figure
sca.update_layout(
    template='simple_white', width=1000, height=1000
)
sca.update_traces(
    marker=dict(size=3)
)
sca.update_xaxes(title=dict(text="x", font_family="Arial", font_size=14))
# equal x/y scaling so the lattice is not distorted; this is the cell's last
# expression, so the notebook displays the returned figure
sca.update_yaxes(
    title=dict(text="y", font_family="Arial", font_size=14),
    scaleanchor="x", scaleratio=1
    )

In [None]:
# Visualisation: tumour sites coloured by their DBSCAN cluster label,
# one panel per recorded time point.

df_plot_2 = dbscan_clusters_at_selected_times.copy()

# ----- build the faceted scatter plot -----
sca = px.scatter(
    data_frame=df_plot_2,
    x='x',
    y='y',
    color='label',
    hover_data=['label'],
    color_continuous_scale='HSV',
    facet_col='time',
    facet_col_wrap=2,
)

# ----- figure styling -----
sca.update_layout(width=1000, height=1000, template='simple_white')
sca.update_traces(marker={'size': 3})
sca.update_xaxes(
    title={'text': "x", 'font_family': "Arial", 'font_size': 14},
)
# lock the y scale to the x scale; as the last expression of the cell, the
# returned figure is what the notebook renders
sca.update_yaxes(
    title={'text': "y", 'font_family': "Arial", 'font_size': 14},
    scaleanchor="x",
    scaleratio=1,
)