# Test: Train a CNN to predict material parameters (n, k, d per layer) for a fixed number of layers (3)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, random_split
import torch.optim as optim
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

In [2]:
# Load the dataset.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python objects, which can
# execute code. Only load this file if it comes from a trusted source.
data = torch.load("data/Synthetic_data_100k_1to5.pt", weights_only=False)

# Extract components
synthetic_data = data["synthetic_data"]
material_params = data["material_params"]
num_layers = data["num_layers"]

# Boolean mask selecting entries with exactly 3 layers
mask = num_layers == 3

# Apply the mask. The tensors are indexed directly. material_params is a per-sample sequence
# (not a tensor), so filter it against the mask converted to plain bools once. This avoids
# creating a 0-d tensor for every element, as mask[i] did.
synthetic_data_3layers = synthetic_data[mask]
material_params_3layers = [params for params, keep in zip(material_params, mask.tolist()) if keep]
num_layers_3layers = num_layers[mask]

# All filtered views must describe the same samples. This also catches a length mismatch
# between material_params and the mask, which zip() would otherwise silently truncate.
assert len(material_params_3layers) == len(synthetic_data_3layers) == len(num_layers_3layers)

# Confirm the filtering
print("Filtered dataset for 3-layer samples:")
print(f"Number of samples: {len(synthetic_data_3layers)}")
print(f"Shape of synthetic_data: {synthetic_data_3layers.shape}")
print(f"Shape of num_layers: {num_layers_3layers.shape}")
print(f"Example number of layers: {num_layers_3layers[:10]}")



Filtered dataset for 3-layer samples:
Number of samples: 19993
Shape of synthetic_data: torch.Size([19993, 1024])
Shape of num_layers: torch.Size([19993])
Example number of layers: tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3])


In [3]:
def convert_to_nkd(params_per_sample):
    """
    Flatten one sample's layer list into a 1-D float32 tensor.

    Each entry of ``params_per_sample`` is a ``(complex_index, thickness)`` pair. The result is
    ``[n1, k1, d1, n2, k2, d2, ...]``, i.e. three values per layer (nine for a 3-layer sample),
    where ``n`` is the real part and ``k`` the *signed* imaginary part of the complex index.
    With data stored as ``n - ik`` (as in this dataset's printed sample), ``k`` comes out
    negative. An empty layer list yields an empty tensor.
    """
    flat_values = [
        component
        for refractive_index, thickness in params_per_sample
        for component in (refractive_index.real, refractive_index.imag, thickness)
    ]
    return torch.tensor(flat_values, dtype=torch.float32)

# Convert the list to a tensor of shape [num_samples, 9]
targets_nkd = torch.stack([convert_to_nkd(params) for params in material_params_3layers])

print(material_params_3layers[0])
print(targets_nkd[0])

[((4.474795207839568-0.013104228373755605j), 0.0006001358074770559), ((1.4531826326434016-0.009297615570658752j), 0.0006093181017557209), ((5.400491992529762-0.041721082325605646j), 0.0007296340700163872)]
tensor([ 4.4748e+00, -1.3104e-02,  6.0014e-04,  1.4532e+00, -9.2976e-03,
         6.0932e-04,  5.4005e+00, -4.1721e-02,  7.2963e-04])


In [4]:
class ConditionedCNN(nn.Module):
    """1-D CNN regressor mapping a single-channel signal to ``output_dim`` values.

    Three Conv1d(kernel 5, padding 2) + ReLU stages widen the channels 1 -> 32 -> 64 -> 128.
    Each of the first two stages is followed by MaxPool1d(2), and the last by global average
    pooling. A two-layer MLP head (with 20% dropout) produces the outputs.

    ``input_length`` is accepted for API compatibility but not used. AdaptiveAvgPool1d makes
    the network independent of the signal length.
    NOTE(review): despite the class name, forward() takes no conditioning input.
    """

    def __init__(self, input_length=1024, output_dim=9):
        super().__init__()
        channel_plan = (1, 32, 64, 128)
        stage_pairs = list(zip(channel_plan[:-1], channel_plan[1:]))

        feature_layers = []
        for stage_idx, (c_in, c_out) in enumerate(stage_pairs):
            feature_layers.append(nn.Conv1d(c_in, c_out, kernel_size=5, padding=2))
            feature_layers.append(nn.ReLU())
            final_stage = stage_idx == len(stage_pairs) - 1
            # Downsample between stages; collapse the length axis to 1 after the last stage
            feature_layers.append(nn.AdaptiveAvgPool1d(1) if final_stage else nn.MaxPool1d(2))
        self.cnn = nn.Sequential(*feature_layers)

        width = channel_plan[-1]
        self.fc = nn.Sequential(
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(width, output_dim),
        )

    def forward(self, x):
        pooled = self.cnn(x)            # [batch, 128, 1]
        features = pooled.squeeze(-1)   # [batch, 128]
        return self.fc(features)
    

In [5]:
def _evaluate_mse(model, loader, criterion, device):
    """Return the sample-weighted mean of ``criterion`` over ``loader`` (NaN if it yields nothing)."""
    model.eval()
    loss_sum = 0.0
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            batch_size = inputs.size(0)
            # criterion returns a per-batch mean; weight by batch size so a short final
            # batch does not count as much as a full one
            loss_sum += criterion(model(inputs), targets).item() * batch_size
            n_seen += batch_size
    return loss_sum / n_seen if n_seen else float("nan")


