In [1]:
import torch
import numpy as np
from torch.autograd import Variable
import torch.nn.functional as F
from torch.nn.functional import softmax
from torch.distributions import uniform, cauchy, normal, relaxed_bernoulli
from scipy.linalg import toeplitz, circulant
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
from matplotlib.collections import PatchCollection
plt.style.use('ggplot')
import os
import ipdb

In [2]:
class coupling_network(torch.nn.Module):
    """Convolutional network that maps an image to a pairwise coupling matrix.

    A single conv layer extracts `num_features` feature maps. The coupling
    between pixels i and j is the dot product of their feature vectors divided
    by `num_features`, so the returned matrix is symmetric.

    Args:
        num_features: number of conv output channels.
        img_side: side length of the input images; stored on the instance but
            not used by forward(), which takes sizes from its input.
        kernel_size: conv kernel size. Odd and even sizes are both padded so
            the spatial size of the feature maps equals that of the input.
    """

    def __init__(self, num_features=32, img_side=32, kernel_size=5):
        super(coupling_network, self).__init__()
        self.img_side = img_side
        self.num_features = num_features
        # "Same" padding: distribute kernel_size - 1 zeros over the two sides
        # of each axis. For even kernels the extra zero goes right/bottom;
        # symmetric int((k - 1) / 2) padding would shrink the output by one
        # pixel per axis and break the (H*W, H*W) coupling shape.
        pad_before = (kernel_size - 1) // 2
        pad_after = kernel_size - 1 - pad_before
        self.padding = torch.nn.ZeroPad2d((pad_before, pad_after, pad_before, pad_after))
        self.conv = torch.nn.Conv2d(1, num_features, kernel_size, stride=1)

    def forward(self, x):
        """Return the (batch, H*W, H*W) coupling matrix.

        Expects x shaped (batch, 1, H, W) (single input channel, per the conv).
        """
        batch_size = x.shape[0]
        x = self.padding(x)
        # (batch, num_features, H*W): one feature vector per pixel
        x = self.conv(x).reshape(batch_size, self.num_features, -1)
        # Gram matrix over pixels, normalised by the feature count
        return torch.einsum('bci,bcj->bij', x, x) / self.num_features

