In [None]:
import os
# Move the working directory one level up (presumably the project root, so the
# local packages imported below resolve). `workbookDir` doubles as a
# once-per-kernel marker: re-running this cell must not keep climbing the tree.
first_run = 'workbookDir' not in globals()
if first_run:
    print('Updating working directory')
    workbookDir = os.path.dirname(os.getcwd())
    os.chdir(workbookDir)
print(os.getcwd())

from python_code.util import  preprocess
from python_code.util.binary_loss import BinaryLoss
from trace_brightfield.util_deep_learning import predict_3D_stack
from trace_brightfield.UNet_3D_model import UNet_3D

import tifffile
import torch
from pathlib import Path
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
from torch import tensor
import numpy as np
from torch import nn, cat
import torch.nn.functional as F
from torch.nn import ReLU, MaxPool3d, MSELoss, ConvTranspose3d, Conv3d,BCELoss, CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader

%load_ext autoreload
%autoreload 2

In [None]:
class CustomImageDataset(Dataset):
    """3D image/target pairs stored as same-named ``.tif`` files in two folders.

    Every ``.tif`` file found in ``folder_imgs`` is paired with the file of the
    same name in ``folder_targets``. Each item is returned as a pair of float
    tensors with a leading channel axis added (``[D, W, H]`` -> ``[1, D, W, H]``).

    Args:
        folder_imgs: folder containing the input stacks.
        folder_targets: folder containing the target (annotation) stacks.
        augment: when True (default), randomly flip along axis 1 and along
            axis 2, each with probability 0.5. The same flip is applied to the
            image and its target so they stay aligned. Set to False for
            validation / inference datasets.
    """

    def __init__(self, folder_imgs, folder_targets, augment=True):
        self.folder_imgs = folder_imgs
        self.folder_targets = folder_targets
        self.augment = augment
        valid_suffix = {'.tif'}

        # Sorted so the index -> file mapping is reproducible; Path.iterdir()
        # yields entries in an arbitrary, filesystem-dependent order.
        self.file_names = sorted(
            p.name for p in Path(self.folder_imgs).iterdir()
            if p.is_file() and p.suffix in valid_suffix
        )

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        file_name = self.file_names[idx]

        img = tifffile.imread(Path(self.folder_imgs, file_name))
        target = tifffile.imread(Path(self.folder_targets, file_name))

        # Preprocess image for normalization
        #img = preprocess.preprocess_3d_stack_for_AI_segmentation(img)

        if self.augment:
            # Slicing only creates a view, so the result must be assigned back.
            # The target gets the identical flip, otherwise the labels would no
            # longer line up with the image.
            if np.random.uniform(0, 1) > 0.5:
                # randomly flip along axis 1 (W)
                img = img[:, ::-1, :]
                target = target[:, ::-1, :]
            if np.random.uniform(0, 1) > 0.5:
                # randomly flip along axis 2 (H)
                img = img[:, :, ::-1]
                target = target[:, :, ::-1]

        # Optional intensity jitter (disabled):
        # if np.random.uniform(0, 1) > 0.5:
        #     img += np.random.uniform(-0.1, 0.1)

        # img is a 3D image with [D, W, H]; the network requires [C, D, W, H] with C = 1.
        # np.ascontiguousarray materializes flipped views: torch cannot build a
        # tensor from a numpy array with negative strides.
        img = tensor(np.ascontiguousarray(img)).float().unsqueeze(0)
        target = tensor(np.ascontiguousarray(target)).float().unsqueeze(0)

        return img, target

In [None]:
# Training data: input stacks and their targets are paired by file name.
folder_imgs = r"E:\SPERM\Training_dataset\2024_12_26_flagellum_head_brightfield\input"
folder_targets = r"E:\SPERM\Training_dataset\2024_12_26_flagellum_head_brightfield\target"
BATCH_SIZE = 16
SHOW_SAMPLE_BATCH = False  # set True to plot 8 (input, target) max-projection pairs from one batch

