In [None]:
import torch
from matplotlib import pyplot as plt
from the_well.data import WellDataset
from torch.utils.data import DataLoader
import numpy as np

from dataclasses import dataclass
from datasets import load_dataset

from torchvision import transforms
import torch

from accelerate import Accelerator
from huggingface_hub import create_repo, upload_folder
from tqdm.auto import tqdm
from pathlib import Path
import os
import torch.nn.functional as F
from matplotlib import pyplot as plt
from accelerate import notebook_launcher
from diffusers.optimization import get_cosine_schedule_with_warmup
from diffusers import UNet2DModel
from diffusers import DDPMScheduler
from diffusers import DDPMPipeline
from diffusers.utils import make_image_grid
import os
from PIL import Image
import glob
import numpy as np
from torchvision.transforms import InterpolationMode

In [None]:
@dataclass
class TrainingConfig:
    """Hyper-parameters and run settings shared by the notebook cells.

    Every attribute carries a type annotation so that ``@dataclass`` turns it
    into a real field: without annotations the decorator generates an
    ``__init__`` with no parameters and values cannot be overridden with
    ``TrainingConfig(num_epochs=...)``.
    """

    image_size: int = 128  # passed to the UNet as `sample_size`
    train_batch_size: int = 16
    eval_batch_size: int = 16  # number of images sampled by `evaluate`
    num_epochs: int = 1000
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    save_image_epochs: int = 10  # sample/plot every N epochs
    save_model_epochs: int = 10  # save the pipeline every N epochs
    mixed_precision: str = "fp16"  # forwarded to `Accelerator(mixed_precision=...)`
    output_dir: str = "ddpm-active-matter-128"  # samples, logs and weights go here

    push_to_hub: bool = False
    overwrite_output_dir: bool = True
    # Read by the resume cell; it was previously undefined, which made that
    # cell raise AttributeError.
    resume_from_checkpoint: bool = False
    seed: int = 42
    # Evaluated once, when the class body runs.
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



config = TrainingConfig()


In [None]:
# NOTE(review): `dataset` is only assigned in the next cell, so on a fresh
# kernel (Restart & Run All) this cell raises NameError.
# NOTE(review): every other cell reads samples through 'input_fields' /
# 'constant_scalars'; presumably the 'X' key belongs to a different dataset —
# confirm this cell is still needed.
plt.plot(dataset[1]['X'])
plt.show()


In [None]:
# Training split of the `active_matter` dataset, read from a path relative to
# the notebook's working directory.
dataset = WellDataset(
    well_split_name="train",
    well_dataset_name="active_matter",
    well_base_path="./../the_well/datasets/",
)

# 2x2 average pooling with stride 2 (halves both spatial dimensions); reused
# by the plotting and training cells.
avg_pool = torch.nn.AvgPool2d(stride=2, kernel_size=2)

# Shuffled mini-batches sized by the training configuration.
train_dataloader = DataLoader(
    dataset=dataset,
    shuffle=True,
    batch_size=config.train_batch_size,
)


In [None]:
# Display the number of samples in the training split.
len(dataset)

In [None]:

# Draw four random training samples and preview channel 0 of each one after
# the same 2x2 average pooling that the training loop applies.
random_samples = [dataset[idx] for idx in torch.randint(len(dataset), (4,))]

fig, axs = plt.subplots(1, 4, figsize=(16, 4))
for ax, sample in zip(axs, random_samples):
    # Assumes 'input_fields' is (1, H, W, C) — TODO confirm. After the permute
    # the channel axis leads, and channel 0 keeps a leading axis for pooling.
    fields = sample['input_fields'].squeeze(0).permute(2, 0, 1)
    image = avg_pool(fields[0, :, :].unsqueeze(0))

    im = ax.imshow(image.squeeze(), cmap='viridis')
    ax.set_axis_off()
    fig.colorbar(im, ax=ax)

plt.tight_layout()
plt.show()

In [None]:
from diffusers import UNet2DModel
from diffusers import UNet2DConditionModel

