In [None]:
# %%
# === Imports ===
from keras.models import Model
from keras.layers import Input, Activation, UpSampling2D, Convolution2D, MaxPooling2D, BatchNormalization
from keras.callbacks import Callback
from keras import backend as K
from keras import optimizers
from keras import losses
import tensorflow as tf
import numpy as np
import pickle
from sklearn.model_selection import train_test_split
from keras.callbacks import EarlyStopping, ModelCheckpoint
import matplotlib.pyplot as plt
import random


In [None]:
import tensorflow as tf

# Check GPU availability. If at least one GPU exists, expose only the first one
# and enable memory growth on it.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    print(f"✅ GPU detected: {gpus}")
    try:
        # Limit TensorFlow to just the first GPU (optional). Use the stable
        # tf.config API; the tf.config.experimental alias is deprecated.
        tf.config.set_visible_devices(gpus[0], 'GPU')
        # Memory growth still lives only under tf.config.experimental.
        tf.config.experimental.set_memory_growth(gpus[0], True)
    except RuntimeError as e:
        # Raised when GPUs were already initialized (e.g. this cell is re-run
        # after TF ops executed); the earlier configuration stays in effect.
        print(e)
else:
    print("⚠️ No GPU detected, training will be on CPU.")


In [None]:
import matplotlib.pyplot as plt
from IPython.display import clear_output

# Per-epoch loss history drawn by the live plot. Both lists are cleared when a
# new fit() starts, so re-running training does not append to a stale curve.
train_loss = []
val_loss = []

def on_train_begin(logs=None):
    """Reset the accumulated loss curves at the start of each ``fit`` call."""
    train_loss.clear()
    val_loss.clear()

def on_epoch_end(epoch, logs):
    """Record this epoch's losses and redraw the live training plot.

    Args:
        epoch: Epoch index passed by Keras (unused; curves are plotted by position).
        logs: Epoch metrics dict from Keras. ``'loss'`` is required. ``'val_loss'``
            is optional (absent when fit() gets no validation data) and is stored
            as NaN in that case so both curves stay aligned by epoch.
    """
    train_loss.append(logs['loss'])
    val_loss.append(logs.get('val_loss', float('nan')))

    clear_output(wait=True)
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(train_loss, label='Train Loss')
    ax.plot(val_loss, label='Validation Loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Live Training Progress')
    ax.legend()
    ax.grid(True)
    plt.show()
    # Close explicitly so non-inline backends don't accumulate one figure per epoch.
    plt.close(fig)

live_plot_callback = tf.keras.callbacks.LambdaCallback(
    on_train_begin=on_train_begin,
    on_epoch_end=on_epoch_end,
)


In [None]:
# === Helper Functions ===

def project_01(im):
    """Min-max rescale an image to the [0, 1] range.

    Singleton axes are squeezed away first.

    Args:
        im: Array-like image of any numeric dtype.

    Returns:
        Array of the squeezed shape where the minimum maps to 0 and the maximum
        to 1. A constant image has no range to stretch, so it maps to all zeros
        (float64) instead of dividing by zero and producing NaNs.
    """
    im = np.squeeze(im)
    min_val = im.min()
    max_val = im.max()
    value_range = max_val - min_val
    if value_range == 0:
        return np.zeros_like(im, dtype=float)
    return (im - min_val) / value_range

def normalize_im(im, dmean, dstd):
    """Standardize an image with a given mean and standard deviation.

    Singleton axes are squeezed away first. The result is returned as float32,
    which is what the original pre-allocated buffer intended. That buffer was
    overwritten immediately, so the dtype was never actually applied.

    Args:
        im: Array-like image.
        dmean: Mean to subtract (scalar or broadcastable array).
        dstd: Standard deviation to divide by (scalar or broadcastable array).

    Returns:
        float32 array ``(im - dmean) / dstd`` of the squeezed shape.
    """
    im = np.squeeze(im)
    return ((im - dmean) / dstd).astype(np.float32)

