In [1]:
# !git clone https://github.com/simulamet-host/conditional-polyp-diffusion.git
!pip install taming-transformers-rom1504
!pip install pytorch_lightning

Collecting taming-transformers-rom1504
  Using cached taming_transformers_rom1504-0.0.6-py3-none-any.whl.metadata (406 bytes)
Collecting torch (from taming-transformers-rom1504)
  Using cached torch-2.8.0-cp39-cp39-manylinux_2_28_x86_64.whl.metadata (30 kB)
Collecting torchvision (from taming-transformers-rom1504)
  Using cached torchvision-0.23.0-cp39-cp39-manylinux_2_28_x86_64.whl.metadata (6.1 kB)
Collecting omegaconf>=2.0.0 (from taming-transformers-rom1504)
  Using cached omegaconf-2.3.0-py3-none-any.whl.metadata (3.9 kB)
Collecting pytorch-lightning>=1.0.8 (from taming-transformers-rom1504)
  Using cached pytorch_lightning-2.5.6-py3-none-any.whl.metadata (20 kB)
Collecting antlr4-python3-runtime==4.9.* (from omegaconf>=2.0.0->taming-transformers-rom1504)
  Using cached antlr4_python3_runtime-4.9.3-py3-none-any.whl
Collecting torchmetrics>0.7.0 (from pytorch-lightning>=1.0.8->taming-transformers-rom1504)
  Using cached torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collectin

In [5]:
# Select/order GPUs *before* importing torch (or anything that imports it), so the settings
# are in place before CUDA can be initialised; once a CUDA context exists, changing these
# variables has no effect.
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import random

import cv2
import numpy as np
import torch
from PIL import Image
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.trainer import Trainer
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from ldm.models.autoencoder import VQModel

# Seed the python / numpy / torch RNGs (the stroke masks draw from `random` and `np.random`).
# workers=True also gives each DataLoader worker a distinct, reproducible seed.
SEED = 42
seed_everything(SEED, workers=True)

# Lightning's Trainer does its own device placement; `device` is only for manual tensor work.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
# ---- VQ-VAE (VQModel) hyper-parameters ----

# Encoder/decoder architecture. At resolution 256 the model reports a latent
# "z of shape (1, 3, 64, 64)", i.e. 4x spatial downsampling into 3 channels.
# NOTE(review): dataconfig["size"] below is 512, not 256 — confirm this is intended.
ddconfig = dict(
    double_z=False,
    z_channels=3,
    resolution=256,
    in_channels=3,
    out_ch=3,
    ch=128,
    ch_mult=[1, 2, 4],
    num_res_blocks=2,
    attn_resolutions=[],
    dropout=0.0,
)

# Loss module: `target` class path plus the keyword arguments in `params`.
lossconfig = dict(
    target="taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator",
    params=dict(
        disc_conditional=False,
        disc_in_channels=3,
        disc_start=0,
        disc_weight=0.75,
        codebook_weight=1.0,
    ),
)

# Data loading; every image is resized to size x size.
dataconfig = dict(
    batch_size=1,
    num_workers=4,
    path="../../../01_data/02_preproc/02_abnormal/P2/images",  # image directory
    size=512,
)

# Codebook: number of entries and the dimension of each (8192 * 3 = the quantizer's 24.6 K params).
n_embed = 8192
embed_dim = 3

# Pretrained weights to start from ...
ckpt_path = "../../../03_model/model.ckpt"
# ... and the directory new checkpoints are written to.
saving_ckpt_path = "../../../03_model/abnormal/P2/vqvae/"

