In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import os
import sys

# ==============================================================================
# 0. Configuration & Constants
# ==============================================================================
class Config:
    """Project-wide constants for the training/export pipeline.

    All values are plain class attributes so they can be read without
    instantiating the class.
    """

    # Identification (shown in the start-up banner)
    PROJECT_NAME = "FPGA_CNN_ACCELERATOR"
    VERSION = "1.0.0"

    # Filesystem locations: dataset directory and directory for exported files
    DATA_DIR = "./data"
    EXPORT_DIR = "./export"

    # Training hyper-parameters
    BATCH_SIZE = 64
    LEARNING_RATE = 0.001
    EPOCHS = 1  # Kept low for demo speed

    # Hardware constraints: image geometry in pixels, data/weight widths in bits
    INPUT_WIDTH = 28
    INPUT_HEIGHT = 28
    DATA_BIT_WIDTH = 8
    WEIGHT_BIT_WIDTH = 8

    @staticmethod
    def banner():
        """Print the project name, version and target between two 80-column rules."""
        rule = "=" * 80
        print("\n" + rule)
        print("  " + Config.PROJECT_NAME + " - v" + Config.VERSION)
        print("  Target: Xilinx Artix-7 FPGA (8-bit Quantized)")
        print(rule + "\n")

# ==============================================================================
# 1. Hardware-Aware Model
# ==============================================================================
class HardwareCNN(nn.Module):
    """Single-stage feature extractor: 3x3 convolution -> ReLU -> 2x2 max-pool.

    The convolution maps 1 input channel to 8 output channels with no padding
    and no bias term, so the only learnable tensor is ``conv.weight``.
    """

    def __init__(self):
        super().__init__()
        # Matches FPGA architecture: 1 Input -> 8 Filters -> ReLU -> MaxPool
        self.conv = nn.Conv2d(
            in_channels=1,
            out_channels=8,
            kernel_size=3,
            stride=1,
            padding=0,
            bias=False,
        )
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Apply conv, ReLU and max-pool in sequence and return the result."""
        return self.pool(self.relu(self.conv(x)))

# ==============================================================================
# 2. Data Pipeline
# ==============================================================================
class DataEngine:
    """Builds the MNIST datasets/loaders used by the pipeline."""

    @staticmethod
    def hardware_transform(x):
        """
        Map a [0.0, 1.0] float input onto roughly [-1.0, 1.0).

        The value is first expressed on the 0..255 scale, shifted down by 128
        (mirroring a signed 8-bit range of -128 to 127) and then divided by
        128 to bring it back to a unit-sized range.
        """
        shifted = (x * 255.0) - 128.0
        return shifted / 128.0

    @staticmethod
    def get_loaders():
        """
        Return ``(train_loader, test_data)``.

        ``train_loader`` is a shuffled DataLoader over the MNIST training
        split with ``hardware_transform`` applied after ``ToTensor``.
        ``test_data`` is the MNIST test split as a dataset (not a loader)
        with only ``ToTensor`` applied.
        """
        train_pipeline = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(DataEngine.hardware_transform),
        ])
        test_pipeline = transforms.ToTensor()

        train_data = datasets.MNIST(
            Config.DATA_DIR, train=True, download=True, transform=train_pipeline
        )
        test_data = datasets.MNIST(
            Config.DATA_DIR, train=False, download=True, transform=test_pipeline
        )

        train_loader = torch.utils.data.DataLoader(
            train_data, batch_size=Config.BATCH_SIZE, shuffle=True
        )
        return train_loader, test_data

# ==============================================================================
# 3. Export Engine
# ==============================================================================
class Exporter:
    """Writes tensors as text files into a single output directory.

    The directory is created on construction if it does not exist.
    """

    def __init__(self, output_dir):
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)

    @staticmethod
    def _flatten_to_ints(tensor):
        """Return the tensor's elements as a flat list of Python ints.

        ``reshape`` (rather than ``view``) is used so non-contiguous tensors,
        e.g. the result of a transpose, are accepted. ``detach().cpu()`` is a
        no-op for plain CPU tensors and lets other tensors be converted too.
        Float values are truncated toward zero by ``.int()``.
        """
        return tensor.detach().cpu().reshape(-1).int().tolist()

    def save_decimal(self, filename, tensor, comment=""):
        """Write every element of ``tensor`` as a decimal integer, one per line.

        Args:
            filename: File name relative to the output directory.
            tensor: Tensor to export; it is flattened in row-major order.
            comment: Free text appended to the console status line.
        """
        path = os.path.join(self.output_dir, filename)
        values = self._flatten_to_ints(tensor)
        with open(path, "w") as f:
            f.writelines(f"{val}\n" for val in values)
        print(f"  [EXPORT] Wrote {len(values)} values to {path} {comment}")

    def save_verilog_hex(self, filename, weights_tensor, weights_per_line=9):
        """
        Packs weights into hex lines for Verilog `readmemh`.

        With the default of 9 weights per line, each line holds one 3x3
        filter = 72 bits (9 weights * 8 bits). Every weight is emitted as two
        upper-case hex digits holding its 8-bit two's-complement encoding.
        If the number of weights is not a multiple of ``weights_per_line``
        the final line is shorter.

        Args:
            filename: File name relative to the output directory.
            weights_tensor: Tensor of integer-valued weights.
            weights_per_line: How many weights are packed into each line.

        Raises:
            ValueError: If ``weights_per_line`` is smaller than 1.
        """
        if weights_per_line < 1:
            raise ValueError("weights_per_line must be at least 1")

        path = os.path.join(self.output_dir, filename)
        values = self._flatten_to_ints(weights_tensor)

        with open(path, "w") as f:
            for start in range(0, len(values), weights_per_line):
                group = values[start:start + weights_per_line]
                # Masking with 0xFF yields the 8-bit two's-complement byte for
                # negative and non-negative values alike.
                f.write("".join(f"{val & 0xFF:02X}" for val in group) + "\n")

        print(f"  [EXPORT] Wrote Verilog ROM file to {path} (Packed Hex)")

# ==============================================================================
# 4. Main Execution Pipeline
# ==============================================================================
def main():
    """Run the full pipeline: train, quantize, build golden data, export, report.

    Steps:
      1. Build the model, data loader, temporary classifier head and optimizer.
      2. Train briefly (each epoch is cut short for demo purposes).
      3. Quantize the convolution weights to integers in [-127, 127].
      4. Run one test image through the quantized model to obtain the golden
         reference output.
      5. Export image, weights and golden output, then print the dashboard.

    Raises:
        RuntimeError: If every convolution weight is zero, in which case no
            quantization scale can be derived.
    """
    Config.banner()

    # --- Step 1: Setup ---
    print("[1/5] Initializing Environment...")
    device = torch.device("cpu") # CPU is fine for this size
    model = HardwareCNN().to(device)
    train_loader, test_dataset = DataEngine.get_loaders()

    # Temporary FC layer for training convergence. It only exists so a
    # classification loss can be computed; it is not exported.
    fc_temp = nn.Linear(8 * 13 * 13, 10).to(device)

    # The optimizer must own the head's parameters as well as the model's;
    # otherwise the head stays at its random initialisation and the conv
    # layer is trained against a frozen random classifier.
    trainable_params = list(model.parameters()) + list(fc_temp.parameters())
    optimizer = optim.Adam(trainable_params, lr=Config.LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()

    # --- Step 2: Training ---
    print(f"\n[2/5] Training Model (Epochs: {Config.EPOCHS})...")
    model.train()

    for _epoch in range(Config.EPOCHS):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            features = model(data)
            output = fc_temp(features.view(features.size(0), -1))
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            # Simple progress indicator
            if batch_idx % 50 == 0:
                print(f"  Batch {batch_idx}/{len(train_loader)} | Loss: {loss.item():.4f}")
            if batch_idx > 150:
                break # Quick exit for demo purposes

    # --- Step 3: Quantization ---
    print(f"\n[3/5] Quantizing Weights (Float32 -> Int8)...")
    float_weights = model.conv.weight.data
    max_val = torch.max(torch.abs(float_weights))
    if max_val == 0:
        # Dividing by zero would give an infinite scale and NaN weights.
        raise RuntimeError("Cannot quantize: all convolution weights are zero.")
    # The largest-magnitude weight is mapped to +/-100.
    scale_factor = 100.0 / max_val

    # Quantize and Clamp to -127..127
    int_weights = torch.round(float_weights * scale_factor).int()
    int_weights = torch.clamp(int_weights, -127, 127)

    # Update model with integer weights (stored as float for PyTorch compatibility)
    model.conv.weight.data = int_weights.float()
    print(f"  Scale Factor used: {scale_factor:.2f}")
    print(f"  Weight Range: [{int_weights.min()}, {int_weights.max()}]")

    # --- Step 4: Golden Model Generation ---
    print(f"\n[4/5] Generating Golden Reference Data...")
    model.eval()

    # Pick a specific test image (Index 3 is usually a '0' digit)
    raw_image, label = test_dataset[3]

    # Transform to hardware domain (0..255 -> -128..127)
    input_int = torch.round(raw_image * 255.0) - 128.0
    input_int = torch.clamp(input_int, -128, 127)

    # Run Inference
    with torch.no_grad():
        golden_output = model(input_int.unsqueeze(0).to(device))

    print(f"  Processed Image Label: {label}")
    print(f"  Golden Output Shape: {golden_output.shape} (Flattened: {golden_output.numel()})")

    # --- Step 5: Exporting & Reporting ---
    print(f"\n[5/5] Exporting Files to '{Config.EXPORT_DIR}'...")
    exporter = Exporter(Config.EXPORT_DIR)

    # Export Files
    exporter.save_decimal("image_data.txt", input_int, "(Input Image)")
    exporter.save_decimal("weights.txt", model.conv.weight.data, "(Flattened Weights)")
    exporter.save_decimal("golden_output.txt", golden_output, "(Expected Output)")
    exporter.save_verilog_hex("weights.mem", model.conv.weight.data)

    # --- Visual Report ---
    print_visual_report(input_int, model.conv.weight.data, golden_output)

def print_visual_report(img, weights, output):
    """Print a console dashboard summarising the exported data.

    Args:
        img: Tensor holding 28*28 values; the centre 10x10 window
            (rows/columns 9..18) is printed.
        weights: Tensor holding 8*3*3 values; the first two 3x3 kernels are
            printed side by side.
        output: Tensor whose non-zero count and total element count are
            reported.
    """
    heavy_rule = "=" * 80
    light_rule = "-" * 50

    print("\n" + heavy_rule)
    print(" PROJECT DEMO DASHBOARD")
    print(heavy_rule)

    # 1. Image section: centre crop around pixel (14, 14)
    print("\n1. INPUT IMAGE DATA (Center 10x10 Crop)")
    print("   Format: Signed 8-bit Integer (-128 to 127)")
    print("   Note: -128 represents logical 0 (Black), 127 represents 255 (White)")
    print(light_rule)

    pixels = img.view(28, 28).int().numpy()
    mid = 14
    for pixel_row in pixels[mid - 5:mid + 5, mid - 5:mid + 5]:
        # Each value is right-aligned in 4 columns followed by one space
        print("  " + "".join(f"{px:4d} " for px in pixel_row))
    print(light_rule)

    # 2. Kernel section: filters 0 and 1, one kernel row per line
    print("\n2. TRAINED KERNELS (First 2 Filters)")
    print("   Format: 3x3 Signed Integers")
    print(light_rule)
    kernels = weights.view(8, 3, 3).int().numpy()

    print("Filter 0".ljust(25) + " | " + "Filter 1".ljust(25))
    print(light_rule)

    for left_row, right_row in zip(kernels[0], kernels[1]):
        left_txt = " ".join(f"{w:3d}" for w in left_row)
        right_txt = " ".join(f"{w:3d}" for w in right_row)
        print(f"[{left_txt}]   |   [{right_txt}]")
    print(light_rule)

    # 3. Metrics section
    print("\n3. VERIFICATION METRICS")
    nonzero_count = torch.count_nonzero(output).item()
    print(f"   Max Pool Output Non-Zeros: {nonzero_count} / {output.numel()}")
    print("\n   [SUCCESS] System Ready for Vivado Simulation.")
    print(heavy_rule + "\n")

# Entry point: run main() only when this code is executed as the main module.
if __name__ == "__main__":
    main()


  FPGA_CNN_ACCELERATOR - v1.0.0
  Target: Xilinx Artix-7 FPGA (8-bit Quantized)

[1/5] Initializing Environment...

[2/5] Training Model (Epochs: 1)...
  Batch 0/938 | Loss: 2.3141
  Batch 50/938 | Loss: 2.2539
  Batch 100/938 | Loss: 2.2811
  Batch 150/938 | Loss: 2.2242

[3/5] Quantizing Weights (Float32 -> Int8)...
  Scale Factor used: 230.18
  Weight Range: [-85, 100]

[4/5] Generating Golden Reference Data...
  Processed Image Label: 0
  Golden Output Shape: torch.Size([1, 8, 13, 13]) (Flattened: 1352)

[5/5] Exporting Files to './export'...
  [EXPORT] Wrote 784 values to ./export\image_data.txt (Input Image)
  [EXPORT] Wrote 72 values to ./export\weights.txt (Flattened Weights)
  [EXPORT] Wrote 1352 values to ./export\golden_output.txt (Expected Output)
  [EXPORT] Wrote Verilog ROM file to ./export\weights.mem (Packed Hex)

 PROJECT DEMO DASHBOARD

1. INPUT IMAGE DATA (Center 10x10 Crop)
   Format: Signed 8-bit Integer (-128 to 127)
   Note: -128 represents logical 0 (Black), 12

In [None]:
import sys
import os
import math

class VerificationEngine:
    """Compares a hardware-simulation output file against a golden reference.

    Both files are expected to contain one decimal integer per line; blank
    lines are ignored. Values are assumed to be laid out filter by filter,
    each filter contributing ``OUT_H * OUT_W`` consecutive values.
    """

    def __init__(self, hardware_file="export/hardware_output.txt", golden_file="export/golden_output.txt"):
        self.hw_path = hardware_file
        self.gold_path = golden_file
        self.hw_data = []
        self.gold_data = []

        # Architecture parameters for detailed analysis
        self.NUM_FILTERS = 8
        self.OUT_H = 13
        self.OUT_W = 13
        self.TOTAL_EXPECTED = self.NUM_FILTERS * self.OUT_H * self.OUT_W

    @staticmethod
    def _read_int_file(path):
        """Read one integer per non-blank line of ``path`` and return them as a list.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
            ValueError: If a non-blank line is not a decimal integer; the
                message names the file and the 1-based line number.
        """
        values = []
        with open(path, 'r') as f:
            for line_no, line in enumerate(f, start=1):
                token = line.strip()
                if not token:
                    continue
                try:
                    values.append(int(token))
                except ValueError as err:
                    raise ValueError(
                        f"{path}:{line_no}: expected a decimal integer, got {token!r}"
                    ) from err
        return values

    def load_data(self):
        """Load both files into ``hw_data`` and ``gold_data``.

        Prints an error and exits the process with status 1 when either file
        is missing.

        Raises:
            ValueError: If a file contains a token that is not an integer.
        """
        print(f"\n[INFO] Loading Verification Files...")
        try:
            self.hw_data = self._read_int_file(self.hw_path)
            print(f"  -> Loaded Hardware Output: {len(self.hw_data)} values")

            self.gold_data = self._read_int_file(self.gold_path)
            print(f"  -> Loaded Golden Model:    {len(self.gold_data)} values")

        except FileNotFoundError as e:
            print(f"\n[CRITICAL ERROR] File missing: {e.filename}")
            print("  Please run the simulation first to generate 'hardware_output.txt'")
            sys.exit(1)

    def print_header(self):
        """Print the report title between two 70-column rules."""
        print("\n" + "="*70)
        print("                 CNN HARDWARE VERIFICATION REPORT")
        print("="*70)

    def analyze(self):
        """Compare the loaded data, print the full report and return the verdict.

        Returns:
            bool: True only when there are no mismatches, both files have the
            same number of values, and that number equals ``TOTAL_EXPECTED``.
            Empty or equally-truncated files therefore never pass.
        """
        self.print_header()

        # 1. Length Check
        len_hw = len(self.hw_data)
        len_gold = len(self.gold_data)

        if len_hw != len_gold:
            print(f"\n[WARNING] Length Mismatch!")
            print(f"  Expected: {len_gold}")
            print(f"  Received: {len_hw}")
            print(f"  Missing/Extra: {len_hw - len_gold} values")
            if len_hw < len_gold:
                print("  (Simulation likely stopped early or RESET logic issue)")

        if len_gold != self.TOTAL_EXPECTED:
            # Catches the case where both files are empty or cut short by the
            # same amount, which the hardware-vs-golden comparison cannot see.
            print(f"\n[WARNING] Golden Model Size Unexpected!")
            print(f"  Architecture expects: {self.TOTAL_EXPECTED}")
            print(f"  Golden file contains: {len_gold}")

        limit = min(len_hw, len_gold)
        pixels_per_filter = self.OUT_H * self.OUT_W
        errors = 0
        max_diff = 0
        mse_sum = 0

        # 2. Value Comparison
        print(f"\n[ANALYSIS] Comparing first {limit} values...")

        # We will track errors per filter to see if a specific channel is broken
        filter_errors = [0] * self.NUM_FILTERS

        # zip() stops at the shorter list, i.e. after exactly `limit` pairs.
        for i, (hw, gold) in enumerate(zip(self.hw_data, self.gold_data)):
            if hw == gold:
                continue

            diff = abs(hw - gold)
            errors += 1
            max_diff = max(max_diff, diff)
            mse_sum += diff ** 2

            # Identify which filter this belongs to
            # Index map: [Filter 0 (169 vals)] [Filter 1 (169 vals)] ...
            filter_idx = i // pixels_per_filter
            if filter_idx < self.NUM_FILTERS:
                filter_errors[filter_idx] += 1

            # Print first few errors for debug
            if errors <= 5:
                print(f"  [MISMATCH] Idx {i:4d} | HW: {hw:6d} | Gold: {gold:6d} | Diff: {hw-gold}")

        # 3. Statistical Summary
        mse = mse_sum / limit if limit > 0 else 0
        accuracy = 100.0 * (limit - errors) / limit if limit > 0 else 0

        print("\n" + "-"*70)
        print("  STATISTICAL SUMMARY")
        print("-"*70)
        print(f"  Total Data Points:   {limit}")
        print(f"  Perfect Matches:     {limit - errors}")
        print(f"  Mismatches:          {errors}")
        print(f"  Accuracy:            {accuracy:.2f}%")
        print(f"  Max Absolute Error:  {max_diff}")
        print(f"  Mean Squared Error:  {mse:.4f}")

        # 4. Filter Breakdown
        print("\n" + "-"*70)
        print("  FILTER CHANNEL DIAGNOSTICS")
        print("-"*70)
        print(f"  {'Filter ID':<10} | {'Status':<10} | {'Errors':<10} | {'Integrity'}")

        for f_id, err_count in enumerate(filter_errors):
            # How many of this filter's values were actually compared; fewer
            # than a full filter means the data ended early.
            compared = max(0, min(limit - f_id * pixels_per_filter, pixels_per_filter))
            if err_count > 0:
                status = "FAIL"
            elif compared < pixels_per_filter:
                status = "INCOMPLETE"
            else:
                status = "PASS"

            # Visual integrity bar: only values that were compared and matched
            # count towards integrity.
            integrity = 100.0 * (compared - err_count) / pixels_per_filter
            bar_len = int(integrity / 10)
            bar = "█" * bar_len + "." * (10 - bar_len)

            print(f"  #{f_id:<9} | {status:<10} | {err_count:<10} | [{bar}] {integrity:.1f}%")

        # 5. Visual Sample (First Filter)
        self.print_visual_sample()

        # 6. Final Verdict
        passed = (
            errors == 0
            and limit > 0
            and len_hw == len_gold == self.TOTAL_EXPECTED
        )
        print("\n" + "="*70)
        if passed:
            print("  FINAL VERDICT: [ PASS ]  Hardware is Bit-Accurate.")
        else:
            print("  FINAL VERDICT: [ FAIL ]  Design needs debugging.")
        print("="*70 + "\n")
        return passed

    def _print_grid(self, data, reference=None):
        """Print the first ``OUT_H x OUT_W`` values of ``data`` as a grid.

        Positions beyond the end of ``data`` are shown as ``?``. When
        ``reference`` is given, values that differ from the reference value at
        the same index are suffixed with ``!``.
        """
        print("  " + "-"*(self.OUT_W * 5))
        for r in range(self.OUT_H):
            cells = []
            for c in range(self.OUT_W):
                # Filter 0 linear index
                idx = r * self.OUT_W + c
                if idx >= len(data):
                    cells.append("   ? ")
                elif (reference is not None and idx < len(reference)
                        and data[idx] != reference[idx]):
                    cells.append(f"{data[idx]:3d}! ") # Mark errors with '!'
                else:
                    cells.append(f"{data[idx]:4d} ")
            print("  " + "".join(cells))

    def print_visual_sample(self):
        """Prints the FULL 13x13 output grids one below the other"""
        print("\n" + "="*80)
        print(f"  FULL VISUAL ALIGNMENT (Filter 0: {self.OUT_H}x{self.OUT_W})")
        print("=" * 80)

        # 1. Golden model grid
        print("\n  [1] GOLDEN MODEL (Expected Pattern):")
        self._print_grid(self.gold_data)

        # 2. Hardware output grid, with mismatches flagged against golden
        print("\n  [2] HARDWARE OUTPUT (Actual Result):")
        self._print_grid(self.hw_data, reference=self.gold_data)

        print("\n" + "="*80)
        print("  Legend: '!' indicates a specific pixel mismatch.")

# Entry point: when executed as the main module, load both files using the
# default paths and print the verification report.
if __name__ == "__main__":
    verifier = VerificationEngine()
    verifier.load_data()
    verifier.analyze()


[INFO] Loading Verification Files...
  -> Loaded Hardware Output: 1352 values
  -> Loaded Golden Model:    1352 values

                 CNN HARDWARE VERIFICATION REPORT

[ANALYSIS] Comparing first 1352 values...

----------------------------------------------------------------------
  STATISTICAL SUMMARY
----------------------------------------------------------------------
  Total Data Points:   1352
  Perfect Matches:     1352
  Mismatches:          0
  Accuracy:            100.00%
  Max Absolute Error:  0
  Mean Squared Error:  0.0000

----------------------------------------------------------------------
  FILTER CHANNEL DIAGNOSTICS
----------------------------------------------------------------------
  Filter ID  | Status     | Errors     | Integrity
  #0         | PASS       | 0          | [██████████] 100.0%
  #1         | PASS       | 0          | [██████████] 100.0%
  #2         | PASS       | 0          | [██████████] 100.0%
  #3         | PASS       | 0          | [██████