Removed the skip connection as an option in the candidate operations.

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import h5py
from scipy.io import loadmat
import matplotlib.pyplot as plt
import os
from time import time
from collections import defaultdict
import warnings

# ============ Visualization Setup ============
# Non-interactive plotting: figures are only written to disk, never displayed.
plt.ioff()
plt.rcParams['figure.constrained_layout.use'] = True  # Enable constrained layout

# Output folders for the periodic plots produced during training.
for _plot_dir in ("progress_plots", "architecture_plots", "channel_estimates"):
    os.makedirs(_plot_dir, exist_ok=True)

def plot_losses(epochs, train_losses, val_losses, test_losses=None):
    """Save train/val and test loss curves to ``progress_plots/losses_<unix time>.png``.

    Args:
        epochs: x-values shared by ``train_losses`` and ``val_losses``.
        train_losses: training loss per entry of ``epochs`` (any sequence).
        val_losses: validation loss per entry of ``epochs`` (any sequence).
        test_losses: optional dict with ``'epochs'`` and ``'values'`` lists,
            drawn in the right-hand panel. The panel stays blank when this is
            ``None`` or holds no values.

    A panel gets a log y-axis only when it has data and every value is > 0.
    """
    # Copy to plain lists so '+' below concatenates. With NumPy arrays it
    # would add element-wise instead.
    train_values = list(train_losses)
    val_values = list(val_losses)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6), constrained_layout=True)

        # Training and Validation Loss
        ax1.plot(epochs, train_values, label='Train Loss')
        ax1.plot(epochs, val_values, label='Val Loss')
        ax1.set_title('Training & Validation Loss', fontsize=12)
        ax1.set_xlabel('Epochs', fontsize=10)
        ax1.set_ylabel('MSE Loss', fontsize=10)
        combined = train_values + val_values
        if combined and all(y > 0 for y in combined):
            ax1.set_yscale('log')
        ax1.legend()
        ax1.grid(True)

        # Test Loss if available (a dict with empty lists counts as "not available")
        test_values = list(test_losses.get('values', [])) if test_losses else []
        if test_values:
            ax2.plot(test_losses['epochs'], test_values, 'r-')
            ax2.set_title('Test Loss Progression', fontsize=12)
            ax2.set_xlabel('Epochs', fontsize=10)
            ax2.set_ylabel('MSE Loss', fontsize=10)
            if all(y > 0 for y in test_values):
                ax2.set_yscale('log')
            ax2.grid(True)

        plt.savefig(f"progress_plots/losses_{int(time())}.png")
        plt.close(fig)

def plot_architecture(alphas, epoch):
    """Plot the evolution of alpha parameters as one bar chart per entry.

    Args:
        alphas: mapping from a label (shown as "Edge <label>" in the subplot
            title) to a 1-D sequence of operation weights.
        epoch: epoch number used in the file name
            ``architecture_plots/alpha_epoch_<epoch>.png``.

    The grid has two columns and as many rows as needed. A fixed 4x2 grid
    raised an error for a ninth entry. Bars are labelled with the ``OPS``
    names when a weight vector has one value per operation, and with their
    indices otherwise. Nothing is written for an empty mapping.
    """
    if not alphas:
        return
    op_names = list(OPS.keys())
    ncols = 2
    nrows = (len(alphas) + ncols - 1) // ncols
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # 4 inches of height per row, matching the old 18x16 figure for 4 rows.
        fig = plt.figure(figsize=(18, 4 * nrows), constrained_layout=True)

        for i, (edge, alpha) in enumerate(alphas.items(), start=1):
            ax = fig.add_subplot(nrows, ncols, i)
            ax.bar(range(len(alpha)), alpha, width=0.6)
            ax.set_title(f'Edge {edge} Alpha Values', fontsize=10)
            ax.set_xlabel('Operation', fontsize=8)
            ax.set_ylabel('Weight', fontsize=8)
            ax.set_xticks(range(len(alpha)))
            # The number of tick labels must match the number of ticks.
            if len(alpha) == len(op_names):
                labels = op_names
            else:
                labels = [str(k) for k in range(len(alpha))]
            ax.set_xticklabels(labels, rotation=60, ha='right', fontsize=7)
            ax.tick_params(axis='y', labelsize=7)

        plt.savefig(f"architecture_plots/alpha_epoch_{epoch}.png")
        plt.close(fig)

def plot_channel_estimates(model, test_loader, epoch, num_examples=3):
    """Plot example channel estimates.

    For each of the first ``num_examples`` batches of ``test_loader``, plot the
    true and predicted real/imaginary parts of the batch's first sample. Each
    figure is saved to ``channel_estimates/epoch_<epoch>_example_<i>.png``.

    Leaves ``model`` in eval mode; callers that keep training must switch it
    back with ``model.train()``.
    """
    model.eval()
    with torch.no_grad(), warnings.catch_warnings():
        # Same warning suppression as the other plotting helpers. The run log
        # shows a warning from every savefig call below without it.
        warnings.simplefilter("ignore")
        for i, (real_in, imag_in, real_tar, imag_tar) in enumerate(test_loader):
            if i >= num_examples:
                break
            inputs = torch.cat([real_in.unsqueeze(1), imag_in.unsqueeze(1)], dim=1).to(device)
            preds = model(inputs)
            # preds is (batch, 2, L). Index channel 0/1 of sample 0 so each curve
            # is 1-D. A (1, L) array would be drawn by matplotlib as L separate
            # one-point lines, i.e. nothing visible.
            rpred = preds[0, 0].cpu().numpy()
            ipred = preds[0, 1].cpu().numpy()

            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5), constrained_layout=True)

            ax1.plot(real_tar[0].numpy(), label='True Real')
            ax1.plot(rpred, label='Predicted Real')
            ax1.set_title('Real Component', fontsize=10)
            ax1.legend()

            ax2.plot(imag_tar[0].numpy(), label='True Imag')
            ax2.plot(ipred, label='Predicted Imag')
            ax2.set_title('Imaginary Component', fontsize=10)
            ax2.legend()

            plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")
            plt.close(fig)

