In [1]:
!pip install  dgl -f https://data.dgl.ai/wheels/torch-2.3/cu121/repo.html

Looking in links: https://data.dgl.ai/wheels/torch-2.3/cu121/repo.html
Collecting dgl
  Downloading https://data.dgl.ai/wheels/torch-2.3/cu121/dgl-2.2.1%2Bcu121-cp310-cp310-manylinux1_x86_64.whl (199.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.8/199.8 MB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
Collecting torchdata>=0.5.0 (from dgl)
  Downloading torchdata-0.7.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.7/4.7 MB[0m [31m52.4 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=2->torchdata>=0.5.0->dgl)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=2->torchdata>=0.5.0->dgl)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=2->torchdata>=

In [11]:
from torch.utils.data import Dataset, DataLoader
from torch import FloatTensor
import dgl
from dgl.data.utils import load_graphs
from tqdm import tqdm
from abc import abstractmethod
from torch.nn.utils.rnn import pad_sequence
import torch



from torch.utils.data import Dataset, DataLoader
import dgl
from torch import FloatTensor, stack

class BaseDataset(Dataset):
    @staticmethod
    def num_classes():
        pass

    def __init__(self, X_train, Y_train):
        """
        self.data is a list of dictionaries with keys graph and label
        """
        assert len(X_train) == len(Y_train), "The number of graphs must match the number of labels"
        self.data = [{"graph": graph, "label": label} for graph, label in zip(X_train, Y_train)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        return sample["graph"], sample["label"]

    def _collate(self, batch):
        graphs, labels = zip(*batch)
        batched_graph = dgl.batch(graphs)
        num_nodes_per_graph = [graph.number_of_nodes() for graph in graphs]


        pad_vector = torch.tensor([6] + [-1]*16, dtype=torch.float32)


        max_length = 20
        padded_labels = []
        for label in labels:
            label_tensor = torch.tensor(label, dtype=torch.float32)
            label_length = label_tensor.shape[0]

            if label_length < max_length:

                padding_count = max_length - label_length

                padding = pad_vector.repeat(padding_count, 1)

                padded_label = torch.cat([label_tensor, padding], dim=0)
            elif label_length > max_length:

                padded_label = label_tensor[:max_length]
            else:
                padded_label = label_tensor

            padded_labels.append(padded_label)


        padded_labels = torch.stack(padded_labels)
        return {"graph": batched_graph, "labels": padded_labels, "num_nodes": num_nodes_per_graph}


    def get_dataloader(self, batch_size, shuffle=True, num_workers=0):
        return DataLoader(
            self,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=self._collate,
            num_workers=num_workers,
            drop_last=True
        )

In [5]:
import torch
from torch import nn
import torch.nn.functional as F
from dgl.nn.pytorch.conv import NNConv
from dgl.nn.pytorch.glob import MaxPooling
from torch.nn import TransformerDecoder, TransformerDecoderLayer
import math


# Convolutional Layers
def _conv1d(in_channels, out_channels, kernel_size=3, padding=0, bias=False):
    """
    Helper function to create a 1D convolutional layer with batchnorm and LeakyReLU activation

    Args:
        in_channels (int): Input channels
        out_channels (int): Output channels
        kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
        padding (int, optional): Padding size on each side. Defaults to 0.
        bias (bool, optional): Whether bias is used. Defaults to False.

    Returns:
        nn.Sequential: Sequential contained the Conv1d, BatchNorm1d and LeakyReLU layers
    """
    return nn.Sequential(
        nn.Conv1d(
            in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=bias
        ),
        nn.BatchNorm1d(out_channels),
        nn.LeakyReLU(),
    )


def _conv2d(in_channels, out_channels, kernel_size, padding=0, bias=False):
    """
    Helper function to create a 2D convolutional layer with batchnorm and LeakyReLU activation

    Args:
        in_channels (int): Input channels
        out_channels (int): Output channels
        kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
        padding (int, optional): Padding size on each side. Defaults to 0.
        bias (bool, optional): Whether bias is used. Defaults to False.

    Returns:
        nn.Sequential: Sequential contained the Conv2d, BatchNorm2d and LeakyReLU layers
    """
    return nn.Sequential(
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding,
            bias=bias,
        ),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(),
    )


