In [1]:
from utils.config import CONFIG
# NOTE(review): this flag is set before the psp/utils imports below, presumably
# because those modules read CONFIG at import time -- keep this ordering.
CONFIG.PSP_USE_MEAN = True
import torch
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import tqdm
from psp.pSp_pure import pSp
from utils.dataset import MulticlassImageDataset
from pathlib import Path
from utils.utils import repeat
import torch.nn.functional as F

# Directory that receives the latent-mixing comparison PNGs.
output_dir = Path("output/paper/mixed")
output_dir.mkdir(parents=True, exist_ok=True)

SEED = 0        # fixes the shuffle order, i.e. which images get paired for mixing
BATCH_SIZE = 2  # the mixing loop consumes exactly two images per batch

dataset = MulticlassImageDataset(
    ["input/data/covid_ct_pneumonia/train"],
    ["lung", "localizer", "bones", "drr", "soft"],
)
assert len(dataset) >= BATCH_SIZE, "need at least two images to form a mixing pair"

# `repeat` is drawn from with `next(...)` by the mixing loop; presumably it
# cycles the loader so more batches than one epoch can be taken -- confirm in
# utils.utils.
dataloader = repeat(DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    num_workers=2,
    prefetch_factor=2,
    shuffle=True,
    # The mixing loop indexes the second image of every batch, so a short
    # trailing batch at the end of an epoch must never be yielded.
    drop_last=True,
    # A seeded generator makes the shuffle (and worker seeds) reproducible.
    generator=torch.Generator().manual_seed(SEED),
))

Loading images from each class: 100%|██████████| 5/5 [00:04<00:00,  1.25it/s]


In [2]:
# Restore the trained pSp network from its checkpoint on GPU 0 and switch it
# to eval mode (the trailing `;` suppresses the module repr in the output).
ckpt = torch.load(
    "output/psp_pure_pneumonia_chosen_softer/checkpoint/160000.pt",
    map_location="cuda:0",
)
net = pSp(ckpt).to("cuda:0")
# latent_avg is moved by hand; presumably it is a plain tensor attribute that
# Module.to() does not relocate -- confirm in psp.pSp_pure.
if net.latent_avg is not None:
    net.latent_avg = net.latent_avg.to("cuda:0")
net.eval();

In [4]:
# For N_PAIRS random image pairs: encode both images with pSp, decode two
# latent mixes, and save one comparison PNG per pair laid out as
#
#   +-----+------+----------+-----------+
#   | in1 | out1 |          |           |
#   +-----+------+ mean mix | crossover |
#   | in2 | out2 |          |           |
#   +-----+------+----------+-----------+
N_PAIRS = 2000  # number of comparison images to write
MIX_LAYER = 7   # crossover point: W+ entries [:MIX_LAYER] from image 1, the rest from image 2

# Follow the model's device instead of hard-coding it a second time.
device = next(net.parameters()).device


def decode_w_plus(codes, size):
    """Decode W+ codes with ``net.decoder`` and bilinearly resize to ``size`` x ``size``.

    Args:
        codes: latent codes in the layout returned by ``net(...)``.
        size: target side length in pixels.

    Returns:
        The decoded image batch, resized to a square of side ``size``.
    """
    images = net.decoder([codes], return_latents=False, input_type="w_plus", noises=None)
    return F.interpolate(images, size=(size, size), mode="bilinear", align_corners=True)


with torch.no_grad():
    for idx in tqdm(range(N_PAIRS)):
        batch = next(dataloader)
        if batch["lung"].shape[0] < 2:
            # A short trailing batch from the end of an epoch; draw the next one.
            batch = next(dataloader)
        assert batch["lung"].shape[0] >= 2, "need two images per batch to mix latents"

        img_in = batch["lung"].to(device)
        img_in1 = img_in[0:1]  # keep the batch dim, same as img_in[0].unsqueeze(0)
        img_in2 = img_in[1:2]

        img_out1, codes1 = net(img_in1)
        img_out2, codes2 = net(img_in2)

        # codes are rank-3 (presumably batch, style index, latent dim); mix them
        # by splicing along the style axis and by plain averaging.
        code_7 = torch.cat([codes1[:, :MIX_LAYER, :], codes2[:, MIX_LAYER:, :]], dim=1)
        code_mean = (codes1 + codes2) / 2

        # 2x2 grid: each row is (input, reconstruction).
        img1 = torch.cat([img_in1, img_out1], dim=3)
        img2 = torch.cat([img_in2, img_out2], dim=3)
        img_original = torch.cat([img1, img2], dim=2)

        # The decoder output is resized to the grid's height so the mixes can be
        # concatenated beside it (e.g. 512 px when inputs are 256 px tall).
        side = img_original.shape[2]
        img_7 = decode_w_plus(code_7, side)
        img_mean = decode_w_plus(code_mean, side)

        img = torch.cat([img_original, img_mean, img_7], dim=3)

        save_image(
            img,
            str(output_dir / f"{idx:06d}.png"),
            nrow=1,
            normalize=True,
            value_range=(-1, 1),
        )

100%|██████████| 2000/2000 [08:55<00:00,  3.74it/s]
