Notes:
- In the spirit of RGB matters, I am not going to predict the width...
- That could be done with FAS, and it is dependent on coordinates, whilst the orientation and position are not!

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from typing import Callable
import gymnasium as gym
import bam_gym 
from bam_gym import print_reset, print_step
import time
np.random.seed(1)
from bam_artist.heatmap_helper import show_heatmap_img, show_heatmap_img_grid

In [None]:
# Scene resolution in pixels per side (2**5 = 32); passed to the env as a square.
screen_size = 2**5
# Rectangle width bounds given to the env: 1% of the screen rounded up, and
# three times that value.
min_width = np.ceil(screen_size*0.01)
max_width = min_width * 3
print("Screen Size:", screen_size)
print("Min Width:", min_width)
print("Max Width:", max_width)
# Render scale handed to the env (500 / screen_size).
# presumably yields a ~500 px window — TODO confirm against the env renderer
render_screen_scale = 500/screen_size
# Number of discrete angle bins and the angle interval (radians) they cover.
# Both are passed to the env and reused by the plotting helpers in this notebook.
N_ANGLES = 6    
ANGLE_RANGE = [0, np.pi]

# Build the grasp environment with the parameters defined above.
env = gym.make(
    "bam_local/GraspXYRZ",
    render_mode="human",
    screen_size=(screen_size, screen_size),
    num_rectangles=5,
    min_width=min_width,
    max_width=max_width,
    min_height_ratio=2,
    max_height_ratio=5.0,
    render_screen_scale=render_screen_scale,
    n_angles=N_ANGLES,
    angle_range=ANGLE_RANGE
)

# Inspect the spaces, then take a first observation so `state` / `info`
# exist for the following cells.
print("Action Space:", env.action_space)
print("Observation Space:", env.observation_space)
state, info = env.reset()
print(state.shape)

In [None]:
def view_labels(state, label, alpha=0.5):
    """Overlay per-angle heatmaps on ``state`` with one titled panel per angle bin.

    Args:
        state: Image forwarded unchanged to ``show_heatmap_img_grid``.
        label: Heatmap stack forwarded unchanged to ``show_heatmap_img_grid``
            (presumably one channel per angle bin — TODO confirm).
        alpha: Overlay opacity forwarded to ``show_heatmap_img_grid``.

    The panel titles are derived from the module-level ``N_ANGLES`` and
    ``ANGLE_RANGE``: bin ``i`` is labelled with
    ``ANGLE_RANGE[0] + i * (ANGLE_RANGE[1] - ANGLE_RANGE[0]) / N_ANGLES``.
    """
    # The bin width does not depend on the loop index, so compute it once.
    step_size = (ANGLE_RANGE[1] - ANGLE_RANGE[0]) / N_ANGLES

    labels = []
    for i in range(N_ANGLES):
        angle = ANGLE_RANGE[0] + i * step_size
        # Wrap the angle into [-pi, pi) for easier reading.
        wrapped_angle = (angle + np.pi) % (2 * np.pi) - np.pi
        # "\u00b0" is the degree sign, written as an escape so the title
        # survives any file re-encoding.
        labels.append(f"Angle {np.rad2deg(wrapped_angle):.1f}\u00b0 ~ {wrapped_angle:.1f}rad")

    # When judging whether an angle is correct, keep the neutral position of the
    # gripper in mind: at gripper angle 0 it should pick up an item lying along
    # the x axis.

    show_heatmap_img_grid(state, label, rows=2, cols=8, alpha=alpha, title_list=labels)

In [None]:
# Take one random action and visualise the label the env reports for it.
state, info = env.reset()
# Overwrite the render mode on the unwrapped env for this preview.
env.unwrapped.render_mode = "none"
for i in range(10):
    action = env.action_space.sample()
    print(action)
    # action = policy(state)
    state, reward, terminated, truncated, info = env.step(action)
    view_labels(state, info["curr_label"], alpha=0.9)
    # time.sleep(0.5)
    # print_step(action, state, reward, terminated, truncated, info, i)
    # Unconditional break: only the first of the 10 iterations ever runs.
    break