def _fc(in_features, out_features, bias=False):
    return nn.Sequential(
        nn.Linear(in_features, out_features, bias=bias),
        nn.BatchNorm1d(out_features),
        nn.LeakyReLU(),
    )

class _MLP(nn.Module):
    """"""

    def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
        """
        MLP with linear output
        Args:
            num_layers (int): The number of linear layers in the MLP
            input_dim (int): Input feature dimension
            hidden_dim (int): Hidden feature dimensions for all hidden layers
            output_dim (int): Output feature dimension

        Raises:
            ValueError: If the given number of layers is <1
        """
        super(_MLP, self).__init__()
        self.linear_or_not = True  # default is linear model
        self.num_layers = num_layers
        self.output_dim = output_dim

        if num_layers < 1:
            raise ValueError("Number of layers should be positive!")
        elif num_layers == 1:
            # Linear model
            self.linear = nn.Linear(input_dim, output_dim)
        else:
            # Multi-layer model
            self.linear_or_not = False
            self.linears = torch.nn.ModuleList()
            self.batch_norms = torch.nn.ModuleList()

            self.linears.append(nn.Linear(input_dim, hidden_dim))
            for layer in range(num_layers - 2):
                self.linears.append(nn.Linear(hidden_dim, hidden_dim))
            self.linears.append(nn.Linear(hidden_dim, output_dim))

            # TODO: this could move inside the above loop
            for layer in range(num_layers - 1):
                self.batch_norms.append(nn.BatchNorm1d((hidden_dim)))

    def forward(self, x):
        if self.linear_or_not:
            # If linear model
            return self.linear(x)
        else:
            # If MLP
            h = x
            for i in range(self.num_layers - 1):
                h = F.relu(self.batch_norms[i](self.linears[i](h)))
            return self.linears[-1](h)

class UVNetCurveEncoder(nn.Module):
    def __init__(self, in_channels=6, output_dims=64):
        """
        This is the 1D convolutional network that extracts features from the B-rep edge
        geometry described as 1D UV-grids (see Section 3.2, Curve & surface convolution
        in paper)

        Args:
            in_channels (int, optional): Number of channels in the edge UV-grids. By default
                                         we expect 3 channels for point coordinates and 3 for
                                         curve tangents. Defaults to 6.
            output_dims (int, optional): Output curve embedding dimension. Defaults to 64.
        """
        super(UVNetCurveEncoder, self).__init__()
        self.in_channels = in_channels
        self.conv1 = _conv1d(in_channels, 64, kernel_size=3, padding=1, bias=False)
        self.conv2 = _conv1d(64, 128, kernel_size=3, padding=1, bias=False)
        self.conv3 = _conv1d(128, 256, kernel_size=3, padding=1, bias=False)
        self.final_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = _fc(256, output_dims, bias=False)

        for m in self.modules():
            self.weights_init(m)

    def weights_init(self, m):
        if isinstance(m, (nn.Linear, nn.Conv1d)):
            torch.nn.init.kaiming_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)

    def forward(self, x):
        assert x.size(1) == self.in_channels
        batch_size = x.size(0)
        x = x.float()
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.final_pool(x)
        x = x.view(batch_size, -1)
        x = self.fc(x)
        return x


class UVNetSurfaceEncoder(nn.Module):
    def __init__(
        self,
        in_channels=7,
        output_dims=64,
    ):
        """
        This is the 2D convolutional network that extracts features from the B-rep face
        geometry described as 2D UV-grids (see Section 3.2, Curve & surface convolution
        in paper)

        Args:
            in_channels (int, optional): Number of channels in the edge UV-grids. By default
                                         we expect 3 channels for point coordinates and 3 for
                                         surface normals and 1 for the trimming mask. Defaults
                                         to 7.
            output_dims (int, optional): Output surface embedding dimension. Defaults to 64.
        """
        super(UVNetSurfaceEncoder, self).__init__()
        self.in_channels = in_channels
        self.conv1 = _conv2d(in_channels, 64, 3, padding=1, bias=False)
        self.conv2 = _conv2d(64, 128, 3, padding=1, bias=False)
        self.conv3 = _conv2d(128, 256, 3, padding=1, bias=False)
        self.final_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = _fc(256, output_dims, bias=False)
        for m in self.modules():
            self.weights_init(m)

    def weights_init(self, m):
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            torch.nn.init.kaiming_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)

    def forward(self, x):
        assert x.size(1) == self.in_channels
        batch_size = x.size(0)
        x = x.float()
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.final_pool(x)
        x = x.view(batch_size, -1)
        x = self.fc(x)
        return x

