OK, now I feel ready to take on training on the GraspNet dataset.

- This is still just grasp xy, rx, ry, rz because we are not doing the FAS for z/width yet.

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# PYTHON
import numpy as np
import matplotlib.pyplot as plt
import time
import wandb
from typing import Callable
import gymnasium as gym
import yaml

# BAM
import bam_gym 
from bam_gym import print_reset, print_step
from bam_artist.heatmap_helper import show_heatmap_img, show_heatmap_img_grid
from bam_grasp.net.layers import HeatmapResNet, MaskedBCEWithLogitsLoss

# TORCH
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_

import torchvision as tv
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import make_grid
from torchinfo import summary
from torch.utils.data import Subset

# Seed every RNG this notebook draws from so a restart-and-run-all is repeatable.
# Seeding NumPy alone is not enough: the shuffling DataLoaders and the
# torch.randint / torch.bernoulli exploration in train() use PyTorch's generator.
SEED = 1
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
from rgb_matters.constant import GRASPNET_ROOT, LABEL_DIR
from rgb_matters.data.utils.generate_anchor_matrix import NUM_VIEWS, NUM_ANGLES
from rgb_matters.data import GraspNetDataset


In [None]:
# Load the experiment configuration and apply the notebook-local overrides.
# NOTE(review): absolute, machine-specific path; consider making it configurable.
config_path = "/home/bam/bam_ws/src/bam_brain/actor/examples/ex10_train_config.yaml"

with open(config_path, "r") as f:
    # NOTE(review): FullLoader can construct more than plain YAML scalars/lists/dicts.
    # If this file can ever come from an untrusted source, switch to yaml.safe_load.
    C = yaml.load(f, Loader=yaml.FullLoader)

print(C)

# Overrides applied on top of the YAML file for this experiment.
C["train_label_root"] = LABEL_DIR
C["test_label_root"] = LABEL_DIR
C["eval_for_n_batch"] = 50   # read by test() to cap how many batches are evaluated
C["num_layers"] = 50         # passed to HeatmapResNet(num_layers=...)

# Settings for the masked-loss branch in train().
C["epsilon"] = 0.5           # probability of replacing a selected action with a random one
C["mask_loss_fn"] = False    # False => plain loss_fn on the full heatmap
# This is a count used as the k of torch.topk and as a tensor dimension,
# so it must be an int (a float such as 0.0 raises a TypeError there).
C["n_executed_actions"] = 0

In [None]:
# Start a Weights & Biases run; the whole config dict C is attached to the run
# so every logged metric can be traced back to the settings that produced it.
# `run` is used later by train() and by the epoch loop to log metrics.
run = wandb.init(
    entity="zach-yamaoka-independent",
    project="bam-grasp",
    # job_type="test",
    group="ex10-grasp-xy-rxryrz",
    config=C
)

# max_action = 32 * 32 * 1
# print(f"Max action: {max_action}")
# assert run.config["n_executed_actions"] <= max_action

In [None]:
# Build the network, the losses and the optimizer on whichever device is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# One output channel per (angle, view) pair.
model = HeatmapResNet(
    num_layers=C["num_layers"],
    out_channels=NUM_ANGLES * NUM_VIEWS,
    upsample_logits=False,
    input_config='none',
).to(device)

# Follow `device` instead of hard-coding .cuda(), so the CPU fallback above is honoured.
loss_fn = nn.MSELoss().to(device)
LR = float(C["lr"])  # float() because YAML may parse values such as "1e-4" as strings
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)
# loss_fn = nn.BCEWithLogitsLoss()
# Only used by train() when C["mask_loss_fn"] is True.
masked_loss_fn = MaskedBCEWithLogitsLoss()


