In [None]:
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam

import numpy as np
import os
import json

import matplotlib.pyplot as plt
from tqdm import tqdm

from utils.unet_model import construct_unet
from utils import img_functions
from torch.utils.data import random_split
import wandb

from utils import train_functions, viz_functions

ModuleNotFoundError: No module named 'tutorials'

In [None]:
# Base directory that holds the dataset folder and the model weight files.
# Set the PVCRACKS_DATA_DIR environment variable to run on another machine
# without editing this cell; the default is the original local path.
DATA_DIR = os.environ.get(
    "PVCRACKS_DATA_DIR", "/Users/ojas/Desktop/saj/SANDIA/pvcracks_data"
)

# Dataset root. The trailing "" makes os.path.join keep a trailing path
# separator, matching the original hard-coded string.
root = os.path.join(DATA_DIR, "Channeled_Combined_CWRU_LBNL_ASU_No_Empty", "")

# Available pretrained weight files, keyed by a short label.
model_weight_paths = {
    "emma_retrained": os.path.join(DATA_DIR, "retrained_pv-vision_model.pt"),
    "original": os.path.join(DATA_DIR, "pv-vision_model.pt"),
}

# weight_path = model_weight_paths["emma_retrained"]
weight_path = model_weight_paths["original"]

# Name of the dataset folder. normpath strips any trailing separator first, so
# this gives the same answer whether or not `root` ends with "/".
checkpoint_name = os.path.basename(os.path.normpath(root))

In [None]:
# Integer id -> human-readable category name; passed to
# viz_functions.channeled_inference_and_show in the visualisation cells below.
# NOTE(review): the ids presumably correspond to mask channel indices — confirm
# against the dataset definition.
category_mapping = {0: "dark", 1: "busbar", 2: "crack", 3: "non-cell"}

In [None]:
def dice_coefficient(pred, target, epsilon=1e-6):
    """Return the Dice coefficient between `pred` and `target`.

    Computed over every element of the inputs as
    ``(2 * sum(pred * target) + epsilon) / (sum(pred) + sum(target) + epsilon)``.

    Args:
        pred: Predicted mask; must support ``*`` with `target` and ``.sum()``.
        target: Reference mask, broadcast-compatible with `pred`.
        epsilon: Smoothing term added to numerator and denominator so the
            ratio stays defined (and equals 1) when both inputs sum to zero.

    Returns:
        The Dice coefficient, in the same array/tensor type as ``.sum()`` yields.
    """
    overlap = (pred * target).sum()
    # Sum of the two mask totals (not a set union: the overlap is counted twice).
    total_mass = pred.sum() + target.sum()
    return (2.0 * overlap + epsilon) / (total_mass + epsilon)


def iou_score(pred, target, epsilon=1e-6):
    """Return the intersection-over-union (Jaccard index) of `pred` and `target`.

    Computed over every element of the inputs as
    ``(I + epsilon) / (sum(pred) + sum(target) - I + epsilon)`` where
    ``I = sum(pred * target)``.

    Args:
        pred: Predicted mask; must support ``*`` with `target` and ``.sum()``.
        target: Reference mask, broadcast-compatible with `pred`.
        epsilon: Smoothing term added to numerator and denominator so the
            ratio stays defined (and equals 1) when both inputs sum to zero.

    Returns:
        The IoU value, in the same array/tensor type as ``.sum()`` yields.
    """
    overlap = (pred * target).sum()
    # Inclusion-exclusion: subtract the overlap that both totals counted.
    combined = pred.sum() + target.sum() - overlap
    return (overlap + epsilon) / (combined + epsilon)

In [None]:
# Load the dataset found under `root` through the project helper; the
# train/test split is performed in the following cell.
# NOTE(review): full_set=True presumably returns one undivided dataset —
# confirm in utils.train_functions.
full_dataset = train_functions.load_dataset(root, full_set=True)

In [None]:
# 90/10 train/test split. The generator is seeded with a fixed value so the
# same samples land in each subset on every run.
n_samples = len(full_dataset)
train_size = int(0.9 * n_samples)
test_size = n_samples - train_size

split_generator = torch.Generator().manual_seed(42)
train_subset, test_subset = random_split(
    full_dataset,
    [train_size, test_size],
    generator=split_generator,
)