# Single-channel UNet with six resolution levels. One attention block sits at
# the fifth down level and at the matching (second) up level; every other
# level is a plain down/up block.
model = UNet2DModel(
    sample_size=config.image_size,
    in_channels=1,
    out_channels=1,
    layers_per_block=2,
    block_out_channels=(128, 128, 256, 256, 512, 512),
    down_block_types=("DownBlock2D",) * 4 + ("AttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "AttnUpBlock2D") + ("UpBlock2D",) * 4,
)

# 1000-step DDPM noise schedule without sample clipping.
noise_scheduler = DDPMScheduler(clip_sample=False, num_train_timesteps=1000)

optimizer = torch.optim.AdamW(params=model.parameters(), lr=config.learning_rate)

# Cosine decay with warm-up, spanning one scheduler step per training batch.
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    config.lr_warmup_steps,
    len(train_dataloader) * config.num_epochs,
)



In [None]:
# Optional resume step. `getattr` keeps this cell a no-op when the config has
# no `resume_from_checkpoint` attribute instead of raising AttributeError
# during Restart & Run All.
# NOTE(review): when enabled, this still needs `epoch_start` (assigned in a
# later cell) and a module-level `accelerator`, which no visible cell creates;
# `train_loop` builds its own local Accelerator. Confirm where resuming is
# meant to happen.
# NOTE(review): `train_loop` saves its state under
# `<output_dir>/samples/checkpoint-...`, not under the path built here.
if getattr(config, "resume_from_checkpoint", False):
    checkpoint_dir = os.path.join(config.output_dir, f"checkpoint-{epoch_start}")
    accelerator.load_state(checkpoint_dir)

In [None]:
# Module-level list that `train_loop` appends to in place (one loss value
# every 10th batch) and plots when it saves sample images.
loss_history = []
# First epoch index; `train_loop` iterates
# range(epoch_start, epoch_start + config.num_epochs).
epoch_start = 0


def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
    """Train `model` to predict the noise added by `noise_scheduler`.

    Relies on module-level names defined in other cells: `loss_history`
    (appended to in place), `epoch_start`, `avg_pool` and `evaluate`.

    Args:
        config: settings object; reads `mixed_precision`,
            `gradient_accumulation_steps`, `output_dir`, `num_epochs`,
            `save_image_epochs` and `save_model_epochs`.
        model: network called as `model(noisy_images, timesteps, return_dict=False)`.
        noise_scheduler: provides `config.num_train_timesteps` and `add_noise`.
        optimizer: optimizer over `model`'s parameters.
        train_dataloader: yields dict batches with an 'input_fields' entry.
        lr_scheduler: stepped once per batch, right after the optimizer.

    Returns:
        The module-level `loss_history` list (previously the function returned
        None, so existing callers are unaffected).
    """
    import json  # local import: only needed to persist the loss history

    # Initialize accelerator and tensorboard logging
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=os.path.join(config.output_dir, "logs"),
    )
    if accelerator.is_main_process:
        os.makedirs(config.output_dir, exist_ok=True)
        accelerator.init_trackers("train_example")

    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    global_step = 0
    # The epoch range is offset by `epoch_start`, so the final epoch index
    # must be offset too; comparing against `config.num_epochs - 1` alone
    # never matches when epoch_start > 0.
    last_epoch = epoch_start + config.num_epochs - 1

    for epoch in range(epoch_start, epoch_start + config.num_epochs):
        # `evaluate` switches the UNet to eval mode, so re-enter training mode
        # at the start of every epoch.
        model.train()
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process, desc=f"Epoch {epoch}")

        for step, batch in enumerate(train_dataloader):
            # Keep channel 0 of the fields as a one-channel image batch and
            # halve its resolution. Assumes 'input_fields' is (B, 1, H, W, C)
            # — TODO confirm.
            images = avg_pool(batch['input_fields'].squeeze(1).permute(0, 3, 1, 2)[:, 0, :, :].unsqueeze(1))

            noise = torch.randn(images.shape, device=images.device)
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (images.shape[0],), device=images.device, dtype=torch.int64)
            noisy_images = noise_scheduler.add_noise(images, noise, timesteps)

            with accelerator.accumulate(model):
                noise_pred = model(noisy_images, timesteps, return_dict=False)[0]

                loss = F.mse_loss(noise_pred, noise)

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Read the scalar once instead of calling .item() repeatedly.
            loss_value = loss.detach().item()

            progress_bar.update(1)
            logs = {"loss": loss_value, "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

            # Save loss value every 10 batches
            if step % 10 == 0:
                loss_history.append(loss_value)

        progress_bar.close()

        if accelerator.is_main_process:
            pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)

            if (epoch + 1) % config.save_image_epochs == 0 or epoch == last_epoch:
                # `evaluate` also creates the "samples" directory used below.
                evaluate(config, epoch, pipeline)

                # Plot and save loss graph
                plt.figure(figsize=(10, 5))
                plt.plot(loss_history)
                plt.yscale('log')  # Set y-axis to logarithmic scale
                plt.title('Training Loss')
                plt.xlabel('Steps (x10)')
                plt.ylabel('Loss (log scale)')
                plt.savefig(os.path.join(config.output_dir, "samples", "loss_graph.png"))
                plt.close()

                # Persist the history so the validation cell can reload it
                # from `<output_dir>/loss_history.json`.
                with open(os.path.join(config.output_dir, "loss_history.json"), "w") as f:
                    json.dump(loss_history, f)

                # f-string: the original literal "checkpoint-{epoch}" wrote
                # every checkpoint into the same directory.
                save_dir = os.path.join(config.output_dir, "samples", f"checkpoint-{epoch}")
                accelerator.save_state(save_dir)

            if (epoch + 1) % config.save_model_epochs == 0 or epoch == last_epoch:
                pipeline.save_pretrained(config.output_dir)

    accelerator.end_training()
    return loss_history

