In [None]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from keras.datasets import mnist
from collections import namedtuple
import networkx as nx
import random
import itertools

T = 2000 # the size of training dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Binarize the first T training images: pixels > 127 become 1, all others 0.
# Built as a NEW array (vectorized) so the raw x_train is left untouched;
# x_train[:T] is a view, so the previous in-place loop also rewrote x_train.
dataset = (x_train[:T] > 127).astype(x_train.dtype)
assert set(np.unique(dataset).tolist()) <= {0, 1}, "binarization produced non-binary pixels"
dataset.shape

(2000, 28, 28)

In [None]:
class Node:
  '''A variable node ("x...") or a factor node ("f...") of the factor graph.

  Nodes are identified by their id: two nodes with the same id compare equal
  and hash equally, whatever their other attributes.
  '''
  def __init__(self, id, value, is_factor, is_observed):
    ''' create a node
    Args:
    id: the node id (e.g. x1)
    value: the stored value of the node (Graph passes the observed pixel value
      for observed variable nodes and None otherwise)
    is_factor: a boolean variable, true: is factor; false: is node
    is_observed: a boolean variable, true: is observed; false: is unobserved
      (Graph passes None for factor nodes)
    '''
    self.id = id
    self.value = value
    self.is_factor_node = is_factor
    self.is_observed = is_observed
    # Edge objects leaving / entering this node; filled in by the graph builder.
    self.outgoing_edges = []
    self.incoming_edges = []

  def __hash__(self):
    return hash(self.id)

  def __eq__(self, other):
    # Comparing against a non-Node used to raise AttributeError (other.id);
    # let Python fall back to its default comparison instead.
    if not isinstance(other, Node):
      return NotImplemented
    return self.id == other.id

  def __str__(self):
    return self.id

In [None]:
class Edge:
  '''A directed edge of the factor graph, carrying one message per BP iteration.'''
  def __init__(self, from_node, to_node):
    '''
    1. create an edge with direction (from_node -> to_node)
    2. each directed edge corresponds to one message per iteration: the message sent by the
    edge's origin node to the edge's end node; the messages are stored in the list messages[]

    :param from_node: the origin node of this edge
    :param to_node: the end node of this edge
    '''
    self.from_node = from_node
    self.to_node = to_node
    self.messages = []

  def __hash__(self):
    # Hash the (origin, end) pair. The previous `from_node + to_node` raised
    # TypeError for Node objects, which do not define `+`. Equal edges have
    # equal endpoints, so this stays consistent with __eq__.
    return hash((self.from_node, self.to_node))

  def __eq__(self, other):
    if not isinstance(other, Edge):
      return NotImplemented
    return (self.from_node, self.to_node) == (other.from_node, other.to_node)

  def __str__(self):
    return str(self.from_node) + "-->" + str(self.to_node)

In [None]:
class Graph:
  '''Factor graph built from an adjacency dictionary.

  Keys starting with "x" become variable nodes and keys starting with "f"
  become factor nodes. Every (key -> neighbour) pair in graph_map becomes one
  directed Edge, stored in self.edges under its string form "from-->to".
  '''

  def __init__(self, graph_map, value_map, hidden_list):
    '''
    create a graph according to graph_map, value_map and hidden_list
    :param graph_map: the dictionary of the graph (node id -> list of neighbour ids)
    :param value_map: the dictionary of observed values (observed variable id -> value)
    :param hidden_list: a list which contains all unobserved variable ids
    :raises ValueError: if a node id starts with neither "x" nor "f"
    :raises KeyError: if a neighbour id is not itself a key of graph_map, or an
      observed variable has no entry in value_map
    '''
    self.graph_map = graph_map
    self.value_map = value_map
    self.hidden_list = hidden_list
    hidden = set(hidden_list)  # O(1) membership tests below

    # First pass: create every node.
    self.nodes = {}
    self.edges = {}
    for name in self.graph_map:
      if name.startswith("x"):
        if name in hidden:
          node = Node(name, None, False, False)
        else:
          node = Node(name, value_map[name], False, True)
      elif name.startswith("f"):
        node = Node(name, None, True, None)
      else:
        raise ValueError(
            "node id {!r} must start with 'x' (variable) or 'f' (factor)".format(name))
      self.nodes[node.id] = node

    # Second pass: one directed edge per listed neighbour.
    for name, connections in self.graph_map.items():
      source = self.nodes[name]
      for connection in connections:
        if connection not in self.nodes:
          raise KeyError("{!r} lists unknown neighbour {!r}".format(name, connection))
        target = self.nodes[connection]
        edge = Edge(source, target)
        source.outgoing_edges.append(edge)
        target.incoming_edges.append(edge)
        self.edges[str(edge)] = edge

  def __str__(self):
    result = ""
    for node in self.nodes.values():
      result += str(node) + " (" + str(node.is_factor_node) + ") :\n"
      for edge in node.outgoing_edges:
        result += str(edge) + "\n"
      result += "\n"
    return result

