<a href="https://colab.research.google.com/github/yasin-arkan/waveform_diff/blob/main/diffusion_flow_matching.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive at /gdrive (Colab only); later cells change into a folder under this mount.
from google.colab import drive
drive.mount('/gdrive')




Drive already mounted at /gdrive; to attempt to forcibly remount, call drive.mount("/gdrive", force_remount=True).


In [2]:
# Step up one directory; per the recorded output this leaves the working directory at "/".
%cd ..


/


In [3]:
%cd gdrive/MyDrive/diffusion

/gdrive/MyDrive/diffusion


## Preparing the data for training

In [4]:
import torch
import torchaudio
import torch.nn.functional as F
import numpy as np
import pandas as pd
import random
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset, DataLoader
from geopy import distance

import librosa


# STFT settings passed to load_data (and from there to librosa.stft).
# With N_FFT = 256 the recorded output shows 129 frequency bins per frame.
N_FFT = 256
HOP_LENGTH = 64
WIN_LENGTH = N_FFT

# CSV holding the metadata columns and the waveform samples, relative to the working directory.
file_path = "data/timeseries_EW.csv"

def load_data(path, n_fft, hop_length, win_length, batch_size=16,
              n_freq_bins=64, n_frames=128, n_train=1120, wf_start_col=16):
    """Load the waveform CSV and build normalised STFT datasets and dataloaders.

    Pipeline:
      1. Read the CSV and assemble four conditioning variables per record:
         'Depth', 'Magnitude' (CSV columns) plus 'Distance' and 'Angle'
         (derived by source_site_distance / compute_angle).
      2. Standardise each conditioning variable with its own mean/std.
      3. Take the STFT magnitude of every waveform, zero-pad the time axis to
         `n_frames`, keep the lowest `n_freq_bins` frequency bins, and
         standardise with a single global mean/std.
      4. Split into train (first `n_train` rows) and validation (the rest).

    Note that every normalisation statistic is computed over the whole dataset
    (train and validation together), and the spectrogram statistics include the
    zero-padded frames.

    Args:
        path: CSV file path.
        n_fft, hop_length, win_length: forwarded to librosa.stft.
        batch_size: batch size of all three dataloaders.
        n_freq_bins: number of low-frequency bins kept (default 64).
        n_frames: time-frame count after padding (default 128). If the STFT
            already has more frames, F.pad receives a negative pad and crops.
        n_train: number of leading rows used for training (default 1120).
        wf_start_col: index of the first waveform column in the CSV (default 16).

    Returns:
        (train_dataset, train_dataloader, val_dataset, val_dataloader,
         all_dataset, all_dataloader, wfs_mean, wfs_std, norm_dict,
         orig_wfs, cond_vars) where `norm_dict` maps each conditioning name to
        [mean, std], `orig_wfs` is the raw waveform array and `cond_vars` is the
        un-normalised conditioning DataFrame.
    """
    data = pd.read_csv(path)

    # Conditioning table: two raw columns plus two derived geometric features.
    cond_cols = ['Depth', 'Magnitude', 'Distance', 'Angle']
    cond_vars = data[['Depth', 'Magnitude']].copy()
    cond_vars['Distance'] = source_site_distance(data).ravel()
    cond_vars['Angle'] = compute_angle(data).ravel()

    cond_data = []
    norm_dict = {}
    for cvar in cond_cols:
        cv = cond_vars[cvar].to_numpy().reshape(-1, 1)

        cv_mean, cv_std = cv.mean(), cv.std()
        if cv_std == 0:
            # A constant column would otherwise turn into NaNs; leave it centred at zero.
            cv_std = 1.0
        norm_dict[cvar] = [cv_mean, cv_std]

        cond_data.append((cv - cv_mean) / cv_std)

    wfs = data.iloc[:, wf_start_col:].to_numpy()
    orig_wfs = wfs.copy()

    # Magnitude spectrogram: [n_records, n_fft // 2 + 1, n_stft_frames]
    wfs = np.abs(librosa.stft(wfs, n_fft=n_fft, hop_length=hop_length, win_length=win_length))

    print(wfs.shape)
    wfs = torch.from_numpy(wfs).float()

    print("Before padding:", wfs.shape)
    # Right-pad the time axis with zeros, then drop the upper frequency bins.
    padding_needed = n_frames - wfs.shape[2]
    wfs = F.pad(wfs, (0, padding_needed), mode='constant', value=0)
    wfs = wfs[:, :n_freq_bins, :]
    print("After padding:", wfs.shape)

    # Flatten each spectrogram, standardise with one global mean/std, restore the shape.
    length, x, y = wfs.shape
    wfs = wfs.reshape((length, -1))

    wfs_mean, wfs_std = wfs.mean(), wfs.std()
    wfs = (wfs - wfs_mean) / wfs_std

    wfs = wfs.reshape((length, x, y))

    cond_var = torch.from_numpy(np.concatenate(cond_data, axis=1))

    train_dataset = STFTDataset(wfs[:n_train, :, :], cond_var[:n_train, :])
    val_dataset = STFTDataset(wfs[n_train:, :, :], cond_var[n_train:, :])
    all_dataset = STFTDataset(wfs, cond_var)

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    all_dataloader = DataLoader(all_dataset, batch_size=batch_size, shuffle=False)

    print("Waveform shape:", train_dataset.wfs.shape)
    print("Cond shape:", train_dataset.cond_var.shape)

    print("Waveform shape (val):", val_dataset.wfs.shape)
    print("Cond shape (val):", val_dataset.cond_var.shape)

    return train_dataset, train_dataloader, val_dataset, val_dataloader, all_dataset, all_dataloader, wfs_mean, wfs_std, norm_dict, orig_wfs, cond_vars




