In [1]:
from IPython.display import HTML, display

In [6]:
class Connect4:
    """Connect-4 position stored as two integer bitboards, one per player.

    Cell (col, row), with 0 <= col < 7 and 0 <= row < 6, lives at bit
    ``col * 7 + row``. Row 0 is the bottom of a column: ``move`` fills rows
    upward starting at 0. Bit ``col * 7 + 6`` of every column is a sentinel
    that ``move`` never sets. Because of it, the sliding win masks in
    ``iswin`` cannot match across a column boundary.

    ``data[0]`` holds the coins placed on even turns (cell value 1, drawn as
    'o' / red). ``data[1]`` holds those placed on odd turns (cell value 2,
    drawn as 'x' / yellow). ``turn`` counts the moves made through ``move``.
    """

    def __init__(self, data=None):
        """Create a board, empty by default.

        ``data`` may be an existing ``[bitboard0, bitboard1]`` pair. It is
        copied, so playing on this board never mutates the caller's list.
        ``turn`` always starts at 0, even when ``data`` already holds coins.
        """
        self.data = [0, 0] if data is None else list(data)
        self.turn = 0
        # NOTE(review): not read or written by any method of this class.
        self.history = []

    def get_col_row(self, col, row):
        """Return the content of cell (col, row): 0 empty, 1 = data[0], 2 = data[1]."""
        mask = 1 << (col * 7 + row)
        if self.data[1] & mask:
            return 2
        return int(bool(self.data[0] & mask))

    def iswin(self):
        """Return True if the player who made the most recent move has four in a row.

        That player is ``(turn - 1) % 2``, which equals ``1 - turn % 2``.
        Each base mask below is shifted through every position whose top bit
        stays below ``bound``. Windows that straddle two columns always
        contain a sentinel bit. ``move`` never sets sentinel bits, so those
        windows cannot match.
        """
        bitboard = self.data[1 - self.turn % 2]
        bound = 1 << 49  # 7 columns * 7 bits (6 rows + 1 sentinel)
        base_masks = (
            0x204081,   # horizontal: 1 | 1<<7 | 1<<14 | 1<<21
            0xf,        # vertical:   1 | 1<<1 | 1<<2  | 1<<3
            0x1010101,  # up-right:   1 | 1<<8 | 1<<16 | 1<<24
            0x208208,   # down-right: 1<<3 | 1<<9 | 1<<15 | 1<<21
        )
        for mask in base_masks:
            while mask < bound:
                if mask & bitboard == mask:
                    return True
                mask <<= 1
        return False

    def set_col_row(self, col, row, value):
        """Set cell (col, row) to ``value``.

        A value of 1 gives the cell to data[0], 2 gives it to data[1], and
        any other value empties it. Only that single bit is touched in either
        bitboard. ``turn`` is not changed.
        """
        mask = 1 << (col * 7 + row)
        # ~mask clears exactly one bit at any width. The old 0xFFFFFFFF ^ mask
        # also wiped every bit >= 32, i.e. columns 4-6.
        if value in (1, 2):
            self.data[value - 1] |= mask
            self.data[2 - value] &= ~mask
        else:
            self.data[0] &= ~mask
            self.data[1] &= ~mask

    def remove(self, col):
        """Take the top coin out of column ``col``; the inverse of ``move(col)``.

        Does nothing if the column is empty. Otherwise the highest occupied
        cell of that column is cleared in both bitboards and ``turn`` is
        decremented. As a result, undoing the last ``move`` restores both the
        previous position and the previous turn parity used by ``iswin``.
        """
        assert 0 <= col < 7
        shift = col * 7
        column = ((self.data[0] | self.data[1]) >> shift) & 0x3f
        if not column:
            return
        # Highest occupied row of this column, moved back to its board position.
        top = 1 << (shift + column.bit_length() - 1)
        self.data[0] &= ~top
        self.data[1] &= ~top
        self.turn -= 1

    def move(self, col, test=False):
        """Drop a coin for the player to move into column ``col``.

        Returns None if the column is already full; nothing changes in that
        case. Otherwise returns ``self`` so calls can be chained. With
        ``test=True`` the board and ``turn`` are left untouched, so the call
        only reports whether the column has room.
        """
        assert 0 <= col < 7
        shift = col * 7
        # A column whose k coins are stacked from row 0 (as move leaves them)
        # reads 2**k - 1 in its low 6 bits. Adding 1 therefore yields the bit
        # of the first free row; 64 means no free row.
        mask = (((self.data[0] | self.data[1]) >> shift) & 0x3f) + 1
        if mask >= 64:
            return None
        if not test:
            self.data[self.turn % 2] |= mask << shift
            self.turn += 1
        return self

    def board_data(self):
        """Yield (col, row, value) for every occupied cell, column by column from the bottom."""
        for col in range(7):
            for row in range(6):
                value = self.get_col_row(col, row)
                if value != 0:
                    yield col, row, value

    def _repr_html_(self):
        """Render the board as HTML for Jupyter's rich display.

        Every cell first gets an 'empty' image. Each occupied cell then gets a
        coin image at the same coordinates, later in the markup. Images are
        loaded from the relative path ``img/<name>.png``. Rows are flipped so
        row 0 appears at the bottom.
        """
        def pos(i):
            # Pixel offset of the i-th cell from the top/left edge of the board.
            return int(7+(220-6.5)*i/8)
        imgstr = "<img src='img/%s.png' width='23px' height='23px' style='position: absolute; top: %spx; left: %spx;margin-top: 0' />"
        header = """<div style="width: 200px; height:180px;position: relative;background: blue">"""
        header += "\n".join(imgstr%('empty', pos(5-j), pos(i)) for i in range(7) for j in range(6))
        return header +"\n".join(imgstr%('red_coin' if c==1 else 'yellow_coin', pos(5-j), pos(i)) for (i,j,c) in self.board_data()) +"</div>"

    def display(self):
        """Show the board's HTML rendering in the notebook output."""
        # The bare name `display` below is the module-level IPython display
        # function, not this method: methods are not in scope as bare names.
        display(HTML(self._repr_html_()))

    def __repr__(self):
        """ASCII view: one line per row, top row first; '.' empty, 'o' data[0], 'x' data[1]."""
        def row_str(row):
            return "".join(".ox"[self.get_col_row(col, row)] for col in range(7))
        return "\n".join(row_str(row) for row in range(5, -1, -1))

