In [41]:
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from tqdm import tqdm

#### GPU Usage

In [42]:
# Prefer the GPU whenever CUDA is usable; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


#### Dataset class for inpainting

In [43]:
class ImageInpaintingDataset(Dataset):
    """Dataset of PNG images, each paired with a fixed top-region mask.

    Every item is ``(image, mask)``:

    * ``image`` -- the file read with ``cv2.imread`` (OpenCV's default BGR
      channel order), resized to ``image_size``. When no ``transform`` is
      given it is returned as a float32 array of shape (H, W, 3) scaled to
      [0, 1]; otherwise it is whatever ``transform`` returns for the resized
      uint8 array.
    * ``mask`` -- a uint8 tensor of shape (1, H, W) that is 0 on the top
      ``mask_percentage`` fraction of rows and 1 elsewhere.

    Args:
        data_dir: Directory scanned (non-recursively) for ``*.png`` files.
        mask_percentage: Fraction of the image height, measured from the top,
            that is zeroed in the mask.
        image_size: Target size as ``(height, width)``.
        transform: Optional callable applied to the resized uint8 image.

    Raises:
        OSError: From ``__getitem__`` when an image file cannot be decoded.
    """

    def __init__(self, data_dir="../data/neo_usable", mask_percentage=0.7, image_size=(256, 256), transform=None):
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.image_paths = sorted(
            os.path.join(data_dir, fname)
            for fname in os.listdir(data_dir)
            if fname.endswith('.png')
        )
        self.mask_percentage = mask_percentage
        self.image_size = image_size
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"Could not read image: {image_path}")

        height, width = self.image_size
        # cv2.resize expects dsize as (width, height), whereas the mask below
        # is built as (height, width); swap so both agree for non-square sizes.
        image = cv2.resize(image, (width, height))
        if self.transform:
            image = self.transform(image)
        else:
            # Scale to [0, 1] floats so the image can be mixed with Gaussian
            # noise and compared against a sigmoid-bounded model output.
            image = image.astype(np.float32) / 255.0

        mask = np.ones((height, width), dtype=np.uint8)
        mask_height = int(height * self.mask_percentage)
        mask[:mask_height, :] = 0  # top rows are the region marked with 0

        mask = torch.from_numpy(mask).unsqueeze(0)  # add a channel axis: (1, H, W)
        return image, mask

#### UNet architecture

In [44]:
class UNet(nn.Module):
    """Convolutional encoder-decoder mapping a 3-channel image to a 3-channel image.

    Despite the name, the network has no down/up-sampling and no skip
    connections: every layer is a 3x3, stride-1, padding-1 (transposed)
    convolution, so the spatial size of the input is preserved. The encoder
    widens channels 3 -> 64 -> 128 -> 256 and the decoder narrows them back;
    the final Sigmoid bounds the output to [0, 1].
    """

    def __init__(self):
        super().__init__()
        widths = (3, 64, 128, 256)

        # Encoder: Conv2d + ReLU for each consecutive pair of channel widths.
        enc_layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            enc_layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            enc_layers.append(nn.ReLU())
        self.encoder = nn.Sequential(*enc_layers)

        # Decoder: mirror of the encoder using transposed convolutions; the
        # last activation is a Sigmoid instead of a ReLU.
        reversed_widths = widths[::-1]
        final_stage = len(reversed_widths) - 2
        dec_layers = []
        for stage, (c_in, c_out) in enumerate(zip(reversed_widths[:-1], reversed_widths[1:])):
            dec_layers.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=3, padding=1))
            dec_layers.append(nn.Sigmoid() if stage == final_stage else nn.ReLU())
        self.decoder = nn.Sequential(*dec_layers)

    def forward(self, x):
        """Run ``x`` through the encoder and then the decoder."""
        return self.decoder(self.encoder(x))

#### Forward diffusion (noise addition) and loss function

