In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import conv_autoenc
import torch.nn as nn
import h5py
import umap
import scipy.spatial as spat

In [None]:
# --- Run configuration --------------------------------------------------------
# Location handed to h5py.File when the dataset is opened.
# NOTE(review): the value ends with "/" yet is opened as an HDF5 file - confirm the path.
dset_path = "./data/data_temp/"
batch_size = 20  # mini-batch size given to the DataLoader
n_z = 20         # passed to conv_autoenc.Net as n_z; also the width of the saved "vecs" dataset
n_channels = 2   # passed to conv_autoenc.Net as n_channels
visual = True    # when True, a live loss/reconstruction figure is built and refreshed while training

In [None]:
# Open the HDF5 file read/write: later cells store the "vecs" and "distmat" datasets in it.
h5 = h5py.File(dset_path, mode="r+")
# T is used below as the number of frames; W and H size the preview images.
T, W, H = (h5.attrs[name] for name in ("T", "W", "H"))

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def to_np(ten):
    """Return the values of tensor `ten` as a NumPy array.

    The tensor is cut from the autograd graph and moved to host memory
    before the conversion, so it may live on any device and may require
    gradients.
    """
    host_tensor = ten.detach().cpu()
    return host_tensor.numpy()
class DS(torch.utils.data.Dataset):
    """Dataset over the frames stored in the notebook-level HDF5 handle `h5`.

    Item `i` is the array at ``h5["<i>/frame"]``, max-reduced along axis 3
    and divided by 255. The length is the notebook-level frame count `T`.
    """

    def __init__(self):
        super().__init__()

    def __len__(self):
        return T

    def __getitem__(self, i):
        # presumably frames hold 8-bit intensities, so /255 maps them to [0, 1] - confirm
        frame = np.array(h5[str(i) + "/frame"])
        projection = np.max(frame, axis=3)
        return torch.tensor(projection / 255)
# Dataset and a shuffling loader that yields mini-batches of frames.
ds = DS()
dl = torch.utils.data.DataLoader(
    ds,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
)

# Build the autoencoder and move it to the selected device.
net = conv_autoenc.Net(n_channels=n_channels, n_z=n_z)
# Assigning the result (Module.to returns the module itself) keeps the cell
# from echoing the network's repr as its output.
net = net.to(device=device)

In [None]:
%matplotlib notebook
if visual:
    # Live training monitor: loss curve on the top row, two image panels below.
    fig = plt.figure(figsize=(8, 5))

    ax1 = fig.add_subplot(2, 1, 1)
    (lplot,) = ax1.plot([], [], label="Loss")
    ax1.legend()
    ax1.set_yscale("log")

    # Bottom-left panel: fed from `res` (the network output) by update().
    ax2 = fig.add_subplot(2, 2, 3)
    im = ax2.imshow(np.zeros((W, H)).T, vmin=0, vmax=0.8)

    # Bottom-right panel: fed from `ims` (the input batch) by update().
    ax3 = fig.add_subplot(2, 2, 4)
    imans = ax3.imshow(np.zeros((W, H)).T, vmin=0, vmax=0.8)

    def update():
        """Redraw the monitor from the notebook-level `losses`, `res` and `ims`.

        Does nothing until at least two loss values have been recorded.
        Otherwise rescales the loss axes to the recorded values, replots the
        loss curve against the 1-based step number, and shows element [0, 0]
        of `res` and of `ims` in the two image panels.
        """
        n_steps = len(losses)
        if n_steps < 2:
            return
        ax1.set_ylim(np.min(losses), np.max(losses))
        ax1.set_xlim(1, n_steps)
        steps = np.arange(1, n_steps + 1)
        lplot.set_data(steps, np.array(losses))
        im.set_array(to_np(res[0, 0]).T)
        imans.set_array(to_np(ims[0, 0]).T)
        fig.canvas.draw()
# Training setup: epoch budget, optimiser, and the running loss history.
num_epochs = 30
opt = torch.optim.Adam(net.parameters())  # Adam with its default hyper-parameters
losses = []  # the training loop appends one loss value per mini-batch; update() reads it

In [None]:
# Train the autoencoder to reconstruct its input frames (MSE between output and input).
# Select training mode explicitly: the latent-extraction cell further down calls
# net.eval(), so re-running this cell afterwards would otherwise train in eval mode.
net.train()
for epoch in range(num_epochs):
    print("\r Epoch "+str(epoch+1)+"/"+str(num_epochs),end="")
    for ims in dl:
        # `ims` and `res` stay notebook-level names because update() reads them.
        ims=ims.to(device=device,dtype=torch.float32)
        res,latent=net(ims)
        loss=nn.functional.mse_loss(res,ims)
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())
        if visual:
            update()