class _EdgeConv(nn.Module):
    def __init__(
        self,
        edge_feats,
        out_feats,
        node_feats,
        num_mlp_layers=2,
        hidden_mlp_dim=64,
    ):
        """
        This module implements Eq. 2 from the paper where the edge features are
        updated using the node features at the endpoints.

        Args:
            edge_feats (int): Input edge feature dimension
            out_feats (int): Output feature deimension
            node_feats (int): Input node feature dimension
            num_mlp_layers (int, optional): Number of layers used in the MLP. Defaults to 2.
            hidden_mlp_dim (int, optional): Hidden feature dimension in the MLP. Defaults to 64.
        """
        super(_EdgeConv, self).__init__()
        self.proj = _MLP(1, node_feats, hidden_mlp_dim, edge_feats)
        self.mlp = _MLP(num_mlp_layers, edge_feats, hidden_mlp_dim, out_feats)
        self.batchnorm = nn.BatchNorm1d(out_feats)
        self.eps = torch.nn.Parameter(torch.FloatTensor([0.0]))

    def forward(self, graph, nfeat, efeat):
        src, dst = graph.edges()
        proj1, proj2 = self.proj(nfeat[src]), self.proj(nfeat[dst])
        agg = proj1 + proj2
        h = self.mlp((1 + self.eps) * efeat + agg)
        h = F.leaky_relu(self.batchnorm(h))
        return h


class _NodeConv(nn.Module):
    def __init__(
        self,
        node_feats,
        out_feats,
        edge_feats,
        num_mlp_layers=2,
        hidden_mlp_dim=64,
    ):
        """
        This module implements Eq. 1 from the paper where the node features are
        updated using the neighboring node and edge features.

        Args:
            node_feats (int): Input edge feature dimension
            out_feats (int): Output feature deimension
            node_feats (int): Input node feature dimension
            num_mlp_layers (int, optional): Number of layers used in the MLP. Defaults to 2.
            hidden_mlp_dim (int, optional): Hidden feature dimension in the MLP. Defaults to 64.
        """
        super(_NodeConv, self).__init__()
        self.gconv = NNConv(
            in_feats=node_feats,
            out_feats=out_feats,
            edge_func=nn.Linear(edge_feats, node_feats * out_feats),
            aggregator_type="sum",
            bias=False,
        )
        self.batchnorm = nn.BatchNorm1d(out_feats)
        self.mlp = _MLP(num_mlp_layers, node_feats, hidden_mlp_dim, out_feats)
        self.eps = torch.nn.Parameter(torch.FloatTensor([0.0]))

    def forward(self, graph, nfeat, efeat):
        h = (1 + self.eps) * nfeat
        h = self.gconv(graph, h, efeat)
        h = self.mlp(h)
        h = F.leaky_relu(self.batchnorm(h))
        return h


