# OFDM-GAN-SR Training on Google Colab

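This notebook trains the CWGAN-GP (conditional Wasserstein GAN with gradient penalty) model for OFDM signal reconstruction. It clones the repository, installs dependencies, sanity-checks the generator and discriminator, and then launches training, with optional TensorBoard monitoring and Google Drive backup.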

In [None]:
# Clone repository
!git clone https://github.com/orpheus016/ofdm-gan-sr.git
%cd ofdm-gan-sr

In [None]:
# Install dependencies
!pip install -q -r requirements.txt

In [None]:
# Test model architectures before training: build both networks, run one
# forward pass each on random inputs, and check the output shapes.
import torch
from models.generator import UNetGenerator
from models.discriminator import Discriminator

print("=" * 60)
print("MODEL VALIDATION TEST")
print("=" * 60)

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\n✓ Device: {device}")
if torch.cuda.is_available():
    print(f"  GPU: {torch.cuda.get_device_name(0)}")
    print(f"  Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# Test-input geometry: the generator must map
# (batch_size, signal_channels, signal_length) to the same shape.
batch_size = 4
signal_channels = 2
signal_length = 1024

# Initialize models
print("\n[1/5] Initializing Generator...")
generator = UNetGenerator(
    input_channels=signal_channels,
    output_channels=signal_channels,
    base_channels=32,
    depth=5
).to(device)

print("[2/5] Initializing Discriminator...")
discriminator = Discriminator(
    input_channels=2 * signal_channels,  # 2 (candidate) + 2 (condition)
    base_channels=32,
    num_layers=6
).to(device)

# Count parameters
gen_params = sum(p.numel() for p in generator.parameters())
disc_params = sum(p.numel() for p in discriminator.parameters())
print(f"\n✓ Generator parameters: {gen_params:,}")
print(f"✓ Discriminator parameters: {disc_params:,}")

# Test forward pass. Only shapes are checked, so autograd is disabled to avoid
# building graphs and keeping activations alive.
print("\n[3/5] Testing Generator forward pass...")
noisy_signal = torch.randn(batch_size, signal_channels, signal_length, device=device)

try:
    with torch.no_grad():
        fake_signal = generator(noisy_signal)
    assert fake_signal.shape == (batch_size, signal_channels, signal_length), \
        f"Generator output shape mismatch: {fake_signal.shape}"
    print(f"✓ Generator output shape: {tuple(fake_signal.shape)}")
except Exception as e:
    print(f"❌ Generator test FAILED: {e}")
    raise

# Test discriminator
print("\n[4/5] Testing Discriminator forward pass...")
condition = noisy_signal  # Use noisy as condition

try:
    # Concatenate along channel dimension: candidate first, then condition,
    # matching the channel layout documented on `input_channels` above.
    disc_input = torch.cat([fake_signal, condition], dim=1)
    with torch.no_grad():
        score = discriminator(disc_input)
    # The full output shape depends on the Discriminator implementation; only
    # require that it scores each sample in the batch.
    assert score.shape[0] == batch_size, \
        f"Discriminator batch dimension mismatch: {tuple(score.shape)}"
    print(f"✓ Discriminator output shape: {tuple(score.shape)}")
    print(f"✓ Score range: [{score.min().item():.3f}, {score.max().item():.3f}]")
except Exception as e:
    print(f"❌ Discriminator test FAILED: {e}")
    raise

# Memory usage
print("\n[5/5] Checking memory usage...")
if torch.cuda.is_available():
    allocated = torch.cuda.memory_allocated(device) / 1e9
    reserved = torch.cuda.memory_reserved(device) / 1e9
    print(f"✓ GPU Memory: {allocated:.2f} GB allocated, {reserved:.2f} GB reserved")

print("\n" + "=" * 60)
print("✅ ALL TESTS PASSED - Models are ready for training!")
print("=" * 60)

# Training runs in a separate `!python train.py` process and does not reuse these
# objects, so release them (and the allocator's cached blocks) instead of holding
# GPU memory in this kernel while training runs.
del generator, discriminator, noisy_signal, condition, fake_signal, disc_input, score
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# Start training with fixed configuration
!python train.py --synthetic --epochs 100 --lr 0.0001 --batch_size 16

In [None]:
# Optional: Monitor training with TensorBoard
%load_ext tensorboard
%tensorboard --logdir runs/

In [None]:
# Optional: Save checkpoints to Google Drive.
# Run this after training completes (or periodically during long runs).
import shutil
from pathlib import Path

from google.colab import drive

drive.mount('/content/drive')

# Create backup directory
BACKUP_DIR = Path('/content/drive/MyDrive/ofdm-gan-checkpoints')
BACKUP_DIR.mkdir(parents=True, exist_ok=True)

# (source, destination) pairs: checkpoint files go directly into BACKUP_DIR,
# and TensorBoard logs go into a `runs/` subfolder. copytree creates missing
# destination folders and merges into existing ones, so repeated backups work.
backup_pairs = [
    (Path('checkpoints'), BACKUP_DIR),
    (Path('runs'), BACKUP_DIR / 'runs'),
]

for src, dst in backup_pairs:
    # Skip missing/empty sources explicitly instead of letting an unexpanded glob
    # make `cp` fail while a success message is still printed.
    if not src.is_dir() or not any(src.iterdir()):
        print(f"⚠ Skipped {src}/ (missing or empty)")
        continue
    shutil.copytree(src, dst, dirs_exist_ok=True)
    print(f"✓ Copied {src}/ → {dst}")

In [None]:
!python train.py --synthetic --epochs 100 --lr 0.0001