class STFTDataset(Dataset):
    """Index-aligned pairs of (spectrogram, conditioning vector).

    Both inputs are stored untouched; callers are expected to pass data that is
    already scaled and in tensor form.
    """

    def __init__(self, wfs, cond_var):
        self.wfs, self.cond_var = wfs, cond_var

    def __len__(self):
        # Number of samples = size of the leading dimension.
        return len(self.wfs)

    def __getitem__(self, idx):
        spectrogram = self.wfs[idx]
        condition = self.cond_var[idx]
        return spectrogram, condition



def source_site_distance(info):
    '''Source-to-station distance in kilometres for every row of `info`.

    `info` must expose EventLat, EventLon, StationLat and StationLon columns.
    Each row's distance comes from geopy's `distance.distance(...).km`
    (reference: https://geopy.readthedocs.io/en/stable/#module-geopy.distance).

    Returns a numpy array of shape [#samples, 1].
    '''
    # (lat, long) pairs for the event and for the recording station
    sources = zip(info.EventLat.to_numpy(), info.EventLon.to_numpy())
    stations = zip(info.StationLat.to_numpy(), info.StationLon.to_numpy())

    dist = [distance.distance(src, sta).km for src, sta in zip(sources, stations)]

    return np.array(dist).reshape(len(dist), 1)



def compute_angle(info):
    '''
    Direction from each source (event) to its station, in degrees.

    The angle is arctan2(delta_longitude, delta_latitude) computed directly on
    the coordinate differences in degrees (a planar approximation). It is
    therefore measured from the latitude axis (station due north -> 0) towards
    increasing longitude (station due east -> +90), in the range (-180, 180].

    `info` must provide EventLat, EventLon, StationLat and StationLon columns.
    Returns a numpy array of shape [#samples, 1].
    '''
    # Vector from the source centre to the station, component-wise for all rows at once.
    d_lat = info['StationLat'].to_numpy(dtype=float) - info['EventLat'].to_numpy(dtype=float)
    d_lon = info['StationLon'].to_numpy(dtype=float) - info['EventLon'].to_numpy(dtype=float)

    angle = np.degrees(np.arctan2(d_lon, d_lat))

    return angle.reshape(len(angle), 1)





# Build datasets, dataloaders and normalisation statistics with the module-level STFT settings.
(
    train_dataset, train_dataloader,
    val_dataset, val_dataloader,
    all_dataset, all_dataloader,
    wfs_mean, wfs_std,
    norm_dict, orig_wfs, cond_vars,
) = load_data(file_path, N_FFT, HOP_LENGTH, WIN_LENGTH, batch_size=16)


