# In [None]:
import torch
import torch.nn as nn
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import os
import cv2
from PIL import Image
import pandas as pd

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Diffusion process parameters
T = 100  # number of diffusion steps
# NOTE(review): despite its name, beta_min is the *largest* beta in the schedule
# below -- the linspace runs from 0 up to beta_min.
beta_min = 0.01

# Linear noise schedule: T evenly spaced betas from 0.0 to beta_min, as float32.
betas = torch.from_numpy(np.linspace(0.0, beta_min, num=T)).to(torch.float32)

# Define a simple generator and discriminator network
class Generator(nn.Module):
    """Toy generator: a linear map from a 64-d latent to 784 features, blended with noise by ``t``."""

    def __init__(self):
        super(Generator, self).__init__()
        self.fc = nn.Linear(64, 784)  # Example: Mapping from noise vector to image space

    def forward(self, z, t):
        """Return ``(1 - t) * fc(z) + t * eps`` with ``eps ~ N(0, I)``.

        Args:
            z: latent tensor of shape (B, 64).
            t: blend weight per sample -- a tensor reshapeable to (B, 1) such as
               shape (B,) -- or a scalar (Python number or 0-d tensor) applied
               to the whole batch.

        Returns:
            Tensor of shape (B, 784).
        """
        x0 = self.fc(z)
        # The noise must match the *output* shape (B, 784), not the latent (B, 64).
        noise = torch.randn_like(x0)
        # The output is 2-D, so t is broadcast as a column (B, 1). The previous
        # view(-1, 1, 1, 1) produced 4-D broadcasts whose shapes could not be added.
        t = torch.as_tensor(t, dtype=x0.dtype, device=x0.device).reshape(-1, 1)
        return (1 - t) * x0 + t * noise

# Define a simple discriminator network (example)
class Discriminator(nn.Module):
    """Two-layer MLP mapping 784 input features to a single unnormalized score."""

    def __init__(self):
        super().__init__()
        hidden_units = 128
        # Linear -> ReLU -> Linear; no output activation, so scores are raw logits.
        self.fc = nn.Sequential(
            nn.Linear(in_features=784, out_features=hidden_units),
            nn.ReLU(),
            nn.Linear(in_features=hidden_units, out_features=1),
        )

    def forward(self, x):
        """Return raw scores of shape (..., 1) for ``x`` of shape (..., 784)."""
        scores = self.fc(x)
        return scores

