In [1]:
!pip install dgl -f https://data.dgl.ai/wheels/cu121/repo.html

Looking in links: https://data.dgl.ai/wheels/cu121/repo.html


In [2]:
import random

import numpy as np
import torch
from scipy.spatial.transform import Rotation


def bounding_box_uvgrid(inp: torch.Tensor):
    pts = inp[..., :3].reshape((-1, 3))
    mask = inp[..., 6].reshape(-1)
    point_indices_inside_faces = mask == 1
    pts = pts[point_indices_inside_faces, :]
    return bounding_box_pointcloud(pts)


def bounding_box_pointcloud(pts: torch.Tensor):
    x = pts[:, 0]
    y = pts[:, 1]
    z = pts[:, 2]
    box = [[x.min(), y.min(), z.min()], [x.max(), y.max(), z.max()]]
    return torch.tensor(box)


def center_and_scale_uvgrid(inp: torch.Tensor, return_center_scale=False):
    bbox = bounding_box_uvgrid(inp)
    diag = bbox[1] - bbox[0]
    scale = 2.0 / max(diag[0], diag[1], diag[2])
    center = 0.5 * (bbox[0] + bbox[1])
    inp[..., :3] -= center
    inp[..., :3] *= scale
    if return_center_scale:
        return inp, center, scale
    return inp


def get_random_rotation():
    """Get a random rotation in 90 degree increments along the canonical axes"""
    axes = [
        np.array([1, 0, 0]),
        np.array([0, 1, 0]),
        np.array([0, 0, 1]),
    ]
    angles = [0.0, 90.0, 180.0, 270.0]
    axis = random.choice(axes)
    angle_radians = np.radians(random.choice(angles))
    return Rotation.from_rotvec(angle_radians * axis)


def rotate_uvgrid(inp, rotation):
    """Rotate the node features in the graph by a given rotation"""
    Rmat = torch.tensor(rotation.as_matrix()).float()
    orig_size = inp[..., :3].size()
    inp[..., :3] = torch.mm(inp[..., :3].view(-1, 3), Rmat).view(
        orig_size
    )  # Points
    inp[..., 3:6] = torch.mm(inp[..., 3:6].view(-1, 3), Rmat).view(
        orig_size
    )  # Normals/tangents
    return inp


INVALID_FONTS = [
    "Bokor",
    "Lao Muang Khong",
    "Lao Sans Pro",
    "MS Outlook",
    "Catamaran Black",
    "Dubai",
    "HoloLens MDL2 Assets",
    "Lao Muang Don",
    "Oxanium Medium",
    "Rounded Mplus 1c",
    "Moul Pali",
    "Noto Sans Tamil",
    "Webdings",
    "Armata",
    "Koulen",
    "Yinmar",
    "Ponnala",
    "Noto Sans Tamil",
    "Chenla",
    "Lohit Devanagari",
    "Metal",
    "MS Office Symbol",
    "Cormorant Garamond Medium",
    "Chiller",
    "Give You Glory",
    "Hind Vadodara Light",
    "Libre Barcode 39 Extended",
    "Myanmar Sans Pro",
    "Scheherazade",
    "Segoe MDL2 Assets",
    "Siemreap",
    "Signika SemiBold" "Taprom",
    "Times New Roman TUR",
    "Playfair Display SC Black",
    "Poppins Thin",
    "Raleway Dots",
    "Raleway Thin",
    "Segoe MDL2 Assets",
    "Segoe MDL2 Assets",
    "Spectral SC ExtraLight",
    "Txt",
    "Uchen",
    "Yinmar",
    "Almarai ExtraBold",
    "Fasthand",
    "Exo",
    "Freckle Face",
    "Montserrat Light",
    "Inter",
    "MS Reference Specialty",
    "MS Outlook",
    "Preah Vihear",
    "Sitara",
    "Barkerville Old Face",
    "Bodoni MT" "Bokor",
    "Fasthand",
    "HoloLens MDL2 Assests",
    "Libre Barcode 39",
    "Lohit Tamil",
    "Marlett",
    "MS outlook",
    "MS office Symbol Semilight",
    "MS office symbol regular",
    "Ms office symbol extralight",
    "Ms Reference speciality",
    "Segoe MDL2 Assets",
    "Siemreap",
    "Sitara",
    "Symbol",
    "Wingdings",
    "Metal",
    "Ponnala",
    "Webdings",
    "Souliyo Unicode",
    "Aguafina Script",
    "Yantramanav Black",
    # "Yaldevi",
    # Taprom,
    # "Zhi Mang Xing",
    # "Taviraj",
    # "SeoulNamsan EB",
]


def valid_font(filename):
    for name in INVALID_FONTS:
        if name.lower() in str(filename).lower():
            return False
    return True

In [3]:
from torch.utils.data import Dataset, DataLoader
from torch import FloatTensor
import dgl
from dgl.data.utils import load_graphs
from tqdm import tqdm
from abc import abstractmethod
from torch.nn.utils.rnn import pad_sequence
import torch



from torch.utils.data import Dataset, DataLoader
import dgl
from torch import FloatTensor, stack

