In [1]:
import random

import numpy as np
import torch
from sklearn.model_selection import train_test_split
import torch.nn.functional as F

from models.can.can import CAN
from models.utils.sparse import from_sparse

torch.manual_seed(0)
np.random.seed(0)

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [3]:
import pickle

with open("data/qm9_test_cell_complex.pkl", "rb") as f:
    cc_list = pickle.load(f)


In [4]:
from cell_loader import CCDataset
dataset = CCDataset(cc_list)

In [5]:
x_0_list = [data[0] for data in dataset]
x_1_list = [data[1] for data in dataset]
y_list = [random.choice([0, 1]) for _ in range(30)]

lower_neighborhood_list = []
upper_neighborhood_list = []
adjacency_0_list = []

for cell_complex in cc_list:
    adjacency_0 = cell_complex.adjacency_matrix(rank=0)
    adjacency_0 = torch.from_numpy(adjacency_0.todense()).to_sparse()
    adjacency_0_list.append(adjacency_0)

    lower_neighborhood_t = cell_complex.down_laplacian_matrix(rank=1, signed=False)
    lower_neighborhood_t = from_sparse(lower_neighborhood_t)
    lower_neighborhood_list.append(lower_neighborhood_t)

    try:
        upper_neighborhood_t = cell_complex.up_laplacian_matrix(rank=1, signed=False)
        upper_neighborhood_t = from_sparse(upper_neighborhood_t)
    except:
        upper_neighborhood_t = np.zeros(
            (lower_neighborhood_t.shape[0], lower_neighborhood_t.shape[0])
        )
        upper_neighborhood_t = torch.from_numpy(upper_neighborhood_t).to_sparse()

    upper_neighborhood_list.append(upper_neighborhood_t)

In [6]:
in_channels_0 = x_0_list[0].shape[-1]
in_channels_1 = x_1_list[0].shape[-1]
#in_channels_2 = 5

In [7]:
in_channels_0, in_channels_1

(4, 3)

In [8]:
model = CAN(
    in_channels_0,
    in_channels_1,
    16,
    dropout=0.5,
    heads=3,
    num_classes=2,
    n_layers=2,
    att_lift=True,
)
model = model.to(device)

In [9]:
crit = torch.nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=0.001)
model