def evaluate(config, epoch, pipeline):
    """Sample a batch from `pipeline` and save it as an image grid.

    The figure is written to `<config.output_dir>/samples/<epoch:04d>.png`.

    Args:
        config: reads `eval_batch_size`, `seed` and `output_dir`.
        epoch: integer used for the output file name.
        pipeline: callable pipeline exposing `.unet`; its result must have an
            `.images` attribute.

    Returns:
        The `.images` object returned by the pipeline call.
    """
    unet = pipeline.unet
    was_training = unet.training
    unet.eval()
    try:
        # Generate images with a fixed CPU generator so that successive
        # evaluations are comparable.
        # NOTE(review): DDPMPipeline documents "pil" and "np" as output types;
        # presumably "tensor" falls through to NumPy output — confirm.
        images = pipeline(
            batch_size=config.eval_batch_size,
            generator=torch.Generator(device='cpu').manual_seed(config.seed),
            output_type="tensor"
        ).images
    finally:
        # Restore the caller's train/eval mode; the original left the UNet in
        # eval mode for the remainder of training.
        unet.train(was_training)

    # Size the grid from the number of images instead of a fixed 4x4 layout,
    # which raised IndexError when eval_batch_size exceeded 16.
    cols = 4
    rows = max(1, (len(images) + cols - 1) // cols)
    fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows), squeeze=False)

    # Hide every axis up front so unused cells of the grid stay blank.
    for ax in axes.flat:
        ax.axis('off')

    for ax, img in zip(axes.flat, images):
        if isinstance(img, torch.Tensor):
            img = img.detach().cpu().numpy()
        # Drop singleton channel axes regardless of the container type.
        im = ax.imshow(np.squeeze(np.asarray(img)))
        fig.colorbar(im, ax=ax)

    fig.tight_layout()

    test_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(test_dir, exist_ok=True)

    fig.savefig(f"{test_dir}/{(epoch):04d}.png", dpi=150, bbox_inches='tight')
    plt.close(fig)
    return images


