In [16]:
import simpy
from random import random
import networkx as nx
import sys

# Protocol:
class P(object):
    TYPE = 0
    HELLO_MSG = 0
    FORWARD_MSG = 1
    SWITCH_NAME = 1
    LSP = 1
    FROM_WHOM = 2
    ACK = 3
    COST = 2

def parseLS(ls, reverse=False):
    if ls is None: raise Exception('No link states.')
    if reverse:
        return (ls[1], ls[0], {'weight': ls[2]})
    else:
        return (ls[0], ls[1], {'weight': ls[2]})
    
def makePrinter(log):
    def p(msgs):
        if log:
            content = ''
            for msg in msgs:
                content += ' ' + str(msg)
            print content
    return p
p = makePrinter(False)

# Define a Port
class Port(object):
    def __init__(self, env, delay=0):
        self.env = env
        self.store = simpy.Store(env)

    def send(self, msg):
        return self.store.put(msg)
    
    def receive(self):
        if self.size() == 0:
            return None 
        return self.store.get()
    
    def size(self):
        return len(self.store.items)

# Define a Link
class Link(object):
    def __init__(self, env, delay=0):
        self.env = env
        self.port1 = Port(env)
        self.port2 = Port(env)
        self.delay = delay
        self.hasConnnected = False
                           
    def connect(self, switch1, switch2):
        if not self.hasConnnected:
            portForSwitch1 = {'in': self.port1, 'out': self.port2, 'delay': self.delay}
            switch1.addPort(portForSwitch1)
            portForSwitch2 = {'in': self.port2, 'out': self.port1, 'delay': self.delay}
            switch2.addPort(portForSwitch2)
            self.hasConnnected = True
        else:
            raise Exception('Already connected to switches')

