In [1]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam
import torch.nn.functional as F
from tqdm import tqdm
from scipy.io import loadmat
import numpy as np


In [2]:
class TopologyDataset(torch.utils.data.Dataset):
    """Paired dataset of MATLAB ``.mat`` files read from two folders.

    The ``.mat`` file names of each folder are sorted independently and paired
    by position, so sample ``i`` is the i-th input file together with the i-th
    output file. Input files must contain a variable ``'X'`` and output files
    a variable ``'Y'``.
    """

    def __init__(self, input_folder, output_folder):
        """Index both folders and check that they hold the same number of files.

        Args:
            input_folder: Directory containing the input ``.mat`` files.
            output_folder: Directory containing the target ``.mat`` files.

        Raises:
            AssertionError: If the two folders hold a different number of
                ``.mat`` files.
        """
        self.input_folder = input_folder
        self.output_folder = output_folder
        self.input_files = self._mat_files_in(input_folder)
        self.output_files = self._mat_files_in(output_folder)
        assert len(self.input_files) == len(self.output_files), "Input/output mismatch"

    @staticmethod
    def _mat_files_in(folder):
        """Return the sorted names of the ``.mat`` files directly inside ``folder``."""
        names = [name for name in os.listdir(folder) if name.endswith('.mat')]
        names.sort()
        return names

    def __len__(self):
        """Number of (input, output) pairs."""
        return len(self.input_files)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return it as two float32 tensors.

        Returns:
            ``(X, Y)`` where ``X`` is the ``'X'`` array with its last axis
            moved to the front (H x W x 3 -> 3 x H x W) and ``Y`` is the
            ``'Y'`` array with a leading singleton axis (H x W -> 1 x H x W).
        """
        input_path = os.path.join(self.input_folder, self.input_files[idx])
        field = loadmat(input_path)['X']        # H x W x 3
        target_path = os.path.join(self.output_folder, self.output_files[idx])
        target = loadmat(target_path)['Y']      # H x W

        # Channels-first layout for the input, single-channel layout for the target.
        X = torch.tensor(np.transpose(field, (2, 0, 1)), dtype=torch.float32)
        Y = torch.tensor(target[np.newaxis, :, :], dtype=torch.float32)
        return X, Y


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    """U-Net style encoder/decoder built from double 3x3 convolution blocks.

    The encoder applies one block per entry of ``features`` and halves the
    spatial size with 2x2 max pooling after each block. A bottleneck block
    doubles the last width. The decoder mirrors the encoder: a stride-2
    transposed convolution upsamples, the matching encoder output is
    concatenated along the channel axis (skip connection), and a block fuses
    the result. A final 1x1 convolution maps to ``out_channels``; no output
    activation is applied.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the returned tensor.
        features: Channel widths of the encoder stages, shallowest first.
            Must be a non-empty sequence of ints; it is only read.
    """

    # The default is a tuple rather than a list so that no mutable object is
    # shared between calls through the signature.
    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
        super().__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(2, 2)

        # Encoder: each stage consumes the previous stage's width.
        for feature in features:
            self.downs.append(self._block(in_channels, feature))
            in_channels = feature

        # Decoder: `ups` alternates [upsample, block, upsample, block, ...],
        # deepest stage first. Each block takes 2 * feature channels because
        # the skip connection is concatenated onto the upsampled tensor.
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            self.ups.append(self._block(feature * 2, feature))

        self.bottleneck = self._block(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def _block(self, in_channels, out_channels):
        """Return two (3x3 conv -> batch norm -> ReLU) layers; spatial size is kept."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run the network on a batch.

        Args:
            x: Tensor of shape (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, out_channels, H, W).
        """
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Walk the decoder as (upsample, block) pairs, matching each with the
        # encoder output of the same depth (deepest first).
        stages = zip(self.ups[0::2], self.ups[1::2], reversed(skip_connections))
        for upsample, fuse, skip_connection in stages:
            x = upsample(x)
            # Max pooling floors odd sizes, so the upsampled map can be smaller
            # than the skip tensor; resize it to match before concatenating.
            # The function is reached through `nn.functional` instead of a
            # short module-level alias so a later rebinding of that alias in
            # the same namespace cannot break the forward pass.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = nn.functional.interpolate(x, size=skip_connection.shape[2:])
            x = torch.cat((skip_connection, x), dim=1)
            x = fuse(x)

        return self.final_conv(x)


In [4]:
# Folders holding the paired training .mat files (inputs and targets).
train_input = 'data/train/input/'
train_output = 'data/train/output/'

# Pair the files up and serve them in shuffled mini-batches of two samples.
dataset = TopologyDataset(input_folder=train_input, output_folder=train_output)
loader = DataLoader(dataset=dataset, batch_size=2, shuffle=True)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model = UNet(in_channels=3, out_channels=1)
model = model.to(device)

# Adam over every model parameter; mean-squared error as the objective.
optimizer = Adam(params=model.parameters(), lr=1e-4)
criterion = nn.MSELoss()  # alternatives: nn.L1Loss(), nn.BCEWithLogitsLoss()


In [None]:
num_epochs = 50
model.train()  # put the model in training mode once, before the first epoch

for epoch in range(num_epochs):
    # Sum of the per-batch losses seen during this epoch.
    running_loss = 0.0
    progress = tqdm(loader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for X, Y in progress:
        X = X.to(device)
        Y = Y.to(device)

        # Standard step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        preds = model(X)
        loss = criterion(preds, Y)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the mean of the per-batch losses for the epoch.
    print(f"Epoch [{epoch+1}/{num_epochs}] - Loss: {running_loss/len(loader):.6f}")


Epoch 1/50: 100%|███████████████████████████████████████████████████████████████████████| 50/50 [00:21<00:00,  2.31it/s]


Epoch [1/50] - Loss: 0.028349


Epoch 2/50: 100%|███████████████████████████████████████████████████████████████████████| 50/50 [00:15<00:00,  3.20it/s]


Epoch [2/50] - Loss: 0.006374


Epoch 3/50: 100%|███████████████████████████████████████████████████████████████████████| 50/50 [00:16<00:00,  3.10it/s]


Epoch [3/50] - Loss: 0.005331


Epoch 4/50:  44%|███████████████████████████████▏                                       | 22/50 [00:06<00:09,  3.11it/s]

In [None]:
# Persist only the model's state dict (parameters and buffers), not the module object.
torch.save(obj=model.state_dict(), f='unet_topopt.pth')

In [None]:
# Evaluation mode: the BatchNorm layers use their running statistics.
model.eval()

# Folders holding the paired validation .mat files.
val_input = 'data/val/input/'
val_output = 'data/val/output/'

dataset_val = TopologyDataset(input_folder=val_input, output_folder=val_output)

# Predict on the last validation sample without tracking gradients.
with torch.no_grad():
    X, Y_true = dataset_val[-1]
    batch = X.unsqueeze(0).to(device)   # add the batch axis the model expects
    pred = model(batch)
    # Drop singleton axes and move the result to a NumPy array on the host.
    pred = pred.squeeze().cpu().numpy()

In [None]:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import PolyCollection
import scipy.io

# ======================
# Load mesh
# ======================
data = scipy.io.loadmat('Lbracket2d.mat')
V = data['V']                 # vertex coordinates; expected nverts x 2
# Element connectivity, shifted from 1-based to 0-based node indices.
# Deliberately not stored under the name `F`: this notebook imports
# torch.nn.functional as `F`, and rebinding that name here would clobber the
# alias for every cell executed afterwards.
elem_nodes = data['F'] - 1    # expected nelems x nodes-per-element

# ======================
# Prepare predicted field
# ======================
# Flatten the prediction so it can be indexed as one value per entry.
# NOTE(review): assumes the flattened prediction has one value per mesh
# element, in element order — confirm against how the training targets
# were generated.
pred_flat = pred.flatten()

# Min-max normalise to [0, 1] for the colormap; the small epsilon keeps the
# division finite when the prediction is constant.
pred_norm = (pred_flat - np.min(pred_flat)) / (np.max(pred_flat) - np.min(pred_flat) + 1e-12)

# Colormap
cmap = plt.cm.jet
colors = cmap(pred_norm)

# ======================
# Create faces list
# ======================
# One (nodes-per-element x 2) coordinate array per element.
faces = [V[nodes, :] for nodes in elem_nodes]

# ======================
# Plot
# ======================
fig, ax = plt.subplots(figsize=(6,6))
mesh = PolyCollection(faces, facecolors=colors, edgecolors='k')  # show mesh edges
ax.add_collection(mesh)

# Formatting
ax.set_aspect('equal')
ax.axis('off')
ax.autoscale_view()

plt.show()


In [None]:
# Re-read the vertex array from the loaded mesh data and show it as the cell output.
V = data['V']       # nverts x 2 (presumably one coordinate pair per row — confirm)
V

In [None]:
# Re-read the element connectivity from the loaded mesh data and show it as the cell output.
# NOTE(review): this rebinds `F`, which the import cells bind to
# torch.nn.functional. After this cell runs, any code that calls
# F.<function> sees this array instead. Consider a different name, after
# confirming nothing downstream relies on `F` being the mesh array.
F = data['F'] - 1   # nelems x nodes, convert 1-based -> 0-based
F