In [1]:
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam

import numpy as np
import os
import json

import matplotlib.pyplot as plt
from tqdm import tqdm

from pvcracks.utils import train_functions, viz_functions

In [2]:
root = "/Users/ojas/Desktop/saj/SANDIA/pvcracks_data/Channeled_Combined_CWRU_LBNL_ASU_No_Empty_RNE_Revise/"

checkpoint_name = root.split("/")[-2]

In [3]:
category_mapping = {0: "dark", 1: "busbar", 2: "crack", 3: "non-cell"}

In [4]:
train_dataset, val_dataset = train_functions.load_dataset(root)
device, model = train_functions.load_device_and_model(category_mapping)

# Training

In [5]:
batch_size_val = 1
batch_size_train = 1
lr = 1e-4
step_size = 1
gamma = 0.1
num_epochs = 2
criterion = torch.nn.BCEWithLogitsLoss()

save_dir = train_functions.get_save_dir(str(root), checkpoint_name)
os.makedirs(save_dir, exist_ok=True)

params_dict = {
    "batch_size_val": batch_size_val,
    "batch_size_train": batch_size_train,
    "lr": lr,
    "step_size": step_size,
    "gamma": gamma,
    "num_epochs": num_epochs,
    "criterion": str(criterion),
}

with open(os.path.join(save_dir, "params.json"), "w", encoding="utf-8") as f:
    json.dump(params_dict, f, ensure_ascii=False, indent=4)

In [6]:
train_loader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size_val, shuffle=False)

In [7]:
optimizer = Adam(model.parameters(), lr=lr)
evaluate_metric = None
running_record = {"train": {"loss": []}, "val": {"loss": []}}

save_name = "model.pt"
cache_output = True

In [None]:
def get_metrics(pred, target, epsilon=1e-6):
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection
    
    tp = intersection
    fp = pred.sum() - tp
    fn = target.sum() - tp
    tn = pred.numel() - (tp + fp + fn)
    
    accuracy = (tp + tn) / (tp + tn + fp + fn + epsilon)
    precision = (tp + epsilon) / (tp + fp + epsilon)
    recall = (tp + epsilon) / (tp + fn + epsilon)
    f1 = (2 * precision * recall) / (precision + recall + epsilon)
    iou = (intersection + epsilon) / (union + epsilon)
    
    return {
        "accuracy": accuracy.item(),
        "precision": precision.item(),
        "recall": recall.item(),
        "dice": f1.item(),
        "iou": iou.item()
    }

training_epoch_loss = []
val_epoch_loss = []

for epoch in tqdm(range(1, num_epochs + 1)):
    training_step_loss = []

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        target = target.float()

        optimizer.zero_grad()

        # forward pass
        output = model(data)

        # calc loss -- bce with logits loss applies sigmoid interally
        training_loss = criterion(output, target)

        # backward pass
        training_loss.backward()
        optimizer.step()

        # record loss
        training_step_loss.append(training_loss.item())

    training_epoch_loss.append(np.array(training_step_loss).mean())

    val_step_loss = []
    val_metrics_storage = {}

    for batch_idx, (data, target) in enumerate(val_loader):
        data, target = data.to(device), target.to(device)
        target = target.float()

        # forward pass
        data = data.to(device)

        output = model(data)

        # calc loss -- bce with logits loss applies sigmoid interally
        val_loss = criterion(output, target)

        val_step_loss.append(val_loss.item())

        # Compute metrics per channel
        pred_probs = torch.sigmoid(output)
        pred_binary = (pred_probs > 0.5).float()
        
        for i in range(pred_binary.size(1)):
            if i not in val_metrics_storage:
                val_metrics_storage[i] = {k: [] for k in ["accuracy", "precision", "recall", "dice", "iou"]}
            
            m = get_metrics(pred_binary[:, i], target[:, i])
            for k, v in m.items():
                val_metrics_storage[i][k].append(v)

    val_epoch_loss.append(np.array(val_step_loss).mean())

    print(f"Epoch {epoch} Metrics:")
    # Per class average
    for i in sorted(val_metrics_storage.keys()):
        c_name = category_mapping.get(i, f"class_{i}")
        c_avgs = {k: np.mean(v) for k, v in val_metrics_storage[i].items()}
        print(f"  {c_name}: {c_avgs}")
        
    # Global average
    global_avgs_list = {k: [] for k in ["accuracy", "precision", "recall", "dice", "iou"]}
    for i in val_metrics_storage:
        for k in global_avgs_list:
            global_avgs_list[k].extend(val_metrics_storage[i][k])
            
    final_global_avgs = {k: np.mean(v) for k, v in global_avgs_list.items()}
    print(f"  Aggregate: {final_global_avgs}")


    os.makedirs(os.path.join(save_dir, f"epoch_{epoch}"), exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_dir, f"epoch_{epoch}", save_name))
    print(f"Saved model at epoch {epoch}")

  0%|          | 0/2 [00:00<?, ?it/s]

