In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Generate synthetic dataset (y, psi, omega) assuming 10000 rows
num_samples = 10000
n, d = 150, 512  # Dimensions given in problem

y = torch.randn(num_samples, n, 1)
psi = torch.randn(num_samples, n, d)
omega = torch.randn(num_samples, n, 1)

# Compute Least Squares Estimate h_LS using Moore-Penrose pseudo-inverse
h_LS = torch.linalg.pinv(psi) @ (y - omega)  # Shape: (num_samples, d, 1)
h_LS = h_LS.squeeze(-1)  # Shape: (num_samples, d)

# Prepare dataset
train_dataset = TensorDataset(h_LS, h_LS)  # Input and output are both h_LS
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Define candidate operations for the search space
OPS = {
    'conv_3x3': lambda C: nn.Conv2d(C, C, kernel_size=3, stride=1, padding=1, bias=False),
    'conv_5x5': lambda C: nn.Conv2d(C, C, kernel_size=5, stride=1, padding=2, bias=False),
    'identity': lambda C: nn.Identity(),
    'skip_connection': lambda C: nn.Sequential(nn.Conv2d(C, C, 1, stride=1, bias=False), nn.BatchNorm2d(C)),
    'zero': lambda C: nn.ZeroPad2d(0),
}

# Mixed operation layer with probability-based selection
class MixedOp(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.ops = nn.ModuleList([op(C) for op in OPS.values()])
        self.alphas = nn.Parameter(torch.randn(len(self.ops)))  # Learnable probabilities
    
    def forward(self, x):
        weights = F.softmax(self.alphas, dim=0)  # Normalize with softmax
        return sum(w * op(x) for w, op in zip(weights, self.ops))

# Searchable Neural Network with architecture selection
class SearchNetwork(nn.Module):
    def __init__(self, C, num_layers=3):
        super().__init__()
        self.layers = nn.ModuleList([MixedOp(C) for _ in range(num_layers)])
        self.output_layer = nn.Conv2d(C, C, kernel_size=1, stride=1, bias=False)
    
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.output_layer(x)

# Training function with truncated reverse-mode AD for NAS
def train(model, train_loader, arch_optimizer, model_optimizer, criterion, unroll_steps=1):
    model.train()
    for x, y in train_loader:
        x, y = x.unsqueeze(1), y.unsqueeze(1)  # Add channel dimension for CNN layers
        x, y = x.to(device), y.to(device)
        
        # Compute loss and update model parameters
        model_optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        model_optimizer.step()
        
        # Architecture update using truncated differentiation
        arch_optimizer.zero_grad()
        with torch.no_grad():
            temp_model = SearchNetwork(d).to(device)
            temp_model.load_state_dict(model.state_dict())
        for _ in range(unroll_steps):
            temp_output = temp_model(x)
            temp_loss = criterion(temp_output, y)
            temp_loss.backward()
        arch_optimizer.step()

# Hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_layers = 3
num_epochs = 20

# Initialize model and optimizers
model = SearchNetwork(d).to(device)
arch_optimizer = optim.Adam(model.parameters(), lr=0.003)
model_optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

# Train model with NAS
for epoch in range(num_epochs):
    train(model, train_loader, arch_optimizer, model_optimizer, criterion)
    print(f"Epoch {epoch+1}/{num_epochs} completed.")

# Extract final architecture based on highest probability operations
def extract_final_architecture(model):
    final_ops = []
    for layer in model.layers:
        best_op_idx = torch.argmax(layer.alphas).item()
        best_op_name = list(OPS.keys())[best_op_idx]
        final_ops.append(best_op_name)
    return final_ops

final_architecture = extract_final_architecture(model)
print("Final Architecture:", final_architecture)

# Save the best model
torch.save(model.state_dict(), "best_model.pth")

print("Neural architecture search completed.")

with real and imaginary separate

In [None]:
import numpy as np    
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Generate synthetic dataset with complex values
num_samples = 10000
n, d = 150, 512  # Dimensions given in problem

y_complex = torch.randn(num_samples, n, 1, dtype=torch.cfloat)
psi_complex = torch.randn(num_samples, n, d, dtype=torch.cfloat)
omega_complex = torch.randn(num_samples, n, 1, dtype=torch.cfloat)

# Compute Least Squares Estimate h_LS using Moore-Penrose pseudo-inverse
h_LS_complex = torch.linalg.pinv(psi_complex) @ (y_complex - omega_complex)  # Shape: (num_samples, d, 1)
h_LS_complex = h_LS_complex.squeeze(-1)  # Shape: (num_samples, d)

# Split complex numbers into real and imaginary parts
h_LS_real = h_LS_complex.real  # Shape: (num_samples, d)
h_LS_imag = h_LS_complex.imag  # Shape: (num_samples, d)

# Prepare dataset
train_dataset = TensorDataset(h_LS_real, h_LS_imag, h_LS_real, h_LS_imag)  # Inputs and outputs for training
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Define candidate operations for the search space
OPS = {
    'conv_3x3': lambda C: nn.Conv2d(C, C, kernel_size=3, stride=1, padding=1, bias=False),
    'conv_5x5': lambda C: nn.Conv2d(C, C, kernel_size=5, stride=1, padding=2, bias=False),
    'identity': lambda C: nn.Identity(),
    'skip_connection': lambda C: nn.Sequential(nn.Conv2d(C, C, 1, stride=1, bias=False), nn.BatchNorm2d(C)),
    'zero': lambda C: nn.ZeroPad2d(0),
}

# Mixed operation layer with probability-based selection
class MixedOp(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.ops = nn.ModuleList([op(C) for op in OPS.values()])
        self.alphas = nn.Parameter(torch.randn(len(self.ops)))  # Learnable probabilities
    
    def forward(self, x):
        weights = F.softmax(self.alphas, dim=0)  # Normalize with softmax
        return sum(w * op(x) for w, op in zip(weights, self.ops))

# Searchable Neural Network with architecture selection
class SearchNetwork(nn.Module):
    def __init__(self, C, num_layers=3):
        super().__init__()
        self.layers = nn.ModuleList([MixedOp(C) for _ in range(num_layers)])
        self.output_layer = nn.Conv2d(C, C, kernel_size=1, stride=1, bias=False)
    
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.output_layer(x)

# Training function with truncated reverse-mode AD for NAS
def train(model, train_loader, arch_optimizer, model_optimizer, criterion, unroll_steps=1):
    model.train()
    for real_in, imag_in, real_out, imag_out in train_loader:
        real_in, imag_in = real_in.unsqueeze(1), imag_in.unsqueeze(1)  # Add channel dimension
        real_out, imag_out = real_out.unsqueeze(1), imag_out.unsqueeze(1)
        real_in, imag_in, real_out, imag_out = real_in.to(device), imag_in.to(device), real_out.to(device), imag_out.to(device)
        
        # Compute loss and update model parameters
        model_optimizer.zero_grad()
        real_pred, imag_pred = model(real_in), model(imag_in)
        loss = criterion(real_pred, real_out) + criterion(imag_pred, imag_out)
        loss.backward()
        model_optimizer.step()
        
        # Architecture update using truncated differentiation
        arch_optimizer.zero_grad()
        with torch.no_grad():
            temp_model = SearchNetwork(d).to(device)
            temp_model.load_state_dict(model.state_dict())
        for _ in range(unroll_steps):
            real_temp_pred, imag_temp_pred = temp_model(real_in), temp_model(imag_in)
            temp_loss = criterion(real_temp_pred, real_out) + criterion(imag_temp_pred, imag_out)
            temp_loss.backward()
        arch_optimizer.step()

# Hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_layers = 3
num_epochs = 20

# Initialize model and optimizers
model = SearchNetwork(d).to(device)
arch_optimizer = optim.Adam(model.parameters(), lr=0.003)
model_optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

# Train model with NAS
for epoch in range(num_epochs):
    train(model, train_loader, arch_optimizer, model_optimizer, criterion)
    print(f"Epoch {epoch+1}/{num_epochs} completed.")

# Extract final architecture based on highest probability operations
def extract_final_architecture(model):
    final_ops = []
    for layer in model.layers:
        best_op_idx = torch.argmax(layer.alphas).item()
        best_op_name = list(OPS.keys())[best_op_idx]
        final_ops.append(best_op_name)
    return final_ops

final_architecture = extract_final_architecture(model)
print("Final Architecture:", final_architecture)

# Save the best model
torch.save(model.state_dict(), "best_model.pth")
print("Neural architecture search completed.")

with the initial 2 convolutional layers that extract features from the input and the decoder module at the end

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Generate synthetic dataset with complex values
num_samples = 10000
n, d = 150, 512  # Dimensions

y_complex = torch.randn(num_samples, n, 1, dtype=torch.cfloat)
psi_complex = torch.randn(num_samples, n, d, dtype=torch.cfloat)
omega_complex = torch.randn(num_samples, n, 1, dtype=torch.cfloat)

# Compute Least Squares Estimate h_LS
h_LS_complex = torch.linalg.pinv(psi_complex) @ (y_complex - omega_complex)  # Shape: (num_samples, d, 1)
h_LS_complex = h_LS_complex.squeeze(-1)  # Shape: (num_samples, d)

# Split into real and imaginary components
h_LS_real = h_LS_complex.real  # Shape: (num_samples, d)
h_LS_imag = h_LS_complex.imag  # Shape: (num_samples, d)

# Prepare dataset
train_dataset = TensorDataset(h_LS_real, h_LS_imag, h_LS_real, h_LS_imag)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Feature Extraction Layers
class FeatureExtractor(nn.Module):
    def __init__(self, input_type):
        super().__init__()
        if input_type == 'vector':
            self.conv1 = nn.Conv1d(1, 16, kernel_size=3, padding=1)
            self.conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1)
        else:
            self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
            self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return x

# NAS Candidate Operations
OPS = {
    'conv_3x3': lambda C: nn.Conv2d(C, C, kernel_size=3, padding=1, bias=False),
    'conv_5x5': lambda C: nn.Conv2d(C, C, kernel_size=5, padding=2, bias=False),
    'identity': lambda C: nn.Identity(),
    'skip_connection': lambda C: nn.Sequential(nn.Conv2d(C, C, 1, bias=False), nn.BatchNorm2d(C)),
    'zero': lambda C: nn.ZeroPad2d(0),
}

# Mixed Operation Layer
class MixedOp(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.ops = nn.ModuleList([op(C) for op in OPS.values()])
        self.alphas = nn.Parameter(torch.randn(len(self.ops)))

    def forward(self, x):
        weights = F.softmax(self.alphas, dim=0)
        return sum(w * op(x) for w, op in zip(weights, self.ops))

# Searchable Neural Network
class SearchNetwork(nn.Module):
    def __init__(self, C, num_layers=3):
        super().__init__()
        self.layers = nn.ModuleList([MixedOp(C) for _ in range(num_layers)])
        self.output_layer = nn.Conv2d(C, C, kernel_size=1, bias=False)
    
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.output_layer(x)

# Decoder Module
class Decoder(nn.Module):
    def __init__(self, input_type):
        super().__init__()
        self.fc = nn.Linear(512, 512)
        
        if input_type == 'vector':
            self.sep_conv1 = nn.Conv1d(32, 16, kernel_size=3, padding=1, groups=16)
            self.sep_conv2 = nn.Conv1d(16, 1, kernel_size=3, padding=1, groups=1)
            self.final_conv = nn.Conv1d(1, 1, kernel_size=3, padding=1)
        else:
            self.sep_conv1 = nn.Conv2d(32, 16, kernel_size=3, padding=1, groups=16)
            self.sep_conv2 = nn.Conv2d(16, 1, kernel_size=3, padding=1, groups=1)
            self.final_conv = nn.Conv2d(1, 1, kernel_size=3, padding=1)
        
    def forward(self, x):
        x = F.relu(self.fc(x))
        x = F.relu(self.sep_conv1(x))
        x = F.relu(self.sep_conv2(x))
        return self.final_conv(x)

# Complete Model
class FullModel(nn.Module):
    def __init__(self, input_type):
        super().__init__()
        self.feature_extractor = FeatureExtractor(input_type)
        self.search_network = SearchNetwork(32)
        self.decoder = Decoder(input_type)
    
    def forward(self, x):
        x = self.feature_extractor(x)
        x = self.search_network(x)
        return self.decoder(x)

# Training Function
def train(model, train_loader, arch_optimizer, model_optimizer, criterion, unroll_steps=1):
    model.train()
    for real_in, imag_in, real_out, imag_out in train_loader:
        real_in, imag_in = real_in.unsqueeze(1), imag_in.unsqueeze(1)  # Add channel dimension
        real_out, imag_out = real_out.unsqueeze(1), imag_out.unsqueeze(1)
        real_in, imag_in, real_out, imag_out = real_in.to(device), imag_in.to(device), real_out.to(device), imag_out.to(device)
        
        model_optimizer.zero_grad()
        real_pred, imag_pred = model(real_in), model(imag_in)
        loss = criterion(real_pred, real_out) + criterion(imag_pred, imag_out)
        loss.backward()
        model_optimizer.step()
        
        arch_optimizer.zero_grad()
        arch_optimizer.step()

# Hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_epochs = 20

# Initialize model
model = FullModel('vector').to(device)
arch_optimizer = optim.Adam(model.parameters(), lr=0.003)
model_optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

# Train model
for epoch in range(num_epochs):
    train(model, train_loader, arch_optimizer, model_optimizer, criterion)
    print(f"Epoch {epoch+1}/{num_epochs} completed.")

torch.save(model.state_dict(), "best_model.pth")
print("Training completed.")

with the DAG denoise cells with 4 nodes in each cell and the sequence of denoise cells have 10 cells.

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Generate synthetic dataset with complex values
num_samples = 10000
n, d = 150, 512  # Dimensions

y_complex = torch.randn(num_samples, n, 1, dtype=torch.cfloat)
psi_complex = torch.randn(num_samples, n, d, dtype=torch.cfloat)
omega_complex = torch.randn(num_samples, n, 1, dtype=torch.cfloat)

# Compute Least Squares Estimate h_LS
h_LS_complex = torch.linalg.pinv(psi_complex) @ (y_complex - omega_complex)
h_LS_complex = h_LS_complex.squeeze(-1)

# Split into real and imaginary components
h_LS_real = h_LS_complex.real
h_LS_imag = h_LS_complex.imag

# Prepare dataset
train_dataset = TensorDataset(h_LS_real, h_LS_imag, h_LS_real, h_LS_imag)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Feature Extraction Layers
class FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return x

# NAS Candidate Operations
OPS = {
    'conv_3x3': lambda C: nn.Conv1d(C, C, kernel_size=3, padding=1, bias=False),
    'conv_5x5': lambda C: nn.Conv1d(C, C, kernel_size=5, padding=2, bias=False),
    'identity': lambda C: nn.Identity(),
    'skip_connection': lambda C: nn.Sequential(nn.Conv1d(C, C, 1, bias=False), nn.BatchNorm1d(C)),
    'zero': lambda C: nn.ZeroPad1d(0),
}

# Denoise Cell
class DenoiseCell(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.ops = nn.ModuleList([op(C) for op in OPS.values()])
        self.alphas = nn.Parameter(torch.randn(len(self.ops), 3))  # 3 edges per cell

    def forward(self, inputs):
        assert len(inputs) == 2, "Each denoise cell must take two inputs."
        node_outputs = [inputs[0], inputs[1], torch.zeros_like(inputs[0]), torch.zeros_like(inputs[0])]
        
        for i in range(3):  # Three edges per cell
            weights = F.softmax(self.alphas[:, i], dim=0)
            node_outputs[i + 1] = sum(w * op(node_outputs[i]) for w, op in zip(weights, self.ops))
        
        return node_outputs[-1]

# Sequence of 10 Denoise Cells
class DenoiseModule(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.cells = nn.ModuleList([DenoiseCell(C) for _ in range(10)])

    def forward(self, x):
        out1, out2 = x, x  # First two cells take same input
        outputs = [out1, out2]
        
        for i in range(10):
            out = self.cells[i]([outputs[-2], outputs[-1]])
            outputs.append(out)
        
        return outputs[-1]

# Decoder Module
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(512, 512)
        self.sep_conv1 = nn.Conv1d(32, 16, kernel_size=3, padding=1, groups=16)
        self.sep_conv2 = nn.Conv1d(16, 1, kernel_size=3, padding=1, groups=1)
        self.final_conv = nn.Conv1d(1, 1, kernel_size=3, padding=1)
    
    def forward(self, x):
        x = F.relu(self.fc(x))
        x = F.relu(self.sep_conv1(x))
        x = F.relu(self.sep_conv2(x))
        return self.final_conv(x)

# Complete Model
class FullModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        self.denoise_module = DenoiseModule(32)
        self.decoder = Decoder()
    
    def forward(self, x):
        x = self.feature_extractor(x)
        x = self.denoise_module(x)
        return self.decoder(x)

# Training Function
def train(model, train_loader, arch_optimizer, model_optimizer, criterion):
    model.train()
    for real_in, imag_in, real_out, imag_out in train_loader:
        real_in, imag_in = real_in.unsqueeze(1), imag_in.unsqueeze(1)  # Add channel dimension
        real_out, imag_out = real_out.unsqueeze(1), imag_out.unsqueeze(1)
        real_in, imag_in, real_out, imag_out = real_in.to(device), imag_in.to(device), real_out.to(device), imag_out.to(device)
        
        model_optimizer.zero_grad()
        real_pred, imag_pred = model(real_in), model(imag_in)
        loss = criterion(real_pred, real_out) + criterion(imag_pred, imag_out)
        loss.backward()
        model_optimizer.step()
        
        arch_optimizer.zero_grad()
        arch_optimizer.step()

# Hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_epochs = 20

# Initialize model
model = FullModel().to(device)
arch_optimizer = optim.Adam(model.parameters(), lr=0.003)
model_optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

# Train model
for epoch in range(num_epochs):
    train(model, train_loader, arch_optimizer, model_optimizer, criterion)
    print(f"Epoch {epoch+1}/{num_epochs} completed.")

torch.save(model.state_dict(), "best_model.pth")
print("Training completed.")

with 6 edges in each denoise cell

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Generate synthetic dataset with complex values
num_samples = 10000
n, d = 150, 512  # Dimensions

y_complex = torch.randn(num_samples, n, 1, dtype=torch.cfloat)
psi_complex = torch.randn(num_samples, n, d, dtype=torch.cfloat)
omega_complex = torch.randn(num_samples, n, 1, dtype=torch.cfloat)

# Compute Least Squares Estimate h_LS
h_LS_complex = torch.linalg.pinv(psi_complex) @ (y_complex - omega_complex)
h_LS_complex = h_LS_complex.squeeze(-1)

# Split into real and imaginary components
h_LS_real = h_LS_complex.real
h_LS_imag = h_LS_complex.imag

# Prepare dataset
train_dataset = TensorDataset(h_LS_real, h_LS_imag, h_LS_real, h_LS_imag)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Feature Extraction Layers
class FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return x

# NAS Candidate Operations
OPS = {
    'conv_3x3': lambda C: nn.Conv1d(C, C, kernel_size=3, padding=1, bias=False),
    'conv_5x5': lambda C: nn.Conv1d(C, C, kernel_size=5, padding=2, bias=False),
    'identity': lambda C: nn.Identity(),
    'skip_connection': lambda C: nn.Sequential(nn.Conv1d(C, C, 1, bias=False), nn.BatchNorm1d(C)),
    'zero': lambda C: nn.ZeroPad1d(0),
}

# Denoise Cell with DAG Structure
class DenoiseCell(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.ops = nn.ModuleList([op(C) for op in OPS.values()])
        self.alphas = nn.Parameter(torch.randn(len(self.ops), 6))  # 6 edges in DAG

    def forward(self, inputs):
        assert len(inputs) == 2, "Each denoise cell must take two inputs."
        node_outputs = [inputs[0], inputs[1], torch.zeros_like(inputs[0]), torch.zeros_like(inputs[0])]
        
        edges = [(0, 2), (0, 3), (0, 4), (2, 3), (2, 4), (3, 4)]
        for edge_idx, (src, dest) in enumerate(edges):
            weights = F.softmax(self.alphas[:, edge_idx], dim=0)
            node_outputs[dest] += sum(w * op(node_outputs[src]) for w, op in zip(weights, self.ops))
        
        return node_outputs[4]  # Output of last node

# Sequence of 10 Denoise Cells
class DenoiseModule(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.cells = nn.ModuleList([DenoiseCell(C) for _ in range(10)])

    def forward(self, x):
        out1, out2 = x, x
        outputs = [out1, out2]
        
        for i in range(10):
            out = self.cells[i]([outputs[-2], outputs[-1]])
            outputs.append(out)
        
        return outputs[-1]

# Decoder Module
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(512, 512)
        self.sep_conv1 = nn.Conv1d(32, 16, kernel_size=3, padding=1, groups=16)
        self.sep_conv2 = nn.Conv1d(16, 1, kernel_size=3, padding=1, groups=1)
        self.final_conv = nn.Conv1d(1, 1, kernel_size=3, padding=1)
    
    def forward(self, x):
        x = F.relu(self.fc(x))
        x = F.relu(self.sep_conv1(x))
        x = F.relu(self.sep_conv2(x))
        return self.final_conv(x)

# Complete Model
class FullModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        self.denoise_module = DenoiseModule(32)
        self.decoder = Decoder()
    
    def forward(self, x):
        x = self.feature_extractor(x)
        x = self.denoise_module(x)
        return self.decoder(x)

# Training Function
def train(model, train_loader, arch_optimizer, model_optimizer, criterion):
    model.train()
    for real_in, imag_in, real_out, imag_out in train_loader:
        real_in, imag_in = real_in.unsqueeze(1), imag_in.unsqueeze(1)
        real_out, imag_out = real_out.unsqueeze(1), imag_out.unsqueeze(1)
        real_in, imag_in, real_out, imag_out = real_in.to(device), imag_in.to(device), real_out.to(device), imag_out.to(device)
        
        model_optimizer.zero_grad()
        real_pred, imag_pred = model(real_in), model(imag_in)
        loss = criterion(real_pred, real_out) + criterion(imag_pred, imag_out)
        loss.backward()
        model_optimizer.step()
        
        arch_optimizer.zero_grad()
        arch_optimizer.step()

# Hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_epochs = 20

# Initialize model
model = FullModel().to(device)
arch_optimizer = optim.Adam(model.parameters(), lr=0.003)
model_optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

# Train model
for epoch in range(num_epochs):
    train(model, train_loader, arch_optimizer, model_optimizer, criterion)
    print(f"Epoch {epoch+1}/{num_epochs} completed.")

torch.save(model.state_dict(), "best_model.pth")
print("Training completed.")

with base 2 denoise cells sorted

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Generate synthetic dataset with complex values
num_samples = 10000
n, d = 150, 512  # Dimensions

y_complex = torch.randn(num_samples, n, 1, dtype=torch.cfloat)
psi_complex = torch.randn(num_samples, n, d, dtype=torch.cfloat)
omega_complex = torch.randn(num_samples, n, 1, dtype=torch.cfloat)

# Compute Least Squares Estimate h_LS
h_LS_complex = torch.linalg.pinv(psi_complex) @ (y_complex - omega_complex)
h_LS_complex = h_LS_complex.squeeze(-1)

# Split into real and imaginary components
h_LS_real = h_LS_complex.real
h_LS_imag = h_LS_complex.imag

# Prepare dataset
train_dataset = TensorDataset(h_LS_real, h_LS_imag, h_LS_real, h_LS_imag)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Feature Extraction Layers
class FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return x

# NAS Candidate Operations
OPS = {
    'conv_3x3': lambda C: nn.Conv1d(C, C, kernel_size=3, padding=1, bias=False),
    'conv_5x5': lambda C: nn.Conv1d(C, C, kernel_size=5, padding=2, bias=False),
    'identity': lambda C: nn.Identity(),
    'skip_connection': lambda C: nn.Sequential(nn.Conv1d(C, C, 1, bias=False), nn.BatchNorm1d(C)),
    'zero': lambda C: nn.ZeroPad1d(0),
}

# Denoise Cell with DAG Structure
class DenoiseCell(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.ops = nn.ModuleList([op(C) for op in OPS.values()])
        self.alphas = nn.Parameter(torch.randn(len(self.ops), 6))  # 6 edges in DAG

    def forward(self, inputs):
        assert len(inputs) == 2, "Each denoise cell must take two inputs."
        node_outputs = [inputs[0], inputs[1], torch.zeros_like(inputs[0]), torch.zeros_like(inputs[0])]
        
        edges = [(0, 2), (0, 3), (0, 4), (2, 3), (2, 4), (3, 4)]
        for edge_idx, (src, dest) in enumerate(edges):
            weights = F.softmax(self.alphas[:, edge_idx], dim=0)
            node_outputs[dest] += sum(w * op(node_outputs[src]) for w, op in zip(weights, self.ops))
        
        return node_outputs[4]  # Output of last node

# Sequence of 10 Denoise Cells
class DenoiseModule(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.cells = nn.ModuleList([DenoiseCell(C) for _ in range(10)])

    def forward(self, x):
        out1, out2 = x, x  # Initial input duplicated for first denoise cell
        outputs = [out1, out2]
        
        for i in range(10):
            if i == 0:
                # First cell: use feature extraction output twice
                out = self.cells[i]([out1, out1])
            elif i == 1:
                # Second cell: use output of first cell and original feature extraction output
                out = self.cells[i]([outputs[1], out1])
            else:
                # Remaining cells: use last 2 outputs
                out = self.cells[i]([outputs[-2], outputs[-1]])
            
            outputs.append(out)
        
        return outputs[-1]

# Decoder Module
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(512, 512)
        self.sep_conv1 = nn.Conv1d(32, 16, kernel_size=3, padding=1, groups=16)
        self.sep_conv2 = nn.Conv1d(16, 1, kernel_size=3, padding=1, groups=1)
        self.final_conv = nn.Conv1d(1, 1, kernel_size=3, padding=1)
    
    def forward(self, x):
        x = F.relu(self.fc(x))
        x = F.relu(self.sep_conv1(x))
        x = F.relu(self.sep_conv2(x))
        return self.final_conv(x)

# Complete Model
class FullModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        self.denoise_module = DenoiseModule(32)
        self.decoder = Decoder()
    
    def forward(self, x):
        x = self.feature_extractor(x)
        x = self.denoise_module(x)
        return self.decoder(x)

# Training Function
def train(model, train_loader, arch_optimizer, model_optimizer, criterion):
    model.train()
    for real_in, imag_in, real_out, imag_out in train_loader:
        real_in, imag_in = real_in.unsqueeze(1), imag_in.unsqueeze(1)
        real_out, imag_out = real_out.unsqueeze(1), imag_out.unsqueeze(1)
        real_in, imag_in, real_out, imag_out = real_in.to(device), imag_in.to(device), real_out.to(device), imag_out.to(device)
        
        model_optimizer.zero_grad()
        real_pred, imag_pred = model(real_in), model(imag_in)
        loss = criterion(real_pred, real_out) + criterion(imag_pred, imag_out)
        loss.backward()
        model_optimizer.step()
        
        arch_optimizer.zero_grad()
        arch_optimizer.step()

# Hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_epochs = 20

# Initialize model
model = FullModel().to(device)
arch_optimizer = optim.Adam(model.parameters(), lr=0.003)
model_optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

# Train model
for epoch in range(num_epochs):
    train(model, train_loader, arch_optimizer, model_optimizer, criterion)
    print(f"Epoch {epoch+1}/{num_epochs} completed.")

torch.save(model.state_dict(), "best_model.pth")
print("Training completed.")

with aggregation done in cases where there at mutiple inputs in the nodes

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Generate synthetic dataset with complex values
num_samples = 10000
n, d = 150, 512  # Dimensions

y_complex = torch.randn(num_samples, n, 1, dtype=torch.cfloat)
psi_complex = torch.randn(num_samples, n, d, dtype=torch.cfloat)
omega_complex = torch.randn(num_samples, n, 1, dtype=torch.cfloat)

# Compute Least Squares Estimate h_LS
h_LS_complex = torch.linalg.pinv(psi_complex) @ (y_complex - omega_complex)
h_LS_complex = h_LS_complex.squeeze(-1)

# Split into real and imaginary components
h_LS_real = h_LS_complex.real
h_LS_imag = h_LS_complex.imag

# Prepare dataset
train_dataset = TensorDataset(h_LS_real, h_LS_imag, h_LS_real, h_LS_imag)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Feature Extraction Layers
class FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return x

# NAS Candidate Operations
OPS = {
    'conv_3x3': lambda C: nn.Conv1d(C, C, kernel_size=3, padding=1, bias=False),
    'conv_5x5': lambda C: nn.Conv1d(C, C, kernel_size=5, padding=2, bias=False),
    'identity': lambda C: nn.Identity(),
    'skip_connection': lambda C: nn.Sequential(nn.Conv1d(C, C, 1, bias=False), nn.BatchNorm1d(C)),
    'zero': lambda C: nn.ZeroPad1d(0),
}

# Denoise Cell with DAG Structure
class DenoiseCell(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.C = C
        self.ops = nn.ModuleList([op(C) for op in OPS.values()])
        self.alphas = nn.Parameter(torch.randn(len(self.ops), 6))  # 6 edges in DAG

        # 1x1 convolution to reduce dimensions after concatenation for node 3
        self.conv1x1_node3 = nn.Conv1d(3 * C, C, kernel_size=1, bias=False)

    def forward(self, inputs):
        assert len(inputs) == 2, "Each denoise cell must take two inputs."
        # Concatenation for nodes 0 and 2
        node0 = torch.cat(inputs, dim=1)
        node2 = torch.cat(inputs, dim=1)
        
        node_outputs = [
            F.relu(node0),  # Node 0: concatenation of 2 inputs
            F.relu(node2),  # Node 1: concatenation of 2 inputs
            torch.zeros_like(inputs[0]),  # Node 2
            torch.zeros_like(inputs[0]),  # Node 3
            torch.zeros_like(inputs[0]),  # Node 4
        ]
        
        edges = [(0, 2), (0, 3), (0, 4), (2, 3), (2, 4), (3, 4)]
        for edge_idx, (src, dest) in enumerate(edges):
            weights = F.softmax(self.alphas[:, edge_idx], dim=0)
            aggregated_output = sum(w * op(node_outputs[src]) for w, op in zip(weights, self.ops))

            if dest == 3:
                # Node 3: Concatenate 3 inputs and reduce dimensions with 1x1 convolution
                node_outputs[dest] = F.relu(self.conv1x1_node3(torch.cat([node_outputs[0], node_outputs[1], aggregated_output], dim=1)))
            elif dest == 2:
                node_outputs[dest] += aggregated_output
            else:
                node_outputs[dest] += aggregated_output
        
        return node_outputs[4]  # Output of last node

# Sequence of 10 Denoise Cells
class DenoiseModule(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.cells = nn.ModuleList([DenoiseCell(C) for _ in range(10)])

    def forward(self, x):
        out1, out2 = x, x  # Initial input duplicated for first denoise cell
        outputs = [out1, out2]
        
        for i in range(10):
            if i == 0:
                # First cell: use feature extraction output twice
                out = self.cells[i]([out1, out1])
            elif i == 1:
                # Second cell: use output of first cell and original feature extraction output
                out = self.cells[i]([outputs[1], out1])
            else:
                # Remaining cells: use last 2 outputs
                out = self.cells[i]([outputs[-2], outputs[-1]])
            
            outputs.append(out)
        
        return outputs[-1]

# Decoder Module
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(512, 512)
        self.sep_conv1 = nn.Conv1d(32, 16, kernel_size=3, padding=1, groups=16)
        self.sep_conv2 = nn.Conv1d(16, 1, kernel_size=3, padding=1, groups=1)
        self.final_conv = nn.Conv1d(1, 1, kernel_size=3, padding=1)
    
    def forward(self, x):
        x = F.relu(self.fc(x))
        x = F.relu(self.sep_conv1(x))
        x = F.relu(self.sep_conv2(x))
        return self.final_conv(x)

# Complete Model
class FullModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        self.denoise_module = DenoiseModule(32)
        self.decoder = Decoder()
    
    def forward(self, x):
        x = self.feature_extractor(x)
        x = self.denoise_module(x)
        return self.decoder(x)

# Training Function
def train(model, train_loader, arch_optimizer, model_optimizer, criterion):
    model.train()
    for real_in, imag_in, real_out, imag_out in train_loader:
        print(f"real_in.shape = {real_in.shape}")
        print(f"imag_in.shape = {imag_in.shape}")
        real_in, imag_in = real_in.unsqueeze(1), imag_in.unsqueeze(1)
        real_out, imag_out = real_out.unsqueeze(1), imag_out.unsqueeze(1)
        real_in, imag_in, real_out, imag_out = real_in.to(device), imag_out.to(device), real_out.to(device), imag_out.to(device)
        
        model_optimizer.zero_grad()
        real_pred = model(real_in)
        imag_pred = model(imag_in)
        loss = criterion(real_pred, real_out) + criterion(imag_pred, imag_out)
        loss.backward()
        model_optimizer.step()
        
        arch_optimizer.zero_grad()
        arch_optimizer.step()

# Hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device = {device}")
num_epochs = 20

# Initialize model
model = FullModel().to(device)
arch_optimizer = optim.Adam(model.parameters(), lr=0.003)
model_optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

# Train model
for epoch in range(num_epochs):
    train(model, train_loader, arch_optimizer, model_optimizer, criterion)
    print(f"Epoch {epoch+1}/{num_epochs} completed.")

torch.save(model.state_dict(), "best_model.pth")
print("Training completed.")

TRIAL

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Generate synthetic dataset with complex values
num_samples = 10000
n, d = 150, 512  # Dimensions

# Generate test samples
test_samples = 2000
train_samples = num_samples - test_samples

# Create full dataset
y_complex = torch.randn(num_samples, n, 1, dtype=torch.cfloat)
psi_complex = torch.randn(num_samples, n, d, dtype=torch.cfloat)
omega_complex = torch.randn(num_samples, n, 1, dtype=torch.cfloat)

# Compute Least Squares Estimate h_LS
h_LS_complex = torch.linalg.pinv(psi_complex) @ (y_complex - omega_complex)
h_LS_complex = h_LS_complex.squeeze(-1)

# Split into real and imaginary components
h_LS_real = h_LS_complex.real
h_LS_imag = h_LS_complex.imag

# Split into training and testing sets
train_real, test_real = h_LS_real[:train_samples], h_LS_real[train_samples:]
train_imag, test_imag = h_LS_imag[:train_samples], h_LS_imag[train_samples:]

# Prepare datasets
train_dataset = TensorDataset(train_real, train_imag, train_real, train_imag)
test_dataset = TensorDataset(test_real, test_imag, test_real, test_imag)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Feature Extraction Layers
class FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return x

# NAS Candidate Operations
OPS = {
    'conv_3x3': lambda C: nn.Conv1d(C, C, kernel_size=3, padding=1, bias=False),
    'conv_5x5': lambda C: nn.Conv1d(C, C, kernel_size=5, padding=2, bias=False),
    'identity': lambda C: nn.Identity(),
    'skip_connection': lambda C: nn.Sequential(nn.Conv1d(C, C, 1, bias=False), nn.BatchNorm1d(C)),
    'zero': lambda C: nn.ZeroPad1d(0),
}

# Denoise Cell with DAG Structure
class DenoiseCell(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.C = C
        self.ops = nn.ModuleList([op(C) for op in OPS.values()])
        self.alphas = nn.Parameter(torch.randn(len(self.ops), 6))
        self.conv1x1_node3 = nn.Conv1d(8 * C, C, kernel_size=1, bias=False)

    def pad_and_concat(self, inputs):
        max_size = max(inp.shape[2] for inp in inputs)
        padded_inputs = [
            F.pad(inp, (0, max_size - inp.shape[2])) if inp.shape[2] < max_size else inp
            for inp in inputs
        ]
        return torch.cat(padded_inputs, dim=1)

    def forward(self, inputs):
        assert len(inputs) == 2, "Each denoise cell must take two inputs."
        
        node_outputs = [None] * 4
        node_outputs[0] = F.relu(self.pad_and_concat(inputs))
        node_outputs[1] = F.relu(node_outputs[0])
        node_outputs[2] = F.relu(self.pad_and_concat([node_outputs[0], node_outputs[1]]))
        concat_features = self.pad_and_concat([node_outputs[0], node_outputs[1], node_outputs[2]])
        node_outputs[3] = F.relu(self.conv1x1_node3(concat_features))
        return node_outputs[3]

# Sequence of 10 Denoise Cells
class DenoiseModule(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.cells = nn.ModuleList([DenoiseCell(C) for _ in range(10)])

    def forward(self, x):
        outputs = [x, x]
        for i in range(10):
            out = self.cells[i]([outputs[0], outputs[0]] if i == 0 else [outputs[-2], outputs[-1]])
            outputs.append(out)
        return outputs[-1]

# Decoder Module
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(32, 512)
        self.sep_conv1 = nn.Conv1d(512, 16, kernel_size=3, padding=1)
        self.sep_conv2 = nn.Conv1d(16, 1, kernel_size=3, padding=1)
        self.final_conv = nn.Conv1d(1, 1, kernel_size=3, padding=1)
    
    def forward(self, x):
        x = x.permute(0, 2, 1)
        x = F.relu(self.fc(x))
        x = x.permute(0, 2, 1)
        x = F.relu(self.sep_conv1(x))
        x = F.relu(self.sep_conv2(x))
        return self.final_conv(x).squeeze(1)

# Complete Model
class FullModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        self.denoise_module = DenoiseModule(32)
        self.decoder = Decoder()
    
    def forward(self, x):
        x = self.feature_extractor(x)
        x = self.denoise_module(x)
        return self.decoder(x)

# Evaluation Function
def evaluate(model, test_loader, criterion):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for real_in, imag_in, real_out, imag_out in test_loader:
            real_in = real_in.unsqueeze(1).to(device)
            imag_in = imag_in.unsqueeze(1).to(device)
            real_out = real_out.to(device)
            imag_out = imag_out.to(device)
            
            real_pred = model(real_in)
            imag_pred = model(imag_in)
            
            loss = criterion(real_pred, real_out) + criterion(imag_pred, imag_out)
            total_loss += loss.item()
    
    avg_loss = total_loss / len(test_loader)
    print(f"\nTest Loss: {avg_loss:.4f}")
    return avg_loss

# Training Function
def train(model, train_loader, test_loader, arch_optimizer, model_optimizer, criterion):
    best_loss = float('inf')
    for epoch in range(num_epochs):
        model.train()
        for real_in, imag_in, real_out, imag_out in train_loader:
            real_in = real_in.unsqueeze(1).to(device)
            imag_in = imag_in.unsqueeze(1).to(device)
            real_out = real_out.to(device)
            imag_out = imag_out.to(device)
            
            model_optimizer.zero_grad()
            
            real_pred = model(real_in)
            imag_pred = model(imag_in)
            
            loss = criterion(real_pred, real_out) + criterion(imag_pred, imag_out)
            loss.backward()
            model_optimizer.step()
            arch_optimizer.step()
            arch_optimizer.zero_grad()
        
        # Evaluate after each epoch
        current_loss = evaluate(model, test_loader, criterion)
        if current_loss < best_loss:
            best_loss = current_loss
            torch.save(model.state_dict(), "best_model.pth")
            print("Saved new best model")
        
        print(f"Epoch {epoch+1}/{num_epochs} completed.")

# Hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device = {device}")
num_epochs = 20

# Initialize model
model = FullModel().to(device)
arch_optimizer = optim.Adam(model.parameters(), lr=0.003)
model_optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

# Train and evaluate
train(model, train_loader, test_loader, arch_optimizer, model_optimizer, criterion)

# Load best model and final evaluation
model.load_state_dict(torch.load("best_model.pth"))
print("\nFinal Evaluation with Best Model:")
final_loss = evaluate(model, test_loader, criterion)
print(f"Best Model Test Loss: {final_loss:.4f}")

print("Training and Evaluation completed.")

WITH EARLY STOPPING

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Generate synthetic dataset with complex values
num_samples = 10000
n, d = 150, 512  # Dimensions

# Generate test samples
test_samples = 2000
train_samples = num_samples - test_samples

# Create full dataset
y_complex = torch.randn(num_samples, n, 1, dtype=torch.cfloat)
psi_complex = torch.randn(num_samples, n, d, dtype=torch.cfloat)
omega_complex = torch.randn(num_samples, n, 1, dtype=torch.cfloat)

# Compute Least Squares Estimate h_LS
h_LS_complex = torch.linalg.pinv(psi_complex) @ (y_complex - omega_complex)
h_LS_complex = h_LS_complex.squeeze(-1)

# Split into real and imaginary components
h_LS_real = h_LS_complex.real
h_LS_imag = h_LS_complex.imag

# Split into training and testing sets
train_real, test_real = h_LS_real[:train_samples], h_LS_real[train_samples:]
train_imag, test_imag = h_LS_imag[:train_samples], h_LS_imag[train_samples:]

# Prepare datasets
train_dataset = TensorDataset(train_real, train_imag, train_real, train_imag)
test_dataset = TensorDataset(test_real, test_imag, test_real, test_imag)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Feature Extraction Layers
class FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return x

# NAS Candidate Operations
OPS = {
    'conv_3x3': lambda C: nn.Conv1d(C, C, kernel_size=3, padding=1, bias=False),
    'conv_5x5': lambda C: nn.Conv1d(C, C, kernel_size=5, padding=2, bias=False),
    'identity': lambda C: nn.Identity(),
    'skip_connection': lambda C: nn.Sequential(nn.Conv1d(C, C, 1, bias=False), nn.BatchNorm1d(C)),
    'zero': lambda C: nn.ZeroPad1d(0),
}

# Denoise Cell with DAG Structure
class DenoiseCell(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.C = C
        self.ops = nn.ModuleList([op(C) for op in OPS.values()])
        self.alphas = nn.Parameter(torch.randn(len(self.ops), 6))
        self.conv1x1_node3 = nn.Conv1d(8 * C, C, kernel_size=1, bias=False)

    def pad_and_concat(self, inputs):
        max_size = max(inp.shape[2] for inp in inputs)
        padded_inputs = [
            F.pad(inp, (0, max_size - inp.shape[2])) if inp.shape[2] < max_size else inp
            for inp in inputs
        ]
        return torch.cat(padded_inputs, dim=1)

    def forward(self, inputs):
        assert len(inputs) == 2, "Each denoise cell must take two inputs."
        
        node_outputs = [None] * 4
        node_outputs[0] = F.relu(self.pad_and_concat(inputs))
        node_outputs[1] = F.relu(node_outputs[0])
        node_outputs[2] = F.relu(self.pad_and_concat([node_outputs[0], node_outputs[1]]))
        concat_features = self.pad_and_concat([node_outputs[0], node_outputs[1], node_outputs[2]])
        node_outputs[3] = F.relu(self.conv1x1_node3(concat_features))
        return node_outputs[3]

# Sequence of 10 Denoise Cells
class DenoiseModule(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.cells = nn.ModuleList([DenoiseCell(C) for _ in range(10)])

    def forward(self, x):
        outputs = [x, x]
        for i in range(10):
            out = self.cells[i]([outputs[0], outputs[0]] if i == 0 else [outputs[-2], outputs[-1]])
            outputs.append(out)
        return outputs[-1]

# Decoder Module
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(32, 512)
        self.sep_conv1 = nn.Conv1d(512, 16, kernel_size=3, padding=1)
        self.sep_conv2 = nn.Conv1d(16, 1, kernel_size=3, padding=1)
        self.final_conv = nn.Conv1d(1, 1, kernel_size=3, padding=1)
    
    def forward(self, x):
        x = x.permute(0, 2, 1)
        x = F.relu(self.fc(x))
        x = x.permute(0, 2, 1)
        x = F.relu(self.sep_conv1(x))
        x = F.relu(self.sep_conv2(x))
        return self.final_conv(x).squeeze(1)

# Complete Model
class FullModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        self.denoise_module = DenoiseModule(32)
        self.decoder = Decoder()
    
    def forward(self, x):
        x = self.feature_extractor(x)
        x = self.denoise_module(x)
        return self.decoder(x)

# Evaluation Function
def evaluate(model, test_loader, criterion):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for real_in, imag_in, real_out, imag_out in test_loader:
            real_in = real_in.unsqueeze(1).to(device)
            imag_in = imag_in.unsqueeze(1).to(device)
            real_out = real_out.to(device)
            imag_out = imag_out.to(device)
            
            real_pred = model(real_in)
            imag_pred = model(imag_in)
            
            loss = criterion(real_pred, real_out) + criterion(imag_pred, imag_out)
            total_loss += loss.item()
    
    avg_loss = total_loss / len(test_loader)
    return avg_loss

# Training Function with Early Stopping
def train(model, train_loader, test_loader, arch_optimizer, model_optimizer, criterion):
    best_loss = float('inf')
    epochs_no_improve = 0
    patience = 10  # Number of epochs to wait after last improvement
    
    for epoch in range(num_epochs):
        model.train()
        for real_in, imag_in, real_out, imag_out in train_loader:
            real_in = real_in.unsqueeze(1).to(device)
            imag_in = imag_in.unsqueeze(1).to(device)
            real_out = real_out.to(device)
            imag_out = imag_out.to(device)
            
            model_optimizer.zero_grad()
            
            real_pred = model(real_in)
            imag_pred = model(imag_in)
            
            loss = criterion(real_pred, real_out) + criterion(imag_pred, imag_out)
            loss.backward()
            model_optimizer.step()
            arch_optimizer.step()
            arch_optimizer.zero_grad()
        
        # Evaluate after each epoch
        current_loss = evaluate(model, test_loader, criterion)
        print(f"\nEpoch {epoch+1}/{num_epochs} - Test Loss: {current_loss:.4f}")
        
        # Early stopping check
        if current_loss < best_loss:
            best_loss = current_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), "best_model.pth")
            print("Saved new best model")
        else:
            epochs_no_improve += 1
            print(f"No improvement for {epochs_no_improve}/{patience} epochs")
            if epochs_no_improve >= patience:
                print(f"\nEarly stopping triggered after {epoch+1} epochs!")
                break

# Hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device = {device}")
num_epochs = 1000

# Initialize model
model = FullModel().to(device)
arch_optimizer = optim.Adam(model.parameters(), lr=0.003)
model_optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

# Train and evaluate
train(model, train_loader, test_loader, arch_optimizer, model_optimizer, criterion)

# Load best model and final evaluation
model.load_state_dict(torch.load("best_model.pth"))
print("\nFinal Evaluation with Best Model:")
final_loss = evaluate(model, test_loader, criterion)
print(f"Best Model Test Loss: {final_loss:.4f}")

print("Training and Evaluation completed.")

ACTUAL DATA INSTEAD OF SYNTHETIC DATA

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import h5py
from scipy.io import loadmat

# ============ GPU Setup ============
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ============ Load .mat Files with Fallback ============
def load_mat_file(file_path, var_name):
    try:
        with h5py.File(file_path, 'r') as f:
            data = f[var_name][:]
        print(f"Loaded {file_path} using h5py.")
    except OSError:
        print(f"Failed to load {file_path} with h5py. Trying scipy.io.loadmat...")
        data = loadmat(file_path)[var_name]
        print(f"Loaded {file_path} using scipy.io.loadmat.")
    return data

# Load yDL
y_dl = load_mat_file('yDL_10dB_40k_150pilots_ipjp.mat', 'yDL')
y_complex = torch.tensor(y_dl[..., 0] + 1j * y_dl[..., 1], dtype=torch.cfloat)

# Load PsiDL
psi_dl = load_mat_file('PsiDL_10dB_40k_150pilots_ipjp.mat', 'PsiDL')
psi_real = psi_dl[0].transpose(2, 1, 0)
psi_imag = psi_dl[1].transpose(2, 1, 0)
psi_complex = torch.tensor(psi_real + 1j * psi_imag, dtype=torch.cfloat)

# Load hDL
h_dl = load_mat_file('hDL_10dB_40k_150pilots_ipjp.mat', 'hDL')
h_complex = torch.tensor(h_dl[..., 0] + 1j * h_dl[..., 1], dtype=torch.cfloat, device=device)

# Load sigma2DL
sigma2_dl = load_mat_file('sigma2DL_10dB_40k_150pilots_ipjp.mat', 'sigma2DL')
sigma2 = torch.tensor(sigma2_dl, dtype=torch.float32, device=device)

# ============ Compute LS Estimate in Batches ============
num_samples = 40000
batch_size = 1000
h_LS_complex = torch.empty((num_samples, 512), dtype=torch.cfloat)

for start in range(0, num_samples, batch_size):
    end = min(start + batch_size, num_samples)
    print(f"Processing batch {start//batch_size + 1}/{num_samples//batch_size + 1}")
    psi_batch = psi_complex[start:end].to(device)
    y_batch = y_complex[start:end].to(device)
    h_LS_batch = torch.linalg.lstsq(psi_batch, y_batch.unsqueeze(-1)).solution.squeeze(-1)
    h_LS_complex[start:end] = h_LS_batch.cpu()

# Split data into real/imaginary parts
h_LS_real, h_LS_imag = h_LS_complex.real, h_LS_complex.imag
h_real, h_imag = h_complex.real.cpu(), h_complex.imag.cpu()

# ============ Dataset Preparation ============
train_samples, val_samples, test_samples = 32000, 4000, 4000

train_dataset = TensorDataset(h_LS_real[:train_samples], h_LS_imag[:train_samples],
                              h_real[:train_samples], h_imag[:train_samples])
val_dataset = TensorDataset(h_LS_real[train_samples:train_samples+val_samples],
                            h_LS_imag[train_samples:train_samples+val_samples],
                            h_real[train_samples:train_samples+val_samples],
                            h_imag[train_samples:train_samples+val_samples])
test_dataset = TensorDataset(h_LS_real[train_samples+val_samples:],
                             h_LS_imag[train_samples+val_samples:],
                             h_real[train_samples+val_samples:],
                             h_imag[train_samples+val_samples:])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# ============ Neural Architecture Components ============
class FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(2, 16, 3, padding=1)
        self.conv2 = nn.Conv1d(16, 32, 3, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return x

OPS = {
    'conv_3x3': lambda C: nn.Conv1d(C, C, 3, padding=1, bias=False),
    'conv_5x5': lambda C: nn.Conv1d(C, C, 5, padding=2, bias=False),
    'identity': lambda C: nn.Identity(),
    'skip_connection': lambda C: nn.Sequential(
        nn.Conv1d(C, C, 1, bias=False), nn.BatchNorm1d(C)),
    'zero': lambda C: nn.ZeroPad1d(0)
}

class DenoiseCell(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.C = C
        self.num_edges = 8  # Correct number of edges for the DAG
        self.num_ops = len(OPS)
        
        # Architecture parameters
        self.alphas = nn.Parameter(torch.randn(self.num_edges, self.num_ops))
        
        # Operations for each edge
        self.ops = nn.ModuleList()
        for _ in range(self.num_edges):
            self.ops.append(nn.ModuleList([op(C) for op in OPS.values()]))
            
        self.conv1x1 = nn.Conv1d(3*C, C, 1, bias=False)

    def apply_ops(self, x, edge_idx):
        weights = F.softmax(self.alphas[edge_idx], dim=-1)
        return sum(w * op(x) for w, op in zip(weights, self.ops[edge_idx]))

    def forward(self, inputs):
        # Input processing
        in0, in1 = inputs
        
        # Edge 0: input0 -> node0
        # Edge 1: input1 -> node0
        node0 = F.relu(self.apply_ops(in0, 0) + self.apply_ops(in1, 1))
        
        # Edge 2: node0 -> node1
        node1 = F.relu(self.apply_ops(node0, 2))
        
        # Edge 3: node0 -> node2
        # Edge 4: node1 -> node2
        node2 = F.relu(self.apply_ops(node0, 3) + self.apply_ops(node1, 4))
        
        # Edge 5: node0 -> node3
        # Edge 6: node1 -> node3
        # Edge 7: node2 -> node3
        out = torch.cat([
            self.apply_ops(node0, 5),
            self.apply_ops(node1, 6),
            self.apply_ops(node2, 7)
        ], dim=1)
        
        return F.relu(self.conv1x1(out))

class DenoiseModule(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.cells = nn.ModuleList([DenoiseCell(C) for _ in range(10)])

    def forward(self, x):
        states = [x, x]
        for cell in self.cells:
            states.append(cell([states[-2], states[-1]]))
        return states[-1]

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(32, 16, 3, padding=1)
        self.conv2 = nn.Conv1d(16, 2, 3, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return self.conv2(x)

class FullModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        self.denoiser = DenoiseModule(32)
        self.decoder = Decoder()

    def forward(self, x):
        x = self.feature_extractor(x)
        x = self.denoiser(x)
        return self.decoder(x)

# ============ Training Infrastructure ============
def train(model, train_loader, val_loader, optimizer, criterion, epochs=20):
    best_loss = float('inf')
    for epoch in range(epochs):
        model.train()
        for real_in, imag_in, real_tar, imag_tar in train_loader:
            inputs = torch.stack([real_in, imag_in], 1).to(device)
            targets = torch.stack([real_tar, imag_tar], 1).to(device)
            
            optimizer.zero_grad()
            preds = model(inputs)
            loss = criterion(preds[:,0], targets[:,0]) + criterion(preds[:,1], targets[:,1])
            loss.backward()
            optimizer.step()
        
        val_loss = evaluate(model, val_loader, criterion)
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')
            print(f'Epoch {epoch+1}: New best validation loss {val_loss:.4f}')
        else:
            print(f'Epoch {epoch+1}: Validation loss {val_loss:.4f}')
    return best_loss

def evaluate(model, loader, criterion):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for real_in, imag_in, real_tar, imag_tar in loader:
            inputs = torch.stack([real_in, imag_in], 1).to(device)
            targets = torch.stack([real_tar, imag_tar], 1).to(device)
            preds = model(inputs)
            loss = criterion(preds[:,0], targets[:,0]) + criterion(preds[:,1], targets[:,1])
            total_loss += loss.item()
    return total_loss / len(loader)

# ============ Main Execution ============
if __name__ == "__main__":
    model = FullModel().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.003)
    criterion = nn.MSELoss()

    print("Starting training...")
    best_val = train(model, train_loader, val_loader, optimizer, criterion, 20)
    print(f"Best validation loss: {best_val:.4f}")

    model.load_state_dict(torch.load('best_model.pth'))
    test_loss = evaluate(model, test_loader, criterion)
    print(f"Final test loss: {test_loss:.4f}")

To know what is inside yDL_10dB_40k_150pilots_ipjp.mat file.

In [None]:
import scipy.io

def print_mat_contents(file_path):
    """
    Load and print the contents of a .mat file.
    
    :param file_path: Path to the .mat file
    """
    mat_data = scipy.io.loadmat(file_path)
    
    print(f"Contents of {file_path}:\n")
    for key, value in mat_data.items():
        if not key.startswith("__"):  # Ignore metadata entries
            print(f"{key}: {value.shape} \n{value}\n")

# Example usage
file_path = "yDL_10dB_40k_150pilots_ipjp.mat"  # Replace with the actual path
print_mat_contents(file_path)

To know what is inside hDL_10dB_40k_150pilots_ipjp.mat file.

In [None]:
import scipy.io

def print_mat_contents(file_path):
    """
    Load and print the contents of a .mat file.
    
    :param file_path: Path to the .mat file
    """
    mat_data = scipy.io.loadmat(file_path)
    
    print(f"Contents of {file_path}:\n")
    for key, value in mat_data.items():
        if not key.startswith("__"):  # Ignore metadata entries
            print(f"{key}: {value.shape} \n{value}\n")

# Example usage
file_path = "hDL_10dB_40k_150pilots_ipjp.mat"  # Replace with the actual path
print_mat_contents(file_path)

To know what is PsiDL_10dB_40k_150pilots_ipjp.mat file.

In [None]:
import h5py
import numpy as np

def print_mat_contents(file_path):
    """
    Load and print the contents of a .mat file (MATLAB v7.3 HDF5 format).
    
    :param file_path: Path to the .mat file
    """
    with h5py.File(file_path, 'r') as mat_data:
        print(f"Contents of {file_path}:\n")
        
        def print_dataset(name, obj):
            if isinstance(obj, h5py.Dataset):  # If it's a dataset, print its details
                data = obj[()]  # Load the dataset
                if isinstance(data, np.ndarray):
                    print(f"{name}: shape {data.shape}, dtype {data.dtype} \n{data}\n")
                else:
                    print(f"{name}: {data}\n")
            elif isinstance(obj, h5py.Group):  # If it's a group, just print its name
                print(f"{name}: Group")

        mat_data.visititems(print_dataset)

# Example usage
file_path = "PsiDL_10dB_40k_150pilots_ipjp.mat"  # Replace with the actual path
print_mat_contents(file_path)

To know what is sigma2DL_10dB_40k_150pilots_ipjp.mat file.

In [None]:
import scipy.io

def print_mat_contents(file_path):
    """
    Load and print the contents of a .mat file.
    
    :param file_path: Path to the .mat file
    """
    mat_data = scipy.io.loadmat(file_path)
    
    print(f"Contents of {file_path}:\n")
    for key, value in mat_data.items():
        if not key.startswith("__"):  # Ignore metadata entries
            print(f"{key}: {value.shape} \n{value}\n")

# Example usage
file_path = "sigma2DL_10dB_40k_150pilots_ipjp.mat"  # Replace with the actual path
print_mat_contents(file_path)

Best till now

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import h5py
from scipy.io import loadmat

# ============ GPU Setup ============
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ============ Load .mat Files with Fallback ============
def load_mat_file(file_path, var_name):
    try:
        with h5py.File(file_path, 'r') as f:
            data = f[var_name][:]
        print(f"Loaded {file_path} using h5py.")
    except OSError:
        print(f"Failed to load {file_path} with h5py. Trying scipy.io.loadmat...")
        data = loadmat(file_path)[var_name]
        print(f"Loaded {file_path} using scipy.io.loadmat.")
    return data

# Load yDL
y_dl = load_mat_file('yDL_10dB_40k_150pilots_ipjp.mat', 'yDL')
y_complex = torch.tensor(y_dl[..., 0] + 1j * y_dl[..., 1], dtype=torch.cfloat)

# Load PsiDL
psi_dl = load_mat_file('PsiDL_10dB_40k_150pilots_ipjp.mat', 'PsiDL')
psi_real = psi_dl[0].transpose(2, 1, 0)
psi_imag = psi_dl[1].transpose(2, 1, 0)
psi_complex = torch.tensor(psi_real + 1j * psi_imag, dtype=torch.cfloat)

# Load hDL
h_dl = load_mat_file('hDL_10dB_40k_150pilots_ipjp.mat', 'hDL')
h_complex = torch.tensor(h_dl[..., 0] + 1j * h_dl[..., 1], dtype=torch.cfloat, device=device)

# Load sigma2DL
sigma2_dl = load_mat_file('sigma2DL_10dB_40k_150pilots_ipjp.mat', 'sigma2DL')
sigma2 = torch.tensor(sigma2_dl, dtype=torch.float32, device=device)

# ============ Compute LS Estimate in Batches ============
num_samples = 40000
batch_size = 1000
h_LS_complex = torch.empty((num_samples, 512), dtype=torch.cfloat)

for start in range(0, num_samples, batch_size):
    end = min(start + batch_size, num_samples)
    print(f"Processing batch {start//batch_size + 1}/{num_samples//batch_size + 1}")
    psi_batch = psi_complex[start:end].to(device)
    y_batch = y_complex[start:end].to(device)
    h_LS_batch = torch.linalg.lstsq(psi_batch, y_batch.unsqueeze(-1)).solution.squeeze(-1)
    h_LS_complex[start:end] = h_LS_batch.cpu()

# Split data into real/imaginary parts
h_LS_real, h_LS_imag = h_LS_complex.real, h_LS_complex.imag
h_real, h_imag = h_complex.real.cpu(), h_complex.imag.cpu()

# ============ Dataset Preparation ============
train_samples, val_samples, test_samples = 32000, 4000, 4000

train_dataset = TensorDataset(h_LS_real[:train_samples], h_LS_imag[:train_samples],
                              h_real[:train_samples], h_imag[:train_samples])
val_dataset = TensorDataset(h_LS_real[train_samples:train_samples+val_samples],
                            h_LS_imag[train_samples:train_samples+val_samples],
                            h_real[train_samples:train_samples+val_samples],
                            h_imag[train_samples:train_samples+val_samples])
test_dataset = TensorDataset(h_LS_real[train_samples+val_samples:],
                             h_LS_imag[train_samples+val_samples:],
                             h_real[train_samples+val_samples:],
                             h_imag[train_samples+val_samples:])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# ============ Neural Architecture Components ============
class FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(2, 16, 3, padding=1)
        self.conv2 = nn.Conv1d(16, 32, 3, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return x

OPS = {
    'conv_3x3': lambda C: nn.Conv1d(C, C, 3, padding=1, bias=False),
    'conv_5x5': lambda C: nn.Conv1d(C, C, 5, padding=2, bias=False),
    'identity': lambda C: nn.Identity(),
    'skip_connection': lambda C: nn.Sequential(
        nn.Conv1d(C, C, 1, bias=False), nn.BatchNorm1d(C)),
    'zero': lambda C: nn.ZeroPad1d(0)
}

class DenoiseCell(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.C = C
        self.num_edges = 8
        self.num_ops = len(OPS)
        
        # Architecture parameters
        self.alphas = nn.Parameter(torch.randn(self.num_edges, self.num_ops))
        
        # Operations for each edge
        self.ops = nn.ModuleList()
        for _ in range(self.num_edges):
            self.ops.append(nn.ModuleList([op(C) for op in OPS.values()]))
            
        self.conv1x1 = nn.Conv1d(3*C, C, 1, bias=False)

    def apply_ops(self, x, edge_idx):
        weights = F.softmax(self.alphas[edge_idx], dim=-1)
        return sum(w * op(x) for w, op in zip(weights, self.ops[edge_idx]))

    def pad_and_concat(self, inputs):
        max_size = max(inp.shape[2] for inp in inputs)
        padded_inputs = [F.pad(inp, (0, max_size - inp.shape[2])) if inp.shape[2] < max_size else inp for inp in inputs]
        return torch.cat(padded_inputs, dim=1)

    def forward(self, inputs):
        # Input processing with padding and concatenation
        in0, in1 = inputs
        padded_inputs = self.pad_and_concat([in0, in1])
        
        # Edge 0: input0 -> node0
        # Edge 1: input1 -> node0
        node0 = F.relu(self.apply_ops(in0, 0) + self.apply_ops(in1, 1))
        
        # Edge 2: node0 -> node1
        node1 = F.relu(self.apply_ops(node0, 2))
        
        # Edge 3: node0 -> node2
        # Edge 4: node1 -> node2
        node2 = F.relu(self.apply_ops(node0, 3) + self.apply_ops(node1, 4))
        
        # Edge 5: node0 -> node3
        # Edge 6: node1 -> node3
        # Edge 7: node2 -> node3
        out = torch.cat([
            self.apply_ops(node0, 5),
            self.apply_ops(node1, 6),
            self.apply_ops(node2, 7)
        ], dim=1)
        
        return F.relu(self.conv1x1(out))

class DenoiseModule(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.cells = nn.ModuleList([DenoiseCell(C) for _ in range(10)])
        self.base_feature = None  # Store initial convolutional features

    def forward(self, x):
        # Store initial features
        self.base_feature = x
        
        # First cell: use base feature twice
        out0 = self.cells[0]([x, x])
        
        # Second cell: use first cell output and base feature
        out1 = self.cells[1]([out0, self.base_feature])
        
        # Subsequent cells: use previous two outputs
        prev_outs = [out0, out1]
        for i in range(2, 10):
            new_out = self.cells[i]([prev_outs[-2], prev_outs[-1]])
            prev_outs.append(new_out)
            # Maintain only last two outputs
            prev_outs = prev_outs[1:] if len(prev_outs) > 2 else prev_outs
        
        return prev_outs[-1]

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(32, 16, 3, padding=1)
        self.conv2 = nn.Conv1d(16, 2, 3, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return self.conv2(x)

class FullModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        self.denoiser = DenoiseModule(32)
        self.decoder = Decoder()

    def forward(self, x):
        x = self.feature_extractor(x)
        x = self.denoiser(x)
        return self.decoder(x)

# ============ Training Infrastructure ============
def train(model, train_loader, val_loader, optimizer, criterion, epochs=20):
    best_loss = float('inf')
    for epoch in range(epochs):
        model.train()
        for real_in, imag_in, real_tar, imag_tar in train_loader:
            inputs = torch.stack([real_in, imag_in], 1).to(device)
            targets = torch.stack([real_tar, imag_tar], 1).to(device)
            
            optimizer.zero_grad()
            preds = model(inputs)
            loss = criterion(preds[:,0], targets[:,0]) + criterion(preds[:,1], targets[:,1])
            loss.backward()
            optimizer.step()
        
        val_loss = evaluate(model, val_loader, criterion)
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')
            print(f'Epoch {epoch+1}: New best validation loss {val_loss:.4f}')
        else:
            print(f'Epoch {epoch+1}: Validation loss {val_loss:.4f}')
    return best_loss

def evaluate(model, loader, criterion):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for real_in, imag_in, real_tar, imag_tar in loader:
            inputs = torch.stack([real_in, imag_in], 1).to(device)
            targets = torch.stack([real_tar, imag_tar], 1).to(device)
            preds = model(inputs)
            loss = criterion(preds[:,0], targets[:,0]) + criterion(preds[:,1], targets[:,1])
            total_loss += loss.item()
    return total_loss / len(loader)

# ============ Main Execution ============
if __name__ == "__main__":
    model = FullModel().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.003)
    criterion = nn.MSELoss()

    print("Starting training...")
    best_val = train(model, train_loader, val_loader, optimizer, criterion, 20)
    print(f"Best validation loss: {best_val:.4f}")

    model.load_state_dict(torch.load('best_model.pth'))
    test_loss = evaluate(model, test_loader, criterion)
    print(f"Final test loss: {test_loss:.4f}")

Corrected least square estimate
Truncated reverse-mode automatic differentiation applied (actually it is not happening correctly)

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import h5py
from scipy.io import loadmat

# ============ GPU Setup ============
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ============ Load .mat Files with Fallback ============
def load_mat_file(file_path, var_name):
    try:
        with h5py.File(file_path, 'r') as f:
            data = f[var_name][:]
        print(f"Loaded {file_path} using h5py.")
    except OSError:
        print(f"Failed to load {file_path} with h5py. Trying scipy.io.loadmat...")
        data = loadmat(file_path)[var_name]
        print(f"Loaded {file_path} using scipy.io.loadmat.")
    return data

# Load datasets
y_dl = load_mat_file('yDL_10dB_40k_150pilots_ipjp.mat', 'yDL')
psi_dl = load_mat_file('PsiDL_10dB_40k_150pilots_ipjp.mat', 'PsiDL')
h_dl = load_mat_file('hDL_10dB_40k_150pilots_ipjp.mat', 'hDL')
sigma2_dl = load_mat_file('sigma2DL_10dB_40k_150pilots_ipjp.mat', 'sigma2DL')

# ============ Data Preparation ============
# Convert to complex tensors
y_complex = torch.tensor(y_dl[..., 0] + 1j * y_dl[..., 1], dtype=torch.cfloat)
psi_complex = torch.tensor(psi_dl[0].transpose(2, 1, 0) + 1j * psi_dl[1].transpose(2, 1, 0), dtype=torch.cfloat)
h_complex = torch.tensor(h_dl[..., 0] + 1j * h_dl[..., 1], dtype=torch.cfloat, device=device)
sigma2 = torch.tensor(sigma2_dl, dtype=torch.float32, device=device)

# ============ Least Squares Calculation ============
num_samples = 40000
batch_size = 1000
h_LS_complex = torch.empty((num_samples, 512), dtype=torch.cfloat)

for start in range(0, num_samples, batch_size):
    end = min(start + batch_size, num_samples)
    print(f"Processing batch {start//batch_size + 1}/{(num_samples//batch_size)+1}")
    
    psi_batch = psi_complex[start:end].to(device)
    y_batch = y_complex[start:end].to(device)
    
    # Compute LS estimate using normal equations
    psi_H = torch.conj(psi_batch.transpose(-2, -1))
    A = torch.matmul(psi_H, psi_batch)
    B = torch.matmul(psi_H, y_batch.unsqueeze(-1))
    h_LS_batch = torch.linalg.solve(A, B).squeeze(-1)
    
    h_LS_complex[start:end] = h_LS_batch.cpu()

# Split into real/imaginary components
h_LS_real, h_LS_imag = h_LS_complex.real, h_LS_complex.imag
h_real, h_imag = h_complex.real.cpu(), h_complex.imag.cpu()

# ============ Dataset Preparation ============
train_samples, val_samples, test_samples = 32000, 4000, 4000

train_dataset = TensorDataset(h_LS_real[:train_samples], h_LS_imag[:train_samples],
                              h_real[:train_samples], h_imag[:train_samples])
val_dataset = TensorDataset(h_LS_real[train_samples:train_samples+val_samples],
                            h_LS_imag[train_samples:train_samples+val_samples],
                            h_real[train_samples:train_samples+val_samples],
                            h_imag[train_samples:train_samples+val_samples])
test_dataset = TensorDataset(h_LS_real[train_samples+val_samples:],
                             h_LS_imag[train_samples+val_samples:],
                             h_real[train_samples+val_samples:],
                             h_imag[train_samples+val_samples:])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# ============ Neural Architecture Components ============
class FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(2, 16, 3, padding=1)
        self.conv2 = nn.Conv1d(16, 32, 3, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

OPS = {
    'conv_3x3': lambda C: nn.Conv1d(C, C, 3, padding=1, bias=False),
    'conv_5x5': lambda C: nn.Conv1d(C, C, 5, padding=2, bias=False),
    'identity': lambda C: nn.Identity(),
    'skip_connection': lambda C: nn.Sequential(
        nn.Conv1d(C, C, 1, bias=False), nn.BatchNorm1d(C)),
    'zero': lambda C: nn.ZeroPad1d(0)
}

class DenoiseCell(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.C = C
        self.num_edges = 8
        self.num_ops = len(OPS)
        
        self.alphas = nn.Parameter(torch.randn(self.num_edges, self.num_ops))
        self.ops = nn.ModuleList([nn.ModuleList([op(C) for op in OPS.values()]) 
                                for _ in range(self.num_edges)])
        self.conv1x1 = nn.Conv1d(3*C, C, 1, bias=False)

    def apply_ops(self, x, edge_idx):
        weights = F.softmax(self.alphas[edge_idx], dim=-1)
        return sum(w * op(x) for w, op in zip(weights, self.ops[edge_idx]))

    def forward(self, inputs):
        in0, in1 = inputs
        node0 = F.relu(self.apply_ops(in0, 0) + self.apply_ops(in1, 1))
        node1 = F.relu(self.apply_ops(node0, 2))
        node2 = F.relu(self.apply_ops(node0, 3) + self.apply_ops(node1, 4))
        out = torch.cat([self.apply_ops(node0, 5),
                        self.apply_ops(node1, 6),
                        self.apply_ops(node2, 7)], dim=1)
        return F.relu(self.conv1x1(out))

class DenoiseModule(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.cells = nn.ModuleList([DenoiseCell(C) for _ in range(10)])
        self.base_feature = None

    def forward(self, x):
        self.base_feature = x
        out0 = self.cells[0]([x, x])
        out1 = self.cells[1]([out0, self.base_feature])
        prev_outs = [out0, out1]
        for i in range(2, 10):
            new_out = self.cells[i]([prev_outs[-2], prev_outs[-1]])
            prev_outs = [prev_outs[-1], new_out]
        return new_out

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(32, 16, 3, padding=1)
        self.conv2 = nn.Conv1d(16, 2, 3, padding=1)

    def forward(self, x):
        return self.conv2(F.relu(self.conv1(x)))

class FullModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        self.denoiser = DenoiseModule(32)
        self.decoder = Decoder()

    def forward(self, x):
        x = self.feature_extractor(x)
        return self.decoder(self.denoiser(x))

# ============ Training Infrastructure ============
def train_model(model, train_loader, optimizer, criterion, trunc_steps=50):
    model.train()
    total_loss = 0
    optimizer.zero_grad()
    
    for i, (real_in, imag_in, real_tar, imag_tar) in enumerate(train_loader, 1):
        inputs = torch.stack([real_in, imag_in], 1).to(device)
        targets = torch.stack([real_tar, imag_tar], 1).to(device)
        
        preds = model(inputs)
        loss = criterion(preds[:,0], targets[:,0]) + criterion(preds[:,1], targets[:,1])
        loss = loss / trunc_steps  # Scale loss for gradient accumulation
        loss.backward()
        
        if i % trunc_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            total_loss += loss.item() * trunc_steps
    
    # Handle remaining steps
    if len(train_loader) % trunc_steps != 0:
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item() * (len(train_loader) % trunc_steps)
    
    return total_loss / len(train_loader)

def train_alphas(model, val_loader, optimizer, criterion):
    model.eval()
    total_loss = 0
    
    # Freeze model parameters, unfreeze alphas
    for param in model.parameters():
        param.requires_grad = False
    for cell in model.denoiser.cells:
        cell.alphas.requires_grad = True
    
    optimizer.zero_grad()
    for real_in, imag_in, real_tar, imag_tar in val_loader:
        inputs = torch.stack([real_in, imag_in], 1).to(device)
        targets = torch.stack([real_tar, imag_tar], 1).to(device)
        
        preds = model(inputs)
        loss = criterion(preds[:,0], targets[:,0]) + criterion(preds[:,1], targets[:,1])
        total_loss += loss.item()
        
        loss.backward()  # Accumulate gradients for alphas
    
    optimizer.step()
    
    # Restore parameter states
    for param in model.parameters():
        param.requires_grad = True
    
    return total_loss / len(val_loader)

# ============ Main Execution ============
if __name__ == "__main__":
    model = FullModel().to(device)
    criterion = nn.MSELoss()

    # Split parameters into model weights and alphas
    model_params = []
    alpha_params = []
    for name, param in model.named_parameters():
        if 'alphas' in name:
            alpha_params.append(param)
        else:
            model_params.append(param)

    model_optim = optim.Adam(model_params, lr=0.01)
    alpha_optim = optim.Adam(alpha_params, lr=0.003)

    print("Starting training...")
    for epoch in range(20):
        # Train model weights with truncated backprop
        train_loss = train_model(model, train_loader, model_optim, criterion, 50)
        
        # Train architecture parameters
        val_loss = train_alphas(model, val_loader, alpha_optim, criterion)
        
        print(f"Epoch {epoch+1}: "
              f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    # Final evaluation
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for real_in, imag_in, real_tar, imag_tar in test_loader:
            inputs = torch.stack([real_in, imag_in], 1).to(device)
            targets = torch.stack([real_tar, imag_tar], 1).to(device)
            preds = model(inputs)
            loss = criterion(preds[:,0], targets[:,0]) + criterion(preds[:,1], targets[:,1])
            total_loss += loss.item()
    
    print(f"Final Test Loss: {total_loss/len(test_loader):.4f}")

Truncated RAD removed

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import h5py
from scipy.io import loadmat

# ============ GPU Setup ============
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ============ Load .mat Files with Fallback ============
def load_mat_file(file_path, var_name):
    try:
        with h5py.File(file_path, 'r') as f:
            data = f[var_name][:]
        print(f"Loaded {file_path} using h5py.")
    except OSError:
        print(f"Failed to load {file_path} with h5py. Trying scipy.io.loadmat...")
        data = loadmat(file_path)[var_name]
        print(f"Loaded {file_path} using scipy.io.loadmat.")
    return data

# ============ Load Data ============
# Load datasets (40,000 instances)
data_y = load_mat_file('yDL_10dB_40k_150pilots_ipjp.mat', 'yDL')          # shape: (40000, 150, 2)
data_psi = load_mat_file('PsiDL_10dB_40k_150pilots_ipjp.mat', 'PsiDL')      # shape: (2, 512, 150, 40000)
data_h = load_mat_file('hDL_10dB_40k_150pilots_ipjp.mat', 'hDL')            # shape: (40000, 512, 2)
data_sigma2 = load_mat_file('sigma2DL_10dB_40k_150pilots_ipjp.mat', 'sigma2DL')  # shape: (40000, 1)

# Convert yDL and hDL to complex tensors.
y_complex = torch.tensor(data_y[..., 0] + 1j * data_y[..., 1], dtype=torch.cfloat, device=device)
h_complex = torch.tensor(data_h[..., 0] + 1j * data_h[..., 1], dtype=torch.cfloat, device=device)

# For PsiDL, reshape from (2, 512, 150, 40000) to (40000, 150, 512)
psi_real = data_psi[0].transpose(2, 3, 0).squeeze(-1)  # shape: (40000, 150, 512)
psi_imag = data_psi[1].transpose(2, 3, 0).squeeze(-1)  # shape: (40000, 150, 512)
psi_complex = torch.tensor(psi_real + 1j * psi_imag, dtype=torch.cfloat, device=device)

# sigma2 is kept for reference.
sigma2 = torch.tensor(data_sigma2, dtype=torch.float32, device=device)

# ============ Compute LS Estimate Using Normal Equation ============
num_samples = 40000
batch_size = 1000
h_LS_complex = torch.empty((num_samples, 512), dtype=torch.cfloat)

for start in range(0, num_samples, batch_size):
    end = min(start + batch_size, num_samples)
    print(f"Processing LS batch {start//batch_size + 1}/{(num_samples//batch_size)}")
    psi_batch = psi_complex[start:end]  # shape: (batch, 150, 512)
    y_batch = y_complex[start:end]      # shape: (batch, 150)
    
    # Compute Hermitian transpose of psi: (batch, 512, 150)
    psi_H = torch.conj(psi_batch.transpose(-2, -1))
    A = torch.matmul(psi_H, psi_batch)                # (batch, 512, 512)
    B = torch.matmul(psi_H, y_batch.unsqueeze(-1))      # (batch, 512, 1)
    
    # Solve the normal equations A * h = B
    h_LS_batch = torch.linalg.solve(A, B).squeeze(-1)   # (batch, 512)
    h_LS_complex[start:end] = h_LS_batch.cpu()

# Split LS estimate and ground truth into real and imaginary parts (move to CPU)
h_LS_real = h_LS_complex.real  # (40000, 512)
h_LS_imag = h_LS_complex.imag  # (40000, 512)
h_real = h_complex.real.cpu()  # (40000, 512)
h_imag = h_complex.imag.cpu()  # (40000, 512)

# ============ Split Data into Train, Validation, Test ============
train_samples = 32000
val_samples = 4000
test_samples = 4000

train_real_in = h_LS_real[:train_samples]
train_imag_in = h_LS_imag[:train_samples]
train_real_target = h_real[:train_samples]
train_imag_target = h_imag[:train_samples]

val_real_in = h_LS_real[train_samples:train_samples+val_samples]
val_imag_in = h_LS_imag[train_samples:train_samples+val_samples]
val_real_target = h_real[train_samples:train_samples+val_samples]
val_imag_target = h_imag[train_samples:train_samples+val_samples]

test_real_in = h_LS_real[train_samples+val_samples:]
test_imag_in = h_LS_imag[train_samples+val_samples:]
test_real_target = h_real[train_samples+val_samples:]
test_imag_target = h_imag[train_samples+val_samples:]

train_dataset = TensorDataset(train_real_in, train_imag_in, train_real_target, train_imag_target)
val_dataset = TensorDataset(val_real_in, val_imag_in, val_real_target, val_imag_target)
test_dataset = TensorDataset(test_real_in, test_imag_in, test_real_target, test_imag_target)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# ============ Neural Architecture Components ============
# Feature extractor now takes 2 channels as input (real and imag concatenated)
class FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(2, 16, 3, padding=1)
        self.conv2 = nn.Conv1d(16, 32, 3, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

OPS = {
    'conv_3x3': lambda C: nn.Conv1d(C, C, 3, padding=1, bias=False),
    'conv_5x5': lambda C: nn.Conv1d(C, C, 5, padding=2, bias=False),
    'identity': lambda C: nn.Identity(),
    'skip_connection': lambda C: nn.Sequential(nn.Conv1d(C, C, 1, bias=False), nn.BatchNorm1d(C)),
    'zero': lambda C: nn.ZeroPad1d(0)
}

class DenoiseCell(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.C = C
        self.num_edges = 8
        self.num_ops = len(OPS)
        self.alphas = nn.Parameter(torch.randn(self.num_edges, self.num_ops))
        self.ops = nn.ModuleList([nn.ModuleList([op(C) for op in OPS.values()]) 
                                   for _ in range(self.num_edges)])
        self.conv1x1 = nn.Conv1d(3 * C, C, 1, bias=False)

    def pad_and_concat(self, inputs):
        """Zero pad inputs along temporal dimension to the maximum length and concatenate along channels."""
        max_size = max(inp.shape[2] for inp in inputs)
        padded_inputs = [F.pad(inp, (0, max_size - inp.shape[2])) if inp.shape[2] < max_size else inp for inp in inputs]
        return torch.cat(padded_inputs, dim=1)

    def apply_ops(self, x, edge_idx):
        weights = F.softmax(self.alphas[edge_idx], dim=-1)
        return sum(w * op(x) for w, op in zip(weights, self.ops[edge_idx]))

    def forward(self, inputs):
        # Expected inputs: list of 2 tensors.
        assert len(inputs) == 2, "Each denoise cell must take two inputs."
        in0, in1 = inputs
        # Node 0: receives both inputs (pad and concat)
        node0 = self.pad_and_concat([in0, in1])
        node0 = F.relu(node0)  # shape: (batch, 2*C, T)
        # Node 1: receives output from Node 0 (single input)
        node1 = F.relu(node0)  # shape: (batch, 2*C, T)
        # Node 2: receives outputs from Node 0 and Node 1 (pad and concat)
        node2 = self.pad_and_concat([node0, node1])
        node2 = F.relu(node2)  # shape: (batch, 4*C, T)
        # Node 3: receives outputs from Node 0, Node 1, and Node 2 (pad, concat, then 1x1 conv)
        node3_inputs = self.pad_and_concat([node0, node1, node2])
        node3 = F.relu(self.conv1x1(node3_inputs))
        return node3

class DenoiseModule(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.cells = nn.ModuleList([DenoiseCell(C) for _ in range(10)])

    def forward(self, x):
        # For first cell, duplicate input.
        outputs = [x, x]
        for i in range(10):
            if i == 0:
                out = self.cells[i]([outputs[0], outputs[0]])
            elif i == 1:
                out = self.cells[i]([outputs[1], outputs[0]])
            else:
                out = self.cells[i]([outputs[-2], outputs[-1]])
            outputs.append(out)
        return outputs[-1]

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(32, 16, 3, padding=1)
        self.conv2 = nn.Conv1d(16, 2, 3, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return self.conv2(x)

class FullModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        self.denoiser = DenoiseModule(32)
        self.decoder = Decoder()

    def forward(self, x):
        x = self.feature_extractor(x)
        x = self.denoiser(x)
        return self.decoder(x)

# ============ Training and Evaluation Functions ============
def train_model(model, train_loader, val_loader, weight_optimizer, alpha_optimizer, criterion, num_epochs=20):
    for epoch in range(num_epochs):
        model.train()
        for real_in, imag_in, real_tar, imag_tar in train_loader:
            # Prepare inputs and targets: stack real and imaginary parts along channel dimension.
            real_in = real_in.unsqueeze(1)
            imag_in = imag_in.unsqueeze(1)
            input_data = torch.cat([real_in, imag_in], dim=1).to(device)  # shape: (batch, 2, T)
            
            real_tar = real_tar.unsqueeze(1)
            imag_tar = imag_tar.unsqueeze(1)
            target_data = torch.cat([real_tar, imag_tar], dim=1).to(device)  # shape: (batch, 2, T)
            
            weight_optimizer.zero_grad()
            preds = model(input_data)  # expected output shape: (batch, 2, T)
            # Split predictions and targets into real and imaginary parts.
            real_pred, imag_pred = preds.chunk(2, dim=1)
            real_target, imag_target = target_data.chunk(2, dim=1)
            loss_train = criterion(real_pred, real_target) + criterion(imag_pred, imag_target)
            loss_train.backward()
            weight_optimizer.step()
        
        # After weight updates, update architecture parameters using the validation loss.
        model.eval()
        total_val_loss = 0.0
        alpha_optimizer.zero_grad()
        for real_in, imag_in, real_tar, imag_tar in val_loader:
            real_in = real_in.unsqueeze(1)
            imag_in = imag_in.unsqueeze(1)
            v_input = torch.cat([real_in, imag_in], dim=1).to(device)
            real_tar = real_tar.unsqueeze(1).to(device)
            imag_tar = imag_tar.unsqueeze(1).to(device)
            v_target = torch.cat([real_tar, imag_tar], dim=1)
            
            v_preds = model(v_input)
            v_real_pred, v_imag_pred = v_preds.chunk(2, dim=1)
            v_real_target, v_imag_target = v_target.chunk(2, dim=1)
            loss_val = criterion(v_real_pred, v_real_target) + criterion(v_imag_pred, v_imag_target)
            loss_val.backward()
            total_val_loss += loss_val.item()
        alpha_optimizer.step()
        print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {loss_train.item():.4f}, Val Loss: {total_val_loss/len(val_loader):.4f}")
    return

def evaluate(model, loader, criterion):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for real_in, imag_in, real_tar, imag_tar in loader:
            real_in = real_in.unsqueeze(1)
            imag_in = imag_in.unsqueeze(1)
            input_data = torch.cat([real_in, imag_in], dim=1).to(device)
            real_tar = real_tar.unsqueeze(1).to(device)
            imag_tar = imag_tar.unsqueeze(1).to(device)
            target_data = torch.cat([real_tar, imag_tar], dim=1)
            preds = model(input_data)
            real_pred, imag_pred = preds.chunk(2, dim=1)
            real_target, imag_target = target_data.chunk(2, dim=1)
            loss = criterion(real_pred, real_target) + criterion(imag_pred, imag_target)
            total_loss += loss.item()
    avg_loss = total_loss / len(loader)
    print(f"Evaluation Loss: {avg_loss:.4f}")
    return avg_loss

# ============ Get Data Loaders from .mat Files ============
def get_dataloaders_from_mat(batch_size=64):
    # Load .mat files
    data_y = load_mat_file('yDL_10dB_40k_150pilots_ipjp.mat', 'yDL')          # shape: (40000, 150, 2)
    data_psi = load_mat_file('PsiDL_10dB_40k_150pilots_ipjp.mat', 'PsiDL')      # shape: (2, 512, 150, 40000)
    data_h = load_mat_file('hDL_10dB_40k_150pilots_ipjp.mat', 'hDL')            # shape: (40000, 512, 2)
    data_sigma2 = load_mat_file('sigma2DL_10dB_40k_150pilots_ipjp.mat', 'sigma2DL')  # shape: (40000, 1)
    
    # Convert to complex tensors:
    y_complex = torch.tensor(data_y[..., 0] + 1j * data_y[..., 1], dtype=torch.cfloat)
    h_complex = torch.tensor(data_h[..., 0] + 1j * data_h[..., 1], dtype=torch.cfloat, device=device)
    
    # PsiDL: reshape from (2, 512, 150, 40000) to (40000, 150, 512)
    psi_real = data_psi[0].transpose(2, 3, 0).squeeze(-1)
    psi_imag = data_psi[1].transpose(2, 3, 0).squeeze(-1)
    psi_complex = torch.tensor(psi_real + 1j * psi_imag, dtype=torch.cfloat)
    
    # Compute LS estimate using batched least squares (normal equation)
    h_LS_complex = torch.linalg.lstsq(psi_complex, y_complex.unsqueeze(-1)).solution.squeeze(-1)
    
    # Split LS estimate and ground truth into real and imaginary parts (move to CPU)
    h_LS_real = h_LS_complex.real.cpu()
    h_LS_imag = h_LS_complex.imag.cpu()
    h_real = h_complex.real.cpu()
    h_imag = h_complex.imag.cpu()
    
    # Split into training (32000), validation (4000), and testing (4000)
    total_samples = 40000
    train_samples = 32000
    val_samples = 4000
    
    train_real_in = h_LS_real[:train_samples]
    train_imag_in = h_LS_imag[:train_samples]
    train_real_target = h_real[:train_samples]
    train_imag_target = h_imag[:train_samples]
    
    val_real_in = h_LS_real[train_samples:train_samples+val_samples]
    val_imag_in = h_LS_imag[train_samples:train_samples+val_samples]
    val_real_target = h_real[train_samples:train_samples+val_samples]
    val_imag_target = h_imag[train_samples:train_samples+val_samples]
    
    test_real_in = h_LS_real[train_samples+val_samples:]
    test_imag_in = h_LS_imag[train_samples+val_samples:]
    test_real_target = h_real[train_samples+val_samples:]
    test_imag_target = h_imag[train_samples+val_samples:]
    
    train_dataset = TensorDataset(train_real_in, train_imag_in, train_real_target, train_imag_target)
    val_dataset = TensorDataset(val_real_in, val_imag_in, val_real_target, val_imag_target)
    test_dataset = TensorDataset(test_real_in, test_imag_in, test_real_target, test_imag_target)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    return train_loader, val_loader, test_loader

# ============ Main Execution ============
if __name__ == "__main__":
    # Get data loaders from .mat files
    train_loader, val_loader, test_loader = get_dataloaders_from_mat(batch_size=64)
    
    # Initialize model
    model = FullModel().to(device)
    
    # Split parameters into weight parameters and architecture parameters (alphas)
    weight_params = []
    alpha_params = []
    for name, param in model.named_parameters():
        if 'alphas' in name:
            alpha_params.append(param)
        else:
            weight_params.append(param)
    
    weight_optimizer = optim.Adam(weight_params, lr=0.01)
    alpha_optimizer = optim.Adam(alpha_params, lr=0.003)
    criterion = nn.MSELoss()
    
    # Bilevel Training: inner loop (weight updates) and outer loop (alpha updates)
    num_epochs = 20
    print("Starting bilevel training...")
    for epoch in range(num_epochs):
        # Train model weights using training loss
        train_loss = train_model(model, train_loader, val_loader, weight_optimizer, alpha_optimizer, criterion, num_epochs=1)
        
        # Evaluate on validation set
        val_loss = evaluate(model, val_loader, criterion)
        print(f"Epoch {epoch+1}/{num_epochs} - Validation Loss: {val_loss:.4f}")
    
    print("Bilevel training completed.")
    
    # Final evaluation on test set
    print("Final evaluation on test set:")
    test_loss = evaluate(model, test_loader, criterion)
    print(f"Test Loss: {test_loss:.4f}")

Tried to implement DARTS-style approach

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import h5py
from scipy.io import loadmat

# ============ GPU Setup ============
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ============ Load .mat Files with Fallback ============
def load_mat_file(file_path, var_name):
    try:
        with h5py.File(file_path, 'r') as f:
            data = f[var_name][:]
        print(f"Loaded {file_path} using h5py.")
    except OSError:
        print(f"Failed to load {file_path} with h5py. Trying scipy.io.loadmat...")
        data = loadmat(file_path)[var_name]
        print(f"Loaded {file_path} using scipy.io.loadmat.")
    return data

# ============ Load Data ============
def load_data():
    data_y = load_mat_file('yDL_10dB_40k_150pilots_ipjp.mat', 'yDL')
    data_psi = load_mat_file('PsiDL_10dB_40k_150pilots_ipjp.mat', 'PsiDL')
    data_h = load_mat_file('hDL_10dB_40k_150pilots_ipjp.mat', 'hDL')
    data_sigma2 = load_mat_file('sigma2DL_10dB_40k_150pilots_ipjp.mat', 'sigma2DL')

    y_complex_np = data_y[..., 0] + 1j * data_y[..., 1]
    h_complex_np = data_h[..., 0] + 1j * data_h[..., 1]
    psi_real_np = data_psi[0].transpose(2, 1, 0)
    psi_imag_np = data_psi[1].transpose(2, 1, 0)
    
    return y_complex_np, psi_real_np, psi_imag_np, h_complex_np

y_complex_np, psi_real_np, psi_imag_np, h_complex_np = load_data()

# ============ Compute LS Estimate in Batches ============
num_samples = 40000
batch_size = 1000
h_LS_complex = torch.empty((num_samples, 512), dtype=torch.cfloat)

for start in range(0, num_samples, batch_size):
    end = min(start + batch_size, num_samples)
    print(f"Processing LS batch {start//batch_size + 1}/{(num_samples//batch_size)}")
    
    psi_real_batch = psi_real_np[start:end]
    psi_imag_batch = psi_imag_np[start:end]
    y_batch = y_complex_np[start:end]
    
    psi_batch = torch.tensor(psi_real_batch + 1j*psi_imag_batch, 
                           dtype=torch.cfloat).to(device)
    y_batch = torch.tensor(y_batch, dtype=torch.cfloat).to(device)

    psi_H = torch.conj(psi_batch.transpose(-2, -1))
    A = torch.matmul(psi_H, psi_batch)
    B = torch.matmul(psi_H, y_batch.unsqueeze(-1))
    h_LS_batch = torch.linalg.solve(A, B).squeeze(-1)
    
    h_LS_complex[start:end] = h_LS_batch.cpu()

h_LS_real, h_LS_imag = h_LS_complex.real, h_LS_complex.imag
h_real = torch.tensor(h_complex_np.real, dtype=torch.float32)
h_imag = torch.tensor(h_complex_np.imag, dtype=torch.float32)

# ============ Dataset Split ============
train_samples, val_samples, test_samples = 32000, 4000, 4000

train_dataset = TensorDataset(
    h_LS_real[:train_samples], h_LS_imag[:train_samples],
    h_real[:train_samples], h_imag[:train_samples]
)

val_dataset = TensorDataset(
    h_LS_real[train_samples:train_samples+val_samples],
    h_LS_imag[train_samples:train_samples+val_samples],
    h_real[train_samples:train_samples+val_samples],
    h_imag[train_samples:train_samples+val_samples]
)

test_dataset = TensorDataset(
    h_LS_real[train_samples+val_samples:],
    h_LS_imag[train_samples+val_samples:],
    h_real[train_samples+val_samples:],
    h_imag[train_samples+val_samples:]
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# ============ Neural Architecture Components ============
class FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(2, 16, 3, padding=1)
        self.conv2 = nn.Conv1d(16, 32, 3, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

OPS = {
    'conv_3x3': lambda C: nn.Conv1d(C, C, 3, padding=1, bias=False),
    'conv_5x5': lambda C: nn.Conv1d(C, C, 5, padding=2, bias=False),
    'identity': lambda C: nn.Identity(),
    'skip_connection': lambda C: nn.Sequential(nn.Conv1d(C, C, 1, bias=False), nn.BatchNorm1d(C)),
    'zero': lambda C: nn.ZeroPad1d(0)
}

class DenoiseCell(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.C = C
        self.num_edges = 8
        self.num_ops = len(OPS)
        self.alphas = nn.Parameter(torch.randn(self.num_edges, self.num_ops))
        self.ops = nn.ModuleList([nn.ModuleList([op(C) for op in OPS.values()]) 
                                   for _ in range(self.num_edges)])
        # Corrected input channels calculation to 8*C (from node0+node1+node2 concatenation)
        self.conv1x1 = nn.Conv1d(8*C, C, 1, bias=False)  # Adjusted from 4*C to 8*C

    def pad_and_concat(self, inputs):
        max_size = max(inp.shape[2] for inp in inputs)
        padded_inputs = [F.pad(inp, (0, max_size - inp.shape[2])) if inp.shape[2] < max_size else inp
                         for inp in inputs]
        return torch.cat(padded_inputs, dim=1)

    def apply_ops(self, x, edge_idx):
        weights = F.softmax(self.alphas[edge_idx], dim=-1)
        return sum(w * op(x) for w, op in zip(weights, self.ops[edge_idx]))

    def forward(self, inputs):
        in0, in1 = inputs
        
        # Node 0: concatenate inputs (32+32=64 channels)
        node0 = self.pad_and_concat([in0, in1])  # 64 channels
        node0 = F.relu(node0)
        
        # Node 1: same as node0 (64 channels)
        node1 = F.relu(node0)
        
        # Node 2: concatenate node0+node1 (64+64=128 channels)
        node2 = self.pad_and_concat([node0, node1])
        node2 = F.relu(node2)
        
        # Node3 inputs: node0+node1+node2 (64+64+128=256 channels)
        node3_inputs = self.pad_and_concat([node0, node1, node2])
        
        # Final 1x1 conv to reduce from 256 (8*C) to 32 (C) channels
        node3 = F.relu(self.conv1x1(node3_inputs))
        return node3

class DenoiseModule(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.cells = nn.ModuleList([DenoiseCell(C) for _ in range(10)])

    def forward(self, x):
        outputs = [x, x]
        for i in range(10):
            if i == 0:
                out = self.cells[i]([outputs[0], outputs[0]])
            elif i == 1:
                out = self.cells[i]([outputs[1], outputs[0]])
            else:
                out = self.cells[i]([outputs[-2], outputs[-1]])
            outputs.append(out)
        return outputs[-1]

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(32, 16, 3, padding=1)
        self.conv2 = nn.Conv1d(16, 2, 3, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return self.conv2(x)

class FullModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        self.denoiser = DenoiseModule(32)
        self.decoder = Decoder()

    def forward(self, x):
        x = self.feature_extractor(x)
        x = self.denoiser(x)
        return self.decoder(x)

# ============ Single-Step Unrolled DARTS-Style Approach ============
def compute_unrolled_weights(model, loss_fn, train_batch, lr_w):
    """
    Perform one gradient step on the training batch to get w' (the unrolled weights).
    We'll create a clone of the model, do forward/backward on train_batch, and
    update the clone's parameters (NOT the original model's) by one step.
    """
    real_in, imag_in, real_tar, imag_tar = train_batch
    real_in = real_in.unsqueeze(1).to(device)
    imag_in = imag_in.unsqueeze(1).to(device)
    inputs = torch.cat([real_in, imag_in], dim=1)

    real_tar = real_tar.unsqueeze(1).to(device)
    imag_tar = imag_tar.unsqueeze(1).to(device)
    targets = torch.cat([real_tar, imag_tar], dim=1)

    # Clone the model to get a new set of parameters
    cloned_model = FullModel().to(device)
    cloned_model.load_state_dict(model.state_dict())

    # Identify alpha vs weight parameters in the clone
    alpha_params = []
    w_params = []
    for n,p in cloned_model.named_parameters():
        if 'alphas' in n:
            alpha_params.append(p)
        else:
            w_params.append(p)

    # Turn off grads for alpha in the clone, since we only want to update w
    for p in alpha_params:
        p.requires_grad = False

    # Forward/backward on the clone
    preds = cloned_model(inputs)
    rpred, ipred = preds.chunk(2, dim=1)
    rtar, itar = targets.chunk(2, dim=1)
    loss = loss_fn(rpred, rtar) + loss_fn(ipred, itar)

    grads = torch.autograd.grad(loss, w_params, create_graph=True)

    # Manually update the clone's w-params: w' = w - lr_w * grads
    # We'll store them in a dictionary for easy use
    w_prime_dict = {}
    for (name, p), g in zip(cloned_model.named_parameters(), grads):
        if 'alphas' not in name:
            w_prime_dict[name] = p - lr_w*g

    return w_prime_dict

def forward_with_weights(model, w_prime_dict, input_batch):
    """
    Forward pass the model using a dictionary of updated weights (w_prime).
    We do this by temporarily substituting the model's w-params with w_prime_dict.
    """
    # Save old param.data
    backup = {}
    for name, param in model.named_parameters():
        backup[name] = param.data

    # Substitute param.data with w_prime
    for name, param in model.named_parameters():
        if name in w_prime_dict:
            param.data = w_prime_dict[name].data

    # Forward pass
    real_in, imag_in, real_tar, imag_tar = input_batch
    real_in = real_in.unsqueeze(1).to(device)
    imag_in = imag_in.unsqueeze(1).to(device)
    inputs = torch.cat([real_in, imag_in], dim=1)

    preds = model(inputs)

    # restore old param.data
    for name, param in model.named_parameters():
        param.data = backup[name]

    return preds

def darts_unrolled_step(model, train_batch, val_batch, w_optimizer, alpha_optimizer, lr_w, criterion):
    """
    Single-step unrolled approach:
      1) compute w' by taking one step of gradient on train batch
      2) do forward on val batch with w' to compute gradient wrt alpha
      3) update alpha
      4) update w (again) with train batch or keep the unrolled step
         depending on the standard approach
    """

    # Step 1: compute unrolled w'
    w_prime_dict = compute_unrolled_weights(model, criterion, train_batch, lr_w)

    # Step 2: forward pass on val with w'
    real_in, imag_in, real_tar, imag_tar = val_batch
    preds_val = forward_with_weights(model, w_prime_dict, (real_in, imag_in, real_tar, imag_tar))
    # compute val loss
    real_pred_val, imag_pred_val = preds_val.chunk(2, dim=1)
    real_tar_val = real_tar.unsqueeze(1).to(device)
    imag_tar_val = imag_tar.unsqueeze(1).to(device)
    targets_val = torch.cat([real_tar_val, imag_tar_val], dim=1)
    rtar_val, itar_val = targets_val.chunk(2, dim=1)

    loss_val = criterion(real_pred_val, rtar_val) + criterion(imag_pred_val, itar_val)

    # Step 3: update alpha
    alpha_optimizer.zero_grad()
    loss_val.backward()
    alpha_optimizer.step()

    # Step 4: do a normal step on w with the train batch
    #   or we can skip if we want w to be replaced by w' directly. 
    #   Here, we do a normal step for clarity.
    real_in_t, imag_in_t, real_tar_t, imag_tar_t = train_batch
    real_in_t = real_in_t.unsqueeze(1).to(device)
    imag_in_t = imag_in_t.unsqueeze(1).to(device)
    inputs_t = torch.cat([real_in_t, imag_in_t], dim=1)
    real_tar_t = real_tar_t.unsqueeze(1).to(device)
    imag_tar_t = imag_tar_t.unsqueeze(1).to(device)
    targets_t = torch.cat([real_tar_t, imag_tar_t], dim=1)

    w_optimizer.zero_grad()
    preds_t = model(inputs_t)
    rpred_t, ipred_t = preds_t.chunk(2, dim=1)
    rtar_t, itar_t = targets_t.chunk(2, dim=1)
    loss_train = criterion(rpred_t, rtar_t) + criterion(ipred_t, itar_t)
    loss_train.backward()
    w_optimizer.step()

    return loss_train.item(), loss_val.item()

def train_darts_style(model, train_loader, val_loader, w_optimizer, alpha_optimizer, criterion, epochs=20):
    """
    Single-step DARTS approach:
      For each mini-batch in training, we pick a mini-batch in validation
      (for simplicity, we cycle or zip).
    """
    from itertools import cycle

    val_iter = cycle(val_loader)  # infinite val loader
    for epoch in range(epochs):
        model.train()
        total_train_loss = 0.0
        total_val_loss = 0.0
        for train_batch in train_loader:
            val_batch = next(val_iter)
            # unrolled step
            train_loss_val, val_loss_val = darts_unrolled_step(
                model, train_batch, val_batch,
                w_optimizer, alpha_optimizer,
                lr_w=0.01,  # can also pass w_optimizer.param_groups[0]['lr']
                criterion=criterion
            )
            total_train_loss += train_loss_val
            total_val_loss += val_loss_val

        avg_train_loss = total_train_loss / len(train_loader)
        avg_val_loss = total_val_loss / len(train_loader)  # we count val losses per train batch
        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {avg_train_loss:.4f} | Proxy Val Loss: {avg_val_loss:.4f}")

def evaluate(model, loader, criterion):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for real_in, imag_in, real_tar, imag_tar in loader:
            real_in = real_in.unsqueeze(1).to(device)
            imag_in = imag_in.unsqueeze(1).to(device)
            inputs = torch.cat([real_in, imag_in], dim=1)
            real_tar = real_tar.unsqueeze(1).to(device)
            imag_tar = imag_tar.unsqueeze(1).to(device)
            targets = torch.cat([real_tar, imag_tar], dim=1)
            preds = model(inputs)
            rpred, ipred = preds.chunk(2, dim=1)
            rtar, itar = targets.chunk(2, dim=1)
            loss = criterion(rpred, rtar) + criterion(ipred, itar)
            total_loss += loss.item()
    return total_loss / len(loader)

# ============ Main Execution ============
if __name__ == "__main__":
    # Build model
    model = FullModel().to(device)

    # Separate alpha-params from w-params
    alpha_params = [p for n,p in model.named_parameters() if 'alphas' in n]
    w_params = [p for n,p in model.named_parameters() if 'alphas' not in n]
    
    w_optim = optim.Adam(w_params, lr=0.01)
    alpha_optim = optim.Adam(alpha_params, lr=0.003)
    criterion = nn.MSELoss()

    print("Starting DARTS-style training with optimized memory...")
    train_darts_style(model, train_loader, val_loader, w_optim, alpha_optim, criterion, epochs=20)
    
    test_loss = evaluate(model, test_loader, criterion)
    print(f"Final Test Loss: {test_loss:.4f}")

Truncated RAD

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import h5py
from scipy.io import loadmat

# ============ GPU Setup ============
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ============ Load .mat Files with Fallback ============
def load_mat_file(file_path, var_name):
    try:
        with h5py.File(file_path, 'r') as f:
            data = f[var_name][:]
        print(f"Loaded {file_path} using h5py.")
    except OSError:
        print(f"Failed to load {file_path} with h5py. Trying scipy.io.loadmat...")
        data = loadmat(file_path)[var_name]
        print(f"Loaded {file_path} using scipy.io.loadmat.")
    return data

# ============ Load Data ============
def load_data():
    data_y = load_mat_file('yDL_10dB_40k_150pilots_ipjp.mat', 'yDL')
    data_psi = load_mat_file('PsiDL_10dB_40k_150pilots_ipjp.mat', 'PsiDL')
    data_h = load_mat_file('hDL_10dB_40k_150pilots_ipjp.mat', 'hDL')
    data_sigma2 = load_mat_file('sigma2DL_10dB_40k_150pilots_ipjp.mat', 'sigma2DL')

    y_complex_np = data_y[..., 0] + 1j * data_y[..., 1]
    h_complex_np = data_h[..., 0] + 1j * data_h[..., 1]
    psi_real_np = data_psi[0].transpose(2, 1, 0)
    psi_imag_np = data_psi[1].transpose(2, 1, 0)
    
    return y_complex_np, psi_real_np, psi_imag_np, h_complex_np

y_complex_np, psi_real_np, psi_imag_np, h_complex_np = load_data()

# ============ Compute LS Estimate in Batches ============
num_samples = 40000
batch_size = 1000
h_LS_complex = torch.empty((num_samples, 512), dtype=torch.cfloat)

for start in range(0, num_samples, batch_size):
    end = min(start + batch_size, num_samples)
    print(f"Processing LS batch {start//batch_size + 1}/{(num_samples//batch_size)}")
    
    psi_real_batch = psi_real_np[start:end]
    psi_imag_batch = psi_imag_np[start:end]
    y_batch = y_complex_np[start:end]
    
    psi_batch = torch.tensor(psi_real_batch + 1j*psi_imag_batch, 
                           dtype=torch.cfloat).to(device)
    y_batch = torch.tensor(y_batch, dtype=torch.cfloat).to(device)

    psi_H = torch.conj(psi_batch.transpose(-2, -1))
    A = torch.matmul(psi_H, psi_batch)
    B = torch.matmul(psi_H, y_batch.unsqueeze(-1))
    h_LS_batch = torch.linalg.solve(A, B).squeeze(-1)
    
    h_LS_complex[start:end] = h_LS_batch.cpu()

h_LS_real, h_LS_imag = h_LS_complex.real, h_LS_complex.imag
h_real = torch.tensor(h_complex_np.real, dtype=torch.float32)
h_imag = torch.tensor(h_complex_np.imag, dtype=torch.float32)

# ============ Dataset Split ============
train_samples, val_samples, test_samples = 32000, 4000, 4000

train_dataset = TensorDataset(
    h_LS_real[:train_samples], h_LS_imag[:train_samples],
    h_real[:train_samples], h_imag[:train_samples]
)

val_dataset = TensorDataset(
    h_LS_real[train_samples:train_samples+val_samples],
    h_LS_imag[train_samples:train_samples+val_samples],
    h_real[train_samples:train_samples+val_samples],
    h_imag[train_samples:train_samples+val_samples]
)

test_dataset = TensorDataset(
    h_LS_real[train_samples+val_samples:],
    h_LS_imag[train_samples+val_samples:],
    h_real[train_samples+val_samples:],
    h_imag[train_samples+val_samples:]
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# ============ Neural Architecture Components ============
class FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(2, 16, 3, padding=1)
        self.conv2 = nn.Conv1d(16, 32, 3, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

OPS = {
    'conv_3x3': lambda C: nn.Conv1d(C, C, 3, padding=1, bias=False),
    'conv_5x5': lambda C: nn.Conv1d(C, C, 5, padding=2, bias=False),
    'identity': lambda C: nn.Identity(),
    'skip_connection': lambda C: nn.Sequential(nn.Conv1d(C, C, 1, bias=False), nn.BatchNorm1d(C)),
    'zero': lambda C: nn.ZeroPad1d(0)
}

class DenoiseCell(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.C = C
        self.num_edges = 8
        self.num_ops = len(OPS)
        self.alphas = nn.Parameter(torch.randn(self.num_edges, self.num_ops))
        self.ops = nn.ModuleList([nn.ModuleList([op(C) for op in OPS.values()]) 
                                   for _ in range(self.num_edges)])
        self.conv1x1 = nn.Conv1d(8*C, C, 1, bias=False)

    def pad_and_concat(self, inputs):
        max_size = max(inp.shape[2] for inp in inputs)
        padded_inputs = [F.pad(inp, (0, max_size - inp.shape[2])) if inp.shape[2] < max_size else inp
                         for inp in inputs]
        return torch.cat(padded_inputs, dim=1)

    def apply_ops(self, x, edge_idx):
        weights = F.softmax(self.alphas[edge_idx], dim=-1)
        return sum(w * op(x) for w, op in zip(weights, self.ops[edge_idx]))

    def forward(self, inputs):
        in0, in1 = inputs
        
        node0 = self.pad_and_concat([in0, in1])
        node0 = F.relu(node0)
        
        node1 = F.relu(node0)
        
        node2 = self.pad_and_concat([node0, node1])
        node2 = F.relu(node2)
        
        node3_inputs = self.pad_and_concat([node0, node1, node2])
        
        node3 = F.relu(self.conv1x1(node3_inputs))
        return node3

class DenoiseModule(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.cells = nn.ModuleList([DenoiseCell(C) for _ in range(10)])
        self.state_history = []

    def forward(self, x):
        outputs = [x, x]
        self.state_history = []
        for i in range(10):
            if i == 0:
                out = self.cells[i]([outputs[0], outputs[0]])
            elif i == 1:
                out = self.cells[i]([outputs[1], outputs[0]])
            else:
                out = self.cells[i]([outputs[-2], outputs[-1]])
            outputs.append(out)
            self.state_history.append({
                'outputs': list(outputs),
                'alpha': list(self.parameters())
            })
        return outputs[-1]

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(32, 16, 3, padding=1)
        self.conv2 = nn.Conv1d(16, 2, 3, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return self.conv2(x)

class FullModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        self.denoiser = DenoiseModule(32)
        self.decoder = Decoder()

    def forward(self, x):
        x = self.feature_extractor(x)
        x = self.denoiser(x)
        return self.decoder(x)

# ============ Truncated RAD Implementation ============
def compute_truncated_grad(model, val_batch, criterion, truncation_steps=3):
    real_in, imag_in, real_tar, imag_tar = val_batch
    real_in = real_in.unsqueeze(1).to(device)
    imag_in = imag_in.unsqueeze(1).to(device)
    inputs = torch.cat([real_in, imag_in], dim=1)
    
    real_tar = real_tar.unsqueeze(1).to(device)
    imag_tar = imag_tar.unsqueeze(1).to(device)
    targets = torch.cat([real_tar, imag_tar], dim=1)

    # Forward pass through unrolled steps
    preds = model(inputs)
    rpred, ipred = preds.chunk(2, dim=1)
    rtar, itar = targets.chunk(2, dim=1)
    loss = criterion(rpred, rtar) + criterion(ipred, itar)

    # Get computation history
    history = model.denoiser.state_history
    T = len(history)
    M = min(truncation_steps, T)
    
    # Initialize gradients
    alpha_list = [p for n, p in model.named_parameters() if 'alphas' in n]
    alpha_indices = {p: idx for idx, p in enumerate(alpha_list)}
    grad_alpha = [torch.zeros_like(p) for p in alpha_list]
    
    # Initialize lambda with proper gradient handling
    if T == 0:
        return grad_alpha, loss.item()
    
    outputs = history[-1]['outputs']
    lambda_t = torch.autograd.grad(
        loss, outputs, 
        retain_graph=True, 
        allow_unused=True
    )
    
    # Replace None in lambda_t with zeros
    lambda_t = list(lambda_t)
    for i in range(len(lambda_t)):
        if lambda_t[i] is None:
            lambda_t[i] = torch.zeros_like(outputs[i])
    
    # Reverse through truncated steps
    for t in reversed(range(max(0, T-M), T)):
        state = history[t]
        current_outputs = state['outputs']
        current_alpha = state['alpha']

        # Compute gradients with allow_unused=True
        A = torch.autograd.grad(
            current_outputs, current_alpha, 
            grad_outputs=lambda_t, 
            retain_graph=True, 
            allow_unused=True
        )
        B = torch.autograd.grad(
            current_outputs, current_outputs, 
            grad_outputs=lambda_t, 
            retain_graph=True, 
            allow_unused=True
        )
        
        # Update gradients and lambda with None checks
        for g_a, a in zip(A, current_alpha):
            if g_a is not None and a in alpha_indices:
                grad_alpha[alpha_indices[a]] += g_a.detach()
        # Update lambda_t for next iteration
        lambda_t = [b.detach() if b is not None else None for b in B]
        # Replace None in lambda_t with zeros for next iteration
        for i in range(len(lambda_t)):
            if lambda_t[i] is None:
                lambda_t[i] = torch.zeros_like(current_outputs[i])

    return grad_alpha, loss.item()

def truncated_rad_step(model, train_batch, val_batch, w_optimizer, alpha_optimizer, 
                      criterion, truncation_steps=3):
    # Train on current batch
    real_in_t, imag_in_t, real_tar_t, imag_tar_t = train_batch
    real_in_t = real_in_t.unsqueeze(1).to(device)
    imag_in_t = imag_in_t.unsqueeze(1).to(device)
    inputs_t = torch.cat([real_in_t, imag_in_t], dim=1)
    
    w_optimizer.zero_grad()
    preds_t = model(inputs_t)
    rpred_t, ipred_t = preds_t.chunk(2, dim=1)
    rtar_t, itar_t = real_tar_t.unsqueeze(1).to(device), imag_tar_t.unsqueeze(1).to(device)
    loss_train = criterion(rpred_t, rtar_t) + criterion(ipred_t, itar_t)
    loss_train.backward()
    w_optimizer.step()

    # Compute truncated gradients for alpha
    grad_alpha, val_loss = compute_truncated_grad(model, val_batch, criterion, truncation_steps)
    
    # Update alpha parameters
    alpha_optimizer.zero_grad()
    for p, g in zip([p for n, p in model.named_parameters() if 'alphas' in n], grad_alpha):
        if g is not None:
            if p.grad is None:
                p.grad = g.to(device)
            else:
                p.grad += g.to(device)
    alpha_optimizer.step()

    return loss_train.item(), val_loss

def train_truncated_rad(model, train_loader, val_loader, w_optimizer, alpha_optimizer, 
                       criterion, epochs=20, truncation_steps=3):
    from itertools import cycle
    val_iter = cycle(val_loader)
    
    for epoch in range(epochs):
        model.train()
        total_train_loss = 0.0
        total_val_loss = 0.0
        
        for train_batch in train_loader:
            val_batch = next(val_iter)
            train_loss, val_loss = truncated_rad_step(
                model, train_batch, val_batch,
                w_optimizer, alpha_optimizer,
                criterion, truncation_steps
            )
            total_train_loss += train_loss
            total_val_loss += val_loss
        
        avg_train_loss = total_train_loss / len(train_loader)
        avg_val_loss = total_val_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

def evaluate(model, loader, criterion):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for real_in, imag_in, real_tar, imag_tar in loader:
            real_in = real_in.unsqueeze(1).to(device)
            imag_in = imag_in.unsqueeze(1).to(device)
            inputs = torch.cat([real_in, imag_in], dim=1)
            real_tar = real_tar.unsqueeze(1).to(device)
            imag_tar = imag_tar.unsqueeze(1).to(device)
            targets = torch.cat([real_tar, imag_tar], dim=1)
            preds = model(inputs)
            rpred, ipred = preds.chunk(2, dim=1)
            rtar, itar = targets.chunk(2, dim=1)
            loss = criterion(rpred, rtar) + criterion(ipred, itar)
            total_loss += loss.item()
    return total_loss / len(loader)

# ============ Main Execution ============
if __name__ == "__main__":
    model = FullModel().to(device)
    
    # Separate parameters
    alpha_params = [p for n, p in model.named_parameters() if 'alphas' in n]
    w_params = [p for n, p in model.named_parameters() if 'alphas' not in n]
    
    # Optimizers
    w_optim = optim.Adam(w_params, lr=0.01)
    alpha_optim = optim.Adam(alpha_params, lr=0.003)
    criterion = nn.MSELoss()

    print("Starting training with Truncated RAD...")
    train_truncated_rad(model, train_loader, val_loader, w_optim, alpha_optim, 
                       criterion, epochs=2000, truncation_steps=50)
    
    test_loss = evaluate(model, test_loader, criterion)
    print(f"Final Test Loss: {test_loss:.4f}")

Using device: cuda
Failed to load yDL_10dB_40k_150pilots_ipjp.mat with h5py. Trying scipy.io.loadmat...
Loaded yDL_10dB_40k_150pilots_ipjp.mat using scipy.io.loadmat.
Loaded PsiDL_10dB_40k_150pilots_ipjp.mat using h5py.
Failed to load hDL_10dB_40k_150pilots_ipjp.mat with h5py. Trying scipy.io.loadmat...
Loaded hDL_10dB_40k_150pilots_ipjp.mat using scipy.io.loadmat.
Failed to load sigma2DL_10dB_40k_150pilots_ipjp.mat with h5py. Trying scipy.io.loadmat...
Loaded sigma2DL_10dB_40k_150pilots_ipjp.mat using scipy.io.loadmat.
Processing LS batch 1/40
Processing LS batch 2/40
Processing LS batch 3/40
Processing LS batch 4/40
Processing LS batch 5/40
Processing LS batch 6/40
Processing LS batch 7/40
Processing LS batch 8/40
Processing LS batch 9/40
Processing LS batch 10/40
Processing LS batch 11/40
Processing LS batch 12/40
Processing LS batch 13/40
Processing LS batch 14/40
Processing LS batch 15/40
Processing LS batch 16/40
Processing LS batch 17/40
Processing LS batch 18/40
Processing LS b

With visualizations(+Skip)

In [8]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import h5py
from scipy.io import loadmat
import matplotlib.pyplot as plt
import os
from time import time
from collections import defaultdict
import warnings

# ============ Visualization Setup ============
plt.ioff()
plt.rcParams['figure.constrained_layout.use'] = True  # Enable constrained layout
os.makedirs("progress_plots", exist_ok=True)
os.makedirs("architecture_plots", exist_ok=True)
os.makedirs("channel_estimates", exist_ok=True)

def plot_losses(epochs, train_losses, val_losses, test_losses=None):
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6), constrained_layout=True)
        
        # Training and Validation Loss
        ax1.plot(epochs, train_losses, label='Train Loss')
        ax1.plot(epochs, val_losses, label='Val Loss')
        ax1.set_title('Training & Validation Loss', fontsize=12)
        ax1.set_xlabel('Epochs', fontsize=10)
        ax1.set_ylabel('MSE Loss', fontsize=10)
        if all(y > 0 for y in train_losses + val_losses):
            ax1.set_yscale('log')
        ax1.legend()
        ax1.grid(True)

        # Test Loss if available
        if test_losses:
            ax2.plot(test_losses['epochs'], test_losses['values'], 'r-')
            ax2.set_title('Test Loss Progression', fontsize=12)
            ax2.set_xlabel('Epochs', fontsize=10)
            ax2.set_ylabel('MSE Loss', fontsize=10)
            if all(y > 0 for y in test_losses['values']):
                ax2.set_yscale('log')
            ax2.grid(True)

        plt.savefig(f"progress_plots/losses_{int(time())}.png")
        plt.close()

def plot_architecture(alphas, epoch):
    """Plot the evolution of alpha parameters"""
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        fig = plt.figure(figsize=(18, 16), constrained_layout=True)
        keys = sorted(alphas.keys())
        
        for i, (edge, alpha) in enumerate(alphas.items()):
            ax = fig.add_subplot(4, 2, i+1)  # 4 rows, 2 columns
            ax.bar(range(len(alpha)), alpha, width=0.6)
            ax.set_title(f'Edge {edge} Alpha Values', fontsize=10)
            ax.set_xlabel('Operation', fontsize=8)
            ax.set_ylabel('Weight', fontsize=8)
            ax.set_xticks(range(len(alpha)))
            ax.set_xticklabels(list(OPS.keys()), rotation=60, ha='right', fontsize=7)
            ax.tick_params(axis='y', labelsize=7)
            
        plt.savefig(f"architecture_plots/alpha_epoch_{epoch}.png")
        plt.close()

def plot_channel_estimates(model, test_loader, epoch, num_examples=3):
    """Plot example channel estimates"""
    model.eval()
    with torch.no_grad():
        for i, (real_in, imag_in, real_tar, imag_tar) in enumerate(test_loader):
            if i >= num_examples:
                break
            inputs = torch.cat([real_in.unsqueeze(1), imag_in.unsqueeze(1)], dim=1).to(device)
            preds = model(inputs)
            rpred, ipred = preds.chunk(2, dim=1)
            
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5), constrained_layout=True)
            
            ax1.plot(real_tar[0].numpy(), label='True Real')
            ax1.plot(rpred[0].cpu().numpy(), label='Predicted Real')
            ax1.set_title('Real Component', fontsize=10)
            ax1.legend()
            
            ax2.plot(imag_tar[0].numpy(), label='True Imag')
            ax2.plot(ipred[0].cpu().numpy(), label='Predicted Imag')
            ax2.set_title('Imaginary Component', fontsize=10)
            ax2.legend()
            
            plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")
            plt.close()

def plot_learning_rates(lr_history):
    """Plot learning rate evolution"""
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5), constrained_layout=True)
        
        ax1.plot(lr_history['w_lr'], label='Weight LR')
        ax1.set_title('Weight Learning Rate', fontsize=10)
        ax1.set_xlabel('Epoch', fontsize=8)
        ax1.set_ylabel('Learning Rate', fontsize=8)
        
        ax2.plot(lr_history['alpha_lr'], label='Alpha LR')
        ax2.set_title('Alpha Learning Rate', fontsize=10)
        ax2.set_xlabel('Epoch', fontsize=8)
        ax2.set_ylabel('Learning Rate', fontsize=8)
        
        plt.savefig(f"progress_plots/learning_rates_{int(time())}.png")
        plt.close()

def plot_gradient_norms(grad_norms):
    """Plot gradient norms over time"""
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5), constrained_layout=True)
        
        for name, norms in grad_norms['weight'].items():
            ax1.plot(norms, label=name)
        ax1.set_title('Weight Gradient Norms', fontsize=10)
        ax1.set_xlabel('Epoch', fontsize=8)
        ax1.set_ylabel('Gradient Norm', fontsize=8)
        if any(n > 0 for n in norms):
            ax1.set_yscale('log')
        ax1.legend(fontsize=7)
        
        for name, norms in grad_norms['alpha'].items():
            ax2.plot(norms, label=name)
        ax2.set_title('Alpha Gradient Norms', fontsize=10)
        ax2.set_xlabel('Epoch', fontsize=8)
        ax2.set_ylabel('Gradient Norm', fontsize=8)
        if any(n > 0 for n in norms):
            ax2.set_yscale('log')
        ax2.legend(fontsize=7)
        
        plt.savefig(f"progress_plots/gradient_norms_{int(time())}.png")
        plt.close()

# ============ GPU Setup ============
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ============ Load .mat Files with Fallback ============
def load_mat_file(file_path, var_name):
    try:
        with h5py.File(file_path, 'r') as f:
            data = f[var_name][:]
        print(f"Loaded {file_path} using h5py.")
    except OSError:
        print(f"Failed to load {file_path} with h5py. Trying scipy.io.loadmat...")
        data = loadmat(file_path)[var_name]
        print(f"Loaded {file_path} using scipy.io.loadmat.")
    return data

# ============ Load Data ============
def load_data():
    data_y = load_mat_file('yDL_10dB_40k_150pilots_ipjp.mat', 'yDL')
    data_psi = load_mat_file('PsiDL_10dB_40k_150pilots_ipjp.mat', 'PsiDL')
    data_h = load_mat_file('hDL_10dB_40k_150pilots_ipjp.mat', 'hDL')
    data_sigma2 = load_mat_file('sigma2DL_10dB_40k_150pilots_ipjp.mat', 'sigma2DL')

    y_complex_np = data_y[..., 0] + 1j * data_y[..., 1]
    h_complex_np = data_h[..., 0] + 1j * data_h[..., 1]
    psi_real_np = data_psi[0].transpose(2, 1, 0)
    psi_imag_np = data_psi[1].transpose(2, 1, 0)
    
    return y_complex_np, psi_real_np, psi_imag_np, h_complex_np

y_complex_np, psi_real_np, psi_imag_np, h_complex_np = load_data()

# ============ Compute LS Estimate in Batches ============
num_samples = 40000
batch_size = 1000
h_LS_complex = torch.empty((num_samples, 512), dtype=torch.cfloat)

for start in range(0, num_samples, batch_size):
    end = min(start + batch_size, num_samples)
    print(f"Processing LS batch {start//batch_size + 1}/{(num_samples//batch_size)}")
    
    psi_real_batch = psi_real_np[start:end]
    psi_imag_batch = psi_imag_np[start:end]
    y_batch = y_complex_np[start:end]
    
    psi_batch = torch.tensor(psi_real_batch + 1j*psi_imag_batch, 
                           dtype=torch.cfloat).to(device)
    y_batch = torch.tensor(y_batch, dtype=torch.cfloat).to(device)

    psi_H = torch.conj(psi_batch.transpose(-2, -1))
    A = torch.matmul(psi_H, psi_batch)
    B = torch.matmul(psi_H, y_batch.unsqueeze(-1))
    h_LS_batch = torch.linalg.solve(A, B).squeeze(-1)
    
    h_LS_complex[start:end] = h_LS_batch.cpu()

h_LS_real, h_LS_imag = h_LS_complex.real, h_LS_complex.imag
h_real = torch.tensor(h_complex_np.real, dtype=torch.float32)
h_imag = torch.tensor(h_complex_np.imag, dtype=torch.float32)

# ============ Dataset Split ============
train_samples, val_samples, test_samples = 32000, 4000, 4000

train_dataset = TensorDataset(
    h_LS_real[:train_samples], h_LS_imag[:train_samples],
    h_real[:train_samples], h_imag[:train_samples]
)

val_dataset = TensorDataset(
    h_LS_real[train_samples:train_samples+val_samples],
    h_LS_imag[train_samples:train_samples+val_samples],
    h_real[train_samples:train_samples+val_samples],
    h_imag[train_samples:train_samples+val_samples]
)

test_dataset = TensorDataset(
    h_LS_real[train_samples+val_samples:],
    h_LS_imag[train_samples+val_samples:],
    h_real[train_samples+val_samples:],
    h_imag[train_samples+val_samples:]
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# ============ Neural Architecture Components ============
class FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(2, 16, 3, padding=1)
        self.conv2 = nn.Conv1d(16, 32, 3, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

OPS = {
    'conv_3x3': lambda C: nn.Conv1d(C, C, 3, padding=1, bias=False),
    'conv_5x5': lambda C: nn.Conv1d(C, C, 5, padding=2, bias=False),
    'identity': lambda C: nn.Identity(),
    'skip_connection': lambda C: nn.Sequential(nn.Conv1d(C, C, 1, bias=False), nn.BatchNorm1d(C)),
    'zero': lambda C: nn.ZeroPad1d(0)
}

class DenoiseCell(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.C = C
        self.num_edges = 8
        self.num_ops = len(OPS)
        self.alphas = nn.Parameter(torch.randn(self.num_edges, self.num_ops))
        self.ops = nn.ModuleList([nn.ModuleList([op(C) for op in OPS.values()]) 
                                   for _ in range(self.num_edges)])
        self.conv1x1 = nn.Conv1d(8*C, C, 1, bias=False)

    def pad_and_concat(self, inputs):
        max_size = max(inp.shape[2] for inp in inputs)
        padded_inputs = [F.pad(inp, (0, max_size - inp.shape[2])) if inp.shape[2] < max_size else inp
                         for inp in inputs]
        return torch.cat(padded_inputs, dim=1)

    def apply_ops(self, x, edge_idx):
        weights = F.softmax(self.alphas[edge_idx], dim=-1)
        return sum(w * op(x) for w, op in zip(weights, self.ops[edge_idx]))

    def forward(self, inputs):
        in0, in1 = inputs
        
        node0 = self.pad_and_concat([in0, in1])
        node0 = F.relu(node0)
        
        node1 = F.relu(node0)
        
        node2 = self.pad_and_concat([node0, node1])
        node2 = F.relu(node2)
        
        node3_inputs = self.pad_and_concat([node0, node1, node2])
        
        node3 = F.relu(self.conv1x1(node3_inputs))
        return node3

class DenoiseModule(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.cells = nn.ModuleList([DenoiseCell(C) for _ in range(10)])
        self.state_history = []

    def forward(self, x):
        outputs = [x, x]
        self.state_history = []
        for i in range(10):
            if i == 0:
                out = self.cells[i]([outputs[0], outputs[0]])
            elif i == 1:
                out = self.cells[i]([outputs[1], outputs[0]])
            else:
                out = self.cells[i]([outputs[-2], outputs[-1]])
            outputs.append(out)
            self.state_history.append({
                'outputs': list(outputs),
                'alpha': list(self.parameters())
            })
        return outputs[-1]

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(32, 16, 3, padding=1)
        self.conv2 = nn.Conv1d(16, 2, 3, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return self.conv2(x)

class FullModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        self.denoiser = DenoiseModule(32)
        self.decoder = Decoder()

    def forward(self, x):
        x = self.feature_extractor(x)
        x = self.denoiser(x)
        return self.decoder(x)

# ============ Evaluation Function ============
def evaluate(model, loader, criterion):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for real_in, imag_in, real_tar, imag_tar in loader:
            inputs = torch.cat([real_in.unsqueeze(1), imag_in.unsqueeze(1)], dim=1).to(device)
            targets = torch.cat([real_tar.unsqueeze(1), imag_tar.unsqueeze(1)], dim=1).to(device)
            
            preds = model(inputs)
            loss = criterion(preds, targets)
            total_loss += loss.item()
    return total_loss / len(loader)

# ============ Truncated RAD Implementation ============
def compute_truncated_grad(model, val_batch, criterion, truncation_steps=3):
    real_in, imag_in, real_tar, imag_tar = val_batch
    real_in = real_in.unsqueeze(1).to(device)
    imag_in = imag_in.unsqueeze(1).to(device)
    inputs = torch.cat([real_in, imag_in], dim=1)
    
    real_tar = real_tar.unsqueeze(1).to(device)
    imag_tar = imag_tar.unsqueeze(1).to(device)
    targets = torch.cat([real_tar, imag_tar], dim=1)

    # Forward pass through unrolled steps
    preds = model(inputs)
    rpred, ipred = preds.chunk(2, dim=1)
    rtar, itar = targets.chunk(2, dim=1)
    loss = criterion(rpred, rtar) + criterion(ipred, itar)

    # Get computation history
    history = model.denoiser.state_history
    T = len(history)
    M = min(truncation_steps, T)
    
    # Initialize gradients
    alpha_list = [p for n, p in model.named_parameters() if 'alphas' in n]
    alpha_indices = {p: idx for idx, p in enumerate(alpha_list)}
    grad_alpha = [torch.zeros_like(p) for p in alpha_list]
    
    # Initialize lambda with proper gradient handling
    if T == 0:
        return grad_alpha, loss.item()
    
    outputs = history[-1]['outputs']
    lambda_t = torch.autograd.grad(
        loss, outputs, 
        retain_graph=True, 
        allow_unused=True
    )
    
    # Replace None in lambda_t with zeros
    lambda_t = list(lambda_t)
    for i in range(len(lambda_t)):
        if lambda_t[i] is None:
            lambda_t[i] = torch.zeros_like(outputs[i])
    
    # Reverse through truncated steps
    for t in reversed(range(max(0, T-M), T)):
        state = history[t]
        current_outputs = state['outputs']
        current_alpha = state['alpha']

        # Compute gradients with allow_unused=True
        A = torch.autograd.grad(
            current_outputs, current_alpha, 
            grad_outputs=lambda_t, 
            retain_graph=True, 
            allow_unused=True
        )
        B = torch.autograd.grad(
            current_outputs, current_outputs, 
            grad_outputs=lambda_t, 
            retain_graph=True, 
            allow_unused=True
        )
        
        # Update gradients and lambda with None checks
        for g_a, a in zip(A, current_alpha):
            if g_a is not None and a in alpha_indices:
                grad_alpha[alpha_indices[a]] += g_a.detach()
        # Update lambda_t for next iteration
        lambda_t = [b.detach() if b is not None else None for b in B]
        # Replace None in lambda_t with zeros for next iteration
        for i in range(len(lambda_t)):
            if lambda_t[i] is None:
                lambda_t[i] = torch.zeros_like(current_outputs[i])

    return grad_alpha, loss.item()

def truncated_rad_step(model, train_batch, val_batch, w_optimizer, alpha_optimizer, 
                      criterion, truncation_steps=3):
    # Train on current batch
    real_in_t, imag_in_t, real_tar_t, imag_tar_t = train_batch
    real_in_t = real_in_t.unsqueeze(1).to(device)
    imag_in_t = imag_in_t.unsqueeze(1).to(device)
    inputs_t = torch.cat([real_in_t, imag_in_t], dim=1)
    
    w_optimizer.zero_grad()
    preds_t = model(inputs_t)
    rpred_t, ipred_t = preds_t.chunk(2, dim=1)
    rtar_t, itar_t = real_tar_t.unsqueeze(1).to(device), imag_tar_t.unsqueeze(1).to(device)
    loss_train = criterion(rpred_t, rtar_t) + criterion(ipred_t, itar_t)
    loss_train.backward()
    w_optimizer.step()

    # Compute truncated gradients for alpha
    grad_alpha, val_loss = compute_truncated_grad(model, val_batch, criterion, truncation_steps)
    
    # Update alpha parameters
    alpha_optimizer.zero_grad()
    for p, g in zip([p for n, p in model.named_parameters() if 'alphas' in n], grad_alpha):
        if g is not None:
            if p.grad is None:
                p.grad = g.to(device)
            else:
                p.grad += g.to(device)
    alpha_optimizer.step()

    return loss_train.item(), val_loss

# ============ Modified Training Function ============
def train_truncated_rad(model, train_loader, val_loader, test_loader, w_optimizer, 
                       alpha_optimizer, criterion, epochs=20, truncation_steps=3,
                       test_interval=50):
    from itertools import cycle
    val_iter = cycle(val_loader)
    
    # Initialize tracking variables
    train_losses = []
    val_losses = []
    test_losses = {'epochs': [], 'values': []}
    all_epochs = []
    start_time = time()
    
    # For learning rate tracking
    lr_history = {'w_lr': [], 'alpha_lr': []}
    
    # For gradient norm tracking
    grad_norms = {
        'weight': defaultdict(list),
        'alpha': defaultdict(list)
    }
    
    # For architecture visualization
    alpha_history = []
    
    for epoch in range(epochs):
        epoch_start = time()
        model.train()
        total_train_loss = 0.0
        total_val_loss = 0.0
        
        # Store current learning rates
        lr_history['w_lr'].append(w_optimizer.param_groups[0]['lr'])
        lr_history['alpha_lr'].append(alpha_optimizer.param_groups[0]['lr'])
        
        for train_batch in train_loader:
            val_batch = next(val_iter)
            train_loss, val_loss = truncated_rad_step(
                model, train_batch, val_batch,
                w_optimizer, alpha_optimizer,
                criterion, truncation_steps
            )
            total_train_loss += train_loss
            total_val_loss += val_loss
        
        # Record gradient norms
        with torch.no_grad():
            for name, param in model.named_parameters():
                if param.grad is not None:
                    norm = param.grad.norm().item()
                    if 'alphas' in name:
                        grad_norms['alpha'][name].append(norm)
                    else:
                        grad_norms['weight'][name].append(norm)
        
        # Store alpha values for visualization
        if (epoch+1) % 10 == 0:
            alphas = {}
            for name, param in model.named_parameters():
                if 'alphas' in name:
                    edge_num = name.split('.')[1]
                    alphas[edge_num] = F.softmax(param, dim=-1).detach().cpu().numpy()[0]
            alpha_history.append((epoch+1, alphas))
        
        # Calculate epoch metrics
        avg_train_loss = total_train_loss / len(train_loader)
        avg_val_loss = total_val_loss / len(train_loader)
        
        # Store metrics with small offset to avoid log(0)
        train_losses.append(avg_train_loss + 1e-12)
        val_losses.append(avg_val_loss + 1e-12)
        all_epochs.append(epoch+1)
        
        # Periodic testing and visualization
        if (epoch+1) % test_interval == 0 or (epoch+1) == epochs:
            test_loss = evaluate(model, test_loader, criterion)
            test_losses['epochs'].append(epoch+1)
            test_losses['values'].append(test_loss + 1e-12)
            print(f"Test Loss @ Epoch {epoch+1}: {test_loss:.4f}")
            
            # Plot channel estimates
            plot_channel_estimates(model, test_loader, epoch+1)
        
        # Print and plot
        print(f"Epoch {epoch+1}/{epochs} | Time: {time()-epoch_start:.1f}s | "
              f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
        
        if (epoch+1) % 10 == 0:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                plot_losses(all_epochs, train_losses, val_losses, test_losses)
                plot_learning_rates(lr_history)
                plot_gradient_norms(grad_norms)
                
                for epoch_num, alphas in alpha_history:
                    plot_architecture(alphas, epoch_num)
                alpha_history = []
    
    # Final plots
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        plot_losses(all_epochs, train_losses, val_losses, test_losses)
        plot_learning_rates(lr_history)
        plot_gradient_norms(grad_norms)
    print(f"Total training time: {(time()-start_time)/3600:.2f} hours")
    
    return train_losses, val_losses, test_losses

# ============ Main Execution ============
if __name__ == "__main__":
    model = FullModel().to(device)
    
    # Separate parameters
    alpha_params = [p for n, p in model.named_parameters() if 'alphas' in n]
    w_params = [p for n, p in model.named_parameters() if 'alphas' not in n]
    
    # Optimizers
    w_optim = optim.Adam(w_params, lr=0.01)
    alpha_optim = optim.Adam(alpha_params, lr=0.003)
    criterion = nn.MSELoss()

    print("Starting training with Truncated RAD...")
    train_losses, val_losses, test_losses = train_truncated_rad(
        model, train_loader, val_loader, test_loader,
        w_optim, alpha_optim, criterion,
        epochs=20, truncation_steps=3, test_interval=50
    )
    
    # Final evaluation
    test_loss = evaluate(model, test_loader, criterion)
    print(f"Final Test Loss: {test_loss:.4f}")
    
    # Save final model
    torch.save(model.state_dict(), "final_model.pth")
    print("Model saved as final_model.pth")

Using device: cuda
Failed to load yDL_10dB_40k_150pilots_ipjp.mat with h5py. Trying scipy.io.loadmat...
Loaded yDL_10dB_40k_150pilots_ipjp.mat using scipy.io.loadmat.
Loaded PsiDL_10dB_40k_150pilots_ipjp.mat using h5py.
Failed to load hDL_10dB_40k_150pilots_ipjp.mat with h5py. Trying scipy.io.loadmat...
Loaded hDL_10dB_40k_150pilots_ipjp.mat using scipy.io.loadmat.
Failed to load sigma2DL_10dB_40k_150pilots_ipjp.mat with h5py. Trying scipy.io.loadmat...
Loaded sigma2DL_10dB_40k_150pilots_ipjp.mat using scipy.io.loadmat.
Processing LS batch 1/40
Processing LS batch 2/40
Processing LS batch 3/40
Processing LS batch 4/40
Processing LS batch 5/40
Processing LS batch 6/40
Processing LS batch 7/40
Processing LS batch 8/40
Processing LS batch 9/40
Processing LS batch 10/40
Processing LS batch 11/40
Processing LS batch 12/40
Processing LS batch 13/40
Processing LS batch 14/40
Processing LS batch 15/40
Processing LS batch 16/40
Processing LS batch 17/40
Processing LS batch 18/40
Processing LS b

  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")
  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")
  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")


Epoch 20/20 | Time: 36.1s | Train Loss: 0.8956 | Val Loss: 0.8920
Total training time: 0.13 hours
Final Test Loss: 0.4505
Model saved as final_model.pth