CAN(
  (lift_layer): MultiHeadLiftLayer(
    (lifts): LiftLayer()
  )
  (layers): ModuleList(
    (0): CANLayer(
      (lower_att): MultiHeadCellAttention_v2(
        (att_activation): LeakyReLU(negative_slope=0.2)
        (lin_src): Linear(in_features=7, out_features=48, bias=False)
        (lin_dst): Linear(in_features=7, out_features=48, bias=False)
      )
      (upper_att): MultiHeadCellAttention_v2(
        (att_activation): LeakyReLU(negative_slope=0.2)
        (lin_src): Linear(in_features=7, out_features=48, bias=False)
        (lin_dst): Linear(in_features=7, out_features=48, bias=False)
      )
      (lin): Linear(in_features=7, out_features=48, bias=False)
      (aggregation): Aggregation()
    )
    (1): CANLayer(
      (lower_att): MultiHeadCellAttention_v2(
        (att_activation): LeakyReLU(negative_slope=0.2)
        (lin_src): Linear(in_features=48, out_features=48, bias=False)
        (lin_dst): Linear(in_features=48, out_features=48, bias=False)
      )
      (uppe

In [10]:
test_size = 0.3
x_1_train, x_1_test = train_test_split(x_1_list, test_size=test_size, shuffle=False)
x_0_train, x_0_test = train_test_split(x_0_list, test_size=test_size, shuffle=False)
lower_neighborhood_train, lower_neighborhood_test = train_test_split(
    lower_neighborhood_list, test_size=test_size, shuffle=False
)
upper_neighborhood_train, upper_neighborhood_test = train_test_split(
    upper_neighborhood_list, test_size=test_size, shuffle=False
)
adjacency_0_train, adjacency_0_test = train_test_split(
    adjacency_0_list, test_size=test_size, shuffle=False
)
y_train, y_test = train_test_split(y_list, test_size=test_size, shuffle=False)

In [11]:
test_interval = 1
num_epochs = 4
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    num_samples = 0
    correct = 0
    model.train()
    for x_0, x_1, adjacency, lower_neighborhood, upper_neighborhood, y in zip(
        x_0_train,
        x_1_train,
        adjacency_0_train,
        lower_neighborhood_train,
        upper_neighborhood_train,
        y_train,
    ):
        x_0 = x_0.float().to(device)
        x_1, y = x_1.float().to(device), torch.tensor(y, dtype=torch.long).to(device)
        adjacency = adjacency.float().to(device)
        lower_neighborhood, upper_neighborhood = lower_neighborhood.float().to(
            device
        ), upper_neighborhood.float().to(device)
        opt.zero_grad()
        y_hat = model(x_0, x_1, adjacency, lower_neighborhood, upper_neighborhood)
        #print('notice')
        loss = crit(y_hat, y)
        #print(loss)
        correct += (y_hat.argmax() == y).sum().item()
        num_samples += 1
        loss.backward()
        opt.step()
        epoch_loss.append(loss.item())
    train_acc = correct / num_samples
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {train_acc:.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        with torch.no_grad():
            num_samples = 0
            correct = 0
            for x_0, x_1, adjacency, lower_neighborhood, upper_neighborhood, y in zip(
                x_0_test,
                x_1_test,
                adjacency_0_test,
                lower_neighborhood_test,
                upper_neighborhood_test,
                y_test,
            ):
                x_0 = x_0.float().to(device)
                x_1, y = x_1.float().to(device), torch.tensor(y, dtype=torch.long).to(
                    device
                )
                adjacency = adjacency.float().to(device)
                lower_neighborhood, upper_neighborhood = lower_neighborhood.float().to(
                    device
                ), upper_neighborhood.float().to(device)
                y_hat = model(
                    x_0, x_1, adjacency, lower_neighborhood, upper_neighborhood
                )
                #print(y_hat)
                correct += (y_hat.argmax() == y).sum().item()
                num_samples += 1
            test_acc = correct / num_samples
            print(f"Test_acc: {test_acc:.4f}", flush=True)

Epoch: 1 loss: 0.7065 Train_acc: 0.5714
Test_acc: 0.5556
Epoch: 2 loss: 0.7037 Train_acc: 0.5714
Test_acc: 0.5556
Epoch: 3 loss: 0.6979 Train_acc: 0.5714
Test_acc: 0.5556
Epoch: 4 loss: 0.6853 Train_acc: 0.5714
Test_acc: 0.5556


In [12]:
from models.can.can_layer import MultiHeadCellAttention_v2

mh = MultiHeadCellAttention_v2(in_channels=3, out_channels = 32, 
                               heads = 3, concat= True,
                               att_activation= torch.nn.ReLU(), aggr_func='sum',
                               dropout=0.5)

In [13]:
i = 17
print(upper_neighborhood_list[i])
print(cc_list[i].cells)
print(cc_list[i].number_of_edges())
print(mh(x_1_list[i], upper_neighborhood_list[i]).shape)
#print(mh(x_1_list[i], lower_neighborhood_list[i]).shape)

tensor(indices=tensor([[5, 5, 5, 7, 7, 7, 9, 9, 9],
                       [5, 7, 9, 5, 7, 9, 5, 7, 9]]),
       values=tensor([1., 1., 1., 1., 1., 1., 1., 1., 1.]),
       size=(18, 18), nnz=9, layout=torch.sparse_coo)
CellView([Cell((2, 3, 4))])
18
torch.Size([18, 96])


In [14]:
(target_index_i,source_index_j,) = upper_neighborhood_list[i].indices() 
print(target_index_i.shape)
print(source_index_j.shape)
print(mh(x_1_list[i], upper_neighborhood_list[i]).shape)
print(upper_neighborhood_list[i].shape[0])

torch.Size([9])
torch.Size([9])
torch.Size([18, 96])
18


In [15]:
class CustomTensor:
    def __init__(self, tensor, n_edges=None):
        self.tensor = tensor
        self.n_edges = n_edges

    def __getattr__(self, name):
        # Delegate attribute access to the tensor
        return getattr(self.tensor, name)

    # If you need to support specific tensor methods, you might need to explicitly define them
    def to(self, *args, **kwargs):
        # Example of delegating a method call to the tensor
        self.tensor = self.tensor.to(*args, **kwargs)
        return self

# Usage
your_tensor = torch.tensor([1, 2, 3])
n_edges = 18
custom_tensor = CustomTensor(your_tensor, n_edges=n_edges)

# Tensor operations
result = custom_tensor.tensor + torch.tensor([1, 1, 1])

# Accessing custom attribute
print("n_edges:", custom_tensor.n_edges)


n_edges: 18


In [16]:
import torch

def scatter_add_(self, index, src, num_edges=None):
    if num_edges is None:
        num_edges = src.numel()

    # Check if dimensions match and num_edges is within the valid range
    if self.dim() != src.dim() or self.dim() != index.dim():
        raise ValueError("All tensors must have the same number of dimensions")
    if num_edges > src.numel():
        raise ValueError("num_edges is larger than the number of elements in src")

    # Iterate and add values
    for n in range(num_edges):
        # Convert flat index to 3D index
        i, j, k = np.unravel_index(n, src.shape)
        self[index[i][j][k]][j][k] += src[i][j][k]

    return self

# Example usage
self_tensor = torch.zeros(5, 3, 3)  # Adjust the size as needed
index_tensor = torch.randint(0, 5, (2, 3, 3))
src_tensor = torch.randn(2, 3, 3)

scatter_add_(self_tensor, index_tensor, src_tensor).shape


torch.Size([5, 3, 3])