In [None]:
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam

import numpy as np
import os
import json

import matplotlib.pyplot as plt
from tqdm import tqdm

import wandb

from utils.unet_model import construct_unet
from utils import img_functions

In [2]:
root = "/Users/ojas/Desktop/saj/SANDIA/pvcracks_data/Channeled_ASU/"
# root = "/Users/ojas/Desktop/saj/SANDIA/pvcracks_data/Channeled_CWRU_Dupont/"
# root = "/Users/ojas/Desktop/saj/SANDIA/pvcracks_data/Channeled_CWRU_SunEdison/"
# root = "/Users/ojas/Desktop/saj/SANDIA/pvcracks_data/Channeled_LBNL/"
# root = "/Users/ojas/Desktop/saj/SANDIA/pvcracks_data/Channeled_Combined_CWRU_LBNL_ASU/"


model_weight_paths = {
    "emma_retrained": "/Users/ojas/Desktop/saj/SANDIA/pvcracks_data/retrained_pv-vision_model.pt",
    "original": "/Users/ojas/Desktop/saj/SANDIA/pvcracks_data/pv-vision_model.pt",
}

# weight_path = model_weight_paths["emma_retrained"]
weight_path = model_weight_paths["original"]

checkpoint_name = "wandb_experiment_" + root.split("/")[-2]

In [None]:
def load_dataset(root):
    transformers = img_functions.Compose(
        [
            img_functions.ChanneledFixResize(256),
            img_functions.ToTensor(),
            img_functions.Normalize(),
        ]
    )

    train_dataset = img_functions.SolarDataset(
        root, image_folder="img/train", mask_folder="ann/train", transforms=transformers
    )

    val_dataset = img_functions.SolarDataset(
        root, image_folder="img/val", mask_folder="ann/val", transforms=transformers
    )

    return train_dataset, val_dataset

In [4]:
def load_device_and_model(weight_path):
    # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    device = torch.device("mps")
    unet = construct_unet(5)
    unet = torch.nn.DataParallel(unet)

    model = unet.module.to(device)

    return device, model

In [5]:
def get_save_dir(base_dir, checkpoint_name):
    checkpoint_dir = base_dir + "/checkpoints/"
    folders = [folder for folder in os.listdir(checkpoint_dir)]

    max_number = 0
    for folder in folders:
        number = int(folder[-1])
        if number > max_number:
            max_number = number

    new_folder_name = f"{checkpoint_name}{max_number + 1}"
    new_folder_path = os.path.join(checkpoint_dir, new_folder_name)

    os.makedirs(new_folder_path, exist_ok=True)

    return new_folder_path

In [6]:
train_dataset, val_dataset = load_dataset(root)
device, model = load_device_and_model(weight_path)

In [7]:
category_mapping = {0: "empty", 1: "dark", 2: "busbar", 3: "crack", 4: "non-cell"}

In [8]:
def new_inference_and_show(idx, threshold=0.5):
    # Get the preprocessed image and multi-hot ground truth mask
    img, mask = train_loader.dataset.__getitem__(idx)
    img = img.to(device)

    # Get the raw image for display (assuming __getraw__ returns a PIL image)
    raw_img, _ = train_loader.dataset.__getraw__(idx)

    # --- Run inference ---
    # Get raw logits from the model, then apply Sigmoid and threshold
    logits = model(img.unsqueeze(0)).detach().cpu()  # shape: [1, 5, H, W]
    probs = torch.sigmoid(logits)  # shape: [1, 5, H, W]
    pred_mask = (probs > threshold).float().squeeze(0).numpy()  # shape: [5, H, W]

    # Ground truth is assumed to be already a 5-channel multi-hot mask.
    gt_mask = mask.cpu().numpy()  # shape: [5, H, W]

    # --- Visualization ---
    # Create a grid with 3 rows and 5 columns:
    #   Row 0: Raw image (displayed only once in the first column)
    #   Row 1: Ground truth masks for each class
    #   Row 2: Predicted masks for each class
    n_classes = 5
    class_names = [f"({k}) {v}" for k, v in category_mapping.items()]

    fig, axs = plt.subplots(3, n_classes, figsize=(4 * n_classes, 12))

    # Row 0: Display raw image in first subplot; hide other subplots in this row.
    axs[0, 0].imshow(raw_img.convert("L"), cmap="viridis")
    axs[0, 0].set_title("Raw Image")
    axs[0, 0].axis("off")
    for j in range(1, n_classes):
        axs[0, j].axis("off")

    # Row 1: Ground truth for each class (each channel)
    for j in range(n_classes):
        axs[1, j].imshow(gt_mask[j], cmap="viridis")
        axs[1, j].set_title(f"GT: {class_names[j]}")
        axs[1, j].axis("off")

    # Row 2: Predictions for each class (each channel)
    for j in range(n_classes):
        axs[2, j].imshow(pred_mask[j], cmap="viridis")
        axs[2, j].set_title(f"Pred: {class_names[j]}")
        axs[2, j].axis("off")

    fig.suptitle("Retrained Model Prediction", fontsize=16)

    plt.tight_layout()
    plt.show()