# Define a switch
class Switch(object):
    
    UNIT = 0.1
    
    def __init__(self, env, name, reporter): #measureF=None):
        self.name = name
        # self.routingTable = RoutingTable(measureF)
        self.ports = []
        self.reporter = reporter
        self.env = env
        self.graph = nx.Graph()
        self.graph.add_node(self.name)
        # self.buf = {}
        self.paths = None
    
    def launch(self):
        self.env.process(self.run())
    
    def addPort(self, port):
        port['#'] = len(self.ports)
        self.ports.append(port)

    def _send(self, port, msg):
        yield port['out'].send((msg, self.env.now))
                           
    def _fecth(self, port):
        return port['in'].receive()

    def __str__(self):
        return self.name
                           
    def _getNeighboringEdges(self):
        neiIter = nx.edges_iter(self.graph, [self.name])
        edges = []
        for e in neiIter:
            edge = (e[0], e[1], self.graph[e[0]][e[1]]['weight']) 
            edges.append(edge)
        return edges
    
    def run(self):
        thens = {}
        msgs  = {}
        for port in self.ports:
            msg = [P.HELLO_MSG, self.name, port['delay']]
            p((port['#'], 'from', self.name))
            yield self.env.process(self._send(port, msg))
            
            thens[str(port)] = None
            msgs[str(port)] = None
            
        while True:
            yield self.env.timeout(Switch.UNIT)
            if len(self.ports) == 0: continue
            for port in self.ports:
                if thens[str(port)] is None:
                    do_fetch = self._fecth(port)
                    if do_fetch is None: continue
                    (msgs[str(port)], thens[str(port)]) = yield do_fetch
                
                if self.env.now - thens[str(port)] < port['delay']: continue
                else: thens[str(port)] = None

                newMsg = msgs[str(port)]
                
                if newMsg[P.TYPE] is P.HELLO_MSG:
                    # ADD THE SWITCH TO TOPOLOGY
                    switch = newMsg[P.SWITCH_NAME]
                    delay  = newMsg[P.COST]
                    self.graph.add_edge(self.name, switch, weight=delay)
                    p(('Received HELLO from:', switch))
                    lsps = self._getNeighboringEdges()
                    ack_num = random()
                    
                    # forward link states: msg_type, routes, from_whom, ack_num
                    msg = [P.FORWARD_MSG, lsps, self.name, ack_num]
                    
                    for port_ in self.ports:
                        if port_['#'] != port['#']:
                            yield self.env.process(self._send(port_, msg))
                    
                    p(('link states of %s: %s' % (self.name, lsps),))
                    self.reporter[self.name] = self.env.now
                elif newMsg[P.TYPE] is P.FORWARD_MSG:
                    source = newMsg[P.FROM_WHOM]
                    
                    # ignore it.. otherwise there will be loops
                    if source == self.name: pass
                    else:
                        p((self.env.now, 'Received FORWARD msg:', newMsg))
                        lsps = newMsg[P.LSP]
                        # check whether the link states make difference
                        currentEdges = self.graph.edges(data=True)
                        
                        makeDiff = False
                        for ls in lsps:
                            if parseLS(ls) in currentEdges or parseLS(ls, reverse=True) in currentEdges: pass
                            else:
                                self.graph.add_weighted_edges_from([ls])
                                makeDiff = True
                        
                        if makeDiff:
                            # print self.env.now, self.name, 'make difference ..'
                            # if it makes difference, flood, excluding the receiving port
                            self.reporter[self.name] = self.env.now
                            for port_ in self.ports:
                                if port_['#'] != port['#']:
                                    yield self.env.process(self._send(port_, newMsg))

    def path_delay(self, path):
        pathCost = 0
        for i in xrange(len(path)-1):
            pathCost += self.graph.get_edge_data(path[i], path[i+1])['weight']
        return pathCost

    def swapTimes(self, path, st=-1, line=-1):

        def findLine(A, B):
            lset = []
            for i in xrange(len(lines.LINES)):
                stations = map(lambda s: s[0], lines.LINES[i])
                if A in stations and B in stations:
                    lset.append(i+1)
            return lset
        
        if len(path) == 2:
            cands = findLine(path[0], path[1])
            if len(cands) == 0: return (sys.maxint, -1)
            else: 
                if line in cands: return (st, line)
                else:
                    return (st+1, cands[0])

        # else:
        # for more than two stations ahead..
        cands = findLine(path[0], path[1])
        path = path[1:]

        if len(cands) == 0: 
            return (sys.maxint, -1)
        else:
            if line not in cands:
                st += 1
                res = [self.swapTimes(path+[], st=st, line=c) for c in cands]
                (st, line) = min(res, key=lambda e: e[0])
                return (st, line)
            else:
                oldLine = line
                (st, line) = min([self.swapTimes(path+[], st=st, line=c) for c in cands], \
                                 key=lambda e: e[0])
                return (st, line)
            ######################

                
                                    
    def query_route(self, dst, tp='delay'):
        if dst == self.name: return None
        
        if tp == 'delay':
            self.paths = nx.single_source_dijkstra_path(self.graph, self.name, weight='weight')
            path = self.paths[dst]
            pathCost = self.path_delay(path)
            return (path, pathCost)
        
        elif tp == 'swap':
            pathCand = None
            minSwap = sys.maxint
            for path in nx.all_simple_paths(self.graph, self.name, dst):
                # print path
                (st, line) = self.swapTimes(path)
                if minSwap > st:
                    minSwap = st
                    pathCand = path
                # if swap times are the same, then compare their delay
                elif minSwap == st:
                    if self.path_delay(path) < self.path_delay(pathCand):
                        pathCand = path
                    
            if pathCand is not None: 
                return (pathCand, minSwap)
            else: 
                print 'Unreachable.'
                return
                    
                    

class Map(object):
    def __init__(self):
        self.env = simpy.Environment()
        self.stations = {}
        self.lines = []
        self.convRecords = {}

    def load(self, statPairs):
        for i in xrange(len(statPairs)-1):
        	self.addStation(statPairs[i][0],  \
                            statPairs[i+1][0],\
                            statPairs[i][1])
        
    def addStation(self, stationA, stationB, delay):
        if (stationA, stationB) in self.lines or (stationB, stationA) in self.lines: return

        self.lines.append((stationA, stationB))
        
        if stationA not in self.stations:
            self.stations[stationA] = Switch(self.env, stationA, self.convRecords)
            self.convRecords[stationA] = 0
        stationA = self.stations[stationA]
        
        if stationB not in self.stations:
            self.stations[stationB] = Switch(self.env, stationB, self.convRecords)
            self.convRecords[stationB] = 0
        stationB = self.stations[stationB]

        l = Link(self.env, delay=delay)    
        l.connect(stationA, stationB)

    def launch(self, maxTime=1000):
        for station in self.stations.values():
            station.launch()
        self.env.run(until=maxTime)