class BaseDataset(Dataset):
    @staticmethod
    def num_classes():
        pass

    def __init__(self, X_train, Y_train):
        """
        self.data is a list of dictionaries with keys graph and label
        """
        assert len(X_train) == len(Y_train), "The number of graphs must match the number of labels"
        self.data = [{"graph": graph, "label": label} for graph, label in zip(X_train, Y_train)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        return sample["graph"], sample["label"]  # Returns a tuple of the sample graph and its corresponding label

    def _collate(self, batch):
        graphs, labels = zip(*batch)
        batched_graph = dgl.batch(graphs)
        #print("batched_graph", batched_graph)
        num_nodes_per_graph = [graph.number_of_nodes() for graph in graphs]

        # Create a default padding command vector
        pad_vector = torch.tensor([6] + [-1]*16, dtype=torch.float32)

        # Prepare labels with padding
        padded_labels = []
        for label in labels:
            label_length = label.shape[0]
            if label_length < 60:
                # Calculate how many padding vectors are needed
                padding_count = 60 - label_length
                # Create a tensor of padding vectors
                padding = pad_vector.repeat(padding_count, 1)
                # Concatenate the original label with the padding
                padded_label = torch.cat([torch.tensor(label, dtype=torch.float32), padding], dim=0)
            else:
                padded_label = torch.tensor(label, dtype=torch.float32)
            padded_labels.append(padded_label)

        # Stack all the padded labels into a single tensor
        padded_labels = stack(padded_labels)
        return {"graph": batched_graph, "labels": padded_labels, "num_nodes": num_nodes_per_graph}
        # return {"graph": batched_graph, "labels": padded_labels}


    def get_dataloader(self, batch_size, shuffle=True, num_workers=0):
        return DataLoader(
            self,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=self._collate,
            num_workers=num_workers,  # Can be set to non-zero on Linux
            drop_last=True
        )

In [4]:
import torch
from torch import nn
import torch.nn.functional as F
from dgl.nn.pytorch.conv import NNConv
from dgl.nn.pytorch.glob import MaxPooling
from torch.nn import TransformerDecoder, TransformerDecoderLayer
import math


# Convolutional Layers
def _conv1d(in_channels, out_channels, kernel_size=3, padding=0, bias=False):
    """
    Helper function to create a 1D convolutional layer with batchnorm and LeakyReLU activation

    Args:
        in_channels (int): Input channels
        out_channels (int): Output channels
        kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
        padding (int, optional): Padding size on each side. Defaults to 0.
        bias (bool, optional): Whether bias is used. Defaults to False.

    Returns:
        nn.Sequential: Sequential contained the Conv1d, BatchNorm1d and LeakyReLU layers
    """
    return nn.Sequential(
        nn.Conv1d(
            in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=bias
        ),
        nn.BatchNorm1d(out_channels),
        nn.LeakyReLU(),
    )


def _conv2d(in_channels, out_channels, kernel_size, padding=0, bias=False):
    """
    Helper function to create a 2D convolutional layer with batchnorm and LeakyReLU activation

    Args:
        in_channels (int): Input channels
        out_channels (int): Output channels
        kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
        padding (int, optional): Padding size on each side. Defaults to 0.
        bias (bool, optional): Whether bias is used. Defaults to False.

    Returns:
        nn.Sequential: Sequential contained the Conv2d, BatchNorm2d and LeakyReLU layers
    """
    return nn.Sequential(
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding,
            bias=bias,
        ),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(),
    )


def _fc(in_features, out_features, bias=False):
    return nn.Sequential(
        nn.Linear(in_features, out_features, bias=bias),
        nn.BatchNorm1d(out_features),
        nn.LeakyReLU(),
    )

class _MLP(nn.Module):
    """"""

    def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
        """
        MLP with linear output
        Args:
            num_layers (int): The number of linear layers in the MLP
            input_dim (int): Input feature dimension
            hidden_dim (int): Hidden feature dimensions for all hidden layers
            output_dim (int): Output feature dimension

        Raises:
            ValueError: If the given number of layers is <1
        """
        super(_MLP, self).__init__()
        self.linear_or_not = True  # default is linear model
        self.num_layers = num_layers
        self.output_dim = output_dim

        if num_layers < 1:
            raise ValueError("Number of layers should be positive!")
        elif num_layers == 1:
            # Linear model
            self.linear = nn.Linear(input_dim, output_dim)
        else:
            # Multi-layer model
            self.linear_or_not = False
            self.linears = torch.nn.ModuleList()
            self.batch_norms = torch.nn.ModuleList()

            self.linears.append(nn.Linear(input_dim, hidden_dim))
            for layer in range(num_layers - 2):
                self.linears.append(nn.Linear(hidden_dim, hidden_dim))
            self.linears.append(nn.Linear(hidden_dim, output_dim))

            # TODO: this could move inside the above loop
            for layer in range(num_layers - 1):
                self.batch_norms.append(nn.BatchNorm1d((hidden_dim)))

    def forward(self, x):
        if self.linear_or_not:
            # If linear model
            return self.linear(x)
        else:
            # If MLP
            h = x
            for i in range(self.num_layers - 1):
                h = F.relu(self.batch_norms[i](self.linears[i](h)))
            return self.linears[-1](h)

class UVNetCurveEncoder(nn.Module):
    def __init__(self, in_channels=6, output_dims=64):
        """
        This is the 1D convolutional network that extracts features from the B-rep edge
        geometry described as 1D UV-grids (see Section 3.2, Curve & surface convolution
        in paper)

        Args:
            in_channels (int, optional): Number of channels in the edge UV-grids. By default
                                         we expect 3 channels for point coordinates and 3 for
                                         curve tangents. Defaults to 6.
            output_dims (int, optional): Output curve embedding dimension. Defaults to 64.
        """
        super(UVNetCurveEncoder, self).__init__()
        self.in_channels = in_channels
        self.conv1 = _conv1d(in_channels, 64, kernel_size=3, padding=1, bias=False)
        self.conv2 = _conv1d(64, 128, kernel_size=3, padding=1, bias=False)
        self.conv3 = _conv1d(128, 256, kernel_size=3, padding=1, bias=False)
        self.final_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = _fc(256, output_dims, bias=False)

        for m in self.modules():
            self.weights_init(m)

    def weights_init(self, m):
        if isinstance(m, (nn.Linear, nn.Conv1d)):
            torch.nn.init.kaiming_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)

    def forward(self, x):
        assert x.size(1) == self.in_channels
        batch_size = x.size(0)
        x = x.float()
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.final_pool(x)
        x = x.view(batch_size, -1)
        x = self.fc(x)
        return x


