In [1]:
import torch
import torch.nn as nn
import numpy as np
import h5py
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader

# Define paths explicitly here
DATA_FILE = '/users/1/kuma0458/open_channel_ret180/data/filter_calc_input.mat'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Running on {DEVICE}")

Running on cuda


In [2]:
class SineLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True, is_first=False, omega_0=30):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        
    def forward(self, input):
        return torch.sin(self.omega_0 * self.linear(input))

class DualFilterSiren(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            SineLayer(3, 256, is_first=True, omega_0=30),
            SineLayer(256, 256, is_first=False, omega_0=30),
            SineLayer(256, 256, is_first=False, omega_0=30),
            nn.Linear(256, 2)
        )
    def forward(self, x): return self.net(x)

In [3]:
import h5py
import numpy as np
import torch
import os

# Define paths explicitly
DATA_DIR = '/users/1/kuma0458/open_channel_ret180/data'
DATA_FILENAME = 'filter_calc_input.mat'

def load_matlab_data():
    filepath = os.path.join(DATA_DIR, DATA_FILENAME)
    print(f"Attempting to load data from: {filepath}")
    
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"Could not find data file at: {filepath}")

    with h5py.File(filepath, 'r') as f:
        # Load and Transpose (Fixing MATLAB v7.3 dimension swap)
        if 'phin' not in f.keys():
             raise KeyError("Variable 'phin' not found.")
        
        phin = np.array(f['phin']).transpose()
        W    = np.array(f['W']).transpose()
        Kxn  = np.array(f['Kxn']).transpose()
        Kyn  = np.array(f['Kyn']).transpose()
        Zn   = np.array(f['Zn']).transpose()

    print(f"Original Data Shapes: {phin.shape}")

    # Flatten inputs
    coords_flat = np.stack([Kxn.flatten(), Kyn.flatten(), Zn.flatten()], axis=1)
    targets_flat = (phin.flatten() < 0).astype(np.longlong)
    weights_flat = W.flatten()

    # --- THE FIX: CLEAN UP NaNs and Infs ---
    # 1. Check for NaNs (0/0 errors from normalization)
    if np.isnan(coords_flat).any() or np.isnan(weights_flat).any():
        print("⚠️ Warning: NaNs detected! Replacing with 0.0 to prevent crash.")
        coords_flat = np.nan_to_num(coords_flat, nan=0.0, posinf=0.0, neginf=0.0)
        weights_flat = np.nan_to_num(weights_flat, nan=0.0, posinf=0.0, neginf=0.0)
        
    # 2. Clip weights to be safe (strictly between 0 and 1)
    weights_flat = np.clip(weights_flat, 0.0, 1.0)
    # ----------------------------------------
    
    return (torch.tensor(coords_flat, dtype=torch.float32),
            torch.tensor(targets_flat, dtype=torch.long),
            torch.tensor(weights_flat, dtype=torch.float32))

In [None]:
# Load Data
# ... (Paste load_matlab_data logic here) ...
coords, targets, weights = load_matlab_data() # Ensure this function is defined

# Create DataLoader
# CRITICAL FOR JUPYTER: num_workers=0
dataset = TensorDataset(coords, targets, weights)
dataloader = DataLoader(dataset, batch_size=32768, shuffle=True, num_workers=0)

# Init Model
model = DualFilterSiren().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Training Loop with Progress Plotting
loss_history = []

print("Starting training...")
for epoch in range(2000): # Adjust epochs as needed
    epoch_loss = 0
    for batch_coords, batch_targets, batch_weights in dataloader:
        batch_coords, batch_targets, batch_weights = \
            batch_coords.to(DEVICE), batch_targets.to(DEVICE), batch_weights.to(DEVICE)

        optimizer.zero_grad()
        logits = model(batch_coords)
        
        # Weighted Loss
        raw_loss = torch.nn.functional.cross_entropy(logits, batch_targets, reduction='none')
        loss = (raw_loss * batch_weights).mean()
        
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        
    # Log Average Loss
    avg = epoch_loss / len(dataloader)
    loss_history.append(avg)
    
    if epoch % 100 == 0:
        print(f"Epoch {epoch} | Loss: {avg:.6f}")

# Plot Loss Curve inline
plt.figure()
plt.plot(loss_history)
plt.yscale('log')
plt.title('Training Convergence')
plt.show()

Attempting to load data from: /users/1/kuma0458/open_channel_ret180/data/filter_calc_input.mat
Original Data Shapes: (192, 255, 319)
Starting training...
Epoch 0 | Loss: 0.004173