In [None]:
# Encode every frame with the trained network and collect the latent vectors.
net.eval()
vecs=[]
with torch.no_grad():
    for i in range(T):
        done=i+1
        if done%100==0:
            # Report the number of frames reached (previously printed i, i.e. one short).
            print("\r"+str(done)+"/"+str(T),end="")
        # Add a leading batch axis of size 1 before feeding the single frame to the net.
        frame=ds[i].unsqueeze(0).to(device=device,dtype=torch.float32)
        _,latent=net(frame)
        vecs.append(to_np(latent[0]))
# Stack the per-frame latents into one array with a row per frame.
vecs=np.array(vecs)

In [None]:
# Persist the latent vectors in the HDF5 file under "vecs", replacing any earlier copy.
key="vecs"
if key in h5:
    del h5[key]
# Use a dedicated name for the HDF5 dataset handle: binding it to `ds` would shadow
# the torch Dataset that the latent-extraction cell indexes, breaking a re-run of it.
vecs_dset=h5.create_dataset(key,shape=(T,n_z),dtype="f4")
vecs_dset[...]=vecs.astype(np.float32)

In [None]:
def standardize(vecs):
    """Centre and scale each column of `vecs`.

    The mean and standard deviation are taken along axis 0; a small constant
    (1e-8) is added to the standard deviation so constant columns do not
    cause a division by zero.
    """
    center = np.mean(vecs, axis=0)
    spread = np.std(vecs, axis=0)
    shifted = vecs - center
    return shifted / (spread + 1e-8)
# Standardise into a new name instead of overwriting `vecs`: the raw latents stay
# available (e.g. if the "vecs" save cell is re-run) and this cell is idempotent.
vecs_std=standardize(vecs)
# Project the standardised latents to 2-D with UMAP.
# NOTE(review): no random_state is set, so the embedding can differ between runs.
u_map=umap.UMAP(n_components=2)
# `res` is kept as the result name because the plotting cell reads it.
res=u_map.fit_transform(vecs_std)
# Pairwise distances between all embedded points.
distmat=spat.distance_matrix(res,res)

In [None]:
# Draw on a dedicated figure: bare plt.subplot() targets the current figure, which
# with the interactive backend selected above can still be the live training monitor.
emb_fig, (ax_scatter, ax_dist) = plt.subplots(1, 2)
# Left: the 2-D embedding, one point per row of `res`.
ax_scatter.scatter(res[:, 0], res[:, 1], s=1)
ax_scatter.set_title("UMAP embedding")
# Right: the pairwise distance matrix of the embedded points.
ax_dist.imshow(distmat)
ax_dist.set_title("Pairwise distances");

In [None]:
# Persist the distance matrix in the HDF5 file under "distmat", replacing any earlier copy.
key="distmat"
if key in h5:
    del h5[key]
# Use a dedicated name for the HDF5 dataset handle: binding it to `ds` would shadow
# the torch Dataset that the latent-extraction cell indexes, breaking a re-run of it.
distmat_dset=h5.create_dataset(key,shape=(T,T),dtype="f4")
distmat_dset[...]=distmat.astype(np.float32)

In [None]:
#check
# Each row shows one randomly chosen frame followed by the four frames that
# rank next to it in its row of `distmat` (rank 0 is skipped).
def _frame_view(idx):
    """Image for frame `idx`: entry 0 along the first axis, max over axis 2, transposed."""
    stack = np.array(h5[str(idx) + "/frame"][0])
    return np.max(stack, axis=2).T

plt.figure(figsize=(8, 4))
exs = np.random.choice(T, 4, replace=False)
for row, ex in enumerate(exs):
    first_cell = row * 5 + 1
    plt.subplot(4, 5, first_cell)
    plt.imshow(_frame_view(ex))
    # Skip rank 0 of the sorted distances - presumably the frame itself (distance 0).
    close = np.argsort(distmat[ex])[1:]
    for col in range(4):
        plt.subplot(4, 5, first_cell + col + 1)
        plt.imshow(_frame_view(close[col]))

In [None]:
# Close the HDF5 file opened at the top of the notebook; cells that use `h5`
# cannot be re-run after this until the file is opened again.
h5.close()