# CUDA Thread Indexing Visualisation

This notebook provides interactive visualisations to help you understand how CUDA threads are organised into blocks and grids, and how they map to array indices.

In [None]:
%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.widgets import Slider
import ipywidgets as widgets

In CUDA programming:
- **Threads** are the smallest execution units
- **Blocks** contain multiple threads
- **Grids** contain multiple blocks

Each thread calculates its unique index to determine which data element to process.

### Key Concepts to Remember:

- Each thread calculates its unique global index
- 1D index = `blockIdx.x * blockDim.x + threadIdx.x`
- 2D to linear index = `y * width + x`
- **Warps**: Hardware executes threads in groups of 32 (warps)
- Different blocks may execute in any order

### Important Performance Tips:

- Prefer block sizes that are multiples of the warp size (32)
- Common efficient block sizes: 128, 256, 512
- Remember that masked threads in partial warps waste GPU resources

Use the interactive controls to explore different configurations and click on threads to see their index calculations.

## 1D Thread Indexing Visualisation

In [None]:
class ThreadVisualiser1D:
    """Interactive matplotlib view of 1D CUDA thread indexing.

    Draws one row of blocks; every block is split into warps of ``warp_size``
    thread slots. A slot is green when its global index falls inside the
    array, orange when the thread belongs to the block but its index is past
    the end of the array, and grey when the slot only pads the last warp of
    the block. Two sliders change the array size and the block size, and
    clicking a thread slot shows how its global index is computed.

    Args:
        array_size: number of array elements to cover (must be >= 0).
        threads_per_block: block size, i.e. ``blockDim.x`` (must be >= 1).

    Raises:
        ValueError: if either argument is outside the range above.
    """

    # layout constants in data coordinates, shared by drawing and hit-testing
    # so the two can never drift apart
    BLOCK_WIDTH = 2.5
    BLOCK_HEIGHT = 0.8
    BLOCK_SPACING = 0.2
    ROW_Y = 0

    def __init__(self, array_size=32, threads_per_block=32):
        # validate before any figure is created so a bad call leaves nothing behind
        if threads_per_block < 1:
            raise ValueError(
                f"threads_per_block must be at least 1, got {threads_per_block}")
        if array_size < 0:
            raise ValueError(
                f"array_size must not be negative, got {array_size}")

        self.array_size = array_size
        self.threads_per_block = threads_per_block
        self.warp_size = 32
        self.num_blocks = self._blocks_needed()
        self.highlighted_thread = None

        # create figure and axis
        self.fig, self.ax = plt.subplots(figsize=(16, 10))
        self.fig.suptitle("1D Thread Indexing and Warp Visualisation", fontsize=16)

        # leave room for the sliders at the very bottom
        self.fig.subplots_adjust(bottom=0.15, top=0.85)

        # slider axes are added to *this* figure explicitly rather than to
        # whatever pyplot currently considers the active figure
        ax_array = self.fig.add_axes([0.2, 0.08, 0.6, 0.03])
        self.slider_array = Slider(ax_array, "Array Size", 8, 128,
                                   valinit=array_size, valstep=8)

        ax_threads = self.fig.add_axes([0.2, 0.03, 0.6, 0.03])
        self.slider_threads = Slider(ax_threads, "Threads/Block", 8, 128,
                                     valinit=threads_per_block, valstep=8)

        # connect sliders to update function
        self.slider_array.on_changed(self.update_params)
        self.slider_threads.on_changed(self.update_params)

        # connect mouse click event
        self.fig.canvas.mpl_connect("button_press_event", self.on_click)

        # initial draw
        self.draw_visualisation()

    def _blocks_needed(self):
        """Return the number of blocks needed to give every element a thread."""
        # ceiling division without floats
        return (self.array_size + self.threads_per_block - 1) // self.threads_per_block

    def _warp_layout(self):
        """Return ``(warps_per_block, warp_width, thread_width)`` for the current block size."""
        warps_per_block = (self.threads_per_block + self.warp_size - 1) // self.warp_size
        warp_width = self.BLOCK_WIDTH / max(1, warps_per_block)
        return warps_per_block, warp_width, warp_width / self.warp_size

    def update_params(self, val):
        """Slider callback: re-read both sliders, clear the selection and redraw.

        ``val`` is the value matplotlib passes for the slider that moved; it is
        ignored because both sliders are read directly.
        """
        # sliders report floats; round instead of truncating so a value such
        # as 63.999... cannot silently become 63
        self.array_size = int(round(self.slider_array.val))
        self.threads_per_block = int(round(self.slider_threads.val))
        self.num_blocks = self._blocks_needed()
        self.highlighted_thread = None
        self.draw_visualisation()

    def draw_visualisation(self):
        """Clear the axes and redraw blocks, legend, calculation panel and warnings."""
        self.ax.clear()

        self._draw_blocks()
        legend_x = self._draw_legend()
        self._draw_index_calculation()
        warning_y_base, has_warnings = self._draw_warnings()

        # set axis properties to accommodate legend and warnings
        x_max = legend_x + 2.5  # space for legend on right
        y_min = warning_y_base - 0.5 if has_warnings else self.ROW_Y - 1.0

        self.ax.set_xlim(-0.3, x_max)
        self.ax.set_ylim(y_min, self.ROW_Y + self.BLOCK_HEIGHT + 0.8)
        self.ax.set_aspect("equal")
        self.ax.axis("off")

        # add title with parameters
        title = (f"Array size: {self.array_size}, "
                 f"Threads per block: {self.threads_per_block}, "
                 f"Number of blocks: {self.num_blocks}")
        self.ax.set_title(title, fontsize=14, pad=30)

        # redraw this figure specifically; plt.draw() would redraw whichever
        # figure pyplot currently treats as active
        self.fig.canvas.draw_idle()

    def _draw_blocks(self):
        """Draw every block with its warp outlines and thread slots."""
        # the warp layout depends only on the block size, so compute it once
        warps_per_block, warp_width, thread_width = self._warp_layout()
        y0 = self.ROW_Y

        for block_idx in range(self.num_blocks):
            block_x = block_idx * (self.BLOCK_WIDTH + self.BLOCK_SPACING)

            # block outline and label
            self.ax.add_patch(patches.Rectangle(
                (block_x, y0), self.BLOCK_WIDTH, self.BLOCK_HEIGHT,
                linewidth=3, edgecolor="black",
                facecolor="lightblue", alpha=0.3))
            self.ax.text(block_x + self.BLOCK_WIDTH / 2,
                         y0 + self.BLOCK_HEIGHT + 0.3,
                         f"Block {block_idx}", ha="center", va="bottom",
                         fontsize=12, fontweight="bold")

            for warp_idx in range(warps_per_block):
                warp_x = block_x + warp_idx * warp_width

                # warp outline and label
                self.ax.add_patch(patches.Rectangle(
                    (warp_x, y0), warp_width, self.BLOCK_HEIGHT,
                    linewidth=2, edgecolor="darkblue", facecolor="none"))
                self.ax.text(warp_x + warp_width / 2,
                             y0 + self.BLOCK_HEIGHT + 0.15,
                             f"Warp {warp_idx}", ha="center", va="top",
                             fontsize=9, style="italic")

                # every warp shows all of its slots, including the padding
                # slots past the end of the block
                for lane in range(self.warp_size):
                    thread_idx = warp_idx * self.warp_size + lane
                    global_idx = block_idx * self.threads_per_block + thread_idx
                    self._draw_thread(warp_x + lane * thread_width,
                                      thread_width, thread_idx, global_idx)

    def _draw_thread(self, thread_x, thread_width, thread_idx, global_idx):
        """Draw one thread slot, coloured and labelled by its status."""
        if thread_idx >= self.threads_per_block:
            # slot pads the last warp; no thread of this block lives here
            status, colour = "masked", "lightgray"
        elif global_idx < self.array_size:
            status, colour = "active", "lightgreen"
        else:
            # thread exists but its index is past the end of the array
            status, colour = "out_of_bounds", "orange"

        # masked slots are never highlighted: their global_idx collides with
        # a real thread of the next block
        if status != "masked" and self.highlighted_thread == global_idx:
            colour = "red"

        self.ax.add_patch(patches.Rectangle(
            (thread_x, self.ROW_Y), thread_width, self.BLOCK_HEIGHT,
            linewidth=0.5, edgecolor="gray", facecolor=colour,
            alpha=0.3 if status == "masked" else 0.7))

        if status == "active":
            label = str(global_idx)
            fontsize = 6 if global_idx > 99 else 7
            text_color = "black"
        elif status == "out_of_bounds":
            label = f"!{global_idx}"
            fontsize = 6
            text_color = "darkred"
        else:  # masked
            label = "×"
            fontsize = 8
            text_color = "gray"

        self.ax.text(thread_x + thread_width / 2,
                     self.ROW_Y + self.BLOCK_HEIGHT / 2,
                     label, ha="center", va="center",
                     fontsize=fontsize, color=text_color,
                     fontweight="bold" if status == "out_of_bounds" else "normal")

    def _draw_legend(self):
        """Draw the colour legend right of the blocks and return its x position."""
        legend_x = self.num_blocks * (self.BLOCK_WIDTH + self.BLOCK_SPACING) + 0.5
        legend_y = self.ROW_Y + self.BLOCK_HEIGHT * 0.5

        self.ax.text(legend_x, legend_y + 0.3, "Legend:",
                     fontweight="bold", fontsize=11)

        # (swatch y offset, label y offset, swatch style, label)
        entries = (
            (0, 0.1, {"facecolor": "lightgreen"}, "Active thread (valid access)"),
            (-0.3, -0.2, {"facecolor": "orange"}, "Out-of-bounds"),
            (-0.6, -0.5, {"facecolor": "lightgray", "alpha": 0.3},
             "Masked thread (wasted)"),
        )
        for swatch_dy, label_dy, style, label in entries:
            self.ax.add_patch(patches.Rectangle(
                (legend_x, legend_y + swatch_dy), 0.2, 0.2,
                edgecolor="gray", **style))
            self.ax.text(legend_x + 0.3, legend_y + label_dy, label,
                         fontsize=10, va="center")

        return legend_x

    def _draw_index_calculation(self):
        """Draw the index-calculation panel for the highlighted thread, if any."""
        if self.highlighted_thread is None:
            return

        block_id, thread_id = divmod(self.highlighted_thread, self.threads_per_block)

        calc_text = (f"Thread {self.highlighted_thread} calculation:\n"
                     f"blockIdx.x = {block_id}\n"
                     f"threadIdx.x = {thread_id}\n"
                     f"blockDim.x = {self.threads_per_block}\n"
                     "Global index = blockIdx.x × blockDim.x + threadIdx.x\n"
                     f"Global index = {block_id} × {self.threads_per_block}"
                     f" + {thread_id} = {self.highlighted_thread}")

        # place calculation box on the left side, below blocks
        self.ax.add_patch(patches.FancyBboxPatch(
            (0, self.ROW_Y - 0.8), self.BLOCK_WIDTH * 1.5, 0.65,
            boxstyle="round,pad=0.05",
            facecolor="wheat", alpha=0.9,
            edgecolor="brown", linewidth=1))
        self.ax.text(self.BLOCK_WIDTH * 0.75, self.ROW_Y - 0.475,
                     calc_text, ha="center", va="center",
                     fontsize=9, family="monospace")

    def _draw_warnings(self):
        """Draw the bounds-check and warp-efficiency warnings that apply.

        Returns:
            ``(warning_y_base, has_warnings)``: the y base left after the
            bounds warning (if drawn) and whether any warning was drawn; the
            caller uses both to size the y axis.
        """
        warning_y_base = self.ROW_Y - 1.0
        total_threads = self.num_blocks * self.threads_per_block
        out_of_bounds_threads = max(0, total_threads - self.array_size)
        partial_warp = self.threads_per_block % self.warp_size != 0

        # add out-of-bounds warning if applicable
        if out_of_bounds_threads > 0:
            oob_warning = (f"Bounds Check Required:\n{out_of_bounds_threads} threads would access out-of-bounds.\n"
                           f"Your kernel must include: if (idx < {self.array_size}) {{ ... }}")
            self._draw_warning_box(warning_y_base, oob_warning,
                                   "orange", "darkorange", fontweight="bold")
            # push the next warning below this one
            warning_y_base -= 0.6

        # add warp efficiency warning if applicable
        if partial_warp:
            warps_used = (self.threads_per_block + self.warp_size - 1) // self.warp_size
            threads_wasted = (warps_used * self.warp_size) - self.threads_per_block
            efficiency = (self.threads_per_block / (warps_used * self.warp_size)) * 100

            warning_text = (f"Warning: Block size ({self.threads_per_block}) is not a multiple of warp size ({self.warp_size})\n"
                            f"This wastes {threads_wasted} threads per block (efficiency: {efficiency:.1f}%)")
            self._draw_warning_box(warning_y_base, warning_text,
                                   "yellow", "goldenrod")

        return warning_y_base, (out_of_bounds_threads > 0 or partial_warp)

    def _draw_warning_box(self, y_base, text, facecolor, edgecolor, **text_style):
        """Draw one rounded warning box spanning the block row, with centred text."""
        row_width = self.num_blocks * (self.BLOCK_WIDTH + self.BLOCK_SPACING)
        self.ax.add_patch(patches.FancyBboxPatch(
            (0, y_base - 0.4), row_width - 0.2, 0.35,
            boxstyle="round,pad=0.1",
            facecolor=facecolor, alpha=0.8,
            edgecolor=edgecolor, linewidth=2))
        self.ax.text(row_width / 2, y_base - 0.225,
                     text, ha="center", va="center",
                     fontsize=11, **text_style)

    def _thread_at(self, x, y):
        """Return the global index of the thread slot at data point (x, y).

        Returns ``None`` when the point is outside the block row, in the gap
        between two blocks, or on a masked padding slot.
        """
        if x is None or y is None:
            return None
        # only the block row itself is clickable, not the legend, the
        # calculation panel or the warnings that share the same x range
        if not self.ROW_Y <= y <= self.ROW_Y + self.BLOCK_HEIGHT:
            return None

        _, warp_width, thread_width = self._warp_layout()

        for block_idx in range(self.num_blocks):
            block_x = block_idx * (self.BLOCK_WIDTH + self.BLOCK_SPACING)
            if not block_x <= x <= block_x + self.BLOCK_WIDTH:
                continue

            relative_x = x - block_x
            warp_idx = int(relative_x / warp_width)
            # clamp so float rounding at a warp's right edge cannot spill
            # into the next warp
            lane = min(int((relative_x - warp_idx * warp_width) / thread_width),
                       self.warp_size - 1)
            thread_idx = warp_idx * self.warp_size + lane

            # slots past the block size are padding, not threads
            if thread_idx >= self.threads_per_block:
                return None
            return block_idx * self.threads_per_block + thread_idx

        return None

    def on_click(self, event):
        """Mouse callback: highlight the clicked thread slot and redraw.

        Clicks outside the main axes, outside the block row, or on a masked
        slot leave the current selection unchanged. Out-of-bounds (orange)
        threads can be selected.
        """
        if event.inaxes is not self.ax:
            return

        global_idx = self._thread_at(event.xdata, event.ydata)
        if global_idx is None:
            return

        self.highlighted_thread = global_idx
        self.draw_visualisation()

# instantiate with the defaults (array_size=32, threads_per_block=32); constructing the
# object creates the figure and performs the initial draw
# NOTE(review): keep the `vis1d` reference alive - matplotlib is documented to hold
# bound-method callbacks weakly, so the sliders/click handler presumably stop responding
# if the instance is garbage-collected; confirm against the matplotlib event docs
vis1d = ThreadVisualiser1D()