In [None]:
def factor_function(x1, x2, x3, x4, a, b, c):
  '''Potential of one 2x2 factor, chosen by how many of its four pixels are 1.

  Returns a when all four pixels agree (sum 0 or 4), b when exactly one
  disagrees (sum 1 or 3), c for a 2/2 split (sum 2), and None for any other sum.
  '''
  ones = x1 + x2 + x3 + x4
  if ones in (0, 4):
    return a
  if ones in (1, 3):
    return b
  if ones == 2:
    return c
  return None

In [None]:
def can_calculate(from_node, to_node, i):
  '''
  Report whether the message from_node -> to_node can be computed at iteration i.

  True when every edge entering from_node, other than the one coming back
  from to_node, already holds a message at index i.
  '''
  return all(
      edge.from_node.id == to_node.id or edge.messages[i] is not None
      for edge in from_node.incoming_edges
  )

def Message_f_to_x(incoming_messages, to_node_value, a, b, c):
  '''
  Message from a factor to one of its four variable nodes, evaluated at to_node_value.

  The factor's other three neighbours are summed out. For every binary assignment
  (xi, xj, xk) of them, add factor_function(xi, xj, xk, to_node_value, a, b, c)
  times their three incoming messages at those values. An assignment is skipped
  when any of the three messages is None, which is how an observed neighbour
  marks the value it does not take.

  :param incoming_messages: messages from the other three neighbours, each indexable by 0/1
  :param to_node_value: the value (0 or 1) of the receiving variable
  :param a, b, c: factor potentials forwarded to factor_function
  :return: the unnormalized message value (0 if every assignment was skipped)
  '''
  total = 0
  for xi, xj, xk in itertools.product([0, 1], repeat=3):
    m_i = incoming_messages[0][xi]
    if m_i is None:
      continue
    m_j = incoming_messages[1][xj]
    if m_j is None:
      continue
    m_k = incoming_messages[2][xk]
    if m_k is None:
      continue
    total += factor_function(xi, xj, xk, to_node_value, a, b, c) * m_i * m_j * m_k
  return total


def Message_x_to_f(incoming_messages, to_node_value):
  '''
  Message from a variable node to a factor, evaluated at to_node_value.

  :param incoming_messages: messages from the variable's other factors, each indexable by 0/1
  :param to_node_value: the variable value (0 or 1) to evaluate at
  :return: the product of every incoming message at to_node_value (1 if there are none)
  '''
  product = 1
  for message in incoming_messages:
    product *= message[to_node_value]
  return product