dataset = CustomImageDataset(folder_imgs, folder_targets)

# With drop_last=True a dataset smaller than one batch yields zero batches, and the
# training loop would later divide its summed loss by zero. Fail early instead.
if len(dataset) < BATCH_SIZE:
    raise ValueError(
        f"Need at least {BATCH_SIZE} training stacks in {folder_imgs!r}, found {len(dataset)}"
    )
img, target = dataset[0]

# Create the dataloader
train_dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, drop_last=True)

# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"The total of images in dataset is: {len(dataset)}")
print(f"The input image shape is: {img.shape}")
print(f"Running device: {device}")

# Real (flag-guarded) code instead of a bare string literal: a string as the last
# expression of a cell is echoed by Jupyter as the cell's output.
if SHOW_SAMPLE_BATCH:
    images, targets = next(iter(train_dataloader))

    # 4x4 grid: each row holds two (input, target) pairs side by side
    fig, axes = plt.subplots(4, 4, figsize=(12, 12))
    axes = axes.flatten()
    for i in range(min(8, len(images))):
        ax = axes[2*i]
        ax.imshow(np.max(images[i].squeeze(0).cpu().numpy(), axis=0), cmap='gray')
        ax.axis('off')
        ax = axes[2*i+1]
        ax.imshow(np.max(targets[i].squeeze(0).cpu().numpy(), axis=0), cmap='gray')
        ax.axis('off')
    plt.tight_layout()

In [None]:
def save_samples(folder_path, model, device, label=""):
    """Predict every ``.tif`` stack in ``folder_path`` and save the results.

    For each stack, ``predict_3D_stack`` is run and its output is written to
    ``<folder_path>/results/Network_output_<stem>_epoch_<label>.tif``. A summary
    figure with one row per stack is saved as
    ``<folder_path>/results/Network_output_epoch_<label>``; matplotlib appends its
    default image extension. Each row shows the max projection along axis 0 of
    the input (left) and of the prediction (right).

    Args:
        folder_path: folder containing the test stacks.
        model: torch model passed to ``predict_3D_stack``. It is switched to eval
            mode for the predictions and restored to its previous mode afterwards.
        device: device passed to ``predict_3D_stack``.
        label: suffix for the output names; non-string labels (e.g. an epoch
            number) are zero-padded to 5 digits.
    """
    folder_output = Path(folder_path, "results")
    folder_output.mkdir(parents=True, exist_ok=True)
    # Sorted for a stable row order in the figure across epochs.
    file_names = sorted(
        p.name for p in Path(folder_path).iterdir() if p.is_file() and p.suffix in {'.tif'}
    )

    if not isinstance(label, str):
        label = f"{label:05}"

    if not file_names:
        return

    # squeeze=False keeps `axes` 2-D even for a single file; grow the figure
    # height once there are more rows than fit comfortably in 12 inches.
    n_rows = len(file_names)
    fig, axes = plt.subplots(n_rows, 2, figsize=(12, max(12, 3 * n_rows)), squeeze=False)

    # This is called from inside the training loop, where the model is in train
    # mode; predictions must use eval mode (dropout off, running norm stats) and
    # need no autograd graph.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for row, fn in enumerate(file_names):
                img = tifffile.imread(Path(folder_path, fn))
                network_output = predict_3D_stack(img, model, device=device)

                ax_img, ax_pred = axes[row, 0], axes[row, 1]
                ax_img.imshow(np.max(img, axis=0), cmap='gray')
                ax_img.axis('off')
                ax_pred.imshow(np.max(network_output, axis=0), cmap='gray')
                ax_pred.axis('off')

                tifffile.imwrite(
                    Path(folder_output, f"Network_output_{Path(fn).stem}_epoch_{label}.tif"),
                    network_output,
                )
        fig.savefig(Path(folder_output, f"Network_output_epoch_{label}"))
    finally:
        # Close this specific figure (not just the "current" one) even on error,
        # so figures do not accumulate across epochs; then restore the model mode.
        plt.close(fig)
        model.train(was_training)

