In [None]:
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam

import numpy as np
import os
import json

from tqdm import tqdm

import wandb

from utils import train_functions

In [None]:
# root = "/Users/ojas/Desktop/saj/SANDIA/pvcracks_data/Channeled_ASU/"
# root = "/Users/ojas/Desktop/saj/SANDIA/pvcracks_data/Channeled_CWRU_Dupont/"
# root = "/Users/ojas/Desktop/saj/SANDIA/pvcracks_data/Channeled_CWRU_SunEdison/"
# root = "/Users/ojas/Desktop/saj/SANDIA/pvcracks_data/Channeled_LBNL/"
root = "/Users/ojas/Desktop/saj/SANDIA/pvcracks_data/Channeled_Combined_CWRU_LBNL_ASU/"


model_weight_paths = {
    "emma_retrained": "/Users/ojas/Desktop/saj/SANDIA/pvcracks_data/retrained_pv-vision_model.pt",
    "original": "/Users/ojas/Desktop/saj/SANDIA/pvcracks_data/pv-vision_model.pt",
}

# weight_path = model_weight_paths["emma_retrained"]
weight_path = model_weight_paths["original"]

checkpoint_name = "wandb_experiment_" + root.split("/")[-2]

In [None]:
def dice_coefficient(pred, target, epsilon=1e-6):
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum()
    dice = (2.0 * intersection + epsilon) / (union + epsilon)
    return dice


def iou_score(pred, target, epsilon=1e-6):
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection
    iou = (intersection + epsilon) / (union + epsilon)
    return iou

In [None]:
train_dataset, val_dataset = train_functions.load_dataset(root)
device, model = train_functions.load_device_and_model()

In [7]:
category_mapping = {0: "empty", 1: "dark", 2: "busbar", 3: "crack", 4: "non-cell"}

In [9]:
def create_wandb_image(idx, threshold=0.5):
    # Get the preprocessed image and multi-hot ground truth mask
    img, mask = train_loader.dataset.__getitem__(idx)
    img = img.to(device)

    # --- Run inference ---
    # Get raw logits from the model, then apply Sigmoid and threshold
    logits = model(img.unsqueeze(0)).detach().cpu()  # shape: [1, 5, H, W]
    probs = torch.sigmoid(logits)  # shape: [1, 5, H, W]
    pred_mask = (probs > threshold).float().squeeze(0).numpy()  # shape: [5, H, W]

    # Ground truth is assumed to be already a 5-channel multi-hot mask.
    gt_mask = mask.cpu().numpy()  # shape: [5, H, W]

    n_classes = len(category_mapping)

    this_id_mask_images = []
    for i in range(n_classes):
        masks_dict = {
            "predictions": {
                "mask_data": pred_mask[i],
                "class_labels": category_mapping,
            },
            "ground_truth": {
                "mask_data": gt_mask[i],
                "class_labels": category_mapping,
            },
        }

        mask_img = wandb.Image(
            img,
            masks=masks_dict,
        )
        this_id_mask_images.append(mask_img)
    return this_id_mask_images

In [10]:
def create_wandb_image_for_table(idx, threshold=0.5):
    # Get the preprocessed image and multi-hot ground truth mask
    img, mask = train_loader.dataset.__getitem__(idx)
    img = img.to(device)

    # --- Run inference ---
    # Get raw logits from the model, then apply Sigmoid and threshold
    logits = model(img.unsqueeze(0)).detach().cpu()  # shape: [1, 5, H, W]
    probs = torch.sigmoid(logits)  # shape: [1, 5, H, W]
    pred_mask = (probs > threshold).float().squeeze(0).numpy()  # shape: [5, H, W]

    # Ground truth is assumed to be already a 5-channel multi-hot mask.
    gt_mask = mask.cpu().numpy()  # shape: [5, H, W]

    n_classes = len(category_mapping)

    this_id_table_info = []
    this_id_table_info.append(wandb.Image(img))

    for i in range(n_classes):
        this_id_table_info.append(
            wandb.Image(
                img,
                masks={
                    "ground_truth": {
                        "mask_data": gt_mask[i],
                        "class_labels": {0: category_mapping[i]},
                    },
                },
            )
        )
        this_id_table_info.append(
            wandb.Image(
                img,
                masks={
                    "predictions": {
                        "mask_data": pred_mask[i],
                        "class_labels": {0: category_mapping[i]},
                    },
                },
            )
        )

    return this_id_table_info

# Training

In [None]:
save_dir = train_functions.get_save_dir(str(root), checkpoint_name)
os.makedirs(save_dir, exist_ok=True)

original_config = {
    "batch_size_train": 8,
    "lr": 0.00092234,
    "gamma": 0.11727,
    "num_epochs": 45,
    "batch_size_val": 8,
    "criterion": torch.nn.BCEWithLogitsLoss(),
    # "lr_scheduler_step_size": 1,
}