def train_conditioned_model(model, train_loader, val_loader, num_epochs=10, lr=1e-3, device=None):
    """Train ``model`` with Adam on an MSE objective and record per-epoch train/val loss.

    Args:
        model: nn.Module whose outputs are compared to the targets by nn.MSELoss.
        train_loader: iterable of (inputs, targets) tensor batches used for optimization.
        val_loader: iterable of (inputs, targets) tensor batches evaluated after each epoch.
        num_epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        device: device to train on. Defaults to CUDA when available, otherwise CPU.

    Returns:
        (train_loss_hist, val_loss_hist): lists with one float per epoch. Each value is the
        sample-weighted mean MSE over the loader, or NaN if the loader yielded no samples.
        Training losses are accumulated while the weights are being updated.

    The model is moved to ``device`` in place (nn.Module.to) and left there.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    train_loss_hist = []
    val_loss_hist = []

    for epoch in range(num_epochs):
        model.train()
        loss_sum = 0.0
        n_seen = 0

        # Training loop with tqdm progress bar
        with tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch") as pbar:
            for inputs, targets in pbar:
                inputs = inputs.to(device)
                targets = targets.to(device)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

                batch_size = inputs.size(0)
                loss_sum += loss.item() * batch_size
                n_seen += batch_size
                # Use our own sample counter: tqdm only refreshes pbar.n on display updates
                # (and never when disabled), so it lags the true number of processed batches.
                pbar.set_postfix(loss=loss_sum / n_seen)  # Running average

        train_loss = loss_sum / n_seen if n_seen else float("nan")
        train_loss_hist.append(train_loss)

        val_loss = _evaluate_mse(model, val_loader, criterion, device)
        val_loss_hist.append(val_loss)

        print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f} | Val Loss = {val_loss:.4f}")

    return train_loss_hist, val_loss_hist


In [6]:
# Prepare DataLoaders for the 3-layer subset.
SEED = 42             # fixes the train/val split, shuffling order and weight initialization
TRAIN_FRACTION = 0.8
BATCH_SIZE = 64

# synthetic_data_3layers is [N, 1024]; Conv1d expects [N, channels, length], so add a channel axis
inputs = synthetic_data_3layers.unsqueeze(1)  # → [N, 1, 1024]
dataset = TensorDataset(inputs, targets_nkd)

n_train = int(TRAIN_FRACTION * len(dataset))
train_ds, val_ds = random_split(
    dataset,
    [n_train, len(dataset) - n_train],
    generator=torch.Generator().manual_seed(SEED),  # reproducible split across runs
)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)

# NOTE(review): targets are not rescaled. In the sample printed above, n is ~1-5, |k| ~ 1e-2
# and d ~ 6e-4, so the MSE loss is dominated by the n columns and carries almost no signal for
# k and d. Consider standardizing each target column (and inverting the scaling on predictions).

# Initialize and train (seed the global RNG so weight init and DataLoader shuffling repeat)
torch.manual_seed(SEED)
model = ConditionedCNN()
train_loss_hist, val_loss_hist = train_conditioned_model(model, train_loader, val_loader)

Epoch 1/10: 100%|██████████| 250/250 [00:44<00:00,  5.56batch/s, loss=0.965]


Epoch 1: Train Loss = 0.9651 | Val Loss = 0.6710


Epoch 2/10: 100%|██████████| 250/250 [00:39<00:00,  6.41batch/s, loss=0.696]


Epoch 2: Train Loss = 0.6955 | Val Loss = 0.6584


Epoch 3/10: 100%|██████████| 250/250 [00:39<00:00,  6.40batch/s, loss=0.676]


Epoch 3: Train Loss = 0.6760 | Val Loss = 0.6300


Epoch 4/10: 100%|██████████| 250/250 [00:39<00:00,  6.40batch/s, loss=0.634]


Epoch 4: Train Loss = 0.6341 | Val Loss = 0.5939


Epoch 5/10: 100%|██████████| 250/250 [00:38<00:00,  6.46batch/s, loss=0.604]


Epoch 5: Train Loss = 0.6038 | Val Loss = 0.5590


Epoch 6/10: 100%|██████████| 250/250 [00:38<00:00,  6.50batch/s, loss=0.574]


Epoch 6: Train Loss = 0.5740 | Val Loss = 0.5133


Epoch 7/10: 100%|██████████| 250/250 [00:38<00:00,  6.52batch/s, loss=0.53] 


Epoch 7: Train Loss = 0.5299 | Val Loss = 0.4893


Epoch 8/10: 100%|██████████| 250/250 [00:39<00:00,  6.41batch/s, loss=0.511]


Epoch 8: Train Loss = 0.5108 | Val Loss = 0.4635


Epoch 9/10: 100%|██████████| 250/250 [00:37<00:00,  6.68batch/s, loss=0.488]


Epoch 9: Train Loss = 0.4879 | Val Loss = 0.4513


Epoch 10/10: 100%|██████████| 250/250 [00:38<00:00,  6.50batch/s, loss=0.473]


Epoch 10: Train Loss = 0.4727 | Val Loss = 0.4479
