In [2]:
import math
import os
import sys
from os.path import abspath
from pathlib import Path

# Put the parent directory on the import path (once) so the local
# `flows4manufacturing` package can be imported when the notebook runs
# from its own directory.
_repo_root = abspath("..")
if _repo_root not in sys.path:
    sys.path.append(_repo_root)

import matplotlib.pyplot as plt
import numpy as np
import torch

from flows4manufacturing.common.flows import (
    AffineCouplingBlock,
    NormalizingFlow,
    SequentialBijector,
)
from flows4manufacturing.image_generation.generation import (
    Autoencoder,
    ScaleTranslateNet,
)
from flows4manufacturing.image_generation.kolektor import KolektorDataset


# Run on the GPU when PyTorch reports one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


In [2]:
# Model hyperparameters. These must match the values the checkpoints loaded
# later were trained with, or load_state_dict will reject the weights.
input_shape = (1, 64, 64)  # presumably (channels, height, width) — TODO confirm
hidden = 16  # size of the latent code the flow models and the decoder consumes
flow_hidden = 512  # hidden width passed to each ScaleTranslateNet
num_coupling_layers = 8  # number of AffineCouplingBlocks chained in the flow

# Boolean mask over the latent dimensions (True at even indices). Successive
# coupling blocks alternate between this mask and its complement.
checkerboard = torch.arange(hidden) % 2 == 0
autoencoder = Autoencoder(input_shape, hidden)
# Blocks are created in order (even index -> mask, odd -> complement), so
# parameter initialisation consumes the RNG exactly as a hand-written list would.
bij = SequentialBijector(
    *(
        AffineCouplingBlock(
            ScaleTranslateNet(hidden, flow_hidden),
            checkerboard if layer_idx % 2 == 0 else ~checkerboard,
        )
        for layer_idx in range(num_coupling_layers)
    )
)
flow = NormalizingFlow(bij, (hidden,))


In [17]:
# Replace with paths to your trained autoencoder and flow checkpoints
autoencoder_ckpt = "./image-autoencoder.pt"
flow_ckpt = "./image-flow.pt"
# map_location="cpu" lets checkpoints saved from a GPU load on CPU-only
# machines; the modules are moved to `device` right after.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
autoencoder.load_state_dict(torch.load(autoencoder_ckpt, map_location="cpu"))
flow.load_state_dict(torch.load(flow_ckpt, map_location="cpu"))
# Inference only: eval() disables training-mode behaviour (dropout,
# batch-norm statistic updates) in case either module contains such layers.
autoencoder.to(device).eval()
flow.to(device).eval()
pass  # keep the cell from echoing a module repr


In [18]:
# Location of the KolektorSDD2 training split. Set the KOLEKTOR_TRAIN_DIR
# environment variable to point at your copy; the fallback is the original
# author's local path and will not exist on other machines.
kolektor_train_dir = os.environ.get(
    "KOLEKTOR_TRAIN_DIR", r"C:\Users\Matthew\Downloads\KolektorSDD2\train"
)
dataset = KolektorDataset(kolektor_train_dir)
# Keep only samples whose label (index 1) equals 1. Items are unpacked as
# (image, label) pairs later in this notebook.
defects_only = [sample for sample in dataset if sample[1] == 1]


In [32]:
def show_images(
    out: torch.Tensor,
    color: bool = True,
    num_per_row: int = 8,
    dpi: int = 1500,
):
    """Lay out a batch of images in a grid and return the figure.

    Args:
        out: Batch of images indexed along dim 0. Only channel 0 of each
            image is drawn. Values are scaled by 255; anything outside
            [0, 1] is clipped rather than wrapped.
        color: If True draw with the "gray" colormap (0 -> black); if False
            with "binary" (0 -> white, i.e. inverted).
        num_per_row: Number of images per grid row.
        dpi: Figure resolution in dots per inch.

    Returns:
        The matplotlib Figure holding the grid.

    Raises:
        ValueError: If ``out`` contains no images.
    """
    num_images = out.shape[0]
    if num_images == 0:
        raise ValueError("show_images needs at least one image")

    ims = []
    for x in out:
        # Clip before the uint8 cast: out-of-range floats would otherwise
        # wrap around (behaviour of such casts is not well defined).
        scaled = x[0].detach().cpu().numpy() * 255
        ims.append(np.clip(scaled, 0, 255).astype(np.uint8))

    num_rows = math.ceil(num_images / num_per_row)
    fig, axs = plt.subplots(
        ncols=num_per_row,
        nrows=num_rows,
        figsize=(7, num_rows * 7 / num_per_row),
        dpi=dpi,
        layout="constrained",
        squeeze=False,  # always a 2-D array, whatever the grid size
    )
    # Hide every axis, including the empty cells of a partial last row.
    for ax in axs.flat:
        ax.axis("off")
    for im, ax in zip(ims, axs.flat):
        ax.imshow(im, vmin=0, vmax=255, cmap="gray" if color else "binary")
    fig.get_layout_engine().set(wspace=0.05, hspace=0.05, w_pad=0.0, h_pad=0.0)
    return fig


In [None]:
# Three figures: decoded flow samples, standard-normal codes decoded directly
# (bypassing the learned flow), and real defect images for reference.
num_samples = 8
figures_dir = Path("../figures")
figures_dir.mkdir(parents=True, exist_ok=True)  # savefig does not create dirs

with torch.no_grad():  # inference only — skip building autograd graphs
    torch.manual_seed(0)
    codes = flow.sample(num_samples)
    out = autoencoder.decode(codes)
    # Re-seed so the Gaussian baseline uses the same seed as the flow samples
    # (presumably the same base noise the flow transformed — TODO confirm).
    torch.manual_seed(0)
    bad_codes = torch.randn_like(codes)
    bad_out = autoencoder.decode(bad_codes)

fig = show_images(out, color=False)
fig.savefig(figures_dir / "flow-images-false.jpg")
fig = show_images(bad_out, color=False)
fig.savefig(figures_dir / "ae-images-false.jpg")
samples = torch.stack([im for im, label in defects_only[:num_samples]])
fig = show_images(samples, color=False)
fig.savefig(figures_dir / "real-images-false.jpg")