In [None]:
# Re-import edited modules (such as the local `src` package used below) before each cell runs,
# so changes to src/*.py are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch

# Pin PyTorch's default floating-point dtype to float32 for tensors created in later cells.
torch.set_default_dtype(torch.float32)

In [None]:
# Load the preprocessed data bundle and show which top-level keys it provides.
# Later cells read data_dict['grid'] (per-axis coordinates fed to meshgrid) and data_dict['S'].
data_path = 'data/summarized_neutron_data_w_bkg_260meV_ML.pt'
data_dict = torch.load(data_path)
print(data_dict.keys())

In [None]:
from src.sqw import SpecNeuralRepr

# Pretrained SpecNeuralRepr checkpoint; it is later passed to L_Kernel as `forward_model`.
# Tensors are mapped onto the CPU at load time.
sqw_ckpt_path = 'version_14896845/checkpoints/epoch=7160-step=343728.ckpt'
cpu_device = torch.device('cpu')
model_sqw = SpecNeuralRepr.load_from_checkpoint(sqw_ckpt_path, map_location=cpu_device)

In [None]:
from src.dataset import NeighborDataset
from torch.utils.data import DataLoader

In [None]:
# Dense coordinate grid built from the 1-D axis tensors in data_dict['grid'] (in dict key order):
# hklw_grid[i, j, k, l, :] holds the coordinates of node (i, j, k, l), one per axis.
# torch.stack(..., dim=-1) matches the former vstack(...).permute(1, 2, 3, 4, 0) for the
# four axes used here, but works for any number of axes and returns a contiguous tensor.
hklw_grid = torch.stack(
    torch.meshgrid(*data_dict['grid'].values(), indexing='ij'),
    dim=-1,
)

In [None]:
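# NOTE(review): commented-out earlier inline draft of L_Kernel, kept for reference only.
# The notebook now imports L_Kernel from src.kernel, whose constructor takes
# (forward_model, model_config, loss_bkg_mag_weight) instead of the arguments below.
# import torch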
# import torch.nn as nn
# import lightning
# from src.kernel import KernelNet
# from src.siren import SirenNet

# class L_Kernel(lightning.LightningModule):
#     def __init__(self, forward_model, dim=3, neighbor_range=1, exclude_corner=True):
#         super().__init__()
#         # self.save_hyperparameters()
        
#         self.dim = dim
#         self.neighbor_range = neighbor_range
#         self.exclude_corner = exclude_corner
        
#         self.kernel_net = KernelNet(
#             dim=dim, neighbor_range=neighbor_range, 
#             exclude_corner=exclude_corner)
#         self.bkgd_net = SirenNet(
#                 dim_in = dim,
#                 dim_hidden = self.kernel_net.hidden_dim,
#                 dim_out = 1,
#                 num_layers = self.kernel_net.num_layers,
#                 w0_initial = 30.,
#                 final_activation = torch.nn.ReLU()
#         )
#         self.forward_model = forward_model
        
#     def forward(self, x):
#         return self.kernel_net(x)
    
#     def compute_metrics_on_batch(self, batch):
#         kappa = self.forward(
#             batch['center_pts'].to(self.dtype).to(self.device))
#         neighb_data = self.forward_model.forward_qw(
#             batch['neighb_pts'].to(self.dtype).to(self.device))
#         s_sig = torch.einsum(
#             'ij, ij -> i', 
#             kappa, neighb_data[:,self.kernel_net.kernel_mask_flat]
#         ).unsqueeze(-1)
#         # s_pred = s_sig
#         s_bkg = self.bkgd_net(batch['center_pts'].to(self.dtype).to(self.device))
#         s_pred = s_sig + s_bkg
#         s_target = batch['center_data']
#         loss_reconst = torch.nn.functional.mse_loss(s_pred.cpu(), s_target.cpu())
#         loss_bkg_mag = 1e-2 * s_bkg.pow(2).mean()
#         # loss = loss_reconst + loss_bkg_mag
#         return loss_reconst, loss_bkg_mag
    
#     def training_step(self, batch, batch_idx):
#         loss_reconst, loss_bkg_mag = self.compute_metrics_on_batch(batch)
#         loss = loss_reconst + loss_bkg_mag
#         self.log('train_reconst', loss_reconst.item(), prog_bar=True)
#         self.log('train_bkg_mag', loss_bkg_mag.item(), prog_bar=True)
#         self.log('train_loss', loss.item(), prog_bar=True)
#         return loss
    
#     def validation_step(self, batch, batch_idx):
#         loss_reconst, loss_bkg_mag = self.compute_metrics_on_batch(batch)
#         loss = loss_reconst + loss_bkg_mag
#         self.log('val_reconst', loss_reconst.item(), prog_bar=True)
#         self.log('val_bkg_mag', loss_bkg_mag.item(), prog_bar=True)
#         self.log('val_loss', loss.item(), prog_bar=True)
    
#     def configure_optimizers(self):
#         optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
#         return optimizer

In [None]:
from src.dataset import NeighborDataset
from torch.utils.data import DataLoader

In [None]:
# Samples of grid points with their surrounding neighbourhood, drawn from the S data on hklw_grid.
# NOTE(review): neighbor_range is set here independently of model_config['neighbor_range']
# (both 3); presumably the two must agree for the kernel mask to line up — keep them in sync.
dataset = NeighborDataset(
    hklw_grid,
    data_dict['S'],
    neighbor_range=3,
)
dataloader = DataLoader(
    dataset,
    batch_size=5000,
    shuffle=True,
    num_workers=32,
)

In [None]:
from src.kernel import L_Kernel

# Weight passed to L_Kernel for the background-magnitude loss term
# (presumably replaces the hard-coded 1e-2 factor in the commented-out draft above).
loss_bkg_mag_weight = 5e-2

# Architecture / kernel settings forwarded to L_Kernel.
model_config = dict(
    dim=4,
    neighbor_range=3,
    exclude_corner=True,
    hidden_dim=256,
    num_layers=3,
    scale_factor_initial='none',
)

# Overwrite the forward model's `params` before wrapping it.
# NOTE(review): the meaning of 29.0 and 1.68 is not documented here — record their origin.
model_sqw.params = torch.tensor([29.0, 1.68])

L_model = L_Kernel(
    forward_model=model_sqw,
    model_config=model_config,
    loss_bkg_mag_weight=loss_bkg_mag_weight,
)

In [None]:
import lightning
from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar
from lightning.pytorch.strategies import DDPStrategy
from lightning.pytorch.loggers import TensorBoardLogger

# TensorBoard runs go under ./lightning_logs/sf_net/.
logger = TensorBoardLogger("./lightning_logs", name="sf_net")

# Every 10 training steps: keep the best checkpoint by train_loss plus a rolling "last" copy.
checkpoint_callback = ModelCheckpoint(
    monitor="train_loss",
    save_top_k=1,
    save_last=True,
    every_n_train_steps=10,
    filename=f"sf_{model_config['scale_factor_initial']}-{{epoch}}-{{step}}",
)

# Permit faster reduced-precision kernels (e.g. TF32) for float32 matmuls.
torch.set_float32_matmul_precision('high')

# Single-GPU training run.
# NOTE(review): sync_batchnorm is only meaningful under multi-device strategies; devices=1 here.
trainer = lightning.Trainer(
    accelerator='gpu',
    devices=1,
    max_epochs=15,
    logger=logger,
    callbacks=[checkpoint_callback, TQDMProgressBar(refresh_rate=10)],
    log_every_n_steps=1,
    sync_batchnorm=True,
    enable_checkpointing=True,
    default_root_dir='./',
)

trainer.fit(L_model, dataloader)