class LossHistory(Callback):
    """Keras callback that records the training loss after every batch.

    ``self.losses`` is reset at the start of each training run and then gets
    one entry per batch (``None`` when the batch logs carry no ``'loss'``).
    """

    def on_train_begin(self, logs=None):
        # None instead of a shared mutable {} default.
        self.losses = []

    def on_batch_end(self, batch, logs=None):
        # Treat missing logs as empty rather than relying on a shared {} default.
        self.losses.append((logs or {}).get('loss'))

def matlab_style_gauss2D(shape=(7,7), sigma=1, scale=2.0):
    """Build a MATLAB-style 2-D Gaussian kernel, normalized and then scaled.

    Args:
        shape: (rows, cols) of the kernel; the grid is centred on the kernel.
        sigma: Standard deviation of the Gaussian, in pixels.
        scale: Multiplier applied after normalization. It defaults to 2.0,
            the factor previously hard-coded here.

    Returns:
        float32 array of ``shape``. Its values sum to ``scale``, unless every
        value was zeroed by the cutoff below.
    """
    m, n = [(ss - 1.) / 2. for ss in shape]
    y, x = np.ogrid[-m:m + 1, -n:n + 1]
    h = np.exp(-(x * x + y * y) / (2. * sigma * sigma))
    # Zero out values negligible relative to the peak (below machine eps * max).
    # (A former `h.astype(K.floatx())` here discarded its result, so it was removed.)
    h[h < np.finfo(h.dtype).eps * h.max()] = 0
    sumh = h.sum()
    if sumh != 0:
        h /= sumh
    return (h * scale).astype('float32')

# PSF used by the loss: a 7x7 Gaussian (sigma=1) reshaped into a conv2d kernel
# of shape (height, width, in_channels=1, out_channels=1). The kernel size is
# taken from the PSF itself, so changing `shape` above cannot desync the reshape.
psf_heatmap = matlab_style_gauss2D(shape=(7, 7), sigma=1)
gfilter = tf.reshape(psf_heatmap, [*psf_heatmap.shape, 1, 1])

def L1L2loss(input_shape):
    """Create the training loss: heatmap MSE plus an L1 penalty on the prediction.

    The network output is blurred with the module-level ``gfilter`` kernel and
    compared with the target by mean squared error. The mean absolute value of
    the raw output is then added as an L1 penalty.

    Args:
        input_shape: Not used by the loss; kept for call-site compatibility.

    Returns:
        ``bump_mse(heatmap_true, spikes_pred)``, which returns a scalar tensor.
    """
    def bump_mse(heatmap_true, spikes_pred):
        # Blur the predicted spikes with the PSF to get a predicted heatmap.
        blurred = tf.nn.conv2d(spikes_pred, filters=gfilter,
                               strides=[1, 1, 1, 1], padding='SAME')

        mse_term = tf.reduce_mean(tf.square(heatmap_true - blurred))
        l1_term = tf.reduce_mean(tf.abs(spikes_pred))
        return mse_term + l1_term

    return bump_mse


def conv_bn_relu(nb_filter, rk, ck, name):
    """Return a builder applying Conv2D -> BatchNormalization -> ReLU.

    Args:
        nb_filter: Number of convolution filters.
        rk, ck: Kernel height and width.
        name: Suffix for the three layer names ('conv-', 'BN-', 'Relu-' + name).

    Returns:
        A callable that creates the layers when applied to a tensor and returns
        the activated output.
    """
    def block(x):
        x = Convolution2D(
            nb_filter,
            kernel_size=(rk, ck),
            strides=(1, 1),
            padding="same",
            use_bias=False,
            kernel_initializer="Orthogonal",
            name='conv-' + name,
        )(x)
        x = BatchNormalization(name='BN-' + name)(x)
        return Activation(activation="relu", name='Relu-' + name)(x)
    return block

def CNN(input, names):
    """Encoder-decoder feature extractor.

    Three conv-bn-relu + 2x2 max-pool stages downsample by 8. A 512-filter block
    follows, then three 2x upsample + conv-bn-relu stages restore the input
    resolution. Layer names are prefixed with ``names``.

    Returns:
        The final 32-channel feature tensor.
    """
    x = input
    # Encoder: F1..F3 each followed by Pool1..Pool3.
    for stage, filters in enumerate((32, 64, 128), start=1):
        x = conv_bn_relu(filters, 3, 3, names + f'F{stage}')(x)
        x = MaxPooling2D(pool_size=(2, 2), name=names + f'Pool{stage}')(x)
    # Bottleneck.
    x = conv_bn_relu(512, 3, 3, names + 'F4')(x)
    # Decoder: Upsample1..Upsample3 each followed by F5..F7.
    for stage, filters in enumerate((128, 64, 32), start=1):
        x = UpSampling2D(size=(2, 2), name=names + f'Upsample{stage}')(x)
        x = conv_bn_relu(filters, 3, 3, names + f'F{stage + 4}')(x)
    return x

