In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from tqdm import tqdm

from sampler import Sampler
from dataset import SQGDataset

from diffusion_networks import SongUNet
from pathlib import Path
import torch.optim as optim
import csv

In [None]:
# Batch size handed to the DataLoader below.
bs = 10
# Value forwarded to SQGDataset as `std` (together with mean=0).
data_std = 2660
dataset = SQGDataset("data/SQG", mean=0, std=data_std)

# Batches are served in dataset order (no shuffling).
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=bs,
    shuffle=False,
)

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
def etaz_loss(z0, z1):
    """Regression target that is simply the first argument, ``z0``."""
    return z0


def eta1_loss(z0, z1):
    """Regression target that is simply the second argument, ``z1``."""
    return z1


def b_loss(z0, z1):
    """Regression target ``z1 - z0``. All the samplers are for b."""
    return z1 - z0


target_fn = b_loss  # Choose the target function for the loss

# Network that is trained by `train` below; it takes and returns 2-channel
# fields at a resolution of 64.
model = SongUNet(
    img_resolution=64,
    in_channels=2,
    out_channels=2,
    embedding_type='fourier',
    encoder_type='residual',
    decoder_type='standard',
    channel_mult_noise=2,
    resample_filter=[1, 3, 3, 1],
    model_channels=32,
    channel_mult=[2, 2, 2],
    attn_resolutions=[32],
)

def loss_fn(model, batch, target_fn):
    """Compute the training loss for one batch.

    A Gaussian noise sample ``z0`` (same shape as ``batch``) and the data
    ``z1 = batch`` are mixed linearly at a per-sample time ``t`` drawn
    uniformly from [0, 1). The model is evaluated on the mixture and scored
    against ``target_fn(z0, z1)``.

    Args:
        model: callable invoked as ``model(zt, t)``.
        batch: 4-D tensor; ``t`` is broadcast over its last three dimensions.
        target_fn: callable ``(z0, z1) -> target`` broadcastable against the
            model output.

    Returns:
        Scalar tensor ``mean(0.5 * pred**2 - target * pred)``. This equals
        ``mean(0.5 * (pred - target)**2)`` minus ``mean(0.5 * target**2)``,
        a term that does not depend on ``pred``.
    """
    # Keep the RNG call order: noise first, then the time samples.
    noise = torch.randn_like(batch)
    data = batch

    t = torch.rand(batch.shape[0], device=batch.device)
    t_b = t.view(-1, 1, 1, 1)  # broadcast one time value over each sample
    mixed = (1 - t_b) * noise + t_b * data

    prediction = model(mixed, t)
    target = target_fn(noise, data)
    return (0.5 * prediction ** 2 - target * prediction).mean()

def train(target_fn, device, loader, val_loader):
    """Train the module-level ``model`` and log per-epoch losses.

    Runs 50 epochs of AdamW (lr 1e-3, weight decay 1e-4). The learning rate is
    warmed up linearly over the first 1000 optimizer steps and annealed with a
    cosine schedule that is stepped once per epoch.

    Side effects (all under ``results/``):
        * ``training_log.csv`` -- one row per epoch (epoch, train loss, val loss).
        * ``best_model.pth``   -- weights with the lowest validation loss so far.
        * ``final_model.pth``  -- weights after the most recently finished epoch.

    Args:
        target_fn: callable ``(z0, z1) -> target`` forwarded to ``loss_fn``.
        device: device the model and every batch are moved to.
        loader: iterable of training batches (must support ``len``).
        val_loader: iterable of validation batches (must support ``len``).

    Returns:
        ``(loss_values, val_loss_values)``: two lists of floats holding the
        average training and validation loss of each epoch.
    """
    # Batches are moved to `device` below, so the parameters must live there
    # too; otherwise the forward pass fails with a device mismatch on CUDA.
    model.to(device)

    print("Num params: ", sum(p.numel() for p in model.parameters()), flush=True)
    result_path = Path('results')
    result_path.mkdir(parents=True, exist_ok=True)

    learning_rate = 1e-3
    weight_decay = 1e-4
    num_epochs = 50

    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # Cosine decay is stepped per epoch, the linear warmup per optimizer step.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    warmup_scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=0.001, end_factor=1.0, total_iters=1000)

    loss_values = []
    val_loss_values = []
    best_val_loss = float('inf')

    # Setup for logging: (re)create the CSV with just its header row.
    log_file_path = result_path / 'training_log.csv'
    with open(log_file_path, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['Epoch', 'Average Training Loss', 'Validation Loss'])

    # Training loop
    for epoch in range(num_epochs):

        # Training phase
        model.train()
        total_train_loss = 0
        for image in tqdm(loader):
            image = image.to(device)

            optimizer.zero_grad()
            loss = loss_fn(model, image, target_fn)
            total_train_loss += loss.item()

            loss.backward()
            optimizer.step()
            warmup_scheduler.step()

        avg_train_loss = total_train_loss / len(loader)

        # Validation phase: iterate the validation loader (not the training
        # one) so the sum and the divisor below refer to the same batches.
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for image in tqdm(val_loader):
                image = image.to(device)

                loss = loss_fn(model, image, target_fn)
                total_val_loss += loss.item()

        avg_val_loss = total_val_loss / len(val_loader)

        # Checkpointing
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), result_path / 'best_model.pth')

        scheduler.step()

        # Both histories hold plain floats, one entry per epoch.
        loss_values.append(avg_train_loss)
        val_loss_values.append(avg_val_loss)

        with open(log_file_path, mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([epoch+1, avg_train_loss, avg_val_loss])

        print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}', flush=True)

        # Overwritten every epoch, so an interrupted run still leaves the
        # most recent weights on disk.
        torch.save(model.state_dict(), result_path / 'final_model.pth')

    return loss_values, val_loss_values