In [45]:
def forward_diffusion(image, t, noise_schedule):
    """Sample a noised image x_t from the closed-form forward process q(x_t | x_0).

    With per-step variances beta_s given by ``noise_schedule`` and
    alpha_bar_t = prod_{s <= t} (1 - beta_s), this returns

        x_t = sqrt(alpha_bar_t) * image + sqrt(1 - alpha_bar_t) * eps,

    where eps is standard Gaussian noise of the same shape as ``image``.

    Args:
        image: Floating-point tensor; the leading axis is treated as the
            batch axis when ``t`` holds more than one timestep.
        t: Timestep index or indices into ``noise_schedule``: a Python int, a
            0-d tensor, or a tensor holding one index per batch element.
        noise_schedule: 1-D tensor of per-step noise variances (betas).

    Returns:
        Tensor with the same shape and dtype as ``image``.

    Raises:
        TypeError: If ``image`` is not a floating-point tensor.
    """
    if not torch.is_floating_point(image):
        raise TypeError(
            f"forward_diffusion expects a floating-point image, got {image.dtype}"
        )

    # Cumulative signal retention up to each step. Using beta_t alone would
    # only describe a single step x_{t-1} -> x_t, not the marginal given x_0.
    alpha_bar = torch.cumprod(1.0 - noise_schedule, dim=0)

    t = torch.as_tensor(t, dtype=torch.long, device=alpha_bar.device)
    alpha_bar_t = alpha_bar[t].to(device=image.device, dtype=image.dtype)
    # Give the coefficient shape (B, 1, ..., 1) so it scales each sample as a
    # whole; a bare (B,) vector would broadcast against the last image axis.
    alpha_bar_t = alpha_bar_t.reshape(-1, *([1] * (image.dim() - 1)))

    noise = torch.randn_like(image)  # Gaussian noise
    noisy_image = torch.sqrt(alpha_bar_t) * image + torch.sqrt(1.0 - alpha_bar_t) * noise
    return noisy_image

# Loss function for diffusion model
def diffusion_loss(pred_noise, noisy_image):
    """Return the mean squared error between the two tensors.

    This is a plain element-wise MSE averaged over all elements; the argument
    names carry no special meaning inside the function.

    Args:
        pred_noise: Model output tensor.
        noisy_image: Target tensor, broadcast-compatible with ``pred_noise``.

    Returns:
        0-d tensor holding the loss.
    """
    criterion = nn.MSELoss()
    loss = criterion(pred_noise, noisy_image)
    return loss

#### Training

In [46]:
def train_diffusion_model(model, dataset, num_epochs=10, batch_size=16, timesteps=1000, learning_rate=1e-4, num_workers=4):
    """Train ``model`` to reconstruct clean images from noised versions of them.

    For every batch a random timestep is drawn per sample, the batch is noised
    with ``forward_diffusion``, and the model output is regressed (MSE) onto
    the clean batch. The model receives only the noisy image, not the
    timestep. The masks yielded by ``dataset`` are not used here.

    Args:
        model: Network mapping an (N, 3, H, W) batch to a batch of the same shape.
        dataset: Dataset yielding ``(image, mask)`` with ``image`` laid out
            as (H, W, C); uint8 images are rescaled to [0, 1] here.
        num_epochs: Number of passes over the dataset.
        batch_size: Samples per optimisation step.
        timesteps: Length of the linear noise schedule (1e-4 .. 0.02).
        learning_rate: Adam learning rate.
        num_workers: DataLoader worker processes. Use 0 if worker processes
            stall or fail to start, e.g. when the dataset class is defined
            inside the notebook itself.

    Returns:
        List with the average loss of each epoch.

    Raises:
        ValueError: If ``dataset`` is empty.
    """
    if len(dataset) == 0:
        raise ValueError("Cannot train on an empty dataset.")

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=(device.type == "cuda"),  # pinning only helps host -> GPU copies
    )

    model = model.to(device)
    model.train()  # undo any earlier model.eval() before optimising
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    noise_schedule = torch.linspace(1e-4, 0.02, timesteps).to(device)

    epoch_losses = []
    print("Starting training...")
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, (images, _masks) in enumerate(tqdm(dataloader)):
            images = images.to(device).permute(0, 3, 1, 2)  # Convert to NCHW format
            if not torch.is_floating_point(images):
                # Raw uint8 pixels: bring them into the [0, 1] range that the
                # Gaussian noise and the sigmoid-bounded output assume.
                images = images.float() / 255.0

            t = torch.randint(0, timesteps, (images.size(0),), dtype=torch.long, device=device)
            # Shape t as (B, 1, 1, 1) so the per-sample schedule lookup
            # broadcasts over channels/height/width instead of the width axis.
            noisy_images = forward_diffusion(images, t.view(-1, 1, 1, 1), noise_schedule)

            reconstruction = model(noisy_images)
            # The target is the clean batch; comparing the output with its own
            # noisy input would only teach the network the identity map.
            loss = diffusion_loss(reconstruction, images)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Progress tracking
            if i % 10 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], Loss: {loss.item():.4f}")

        average_loss = running_loss / len(dataloader)
        epoch_losses.append(average_loss)
        print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {average_loss:.4f}")

    return epoch_losses