class UVNetSurfaceEncoder(nn.Module):
    def __init__(
        self,
        in_channels=7,
        output_dims=64,
    ):
        """
        This is the 2D convolutional network that extracts features from the B-rep face
        geometry described as 2D UV-grids (see Section 3.2, Curve & surface convolution
        in paper)

        Args:
            in_channels (int, optional): Number of channels in the edge UV-grids. By default
                                         we expect 3 channels for point coordinates and 3 for
                                         surface normals and 1 for the trimming mask. Defaults
                                         to 7.
            output_dims (int, optional): Output surface embedding dimension. Defaults to 64.
        """
        super(UVNetSurfaceEncoder, self).__init__()
        self.in_channels = in_channels
        self.conv1 = _conv2d(in_channels, 64, 3, padding=1, bias=False)
        self.conv2 = _conv2d(64, 128, 3, padding=1, bias=False)
        self.conv3 = _conv2d(128, 256, 3, padding=1, bias=False)
        self.final_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = _fc(256, output_dims, bias=False)
        for m in self.modules():
            self.weights_init(m)

    def weights_init(self, m):
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            torch.nn.init.kaiming_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)

    def forward(self, x):
        assert x.size(1) == self.in_channels
        batch_size = x.size(0)
        x = x.float()
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.final_pool(x)
        x = x.view(batch_size, -1)
        x = self.fc(x)
        return x

class _EdgeConv(nn.Module):
    def __init__(
        self,
        edge_feats,
        out_feats,
        node_feats,
        num_mlp_layers=2,
        hidden_mlp_dim=64,
    ):
        """
        This module implements Eq. 2 from the paper where the edge features are
        updated using the node features at the endpoints.

        Args:
            edge_feats (int): Input edge feature dimension
            out_feats (int): Output feature deimension
            node_feats (int): Input node feature dimension
            num_mlp_layers (int, optional): Number of layers used in the MLP. Defaults to 2.
            hidden_mlp_dim (int, optional): Hidden feature dimension in the MLP. Defaults to 64.
        """
        super(_EdgeConv, self).__init__()
        self.proj = _MLP(1, node_feats, hidden_mlp_dim, edge_feats)
        self.mlp = _MLP(num_mlp_layers, edge_feats, hidden_mlp_dim, out_feats)
        self.batchnorm = nn.BatchNorm1d(out_feats)
        self.eps = torch.nn.Parameter(torch.FloatTensor([0.0]))

    def forward(self, graph, nfeat, efeat):
        src, dst = graph.edges()
        proj1, proj2 = self.proj(nfeat[src]), self.proj(nfeat[dst])
        agg = proj1 + proj2
        h = self.mlp((1 + self.eps) * efeat + agg)
        h = F.leaky_relu(self.batchnorm(h))
        return h


class _NodeConv(nn.Module):
    def __init__(
        self,
        node_feats,
        out_feats,
        edge_feats,
        num_mlp_layers=2,
        hidden_mlp_dim=64,
    ):
        """
        This module implements Eq. 1 from the paper where the node features are
        updated using the neighboring node and edge features.

        Args:
            node_feats (int): Input edge feature dimension
            out_feats (int): Output feature deimension
            node_feats (int): Input node feature dimension
            num_mlp_layers (int, optional): Number of layers used in the MLP. Defaults to 2.
            hidden_mlp_dim (int, optional): Hidden feature dimension in the MLP. Defaults to 64.
        """
        super(_NodeConv, self).__init__()
        self.gconv = NNConv(
            in_feats=node_feats,
            out_feats=out_feats,
            edge_func=nn.Linear(edge_feats, node_feats * out_feats),
            aggregator_type="sum",
            bias=False,
        )
        self.batchnorm = nn.BatchNorm1d(out_feats)
        self.mlp = _MLP(num_mlp_layers, node_feats, hidden_mlp_dim, out_feats)
        self.eps = torch.nn.Parameter(torch.FloatTensor([0.0]))

    def forward(self, graph, nfeat, efeat):
        h = (1 + self.eps) * nfeat
        h = self.gconv(graph, h, efeat)
        h = self.mlp(h)
        h = F.leaky_relu(self.batchnorm(h))
        return h


