In [4]:
import yaml
import torch
import pandas as pd

from torch_geometric.data import Batch
from data.dataset import PolygonDataset
from model.nn import build_model
from utils.plot import draw_graph

In [5]:
# Read the glyph training configuration; safe_load parses plain YAML only.
with open('cfg/train_glyph.yaml', 'r') as cfg_file:
    cfg = yaml.safe_load(cfg_file)

# Keys of the `cls` mapping, in the order they appear in the config.
# Predicted indices are looked up in this list when titling the plots.
cls = list(cfg['cls'])

In [None]:
# Load the pickled glyph test split named by cfg['test'], keep only rows whose
# `name` is one of the configured classes, and renumber the rows from zero.
# NOTE(review): read_pickle can execute arbitrary code; only load trusted files.
glyph_df = (
    pd.read_pickle(cfg['test'])
    .loc[lambda frame: frame.name.isin(cls)]
    .reset_index(drop=True)
)
glyph_set = PolygonDataset(glyph_df, cfg['cls'])

In [7]:
# Latest hooked layer outputs, keyed by the label passed to get_activation.
activation = {}


def get_activation(name):
    """Create a forward hook that records a layer's output under ``name``.

    The returned callable takes ``(module, inputs, output)`` positionally and
    overwrites ``activation[name]`` with ``output.detach()`` each time it is
    called. It returns ``None``.
    """
    def _record(module, inputs, output):
        activation[name] = output.detach()

    return _record

In [8]:
# Pick one glyph sample and wrap it in a single-element batch; this `data`
# object is what each model cell below is called on.
idx = 0
glyph_sample = glyph_set[idx]
data = Batch.from_data_list([glyph_sample])

In [9]:
# DeepSet: restore the epoch-100 checkpoint, predict on `data`, and turn the
# hooked `conv2` output into a plotting mask.
cfg['nn'] = 'deepset'
cfg['path'] = f"save/frac0.8/{cfg['nn']}/ckpt/epoch100"

deepset = build_model(cfg=cfg)
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
# machine; load_state_dict then copies the values into the model's parameters.
deepset.load_state_dict(torch.load(cfg['path'], map_location='cpu')['params'])
deepset.eval()
# The hook stores the detached output of `conv2` in activation['conv2'].
deepset.conv2.register_forward_hook(get_activation('conv2'))

# Inference only, so skip autograd bookkeeping during the forward pass.
with torch.no_grad():
    deepset_pred = deepset(data).argmax(1).item()
# Mean over dim 1, then softmax over dim 0 (presumably one row per node;
# confirm against the model). The hook already detached the tensor.
deepset_feat = activation['conv2'].mean(1).softmax(dim=0).cpu().numpy()

In [10]:
# SetTransformer: restore the epoch-100 checkpoint, predict on `data`, and turn
# the hooked `enc` output into a plotting mask.
cfg['nn'] = 'transformer'
cfg['path'] = f"save/frac0.8/{cfg['nn']}/ckpt/epoch100"

transformer = build_model(cfg=cfg)
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
# machine; load_state_dict then copies the values into the model's parameters.
transformer.load_state_dict(torch.load(cfg['path'], map_location='cpu')['params'])
transformer.eval()
# The hook stores the detached output of `enc` in activation['enc'].
transformer.enc.register_forward_hook(get_activation('enc'))

# Inference only, so skip autograd bookkeeping during the forward pass.
with torch.no_grad():
    tm_pred = transformer(data).argmax(1).item()
# Drop the leading size-1 dim and keep the first data.pos.size(0) rows
# (presumably trimming padding beyond the real nodes; confirm against the model).
enc_out = activation['enc'].squeeze(0)[:data.pos.size(0)]
# Mean over dim 1, then softmax over dim 0. The hook already detached the tensor.
tm_feat = enc_out.mean(1).softmax(dim=0).cpu().numpy()

In [11]:
# GCN: restore the epoch-100 checkpoint, predict on `data`, and turn the
# hooked `conv2` output into a plotting mask.
cfg['nn'] = 'gcn'
cfg['path'] = f"save/frac0.8/{cfg['nn']}/ckpt/epoch100"

gcn = build_model(cfg=cfg)
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
# machine; load_state_dict then copies the values into the model's parameters.
gcn.load_state_dict(torch.load(cfg['path'], map_location='cpu')['params'])
gcn.eval()
# The hook stores the detached output of `conv2` in activation['conv2'],
# replacing whatever an earlier model cell left under the same key.
gcn.conv2.register_forward_hook(get_activation('conv2'))

# Inference only, so skip autograd bookkeeping during the forward pass.
with torch.no_grad():
    gcn_pred = gcn(data).argmax(1).item()
# Mean over dim 1, then softmax over dim 0 (presumably one row per node;
# confirm against the model). The hook already detached the tensor.
gcn_feat = activation['conv2'].mean(1).softmax(dim=0).cpu().numpy()