In [None]:
device = config.device
args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)
# Do not assign the launcher's result to `loss_history`: it is not the
# training history, and the assignment replaced the module-level list that
# `train_loop` fills in place.
notebook_launcher(train_loop, args, num_processes=1)


# VALIDATION

In [None]:
from diffusers import DiffusionPipeline
import json
import torch
import time
import importlib
import os
import sys

# Make the parent directory importable before either custom scheduler module
# is imported. Appending goes to the end of sys.path, so modules found in the
# current directory still take precedence.
sys.path.append(os.path.abspath('..'))

# Reload so that edits to the scheduler modules are picked up without
# restarting the kernel.
import custom_scheduler_active_matter
importlib.reload(custom_scheduler_active_matter)

import custom_scheduler_base
importlib.reload(custom_scheduler_base)

# Load the trained pipeline from the directory training saved it to, and place
# it on the configured device rather than a hard-coded "cuda".
pipe = DiffusionPipeline.from_pretrained(config.output_dir, use_safetensors=True).to(config.device)

# Load loss history from file; tolerate a missing file so the validation
# cells can run against a model whose history was never exported.
loss_history_path = os.path.join(config.output_dir, 'loss_history.json')
if os.path.exists(loss_history_path):
    with open(loss_history_path, 'r') as f:
        loss_history = json.load(f)
else:
    print(f"No loss history found at {loss_history_path}")
    loss_history = []


def sample_arbitrary(config, pipeline, num_inference_steps=1000):
    """Run the reverse diffusion loop by hand with the pipeline's current scheduler.

    Starts from Gaussian noise and, for each scheduler timestep, predicts the
    noise with `pipeline.unet` and applies `pipeline.scheduler.step`.

    Args:
        config: reads `device`; `eval_batch_size` and `image_size` are used
            when present, otherwise the former hard-coded 16 and 128 apply.
        pipeline: object exposing `.unet` and `.scheduler`.
        num_inference_steps: forwarded to `scheduler.set_timesteps`.

    Returns:
        CPU tensor of shape (batch, 1, image_size, image_size).
    """
    batch_size = getattr(config, "eval_batch_size", 16)
    image_size = getattr(config, "image_size", 128)
    images = torch.randn((batch_size, 1, image_size, image_size), device=config.device)

    pipeline.scheduler.set_timesteps(num_inference_steps)

    unet = pipeline.unet
    was_training = unet.training
    unet.eval()
    unet.to(config.device)

    progress_bar = tqdm(total=num_inference_steps, desc="yoyoyo")
    try:
        for t in pipeline.scheduler.timesteps:
            # One-element timestep tensor; the UNet call below receives it
            # together with the full image batch.
            timestep = torch.full((1,), t, device=config.device, dtype=torch.long)
            with torch.no_grad():
                noise_pred = unet(images, timestep).sample

            images = pipeline.scheduler.step(noise_pred, t, images).prev_sample
            progress_bar.update(1)
    finally:
        progress_bar.close()
        # Restore whatever mode the UNet was in, even on failure; the original
        # unconditionally switched it to training mode.
        unet.train(was_training)
    return images.detach().cpu()

In [None]:

# Both schedulers are built with identical keyword arguments, so define them once.
scheduler_kwargs = dict(clip_sample=False, sample_max_value=2.0, clip_sample_range=2.0)

# Samples produced with the active-matter custom scheduler.
pipe.scheduler = custom_scheduler_active_matter.CustomScheduler(**scheduler_kwargs)
sample_conditional = sample_arbitrary(config, pipe)

# Samples produced with the base custom scheduler.
pipe.scheduler = custom_scheduler_base.CustomSchedulerBase(**scheduler_kwargs)
sample_unconditional = sample_arbitrary(config, pipe)


In [None]:
# Take 4 samples from each type
samples_conditional = sample_conditional[:4].detach().clone()
samples_unconditional = sample_unconditional[:4].detach().clone()
random_samples = [dataset[i] for i in torch.randint(len(dataset), (4,))]

fig, axs = plt.subplots(3, 4, figsize=(20, 12))