class UVNetGraphEncoder(nn.Module):
    def __init__(
        self,
        input_dim,
        input_edge_dim,
        output_dim,
        hidden_dim=64,
        learn_eps=True,
        num_layers=3,
        num_mlp_layers=2,
    ):
        """
        This is the graph neural network used for message-passing features in the
        face-adjacency graph.

        Args:
            input_dim ([type]): [description]
            input_edge_dim ([type]): [description]
            output_dim ([type]): [description]
            hidden_dim (int, optional): [description]. Defaults to 64.
            learn_eps (bool, optional): [description]. Defaults to True.
            num_layers (int, optional): [description]. Defaults to 3.
            num_mlp_layers (int, optional): [description]. Defaults to 2.
        """
        super(UVNetGraphEncoder, self).__init__()
        self.num_layers = num_layers
        self.learn_eps = learn_eps

        # List of layers for node and edge feature message passing
        self.node_conv_layers = torch.nn.ModuleList()
        self.edge_conv_layers = torch.nn.ModuleList()

        for layer in range(self.num_layers - 1):
            node_feats = input_dim if layer == 0 else hidden_dim
            edge_feats = input_edge_dim if layer == 0 else hidden_dim
            self.node_conv_layers.append(
                _NodeConv(
                    node_feats=node_feats,
                    out_feats=hidden_dim,
                    edge_feats=edge_feats,
                    num_mlp_layers=num_mlp_layers,
                    hidden_mlp_dim=hidden_dim,
                ),
            )
            self.edge_conv_layers.append(
                _EdgeConv(
                    edge_feats=edge_feats,
                    out_feats=hidden_dim,
                    node_feats=node_feats,
                    num_mlp_layers=num_mlp_layers,
                    hidden_mlp_dim=hidden_dim,
                )
            )

        # Linear function for graph poolings of output of each layer
        # which maps the output of different layers into a prediction score
        self.linears_prediction = torch.nn.ModuleList()

        for layer in range(num_layers):
            if layer == 0:
                self.linears_prediction.append(nn.Linear(input_dim, output_dim))
            else:
                self.linears_prediction.append(nn.Linear(hidden_dim, output_dim))

        self.drop1 = nn.Dropout(0.3)
        self.drop = nn.Dropout(0.5)
        self.pool = MaxPooling()

    def forward(self, g, h, efeat):
        hidden_rep = [h]
        he = efeat

        for i in range(self.num_layers - 1):
            # Update node features
            h = self.node_conv_layers[i](g, h, he)
            # Update edge features
            he = self.edge_conv_layers[i](g, h, he)
            hidden_rep.append(h)
        # print(f'hidden_rep is {type(hidden_rep)}')
        # Use the node embeddings from the last layer
        node_embeddings = hidden_rep[-1]
        # print(f'node_embeddings_0 = {node_embeddings.shape}')
        node_embeddings = self.drop1(node_embeddings)
        # print(f'node_embeddings_1 = {node_embeddings.shape}')

        # Optional: Perform pooling to get a graph-level representation
        graph_representation = 0
        for i, h in enumerate(hidden_rep):
            pooled_h = self.pool(g, h)
            graph_representation += self.drop(self.linears_prediction[i](pooled_h))

        # print(f'node_embeddings_2.shape = {node_embeddings.shape}')
        # print(f'graph_representation.shape = {graph_representation.shape}')

        return node_embeddings, graph_representation

class CADEncoder(nn.Module):
  def __init__(self, crv_emb_dim=64, srf_emb_dim=64, graph_emb_dim=128, dropout=0.3):
    super(CADEncoder, self).__init__()
    self.curv_encoder = UVNetCurveEncoder(in_channels=10, output_dims=crv_emb_dim) # in_channels originally 6
    self.surf_encoder = UVNetSurfaceEncoder(in_channels=10, output_dims=srf_emb_dim)
    self.graph_encoder = UVNetGraphEncoder(srf_emb_dim, crv_emb_dim, graph_emb_dim)

  def forward(self, batched_graph):
    input_crv_feat = batched_graph.edata["x"]
    input_srf_feat = batched_graph.ndata["x"]
    hidden_crv_feat = self.curv_encoder(input_crv_feat)
    hidden_srf_feat = self.surf_encoder(input_srf_feat)
    node_emb, graph_emb = self.graph_encoder(batched_graph, hidden_srf_feat, hidden_crv_feat)
    return node_emb, graph_emb


class PositionalEncodingLUT(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=250):
        super(PositionalEncodingLUT, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.long).unsqueeze(1)
        self.register_buffer('position', position)

        self.pos_embed = nn.Embedding(max_len, d_model)

        self._init_embeddings()

    def _init_embeddings(self):
        nn.init.kaiming_normal_(self.pos_embed.weight, mode="fan_in")

    def forward(self, x):
        batch_size, seq_length, _ = x.shape

        # Get positional encodings for the sequence length
        pos = self.position[:seq_length]  # This will have shape [seq_length, 1]

        # Expand positional encodings to cover the whole batch
        pos = pos.expand(-1, batch_size).contiguous()  # Reshape to [seq_length, batch_size]
        pos = pos.transpose(0, 1)  # Transpose to [batch_size, seq_length]

        # Retrieve positional embeddings
        pos_embeddings = self.pos_embed(pos)  # This should now be [batch_size, seq_length, d_model]

        # Element-wise addition of embeddings to input x
        x = x + pos_embeddings

        # Apply dropout and return
        return self.dropout(x)

class CADEmbedding(nn.Module):
    """Embedding: positional embed + command embed + parameter embed + group embed (optional)"""
    def __init__(self, n_commands=7, d_model=64, n_args=16, args_dim=257, seq_len=60):
        super(CADEmbedding, self).__init__()
        self.command_embed = nn.Embedding(n_commands, d_model)
        self.arg_embed = nn.Embedding(args_dim, d_model, padding_idx=0)
        self.embed_fcn = nn.Linear(d_model * n_args, d_model)
        self.pos_encoding = PositionalEncodingLUT(d_model, max_len=seq_len+2)

    def forward(self, commands, args):
        S, N = commands.shape
        src = self.command_embed(commands.long()) + \
              self.embed_fcn(self.arg_embed((args + 1).long()).view(S, N, -1))
        src = self.pos_encoding(src)
        return src

class FusionModule(nn.Module):
    def __init__(self, latent_size, command_embedding_size=64, hidden_size=128, output_size=64):
        super(FusionModule, self).__init__()
        self.fc1 = nn.Linear(latent_size + command_embedding_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, latent, command_embedding):
        # print("latent dimension", latent.shape)
        # print("command embedding", command_embedding.shape)
        command_embedding = command_embedding.unsqueeze(1)
        # print('command embedding 1', command_embedding.shape)
        combined = torch.cat((latent, command_embedding), dim=2)
        # print('combined shape', combined.shape)
        x = self.fc1(combined)
        # print('x1 shape', x.shape)
        x = F.relu(x)
        x = self.fc2(x)
        # print('x2 shape', x.shape)
        x = F.relu(x)
        # print('fusion output', x.shape)
        return x



from torch import Tensor
from typing import Optional

# class CausalTransformerDecoder(nn.TransformerDecoder):
#     def forward(
#         self,
#         tgt: Tensor,
#         memory: Optional[Tensor] = None,
#         cache: Optional[Tensor] = None,
#         memory_mask: Optional[Tensor] = None,
#         tgt_key_padding_mask: Optional[Tensor] = None,
#         memory_key_padding_mask: Optional[Tensor] = None,
#     ) -> Tensor:
#         print("tgt in TransformerDecoder ",tgt.shape)
#         print("memory in TransformerDecoder ",memory.shape)

