# 🎯 6G PA DPD Training Notebook
## CWGAN-GP + A-SPSA for 29th LSI Design Contest

This notebook trains a **Time-Delay Neural Network (TDNN)** for Digital Predistortion using:
- **CWGAN-GP**: Conditional Wasserstein GAN with Gradient Penalty
- **Spectral Loss**: EVM + ACPR optimization
- **QAT**: Quantization-Aware Training for FPGA deployment

### Architecture
```
Input (18) → FC(32) → LeakyReLU → FC(16) → LeakyReLU → FC(2) → Tanh → Output
```

**Target**: PYNQ-Z1 / ZCU104 FPGA with HDMI loopback demo

## 1️⃣ Setup Environment

In [None]:
# Clone repository (run once)
# NOTE(review): replace YOUR_USERNAME with the real GitHub account before running.
# NOTE(review): `2>/dev/null || echo ...` reports *any* clone failure (wrong URL,
# no network) as "Already cloned"; if the folder was never created, the %cd
# below will not find it.
!git clone https://github.com/YOUR_USERNAME/6g-pa-gan-dpd.git 2>/dev/null || echo "Already cloned"
# Relative path: re-running this cell from inside the repo will not find the folder.
%cd 6g-pa-gan-dpd

# Install minimal dependencies (most are pre-installed on Colab)
!pip install -q h5py pyyaml tqdm

In [None]:
# Verify imports
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import yaml

# Report the runtime environment, then choose where tensors/models will live.
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"\nUsing device: {device}")

## 2️⃣ Load Configuration

In [None]:
# Load the project configuration. The path is relative to the repository root,
# which the setup cell makes the working directory.
# NOTE(review): the training cells below hard-code their hyperparameters instead
# of reading them from `config`; the values printed here may differ from those used.
with open('config/config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Display key parameters
print("=" * 50)
print("Configuration Summary")
print("=" * 50)
print(f"Sample Rate: {config['system']['sample_rate_mhz']} MHz")
print(f"TDNN Architecture: {config['model']['generator']['input_dim']} → "
      f"{config['model']['generator']['hidden_dims']} → "
      f"{config['model']['generator']['output_dim']}")
print(f"Quantization: {config['quantization']['weight_bits']}-bit weights")
print(f"Batch Size: {config['training']['batch_size']}")
print(f"Learning Rate: {config['training']['learning_rate']}")

## 3️⃣ Define Models

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class TDNNGenerator(nn.Module):
    """
    Time-Delay Neural Network Generator for DPD.

    A small MLP applied independently to every time step. The memory taps and
    envelope terms arrive as input features, so the network itself holds no
    state between samples.

    Args:
        input_dim: Number of input features per time step.
        hidden_dims: Widths of the hidden layers; each is followed by LeakyReLU(0.25).
        output_dim: Number of outputs per time step (2 = I/Q).
        quantize: If True, weights and biases are fake-quantized in the forward
            pass (quantization-aware training). Values are rounded onto the grid
            k / (2**(num_bits-1) - 1), saturated to [-1, 1 - 1/scale].
            Gradients pass straight through the rounding.
        num_bits: Word length used when ``quantize`` is True.

    Input shape: [..., input_dim] (e.g. [batch, seq, input_dim] or [batch, input_dim]).
    Output shape: [..., output_dim], values in (-1, 1) because of the final Tanh.
    """
    def __init__(self, input_dim=18, hidden_dims=(32, 16), output_dim=2,
                 quantize=False, num_bits=16):
        super().__init__()
        self.input_dim = input_dim
        self.quantize = quantize
        self.num_bits = num_bits

        # Build layers: Linear -> LeakyReLU for every hidden width
        layers = []
        in_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(in_dim, hidden_dim))
            layers.append(nn.LeakyReLU(0.25))  # 0.25 for easy shift in HW
            in_dim = hidden_dim

        self.features = nn.Sequential(*layers)
        self.output = nn.Linear(in_dim, output_dim)
        self.tanh = nn.Tanh()

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal weights (matched to LeakyReLU slope 0.25), zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, a=0.25)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _fake_quant(self, t):
        """Round *t* to the fixed-point grid; identity gradient (straight-through)."""
        scale = 2 ** (self.num_bits - 1) - 1
        q = torch.round(torch.clamp(t, -1.0, 1.0 - 1 / scale) * scale) / scale
        return t + (q - t).detach()

    def _apply_linear(self, layer, h):
        """Run one nn.Linear, with fake-quantized parameters when QAT is enabled."""
        if not self.quantize:
            return layer(h)
        bias = None if layer.bias is None else self._fake_quant(layer.bias)
        return F.linear(h, self._fake_quant(layer.weight), bias)

    def forward(self, x):
        # Flatten all leading dims, run the per-sample MLP, then restore them.
        lead_shape = x.shape[:-1]
        h = x.reshape(-1, self.input_dim)
        for layer in self.features:
            h = self._apply_linear(layer, h) if isinstance(layer, nn.Linear) else layer(h)
        out = self.tanh(self._apply_linear(self.output, h))
        # Explicit last dim (not -1) so empty batches reshape unambiguously.
        return out.reshape(*lead_shape, out.shape[-1])

