<a href="https://colab.research.google.com/github/zrghassabi/Diffusion-Models/blob/main/DiffuionModelimageenhancementObectDetection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim

class SyntheticDataset(Dataset):
    """Dataset of random Gaussian "images" for smoke-testing the training loop.

    Each sample is a float tensor of shape ``(3, img_size, img_size)`` drawn
    from a standard normal distribution. Samples carry no labels.

    Args:
        size: number of samples reported by ``len()``.
        img_size: height and width of each square sample.
        num_classes: stored on the instance for API compatibility; not used
            when generating samples.
        seed: optional base seed. When given, sample ``idx`` is generated from
            ``seed + idx``, so repeated reads of the same index return the same
            tensor. When ``None`` (default), every read draws fresh noise from
            the global torch RNG.
    """

    def __init__(self, size, img_size, num_classes, seed=None):
        self.size = size
        self.img_size = img_size
        self.num_classes = num_classes
        self.seed = seed

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Support negative indices like a normal sequence.
        if idx < 0:
            idx += self.size
        # Raising IndexError is what terminates Python's sequence-iteration
        # protocol (torch Dataset has no __iter__); without it, `for x in ds`
        # would loop forever.
        if not 0 <= idx < self.size:
            raise IndexError(f"index {idx} out of range for dataset of size {self.size}")

        shape = (3, self.img_size, self.img_size)
        if self.seed is None:
            return torch.randn(*shape)
        generator = torch.Generator().manual_seed(self.seed + int(idx))
        return torch.randn(*shape, generator=generator)


# Build 1000 random 64x64 samples and a shuffling loader that yields batches of 32.
dataset = SyntheticDataset(size=1000, img_size=64, num_classes=10)
dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)


In [None]:
import torch
import torch.nn as nn

# Step 1: Define the U-Net Model
class UNet(nn.Module):
    """Encoder/decoder U-Net with skip connections.

    Each encoder stage is two 3x3 conv + ReLU layers followed by 2x2 max
    pooling. Each decoder stage upsamples with a 2x2 transposed conv,
    concatenates the matching encoder output, and applies two more conv +
    ReLU layers. A final 1x1 conv maps to ``out_channels``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        features: channel width of each encoder stage, shallowest first. The
            bottleneck uses ``2 * features[-1]`` channels. Any sequence is
            accepted; the default is a tuple to avoid a shared mutable default.
    """

    def __init__(self, in_channels, out_channels, features=(64, 128, 256, 512)):
        super(UNet, self).__init__()
        features = list(features)
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder
        for feature in features:
            self.encoder.append(self._block(in_channels, feature))
            in_channels = feature

        # Bottleneck
        self.bottleneck = self._block(features[-1], features[-1] * 2)

        # Decoder: alternating [upsample, conv block] pairs, deepest stage first.
        for feature in reversed(features):
            self.decoder.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            # Input = upsampled tensor concatenated with its skip (feature + feature channels).
            self.decoder.append(self._block(feature * 2, feature))

        # Final layer
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Map ``(B, in_channels, H, W)`` to ``(B, out_channels, H, W)``.

        H and W need not be multiples of ``2 ** len(features)`` (as long as
        pooling does not shrink them to zero): when pooling floors an odd size,
        the upsampled tensor is resized to match its skip connection.
        """
        skip_connections = []
        for enc in self.encoder:
            x = enc(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.decoder), 2):
            x = self.decoder[idx](x)
            skip_connection = skip_connections[idx // 2]
            # Channels match by construction; only spatial dims can differ.
            # nn.functional is referenced explicitly because `F` is not imported here.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = nn.functional.interpolate(x, size=skip_connection.shape[2:])
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.decoder[idx + 1](concat_skip)

        return self.final_conv(x)

    def _block(self, in_channels, out_channels):
        """Two 3x3 same-padding convolutions, each followed by ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )


