In [None]:
from data_processing import Dataset
from noise import NoiseScheduler
import matplotlib.pyplot as plt
import numpy as np
from torchvision import transforms
import torch
from torch.utils.data import TensorDataset, DataLoader
from diffusers.optimization import get_cosine_schedule_with_warmup
import torch.nn.functional as F
from tqdm import tqdm
import torch
from torch.optim import Adam
from pathlib import Path
import os
import numpy as np

# Report which compute backend PyTorch can see before any heavy work starts.
if not torch.cuda.is_available():
    print("CUDA is not available. Running on CPU.")
else:
    print("CUDA is available!")
    print("Number of available GPUs:", torch.cuda.device_count())
    print("Current GPU:", torch.cuda.current_device())

In [None]:
# Directory where the ROOT input files are stored. The site-specific default is
# kept, but it can be overridden without editing the notebook by exporting
# DIFFUSION_DATA_DIR before starting the kernel.
data_dir = os.environ.get(
    "DIFFUSION_DATA_DIR",
    "/cephfs/dice/users/ek19824/l1trigger/diffusion/datasets",
)

# NOTE(review): the first two positional arguments are presumably the number of
# events and the raw image dimensions — confirm against data_processing.Dataset.
dataset = Dataset(
    10000,
    (120, 72),
    signal_file=f"{data_dir}/CaloImages_signal.root",
    pile_up_file=f"{data_dir}/CaloImages_bkg.root",
    save=False,
)


In [None]:
# Invoke the Dataset object. Per the original author's note, once the result is
# cached you do not have to re-load it.
dataset()

In [None]:
# Image size handed to the dataset preprocessing step and to the model constructor.
new_dim = (64, 64)

In [None]:
# Run the dataset's own preprocessing step with the target size `new_dim`.
# NOTE(review): the meaning of the first argument (16) is not visible in this
# notebook — confirm against data_processing.Dataset.preprocess.
dataset.preprocess(16, new_dim)

In [None]:
# Conversion pipeline for the raw arrays: a single ToTensor step.
preprocess = transforms.Compose([transforms.ToTensor()])

In [None]:
# Convert both samples to float tensors and reorder the axes with permute(1, 2, 0)
# ("pytorch semantics" in the original note).
# NOTE(review): assumes dataset.signal / dataset.pile_up are stored event-first so
# that the permute undoes ToTensor's axis move — TODO confirm.
clean_frames, pile_up = (
    preprocess(frames).float().permute(1, 2, 0)
    for frames in (dataset.signal, dataset.pile_up)
)

In [None]:
batch_size = 16  # adjust as needed

# Insert a size-1 axis at position 1 of the clean frames, then batch them in
# their stored order (no shuffling).
dataloader = DataLoader(
    dataset=clean_frames.unsqueeze(1),
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
# Check tensor shape: print the shape of the first element of the first batch.
batch = next(iter(dataloader), None)
if batch is not None:
    tensor = next(iter(batch), None)
    if tensor is not None:
        print(tensor.shape)

In [None]:
from models import Model, TrainingConfig

# Build the 'UNet-lite' wrapper for the target image size and take the object
# its __getitem__() returns (presumably the underlying network — it is used
# below via .parameters()).
model = Model('UNet-lite', new_dim).__getitem__()

config = TrainingConfig(output_dir='trained_models_lite')

# Number of learnable parameters.
trainable_param_count = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
print(trainable_param_count)


In [None]:
# AdamW over all model parameters, using the learning rate from the config.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=config.learning_rate)

# The cosine schedule spans one scheduler step per batch over every configured epoch.
total_training_steps = config.num_epochs * len(dataloader)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    config.lr_warmup_steps,
    total_training_steps,
)

In [None]:

def train_loop(config, model, noise_sample, optimizer, train_dataloader, lr_scheduler,noise_scheduler, n_events):
    """Train ``model`` to predict the noise that ``noise_scheduler`` adds to clean frames.

    For every batch a timestep tensor and a random index are drawn, the noise
    scheduler produces the noisy images together with the noise it added, and
    the model is optimised with an MSE loss between its prediction and that
    noise. A checkpoint of the model's ``state_dict`` is written once per epoch.

    Args:
        config: Training configuration. Reads ``num_epochs``,
            ``num_train_timesteps`` and ``output_dir``.
        model: Called as ``model(noisy_images, timesteps)``; element ``[0]`` of
            the return value is used as the noise prediction.
        noise_sample: Forwarded unchanged to ``noise_scheduler.add_noise``.
        optimizer: Optimizer whose gradients are reset and stepped every batch.
        train_dataloader: Sized iterable yielding one tensor of clean images
            per step.
        lr_scheduler: Stepped once per batch; must expose ``get_last_lr()``.
        noise_scheduler: Object exposing ``add_noise(clean_frame=...,
            noise_sample=..., timestep=..., random_seed=..., n_events=...)``
            that returns ``(noisy_images, noise_added)``.
        n_events: Exclusive upper bound of the random index drawn each step;
            also forwarded to ``add_noise``.

    Side effects:
        Writes ``<config.output_dir>/model_epoch_<epoch>.pt`` at the end of
        every epoch, creating the directory if it does not exist.
    """
    # torch.save does not create missing parent directories.
    if config.output_dir:
        os.makedirs(config.output_dir, exist_ok=True)

    # Plain int: it is only used for progress reporting.
    global_step = 0

    # Use the configured epoch count so the loop matches the step budget the
    # LR schedule was built with (len(dataloader) * config.num_epochs).
    for epoch in range(config.num_epochs):
        # The context manager closes the bar even if a step raises.
        with tqdm(total=len(train_dataloader)) as progress_bar:
            progress_bar.set_description(f"Epoch {epoch}")

            for batch in train_dataloader:
                clean_images = batch

                # NOTE(review): for a (batch, channel, H, W) tensor this reads
                # the channel size of the first image rather than the batch
                # size. Left unchanged because add_noise's timestep contract is
                # not visible here — confirm whether clean_images.shape[0] was
                # intended.
                bs = clean_images[0].shape[0]
                timesteps = torch.randint(
                    0, config.num_train_timesteps, (bs,), device=clean_images.device
                ).long()

                # Random index in [0, n_events) handed to the noise scheduler.
                random_seed = np.random.randint(0, n_events)

                noisy_images, noise_added = noise_scheduler.add_noise(
                    clean_frame=clean_images,
                    noise_sample=noise_sample,
                    timestep=timesteps,
                    random_seed=random_seed,
                    n_events=n_events,
                )

                # Predict the noise residual
                noise_pred = model(noisy_images, timesteps)[0]
                loss = F.mse_loss(noise_pred, noise_added.float())

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                lr_scheduler.step()

                progress_bar.update(1)
                logs = {
                    "loss": loss.detach().item(),
                    "lr": lr_scheduler.get_last_lr()[0],
                    "step": global_step,
                }
                progress_bar.set_postfix(**logs)
                global_step += 1

        # One checkpoint per epoch, written after the last batch, instead of
        # rewriting the same file on every step.
        torch.save(model.state_dict(), os.path.join(config.output_dir, f"model_epoch_{epoch}.pt"))


In [None]:
from accelerate import notebook_launcher

# Positional arguments for train_loop, in signature order.
noise_scheduler = NoiseScheduler('pile-up')
n_events = torch.tensor(10000)
args = (
    config,
    model,
    pile_up,
    optimizer,
    dataloader,
    lr_scheduler,
    noise_scheduler,
    n_events,
)

# Single-process launch. Original note: will port to GPU if available
# (multi-GPU training is not possible at Bristol).
notebook_launcher(train_loop, args, num_processes=1)