def plot_sample_stft(index):
  """Plot a training sample three ways: Griffin-Lim waveform, spectrogram, raw waveform.

  Reads the notebook globals train_dataset, wfs_mean, wfs_std, orig_wfs and the
  STFT settings N_FFT / HOP_LENGTH / WIN_LENGTH.

  Args:
      index: row index into train_dataset.wfs (and the matching row of orig_wfs).
  """
  spectrogram = train_dataset.wfs[index].cpu().numpy()

  # Undo the global standardisation applied in load_data.
  spectrogram = (spectrogram * wfs_std.numpy()) + wfs_mean.numpy()

  # load_data keeps only the lowest frequency bins, but Griffin-Lim needs the
  # full N_FFT // 2 + 1 rows and the same hop/window as the forward STFT.
  # Without them librosa infers n_fft from the row count and reconstructs a
  # waveform with the wrong time and frequency scale.
  missing_bins = max(0, N_FFT // 2 + 1 - spectrogram.shape[0])
  full_band = np.pad(spectrogram, ((0, missing_bins), (0, 0)))
  tf = librosa.griffinlim(full_band, n_iter=512, n_fft=N_FFT,
                          hop_length=HOP_LENGTH, win_length=WIN_LENGTH)

  fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 4))

  ax1.set_title(f"Waveform for sample {index}")
  ax1.set_xlabel("Time (samples)")  # the x axis is the sample index, not seconds
  ax1.set_ylabel("Amplitude")
  ax1.plot(tf)

  print(spectrogram.shape)
  ax2.imshow(spectrogram, aspect='auto')
  ax2.set_xlabel('Time Frame')
  ax2.set_ylabel('Frequency Bin')
  ax2.set_title(f"Spectrogram for Sample {index}")

  ax3.set_title("Original waveform")
  orig = orig_wfs[index]
  ax3.plot(orig)

  plt.show()



# plot_sample_stft(333)

(1183, 129, 110)
Before padding: torch.Size([1183, 129, 110])
After padding: torch.Size([1183, 64, 128])
Waveform shape: torch.Size([1120, 64, 128])
Cond shape: torch.Size([1120, 4])
Waveform shape (val): torch.Size([63, 64, 128])
Cond shape (val): torch.Size([63, 4])


In [5]:
# Install flow_matching, which provides the path, scheduler and ODE solver imported in the next cell.
# NOTE(review): no version is pinned, so a fresh runtime may resolve a different release.
%pip install flow_matching



In [6]:
# flow_matching
from flow_matching.path.scheduler import CondOTScheduler
from flow_matching.path import AffineProbPath
from flow_matching.solver import Solver, ODESolver
from flow_matching.utils import ModelWrapper

# Probability path shared by training: train_loop calls path.sample(t, x_0, x_1)
# and regresses the model output onto the returned dx_t.
path = AffineProbPath(scheduler=CondOTScheduler())


## Config, model and training loop

In [7]:
from diffusers import UNet2DModel, UNet2DConditionModel
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Conditional UNet used as the velocity model for single-channel 64x128 spectrograms.
# It is not moved to `device` here; train_loop passes it through accelerator.prepare.
model = UNet2DConditionModel(
    sample_size=(64, 128),
    in_channels=1,
    out_channels=1,
    layers_per_block=2, # how many ResNet layers to use per UNet block
    block_out_channels=(128, 128, 256, 256, 512, 512),  # the number of output channels for each UNet block
    down_block_types=(
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "CrossAttnDownBlock2D", # Use CrossAttn blocks where you want to inject conditioning
        "CrossAttnDownBlock2D",
    ),
    up_block_types=(
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),

    # Must equal the output width of cond_emb (512), whose output is fed in as encoder_hidden_states.
    cross_attention_dim=512,
)


# MLP that lifts the 4 conditioning variables (Depth, Magnitude, Distance, Angle)
# to a 512-dim vector; train_loop adds a sequence axis and passes it to the UNet's cross-attention.
# NOTE(review): this module lives outside `model`, so it is only trained/saved if the
# optimizer and checkpoint code handle it explicitly.
cond_emb = nn.Sequential(
    nn.Linear(4, 64),
    nn.ReLU(),
    nn.Linear(64, 256),
    nn.ReLU(),
    nn.Linear(256, 512)
).to(device)


In [8]:
class WrappedModel(ModelWrapper):
    """Adapter that lets the ODE solver call the UNet as a velocity field.

    The wrapped model is called with integer timesteps: the incoming time `t`
    is multiplied by (num_train_timesteps - 1) and truncated with `.long()`.
    """

    def __init__(self, model, num_train_timesteps):
        super().__init__(model)
        self.num_train_timesteps = num_train_timesteps

    def forward(self, x: torch.Tensor, t: torch.Tensor, h: torch.Tensor):
        """Return the `.sample` of the wrapped model for state `x`, time `t` and conditioning `h`."""
        last_index = self.num_train_timesteps - 1
        timestep = (t * last_index).long()
        prediction = self.model(x, timestep, h)
        return prediction.sample

In [9]:
from diffusers import DDPMPipeline, UNet2DConditionModel