#         if self.training:
#             if cache is not None:
#                 raise ValueError("cache parameter should be None in training mode")
#             for mod in self.layers:
#                 tgt = mod(
#                     tgt,
#                     memory,
#                     tgt_mask=None,
#                     memory_mask=memory_mask,
#                     tgt_key_padding_mask=tgt_key_padding_mask,
#                     memory_key_padding_mask=memory_key_padding_mask,
#                 )
#             return tgt

#         new_token_cache = []

#         for i, mod in enumerate(self.layers):
#             tgt = mod(tgt, memory)
#             new_token_cache.append(tgt)
#             if cache is not None:
#                 tgt = torch.cat([cache[i], tgt], dim=0)

#         if cache is not None:
#             new_cache = torch.cat([cache, torch.stack(new_token_cache, dim=0)], dim=1)
#         else:
#             new_cache = torch.stack(new_token_cache, dim=0)

#         # Return only the last token's prediction and the new cache
#         print("tgt[-1:] from the transformer decoder block ", tgt[-1:])
#         print("new_cache from the transformer decoder block ", new_cache)
#         return tgt[-1:], new_cache




# class CausalTransformerDecoderLayer(nn.TransformerDecoderLayer):
#     def __init__(self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1, activation="relu"):
#         super().__init__(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, activation=activation)

#     def forward(
#         self,
#         tgt: Tensor,
#         memory: Optional[Tensor] = None,
#         tgt_mask: Optional[Tensor] = None,
#         memory_mask: Optional[Tensor] = None,
#         tgt_key_padding_mask: Optional[Tensor] = None,
#         memory_key_padding_mask: Optional[Tensor] = None,
#     ) -> Tensor:
#         if self.training:
#             # In training mode, follow the standard procedure including masking
#             print("tgt", tgt.shape)
#             print("memory", memory.shape)
#             returned = super().forward(
#                 tgt,
#                 memory,
#                 tgt_mask=tgt_mask,
#                 memory_mask=memory_mask,
#                 tgt_key_padding_mask=tgt_key_padding_mask,
#                 memory_key_padding_mask=memory_key_padding_mask,
#             )
#             print("returned: ", returned.shape)
#             return returned
#         else:
#             # In evaluation mode, proceed with the autoregressive manner
#             tgt_last_tok = tgt[-1:, :, :]  # Handle the last token from the sequence only

#             # Perform self-attention on the last token
#             tgt_last_tok = self.self_attn(
#                 tgt_last_tok,
#                 tgt,
#                 tgt,
#                 attn_mask=None,
#                 key_padding_mask=tgt_key_padding_mask,
#             )[0] + tgt_last_tok
#             tgt_last_tok = self.norm1(tgt_last_tok)

#             # Perform cross-attention with the memory (encoder's output)
#             if memory is not None:
#                 tgt_last_tok = self.multihead_attn(
#                     tgt_last_tok,
#                     memory,
#                     memory,
#                     attn_mask=memory_mask,
#                     key_padding_mask=memory_key_padding_mask,
#                 )[0] + tgt_last_tok
#                 tgt_last_tok = self.norm2(tgt_last_tok)

#             # Pass through the final feed-forward network
#             tgt_last_tok = self.linear2(self.dropout(self.activation(self.linear1(tgt_last_tok)))) + tgt_last_tok
#             tgt_last_tok = self.norm3(tgt_last_tok)

class CausalTransformerDecoderLayer(nn.TransformerDecoderLayer):
    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        if self.training:
            return super().forward(tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask)
        else:
            tgt_last_tok = tgt[-1:, :, :]
            tgt_last_tok = self.self_attn(tgt_last_tok, tgt, tgt)[0] + tgt_last_tok
            tgt_last_tok = self.norm1(tgt_last_tok)
            if memory is not None:
                tgt_last_tok = self.multihead_attn(tgt_last_tok, memory, memory)[0] + tgt_last_tok
                tgt_last_tok = self.norm2(tgt_last_tok)
            tgt_last_tok = self.linear2(self.dropout(self.activation(self.linear1(tgt_last_tok)))) + tgt_last_tok
            tgt_last_tok = self.norm3(tgt_last_tok)
            return tgt_last_tok