for row, (samples, title) in enumerate([
    (samples_conditional, "Conditional Samples"),
    (samples_unconditional, "Unconditional Samples"),
    (random_samples, "Dataset Samples")
]):
    # Add text to the left of the plots
    fig.text(0.05, 0.85 - row*0.33, title, rotation=90, verticalalignment='center', fontsize=15)

    for i, sample in enumerate(samples):
        if row < 2:  # For conditional and unconditional samples
            image_np = sample.detach().cpu().numpy().squeeze()
        else:  # For dataset samples
            image = sample['input_fields'].squeeze(0).permute(2, 0, 1)[0, :, :].unsqueeze(0)
            # Pool the real field exactly as the training, preview and FID
            # cells do, so it is shown at the resolution the model generates.
            image = avg_pool(image)
            image_np = image.detach().cpu().numpy().squeeze()

        im = axs[row, i].imshow(image_np, cmap='viridis')
        axs[row, i].set_axis_off()
        fig.colorbar(im, ax=axs[row, i], fraction=0.046, pad=0.04)

plt.tight_layout()
plt.subplots_adjust(left=0.07)  # Adjust left margin to make room for text
plt.show()

In [None]:
(10.0, 1.0, -5.0): 400 occurrences
(10.0, 1.0, -4.0): 320 occurrences
(10.0, 1.0, -3.0): 240 occurrences
(10.0, 1.0, -2.0): 320 occurrences
(10.0, 1.0, -1.0): 240 occurrences
(10.0, 3.0, -5.0): 400 occurrences
(10.0, 3.0, -4.0): 320 occurrences
(10.0, 3.0, -3.0): 240 occurrences
(10.0, 3.0, -2.0): 320 occurrences
(10.0, 3.0, -1.0): 400 occurrences
(10.0, 5.0, -5.0): 320 occurrences
(10.0, 5.0, -4.0): 320 occurrences
(10.0, 5.0, -3.0): 240 occurrences
(10.0, 5.0, -2.0): 320 occurrences
(10.0, 5.0, -1.0): 240 occurrences
(10.0, 7.0, -5.0): 320 occurrences
(10.0, 7.0, -4.0): 400 occurrences
(10.0, 7.0, -3.0): 240 occurrences
(10.0, 7.0, -2.0): 240 occurrences
(10.0, 7.0, -1.0): 400 occurrences
(10.0, 9.0, -5.0): 400 occurrences
(10.0, 9.0, -4.0): 320 occurrences
(10.0, 9.0, -3.0): 320 occurrences
(10.0, 9.0, -2.0): 400 occurrences
(10.0, 9.0, -1.0): 240 occurrences
(10.0, 11.0, -5.0): 400 occurrences
(10.0, 11.0, -4.0): 80 occurrences
(10.0, 11.0, -3.0): 240 occurrences
(10.0, 11.0, -2.0): 240 occurrences
(10.0, 11.0, -1.0): 320 occurrences
(10.0, 13.0, -5.0): 320 occurrences
(10.0, 13.0, -4.0): 400 occurrences
(10.0, 13.0, -3.0): 240 occurrences
(10.0, 13.0, -2.0): 320 occurrences
(10.0, 13.0, -1.0): 320 occurrences
(10.0, 15.0, -5.0): 160 occurrences
(10.0, 15.0, -4.0): 320 occurrences
(10.0, 15.0, -3.0): 400 occurrences
(10.0, 15.0, -2.0): 320 occurrences
(10.0, 15.0, -1.0): 320 occurrences
(10.0, 17.0, -5.0): 400 occurrences
(10.0, 17.0, -4.0): 240 occurrences
(10.0, 17.0, -3.0): 240 occurrences
(10.0, 17.0, -2.0): 400 occurrences
(10.0, 17.0, -1.0): 400 occurrences

In [None]:
# 2x2 average pooling with stride 2, matching the pooling used for training.
avg_pool = torch.nn.AvgPool2d(stride=2, kernel_size=2)