In [7]:
# Play a short opening column by column, render the board twice, then print the ASCII view.
game = Connect4()
for column in (3, 4, 4):
    game.move(column)
game.display()
for column in (2, 1, 2):
    game.move(column)
game.display()
print(repr(game))

.......
.......
.......
.......
..x.o..
.oxox..


In [26]:
def test_moves(name, moves, answers, display=False):
    """Play ``moves`` on a fresh board and check ``iswin()`` after each one.

    Parameters
    ----------
    name : str
        Label printed before the run and used in failure messages.
    moves : iterable of int
        Columns to play, in order.
    answers : iterable
        One entry per move. Each entry is the expected ``iswin()`` result
        (True/False) after that move, or None when the move is expected to be
        rejected because the column is full.
    display : bool
        If True, print each step and render the board after every move.
        This parameter shadows IPython's ``display`` inside the function;
        only the ``game.display()`` method is called here.

    Returns
    -------
    The final ``Connect4`` board.

    Raises
    ------
    AssertionError
        If ``moves`` and ``answers`` differ in length, or if any observed
        result differs from its expectation.
    """
    print("Test::" + name)
    moves = list(moves)
    answers = list(answers)
    # A mismatch would otherwise either raise a bare IndexError or silently
    # skip the trailing expectations.
    assert len(moves) == len(answers), (
        "%s: %d moves but %d answers" % (name, len(moves), len(answers)))
    game = Connect4()
    for i, (m, expected) in enumerate(zip(moves, answers)):
        if game.move(m) is None:
            assert expected is None, (
                "%s: move %d (col %r) was rejected, expected %r" % (name, i, m, expected))
            continue
        won = game.iswin()
        if display:
            print(i, m, won, expected)
            game.display()
        assert won == expected, (
            "%s: move %d (col %r): iswin() == %r, expected %r" % (name, i, m, won, expected))
    return game

# Regression cases: (name, columns played in order, expected iswin() after
# each move -- None where the move must be rejected because the column is full).
test_data = [
    ("Overflow1", [3] * 7, [False] * 6 + [None]),
    ("Overflow2", [6] * 7, [False] * 6 + [None]),
    ("Overflow3", [0, 1, 2] * 7 + [3], [False] * 18 + [None] * 3 + [False]),
    ("Vertical1", [1, 2] * 4, [False] * 6 + [True] * 2),
    ("Vertical2", [1, 2] * 3 + [2, 1] * 3, [False] * 12),
    ("Vertical3", [6] * 3 + [5, 6] * 3, [False] * 8 + [True]),
    ("Horizontal1", [0, 0, 1, 1, 2, 2, 3, 3], [False] * 6 + [True] * 2),
    ("Horizontal2",
     [0, 1, 2, 3, 4, 5, 6] * 2 + [1, 2, 3, 4, 5, 6, 0] * 2
     + [0, 1, 2, 3, 4, 5, 6] + [6, 0, 5, 1, 4, 2, 3],
     [False] * 41 + [True]),
    ("Diagonal1", [0, 1, 2, 3, 4, 5, 6] * 3 + [0, 1], [False] * 21 + [True] * 2),
    ("Diagonal2", [0, 1, 2, 3, 4, 5, 6] * 3 + [1, 2], [False] * 21 + [False] * 2),
    ("Diagonal3", [0, 1, 2, 3, 4] * 3 + [2, 3, 4], [False] * 15 + [False, True, True]),
    ("Diagonal4", [0, 1, 2, 3, 4, 5, 6] * 3 + [2, 3], [False] * 21 + [True] * 2),
]

# Run every case and show the final board it produced.
for data in test_data:
    finished_game = test_moves(*data)
    finished_game.display()

Test::Overflow1


Test::Overflow2


Test::Overflow3


Test::Vertical1


Test::Vertical2


Test::Vertical3


Test::Horizontal1


Test::Horizontal2


Test::Diagonal1


Test::Diagonal2


Test::Diagonal3


Test::Diagonal4
