In [7]:
import torch
from time import time
from dgl.nn.pytorch import GraphConv
from torch_geometric.data import Data
from torch_geometric.nn import radius_graph
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import sklearn.metrics as metrics
from torch_geometric.nn import MessagePassing

In [8]:
BATCH_SIZE = 32

## transformations
transform = transforms.Compose([transforms.ToTensor()])

## download and load training dataset
trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=2)

## download and load testing dataset
testset = torchvision.datasets.MNIST(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=2,drop_last= True)

In [9]:
class MPNN(MessagePassing):
    def __init__(self):
        super(MPNN,self).__init__(aggr = 'mean')
        self.l1 = nn.Linear(1,5)
    
    def forward(self, x , edge_index):
        return self.propagate(edge_index, x = x)
    
    def message(self, x_j):
        return self.l1(x_j)

In [10]:
class Model_MPNN(nn.Module):
    def __init__(self):
        super(Model_MPNN, self).__init__()
        self.batch_size = 32
        self.num_nodes = 26 * 26 * 32
        self.x_dis = torch.arange(self.batch_size*self.num_nodes).view(self.batch_size*self.num_nodes,-1).float()
        self.batch = torch.arange(self.batch_size).view(-1,1).repeat(1,self.num_nodes).view(-1)
        self.edge_index = radius_graph(self.x_dis, 2 , self.batch, loop = True).to("cuda")
        # 28x28x1 => 26x26x32
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)
        
        self.mp = MPNN()
        self.d2 = nn.Linear(26 * 26 * 32 * 5, 10)

    def forward(self, x):
        # 32x1x28x28 => 32x32x26x26
        x = self.conv1(x)
        x = F.relu(x)
         
        # flatten => 32 x (32*26*26)
#         x = x.flatten(start_dim = 1)
        x = x.view(-1, 1)
        x = self.mp(x, self.edge_index)
        x = x.view(32,-1)
        
    

        # logits => 32x10
        logits = self.d2(x)
        
        return logits

In [11]:
learning_rate = 0.001
num_epochs = 5

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Model_MPNN()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [12]:
start = time()
for epoch in range(num_epochs):
    train_running_loss = 0.0
    train_acc = 0.0

    ## training step
    for i, (images, labels) in enumerate(trainloader):
        
        images = images.to(device)
        labels = labels.to(device)

        ## forward + backprop + loss
        logits = model(images)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()

        ## update model params
        optimizer.step()

        train_running_loss += loss.detach().item()
        train_acc += (torch.argmax(logits, 1).flatten() == labels).type(torch.float).mean().item()
    
    print('Epoch: %d | Loss: %.4f | Train Accuracy: %.3f' \
          %(epoch, train_running_loss / i, train_acc/i))
print("Total time", time() - start)

Epoch: 0 | Loss: 3.0400 | Train Accuracy: 0.836
Epoch: 1 | Loss: 0.3099 | Train Accuracy: 0.959
Epoch: 2 | Loss: 0.0991 | Train Accuracy: 0.976
Epoch: 3 | Loss: 0.0640 | Train Accuracy: 0.981
Epoch: 4 | Loss: 0.0537 | Train Accuracy: 0.984
Total time 90.54811882972717


In [13]:
import numpy as np
model_params = filter(lambda p:p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_params])

In [14]:
print(params)

1081940


In [15]:
test_acc = 0.0
for i, (images, labels) in enumerate(testloader, 0):
    images = images.to(device)
    labels = labels.to(device)
    outputs = model(images)
    test_acc += (torch.argmax(outputs, 1).flatten() == labels).type(torch.float).mean().item()
    preds = torch.argmax(outputs, 1).flatten().cpu().numpy()
        
print('Test Accuracy: %.3f'%(test_acc/i))

Test Accuracy: 0.986


In [16]:
class Model_vanilla(nn.Module):
    def __init__(self):
        super(Model_vanilla, self).__init__()

        # 28x28x1 => 26x26x32
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)
        self.d1 = nn.Linear(26 * 26 * 32, 128)
        self.d2 = nn.Linear(128, 10)

    def forward(self, x):
        # 32x1x28x28 => 32x32x26x26
        x = self.conv1(x)
        x = F.relu(x)

        # flatten => 32 x (32*26*26)
        x = x.flatten(start_dim = 1)
        #x = x.view(32, -1)

        # 32 x (32*26*26) => 32x128
        x = self.d1(x)
        x = F.relu(x)

        # logits => 32x10
        logits = self.d2(x)
    
        return logits

In [17]:
learning_rate = 0.001
num_epochs = 5

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = Model_vanilla()
net = net.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

In [18]:
from time import time
for epoch in range(num_epochs):
    train_running_loss = 0.0
    train_acc = 0.0
    start = time()
    ## training step
    for i, (images, labels) in enumerate(trainloader):
        
        images = images.to(device)
        labels = labels.to(device)

        ## forward + backprop + loss
        logits = net(images)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()

        ## update model params
        optimizer.step()

        train_running_loss += loss.detach().item()
        train_acc += (torch.argmax(logits, 1).flatten() == labels).type(torch.float).mean().item()
    
    print('Epoch: %d | Loss: %.4f | Train Accuracy: %.3f' \
          %(epoch, train_running_loss / i, train_acc/i))
print("Total time", time()-start)

Epoch: 0 | Loss: 0.2144 | Train Accuracy: 0.937
Epoch: 1 | Loss: 0.0655 | Train Accuracy: 0.980
Epoch: 2 | Loss: 0.0409 | Train Accuracy: 0.988
Epoch: 3 | Loss: 0.0287 | Train Accuracy: 0.992
Epoch: 4 | Loss: 0.0185 | Train Accuracy: 0.995
Total time 7.177600622177124


In [19]:
import numpy as np
net_params = filter(lambda p:p.requires_grad, net.parameters())
params = sum([np.prod(p.size()) for p in net_params])
print(params)

2770634


In [20]:
test_acc = 0.0
for i, (images, labels) in enumerate(testloader, 0):
    images = images.to(device)
    labels = labels.to(device)
    outputs = net(images)
    test_acc += (torch.argmax(outputs, 1).flatten() == labels).type(torch.float).mean().item()
    preds = torch.argmax(outputs, 1).flatten().cpu().numpy()
        
print('Test Accuracy: %.3f'%(test_acc/i))

Test Accuracy: 0.987