# Test split of the `active_matter` dataset.
dataset_test = WellDataset(
    well_split_name="test",
    well_dataset_name="active_matter",
    well_base_path="./../the_well/datasets/",
)

# Shuffled batches; note that the batch size is taken from `train_batch_size`.
test_dataloader = DataLoader(
    dataset=dataset_test,
    shuffle=True,
    batch_size=config.train_batch_size,
)

"""
train_dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)
#test_dataloader = DataLoader(dataset_test, batch_size=config.train_batch_size, shuffle=True)
avg_pool = torch.nn.AvgPool2d(kernel_size=2, stride=2)

batch = next(iter(train_dataloader))
images = avg_pool(batch['input_fields'].squeeze(1).permute(0, 3, 1, 2)[:, 0, :, :].unsqueeze(1)) # (16, 1, 128, 128)
target = batch['constant_scalars'][:, 1].to(torch.int64)
"""

# CNN MODEL FOR FID

In [None]:

# Generate samples from both models and save them
import os

# Scratch directory for the saved sample tensors (no error if it already exists).
os.makedirs("temp", exist_ok=True)

# Both schedulers take the same keyword arguments.
scheduler_options = dict(clip_sample=False, sample_max_value=2.0, clip_sample_range=2.0)

# Sample with the active-matter custom scheduler and store the result.
pipe.scheduler = custom_scheduler_active_matter.CustomScheduler(**scheduler_options)
conditional_samples = sample_arbitrary(config, pipe)
torch.save(conditional_samples, "temp/conditional_samples.pt")

# Sample with the base custom scheduler and store the result.
pipe.scheduler = custom_scheduler_base.CustomSchedulerBase(**scheduler_options)
unconditional_samples = sample_arbitrary(config, pipe)
torch.save(unconditional_samples, "temp/unconditional_samples.pt")

In [None]:
import torch
import torchvision.transforms as transforms
from torchmetrics.image.fid import FrechetInceptionDistance
from the_well.data import WellDataset
from torch.utils.data import DataLoader
# Rebuild the test split and its loader so this cell does not depend on the
# earlier test-data cell having been run.
avg_pool = torch.nn.AvgPool2d(stride=2, kernel_size=2)
dataset_test = WellDataset(
    well_split_name="test",
    well_dataset_name="active_matter",
    well_base_path="./../the_well/datasets/",
)
test_dataloader = DataLoader(
    dataset=dataset_test,
    shuffle=True,
    batch_size=config.train_batch_size,
)

# Generated samples written by the previous cell
# (presumably (16, 1, 128, 128) each — TODO confirm).
conditional_samples = torch.load("temp/conditional_samples.pt")
unconditional_samples = torch.load("temp/unconditional_samples.pt")

# One shuffled batch of real fields: keep channel 0 as a one-channel image and
# pool it exactly as in training.
real_batch = next(iter(test_dataloader))
real_fields = real_batch['input_fields'].squeeze(1).permute(0, 3, 1, 2)
real_samples = avg_pool(real_fields[:, 0, :, :].unsqueeze(1))

# Per-image transform: resize to 299x299, then repeat the single channel three
# times. NOTE(review): nothing in the visible cells uses `transform`;
# `prepare_images` performs the same two steps itself.
transform = transforms.Compose([
    transforms.Resize(size=(299, 299)),
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
])

# Apply transformation
def prepare_images(batch):
    """Resize each image of `batch` to 299x299 and tile it to three channels.

    Args:
        batch: iterable of single-channel image tensors, e.g. a
            (B, 1, H, W) tensor iterated along its first axis.

    Returns:
        Tensor of shape (B, 3, 299, 299) stacking the processed images.
    """
    target_size = (299, 299)
    return torch.stack([
        transforms.functional.resize(single, target_size).repeat(3, 1, 1)
        for single in batch
    ])


real_prepared = prepare_images(real_samples)
cond_prepared = prepare_images(conditional_samples)
uncond_prepared = prepare_images(unconditional_samples)