In [None]:
# Step 3: Define the Diffusion Process
class DiffusionModel(nn.Module):
    """Condition a denoising network on the diffusion timestep.

    The timestep is broadcast to a constant extra input channel and
    concatenated after the image channels. The wrapped network must therefore
    accept ``x.size(1) + 1`` input channels (e.g. ``UNet(in_channels=4, ...)``
    for 3-channel images).
    """

    def __init__(self, unet):
        super(DiffusionModel, self).__init__()
        self.unet = unet

    def forward(self, x, t):
        """Run the wrapped network on ``x`` with the timestep channel appended.

        Args:
            x: 4-D batch of shape ``(B, C, H, W)``.
            t: timesteps, either one per sample (``B`` elements) or a single
                element that is broadcast to the whole batch. Integer or float
                tensors are both accepted.

        Returns:
            Whatever the wrapped network returns for the ``(B, C + 1, H, W)`` input.
        """
        batch, _, height, width = x.shape
        # Cast explicitly: t is often an integer tensor (torch.randint) and may
        # live on another device; torch.cat needs matching dtype and device.
        t_map = t.to(device=x.device, dtype=x.dtype).reshape(-1, 1, 1, 1)
        # expand() broadcasts without copying; torch.cat materialises the result.
        t_map = t_map.expand(batch, 1, height, width)
        return self.unet(torch.cat([x, t_map], dim=1))

def forward_diffusion(x_0, t, noise_schedule, noise=None):
    """Sample a noised version of ``x_0`` at timestep ``t`` in closed form.

    Computes ``sqrt(a_t) * x_0 + sqrt(1 - a_t) * noise`` with
    ``a_t = noise_schedule[t]``. The schedule entry is therefore used as the
    cumulative signal fraction (often called alpha-bar), not a per-step beta.

    Args:
        x_0: clean batch whose first dimension is the batch (any rank >= 1).
        t: timestep indices, one per sample; cast to ``long`` for indexing.
        noise_schedule: 1-D tensor indexed by ``t``.
        noise: optional noise shaped like ``x_0``. When ``None`` it is sampled
            from a standard normal via ``torch.randn_like``.

    Returns:
        Tensor shaped like ``x_0``.
    """
    if noise is None:
        noise = torch.randn_like(x_0)
    t = t.long().to(x_0.device)
    # Move the schedule before indexing: indexing a CPU tensor with CUDA
    # indices raises, so index on x_0's device.
    alpha_t = noise_schedule.to(x_0.device)[t]
    # Shape (B, 1, ..., 1) so it broadcasts over every non-batch dimension.
    alpha_t = alpha_t.reshape(-1, *([1] * (x_0.dim() - 1)))
    return torch.sqrt(alpha_t) * x_0 + torch.sqrt(1 - alpha_t) * noise

def reverse_diffusion(x_t, t, model, noise_schedule):
    """Apply one denoising update using the model's noise prediction.

    Computes ``(x_t - (1 - a_t) * eps) / sqrt(a_t)`` where
    ``a_t = noise_schedule[t]`` and ``eps = model(x_t, t)``.

    NOTE(review): this uses one schedule entry for both the per-step and
    cumulative terms, and omits the ``1 / sqrt(1 - alpha_bar_t)`` factor on
    the noise term that the DDPM posterior mean has. It is also not the exact
    inverse of ``forward_diffusion`` (that would use ``sqrt(1 - a_t)``).
    Confirm the intended formula before using this for sampling.

    Args:
        x_t: noised batch whose first dimension is the batch.
        t: timestep indices, one per sample; cast to ``long``.
        model: callable ``model(x_t, t)`` returning a tensor shaped like ``x_t``.
        noise_schedule: 1-D tensor indexed by ``t``.

    Returns:
        Tensor shaped like ``x_t``.
    """
    t = t.long().to(x_t.device)
    # Index on x_t's device (mixed-device indexing raises) and reshape once
    # to (B, 1, ..., 1) so it broadcasts over every non-batch dimension.
    alpha_t = noise_schedule.to(x_t.device)[t]
    alpha_t = alpha_t.reshape(-1, *([1] * (x_t.dim() - 1)))
    beta_t = 1 - alpha_t
    predicted_noise = model(x_t, t)
    return (x_t - beta_t * predicted_noise) / torch.sqrt(alpha_t)



