In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
import matplotlib
# matplotlib.use("TkAgg")
# %matplotlib tk
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
import dtnnlib as dtnn

In [None]:
# Everything in this notebook runs on the CPU. The models below create their tensors on the
# CPU (torch.rand / torch.randn / nn.Linear defaults), so switching to a GPU would also
# require moving the models with `.to(device)`.
device = torch.device("cpu")

In [None]:
# Dense evaluation grid over [-1.5, 1.5]^2. X1/X2 keep the 2-D meshgrid shape (used later
# to reshape labels for imshow); XX holds the flattened (num_points**2, 2) float32 points.
num_points = 1000
axis_values = np.linspace(-1.5, 1.5, num_points)
X1, X2 = np.meshgrid(axis_values, axis_values)

grid_points = np.column_stack((X1.ravel(), X2.ravel()))
XX = torch.as_tensor(grid_points, dtype=torch.float32).to(device)

In [None]:
SEED = 741  # alternative seed kept from earlier runs: 147
torch.manual_seed(SEED)

## Distance Voronoi

In [None]:
class DistanceVoronoi(nn.Module):
    """Voronoi-style partition of 2-D points by the closest (optionally biased) center.

    Wraps ``dtnn.DistanceTransformBase(2, h)``. Its output is presumably an ``(N, h)``
    matrix of point-to-center distances (TODO confirm against dtnnlib). Each point is
    labelled with the argmin over dim 1 of that matrix.

    Args:
        h: Number of centers, i.e. number of regions.
        bias: If True, add a random per-center offset to the distances before the
            argmin, which moves the region boundaries.
    """

    def __init__(self, h, bias=False):
        super().__init__()
        self.dt = dtnn.DistanceTransformBase(2, h)
        # Replace dtnnlib's own initialisation: centers uniform in [-1, 1)^2.
        self.dt.centers.data = torch.rand(self.dt.centers.shape)*2-1
        self.bias = None
        if bias:
            self._init_bias_()

    def _init_bias_(self):
        """(Re)draw a per-center bias ~ N(0, 0.2^2) on the same device/dtype as the centers."""
        centers = self.dt.centers
        self.bias = torch.randn(centers.shape[0], dtype=centers.dtype, device=centers.device) * 0.2

    def forward(self, x):
        """Return, for every row of ``x``, the index of the center with the smallest
        (bias-adjusted) value of ``self.dt(x)`` along dim 1."""
        dists = self.dt(x)
        if self.bias is not None:
            dists = dists + self.bias
        return torch.argmin(dists, dim=1)

    def set_centroid(self, index, value):
        """Overwrite center ``index`` with ``value`` (any array-like of coordinates)."""
        centers = self.dt.centers
        # `torch.Tensor(..., dtype=...)` is not a valid constructor call, and writing in
        # place into a leaf parameter that requires grad raises unless grad is disabled.
        with torch.no_grad():
            centers[index] = torch.as_tensor(value, dtype=centers.dtype, device=centers.device)

In [None]:
# Sanity check on a throwaway 10-region model: which labels occur on the grid, and how often.
probe_model = DistanceVoronoi(10)
out = probe_model(XX)
out.unique(return_counts=True)

In [None]:
out.size()  # one label per grid point

In [None]:
# The 10-region distance model used by the distance plots below;
# `cls` holds the region label of every grid point.
regions = 10
voronoi = DistanceVoronoi(h=regions)
cls = voronoi(XX)

In [None]:
import os

# Create the figure directory (and the parent `outputs/` if needed); `exist_ok=True`
# keeps the cell safe to re-run, unlike a plain `mkdir`.
os.makedirs("outputs/02_voronoi_diagrams", exist_ok=True)

In [None]:
# Distance-based regions (no bias) with their centers.
# Both the scatter and the image use a fixed colour range (vmin=-0.5, vmax=9.5 over tab10's
# 10 colours), so label i is always drawn in tab10 colour i, even if some label is absent
# from the grid (imshow would otherwise normalise by the labels actually present).
plt.figure(figsize=(5, 5))

cents = voronoi.dt.centers.data.cpu()
cent_label = np.arange(0, regions, step=1)

plt.scatter(*cents.t(), c=cent_label, s=100, cmap='tab10', ec='k', vmin=-0.5, vmax=9.5)

plt.imshow(cls.data.cpu().numpy().reshape(X1.shape), interpolation='nearest',
           extent=(-1.5, 1.5, -1.5, 1.5),
           alpha=0.6, cmap='tab10', vmin=-0.5, vmax=9.5,
           aspect='auto', origin='lower')