# FrechetInceptionDistance(normalize=True) expects float images in [0, 1].
# The fields are not guaranteed to lie in that range, so map all three sets
# with one affine transform taken from the real data's min/max and clamp
# anything that falls outside. Using the same mapping keeps both FID values
# comparable.
value_min = real_prepared.min()
value_span = (real_prepared.max() - value_min).clamp_min(1e-12)  # guard constant input
real_prepared, cond_prepared, uncond_prepared = (
    ((images.float() - value_min) / value_span).clamp(0.0, 1.0)
    for images in (real_prepared, cond_prepared, uncond_prepared)
)

# Initialize FID metric
# NOTE(review): each set holds a single batch of images; FID estimated from so
# few samples is very noisy — treat the numbers as indicative only.
fid = FrechetInceptionDistance(normalize=True)

# Compute FID for conditional
fid.update(real_prepared, real=True)
fid.update(cond_prepared, real=False)
fid_conditional = fid.compute()
fid.reset()

# Compute FID for unconditional
fid.update(real_prepared, real=True)
fid.update(uncond_prepared, real=False)
fid_unconditional = fid.compute()

print("FID (Conditional):", fid_conditional.item())
print("FID (Unconditional):", fid_unconditional.item())


In [None]:
import torch

def compute_snr(real, generated):
    """Mean signal-to-noise ratio, in decibels, of `generated` relative to `real`.

    For every batch element the signal power is the mean of `real ** 2` and
    the noise power is the mean of `(real - generated) ** 2`, both taken over
    dimensions 1-3. The per-element ratio is converted to dB and averaged.

    Args:
        real: reference tensor of shape (B, C, H, W).
        generated: tensor with the same shape as `real`.

    Returns:
        Python float with the batch-averaged SNR in dB.
    """
    eps = 1e-8  # keeps the ratio finite when the two inputs are identical
    reduce_dims = [1, 2, 3]

    reference = real.to(torch.float32)
    candidate = generated.to(torch.float32)

    signal = reference.pow(2).mean(dim=reduce_dims)
    residual = (reference - candidate).pow(2).mean(dim=reduce_dims)

    snr_db = 10 * torch.log10(signal / (residual + eps))
    return snr_db.mean().item()

# Assume real_samples and generated_samples are (B, 1, 128, 128)
# NOTE(review): compute_snr subtracts the tensors element-wise, but
# real_samples is one shuffled test batch and the generated tensors are
# independent draws from noise, so image i of each set is not a matched pair.
# Presumably this is meant as a rough scale comparison only — confirm.
snr_conditional = compute_snr(real_samples, conditional_samples)
snr_unconditional = compute_snr(real_samples, unconditional_samples)

print("SNR (Conditional):", snr_conditional)
print("SNR (Unconditional):", snr_unconditional)

In [None]:
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure

# Initialize SSIM metric
# NOTE(review): no `data_range` is passed here, so the metric does not assume
# images in [0, 1]; presumably torchmetrics then derives the range from each
# pair of inputs, which would make the two values below use different ranges
# — confirm against the torchmetrics documentation.
ssim_metric = StructuralSimilarityIndexMeasure()

# Compute SSIM for conditional and unconditional samples
# (the generated and real images are compared index-by-index, although they
# are not matched pairs).
ssim_conditional = ssim_metric(conditional_samples, real_samples).item()
ssim_unconditional = ssim_metric(unconditional_samples, real_samples).item()

print("SSIM (Conditional):", ssim_conditional)
print("SSIM (Unconditional):", ssim_unconditional)


In [None]:
# CNN Classification Model for the 2D concentration dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Define the model
class SimpleCNN(nn.Module):
    """Three conv/ReLU/max-pool stages followed by a two-layer classifier head.

    Each stage halves the spatial resolution (16, 32, then 64 channels). A
    parameter-free adaptive average pool then brings the feature map to 16x16
    before flattening, so inputs of other sizes than 128x128 no longer break
    the hard-wired `64 * 16 * 16` input width of `fc1`. For 128x128 inputs the
    map is already 16x16 and the adaptive pool is an identity, and because it
    adds no parameters, existing state dicts still load.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=17):
        super().__init__()
        # Input: (batch_size, 1, H, W); 128x128 gives a 16x16 map after pool3.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Fixes the spatial size seen by fc1 regardless of the input size.
        self.adapt = nn.AdaptiveAvgPool2d((16, 16))

        # Fully connected layers
        self.fc1 = nn.Linear(64 * 16 * 16, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return raw class logits of shape (batch_size, num_classes)."""
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = self.pool3(F.relu(self.conv3(x)))
        x = self.adapt(x)

        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# Five-way classifier.