In [9]:
def create_wandb_image(idx, threshold=0.5):
    # Get the preprocessed image and multi-hot ground truth mask
    img, mask = train_loader.dataset.__getitem__(idx)
    img = img.to(device)

    # --- Run inference ---
    # Get raw logits from the model, then apply Sigmoid and threshold
    logits = model(img.unsqueeze(0)).detach().cpu()  # shape: [1, 5, H, W]
    probs = torch.sigmoid(logits)  # shape: [1, 5, H, W]
    pred_mask = (probs > threshold).float().squeeze(0).numpy()  # shape: [5, H, W]

    # Ground truth is assumed to be already a 5-channel multi-hot mask.
    gt_mask = mask.cpu().numpy()  # shape: [5, H, W]

    n_classes = len(category_mapping)

    this_id_mask_images = []
    for i in range(n_classes):
        masks_dict = {
            "predictions": {
                "mask_data": pred_mask[i],
                "class_labels": category_mapping,
            },
            "ground_truth": {
                "mask_data": gt_mask[i],
                "class_labels": category_mapping,
            },
        }

        mask_img = wandb.Image(
            img,
            masks=masks_dict,
        )
        this_id_mask_images.append(mask_img)
    return this_id_mask_images

In [10]:
def create_wandb_image_for_table(idx, threshold=0.5):
    # Get the preprocessed image and multi-hot ground truth mask
    img, mask = train_loader.dataset.__getitem__(idx)
    img = img.to(device)

    # --- Run inference ---
    # Get raw logits from the model, then apply Sigmoid and threshold
    logits = model(img.unsqueeze(0)).detach().cpu()  # shape: [1, 5, H, W]
    probs = torch.sigmoid(logits)  # shape: [1, 5, H, W]
    pred_mask = (probs > threshold).float().squeeze(0).numpy()  # shape: [5, H, W]

    # Ground truth is assumed to be already a 5-channel multi-hot mask.
    gt_mask = mask.cpu().numpy()  # shape: [5, H, W]

    n_classes = len(category_mapping)

    this_id_table_info = []
    this_id_table_info.append(wandb.Image(img))

    for i in range(n_classes):
        this_id_table_info.append(
            wandb.Image(
                img,
                masks={
                    "ground_truth": {
                        "mask_data": gt_mask[i],
                        "class_labels": {0: category_mapping[i]},
                    },
                },
            )
        )
        this_id_table_info.append(
            wandb.Image(
                img,
                masks={
                    "predictions": {
                        "mask_data": pred_mask[i],
                        "class_labels": {0: category_mapping[i]},
                    },
                },
            )
        )

    return this_id_table_info

# Training

In [11]:
save_dir = get_save_dir(str(root), checkpoint_name)
os.makedirs(save_dir, exist_ok=True)

config = {
    "batch_size_val": 4,
    "batch_size_train": 4,
    "lr": 1e-4,
    "step_size": 1,
    "gamma": 0.1,
    "num_epochs": 2,
    "criterion": torch.nn.BCEWithLogitsLoss(),
}

config_serializable = config.copy()
config_serializable["criterion"] = str(config_serializable["criterion"])

