In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
import torch.optim as optim
import torchvision.datasets as datasets
from torchvision import models, transforms
from torchvision.utils import save_image, make_grid
from torch.optim.lr_scheduler import StepLR
from torch import autograd
from torch.autograd import Variable
from tensorboardX import SummaryWriter

from typing import Dict, Tuple
from tqdm import tqdm
import numpy as np
import time
import os
import random
from tabulate import tabulate

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter

%matplotlib inline

if __name__ == "__main__":
    # Environment report: PyTorch build, CUDA status, and visible GPUs.
    print("Torch version:", torch.__version__)
    print("CUDA available:", torch.cuda.is_available())
    print("CUDA version:", torch.version.cuda)
    print("Number of GPUs:", torch.cuda.device_count())
    # Only ask for a device name when at least one GPU is present.
    if torch.cuda.device_count() > 0:
        print("GPU name:", torch.cuda.get_device_name(0))
    else:
        print("GPU name:", "No GPU detected")



Torch version: 2.7.0+cu126
CUDA available: True
CUDA version: 12.6
Number of GPUs: 1
GPU name: NVIDIA GeForce RTX 4090


### Loading the Dataset

In [None]:
# Project-local dataset class; its implementation lives outside this notebook.
from waveguide_dataset import WaveguideDataset
# Build the dataset from 'train_test_split.h5' (presumably an HDF5 file, going by
# the extension). The resulting `dataset` global is what main(dataset) consumes.
dataset = WaveguideDataset('train_test_split.h5')

# # Define split sizes (e.g., 80% train, 20% test)
# train_size = int(0.8 * len(dataset))
# test_size = len(dataset) - train_size

# # Split the dataset
# train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# # Create DataLoaders
# train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=4)
# test_loader = DataLoader(dataset, batch_size=256, shuffle=False, num_workers=4)
# pbar = tqdm(test_loader)
# i=0
# for c, p, x in pbar:
#     print(c)
#     break


### CNN Modules

In [21]:
class Flatten(nn.Module):
    """
    Module wrapper around flattening, so it can be used as a layer inside
    nn.Sequential(...)
    """
    def forward(self, x):
        # Keep dim 0 (the batch) and merge every remaining dimension into one.
        return x.flatten(start_dim=1)

class Net(nn.Module):
    """
    CNN regressor for a single-channel image plus a conditioning vector.

    The image passes through two unpadded 3x3 convolutions and a 2x2 max-pool.
    The flattened features are concatenated with the conditioning vector and
    fed to a small fully connected head that outputs continuous values.

    Args:
        img_size: Height and width of the (square) input image. Default 32.
        cond_dim: Length of the conditioning vector. Default 4.
        out_dim: Number of output values. Default 8.

    Raises:
        ValueError: If ``img_size`` is smaller than 6, i.e. nothing would be
            left after the two convolutions and the pooling step.

    With the default arguments the layers (and therefore the ``state_dict``
    keys and shapes) are identical to the original fixed-size network.
    """
    def __init__(self, img_size: int = 32, cond_dim: int = 4, out_dim: int = 8):
        super().__init__()

        if img_size < 6:
            raise ValueError(
                f"img_size must be at least 6 to survive two 3x3 convs and a 2x2 pool, got {img_size}"
            )

        # Each unpadded 3x3 conv trims 2 pixels per side length; the pool halves it (floor).
        pooled = (img_size - 4) // 2
        n_features = 64 * pooled * pooled  # 64 * 14 * 14 = 12544 for 32x32 input

        # CNN feature extractor for the image
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1),   # [B, 1, S, S] -> [B, 32, S-2, S-2]
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1),  # -> [B, 64, S-4, S-4]
            nn.ReLU(),
            nn.MaxPool2d(2),          # -> [B, 64, (S-4)//2, (S-4)//2]
            nn.Dropout(0.25),
            nn.Flatten()              # built-in flatten from dim 1 -> [B, n_features]
        )

        # Fully connected layers after combining with the conditioning vector
        self.fc = nn.Sequential(
            nn.Linear(n_features + cond_dim, 128),  # Concatenate condition
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, out_dim)                 # Continuous outputs
        )

    def forward(self, x_img, x_cond):
        """
        Args:
            x_img: Image batch of shape [B, 1, img_size, img_size].
            x_cond: Conditioning batch of shape [B, cond_dim].

        Returns:
            Tensor of shape [B, out_dim].
        """
        x = self.cnn(x_img)                 # [B, n_features]
        x = torch.cat((x, x_cond), dim=1)   # [B, n_features + cond_dim]
        output = self.fc(x)                 # [B, out_dim]
        return output

