In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
import torchvision.transforms.functional as F

import numpy as np
import os
import sys
import json
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.patches as mpatches
from tqdm import tqdm

from sklearn.model_selection import train_test_split

# Local checkouts of the helper repositories. Each location can be overridden
# with an environment variable so the notebook runs on machines other than the
# ones these defaults were written for.
pv_vision_dir = os.environ.get(
    "PV_VISION_DIR", os.path.join("/home/eccoope", "pv-vision")
)
functions_dir = os.environ.get(
    "PV_FUNCTIONS_DIR", os.path.join("/home/eccoope", "el_img_cracks_ec", "scripts")
)
ojas_functions_dir = os.environ.get(
    "PV_RETRAIN_DIR", "/Users/ojas/Desktop/saj/SANDIA/pvcracks/retrain/"
)

# Append each directory at most once so re-running this cell does not keep
# growing sys.path with duplicates.
for _extra_dir in (pv_vision_dir, functions_dir, ojas_functions_dir):
    if _extra_dir not in sys.path:
        sys.path.append(_extra_dir)

from utils.unet_model import construct_unet
import functions
from torch.utils.data import random_split
import wandb

ModuleNotFoundError: No module named 'tutorials'

In [None]:
# Dataset root; load_dataset reads image_folder="img/all" and mask_folder="ann/all" under it.
root = "/Users/ojas/Desktop/saj/SANDIA/pvcracks_data/Channeled_Combined_CWRU_LBNL_ASU_No_Empty/"


# Candidate weight files for the model.
model_weight_paths = {
    "emma_retrained": "/Users/ojas/Desktop/saj/SANDIA/pvcracks_data/retrained_pv-vision_model.pt",
    "original": "/Users/ojas/Desktop/saj/SANDIA/pvcracks_data/pv-vision_model.pt",
}

# weight_path = model_weight_paths["emma_retrained"]
weight_path = model_weight_paths["original"]

# Checkpoint folders are named after the dataset directory. normpath strips a
# trailing separator, so this yields the directory name whether or not `root`
# ends with "/" (split("/")[-2] returned the *parent* name without one).
checkpoint_name = os.path.basename(os.path.normpath(root))

In [None]:
# Class index -> class name. construct_unet is built with one output per entry.
category_mapping = dict(enumerate(("dark", "busbar", "crack", "non-cell")))

In [None]:
def dice_coefficient(pred, target, epsilon=1e-6):
    """Smoothed Dice coefficient between two binary masks.

    Computes ``(2 * |pred ∩ target| + eps) / (|pred| + |target| + eps)``, where
    the intersection is the sum of the element-wise product. ``epsilon`` keeps the
    ratio defined (equal to 1) when both masks are empty.

    Args:
        pred: Binary prediction (tensor or array supporting ``*`` and ``.sum()``).
        target: Binary ground truth of the same shape.
        epsilon: Smoothing constant added to numerator and denominator.

    Returns:
        The Dice score as a 0-d value of the input's type.
    """
    overlap = (pred * target).sum()
    numerator = 2.0 * overlap + epsilon
    denominator = pred.sum() + target.sum() + epsilon
    return numerator / denominator


def iou_score(pred, target, epsilon=1e-6):
    """Smoothed intersection-over-union (Jaccard index) between two binary masks.

    Computes ``(|pred ∩ target| + eps) / (|pred| + |target| - |pred ∩ target| + eps)``.
    With both masks empty the result is 1.

    Args:
        pred: Binary prediction (tensor or array supporting ``*`` and ``.sum()``).
        target: Binary ground truth of the same shape.
        epsilon: Smoothing constant added to numerator and denominator.

    Returns:
        The IoU as a 0-d value of the input's type.
    """
    overlap = (pred * target).sum()
    combined = pred.sum() + target.sum() - overlap
    return (overlap + epsilon) / (combined + epsilon)

In [None]:
def load_dataset(root):
    """Build the full ``functions.SolarDataset`` for ``root``.

    The dataset is given ``image_folder="img/all"`` and ``mask_folder="ann/all"``
    (presumably relative to ``root`` — confirm in SolarDataset). Samples are
    resized with ChanneledFixResize(256), converted with ToTensor, and normalized
    with Normalize, in that order.
    """
    pipeline = functions.Compose(
        [
            functions.ChanneledFixResize(256),
            functions.ToTensor(),
            functions.Normalize(),
        ]
    )
    return functions.SolarDataset(
        root,
        image_folder="img/all",
        mask_folder="ann/all",
        transforms=pipeline,
    )

