What is left to make it work: 
* 3D scan hat meshes, put in /assets/hat.obj
* Get drone images for /data/drone_footage (with annotations.json)
* Get printer specs from FabLab for config
* Calibrate drone scale in annotations.json metadata

# CloakHat Patch Generation Pipeline

## 1: Conda Setup

* Conda is the main env manager, pip is for Python packages
* PyTorch is the main AI/ML library
* NVIDIA CUDA is for GPU acceleration
* PyTorch3D is for rendering the hat
* ipykernel allows JupyterLab to use the Conda env
* Ultralytics has YOLO models
* opencv-python-headless is for image processing
* matplotlib is for plots
* tqdm is for progress bars
* NumPy is for data manipulation

### Option 1:

Create and activate the environment <br>
`conda env create -f environment.yaml` <br>
`conda activate cloakhat` <br>

Register the kernel with Jupyter <br>
`python -m ipykernel install --user --name cloakhat --display-name "Python (cloakhat)"`

Deactivate and remove the environment <br>
`conda deactivate` <br>
`conda env remove -n cloakhat` <br>

### Option 2: 

Set up the environment

`conda create -n cloakhat python=3.10 -y` <br>
`conda activate cloakhat`

PyTorch with CUDA. Also ipykernel. <br>
`conda install pytorch ipykernel pytorch-cuda=11.8 -c pytorch -c nvidia -y`<br>

PyTorch3D for differentiable rendering <br>
`conda install -c pytorch3d pytorch3d -y`

Detection models <br>
`pip install ultralytics`

Other stuff <br>
`pip install opencv-python-headless matplotlib tqdm numpy`

Register the kernel with Jupyter <br>
`python -m ipykernel install --user --name cloakhat --display-name "Python (cloakhat)"`

Deactivate and remove the environment <br>
`conda deactivate` <br>
`conda env remove -n cloakhat` <br>

### Option 3: 

Run the bash <br>
`bash LabSetup.sh`

Activate the environment <br>
`conda activate cloakhat` <br>

Deactivate and remove the environment <br>
`conda deactivate` <br>
`conda env remove -n cloakhat` <br>
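
Whichever option is used, it can help to confirm that the packages listed at the top of this section are visible to the active environment before opening the notebook. The snippet below is a minimal sketch using only the standard library; `REQUIRED_MODULES` and `find_missing` are hypothetical names introduced here, and the list holds import names (for example `cv2`), not conda or pip package names.

```python
import importlib.util

# Import names (not package names) that the notebook cells below rely on.
REQUIRED_MODULES = [
    "torch", "pytorch3d", "ultralytics", "cv2",
    "matplotlib", "tqdm", "numpy", "ipykernel",
]

def find_missing(modules):
    """Return the entries of `modules` that cannot be located in this environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]

missing = find_missing(REQUIRED_MODULES)
print("Missing modules:", missing if missing else "none")
```

Run it with the `cloakhat` environment active; anything it lists still needs to be installed.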

## 2: Python Setup

Get the libraries we need

In [1]:
#Deep learning stuff
import torch
import torch.nn as nn
import torch.nn.functional as F

#Data manipulation
import numpy as np

#Image processing
import cv2

#Plotting
import matplotlib.pyplot as plt

#Working with the file system
from pathlib import Path

#Progress bars
from tqdm import tqdm

#Better logging than print statements
import logging

#JSON utilities
import json

#PyTorch3d utilities
from pytorch3d.io import load_obj
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (look_at_view_transform, FoVPerspectiveCameras, RasterizationSettings, MeshRasterizer, SoftPhongShader, TexturesUV, PointLights)

#Gets YOLO models
from ultralytics import YOLO

#Logging with timestamps
logging.basicConfig(level=logging.INFO, format='%(asctime)s | %(message)s')
logger = logging.getLogger(__name__)

#Check what device is being used (especially if we want GPU) and log it
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Device: {device}")

Creating new Ultralytics Settings v0.0.6 file ✅ 
View Ultralytics Settings with 'yolo settings' or at '/home/jovyan/.config/Ultralytics/settings.json'
Update Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings.


2026-02-05 18:26:34,989 | Device: cuda


## 3: Config

Control variables

In [None]:
CONFIG = {
    'dataset_dir': './data/drone_footage', #Drone footage/image samples
    'mesh_path': './assets/hat.obj', #Hat mesh
    'output_dir': './outputs', #Where our outputs will go (textures, evaluations, stuff like that)
    
    #Generator
    'latent_channels': 128, #Channels
    'latent_size': 9, #Spatial size of latent input
    #(so the latent is 128x9x9)
    'texture_size': 288, #Output texture size from generator
    #(so the latent becomes a texture that is 3x288x288)
    
    #Viewpoint sampling
    'scale_jitter': 0.1, #Fraction of scale variation
    'camera_pitch_jitter': 5.0, #Alias for elevation jitter (degrees)
    'heading_jitter': 10.0, #Alias for azimuth jitter (degrees)

    
    'num_workers': 8, #DataLoader
    'det_conf_floor': 0.001, #Minimum confidence for detector loss
    
    #Training Stage 1
    'stage1_epochs': 100, #100 epochs
    'stage1_batch_size': 8, #8 batch minibatch gradient descent
    'stage1_lr': 2e-4, #learning rate
    
    #Training Stage 2  
    'stage2_iterations': 2000, #Now we optimize the single tensor
    'stage2_lr': 0.01, #Bigger learning rate
    'local_latent_size': 18, #Size of optimizable latent pattern. Bigger than 9x9 (input), so tile seamlessly
    
    #Loss weights
    'lambda_tv': 2.5, #Total variation - makes the textures smoother/less noisy/able to be printed
    'lambda_nps': 0.01, #Non-printability score - penalize colors that can't print well
    'lambda_info': 0.1, #Mutual information (Stage 1 only) - ensures latent is correlated to the texture
    
    #T-SEA Stuff
    'cutout_prob': 0.9, #90% of the time, cut off 40% of the hat
    'cutout_ratio': 0.4,
    'shakedrop_prob': 0.5, #50% of the time, perturb the features with noise (see shakedrop_forward)
    
    #Rendering
    'render_size': 256, #Output 256x256 images
    
    #Printer specifications (GET FROM FABLAB)
    'printer': {
        'dpi': 300,
        'physical_size_inches': (8, 8),
        'max_saturation': 0.85,
        'max_brightness': 0.95,
        'min_brightness': 0.08,
        'nps_threshold': 0.7, #Saturation * brightness threshold (penalize when saturation × brightness > 0.7)
        'gamut_samples_path': './assets/printer_gamut.npy',
    },
    
    #Attack config (white, gray, black)
    'attack_mode': 'gray',
}

#Calculate output size
CONFIG['texture_output_size'] = int(CONFIG['printer']['dpi'] * CONFIG['printer']['physical_size_inches'][0])

#Make sure the folder exists
Path(CONFIG['output_dir']).mkdir(parents=True, exist_ok=True)
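
The size arithmetic in this cell can be checked by hand. The sketch below copies the relevant values out of `CONFIG` and works out the print resolution and the generator's upsampling factor. The last step, the latent side length that would cover the whole print, assumes the full-size texture is produced by feeding a larger latent through the same fully convolutional generator; `latent_side_for_print` is a name made up for this example.

```python
# Values copied from CONFIG above.
dpi = 300
physical_size_inches = (8, 8)
texture_size = 288   # generator output side for a 9x9 latent
latent_size = 9

# Pixels needed along one side of the printed texture.
texture_output_size = int(dpi * physical_size_inches[0])

# Fixed upsampling factor of the generator (five stride-2 layers).
upsample_factor = texture_size // latent_size

# Smallest latent side whose generated texture covers the print (ceiling division).
latent_side_for_print = -(-texture_output_size // upsample_factor)

print(texture_output_size, upsample_factor, latent_side_for_print)  # 2400 32 75
```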

## 4: Dataset Preparation

Prepare the dataset

In [None]:
"""
dataset_dir/
        frames/
            frame_0001.png
            frame_0002.png
            ...
        masks/
            frame_0001_mask.png  (binary mask of green hat region)
            ...
        annotations.json  (person bounding boxes, metadata)

Annotations.json:
{
    "frames": [
        {
            "frame_id": "frame_0001",
            "image_path": "frames/frame_0001.png",
            "mask_path": "masks/frame_0001_mask.png",
            "person_bbox": [x1, y1, x2, y2],
            "drone": {
                "camera_pitch": 82.5,
                "heading": 45.0,
                "altitude_meters": 15.0
            }
        },
        ...
    ],
    "metadata": {
        "altitude_to_scale": {
            "min_altitude": 5.0,
            "max_altitude": 50.0,
            "min_scale": 0.3,
            "max_scale": 1.2
        }
    }
}
"""

class DroneDataset(torch.utils.data.Dataset):
    
    def __init__(self, dataset_dir, transform=None):
        self.dataset_dir = Path(dataset_dir)
        self.transform = transform
        
        if not self.dataset_dir.exists():
            logger.warning(f"Dataset directory not found: {dataset_dir}")
            logger.warning("Using placeholder data for testing.")
            self.use_placeholder = True
            self.length = 100
            return
        
        # Load annotations
        annotations_path = self.dataset_dir / 'annotations.json'
        if not annotations_path.exists():
            raise FileNotFoundError(f"Missing annotations.json in {dataset_dir}")
        
        with open(annotations_path) as f:
            self.annotations = json.load(f)
        
        self.frames = self.annotations['frames']
        self.length = len(self.frames)
        self.use_placeholder = False
        
        # Altitude to scale conversion params
        meta = self.annotations['metadata']['altitude_to_scale']
        self.alt_min = meta['min_altitude']
        self.alt_max = meta['max_altitude']
        self.scale_min = meta['min_scale']
        self.scale_max = meta['max_scale']
        
        logger.info(f"Loaded {self.length} frames from {dataset_dir}")
    
    def _altitude_to_scale(self, altitude_meters):
        """Higher altitude = smaller hat (farther away)."""
        t = (altitude_meters - self.alt_min) / (self.alt_max - self.alt_min)
        t = np.clip(t, 0, 1)
        scale = self.scale_max - t * (self.scale_max - self.scale_min)
        return scale
    
    def __len__(self):
        return self.length
    
    def __getitem__(self, idx):
        if self.use_placeholder:
            #Placeholder with fake viewpoint data
            hat_half_size = 50
            person_pad_y, person_pad_x, person_pad_bottom = 150, 100, 200
            
            cy = np.random.randint(person_pad_y + hat_half_size, 1080 - person_pad_bottom - hat_half_size)
            cx = np.random.randint(person_pad_x + hat_half_size, 1920 - person_pad_x - hat_half_size)
            
            image = torch.rand(3, 1080, 1920)
            hat_mask = torch.zeros(1, 1080, 1920)
            hat_mask[:, cy-hat_half_size:cy+hat_half_size, cx-hat_half_size:cx+hat_half_size] = 1.0
            
            person_bbox = torch.tensor([
                cx - person_pad_x, cy - person_pad_y,
                cx + person_pad_x, cy + person_pad_bottom
            ], dtype=torch.float32)
            
            return {
                'image': image,
                'hat_mask': hat_mask,
                'person_bbox': person_bbox,
                'elevation': torch.tensor(np.random.uniform(60, 90), dtype=torch.float32),
                'azimuth': torch.tensor(np.random.uniform(0, 360), dtype=torch.float32),
                'scale': torch.tensor(np.random.uniform(0.3, 1.2), dtype=torch.float32),
            }
        
        # Real data
        frame = self.frames[idx]
        
        # Load image
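        image_path = self.dataset_dir / frame['image_path']
        image = cv2.imread(str(image_path))
        if image is None:
            #cv2.imread returns None instead of raising when the file is missing or unreadable
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)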
        image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
        
        # Load mask
        mask_path = self.dataset_dir / frame['mask_path']
        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        mask = torch.from_numpy(mask).unsqueeze(0).float() / 255.0
        
        # Bounding box
        person_bbox = torch.tensor(frame['person_bbox'], dtype=torch.float32)
        
        # Viewpoint
        drone = frame['drone']
        elevation = drone['camera_pitch']
        azimuth = drone.get('heading', np.random.uniform(0, 360))  # Randomize if not provided
        scale = self._altitude_to_scale(drone['altitude_meters'])
        
        return {
            'image': image,
            'hat_mask': mask,
            'person_bbox': person_bbox,
            'elevation': torch.tensor(elevation, dtype=torch.float32),
            'azimuth': torch.tensor(azimuth, dtype=torch.float32),
            'scale': torch.tensor(scale, dtype=torch.float32),
        }


def segment_green_hat(frame):
    hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
    
    #Green range in HSV
    lower_green = np.array([35, 100, 100])
    upper_green = np.array([85, 255, 255])
    
    mask = cv2.inRange(hsv, lower_green, upper_green)
    
    #Clean up mask
    kernel = np.ones((5, 5), np.uint8)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
    
    return mask


#Create dataset (will use placeholder if data doesn't exist)
dataset = DroneDataset(CONFIG['dataset_dir'])
logger.info(f"Dataset size: {len(dataset)}")
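
To make the altitude-to-scale mapping concrete, here is a standalone sketch of the same linear interpolation as `_altitude_to_scale`, using the example metadata values from the docstring as defaults. Those values are placeholders until the drone scale is calibrated, so the printed numbers only illustrate the shape of the mapping.

```python
def altitude_to_scale(altitude_m, alt_min=5.0, alt_max=50.0, scale_min=0.3, scale_max=1.2):
    """Linearly map altitude to hat scale; higher altitude gives a smaller hat."""
    t = (altitude_m - alt_min) / (alt_max - alt_min)
    t = min(max(t, 0.0), 1.0)  # clamp altitudes outside the calibrated range
    return scale_max - t * (scale_max - scale_min)

for altitude in (5.0, 15.0, 50.0, 80.0):
    print(altitude, round(altitude_to_scale(altitude), 3))
```

Altitudes above `max_altitude` are clamped, so 80 m maps to the same scale as 50 m.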

## 5: FCN Generator

Make the texture (turn noise into an image)

In [None]:
class FCNGenerator(nn.Module):
    
    def __init__(self, latent_channels=128):
        super().__init__()
        
        self.net = nn.Sequential(
            #9 -> 9
            nn.Conv2d(latent_channels, 512, 3, 1, 1, padding_mode='zeros'),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),     
            #9 -> 18
            nn.ConvTranspose2d(512, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            #18 -> 36
            nn.ConvTranspose2d(512, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # 36 -> 72
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # 72 -> 144
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            # 144 -> 288
            nn.ConvTranspose2d(64, 32, 4, 2, 1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            # 288 -> 288 (to RGB)
            nn.Conv2d(32, 3, 3, 1, 1, padding_mode='zeros'),
            nn.Tanh()
        )
        self.output_size = 288
        self._init_weights()
        
    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, a=0.2, nonlinearity='leaky_relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
    
    def forward(self, z):
        return self.net(z)
    
    def generate(self, z=None, batch_size=1):
        if z is None:
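            #Read the channel count from the first conv so a non-default latent_channels still works
            z = torch.randn(batch_size, self.net[0].in_channels, 9, 9, device=next(self.parameters()).device)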
        return (self.forward(z) + 1) / 2

#Test
generator = FCNGenerator(CONFIG['latent_channels']).to(device)
test_texture = generator.generate(batch_size=1)
assert test_texture.shape[-1] == CONFIG['texture_size'], f"Generator outputs {test_texture.shape[-1]}px but config expects {CONFIG['texture_size']}px"
logger.info(f"Generator output: {test_texture.shape}")  #Should be (1, 3, 288, 288)
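
The `9 -> 18 -> ... -> 288` comments in the generator follow from the standard transposed-convolution size formula. The sketch below applies it with the layer settings used above (kernel 4, stride 2, padding 1, assuming dilation 1 and no output padding, which are the PyTorch defaults); the helper names are made up for this example.

```python
def conv_transpose_out(size, kernel=4, stride=2, padding=1):
    """Output side length of a ConvTranspose2d layer (dilation 1, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel

def generator_output_size(latent_size, num_upsample_layers=5):
    """Trace a latent side length through the five upsampling layers."""
    size = latent_size  # the 3x3, stride-1, padding-1 convs leave the size unchanged
    for _ in range(num_upsample_layers):
        size = conv_transpose_out(size)
    return size

print(generator_output_size(9))   # 288
print(generator_output_size(18))  # 576
```

Because every layer doubles the size, the network as a whole upsamples by a factor of 32 for any latent size.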

## 6: Auxiliary Network

Forces the texture to derive from the latent

In [None]:
class AuxiliaryNetwork(nn.Module):
    def __init__(self, latent_channels=128):
        super().__init__()
        
        #Texture encoder
        self.tex_enc = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1), nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d(1), nn.Flatten()
        )
        
        #Latent encoder
        self.lat_enc = nn.Sequential(
            nn.Conv2d(latent_channels, 256, 3, 1, 1), nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d(1), nn.Flatten()
        )
        
        #Joint network
        self.joint = nn.Sequential(
            nn.Linear(512, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 1)
        )
        
    def forward(self, texture, z):
        tex_feat = self.tex_enc(texture)
        lat_feat = self.lat_enc(z)
        return self.joint(torch.cat([tex_feat, lat_feat], dim=1))


def compute_mi_loss(aux_net, texture, z):

    #Matched pairs
    T_joint = aux_net(texture, z)
    pos_term = -F.softplus(-T_joint).mean()
    
    #Mismatched pairs (shuffle z)
    z_shuffle = z[torch.randperm(z.size(0), device=z.device)]
    T_marginal = aux_net(texture, z_shuffle)
    neg_term = F.softplus(T_marginal).mean()
    
    mi = pos_term - neg_term
    return -mi  #Negate because we minimize loss but want to maximize MI
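
To get a feel for what `compute_mi_loss` rewards, the sketch below evaluates the same softplus expression on hand-picked scalar scores instead of network outputs. It is plain Python so it runs without a GPU. The softplus form resembles the Jensen-Shannon style estimators used for mutual information, but that reading is an assumption; the scores are invented for illustration.

```python
import math

def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def mi_estimate(matched_scores, mismatched_scores):
    """Same expression as compute_mi_loss before the final negation."""
    pos = sum(-softplus(-t) for t in matched_scores) / len(matched_scores)
    neg = sum(softplus(t) for t in mismatched_scores) / len(mismatched_scores)
    return pos - neg

# A network that cannot tell matched from mismatched pairs scores everything 0.
uninformed = mi_estimate([0.0, 0.0], [0.0, 0.0])
# A network that separates them gives high scores to matched pairs, low to shuffled ones.
informed = mi_estimate([5.0, 6.0], [-5.0, -6.0])

print(round(uninformed, 4), round(informed, 4))
```

The estimate rises toward 0 as the auxiliary network gets better at telling which latent produced which texture, so minimizing its negation pushes the generator to keep the texture dependent on the latent.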

## 7: Render Hat

Render the hat using the texture and capture angle

In [None]:
class HatRenderer:
    def __init__(self, mesh_path, render_size=256, device='cuda'):
        self.device = device
        self.render_size = render_size
        
        #Load mesh
        self.mesh_loaded = False
        if Path(mesh_path).exists():
            verts, faces, aux = load_obj(mesh_path, device=device)
            self.verts = verts
            self.faces = faces.verts_idx
            self.verts_uvs = aux.verts_uvs
            self.faces_uvs = faces.textures_idx
            self.mesh_loaded = True
            logger.info(f"Loaded mesh: {len(verts)} verts, {len(self.faces)} faces")
        else:
            logger.warning(f"Mesh not found at {mesh_path}. Using placeholder.")
            self._create_placeholder_mesh()
            
        #Rasterization settings
        self.raster_settings = RasterizationSettings(
            image_size=render_size, 
            blur_radius=0.0, 
            faces_per_pixel=1
        )
        
        #Create rasterizer once
        self.rasterizer = MeshRasterizer(raster_settings=self.raster_settings)
    
    def _create_placeholder_mesh(self):
        #Simple disk
        n_points = 32
        angles = torch.linspace(0, 2*np.pi, n_points+1)[:-1]
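        #Vertices: center + rim, laid flat in the XZ plane so the top of the disk faces +Y
        #(PyTorch3D measures elevation from the y = 0 plane, so a top-down camera looks along -Y)
        verts = [[0, 0, 0]]  # center
        for a in angles:
            verts.append([torch.cos(a).item(), 0, -torch.sin(a).item()])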
        self.verts = torch.tensor(verts, dtype=torch.float32, device=self.device)
        #Faces: triangles from center to rim
        faces = []
        for i in range(n_points):
            faces.append([0, i + 1, (i + 1) % n_points + 1])  #Last triangle wraps back to rim vertex 1
        self.faces = torch.tensor(faces, dtype=torch.int64, device=self.device)
        #UVs: simple radial mapping
        uvs = [[0.5, 0.5]]  # center
        for a in angles:
            uvs.append([0.5 + 0.5*torch.cos(a).item(), 0.5 + 0.5*torch.sin(a).item()])
        self.verts_uvs = torch.tensor(uvs, dtype=torch.float32, device=self.device)
        self.faces_uvs = self.faces.clone()
        self.mesh_loaded = True
    
    def render(self, texture, elevation=90, azimuth=0, scale=1.0):
        batch_size = texture.shape[0]
        
        #Scale vertices
        verts = self.verts * scale
        
        #Camera setup
        dist = 2.5  #Camera distance
        R, T = look_at_view_transform(dist=dist, elev=elevation, azim=azimuth, device=self.device)
        cameras = FoVPerspectiveCameras(R=R, T=T, device=self.device)
        
        #Lighting (varying lighting)
        light_x = np.random.uniform(-1, 1)
        light_y = np.random.uniform(1, 3)  #Always somewhat above
        light_z = np.random.uniform(-1, 1)
        lights = PointLights(
            device=self.device, 
            location=[[light_x, light_y, light_z]],
            ambient_color=[[0.5, 0.5, 0.5]],
            diffuse_color=[[0.3, 0.3, 0.3]],
            specular_color=[[0.2, 0.2, 0.2]]
        )
        
        #Create shader once per render call (lighting changes)
        shader = SoftPhongShader(device=self.device, cameras=cameras, lights=lights)
        
        #Create batched texture
        tex_maps = texture.permute(0, 2, 3, 1)  # (B, H, W, 3)
        textures = TexturesUV(
            maps=tex_maps,
            faces_uvs=[self.faces_uvs] * batch_size,
            verts_uvs=[self.verts_uvs] * batch_size
        )
        
        #Create batched mesh
        meshes = Meshes(
            verts=[verts] * batch_size,
            faces=[self.faces] * batch_size,
            textures=textures
        )
        
        #Render entire batch at once
        fragments = self.rasterizer(meshes, cameras=cameras)
        images = shader(fragments, meshes, cameras=cameras, lights=lights)
        
        rendered_images = images[..., :3].permute(0, 3, 1, 2)
        alpha_masks = images[..., 3:4].permute(0, 3, 1, 2)
            
        return rendered_images, alpha_masks

#Test renderer
renderer = HatRenderer(CONFIG['mesh_path'], CONFIG['render_size'], device)
test_render, test_alpha = renderer.render(test_texture, elevation=85, azimuth=45)
logger.info(f"Rendered shape: {test_render.shape}, alpha shape: {test_alpha.shape}")

#Visualize
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
axes[0].imshow(test_texture[0].permute(1,2,0).detach().cpu())
axes[0].set_title('Texture')
axes[1].imshow(test_render[0].permute(1,2,0).detach().cpu())
axes[1].set_title('Rendered Hat')
axes[2].imshow(test_alpha[0, 0].detach().cpu(), cmap='gray')
axes[2].set_title('Alpha Mask')
plt.tight_layout()
plt.show()
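
The `elevation` and `azimuth` arguments are easier to reason about as a camera position. The sketch below assumes PyTorch3D's documented convention for `look_at_view_transform`: elevation is the angle above the horizontal `y = 0` plane, and azimuth is measured in that plane from the `+Z` axis. `camera_position` is a helper written for this example, not a PyTorch3D function.

```python
import math

def camera_position(dist, elev_deg, azim_deg):
    """Camera location on a sphere of radius `dist` around the origin."""
    elev = math.radians(elev_deg)
    azim = math.radians(azim_deg)
    x = dist * math.cos(elev) * math.sin(azim)
    y = dist * math.sin(elev)
    z = dist * math.cos(elev) * math.cos(azim)
    return (x, y, z)

# Same distance as HatRenderer.render, with a drone-like view and two reference views.
for elev, azim in ((85, 45), (90, 0), (0, 0)):
    print(elev, azim, tuple(round(c, 3) for c in camera_position(2.5, elev, azim)))
```

Under this convention a drone looking straight down (`elevation=90`) sits on the `+Y` axis, so the top of the hat mesh needs to face `+Y` to be visible.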

## 8: T-SEA Augmentations

Helper methods for black/gray/white box transfers

In [None]:
#Randomly mask a region of the rendered hat. Prevents overfitting to specific texture patterns.
def patch_cutout(rendered_hat, alpha_mask, prob=0.9, ratio=0.4, fill=0.5):
    if np.random.random() > prob:
        return rendered_hat
    B, C, H, W = rendered_hat.shape
    #Random cutout size
    cut_h = int(H * ratio)
    cut_w = int(W * ratio)
    #Random position
    top = np.random.randint(0, H - cut_h + 1)
    left = np.random.randint(0, W - cut_w + 1)
    #Apply cutout (only where alpha > 0)
    mask = alpha_mask.clone()
    mask[:, :, top:top+cut_h, left:left+cut_w] = 0
    rendered_hat = rendered_hat * mask + fill * (1 - mask) * (alpha_mask > 0).float()
    return rendered_hat

#Mild augmentations that don't distort the image too much.
def constrained_augmentation(image):
    B, C, H, W = image.shape
    #Random scale (0.9 - 1.1)
    scale = np.random.uniform(0.9, 1.1)
    new_size = int(H * scale)
    image = F.interpolate(image, size=new_size, mode='bilinear', align_corners=False)
    #Crop/pad back to original size
    if new_size > H:
        start = (new_size - H) // 2
        image = image[:, :, start:start+H, start:start+W]
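    else:
        #The two pads must sum to H - new_size, also when that difference is odd
        pad_before = (H - new_size) // 2
        pad_after = H - new_size - pad_before
        image = F.pad(image, [pad_before, pad_after, pad_before, pad_after], mode='reflect')
        image = image[:, :, :H, :W]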
    #Color jitter (mild)
    brightness = np.random.uniform(0.9, 1.1)
    image = image * brightness
    #Random horizontal flip
    if np.random.random() > 0.5:
        image = torch.flip(image, dims=[3])
    return image.clamp(0, 1)

#ShakeDrop creates a virtual ensemble of model variants
def shakedrop_forward(model, x, drop_prob=0.5, alpha_range=(0, 2)):
    #Simplified stand-in for ShakeDrop: add scaled Gaussian noise to the features
    if np.random.random() < drop_prob:
        alpha = np.random.uniform(*alpha_range)
        noise = torch.randn_like(x) * 0.1 * alpha
        x = x + noise
    return x
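
When `constrained_augmentation` shrinks an image and pads it back, the padding on each axis has to add up to exactly the size difference, otherwise the output ends up one pixel short whenever that difference is odd. A small sketch of the split, with `split_padding` as a hypothetical helper name:

```python
def split_padding(original, resized):
    """Split the missing pixels into (before, after) pads that sum to the difference."""
    total = original - resized
    before = total // 2
    after = total - before  # takes the extra pixel when `total` is odd
    return before, after

print(split_padding(256, 251))  # (2, 3)
print(split_padding(256, 250))  # (3, 3)
```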

## 9: URAdv Augmentations

For better performance under drone conditions

In [None]:
#Add simulated light reflections on the hat surface.
def add_light_spots(image, alpha_mask, num_range=(0, 3), intensity_range=(0.1, 0.4)):
    if np.random.random() > 0.5:
        return image
    B, C, H, W = image.shape
    num_spots = np.random.randint(*num_range)
    for _ in range(num_spots):
        #Random spot position (within hat region)
        cy = np.random.randint(H // 4, 3 * H // 4)
        cx = np.random.randint(W // 4, 3 * W // 4)
        #Spot parameters
        radius = np.random.uniform(0.05, 0.15) * min(H, W)
        intensity = np.random.uniform(*intensity_range)
        #Create Gaussian spot
        y, x = torch.meshgrid(torch.arange(H, device=image.device), torch.arange(W, device=image.device), indexing='ij')
        dist = ((x - cx) ** 2 + (y - cy) ** 2).float()
        spot = torch.exp(-dist / (2 * radius ** 2)) * intensity
        #Apply only within hat (where alpha > 0)
        spot = spot.unsqueeze(0).unsqueeze(0) * (alpha_mask > 0).float()
        image = image + spot
    return image.clamp(0, 1)

#Add simulated shadows on the hat surface.
def add_shadows(image, alpha_mask, num_range=(0, 2), opacity_range=(0.2, 0.5)):
    if np.random.random() > 0.5:
        return image
    B, C, H, W = image.shape
    num_shadows = np.random.randint(*num_range)
    for _ in range(num_shadows):
        #Random shadow as diagonal stripe
        angle = np.random.uniform(0, np.pi)
        opacity = np.random.uniform(*opacity_range)
        width = np.random.uniform(0.1, 0.3) * min(H, W)
        #Create shadow mask
        y, x = torch.meshgrid(torch.arange(H, device=image.device), torch.arange(W, device=image.device), indexing='ij')
        offset = np.random.uniform(0, H)
        dist = torch.abs(x * np.cos(angle) + y * np.sin(angle) - offset)
        shadow = (dist < width).float() * opacity
        #Apply only within hat
        shadow = shadow.unsqueeze(0).unsqueeze(0) * (alpha_mask > 0).float()
        image = image * (1 - shadow)
    return image.clamp(0, 1)

#Simulate printer color/brightness variation.
def simulate_printing(texture, mul_std=0.1, add_std=0.05):
    #Multiplicative noise
    mul_noise = torch.randn_like(texture) * mul_std + 1.0
    texture = texture * mul_noise
    #Additive noise
    add_noise = torch.randn_like(texture) * add_std
    texture = texture + add_noise
    return texture.clamp(0, 1)

class PrinterGamut:
    
    def __init__(self, config):
        self.config = config['printer']
        
        gamut_path = self.config.get('gamut_samples_path')
        if gamut_path and Path(gamut_path).exists():
            self.gamut_samples = torch.from_numpy(np.load(gamut_path)).float().to(device)
            self.use_measured_gamut = True
            logger.info(f"Loaded {len(self.gamut_samples)} gamut samples")
        else:
            self.use_measured_gamut = False
            logger.info("Using simplified gamut constraints")
    
    def nps_loss(self, texture):
        max_ch = texture.max(dim=1)[0]
        min_ch = texture.min(dim=1)[0]
        saturation = (max_ch - min_ch) / (max_ch + 1e-8)
        brightness = max_ch
        
        loss = 0.0
        
        #Saturation * brightness threshold
        loss = loss + F.relu(saturation * brightness - self.config['nps_threshold']).mean()
        
        #Saturation cap
        loss = loss + F.relu(saturation - self.config['max_saturation']).mean()
        
        #Brightness bounds
        loss = loss + F.relu(brightness - self.config['max_brightness']).mean()
        loss = loss + F.relu(self.config['min_brightness'] - brightness).mean()
        
        return loss
    
    def clamp_to_gamut(self, texture):
        return texture.clamp(self.config['min_brightness'], self.config['max_brightness'])


#Initialize globally
printer_gamut = PrinterGamut(CONFIG)

#Apply camera artifacts: blur, noise.
def apply_environmental_augmentation(image, prob=0.3):
    #Motion blur
    if np.random.random() < prob:
        kernel_size = int(np.random.choice([3, 5, 7]))  #Plain int for the torch calls below
        kernel = torch.zeros(kernel_size, kernel_size, device=image.device)
        kernel[kernel_size//2, :] = 1.0 / kernel_size
        #Simplified: horizontal blur only (no random kernel rotation yet)
        image = F.conv2d(image, kernel.view(1, 1, kernel_size, kernel_size).expand(3, 1, -1, -1), padding=kernel_size//2, groups=3)
    #Gaussian noise
    if np.random.random() < prob:
        noise_std = np.random.uniform(0.01, 0.05)
        image = image + torch.randn_like(image) * noise_std
    return image.clamp(0, 1)

def apply_viewpoint_jitter(elevation, azimuth, scale, config):
    elev = elevation + np.random.uniform(-config['camera_pitch_jitter'], config['camera_pitch_jitter'])
    elev = np.clip(elev, 0, 90)
    
    azim = azimuth + np.random.uniform(-config['heading_jitter'], config['heading_jitter'])
    azim = azim % 360
    
    scl = scale * (1 + np.random.uniform(-config['scale_jitter'], config['scale_jitter']))
    scl = np.clip(scl, 0.1, 2.0)
    
    return elev, azim, scl
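
The simplified `nps_loss` in `PrinterGamut` is easier to follow on single pixels. The sketch below is a per-pixel version in plain Python, using the placeholder printer values from `CONFIG` (to be replaced with the FabLab specs). For a single pixel, summing the four terms matches what `nps_loss` computes; on a full texture each term is averaged over all pixels first.

```python
# Placeholder printer limits copied from CONFIG['printer'].
PRINTER = {"max_saturation": 0.85, "max_brightness": 0.95,
           "min_brightness": 0.08, "nps_threshold": 0.7}

def relu(x):
    return max(x, 0.0)

def pixel_nps_penalty(r, g, b, printer=PRINTER):
    """Penalty for one RGB pixel with channels in [0, 1]."""
    max_ch, min_ch = max(r, g, b), min(r, g, b)
    saturation = (max_ch - min_ch) / (max_ch + 1e-8)
    brightness = max_ch
    return (relu(saturation * brightness - printer["nps_threshold"])
            + relu(saturation - printer["max_saturation"])
            + relu(brightness - printer["max_brightness"])
            + relu(printer["min_brightness"] - brightness))

print(round(pixel_nps_penalty(1.0, 0.0, 0.0), 3))     # pure red: penalized on three terms
print(round(pixel_nps_penalty(0.5, 0.5, 0.5), 3))     # mid gray: inside all limits
print(round(pixel_nps_penalty(0.02, 0.02, 0.02), 3))  # near black: below min_brightness
```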

## 10: Toroidal Cropping

Wrapping the texture

In [None]:
class ToroidalLatent(nn.Module):
    
    def __init__(self, local_size, crop_size=9, latent_channels=128, device='cuda'):
        super().__init__()
        self.local_size = local_size
        self.crop_size = crop_size
        self.latent_channels = latent_channels
        
        #Initialize local latent pattern as registered parameter
        self.z_local = nn.Parameter(
            torch.randn(1, latent_channels, local_size, local_size) * 0.1
        )
        
        #Move to device
        self.to(device)
        
    def get_random_crops(self, batch_size):
        #Tile 3x3 for wraparound
        z_tiled = self.z_local.repeat(1, 1, 3, 3)
        
        crops = []
        for _ in range(batch_size):
            #Random offset within middle tile (to enable wraparound)
            i = np.random.randint(self.local_size, 2 * self.local_size)
            j = np.random.randint(self.local_size, 2 * self.local_size)
            crop = z_tiled[:, :, i:i+self.crop_size, j:j+self.crop_size]
            crops.append(crop)
            
        return torch.cat(crops, dim=0)
    
    def get_full_latent(self, target_spatial_size):
        reps = (target_spatial_size + self.local_size - 1) // self.local_size + 1
        z_tiled = self.z_local.repeat(1, 1, reps, reps)
        return z_tiled[:, :, :target_spatial_size, :target_spatial_size]
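
The 3x3 tiling in `get_random_crops` is one way to get wraparound: a crop that runs off the edge of the pattern continues from the opposite edge. The one-dimensional sketch below shows that cropping from the tiled copy gives the same result as indexing modulo the pattern length. Both helper names are made up for this example.

```python
def toroidal_crop(pattern, start, crop_size):
    """Crop with explicit wraparound (index modulo the pattern length)."""
    n = len(pattern)
    return [pattern[(start + k) % n] for k in range(crop_size)]

def tiled_crop(pattern, start, crop_size):
    """Crop from a 3x tiled copy, starting inside the middle tile."""
    tiled = pattern * 3
    offset = len(pattern) + start
    return tiled[offset:offset + crop_size]

pattern = list(range(18))  # stands in for one row of the 18x18 local latent
print(toroidal_crop(pattern, 15, 9))  # [15, 16, 17, 0, 1, 2, 3, 4, 5]
print(tiled_crop(pattern, 15, 9))     # [15, 16, 17, 0, 1, 2, 3, 4, 5]
```

Since every crop sees the pattern as if it had no edges, the optimized latent should tile without visible seams.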

## 11: Scene Composition

Composite the rendered hat onto the scene

In [None]:
def composite_hat_on_scene(scene_image, hat_mask, rendered_hat, alpha_mask):

    B, C, H, W = scene_image.shape
    
    #For each image in batch, place hat at mask location
    composited = scene_image.clone()
    
    for i in range(B):
        #Find bounding box of hat mask
        mask = hat_mask[i, 0]
        if mask.sum() == 0:
            continue
            
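        ys, xs = torch.where(mask > 0.5)
        if ys.numel() == 0:
            continue
        #Half-open bounds: the +1 keeps the last mask row/column inside the box
        y1, y2 = ys.min().item(), ys.max().item() + 1
        x1, x2 = xs.min().item(), xs.max().item() + 1
        
        hat_h = y2 - y1
        hat_w = x2 - x1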
        
        #Resize rendered hat to fit
        hat_resized = F.interpolate(rendered_hat[i:i+1], size=(hat_h, hat_w), mode='bilinear', align_corners=False)
        alpha_resized = F.interpolate(alpha_mask[i:i+1], size=(hat_h, hat_w), mode='bilinear', align_corners=False)
        
        #Composite
        region = composited[i:i+1, :, y1:y2, x1:x2]
        composited[i:i+1, :, y1:y2, x1:x2] = (hat_resized * alpha_resized + region * (1 - alpha_resized))
        
    return composited
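
The two ingredients of the compositing step are the bounding box of the hat mask and the per-pixel alpha blend. The sketch below shows both on nested lists, using half-open bounds (the end index is one past the last mask pixel) so that a one-pixel-tall mask still has height 1. `mask_bbox` and `blend` are illustrative helpers, not part of the notebook.

```python
def mask_bbox(mask):
    """Return (y1, y2, x1, x2) with exclusive upper bounds, or None for an empty mask."""
    rows = [y for y, row in enumerate(mask) if any(v > 0.5 for v in row)]
    cols = [x for x in range(len(mask[0])) if any(row[x] > 0.5 for row in mask)]
    if not rows:
        return None
    return min(rows), max(rows) + 1, min(cols), max(cols) + 1

def blend(hat_value, scene_value, alpha):
    """Alpha-composite one channel value of the hat over the scene."""
    return hat_value * alpha + scene_value * (1 - alpha)

mask = [
    [0, 0, 0, 0, 0],
    [0, 0, 1, 1, 0],
    [0, 0, 1, 1, 0],
    [0, 0, 0, 0, 0],
]
print(mask_bbox(mask))        # (1, 3, 2, 4)
print(blend(1.0, 0.0, 0.25))  # 0.25
```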

## 12: Ensemble

Ensemble detection

In [None]:
class DetectorEnsemble:
    
    def __init__(self, attack_mode='gray', device='cuda', conf_floor=0.001):
        self.device = device
        self.conf_floor = conf_floor
        self.models = {}
        self.weights = {}
        
        if attack_mode == 'white':
            self.models['yolov8m'] = YOLO('yolov8m.pt')
            self.weights['yolov8m'] = 1.0
            
        elif attack_mode == 'gray':
            model_configs = [
                ('yolov8s', 0.20),
                ('yolov8m', 0.25),
                ('yolov8l', 0.20),
                ('yolov5m', 0.20),
                ('yolov5l', 0.15),
            ]
            for name, weight in model_configs:
                try:
                    self.models[name] = YOLO(f'{name}.pt')
                    self.weights[name] = weight
                    logger.info(f"Loaded {name}")
                except Exception as e:
                    logger.warning(f"Failed to load {name}: {e}")
                    
        elif attack_mode == 'black':
            #Add More
            model_configs = [
                ('yolov8m', 0.30),
                ('yolov8l', 0.25),
                ('yolov5l', 0.25),
                ('yolov5m', 0.20),
            ]
            for name, weight in model_configs:
                try:
                    self.models[name] = YOLO(f'{name}.pt')
                    self.weights[name] = weight
                except Exception as e:
                    logger.warning(f"Failed to load {name}: {e}")
                    
        #Normalize weights
        total = sum(self.weights.values())
        self.weights = {k: v/total for k, v in self.weights.items()}
        
        logger.info(f"Detector ensemble ({attack_mode}): {list(self.weights.keys())}")
        
    def compute_loss(self, images, return_detections=False):
        total_loss = 0.0
        all_detections = [] if return_detections else None
        
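        #Convert to a list of uint8 HWC arrays for YOLO (Ultralytics reads NumPy input as BGR)
        images_np = (images.detach().clamp(0, 1) * 255).byte().permute(0, 2, 3, 1).cpu().numpy()
        images_np = [np.ascontiguousarray(img[..., ::-1]) for img in images_np]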
        
        for name, model in self.models.items():
            weight = self.weights[name]
            
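            #Low confidence floor so weak detections still count.
            #Note: predict() runs without autograd, so these confidences carry no gradient back to the texture.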
            results = model.predict(images_np, conf=self.conf_floor, classes=[0], verbose=False)
            
            #Collect person detection confidences
            batch_confs = []
            for r in results:
                if len(r.boxes) > 0:
                    confs = r.boxes.conf.to(self.device)
                    batch_confs.append(confs)
                    
            if batch_confs:
                #Loss = mean of the top-k confidences pooled across the whole batch
                all_confs = torch.cat(batch_confs)
                k = min(10, len(all_confs))
                top_confs, _ = torch.topk(all_confs, k)
                loss = top_confs.mean()
                total_loss = total_loss + weight * loss
                
            if return_detections:
                all_detections.append({
                    'model': name,
                    'results': results
                })
                
        if return_detections:
            return total_loss, all_detections
        return total_loss
    
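    def detect(self, images, conf_threshold=0.5):
        #Convert to a list of uint8 HWC BGR arrays (assumes RGB input in [0, 1])
        images_np = (images.detach().clamp(0, 1) * 255).byte().permute(0, 2, 3, 1).cpu().numpy()
        images_np = [np.ascontiguousarray(img[..., ::-1]) for img in images_np]
        
        all_results = {}
        for name, model in self.models.items():
            #Keep only detections at or above the requested threshold
            results = model.predict(images_np, conf=conf_threshold, classes=[0], verbose=False)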
            all_results[name] = []
            for r in results:
                all_results[name].append({
                    'boxes': r.boxes.xyxy.cpu().numpy() if len(r.boxes) > 0 else np.array([]),
                    'scores': r.boxes.conf.cpu().numpy() if len(r.boxes) > 0 else np.array([]),
                })
        return all_results


#Initialize detector ensemble
detector = DetectorEnsemble(CONFIG['attack_mode'], device, conf_floor=CONFIG['det_conf_floor'])
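
The weight normalization in `DetectorEnsemble` matters most when a checkpoint fails to load: the surviving weights no longer sum to 1 and get rescaled. A minimal sketch of that arithmetic in plain Python, using the gray-box weights from the class and assuming, purely for illustration, that `yolov5l` could not be loaded. `normalize_weights` is a hypothetical helper, not part of the notebook.

```python
# Gray-box ensemble weights as configured in DetectorEnsemble
configured = {
    'yolov8s': 0.20,
    'yolov8m': 0.25,
    'yolov8l': 0.20,
    'yolov5m': 0.20,
    'yolov5l': 0.15,
}


def normalize_weights(weights):
    """Rescale weights so they sum to 1 (hypothetical helper)."""
    total = sum(weights.values())
    if total == 0:
        raise ValueError("no models loaded, nothing to normalize")
    return {name: w / total for name, w in weights.items()}


# Illustrative failure: drop yolov5l as if its checkpoint was missing
loaded = {name: w for name, w in configured.items() if name != 'yolov5l'}
normalized = normalize_weights(loaded)

for name, w in normalized.items():
    print(f"{name}: {w:.4f}")
```

Each remaining model keeps its share relative to the others, so a silent load failure shifts influence toward the models that did load rather than shrinking the total loss.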

## 13: Loss Calculation

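Custom loss terms: total variation for smooth textures, a non-printability score (NPS) for printable colors, and the combined objective.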
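
To make the NPS penalty concrete, here is a per-pixel sketch in plain Python of the same formula used by `nps_loss` in the cell below. Since saturation is `(max - min) / max` and brightness is `max`, their product is approximately `max - min`, so the penalty only fires for colors whose channel spread exceeds the threshold. The sample colors and the helper name `nps_penalty` are illustrative assumptions.

```python
def nps_penalty(rgb, threshold=0.7):
    """Per-pixel NPS penalty: relu(saturation * brightness - threshold)."""
    max_ch, min_ch = max(rgb), min(rgb)
    saturation = (max_ch - min_ch) / (max_ch + 1e-8)
    brightness = max_ch
    return max(0.0, saturation * brightness - threshold)


# Illustrative colors in [0, 1] RGB
samples = {
    'pure red': (1.0, 0.0, 0.0),
    'mid gray': (0.5, 0.5, 0.5),
    'muted teal': (0.2, 0.6, 0.6),
}

for name, rgb in samples.items():
    print(f"{name}: penalty = {nps_penalty(rgb):.3f}")
```

Only the fully saturated red is penalized here; gray and muted colors fall under the threshold and contribute nothing to the loss.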

In [None]:
#Try to get smooth textures
def total_variation_loss(texture):
    diff_h = texture[:, :, 1:, :] - texture[:, :, :-1, :]
    diff_w = texture[:, :, :, 1:] - texture[:, :, :, :-1]
    return (diff_h.pow(2).mean() + diff_w.pow(2).mean()) / 2

#Try to get printable colors
def nps_loss(texture, threshold=0.7):
    #Compute saturation and brightness
    max_ch = texture.max(dim=1)[0]
    min_ch = texture.min(dim=1)[0]
    saturation = (max_ch - min_ch) / (max_ch + 1e-8)
    brightness = max_ch
    
    #Penalize when saturation * brightness > threshold
    penalty = F.relu(saturation * brightness - threshold)
    return penalty.mean()

#Everything together
def compute_total_loss(texture, detector, config, stage='stage2'):
    loss_det = detector.compute_loss(texture)
    loss_tv = total_variation_loss(texture)
    loss_nps = printer_gamut.nps_loss(texture)  # Use new class
    
    total = loss_det + config['lambda_tv'] * loss_tv + config['lambda_nps'] * loss_nps
    
    return total, {
        'total': total.item() if isinstance(total, torch.Tensor) else total,
        'detection': loss_det.item() if isinstance(loss_det, torch.Tensor) else loss_det,
        'tv': loss_tv.item(),
        'nps': loss_nps.item(),
    }

## 14: Stage 1: Generator Training

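Train the generator and the auxiliary network against the detector ensemble, rendering and augmenting the hat for every batch.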
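
Stage 1 caps the gradient norm of both networks at 1.0 with `torch.nn.utils.clip_grad_norm_`. The rescaling it performs can be sketched in plain Python as follows. `clip_by_global_norm` is a hypothetical helper operating on a flat list of values, and the small `eps` in the denominator is an assumption modeled on PyTorch's implementation.

```python
import math


def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Scale gradient values so their combined L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    # Never scale up: gradients already inside the limit are left unchanged
    scale = min(1.0, max_norm / (total_norm + eps))
    return [g * scale for g in grads], total_norm


clipped, norm_before = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
norm_after = math.sqrt(sum(g * g for g in clipped))

print(f"norm before: {norm_before:.4f}")
print(f"clipped:     {[round(g, 4) for g in clipped]}")
print(f"norm after:  {norm_after:.4f}")
```

The direction of the gradient is preserved; only its length changes, which keeps a single noisy batch from producing an oversized update.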

In [None]:
def train_stage1(generator, aux_net, detector, dataset, config):

    generator.train()
    aux_net.train()
    
    #Optimizers
    opt_g = torch.optim.Adam(generator.parameters(), lr=config['stage1_lr'], betas=(0.5, 0.999))
    opt_aux = torch.optim.Adam(aux_net.parameters(), lr=config['stage1_lr'], betas=(0.5, 0.999))
    
    #DataLoader (persistent_workers is only valid when num_workers > 0)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=config['stage1_batch_size'], shuffle=True,
        num_workers=config['num_workers'], pin_memory=True,
        persistent_workers=config['num_workers'] > 0
    )
    
    #Training loop
    logger.info("Starting Stage 1 training...")
    
    for epoch in range(config['stage1_epochs']):
        epoch_losses = {'total': 0, 'detection': 0, 'tv': 0, 'nps': 0, 'mi': 0}
        
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{config['stage1_epochs']}")
        for batch_idx, batch in enumerate(pbar):
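            #Sample latent (match the actual batch size; the last batch can be smaller)
            batch_size = batch['elevation'].shape[0]
            z = torch.randn(batch_size, config['latent_channels'], 9, 9, device=device)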
            
            #Generate texture
            texture = generator.generate(z)  # [0, 1]
            
            #Sample viewpoint and render
            rendered_hats = []
            alphas = []
            for i in range(texture.shape[0]):
                elev, azim, scl = apply_viewpoint_jitter(
                    batch['elevation'][i].item(),
                    batch['azimuth'][i].item(),
                    batch['scale'][i].item(),
                    config
                )
                rh, al = renderer.render(texture[i:i+1], elev, azim, scl)
                rendered_hats.append(rh)
                alphas.append(al)
            
            rendered_hat = torch.cat(rendered_hats, dim=0)
            alpha = torch.cat(alphas, dim=0)
            
            #Apply T-SEA augmentations
            rendered_hat = patch_cutout(rendered_hat, alpha, config['cutout_prob'], config['cutout_ratio'])
            
            #Apply URAdv augmentations
            rendered_hat = add_light_spots(rendered_hat, alpha)
            rendered_hat = add_shadows(rendered_hat, alpha)
            rendered_hat = simulate_printing(rendered_hat)
            
            #Composite onto scene (using placeholder or real data)
            if not dataset.use_placeholder:
                scene = batch['image'].to(device)
                hat_mask = batch['hat_mask'].to(device)
                composite = composite_hat_on_scene(scene, hat_mask, rendered_hat, alpha)
            else:
                #For placeholder, just use rendered hat directly
                composite = rendered_hat
                
            #Apply environmental augmentations
            composite = constrained_augmentation(composite)
            composite = apply_environmental_augmentation(composite)

            #ShakeDrop: perturb composite to create virtual ensemble variants
            composite = shakedrop_forward(composite, drop_prob=config['shakedrop_prob'])
            
            #Compute losses
            #Detection loss
            loss_det = detector.compute_loss(composite)
            
            #Regularization losses
            loss_tv = total_variation_loss(texture)
            loss_nps = nps_loss(texture, config['printer']['nps_threshold'])
            
            #Mutual information (maximize = negate for min)
            loss_mi = compute_mi_loss(aux_net, texture, z)
            
            #Total loss
            loss = (loss_det + 
                   config['lambda_tv'] * loss_tv + 
                   config['lambda_nps'] * loss_nps + 
                   config['lambda_info'] * loss_mi)  # loss_mi is already negated
            
            #Optimize
            opt_g.zero_grad()
            opt_aux.zero_grad()
            
            if isinstance(loss, torch.Tensor) and loss.requires_grad:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(generator.parameters(), 1.0)
                torch.nn.utils.clip_grad_norm_(aux_net.parameters(), 1.0)
                opt_g.step()
                opt_aux.step()
            
            #Track losses
            epoch_losses['total'] += loss.item() if isinstance(loss, torch.Tensor) else loss
            epoch_losses['detection'] += loss_det.item() if isinstance(loss_det, torch.Tensor) else loss_det
            epoch_losses['tv'] += loss_tv.item()
            epoch_losses['nps'] += loss_nps.item()
            epoch_losses['mi'] += loss_mi.item()
            
            pbar.set_postfix({
                'loss': f"{loss.item():.4f}" if isinstance(loss, torch.Tensor) else f"{loss:.4f}",
                'det': f"{loss_det.item():.4f}" if isinstance(loss_det, torch.Tensor) else f"{loss_det:.4f}"
            })
            
        #Epoch summary
        n_batches = len(dataloader)
        for k in epoch_losses:
            epoch_losses[k] /= n_batches
            
        logger.info(f"Epoch {epoch+1} - Loss: {epoch_losses['total']:.4f}, "
                   f"Det: {epoch_losses['detection']:.4f}, "
                   f"TV: {epoch_losses['tv']:.4f}, "
                   f"MI: {epoch_losses['mi']:.4f}")
        
        #Save checkpoint
        if (epoch + 1) % 10 == 0:
            torch.save({
                'generator': generator.state_dict(),
                'aux_net': aux_net.state_dict(),
                'epoch': epoch,
            }, f"{config['output_dir']}/stage1_epoch{epoch+1}.pth")
            
            #Save sample texture
            with torch.no_grad():
                sample = generator.generate(batch_size=1)
                save_texture(sample[0], f"{config['output_dir']}/texture_epoch{epoch+1}.png")
                
    return generator


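def save_texture(texture, path):
    """Save a [0, 1] RGB texture (CHW or 1xCHW tensor) as a preview image."""
    if isinstance(texture, torch.Tensor):
        texture = texture.detach().cpu()
        if texture.dim() == 4:
            texture = texture[0]  #Drop the batch dimension
        texture = texture.permute(1, 2, 0).numpy()
        #Clip before casting so out-of-range values don't wrap around in uint8
        texture = np.ascontiguousarray((texture.clip(0, 1) * 255).astype(np.uint8))
    cv2.imwrite(str(path), cv2.cvtColor(texture, cv2.COLOR_RGB2BGR))
    logger.info(f"Saved texture to {path}")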


def save_final_texture(texture, config, path):
    """Save print-ready texture at correct DPI."""
    from PIL import Image
    
    output_size = config['texture_output_size']
    
    # Resize to print resolution
    texture_highres = F.interpolate(
        texture, 
        size=(output_size, output_size), 
        mode='bilinear', 
        align_corners=False
    )
    
    # Clamp to printable gamut
    texture_highres = printer_gamut.clamp_to_gamut(texture_highres)
    
    # Convert to uint8
    texture_np = texture_highres[0].permute(1, 2, 0).detach().cpu().numpy()
    texture_np = (texture_np * 255).astype(np.uint8)
    
    img = Image.fromarray(texture_np)
    
    # Save TIFF with DPI metadata (for print)
    tiff_path = Path(path).with_suffix('.tiff')
    img.save(tiff_path, dpi=(config['printer']['dpi'], config['printer']['dpi']))
    
    # Save PNG for preview
    img.save(path)
    
    logger.info(f"Saved: {path} and {tiff_path}")
    logger.info(f"  {output_size}x{output_size}px at {config['printer']['dpi']} DPI")
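
The DPI metadata written by `save_final_texture` determines how large the texture prints. A quick sketch of that conversion, useful for checking that the output covers the hat before sending it to the printer. The pixel and DPI values below are hypothetical stand-ins for `config['texture_output_size']` and `config['printer']['dpi']`, and `print_size` is an illustrative helper.

```python
def print_size(pixels, dpi):
    """Return the printed edge length as (inches, centimeters)."""
    inches = pixels / dpi
    return inches, inches * 2.54


# Hypothetical combinations of texture size and printer resolution
for pixels, dpi in [(1024, 300), (2048, 300), (4096, 600)]:
    inches, cm = print_size(pixels, dpi)
    print(f"{pixels}px at {dpi} DPI -> {inches:.2f} in ({cm:.1f} cm) per side")
```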

## 15: Stage 2: Latent Optimization

Freeze the generator and optimize only the toroidal latent.
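
Stage 2 decays the learning rate with `CosineAnnealingLR` down to `eta_min=1e-4`. The closed-form schedule can be sketched in plain Python as below. The base learning rate and iteration count are hypothetical values standing in for `config['stage2_lr']` and `config['stage2_iterations']`.

```python
import math


def cosine_annealing_lr(step, base_lr, eta_min, t_max):
    """Closed-form cosine annealing from base_lr (step 0) to eta_min (step t_max)."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max))


# Hypothetical settings for illustration
base_lr, eta_min, t_max = 0.01, 1e-4, 1000

for step in (0, 250, 500, 750, 1000):
    lr = cosine_annealing_lr(step, base_lr, eta_min, t_max)
    print(f"step {step:4d}: lr = {lr:.6f}")
```

Note that the schedule only decays if the starting rate is above `eta_min`; with a starting rate below `1e-4` the same formula would raise the learning rate over time.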

In [None]:
def train_stage2(generator, detector, dataset, config, z_local=None):

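    generator.eval()  #Inference behavior for dropout/batch-norm layers
    for p in generator.parameters():
        p.requires_grad_(False)  #Freeze weights; only the latent is optimized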
    
    #Initialize toroidal latent
    toroidal = ToroidalLatent(
        local_size=config['local_latent_size'],
        crop_size=9,
        latent_channels=128,
        device=device
    )
    
    if z_local is not None:
        toroidal.z_local.data = z_local
        
    #Optimizer for latent only
    optimizer = torch.optim.Adam(toroidal.parameters(), lr=config['stage2_lr'])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=config['stage2_iterations'], eta_min=1e-4
    )
    
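    #DataLoader (persistent_workers is only valid when num_workers > 0)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=config['stage1_batch_size'], shuffle=True,
        num_workers=config['num_workers'], pin_memory=True,
        persistent_workers=config['num_workers'] > 0
    )
    data_iter = iter(dataloader)
    
    logger.info("Starting Stage 2 latent optimization...")
    best_loss = float('inf')
    best_z_local = toroidal.z_local.data.clone()  #Fallback in case no iteration improves (e.g. NaN losses)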
    
    pbar = tqdm(range(config['stage2_iterations']), desc="Stage 2")
    for iteration in pbar:
        #Get batch (cycle through dataset)
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(dataloader)
            batch = next(data_iter)
            
        #Get random crops from toroidal latent (one per sample; the last batch can be smaller)
        batch_size = batch['elevation'].shape[0]
        z_crops = toroidal.get_random_crops(batch_size)
            
        #We need gradients through generator for z
        texture = generator.generate(z_crops)
        
        #Sample viewpoints and render
        rendered_hats = []
        alphas = []
        for i in range(batch_size):
            elev, azim, scl = apply_viewpoint_jitter(
                batch['elevation'][i].item(),
                batch['azimuth'][i].item(),
                batch['scale'][i].item(),
                config
            )
            rh, al = renderer.render(texture[i:i+1], elev, azim, scl)
            rendered_hats.append(rh)
            alphas.append(al)
        rendered_hat = torch.cat(rendered_hats, dim=0)
        alpha = torch.cat(alphas, dim=0)
        
        #Apply augmentations
        rendered_hat = patch_cutout(rendered_hat, alpha, config['cutout_prob'], config['cutout_ratio'])
        rendered_hat = add_light_spots(rendered_hat, alpha)
        rendered_hat = simulate_printing(rendered_hat)
        
        #Composite (placeholder mode)
        composite = constrained_augmentation(rendered_hat)
        composite = apply_environmental_augmentation(composite)

        #ShakeDrop: virtual ensemble
        composite = shakedrop_forward(composite, drop_prob=config['shakedrop_prob'])
        
        #Compute loss (no MI term in Stage 2)
        loss_det = detector.compute_loss(composite)
        loss_tv = total_variation_loss(texture)
        loss_nps = nps_loss(texture, config['printer']['nps_threshold'])
        
        loss = (loss_det + 
               config['lambda_tv'] * loss_tv + 
               config['lambda_nps'] * loss_nps)
        
        #Track best (snapshot before the update, so the saved latent is the one that produced this loss)
        loss_val = loss.item() if isinstance(loss, torch.Tensor) else loss
        if loss_val < best_loss:
            best_loss = loss_val
            best_z_local = toroidal.z_local.data.clone()
        
        #Optimize
        optimizer.zero_grad()
        if isinstance(loss, torch.Tensor) and loss.requires_grad:
            loss.backward()
            optimizer.step()
        scheduler.step()
            
        pbar.set_postfix({
            'loss': f"{loss_val:.4f}",
            'best': f"{best_loss:.4f}",
            'lr': f"{scheduler.get_last_lr()[0]:.6f}"
        })
        
        #Periodic logging
        if (iteration + 1) % 200 == 0:
            with torch.no_grad():
                # Generate final texture at full resolution
                z_full = toroidal.get_full_latent(config['latent_size'])
                final_texture = generator.generate(z_full)
                save_texture(final_texture[0], 
                           f"{config['output_dir']}/texture_stage2_iter{iteration+1}.png")
                
    #Save final results
    logger.info(f"Stage 2 complete. Best loss: {best_loss:.4f}")
    
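    #Generate final texture from the best latent, so it matches the z_local that is saved and returned
    with torch.no_grad():
        toroidal.z_local.data.copy_(best_z_local)
        z_full = toroidal.get_full_latent(config['latent_size'])
        final_texture = generator.generate(z_full)
        
        #Tile (not resize) the texture to a 1024x1024 canvas; save_final_texture then resamples it to the print resolution
        def tile_texture(tex, target):
            _, _, h, w = tex.shape
            reps_h = (target + h - 1) // h
            reps_w = (target + w - 1) // w
            tiled = tex.repeat(1, 1, reps_h, reps_w)
            return tiled[:, :, :target, :target]
        
        final_texture = tile_texture(final_texture, 1024)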
        
    save_final_texture(final_texture, config, f"{config['output_dir']}/final_texture.png")
    
    torch.save({
        'z_local': best_z_local,
        'generator': generator.state_dict(),
    }, f"{config['output_dir']}/stage2_final.pth")
    
    return best_z_local, final_texture

## 16: Evaluation

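Measure the attack success rate (ASR): the fraction of (model, image) pairs in which no detection overlaps the ground-truth person box.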
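
A standalone sketch of the IoU test and success criterion behind the ASR, in plain Python so it can be checked by hand. The boxes and scores are made-up values, and `iou_xyxy` and `attack_succeeded` are hypothetical helpers that mirror the logic of `compute_asr` with the confidence filter applied explicitly.

```python
def iou_xyxy(box_a, box_b):
    """IoU of two boxes given as (x1, y1, x2, y2)."""
    x1 = max(box_a[0], box_b[0])
    y1 = max(box_a[1], box_b[1])
    x2 = min(box_a[2], box_b[2])
    y2 = min(box_a[3], box_b[3])
    inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def attack_succeeded(detections, gt_box, conf_threshold=0.5, iou_threshold=0.5):
    """True if no confident detection overlaps the ground-truth box."""
    return not any(
        score >= conf_threshold and iou_xyxy(box, gt_box) >= iou_threshold
        for box, score in detections
    )


# Made-up example: one weak detection on the person, one confident detection elsewhere
gt = (10, 10, 50, 90)
detections = [((12, 8, 52, 88), 0.35), ((200, 200, 240, 260), 0.80)]

for thresh in (0.3, 0.5):
    print(f"conf >= {thresh}: attack succeeded = {attack_succeeded(detections, gt, thresh)}")
```

The same detections count as a success at a threshold of 0.5 and a failure at 0.3, which is why the evaluation reports ASR at several thresholds.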

In [None]:
def compute_asr(detector, images, gt_boxes, conf_threshold=0.5, iou_threshold=0.5):

    results = detector.detect(images, conf_threshold)
    
    n_success = 0
    for model_name, model_results in results.items():
        for i, det in enumerate(model_results):
            gt_box = gt_boxes[i].cpu().numpy()
            
            detected = False
            for box, score in zip(det['boxes'], det['scores']):
                #Ignore detections below the confidence threshold
                if score < conf_threshold:
                    continue
                #Compute IoU
                x1 = max(box[0], gt_box[0])
                y1 = max(box[1], gt_box[1])
                x2 = min(box[2], gt_box[2])
                y2 = min(box[3], gt_box[3])
                
                inter = max(0, x2-x1) * max(0, y2-y1)
                area1 = (box[2]-box[0]) * (box[3]-box[1])
                area2 = (gt_box[2]-gt_box[0]) * (gt_box[3]-gt_box[1])
                iou = inter / (area1 + area2 - inter + 1e-8)
                
                if iou >= iou_threshold:
                    detected = True
                    break
                    
            if not detected:
                n_success += 1
                
    #Average across models
    total = len(results) * len(gt_boxes)
    return n_success / total if total > 0 else 0.0


def evaluate_texture(generator, detector, dataset, z_local, config, num_samples=100):
    generator.eval()
    
    #Create toroidal latent with optimized pattern
    toroidal = ToroidalLatent(
        local_size=config['local_latent_size'],
        crop_size=9,
        device=device
    )
    toroidal.z_local.data = z_local
    
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=8, shuffle=False,
        num_workers=config['num_workers'], pin_memory=True
    )
    
    asr_scores = {thresh: [] for thresh in [0.1, 0.3, 0.5, 0.7, 0.9]}
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
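            #Generate texture (one crop per sample; the last batch can be smaller)
            batch_size = batch['elevation'].shape[0]
            z_crops = toroidal.get_random_crops(batch_size)
            texture = generator.generate(z_crops)
            
            #Render and composite
            rendered_hat, alpha = renderer.render(texture, 85, 0, 1.0)
            
            #For placeholder, use rendered directly
            composite = rendered_hat
            
            #Test at multiple thresholds
            #Note: the all-zero fallback boxes have zero IoU with every detection,
            #so ASR is trivially 100% unless real 'person_bbox' annotations are present
            gt_boxes = batch.get('person_bbox', torch.zeros(batch_size, 4))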
            for thresh in asr_scores.keys():
                asr = compute_asr(detector, composite, gt_boxes, conf_threshold=thresh)
                asr_scores[thresh].append(asr)
                
    #Compute mean ASR
    mean_asr = {}
    for thresh, scores in asr_scores.items():
        mean_asr[thresh] = np.mean(scores)
        logger.info(f"ASR@{thresh}: {mean_asr[thresh]:.2%}")
        
    overall_masr = np.mean(list(mean_asr.values()))
    logger.info(f"Mean ASR: {overall_masr:.2%}")
    
    return mean_asr

## 17: Main

Run both training stages, evaluate the result, and plot a summary.

In [None]:
def main():
    #Initialize models
    logger.info("Initializing models...")
    
    generator = FCNGenerator(latent_channels=CONFIG['latent_channels']).to(device)
    aux_net = AuxiliaryNetwork().to(device)
    
    logger.info(f"Generator params: {sum(p.numel() for p in generator.parameters()):,}")
    logger.info(f"Auxiliary params: {sum(p.numel() for p in aux_net.parameters()):,}")
    
    #Stage 1: Train generator
    logger.info("="*50)
    logger.info("STAGE 1: Generator Training")
    logger.info("="*50)
    
    generator = train_stage1(generator, aux_net, detector, dataset, CONFIG)
    
    #Stage 2: Optimize latent
    logger.info("="*50)
    logger.info("STAGE 2: Latent Optimization")
    logger.info("="*50)
    
    best_z_local, final_texture = train_stage2(generator, detector, dataset, CONFIG)
    
    #Evaluation
    logger.info("="*50)
    logger.info("EVALUATION")
    logger.info("="*50)
    
    asr_results = evaluate_texture(generator, detector, dataset, best_z_local, CONFIG)
    
    #Final visualization
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    # Final texture
    axes[0].imshow(final_texture[0].permute(1, 2, 0).cpu())
    axes[0].set_title('Final Adversarial Texture')
    axes[0].axis('off')
    
    #Rendered examples
    with torch.no_grad():
        rendered, alpha = renderer.render(final_texture, 85, 45, 1.0)
    axes[1].imshow(rendered[0].permute(1, 2, 0).cpu())
    axes[1].set_title('Rendered Hat (85°, 45°)')
    axes[1].axis('off')
    
    #ASR plot
    thresholds = list(asr_results.keys())
    values = [asr_results[t] for t in thresholds]
    axes[2].bar(range(len(thresholds)), values)
    axes[2].set_xticks(range(len(thresholds)))
    axes[2].set_xticklabels([f'{t}' for t in thresholds])
    axes[2].set_xlabel('Confidence Threshold')
    axes[2].set_ylabel('Attack Success Rate')
    axes[2].set_title('ASR vs Threshold')
    
    plt.tight_layout()
    plt.savefig(f"{CONFIG['output_dir']}/final_results.png", dpi=150)
    plt.show()
    
    logger.info("Training complete!")
    logger.info(f"Final texture saved to: {CONFIG['output_dir']}/final_texture.png")
    
    return generator, best_z_local, final_texture


#Run if this is the main notebook
if __name__ == "__main__":
    generator, z_local, texture = main()