def buildModel(input_dim):
    """Build and compile the encoder-decoder prediction model.

    Args:
        input_dim: Input shape without the batch axis, e.g. (208, 208, 1).

    Returns:
        A compiled Model. It ends in a bias-free 1x1 linear convolution with one
        output channel and is compiled with Adam (learning rate 0.001) and
        ``L1L2loss``.
    """
    inputs = Input(shape=input_dim)
    features = CNN(inputs, 'CNN')
    outputs = Convolution2D(
        1,
        kernel_size=(1, 1),
        strides=(1, 1),
        padding="same",
        activation="linear",
        use_bias=False,
        kernel_initializer="Orthogonal",
        name='Prediction',
    )(features)

    model = Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer=optimizers.Adam(learning_rate=0.001),
                  loss=L1L2loss(input_dim))
    return model


In [None]:
# ============================
# 📦 LOAD & RESHAPE THE DATA
# ============================

# SECURITY: pickle.load can execute arbitrary code embedded in the file. Only
# load pickles you created yourself or otherwise trust.
with open("from scratch_training_data.pkl", "rb") as f:
    data = pickle.load(f)

patches = data["patches"]
heatmaps = data["heatmaps"]
# Optional: not every pickle carries spike maps, so don't KeyError when absent.
# (Assumes `data` is a dict, as the key lookups above suggest — confirm.)
spikes = data.get("spikes")

print("✅ Loaded data shapes:")
print("patches:", patches.shape)
print("heatmaps:", heatmaps.shape)
if spikes is not None:
    print("spikes:", spikes.shape)

# ============================
# 📏 NORMALIZATION
# ============================
# Standardize with a single global mean/std computed over the whole dataset.
# NOTE(review): these stats include the samples that become the validation split
# below (mild leakage). Compute them on X_train only if that matters.
mean_val = np.mean(patches)
std_val = np.std(patches)
if std_val == 0:
    # A constant dataset would otherwise silently become all NaNs.
    raise ValueError("patches have zero standard deviation; cannot normalize")

patches = (patches - mean_val) / std_val

# Add a trailing channel axis, presumably (N, H, W) -> (N, H, W, 1).
patches = patches[..., np.newaxis]
heatmaps = heatmaps[..., np.newaxis]

print("📐 Reshaped & normalized:")
print("patches:", patches.shape)
print("heatmaps:", heatmaps.shape)

# ============================
# ✂️ TRAIN/TEST SPLIT
# ============================
# Hold out 10% for validation; random_state makes the split reproducible.
X_train, X_val, y_train, y_val = train_test_split(
    patches, heatmaps, test_size=0.1, random_state=42
)

print("✅ Split:")
print("Training samples:", X_train.shape[0])
print("Validation samples:", X_val.shape[0])


In [None]:
# Rebuild the network for 208x208 single-channel patches, then restore the
# pretrained weights stored in weights_RealMicrotubules.hdf5.
input_shape = 208, 208, 1
model = buildModel(input_shape)
model.load_weights(r"weights_RealMicrotubules.hdf5")


In [None]:
# Pick one sample index uniformly at random from the loaded dataset.
# randrange(n) yields 0..n-1, so every patch (including index 0) is reachable,
# and the index always matches the actual dataset size instead of a hard-coded 5000.
i = random.randrange(len(patches))

In [None]:

# Compare the pretrained model's prediction with the ground truth for sample `i`.
# NOTE(review): `patches`/`heatmaps` are the full dataset, so `i` may index a
# training sample; use X_val/y_val for an unbiased look.
input_patch = patches[i]
true_heatmap = heatmaps[i]
# predict expects a leading batch axis; [0] takes the single result back out.
predicted = model.predict(input_patch[np.newaxis])[0]