#train(target_fn, device, loader, loader)

In [None]:
# NOTE(review): `train` in this notebook writes its checkpoints under
# `results/`; confirm this path points at the weights you intend to load.
model_path = "best_model.pth" # or your trained path
image_shape = (2, 64, 64)

# Number of samples drawn together.
members = 5


def eps(t):
    """Noise when sampling: 0.1 at t = 0, decreasing linearly to 0 at t = 1."""
    return 0.1 * (1 - t)


def invert_eps(t):
    """Noise when inverting: zero for every finite t (same form as ``eps``)."""
    return 0. * (1 - t)


steps = 100
invert_steps = 100

debug = True

# Build the sampler from the settings defined in this cell; the arguments are
# passed positionally in the order the Sampler constructor expects.
sampler = Sampler(
    device,
    members,
    eps,
    steps,
    invert_eps,
    invert_steps,
    model_path,
    debug,
)

### Example: how to sample a physical state

In [None]:
## Random noise: an independent standard-normal field for every member
z0 = torch.randn(members, *image_shape, device=device)

## All the same noise for testing epsilon
#z0 = torch.randn((1, *image_shape), device=device).repeat(members, 1, 1, 1)

# `sample` returns a pair; only its first element is used in this notebook.
z1, _ = sampler.sample(z0)

### Plotting

In [None]:
# 2 x 3 grid of panels; `ax` is the array of Axes objects.
fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(9, 6))
cmap = plt.get_cmap('jet')
# Index into the second dimension of `z1` to display (0 or 1).
level = 0

def set_cbar(im):
    """Draw a horizontal colorbar for ``im`` in an inset above its axes.

    Args:
        im: object exposing ``.axes`` that ``plt.colorbar`` accepts
            (presumably the return value of ``Axes.imshow``).
    """
    host_ax = im.axes
    # Inset as wide as the host axes and 5% of its height, shifted upwards by
    # the anchor box so that it sits above the plot.
    colorbar_ax = inset_axes(
        host_ax,
        width="100%",
        height="5%",
        loc='upper center',
        bbox_to_anchor=(0, 0.18, 1, 1),
        bbox_transform=host_ax.transAxes,
        borderpad=0,
    )
    colorbar = plt.colorbar(im, cax=colorbar_ax, orientation='horizontal')
    # Put ticks and label on the top edge of the colorbar.
    top_axis = colorbar.ax.xaxis
    top_axis.set_ticks_position('top')
    top_axis.set_label_position('top')

# Move the samples to host memory once instead of once per panel.
fields = z1.detach().cpu()

# The grid may have more panels than there are samples; indexing past the
# last sample would raise IndexError, so surplus panels are hidden instead.
# A separate loop variable keeps `ax` bound to the array of axes.
for i, panel in enumerate(ax.flat):
    if i >= fields.shape[0]:
        panel.set_visible(False)
        continue

    panel.set_aspect('equal')
    panel.imshow(fields[i, level], cmap=cmap)

    panel.yaxis.set_visible(False)
    panel.xaxis.set_visible(False)
plt.tight_layout()