class CustomDDPMPipeline(DDPMPipeline):
    """DDPMPipeline subclass whose sampling integrates a flow-matching ODE.

    Calling the pipeline draws Gaussian noise shaped from the UNet config and
    integrates it with the midpoint ODESolver over a 10-point time grid on
    [0, 1] (step_size 0.05), using the UNet wrapped in WrappedModel.
    """

    def __init__(self, unet, scheduler):
        super().__init__(unet, scheduler)

    def __call__(self, batch_size=1, generator=None, num_inference_steps=None, output_type="pil", return_dict=True, encoder_hidden_states=None):
        """Generate `batch_size` samples.

        `num_inference_steps`, `output_type` and `return_dict` are accepted but
        do not influence the result. `generator` seeds the initial noise and
        `encoder_hidden_states` is forwarded to the UNet as conditioning.

        Returns:
            ((final_sample,), intermediate_samples, time_grid)
        """
        unet_cfg = self.unet.config
        noise_shape = (batch_size, unet_cfg.in_channels, unet_cfg.sample_size[0], unet_cfg.sample_size[1])

        # Starting point of the ODE: Gaussian noise, drawn first and then moved to the pipeline device.
        image = torch.randn(noise_shape, generator=generator).to(self.device)

        velocity_model = WrappedModel(self.unet, self.scheduler.config.num_train_timesteps)
        solver = ODESolver(velocity_model=velocity_model)

        # Times at which intermediate states are returned.
        T = torch.linspace(0, 1, 10).to(self.device)

        intermediate_samples = solver.sample(
            time_grid=T,
            x_init=image,
            method='midpoint',
            step_size=0.05,
            return_intermediates=True,
            h=encoder_hidden_states,
        )
        image = intermediate_samples[-1]

        return (image,), intermediate_samples, T

In [40]:
def evaluate(config, epoch, pipeline, hidden, original):
    """Generate one sample with `pipeline` and save it next to the reference spectrogram.

    Writes two files under config.output_dir:
      arrays/epoch_XXXX_generated_sample.npy  - the generated array
      images/epoch_XXXX_orig_vs_gen.png       - side-by-side original vs generated plot

    Args:
        config: object exposing eval_batch_size, seed and output_dir.
        epoch: epoch number, used in file names and plot titles.
        pipeline: callable returning ((image,), intermediates, time_grid).
        hidden: conditioning tensor forwarded as encoder_hidden_states.
        original: reference spectrogram tensor plotted for comparison.
    """
    import os  # local import so the function does not depend on a later cell having imported it

    # A dedicated generator keeps sampling reproducible WITHOUT reseeding the global RNG.
    # torch.manual_seed(...) would reset the global seed on every call, so the training
    # noise drawn after each evaluation would restart from the same state.
    generator = torch.Generator().manual_seed(config.seed)

    images, _, _ = pipeline(
        batch_size=config.eval_batch_size,
        generator=generator,
        output_type="numpy",
        return_dict=False,
        encoder_hidden_states=hidden,
    )

    # Preparing the folders
    arrs_dir = os.path.join(config.output_dir, "arrays")
    images_dir = os.path.join(config.output_dir, "images")
    os.makedirs(arrs_dir, exist_ok=True)
    os.makedirs(images_dir, exist_ok=True)

    img_array = images[0].detach().cpu().squeeze().numpy()

    # Saving the output as an array here
    np.save(os.path.join(arrs_dir, f"epoch_{epoch:04d}_generated_sample.npy"), img_array)

    # Plotting the original and the output of the pipeline for comparison
    original = original.cpu().numpy().squeeze()

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
    panels = (
        (ax1, f"Original Sample, Epoch {epoch}", original),
        (ax2, f"Generated Sample, Epoch {epoch}", img_array),
    )
    for ax, title, data in panels:
        ax.set_title(title)
        ax.set_xlabel('Time Frame')
        ax.set_ylabel('Frequency Bin')
        ax.imshow(data, aspect='auto')

    fig.tight_layout()
    fig.savefig(os.path.join(images_dir, f"epoch_{epoch:04d}_orig_vs_gen.png"))
    plt.close(fig)



In [41]:
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration
from huggingface_hub import create_repo, upload_folder
from diffusers import DDPMPipeline

from tqdm.auto import tqdm
from pathlib import Path
import os
import random