config_serializable = original_config.copy()
config_serializable["criterion"] = str(config_serializable["criterion"])

with open(os.path.join(save_dir, "config.json"), "w", encoding="utf-8") as f:
    json.dump(config_serializable, f, ensure_ascii=False, indent=4)

run = wandb.init(
    project="pvcracks",
    entity="ojas-sanghi-university-of-arizona",
    config=original_config,
)
config = wandb.config

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mojas-sanghi[0m ([33mojas-sanghi-university-of-arizona[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
train_loader = DataLoader(
    train_dataset, batch_size=config.batch_size_train, shuffle=True
)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size_val, shuffle=False)

In [None]:
optimizer = Adam(model.parameters(), lr=config.lr)
# lr_scheduler = StepLR(optimizer, step_size=config.lr_scheduler_step_size, gamma=config.gamma)

save_name = "model.pt"

In [14]:
# log gradients
run.watch(model, log_freq=100)

In [None]:
training_epoch_loss = []
val_epoch_loss = []

for epoch in tqdm(range(1, config.num_epochs + 1)):
    training_step_loss = []

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        target = target.float()

        optimizer.zero_grad()

        # forward pass
        output = model(data)

        # calc loss -- bce with logits loss applies sigmoid interally
        training_loss = original_config["criterion"](output, target)

        # backward pass
        training_loss.backward()
        optimizer.step()

        # record loss
        training_step_loss.append(training_loss.item())

    training_epoch_loss.append(np.array(training_step_loss).mean())

    val_step_loss = []

    for batch_idx, (data, target) in enumerate(val_loader):
        data, target = data.to(device), target.to(device)
        target = target.float()

        # forward pass
        data = data.to(device)

        output = model(data)

        # calc loss -- bce with logits loss applies sigmoid interally
        val_loss = original_config["criterion"](output, target)

        val_step_loss.append(val_loss.item())

    val_epoch_loss.append(np.array(val_step_loss).mean())
    
    # Compute dice and IoU metrics per channel
    pred_probs = torch.sigmoid(output)
    pred_binary = (pred_probs > 0.5).float()

    dice_scores = []
    iou_scores = []

    for i in range(pred_binary.size(1)):  # Loop over channels
        dice = dice_coefficient(pred_binary[:, i], target[:, i])
        iou = iou_score(pred_binary[:, i], target[:, i])
        dice_scores.append(dice.item())
        iou_scores.append(iou.item())

    avg_dice = np.mean(dice_scores)
    avg_iou = np.mean(iou_scores)

    print(
        f"Epoch {epoch}/{config.num_epochs}, Training Loss: {np.array(training_step_loss).mean()}, Validation Loss: {np.array(val_step_loss).mean()}, Avg Dice: {avg_dice}, Avg IoU: {avg_iou}"
    )

    # print("Generating predictions for wandb...")
    # mask_images = []
    # table = wandb.Table(
    #     columns=[
    #         "Image",
    #         "GT Empty",
    #         "Pred Empty",
    #         "GT Dark",
    #         "Pred Dark",
    #         "GT Busbar",
    #         "Pred Busbar",
    #         "GT Crack",
    #         "Pred Crack",
    #         "GT Non-cell",
    #         "Pred Non-cell",
    #     ]
    # )
    # for id in range(20):
    #     mask_images.extend(create_wandb_image(id))
    #     new_img_table = create_wandb_image_for_table(id)
    #     table.add_data(*new_img_table)

    print("Logging to wandb...")
    run.log(
        {
            "train_loss": np.array(training_step_loss).mean(),
            "val_loss": np.array(val_step_loss).mean(),
            # "predictions": mask_images,
            # "table": table,
        }
    )

    table = wandb.Table(columns=["Image"])

    print("Saving model...")
    os.makedirs(os.path.join(save_dir, f"epoch_{epoch}"), exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_dir, f"epoch_{epoch}", save_name))
    print(f"Saved model at epoch {epoch}.", end=" ")

    if epoch >= 2 and epoch < config.num_epochs:
        os.remove(os.path.join(save_dir, f"epoch_{epoch - 1}", save_name))
        print(f"Removed model at epoch {epoch - 1}.", end="")
    print("\n")

  0%|          | 0/2 [00:00<?, ?it/s]

Epoch 1/2, Training Loss: 0.5101762526277183, Validation Loss: 0.43658209235771844
Generating predictions for wandb...
Logging to wandb...
Saving model...


 50%|█████     | 1/2 [00:49<00:49, 49.94s/it]

Saved model at epoch 1. 

Epoch 2/2, Training Loss: 0.40933988232543506, Validation Loss: 0.38161036242609436
Generating predictions for wandb...
Logging to wandb...
Saving model...


100%|██████████| 2/2 [01:40<00:00, 50.33s/it]

Saved model at epoch 2. 






---

In [16]:
run.finish()

0,1
train_loss,█▁
val_loss,█▁

0,1
train_loss,0.40934
val_loss,0.38161