In [None]:
def calculate_message(graph, from_node, to_node, i, a, b, c):
  '''
  Append the iteration-(i+1) message for the edge from_node -> to_node.

  Inputs are the messages stored at index i on every edge entering from_node,
  except the edge coming back from to_node.
  - factor -> variable: sum over the other three variables (Message_f_to_x)
  - variable -> factor: product of the incoming messages (Message_x_to_f)
  If the receiving variable is observed (factor -> variable), or the sending
  variable is observed (variable -> factor), only the entry at its observed
  value is filled and the other stays None.
  If can_calculate() says the inputs are not ready, message i is repeated.

  :param i: the iteration, i.e. the index of the messages read on each edge
  :return: the same graph, with one more message on edge (from_node -> to_node)
  '''
  edge = graph.edges[from_node.id + "-->" + to_node.id]

  if not can_calculate(from_node, to_node, i):
    # Inputs not ready: carry the previous message forward unchanged.
    edge.messages.append(edge.messages[i])
    return graph

  inputs = [e.messages[i] for e in from_node.incoming_edges
            if e.from_node.id != to_node.id]

  if from_node.is_factor_node:
    if to_node.is_observed:
      message = [None, None]
      message[to_node.value] = Message_f_to_x(inputs, to_node.value, a, b, c)
    else:
      message = [Message_f_to_x(inputs, value, a, b, c) for value in (0, 1)]
  elif from_node.is_observed:
    message = [None, None]
    message[from_node.value] = Message_x_to_f(inputs, from_node.value)
  else:
    message = [Message_x_to_f(inputs, value) for value in (0, 1)]

  edge.messages.append(message)
  return graph

In [None]:
def calculate_belief(graph, hidden_nodes):
  """
  Normalized belief [P(x=0), P(x=1)] for every unobserved node.

  Each hidden node's belief is the product of the latest message on every edge
  entering it, normalized to sum to 1 (or [0.5, 0.5] when both entries are 0).
  Results follow the iteration order of graph.nodes, not the order of hidden_nodes.

  :param graph: a Graph whose edges already carry at least one message
  :param hidden_nodes: iterable of the ids of the unobserved nodes
  :return: list of [p0, p1] pairs, one per hidden node
  """
  hidden = set(hidden_nodes)
  results = []
  for node_id, node in graph.nodes.items():
    if node_id not in hidden:
      continue
    belief = [1, 1]
    for edge in node.incoming_edges:
      # Removed a leftover debug print of `factor.incoming_edges.from_node`,
      # which always raised AttributeError (incoming_edges is a list).
      message = edge.messages[-1]  # the latest message
      belief[0] *= message[0]
      belief[1] *= message[1]

    total = belief[0] + belief[1]
    if total == 0:
      results.append([0.5, 0.5])
    else:
      results.append([belief[0] / total, belief[1] / total])
  return results

In [None]:
def passing(graph, square, a, b, c, iterations=None):
  '''
  Loopy belief propagation / message passing:
  1: run `iterations` sweeps (defaults to the notebook-level ITERATION)
  2: in each sweep:
    1) calculate_message(): every node sends a new message to each of its
       neighbours, computed from the messages stored at index i
    2) calculate_belief(): from the latest messages, compute the normalized
       belief vector of every unobserved node; the vectors are printed
  :param square: list of unobserved node ids
  :param iterations: number of sweeps; None means use ITERATION (old behaviour)
  :return: the belief vectors of the last sweep, or of the stored messages if
           there were zero sweeps
  NOTE(review): sweep i reads messages[i] counted from the start of each edge's
  history, so calling passing() again on the same graph re-reads messages from
  the earlier call; re-initialize the graph between calls.
  '''
  if iterations is None:
    iterations = ITERATION
  vectors = None
  for i in range(iterations):
    for node in graph.nodes.values():
      for edge in node.outgoing_edges:
        graph = calculate_message(graph, node, edge.to_node, i, a, b, c)

    vectors = calculate_belief(graph, square)
    print("iter=",i," the vectors is:", vectors)
  if vectors is None:
    # Zero sweeps previously raised UnboundLocalError; report the beliefs of
    # the messages already stored on the graph instead.
    vectors = calculate_belief(graph, square)
  return vectors

In [None]:
def remove_square(x, y, len, image_size=28):
  '''
  define which nodes are hidden/removed
  :param x and y: the 1-based row and column of the top-left corner of the square
  :param len: the side length of the square
  :param image_size: the side length of the (square) image; defaults to MNIST's 28
  :return: list of removed node ids ("x<k>", k = row*image_size + col + 1),
           in row-major order; parts of the square outside the image are dropped
  '''
  # 0-based rows/cols i with x <= i+1 <= x+len-1, clipped to the image.
  rows = range(max(x - 1, 0), min(x + len - 1, image_size))
  cols = range(max(y - 1, 0), min(y + len - 1, image_size))
  return ['x{}'.format(i * image_size + j + 1) for i in rows for j in cols]