In [None]:
# torch.Size([3, 288, 384]) torch.Size([1, 288, 384]) torch.Size([360, 72, 96]) torch.Size([0])
# torch.Size([1, 6, 288, 384])
# Print a per-layer summary of the model (shapes, parameter counts, mult-adds).
info = summary(
    model, 
    input_size=(1, 3, 288, 384),  # batch of one 3-channel 288x384 image (the dataset's rgb shape)
    col_names=["input_size", "output_size", "num_params", "params_percent", "kernel_size", "mult_adds"]
)
print(info)
torch.cuda.empty_cache()   # clears cached allocator

In [None]:
from rgb_matters.net.rgb_normal_net import RGBNormalNet

# Disabled cell: flip the condition to True to print a summary of RGBNormalNet
# for comparison with the model above.
if False:

    net = RGBNormalNet(
        num_layers=50, use_normal=False, normal_only=False
    )
    info = summary(
        net, 
        input_size=[(1, 3, 288, 384), (1, 3, 288, 384)],  # two inputs, each a batch of one 3x288x384 tensor
        col_names=["input_size", "output_size", "num_params", "params_percent", "kernel_size", "mult_adds"]
    )
    print(info)
    torch.cuda.empty_cache() 

In [None]:

# Training dataset: camera, split and augmentation settings all come from the config.
graspnet_train = GraspNetDataset(
    graspnet_root=GRASPNET_ROOT,
    label_root=C["train_label_root"],
    use_normal=C["use_normal"],
    camera=C["train_camera"],
    split=C["split"]["train_split"],
    grayscale=C["augmentation"]["grayscale"],
    colorjitter_scale=C["augmentation"]["colorjitter_scale"],
    random_crop=C["augmentation"]["random_crop"],
)

# Test dataset: same grayscale setting as training, but colour jitter and
# random crop are set to 0.
graspnet_test = GraspNetDataset(
    graspnet_root=GRASPNET_ROOT,
    label_root=C["test_label_root"],
    use_normal=C["use_normal"],
    camera=C["test_camera"],
    split=C["split"]["test_split"],
    grayscale=C["augmentation"]["grayscale"],
    colorjitter_scale=0, # no aug on testing
    random_crop=0,
)

def _build_loader(dataset, batch_size, shuffle):
    """Wrap `dataset` in a DataLoader using the worker count shared by every loader here."""
    return DataLoader(
        dataset,
        shuffle=shuffle,
        batch_size=batch_size,
        num_workers=12,
    )


# A one-sample slice of the training set, used to check the model can overfit.
subset_indices = [0]
overfit_dataset = Subset(graspnet_train, subset_indices)

# Evaluation loaders use the (usually larger) eval batch size.
eval_test_dataloader = _build_loader(graspnet_test, C["eval_batch_size"], shuffle=True)
eval_train_dataloader = _build_loader(graspnet_train, C["eval_batch_size"], shuffle=False)

# Loaders that feed train().
train_dataloader = _build_loader(graspnet_train, C["batch_size"], shuffle=True)
overfit_dataloader = _build_loader(overfit_dataset, C["batch_size"], shuffle=False)


In [None]:
# from bam_utils.python.torch_helper import show_batch, show_batch_histograms

# show_batch(graspnet_test)

In [None]:
# Smoke test: push ONE dataset sample (not a batch) through the model and check
# that the output heatmap has the same shape as the label.
rgb, depth, label, normal = next(iter(graspnet_train))
print(rgb.shape, depth.shape, label.shape, normal.shape)
rgb, label = rgb.to(device), label.to(device)
rgb = rgb.unsqueeze(0)  # the dataset yields unbatched tensors; add the batch dimension

# This is only a shape check, so skip building the autograd graph. eval() keeps a
# single sample from touching any train-mode layer state (e.g. normalisation
# statistics, if the network has such layers); train() switches back with model.train().
model.eval()
with torch.no_grad():
    logits = model(rgb)
print(logits.shape)

assert logits.shape[1:] == label.shape, (
    f"model output {tuple(logits.shape[1:])} does not match label {tuple(label.shape)}"
)