In [None]:
def load_device_and_model(weight_path, load_weights=False):
    """Pick a compute device and build the segmentation U-Net on it.

    Args:
        weight_path: Path to a saved state dict. Only read when ``load_weights``
            is True; by default the model keeps the initialization produced by
            ``construct_unet``.
        load_weights: If True, load ``weight_path`` into the model. A leading
            ``"module."`` on the keys (as written by ``torch.nn.DataParallel``)
            is stripped first, so both wrapped and unwrapped checkpoints work.

    Returns:
        ``(device, model)``: the selected ``torch.device`` and the model moved to it.
    """
    # Prefer CUDA, then Apple MPS, then CPU. Hard-coding "mps" fails on any
    # machine without MPS support.
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    # One output channel per entry in category_mapping.
    model = construct_unet(len(category_mapping))

    if load_weights:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(weight_path, map_location="cpu")
        prefix = "module."
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
        model.load_state_dict(state_dict)

    return device, model.to(device)

In [None]:
def get_save_dir(base_dir, checkpoint_name):
    """Create and return a new, uniquely numbered checkpoint folder.

    Folders live in ``<base_dir>/checkpoints/`` and are named
    ``f"{checkpoint_name}{n}"``. ``n`` is one more than the largest number used
    by an existing folder with that prefix, starting at 1.

    Args:
        base_dir: Directory under which the ``checkpoints`` folder is kept.
        checkpoint_name: Prefix for the numbered run folders.

    Returns:
        Path of the newly created folder.
    """
    import re

    checkpoint_dir = os.path.join(base_dir, "checkpoints")
    # The first run on a new dataset has no checkpoints folder yet.
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Parse the whole numeric suffix. Reading only the last character broke at
    # 10+ runs and made later runs reuse an existing folder. Files and unrelated
    # entries (e.g. .DS_Store) are skipped instead of raising ValueError.
    suffix_pattern = re.compile(re.escape(checkpoint_name) + r"(\d+)")
    max_number = 0
    for entry in os.listdir(checkpoint_dir):
        match = suffix_pattern.fullmatch(entry)
        if match and os.path.isdir(os.path.join(checkpoint_dir, entry)):
            max_number = max(max_number, int(match.group(1)))

    new_folder_path = os.path.join(checkpoint_dir, f"{checkpoint_name}{max_number + 1}")
    os.makedirs(new_folder_path, exist_ok=True)

    return new_folder_path

In [None]:
# Full dataset for `root`, with the resize/ToTensor/Normalize transforms from load_dataset.
full_dataset = load_dataset(root)

In [None]:
# 90/10 train/test split; any remainder from int() truncation goes to the test split.
train_size = int(0.9 * len(full_dataset))
test_size = len(full_dataset) - train_size

# A fixed generator seed (42) makes the split identical on every run.
train_subset, test_subset = random_split(
    full_dataset, [train_size, test_size], generator=torch.Generator().manual_seed(42)
)

In [None]:
# This wrapper is needed so that we can (a) split the dataset into train/test using the same seed as the wandb_k_fold run, and (b) keep SolarDataset methods such as __getraw__ available when doing inference.


class SubsetWithRaw(torch.utils.data.Subset):
    """``Subset`` that also forwards the parent dataset's ``__getraw__``.

    ``idx`` is a position within this subset. It is mapped through
    ``self.indices`` before being passed to the wrapped dataset.
    """

    def __getraw__(self, idx):
        parent_index = self.indices[idx]
        return self.dataset.__getraw__(parent_index)


# Re-wrap the seeded split's indices so __getraw__ is available on each split.
train_set = SubsetWithRaw(full_dataset, train_subset.indices)
test_set = SubsetWithRaw(full_dataset, test_subset.indices)

In [None]:
# Select the device and construct the model.
# NOTE(review): weight_path is passed in, but as written the weights are not loaded into the model — confirm this is intended.
device, model = load_device_and_model(weight_path)

In [None]:
import matplotlib.pyplot as plt