import lines

stationMap = Map()
for line in lines.LINES:
    stationMap.load(line)
stationMap.launch(maxTime=120)

convergeTime = max(stationMap.convRecords.items(), key=lambda e: e[1])[1]

for k in stationMap.stations:
    print k, len(stationMap.stations[k].graph.edges())

print 'When to converge:', convergeTime


YouYiXiLu 100
YanChangLu 100
JiangWanZhen 100
TongHeXinChun 100
ChiFengLu 100
LinPingLu 100
XinZhuang 100
DaMuQiaoLu 100
HuaXiaDongLu 100
ChuangXinZhongLu 100
ChuanSha 100
SongHongLu 100
LouShanGuanLu 100
PuDongGuoJiJiChang 100
XinZhaLu 100
ZhangJiangGaoKe 100
BaoShanLu 100
YanAnXiLu 100
ShanXiNanLu 100
HanZhongLu 100
HongKouZuQiuChang 100
ZhenPingLu 100
LongCaoLu 100
PuDongDaDao 100
CaoYangLu 100
JinJiangLeYuan 100
FuJinLu 100
WaiHuanLu 100
ShiJiDadao 100
WenShuiLu 100
NanPuDaQiao 100
RenMinGuangChang 100
HaiLunLu 100
CaoBaoLu 100
DongBaoXingLu 100
LuJiaZui 100
GongFuXinChun 100
ShiJiGongYuan 100
SongFaLu 100
HengShanLu 100
GongKangLu 100
BeiXinJing 100
ShangHaiNanZhan 100
SongBinLu 100
XuJingDong 100
NanJingDongLu 100
LuBanLu 100
JinKeLu 100
LingKongLu 100
ZhangHuaBang 100
HaiTianSanLu 100
YangShuPuLu 100
XiZangNanLu 100
ShangHaiTiYuGuan 100
WeiNingLu 100
ShangHaiMaXiCheng 100
YuanDongDaDao 100
ShangHaiKeJiGuan 100
HongQiaoLu 100
ShuiChanLu 100
TangQiao 100
YinGaoXiLu 100
DongChangLu

In [17]:
path = stationMap.stations['LianHuaLu'].query_route('LinPingLu', tp='swap')[0]
print path
print stationMap.stations['LinPingLu'].path_delay(path)

['LianHuaLu', 'JinJiangLeYuan', 'ShangHaiNanZhan', 'CaoBaoLu', 'ShangHaiTiYuGuan', 'XuJiaHui', 'HengShanLu', 'ChangShuLu', 'ShanXiNanLu', 'HuangPiNanLu', 'RenMinGuangChang', 'XinZhaLu', 'HanZhongLu', 'ShangHaiHuoCheZhan', 'BaoShanLu', 'HaiLunLu', 'LinPingLu']
39


In [81]:
def findLine(A, B):
    lset = []
    for i in xrange(len(lines.LINES)):
        stations = map(lambda s: s[0], lines.LINES[i])
        if A in stations and B in stations:
            lset.append(i+1)
    return lset

def swapTimes(path, st=-1, line=-1):
    
    if len(path) == 2:
        cands = findLine(path[0], path[1])
        if len(cands) == 0: return (sys.maxint, -1)
        else: 
            if line in cands: return (st, line)
            else:
                return (st+1, cands[0])
    
    # else:
    # for more than two stations ahead..
    cands = findLine(path[0], path[1])
    path = path[1:]
    
    if len(cands) == 0: 
        return (sys.maxint, -1)
    else:
        if line not in cands:
            st += 1
            res = [swapTimes(path+[], st=st, line=c) for c in cands]
            (st, line) = min(res, key=lambda e: e[0])
            return (st, line)
        else:
            oldLine = line
            (st, line) = min([swapTimes(path+[], st=st, line=c) for c in cands], \
                             key=lambda e: e[0])
            return (st, line)
    
print swapTimes(path[0])

(1, 2)