with open(os.path.join(save_dir, "config.json"), "w", encoding="utf-8") as f:
    json.dump(config_serializable, f, ensure_ascii=False, indent=4)

run = wandb.init(
    project="pvcracks", entity="ojas-sanghi-university-of-arizona", config=config
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mojas-sanghi[0m ([33mojas-sanghi-university-of-arizona[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [12]:
train_loader = DataLoader(
    train_dataset, batch_size=config["batch_size_train"], shuffle=True
)
val_loader = DataLoader(val_dataset, batch_size=config["batch_size_val"], shuffle=False)

In [13]:
optimizer = Adam(model.parameters(), lr=config["lr"])
# lr_scheduler = StepLR(optimizer, step_size=config["step_size"], gamma=config["gamma"])

save_name = "model.pt"

In [14]:
# log gradients
run.watch(model, log_freq=100)

In [15]:
training_epoch_loss = []
val_epoch_loss = []

for epoch in tqdm(range(1, config["num_epochs"] + 1)):
    training_step_loss = []

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        target = target.float()

        optimizer.zero_grad()

        # forward pass
        output = model(data)

        # calc loss -- bce with logits loss applies sigmoid interally
        training_loss = config["criterion"](output, target)

        # backward pass
        training_loss.backward()
        optimizer.step()

        # record loss
        training_step_loss.append(training_loss.item())

    training_epoch_loss.append(np.array(training_step_loss).mean())

    val_step_loss = []

    for batch_idx, (data, target) in enumerate(val_loader):
        data, target = data.to(device), target.to(device)
        target = target.float()

        # forward pass
        data = data.to(device)

        output = model(data)

        # calc loss -- bce with logits loss applies sigmoid interally
        val_loss = config["criterion"](output, target)

        val_step_loss.append(val_loss.item())

    val_epoch_loss.append(np.array(val_step_loss).mean())

    print(
        f"Epoch {epoch}/{config['num_epochs']}, Training Loss: {np.array(training_step_loss).mean()}, Validation Loss: {np.array(val_step_loss).mean()}"
    )

    print("Generating predictions for wandb...")
    mask_images = []
    table = wandb.Table(
        columns=[
            "Image",
            "GT Empty",
            "Pred Empty",
            "GT Dark",
            "Pred Dark",
            "GT Busbar",
            "Pred Busbar",
            "GT Crack",
            "Pred Crack",
            "GT Non-cell",
            "Pred Non-cell",
        ]
    )
    for id in range(20):
        mask_images.extend(create_wandb_image(id))
        new_img_table = create_wandb_image_for_table(id)
        table.add_data(*new_img_table)

    print("Logging to wandb...")
    run.log(
        {
            "train_loss": np.array(training_step_loss).mean(),
            "val_loss": np.array(val_step_loss).mean(),
            "predictions": mask_images,
            "table": table,
        }
    )

    table = wandb.Table(columns=["Image"])

    print("Saving model...")
    os.makedirs(os.path.join(save_dir, f"epoch_{epoch}"), exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_dir, f"epoch_{epoch}", save_name))
    print(f"Saved model at epoch {epoch}.", end=" ")

    if epoch >= 2 and epoch < config["num_epochs"]:
        os.remove(os.path.join(save_dir, f"epoch_{epoch - 1}", save_name))
        print(f"Removed model at epoch {epoch - 1}.", end="")
    print("\n")

  0%|          | 0/2 [00:00<?, ?it/s]

Epoch 1/2, Training Loss: 0.5101762526277183, Validation Loss: 0.43658209235771844
Generating predictions for wandb...
Logging to wandb...
Saving model...


 50%|█████     | 1/2 [00:49<00:49, 49.94s/it]

Saved model at epoch 1. 

Epoch 2/2, Training Loss: 0.40933988232543506, Validation Loss: 0.38161036242609436
Generating predictions for wandb...
Logging to wandb...
Saving model...


100%|██████████| 2/2 [01:40<00:00, 50.33s/it]

Saved model at epoch 2. 






---

In [16]:
run.finish()

0,1
train_loss,█▁
val_loss,█▁

0,1
train_loss,0.40934
val_loss,0.38161