def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, val_dataset, lr_scheduler):
    """Train the conditional UNet with a flow-matching (velocity regression) objective.

    For every batch the loop draws Gaussian noise x_0 and a uniform time t,
    asks the global `path` for (x_t, dx_t), and minimises the mean squared
    error between the model output and dx_t. Conditioning vectors are embedded
    with the global `cond_emb`; that module is only updated if the supplied
    `optimizer` owns its parameters.

    After an epoch, on the main process, one fixed validation sample is
    generated via evaluate() every config.save_image_epochs epochs and the
    pipeline is saved every config.save_model_epochs epochs (both also on the
    last epoch).

    Args:
        config: training configuration (output_dir, num_epochs, mixed_precision, ...).
        model: UNet2DConditionModel to train.
        noise_scheduler: scheduler whose config.num_train_timesteps scales t to integer timesteps.
        optimizer, lr_scheduler: optimiser and learning-rate schedule, stepped once per batch.
        train_dataloader: yields (spectrogram, conditioning) batches.
        val_dataset: dataset exposing .wfs and .cond_var for the demo sample.
    """
    # Initialize accelerator and tensorboard logging
    logging_dir = os.path.join(config.output_dir, "logs")
    accelerator_project_config = ProjectConfiguration(project_dir=config.output_dir, logging_dir=logging_dir)
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_config=accelerator_project_config,
    )
    if accelerator.is_main_process:
        if config.push_to_hub:
            create_repo(repo_id=Path(config.output_dir).name, exist_ok=True)
        elif config.output_dir is not None:
            os.makedirs(config.output_dir, exist_ok=True)
        accelerator.init_trackers("train_example")

    # Prepare everything
    # There is no specific order to remember, you just need to unpack the
    # objects in the same order you gave them to the prepare method.
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    global_step = 0

    # One validation sample, chosen once, is regenerated at every evaluation for comparison.
    val_idx = random.randint(0, len(val_dataset) - 1)

    # Integer timestep range expected by the UNet's time embedding.
    max_timestep = noise_scheduler.config.num_train_timesteps - 1

    # Now you train the model
    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        for step, batch in enumerate(train_dataloader):

            magnitudes, conds = batch

            # x_1: data sample [batch, 1, freq, time]
            x_1 = magnitudes.to(accelerator.device).unsqueeze(1)

            # Conditioning embedded and given a length-1 sequence axis: [batch, 1, emb_dim]
            conds = conds.to(accelerator.device).float()
            conds = cond_emb(conds).unsqueeze(1)

            # x_0: noise sample, t: one random time in [0, 1) per item
            x_0 = torch.randn_like(x_1)
            t = torch.rand(x_1.shape[0]).to(x_1.device)
            path_sample = path.sample(t=t, x_0=x_0, x_1=x_1)

            with accelerator.accumulate(model):
                # Scaling the timesteps for UNet2DConditionModel
                model_input_t = (path_sample.t * max_timestep).long()

                predicted_velocity = model(path_sample.x_t, model_input_t, conds).sample
                loss = torch.pow(predicted_velocity - path_sample.dx_t, 2).mean()

                accelerator.backward(loss)
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        progress_bar.close()

        # After each epoch you optionally sample some demo images with evaluate() and save the model
        if accelerator.is_main_process:
            is_last_epoch = epoch == config.num_epochs - 1
            save_images = epoch % config.save_image_epochs == 0 or is_last_epoch
            save_model = epoch % config.save_model_epochs == 0 or is_last_epoch

            if save_images or save_model:
                with torch.no_grad():
                    pipeline = CustomDDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)

                    if save_images:
                        sample_wfs = val_dataset.wfs[val_idx].to(accelerator.device)

                        # Shape the single conditioning vector as [1, 1, emb_dim]
                        sample_cond = val_dataset.cond_var[val_idx].to(accelerator.device).float()
                        sample_cond = cond_emb(sample_cond).unsqueeze(0).unsqueeze(0)

                        evaluate(config, epoch, pipeline, sample_cond, sample_wfs)

                    if save_model:
                        pipeline.save_pretrained(config.output_dir)
                        # The conditioning embedding is not part of the pipeline, so store it
                        # alongside; without it a reloaded checkpoint cannot be conditioned
                        # the same way after the kernel restarts.
                        torch.save(cond_emb.state_dict(), os.path.join(config.output_dir, "cond_emb.pth"))

In [21]:
from diffusers.optimization import get_cosine_schedule_with_warmup
from dataclasses import dataclass
from diffusers import FlowMatchEulerDiscreteScheduler

@dataclass
class TrainingConfig:
    """Hyper-parameters and paths for the flow-matching training run.

    Every attribute is annotated so that @dataclass registers it as a field;
    individual values can then be overridden at construction time, e.g.
    TrainingConfig(num_epochs=5).
    """
    image_size: tuple = (64, 128)  # the generated image resolution
    train_batch_size: int = 16
    eval_batch_size: int = 1  # how many images to sample during evaluation
    num_epochs: int = 120
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-5
    lr_warmup_steps: int = 500
    save_image_epochs: int = 5
    save_model_epochs: int = 15
    mixed_precision: str = 'no'  # `no` for float32, `fp16` for automatic mixed precision
    output_dir: str = 'checkpoints_EW_flow_test'  # the model name locally and on the HF Hub
    data_path: str = 'data/timeseries_EW.csv'

    push_to_hub: bool = False  # whether to upload the saved model to the HF Hub
    hub_private_repo: bool = False
    overwrite_output_dir: bool = True  # overwrite the old model when re-running the notebook
    seed: int = 0