class CausalTransformerDecoder(nn.TransformerDecoder):
    def forward(self, tgt, memory=None, cache=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        # print("tgt in TransformerDecoder ", tgt.shape)
        # print("memory in TransformerDecoder ", memory.shape)
        if self.training:
            if cache is not None:
                raise ValueError("cache parameter should be None in training mode")
            for mod in self.layers:
                tgt = mod(tgt, memory, tgt_mask=None, memory_mask=memory_mask, tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask)
            return tgt
        else:
            new_token_cache = []
            for i, mod in enumerate(self.layers):
                tgt = mod(tgt, memory)
                new_token_cache.append(tgt)
                if cache is not None:
                    tgt = torch.cat([cache[i], tgt], dim=0)
            if cache is not None:
                new_cache = torch.cat([cache, torch.stack(new_token_cache, dim=0)], dim=1)
            else:
                new_cache = torch.stack(new_token_cache, dim=0)
            # print("tgt[-1:] from the transformer decoder block ", tgt[-1:])
            # print("new_cache from the transformer decoder block ", new_cache)
            return tgt[-1:], new_cache



    @staticmethod
    def _generate_causal_mask(sz, device):
        mask = torch.full((sz, sz), float('-inf'))
        mask = torch.triu(mask, diagonal=1)
        return mask.to(device)


    @staticmethod
    def _generate_causal_mask(sz, device):
        """
        Generates a causal mask to hide future tokens for autoregressive tasks.
        """
        mask = torch.full((sz, sz), float('-inf'))
        mask = torch.triu(mask, diagonal=1)
        return mask.to(device)


# class CADDecoder(nn.Module):
#     def __init__(self, d_model=64, nhead=8, num_decoder_layers=4, dim_feedforward=2048, dropout=0.1, activation="relu", num_commands=7, max_seq_len=5000, num_parameters=16, param_cat=257):
#         super(CADDecoder, self).__init__()
#         self.param_cat = param_cat
#         # Embedding
#         self.cad_command_embedding = CADEmbedding(d_model=d_model, n_commands=num_commands, n_args=num_parameters, seq_len=max_seq_len)

#         # Fusion module
#         self.fusion_module = FusionModule(latent_size=128)

#         # Decoder
#         decoder_layer = CausalTransformerDecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, activation=activation)
#         self.transformer_decoder = CausalTransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

#         # Final output
#         self.output_layer1 = nn.Linear(d_model, num_commands)  # t_i
#         self.output_layer2 = nn.Linear(d_model, num_parameters * param_cat)  # p_i

#     def forward(self, node_embeddings, graph_embeddings, command_seq, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
#         # Command embedding
#         print("command seq ", command_seq.shape)
#         commands_count = command_seq.shape[1]
#         commands = command_seq[:, :, 0]
#         args = command_seq[:, :, 1:]
#         construct_embed = self.cad_command_embedding(commands, args)

#         # Fusion module (Modified the Fusion Module)
#         fusion_outputs = None
#         fusion_outputs = []
#         for i in range(commands_count):
#             fusion_output = self.fusion_module(graph_embeddings, construct_embed[:, i, :])
#             fusion_outputs.append(fusion_output)

#         fusion_outputs = torch.cat(fusion_outputs, dim=1)

#         # Reshape node embeddings to [number of nodes, batch size, embedding dimension]
#         batch_size, num_nodes, embed_dim = node_embeddings.shape
#         node_embeddings = node_embeddings.transpose(0, 1)  # [number of nodes, batch size, embedding dimension]

#         # Decoder
#         print("fusion_outputs", fusion_outputs.shape)
#         decoder_output = self.transformer_decoder(
#             tgt=fusion_outputs.transpose(0, 1),  # Ensure tgt shape is [sequence length, batch size, embedding dimension]
#             memory=node_embeddings,
#             memory_mask=memory_mask,
#             tgt_key_padding_mask=tgt_key_padding_mask,
#             memory_key_padding_mask=memory_key_padding_mask
#         )
#         decoder_output = decoder_output.permute(1, 0, 2)
#         print("decoder_output: ",decoder_output.shape)
#         # Final output layers
#         output1 = self.output_layer1(decoder_output)
#         print('output1.1 =', output1.shape)
#         # output1 = output1.transpose(0,1)
#         print('output1.2 =', output1.shape)
#         output1 = F.softmax(output1, dim=-1)  # t_i
#         print('output1.3 =', output1.shape)

#         output2 = self.output_layer2(decoder_output)
#         print('output2.1 =', output2.shape)
#         output2 = output2.view(batch_size, -1, self.param_cat)
#         print('output2.2 =', output2.shape)


#         output2 = F.softmax(output2, dim=2)  # p_i
#         print("output1 ",output1.shape)
#         print("output2 ", output2.shape)
#         return output1, output2

class CADDecoder(nn.Module):
    def __init__(self, d_model=64, nhead=8, num_decoder_layers=4, dim_feedforward=2048, dropout=0.1, activation="relu", num_commands=7, max_seq_len=5000, num_parameters=16, param_cat=257):
        super(CADDecoder, self).__init__()
        self.param_cat = param_cat
        # Embedding
        self.cad_command_embedding = CADEmbedding(d_model=d_model, n_commands=num_commands, n_args=num_parameters, seq_len=max_seq_len)

        # Fusion module
        self.fusion_module = FusionModule(latent_size=128)

        # Decoder
        decoder_layer = CausalTransformerDecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, activation=activation)
        self.transformer_decoder = CausalTransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        # Final output
        self.output_layer1 = nn.Linear(d_model, num_commands)  # t_i
        self.output_layer2 = nn.Linear(d_model, num_parameters * param_cat)  # p_i

    def forward(self, node_embeddings, graph_embeddings, command_seq, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        # Command embedding
        # print("command seq ", command_seq.shape)
        commands_count = command_seq.shape[1]
        commands = command_seq[:, :, 0]
        args = command_seq[:, :, 1:]
        construct_embed = self.cad_command_embedding(commands, args)

        # Fusion module
        fusion_outputs = []
        # print('commands count=', commands_count)
        for i in range(commands_count):
            fusion_output = self.fusion_module(graph_embeddings, construct_embed[:, i, :])
            # print('fusion_output shape =', fusion_output.shape)
            fusion_outputs.append(fusion_output)
        # print('fusion_outputs b4 len =', len(fusion_outputs))
        # print('fusion_outputs b4 =', fusion_outputs[0].shape)
        fusion_outputs = torch.cat(fusion_outputs, dim=1)  # [batch_size, commands_count, d_model]
        # print('fusion_outputs af =', fusion_outputs.shape)

        # Reshape node embeddings to [number of nodes, batch size, embedding dimension]
        batch_size, num_nodes, embed_dim = node_embeddings.shape
        node_embeddings = node_embeddings.transpose(0, 1)  # [number of nodes, batch size, embedding dimension]

        # Initialize decoder outputs
        all_decoder_outputs = []

        # Autoregressive decoding
        for t in range(commands_count):
            if t == 0:
                current_input = fusion_outputs[:, :1, :]  # [batch_size, 1, d_model]
                # print('t==0, current_input shape =', current_input.shape)
            else:
                current_input = torch.cat((fusion_outputs[:, :t, :], last_output), dim=1)
            #     print('t==else, current_input shape =', current_input.shape)
            # print('current_input shape =', current_input.shape)

            current_input = current_input.transpose(0, 1)  # [sequence_length, batch_size, d_model]

            decoder_output = self.transformer_decoder(
                tgt=current_input,
                memory=node_embeddings,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask
            )

            decoder_output = decoder_output.transpose(0, 1)  # [batch_size, sequence_length, d_model]

            # Get the last output token
            last_output = decoder_output[:, -1, :].unsqueeze(1)  # [batch_size, 1, d_model]

            all_decoder_outputs.append(last_output)

        all_decoder_outputs = torch.cat(all_decoder_outputs, dim=1)  # [batch_size, commands_count, d_model]

        # Final output layers
        output1 = self.output_layer1(all_decoder_outputs)
        # print('output1.1 =', output1.shape)
        output1 = F.softmax(output1[:, -1:, :], dim=-1)  # t_i
        # print('output1.2 =', output1.shape)

        output2 = self.output_layer2(all_decoder_outputs)
        # print('output2.1 =', output2.shape)
        output2 = output2.view(batch_size, commands_count, -1, self.param_cat)[:, -1, :, :]  # Reshape and select last
        # print('output2.2 =', output2.shape)
        output2 = F.softmax(output2, dim=2)  # p_i
        # print("output1 ", output1.shape)
        # print("output2 ", output2.shape)
        return output1, output2


class CADParser(nn.Module):
    def __init__(self):
        super(CADParser, self).__init__()

        # encoder
        self.cad_encoder = CADEncoder()

        # decoder
        self.cad_decoder = CADDecoder()

In [5]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [12]:
import dgl
import numpy as np
import torch

from torch import optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import os
import torch.nn.functional as F
# from basedataset import BaseDataset

from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader

# from models import CADParser



graphs, label_dict = dgl.load_graphs("/content/drive/My Drive/DeepCADDataset/all_graphs.bin")
# print(graphs)
# print(label_dict)

npz = np.load("/content/drive/My Drive/DeepCADDataset/all_npz.npz")


seednumber=2024
torch.manual_seed(seednumber)
torch.cuda.manual_seed(seednumber)
np.random.seed(seednumber)

# graphs, label_dict = dgl.load_graphs("data/all_graphs.bin")
# print("number of graphs: ",len(graphs))
# print("detail of one of the graph: ",graphs[0])
# npz = np.load("data/all_npz.npz")
# print("type: ",type(npz))
# print("length: ",len(npz))
# print(npz)
# print(npz['vec_0'])

# import dgl

max_nodes = 0

for graph in graphs:
    num_nodes = graph.number_of_nodes()
    if num_nodes > max_nodes:
        max_nodes = num_nodes

print("Maximum number of nodes in any graph:", max_nodes)

Y = []

# Iterate over the sorted keys to maintain the order
for key in npz.keys():
    # print(key)
    Y.append(npz[key])

X = graphs

from sklearn.model_selection import train_test_split

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.1, random_state=42) # train size = 6353