def new_inference_and_show(idx, threshold=0.5):
    """Run the model on one test sample and plot ground truth vs. prediction per class.

    Draws a 3-row grid:
      * row 0: the raw image, first column only;
      * row 1: the ground-truth mask for each class;
      * row 2: the thresholded prediction for each class.

    Args:
        idx: Index into ``test_loader.dataset``.
        threshold: Sigmoid probability above which a pixel counts as positive.

    Relies on the notebook globals ``model``, ``device``, ``test_loader`` and
    ``category_mapping``. The model's train/eval mode is restored afterwards.
    """
    dataset = test_loader.dataset

    # Preprocessed image tensor and its multi-hot ground-truth mask.
    img, mask = dataset.__getitem__(idx)
    # Untransformed image for display; .convert("L") below assumes a PIL image.
    raw_img, _ = dataset.__getraw__(idx)

    # --- Run inference ---
    # Use eval() and no_grad(). In train mode, normalization layers (if any) would
    # use this single image's statistics and update their running stats, silently
    # changing the model being trained. Restore the caller's mode afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits = model(img.unsqueeze(0).to(device)).cpu()  # shape: [1, C, H, W]
    finally:
        model.train(was_training)

    probs = torch.sigmoid(logits)
    pred_mask = (probs > threshold).float().squeeze(0).numpy()  # shape: [C, H, W]

    # Ground truth is assumed to already be a C-channel multi-hot mask.
    gt_mask = mask.cpu().numpy()  # shape: [C, H, W]

    # --- Visualization ---
    n_classes = len(category_mapping)
    class_names = [f"({k}) {v}" for k, v in category_mapping.items()]

    fig, axs = plt.subplots(3, n_classes, figsize=(4 * n_classes, 12))

    # Row 0: raw image in the first subplot; the rest of the row is hidden.
    axs[0, 0].imshow(raw_img.convert("L"), cmap="viridis")
    axs[0, 0].set_title("Raw Image")
    axs[0, 0].axis("off")
    for j in range(1, n_classes):
        axs[0, j].axis("off")

    # Row 1: ground truth, one channel per class.
    for j in range(n_classes):
        axs[1, j].imshow(gt_mask[j], cmap="viridis")
        axs[1, j].set_title(f"GT: {class_names[j]}")
        axs[1, j].axis("off")

    # Row 2: thresholded predictions, one channel per class.
    for j in range(n_classes):
        axs[2, j].imshow(pred_mask[j], cmap="viridis")
        axs[2, j].set_title(f"Pred: {class_names[j]}")
        axs[2, j].axis("off")

    fig.suptitle("Retrained Model Prediction", fontsize=16)

    plt.tight_layout()
    plt.show()

# Training

In [None]:
save_name = "model.pt"
# get_save_dir creates a fresh numbered folder under <root>/checkpoints/.
save_dir = get_save_dir(str(root), checkpoint_name)
os.makedirs(save_dir, exist_ok=True)

original_config = {
    "batch_size_train": 8,
    "lr": 0.00092234,
    # NOTE(review): gamma is logged but no StepLR scheduler is created in this notebook.
    "gamma": 0.11727,
    "num_epochs": 2,
    # constants
    "batch_size_test": 8,
    "criterion": torch.nn.BCEWithLogitsLoss(),
    # NOTE(review): k_folds is logged but no k-fold loop runs in this notebook.
    "k_folds": 5,
    # "lr_scheduler_step_size": 1,
}

# JSON-friendly copy: the loss module object is replaced by its string form.
config_serializable = original_config.copy()
config_serializable["criterion"] = str(config_serializable["criterion"])

with open(os.path.join(save_dir, "config.json"), "w", encoding="utf-8") as f:
    json.dump(config_serializable, f, ensure_ascii=False, indent=4)

# Log the serializable copy, so W&B records the criterion as the same readable
# string written to config.json instead of having to convert the nn.Module
# itself. The training loop reads the criterion object from original_config.
run = wandb.init(
    project="pvcracks",
    entity="ojas-sanghi-university-of-arizona",
    config=config_serializable,
)
config = wandb.config

In [None]:
# Training batches are reshuffled each epoch (unseeded); test batches keep a fixed order.
train_loader = DataLoader(train_set, batch_size=config.batch_size_train, shuffle=True)
test_loader = DataLoader(test_set, batch_size=config.batch_size_test, shuffle=False)

In [None]:
# Adam over all model parameters, with the learning rate taken from the W&B config.
optimizer = Adam(model.parameters(), lr=config.lr)

# log gradients to W&B (log_freq=100)
run.watch(model, log_freq=100)