#### Save results

In [47]:
def save_reconstructed_images(model, dataset, output_dir="../data/neo_diffusion", timesteps=1000):
    """Noise every dataset image once, run it through ``model`` and save the output.

    Each image gets a single, randomly drawn timestep; the model output is
    scaled from [0, 1] to 8-bit and written as
    ``<output_dir>/reconstructed_<index>.png``. The masks yielded by
    ``dataset`` are not used. The model is left in eval mode.

    Args:
        model: Network mapping a (1, 3, H, W) batch to an image batch in [0, 1].
        dataset: Dataset yielding ``(image, mask)`` with ``image`` laid out
            as (H, W, C); uint8 images are rescaled to [0, 1] here.
        output_dir: Destination directory, created if missing.
        timesteps: Length of the linear noise schedule (1e-4 .. 0.02); should
            match the value used for training.

    Raises:
        OSError: If an output image cannot be written.
    """
    os.makedirs(output_dir, exist_ok=True)

    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
    noise_schedule = torch.linspace(1e-4, 0.02, timesteps).to(device)

    model = model.to(device)
    model.eval()

    with torch.no_grad():
        for i, (image, _mask) in enumerate(tqdm(dataloader)):
            image = image.to(device).permute(0, 3, 1, 2)  # NHWC -> NCHW
            if not torch.is_floating_point(image):
                # Raw uint8 pixels: bring them into [0, 1] before adding noise.
                image = image.float() / 255.0

            t = torch.randint(0, timesteps, (1,), dtype=torch.long, device=device)
            noisy_image = forward_diffusion(image, t.view(-1, 1, 1, 1), noise_schedule)
            reconstructed_image = model(noisy_image)

            # Back to a single HWC array for OpenCV.
            reconstructed_image = reconstructed_image[0].permute(1, 2, 0).cpu().numpy()
            # Clip before the cast so out-of-range values cannot wrap around.
            reconstructed_image = np.clip(reconstructed_image * 255.0, 0, 255).round().astype(np.uint8)
            reconstructed_image = np.ascontiguousarray(reconstructed_image)

            output_path = os.path.join(output_dir, f"reconstructed_{i}.png")
            # cv2.imwrite reports failure through its return value, not an exception.
            if not cv2.imwrite(output_path, reconstructed_image):
                raise OSError(f"Could not write image: {output_path}")
            print(f"Saved: {output_path}")

#### Load dataset and train the model

In [48]:
if __name__ == "__main__":
    # Build the dataset with its default settings and a freshly initialised network.
    dataset = ImageInpaintingDataset()
    model = UNet()

    # Fit for five epochs, then write one reconstruction per dataset image.
    train_diffusion_model(model=model, dataset=dataset, num_epochs=5)
    save_reconstructed_images(model=model, dataset=dataset)

Starting training...


  0%|          | 0/514 [00:00<?, ?it/s]