class UVNetGraphEncoder(nn.Module):
    def __init__(
        self,
        input_dim,
        input_edge_dim,
        output_dim,
        hidden_dim=64,
        learn_eps=True,
        num_layers=3,
        num_mlp_layers=2,
    ):
        """
        This is the graph neural network used for message-passing features in the
        face-adjacency graph.

        Args:
            input_dim ([type]): [description]
            input_edge_dim ([type]): [description]
            output_dim ([type]): [description]
            hidden_dim (int, optional): [description]. Defaults to 64.
            learn_eps (bool, optional): [description]. Defaults to True.
            num_layers (int, optional): [description]. Defaults to 3.
            num_mlp_layers (int, optional): [description]. Defaults to 2.
        """
        super(UVNetGraphEncoder, self).__init__()
        self.num_layers = num_layers
        self.learn_eps = learn_eps

        # List of layers for node and edge feature message passing
        self.node_conv_layers = torch.nn.ModuleList()
        self.edge_conv_layers = torch.nn.ModuleList()

        for layer in range(self.num_layers - 1):
            node_feats = input_dim if layer == 0 else hidden_dim
            edge_feats = input_edge_dim if layer == 0 else hidden_dim
            self.node_conv_layers.append(
                _NodeConv(
                    node_feats=node_feats,
                    out_feats=hidden_dim,
                    edge_feats=edge_feats,
                    num_mlp_layers=num_mlp_layers,
                    hidden_mlp_dim=hidden_dim,
                ),
            )
            self.edge_conv_layers.append(
                _EdgeConv(
                    edge_feats=edge_feats,
                    out_feats=hidden_dim,
                    node_feats=node_feats,
                    num_mlp_layers=num_mlp_layers,
                    hidden_mlp_dim=hidden_dim,
                )
            )

        # Linear function for graph poolings of output of each layer
        # which maps the output of different layers into a prediction score
        self.linears_prediction = torch.nn.ModuleList()

        for layer in range(num_layers):
            if layer == 0:
                self.linears_prediction.append(nn.Linear(input_dim, output_dim))
            else:
                self.linears_prediction.append(nn.Linear(hidden_dim, output_dim))

        self.drop1 = nn.Dropout(0.3)
        self.drop = nn.Dropout(0.5)
        self.pool = MaxPooling()

    def forward(self, g, h, efeat):
        hidden_rep = [h]
        he = efeat

        for i in range(self.num_layers - 1):

            h = self.node_conv_layers[i](g, h, he)
            he = self.edge_conv_layers[i](g, h, he)
            hidden_rep.append(h)
        node_embeddings = hidden_rep[-1]
        node_embeddings = self.drop1(node_embeddings)

        graph_representation = 0
        for i, h in enumerate(hidden_rep):
            pooled_h = self.pool(g, h)
            graph_representation += self.drop(self.linears_prediction[i](pooled_h))


        return node_embeddings, graph_representation

class CADEncoder(nn.Module):
  def __init__(self, crv_emb_dim=64, srf_emb_dim=64, graph_emb_dim=128, dropout=0.3):
    super(CADEncoder, self).__init__()
    self.curv_encoder = UVNetCurveEncoder(in_channels=10, output_dims=crv_emb_dim) # in_channels originally 6
    self.surf_encoder = UVNetSurfaceEncoder(in_channels=10, output_dims=srf_emb_dim)
    self.graph_encoder = UVNetGraphEncoder(srf_emb_dim, crv_emb_dim, graph_emb_dim)

  def forward(self, batched_graph):
    input_crv_feat = batched_graph.edata["x"]
    input_srf_feat = batched_graph.ndata["x"]
    hidden_crv_feat = self.curv_encoder(input_crv_feat)
    hidden_srf_feat = self.surf_encoder(input_srf_feat)
    node_emb, graph_emb = self.graph_encoder(batched_graph, hidden_srf_feat, hidden_crv_feat)
    return node_emb, graph_emb


class PositionalEncodingLUT(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=250):
        super(PositionalEncodingLUT, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.long).unsqueeze(1)
        self.register_buffer('position', position)

        self.pos_embed = nn.Embedding(max_len, d_model)

        self._init_embeddings()

    def _init_embeddings(self):
        nn.init.kaiming_normal_(self.pos_embed.weight, mode="fan_in")

    def forward(self, x):
        batch_size, seq_length, _ = x.shape

        pos = self.position[:seq_length]  # This will have shape [seq_length, 1]

        pos = pos.expand(-1, batch_size).contiguous()  # Reshape to [seq_length, batch_size]
        pos = pos.transpose(0, 1)  # Transpose to [batch_size, seq_length]

        pos_embeddings = self.pos_embed(pos)  # This should now be [batch_size, seq_length, d_model]

        x = x + pos_embeddings

        return self.dropout(x)