In [None]:
# create model and test it
modelo = UNet_3D().to(device)
#modelo.load_state_dict(torch.load(r"E:\SPERM\Training_dataset\2024_12_24_flagellum_head_brightfield\model\modelo_UNet_3D_epoch_00099.pth", weights_only=True))
X = torch.rand(size=(10,1, 9, 101, 96), dtype=torch.float32, device=device)        
out = modelo(X)
out.shape

In [None]:
# Folder of unseen stacks that is re-predicted after every epoch to track generalization.
# NOTE(review): this points at the 2024_12_24 dataset while the training data comes
# from 2024_12_26 — confirm that is intended.
folder_test = r"E:\SPERM\Training_dataset\2024_12_24_flagellum_head_brightfield\test"

# Loss terms: the training loop optimizes Binary_loss; BCE_loss is kept for the
# combined objective that is commented out in that loop.
BCE_loss = BCELoss()
Binary_loss = BinaryLoss()

# Adam over every model parameter, learning rate 1e-3.
optimizer = Adam(params=modelo.parameters(), lr=1e-3)

In [None]:
NUM_EPOCHS = 100
CHECKPOINT_EVERY = 5  # epochs between saved state_dict checkpoints

# With drop_last=True an under-sized dataset yields zero batches; catch that before
# training instead of hitting a ZeroDivisionError when the epoch mean is computed.
if len(train_dataloader) == 0:
    raise RuntimeError(
        "train_dataloader yields no batches: the dataset is smaller than one batch (drop_last=True)."
    )

for epoch in range(NUM_EPOCHS):
    # Set the model in training mode (other cells may have switched it to eval).
    modelo.train()
    epoch_error = 0
    counter = 0
    for imgs, targets in train_dataloader:
        # Cast annotations and inputs to float and move them to the training device.
        targets = targets.float().to(device)
        imgs = imgs.float().to(device)

        optimizer.zero_grad()  # reset gradients accumulated by the previous step

        network_output = modelo(imgs)  # forward pass

        # The network outputs logits; the sigmoid maps them to [0, 1] for the loss.
        loss = Binary_loss(torch.sigmoid(network_output), targets) #+ BCE_loss(torch.sigmoid(network_output), targets) # Loss function
        #loss = BCE_loss(torch.sigmoid(network_output), targets)  # Loss function

        loss.backward()  # compute gradients of the loss
        optimizer.step()  # update the weights using those gradients

        epoch_error += loss.item()
        counter += 1

    # Checkpoint periodically and always after the last epoch, which the
    # plain modulo schedule would otherwise skip.
    if epoch % CHECKPOINT_EVERY == 0 or epoch == NUM_EPOCHS - 1:
        torch.save(modelo.state_dict(), f'modelo_UNet_3D_epoch_{epoch:05}.pth')
    # Mean per-batch loss of this epoch.
    print(f"El error epoch = {epoch} es {epoch_error/counter}")
    save_samples(folder_test, modelo, device, label=epoch)
    

In [None]:
# Save the learned parameters, then reload them into a fresh model to check that
# the checkpoint round-trips before producing the final predictions.
# NOTE: only the state_dict is written (not the pickled module), despite the file
# name and the message below; reloading requires constructing UNet_3D first.
torch.save(modelo.state_dict(), 'full_model.pth')
print("Entire model saved!")

modelo_final = UNet_3D().to(device)  # Reinitialize model
# map_location lets the file be loaded on a machine whose devices differ from
# the one that saved it.
modelo_final.load_state_dict(torch.load('full_model.pth', map_location=device, weights_only=True))
# Inference only: eval mode disables dropout / uses running norm statistics, if present.
modelo_final.eval()

save_samples(folder_test, modelo_final, device, label="_final")