env.close()

In [None]:
# Compare the observation shape with the shape of the env's binary label.
print(state.shape)
label = env.unwrapped.get_binary_label()
print(label.shape)

# Optional debug dump: per-pixel index of the strongest label channel.
# for row in np.argmax(label, axis=2):
#     print(row)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision as tv
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import make_grid

from torchinfo import summary

In [None]:
# Train/test sizes keep a 6:1 ratio; `dataset_scale` scales both (600 / 100 here).
dataset_scale = 100
train_dataset_size = 6 * dataset_scale
test_dataset_size = 1 * dataset_scale



def to_tensor_pairs(numpy_pairs: list[tuple[np.ndarray, np.ndarray]]):
    """Convert channel-last numpy ``(image, mask)`` pairs into float tensors.

    Each array has its last axis moved to the front (``permute(2, 0, 1)``) and
    is cast to float. Images are additionally divided by 255; masks keep
    their values.

    Returns:
        A list of ``(image_tensor, mask_tensor)`` tuples in input order.
    """

    def _channels_first(array: np.ndarray) -> torch.Tensor:
        # Move the last axis to position 0 and cast to float.
        return torch.from_numpy(array).permute(2, 0, 1).float()

    return [
        (_channels_first(image) / 255.0, _channels_first(target))
        for image, target in numpy_pairs
    ]

# Ask the env for (image, mask) samples and convert them to tensors.
train_data_pairs = to_tensor_pairs(env.unwrapped.create_dataset(train_dataset_size))

test_data_pairs = to_tensor_pairs(env.unwrapped.create_dataset(test_dataset_size))


# Sanity-check sizes and per-sample shapes of both splits.
print("Size of train tensor dataset:", len(train_data_pairs))
print("Shape of first image:", train_data_pairs[0][0].shape)
print("Shape of first mask:", train_data_pairs[0][1].shape)

print("Size of test tensor dataset:", len(test_data_pairs))
print("Shape of first image:", test_data_pairs[0][0].shape)
print("Shape of first mask:", test_data_pairs[0][1].shape)

In [None]:
# Collect the first 64 training images (requires at least 64 samples).
images_to_plot = []
for i in range(64):
    images_to_plot.append(train_data_pairs[i][0])

# Stack into one batch tensor so make_grid can tile it.
images_to_plot = torch.stack(images_to_plot)

# Show them as an 8-wide grid; permute moves channels last for imshow.
fig, ax = plt.subplots(figsize=(12, 12))
ax.set_xticks([]); ax.set_yticks([])
ax.imshow(make_grid(images_to_plot[:64], nrow=8, pad_value=1).permute(1, 2, 0))


In [None]:
# Debug Nvidia Failure, solution was just to restart...
import torch, subprocess, os

# Print the torch build and what it can see of CUDA.
print("torch:", torch.__version__)
print("is_available:", torch.cuda.is_available())
print("device_count:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("current_device:", torch.cuda.current_device())
    print("name:", torch.cuda.get_device_name(0))
    print("capability:", torch.cuda.get_device_capability(0))
    try:
        # Round-trip a tiny tensor through the GPU to prove basic ops work.
        a = torch.randn(1, device="cuda")
        b = a.cpu()
        print("basic CUDA op: OK")
        free, total = torch.cuda.mem_get_info()
        print("mem free/total (MB):", free//2**20, "/", total//2**20)
    except Exception as e:
        # Diagnostic cell: report the failure instead of aborting.
        print("basic CUDA op failed:", e)

# Print only the first line of the nvidia-smi output; report the error if the
# call fails.
try:
    print(subprocess.check_output(["nvidia-smi"]).decode().splitlines()[0])
except Exception as e:
    print("nvidia-smi failed:", e)

In [None]:
# Stack every training image into one tensor (new leading batch axis).
imgs = torch.stack([img for img, _ in train_data_pairs], dim=0)

# Compute per-channel mean and std across N,H,W
# These statistics come from the training split only and are reused for
# normalising both splits and live env observations.
mean = imgs.mean(dim=(0,2,3))   # [3]
std  = imgs.std(dim=(0,2,3))    # [3]

print("Mean:", mean)
print("Std:", std)

In [None]:

class HeatmapDataset(Dataset):
    def __init__(
        self,
        input_label_pairs: list[tuple[torch.Tensor, torch.Tensor]],
        transform: Callable | None = None,
        label_transforms: Callable | None = None,
    ):
        """
        tensor_pairs: list of (img, mask) where
            img: torch.Tensor [3,H,W], float in [0,1]
            mask: torch.Tensor [1,H,W], float {0,1}
        transform: applied to images
        label_transforms: applied to masks
        """
        self.data = input_label_pairs
        self.transform = transform
        self.label_transforms = label_transforms


    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img, label = self.data[idx]

        if self.transform:
            img = self.transform(img)
        if self.label_transforms:
            label = self.label_transforms(label)

        return img, label


# Image transform: normalise with the training-set channel statistics.
# `inplace=False` is required here: the datasets hold their tensors in memory
# and hand the very same tensor objects to the transform, so an in-place
# normalise would overwrite the stored sample and re-normalise it on every
# later access in the same process.
transform = transforms.Compose([
    # transforms.RandomCrop(32, padding=4),
    # transforms.RandomHorizontalFlip(),
    # transforms.ToTensor(),
    transforms.Normalize(mean, std, inplace=False),
])
# Labels are used as-is.
label_transforms = None
# transforms.Compose([
#     # transforms.ToTensor(),
#     # transforms.Normalize(mean, std, inplace=True),
# ])


batch_size = 2**9
print(f"Batch size: {batch_size}")


train_dataset = HeatmapDataset(train_data_pairs, transform=transform, label_transforms=label_transforms)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

test_dataset = HeatmapDataset(test_data_pairs, transform=transform, label_transforms=label_transforms)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)


In [None]:
# Use the GPU when torch reports CUDA as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Bare expression so the notebook echoes the chosen device.
device

In [None]:
def show_batch(data_loader: DataLoader):
    """Plot up to 64 images from the first batch of ``data_loader`` as an 8-wide grid.

    Does nothing when the loader yields no batches.
    """
    first_batch = next(iter(data_loader), None)
    if first_batch is None:
        return

    batch_images, _ = first_batch
    # make_grid returns channels-first; imshow wants channels-last.
    grid = make_grid(batch_images[:64], nrow=8, pad_value=1).permute(1, 2, 0)

    fig, ax = plt.subplots(figsize=(12, 12))
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(grid)


show_batch(train_dataloader)




In [None]:

def show_batch_histograms(data_loader: torch.utils.data.DataLoader, n_batches: int = 1):
    """
    Plot value histograms for the leading batches of ``data_loader``.

    One figure per batch: the left panel overlays one histogram per image
    channel, the right panel shows the distribution of all mask values.
    Stops after ``n_batches`` batches (the first batch is always shown).
    """
    for batch_idx, (images, masks) in enumerate(data_loader):
        n_channels = images.shape[1]
        # [B, C, H, W] -> [C, B*H*W]: one row of pixel values per channel.
        per_channel = images.permute(1, 0, 2, 3).reshape(n_channels, -1)
        palette = ['r', 'g', 'b'] if n_channels == 3 else ['k']

        fig, (image_ax, mask_ax) = plt.subplots(1, 2, figsize=(12, 5))

        # Left panel: per-channel image histograms.
        for c, channel_values in enumerate(per_channel):
            image_ax.hist(
                channel_values.numpy().ravel(),
                bins=50,
                color=palette[c],
                alpha=0.5,
                label=f"Channel {c}",
            )
        image_ax.set_title("Image Pixel Distribution")
        image_ax.legend()

        # Right panel: every mask value of the batch in one histogram.
        mask_ax.hist(masks.reshape(-1).numpy().ravel(), bins=50, color='gray', alpha=0.7)
        mask_ax.set_title("Mask Pixel Distribution")

        plt.show()

        if batch_idx + 1 >= n_batches:
            break
show_batch_histograms(train_dataloader, n_batches=2)


In [None]:
from bam_grasp.net.layers.Heatmap_Resnet import HeatmapResNet

# Re-derive the device so this cell can run on its own.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# One output channel per angle bin, moved to the selected device.
model = HeatmapResNet(num_layers=18, out_channels=N_ANGLES).to(device)
# print(model)


In [None]:
# Layer-by-layer summary of the model for one full batch of inputs.
summary(
    model, 
    input_size=(batch_size, 3, 32, 32),  # (batch, channels, height, width); 32 equals screen_size = 2**5 set earlier
    col_names=["input_size", "output_size", "num_params", "params_percent", "kernel_size", "mult_adds"]
)

In [None]:
# Binary cross-entropy on raw logits (the sigmoid is applied inside the loss).
loss_fn = nn.BCEWithLogitsLoss()
# Adam over all model parameters with a fixed learning rate of 0.01.
optimizer = torch.optim.Adam(params = model.parameters(), lr=0.01)
# scheduler = torch.optim.lr_scheduler.LinearLR(optimizer)



In [None]:
# Test one batch
img, label = next(iter(train_dataloader))
img, label = img.to(device), label.to(device)
print(img.shape, label.shape)
logits = model(img) 
print(logits.shape)

assert logits.shape == label.shape

In [None]:
# Per-epoch metric histories, filled by the training loop cell.
# Re-running this cell clears them.
train_losses = []
test_losses = []
test_accuracies = []
# Running epoch counter, kept outside the loop so that re-running the training
# cell continues the numbering.
epoch_total = 0


In [None]:
def train(data_loader: DataLoader) -> dict[str, float]:
    """Run one optimisation pass over ``data_loader`` with the module-level model.

    Uses the module-level ``model``, ``loss_fn``, ``optimizer`` and ``device``.
    Prints the batch loss every 10th batch.

    Returns:
        ``{"avg_loss": mean of the per-batch losses}``.
    """
    model.train()
    dataset_size = len(data_loader.dataset)
    loss_sum = 0.0

    for step, (inputs, targets) in enumerate(data_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Resize the predicted maps to the label resolution before the loss.
        logits = nn.functional.interpolate(
            model(inputs), size=targets.shape[-2:], mode="nearest"
        )
        batch_loss: torch.Tensor = loss_fn(logits, targets)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        loss_value = batch_loss.item()
        loss_sum += loss_value

        if step % 10 == 0:
            seen = (step + 1) * len(inputs)
            print(f"loss: {loss_value:>7f}  [{seen:>5d}/{dataset_size:>5d}]")

    return {"avg_loss": loss_sum / len(data_loader)}


In [None]:
def test(data_loader: DataLoader) -> dict[str, float]:
    """Evaluate the module-level model on ``data_loader`` without gradients.

    Uses the module-level ``model``, ``loss_fn`` and ``device``.

    Accuracy is a "top-1 hit" rate: for each sample the single highest-scoring
    location over all channels and pixels is taken, and the sample counts as
    correct when the label equals 1 at that location.

    Returns:
        ``{"test_loss": mean per-batch loss, "test_accuracy": hit rate}``.
        The loss is averaged over batches so it is directly comparable to the
        ``avg_loss`` reported by ``train``. Both values are NaN when the
        loader yields no data.
    """
    n_batches = len(data_loader)
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for X, y in data_loader:
            X: torch.Tensor = X.to(device)
            y: torch.Tensor = y.to(device)
            logits: torch.Tensor = model(X)
            # Same resize as in train(), so both losses are computed on
            # identically shaped tensors (a no-op when sizes already match).
            logits = nn.functional.interpolate(logits, size=y.shape[-2:], mode="nearest")

            loss_sum += loss_fn(logits, y).item()

            # Classification-style accuracy (argmax over channels compared to
            # y) does not apply to dense multi-channel labels; use the top-1
            # hit rate instead.
            n_samples = logits.shape[0]
            # reshape (not view) also works for non-contiguous tensors.
            logits_flat = logits.reshape(n_samples, -1)   # (N, C*H*W)
            y_flat = y.reshape(n_samples, -1)             # (N, C*H*W)

            # Flat index of the best-scoring location per sample, then the
            # label value at exactly that location.
            pred_idx = logits_flat.argmax(dim=1, keepdim=True)   # (N, 1)
            hits = y_flat.gather(1, pred_idx)                    # (N, 1)
            correct += int((hits == 1).sum().item())
            total += n_samples

    avg_loss = loss_sum / n_batches if n_batches else float("nan")
    # Divide by the number of samples actually evaluated.
    accuracy = correct / total if total else float("nan")

    print(f"Test Error: \n Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {avg_loss:>8f} \n")

    return {"test_loss": avg_loss, "test_accuracy": accuracy}


In [None]:
# Train for 10 more epochs; re-running this cell continues from epoch_total.
for epoch in range(10):
    epoch_total += 1
    print(f"Epoch {epoch_total}\n-------------------------------")
    train_info = train(train_dataloader)
    test_info = test(test_dataloader)
    # test_info = test(train_dataloader) # can I overfit to 100% a dataset? 

    # Record one point per epoch for the plots in the next cell.
    train_losses.append(train_info["avg_loss"])
    test_losses.append(test_info["test_loss"])
    test_accuracies.append(test_info["test_accuracy"])


In [None]:
# Two side-by-side panels built from the recorded histories.
plt.figure(figsize=(10,4))

# Left panel: train and test loss per epoch.
plt.subplot(1,2,1)
plt.plot(train_losses, label="Train Loss")
plt.plot(test_losses, label="Test Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()

# Right panel: test accuracy per epoch.
plt.subplot(1,2,2)
plt.plot(test_accuracies, label="Test Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()

plt.show()

In [None]:

def preprocess_state(state: np.ndarray):
    """Turn one channel-last env observation into a normalised model input.

    Steps: move channels first, cast to float and divide by 255, normalise
    with the module-level ``mean`` / ``std`` (out of place), add a leading
    batch axis and move the result to the module-level ``device``.
    """
    normalize = transforms.Compose([
        # transforms.RandomCrop(32, padding=4),
        # transforms.RandomHorizontalFlip(),
        # transforms.ToTensor(),
        transforms.Normalize(mean, std, inplace=False),
    ])

    # Same conversion as used when building the datasets.
    chw_image = torch.from_numpy(state).permute(2, 0, 1).float() / 255.0
    return normalize(chw_image).unsqueeze(0).to(device)

# Edited on Sept 4 to change the order of x, y, c. May need to play around with it if it doesn't work any more...
class HeatmapPolicy():
    """Epsilon-greedy action selection over a model's ``[B, C, H, W]`` heatmaps.

    With probability ``epsilon`` a uniformly random location is chosen;
    otherwise the global argmax over all channels and pixels is used.

    Return format of ``__call__`` (identical for both branches):
        * single sample, one channel:  ``(x, y)`` tuple of ints
        * single sample, C > 1:        ``(c, x, y)`` tuple of ints
        * batch, one channel:          ``[B, 2]`` tensor with columns ``(x, y)``
        * batch, C > 1:                ``[B, 3]`` tensor with columns ``(x, y, c)``

    Note that the single-sample tuple is channel-first while the batched
    tensor is channel-last.

    Attributes:
        model: Network producing the heatmaps; put into eval mode on construction.
        preprocess_state: Callable turning a raw state into the model input.
        epsilon: Exploration probability.
        iter: Number of calls made so far.
        last_heatmaps: Heatmaps from the most recent call (no autograd graph),
            or ``None`` before the first call.
    """

    def __init__(self, model: nn.Module, preprocess_state: Callable, epsilon=0.2):

        self.model = model
        self.model.eval()
        self.epsilon = epsilon
        self.iter = 0
        self.preprocess_state = preprocess_state
        self.last_heatmaps = None

    def __call__(self, state: np.ndarray):
        """Pick an action for ``state``; see the class docstring for the format."""
        # Be very careful with the order of the output actions: PyTorch tensors
        # are laid out as (c, y, x).
        self.iter += 1

        state = self.preprocess_state(state)

        # Pure inference: no autograd graph is needed, and keeping one alive
        # through `last_heatmaps` would only hold on to memory.
        with torch.no_grad():
            heatmaps = self.model(state)
        self.last_heatmaps = heatmaps
        B, C, H, W = heatmaps.shape

        if np.random.random() <= self.epsilon:
            # Explore: independent uniform draws for channel, column and row.
            c = torch.randint(C, (B,))
            x = torch.randint(W, (B,))
            y = torch.randint(H, (B,))
        else:
            # Exploit: global argmax over (C, H, W) for each sample.
            # reshape (not view) also handles non-contiguous model outputs.
            argmax_idx = heatmaps.reshape(B, -1).argmax(dim=1)   # [B]
            c = argmax_idx // (H * W)
            rem = argmax_idx % (H * W)
            y = rem // W
            x = rem % W

        return self._format_action(c, x, y, C)

    @staticmethod
    def _format_action(c: torch.Tensor, x: torch.Tensor, y: torch.Tensor, n_channels: int):
        """Pack per-sample ``c``/``x``/``y`` index tensors into the public action format."""
        single = c.shape[0] == 1

        if n_channels == 1:
            # The channel carries no information when there is only one.
            if single:
                return int(x.item()), int(y.item())
            return torch.stack([x, y], dim=1)         # [B,2] -> (x,y)

        if single:
            return int(c.item()), int(x.item()), int(y.item())
        return torch.stack([x, y, c], dim=1)          # [B,3] -> (x,y,c)

# Shape check of the preprocessed model input for the current observation.
print(preprocess_state(state).shape)

# Epsilon-greedy policy around the trained model (20% random actions).
policy = HeatmapPolicy(model, preprocess_state, epsilon=0.2)

In [None]:
# Fresh scene and one action with exploration switched off.
state, info = env.reset()
policy.epsilon = 0.0
action = policy(state)
print(action)

# At zero degrees the gripper should point upwards; at 90 it is offset by 90 degrees.
# Without the offset, we expect the angle to lie along the longitudinal axis of the item.
gripper_offset = np.pi/2

def display_heatmap(state, action, policy):
    """Show ``state`` and the policy's last heatmap side by side with the action marked.

    Args:
        state: Image shown in the left panel.
        action: ``(x, y)`` or ``(c, x, y)``; with a channel index ``c`` an arrow
            indicating the corresponding angle is drawn as well.
        policy: Object whose ``last_heatmaps`` holds the ``[B, C, H, W]``
            heatmaps of the most recent call; sample 0 is displayed.

    The angle for channel ``c`` is decoded from the module-level ``N_ANGLES``
    and ``ANGLE_RANGE`` using the same binning as ``view_labels``, so the
    arrow matches the per-angle panel titles.
    """
    # Determine if action includes angle
    if len(action) == 2:
        x, y = action
        angle = None
    else:
        c, x, y = action
        angle_step = (ANGLE_RANGE[1] - ANGLE_RANGE[0]) / N_ANGLES
        angle = ANGLE_RANGE[0] + c * angle_step
        # Flip the rotation direction: image row indices grow downwards, so a
        # positive angle would otherwise be drawn clockwise.
        angle *= -1

    def _mark_action(ax, color):
        # Cross at the chosen pixel plus, if known, an arrow along the angle.
        ax.scatter([x], [y], color=color, marker='x', s=100, label='Action')
        if angle is not None:
            dx = np.cos(angle) * max_width
            dy = np.sin(angle) * max_width
            head = np.ceil(max_width*0.1)
            ax.arrow(x, y, dx, dy, head_width=head, head_length=head, fc=color, ec=color)
        ax.legend()

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))

    # Plot state with action
    axs[0].set_title("State")
    axs[0].imshow(state)
    _mark_action(axs[0], 'white')

    # Plot heatmap with action
    axs[1].set_title("Heatmap")
    heatmap = policy.last_heatmaps[0,:,:,:].detach().cpu().numpy()
    # Show the per-pixel maximum over all angle channels; plotting only one
    # channel made the arrow appear on an apparently low-score location.
    heatmap = heatmap.max(axis=0)
    im = axs[1].imshow(heatmap, cmap='hot')
    _mark_action(axs[1], 'blue')
    fig.colorbar(im, ax=axs[1], fraction=0.046, pad=0.04)

    plt.tight_layout()
    plt.show()


# Decode the chosen channel into an angle with the same binning as
# view_labels (ANGLE_RANGE split into N_ANGLES bins), then wrap to [-pi, pi).
C, X, Y = action
angle = ANGLE_RANGE[0] + C * (ANGLE_RANGE[1] - ANGLE_RANGE[0]) / N_ANGLES
angle = (angle + np.pi) % (2 * np.pi) - np.pi

display_heatmap(state, action, policy)
# Channels-last copy of the predicted heatmaps for the per-angle grid view.
heatmap = policy.last_heatmaps[0,:,:,:].detach().cpu().numpy()
heatmap = np.transpose(heatmap, (1, 2, 0))


# "\u00b0" is the degree sign, written as an escape to survive re-encoding.
print(f"Action Angle {np.rad2deg(angle):.1f}\u00b0 or {angle:.1f}rad")

# Prediction first, then the env's label for the same scene.
view_labels(state, heatmap, alpha=0.9)
view_labels(state, info["curr_label"], alpha=0.9)

In [None]:
# Per-step histories filled by the rollout cell; re-running this cell clears them.
reward_history = []
epsilon_history = []

In [None]:
# Set the render mode on the unwrapped env back to "human" for the rollout.
env.unwrapped.render_mode = "human"

In [None]:
state, info = env.reset()

# Pure exploitation for this rollout.
policy.epsilon = 0.0

for i in range(10):
    # action = env.action_space.sample()
    action = policy(state)
    new_state, reward, terminated, truncated, info = env.step(action)
    reward_history.append(reward)
    epsilon_history.append(policy.epsilon)

    # Failure inspection, deliberately disabled by the trailing `and False`;
    # drop it to stop at the first zero-reward step and visualise it.
    if reward == 0 and False:
        C, X, Y = action
        # Same angle binning as view_labels.
        angle = ANGLE_RANGE[0] + C * (ANGLE_RANGE[1] - ANGLE_RANGE[0]) / N_ANGLES
        angle = (angle + np.pi) % (2 * np.pi) - np.pi

        display_heatmap(state, action, policy)
        heatmap = policy.last_heatmaps[0,:,:,:].detach().cpu().numpy()
        heatmap = np.transpose(heatmap, (1, 2, 0))

        print(f"Action Angle {np.rad2deg(angle):.1f}\u00b0 or {angle:.1f}rad")

        view_labels(state, heatmap, alpha=0.9)
        view_labels(state, info["last_label"], alpha=0.9)

        break

    if terminated or truncated:
        # A finished episode must be reset before the next step.
        state, info = env.reset()
    else:
        state = new_state

    # print_step(action, state, reward, terminated, truncated, info, i)

env.close()

In [None]:
# Raw per-step reward and the epsilon in effect at each step.
# NOTE(review): reward_history gets one entry per env.step call, so the
# "Episode" axis label below actually counts steps — confirm the intended unit.
plt.plot(reward_history, label="Reward")
plt.plot(epsilon_history, label="Epsilon", color='red')
# Compute moving average
# Only drawn once at least `window_size` rewards exist. 'valid' mode yields
# len - window_size + 1 points, plotted from index window_size - 1 so each
# point sits at the end of its window.
window_size = 50
if len(reward_history) >= window_size:
    moving_avg = np.convolve(reward_history, np.ones(window_size)/window_size, mode='valid')
    plt.plot(range(window_size-1, len(reward_history)), moving_avg, label=f"Moving Avg ({window_size})", color='orange')

plt.xlabel("Episode")
plt.ylabel("Reward")
plt.title("Reward History and Moving Average")
plt.legend()
plt.show()