class CADEmbedding(nn.Module):
    """Embedding: positional embed + command embed + parameter embed + group embed (optional)"""
    def __init__(self, n_commands=7, d_model=64, n_args=16, args_dim=257, seq_len=60):
        super(CADEmbedding, self).__init__()
        self.command_embed = nn.Embedding(n_commands, d_model)
        self.arg_embed = nn.Embedding(args_dim, d_model, padding_idx=0)
        self.embed_fcn = nn.Linear(d_model * n_args, d_model)
        self.pos_encoding = PositionalEncodingLUT(d_model, max_len=seq_len+2)

    def forward(self, commands, args):
        S, N = commands.shape
        src = self.command_embed(commands.long()) + \
              self.embed_fcn(self.arg_embed((args + 1).long()).view(S, N, -1))
        src = self.pos_encoding(src)
        return src

class FusionModule(nn.Module):
    def __init__(self, latent_size, command_embedding_size=64, hidden_size=128, output_size=64):
        super(FusionModule, self).__init__()
        self.fc1 = nn.Linear(latent_size + command_embedding_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, latent, command_embedding):
        command_embedding = command_embedding.unsqueeze(1)
        combined = torch.cat((latent, command_embedding), dim=2)
        x = self.fc1(combined)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        return x



from torch import Tensor
from typing import Optional



class CausalTransformerDecoderLayer(nn.TransformerDecoderLayer):
    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        if self.training:
            return super().forward(tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask)
        else:
            tgt_last_tok = tgt[-1:, :, :]
            tgt_last_tok = self.self_attn(tgt_last_tok, tgt, tgt)[0] + tgt_last_tok
            tgt_last_tok = self.norm1(tgt_last_tok)
            if memory is not None:
                tgt_last_tok = self.multihead_attn(tgt_last_tok, memory, memory)[0] + tgt_last_tok
                tgt_last_tok = self.norm2(tgt_last_tok)
            tgt_last_tok = self.linear2(self.dropout(self.activation(self.linear1(tgt_last_tok)))) + tgt_last_tok
            tgt_last_tok = self.norm3(tgt_last_tok)
            return tgt_last_tok