In [None]:
# Function to visualize images
def visualize_images(inputs, targets, predictions, num_images=5):
    """Plot input / target / prediction triplets, one row per sample.

    Args:
        inputs, targets, predictions: batches indexable by sample, where each
            sample is a ``(C, H, W)`` tensor (permuted to ``(H, W, C)`` for
            ``imshow``).
        num_images: maximum number of rows. It is clamped to the smallest
            batch, so short batches do not raise ``IndexError``.

    NOTE(review): matplotlib's ``imshow`` clips float RGB data to [0, 1], so
    unnormalized tensors (e.g. Gaussian noise) display saturated.
    """
    num_images = min(num_images, len(inputs), len(targets), len(predictions))
    if num_images <= 0:
        return
    # squeeze=False keeps `axs` 2-D even for a single row, so axs[i, col] always works.
    _, axs = plt.subplots(num_images, 3, figsize=(12, 4 * num_images), squeeze=False)
    columns = (
        ("Input Image", inputs),
        ("Target Image", targets),
        ("Predicted Image", predictions),
    )
    for i in range(num_images):
        for col, (title, batch) in enumerate(columns):
            ax = axs[i, col]
            # detach() so tensors that still require grad can be converted to NumPy.
            ax.imshow(batch[i].detach().permute(1, 2, 0).cpu().numpy())
            ax.set_title(title)
            ax.axis('off')

    plt.show()

# Updated training loop with visualization
def train_model(model, dataloader, optimizer, num_epochs, noise_schedule, visualize_every=10):
    """Train ``model`` to predict the noise that ``forward_diffusion`` adds.

    For each batch, a random timestep per sample is drawn and ``x_0`` is noised.
    The model's prediction is regressed (MSE) onto the added noise.

    Args:
        model: module called as ``model(x_t, t)``; its parameters determine the device.
        dataloader: yields batches of clean images ``x_0``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``dataloader``.
        noise_schedule: 1-D tensor passed to ``forward_diffusion``; its length
            is the number of timesteps.
        visualize_every: plot a few samples every this many batches; a falsy
            value (0 / None) disables plotting.
    """
    model.train()
    criterion = nn.MSELoss()
    # Derive the device from the model instead of relying on a module-level
    # `device` global that is only defined after this function in the file.
    device = next(model.parameters()).device
    noise_schedule = noise_schedule.to(device)

    for epoch in range(num_epochs):
        for i, x_0 in enumerate(dataloader):
            x_0 = x_0.to(device)
            t = torch.randint(0, len(noise_schedule), (x_0.size(0),)).to(device)
            x_t = forward_diffusion(x_0, t, noise_schedule)
            predicted_noise = model(x_t, t)

            # Recover the noise forward_diffusion added by inverting
            # x_t = sqrt(a) * x_0 + sqrt(1 - a) * noise.
            alpha_t = noise_schedule[t].view(-1, 1, 1, 1)
            noise = (x_t - torch.sqrt(alpha_t) * x_0) / torch.sqrt(1 - alpha_t)

            # Calculate loss
            loss = criterion(predicted_noise, noise)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if visualize_every and i % visualize_every == 0:
                with torch.no_grad():
                    # Show the first few clean inputs, true noise, and predicted noise.
                    sample_inputs = x_0[:5].cpu()
                    sample_targets = noise[:5].cpu()
                    sample_predictions = predicted_noise[:5].detach().cpu()
                    visualize_images(sample_inputs, sample_targets, sample_predictions)

            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{i+1}/{len(dataloader)}], Loss: {loss.item():.4f}')