config = TrainingConfig()


# Only its config.num_train_timesteps is read (to scale t in [0, 1] to integer timesteps).
noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)

# cond_emb produces the cross-attention conditioning, so its weights have to be optimised
# together with the UNet; with model.parameters() alone it would stay at its random init.
optimizer = torch.optim.AdamW(
    list(model.parameters()) + list(cond_emb.parameters()),
    lr=config.learning_rate,
)

# Cosine decay with linear warm-up over the total number of optimiser steps.
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=(len(train_dataloader) * config.num_epochs),
)

In [13]:
# Rebuild datasets and dataloaders using the path and batch size from the training config.
file_path = config.data_path

(
    train_dataset, train_dataloader,
    val_dataset, val_dataloader,
    all_dataset, all_dataloader,
    wfs_mean, wfs_std,
    norm_dict, orig_wfs, cond_vars,
) = load_data(
    file_path,
    N_FFT,
    HOP_LENGTH,
    WIN_LENGTH,
    batch_size=config.train_batch_size,
)

(1183, 129, 110)
Before padding: torch.Size([1183, 129, 110])
After padding: torch.Size([1183, 64, 128])
Waveform shape: torch.Size([1120, 64, 128])
Cond shape: torch.Size([1120, 4])
Waveform shape (val): torch.Size([63, 64, 128])
Cond shape (val): torch.Size([63, 4])


In [14]:
# Sanity check: training spectrograms and their conditioning vectors must share the leading dimension.
print(train_dataset.wfs.shape)
print(train_dataset.cond_var.shape)

torch.Size([1120, 64, 128])
torch.Size([1120, 4])


In [39]:
from accelerate import notebook_launcher


# Echo the run settings before starting.
print("Saving to:", config.output_dir)
print("Batch size:", config.train_batch_size)

# Positional arguments for train_loop, in the order of its signature.
args = (config, model, noise_scheduler, optimizer, train_dataloader, val_dataset, lr_scheduler)


# Run train_loop in a single process from inside the notebook.
notebook_launcher(train_loop, args, num_processes=1, mixed_precision=config.mixed_precision)

Saving to: checkpoints_EW_flow_test
Batch size: 16
Launching training on one GPU.


  0%|          | 0/70 [00:00<?, ?it/s]

Saved 1 sample images to checkpoints_EW_flow_test/images


  0%|          | 0/70 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Run the standalone training script.
# NOTE(review): train.py is not part of this notebook; presumably it mirrors the loop above - confirm.
%run train.py

## MLP for max amplitude

In [None]:
class MaxAmplitudePredictor(nn.Module):
    """Five-layer MLP mapping a conditioning vector to a single scalar.

    Layer widths: input_dim -> 64 -> 128 -> 64 -> 32 -> 1, with ReLU after
    every layer except the last.
    """

    def __init__(self, input_dim):
        super().__init__()
        # Attribute names and creation order are kept stable so saved state dicts stay loadable.
        self.fc1 = nn.Linear(input_dim, 64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 32)
        self.fc5 = nn.Linear(32, 1)

    def forward(self, x):
        # Hidden layers share the same activation; the output layer stays linear.
        for hidden in (self.fc1, self.fc2, self.fc3, self.fc4):
            x = self.relu(hidden(x))
        return self.fc5(x)

In [None]:
# Regression targets: the largest value of each spectrogram as stored in the datasets
# (i.e. after the standardisation applied in load_data).
wfs = train_dataset.wfs
wfs_val = val_dataset.wfs

# Flatten every spectrogram so the maximum runs over all frequency bins and frames.
wfs_flat = wfs.view(wfs.shape[0], -1)
wfs_val_flat = wfs_val.view(wfs_val.shape[0], -1)

# keepdims=True keeps a trailing axis of size 1, matching the MLP's output shape.
train_max_amplitudes = torch.max(wfs_flat, dim=1, keepdims=True)[0]
val_max_amplitudes = torch.max(wfs_val_flat, dim=1, keepdims=True)[0]

print(train_max_amplitudes.shape)