def plot_learning_rates(lr_history):
    """Save the weight and alpha learning-rate curves side by side.

    Args:
        lr_history: dict with 'w_lr' and 'alpha_lr' sequences (one entry per
            epoch); written to progress_plots/learning_rates_<unix time>.png.
    """
    panels = (
        ('w_lr', 'Weight LR', 'Weight Learning Rate'),
        ('alpha_lr', 'Alpha LR', 'Alpha Learning Rate'),
    )
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        fig, axes = plt.subplots(1, 2, figsize=(14, 5), constrained_layout=True)

        for ax, (history_key, curve_label, panel_title) in zip(axes, panels):
            ax.plot(lr_history[history_key], label=curve_label)
            ax.set_title(panel_title, fontsize=10)
            ax.set_xlabel('Epoch', fontsize=8)
            ax.set_ylabel('Learning Rate', fontsize=8)

        plt.savefig(f"progress_plots/learning_rates_{int(time())}.png")
        plt.close()

def plot_gradient_norms(grad_norms):
    """Plot gradient norms over time.

    Args:
        grad_norms: dict with 'weight' and 'alpha' entries, each mapping a
            parameter name to its list of recorded gradient norms. Missing
            entries are treated as empty. Written to
            progress_plots/gradient_norms_<unix time>.png.

    An axis switches to a log y-scale when any of its series has a positive
    value. The check covers every series on that axis; previously it looked
    only at the last one and raised NameError when a group was empty.
    """
    panels = (
        ('weight', 'Weight Gradient Norms'),
        ('alpha', 'Alpha Gradient Norms'),
    )
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        fig, axes = plt.subplots(1, 2, figsize=(14, 5), constrained_layout=True)

        for ax, (group, title) in zip(axes, panels):
            series = grad_norms.get(group, {})
            for name, norms in series.items():
                ax.plot(norms, label=name)
            ax.set_title(title, fontsize=10)
            ax.set_xlabel('Epoch', fontsize=8)
            ax.set_ylabel('Gradient Norm', fontsize=8)
            if any(n > 0 for norms in series.values() for n in norms):
                ax.set_yscale('log')
            if series:  # a legend without any labelled line only warns
                ax.legend(fontsize=7)

        plt.savefig(f"progress_plots/gradient_norms_{int(time())}.png")
        plt.close(fig)

# ============ GPU Setup ============
# Use CUDA whenever PyTorch can see a GPU, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# ============ Load .mat Files with Fallback ============
def load_mat_file(file_path, var_name):
    """Return variable ``var_name`` from the MATLAB file at ``file_path``.

    h5py is tried first. When it raises ``OSError`` (e.g. the file is not
    HDF5-based), the file is re-read with ``scipy.io.loadmat``. The loader
    used is reported on stdout.

    NOTE(review): the two readers may return the same variable with a
    different axis order; callers must account for which path was taken.
    """
    try:
        with h5py.File(file_path, 'r') as handle:
            data = handle[var_name][:]
    except OSError:
        print(f"Failed to load {file_path} with h5py. Trying scipy.io.loadmat...")
        data = loadmat(file_path)[var_name]
        print(f"Loaded {file_path} using scipy.io.loadmat.")
    else:
        print(f"Loaded {file_path} using h5py.")
    return data

# ============ Load Data ============
def load_data():
    """Load the downlink data set from the parent directory.

    Returns:
        (y_complex, psi_real, psi_imag, h_complex):
        - y_complex and h_complex are built from files that store real and
          imaginary parts in the last axis.
        - psi_real and psi_imag are the two halves of PsiDL (index 0 and 1 of
          its first axis), each with its axes reversed via
          ``transpose(2, 1, 0)``.

    NOTE(review): the sigma2DL file is loaded but its contents are unused.
    Psi is the file read through h5py in the logged run, which presumably
    explains the axis reversal. Confirm.
    """
    y_raw = load_mat_file('../yDL_10dB_40k_150pilots_ipjp.mat', 'yDL')
    psi_raw = load_mat_file('../PsiDL_10dB_40k_150pilots_ipjp.mat', 'PsiDL')
    h_raw = load_mat_file('../hDL_10dB_40k_150pilots_ipjp.mat', 'hDL')
    load_mat_file('../sigma2DL_10dB_40k_150pilots_ipjp.mat', 'sigma2DL')  # result unused

    def as_complex(stacked):
        # The last axis holds [real, imag].
        return stacked[..., 0] + 1j * stacked[..., 1]

    psi_real, psi_imag = (psi_raw[part].transpose(2, 1, 0) for part in (0, 1))
    return as_complex(y_raw), psi_real, psi_imag, as_complex(h_raw)

y_complex_np, psi_real_np, psi_imag_np, h_complex_np = load_data()

# ============ Compute LS Estimate in Batches ============
# Per-sample least-squares channel estimate from the normal equations
#     (Psi^H Psi) h_LS = Psi^H y,
# solved on `device` in chunks of `batch_size` samples to bound memory use.
# The sample count and estimate length come from the loaded arrays rather
# than being hard-coded, so the later TensorDatasets always get matching
# lengths.
# NOTE(review): Psi^H Psi is only invertible when Psi has at least as many
# rows as unknowns (its last dimension); confirm this for the 150-pilot data.
num_samples = y_complex_np.shape[0]
if psi_real_np.shape[0] != num_samples or h_complex_np.shape[0] != num_samples:
    raise ValueError(
        f"Sample count mismatch: y={num_samples}, psi={psi_real_np.shape[0]}, "
        f"h={h_complex_np.shape[0]}"
    )
batch_size = 1000
num_batches = (num_samples + batch_size - 1) // batch_size  # ceil: counts a partial last batch
num_unknowns = psi_real_np.shape[-1]  # columns of Psi = length of h_LS
h_LS_complex = torch.empty((num_samples, num_unknowns), dtype=torch.cfloat)

