In [58]:
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import os
import pyheif

In [70]:
# Side length, in pixels, of the square images used throughout the notebook:
# the dataset resizes each patch to (image_size, image_size) and the UNet is
# built with sample_size=image_size.
image_size = 256

In [73]:
class IITM_Dataset(Dataset):
    """Dataset of random square patches cut from the images in a directory.

    Every ``.jpg``, ``.jpeg``, ``.png`` and ``.heic`` file (extension matched
    case-insensitively) located directly inside ``data_dir`` is one item.
    Each access loads the image, converts it to RGB, crops a random square
    patch, resizes the patch and returns it as a tensor.

    Args:
        data_dir: Directory that is scanned (non-recursively) for images.
        patch_size: Side length, in pixels, of the random crop. An image whose
            shorter side is below this value is cropped to a square of that
            shorter side instead of raising.
        output_size: Side length the patch is resized to. ``None`` (default)
            falls back to the module-level ``image_size``, looked up on every
            access exactly as before this parameter existed.

    Raises:
        ValueError: If ``patch_size`` is not positive.
    """

    # Lower-case suffixes accepted as images; compared against name.lower().
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.heic')

    def __init__(self, data_dir = 'iitm_data', patch_size = 512, output_size = None):
        if patch_size <= 0:
            raise ValueError(f"patch_size must be positive, got {patch_size}")
        self.data_dir = data_dir
        self.patch_size = patch_size
        self.output_size = output_size
        # Sorted so that an index maps to the same file on every run;
        # os.listdir returns entries in arbitrary order.
        self.image_filenames = sorted(
            os.path.join(data_dir, name)
            for name in os.listdir(data_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        """Return the number of image files found in ``data_dir``."""
        return len(self.image_filenames)

    def _load_rgb(self, img_path):
        """Load ``img_path`` and return it as an RGB PIL image.

        Loading errors are propagated to the caller instead of being turned
        into a ``None`` sample, which the default DataLoader collation
        cannot batch.
        """
        if img_path.lower().endswith('.heic'):
            heif_file = pyheif.read(img_path)
            img = Image.frombytes(
                heif_file.mode,
                heif_file.size,
                heif_file.data,
                "raw",
                heif_file.mode,
                heif_file.stride,
            )
            return img.convert('RGB')

        # Image.open is lazy; convert() forces the decode, so the file handle
        # can be released as soon as the with-block exits.
        with Image.open(img_path) as img:
            # RGBA / grayscale / palette files would otherwise yield tensors
            # with a channel count other than 3.
            return img.convert('RGB')

    def _random_crop_box(self, width, height):
        """Pick a random square crop box inside a ``width`` x ``height`` image.

        Returns:
            ``(left, top, right, bottom)`` as plain ints. The square's side is
            ``patch_size``, clamped to the image's shorter side.
        """
        side = min(self.patch_size, width, height)
        # randint's upper bound is exclusive: the +1 keeps the last valid
        # offset reachable and makes an exact-fit image (zero slack) legal.
        left = int(np.random.randint(0, width - side + 1))
        top = int(np.random.randint(0, height - side + 1))
        return (left, top, left + side, top + side)

    def __getitem__(self, index):
        """Return one randomly cropped, resized RGB patch as a tensor.

        A new crop position is drawn on every call, so the same index yields
        different patches across accesses.
        """
        img = self._load_rgb(self.image_filenames[index])
        patch = img.crop(self._random_crop_box(*img.size))

        # Resolve the target size lazily so that changes to the module-level
        # image_size are still picked up when output_size was not given.
        target = self.output_size if self.output_size is not None else image_size
        transform = transforms.Compose([
            transforms.Resize((target, target)),
            transforms.ToTensor(),
        ])
        return transform(patch)

In [74]:
# Wrap the image folder in a dataset and serve it in shuffled mini-batches.
dataset = IITM_Dataset(data_dir='../iitm_data')
batch_size = 4
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

In [72]:
from diffusers import UNet2DModel

# UNet with six blocks on each side. Every block is a plain ResNet
# down/upsampling block except one per side that adds spatial self-attention:
# the fifth down block and the second up block.
model = UNet2DModel(
    sample_size=image_size,  # target image resolution
    in_channels=3,  # RGB input
    out_channels=3,  # output has the same channel count as the input
    layers_per_block=2,  # ResNet layers per UNet block
    block_out_channels=(128, 128, 256, 256, 512, 512),  # output channels of each UNet block
    down_block_types=("DownBlock2D",) * 4 + ("AttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "AttnUpBlock2D") + ("UpBlock2D",) * 4,
)

In [84]:
import matplotlib.pyplot as plt

In [98]:
# Sanity check: draw a single batch from the loader and display its shape.
next(iter(dataloader)).shape

torch.Size([4, 3, 256, 256])

In [97]:
# Smoke test: run one batch through the model at timestep 0 and display the
# shape of the returned `.sample`.
model(next(iter(dataloader)), timestep=0).sample.shape

torch.Size([4, 3, 256, 256])