# Example usage
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
unet = UNet(in_channels=4, out_channels=3).to(device)  # in_channels is 4 because we concatenate t
diffusion_model = DiffusionModel(unet).to(device)
optimizer = optim.Adam(diffusion_model.parameters(), lr=1e-4)

# forward_diffusion uses noise_schedule[t] as the cumulative signal fraction
# alpha_bar_t in x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.
# linspace(1e-4, 0.02, 1000) is the standard linear *beta* schedule, so convert
# it to alpha_bar_t = prod_{s<=t} (1 - beta_s). Passing the betas directly
# would give alpha_bar <= 0.02 at every t, i.e. x_t almost pure noise.
betas = torch.linspace(0.0001, 0.02, 1000, device=device)
noise_schedule = torch.cumprod(1.0 - betas, dim=0)

# Define the synthetic dataset and dataloader
dataset = SyntheticDataset(size=1000, img_size=64, num_classes=10)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Train the model
train_model(diffusion_model, dataloader, optimizer, num_epochs=10, noise_schedule=noise_schedule)

In [None]:
# Save the trained weights (the state_dict only, not the module object).
torch.save(diffusion_model.state_dict(), 'diffusion_model.pth')

# Rebuild the same architecture and load the weights back as a round-trip check.
loaded_unet = UNet(in_channels=4, out_channels=3).to(device)
loaded_diffusion_model = DiffusionModel(loaded_unet).to(device)
# map_location keeps the load working if the checkpoint was written on a
# different device (e.g. saved from CUDA, reloaded on a CPU-only runtime).
loaded_diffusion_model.load_state_dict(torch.load('diffusion_model.pth', map_location=device))

In [None]:
# Check the current working directory (IPython `!` shell escape; notebook-only syntax)
!pwd

In [None]:
# List the files in the current directory (long format) to verify that
# diffusion_model.pth was written (IPython `!` shell escape; notebook-only syntax)
!ls -l

In [None]:
import os
print(os.listdir())  # os.listdir() with no argument lists the current directory

In [None]:
# Rebuild the UNet and DiffusionModel on the CPU (no .to(device) here).
unet = UNet(in_channels=4, out_channels=3)
diffusion_model = DiffusionModel(unet)

# Load the saved weights. map_location='cpu' is required when the checkpoint
# was saved from CUDA tensors and this runtime has no GPU.
diffusion_model.load_state_dict(torch.load('diffusion_model.pth', map_location='cpu'))

# Set the model to evaluation mode
diffusion_model.eval()

In [None]:
# Persist just the inner UNet's weights (the diffusion wrapper adds no parameters of its own).
torch.save(obj=unet.state_dict(), f='unet_model.pth')

In [None]:
# Rebuild the UNet on the CPU.
unet = UNet(in_channels=4, out_channels=3)

# Load the UNet weights; map_location='cpu' allows loading a checkpoint that
# was saved from CUDA tensors on a CPU-only runtime.
unet.load_state_dict(torch.load('unet_model.pth', map_location='cpu'))

# Set the model to evaluation mode
unet.eval()

In [None]:
import torchvision.transforms as transforms
from torchvision import models

# Load an ImageNet-pretrained ResNet-18 (the `weights=` API replaces the
# deprecated `pretrained=True`). NOTE: this is an image classifier that
# returns one logit per class, not an object detector.
detection_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Number of output classes you need. The pretrained head already has 1000
# ImageNet outputs; replacing it with a fresh nn.Linear would discard its
# trained weights and make predictions random. Only swap the head when the
# class count actually differs (the new head must then be fine-tuned).
num_detection_classes = 1000
if num_detection_classes != detection_model.fc.out_features:
    detection_model.fc = nn.Linear(detection_model.fc.in_features, num_detection_classes)
detection_model.to(device)
detection_model.eval()

# ImageNet-style preprocessing at 224x224. Note that a later cell rebinds
# `transform` to a 64x64 pipeline before detect_objects uses it.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