for batch_idx, start in enumerate(range(0, num_samples, batch_size), start=1):
    end = min(start + batch_size, num_samples)
    print(f"Processing LS batch {batch_idx}/{num_batches}")

    psi_real_batch = psi_real_np[start:end]
    psi_imag_batch = psi_imag_np[start:end]
    y_batch = y_complex_np[start:end]

    psi_batch = torch.tensor(psi_real_batch + 1j*psi_imag_batch,
                             dtype=torch.cfloat).to(device)
    y_batch = torch.tensor(y_batch, dtype=torch.cfloat).to(device)

    psi_H = torch.conj(psi_batch.transpose(-2, -1))   # Psi^H per sample
    A = torch.matmul(psi_H, psi_batch)                # Psi^H Psi
    B = torch.matmul(psi_H, y_batch.unsqueeze(-1))    # Psi^H y as a column
    h_LS_batch = torch.linalg.solve(A, B).squeeze(-1)

    h_LS_complex[start:end] = h_LS_batch.cpu()

h_LS_real, h_LS_imag = h_LS_complex.real, h_LS_complex.imag
h_real = torch.tensor(h_complex_np.real, dtype=torch.float32)
h_imag = torch.tensor(h_complex_np.imag, dtype=torch.float32)

# ============ Dataset Split ============
# Contiguous (unshuffled) split of the sample axis: the first `train_samples`
# for training, the next `val_samples` for validation, and everything after
# that for testing. `test_samples` is informational only; the test split is
# open-ended. Every dataset yields (LS real, LS imag, true real, true imag).
train_samples, val_samples, test_samples = 32000, 4000, 4000
_val_end = train_samples + val_samples


def _split_dataset(index):
    """TensorDataset of the four channel tensors restricted to ``index``."""
    return TensorDataset(*(t[index] for t in (h_LS_real, h_LS_imag, h_real, h_imag)))


train_dataset = _split_dataset(slice(None, train_samples))
val_dataset = _split_dataset(slice(train_samples, _val_end))
test_dataset = _split_dataset(slice(_val_end, None))

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# ============ Neural Architecture Components ============
class FeatureExtractor(nn.Module):
    """Lift the 2-channel (real, imag) input to 32 feature channels.

    Two length-preserving 1-D convolutions (kernel 3, padding 1), each
    followed by ReLU: (batch, 2, L) -> (batch, 32, L).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=2, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, padding=1)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        features = F.relu(self.conv2(hidden))
        return features

class _Zero(nn.Module):
    """Candidate operation that switches an edge off: zeros shaped like the input."""

    def forward(self, x):
        return torch.zeros_like(x)


# Candidate operations for every searchable edge. Each entry maps a channel
# count C to a module that maps (batch, C, L) to (batch, C, L).
OPS = {
    'conv_3x3': lambda C: nn.Conv1d(C, C, 3, padding=1, bias=False),
    'conv_5x5': lambda C: nn.Conv1d(C, C, 5, padding=2, bias=False),
    'identity': lambda C: nn.Identity(),
    # nn.ZeroPad1d(0) pads nothing, i.e. it is an identity, not a zero op.
    'zero': lambda C: _Zero(),
}

class DenoiseCell(nn.Module):
    """Searchable cell mapping two C-channel inputs to one C-channel output.

    The cell keeps a fixed concatenation topology and places one mixed
    operation (the softmax(alphas[e])-weighted sum of all ``OPS``) on each of
    its 8 edges. Each edge acts on one C-channel slab (``|`` is channel
    concatenation):

        node0 = relu(e0(in0) | e1(in1))                              2C
        node1 = relu(e2(node0[:C]) | e3(node0[C:]))                  2C
        node2 = relu(e4..e7 on the four slabs of node0 | node1)      4C
        out   = relu(conv1x1(node0 | node1 | node2))                  C

    If every edge put all its weight on ``identity``, this is exactly the
    previous fixed cell. That cell never called ``apply_ops``, so ``alphas``
    got no gradient and the architecture search had no effect.

    ``in0`` and ``in1`` must have C channels each, so that the slabs line up
    with the 8 edges and ``conv1x1``'s 8*C inputs.
    """

    def __init__(self, C):
        super().__init__()
        self.C = C
        self.num_edges = 8  # one per C-channel slab: 2 (node0) + 2 (node1) + 4 (node2)
        self.num_ops = len(OPS)  # one architecture weight per candidate operation
        self.alphas = nn.Parameter(torch.randn(self.num_edges, self.num_ops))
        self.ops = nn.ModuleList([nn.ModuleList([op(C) for op in OPS.values()])
                                  for _ in range(self.num_edges)])
        self.conv1x1 = nn.Conv1d(8*C, C, 1, bias=False)

    def pad_and_concat(self, inputs):
        """Right-pad every tensor to the longest length, then concatenate on channels."""
        max_size = max(inp.shape[2] for inp in inputs)
        padded_inputs = [F.pad(inp, (0, max_size - inp.shape[2])) if inp.shape[2] < max_size else inp
                         for inp in inputs]
        return torch.cat(padded_inputs, dim=1)

    def apply_ops(self, x, edge_idx):
        """Mixed operation of edge ``edge_idx``: softmax(alphas)-weighted sum of all ops."""
        weights = F.softmax(self.alphas[edge_idx], dim=-1)
        return sum(w * op(x) for w, op in zip(weights, self.ops[edge_idx]))

    def _mix_slabs(self, x, first_edge):
        """Send consecutive C-channel slabs of ``x`` through edges first_edge, first_edge+1, ..."""
        slabs = torch.split(x, self.C, dim=1)
        mixed = [self.apply_ops(slab, first_edge + k) for k, slab in enumerate(slabs)]
        return self.pad_and_concat(mixed)

    def forward(self, inputs):
        in0, in1 = inputs

        node0 = F.relu(self._mix_slabs(self.pad_and_concat([in0, in1]), 0))   # edges 0-1
        node1 = F.relu(self._mix_slabs(node0, 2))                             # edges 2-3
        node2 = F.relu(self._mix_slabs(self.pad_and_concat([node0, node1]), 4))  # edges 4-7

        node3_inputs = self.pad_and_concat([node0, node1, node2])
        node3 = F.relu(self.conv1x1(node3_inputs))
        return node3

class DenoiseModule(nn.Module):
    """Chain of DenoiseCells in which each cell consumes two earlier states.

    Args:
        C: channel count of the input and of every cell's output.
        num_cells: number of stacked cells. Defaults to 10, the value that
            used to be hard-coded.
    """

    def __init__(self, C, num_cells=10):
        super().__init__()
        self.cells = nn.ModuleList([DenoiseCell(C) for _ in range(num_cells)])
        # Filled by forward(): one snapshot per cell, oldest first.
        self.state_history = []

    def forward(self, x):
        """Run every cell and return the last cell's output (``x`` if there are none).

        Cells 0 and 1 both receive ``[x, x]``. Cell i >= 2 receives the outputs
        of cells i-2 and i-1.
        NOTE(review): cell 1 therefore does not see cell 0's output (cell 0
        only feeds cell 2). This wiring is preserved as-is; confirm it is
        intended.

        After the call, ``state_history[i]`` is a dict with:
        - ``'outputs'``: the state list ``[x, x, out_0, ..., out_i]``;
        - ``'alpha'``: the ``alphas`` parameters of all cells. Only the
          architecture weights are recorded, not every parameter, so
          gradients taken w.r.t. this list skip the convolution weights.
        """
        outputs = [x, x]
        self.state_history = []
        cell_alphas = [cell.alphas for cell in self.cells]
        for i, cell in enumerate(self.cells):
            cell_inputs = [x, x] if i < 2 else [outputs[-2], outputs[-1]]
            outputs.append(cell(cell_inputs))
            self.state_history.append({
                'outputs': list(outputs),
                'alpha': list(cell_alphas)
            })
        return outputs[-1]

class Decoder(nn.Module):
    """Map 32 feature channels back to the 2-channel (real, imag) estimate.

    conv(32->16, k=3) + ReLU, then conv(16->2, k=3) with no activation so the
    output can take either sign. ``padding=1`` keeps the length unchanged.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=32, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_channels=16, out_channels=2, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv2(F.relu(self.conv1(x)))