In [None]:
# Running counters shared by train(), test() and the epoch loop; the whole dict
# is sent to wandb once per epoch.
LOG = {
    "epoch_total": 0,
    "batch_total": 0,
    "samples_total": 0,
    "wall_time_total": 0,
    "lr": float(C["lr"]),
}

print(LOG)

In [None]:
def train(data_loader: DataLoader):
    """Run one training epoch over ``data_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``loss_fn`` /
    ``masked_loss_fn``, ``device``, config ``C`` and wandb ``run``.

    Args:
        data_loader: yields ``(rgb, depth, label, normal)`` batches; only ``rgb``
            (the model input) and ``label`` (the target heatmap) are used.

    Side effects:
        * increments ``LOG["batch_total"]`` and ``LOG["samples_total"]`` per batch,
        * logs the loss to wandb every 10th batch,
        * every 20000 total batches multiplies the learning rate by ``C["lr_decay"]``
          and records it in ``LOG["lr"]``,
        * stores the mean batch loss of the epoch in ``LOG["avg_train_loss"]``.
    """
    model.train()
    n_total = len(data_loader.dataset)
    n_batches = len(data_loader)
    running_loss = 0.0

    # K is the k of torch.topk and a tensor dimension below, so it must be an int
    # even if the config stores it as a float.
    K = int(C["n_executed_actions"])
    lr_decay_every_n_batches = 20000

    for batch, (rgb, depth, label, normal) in enumerate(data_loader):
        X, y = rgb.to(device), label.to(device)

        LOG["batch_total"] += 1
        LOG["samples_total"] += len(X)

        logits = model(X)

        if C["mask_loss_fn"]:
            # Author's note: with random selection this may be suboptimal, because
            # the network keeps learning about the wrong actions.
            B, CH, H, W = logits.shape
            flat = logits.reshape(B, -1)               # [B, C*H*W]
            assert K <= flat.size(1)

            # 1) Exploitation: per-sample top-K (largest logits)
            _, top_flat_idx = torch.topk(flat, K, dim=1)  # [B, K]
            # Unravel the flat index back into (channel, row, col).
            c = top_flat_idx // (H * W)
            rem = top_flat_idx % (H * W)
            h = rem // W
            w = rem % W
            executed_list = torch.stack([c, h, w], dim=2).to(torch.long)  # [B, K, 3]

            # 2) Exploration: with prob epsilon, replace each selected (c,h,w) with a random one
            if C["epsilon"] > 0.0:
                explore_mask = torch.bernoulli(
                    torch.full((B, K), float(C["epsilon"]), device=X.device)
                ).bool()  # True => replace with random

                if explore_mask.any():
                    rand_c = torch.randint(0, CH, (B, K), device=X.device)
                    rand_h = torch.randint(0, H, (B, K), device=X.device)
                    rand_w = torch.randint(0, W, (B, K), device=X.device)

                    executed_list[..., 0] = torch.where(explore_mask, rand_c, executed_list[..., 0])
                    executed_list[..., 1] = torch.where(explore_mask, rand_h, executed_list[..., 1])
                    executed_list[..., 2] = torch.where(explore_mask, rand_w, executed_list[..., 2])

            assert executed_list.shape == (B, K, 3)

            # Gather labels at executed indices: y is [B, C, H, W] -> labels [B, K]
            b_idx = torch.arange(B, device=X.device)[:, None].expand(B, K)
            labels = y[
                b_idx,
                executed_list[..., 0],
                executed_list[..., 1],
                executed_list[..., 2],
            ]  # [B, K]

            executed_list_per_batch = list(executed_list.unbind(0))  # B x [K,3]
            labels_per_batch = list(labels.unbind(0))                # B x [K]

            train_loss: torch.Tensor = masked_loss_fn(logits, executed_list_per_batch, labels_per_batch)
        else:
            train_loss: torch.Tensor = loss_fn(logits, y)

        optimizer.zero_grad()
        train_loss.backward()
        if C["grad_clip"] > 0:
            clip_grad_norm_(model.parameters(), C["grad_clip"])
        optimizer.step()

        running_loss += train_loss.item()

        if batch % 10 == 0:
            loss, current = train_loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{n_total:>5d}]")

            run.log({
                "train/train_loss": loss,
            },
            step = LOG["samples_total"])

        if LOG["batch_total"] % lr_decay_every_n_batches == 0 and LOG["batch_total"] > 100:
            # Decay the learning rate in place. Rebinding `optimizer` / `LR` here
            # would make them function locals (UnboundLocalError at zero_grad())
            # and a fresh Adam would also throw away its moment estimates.
            for group in optimizer.param_groups:
                group["lr"] *= float(C["lr_decay"])
            LOG["lr"] = optimizer.param_groups[0]["lr"]

    # max(..., 1) keeps an empty loader from dividing by zero.
    LOG["avg_train_loss"] = running_loss / max(n_batches, 1)