Epoch 1 Metrics:
  dark: {'accuracy': np.float64(0.9993067358294104), 'precision': np.float64(0.8974359230883957), 'recall': np.float64(0.9065076329042108), 'dice': np.float64(0.8128577883574966), 'iou': np.float64(0.8124905623107876)}
  busbar: {'accuracy': np.float64(0.9639099642761753), 'precision': np.float64(0.7631089084944463), 'recall': np.float64(0.9375813718025501), 'dice': np.float64(0.8276390014318941), 'iou': np.float64(0.7156763084435365)}
  crack: {'accuracy': np.float64(0.9813991122775607), 'precision': np.float64(0.3658282790655074), 'recall': np.float64(0.745650381896432), 'dice': np.float64(0.30709135677645416), 'iou': np.float64(0.22497829348251977)}
  non-cell: {'accuracy': np.float64(0.9759155990731003), 'precision': np.float64(0.43034185962671906), 'recall': np.float64(0.9838367308306898), 'dice': np.float64(0.5825299907189149), 'iou': np.float64(0.42022088434324306)}
  Aggregate: {'accuracy': np.float64(0.9801328528640617), 'precision': np.float64(0.6141787425687

 50%|█████     | 1/2 [04:43<04:43, 283.83s/it]

Saved model at epoch 1


In [None]:
def get_metrics(pred, target, epsilon=1e-6):
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection
    
    tp = intersection
    fp = pred.sum() - tp
    fn = target.sum() - tp
    tn = pred.numel() - (tp + fp + fn)
    
    accuracy = (tp + tn) / (tp + tn + fp + fn + epsilon)
    precision = (tp + epsilon) / (tp + fp + epsilon)
    recall = (tp + epsilon) / (tp + fn + epsilon)
    f1 = (2 * precision * recall) / (precision + recall + epsilon)
    iou = (intersection + epsilon) / (union + epsilon)
    
    return {
        "accuracy": accuracy.item(),
        "precision": precision.item(),
        "recall": recall.item(),
        "dice": f1.item(),
        "iou": iou.item()
    }

training_epoch_loss = []
val_epoch_loss = []

for epoch in tqdm(range(1, num_epochs + 1)):
    training_step_loss = []

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        target = target.float()

        optimizer.zero_grad()

        # forward pass
        output = model(data)

        # calc loss -- bce with logits loss applies sigmoid interally
        training_loss = criterion(output, target)

        # backward pass
        training_loss.backward()
        optimizer.step()

        # record loss
        training_step_loss.append(training_loss.item())

    training_epoch_loss.append(np.array(training_step_loss).mean())

    val_step_loss = []
    val_metrics_storage = {}

    for batch_idx, (data, target) in enumerate(val_loader):
        data, target = data.to(device), target.to(device)
        target = target.float()

        # forward pass
        data = data.to(device)

        output = model(data)

        # calc loss -- bce with logits loss applies sigmoid interally
        val_loss = criterion(output, target)

        val_step_loss.append(val_loss.item())

        # Compute metrics per channel
        pred_probs = torch.sigmoid(output)
        pred_binary = (pred_probs > 0.5).float()
        
        for i in range(pred_binary.size(1)):
            if i not in val_metrics_storage:
                val_metrics_storage[i] = {k: [] for k in ["accuracy", "precision", "recall", "dice", "iou"]}
            
            m = get_metrics(pred_binary[:, i], target[:, i])
            for k, v in m.items():
                val_metrics_storage[i][k].append(v)

    val_epoch_loss.append(np.array(val_step_loss).mean())

    print(f"Epoch {epoch} Metrics:")
    # Per class average
    for i in sorted(val_metrics_storage.keys()):
        c_name = category_mapping.get(i, f"class_{i}")
        c_avgs = {k: np.mean(v) for k, v in val_metrics_storage[i].items()}
        print(f"  {c_name}: {c_avgs}")
        
    # Global average
    global_avgs_list = {k: [] for k in ["accuracy", "precision", "recall", "dice", "iou"]}
    for i in val_metrics_storage:
        for k in global_avgs_list:
            global_avgs_list[k].extend(val_metrics_storage[i][k])
            
    final_global_avgs = {k: np.mean(v) for k, v in global_avgs_list.items()}
    print(f"  Aggregate: {final_global_avgs}")


    os.makedirs(os.path.join(save_dir, f"epoch_{epoch}"), exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_dir, f"epoch_{epoch}", save_name))
    print(f"Saved model at epoch {epoch}")

In [None]:
viz_functions.channeled_inference_and_show(
    train_loader, device, model, category_mapping, -32
)

In [None]:
viz_functions.channeled_inference_and_show(
    train_loader, device, model, category_mapping, 13
)

In [None]:
viz_functions.channeled_inference_and_show(
    train_loader, device, model, category_mapping, 44
)

In [None]:
viz_functions.channeled_inference_and_show(
    train_loader, device, model, category_mapping, 1
)

In [None]:
viz_functions.channeled_inference_and_show(
    train_loader, device, model, category_mapping, 6
)

In [None]:
# for i in range(100):
#     viz_functions.channeled_inference_and_show(train_loader, device, model, category_mapping, i)

In [None]:
fig, ax = plt.subplots()

x = np.arange(1, len(training_epoch_loss) + 1, 1)

ax.scatter(x, training_epoch_loss, label="training loss")
ax.scatter(x, val_epoch_loss, label="validation loss")
ax.legend()
ax.set_xlabel("Epoch")

print(training_epoch_loss)

In [None]:
val_epoch_loss