In [None]:
# python /home/uz1/projects/ImageNet-Datasets-Downloader/downloader.py \
#     -data_root /home/uz1/data \
#     -number_of_classes 10 \
#     -images_per_class 200

In [1]:
import os.path as osp
from math import ceil

import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset, GNNBenchmarkDataset
import torch_geometric.transforms as T
from torch_geometric.data import DenseDataLoader
from torch_geometric.nn import DenseGCNConv as GCNConv, dense_diff_pool,DynamicEdgeConv

# Cluster budget for DiffPool: the first level pools to ceil(0.25 * max_nodes)
# = 15 clusters and the second to ceil(0.25 * 15) = 4 (see DiffPool.__init__).
max_nodes = 60


In [2]:
# Scratch arithmetic (output: 16); not used by later cells.
64//4

16

In [3]:
import numpy as np
import PIL as pl
from torch_geometric.data import Data

import torch_geometric
from torch_geometric.transforms import BaseTransform
from skimage import future
from torch_scatter import scatter_min


class ImgPixelsToGraph(BaseTransform):
    """Turn a (C, H, W) image tensor into a dense pixel-graph ``Data`` object.

    * ``x`` is ``img.reshape(H * W, -1)``. This is a plain reshape of the
      channel-major tensor, so its memory is still the (C, H, W) image and
      row ``i`` is NOT pixel ``i``'s channel vector. ``DiffPool.forward``
      relies on this: it views ``x`` straight back into (C, H, W). Change
      both sides together.
    * ``adj`` / ``edge_index`` / ``pos`` describe a grid over the
      *downsampled* image that the model's stem produces, not over the
      input pixels.

    Args:
        downsample: spatial stride of the consumer's stem (default 2, the
            stride of ``Stem``'s first convolution).
    """

    def __init__(self, downsample: int = 2) -> None:
        super().__init__()
        self.downsample = downsample

    def __call__(self, img) -> Data:
        _, h, w = img.shape
        x = img.reshape(h * w, -1)
        # A 3x3 / padding-1 conv with stride s outputs ceil(h / s) rows. That
        # differs from h // s for odd sizes; -(-a // b) is integer ceil division.
        grid_h = -(-h // self.downsample)
        grid_w = -(-w // self.downsample)
        edge_index, pos = grid(grid_h, grid_w)
        # squeeze(0), not squeeze(), so that a 1x1 grid stays a (1, 1) matrix.
        adj = to_dense_adj(edge_index).squeeze(0)
        return Data(x, adj=adj, edge_index=edge_index, pos=pos)


class ImgToGraph(BaseTransform):
    r"""Convert an (H, W, C) numpy image plus a mask into a superpixel graph.

    The image is over-segmented with :func:`skimage.segmentation.slic`. A
    region adjacency graph (RAG) over the superpixels is then built with
    ``rag_mean_band``. NOTE(review): ``rag_mean_band`` is not defined in this
    notebook and must come from elsewhere; confirm before use.

    The returned :class:`torch_geometric.data.Data` holds:

    * ``x``: the RAG's per-node ``'mean color'`` attribute,
    * ``pos``: the mean (x, y) pixel coordinate of each superpixel,
    * ``edge_index``: ``(2, E)`` RAG edges, each undirected edge listed once,
    * ``edge_weight``: ``(E, 1)`` RAG ``'weight'`` edge attributes,
    * ``y``: the per-superpixel minimum of the binarised mask (0/1).

    Args:
        add_seg (bool, optional): stored on the instance; currently unused.
        add_img (bool, optional): stored on the instance; currently unused.
        **kwargs (optional): stored on the instance; currently unused (the
            SLIC parameters are fixed in :meth:`__call__`).
    """

    def __init__(self, add_seg=False, add_img=False, **kwargs):
        self.add_seg = add_seg
        self.add_img = add_img
        self.kwargs = kwargs

    def __call__(self, img, mask, n_seg=250):
        """Build the superpixel graph.

        Args:
            img: (H, W, C) numpy image.
            mask: numpy array whose first two axes are (H, W) and which has a
                third axis (only its first slice is used). Non-zero entries
                count as foreground. The caller's array is NOT modified.
            n_seg: approximate number of SLIC superpixels.
        """
        # Imported here so the transform does not depend on cell order.
        from skimage import segmentation
        from torch_scatter import scatter_mean

        segments_slic = segmentation.slic(img,
                                          n_segments=n_seg,
                                          compactness=10,
                                          sigma=1,
                                          start_label=0)
        rag = rag_mean_band(img[:, :, :],
                            segments_slic,
                            connectivity=2,
                            mode='similarity',
                            sigma=255.0,
                            ch=img.shape[2])

        h, w, _ = img.shape
        seg = torch.from_numpy(segments_slic).reshape(h * w)

        # Mean pixel coordinate of every superpixel.
        pos_y = torch.arange(h, dtype=torch.float).view(-1, 1).repeat(1, w).view(h * w)
        pos_x = torch.arange(w, dtype=torch.float).view(1, -1).repeat(h, 1).view(h * w)
        pos = scatter_mean(torch.stack([pos_x, pos_y], dim=-1), seg, dim=0)

        # RAG edges as a (2, E) array. Building (E, 2) and transposing keeps
        # each (n1, n2) pair in one column. A plain reshape(2, -1) of the
        # (E, 2) array would pair up unrelated node ids.
        edge_index = np.asarray(list(rag.edges), dtype=np.int64).reshape(-1, 2).T
        weights = np.asarray([attrs['weight'] for _, _, attrs in rag.edges.data()])
        x = np.asarray([attrs['mean color'] for _, attrs in rag.nodes.items()])

        # Binarise a copy (same dtype) so the caller's mask is left untouched;
        # reshape (not view) because the channel slice may be non-contiguous.
        fg = torch.from_numpy((mask != 0).astype(mask.dtype))[:, :, :1]
        node_class = scatter_min(fg.reshape(h * w), seg, dim=0)[0]

        return Data(x=torch.from_numpy(x),
                    pos=pos,
                    edge_index=torch.tensor(edge_index),
                    edge_weight=torch.tensor(weights).unsqueeze(1),
                    y=node_class)


In [4]:
from typing import Callable, Optional

import numpy as np
import torchvision.transforms as T
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.utils import shuffle
from torch.utils.data import DataLoader
from torch.utils.data import SubsetRandomSampler
from torch_geometric.utils import grid, to_dense_adj
from torchvision import datasets
class ImageFolderGraph(datasets.ImageFolder):
    """``ImageFolder`` whose transform produces a PyG ``Data`` object.

    The class index found by ``ImageFolder`` is attached to the returned
    object as ``data.y``, so the (Dense)DataLoader batches it with the graph.

    Args:
        root: image root in torchvision ``ImageFolder`` layout (one
            sub-directory per class).
        transform: callable mapping a loaded image to an object that accepts
            attribute assignment, e.g.
            ``T.Compose([..., T.ToTensor(), ImgPixelsToGraph()])``.
    """

    def __init__(self, root: str, transform: Callable):
        super().__init__(root, transform)

    def __getitem__(self, index: int) -> Data:
        data, y = super().__getitem__(index)
        data.y = y
        return data


# Resize every image to 64x64 and turn it into a pixel graph whose grid
# adjacency matches the stem's stride-2 output (see ImgPixelsToGraph).
transform = T.Compose([T.Resize((64, 64)), T.ToTensor(), ImgPixelsToGraph()])
imagenet = ImageFolderGraph("/home/uz1/data/imagenet_images", transform=transform)

# Stratified 80/20 split on the class labels. Only the labels and the sample
# count matter to the splitter, so no image is loaded here.
labels = [target for _, target in imagenet.samples]
splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
train_index, test_index = next(splitter.split(np.zeros(len(labels)), labels))

# Both loaders wrap the full dataset and differ only in their sampler, which
# draws (reshuffled each epoch) from the split's indices.
train_dataset = SubsetRandomSampler(train_index)
val_dataset = SubsetRandomSampler(test_index)
train_loader = DenseDataLoader(imagenet, batch_size=16, sampler=train_dataset)
val_loader = DenseDataLoader(imagenet, batch_size=16, sampler=val_dataset)




In [5]:
# Batches per loader, and the number of training samples the sampler draws.
len(val_loader),len(train_loader),len(train_dataset.indices)

(26, 100, 1600)

In [6]:
# Dataset-level metadata read by the model classes through `dataset`.
# Load and transform a single sample once; its x has shape (H*W, C).
_first_sample = imagenet[0]
imagenet.num_features = _first_sample.x.shape[1]
imagenet.num_nodes = _first_sample.x.shape[0]
# One class per sub-directory found by ImageFolder (10 for this download).
imagenet.num_classes = len(imagenet.classes)

In [7]:
# Override: the model classes size their first GNN layers with
# dataset.num_features, and those layers consume the Stem's 128-channel
# output (Stem(64,3,128)) rather than the raw pixel channels.
imagenet.num_features=128

In [8]:
# Rows of data.x per image: 64 * 64 = 4096 pixels.
imagenet.num_nodes

4096

In [9]:
# Module-level alias read by the model classes below (num_features, num_classes).
dataset = imagenet


class GNN(torch.nn.Module):
    """Three dense GCN layers, each followed by ReLU and BatchNorm over features.

    Works on dense batches: ``x`` is (B, N, F) and ``adj`` is (B, N, N).

    Args:
        in_channels: input feature size.
        hidden_channels: width of the first two layers.
        out_channels: width of the last layer (the number of clusters when
            the GNN produces DiffPool assignment scores).
        normalize: passed as the third positional argument of ``GCNConv``
            (``DenseGCNConv`` in this notebook). NOTE(review): in PyG that
            parameter is ``improved``, not a normalisation switch. Confirm
            this is what was intended.
        lin: accepted for signature compatibility; unused.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, normalize=False, lin=True):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.bns = torch.nn.ModuleList()
        widths = [in_channels, hidden_channels, hidden_channels, out_channels]
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            self.convs.append(GCNConv(c_in, c_out, normalize))
            self.bns.append(torch.nn.BatchNorm1d(c_out))

    def forward(self, x, adj, mask=None):
        # The unpack doubles as a rank check: x must be (B, N, F).
        _, _, _ = x.size()
        for conv, bn in zip(self.convs, self.bns):
            x = F.relu(conv(x, adj, mask))
            # BatchNorm1d normalises dim 1, so move the features there and back.
            x = bn(x.transpose(1, 2)).transpose(1, 2)
        return x

class DGNN(torch.nn.Module):
    """Three stacked EdgeConv2d layers (each followed by ReLU) on dense features.

    Features use EdgeConv2d's (B, C, N, 1) layout. ``adj`` is the
    (2, B, N, k) neighbour index produced by ``dense_knn_matrix``.

    Args:
        in_channels / hidden_channels / out_channels: feature sizes. BasicConv
            uses grouped 1x1 convolutions with groups=4, so hidden/out
            channels must be multiples of 4 and in_channels must be even.
        normalize: normalisation name forwarded to EdgeConv2d as ``norm``.
        lin: accepted for signature compatibility with GNN; unused.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, normalize='batch', lin=True):
        super().__init__()
        self.convs = torch.nn.ModuleList([
            EdgeConv2d(in_channels, hidden_channels, norm=normalize),
            EdgeConv2d(hidden_channels, hidden_channels, norm=normalize),
            EdgeConv2d(hidden_channels, out_channels, norm=normalize),
        ])
        # Kept (empty) for attribute compatibility; EdgeConv2d normalises itself.
        self.bns = torch.nn.ModuleList()

    def forward(self, x, adj, mask=None):
        """Apply the three EdgeConv layers.

        Args:
            x: (B, C, N, 1) node features.
            adj: (2, B, N, k) neighbour indices.
            mask: accepted for interface compatibility with GNN, but ignored.
                EdgeConv2d has no masking, and its third positional parameter
                is ``y`` (alternative neighbour features). Forwarding the mask
                there would gather neighbour features from the mask itself.
        """
        for conv in self.convs:
            x = F.relu(conv(x, adj))
        return x

class DiffPool(torch.nn.Module):
    """Image classifier: a conv stem followed by two DiffPool levels.

    The pipeline is:

    1. Stem (3 -> 128 channels; the first conv has stride 2).
    2. GNN pool/embed on the stem grid, then dense_diff_pool to
       ceil(0.25 * max_nodes) clusters.
    3. GNN pool/embed, then dense_diff_pool to ceil(0.25 * that) clusters.
    4. GNN embed, mean over clusters, Linear/ReLU/Linear, log-softmax.

    Reads the globals ``max_nodes`` and ``dataset.num_classes``.
    """

    def __init__(self):
        super(DiffPool, self).__init__()

        # The first GNNs consume the stem's output, so their input width is
        # the stem's channel count, not the raw pixel feature size.
        stem_channels = 128

        num_nodes = ceil(0.25 * max_nodes)
        self.gnn1_pool = GNN(stem_channels, 64, num_nodes)
        self.gnn1_embed = GNN(stem_channels, 64, 64)

        num_nodes = ceil(0.25 * num_nodes)
        self.gnn2_pool = GNN(64, 64, num_nodes)
        self.gnn2_embed = GNN(64, 64, 64, lin=False)

        self.gnn3_embed = GNN(64, 64, 64, lin=False)

        self.lin1 = torch.nn.Linear(64, 64)
        self.lin2 = torch.nn.Linear(64, dataset.num_classes)

        self.stem = Stem(64, 3, stem_channels)

    def forward(self, x, adj, mask=None, return_clusters=False):
        """Classify a batch of images given as pixel graphs.

        Args:
            x: (B, N, 3) pixel features of a square image, N = H * W. Assumes
                ImgPixelsToGraph's layout, where the memory is the (3, H, W)
                image, so it is viewed straight back into image form.
            adj: (B, N', N') dense adjacency over the stem's output grid.
            mask: optional (B, N') node mask for the first pooling level.
            return_clusters: if True, return ``(x_stem, x, s1, s2)`` for
                inspection instead of the classification outputs.

        Returns:
            ``(log_probs, link_loss, entropy_loss)``, or the tuple above.
        """
        x_temp = x
        b, n, _ = x.shape
        side = int(n ** .5)
        x = self.stem(x.view(b, 3, side, -1))
        x_stem = x
        # (B, C, H', W') -> (B, H'*W', C): one row per spatial position, in
        # row-major order (as torch_geometric.utils.grid numbers its nodes).
        # A plain reshape to (B, -1, C) would instead pack C spatial values
        # of a single channel into each "node".
        x = x.flatten(2).transpose(1, 2)

        s1 = self.gnn1_pool(x, adj, mask)
        x = self.gnn1_embed(x, adj, mask)
        x, adj, l1, e1 = dense_diff_pool(x, adj, s1, mask)

        s2 = self.gnn2_pool(x, adj)
        x = self.gnn2_embed(x, adj)
        x, adj, l2, e2 = dense_diff_pool(x, adj, s2)

        x = self.gnn3_embed(x, adj)

        x = x.mean(dim=1)
        x = F.relu(self.lin1(x))
        x = self.lin2(x)
        if return_clusters:
            return x_stem, x_temp, s1, s2
        return F.log_softmax(x, dim=-1), l1 + l2, e1 + e2



class DynDiffPool(torch.nn.Module):
    """Work-in-progress variant of DiffPool built on DGNN (EdgeConv over k-NN graphs).

    The training cell builds ``DiffPool``, not this class.
    NOTE(review): as written this class appears unable to run end to end;
    see the inline notes before using it.
    """
    def __init__(self):
        super(DynDiffPool, self).__init__()

        # NOTE(review): with max_nodes = 60 this is 15. The last EdgeConv2d of
        # gnn1_pool then wraps a BasicConv whose Conv2d uses groups=4, and
        # PyTorch requires out_channels to be divisible by groups, so
        # construction likely raises here. DGNN(128,64,16) does build.
        num_nodes = ceil(0.25 * max_nodes)
        self.gnn1_pool = DGNN(dataset.num_features, 64, num_nodes)
        self.gnn1_embed = DGNN(dataset.num_features, 64, 64)

        num_nodes = ceil(0.25 * num_nodes)
        self.gnn2_pool = DGNN(64, 64, num_nodes)
        self.gnn2_embed = DGNN(64, 64, 64, lin=False)

        self.gnn3_embed = DGNN(64, 64, 64, lin=False)

        self.lin1 = torch.nn.Linear(64, 64)
        self.lin2 = torch.nn.Linear(64, dataset.num_classes)

        self.stem = Stem(64,3,128)

    def forward(self, x, adj, mask=None,return_clusters=False):
        x_temp = x
        #add stem downsampling
        b,n,f = x.shape 
        x = self.stem(x.view(b,3,int(n**.5),-1))
        x_stem=x
        # (B, 128, H', W') -> (B, 128, H'*W', 1): the dense layout EdgeConv2d expects.
        x=x.reshape(b,128,-1,1)
        ad = dense_knn_matrix(x,16)
        # print(x.shape,adj.shape)
        s1 = self.gnn1_pool(x, ad, mask)
        x = self.gnn1_embed(x, ad, mask)

        # NOTE(review): `n` is the INPUT pixel count (from x.shape above), but s1
        # has one row per stem-output node (fewer after the stride-2 stem), so
        # s1.reshape(b, n, -1) presumably fails or misaligns. Also, x and s1 are
        # (B, C, N, 1), and a plain reshape to (B, N, C) mixes channels across
        # nodes; squeeze(-1).transpose(1, 2) is probably what was meant.
        # TODO confirm.
        x, adj, l1, e1 = dense_diff_pool(x.reshape(b,-1,64), adj, s1.reshape(b,n,-1), mask)
        #x_1 = s_0.t() @ z_0
        #adj_1 = s_0.t() @ adj_0 @ s_0

        # NOTE(review): after pooling, x is (B, K, 64). Reshaping it to 128
        # channels halves the node count instead of producing (B, 64, K, 1).
        x=x.reshape(b,128,-1,1)
        # NOTE(review): the pooled graph has at most 15 nodes; k = 16 exceeds
        # that, and torch.topk inside dense_knn_matrix likely raises.
        ad = dense_knn_matrix(x,16)

        s2 = self.gnn2_pool(x, ad)
        x = self.gnn2_embed(x, ad)


        # NOTE(review): same reshape concerns as the first pooling. `mask`
        # belongs to the first-level nodes and probably should not be passed
        # here (DiffPool.forward does not pass it to its second level).
        x, adj, l2, e2 = dense_diff_pool(x.reshape(b,-1,64), adj, s2.reshape(b,n,-1), mask)

        # NOTE(review): DGNN expects (B, C, N, 1) features and a k-NN edge
        # index, but receives (B, K, 64) features and the dense (B, K, K) adj here.
        x = self.gnn3_embed(x, adj)

        x = x.mean(dim=1)
        x = F.relu(self.lin1(x))
        x = self.lin2(x)
        if return_clusters : return x_stem, x_temp, s1,s2
        return F.log_softmax(x, dim=-1), l1 + l2, e1 + e2


In [10]:
class AverageMeter:
    """Track the latest value and the count-weighted running mean of a series."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget every recorded value."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new observations."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
import torch.nn as nn
def act_layer(act, inplace=False, neg_slope=0.2, n_prelu=1):
    """Build an activation module from its (case-insensitive) name.

    Supported names: relu, leakyrelu, prelu, gelu, hswish. ``inplace``
    applies to relu/leakyrelu/hswish, ``neg_slope`` to leakyrelu (slope) and
    prelu (initial value), and ``n_prelu`` is prelu's parameter count.
    Raises NotImplementedError for any other name.
    """
    builders = {
        'relu': lambda: nn.ReLU(inplace),
        'leakyrelu': lambda: nn.LeakyReLU(neg_slope, inplace),
        'prelu': lambda: nn.PReLU(num_parameters=n_prelu, init=neg_slope),
        'gelu': lambda: nn.GELU(),
        'hswish': lambda: nn.Hardswish(inplace),
    }
    key = act.lower()
    if key not in builders:
        raise NotImplementedError('activation layer [%s] is not found' % key)
    return builders[key]()

class Stem(nn.Module):
    """ Image to Visual Embedding
    Overlap: https://arxiv.org/pdf/2106.13797.pdf

    Two 3x3 conv + BatchNorm stages with an activation between them. The
    first conv has stride 2, so (B, in_dim, H, W) becomes
    (B, out_dim, ceil(H/2), ceil(W/2)).

    Args:
        img_size: accepted for API compatibility; unused.
        in_dim: input channels.
        out_dim: output channels of both convolutions.
        act: activation name understood by ``act_layer``.
    """
    def __init__(self, img_size=224, in_dim=3, out_dim=768, act='relu'):
        super().__init__()
        downsample = [
            nn.Conv2d(in_dim, out_dim, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_dim),
            act_layer(act),
        ]
        refine = [
            nn.Conv2d(out_dim, out_dim, 3, stride=1, padding=1),
            nn.BatchNorm2d(out_dim),
        ]
        self.convs = nn.Sequential(*downsample, *refine)

    def forward(self, x):
        return self.convs(x)

In [15]:
from tqdm import tqdm
# Prefer the second GPU when it exists (the original intent), then fall back
# to the first GPU, then the CPU, so the notebook also runs on single-GPU
# machines, where 'cuda:1' does not exist.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
model = DiffPool().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Running meters: per-batch training loss and per-epoch validation accuracy.
losses = AverageMeter()
val_accc = AverageMeter()
test_accc = AverageMeter()

# Shape notes:
# - The Stem's first conv has stride 2 and its second stride 1, so a 64x64
#   input becomes a 32x32 feature map (1024 nodes). ImgPixelsToGraph builds
#   its grid adjacency at h//2 x w//2 = 32x32 to match.
# - Each node carries 128 features: the Stem's output channels.


def train(epoch):
    """Run one optimisation epoch of the global ``model`` over ``train_loader``.

    Args:
        epoch: epoch number (currently unused).

    Returns:
        float: mean per-sample NLL loss over the samples the loader actually
        yielded this epoch. The sampler only draws the training split, so
        this is not ``len(train_loader.dataset)``.
    """
    model.train()
    loss_all = 0.0
    num_seen = 0

    for data in tqdm(train_loader, total=len(train_loader)):
        data = data.to(device)
        optimizer.zero_grad()
        # The model also returns DiffPool's link / entropy regularisers; they
        # are not added to the objective here.
        output, _, _ = model(data.x, data.adj)
        loss = F.nll_loss(output, data.y.view(-1))
        loss.backward()
        # .item() detaches the scalar, so the running meter does not keep every
        # batch's autograd graph alive for the whole run.
        batch_loss = loss.item()
        losses.update(batch_loss)
        batch_size = data.y.size(0)
        loss_all += batch_size * batch_loss
        num_seen += batch_size
        optimizer.step()
    return loss_all / max(num_seen, 1)


@torch.no_grad()
def test(loader):
    """Classification accuracy of the global ``model`` on ``loader``.

    Returns:
        float: the fraction of correctly classified samples among those the
        loader yielded (0.0 for an empty loader). Counting yielded samples
        works with any sampler, not only ones exposing ``.indices``.
    """
    model.eval()
    correct = 0
    total = 0
    for data in tqdm(loader, total=len(loader)):
        data = data.to(device)
        pred = model(data.x, data.adj)[0].max(dim=1)[1]
        target = data.y.view(-1)
        correct += pred.eq(target).sum().item()
        total += target.numel()
    return correct / total if total else 0.0


# Train for 150 epochs. Each line reports the epoch's training loss and
# validation accuracy next to the meters' running averages (the meters are
# never reset, so the averages span all epochs so far).
best_val_acc = test_acc = 0
for epoch in range(1, 151):
    train_loss = train(epoch)
    val_acc = test(val_loader)
    val_accc.update(val_acc)
    best_val_acc = max(best_val_acc, val_acc)
    print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f} ({losses.avg}), '
          f'Val Acc: {val_acc:.4f} ({val_accc.avg})')


  0%|          | 0/100 [00:00<?, ?it/s]

In [None]:
# Pickles the whole module object (not just its state_dict); reloading it
# presumably needs this notebook's class definitions in scope.
torch.save(model,"/home/uz1/projects/GCN/checkpoint")

In [None]:
# Unpickles a full model object: only load checkpoints you trust.
# NOTE(review): newer torch versions may default torch.load to
# weights_only=True, which rejects pickled modules. Confirm against the
# installed version.
model = torch.load("/home/uz1/projects/GCN/checkpoint")


In [None]:
# NOTE(review): this builds a freshly initialised DiffPool and replaces the
# trained or loaded `model`, so the clusters extracted below come from random
# weights. Remove this line to inspect the trained model.
model = DiffPool().to(device)
@torch.no_grad()
def test_get_clusters(loader):
    """Collect the global ``model``'s intermediate DiffPool outputs per batch.

    Calls ``model(..., return_clusters=True)``, which returns
    ``(x_stem, x_input, s1, s2)``.

    Returns:
        ``(s1s, s2s, xs, preds)``: per-batch lists of the first- and
        second-level assignment scores (what the GNN passes to
        dense_diff_pool; argmax gives the hard cluster), the input node
        features, and, despite the name ``preds``, the stem feature maps.
        Everything is moved to the CPU so the lists do not pin device memory
        for the whole loader.
    """
    model.eval()
    s1s, s2s, xs, preds = [], [], [], []
    for data in tqdm(loader, total=len(loader)):
        data = data.to(device)
        x_stem, x_input, s1, s2 = model(data.x, data.adj, return_clusters=True)
        s1s.append(s1.cpu())
        s2s.append(s2.cpu())
        xs.append(x_input.cpu())
        preds.append(x_stem.cpu())
    return s1s, s2s, xs, preds
s1s, s2s, xs, preds = test_get_clusters(train_loader)

In [None]:
# Channel-mean of one stem feature map (`preds` holds x_stem; see test_get_clusters).
preds[1][1].mean(0).shape

In [None]:
# Input node features of the first sample in the first batch.
xs[0][0].shape

In [None]:
# Scratch arithmetic: sqrt(256) = 16.
256**.5

In [None]:
import PIL as pl 
def show_img(x, pil=False):
    """Turn a node-feature tensor back into an image for plotting.

    Args:
        x: either a (C, H, W) feature map, which is averaged over channels
            into a single channel, or an (N, C) node matrix whose memory is a
            channel-major (C, H, W) image with H = W = sqrt(N). The second
            form is the layout ImgPixelsToGraph produces with a plain reshape.
        pil: return a PIL image (via ``T.ToPILImage``) instead of an array.

    Returns:
        A PIL image if ``pil``; otherwise an (H, W, C) numpy array on the CPU.
    """
    if x.dim() > 2:
        x = x.mean(0)
        channels = 1
        side = x.shape[0]
    else:
        num_nodes, channels = x.shape
        side = int(num_nodes ** .5)
    # Both branches read the memory as channel-major (C, H, W). The array path
    # then moves channels last, so it agrees with the PIL path instead of
    # reinterpreting CHW memory as HWC.
    img = x.reshape(channels, side, -1)
    if pil:
        to_img = T.ToPILImage()
        return to_img(img)
    return img.permute(1, 2, 0).cpu().detach().numpy()

In [None]:
# Hard second-level cluster id for each first-level cluster of one sample.
s2s[0][0].argmax(1).shape

In [None]:
from skimage import graph, data, io, segmentation, color
from matplotlib import pyplot as plt
from skimage.measure import regionprops


def get_img_labels(s1s, s2s, xs, preds, figsize=(50, 50)):
    """Plot one figure per sample: stem features, input image, and clusters.

    All arguments hold one batch each and are iterated along dim 0.

    Args:
        s1s: (B, N1, K1) first-level assignment scores; N1 must be a square.
        s2s: second-level assignment scores (iterated but not plotted).
        xs: (B, N, 3) input node features (see show_img).
        preds: the stem feature maps, despite the name (test_get_clusters
            stores x_stem under this name).
        figsize: size of each per-sample figure.
    """
    for s1, _s2, x, feat in zip(s1s, s2s, xs, preds):
        img = show_img(x.cpu(), True)
        feat_img = show_img(feat)

        label_shape = int(s1.shape[0] ** .5)
        labels = s1.argmax(1).reshape(label_shape, label_shape).cpu().detach().numpy()
        # The clusters live on the stem's downsampled grid. Resize the image to
        # that grid so label2rgb averages the pixels that belong to each
        # cluster, not just the image's top-left corner.
        img_small = np.asarray(img.resize((label_shape, label_shape)))
        label_rgb = color.label2rgb(labels, img_small, kind='avg')

        # Create the figure before drawing, so every sample gets its own figure.
        fig, axes = plt.subplots(1, 3, figsize=figsize)
        axes[0].imshow(feat_img)
        axes[1].imshow(img)
        axes[2].imshow(pl.Image.fromarray(label_rgb))
get_img_labels(s1s[15],s2s[15],xs[15],preds[15])

In [None]:
# Hard first-level cluster id for every stem-grid node of sample 4 in batch 0.
s1s[0][4].argmax(1)

In [None]:
# Channel-averaged stem feature maps for every sample of batch 1.
preds[1].mean(1).shape

## Adding K-means to the Mix

In [11]:
# Planning notes for the K-means experiment, kept as a bare string literal
# (the cell's only output is this text).
'''
Adding K means to the mix 
-Find the patchs place in the space 
-Cluster based on embed 
craete graph 
vis? - when before after 
classify ? 
'''

'\nAdding K means to the mix \n-Find the patchs place in the space \n-Cluster based on embed \ncraete graph \nvis? - when before after \nclassify ? \n'

In [12]:
from sklearn.cluster import KMeans
import math
def pairwise_distance(x):
    """
    Compute pairwise squared Euclidean distances within each point cloud.
    Args:
        x: tensor (batch_size, num_points, num_dims)
    Returns:
        pairwise distance: (batch_size, num_points, num_points), computed
        under no_grad as |a|^2 - 2 a.b + |b|^2.
    """
    with torch.no_grad():
        x_inner = -2*torch.matmul(x, x.transpose(2, 1))
        x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True)
        return x_square + x_inner + x_square.transpose(2, 1)


def part_pairwise_distance(x, start_idx=0, end_idx=1):
    """
    Squared distances from the query points ``start_idx:end_idx`` to every point.
    Args:
        x: tensor (batch_size, num_points, num_dims)
        start_idx, end_idx: slice of query points along dim 1.
    Returns:
        partial distance: (batch_size, end_idx - start_idx, num_points)
    """
    with torch.no_grad():
        x_part = x[:, start_idx:end_idx]
        x_square_part = torch.sum(torch.mul(x_part, x_part), dim=-1, keepdim=True)
        x_inner = -2*torch.matmul(x_part, x.transpose(2, 1))
        x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True)
        return x_square_part + x_inner + x_square.transpose(2, 1)


def dense_knn_matrix(x, k=16, relative_pos=None):
    """Get the k nearest neighbours of every point by squared distance.

    Args:
        x: (batch_size, num_dims, num_points, 1)
        k: number of neighbours. Clamped to ``num_points`` so that small
            graphs (e.g. after pooling) do not make ``torch.topk`` fail.
        relative_pos: optional bias added to the distance matrix (sliced
            along dim 1 in the chunked path).
    Returns:
        (2, batch_size, num_points, k) long tensor. Index 0 holds the
        neighbour ids (nearest first; usually the point itself comes first)
        and index 1 holds the id of the centre point.
    """
    with torch.no_grad():
        x = x.transpose(2, 1).squeeze(-1)
        batch_size, n_points, n_dims = x.shape
        k = min(k, n_points)
        # Memory-efficient path: score at most n_part query rows at a time.
        n_part = 10000
        if n_points > n_part:
            nn_idx_list = []
            groups = math.ceil(n_points / n_part)
            for i in range(groups):
                start_idx = n_part * i
                end_idx = min(n_points, n_part * (i + 1))
                dist = part_pairwise_distance(x.detach(), start_idx, end_idx)
                if relative_pos is not None:
                    dist += relative_pos[:, start_idx:end_idx]
                _, nn_idx_part = torch.topk(-dist, k=k)
                nn_idx_list += [nn_idx_part]
            nn_idx = torch.cat(nn_idx_list, dim=1)
        else:
            dist = pairwise_distance(x.detach())
            if relative_pos is not None:
                dist += relative_pos
            _, nn_idx = torch.topk(-dist, k=k)  # b, n, k
        center_idx = torch.arange(0, n_points, device=x.device).repeat(batch_size, k, 1).transpose(2, 1)
    return torch.stack((nn_idx, center_idx), dim=0)

In [13]:
from torch import nn
from torch.nn import Sequential as Seq, Linear as Lin, Conv2d
def act_layer(act, inplace=False, neg_slope=0.2, n_prelu=1):
    """Return the activation module named by ``act`` (case-insensitive).

    NOTE: an identical act_layer is defined in an earlier cell; this one
    redefines it so that this cell is self-contained.

    Raises:
        NotImplementedError: for names other than relu, leakyrelu, prelu,
            gelu and hswish.
    """
    name = act.lower()
    if name == 'relu':
        return nn.ReLU(inplace)
    if name == 'leakyrelu':
        return nn.LeakyReLU(neg_slope, inplace)
    if name == 'prelu':
        return nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    if name == 'gelu':
        return nn.GELU()
    if name == 'hswish':
        return nn.Hardswish(inplace)
    raise NotImplementedError('activation layer [%s] is not found' % name)


def norm_layer(norm, nc):
    """Return a 2-D normalisation layer over ``nc`` channels.

    ``norm`` (case-insensitive) is 'batch' for an affine BatchNorm2d or
    'instance' for a non-affine InstanceNorm2d. Any other name raises
    NotImplementedError.
    """
    kind = norm.lower()
    if kind == 'batch':
        return nn.BatchNorm2d(nc, affine=True)
    if kind == 'instance':
        return nn.InstanceNorm2d(nc, affine=False)
    raise NotImplementedError('normalization layer [%s] is not found' % kind)
class BasicConv(Seq):
    """Stack of grouped 1x1 convolutions, each optionally followed by norm, activation and dropout.

    Args:
        channels: channel sizes; layer i maps channels[i-1] -> channels[i].
        act: activation name for ``act_layer`` (None or 'none' disables it).
        norm: normalisation name for ``norm_layer`` (None or 'none' disables it).
        bias: whether the convolutions have a bias term.
        drop: Dropout2d probability after each layer (0 disables dropout).
        groups: groups of every 1x1 convolution (default 4). Every entry of
            ``channels`` must be divisible by it.
    """
    def __init__(self, channels, act='relu', norm=None, bias=True, drop=0., groups=4):
        m = []
        for i in range(1, len(channels)):
            m.append(Conv2d(channels[i - 1], channels[i], 1, bias=bias, groups=groups))
            if norm is not None and norm.lower() != 'none':
                # Size the norm by this layer's output width.
                m.append(norm_layer(norm, channels[i]))
            if act is not None and act.lower() != 'none':
                m.append(act_layer(act))
            if drop > 0:
                m.append(nn.Dropout2d(drop))

        super(BasicConv, self).__init__(*m)

        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-normal conv weights, zero biases, unit/zero norm affine params."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                # Non-affine norms (e.g. InstanceNorm2d(affine=False)) have no
                # weight or bias tensors to initialise.
                if m.weight is not None:
                    m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()
def batched_index_select(x, idx):
    r"""Gather neighbour features for every (batch, vertex, neighbour) index.

    Args:
        x (Tensor): features :math:`\mathbf{X} \in \mathbb{R}^{B \times C \times N \times 1}`.
        idx (Tensor): per-batch vertex indices into ``x``,
            :math:`\in \mathbb{Z}^{B \times N' \times k}`.

    Returns:
        Tensor: :math:`\mathbb{R}^{B \times C \times N' \times k}`, where
        ``out[b, :, n, j] == x[b, :, idx[b, n, j], 0]``.
    """
    bsz, channels, n_src = x.shape[:3]
    _, n_dst, k = idx.shape
    # Shift each batch's indices into a single flattened (B * N, C) row table.
    offsets = torch.arange(0, bsz, device=idx.device).view(-1, 1, 1) * n_src
    flat_idx = (idx + offsets).contiguous().view(-1)
    rows = x.transpose(2, 1).contiguous().view(bsz * n_src, -1)
    gathered = rows[flat_idx, :].view(bsz, n_dst, k, channels)
    return gathered.permute(0, 3, 1, 2).contiguous()

class EdgeConv2d(nn.Module):
    """
    Edge convolution for dense (B, C, N, 1) features.

    For each centre i with neighbours j, computes max_j h([x_i, x_j - x_i]),
    where h is a BasicConv (1x1 conv with optional norm and activation).
    ``edge_index`` is the (2, B, N, k) tensor from ``dense_knn_matrix``:
    ``edge_index[0]`` holds neighbour ids and ``edge_index[1]`` the centre ids.
    If ``y`` is given, neighbour features are gathered from ``y`` instead of ``x``.
    """
    def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True):
        super().__init__()
        self.nn = BasicConv([in_channels * 2, out_channels], act, norm, bias)

    def forward(self, x, edge_index, y=None):
        source = x if y is None else y
        centre = batched_index_select(x, edge_index[1])
        neighbour = batched_index_select(source, edge_index[0])
        messages = self.nn(torch.cat([centre, neighbour - centre], dim=1))
        return messages.max(dim=-1, keepdim=True)[0]

In [14]:
# Smoke test: a DGNN with 16 output channels builds (16 is divisible by BasicConv's groups=4).
DGNN(128,64,16)

DGNN(
  (convs): ModuleList(
    (0): EdgeConv2d(
      (nn): BasicConv(
        (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), groups=4)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
    )
    (1): EdgeConv2d(
      (nn): BasicConv(
        (0): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), groups=4)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
    )
    (2): EdgeConv2d(
      (nn): BasicConv(
        (0): Conv2d(128, 16, kernel_size=(1, 1), stride=(1, 1), groups=4)
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
    )
  )
  (bns): ModuleList()
)