In [None]:
# This is needed so that we can (a) split the dataset into train/test with the same seed as wandb_k_fold, and (b) preserve helpers such as __getraw__ from solardataset when doing inference on a subset.


class SubsetWithRaw(torch.utils.data.Subset):
    """A `Subset` that also forwards `__getraw__` to the wrapped dataset."""

    def __getraw__(self, idx):
        """Return the wrapped dataset's raw item for subset position `idx`.

        `idx` is a position within this subset; it is translated to the
        corresponding index of the underlying dataset before delegating.
        """
        source_index = self.indices[idx]
        return self.dataset.__getraw__(source_index)


# Re-wrap the indices chosen by random_split so that both subsets expose
# __getraw__ in addition to the standard Subset behaviour.
train_set = SubsetWithRaw(full_dataset, train_subset.indices)
test_set = SubsetWithRaw(full_dataset, test_subset.indices)

In [None]:
# Obtain the compute device and the model from the project helper, passing the
# weights file selected in the configuration cell above.
device, model = train_functions.load_device_and_model(weight_path)

# Training

In [None]:
# Checkpoint file name and the directory that receives checkpoints and the
# run configuration.
save_name = "model.pt"
save_dir = train_functions.get_save_dir(str(root), checkpoint_name)
os.makedirs(save_dir, exist_ok=True)

# Settings for this run. The same dict is handed to wandb below; the training
# loop reads the loss module from it directly.
original_config = dict(
    batch_size_train=8,
    lr=0.00092234,
    gamma=0.11727,
    num_epochs=2,
    # constants
    batch_size_test=8,
    criterion=torch.nn.BCEWithLogitsLoss(),
    k_folds=5,
    # "lr_scheduler_step_size": 1,
)

# The loss module is not JSON-serialisable, so the saved copy stores its
# string representation instead.
config_serializable = {
    **original_config,
    "criterion": str(original_config["criterion"]),
}

config_path = os.path.join(save_dir, "config.json")
with open(config_path, "w", encoding="utf-8") as f:
    json.dump(config_serializable, f, ensure_ascii=False, indent=4)

# Start the wandb run and expose its config for the cells that follow.
run = wandb.init(
    project="pvcracks",
    entity="ojas-sanghi-university-of-arizona",
    config=original_config,
)
config = wandb.config

In [None]:
# Batch sizes come from the wandb config. Training batches are shuffled;
# the test loader keeps a fixed order.
train_loader = DataLoader(train_set, batch_size=config.batch_size_train, shuffle=True)
test_loader = DataLoader(test_set, batch_size=config.batch_size_test, shuffle=False)

In [None]:
# Adam over all model parameters, using the learning rate from the run config.
optimizer = Adam(model.parameters(), lr=config.lr)

# log gradients
run.watch(model, log_freq=100)

In [None]:
# Per-epoch history, used by the plotting cells at the end of the notebook.
training_epoch_loss = []
test_epoch_loss = []
# NOTE: despite the "_loss" suffix, these two lists hold Dice / IoU scores.
test_dice_loss = []
test_iou_loss = []

# Metrics of the epoch with the lowest test loss seen so far.
best_epoch_test_loss = float("inf")
best_epoch_dice = 0.0
best_epoch_iou = 0.0

# The loss module lives in original_config (see the configuration cell).
criterion = original_config["criterion"]