In [12]:

# DSC-NMP: restore the epoch-100 checkpoint, predict on `data`, and turn the
# hooked `conv2` output into a plotting mask.
cfg['nn'] = 'dsc_nmp'
cfg['path'] = f"save/frac0.8/{cfg['nn']}/ckpt/epoch100"

dsc_nmp = build_model(cfg=cfg)
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
# machine; load_state_dict then copies the values into the model's parameters.
dsc_nmp.load_state_dict(torch.load(cfg['path'], map_location='cpu')['params'])
dsc_nmp.eval()
# The hook stores the detached output of `conv2` in activation['conv2'],
# replacing whatever an earlier model cell left under the same key.
dsc_nmp.conv2.register_forward_hook(get_activation('conv2'))

# Inference only, so skip autograd bookkeeping during the forward pass.
with torch.no_grad():
    dsc_nmp_pred = dsc_nmp(data).argmax(1).item()
# Mean over dim 1, then softmax over dim 0 (presumably one row per node;
# confirm against the model). The hook already detached the tensor.
dsc_nmp_feat = activation['conv2'].mean(1).softmax(dim=0).cpu().numpy()

In [13]:
# PolyMP: restore the epoch-100 checkpoint, predict on `data`, and turn the
# hooked `mp2` output into a plotting mask.
cfg['nn'] = 'polymp'
cfg['path'] = f"save/frac0.8/{cfg['nn']}/ckpt/epoch100"

polymp = build_model(cfg=cfg)
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
# machine; load_state_dict then copies the values into the model's parameters.
polymp.load_state_dict(torch.load(cfg['path'], map_location='cpu')['params'])
polymp.eval()
# The hook stores the detached output of `mp2` in activation['mp2'].
polymp.mp2.register_forward_hook(get_activation('mp2'))

# Inference only, so skip autograd bookkeeping during the forward pass.
with torch.no_grad():
    polymp_pred = polymp(data).argmax(1).item()
# Mean over dim 1, then softmax over dim 0 (presumably one row per node;
# confirm against the model). The hook already detached the tensor.
polymp_feat = activation['mp2'].mean(1).softmax(dim=0).cpu().numpy()

In [14]:
# PolyMP-DSC: restore the epoch-100 checkpoint, predict on `data`, and turn the
# hooked `mp2` output into a plotting mask.
cfg['nn'] = 'dsc_polymp'
cfg['path'] = f"save/frac0.8/{cfg['nn']}/ckpt/epoch100"

dsc_polymp = build_model(cfg=cfg)
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
# machine; load_state_dict then copies the values into the model's parameters.
dsc_polymp.load_state_dict(torch.load(cfg['path'], map_location='cpu')['params'])
dsc_polymp.eval()
# The hook stores the detached output of `mp2` in activation['mp2'],
# replacing whatever an earlier model cell left under the same key.
dsc_polymp.mp2.register_forward_hook(get_activation('mp2'))

# Inference only, so skip autograd bookkeeping during the forward pass.
with torch.no_grad():
    dsc_polymp_pred = dsc_polymp(data).argmax(1).item()
# Mean over dim 1, then softmax over dim 0 (presumably one row per node;
# confirm against the model). The hook already detached the tensor.
dsc_polymp_feat = activation['mp2'].mean(1).softmax(dim=0).cpu().numpy()

In [None]:
import matplotlib.pyplot as plt

# Figure styling shared by every plot drawn from here on.
plt.rcParams.update({
    'font.size': 20,
    'font.family': 'sans-serif',
    'savefig.bbox': 'tight',
    'figure.constrained_layout.use': True,
    'savefig.dpi': 250,
    'figure.dpi': 250,
})

# One panel per model, all drawn on the same sample held in `data`.
fig, axes = plt.subplots(nrows=1, ncols=6, figsize=(18, 3))

# The geometry is identical for every panel, so convert it once.
node_xy = data.pos.numpy()
edge_pairs = data.edge_index.t().numpy()

# (panel label, predicted class index, mask passed to draw_graph)
panels = [
    ('DeepSet', deepset_pred, deepset_feat),
    ('SetTransformer', tm_pred, tm_feat),
    ('GCAE', gcn_pred, gcn_feat),
    ('DSC-NMP', dsc_nmp_pred, dsc_nmp_feat),
    ('PolyMP', polymp_pred, polymp_feat),
    ('PolyMP-DSC', dsc_polymp_pred, dsc_polymp_feat),
]

for ax, (label, pred, feat) in zip(axes, panels):
    ax.set_title(f'{label} \n Pred.: {cls[pred]}')
    draw_graph(node_xy, edge_pairs, ax=ax, mask=feat)

## Limitations on Glyph