class MLPDataset(Dataset):
    """Index-aligned pairs of (conditioning vector, max-amplitude target)."""

    def __init__(self, cond_var, max_amps):
        self.cond_var, self.max_amps = cond_var, max_amps

    def __len__(self):
        # One item per conditioning row.
        return len(self.cond_var)

    def __getitem__(self, idx):
        condition = self.cond_var[idx]
        target = self.max_amps[idx]
        return condition, target

# Pair the conditioning vectors of each split with the matching max-amplitude targets.
mlp_train_dataset = MLPDataset(train_dataset.cond_var, train_max_amplitudes)
mlp_val_dataset = MLPDataset(val_dataset.cond_var, val_max_amplitudes)

# Training batches are shuffled; validation keeps dataset order.
mlp_train_dataloader = DataLoader(mlp_train_dataset, batch_size=config.train_batch_size, shuffle=True)
mlp_val_dataloader = DataLoader(mlp_val_dataset, batch_size=config.eval_batch_size, shuffle=False)

In [None]:
# Train the max-amplitude regressor on the conditioning vectors with an MSE loss.
input_dim = mlp_train_dataset.cond_var.shape[1]
mlp_model = MaxAmplitudePredictor(input_dim).to(device)

criterion = nn.MSELoss()
mlp_optimizer = torch.optim.Adam(mlp_model.parameters(), lr=0.001)

num_epochs_mlp = 1000

for epoch in range(num_epochs_mlp):
    mlp_model.train()
    running_loss = 0.0
    for conds, max_amps in mlp_train_dataloader:
        # Conditioning is cast to float32 to match the model weights.
        conds = conds.to(device).float()
        max_amps = max_amps.to(device)

        mlp_optimizer.zero_grad()
        outputs = mlp_model(conds)
        loss = criterion(outputs, max_amps)
        loss.backward()
        mlp_optimizer.step()

        # Weight the batch-mean loss by batch size so the epoch value is a per-sample average.
        running_loss += loss.item() * conds.size(0)

    epoch_loss = running_loss / len(mlp_train_dataset)
    # Report every 100th epoch only.
    if (epoch + 1) % 100 == 0:
      print(f"MLP Epoch [{epoch+1}/{num_epochs_mlp}], Loss: {epoch_loss:.4f}")

    # Evaluate the MLP on the validation set
    mlp_model.eval()
    val_running_loss = 0.0
    with torch.no_grad():
        for conds, max_amps in mlp_val_dataloader:
            conds = conds.to(device).float()
            max_amps = max_amps.to(device)
            outputs = mlp_model(conds)
            loss = criterion(outputs, max_amps)
            val_running_loss += loss.item() * conds.size(0)
    val_epoch_loss = val_running_loss / len(mlp_val_dataset)
    if (epoch + 1) % 100 == 0:
      print(f"MLP Validation Loss: {val_epoch_loss:.4f}")

# The loop leaves mlp_model in eval mode (the last call in each epoch is .eval()).
print("MLP training finished.")

In [None]:
# Spot-check the regressor on validation sample 50: compare its prediction with the true maximum.
example = val_dataset.cond_var[50]
example_wf = val_dataset.wfs[50]

# NOTE(review): this rebinds the builtin name `max` for the rest of the session;
# a later cell prints this variable, so renaming it must be done in both places.
max = example_wf.max()

example = example.to(device).float()

print(example)
print(example.shape)
preds = mlp_model(example)
print(preds)
print(max)

In [None]:
# Persist only the weights (state dict) of the trained regressor, relative to the working directory.
model_save_path = 'max_amplitude.pth'

torch.save(mlp_model.state_dict(), model_save_path)

print(f"MLP model saved to {model_save_path}")

In [None]:
# Rebuild the architecture, then restore the saved weights into it.
loaded_model = MaxAmplitudePredictor(input_dim)

model_save_path = 'max_amplitude.pth'
# map_location lets a checkpoint written on a GPU runtime load on whatever `device`
# is available now (e.g. a CPU-only session) instead of failing during deserialisation.
loaded_model.load_state_dict(torch.load(model_save_path, map_location=device))

loaded_model.to(device)

loaded_model.eval()

# Same example as in the spot-check cell: the prediction should match the one printed there.
p = loaded_model(example)
print(p)
print(max)

## Results of the model

In [None]:
import torch

# Ask PyTorch to release cached GPU memory it is no longer using before running inference.
torch.cuda.empty_cache()

In [None]:
import gc
# Force a garbage-collection pass so unreachable Python objects are freed.
gc.collect()

In [None]:
# Second cache release. NOTE(review): empty_cache() performs no autograd work,
# so the no_grad() context looks unnecessary here - confirm and simplify.
with torch.no_grad():
    torch.cuda.empty_cache()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Validation sample 50 is used here and in the comparison cell below.