import torch
from torch import optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import os

torch.set_printoptions(threshold=10_000)
# Path where checkpoints are stored
checkpoint_dir = 'checkpoint'
loss_log_file_path = os.path.join(checkpoint_dir, 'loss_log.txt')

if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
num_epochs = 100
initial_learning_rate = 1e-3
batch_size = 8  # Adjust according to your GPU memory
warmup_epochs = 10
root_dir = ""

# Initialize the dataset and data loader
dataset = BaseDataset(X_train, Y_train)
data_loader = dataset.get_dataloader(batch_size)

# Initialize the model
model = CADParser().to(device)

# Loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=initial_learning_rate)

# Gradual warmup and learning rate decay
scheduler = LambdaLR(
    optimizer,
    lr_lambda=lambda epoch: 0.9**(epoch // 30) * min((epoch + 1) / warmup_epochs, 1)
)

# Function to find the latest checkpoint file
def find_latest_checkpoint(checkpoint_dir):
    checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.endswith('.pth')]
    if checkpoint_files:
        latest_file = max(checkpoint_files, key=lambda x: int(x.strip('model_epoch_').strip('.pth')))
        return os.path.join(checkpoint_dir, latest_file)
    return None
start_epoch = 0

# Load the latest checkpoint if it exists
# latest_checkpoint = find_latest_checkpoint(checkpoint_dir)

# if latest_checkpoint:
#     print(f"Loading checkpoint '{latest_checkpoint}'")
#     checkpoint = torch.load(latest_checkpoint, map_location=device)
#     model.load_state_dict(checkpoint['model_state_dict'])
#     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#     scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
#     start_epoch = checkpoint['epoch']
#     print(f"Resuming training from epoch {start_epoch}")

loss_list = []

X_all_node_counts = []
for batch in data_loader:
    num_nodes_batch = batch['num_nodes']
    X_all_node_counts.append(num_nodes_batch)