class FullModel(nn.Module):
    """End-to-end estimator: feature extractor -> searchable denoiser -> decoder.

    Takes a (batch, 2, L) stack of real/imag parts and returns a
    (batch, 2, L) estimate.
    """

    def __init__(self):
        super().__init__()
        # Submodules are created in this order; it fixes the parameter init sequence.
        self.feature_extractor = FeatureExtractor()
        self.denoiser = DenoiseModule(32)  # 32 = FeatureExtractor output channels
        self.decoder = Decoder()

    def forward(self, x):
        features = self.feature_extractor(x)
        denoised = self.denoiser(features)
        return self.decoder(denoised)

# ============ Evaluation Function ============
def evaluate(model, loader, criterion):
    """Average ``criterion`` over every sample in ``loader`` without gradients.

    Inputs and targets are stacked as (batch, 2, L) real/imag channels, and
    the loss is computed on the stacked tensors. Each batch's loss is
    weighted by its batch size, so a short final batch no longer counts as
    much as a full one.

    With ``nn.MSELoss()`` (mean reduction) the result is the MSE over both
    channels. That is half of the real-MSE + imag-MSE sum used as the
    training loss.

    Returns:
        The sample-weighted mean loss, or ``float('nan')`` for an empty loader.
    """
    model.eval()
    total_loss = 0.0
    total_count = 0
    with torch.no_grad():
        for real_in, imag_in, real_tar, imag_tar in loader:
            inputs = torch.cat([real_in.unsqueeze(1), imag_in.unsqueeze(1)], dim=1).to(device)
            targets = torch.cat([real_tar.unsqueeze(1), imag_tar.unsqueeze(1)], dim=1).to(device)

            preds = model(inputs)
            loss = criterion(preds, targets)
            batch_n = real_in.shape[0]
            total_loss += loss.item() * batch_n
            total_count += batch_n
    if total_count == 0:
        return float('nan')
    return total_loss / total_count

# ============ Truncated RAD Implementation ============
def compute_truncated_grad(model, val_batch, criterion, truncation_steps=3):
    """Gradient of the validation loss w.r.t. the architecture weights (alphas).

    The truncation window is the last ``truncation_steps`` cells of
    ``model.denoiser``. Alphas of cells inside the window get the exact
    gradient of the validation loss; alphas of earlier cells get zeros. A
    cell's alphas reach the loss only through later cells, so no
    backpropagation beyond the window is needed for the in-window gradients.

    This replaces a hand-rolled adjoint recursion over
    ``model.denoiser.state_history``. That recursion re-applied vector-Jacobian
    products of the whole output history at every step, which counted
    contributions several times, and it cost about two extra backward passes
    per cell.

    Args:
        model: network exposing ``denoiser.cells``, each with an ``alphas``
            parameter.
        val_batch: tuple ``(real_in, imag_in, real_tar, imag_tar)``.
        criterion: loss applied separately to the real and imaginary channels.
        truncation_steps: number of trailing cells whose alphas get gradients.

    Returns:
        ``(grad_alpha, loss)``. ``grad_alpha`` is a list of detached tensors
        aligned with the model's ``'alphas'`` parameters in
        ``named_parameters()`` order. ``loss`` is the validation loss as a
        Python float.
    """
    real_in, imag_in, real_tar, imag_tar = val_batch
    inputs = torch.cat([real_in.unsqueeze(1).to(device), imag_in.unsqueeze(1).to(device)], dim=1)
    rtar = real_tar.unsqueeze(1).to(device)
    itar = imag_tar.unsqueeze(1).to(device)

    preds = model(inputs)
    rpred, ipred = preds.chunk(2, dim=1)
    loss = criterion(rpred, rtar) + criterion(ipred, itar)

    alpha_list = [p for n, p in model.named_parameters() if 'alphas' in n]
    grad_alpha = [torch.zeros_like(p) for p in alpha_list]

    cells = list(model.denoiser.cells)
    window = min(truncation_steps, len(cells))
    if window <= 0 or not alpha_list:
        return grad_alpha, loss.item()

    # Alphas of the cells inside the truncation window, keyed by identity.
    window_ids = {id(cell.alphas) for cell in cells[len(cells) - window:]}
    targets = [(idx, p) for idx, p in enumerate(alpha_list) if id(p) in window_ids]
    if not targets:
        return grad_alpha, loss.item()

    # allow_unused: alphas that do not influence the forward pass yield None.
    grads = torch.autograd.grad(loss, [p for _, p in targets], allow_unused=True)
    for (idx, _), g in zip(targets, grads):
        if g is not None:
            grad_alpha[idx] = g.detach()

    return grad_alpha, loss.item()