# (A random index used to be drawn and printed in this cell but was never used,
# which made the printed number misleading.)
h = val_dataset.cond_var[50].to(device).float()

with torch.no_grad():
    # Predicted peak amplitude for this conditioning vector, as a numpy array.
    max_amplitude = mlp_model(h).cpu().numpy()

    # Embed the conditioning and shape it as [batch=1, sequence=1, emb_dim] for cross-attention.
    h = cond_emb(h).unsqueeze(0).unsqueeze(0)

    pipeline = CustomDDPMPipeline.from_pretrained("checkpoints_EW_flow_1e-4/").to(device)

    # A dedicated generator gives reproducible noise without reseeding the global RNG.
    images, samples, T = pipeline(
        batch_size=1,
        generator=torch.Generator().manual_seed(config.seed),
        output_type="numpy",
        return_dict=False,
        encoder_hidden_states=h,
    )





In [None]:
# Show the ODE trajectory: one panel per returned intermediate state.
# The 2x5 grid matches the 10-point time grid used inside CustomDDPMPipeline.
fig, axs = plt.subplots(2, 5,figsize=(20,10))

samples = samples.cpu()

axs = axs.flatten()

for i, s in enumerate(samples):
  # Drop the singleton batch/channel axes so imshow receives a 2-D array.
  s = s.squeeze()
  axs[i].imshow(s, aspect='auto')
  axs[i].axis('off')
  axs[i].set_title(f't= %.2f' % (T[i]))

plt.tight_layout()

plt.show()

In [None]:
# Compare the generated spectrogram with validation sample 50 in de-normalised units.
# `images` is the one-element tuple returned by the pipeline; take the tensor out of it.
image = images[0]


original = val_dataset.wfs[50].cpu().numpy()

# NOTE(review): taken before de-normalisation, so this maximum is in normalised units.
orig_max = original.max()

im_1 = image[0].squeeze().cpu().numpy()

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

mean = wfs_mean.cpu().numpy()
std = wfs_std.cpu().numpy()

# Undo the global standardisation applied in load_data.
im_1 = (im_1 * std) + mean
original = (original * std) + mean

# NOTE(review): max_amplitude is predicted by an MLP trained on maxima of the normalised
# spectrograms, yet it is applied as a multiplicative gain after de-normalisation - confirm intent.
im_1 *= max_amplitude

print("Mean:", mean)
print("STD:", std)
print("Max amplitude:", max_amplitude)
print("Max amplitude (Original):", orig_max)

ax1.set_title(f"Original")
ax1.set_xlabel('Time Frame')
ax1.set_ylabel('Frequency Bin')
orig = ax1.imshow(original, aspect='auto')
fig.colorbar(orig, ax=ax1, label='Amplitude')

ax2.set_title(f"Generated")
ax2.set_xlabel('Time Frame')
ax2.set_ylabel('Frequency Bin')
gen = ax2.imshow(im_1, aspect='auto')
fig.colorbar(gen, ax=ax2, label='Amplitude')

plt.show()

plt.close()


In [None]:
import librosa

# Griffin-Lim must use the same STFT geometry as the forward transform in load_data.
# The stored spectrograms only keep the lowest frequency bins, so they are zero-padded
# back to N_FFT // 2 + 1 rows; otherwise librosa infers a smaller n_fft/hop from the
# row count and the reconstructed waveform has the wrong time and frequency scale.
griffinlim_kwargs = dict(n_iter=512, n_fft=N_FFT, hop_length=HOP_LENGTH, win_length=WIN_LENGTH)
freq_padding = ((0, max(0, N_FFT // 2 + 1 - original.shape[0])), (0, 0))

wf_orig = librosa.griffinlim(np.pad(original, freq_padding), **griffinlim_kwargs)
# Generated values can dip below zero after de-normalisation; magnitudes must be non-negative.
wf_gen = librosa.griffinlim(np.pad(np.clip(im_1, 0, None), freq_padding), **griffinlim_kwargs)

plt.figure(figsize=(10, 6))
plt.title("Generated waveform")
plt.xlabel("Time (samples)")  # the x axis is the sample index
plt.ylabel("Amplitude")
plt.ylim(wf_gen.min(), wf_gen.max())
plt.plot(wf_gen)

plt.figure(figsize=(10, 6))
plt.title("Original waveform")
plt.xlabel("Time (samples)")
plt.ylabel("Amplitude")
plt.ylim(wf_orig.min(), wf_orig.max())
plt.plot(wf_orig)
plt.show()