In [1]:
import torch
from torch.utils.data import DataLoader
from datasets.wss import read_ansys_csv
import os
from models.field_vae.base import SurfaceFieldAutoEncoder

In [None]:
from datasets.wss import WSSPeakDataset
from losses.base import KLSurfaceField

batch_size = 32
train_split = 0.8
# Seed for the train/test partition. Without a seeded generator, random_split
# draws a different split on every kernel restart, so test losses from
# different runs are not comparable.
split_seed = 0

# Dataset location. Set the ANEUG_WSS_ROOT environment variable to override
# the machine-specific default path.
root_dir = os.environ.get('ANEUG_WSS_ROOT', '/media/yaplab/HDD_Storage/wenhao/datasets/AneuG_CFD/peak_wss')
dataset = WSSPeakDataset(root_dir, encode_size=16800, decode_size=3600)

# Split sizes: int() floors the train share; the test split takes the remainder
# so the two lengths always sum to len(dataset), as random_split requires.
train_size = int(train_split * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Loss used by the training loop; the KL term is down-weighted by kl_weight.
loss_module = KLSurfaceField(kl_weight=0.00001)

In [3]:
# All tensors and the model live on the first CUDA device.
device = torch.device("cuda:0")

# --- SurfaceFieldAutoEncoder hyperparameters ---
num_latents, feature_dim = 128, 1
embed_dim, num_freqs = 16, 8
# width = 384; presumably must be divisible by heads (384 / 6 = 64) — confirm in the model code
width, heads = 768 // 2, 6
num_encoder_layers, num_decoder_layers = 6, 12

vae_kwargs = dict(
    num_latents=num_latents,
    feature_dim=feature_dim,
    embed_dim=embed_dim,
    num_freqs=num_freqs,
    width=width,
    heads=heads,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
)
SurfaceFieldVAE = SurfaceFieldAutoEncoder(device=device, **vae_kwargs)

# Move parameters to the device; as the last expression, this also displays the module repr.
SurfaceFieldVAE.to(device)

SurfaceFieldAutoEncoder(
  (fourier_embedder): FourierEmbedder()
  (encoder): CrossAttentionEncoder(
    (fourier_embedder): FourierEmbedder()
    (input_proj): Linear(in_features=52, out_features=384, bias=True)
    (cross_attn): ResidualCrossAttentionBlock(
      (attn): MultiheadCrossAttention(
        (c_q): Linear(in_features=384, out_features=384, bias=True)
        (c_kv): Linear(in_features=384, out_features=768, bias=True)
        (c_proj): Linear(in_features=384, out_features=384, bias=True)
        (attention): QKVMultiheadCrossAttention()
      )
      (ln_1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (ln_2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (mlp): MLP(
        (c_fc): Linear(in_features=384, out_features=1536, bias=True)
        (c_proj): Linear(in_features=1536, out_features=384, bias=True)
        (gelu): GELU(approximate='none')
      )
      (ln_3): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
    )
    (self_attn):

In [4]:
import wandb

# Experiment tracking: flip log_wandb to False to train without a W&B run.
log_wandb = True
meta = "debug"  # run name shown in the W&B project

if log_wandb:
    wandb.login()
    run = wandb.init(name=meta, project="geodiffusion")

# AdamW at lr=3e-4. StepLR multiplies the LR by 0.5 every 2500 calls to
# scheduler.step(); the training loop calls it once per epoch.
optimizer = torch.optim.AdamW(params=SurfaceFieldVAE.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.5, step_size=2500)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mwhding[0m ([33mwhding-imperial-college-london[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
# Train the VAE and evaluate on the held-out split every `eval_every` epochs.
#
# Evaluation runs in eval mode under torch.no_grad(). Without the guard, every
# test batch builds an autograd graph on top of the live training tensors,
# which wastes GPU memory and is a likely contributor to the CUDA OOM below.
num_epochs = 100000 + 1
eval_every = 100

for epoch in range(num_epochs):
    SurfaceFieldVAE.train()  # re-enable training mode after any evaluation pass
    for data in train_loader:
        optimizer.zero_grad()
        data = {key: value.to(device) for key, value in data.items()}
        coords, feats, recon_coords, recon_feats_true = data['coords'], data['feats'], data['recon_coords'], data['recon_feats']
        recon_feats_true = recon_feats_true.squeeze(-1)
        recon_feats_pred, center_pos, posterior = SurfaceFieldVAE(coords, feats, recon_coords, sample_posterior=True)
        loss, loss_log = loss_module(posterior, recon_feats_pred, recon_feats_true)
        loss.backward()
        optimizer.step()

    # Train-loss terms come from the last batch of the epoch (loss_log is overwritten each batch).
    log_dict = {'recon_loss': loss_log['recon_loss'], 'kl_loss': loss_log['kl_loss']}

    if epoch % eval_every == 0:
        SurfaceFieldVAE.eval()
        # Mean over test batches of the full value returned by loss_module
        # (not only its recon term). Batches are weighted equally, so a smaller
        # final batch counts as much as a full one.
        recon_loss_test = 0.0
        with torch.no_grad():
            for data in test_loader:
                data = {key: value.to(device) for key, value in data.items()}
                coords, feats, recon_coords, recon_feats_true = data['coords'], data['feats'], data['recon_coords'], data['recon_feats']
                recon_feats_true = recon_feats_true.squeeze(-1)
                # NOTE(review): sample_posterior=True keeps test loss stochastic; kept to match training.
                recon_feats_pred, center_pos, posterior = SurfaceFieldVAE(coords, feats, recon_coords, sample_posterior=True)
                loss_test, loss_log_test = loss_module(posterior, recon_feats_pred, recon_feats_true)
                recon_loss_test += loss_test.item() / len(test_loader)
        print(f'Epoch: {epoch}, Test Loss: {recon_loss_test}')
        # Log test loss only when it was actually recomputed, never a stale value.
        log_dict['test_loss'] = recon_loss_test
        print(log_dict)

    if log_wandb:
        wandb.log(log_dict, step=epoch)

    scheduler.step()

if log_wandb:
    wandb.finish()  # must be called; a bare `wandb.finish` reference is a no-op

OutOfMemoryError: CUDA out of memory. Tried to allocate 1.78 GiB. GPU 0 has a total capacty of 23.69 GiB of which 1.68 GiB is free. Process 3079653 has 330.00 MiB memory in use. Including non-PyTorch memory, this process has 21.67 GiB memory in use. Of the allocated memory 15.33 GiB is allocated by PyTorch, and 6.03 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF