In [None]:
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam

import numpy as np
import os
import json

import matplotlib.pyplot as plt
from tqdm import tqdm

from utils import train_functions, viz_functions

In [None]:
# Dataset root. It can be overridden with the PVCRACKS_DATA_ROOT environment
# variable so the notebook runs on machines other than the author's.
# NOTE(review): the default keeps its trailing slash; confirm whether
# train_functions.load_dataset relies on it when building sub-paths.
root = os.environ.get(
    "PVCRACKS_DATA_ROOT",
    "/Users/ojas/Desktop/saj/SANDIA/pvcracks_data/Channeled_Combined_CWRU_LBNL_ASU_No_Empty/",
)

# Pretrained weights, overridable with PVCRACKS_WEIGHT_PATH.
# NOTE(review): weight_path is not passed to any call in this notebook;
# confirm whether load_device_and_model is expected to receive it.
weight_path = os.environ.get(
    "PVCRACKS_WEIGHT_PATH",
    "/Users/ojas/Desktop/saj/SANDIA/pvcracks_data/pv-vision_model.pt",
)

# Name checkpoints after the dataset directory. For the default (trailing-slash)
# root, basename(normpath(root)) equals the old root.split("/")[-2]. Unlike that
# expression, it does not return the *parent* directory's name when the trailing
# slash is missing.
checkpoint_name = os.path.basename(os.path.normpath(root)) + "utilstest"

In [None]:
# Index -> class label. Presumably one model output channel per class;
# confirm against train_functions / viz_functions.
category_mapping = dict(enumerate(["dark", "busbar", "crack", "non-cell"]))

In [None]:
# Load the train / validation splits found under `root`.
dataset_splits = train_functions.load_dataset(root)
train_dataset, val_dataset = dataset_splits

# Build the model for the classes in `category_mapping` and get the device it runs on.
# NOTE(review): `weight_path` is never passed here; confirm whether this helper
# loads the pretrained weights on its own.
device, model = train_functions.load_device_and_model(category_mapping)

# Training

In [None]:
batch_size_val = 1
batch_size_train = 1
lr = 1e-4
# NOTE(review): step_size and gamma are logged to params.json, but no
# learning-rate scheduler is created in this notebook, so they currently have
# no effect on training.
step_size = 1
gamma = 0.1
num_epochs = 1
# Seed for torch's global RNG. It drives the shuffling of the train DataLoader
# (shuffle=True) and any other torch randomness during training. It is logged
# below so runs can be reproduced.
seed = 0
criterion = torch.nn.BCEWithLogitsLoss()

torch.manual_seed(seed)

save_dir = train_functions.get_save_dir(str(root), checkpoint_name)
os.makedirs(save_dir, exist_ok=True)

# Record this run's hyperparameters next to its checkpoints.
params_dict = {
    "batch_size_val": batch_size_val,
    "batch_size_train": batch_size_train,
    "lr": lr,
    "step_size": step_size,
    "gamma": gamma,
    "num_epochs": num_epochs,
    "criterion": str(criterion),
    "seed": seed,
}

with open(os.path.join(save_dir, "params.json"), "w", encoding="utf-8") as f:
    json.dump(params_dict, f, ensure_ascii=False, indent=4)

In [None]:
# Training batches are reshuffled every epoch. Validation batches keep a fixed
# order, so validation passes visit samples identically each time.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size_train)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size_val)

In [None]:
# Adam over every model parameter, using the learning rate from the hyperparameter cell.
optimizer = Adam(params=model.parameters(), lr=lr)

# NOTE(review): evaluate_metric, running_record and cache_output are not read by
# any cell visible in this notebook; remove them if nothing else depends on them.
evaluate_metric = None
running_record = {split: {"loss": []} for split in ("train", "val")}

# File name used for each epoch's checkpoint (saved under save_dir/epoch_<n>/).
save_name = "model.pt"
cache_output = True

In [None]:
# Per-epoch mean losses, consumed by the plotting cell further down.
training_epoch_loss = []
val_epoch_loss = []


def _mean_loss(step_losses):
    """Return the mean of a list of per-batch losses, or NaN if the list is empty."""
    # np.mean([]) would also give NaN, but with a RuntimeWarning.
    return float(np.mean(step_losses)) if step_losses else float("nan")


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    # Training-mode behaviour for layers such as dropout / batch-norm.
    model.train()
    step_losses = []
    for data, target in loader:
        data = data.to(device)
        # BCEWithLogitsLoss requires floating-point targets.
        target = target.to(device).float()

        optimizer.zero_grad()
        # BCEWithLogitsLoss applies the sigmoid internally, so `output` stays as raw logits.
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        step_losses.append(loss.item())
    return _mean_loss(step_losses)


@torch.no_grad()
def _evaluate(model, loader, criterion, device):
    """Return the mean batch loss over ``loader`` without updating the model.

    Runs in eval mode and under ``torch.no_grad()``. Validation therefore builds
    no autograd graph and cannot update batch-norm running statistics.
    """
    model.eval()
    step_losses = []
    for data, target in loader:
        data = data.to(device)
        target = target.to(device).float()
        step_losses.append(criterion(model(data), target).item())
    return _mean_loss(step_losses)


for epoch in tqdm(range(1, num_epochs + 1)):
    training_epoch_loss.append(
        _train_one_epoch(model, train_loader, criterion, optimizer, device)
    )
    val_epoch_loss.append(_evaluate(model, val_loader, criterion, device))

    # One checkpoint directory per epoch: save_dir/epoch_<n>/<save_name>.
    epoch_dir = os.path.join(save_dir, f"epoch_{epoch}")
    os.makedirs(epoch_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(epoch_dir, save_name))
    print(f"Saved model at epoch {epoch}")

In [None]:
# Show model predictions for one training sample; the last argument appears to select the sample.
# NOTE(review): -32 is negative, presumably counting from the end; confirm the helper supports it.
# NOTE(review): train_loader shuffles, so if the helper iterates the loader the sample may vary per run.
viz_functions.channeled_inference_and_show(train_loader, device, model, category_mapping, -32)

In [None]:
# Show model predictions for one training sample; the last argument appears to select the sample.
viz_functions.channeled_inference_and_show(train_loader, device, model, category_mapping, 13)

In [None]:
# Show model predictions for one training sample; the last argument appears to select the sample.
viz_functions.channeled_inference_and_show(train_loader, device, model, category_mapping, 44)

In [None]:
# Show model predictions for one training sample; the last argument appears to select the sample.
viz_functions.channeled_inference_and_show(train_loader, device, model, category_mapping, 1)

In [None]:
# Show model predictions for one training sample; the last argument appears to select the sample.
viz_functions.channeled_inference_and_show(train_loader, device, model, category_mapping, 6)

In [None]:
# for i in range(100):
#     viz_functions.channeled_inference_and_show(train_loader, device, model, category_mapping, i)

In [None]:
fig, ax = plt.subplots()

# Epochs are numbered from 1 to match the epoch_<n> checkpoint directories.
x = np.arange(1, len(training_epoch_loss) + 1, 1)

ax.scatter(x, training_epoch_loss, label="training loss")
ax.scatter(x, val_epoch_loss, label="validation loss")
ax.legend()
ax.set_xlabel("Epoch")
# Both series are per-epoch means of the BCEWithLogitsLoss criterion.
ax.set_ylabel("Mean loss (BCE with logits)")
ax.set_title("Training vs. validation loss per epoch")

print(training_epoch_loss)

In [None]:
# Display the per-epoch mean validation losses collected by the training loop.
val_epoch_loss