In [7]:
class InpaintingTrain_autoencoder(Dataset):
    """Folder-of-images dataset for fine-tuning the VQ autoencoder.

    Every image file directly inside ``data_root`` (not recursive) is loaded as RGB
    and resized to ``size x size``. It is returned together with a copy in which a
    random free-form stroke mask has been zeroed out.

    Each item is a dict ``{"image": Tensor, "masked_image": Tensor}``. Both are float32
    tensors of shape ``(size, size, 3)`` (HWC), scaled to ``[-1, 1]``, so masked pixels
    end up at ``-1``.

    Args:
        size: Side length, in pixels, that every image is resized to.
        data_root: Directory holding the images.
        config: Stored as ``self.config``; not otherwise used here.
        mask_prob: Probability that the two outputs are swapped, i.e. that ``"image"``
            holds the masked image and ``"masked_image"`` the clean one. The default
            0.0 reproduces the previous effective behaviour: the old 50% swap was
            computed and then unconditionally overwritten.
    """

    # Accepted file extensions, compared case-insensitively.
    IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp')

    def __init__(self, size, data_root, config=None, mask_prob=0.0):
        self.size = size
        self.config = config
        self.data_root = data_root
        self.mask_prob = mask_prob
        # os.listdir order is arbitrary; sort so each index maps to the same file on every run.
        self.images = sorted(
            name for name in os.listdir(data_root)
            if name.lower().endswith(self.IMG_EXTENSIONS)
        )

    def generate_stroke_mask(self, im_size, parts=15, maxVertex=25, maxLength=80, maxBrushWidth=60, maxAngle=360):
        """Return the union of ``parts`` random strokes as an ``(h, w, 1)`` float32 mask.

        Args:
            im_size: ``(h, w)`` of the mask.
            parts: Number of independent strokes, each drawn by ``np_free_form_mask``.
            maxVertex, maxLength, maxBrushWidth, maxAngle: Forwarded to ``np_free_form_mask``.

        Returns:
            The accumulated mask with values clipped to at most 1.0 (non-zero = masked).
        """
        mask = np.zeros((im_size[0], im_size[1], 1), dtype=np.float32)
        for _ in range(parts):
            mask += self.np_free_form_mask(maxVertex, maxLength, maxBrushWidth, maxAngle, im_size[0], im_size[1])
        return np.minimum(mask, 1.0)

    def np_free_form_mask(self, maxVertex, maxLength, maxBrushWidth, maxAngle, h, w):
        """Draw one random poly-line stroke on a fresh ``(h, w, 1)`` float32 canvas.

        The stroke starts at a uniformly random pixel. It then makes up to ``maxVertex``
        segments, each with a random length (<= ``maxLength``), a random direction and
        an even brush width in ``[10, maxBrushWidth]``. Vertices are clipped to the canvas.

        Returns:
            The canvas. Line pixels are 1 and vertex circles are drawn with value 2;
            ``generate_stroke_mask`` clips the accumulated result to 1.
        """
        mask = np.zeros((h, w, 1), np.float32)
        numVertex = np.random.randint(maxVertex + 1)
        startY = np.random.randint(h)
        startX = np.random.randint(w)
        brushWidth = 0
        for i in range(numVertex):
            angle = np.random.randint(maxAngle + 1) / 360.0 * 2 * np.pi
            if i % 2 == 0:
                angle = 2 * np.pi - angle
            length = np.random.randint(maxLength + 1)
            brushWidth = np.random.randint(10, maxBrushWidth + 1) // 2 * 2
            # Clip the next vertex to the canvas (truncation to int, as before).
            nextY = int(np.clip(startY + length * np.cos(angle), 0, h - 1))
            nextX = int(np.clip(startX + length * np.sin(angle), 0, w - 1))
            # OpenCV points are (x, y) = (column, row). Passing (row, column), as before,
            # draws the stroke transposed and mis-clipped whenever h != w.
            cv2.line(mask, (int(startX), int(startY)), (nextX, nextY), 1, brushWidth)
            # NOTE(review): thickness defaults to 1, so this draws a circle outline rather
            # than a filled round joint; kept as in the original.
            cv2.circle(mask, (int(startX), int(startY)), brushWidth // 2, 2)
            startY, startX = nextY, nextX
        cv2.circle(mask, (int(startX), int(startY)), brushWidth // 2, 2)
        return mask

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        path = os.path.join(self.data_root, self.images[i])
        # The context manager closes the file handle promptly instead of relying on GC.
        with Image.open(path) as img:
            rgb = img.convert("RGB").resize((self.size, self.size))
        image = torch.from_numpy(np.array(rgb).astype(np.float32) / 255.0)  # (H, W, 3) in [0, 1]

        # Binarize the stroke mask: 1 = pixel is masked out.
        mask = self.generate_stroke_mask([self.size, self.size])
        mask = torch.from_numpy((mask >= 0.5).astype(np.float32))
        masked_image = (1 - mask) * image  # the (H, W, 1) mask broadcasts over the RGB channels

        # With probability mask_prob, train on the masked image instead of the clean one.
        if self.mask_prob > 0 and random.random() < self.mask_prob:
            image, masked_image = masked_image, image

        # Rescale both tensors from [0, 1] to [-1, 1].
        return {"image": image * 2.0 - 1.0, "masked_image": masked_image * 2.0 - 1.0}

In [8]:
# VQ-VAE initialised from the pretrained weights at `ckpt_path`.
model = VQModel(
    ddconfig,
    lossconfig,
    n_embed=n_embed,
    embed_dim=embed_dim,
    ckpt_path=ckpt_path,
)

# Keep the 3 checkpoints (VQVAE-<epoch>.ckpt) with the lowest logged "train/total_loss".
checkpoint_callback = ModelCheckpoint(
    dirpath=saving_ckpt_path,
    filename="VQVAE-{epoch:02d}",
    monitor="train/total_loss",
    mode="min",
    save_top_k=3,
)

trainer = Trainer(max_epochs=1000, callbacks=[checkpoint_callback])

making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 3, 64, 64) = 12288 dimensions.
making attention of type 'vanilla' with 512 in_channels
loaded pretrained LPIPS loss from taming/modules/autoencoder/lpips/vgg.pth
VQLPIPSWithDiscriminator running with hinge loss.


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


Restored from ../../../03_model/model.ckpt with 0 missing and 0 unexpected keys


In [9]:
# Training data: every image in dataconfig["path"], resized to dataconfig["size"].
train_dataset = InpaintingTrain_autoencoder(dataconfig["size"], dataconfig["path"])
train_loader = DataLoader(
    train_dataset,
    batch_size=dataconfig["batch_size"],
    shuffle=True,
    num_workers=dataconfig["num_workers"],
    # Keep workers alive across epochs; over a long multi-epoch run, respawning them
    # every epoch is pure overhead. (Only valid when num_workers > 0.)
    persistent_workers=dataconfig["num_workers"] > 0,
    # Page-locked host memory speeds up host-to-GPU copies; only useful with CUDA.
    pin_memory=torch.cuda.is_available(),
)
# `data` is the name the training cell passes to `trainer.fit`.
data = train_loader

In [10]:
# Learning rate picked up when the optimizers are built; the training log reports
# both lr_d and lr_g as 4.5e-06.
model.learning_rate = 4.5e-6
trainer.fit(model, train_dataloaders=data)

/home/gil/anaconda3/envs/YS_pt/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type                     | Params | Mode 
---------------------------------------------------------------------
0 | encoder         | Encoder                  | 22.3 M | train
1 | decoder         | Decoder                  | 33.0 M | train
2 | loss            | VQLPIPSWithDiscriminator | 17.5 M | train
3 | quantize        | VectorQuantizer2         | 24.6 K | train
4 | quant_con

lr_d 4.5e-06
lr_g 4.5e-06




Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1000` reached.