# Training loop
for epoch in range(start_epoch, num_epochs):
    model.train()
    total_loss = 0
    print_out = True
    for batch_idx, batch in enumerate(data_loader):
        graphs, sequences = batch['graph'].to(device), batch['labels'].to(device)
        num_nodes_per_graph = batch['num_nodes']
        # print("num_nodes_per_graph", num_nodes_per_graph)

        optimizer.zero_grad()
        batch_loss = 0

        node_embeddings, graph_embeddings = model.cad_encoder(graphs)
        decoder_input_seq = sequences[:, 0:1, :]  # Start with the first vector (START token)

        batch_num_nodes = node_embeddings.shape[0]
        # print("len num_nodes_per_graph", len(num_nodes_per_graph))
        padded_node_embeddings = torch.zeros(len(num_nodes_per_graph), 66, 64).to(device)
        padded_node_embeddings.to(device)
        graph_embeddings.to(device)
        # decoder_input_seq.to(device)

        start_idx = 0
        for i, num_nodes in enumerate(num_nodes_per_graph):
            end_idx = start_idx + num_nodes
            padded_node_embeddings[i, :num_nodes] = node_embeddings[start_idx:end_idx]
            start_idx = end_idx
        # padded_node_embeddings: [batch size * padded node numbers * embd dim 64]

        graph_embeddings = torch.unsqueeze(graph_embeddings, 1).to(device) # transform from [batch size * embd dim] to [batch size * 1 * embd dim 128]


        for t in range(1, sequences.size(1)):
            # sequences: [batch size * sequence length (padding 60) * 17 len raw vector]
            # print(f"-------------------------{t}-----------------------------")
            # print("decoder input seq: ", decoder_input_seq.shape) # [batch size * sequence length [t] * len-17 raw vector]
            decoder_output_t_i, decoder_output_p_i = model.cad_decoder(padded_node_embeddings, graph_embeddings, decoder_input_seq) # two probability vectors: t_i, v_i
            # print("decoder_output_t_i ", decoder_output_t_i.shape) # [batch size * sequence length [t] * len-17 one hot vector]
            # print("decoder_output_p_i ", decoder_output_p_i.shape) # [batch size * sequence length [t] * len-257 one hot vector]

            gt_t = sequences[:, :t, :] # ground truth seq in raw 17-len vector form

            # command_type_t = sequences[:, :t, 0] # get command type
            command_type_t = sequences[:, t, 0]
            # print('command type true', command_type_t)
            command_type_t = command_type_t.long()
            # print("command_type_t: ", command_type_t.shape)
            # param_t = sequences[:, :t, 1:] # get parameters
            param_t = sequences[:, t, 1:]
            # print('param true', param_t)
            param_t_mapped = param_t + 1
            # param_t_mapped = param_t_mapped.long()  # Ensure param_t_mapped is of type Long
            # param_t_mapped = torch.flatten(param_t_mapped, start_dim=1, end_dim=2)
            # print("param_t_mapped: ", param_t_mapped.shape)

            # decoder_output_t_i = decoder_output_t_i.view(-1, 7)  # Flatten to [batch_size*sequence_length, num_classes]
            # command_type_t = command_type_t.view(-1)  # Flatten to [batch_size*sequence_length]
            # decoder_output_p_i = decoder_output_p_i.view(-1, 257)  # Flatten to [batch_size*sequence_length*num_parameters, num_classes]
            # param_t_mapped = param_t_mapped.view(-1)  # Flatten to [batch_size*sequence_length*num_parameters]
            true_t_i = F.one_hot(command_type_t, num_classes=7)
            # print("true_t_i", true_t_i.shape)
            true_p_i = F.one_hot(param_t_mapped.long(), num_classes=257)
            # print("true_p_i ", true_p_i.shape)

            # Calculate losses
            t_i_loss = criterion(decoder_output_t_i.squeeze(), true_t_i.float())
            p_i_loss = criterion(decoder_output_p_i, true_p_i.float())

            # print(f't_i_loss: {t_i_loss.item()}')
            # print(f'p_i_loss: {p_i_loss.item()}')


            loss = t_i_loss + p_i_loss # compare decoder output with the true next output

            # convert softmax distribution back into valid token
            _, command_type_pred_next = torch.max(decoder_output_t_i, dim=2, keepdim=True)

            _, command_args_pred_next = torch.max(decoder_output_p_i, dim=2, keepdim = True)

            command_args_pred_next = command_args_pred_next - 1
            # print("command_type_pred_next", command_type_pred_next.shape)
            # print("command_args_pred_next", command_args_pred_next.shape)
            next_token_pred = torch.cat((command_type_pred_next, command_args_pred_next), dim=1)

            next_token_pred = next_token_pred.transpose(1,2)
            # print("decoder_input_seq", decoder_input_seq.shape)
            next_token_pred = next_token_pred.view(batch_size, -1, 17)
            # print("next_token_pred ", next_token_pred.shape)
            decoder_input_seq = torch.cat((decoder_input_seq, next_token_pred), dim=1)
            batch_loss += loss

        if print_out:
            print('batch_idx:', batch_idx)
            # print('pred:', decoder_input_seq[0])
            # print('true:', sequences[0])

        # Normalize loss by seq length
        batch_loss /= (sequences.size(1) - 1)
        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += batch_loss.item()

        scheduler.step()  # Update the learning rate
        # print(f'batch_idx = {batch_idx}')
        # print(f'epoch {epoch+1}, training batch_loss = {batch_loss}')
        torch.cuda.empty_cache()



    loss_list.append(total_loss / 49)
    with open(loss_log_file_path, 'a') as f:
        f.write(str(total_loss / 49))
    print(f'----------------------EPOCH {epoch+1}, TOTAL LOSS = {total_loss / 49}')
        # if epoch % 10 == 0:
        #     print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{len(data_loader)}], Loss: {loss.item():.4f}')

    # Save checkpoint at the end of each epoch
    checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': loss_list,
    }
    torch.cuda.empty_cache()
    if (epoch+1) % 10 == 0:
        torch.save(checkpoint, f'{checkpoint_dir}/model_epoch_{epoch+1}.pth')

print('Finished Training')

import matplotlib.pyplot as plt

# Create a plot of the losses
plt.figure(figsize=(10, 5))
plt.plot(loss_list, label='Loss per Epoch')
plt.title('Training Loss Over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

# Save the plot to a PNG file
plt.savefig('plot/loss_plot.png')
plt.close()  # Close the plot explicitly after saving to free up memory


Maximum number of nodes in any graph: 66


OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 22.17 GiB of which 10.88 MiB is free. Process 28898 has 22.15 GiB memory in use. Of the allocated memory 21.88 GiB is allocated by PyTorch, and 18.36 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [14]:
torch.cuda.empty_cache()