In [None]:
import os
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
import timm
import lightning as pl
import mlflow.pytorch
from lightning.pytorch.loggers import MLFlowLogger
import torchvision
import sys
# Add the project directory to the module search path; presumably this is what lets the
# `src.*` imports in the following cell resolve — TODO confirm.
# NOTE(review): this is a machine-specific absolute path with mixed separators; the
# notebook will not import the project on another machine without editing it.
sys.path.append(r'd:\DeepLearning\Projects\Arecanut_segmentation/')

In [None]:
# Project-local imports: the `unet` wrapper, the UnetResnetSentinel2 class handed to it,
# and the data module used to build the train/validation loaders.
from src.unet_sentinel_resnet import unet, UnetResnetSentinel2
from src.data_loader import SegmentDataLoader

In [None]:
# Dataset location and the batch size handed to the data module.
data_dir = r"D:\DeepLearning\Projects\Arecanut_segmentation\Data\augmented_128"
batch_size = 25

# Build the data module and run its setup step before requesting any loaders.
data_module = SegmentDataLoader(data_dir, batch_size=batch_size)
data_module.setup()

# Training loader first, validation loader second (evaluated left to right).
train_loader, val_loader = (
    data_module.train_dataloader(),
    data_module.val_dataloader(),
)


In [None]:
# Build the model through the project's `unet` wrapper, passing the UnetResnetSentinel2
# class followed by 3, 15 and the loss name 'jaccard'.
# NOTE(review): the commented-out checkpoint-loading cell later in this notebook passes
# encoder=UnetResnetSentinel2, nc=3, c=15, loss='jaccard' — presumably the same four
# parameters in this order; confirm against src.unet_sentinel_resnet.
model = unet(UnetResnetSentinel2,3,15,'jaccard')

In [None]:
# Stop training once the validation Jaccard index has failed to improve for 20
# consecutive checks (higher is better, any increase counts as an improvement).
# The validation metric is monitored — the same key ModelCheckpoint tracks — because a
# training-set score typically keeps rising and would rarely trigger the stop.
# NOTE(review): this callback is not included in the Trainer's `callbacks` list in the
# training cell, so it is currently inactive.
early_stop_callback = EarlyStopping(
    monitor="val_JaccardIndex",
    min_delta=0.00,
    patience=20,
    verbose=False,
    mode="max",
)

In [None]:
# Checkpointing policy: rank checkpoints by "val_JaccardIndex" (higher is better),
# keep the ten best, and additionally keep the most recent one.
checkpoint_callback = ModelCheckpoint(
    # Where the files go and how they are named.
    dirpath=r"Arecanut_segmentation/models/",
    filename="arecanut-unetresnetsent2-128-{epoch:02d}-{val_JaccardIndex}",
    # What decides which checkpoints are kept.
    monitor="val_JaccardIndex",
    mode="max",
    save_top_k=10,
    save_last=True,
)



In [None]:
# Send MLflow tracking to the server at 127.0.0.1:5000 and select the experiment.
mlflow.set_tracking_uri("http://127.0.0.1:5000")
mlflow.set_experiment('Training unet')

with mlflow.start_run(log_system_metrics=True) as run:
    # Bind the Lightning logger to the run opened above (same experiment name,
    # tracking URI and run id) instead of letting it create a separate run.
    mlf_logger = MLFlowLogger(
        experiment_name=mlflow.get_experiment(run.info.experiment_id).name,
        tracking_uri=mlflow.get_tracking_uri(),
        log_model='all',
        run_id=run.info.run_id,
        # save_dir = r"Projects\Arecanut_segmentation\NB\mlruns"
    )
    trainer = pl.Trainer(
        accelerator='gpu',
        max_epochs=1,
        callbacks=[checkpoint_callback],
        logger=mlf_logger,
    )
    # `data_module` is a data module, not a dataloader, so hand it over through the
    # dedicated `datamodule=` keyword rather than `train_dataloaders=`.
    trainer.fit(model=model, datamodule=data_module)
    # mlflow.pytorch.log_model(model, "Unet")

In [None]:
# Ensure that the model class is used directly for loading the checkpoint.
# model1 = unet.load_from_checkpoint(
#     checkpoint_path=r'D:\DeepLearning\Projects\Arecanut_segmentation\models\arecanut-unetresnetsent2-128-epoch=99-val_JaccardIndex=0.36997315287590027.ckpt',
#     encoder = UnetResnetSentinel2,
#     nc=3,
#     c=15,
#     loss='jaccard'
# )


In [None]:
# import numpy as np

# # Fetch a batch from the training loader
# images, masks = next(iter(train_loader))

# # Switch the model to evaluation mode
# with torch.no_grad():
#     model1.eval()
#     logits = model1(images.to('cuda'))  # Get raw logits from the model

# # Apply softmax to get class probabilities
# # Shape: [batch_size, num_classes, H, W]

# pr_masks = logits.softmax(dim=1)
# # Convert class probabilities to predicted class labels
# pr_masks = pr_masks.argmax(dim=1)  # Shape: [batch_size, H, W]

# # Visualize a few samples (image, ground truth mask, and predicted mask)
# for idx, (image, gt_mask, pr_mask) in enumerate(zip(images, masks, pr_masks)):
#     if idx <= 4:  # Visualize first 5 samples
#         plt.figure(figsize=(12, 6))

#         # Original Image
#         plt.subplot(1, 3, 1)
#         plt.imshow(
#             image.cpu().numpy().transpose(1, 2, 0)
#         )  # Convert CHW to HWC for plotting
#         plt.title("Image")
#         plt.axis("off")

#         # Ground Truth Mask
#         plt.subplot(1, 3, 2)
#         plt.imshow(gt_mask.cpu().numpy(), cmap="tab20")  # Visualize ground truth mask
#         plt.title("Ground truth")
#         plt.axis("off")

#         # Predicted Mask
#         plt.subplot(1, 3, 3)
#         plt.imshow(pr_mask.cpu().numpy(), cmap="tab20")  # Visualize predicted mask
#         plt.title("Prediction")
#         plt.axis("off")

#         # Show the figure
#         plt.show()
#     else:
#         break