print("for testing the functions:")
# Demo: hide a 5x5 square whose top-left pixel is at row 6, column 17 (1-based).
lis = remove_square(x=6, y=17, len=5)
print(lis)

for testing the functions:
['x157', 'x158', 'x159', 'x160', 'x161', 'x185', 'x186', 'x187', 'x188', 'x189', 'x213', 'x214', 'x215', 'x216', 'x217', 'x241', 'x242', 'x243', 'x244', 'x245', 'x269', 'x270', 'x271', 'x272', 'x273']


In [None]:
def create_graphMap(graph_size=28):
  '''
  create the adjacency dictionary of the grid factor graph
  Variables x1..x(graph_size**2) are the pixels in row-major order. Each of the
  (graph_size-1)**2 factors f1.. covers one 2x2 block of neighbouring pixels.
  :param graph_size: the side length of the image (MNIST: 28)
  :return: dict mapping each node id to its neighbour ids; all variable keys are
           inserted first, then all factor keys
  '''
  last = graph_size - 1  # index of the last row/col; also factors per row
  graph_map = {}
  for i in range(graph_size):
    for j in range(graph_size):
      neighbours = []
      # The up to four 2x2 blocks containing pixel (i, j). The bounds use
      # graph_size - 1; the old hard-coded 27 was only right for 28x28.
      if i > 0:
        if j > 0:
          neighbours.append('f{}'.format((i - 1) * last + j))
        if j < last:
          neighbours.append('f{}'.format((i - 1) * last + j + 1))
      if i < last:
        if j > 0:
          neighbours.append('f{}'.format(i * last + j))
        if j < last:
          neighbours.append('f{}'.format(i * last + j + 1))
      graph_map['x{}'.format(i * graph_size + j + 1)] = neighbours

  for i in range(last):
    for j in range(last):
      top = i * graph_size + j + 1          # top-left pixel of block (i, j)
      bottom = (i + 1) * graph_size + j + 1  # bottom-left pixel
      graph_map['f{}'.format(i * last + j + 1)] = [
          'x{}'.format(top), 'x{}'.format(top + 1),
          'x{}'.format(bottom), 'x{}'.format(bottom + 1)]
  return graph_map

def create_valueMap(image, re_list, image_size=28):
  '''
  create a dictionary of the observed pixels:
  the key is each observed node id ("x<k>", row-major, 1-based);
  the value is the corresponding pixel image[i][j] of the training image
  :param image: 2-D indexable image of at least image_size x image_size
  :param re_list: iterable of removed (hidden) node ids to leave out
  :param image_size: the side length of the image
  '''
  removed = set(re_list)  # O(1) lookups instead of scanning a list per pixel
  value_map = {}
  for i in range(image_size):
    for j in range(image_size):
      node = 'x{}'.format(i * image_size + j + 1)
      if node not in removed:
        value_map[node] = image[i][j]
  return value_map
        
print("for testing the functions:")
map1 = create_graphMap()
# Summarize instead of dumping all 28*28 + 27*27 entries into the output.
assert len(map1) == 28 * 28 + 27 * 27
print(len(map1), "graph entries; first few:", dict(list(map1.items())[:4]))
map2 = create_valueMap(dataset[0], lis)
assert len(map2) == 28 * 28 - len(lis)
print(len(map2), "observed pixels; first few:", dict(list(map2.items())[:4]))


