In [34]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GraphConv, dense_diff_pool, dense_mincut_pool, GATConv
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_dense_batch, to_dense_adj

dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
loader = DataLoader(dataset, batch_size=32, shuffle=True)

class GATWithMinCutPooling(torch.nn.Module):
    def __init__(
        self,
        num_features,
        num_classes,
        num_hidden=8,
        heads=8,
        dropout=0.6,
        name="GAT",
    ):  
        super(GATWithMinCutPooling, self).__init__()
        self.name = name

        self.conv1 = GATConv(
            in_channels=num_features,
            out_channels=num_hidden,
            heads=heads,
            dropout=dropout,
        )

        self.conv2 = GATConv(
            in_channels=num_hidden * heads,
            out_channels=num_classes,
            heads=1,
            dropout=dropout,
        )
        
        self.num_classes = num_classes
        self.feature_transform = torch.nn.Linear(num_classes, num_features)
        
    def forward(self, data):
        x = F.dropout(data.x, p=0.6, training=self.training)
        x = self.conv1(x, data.edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=0.6, training=self.training)
        out = self.conv2(x, data.edge_index)

        # Convert to dense format
        x_dense, mask = to_dense_batch(x, data.batch)
        adj_dense = to_dense_adj(data.edge_index, data.batch)

        # Assignment matrix computation
        num_clusters = self.num_classes
        transform = torch.nn.Linear(out.size(1), num_clusters)
        s = transform(out)
        s = torch.softmax(s, dim=1)

        # Ensure 's' is properly masked and compatible in size
        s = s[mask.bool()].view(mask.size(0), -1, num_clusters)

        # MinCut pooling
        x_pool, adj_pool, mincut_loss, ortho_loss = dense_mincut_pool(x_dense, adj_dense, s, mask=mask)
        x_pool = x_pool.squeeze(0)

        return out, x_pool, adj_pool, mincut_loss, ortho_loss


In [35]:
def reconstruction_loss(original_x, reconstructed_x):
    # Mean Squared Error is a common choice for reconstruction tasks
    return F.mse_loss(reconstructed_x, original_x)


In [36]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = GATWithMinCutPooling(dataset.num_node_features, dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

def train():
    model.train()
    total_loss = 0
    for data in loader:  # Assuming 'loader' is your DataLoader
        data = data.to(device)
        optimizer.zero_grad()
        
        # Forward pass
        out, x_pool, adj_pool, mincut_loss, ortho_loss = model(data)

        # Compute primary task-specific loss (modify this according to your task)
        # For example, if it's a classification task:
        task_loss = F.nll_loss(out, data.y)  # Assuming you have labels 'data.y'

        # Combine losses
        # You might want to balance these losses with different weights
        loss = task_loss + mincut_loss + ortho_loss

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(loader)


# Training loop execution
for epoch in range(1, 51):
    loss = train()
    print(f'Epoch {epoch}, Loss: {loss:.4f}')


IndexError: The shape of the mask [32, 64] at index 0 does not match the shape of the indexed tensor [1023, 6] at index 0