In [None]:
# Per-epoch history, read by the loss-plot cell further down.
training_epoch_loss = []
test_epoch_loss = []
test_dice_loss = []  # per-epoch mean Dice *score* (higher is better)
test_iou_loss = []  # per-epoch mean IoU *score* (higher is better)

best_epoch_test_loss = float("inf")
best_epoch_dice = 0.0
best_epoch_iou = 0.0

criterion = original_config["criterion"]

for epoch in tqdm(range(1, config.num_epochs + 1)):
    # --- Training pass ---
    # Switch to training mode explicitly. The evaluation pass below puts the
    # model in eval mode, so without this later epochs would train in eval mode.
    model.train()
    training_step_loss = []
    for data, target in train_loader:
        data, target = data.to(device), target.to(device).float()

        optimizer.zero_grad()
        # Forward pass on raw logits; BCEWithLogitsLoss applies the sigmoid internally.
        output = model(data)
        training_loss = criterion(output, target)
        # Backward pass and parameter update.
        training_loss.backward()
        optimizer.step()
        training_step_loss.append(training_loss.item())

    # --- Evaluation pass ---
    # eval() stops normalization/dropout layers (if any) from acting as in
    # training and from updating running statistics with test data. no_grad()
    # skips building an autograd graph that would never be used.
    model.eval()
    test_step_loss = []
    dice_scores = []
    iou_scores = []
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device).float()
            output = model(data)
            test_step_loss.append(criterion(output, target).item())

            # Dice and IoU per class channel, on predictions thresholded at 0.5.
            pred_binary = (torch.sigmoid(output) > 0.5).float()
            for i in range(pred_binary.size(1)):
                dice_scores.append(
                    dice_coefficient(pred_binary[:, i], target[:, i]).item()
                )
                iou_scores.append(iou_score(pred_binary[:, i], target[:, i]).item())

    epoch_train_loss = np.mean(training_step_loss)
    epoch_test_loss = np.mean(test_step_loss)
    epoch_avg_dice = np.mean(dice_scores)
    epoch_avg_iou = np.mean(iou_scores)

    training_epoch_loss.append(epoch_train_loss)
    test_epoch_loss.append(epoch_test_loss)
    test_dice_loss.append(epoch_avg_dice)
    test_iou_loss.append(epoch_avg_iou)

    run.log(
        {
            "train_loss": epoch_train_loss,
            "test_loss": epoch_test_loss,
            "avg_dice": epoch_avg_dice,
            "avg_iou": epoch_avg_iou,
        },
        step=epoch,
    )

    # Keep a checkpoint for every epoch that improves the best test loss so far.
    if epoch_test_loss < best_epoch_test_loss:
        best_epoch_test_loss = epoch_test_loss
        best_epoch_dice = epoch_avg_dice
        best_epoch_iou = epoch_avg_iou

        os.makedirs(os.path.join(save_dir, f"epoch_{epoch}"), exist_ok=True)
        torch.save(
            model.state_dict(), os.path.join(save_dir, f"epoch_{epoch}", save_name)
        )
        print(f"Saved model at epoch {epoch}")

    print(
        f"Epoch {epoch} best test_loss: {best_epoch_test_loss:.4f}, dice: {best_epoch_dice:.4f}, iou: {best_epoch_iou:.4f}"
    )

In [None]:
# Mark the W&B run as finished.
run.finish()

In [None]:
# Spot-check predictions on a few test-split samples (a negative index counts from the end).
new_inference_and_show(-32)

In [None]:
new_inference_and_show(13)

In [None]:
new_inference_and_show(44)

In [None]:
new_inference_and_show(1)

In [None]:
new_inference_and_show(6)

In [None]:
# Uncomment to page through test-split indices 0-99.
# for i in range(100):
#     new_inference_and_show(i)

In [None]:
# Per-epoch training vs. test loss.
fig, ax = plt.subplots()

x = np.arange(1, len(training_epoch_loss) + 1, 1)

ax.scatter(x, training_epoch_loss, label="training loss")
ax.scatter(x, test_epoch_loss, label="test loss")
ax.legend()
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss (BCE with logits)")
ax.set_title("Training and test loss per epoch")
# Epochs are whole numbers; avoid fractional tick labels such as 1.5.
ax.xaxis.set_major_locator(mpl.ticker.MaxNLocator(integer=True))

print(training_epoch_loss)

In [None]:
# Display the per-epoch mean test losses.
test_epoch_loss