
### Example of backpropagating error through pattern matcher queries.

This example shows how to maximize the probability of producing the correct sum
of X and Y, where X and Y are digits from the MNIST dataset.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from module import CogModule, CogModel, get_value
from module import InputModule, set_value

try:
    from opencog.utilities import tmp_atomspace
    from opencog.scheme_wrapper import *
    from opencog.atomspace import AtomSpace, types, PtrValue
    from opencog.atomspace import create_child_atomspace
    from opencog.type_constructors import *
    from opencog.utilities import initialize_opencog, finalize_opencog
    from opencog.bindlink import execute_atom
except RuntimeWarning as e:
    pass

**train function**

In [None]:
def train(model, device, train_loader, optimizer, epoch, log_interval, scheduler):
    """Run one training epoch over ``train_loader``.

    For every batch the data is moved to ``device`` and the loss is the
    negative log of whatever ``model.process(data, target)`` returns.
    Both the optimizer and the LR scheduler are stepped once per batch, so
    the scheduler counts batches rather than epochs. Every ``log_interval``
    batches the loss, the current learning rate and ``model.inh_weights``
    are printed.
    """
    model.train()
    for step, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        model.zero_grad()
        prob = model.process(images, labels)
        loss = -torch.log(prob)
        loss.backward()
        optimizer.step()
        scheduler.step()
        if step % log_interval != 0:
            continue
        # Report the learning rate of the last parameter group.
        for param_group in optimizer.param_groups:
            current_lr = param_group['lr']
        progress = 100. * step / len(train_loader)
        header = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f},\t lr: '.format(
            epoch, step * len(images), len(train_loader.dataset),
            progress, loss.item())
        print(header, current_lr)
        print(model.inh_weights)


def exponential_lr(decay_rate, global_step, decay_steps, staircase=False):
    """Return the learning-rate multiplier ``decay_rate ** (global_step / decay_steps)``.

    With ``staircase=True`` the exponent is floor-divided, so the multiplier
    only changes once every ``decay_steps`` steps.
    """
    exponent = (global_step // decay_steps) if staircase else (global_step / decay_steps)
    return decay_rate ** exponent


def main():
    """Train ``MnistModel`` on pairs of MNIST digits, checkpointing every epoch.

    Creates an AtomSpace and initializes OpenCog with it. Trains for a fixed
    number of epochs with Adam and a staircase exponential LR decay that the
    scheduler applies per batch. Saves the model's ``state_dict`` to
    ``mnist_cnn.pt`` after every epoch. The OpenCog runtime is finalized even
    if training raises.
    """
    atomspace = AtomSpace()
    initialize_opencog(atomspace)
    try:
        device = 'cpu'
        epochs = 20
        # MnistModel.process consumes exactly two images (X and Y) per batch.
        batch_size = 2
        lr = 0.0001
        decay_rate = 0.9
        decay_steps = 10000
        log_interval = 300
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST('/tmp/mnist', train=True, download=True,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,)),
                           ])),
            batch_size=batch_size, shuffle=True,
            # A trailing short batch would hold a single image and break the
            # two-image pairing done in MnistModel.process.
            drop_last=True)
        model = MnistModel(atomspace).to(device)
        optimizer = optim.Adam(model.parameters(), lr=lr)
        # train() steps the scheduler once per batch, so the LR is multiplied
        # by decay_rate every decay_steps batches.
        scheduler = optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda step: exponential_lr(decay_rate, step, decay_steps,
                                                  staircase=True))
        for epoch in range(1, epochs + 1):
            train(model, device, train_loader, optimizer, epoch, log_interval,
                  scheduler)
            torch.save(model.state_dict(), "mnist_cnn.pt")
    finally:
        # Release the OpenCog state set up by initialize_opencog(), even on error.
        finalize_opencog()

### CogModule

CogModule is a Python class that attaches itself to an Atom.  
It has a number of helper methods for generating queries in Atomese.

In [None]:
class MnistNet(CogModule):
    """LeNet-style CNN that maps an image batch to per-class softmax probabilities.

    Two conv + max-pool stages are followed by two fully connected layers.
    The flatten size 4*4*50 matches 28x28 single-channel inputs:
    28 -> conv5 -> 24 -> pool -> 12 -> conv5 -> 8 -> pool -> 4.
    """

    def __init__(self, atom):
        super().__init__(atom)
        # Feature extractor: 1 -> 20 -> 50 channels, 5x5 kernels, stride 1.
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        # Classifier: flattened 50x4x4 feature map -> 500 -> 10 classes.
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        """Return an (N, 10) tensor of class probabilities for the batch ``x``."""
        features = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        features = F.max_pool2d(F.relu(self.conv2(features)), 2, 2)
        flat = features.view(-1, 4*4*50)
        logits = self.fc2(F.relu(self.fc1(flat)))
        return F.softmax(logits, dim=1)

class SumProb(CogModule):
    """Multiply two probabilities.

    Used to combine the two per-image digit probabilities of one (X, Y) pair.
    """

    def forward(self, x, y):
        joint = x * y
        return joint

class ProbOfDigit(CogModule):
    """Probability of digit ``i`` for the first row of ``probs``, scaled by ``p_in_range``."""

    def forward(self, probs, i, p_in_range):
        first_row = probs[0]
        return first_row[i] * p_in_range


