In [2]:
import os
import sys
import torch
import matplotlib.pyplot as plt
from matplotlib import gridspec
from matplotlib import image
import numpy as np
from torchvision import transforms
from PIL import Image
import cv2

# Make the parent of the notebook's working directory importable so the
# project's ``src`` package can be imported below. Only append when missing,
# so re-running this cell does not keep growing sys.path with duplicates.
project_root = os.path.abspath(os.path.join(".."))
if project_root not in sys.path:
    sys.path.append(project_root)

In [3]:
from src.models.autoencoder import AELitModule

# Root folder of the dataset: one sub-directory per ICSD code, each holding
# "<code>_structure.npy" and "<code>_+0+0+0.npy" (see the loading cell below).
data_dir = "../data/FDP/"

# Training run to restore, relative to the parent of the working directory,
# and which epoch's checkpoint inside its "checkpoints" folder to load.
run = "logs/train/runs/2023-08-09_13-45-39"
epoch = 299

ckpt_path = os.path.join(os.getcwd(), "..", run, "checkpoints", f"epoch_{epoch}.ckpt")
litmodule = AELitModule.load_from_checkpoint(ckpt_path)

# Switch to inference mode; as the last expression, this also displays the module.
litmodule.eval()

AELitModule(
  (model): Autoencoder(
    (encoder): Encoder(
      (0): Conv2d(1, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), bias=False)
      (1): GDN()
      (2): Conv2d(128, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), bias=False)
      (3): GDN()
      (4): Conv2d(128, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), bias=False)
      (5): GDN()
      (6): Conv2d(128, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), bias=False)
      (7): GDN()
      (8): Conv2d(128, 32, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), bias=False)
    )
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (fc1): Linear(in_features=512, out_features=512, bias=True)
    (fc_bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (fc2): Linear(in_features=512, out_features=512, bias=True)
    (fc_bn2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (fc3): Linear(in_features=512, out_features

In [7]:
# Overlay several structures: for each ICSD code, the real (+0+0+0) pattern and
# the autoencoder's prediction are min-max normalised, up-sampled with
# cv2.pyrUp and summed into one high-resolution "super" image each. Both sums
# are normalised again and written as 8-bit grayscale PNGs.

# Alternative, larger set of codes:
# ["196116", "247641", "607911", "30748", "23419",
#  "21934", "10035", "182922", "183819", "244041"]
codes = ["76415", "77271"]

base_size = 128    # side length the tensors are viewed as before up-sampling
resolution = 4096  # side length of the accumulated / saved images
k = int(np.log2(resolution // base_size))  # number of 2x pyrUp steps
# pyrUp exactly doubles each side, so the per-code images only match the
# accumulator shape when resolution is base_size times a power of two.
assert base_size * 2**k == resolution, "resolution must be base_size * 2**k"

save_individual = False  # True also writes real{i}.png / fake{i}.png per code

ICSD_code = os.listdir(data_dir)  # every code directory available on disk
missing = [code for code in codes if code not in ICSD_code]
if missing:
    raise FileNotFoundError(f"ICSD codes not found in {data_dir}: {missing}")

transform = transforms.ToTensor()
# Feed inputs to whatever device the model actually lives on instead of
# hard-coding .cuda(), which fails on CPU-only hosts / CPU-loaded models.
device = next(litmodule.model.parameters()).device


def _load_pair(code):
    """Load one sample as batched tensors (batch dimension of 1).

    Returns ``(structure, pattern)``: the ``<code>_structure.npy`` array and the
    ``<code>_+0+0+0.npy`` array clipped to [0, 1], both cast to float32 and
    converted with ``transforms.ToTensor``.
    """
    sample_dir = os.path.join(data_dir, code)
    structure = np.load(os.path.join(sample_dir, code + "_structure.npy"))
    pattern = np.clip(np.load(os.path.join(sample_dir, code + "_+0+0+0.npy")), 0, 1)
    return (
        transform(structure.astype(np.float32)).unsqueeze(0),
        transform(pattern.astype(np.float32)).unsqueeze(0),
    )


def _min_max(arr):
    """Shift ``arr`` so its minimum is 0 and scale so its maximum is 1.

    A constant array has no range; it is returned as zeros rather than the
    NaNs a 0/0 division would produce (which would poison the running sums).
    """
    shifted = arr - np.amin(arr)
    peak = np.amax(shifted)
    return shifted / peak if peak > 0 else shifted


def _upsample(arr, steps):
    """Apply ``steps`` rounds of cv2.pyrUp (each doubles both sides)."""
    for _ in range(steps):
        arr = cv2.pyrUp(arr)
    return arr


def _save_grayscale(arr, path):
    """Save ``arr`` (expected in [0, 1]) as an 8-bit grayscale image; return it."""
    img = Image.fromarray(255 * arr).convert("L")
    img.save(path)
    return img


super_real = np.zeros((resolution, resolution))
super_fake = np.zeros((resolution, resolution))

for i, code in enumerate(codes):
    structures, patterns = _load_pair(code)
    with torch.no_grad():  # pure inference: no autograd graph needed
        fake = litmodule.model(structures.to(device))
    fake = fake.view(base_size, base_size).cpu().numpy()
    real = patterns.view(base_size, base_size).cpu().numpy()

    real = _upsample(_min_max(real), k)
    fake = _upsample(_min_max(fake), k)
    super_real += real
    super_fake += fake

    if save_individual:
        _save_grayscale(real, f"real{i}.png")
        _save_grayscale(fake, f"fake{i}.png")

super_real = _min_max(super_real)
super_fake = _min_max(super_fake)

im_super_real = _save_grayscale(super_real, "super_real.png")
im_super_fake = _save_grayscale(super_fake, "super_fake.png")