plt.savefig("./outputs/02_voronoi_diagrams/voronoi_distance_nobias.pdf", bbox_inches='tight')

In [None]:
# Region label of every point of the flattened evaluation grid.
cls

In [None]:
# Current centers of the 10-region distance model.
voronoi.dt.centers

#### With per-region output bias

In [None]:
# Attach a random per-region bias to the existing distance model (centers unchanged)
# and re-label the grid. Re-running this cell draws a fresh bias.
regions = 10
voronoi._init_bias_()
cls = voronoi(XX)

cents = voronoi.dt.centers.data.cpu()
cent_label = np.arange(0, regions, step=1)

# With a bias, some regions may vanish from the grid. Pinning vmin/vmax keeps the
# label -> tab10-colour mapping identical between the scatter and the image.
plt.figure(figsize=(5, 5))
plt.scatter(*cents.t(), c=cent_label, s=100, cmap='tab10', ec='k', vmin=-0.5, vmax=9.5)
plt.imshow(cls.data.cpu().numpy().reshape(X1.shape), interpolation='nearest',
           extent=(-1.5, 1.5, -1.5, 1.5),
           alpha=0.6, cmap='tab10', vmin=-0.5, vmax=9.5,
           aspect='auto', origin='lower')
plt.savefig("./outputs/02_voronoi_diagrams/voronoi_distance_bias.pdf", bbox_inches='tight')

#### With shifted centers

In [None]:
# Translate every center by -0.5 in both coordinates and re-label the grid
# (any bias attached by the previous cell is kept).
# NOTE: the shift is applied in place, so re-running this cell shifts the centers again.
regions = 10
voronoi.dt.centers.data -= 0.5
cls = voronoi(XX)

cents = voronoi.dt.centers.data.cpu()
cent_label = np.arange(0, regions, step=1)

# Fixed vmin/vmax: label i -> tab10 colour i in both the scatter and the image,
# even if some labels do not occur on the grid.
plt.figure(figsize=(5, 5))
plt.scatter(*cents.t(), c=cent_label, s=100, cmap='tab10', ec='k', vmin=-0.5, vmax=9.5)
plt.imshow(cls.data.cpu().numpy().reshape(X1.shape), interpolation='nearest',
           extent=(-1.5, 1.5, -1.5, 1.5),
           alpha=0.6, cmap='tab10', vmin=-0.5, vmax=9.5,
           aspect='auto', origin='lower')
plt.savefig("./outputs/02_voronoi_diagrams/voronoi_distance_shift.pdf", bbox_inches='tight')

## Linear Voronoi

In [None]:
class LinearVoronoi(nn.Module):
    """Partition of 2-D points by the largest linear response ``w_i . x (+ b_i)``.

    Each point is labelled with the argmax over the ``h`` outputs of a bias-free
    ``nn.Linear(2, h)``, optionally plus a random per-region offset.

    Args:
        h: Number of weight vectors, i.e. number of regions.
        bias: If True, add a random per-region offset to the responses before the
            argmax, which moves the region boundaries.
    """

    def __init__(self, h, bias=False):
        super().__init__()
        self.lin = nn.Linear(2, h, bias=False)
        # Random directions normalised to unit length...
        self.lin.weight.data = torch.rand(self.lin.weight.shape)*2-1
        self.lin.weight.data /= torch.norm(self.lin.weight.data, dim=1, keepdim=True)
        # ...then every *entry* is scaled by an independent factor in [0.4, 1.0), which
        # shrinks each vector to a length in [0.4, 1.0) and slightly perturbs its direction.
        self.lin.weight.data *= 0.7+0.3*2*(torch.rand_like(self.lin.weight)-0.5)
        self.bias = None
        if bias:
            self._init_bias_()

    def _init_bias_(self):
        """(Re)draw a per-region bias ~ N(0, 0.2^2) on the same device/dtype as the weights."""
        weight = self.lin.weight
        self.bias = torch.randn(weight.shape[0], dtype=weight.dtype, device=weight.device) * 0.2

    def forward(self, x):
        """Return, for every row of ``x`` (shape ``(N, 2)``), the index of the largest
        (bias-adjusted) linear response."""
        dists = self.lin(x)
        if self.bias is not None:
            dists = dists + self.bias
        return torch.argmax(dists, dim=1)

    def set_centroid(self, index, value):
        """Overwrite weight vector ``index`` with ``value`` (any array-like of 2 numbers)."""
        weight = self.lin.weight
        # `torch.Tensor(..., dtype=...)` is not a valid constructor call, and writing in
        # place into a leaf parameter that requires grad raises unless grad is disabled.
        with torch.no_grad():
            weight[index] = torch.as_tensor(value, dtype=weight.dtype, device=weight.device)