# Build a default-sized generator and print its parameter budget, layer by layer.
model = TDNNGenerator(input_dim=18, hidden_dims=[32, 16], output_dim=2)
total_params = sum(param.numel() for _, param in model.named_parameters())
print(f"TDNN Generator: {total_params:,} parameters")

for name, param in model.named_parameters():
    print(f"  {name}: {list(param.shape)} = {param.numel()}")

In [None]:
class Discriminator(nn.Module):
    """
    CWGAN-GP critic with spectral normalization on every linear layer.

    The critic is applied independently at every position. ``x`` and
    ``condition`` are concatenated along the last axis, so their last-axis
    sizes must add up to ``input_dim``. The default of 4 is I/Q of the PA output
    plus I/Q of the conditioning signal.

    Args:
        input_dim: Size of the concatenated (x, condition) feature vector.
        hidden_dims: Widths of the hidden layers, each followed by LeakyReLU(0.2).

    Returns (forward):
        Unbounded critic score of shape [..., 1]. There is no output activation.
    """
    def __init__(self, input_dim=4, hidden_dims=(64, 32, 16)):
        super().__init__()

        layers = []
        in_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.utils.spectral_norm(nn.Linear(in_dim, hidden_dim)))
            layers.append(nn.LeakyReLU(0.2))
            in_dim = hidden_dim

        self.features = nn.Sequential(*layers)
        self.output = nn.utils.spectral_norm(nn.Linear(in_dim, 1))

    def forward(self, x, condition):
        # Concatenate PA output and condition on the feature axis
        combined = torch.cat([x, condition], dim=-1)
        h = self.features(combined)
        return self.output(h)

# Instantiate the default critic and report its size.
disc = Discriminator()
print("Discriminator: {:,} parameters".format(sum(p.numel() for p in disc.parameters())))