class CausalTransformerDecoder(nn.TransformerDecoder):
    def forward(self, tgt, memory=None, cache=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        # print("tgt in TransformerDecoder ", tgt.shape)
        # print("memory in TransformerDecoder ", memory.shape)
        if self.training:
            if cache is not None:
                raise ValueError("cache parameter should be None in training mode")
            for mod in self.layers:
                tgt = mod(tgt, memory, tgt_mask=None, memory_mask=memory_mask, tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask)
            return tgt
        else:
            new_token_cache = []
            for i, mod in enumerate(self.layers):
                tgt = mod(tgt, memory)
                new_token_cache.append(tgt)
                if cache is not None:
                    tgt = torch.cat([cache[i], tgt], dim=0)
            if cache is not None:
                new_cache = torch.cat([cache, torch.stack(new_token_cache, dim=0)], dim=1)
            else:
                new_cache = torch.stack(new_token_cache, dim=0)
            # print("tgt[-1:] from the transformer decoder block ", tgt[-1:])
            # print("new_cache from the transformer decoder block ", new_cache)
            return tgt[-1:], new_cache



    @staticmethod
    def _generate_causal_mask(sz, device):
        mask = torch.full((sz, sz), float('-inf'))
        mask = torch.triu(mask, diagonal=1)
        return mask.to(device)


    @staticmethod
    def _generate_causal_mask(sz, device):
        """
        Generates a causal mask to hide future tokens for autoregressive tasks.
        """
        mask = torch.full((sz, sz), float('-inf'))
        mask = torch.triu(mask, diagonal=1)
        return mask.to(device)


class CADDecoder(nn.Module):
    def __init__(self, d_model=64, nhead=8, num_decoder_layers=4, dim_feedforward=2048, dropout=0.1, activation="relu", num_commands=7, max_seq_len=5000, num_parameters=16, param_cat=257):
        super(CADDecoder, self).__init__()
        self.param_cat = param_cat
        # Embedding
        self.cad_command_embedding = CADEmbedding(d_model=d_model, n_commands=num_commands, n_args=num_parameters, seq_len=max_seq_len)

        # Fusion module
        self.fusion_module = FusionModule(latent_size=128)

        # Decoder
        decoder_layer = CausalTransformerDecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, activation=activation)
        self.transformer_decoder = CausalTransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        # Final output
        self.output_layer1 = nn.Linear(d_model, num_commands)  # t_i
        self.output_layer2 = nn.Linear(d_model, num_parameters * param_cat)  # p_i

    def forward(self, node_embeddings, graph_embeddings, command_seq, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        commands_count = command_seq.shape[1]
        commands = command_seq[:, :, 0]
        args = command_seq[:, :, 1:]
        construct_embed = self.cad_command_embedding(commands, args)

        # Fusion module
        fusion_outputs = []
        for i in range(commands_count):
            fusion_output = self.fusion_module(graph_embeddings, construct_embed[:, i, :])
            fusion_outputs.append(fusion_output)
        fusion_outputs = torch.cat(fusion_outputs, dim=1)  # [batch_size, commands_count, d_model]

        # Reshape node embeddings to [number of nodes, batch size, embedding dimension]
        batch_size, num_nodes, embed_dim = node_embeddings.shape
        node_embeddings = node_embeddings.transpose(0, 1)  # [number of nodes, batch size, embedding dimension]

        all_decoder_outputs = []

        # Autoregressive decoding
        for t in range(commands_count):
            if t == 0:
                current_input = fusion_outputs[:, :1, :]  # [batch_size, 1, d_model]
            else:
                current_input = torch.cat((fusion_outputs[:, :t, :], last_output), dim=1)

            current_input = current_input.transpose(0, 1)  # [sequence_length, batch_size, d_model]

            decoder_output = self.transformer_decoder(
                tgt=current_input,
                memory=node_embeddings,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask
            )

            decoder_output = decoder_output.transpose(0, 1)  # [batch_size, sequence_length, d_model]

            # Get the last output token
            last_output = decoder_output[:, -1, :].unsqueeze(1)  # [batch_size, 1, d_model]

            all_decoder_outputs.append(last_output)

        all_decoder_outputs = torch.cat(all_decoder_outputs, dim=1)  # [batch_size, commands_count, d_model]

        # Final output layers
        output1 = self.output_layer1(all_decoder_outputs)
        output1 = F.softmax(output1[:, -1:, :], dim=-1)  # t_i

        output2 = self.output_layer2(all_decoder_outputs)
        output2 = output2.view(batch_size, commands_count, -1, self.param_cat)[:, -1, :, :]  # Reshape and select last
        output2 = F.softmax(output2, dim=2)  # p_i
        return output1, output2


class CADParser(nn.Module):
    def __init__(self):
        super(CADParser, self).__init__()

        # encoder
        self.cad_encoder = CADEncoder()

        # decoder
        self.cad_decoder = CADDecoder()
    def forward(graphs, sequences):


DGL backend not selected or invalid.  Assuming PyTorch for now.


Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)


In [4]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [6]:
import dgl
import numpy as np

# Load data
graphs, label_dict = dgl.load_graphs("/content/drive/My Drive/DeepCADDataset/all_graphs.bin")
npz = np.load("/content/drive/My Drive/DeepCADDataset/all_npz.npz")
# Prepare data
max_nodes = max(graph.number_of_nodes() for graph in graphs)
print("Maximum number of nodes in any graph:", max_nodes)

# Filter graphs and corresponding npz entries
filtered_graphs = []
filtered_npz_keys = []
for i in range(len(graphs)):
    if graphs[i].number_of_nodes() <= 20:
        filtered_graphs.append(graphs[i])
        filtered_npz_keys.append(list(npz.keys())[i])

# Count the number of filtered graphs
count = len(filtered_graphs)
print("Number of graphs with <= 20 nodes:", count)


Maximum number of nodes in any graph: 66
Number of graphs with <= 20 nodes: 6703


In [7]:
import dgl
import numpy as np
import torch
from torch import optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import os
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch.cuda.amp import autocast, GradScaler

# Assuming BaseDataset and CADParser are already defined

# Load data
graphs, label_dict = dgl.load_graphs("/content/drive/My Drive/DeepCADDataset/all_graphs.bin")
npz = np.load("/content/drive/My Drive/DeepCADDataset/all_npz.npz")