In [None]:
# The 10-region linear model used by the linear plots below;
# `cls` holds the region label of every grid point.
regions = 10
voronoi = LinearVoronoi(h=regions)
cls = voronoi(XX)

In [None]:
# Linear regions may be empty on the grid; list the labels that actually occur and their counts.
cls.unique(return_counts=True)

In [None]:
# Linear (argmax) regions without bias; each weight vector is drawn as an arrow from the
# origin in tab10 colour i. imshow gets a fixed vmin/vmax so that label i is painted with
# the same tab10 colour i even when some labels are absent from the grid.
plt.figure(figsize=(5, 5))

cls = voronoi(XX)
cents = voronoi.lin.weight.data.cpu()

for i, cent in enumerate(cents):
    c = matplotlib.cm.tab10(i)
    # No `cmap` here: the colour is given explicitly via `facecolor`.
    plt.scatter(cent[0], cent[1], facecolor=c, s=50, ec='k')
    plt.arrow(0, 0, cent[0], cent[1],
              head_width=0.05, head_length=0.04, linestyle='solid', linewidth=3,
              alpha=0.8, fc="k", ec=c)

plt.imshow(cls.data.cpu().numpy().reshape(X1.shape), interpolation='nearest',
           extent=(-1.5, 1.5, -1.5, 1.5),
           alpha=0.6, cmap='tab10', vmin=-0.5, vmax=9.5,
           aspect='auto', origin='lower')
plt.savefig("./outputs/02_voronoi_diagrams/voronoi_linear_nobias.pdf", bbox_inches='tight')

In [None]:
# Current weight vectors of the linear model, one row per region.
voronoi.lin.weight.data

#### With per-region output bias

In [None]:
# Attach a random per-region bias to the existing linear model (weights unchanged)
# and re-label the grid. Re-running this cell draws a fresh bias.
regions = 10
voronoi._init_bias_()
cls = voronoi(XX)

cents = voronoi.lin.weight.data.cpu()

plt.figure(figsize=(5, 5))
for i, cent in enumerate(cents):
    c = matplotlib.cm.tab10(i)
    # No `cmap` here: the colour is given explicitly via `facecolor`.
    plt.scatter(cent[0], cent[1], facecolor=c, s=50, ec='k')
    plt.arrow(0, 0, cent[0], cent[1],
              head_width=0.05, head_length=0.04, linestyle='solid', linewidth=3,
              alpha=0.8, fc="k", ec=c)

# Fixed vmin/vmax: label i -> tab10 colour i, matching the arrows even if some
# labels do not occur on the grid.
plt.imshow(cls.data.cpu().numpy().reshape(X1.shape), interpolation='nearest',
           extent=(-1.5, 1.5, -1.5, 1.5),
           alpha=0.6, cmap='tab10', vmin=-0.5, vmax=9.5,
           aspect='auto', origin='lower')
plt.savefig("./outputs/02_voronoi_diagrams/voronoi_linear_bias.pdf", bbox_inches='tight')

#### With shifted weights

In [None]:
# Shift every weight vector by -0.5 in both coordinates, in place
# (not idempotent: re-running this cell shifts the weights again).
with torch.no_grad():
    voronoi.lin.weight.sub_(0.5)

In [None]:
# Linear regions after the weight shift (any bias attached earlier is kept).
# imshow gets a fixed vmin/vmax so label i is painted with tab10 colour i, matching the
# arrows even when some labels are absent from the grid.
plt.figure(figsize=(5, 5))

cls = voronoi(XX)
cents = voronoi.lin.weight.data.cpu()

for i, cent in enumerate(cents):
    c = matplotlib.cm.tab10(i)
    # No `cmap` here: the colour is given explicitly via `facecolor`.
    plt.scatter(cent[0], cent[1], facecolor=c, s=50, ec='k')
    plt.arrow(0, 0, cent[0], cent[1],
              head_width=0.05, head_length=0.04, linestyle='solid', linewidth=3,
              alpha=0.8, fc="k", ec=c)

plt.imshow(cls.data.cpu().numpy().reshape(X1.shape), interpolation='nearest',
           extent=(-1.5, 1.5, -1.5, 1.5),
           alpha=0.6, cmap='tab10', vmin=-0.5, vmax=9.5,
           aspect='auto', origin='lower')
plt.savefig("./outputs/02_voronoi_diagrams/voronoi_linear_shift.pdf", bbox_inches='tight')