for epoch in tqdm(range(1, config.num_epochs + 1)):
    # ---- training pass ----
    # Explicitly enter training mode: the evaluation pass below switches the
    # model to eval mode, so it must be switched back at every epoch.
    model.train()
    training_step_loss = []
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        target = target.float()

        optimizer.zero_grad()
        # forward pass
        output = model(data)
        # calc loss -- BCE-with-logits loss applies the sigmoid internally
        training_loss = criterion(output, target)
        # backward pass
        training_loss.backward()
        optimizer.step()
        # record loss
        training_step_loss.append(training_loss.item())

    # ---- evaluation pass ----
    # eval() puts layers such as dropout / batch norm into inference behaviour,
    # and no_grad() avoids building autograd graphs that are never used.
    model.eval()
    test_step_loss = []
    dice_scores = []
    iou_scores = []
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            target = target.float()
            # forward pass
            output = model(data)

            # calc loss -- BCE-with-logits loss applies the sigmoid internally
            test_loss = criterion(output, target)
            test_step_loss.append(test_loss.item())

            # compute dice and iou on thresholded predictions, one value per
            # index along dimension 1 for each batch
            pred_probs = torch.sigmoid(output)
            pred_binary = (pred_probs > 0.5).float()
            for i in range(pred_binary.size(1)):
                dice = dice_coefficient(pred_binary[:, i], target[:, i])
                iou = iou_score(pred_binary[:, i], target[:, i])
                dice_scores.append(dice.item())
                iou_scores.append(iou.item())

    # Epoch-level summaries are plain means over the per-batch values.
    epoch_train_loss = np.mean(training_step_loss)
    epoch_test_loss = np.mean(test_step_loss)
    epoch_avg_dice = np.mean(dice_scores)
    epoch_avg_iou = np.mean(iou_scores)

    training_epoch_loss.append(epoch_train_loss)
    test_epoch_loss.append(epoch_test_loss)
    test_dice_loss.append(epoch_avg_dice)
    test_iou_loss.append(epoch_avg_iou)

    run.log(
        {
            "train_loss": epoch_train_loss,
            "test_loss": epoch_test_loss,
            "avg_dice": epoch_avg_dice,
            "avg_iou": epoch_avg_iou,
        },
        step=epoch,
    )

    # Checkpoint only when the test loss improves on the best value so far.
    if epoch_test_loss < best_epoch_test_loss:
        best_epoch_test_loss = epoch_test_loss
        best_epoch_dice = epoch_avg_dice
        best_epoch_iou = epoch_avg_iou

        os.makedirs(os.path.join(save_dir, f"epoch_{epoch}"), exist_ok=True)
        torch.save(
            model.state_dict(), os.path.join(save_dir, f"epoch_{epoch}", save_name)
        )
        print(f"Saved model at epoch {epoch}")

    # Reports the best-so-far metrics, not necessarily those of this epoch.
    print(
        f"Epoch {epoch} best test_loss: {best_epoch_test_loss:.4f}, dice: {best_epoch_dice:.4f}, iou: {best_epoch_iou:.4f}"
    )

In [None]:
# Close the wandb run now that the training loop has finished logging.
run.finish()

In [None]:
# Visualise model output on the test loader via the project helper.
# NOTE(review): the last argument (-32) is presumably a sample index, here
# negative — confirm its meaning in utils.viz_functions.
viz_functions.channeled_inference_and_show(test_loader, device, model, category_mapping, -32)

In [None]:
# Same visualisation helper, last argument 13 (presumably a sample index —
# confirm in utils.viz_functions).
viz_functions.channeled_inference_and_show(test_loader, device, model, category_mapping, 13)

In [None]:
# Same visualisation helper, last argument 44 (presumably a sample index —
# confirm in utils.viz_functions).
viz_functions.channeled_inference_and_show(test_loader, device, model, category_mapping, 44)

In [None]:
# Same visualisation helper, last argument 1 (presumably a sample index —
# confirm in utils.viz_functions).
viz_functions.channeled_inference_and_show(test_loader, device, model, category_mapping, 1)

In [None]:
# Same visualisation helper, last argument 6 (presumably a sample index —
# confirm in utils.viz_functions).
viz_functions.channeled_inference_and_show(test_loader, device, model, category_mapping, 6)

In [None]:
# for i in range(100):
#     viz_functions.channeled_inference_and_show(test_loader, device, model, category_mapping, i)

In [None]:
# Training vs. test loss for every epoch recorded by the training loop.
fig, ax = plt.subplots()

# Epochs are numbered from 1, matching the `step` values logged to wandb.
x = np.arange(1, len(training_epoch_loss) + 1, 1)

ax.scatter(x, training_epoch_loss, label="training loss")
ax.scatter(x, test_epoch_loss, label="test loss")
ax.legend()
ax.set_xlabel("Epoch")
# Label the y-axis and title the figure so it can be read on its own.
ax.set_ylabel("BCE-with-logits loss")
ax.set_title("Training vs. test loss per epoch")

print(training_epoch_loss)

In [None]:
# Display the per-epoch test losses collected by the training loop.
test_epoch_loss