In [None]:
# Rebuild PolyMP from the epoch-100 checkpoint for the glyph failure cases.
cfg['nn'] = 'polymp'
cfg['path'] = f"save/frac0.8/{cfg['nn']}/ckpt/epoch100"

polymp = build_model(cfg=cfg)
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
# machine; load_state_dict then copies the values into the model's parameters.
polymp.load_state_dict(torch.load(cfg['path'], map_location='cpu')['params'])
polymp.eval()
# The hook stores the detached output of `mp2` in activation['mp2'] on every
# forward pass. The trailing semicolon keeps the returned handle's repr out of
# the cell output.
polymp.mp2.register_forward_hook(get_activation('mp2'));

In [21]:
# Glyph samples to inspect. Alternative selection: [14, 77, 237, 57]
idx = [127, 170, 216, 100]

# One single-element batch per selected sample.
data_batch = [
    Batch.from_data_list([glyph_set[sample_idx]])
    for sample_idx in idx
]

In [None]:
# One panel per selected glyph sample, titled with its label and PolyMP's prediction.
fig, axes = plt.subplots(nrows=1, ncols=len(data_batch), figsize=(3 * len(data_batch), 3))

# The loop variable is `sample`, not `data`, so it does not overwrite the
# notebook-level `data` batch that other cells read.
for sample, axis in zip(data_batch, axes):
    # Inference only, so skip autograd bookkeeping. The forward pass also
    # refreshes activation['mp2'] through the registered hook.
    with torch.no_grad():
        polymp_pred = polymp(sample).argmax(1).item()
    # Mean over dim 1, then softmax over dim 0. The hook already detached the tensor.
    polymp_feat = activation['mp2'].mean(1).softmax(dim=0).cpu().numpy()

    # int(...) makes the list index explicit; assumes `y` holds a single label.
    axis.set_title(f'Label: {cls[int(sample.y)]} \n Pred.: {cls[polymp_pred]}')
    draw_graph(sample.pos.numpy(), sample.edge_index.t().numpy(), ax=axis, mask=polymp_feat)

## Limitations on OSM

In [None]:
# Load the pickled OSM test split named by cfg['osm_test'], keep only rows whose
# `name` is one of the configured classes, and renumber the rows from zero.
# NOTE(review): read_pickle can execute arbitrary code; only load trusted files.
osm_df = (
    pd.read_pickle(cfg['osm_test'])
    .loc[lambda frame: frame.name.isin(cls)]
    .reset_index(drop=True)
)
osm_set = PolygonDataset(osm_df, cfg['cls'])

In [25]:
# Rebuild PolyMP from the fine-tuned (normalised) epoch-100 checkpoint for the
# OSM failure cases.
cfg['nn'] = 'polymp'
cfg['path'] = f"save/finetune/normalised/{cfg['nn']}/ckpt/epoch100"

polymp = build_model(cfg=cfg)
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
# machine; load_state_dict then copies the values into the model's parameters.
polymp.load_state_dict(torch.load(cfg['path'], map_location='cpu')['params'])
polymp.eval()

# The hook stores the detached output of `mp2` in activation['mp2'] on every
# forward pass. The trailing semicolon keeps the returned handle's repr out of
# the cell output.
# No forward pass is run here: the previous version called the model on
# whatever `data` was left over from earlier cells (a glyph sample, not OSM),
# and its results were overwritten by the OSM plotting loop before being used.
polymp.mp2.register_forward_hook(get_activation('mp2'));

In [29]:
# OSM samples to inspect. Alternative selection: [173, 83, 214, 213]
idx = [66, 86, 124, 215]

# One single-element batch per selected sample.
data_batch = [
    Batch.from_data_list([osm_set[sample_idx]])
    for sample_idx in idx
]

In [None]:
# One panel per selected OSM sample, titled with its label and PolyMP's prediction.
fig, axes = plt.subplots(nrows=1, ncols=len(data_batch), figsize=(3 * len(data_batch), 3))

# The loop variable is `sample`, not `data`, so it does not overwrite the
# notebook-level `data` batch that other cells read.
for sample, axis in zip(data_batch, axes):
    # Inference only, so skip autograd bookkeeping. The forward pass also
    # refreshes activation['mp2'] through the registered hook.
    with torch.no_grad():
        polymp_pred = polymp(sample).argmax(1).item()
    # Mean over dim 1, then softmax over dim 0. The hook already detached the tensor.
    polymp_feat = activation['mp2'].mean(1).softmax(dim=0).cpu().numpy()

    # int(...) makes the list index explicit; assumes `y` holds a single label.
    axis.set_title(f'Label: {cls[int(sample.y)]} \n Pred.: {cls[polymp_pred]}')
    draw_graph(sample.pos.numpy(), sample.edge_index.t().numpy(), ax=axis, mask=polymp_feat)