In [7]:
import torch

def custom_scatter(index_tensor, value_tensor, n):
    result = torch.zeros(n)
    result.scatter_add_(0, index_tensor, value_tensor)
    return result

index_tensor = torch.tensor([1, 3, 4, 3])
value_tensor = torch.tensor([0.5, 1.0, 0.2, 1.0], requires_grad = True)
n = 5

result = custom_scatter(index_tensor, value_tensor, n).sum()
print(result)

tensor(2.7000, grad_fn=<SumBackward0>)


tensor([1., 1., 1., 1.])

# Petries

In [90]:
class Ark():
    def __init__(s, size):
        s.size = size
    
    def __str__(s):
        return "Ark: " + str(s.size)
    
    def __repr__(s):
        return str(s)

class PetriNet():
    def __init__(s):
        s.places = set()
        s.transitions = set()
        s.p_inputs = {} #dict of set
        s.p_outputs = {} #dict of set
        s.t_inputs = {} #dict of set
        s.t_outputs = {} #dict of set
        s.arks = []
    
    def inputs(s, name):
        if name in s.places:
            return s.p_inputs[name]
        elif name in s.transitions:
            return s.t_inputs[name]
        else:
            raise ValueError()
        
    def outputs(s, name):
        if name in s.places:
            return s.p_outputs[name]
        elif name in s.transitions:
            return s.t_outputs[name]
        else:
            raise ValueError()
    
    def add_place(s, name):
        if name in s.places:
            raise ValueError("Making new place. Name " + name + " already in places")
        elif name in s.transitions:
            raise ValueError("Making new place. Name " + name + " already in transitions")
        else:
            s.places.add(name)
            s.p_inputs[name] = set()
            s.p_outputs[name] = set()
    
    def add_transition(s, name):
        if name in s.places:
            raise ValueError("Making new transition. Name " + name + " already in places")
        elif name in s.transitions:
            raise ValueError("Making new transition. Name " + name + " already in transitions")
        else:
            s.transitions.add(name)
            s.t_inputs[name] = set()
            s.t_outputs[name] = set()
    
    def node_type(s, name):
        if name in s.places:
            return 'place'
        elif name in s.transitions:
            return 'trans'
        else:
            return 'none'
    
    def add_ark(s, name1, name2, size = 1):
        type1 = s.node_type(name1)
        type2 = s.node_type(name2)
        
        if type1 == 'place' and type2 == 'trans':
            s.arks.append(Ark(size))
            idx = len(s.arks) - 1
            s.p_outputs[name1].add(idx)
            s.t_inputs[name2].add(idx)
        elif type1 == 'trans' and type2 == 'place':
            s.arks.append(Ark(size))
            idx = len(s.arks) - 1
            s.t_outputs[name1].add(idx)
            s.p_inputs[name2].add(idx)
        else:
            raise ValueError("Trying to make ark with invalid types " + name1 + " " + type1 + " -> " + name2 + " " + type2)
    
    def seek_ark(s, idx):
        inname = 'none'
        outname = 'none'
        for p in s.places:
            if idx in s.p_outputs[p]:
                inname = p
                break
            if idx in s.p_inputs[p]:
                outname = p
                break
        for t in s.transitions:
            if idx in s.t_outputs[t]:
                inname = t
                break
            if idx in s.t_inputs[t]:
                outname = t
                break
        
        return inname + ' -> ' + outname
    
    def dumb_print(s):
        res = ''
        res+= 'Places:\n'
        for p in s.places:
            res+= p + '\n'
        res+= '\n'
        res+= 'Transitions:\n'
        for t in s.transitions:
            res+= t + '\n'
        res+= '\n'
        res+= 'Arks: \n'
        for idx, ark in enumerate(s.arks):
            res+= str(ark) + " " + s.seek_ark(idx) + '\n'
        return res
    
    def __str__(s):
        return s.dumb_print()
    
    def __repr__(s):
        return str(s)

In [97]:
net = PetriNet()

In [98]:
net.add_place('Copper Ore')
net.add_place('Copper Plate')

In [99]:
net.add_transition('Copper Furnace')

In [100]:
net.add_ark('Copper Ore', 'Copper Furnace')
net.add_ark('Copper Furnace', 'Copper Plate')

In [101]:
net

Places:
Copper Plate
Copper Ore

Transitions:
Copper Furnace

Arks: 
Ark: 1 Copper Ore -> Copper Furnace
Ark: 1 Copper Furnace -> Copper Plate