In [None]:
def test(data_loader: DataLoader, namespace="") -> dict[str, float]:
    n_total = len(data_loader.dataset)
    n_batches = len(data_loader)
    model.eval()
    test_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for i, (rgb, depth, label, normal) in enumerate(data_loader):
            X: torch.Tensor = rgb.to(device)
            y: torch.Tensor = label.to(device)
            logits: torch.Tensor = model(X)

            # accumulate loss
            test_loss += loss_fn(logits, y).item()

            # BUG: you cannot use this classificaiton style accuracy
            # correct += (logits.argmax(1) == y).type(torch.float).sum().item()

            # --- accuracy ---
            N, CH, H, W = logits.shape
            # flatten spatial+channel dims
            logits_flat = logits.view(N, -1)   # (N, C*H*W)
            y_flat = y.view(N, -1)             # (N, H*W) or (N, C*H*W) depending on encoding

            # argmax per sample
            pred_idx = logits_flat.argmax(dim=1)  # (N,)

            hits = y_flat[torch.arange(N), pred_idx]          # (N,)
            correct += (hits == 1).sum().item()
            total += N

            if i > C["eval_for_n_batch"]:
                break


    LOG[namespace + "avg_test_loss"] = test_loss / n_batches
    LOG[namespace + "avg_test_accuracy"] = correct / n_total

    print(f"Test Error: \n Accuracy: {(100*LOG[namespace + "avg_test_accuracy"]):>0.1f}%, Avg loss: {LOG[namespace + "avg_test_loss"]:>8f} \n")



In [None]:
if C["mask_loss_fn"]:
    print(f"Using masked loss fucntion for {C['n_executed_actions']} actions executed")

# Flip to True to train on the one-sample overfit subset as a sanity check,
# instead of the full training set.
overfit_mode = False
if overfit_mode:
    fit_loader, fit_eval_loader = overfit_dataloader, overfit_dataloader
else:
    fit_loader, fit_eval_loader = train_dataloader, eval_train_dataloader

for epoch in range(300):
    print(f"Epoch {LOG['epoch_total']}\n-------------------------------")

    # Only the training pass counts towards the wall-time total.
    epoch_start = time.time()
    train(fit_loader)
    LOG["wall_time_total"] += time.time() - epoch_start
    LOG["epoch_total"] += 1

    test(fit_eval_loader, 'eval_train/')
    test(eval_test_dataloader, 'eval_test/')

    # Send the accumulated counters and metrics for this epoch to wandb.
    run.log(LOG)

    # Checkpoint the model every 100 epochs
    if LOG["epoch_total"] % 100 != 0:
        continue

    checkpoint_path = f"model_checkpoint_epoch_{LOG['epoch_total']}.pt"
    checkpoint = {
        'epoch': LOG["epoch_total"],
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'lr': LOG["lr"],
    }
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved to {checkpoint_path}")


TODO: Time to rewrite this notebook as a standalone Python script
that takes command-line arguments, just like a launch file :)