In [3]:
def circular_moments_batch(phases, masks):
    """Desynchronisation loss between the mean phases of mask-defined groups.

    For every group, the member phases are summed as unit vectors and their
    angle is taken as the group's mean phase. The loss penalises low circular
    moments of those group means:

        sum_{m=1}^{floor(G/2)} 1 / (2 G m^2) * |sum_g exp(i m theta_g)|^2

    It is averaged over all leading (time, batch) positions.

    Args:
        phases: oscillator phases. The training cell passes the stacked
            Kuramoto flow, shaped (T, batch, N).
        masks: group membership, shaped (batch, num_groups, N); nonzero
            entries mark the members of each group.

    Returns:
        Scalar tensor. It is exactly 0 when there are fewer than two groups.
    """
    num_groups = masks.shape[1]
    member = masks.bool()
    # (T, batch, num_groups, N): each phase broadcast against each group mask
    masked_phases = phases.unsqueeze(2) * masks.unsqueeze(0)
    zeros = torch.zeros_like(masked_phases)
    # Sum of unit vectors of each group's members (non-members contribute 0)
    cos_sum = torch.where(member, torch.cos(masked_phases), zeros).sum(-1)
    sin_sum = torch.where(member, torch.sin(masked_phases), zeros).sum(-1)
    mean_angles = torch.atan2(sin_sum, cos_sum)  # (T, batch, num_groups)

    # Only the between-group desynchronisation term contributes to the loss.
    # Start from a tensor, not the int 0, so that num_groups < 2 (empty loop)
    # still yields a tensor with .mean().
    desynch = torch.zeros_like(mean_angles[..., 0])
    for m in range(1, num_groups // 2 + 1):
        moment = torch.cos(m * mean_angles).sum(-1) ** 2 + torch.sin(m * mean_angles).sum(-1) ** 2
        desynch = desynch + (1.0 / (2 * num_groups * m ** 2)) * moment
    return desynch.mean()

In [4]:
# Make data
def make_data(num_samples, num_cells=12, num_textures=4, img_side=32):
    """Generate random Voronoi "texture" images and their per-texture masks.

    Each sample scatters `num_cells` random integer seed points on an
    img_side x img_side grid and assigns every pixel to its nearest seed
    (a Voronoi cell). Cells are then grouped into `num_textures` textures.
    Each texture gets a gray level drawn without replacement from 0..254.

    Cell i belongs to texture i * num_textures // num_cells (contiguous runs).
    When num_cells is divisible by num_textures, this gives exactly
    num_cells // num_textures cells per texture.

    Returns:
        imgs: float tensor (num_samples, img_side, img_side) of gray levels.
        masks: int64 tensor (num_samples, num_textures, img_side, img_side);
            masks[s, t] is 1 on the pixels of texture t and 0 elsewhere.

    Raises:
        ValueError: if num_cells or num_textures is < 1, or num_textures > 255.
    """
    if num_cells < 1 or num_textures < 1:
        raise ValueError('num_cells and num_textures must be positive')
    if num_textures > 255:
        raise ValueError('at most 255 distinct gray levels are available')
    # Pixel coordinates shaped for broadcasting against (num_cells, 1, 1)
    rows = torch.arange(img_side).view(1, -1, 1)
    cols = torch.arange(img_side).view(1, 1, -1)
    texture_of_cell = torch.arange(num_cells) * num_textures // num_cells
    texture_ids = torch.arange(num_textures).view(-1, 1, 1)
    all_imgs = []
    all_masks = []
    for _ in range(num_samples):
        points = torch.randint(0, img_side, size=(num_cells, 2)).float()
        gray_levels = torch.randperm(255)[:num_textures]
        # dists[c, y, x]: Euclidean distance from pixel (y, x) to seed point c
        dists = torch.sqrt((points[:, 0].view(-1, 1, 1) - rows) ** 2
                           + (points[:, 1].view(-1, 1, 1) - cols) ** 2)
        # On ties argmin returns the first minimal index, i.e. the lower cell.
        nearest_cell = torch.argmin(dists, dim=0)
        texture_map = texture_of_cell[nearest_cell]  # (img_side, img_side)
        all_imgs.append(gray_levels[texture_map].float())
        all_masks.append((texture_map.unsqueeze(0) == texture_ids).long())
    if not all_imgs:
        return (torch.zeros((0, img_side, img_side)),
                torch.zeros((0, num_textures, img_side, img_side), dtype=torch.long))
    return torch.stack(all_imgs), torch.stack(all_masks)

In [5]:
def kuramoto_step(phase, coupling, omega, alpha=.01):
    """Advance the phases by one explicit-Euler Kuramoto step of size `alpha`.

    For oscillator j the update is
        phase_j + alpha * (omega_j + mean_i coupling[i, j] * sin(phase_i - phase_j)),
    where the mean is taken over dimension 1 of the pairwise product.
    """
    # pairwise[..., i, j] = sin(phase_i - phase_j)
    pairwise = torch.sin(phase[..., :, None] - phase[..., None, :])
    interaction = torch.mean(coupling * pairwise, dim=1)
    return phase + alpha * (omega + interaction)

In [6]:
# Make Data
# Seed torch's global RNG: make_data draws its seed points and gray levels
# from it, and the training cell uses it for network init and initial phases.
SEED = 0
torch.manual_seed(SEED)
img_side = 32
num_training = 200
num_testing = 50
print('Making data')
training_imgs, training_masks = make_data(num_training, img_side=img_side)
testing_imgs, testing_masks = make_data(num_testing, img_side=img_side)
print('Done')

Making data
Done


In [None]:
batch_size = 20
num_epochs = 10
kuramoto_steps = 100
burn_in_steps = 95
num_features = 25
lr = 1e-4
alpha = 1e-1
sigma = 1.0
num_batches = num_training // batch_size  # trailing samples that do not fill a batch are skipped
num_groups = training_masks.shape[1]  # number of texture masks per image

cn = coupling_network(num_features=num_features, img_side=img_side)
opt = torch.optim.Adam(cn.parameters(), lr=lr)
lh = []  # loss history, one entry per optimizer step
omega = torch.zeros((batch_size, img_side**2))  # zero natural frequencies
for epoch in range(num_epochs):
    for n in range(num_batches):
        opt.zero_grad()
        batch = training_imgs[n*batch_size:(n+1)*batch_size] / 255.
        batch = batch.unsqueeze(1)  # add the single input channel
        masks = training_masks[n*batch_size:(n+1)*batch_size].reshape(batch_size, num_groups, -1)
        coupling = sigma * cn(batch)
        phase = torch.normal(np.pi, .1, size=(batch_size, img_side**2))
        flow = []
        for k in range(kuramoto_steps):
            phase = kuramoto_step(phase, coupling, omega, alpha=alpha)
            flow.append(phase)
        flow = torch.stack(flow)  # (kuramoto_steps, batch_size, img_side**2)
        # Only the post-burn-in part of the trajectory is scored.
        loss = circular_moments_batch(flow[burn_in_steps:], masks)
        lh.append(loss.item())
        loss.backward()
        opt.step()

fig, ax = plt.subplots()
ax.plot(lh)
ax.set(title='Training loss', xlabel='Optimizer step', ylabel='Loss')
plt.show()

In [None]:
print(min(lh))
# flow is (kuramoto_steps, batch, pixels) from the last training batch;
# plot every pixel's phase trajectory for the first image in that batch.
fig, ax = plt.subplots()
ax.plot(flow[:, 0].detach().numpy(), color='b')
ax.set(title='Pixel phase trajectories (first image, last batch)',
       xlabel='Kuramoto step', ylabel='Phase (rad)')
plt.show()