In [22]:
def train(model, device, train_loader_in, optimizer, loss_fn):
    """
    Run one training epoch over ``train_loader_in``.

    Each batch is unpacked as ``(target, p, x)``, moved to ``device``, and the
    model is called as ``model(x, p)``. The per-batch loss is shown on the
    tqdm progress bar.

    Args:
        model: Module called as ``model(x, p)``.
        device: Device (or device string) the batch tensors are moved to.
        train_loader_in: Iterable yielding ``(target, p, x)`` tensor batches.
        optimizer: Optimizer stepping ``model``'s parameters.
        loss_fn: Callable ``loss_fn(output, target)`` returning a scalar tensor.

    Returns:
        float: Mean loss over the epoch, weighted by batch size (this assumes
        ``loss_fn`` returns a per-batch mean, as ``nn.MSELoss()`` does by
        default). ``nan`` if the loader yielded no samples.
    """
    model.train()
    loss_sum = 0.0
    n_seen = 0
    progress = tqdm(train_loader_in)
    for target, p, x in progress:
        target, p, x = target.to(device), p.to(device), x.to(device)
        optimizer.zero_grad()
        output = model(x, p)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for both the bar and the running mean.
        batch_loss = loss.item()
        batch_n = x.size(0)
        loss_sum += batch_loss * batch_n
        n_seen += batch_n
        progress.set_description(f"loss: {batch_loss:.4f}")

    return loss_sum / n_seen if n_seen else float("nan")

# Use the GPU when CUDA is usable, otherwise stay on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Notebook-level model and Adadelta optimizer. main() builds its own model and
# optimizer, so these objects are separate from the ones it trains.
model = Net()
model = model.to(device)
optimizer = optim.Adadelta(model.parameters(), lr=1)
# pbar = tqdm(dataloader)
# train(model, device, pbar, optimizer)


In [None]:
def test(model, device, test_loader_in, loss_fn, dataset):
    model.eval()
    test_loss = 0
    samples = []
    test_loader = tqdm(test_loader_in)
    with torch.no_grad():
        for target, p, x in test_loader:
            target, p, x = target.to(device), p.to(device), x.to(device)
            output = model(x,p)
            test_loss += loss_fn(output, target).item() * x.size(0)
            for t, o in zip(target.cpu(), output.cpu()):
                t_unnorm = dataset.denormalize_cond(t)
                o_unnorm = dataset.denormalize_cond(o)
                if len(samples) < 50:
                    samples.append((t_unnorm.numpy(), o_unnorm.numpy()))

    test_loss /= len(test_loader_in.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}\n')

    chosen = random.sample(samples, 8)
    index = []
    targets = []
    outputs = []
    for i, (target, output) in enumerate(chosen):
        index.append(i)
        targets.append([f"{v:.3f}" for v in target.tolist()])
        outputs.append([f"{v:.3f}" for v in output.tolist()])
    
    headers = [' '] + index
    row_1 = ['Target'] + targets
    row_2 = ['Output'] + outputs
    print(tabulate([row_1, row_2], headers=headers, tablefmt='orgtbl'))
    

In [None]:
def main(dataset, epochs=15, batch_size=64, test_batch_size=1000, lr=1e-3,
         gamma=0.7, num_workers=4, seed=None,
         checkpoint_path='models/mode_cnn.pth'):
    """
    Train ``Net`` on an 80/20 random split of ``dataset`` and checkpoint it.

    Every epoch runs ``train`` on the training split and ``test`` on the
    held-out split, then decays the learning rate by ``gamma`` and overwrites
    the checkpoint with the current ``state_dict``.

    Args:
        dataset: Sized dataset whose items the train/test loops unpack as
            ``(target, p, x)``; also passed to ``test`` for denormalization.
        epochs: Number of passes over the training split. Default 15.
        batch_size: Training batch size. Default 64.
        test_batch_size: Evaluation batch size. Default 1000.
        lr: Initial Adam learning rate. Default 1e-3.
        gamma: Per-epoch multiplicative learning-rate decay. Default 0.7.
        num_workers: DataLoader worker processes. Default 4.
        seed: If given, seeds the train/test split so it is reproducible
            across runs. ``None`` (default) uses the global RNG.
        checkpoint_path: Where the ``state_dict`` is saved after each epoch.

    Returns:
        The trained model.
    """
    # Create the checkpoint's parent directory (skipped for a bare filename).
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = Net().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # step_size=1: the learning rate is multiplied by gamma after every epoch.
    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
    loss_fn = nn.MSELoss()

    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    # Only pass a generator when a seed was requested, so the default call is
    # exactly the unseeded random_split.
    split_kwargs = {}
    if seed is not None:
        split_kwargs['generator'] = torch.Generator().manual_seed(seed)
    train_dataset, test_dataset = random_split(
        dataset, [train_size, test_size], **split_kwargs
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, num_workers=num_workers)

    for epoch in range(epochs):
        print(f'Epoch #{epoch}:')
        train(model, device, train_loader, optimizer, loss_fn)
        # The full dataset object is passed only for its denormalize helper.
        test(model, device, test_loader, loss_fn, dataset)
        scheduler.step()
        # Overwrite the checkpoint each epoch so an interrupted run keeps the
        # most recently completed epoch.
        torch.save(model.state_dict(), checkpoint_path)

    return model

if __name__ == '__main__':
    main(dataset)


Epoch #0:


loss: 119.5393:   7%|▋         | 1026/13737 [00:14<02:56, 72.06it/s]


KeyboardInterrupt: 

In [12]:
# Scratch check of how tabulate renders rows whose cells are Python lists
# (the same row layout test() builds for its Target/Output table).
print(tabulate([['Alice', [1,2,3], 'bruh'], ['Bob', [1,3,4], 'sam']], headers=['Name', 'sam', 'bruh'], tablefmt='orgtbl'))

| Name   | sam       | bruh   |
|--------+-----------+--------|
| Alice  | [1, 2, 3] | bruh   |
| Bob    | [1, 3, 4] | sam    |