# Set seed for reproducibility
seednumber = 2024
torch.manual_seed(seednumber)
torch.cuda.manual_seed(seednumber)
np.random.seed(seednumber)

# Prepare data
max_nodes = max(graph.number_of_nodes() for graph in graphs)
print("Maximum number of nodes in any graph:", max_nodes)

Y = [npz[key] for key in npz.keys()]
X = graphs



Maximum number of nodes in any graph: 66


In [8]:
filtered_X = []
filtered_Y = []
for i in range(len(X)):
    if X[i].number_of_nodes() <= 20:
        filtered_X.append(X[i])
        filtered_Y.append(Y[i])

# Count the number of filtered graphs
count = len(filtered_X)
print("Number of graphs with <= 20 nodes:", count)

Number of graphs with <= 20 nodes: 6703


In [9]:
X_train, X_test, Y_train, Y_test = train_test_split(filtered_X, filtered_Y, test_size=0.1, random_state=42)

In [None]:

# Set device and other configurations
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint_dir = 'checkpoint'
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

# Define hyperparameters
num_epochs = 10
initial_learning_rate = 1e-4
batch_size = 96
warmup_epochs = 10
torch.set_printoptions(threshold=10_000)

# Initialize dataset, dataloader, model, criterion, optimizer, scheduler
dataset = BaseDataset(X_train, Y_train)
data_loader = dataset.get_dataloader(batch_size)
model = CADParser().to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=initial_learning_rate)
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.9 ** (epoch // 30) * min((epoch + 1) / warmup_epochs, 1))
scaler = GradScaler()  # For mixed precision training

def print_memory_usage(tag=""):
    print(f"{tag} - Allocated: {torch.cuda.memory_allocated()/1024**2:.2f} MiB, Reserved: {torch.cuda.memory_reserved()/1024**2:.2f} MiB")

start_epoch = 0
loss_list = []
# Training loop
for epoch in range(start_epoch, num_epochs):
    model.train()
    total_loss = 0
    for batch_idx, batch in enumerate(data_loader):
        graphs, sequences = batch['graph'].to(device), batch['labels'].to(device)
        # print_memory_usage(f"After loading batch {batch_idx}")

        num_nodes_per_graph = batch['num_nodes']

        optimizer.zero_grad()
        batch_loss = 0

        node_embeddings, graph_embeddings = model.cad_encoder(graphs)
        # print_memory_usage(f"After cad_encoder {batch_idx}")

        del graphs  # Release memory of graphs
        torch.cuda.empty_cache()  # Clear unused memory

        decoder_input_seq = sequences[:, 0:1, :].to(device)

        batch_num_nodes = node_embeddings.shape[0]
        padded_node_embeddings = torch.zeros(len(num_nodes_per_graph), 66, 64, device=device)
        # print_memory_usage(f"After creating padded_node_embeddings {batch_idx}")

        start_idx = 0
        for i, num_nodes in enumerate(num_nodes_per_graph):
            end_idx = start_idx + num_nodes
            padded_node_embeddings[i, :num_nodes] = node_embeddings[start_idx:end_idx]
            start_idx = end_idx

        graph_embeddings = torch.unsqueeze(graph_embeddings, 1).to(device)
        # print_memory_usage(f"After processing node_embeddings {batch_idx}")

        for t in range(1, sequences.size(1)):
            torch.cuda.empty_cache()

            # print_memory_usage("1")
            gt_t = sequences[:, :t, :]
            # print_memory_usage("2")
            command_type_t = sequences[:, t, 0].long()
            # print_memory_usage("3")
            param_t = sequences[:, t, 1:]
            param_t_mapped = param_t + 1
            # print_memory_usage("4")
            true_t_i = F.one_hot(command_type_t, num_classes=7).float()
            true_p_i = F.one_hot(param_t_mapped.long(), num_classes=257).float()
            # print_memory_usage("5")

            with autocast():  # Use mixed precision
                decoder_output_t_i, decoder_output_p_i = model.cad_decoder(padded_node_embeddings, graph_embeddings, decoder_input_seq)

            t_i_loss = criterion(decoder_output_t_i.squeeze(), true_t_i)
            p_i_loss = criterion(decoder_output_p_i, true_p_i)
            decoder_output_t_i = decoder_output_t_i.detach()  # Detach to free computation graph
            decoder_output_p_i = decoder_output_p_i.detach()  # Detach to free computation graph
            torch.cuda.empty_cache()
            # print_memory_usage("6")
            loss = t_i_loss + p_i_loss
            batch_loss += loss

            # print(f"t_i_loss.requires_grad: {t_i_loss.requires_grad}")
            # print(f"p_i_loss.requires_grad: {p_i_loss.requires_grad}")
            # print(f"loss.requires_grad: {loss.requires_grad}")
            # print(f"batch_loss.requires_grad: {batch_loss.requires_grad}")

            _, command_type_pred_next = torch.max(decoder_output_t_i, dim=2, keepdim=True)
            _, command_args_pred_next = torch.max(decoder_output_p_i, dim=2, keepdim=True)
            # print_memory_usage("7")
            command_args_pred_next = command_args_pred_next - 1
            next_token_pred = torch.cat((command_type_pred_next, command_args_pred_next), dim=1)
            next_token_pred = next_token_pred.transpose(1, 2).view(batch_size, -1, 17)
            decoder_input_seq = torch.cat((decoder_input_seq, next_token_pred), dim=1)
            # print_memory_usage("8")

            # print_memory_usage(f"After processing timestep {t} of batch {batch_idx}")

            # Free the memory used by previous iterations
            del decoder_output_t_i, decoder_output_p_i, gt_t, command_type_t, param_t, param_t_mapped, true_t_i, true_p_i
            torch.cuda.empty_cache()
            # print_memory_usage("end")

        batch_loss /= (sequences.size(1) - 1)
        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += batch_loss.item()
        print("batch loss: ", batch_loss.item())

        scheduler.step()
        torch.cuda.empty_cache()  # Clear unused memory after each batch
        print_memory_usage(f"After processing batch {batch_idx}")

    print(f'----------------------EPOCH {epoch + 1}, TOTAL LOSS = {total_loss / 96}')
    if epoch % 10 == 0:
      torch.save({
          'epoch': epoch + 1,
          'model_state_dict': model.state_dict(),
          'optimizer_state_dict': optimizer.state_dict(),
          'scheduler_state_dict': scheduler.state_dict(),
          'loss': total_loss / 96,
      }, f'{checkpoint_dir}/model_epoch_{epoch + 1}.pth')
    # print_memory_usage(f"After saving checkpoint for epoch {epoch + 1}")
    loss_list.append(total_loss)

print('Finished Training')

import matplotlib.pyplot as plt
plt.figure(figsize=(10, 5))
plt.plot(loss_list, label='Loss per Epoch')
plt.title('Training Loss Over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.savefig('plot/loss_plot.png')
plt.close()



  return F.conv2d(input, weight, bias, self.stride,


batch loss:  2.1503868103027344
After processing batch 0 - Allocated: 90.37 MiB, Reserved: 610.00 MiB
batch loss:  2.152381420135498
After processing batch 1 - Allocated: 90.20 MiB, Reserved: 644.00 MiB
batch loss:  2.13252592086792
After processing batch 2 - Allocated: 90.30 MiB, Reserved: 596.00 MiB
batch loss:  2.1202807426452637
After processing batch 3 - Allocated: 90.40 MiB, Reserved: 588.00 MiB
batch loss:  2.088874578475952
After processing batch 4 - Allocated: 90.25 MiB, Reserved: 610.00 MiB
batch loss:  2.0595810413360596
After processing batch 5 - Allocated: 90.80 MiB, Reserved: 616.00 MiB
batch loss:  2.0069761276245117
After processing batch 6 - Allocated: 66.56 MiB, Reserved: 434.00 MiB
batch loss:  2.0009541511535645
After processing batch 7 - Allocated: 67.35 MiB, Reserved: 508.00 MiB
batch loss:  1.954213261604309
After processing batch 8 - Allocated: 66.74 MiB, Reserved: 470.00 MiB
batch loss:  1.9417903423309326
After processing batch 9 - Allocated: 66.69 MiB, Reserv

In [35]:
torch.cuda.empty_cache()