# NOTE(review): this rebinds the module-level names `model` and `optimizer`,
# which earlier cells use for the diffusion UNet and its AdamW optimizer;
# re-running those cells after this one will pick up the CNN instead.
model = SimpleCNN(num_classes=5)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model.to(device)
def remap_labels(original_labels):
    """Map the odd labels 1, 3, 5, ..., 17 onto class indices 0, 1, 2, ..., 8.

    Works on plain integers as well as integer tensors (element-wise).
    """
    zero_based = original_labels - 1
    class_index = zero_based // 2
    return class_index
def remap_labels2(original_labels):
    """Map the negative labels -1, -2, ..., -5 onto class indices 0, 1, ..., 4.

    Works on plain integers as well as integer tensors (element-wise).
    """
    shifted = original_labels + 1
    return -shifted
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one training pass of the classifier over `dataloader`.

    Inputs are channel 0 of `batch['input_fields']`, average-pooled 2x2;
    targets come from column 2 of `batch['constant_scalars']`, remapped by
    the module-level `remap_labels2`.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        dataloader: iterable of dict batches with 'input_fields' and
            'constant_scalars' entries.
        criterion: loss called as `criterion(outputs, targets)`.
        optimizer: optimizer over `model`'s parameters.
        device: device the inputs and targets are moved to.

    Returns:
        `(mean loss per batch, accuracy in percent)`; `(0.0, 0.0)` when the
        dataloader yields no batches.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    for batch in dataloader:
        # Get the inputs: keep channel 0 as a one-channel image and halve its
        # resolution. Assumes 'input_fields' is (B, 1, H, W, C) — TODO confirm.
        inputs = batch['input_fields'].squeeze(1).permute(0, 3, 1, 2)[:, 0, :, :].unsqueeze(1)
        inputs = torch.nn.functional.avg_pool2d(inputs, kernel_size=2, stride=2)
        inputs = inputs.to(device)

        # Round before the integer cast: a bare float -> int64 conversion
        # truncates toward zero, so e.g. -4.9999 would become class -4.
        raw_labels = batch['constant_scalars'][:, 2]
        if raw_labels.is_floating_point():
            raw_labels = torch.round(raw_labels)
        targets = remap_labels2(raw_labels.to(torch.int64)).to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        predicted = outputs.detach().argmax(dim=1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()
        num_batches += 1

    # An empty dataloader would otherwise divide by zero.
    if num_batches == 0 or total == 0:
        return 0.0, 0.0

    epoch_loss = running_loss / num_batches
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc

# Train the classifier for a fixed number of epochs, recording the per-epoch
# loss and accuracy returned by `train_epoch`.
num_epochs = 20
train_losses, train_accuracies = [], []

for epoch in range(num_epochs):
    loss, acc = train_epoch(
        model=model,
        dataloader=train_dataloader,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
    )
    train_losses += [loss]
    train_accuracies += [acc]
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss:.4f}, Accuracy: {acc:.2f}%')

# Persist only the weights (state dict), not the whole module.
checkpoint_path = "cnn_classifier.pth"
torch.save(model.state_dict(), checkpoint_path)
print("Model saved to cnn_classifier.pth")

# Plot training progress: loss on the left panel, accuracy on the right.
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 5))

ax_loss.plot(train_losses)
ax_loss.set_title('Training Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')

ax_acc.plot(train_accuracies)
ax_acc.set_title('Training Accuracy')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy (%)')

plt.tight_layout()
plt.show()


