In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import networkx as nx
import dgl
from dgl.nn import GATConv
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

Using backend: pytorch


In [2]:
train_num = 10000
test_num = 1000
length = 6
tot_len = 15

In [3]:
def proc_dataset(dataset):
    n = len(dataset)
    split_data = torch.zeros( (n,tot_len), dtype = torch.long )
    for ii in range(n):
        for i in range(tot_len):
            split_data[ii,i] = torch.sum(dataset[ii,i:i+length])
    return split_data

In [4]:
class VPVRDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, num, length):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        
        self.X = torch.randint(0,9,(num, tot_len) )
        self.y = []
        labels = torch.ones(length)
        for ii in range(num):
            y = torch.sum(self.X[ii,self.X[ii][0]:self.X[ii][0]+length]) 
            self.y.append(y % 10)
            
        self.z = proc_dataset(self.X).unsqueeze(dim = 1)
            
        self.G = dgl.from_networkx(nx.complete_graph(tot_len))
        
    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx], self.z[idx], self.G

In [5]:
def collate(samples):
    # The input `samples` is a list of pairs
    #  (graph, label).
    X,y,z,graphs = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return torch.cat(X), torch.tensor(y, dtype = torch.long), torch.cat(z), batched_graph

In [6]:
trainset = VPVRDataset(train_num, length)
testset = VPVRDataset(test_num, length)

In [7]:
print(trainset[1][2])

tensor([[16, 20, 26, 22, 19, 21, 28, 30, 30, 32, 27, 23, 15,  9,  3]])


In [8]:
train_loader = DataLoader(trainset, batch_size=64, shuffle=True,num_workers=1,collate_fn=collate)
test_loader = DataLoader(testset, batch_size=64, shuffle=False,num_workers=1,collate_fn=collate)

In [9]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.emb1 = nn.Embedding(10, 32)
        self.x2 = torch.autograd.Variable(torch.randn([tot_len,32]),requires_grad=True)
        
        self.lin1 = nn.Linear(64,64)
        self.gat = GATConv(64,64,1)
        self.gat2 = GATConv(64,64,1)
        
        self.sum_pool = dgl.nn.SumPooling()
        
        self.lin2 = nn.Linear(128,32)
        self.lin3 = nn.Linear(32,1)

    def forward(self, g, x):
        batch_size = x.shape[0]//tot_len
        
        x1 = self.emb1(x)
        x2 = self.x2.repeat((batch_size, 1))
        
        x = torch.cat([x1, x2], dim = 1)
        x = self.lin1(x)
        
        x1 = self.gat(g, x).squeeze()
        x2 = self.gat2(g, x).squeeze()
        x = torch.cat([x1, x2], dim = 1)
        
        x = self.lin2(x)
        x = F.relu(x)
        mask = self.lin3(x)
        mask = mask.reshape(batch_size, tot_len)
        mask = F.softmax(mask, dim = 1).unsqueeze(-1)
        
        return mask

generator = Generator()

In [14]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.emb = nn.Embedding(60, 64)
        
        self.fc1 = nn.Linear(64,128)
        self.fc2 = nn.Linear(128,10)

    def forward(self, z, mask):
        #x_emb = F.one_hot(z, num_classes=10)
        x_emb = self.emb(z)
        x = x_emb * mask
        
        x = torch.sum(x, 1)
        
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        
        output = F.log_softmax(x, dim = 1)
        return output
discriminator = Discriminator()

In [15]:
def train(gen, dis, train_loader, optimizer, epoch):
    gen.train()
    dis.train()
    running_loss = 0.0
    correct = 0.0
    
    n = len(train_loader.dataset)
    for batch_idx, (data, target, z, graph) in enumerate(train_loader):
        optimizer.zero_grad()
        mask = gen(graph, data)
        output = dis(z, mask)
        
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
    
    print('Loss/train:', running_loss/n, 'Accuracy/train:', correct/n)

In [16]:
def test(gen, dis, test_loader, epoch):
    gen.eval()
    dis.eval()
    test_loss = 0
    correct = 0
    n = len(test_loader.dataset)
    with torch.no_grad():
        for data, target, z, graph in test_loader:
            mask = gen(graph, data)
            output = dis(z, mask)
        
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()


    print('Loss/test:', test_loss/n, 'Accuracy/test:', correct/n)

In [17]:
optimizer = optim.Adam(list(generator.parameters())+list(discriminator.parameters()), lr=1e-3)
for epoch in range(1, 30):
    train(generator, discriminator, train_loader, optimizer, epoch)
    test(generator, discriminator, test_loader, epoch)

Loss/train: 0.029911385667324066 Accuracy/train: 0.3278
Loss/test: 1.407561149597168 Accuracy/test: 0.509
Loss/train: 0.01793164694905281 Accuracy/train: 0.5747
Loss/test: 0.9192896652221679 Accuracy/test: 0.655
Loss/train: 0.013748155176639557 Accuracy/train: 0.6613
Loss/test: 0.7940436477661132 Accuracy/test: 0.672
Loss/train: 0.012653006130456924 Accuracy/train: 0.6857
Loss/test: 0.7684544048309326 Accuracy/test: 0.684
Loss/train: 0.011475351110100746 Accuracy/train: 0.7103
Loss/test: 0.7421910438537598 Accuracy/test: 0.709
Loss/train: 0.01043642218708992 Accuracy/train: 0.7367
Loss/test: 0.799012071609497 Accuracy/test: 0.679
Loss/train: 0.011169840887188911 Accuracy/train: 0.7091
Loss/test: 0.7363619060516358 Accuracy/test: 0.692
Loss/train: 0.009783775967359543 Accuracy/train: 0.7444
Loss/test: 0.6332630386352539 Accuracy/test: 0.729
Loss/train: 0.010664560198783875 Accuracy/train: 0.7391
Loss/test: 0.7300996646881104 Accuracy/test: 0.706
Loss/train: 0.01134096976518631 Accuracy/