In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms

from data_processing import Dataset
from noise import NoiseScheduler


In [None]:
# Directory holding CaloImages_signal.root / CaloImages_bkg.root.
# Set the DIFFUSION_DATA_DIR environment variable to point elsewhere; the default
# is the original author's cluster path, so existing runs behave unchanged.
data_dir = os.environ.get(
    "DIFFUSION_DATA_DIR",
    "/cephfs/dice/users/ek19824/l1trigger/diffusion/datasets",
)

# NOTE(review): 1_000 and (120, 72) look like the number of events and the raw
# image shape (1000 also reappears as n_events in the noising cell) — confirm
# against Dataset and consider naming them as constants.
dataset = Dataset(
    1_000,
    (120, 72),
    signal_file=f"{data_dir}/CaloImages_signal.root",
    pile_up_file=f"{data_dir}/CaloImages_bkg.root",
    save=False,
)

In [None]:
# Presumably reads the signal and pile-up ROOT files configured above.
# Original note: once this is cached, you don't have to re-load.
# NOTE(review): Dataset was constructed with save=False — confirm where (or whether) the cache is written.
dataset()

In [None]:
# Target 2-D frame shape: passed to dataset.preprocess and used to reshape model
# inputs/outputs back into images for plotting.
new_dim = (64, 64)

In [None]:
# Presumably resizes/prepares the loaded frames to `new_dim` (they are reshaped to
# new_dim for plotting later on).
# NOTE(review): the meaning of the positional argument 16 is not visible here —
# confirm in Dataset.preprocess and pass it by keyword or as a named constant.
dataset.preprocess(16, new_dim)

In [None]:
# Conversion pipeline applied to the raw frame arrays. Currently just ToTensor;
# kept as a Compose so further transforms can be appended without touching callers.
preprocess = transforms.Compose([transforms.ToTensor()])

In [None]:
def _to_frame_batch(frames):
    """Convert a stack of frames into a float tensor with a singleton channel axis.

    ToTensor interprets a 3-D ndarray as (H, W, C) and returns (C, H, W);
    ``permute(1, 2, 0)`` undoes that axis shuffle, and ``unsqueeze(1)`` inserts a
    channel dimension to match PyTorch's (N, C, H, W) convention.
    NOTE(review): assumes ``frames`` is (n_frames, H, W) — TODO confirm against Dataset.
    """
    frames_tensor = preprocess(frames).float()
    return frames_tensor.permute(1, 2, 0).unsqueeze(1)


clean_frames = _to_frame_batch(dataset.signal)
pile_up = _to_frame_batch(dataset.pile_up)

In [None]:
i = 10  # index of the clean frame to noise (also plotted at the end of the notebook)
noise_scheduler = NoiseScheduler('pile-up')

# Per the original author, each unit of timestep represents an additional 5 pile-up events.
timestep = torch.tensor([10], dtype=torch.long)
random_seed = 42

noisy_image, noise = noise_scheduler.add_noise(
    clean_frame=clean_frames[i],
    noise_sample=pile_up,
    timestep=timestep,
    random_seed=random_seed,
    n_events=1000,  # NOTE(review): same value as the Dataset size above — confirm they must match
)


In [None]:

# Show the pile-up-noised version of frame i as a 2-D image.
noisy_2d = noisy_image.reshape(new_dim)
noisy_mappable = plt.imshow(noisy_2d)
plt.colorbar(noisy_mappable)
plt.title("Noised Image")

In [None]:
from models import Model, UNetLite, UNetLite_hls  # NOTE(review): UNetLite is unused in this notebook

# Architecture selector: 'UNet2d' -> Model('UNet', new_dim), 'UNet_lite' -> UNetLite_hls().
modtype = 'UNet_lite'

if modtype == 'UNet2d':
    model = Model('UNet', new_dim)
elif modtype == 'UNet_lite':
    model = UNetLite_hls()
else:
    # Fail fast: otherwise `model` is undefined (NameError below) or silently
    # left over from an earlier run of the kernel.
    raise ValueError(f"Unknown modtype {modtype!r}; expected 'UNet2d' or 'UNet_lite'")

print(model)

In [None]:
# For 'UNet2d', replace `model` with whatever Model.__getitem__() returns; other
# architectures are left as-is.
# NOTE(review): presumably this unwraps the underlying network from the Model
# wrapper — confirm in models.py.
model = model.__getitem__() if modtype == 'UNet2d' else model

In [None]:
# Checkpoint matching the architecture selected by `modtype`.
if modtype == 'UNet2d':
    trained_model_path = '/hdfs/user/ys20884/hackathon/trained_models/trained_diffusor.pt'
elif modtype == 'UNet_lite':
    trained_model_path = 'trained_models_lite/model_epoch_9.pt'
else:
    # Without this, a stale trained_model_path from a previous run could be
    # loaded into the wrong architecture.
    raise ValueError(f"Unknown modtype {modtype!r}; expected 'UNet2d' or 'UNet_lite'")

print(trained_model_path)

In [None]:
# Load the saved weights onto the CPU: the rest of the notebook converts tensors
# with .numpy(), and a checkpoint saved from a GPU would otherwise fail to load on
# a CPU-only machine. load_state_dict copies values into the model's own
# parameters, so this is safe either way.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load(trained_model_path, map_location="cpu")

In [None]:
# Copy the loaded weights into the model. Assuming `model` is a torch nn.Module,
# the default strict=True raises on missing/unexpected keys, and the returned
# (missing_keys, unexpected_keys) result is what this cell displays.
model.load_state_dict(checkpoint)

In [None]:
# Inference only: switch off training-time behaviour (dropout, batch-norm
# statistic updates) and skip autograd bookkeeping.
model.eval()
with torch.no_grad():
    # unsqueeze(1) adds the channel axis expected by the network.
    # NOTE(review): assumes noisy_image is (1, H, W) -> (1, 1, H, W); confirm.
    if modtype == 'UNet2d':
        noise_pred = model(noisy_image.unsqueeze(1), timestep, return_dict=False)[0]
    elif modtype == 'UNet_lite':
        noise_pred = model(noisy_image.unsqueeze(1), timestep)
    else:
        raise ValueError(f"Unknown modtype {modtype!r}; expected 'UNet2d' or 'UNet_lite'")

# .cpu() makes the NumPy conversion work even if the model was moved to a GPU.
ims = plt.imshow(noise_pred.detach().cpu().numpy().reshape(new_dim))
plt.colorbar(ims)
plt.title("Noise Prediction")

In [None]:
# Denoised estimate = noisy input minus the predicted noise.
# Both operands are converted explicitly to CPU NumPy arrays of shape new_dim, so
# the result has one predictable type. This avoids relying on implicit
# torch-tensor/NumPy mixing, and a shape mismatch raises instead of silently broadcasting.
# NOTE(review): assumes noisy_image is (1, H, W), so noisy_image[0] is the 2-D frame.
noisy_frame = noisy_image[0].detach().cpu().numpy().reshape(new_dim)
predicted_noise_2d = noise_pred.detach().cpu().numpy().reshape(new_dim)
de_noised = noisy_frame - predicted_noise_2d

In [None]:

# Show the denoised frame (noisy input minus predicted noise).
denoised_2d = de_noised.reshape(new_dim)
im = plt.imshow(denoised_2d)
plt.colorbar(im)
plt.title("Denoised Image")

In [None]:
# Ground-truth frame i (the clean_frame fed to add_noise), for comparison with the denoised image.
clean_2d = clean_frames[i].squeeze()
im2 = plt.imshow(clean_2d)
plt.colorbar(im2)
plt.title("Clean Image")