def truncated_rad_step(model, train_batch, val_batch, w_optimizer, alpha_optimizer,
                      criterion, truncation_steps=3):
    """One alternating search step.

    1. Update the network weights with the training loss on ``train_batch``.
    2. Update the alphas with the gradient from ``compute_truncated_grad`` on
       ``val_batch``. It is installed as ``.grad`` (added to any existing
       gradient) before ``alpha_optimizer.step()``.

    Returns:
        ``(train_loss, val_loss)`` as Python floats.
    """
    real_in, imag_in, real_tar, imag_tar = train_batch
    x = torch.cat([real_in.unsqueeze(1).to(device), imag_in.unsqueeze(1).to(device)], dim=1)

    # --- weight update on the training batch ---
    w_optimizer.zero_grad()
    pred_real, pred_imag = model(x).chunk(2, dim=1)
    target_real = real_tar.unsqueeze(1).to(device)
    target_imag = imag_tar.unsqueeze(1).to(device)
    loss_train = criterion(pred_real, target_real) + criterion(pred_imag, target_imag)
    loss_train.backward()
    w_optimizer.step()

    # --- architecture update on the validation batch ---
    grad_alpha, val_loss = compute_truncated_grad(model, val_batch, criterion, truncation_steps)

    alpha_optimizer.zero_grad()
    arch_params = [param for name, param in model.named_parameters() if 'alphas' in name]
    for param, grad in zip(arch_params, grad_alpha):
        if grad is None:
            continue
        grad = grad.to(device)
        if param.grad is None:
            param.grad = grad
        else:
            param.grad += grad
    alpha_optimizer.step()

    return loss_train.item(), val_loss