# Set to True to add the two overlay panels.
show_overlays = False

# Each panel is (title, [(image, cmap, alpha), ...]); later layers draw on top.
panels = [
    ("Input Patch", [(input_patch, 'gray', None)]),
    ("True Heatmap", [(true_heatmap, 'hot', None)]),
    ("Predicted Heatmap", [(predicted, 'hot', None)]),
]
if show_overlays:
    panels += [
        ("True Heatmap Overlay", [(input_patch, 'gray', None), (true_heatmap, 'hot', 0.5)]),
        ("Predicted over True", [(true_heatmap, 'hot', None), (predicted, 'cool', 0.5)]),
    ]

# Size the grid to the panels actually drawn; a fixed 1x5 grid left empty slots.
fig, axes = plt.subplots(1, len(panels), figsize=(16, 4), squeeze=False)
for ax, (title, layers) in zip(axes[0], panels):
    for image, cmap, alpha in layers:
        # Drop the trailing channel axis so imshow receives a 2-D array.
        ax.imshow(np.squeeze(image), cmap=cmap, alpha=alpha)
    ax.set_title(title)
    ax.axis('off')

fig.tight_layout()
plt.show()


In [None]:
# === Load Model + Weights ===

# Network input: 208x208 single-channel patches.
input_shape = 208, 208, 1

# Build a fresh architecture, then overwrite its parameters with the saved checkpoint.
model = buildModel(input_shape)
model.load_weights(r"DeepSTORM_model_weights_best.hdf5")

print("✅ Model loaded and weights applied.")

In [None]:
# Compare this model's prediction with the ground truth for the same sample `i`.
# NOTE(review): `patches`/`heatmaps` are the full dataset, so `i` may index a
# training sample; use X_val/y_val for an unbiased look.
input_patch = patches[i]
true_heatmap = heatmaps[i]
# predict expects a leading batch axis; [0] takes the single result back out.
predicted = model.predict(input_patch[np.newaxis])[0]

# Set to True to add the two overlay panels.
show_overlays = False

# Each panel is (title, [(image, cmap, alpha), ...]); later layers draw on top.
panels = [
    ("Input Patch", [(input_patch, 'gray', None)]),
    ("True Heatmap", [(true_heatmap, 'hot', None)]),
    ("Predicted Heatmap", [(predicted, 'hot', None)]),
]
if show_overlays:
    panels += [
        ("True Heatmap Overlay", [(input_patch, 'gray', None), (true_heatmap, 'hot', 0.5)]),
        ("Predicted over True", [(true_heatmap, 'hot', None), (predicted, 'cool', 0.5)]),
    ]

# Size the grid to the panels actually drawn; a fixed 1x5 grid left empty slots.
fig, axes = plt.subplots(1, len(panels), figsize=(16, 4), squeeze=False)
for ax, (title, layers) in zip(axes[0], panels):
    for image, cmap, alpha in layers:
        # Drop the trailing channel axis so imshow receives a 2-D array.
        ax.imshow(np.squeeze(image), cmap=cmap, alpha=alpha)
    ax.set_title(title)
    ax.axis('off')

fig.tight_layout()
plt.show()


In [None]:
# === Load Model + Weights ===

# Network input: 208x208 single-channel patches.
input_shape = 208, 208, 1

# Build a fresh architecture, then overwrite its parameters with the
# weights stored in weights_SimulatedMicrotubules.hdf5.
model = buildModel(input_shape)
model.load_weights(r"weights_SimulatedMicrotubules.hdf5")

print("✅ Model loaded and weights applied.")

In [None]:

# Compare this model's prediction with the ground truth for the same sample `i`.
# NOTE(review): `patches`/`heatmaps` are the full dataset, so `i` may index a
# training sample; use X_val/y_val for an unbiased look.
input_patch = patches[i]
true_heatmap = heatmaps[i]
# predict expects a leading batch axis; [0] takes the single result back out.
predicted = model.predict(input_patch[np.newaxis])[0]

# Set to True to add the two overlay panels.
show_overlays = False