class TorchSum(CogModule):
    """Add up the given probabilities and reject a total greater than 1.

    NOTE(review): the per-digit weights multiplied in upstream
    (MnistModel.inh_weights) are unconstrained parameters. If training pushes
    them above 1, this check may fire. Confirm that this is intended.
    """

    def forward(self, *args):
        total = sum(args)
        exceeds_one = total > 1
        if exceeds_one:
            raise RuntimeError("probability exceeds 1")
        return total

### CogModel

CogModel is a class that holds all learnable parameters.
It also provides methods for executing queries;
such wrappers are necessary to deal with caching.

In [None]:
class MnistModel(CogModel):
    """Model that scores how likely a pair of MNIST images sums to a given label.

    Holds the digit classifier (``mnist``), the Atomese helper modules, and
    one learnable weight per digit (``inh_weights``). Each weight is attached
    to ``InheritanceLink(NumberNode(i), ConceptNode("range"))``.
    """

    def __init__(self, atomspace):
        super().__init__()
        self.atomspace = atomspace
        self.mnist = MnistNet(ConceptNode("mnist"))
        self.sum_prob = SumProb(ConceptNode("SumProb"))
        self.digit_prob = ProbOfDigit(ConceptNode("ProbOfDigit"))
        self.torch_sum = TorchSum(ConceptNode("TorchSum"))
        # One learnable weight per digit 0..9, all initialised to 0.3.
        self.inh_weights = torch.nn.Parameter(torch.Tensor([0.3] * 10))
        for i in range(10):
            # Attach the Python int i to NumberNode(i) under the "cogNet" key.
            NumberNode(str(i)).set_value(PredicateNode("cogNet"), PtrValue(i))
            # Attach the digit's learnable weight to its InheritanceLink.
            inh1 = InheritanceLink(NumberNode(str(i)), ConceptNode("range"))
            set_value(inh1, self.inh_weights[i])

    def process(self, data, label):
        """Return the probability that the two images sum to ``label.sum()``.

        Args:
            data: a batch of exactly two images; each is reshaped to
                ``[1, 1, 28, 28]``.
            label: the two digit labels; only their sum is used.

        Returns:
            The result of executing the probability query built by
            ``p_correct_answer``.

        Raises:
            ValueError: if ``data`` or ``label`` does not hold exactly two items.
        """
        # The pairing below uses data[0] and data[1] and sums all labels, so
        # any other batch size would crash or silently build the wrong target.
        if len(data) != 2 or len(label) != 2:
            raise ValueError(
                "process expects a batch of exactly 2 images and 2 labels, "
                "got {} and {}".format(len(data), len(label)))
        with tmp_atomspace() as atomspace:
            # Compute every (X, Y) pair of NumberNodes with X + Y == label sum.
            pairs = self.get_all_pairs(label, atomspace)

            # Set up the two input images.
            inp1 = InputModule(ConceptNode("img1"), data[0].reshape([1, 1, 28, 28]))
            inp2 = InputModule(ConceptNode("img2"), data[1].reshape([1, 1, 28, 28]))
            return self.p_correct_answer(pairs, inp1, inp2)

    def p_correct_answer(self, pairs, inp1, inp2):
        """Build and execute the query for the probability of the correct sum.

        For each pair (X, Y) in ``pairs``, builds the product of the digit
        terms for ``inp1``/X and ``inp2``/Y (see ``_p_digit``). The total
        probability is the sum of these products over all pairs.
        """
        terms = []
        for pair in pairs.out:
            p_digit1 = self._p_digit(inp1, pair.out[0])
            p_digit2 = self._p_digit(inp2, pair.out[1])
            terms.append(self.sum_prob.execute(p_digit1, p_digit2))
        sum_query = self.torch_sum.execute(*terms)
        return self.execute_atom(sum_query)

    def _p_digit(self, inp, digit):
        """Query atom applying ProbOfDigit to the classifier output for ``inp``, ``digit``
        and the digit's InheritanceLink to ConceptNode("range")."""
        # NOTE(review): mnist.execute is called once per pair and image;
        # presumably CogModel's caching deduplicates this — confirm.
        return self.digit_prob.execute(
            self.mnist.execute(inp.execute()),
            digit,
            InheritanceLink(digit, ConceptNode("range")))

    def get_all_pairs(self, label, atomspace):
        """Find every (X, Y) digit pair whose sum equals ``label.sum()``.

        Runs a BindLink in ``atomspace`` over NumberNodes X and Y that both
        have an InheritanceLink to ConceptNode("range") and satisfy
        X + Y == sum. The result's ``.out`` holds one ListLink(X, Y) per match.
        """
        label = str(int(label.sum()))
        var_x = atomspace.add_node(types.VariableNode, "X")
        var_y = atomspace.add_node(types.VariableNode, "Y")
        vardecl = VariableList(TypedVariableLink(var_x, TypeNode("NumberNode")),
                               TypedVariableLink(var_y, TypeNode("NumberNode")))
        eq = EqualLink(PlusLink(var_x, var_y), NumberNode(label))
        inh1 = InheritanceLink(var_x, ConceptNode("range"))
        inh2 = InheritanceLink(var_y, ConceptNode("range"))
        bindlink = BindLink(vardecl, AndLink(inh1, inh2, eq), ListLink(var_x, var_y))
        return execute_atom(atomspace, bindlink)

In [None]:
# Only start the (long-running) training when executed directly. Inside a
# Jupyter kernel __name__ is also "__main__", so the notebook still runs it.
if __name__ == "__main__":
    main()