# Define the U-Net architecture for segmentation
class UNet(nn.Module):
    """Minimal encoder/decoder network producing per-pixel class scores.

    NOTE(review): despite the name there are no skip connections -- the decoder
    only sees the encoder output. The encoder halves the spatial size with one
    2x2 max-pool and the decoder doubles it with a stride-2 transposed
    convolution, so even H/W are preserved while odd sizes come back one pixel
    smaller.

    Args:
        in_channels: number of channels in the input image.
        out_channels: number of output score channels (one per class).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        width = 64

        # Encoder: two 3x3 conv + ReLU layers, then a 2x2 max-pool (downsample by 2).
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, width, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(width, width, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Decoder: two 3x3 conv + ReLU layers, then a 2x2 stride-2 transposed
        # conv that upsamples by 2 and maps to out_channels.
        self.decoder = nn.Sequential(
            nn.Conv2d(width, width, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(width, width, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(width, out_channels, kernel_size=2, stride=2),
        )

    def forward(self, x):
        """Encode then decode ``x``; returns (B, out_channels, H', W') scores."""
        return self.decoder(self.encoder(x))

# Initialize the U-Net segmentation model.
num_classes = 10  # example value: set this to the number of classes in your data
in_channels = 3  # input channels (e.g. 3 for RGB images)
out_channels = num_classes  # one output score channel per segmentation class
segmentation_model = UNet(in_channels=in_channels, out_channels=out_channels)

# Loss functions: cross-entropy for segmentation, MSE for the reverse-diffusion term.
segmentation_criterion = nn.CrossEntropyLoss()
reverse_diffusion_criterion = nn.MSELoss()

# Adam optimizer over the segmentation model's parameters.
segmentation_optimizer = torch.optim.Adam(params=segmentation_model.parameters(), lr=1e-4)

# Training loop
# Training parameters
num_epochs = 100  # set to a suitable number of epochs
print_interval = 10  # print the average loss every `print_interval` epochs

# NOTE(review): `data_loader` is not defined in this file; it is assumed to be
# created elsewhere (e.g. an earlier notebook cell) and to yield
# (images, labels) batches with images shaped (B, in_channels, H, W) and labels
# in the form nn.CrossEntropyLoss expects for (B, num_classes, H, W) scores,
# i.e. (B, H, W) class indices -- confirm.

# Feed inputs to whatever device the model's parameters live on.
model_device = next(segmentation_model.parameters()).device
segmentation_model.train()

average_loss = float("nan")  # defined even if no epoch / batch runs
for epoch in range(num_epochs):
    total_loss = 0.0  # sum over batches of each batch's mean loss
    num_batches = 0

    for real_images, labels in data_loader:
        # Keep images 4-D (B, C, H, W): the UNet's Conv2d layers cannot take the
        # flattened (B, C*H*W) tensor the previous version produced.
        real_images = real_images.to(model_device)
        labels = labels.to(model_device)
        batch_loss = 0.0

        for t in range(T):
            noise_std = betas[t].item() ** 0.5

            # Forward diffusion: add Gaussian noise with variance beta_t to the clean image.
            x_t = real_images + torch.randn_like(real_images) * noise_std

            # Segment the noisy image and take one optimizer step on the segmentation loss.
            segmentation_output = segmentation_model(x_t)
            segmentation_loss = segmentation_criterion(segmentation_output, labels)
            segmentation_optimizer.zero_grad()
            segmentation_loss.backward()
            segmentation_optimizer.step()

            # "Reverse diffusion" term: subtract a fresh noise sample and compare
            # with the clean image. No model parameter enters this computation, so
            # it cannot be back-propagated (the previous .backward() call raised a
            # RuntimeError); it is tracked as a metric only.
            with torch.no_grad():
                x_denoised = x_t - torch.randn_like(x_t) * noise_std
                reverse_diffusion_loss = reverse_diffusion_criterion(x_denoised, real_images)

            batch_loss += segmentation_loss.item() + reverse_diffusion_loss.item()

        # Average over all T diffusion steps instead of keeping only the last step.
        total_loss += batch_loss / T
        num_batches += 1

    # Mean loss per batch for this epoch (NaN if the loader yielded no batches).
    average_loss = total_loss / num_batches if num_batches else float("nan")

    # Print the loss at a fixed interval.
    if (epoch + 1) % print_interval == 0:
        print(f"Epoch [{epoch + 1}/{num_epochs}] Loss: {average_loss:.4f}")

# Also report the loss once training has finished.
print(f"Training completed. Final loss: {average_loss:.4f}")

# Inference loop for segmentation
# NOTE(review): `num_samples` is not defined in this file; it must be set
# (e.g. in an earlier cell) before this runs.
segmentation_results = []
inference_device = next(segmentation_model.parameters()).device
segmentation_model.eval()  # switch out of training mode for inference
# no_grad: no autograd graph is kept per sample, and the stored results are
# plain tensors that can be converted to NumPy for plotting.
with torch.no_grad():
    for _ in range(num_samples):
        # Start from an all-zeros image and apply T steps of forward noise.
        x_t = torch.zeros(1, 3, 128, 128, device=inference_device)  # Adjust dimensions and channels as needed
        for t in range(T):
            x_t = x_t + torch.randn_like(x_t) * betas[t].item() ** 0.5
        # Perform segmentation on x_t using the trained segmentation model; the raw
        # per-class scores (batch of 1) are moved to CPU for later visualization.
        segmentation_map = segmentation_model(x_t)
        segmentation_results.append(segmentation_map.cpu())

# Visualize or analyze segmentation results

# Visualization helper
def visualize_segmentation(image, segmentation_map, title="Segmentation"):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

    ax1.imshow(image)
    ax1.set_title("Input Image")

    ax2.imshow(segmentation_map, cmap="viridis")  # 적절한 컬러맵 사용
    ax2.set_title(title)

    plt.show()

# Visualize the results.
# A torch DataLoader is not subscriptable, so index its underlying Dataset;
# fall back to data_loader itself if it is already an indexable dataset.
# NOTE(review): the segmentation maps were computed from generated noise
# images, not from these dataset samples -- the pairing is for display only.
visualization_dataset = getattr(data_loader, "dataset", data_loader)
for i, segmentation_map in enumerate(segmentation_results):
    input_image = visualization_dataset[i][0]  # input image of sample i
    visualize_segmentation(input_image, segmentation_map, title=f"Segmentation Result {i+1}")