In [None]:
class PADigitalTwin(nn.Module):
    """
    PA digital twin: a memoryless AM-AM / AM-PM polynomial followed by a causal
    FIR memory filter (a Hammerstein-style simplification of a Volterra model).

    Static part, per sample, with r2 = |x|^2:
        gain  = (alpha1 + alpha3 * r2 + alpha5 * r2^2) * temp_scale
        phase = beta * r2 * temp_scale
        y     = gain * x * exp(j * phase)
    Memory part, applied to [batch, seq, 2] inputs only:
        y_mem[n] = sum_k h[k] * y[n - k]
    The taps h sum to 1 (unit DC gain). Samples before the start of the
    sequence are treated as zero.

    Args:
        memory_depth: Number of FIR taps (>= 1). The tap profile is
            1.0, 0.3, 0.1, 0.05, 0.02, truncated or zero-extended to this length.
        nonlin_order: Stored for reference only; the polynomial is fixed at 5th order.
    """
    _DEFAULT_MEMORY_TAPS = (1.0, 0.3, 0.1, 0.05, 0.02)

    def __init__(self, memory_depth=5, nonlin_order=5):
        super().__init__()
        if memory_depth < 1:
            raise ValueError(f"memory_depth must be >= 1, got {memory_depth}")
        self.memory_depth = memory_depth
        self.nonlin_order = nonlin_order

        # Polynomial coefficients (fixed: registered as buffers, not parameters)
        self.register_buffer('alpha1', torch.tensor(0.95))   # Linear gain
        self.register_buffer('alpha3', torch.tensor(-0.12))  # 3rd order
        self.register_buffer('alpha5', torch.tensor(0.03))   # 5th order
        self.register_buffer('beta', torch.tensor(0.15))     # AM-PM

        # Memory coefficients h[0..memory_depth-1], normalized to unit DC gain
        taps = list(self._DEFAULT_MEMORY_TAPS[:memory_depth])
        taps += [0.0] * (memory_depth - len(taps))
        mem_coef = torch.tensor(taps)
        self.register_buffer('memory_coef', mem_coef / mem_coef.sum())

    def _causal_fir(self, sig):
        """Return y[n] = sum_k h[k] * sig[n - k] along the last axis (zero initial state)."""
        n_taps = self.memory_coef.numel()
        # conv1d is a cross-correlation, so flip the taps and left-pad by n_taps-1
        # to obtain a causal convolution with the same output length.
        padded = F.pad(sig.unsqueeze(1), (n_taps - 1, 0))
        kernel = self.memory_coef.flip(0).to(sig.dtype).view(1, 1, -1)
        return F.conv1d(padded, kernel).squeeze(1)

    def forward(self, x, temperature_state=1):
        """
        Args:
            x: Complex input [batch, seq, 2] (I, Q); a [N, 2] input is treated
               as independent samples (no memory filter).
            temperature_state: 0=cold, 1=normal, 2=hot; scales gain and AM-PM
               by 1 + 0.05 * (temperature_state - 1).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Apply temperature drift
        temp_scale = 1.0 + 0.05 * (temperature_state - 1)

        x_i = x[..., 0]
        x_q = x[..., 1]
        mag_sq = x_i**2 + x_q**2

        # AM-AM: polynomial compression
        gain = (self.alpha1 + self.alpha3 * mag_sq + self.alpha5 * mag_sq**2) * temp_scale

        # AM-PM: amplitude-dependent phase rotation
        phase_shift = self.beta * mag_sq * temp_scale
        cos_phi = torch.cos(phase_shift)
        sin_phi = torch.sin(phase_shift)

        y_i = gain * (x_i * cos_phi - x_q * sin_phi)
        y_q = gain * (x_i * sin_phi + x_q * cos_phi)

        # Memory effects: causal FIR along the time axis (sequences of any length)
        if x.dim() == 3:
            y_i = self._causal_fir(y_i)
            y_q = self._causal_fir(y_q)

        return torch.stack([y_i, y_q], dim=-1)

# Shared PA model instance (default 5-tap memory, 5th-order polynomial) for the plots below.
pa = PADigitalTwin(memory_depth=5, nonlin_order=5)
print("PA Digital Twin created (Volterra model)")

## 4️⃣ Feature Engineering

In [None]:
def create_memory_features(x, memory_depth=5):
    """
    Create input features for TDNN including memory taps and envelope.

    Input: x [batch, seq, 2] (I, Q)
    Output: features [batch, seq, 18], concatenated in this order:
        - Current I/Q x[n]: 2
        - Delayed I/Q x[n-1], x[n-2], x[n-3]: 6
        - Envelope |x[n-k]|^2, |x[n-k]|^4 for k = 0..4 (pairs, interleaved): 10

    Samples before the start of the sequence repeat the first sample
    (replicate padding).

    The tap layout is fixed so the output always matches the 18-input TDNN.
    ``memory_depth`` only sets the padding length. Padding never goes below the
    deepest tap (lag 4), so values under 5 are accepted.
    """
    num_iq_delays = 3   # I/Q taps x[n-1] .. x[n-3]
    num_env_taps = 5    # envelope taps at lags 0 .. 4
    batch, seq, _ = x.shape
    n_features = 2 * (num_iq_delays + 1) + 2 * num_env_taps
    if seq == 0:
        return x.new_zeros(batch, 0, n_features)

    pad = max(memory_depth - 1, num_iq_delays, num_env_taps - 1)
    # Replicate the first sample into the past, independent of F.pad mode support.
    x_padded = torch.cat([x[:, :1, :].expand(-1, pad, -1), x], dim=1)

    def tap(lag):
        """x[n - lag] for every n, shaped [batch, seq, 2]."""
        return x_padded[:, pad - lag:pad - lag + seq, :]

    # Current + delayed I/Q (lag 0 is x itself)
    features_list = [tap(lag) for lag in range(num_iq_delays + 1)]

    # Envelope features: |x[n-k]|^2 and |x[n-k]|^4
    for lag in range(num_env_taps):
        delayed = tap(lag)
        mag_sq = delayed[..., 0]**2 + delayed[..., 1]**2
        features_list.append(mag_sq.unsqueeze(-1))
        features_list.append((mag_sq ** 2).unsqueeze(-1))

    return torch.cat(features_list, dim=-1)  # [batch, seq, 18]

# Sanity-check the feature extractor on random I/Q data.
test_x = torch.randn(4, 64, 2)
test_features = create_memory_features(test_x)
for label, tensor in (("Input shape", test_x), ("Features shape", test_features)):
    print(f"{label}: {tensor.shape}")

## 5️⃣ Loss Functions

In [None]:
def compute_evm(reference, signal):
    """
    Return the RMS error-vector magnitude of *signal* w.r.t. *reference* (linear; lower is better).

    Per-sample powers are summed over the last (I/Q) axis and averaged over
    everything else. A 1e-8 term guards the division.
    """
    mean_err_power = ((signal - reference) ** 2).sum(dim=-1).mean()
    mean_ref_power = (reference ** 2).sum(dim=-1).mean()
    return (mean_err_power / (mean_ref_power + 1e-8)).sqrt()

def compute_nmse(reference, signal, eps=1e-8):
    """
    Compute Normalized Mean Square Error (linear, not dB).

    NMSE = sum((signal - reference)^2) / sum(reference^2), summed over every
    element (batch, time and I/Q).

    Args:
        reference: ideal signal.
        signal: signal under test, same shape as ``reference``.
        eps: guards the division for an all-zero reference. This matches the
            1e-8 guard in compute_evm; without it the result would be NaN or inf.
    """
    error = signal - reference
    return (error ** 2).sum() / ((reference ** 2).sum() + eps)

def compute_acpr(signal, main_bw_ratio=0.8, adj_bw_ratio=0.1):
    """
    Compute Adjacent Channel Power Ratio in dB (more negative is better).

    The FFT is taken along the time axis and fft-shifted so DC sits at the
    centre of the spectrum.
    - Main channel: the central ``main_bw_ratio`` fraction of the band.
    - Adjacent channels: the outermost ``adj_bw_ratio`` fraction on each side.
    Both ratios are fractions of the sample rate. Powers are summed over all
    leading (batch) dimensions, so a single scalar is returned.

    Args:
        signal: real tensor [..., seq, 2] holding (I, Q).

    Returns:
        0-dim tensor 10*log10(adjacent power / main power).
    """
    sig_complex = signal[..., 0] + 1j * signal[..., 1]

    # torch.fft.fft puts DC at index 0. Shift it so the "centre" slice really
    # is the band around DC and the edge slices are the highest frequencies.
    spectrum = torch.fft.fftshift(torch.fft.fft(sig_complex, dim=-1), dim=-1)
    power = spectrum.real**2 + spectrum.imag**2

    n_fft = power.shape[-1]

    # Main channel (centre of the shifted spectrum)
    main_start = int(n_fft * (0.5 - main_bw_ratio / 2))
    main_end = int(n_fft * (0.5 + main_bw_ratio / 2))
    main_power = power[..., main_start:main_end].sum()

    # Adjacent channels (both spectrum edges). Use an explicit end index:
    # `power[..., -0:]` would select the whole spectrum when the band rounds to 0 bins.
    adj_bins = int(n_fft * adj_bw_ratio)
    adj_power = power[..., :adj_bins].sum() + power[..., n_fft - adj_bins:].sum()

    # Small numerator term keeps log10 finite when there is no out-of-band power.
    return 10 * torch.log10((adj_power + 1e-12) / (main_power + 1e-8))

class SpectralLoss(nn.Module):
    """
    Combined spectral loss: EVM + NMSE + ACPR.

        loss = evm_weight * EVM + nmse_weight * NMSE + acpr_weight * ACPR_dB / 50

    ACPR is in dB, and more negative is better. It therefore enters with a
    positive sign, so minimizing the loss drives ACPR down. The /50 scales the
    dB value down relative to the linear EVM/NMSE terms.

    forward(reference, signal) returns (loss, {'evm', 'nmse', 'acpr'} as Python floats).
    """
    def __init__(self, evm_weight=1.0, acpr_weight=0.5, nmse_weight=1.0):
        super().__init__()
        self.evm_weight = evm_weight
        self.acpr_weight = acpr_weight
        self.nmse_weight = nmse_weight

    def forward(self, reference, signal):
        evm = compute_evm(reference, signal)
        nmse = compute_nmse(reference, signal)
        acpr = compute_acpr(signal)

        # '+' on the ACPR term: subtracting it would *maximize* adjacent-channel leakage.
        loss = (self.evm_weight * evm +
                self.nmse_weight * nmse +
                self.acpr_weight * acpr / 50)  # Normalize ACPR contribution

        return loss, {'evm': evm.item(), 'nmse': nmse.item(), 'acpr': acpr.item()}

# Standalone instance with the default weighting (EVM 1.0, ACPR 0.5, NMSE 1.0).
spectral_loss = SpectralLoss(evm_weight=1.0, acpr_weight=0.5, nmse_weight=1.0)
print("Spectral loss function created")

In [None]:
def compute_gradient_penalty(disc, real, fake, condition, device=None):
    """
    WGAN-GP gradient penalty for a conditional critic.

    The critic here scores every position independently (output [..., 1]), and
    the critic loss averages those scores. Each position is therefore one
    critic sample: the gradient norm is taken over the last (feature) axis
    only and pushed toward 1 at every position. Flattening over the whole
    sequence would instead push the per-step norm toward 1/sqrt(seq_len).
    For 2-D inputs [batch, features] this is the usual per-sample penalty.

    Args:
        disc: critic, called as ``disc(x, condition)``.
        real, fake: tensors of identical shape [batch, ..., features].
        condition: conditioning tensor, passed to the critic unchanged.
        device: device for the interpolation weights; defaults to ``real.device``.

    Returns:
        Scalar tensor, differentiable w.r.t. the critic's parameters.
    """
    if device is None:
        device = real.device

    # One interpolation weight per batch element, broadcast over the other dims.
    alpha = torch.rand((real.size(0),) + (1,) * (real.dim() - 1), device=device)
    # Detach: the penalty must only train the critic, never the generator.
    interpolates = (alpha * real + (1 - alpha) * fake).detach().requires_grad_(True)

    d_interp = disc(interpolates, condition)

    gradients = torch.autograd.grad(
        outputs=d_interp,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interp),
        create_graph=True,
        retain_graph=True
    )[0]

    if gradients.dim() == 1:  # scalar feature per sample
        gradients = gradients.unsqueeze(-1)
    grad_norm = gradients.norm(2, dim=-1)
    return ((grad_norm - 1) ** 2).mean()

## 6️⃣ Generate Synthetic Training Data

In [None]:
def generate_ofdm_signal(batch_size, seq_len, num_subcarriers=64):
    """
    Draw a random OFDM-like baseband burst.

    Each batch row gets ``num_subcarriers`` independent 16-QAM symbols (levels
    -3, -1, 1, 3 on each rail). These are IFFT'd to ``seq_len`` time samples;
    ``torch.fft.ifft`` zero-pads or trims the symbol vector to that length.
    The burst is scaled so its largest |I| or |Q| across the whole batch is 0.8.

    Returns:
        float32 tensor [batch_size, seq_len, 2] holding (I, Q).
    """
    grid = (batch_size, num_subcarriers)
    levels_i = torch.randint(0, 4, grid) * 2 - 3
    levels_q = torch.randint(0, 4, grid) * 2 - 3
    symbols = levels_i + 1j * levels_q
    symbols = symbols / symbols.abs().max()

    waveform = torch.fft.ifft(symbols, n=seq_len, dim=-1)
    iq = torch.stack((waveform.real, waveform.imag), dim=-1)

    peak = iq.abs().max()
    return (iq / (peak + 1e-8) * 0.8).float()

# Draw one training-sized batch and report its shape and amplitude range.
test_signal = generate_ofdm_signal(batch_size=8, seq_len=256)
print(f"Generated signal shape: {test_signal.shape}")
print("Signal range: [{:.3f}, {:.3f}]".format(test_signal.min().item(), test_signal.max().item()))

In [None]:
# Visualize the test burst before and after the PA model (no DPD yet).
# Top row: input; bottom row: PA output. Left: time domain; right: constellation.
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

pa_output = pa(test_signal)
rows = (
    (test_signal, 'Input Signal (Time Domain)', 'Input Constellation'),
    (pa_output.detach(), 'PA Output (Distorted)', 'PA Output Constellation (Distorted)'),
)

for (ax_time, ax_const), (sig, time_title, const_title) in zip(axes, rows):
    ax_time.plot(sig[0, :100, 0].numpy(), label='I')
    ax_time.plot(sig[0, :100, 1].numpy(), label='Q')
    ax_time.set_title(time_title)
    ax_time.legend()

    ax_const.scatter(sig[0, :, 0].numpy(), sig[0, :, 1].numpy(), alpha=0.5, s=5)
    ax_const.set_title(const_title)
    ax_const.set_xlabel('I')
    ax_const.set_ylabel('Q')
    ax_const.axis('equal')

plt.tight_layout()
plt.show()

# Distortion metrics of the bare PA
evm = compute_evm(test_signal, pa_output)
print(f"\nPA Distortion Metrics (without DPD):")
print(f"  EVM: {evm.item()*100:.2f}%")
print(f"  EVM (dB): {20*np.log10(evm.item()):.2f} dB")

## 7️⃣ Training Loop

In [None]:
# --- Hyperparameters (hard-coded; not read from config.yaml) ---
NUM_EPOCHS, BATCH_SIZE, SEQ_LEN = 100, 32, 256
LR_G = LR_D = 1e-4
N_CRITIC = 5           # critic updates per generator update
GP_WEIGHT = 10.0       # weight of the gradient penalty in the critic loss
SPECTRAL_WEIGHT = 0.5  # weight of the spectral loss in the generator loss

# --- Models (created in this order: generator, critic, PA) ---
generator = TDNNGenerator(input_dim=18, hidden_dims=[32, 16], output_dim=2).to(device)
discriminator = Discriminator(input_dim=4, hidden_dims=[64, 32, 16]).to(device)
pa_model = PADigitalTwin().to(device)

# --- Optimizers ---
opt_g = torch.optim.Adam(generator.parameters(), lr=LR_G, betas=(0.5, 0.9))
opt_d = torch.optim.Adam(discriminator.parameters(), lr=LR_D, betas=(0.5, 0.9))

# --- Loss ---
spectral_loss_fn = SpectralLoss(evm_weight=1.0, acpr_weight=0.5, nmse_weight=1.0)

print("Training setup complete!")
for label, net in (("Generator", generator), ("Discriminator", discriminator)):
    print(f"  {label} params: {sum(p.numel() for p in net.parameters()):,}")

In [None]:
# Training history: one entry per epoch, each averaged over the epoch's batches.
# d_loss / w_distance use the last critic step of every batch.
history = {
    'd_loss': [], 'g_loss': [], 'evm': [], 'acpr': [], 'w_distance': []
}

# Ensure training mode, e.g. when re-running after the evaluation cell called generator.eval().
generator.train()
discriminator.train()

# Training loop
pbar = tqdm(range(NUM_EPOCHS), desc='Training')

for epoch in pbar:
    epoch_d_loss = 0
    epoch_g_loss = 0
    epoch_evm = 0
    epoch_acpr = 0
    n_batches = 100  # Batches per epoch

    for batch_idx in range(n_batches):
        # Fresh synthetic data every batch (no fixed dataset)
        x_input = generate_ofdm_signal(BATCH_SIZE, SEQ_LEN).to(device)

        # Create memory features
        features = create_memory_features(x_input)

        # ==================
        # Train Discriminator
        # ==================
        for _ in range(N_CRITIC):
            opt_d.zero_grad()

            # Generator output (predistorted signal); no generator gradients needed here
            with torch.no_grad():
                dpd_output = generator(features)

            # PA output
            pa_output = pa_model(dpd_output)

            # Real = ideal output (input signal)
            # Fake = PA output after DPD
            # Both are conditioned on the input signal.
            real_score = discriminator(x_input, x_input)
            fake_score = discriminator(pa_output, x_input)

            # Wasserstein loss
            d_loss = fake_score.mean() - real_score.mean()

            # Gradient penalty
            gp = compute_gradient_penalty(discriminator, x_input, pa_output, x_input, device)

            # Total discriminator loss
            d_total = d_loss + GP_WEIGHT * gp
            d_total.backward()
            opt_d.step()

        # ==================
        # Train Generator
        # ==================
        opt_g.zero_grad()

        # Generate predistorted signal
        dpd_output = generator(features)

        # PA output
        pa_output = pa_model(dpd_output)

        # Adversarial loss (want discriminator to think PA output is real)
        fake_score = discriminator(pa_output, x_input)
        g_adv_loss = -fake_score.mean()

        # Spectral loss (EVM, NMSE, ACPR)
        g_spectral_loss, spectral_info = spectral_loss_fn(x_input, pa_output)

        # Total generator loss. backward() also leaves gradients on the critic's
        # parameters; opt_d.zero_grad() clears them before the next critic step.
        g_total = g_adv_loss + SPECTRAL_WEIGHT * g_spectral_loss
        g_total.backward()
        opt_g.step()

        # Accumulate metrics
        epoch_d_loss += d_loss.item()
        epoch_g_loss += g_total.item()
        epoch_evm += spectral_info['evm']
        epoch_acpr += spectral_info['acpr']

    # Average metrics
    epoch_d_loss /= n_batches
    epoch_g_loss /= n_batches
    epoch_evm /= n_batches
    epoch_acpr /= n_batches

    # Record history
    history['d_loss'].append(epoch_d_loss)
    history['g_loss'].append(epoch_g_loss)
    history['evm'].append(epoch_evm)
    history['acpr'].append(epoch_acpr)
    history['w_distance'].append(-epoch_d_loss)  # W estimate = E[real] - E[fake]

    # Update progress bar
    pbar.set_postfix({
        'D_loss': f'{epoch_d_loss:.4f}',
        'G_loss': f'{epoch_g_loss:.4f}',
        'EVM': f'{epoch_evm*100:.2f}%',
        'ACPR': f'{epoch_acpr:.1f}dB',
    })

In [None]:
# Plot training history: losses, Wasserstein estimate, and EVM (linear % and dB).
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

ax = axes[0, 0]
for key, label in (('d_loss', 'Discriminator'), ('g_loss', 'Generator')):
    ax.plot(history[key], label=label)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Loss')
ax.legend()

single_series_panels = (
    (axes[0, 1], history['w_distance'], 'Wasserstein Distance', 'Wasserstein Distance'),
    (axes[1, 0], [e*100 for e in history['evm']], 'EVM (%)', 'Error Vector Magnitude'),
    (axes[1, 1], [20*np.log10(e) for e in history['evm']], 'EVM (dB)', 'EVM in dB (lower is better)'),
)
for ax, series, ylabel, title in single_series_panels:
    ax.plot(series)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)

plt.tight_layout()
plt.show()

## 8️⃣ Evaluate Results

In [None]:
# Final evaluation on a fresh 1024-sample burst: PA alone vs. DPD followed by PA.
generator.eval()

with torch.no_grad():
    test_input = generate_ofdm_signal(1, 1024).to(device)
    test_features = create_memory_features(test_input)
    pa_no_dpd = pa_model(test_input)          # reference: PA without DPD
    dpd_output = generator(test_features)     # predistorted drive signal
    pa_with_dpd = pa_model(dpd_output)

evm_no_dpd, evm_with_dpd = (compute_evm(test_input, y) for y in (pa_no_dpd, pa_with_dpd))

print("="*50)
print("Final Evaluation Results")
print("="*50)
for label, evm_val in (("Without DPD", evm_no_dpd), ("With DPD", evm_with_dpd)):
    print(f"\n{label}:")
    print(f"  EVM: {evm_val.item()*100:.2f}% ({20*np.log10(evm_val.item()):.2f} dB)")
print(f"\nImprovement: {(evm_no_dpd.item() - evm_with_dpd.item())*100:.2f}% absolute")
print(f"             {20*np.log10(evm_no_dpd.item()/evm_with_dpd.item()):.2f} dB")

In [None]:
# Visualize results: constellations (top row) and spectra (bottom row).
fig, axes = plt.subplots(2, 3, figsize=(15, 8))

test_input_np = test_input[0].cpu().numpy()
pa_no_dpd_np = pa_no_dpd[0].cpu().numpy()
pa_with_dpd_np = pa_with_dpd[0].cpu().numpy()

# Constellation plots: (points, title, extra scatter style)
constellations = (
    (test_input_np, 'Input Signal', {}),
    (pa_no_dpd_np, f'Without DPD (EVM={evm_no_dpd.item()*100:.1f}%)', {'c': 'red'}),
    (pa_with_dpd_np, f'With DPD (EVM={evm_with_dpd.item()*100:.1f}%)', {'c': 'green'}),
)
for ax, (points, title, style) in zip(axes[0], constellations):
    ax.scatter(points[:, 0], points[:, 1], alpha=0.5, s=5, **style)
    ax.set_title(title)
    ax.set_xlabel('I')
    ax.set_ylabel('Q')
    ax.axis('equal')

# Spectrum plots
def plot_spectrum(ax, signal, title, normalize=True, ylim=(-60, 10)):
    """
    Plot the power spectrum (fft-shifted, DC in the middle) of an I/Q sequence.

    The raw FFT magnitude grows with sequence length and signal level; for a
    1024-sample burst the in-band level sits far above a fixed +10 dB limit.
    By default the curve is therefore shown relative to its own peak
    (0 dB = peak), which keeps in-band level and spectral regrowth inside
    ``ylim``.

    Args:
        ax: matplotlib Axes to draw on.
        signal: array [seq, 2] of (I, Q) samples.
        title: Axes title.
        normalize: if True, subtract the peak so the maximum is 0 dB.
        ylim: y-axis limits in dB.
    """
    sig_complex = signal[:, 0] + 1j * signal[:, 1]
    spectrum = np.fft.fftshift(np.fft.fft(sig_complex))
    power_db = 20 * np.log10(np.abs(spectrum) + 1e-10)
    if normalize:
        power_db = power_db - power_db.max()
    freqs = np.fft.fftshift(np.fft.fftfreq(len(sig_complex)))
    ax.plot(freqs, power_db)
    ax.set_title(title)
    ax.set_xlabel('Normalized Frequency')
    ax.set_ylabel('Power (dB rel. peak)' if normalize else 'Power (dB)')
    ax.set_ylim(*ylim)

# Bottom row: input spectrum, PA without DPD, PA with DPD.
spectra = (
    (test_input_np, 'Input Spectrum'),
    (pa_no_dpd_np, 'Without DPD Spectrum'),
    (pa_with_dpd_np, 'With DPD Spectrum'),
)
for ax, (sig, title) in zip(axes[1], spectra):
    plot_spectrum(ax, sig, title)

plt.tight_layout()
plt.show()

## 9️⃣ Export Weights for FPGA

In [None]:
def quantize_weights(weight, num_bits=16):
    """
    Quantize weights to fixed-point Q1.15 format (for the default 16 bits).

    Values are saturated to [-1, 1 - 1/scale] with scale = 2**(num_bits-1) - 1,
    then rounded onto the grid k / scale. The result is still a float tensor
    (fake quantization), not integer codes.
    """
    full_scale = (1 << (num_bits - 1)) - 1
    upper = 1.0 - 1 / full_scale
    return torch.round(weight.clamp(-1.0, upper) * full_scale) / full_scale

def export_weights_hex(model, filepath, num_bits=16):
    """
    Export model weights to Verilog $readmemh format.

    Each parameter is quantized with ``quantize_weights`` and written as one
    two's-complement hex word per line, in ``model.named_parameters()`` order.
    For TDNNGenerator that order is features.0.weight, features.0.bias,
    features.2.weight, features.2.bias, output.weight, output.bias. Each tensor
    is flattened row-major, so for a weight[out, in] the ``in`` index varies fastest.

    Values outside [-1, 1 - 1/scale] are saturated by the quantizer. The count
    of saturated values is printed per tensor, because silent clipping changes
    the deployed model.

    Args:
        model: module whose parameters are exported.
        filepath: output text file.
        num_bits: word length (16 = Q1.15).
    """
    scale = 2 ** (num_bits - 1) - 1
    word_mask = (1 << num_bits) - 1
    hex_digits = (num_bits + 3) // 4
    all_weights = []

    for name, param in model.named_parameters():
        w = param.detach().cpu()
        w_quant = quantize_weights(w, num_bits)
        # Round before casting: w_quant * scale is only approximately integral in
        # floating point, and an integer cast truncates toward zero (off-by-one codes).
        w_int = torch.round(w_quant * scale).to(torch.int64)
        all_weights.extend(w_int.flatten().tolist())

        n_saturated = int(((w < -1.0) | (w > 1.0 - 1 / scale)).sum())
        line = f"  {name}: {list(w.shape)} = {w.numel()} params"
        if n_saturated:
            line += f"  [WARNING: {n_saturated} value(s) saturated]"
        print(line)

    with open(filepath, 'w', encoding='ascii') as f:
        f.write(f"// TDNN Generator weights - {len(all_weights)} total\n")
        f.write(f"// Format: Q1.{num_bits - 1} signed fixed-point\n\n")
        # Masking yields the two's-complement bit pattern for negative codes.
        f.writelines(f"{v & word_mask:0{hex_digits}X}\n" for v in all_weights)

    print(f"\nExported {len(all_weights)} weights to {filepath}")

# Write the trained generator's parameters as a $readmemh-compatible hex file.
print("Exporting weights...")
export_weights_hex(model=generator, filepath='weights_trained.hex')

In [None]:
# Save checkpoint for later use: model and optimizer states, training history,
# and the generator's layer sizes.
checkpoint = {
    'epoch': NUM_EPOCHS,
    'generator_state_dict': generator.state_dict(),
    'discriminator_state_dict': discriminator.state_dict(),
    'optimizer_g_state_dict': opt_g.state_dict(),
    'optimizer_d_state_dict': opt_d.state_dict(),
    'history': history,
    'config': {
        'input_dim': 18,
        'hidden_dims': [32, 16],
        'output_dim': 2,
    }
}

torch.save(checkpoint, 'dpd_trained.pt')
print("Checkpoint saved to dpd_trained.pt")

# Download files (for Colab). Only a missing google.colab module means
# "not on Colab". A bare `except:` would also hide real download failures
# and swallow KeyboardInterrupt.
try:
    from google.colab import files
except ImportError:
    print("Not running on Colab - files saved locally")
else:
    files.download('weights_trained.hex')
    files.download('dpd_trained.pt')
    print("Files downloaded!")

## 🎯 Summary

### Trained Model
- **Architecture**: TDNN 18→32→16→2
- **Parameters**: 1,170 (2,340 bytes as 16-bit words, ≈2.3 KB)
- **Quantization**: Q1.15 (16-bit fixed-point)

### Next Steps
1. Download `weights_trained.hex`
2. Copy to `rtl/weights/` folder
3. Run RTL simulation: `cd rtl && make sim_all`
4. Build FPGA bitstream in Vivado
5. Run HDMI demo on PYNQ-Z1/ZCU104

### Files Generated
- `weights_trained.hex` - Verilog weight file
- `dpd_trained.pt` - PyTorch checkpoint