# Each panel is (title, [(image, cmap, alpha), ...]); later layers draw on top.
panels = [
    ("Input Patch", [(input_patch, 'gray', None)]),
    ("True Heatmap", [(true_heatmap, 'hot', None)]),
    ("Predicted Heatmap", [(predicted, 'hot', None)]),
]
if show_overlays:
    panels += [
        ("True Heatmap Overlay", [(input_patch, 'gray', None), (true_heatmap, 'hot', 0.5)]),
        ("Predicted over True", [(true_heatmap, 'hot', None), (predicted, 'cool', 0.5)]),
    ]

# Size the grid to the panels actually drawn; a fixed 1x5 grid left empty slots.
fig, axes = plt.subplots(1, len(panels), figsize=(16, 4), squeeze=False)
for ax, (title, layers) in zip(axes[0], panels):
    for image, cmap, alpha in layers:
        # Drop the trailing channel axis so imshow receives a 2-D array.
        ax.imshow(np.squeeze(image), cmap=cmap, alpha=alpha)
    ax.set_title(title)
    ax.axis('off')

fig.tight_layout()
plt.show()

In [None]:
import tensorflow as tf
import multiprocessing

# Size TensorFlow's op-level thread pools to the machine's CPU count.
# NOTE: multiprocessing.cpu_count() reports *logical* CPUs (hyper-threads
# included), not physical cores.
num_cores = multiprocessing.cpu_count()
print(f"🖥️ CPU cores detected: {num_cores}")

# These setters only work before the TensorFlow runtime is initialized. Earlier
# cells in this notebook already execute TF ops (building the PSF filter and the
# models), so TF raises RuntimeError here. Report it instead of aborting the
# run; move this cell above the first TF op for the setting to take effect.
try:
    tf.config.threading.set_intra_op_parallelism_threads(num_cores)
    tf.config.threading.set_inter_op_parallelism_threads(num_cores)
except RuntimeError as err:
    print(f"⚠️ TensorFlow threading not changed (runtime already initialized): {err}")
else:
    print("✅ Configured TensorFlow to use all CPU cores")

In [None]:
# ============================
# 🏗️ BUILD THE MODEL
# ============================

# Infer the per-sample input shape from the training data, e.g. (208, 208, 1).
# NOTE(review): this builds a freshly initialized model, so any weights loaded
# in earlier cells are discarded. Load them here if fine-tuning is intended.
input_shape = X_train.shape[1:]
model = buildModel(input_shape)

In [None]:
# ============================
# ⏱️ CALLBACKS
# ============================

# Early stopping on validation loss. patience is left at the Keras default (0),
# i.e. stop after the first epoch without improvement; raise it for longer runs.
# When it triggers, the best-val_loss weights are restored.
early_stop = EarlyStopping(monitor='val_loss', restore_best_weights=True)

# Save to disk only when val_loss reaches a new best. With the default
# save_weights_only=False, this .keras file holds the full model.
checkpoint = ModelCheckpoint(
    filepath='best_finetuned_weights.keras',
    monitor='val_loss',
    save_best_only=True,
)

In [None]:
# ============================
# 🧠 TRAINING
# ============================

# Train on the training split and validate on the held-out split each epoch.
# epochs=1 runs a single pass over the data; increase it for a full training run.
training_callbacks = [early_stop, checkpoint, live_plot_callback]
history = model.fit(
    x=X_train,
    y=y_train,
    batch_size=8,
    epochs=1,
    validation_data=(X_val, y_val),
    shuffle=True,
    callbacks=training_callbacks,
)

In [None]:

# Inspect the freshly trained model on the same sample `i` picked earlier.
# NOTE(review): `i` indexes the full dataset, so it may be a training sample.
input_patch = patches[i]
true_heatmap = heatmaps[i]
# predict expects a leading batch axis; [0] takes the single result back out.
predicted = model.predict(input_patch[np.newaxis])[0]

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for ax, image, cmap, title in zip(
    axes,
    (input_patch, true_heatmap, predicted),
    ('gray', 'hot', 'hot'),
    ("Input Patch", "True Heatmap", "Predicted Heatmap"),
):
    # Drop the trailing channel axis so imshow receives a 2-D array.
    ax.imshow(np.squeeze(image), cmap=cmap)
    ax.set_title(title)
    ax.axis('off')

fig.tight_layout()
plt.show()