In [None]:
from google.colab import files
from PIL import Image
import numpy as np
import torch
import torchvision.transforms as transforms

# Upload an image file through the Colab file picker.
uploaded = files.upload()
if not uploaded:
    raise RuntimeError("No file was uploaded; re-run this cell and choose an image.")
image_path = next(iter(uploaded))  # first uploaded filename

# Preprocessing for the diffusion model. This rebinds `transform`, replacing
# the earlier 224x224 pipeline; detect_objects below uses this one too.
transform = transforms.Compose([
    transforms.Resize((64, 64)),  # Adjust based on your model's input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the diffusion model and load the trained weights. map_location lets
# a checkpoint saved on another device (e.g. GPU) load on this runtime.
model = UNet(in_channels=4, out_channels=3).to(device)
diffusion_model = DiffusionModel(model).to(device)
diffusion_model.load_state_dict(torch.load('diffusion_model.pth', map_location=device))
diffusion_model.eval()

# detection_model is expected to already exist (created in an earlier cell).

def enhance_image(image_path, t=0):
    """Run the diffusion model once on an image file and return the output as a PIL image.

    The image is loaded as RGB and preprocessed with the module-level
    ``transform``. It is passed through ``diffusion_model`` at timestep ``t``,
    and the 3-channel output is converted back to an 8-bit RGB image.

    NOTE(review): train_model trains the model to predict the *noise* added by
    forward_diffusion, so this raw output is a noise estimate rather than a
    denoised image. Confirm this is the intended "enhancement".

    Args:
        image_path: path to an image file readable by PIL.
        t: timestep value fed to the model (as a float tensor).

    Returns:
        ``PIL.Image.Image`` in RGB mode.
    """
    # Context manager closes the underlying file; convert() returns a new, loaded image.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    image = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        t_tensor = torch.tensor([t], dtype=torch.float32).to(device)
        enhanced_image = diffusion_model(image, t_tensor)

    # (1, 3, H, W) -> (H, W, 3). squeeze(0) drops only the batch axis.
    enhanced_image = enhanced_image.squeeze(0).cpu().numpy().transpose(1, 2, 0)
    # The output is unbounded, so clip to [0, 1] before the uint8 cast.
    # Out-of-range floats would otherwise not saturate and instead produce
    # garbage pixels (typically wrap-around).
    enhanced_image = (np.clip(enhanced_image, 0.0, 1.0) * 255).astype(np.uint8)
    return Image.fromarray(enhanced_image)

def detect_objects(image):
    """Run ``detection_model`` on a PIL image and return its raw output tensor.

    The image is preprocessed with whatever module-level ``transform`` is bound
    at call time, given a batch axis, and moved to ``device``. Inference runs
    without gradient tracking. The caller is responsible for interpreting the
    output (e.g. taking an argmax over classes).
    """
    batch = transform(image).unsqueeze(0)
    batch = batch.to(device)

    with torch.no_grad():
        logits = detection_model(batch)
    return logits

# Enhance the uploaded image and display it.
# NOTE(review): PIL's Image.show() hands the image to an external viewer; in a
# hosted notebook it may not render inline (IPython.display.display does).
enhanced_image = enhance_image(image_path)
enhanced_image.show()

# Classify the enhanced image.
outputs = detect_objects(enhanced_image)

# Report the top-1 class index (outputs has one row of scores per image;
# there is a single image here, so .item() is valid).
preds = outputs.argmax(dim=1)
print('Predicted class:', preds.item())


In [None]:
# Re-run enhancement on the uploaded image, write the result to disk, then display it.
enhanced_image_path = "enhanced_image.jpg"
enhanced_image = enhance_image(image_path)
enhanced_image.save(enhanced_image_path)
enhanced_image.show()


In [None]:
from google.colab import files
# Send the saved enhanced image to the browser as a download (google.colab API, Colab-only).
files.download(enhanced_image_path)

In [None]:
import os
print(os.listdir(path='.'))  # Show every entry in the working directory