In [None]:
import gwb as gwb
from gwb import GM as gm

import numpy as np
import ot
import os
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import trange

In [None]:
def img2atomic(img):
    '''
    Creates a discrete measure from an image.

    Every strictly positive pixel becomes one atom of the measure: its
    location is derived from the pixel position and its mass is the pixel
    value (masses are NOT normalized here).

    Parameters
    ----------
    img : 2d numpy array of shape (n_rows, n_cols).

    Returns
    -------
    pts : array of shape (k, 2). Column 0 holds the pixel's column index,
        column 1 holds ``n_rows - row_index`` (so the top image row gets the
        largest second coordinate).
    weights : array of shape (k,) with the positive pixel values, in
        row-major pixel order (same order as ``pts``).

    Raises
    ------
    AssertionError if ``img`` is not 2-dimensional.
    '''
    assert img.ndim == 2, "img needs to be 2d array"
    n_rows, n_cols = img.shape
    # With the default 'xy' indexing, meshgrid returns arrays of shape
    # (len(second), len(first)). Passing the column range first therefore
    # yields (n_rows, n_cols) grids whose flattening matches img.flatten(),
    # also for non-square images.
    col_grid, height_grid = np.meshgrid(np.arange(n_cols), n_rows - np.arange(n_rows))
    pts = np.stack([col_grid.flatten(), height_grid.flatten()], axis=1)
    weights = img.flatten()
    mask = weights > 0
    return pts[mask], weights[mask]

In [None]:
# Load every .gif file from the shapes directory (in sorted filename order).
#   ims_fill : grayscale PIL images at their original resolution
#   ims      : resize_pix x resize_pix float arrays, binarized (non-zero -> 1)
ims = []
ims_fill = []
resize_pix = 10
data_dir = "../data/2d_shapes/"
for filename in sorted(os.listdir(data_dir)):
    if not filename.endswith(".gif"):
        continue
    # Image.open is lazy and keeps the file handle open; the context manager
    # closes it as soon as convert() has produced an independent copy.
    with Image.open(os.path.join(data_dir, filename)) as src:
        image = src.convert("L")
    tmp = np.array(image)
    tmp[tmp < 0.01] = 0
    image = Image.fromarray(tmp)
    ims_fill.append(image)
    image = image.resize((resize_pix, resize_pix))

    tmp = np.array(image, dtype=float)
    # Binarize: every pixel that is non-zero after resizing counts as shape.
    tmp[tmp != 0] = 1
    # tmp is a 0/1 mask here, so the division leaves it unchanged unless the
    # image is entirely zero (then the result is NaN).
    ims.append(tmp / np.max(tmp))
n = len(ims)

In [None]:
# Pick n_ims_per_class consecutive images for each chosen class index.
# Assumes the loaded images are grouped by class in blocks of
# n_ims_per_class (sorted filename order) -- TODO confirm against the data set.
n_ims_per_class = 10
classes = [0, 8, 12, 50]
n_classes = len(classes)
N = n_classes * n_ims_per_class

class_blocks = []
for class_idx in classes:
    first = class_idx * n_ims_per_class
    class_blocks.append(ims[first:first + n_ims_per_class])
ims_pick = np.concatenate(class_blocks, axis=0)

In [None]:
# Convert each selected image into a gm object: atoms come from img2atomic,
# the masses are normalized to sum to one.
Xs = []
for idx in range(N):
    support, weights = img2atomic(ims_pick[idx])
    support = np.array(support, dtype=float)
    weights /= np.sum(weights)
    Xs.append(
        gm(
            mode="euclidean",
            gauge_mode="euclidean",
            X=support,
            xi=weights,
            normalize_gauge=True,
        )
    )

In [None]:
# Pairwise matrix over all measures in Xs, as computed by gwb.pairwise_GW
# (presumably Gromov-Wasserstein distances -- confirm in gwb). It is the
# reference the LGW matrices are correlated against further down.
pwGW = gwb.pairwise_GW(Xs)

In [None]:
# Display the pairwise matrix as an image (heat map).
plt.imshow(pwGW)

In [None]:
# Parameters for the TB iterations.
# Number of TB iterations to run.
n_its_tb = 5
# Index into Xs of the measure used as the initial barycenter.
i_init_tb = 0
# NOTE(review): this overwrites the earlier `n = len(ims)` and is not used in
# the cells visible here. Presumably an upper bound on the barycenter support
# size (sum of support sizes - N + 1) -- confirm.
n = np.sum([X.len for X in Xs]) - N + 1
# Value passed as `n` to gwb.sample_GM after every TB iteration.
n_sample = 500

In [None]:
#TB iterations and spectral clustering
bary = Xs[i_init_tb]
#bary = gwb.sample_GM(bary,n=n_sample)
LGWs = []
for i in trange(n_its_tb):
    bary_prev = bary
    bary,log = gwb.TB(bary_prev,Xs,ws = ot.unif(N),mode="avg_gauge_only",log=True)
    idxs, meas, Ps = log.values()
    bary = gwb.sample_GM(bary,n=n_sample)
    
    #LGW   
    LGWs.append(gwb.LGW_via_idxs(Xs,idxs,meas))
print("TB iterations and clustering completed!")

In [None]:
# Show every recorded LGW matrix side by side, one panel per TB iteration.
n_panels = len(LGWs)
# figsize is (width, height): a single row of panels grows in width.
fig, ax = plt.subplots(1, n_panels, figsize=(10 * n_panels, 10))
# plt.subplots returns a bare Axes (not an array) for a 1x1 grid;
# np.atleast_1d keeps the single-iteration case iterable.
for k, (panel, lgw) in enumerate(zip(np.atleast_1d(ax), LGWs)):
    panel.imshow(lgw)
    panel.set_title(f"LGW, TB iteration {k + 1}")
plt.show()

In [None]:
# Pearson correlation between the flattened pairwise GW matrix and the
# flattened LGW matrix of each TB iteration (one value printed per iteration).
pw_flat = pwGW.flatten()
for lgw in LGWs:
    corr_matrix = np.corrcoef(pw_flat, lgw.flatten())
    print(corr_matrix[1, 0])