# ============ Modified Training Function ============
def train_truncated_rad(model, train_loader, val_loader, test_loader, w_optimizer,
                       alpha_optimizer, criterion, epochs=20, truncation_steps=3,
                       test_interval=50):
    """Alternate weight and architecture updates for ``epochs`` epochs.

    Every training batch is paired with the next validation batch (the
    validation loader is cycled) and passed to ``truncated_rad_step``.

    Recorded per epoch:
    - the average training and validation losses;
    - both learning rates;
    - the gradient norms left on each parameter after the epoch's last step.

    Periodic outputs:
    - every 10 epochs: loss, learning-rate, gradient-norm and architecture
      plots;
    - every ``test_interval`` epochs (if non-zero) and after the last epoch:
      the test loss, plus plots of example channel estimates.

    Returns:
        ``(train_losses, val_losses, test_losses)``: the per-epoch average
        losses, and a dict ``{'epochs': [...], 'values': [...]}`` of test
        losses.
    """
    from itertools import cycle
    val_iter = cycle(val_loader)

    # Initialize tracking variables
    train_losses = []
    val_losses = []
    test_losses = {'epochs': [], 'values': []}
    all_epochs = []
    start_time = time()

    lr_history = {'w_lr': [], 'alpha_lr': []}
    grad_norms = {
        'weight': defaultdict(list),
        'alpha': defaultdict(list)
    }
    alpha_history = []  # (epoch, {edge: op weights}) waiting to be plotted
    # Guard the per-epoch averages against an empty training loader.
    num_batches = max(len(train_loader), 1)

    for epoch in range(epochs):
        epoch_no = epoch + 1
        epoch_start = time()
        model.train()
        total_train_loss = 0.0
        total_val_loss = 0.0

        lr_history['w_lr'].append(w_optimizer.param_groups[0]['lr'])
        lr_history['alpha_lr'].append(alpha_optimizer.param_groups[0]['lr'])

        for train_batch in train_loader:
            val_batch = next(val_iter)
            train_loss, val_loss = truncated_rad_step(
                model, train_batch, val_batch,
                w_optimizer, alpha_optimizer,
                criterion, truncation_steps
            )
            total_train_loss += train_loss
            total_val_loss += val_loss

        # Gradient norms left by the last step of the epoch
        with torch.no_grad():
            for name, param in model.named_parameters():
                if param.grad is not None:
                    group = 'alpha' if 'alphas' in name else 'weight'
                    grad_norms[group][name].append(param.grad.norm().item())

        # Architecture snapshot: one bar chart per edge of the last cell.
        # (Keying by name.split('.')[1] gives 'cells' for every cell, so all
        # snapshots but one overwrote each other.)
        if epoch_no % 10 == 0:
            alpha_params = [p for n, p in model.named_parameters() if 'alphas' in n]
            if alpha_params:
                probs = F.softmax(alpha_params[-1], dim=-1).detach().cpu().numpy()
                alpha_history.append(
                    (epoch_no, {edge: probs[edge] for edge in range(probs.shape[0])})
                )

        # One validation batch is drawn per training batch, hence the same divisor.
        avg_train_loss = total_train_loss / num_batches
        avg_val_loss = total_val_loss / num_batches
        # Per-epoch curves used by the loss plots and the return value.
        all_epochs.append(epoch_no)
        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)

        # Periodic testing and visualization
        if (test_interval and epoch_no % test_interval == 0) or epoch_no == epochs:
            test_loss = evaluate(model, test_loader, criterion)
            test_losses['epochs'].append(epoch_no)
            # Tiny offset keeps the value strictly positive for the log-scale plot.
            test_losses['values'].append(test_loss + 1e-12)
            print(f"Test Loss @ Epoch {epoch_no}: {test_loss:.4f}")

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                plot_channel_estimates(model, test_loader, epoch_no)

        print(f"Epoch {epoch_no}/{epochs} | Time: {time()-epoch_start:.1f}s | "
              f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

        if epoch_no % 10 == 0:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                plot_losses(all_epochs, train_losses, val_losses, test_losses)
                plot_learning_rates(lr_history)
                plot_gradient_norms(grad_norms)

                for epoch_num, alphas in alpha_history:
                    plot_architecture(alphas, epoch_num)
                alpha_history = []

    # Final plots
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        plot_losses(all_epochs, train_losses, val_losses, test_losses)
        plot_learning_rates(lr_history)
        plot_gradient_norms(grad_norms)
    print(f"Total training time: {(time()-start_time)/3600:.2f} hours")

    return train_losses, val_losses, test_losses

# ============ Main Execution ============
if __name__ == "__main__":
    model = FullModel().to(device)

    # Architecture weights ('alphas') and ordinary network weights get
    # separate optimizers; parameter order within each group is preserved.
    alpha_params = []
    w_params = []
    for param_name, param in model.named_parameters():
        if 'alphas' in param_name:
            alpha_params.append(param)
        else:
            w_params.append(param)

    # NOTE(review): in the logged run the training loss stays at 0.8956 from
    # epoch 50 onward. Check whether lr=0.01 for the weights is too aggressive
    # (e.g. dead ReLUs).
    w_optim = optim.Adam(w_params, lr=0.01)
    alpha_optim = optim.Adam(alpha_params, lr=0.003)
    criterion = nn.MSELoss()

    print("Starting training with Truncated RAD...")
    train_losses, val_losses, test_losses = train_truncated_rad(
        model, train_loader, val_loader, test_loader,
        w_optim, alpha_optim, criterion,
        epochs=2000, truncation_steps=50, test_interval=50,
    )

    # Final evaluation on the held-out test split
    test_loss = evaluate(model, test_loader, criterion)
    print(f"Final Test Loss: {test_loss:.4f}")

    # Persist the trained weights (state_dict only)
    torch.save(model.state_dict(), "final_model.pth")
    print("Model saved as final_model.pth")

Using device: cuda
Failed to load ../yDL_10dB_40k_150pilots_ipjp.mat with h5py. Trying scipy.io.loadmat...
Loaded ../yDL_10dB_40k_150pilots_ipjp.mat using scipy.io.loadmat.
Loaded ../PsiDL_10dB_40k_150pilots_ipjp.mat using h5py.
Failed to load ../hDL_10dB_40k_150pilots_ipjp.mat with h5py. Trying scipy.io.loadmat...
Loaded ../hDL_10dB_40k_150pilots_ipjp.mat using scipy.io.loadmat.
Failed to load ../sigma2DL_10dB_40k_150pilots_ipjp.mat with h5py. Trying scipy.io.loadmat...
Loaded ../sigma2DL_10dB_40k_150pilots_ipjp.mat using scipy.io.loadmat.
Processing LS batch 1/40
Processing LS batch 2/40
Processing LS batch 3/40
Processing LS batch 4/40
Processing LS batch 5/40
Processing LS batch 6/40
Processing LS batch 7/40
Processing LS batch 8/40
Processing LS batch 9/40
Processing LS batch 10/40
Processing LS batch 11/40
Processing LS batch 12/40
Processing LS batch 13/40
Processing LS batch 14/40
Processing LS batch 15/40
Processing LS batch 16/40
Processing LS batch 17/40
Processing LS batch 

  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")
  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")
  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")


Epoch 50/2000 | Time: 41.2s | Train Loss: 0.8956 | Val Loss: 0.8919
Epoch 51/2000 | Time: 28.1s | Train Loss: 0.8956 | Val Loss: 0.8926
Epoch 52/2000 | Time: 28.2s | Train Loss: 0.8956 | Val Loss: 0.8924
Epoch 53/2000 | Time: 28.2s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 54/2000 | Time: 28.1s | Train Loss: 0.8956 | Val Loss: 0.8925
Epoch 55/2000 | Time: 28.2s | Train Loss: 0.8956 | Val Loss: 0.8920
Epoch 56/2000 | Time: 28.2s | Train Loss: 0.8956 | Val Loss: 0.8919
Epoch 57/2000 | Time: 28.2s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 58/2000 | Time: 28.3s | Train Loss: 0.8956 | Val Loss: 0.8924
Epoch 59/2000 | Time: 28.2s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 60/2000 | Time: 28.0s | Train Loss: 0.8956 | Val Loss: 0.8924
Epoch 61/2000 | Time: 27.8s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 62/2000 | Time: 28.0s | Train Loss: 0.8956 | Val Loss: 0.8921
Epoch 63/2000 | Time: 27.9s | Train Loss: 0.8956 | Val Loss: 0.8924
Epoch 64/2000 | Time: 28.3s | Train Loss: 0.8956

  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")
  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")
  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")


Epoch 100/2000 | Time: 41.9s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 101/2000 | Time: 28.4s | Train Loss: 0.8956 | Val Loss: 0.8925
Epoch 102/2000 | Time: 28.5s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 103/2000 | Time: 28.3s | Train Loss: 0.8956 | Val Loss: 0.8921
Epoch 104/2000 | Time: 28.2s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 105/2000 | Time: 28.0s | Train Loss: 0.8956 | Val Loss: 0.8925
Epoch 106/2000 | Time: 28.0s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 107/2000 | Time: 28.2s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 108/2000 | Time: 28.4s | Train Loss: 0.8956 | Val Loss: 0.8920
Epoch 109/2000 | Time: 28.4s | Train Loss: 0.8956 | Val Loss: 0.8924
Epoch 110/2000 | Time: 28.4s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 111/2000 | Time: 28.4s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 112/2000 | Time: 28.5s | Train Loss: 0.8956 | Val Loss: 0.8926
Epoch 113/2000 | Time: 28.6s | Train Loss: 0.8956 | Val Loss: 0.8919
Epoch 114/2000 | Time: 28.5s | Tra

  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")
  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")
  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")


Epoch 150/2000 | Time: 41.7s | Train Loss: 0.8956 | Val Loss: 0.8921
Epoch 151/2000 | Time: 28.3s | Train Loss: 0.8956 | Val Loss: 0.8921
Epoch 152/2000 | Time: 28.3s | Train Loss: 0.8956 | Val Loss: 0.8926
Epoch 153/2000 | Time: 28.6s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 154/2000 | Time: 28.6s | Train Loss: 0.8956 | Val Loss: 0.8921
Epoch 155/2000 | Time: 28.8s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 156/2000 | Time: 28.7s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 157/2000 | Time: 28.4s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 158/2000 | Time: 28.3s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 159/2000 | Time: 28.4s | Train Loss: 0.8956 | Val Loss: 0.8926
Epoch 160/2000 | Time: 28.1s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 161/2000 | Time: 28.2s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 162/2000 | Time: 28.3s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 163/2000 | Time: 28.8s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 164/2000 | Time: 28.3s | Tra

  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")
  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")
  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")


Epoch 200/2000 | Time: 42.0s | Train Loss: 0.8956 | Val Loss: 0.8924
Epoch 201/2000 | Time: 28.4s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 202/2000 | Time: 28.7s | Train Loss: 0.8956 | Val Loss: 0.8924
Epoch 203/2000 | Time: 28.2s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 204/2000 | Time: 28.2s | Train Loss: 0.8956 | Val Loss: 0.8921
Epoch 205/2000 | Time: 28.2s | Train Loss: 0.8956 | Val Loss: 0.8924
Epoch 206/2000 | Time: 28.1s | Train Loss: 0.8956 | Val Loss: 0.8925
Epoch 207/2000 | Time: 28.9s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 208/2000 | Time: 28.2s | Train Loss: 0.8956 | Val Loss: 0.8926
Epoch 209/2000 | Time: 28.2s | Train Loss: 0.8956 | Val Loss: 0.8920
Epoch 210/2000 | Time: 28.7s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 211/2000 | Time: 28.6s | Train Loss: 0.8956 | Val Loss: 0.8927
Epoch 212/2000 | Time: 28.3s | Train Loss: 0.8956 | Val Loss: 0.8921
Epoch 213/2000 | Time: 28.6s | Train Loss: 0.8956 | Val Loss: 0.8921
Epoch 214/2000 | Time: 28.5s | Tra

  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")
  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")
  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")


Epoch 250/2000 | Time: 42.0s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 251/2000 | Time: 28.0s | Train Loss: 0.8956 | Val Loss: 0.8921
Epoch 252/2000 | Time: 28.3s | Train Loss: 0.8956 | Val Loss: 0.8924
Epoch 253/2000 | Time: 28.3s | Train Loss: 0.8956 | Val Loss: 0.8925
Epoch 254/2000 | Time: 28.2s | Train Loss: 0.8956 | Val Loss: 0.8924
Epoch 255/2000 | Time: 28.2s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 256/2000 | Time: 29.1s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 257/2000 | Time: 28.8s | Train Loss: 0.8956 | Val Loss: 0.8921
Epoch 258/2000 | Time: 28.3s | Train Loss: 0.8956 | Val Loss: 0.8924
Epoch 259/2000 | Time: 28.3s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 260/2000 | Time: 28.3s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 261/2000 | Time: 28.1s | Train Loss: 0.8956 | Val Loss: 0.8920
Epoch 262/2000 | Time: 28.0s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 263/2000 | Time: 28.1s | Train Loss: 0.8956 | Val Loss: 0.8924
Epoch 264/2000 | Time: 28.3s | Tra

  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")
  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")
  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")


Epoch 300/2000 | Time: 41.7s | Train Loss: 0.8956 | Val Loss: 0.8924
Epoch 301/2000 | Time: 28.1s | Train Loss: 0.8956 | Val Loss: 0.8926
Epoch 302/2000 | Time: 28.7s | Train Loss: 0.8956 | Val Loss: 0.8919
Epoch 303/2000 | Time: 28.6s | Train Loss: 0.8956 | Val Loss: 0.8926
Epoch 304/2000 | Time: 28.9s | Train Loss: 0.8956 | Val Loss: 0.8924
Epoch 305/2000 | Time: 28.5s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 306/2000 | Time: 27.7s | Train Loss: 0.8956 | Val Loss: 0.8925
Epoch 307/2000 | Time: 28.3s | Train Loss: 0.8956 | Val Loss: 0.8920
Epoch 308/2000 | Time: 28.3s | Train Loss: 0.8956 | Val Loss: 0.8919
Epoch 309/2000 | Time: 28.1s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 310/2000 | Time: 28.4s | Train Loss: 0.8956 | Val Loss: 0.8924
Epoch 311/2000 | Time: 28.0s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 312/2000 | Time: 28.4s | Train Loss: 0.8956 | Val Loss: 0.8924
Epoch 313/2000 | Time: 28.2s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 314/2000 | Time: 28.3s | Tra

  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")
  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")
  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")


Epoch 350/2000 | Time: 41.2s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 351/2000 | Time: 28.1s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 352/2000 | Time: 28.7s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 353/2000 | Time: 28.0s | Train Loss: 0.8956 | Val Loss: 0.8925
Epoch 354/2000 | Time: 27.5s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 355/2000 | Time: 27.6s | Train Loss: 0.8956 | Val Loss: 0.8921
Epoch 356/2000 | Time: 27.7s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 357/2000 | Time: 28.1s | Train Loss: 0.8956 | Val Loss: 0.8925
Epoch 358/2000 | Time: 28.2s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 359/2000 | Time: 28.9s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 360/2000 | Time: 27.8s | Train Loss: 0.8956 | Val Loss: 0.8921
Epoch 361/2000 | Time: 28.4s | Train Loss: 0.8956 | Val Loss: 0.8924
Epoch 362/2000 | Time: 27.8s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 363/2000 | Time: 28.1s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 364/2000 | Time: 28.1s | Tra

  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")
  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")
  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")


Epoch 400/2000 | Time: 40.9s | Train Loss: 0.8956 | Val Loss: 0.8927
Epoch 401/2000 | Time: 27.5s | Train Loss: 0.8956 | Val Loss: 0.8921
Epoch 402/2000 | Time: 27.6s | Train Loss: 0.8956 | Val Loss: 0.8921
Epoch 403/2000 | Time: 27.6s | Train Loss: 0.8956 | Val Loss: 0.8921
Epoch 404/2000 | Time: 27.6s | Train Loss: 0.8956 | Val Loss: 0.8926
Epoch 405/2000 | Time: 27.6s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 406/2000 | Time: 27.6s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 407/2000 | Time: 27.7s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 408/2000 | Time: 27.6s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 409/2000 | Time: 27.5s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 410/2000 | Time: 27.5s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 411/2000 | Time: 27.6s | Train Loss: 0.8956 | Val Loss: 0.8926
Epoch 412/2000 | Time: 27.6s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 413/2000 | Time: 27.6s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 414/2000 | Time: 27.6s | Tra

  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")
  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")
  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")


Epoch 450/2000 | Time: 41.4s | Train Loss: 0.8956 | Val Loss: 0.8920
Epoch 451/2000 | Time: 27.6s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 452/2000 | Time: 27.7s | Train Loss: 0.8956 | Val Loss: 0.8924
Epoch 453/2000 | Time: 27.9s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 454/2000 | Time: 27.8s | Train Loss: 0.8956 | Val Loss: 0.8924
Epoch 455/2000 | Time: 27.9s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 456/2000 | Time: 27.8s | Train Loss: 0.8956 | Val Loss: 0.8921
Epoch 457/2000 | Time: 28.0s | Train Loss: 0.8956 | Val Loss: 0.8924
Epoch 458/2000 | Time: 27.8s | Train Loss: 0.8956 | Val Loss: 0.8925
Epoch 459/2000 | Time: 27.7s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 460/2000 | Time: 27.8s | Train Loss: 0.8956 | Val Loss: 0.8926
Epoch 461/2000 | Time: 27.7s | Train Loss: 0.8956 | Val Loss: 0.8920
Epoch 462/2000 | Time: 27.9s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 463/2000 | Time: 27.5s | Train Loss: 0.8956 | Val Loss: 0.8927
Epoch 464/2000 | Time: 27.6s | Tra

  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")
  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")
  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")


Epoch 500/2000 | Time: 41.3s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 501/2000 | Time: 27.6s | Train Loss: 0.8956 | Val Loss: 0.8924
Epoch 502/2000 | Time: 27.7s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 503/2000 | Time: 27.5s | Train Loss: 0.8956 | Val Loss: 0.8921
Epoch 504/2000 | Time: 27.5s | Train Loss: 0.8956 | Val Loss: 0.8924
Epoch 505/2000 | Time: 27.7s | Train Loss: 0.8956 | Val Loss: 0.8925
Epoch 506/2000 | Time: 27.4s | Train Loss: 0.8956 | Val Loss: 0.8924
Epoch 507/2000 | Time: 27.7s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 508/2000 | Time: 27.8s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 509/2000 | Time: 27.6s | Train Loss: 0.8956 | Val Loss: 0.8921
Epoch 510/2000 | Time: 27.4s | Train Loss: 0.8956 | Val Loss: 0.8924
Epoch 511/2000 | Time: 27.7s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 512/2000 | Time: 27.6s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 513/2000 | Time: 27.6s | Train Loss: 0.8956 | Val Loss: 0.8920
Epoch 514/2000 | Time: 27.8s | Tra

  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")
  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")
  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")


Epoch 550/2000 | Time: 41.2s | Train Loss: 0.8956 | Val Loss: 0.8924
Epoch 551/2000 | Time: 27.6s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 552/2000 | Time: 27.6s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 553/2000 | Time: 27.4s | Train Loss: 0.8956 | Val Loss: 0.8926
Epoch 554/2000 | Time: 28.6s | Train Loss: 0.8956 | Val Loss: 0.8919
Epoch 555/2000 | Time: 28.2s | Train Loss: 0.8956 | Val Loss: 0.8926
Epoch 556/2000 | Time: 28.4s | Train Loss: 0.8956 | Val Loss: 0.8924
Epoch 557/2000 | Time: 27.8s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 558/2000 | Time: 27.7s | Train Loss: 0.8956 | Val Loss: 0.8925
Epoch 559/2000 | Time: 27.8s | Train Loss: 0.8956 | Val Loss: 0.8920
Epoch 560/2000 | Time: 27.6s | Train Loss: 0.8956 | Val Loss: 0.8919
Epoch 561/2000 | Time: 27.7s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 562/2000 | Time: 27.7s | Train Loss: 0.8956 | Val Loss: 0.8924
Epoch 563/2000 | Time: 27.8s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 564/2000 | Time: 28.3s | Tra

  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")
  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")
  plt.savefig(f"channel_estimates/epoch_{epoch}_example_{i}.png")


Epoch 600/2000 | Time: 41.5s | Train Loss: 0.8956 | Val Loss: 0.8926
Epoch 601/2000 | Time: 28.3s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 602/2000 | Time: 28.5s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 603/2000 | Time: 28.3s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 604/2000 | Time: 28.4s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 605/2000 | Time: 28.6s | Train Loss: 0.8956 | Val Loss: 0.8925
Epoch 606/2000 | Time: 28.6s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 607/2000 | Time: 28.5s | Train Loss: 0.8956 | Val Loss: 0.8921
Epoch 608/2000 | Time: 28.5s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 609/2000 | Time: 28.6s | Train Loss: 0.8956 | Val Loss: 0.8925
Epoch 610/2000 | Time: 28.5s | Train Loss: 0.8956 | Val Loss: 0.8922
Epoch 611/2000 | Time: 28.6s | Train Loss: 0.8956 | Val Loss: 0.8923
Epoch 612/2000 | Time: 29.8s | Train Loss: 0.8956 | Val Loss: 0.8920
Epoch 613/2000 | Time: 29.5s | Train Loss: 0.8956 | Val Loss: 0.8924
Epoch 614/2000 | Time: 28.8s | Tra