for testing the functions:
{'x1': ['f1'], 'x2': ['f1', 'f2'], 'x3': ['f2', 'f3'], 'x4': ['f3', 'f4'], 'x5': ['f4', 'f5'], 'x6': ['f5', 'f6'], 'x7': ['f6', 'f7'], 'x8': ['f7', 'f8'], 'x9': ['f8', 'f9'], 'x10': ['f9', 'f10'], 'x11': ['f10', 'f11'], 'x12': ['f11', 'f12'], 'x13': ['f12', 'f13'], 'x14': ['f13', 'f14'], 'x15': ['f14', 'f15'], 'x16': ['f15', 'f16'], 'x17': ['f16', 'f17'], 'x18': ['f17', 'f18'], 'x19': ['f18', 'f19'], 'x20': ['f19', 'f20'], 'x21': ['f20', 'f21'], 'x22': ['f21', 'f22'], 'x23': ['f22', 'f23'], 'x24': ['f23', 'f24'], 'x25': ['f24', 'f25'], 'x26': ['f25', 'f26'], 'x27': ['f26', 'f27'], 'x28': ['f27'], 'x29': ['f1', 'f28'], 'x30': ['f1', 'f2', 'f28', 'f29'], 'x31': ['f2', 'f3', 'f29', 'f30'], 'x32': ['f3', 'f4', 'f30', 'f31'], 'x33': ['f4', 'f5', 'f31', 'f32'], 'x34': ['f5', 'f6', 'f32', 'f33'], 'x35': ['f6', 'f7', 'f33', 'f34'], 'x36': ['f7', 'f8', 'f34', 'f35'], 'x37': ['f8', 'f9', 'f35', 'f36'], 'x38': ['f9', 'f10', 'f36', 'f37'], 'x39': ['f10', 'f11', 'f37', 'f

In [None]:
def initialize(image, square, image_size=28):
  '''
  create a graph and give every edge its initial message (messages[0])
  - edges touching an observed variable (factor -> it, or it -> factor):
    1 at the observed value and None at the other value
  - every other edge: [1, 1]
  :param image: the training image
  :param square: the removed square (a list of hidden node ids)
  :param image_size: the side length of the image; now forwarded to the map
                     builders (it used to be ignored)
  :return: the initialized Graph
  '''
  graph_map = create_graphMap(image_size)
  value_map = create_valueMap(image, square, image_size)
  graph = Graph(graph_map, value_map, square)

  for edge in graph.edges.values():
    from_node, to_node = edge.from_node, edge.to_node
    if from_node.is_factor_node and to_node.is_observed:
      message = [None, None]
      message[to_node.value] = 1
    elif (not from_node.is_factor_node) and from_node.is_observed:
      message = [None, None]
      message[from_node.value] = 1
    else:
      message = [1, 1]
    edge.messages.append(message)

  return graph

In [None]:
def repair(image, square, x, y, len, image_size=28):
  '''
  Write inferred values into the removed square of an image, in place.

  The square is the one remove_square(x, y, len) describes: 1-based top-left
  corner (x, y), side len, clipped to the image. Its pixels are filled in
  row-major order with square[0], square[1], ...
  NOTE(review): passing() returns [p0, p1] pairs, and each entry is written
  into a pixel as-is, so callers presumably need to reduce them to scalar
  pixel values first. Confirm.
  :param image: the original image (mutated and returned)
  :param square: the values to write, one per pixel of the square
  :param x, y, len: the position and size of the square
  '''
  rows = range(max(x - 1, 0), min(x + len - 1, image_size))
  cols = range(max(y - 1, 0), min(y + len - 1, image_size))
  cursor = 0
  for i in rows:
    for j in cols:
      image[i][j] = square[cursor]
      cursor += 1
  return image


In [None]:
def true_vectors(image, square, image_size=28):
  '''
  One-hot ground truth for the hidden pixels, in row-major order.

  :param image: the original (binary) image
  :param square: iterable of hidden node ids
  :param image_size: the side length of the image
  :return: list with [1, 0] for each hidden pixel equal to 0 and [0, 1] otherwise
  '''
  hidden = set(square)  # O(1) membership per pixel
  vectors = []
  for i in range(image_size):
    for j in range(image_size):
      if 'x{}'.format(i * image_size + j + 1) in hidden:
        vectors.append([1, 0] if image[i][j] == 0 else [0, 1])
  return vectors

In [None]:
def training(image, square, a, b, c, epochs=20, lr=None):
  '''
  Fit the factor potentials (a, b, c) by gradient descent through belief propagation.

  Each epoch rebuilds the graph, runs passing() with the current parameters,
  and scores the normalized beliefs of the hidden pixels against the true pixels
  with the negative log-likelihood.
  The previous version could not run:
  - it passed a Python list to the loss;
  - it wrapped the loss in a fresh Variable, cutting it off from a/b/c;
  - it read .grad from plain floats.

  :param image: the training image (binary pixels)
  :param square: list of hidden node ids, e.g. from remove_square()
  :param a, b, c: initial factor potentials (numbers or scalar tensors)
  :param epochs: number of gradient steps (20, as before)
  :param lr: step size; None means the notebook-level LR
  :return: the learned `a` as a float64 scalar tensor (b and c are updated
           too but, as before, only `a` is returned)
  :raises ValueError: if square is empty (nothing to fit)
  '''
  if len(square) == 0:
    raise ValueError("square must contain at least one hidden node")
  if lr is None:
    lr = LR

  def _as_param(value):
    # Fresh float64 *leaf* tensor, so backward() populates its .grad.
    if torch.is_tensor(value):
      return value.detach().clone().to(torch.float64).requires_grad_(True)
    return torch.tensor(float(value), dtype=torch.float64, requires_grad=True)

  a, b, c = _as_param(a), _as_param(b), _as_param(c)
  # Class index per hidden pixel: 0 for a [1, 0] (black) and 1 for a [0, 1] one.
  target = torch.tensor(true_vectors(image, square), dtype=torch.float64).argmax(dim=1)
  # The beliefs are already normalized probabilities, so use NLL on their log.
  # CrossEntropyLoss would apply a second softmax on top of them.
  criterion = nn.NLLLoss()
  for epoch in range(epochs):
    # Rebuild the graph so message passing starts from the initial messages.
    # Otherwise it would read messages built with older parameters, whose
    # autograd graphs were already freed by an earlier backward().
    graph = initialize(image, square)
    beliefs = passing(graph, square, a, b, c)
    inferred = torch.stack([
        torch.stack([torch.as_tensor(p, dtype=torch.float64) for p in pair])
        for pair in beliefs])
    loss = criterion(torch.log(inferred.clamp_min(1e-12)), target)
    loss.backward()
    with torch.no_grad():
      for param in (a, b, c):
        if param.grad is not None:
          param -= lr * param.grad
          param.grad.zero_()
    print("epoch=", epoch, " loss=", loss.item())
  return a
    



In [None]:
from torch.autograd import Variable

# Hidden square: top-left corner at (row X, column Y), 1-based, side LEN.
X, Y, LEN = 5, 15, 4
# BP sweeps per passing() call, and the gradient-descent step size.
ITERATION = 1
LR = 1e-5
# Initial factor potentials: a when a 2x2 block is uniform, b for a 3/1 split,
# c for a 2/2 split (see factor_function).
a, b, c = 1, 0.7, 0.1

image = torch.tensor(dataset[0], requires_grad=True, dtype=torch.float64)
square = remove_square(X, Y, LEN)
a = training(dataset[0], square, a, b, c)
print(a, b, c)

In [None]:
import math
# Create a and b as *leaf* tensors. `torch.ones(..., requires_grad=True) * 2.6`
# is the output of an op (a non-leaf), so its .grad was never filled and the
# cell printed None.
a = torch.full([1], 2.6, requires_grad=True)

b = torch.full([1], 3.0, requires_grad=True)

# Compare with a tolerance: in float32, 2.6 + 3 is not exactly 5.6, so the
# former `== 5.6` test always took the else branch.
if math.isclose((a + b).item(), 5.6, rel_tol=1e-6):
  x = torch.exp(abs(a+b))
else:
  x = torch.exp(abs(a+b))+1
print(a.item())
x.backward()
print(a.grad)

2.5999999046325684
None


  return self._grad


In [None]:
import torch
a = torch.full((1, 2), 2.0)
b = torch.full((1, 2), 3.0)
c = torch.cat((a, b), dim=0)  # stack the two rows -> shape (2, 2)
print(c)
c = c[1:]                     # drop the first row -> shape (1, 2)
print(c.shape)
print(c)




tensor([[2., 2.],
        [3., 3.]])
torch.Size([1, 2])
tensor([[3., 3.]])


In [None]:
a = torch.full((1, 2), 2.0)  # [[2., 2.]]
print(a.sum())

tensor(4.)
