In [1]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import TQDMProgressBar  # NOTE(review): not used in the cells shown

import torch
from torch.utils.data import TensorDataset, DataLoader
from src.model_specpred import SpectrumPredictor  # project-local model definition
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt

# Make float64 the default floating-point dtype: floating tensors created without an
# explicit dtype from here on (e.g. parameters of newly constructed nn layers) are double precision.
torch.set_default_dtype(torch.float64)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the CrI3 dataset; it is indexed below by the keys 'param', 'omega' and 'inten'.
data = torch.load("data/CrI3/20221110.pt")

# Inputs: the first two columns of 'param'.
# Targets: the 'omega' columns followed by the 'inten' columns, concatenated column-wise.
X = data['param'][:, :2]
Y = torch.cat((data['omega'], data['inten']), dim=1)

# Fixed-seed split: 10% for training; the remaining 90% is halved into
# validation and test sets.
SPLIT_SEED = 42
(X_train, X_val_test,
 Y_train, Y_val_test) = train_test_split(X, Y, test_size=0.9, random_state=SPLIT_SEED)
(X_val, X_test,
 Y_val, Y_test) = train_test_split(X_val_test, Y_val_test, test_size=0.5, random_state=SPLIT_SEED)

# Echo the first rows of each split so a later run can confirm the split is unchanged.
print("print some values for further reference:")
for split_name, split_X in [("training", X_train), ("validation", X_val), ("testing", X_test)]:
    print(f"{split_name}:\n", split_X[:5])

BATCH_SIZE = 32
train_dataset = TensorDataset(X_train, Y_train)
val_dataset = TensorDataset(X_val, Y_val)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)  # reshuffled each epoch
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

print some values for further reference:
training:
 tensor([[-1.0158, -0.0140],
        [-2.4979, -0.1462],
        [-2.8272, -0.4800],
        [-1.4898, -0.6226],
        [-1.1844, -0.0254]])
validation:
 tensor([[-1.4781, -0.1038],
        [-1.9327, -0.3722],
        [-1.3576, -0.7201],
        [-2.7287, -0.8826],
        [-2.1087, -0.1560]])
testing:
 tensor([[-2.7746, -0.3517],
        [-2.9461, -0.6775],
        [-1.6326, -0.5627],
        [-2.0043, -0.0640],
        [-1.9798, -0.2081]])


In [3]:
# Seed Python / NumPy / torch RNGs before building the model, so that weight
# initialisation and the training DataLoader's shuffling (which draws from the
# global torch RNG when iterated) are reproducible across Restart & Run All.
pl.seed_everything(42)

# One input feature per column kept in X. num_mode=2 is presumably the number of
# spectral modes; the plots below show two omega and two S target columns.
model_spec = SpectrumPredictor(num_param_in=X.shape[1], num_mode=2)

In [4]:
from pytorch_lightning.callbacks import ModelCheckpoint

# Checkpoint policy: track "val_loss", keep only the single best checkpoint,
# also write a "last" checkpoint, and evaluate at validation time rather than
# at the end of each training epoch.
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    save_top_k=1,
    save_last=True,
    save_on_train_epoch_end=False,
)

# Single-GPU run; logs and checkpoints go under training_logs/.
trainer = pl.Trainer(
    default_root_dir="training_logs",
    accelerator="gpu",
    devices=1,
    max_epochs=10000,
    log_every_n_steps=2,
    enable_checkpointing=True,
    callbacks=[checkpoint_callback],
)

trainer.fit(model_spec, train_dataloader, val_dataloader)

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type       | Params
--------------------------------------
0 | fc_net | Sequential | 42.2 K
--------------------------------------
42.2 K    Trainable params
0         Non-trainable params
42.2 K    Total params
0.169     Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 9999: 100%|██████████| 19/19 [00:00<00:00, 49.50it/s, loss=0.00133, v_num=5] 

`Trainer.fit` stopped: `max_epochs=10000` reached.


Epoch 9999: 100%|██████████| 19/19 [00:00<00:00, 49.29it/s, loss=0.00133, v_num=5]


In [None]:
# `load_from_checkpoint` is a classmethod that RETURNS a new model instance; it does
# not load weights in place. The result must be assigned, otherwise `model_spec`
# silently keeps its last-epoch weights.
# The init kwargs are passed explicitly, so loading also works when the model does not
# store its hyperparameters in the checkpoint.
# To use the best checkpoint from the fit above instead, set
# CKPT_PATH = checkpoint_callback.best_model_path
CKPT_PATH = "training_logs/lightning_logs/version_52/checkpoints/epoch=713-step=17850.ckpt"
model_spec = SpectrumPredictor.load_from_checkpoint(
    CKPT_PATH, num_param_in=X.shape[1], num_mode=2
)
model_spec.eval()

In [None]:
# Predict the whole validation set in one batched forward pass. This matches the
# batched inputs the DataLoaders supply during training, instead of feeding one
# 1-D sample at a time.
# eval() disables train-only behaviour such as dropout, if the model has any.
# no_grad() avoids building an autograd graph.
model_spec.eval()
with torch.no_grad():
    Y_val_pred = model_spec(X_val.to(model_spec.device)).cpu()

# The parity plots index Y_val_pred[:, i] against Y_val[:, i], so the shapes must agree.
assert Y_val_pred.shape == Y_val.shape, (Y_val_pred.shape, Y_val.shape)

In [None]:
# Parity plots: true vs. predicted validation values for each of the four target columns.
# Raw strings: "\o" is an invalid escape in a normal string literal
# (DeprecationWarning, and a SyntaxWarning on Python 3.12+).
labels = [r'$\omega_1$', r'$\omega_2$', r'$S_1$', r'$S_2$']

fig, axes = plt.subplots(2, 2, figsize=(5, 5))
for i, (ax, label) in enumerate(zip(axes.flat, labels)):
    y_true = Y_val[:, i]
    y_pred = Y_val_pred[:, i]

    # For non-negative data this is the original window [-0.1*max, 1.1*max].
    # It widens to include negative true values instead of clipping them, and
    # falls back to a unit pad when the column is constant.
    lo = min(y_true.min().item(), 0.0)
    hi = y_true.max().item()
    pad = 0.1 * (hi - lo) if hi > lo else 1.0
    lims = (lo - pad, hi + pad)

    # y = x reference line spanning exactly the visible window.
    ax.plot(lims, lims, 'k', linewidth=0.5)
    ax.scatter(y_true, y_pred, s=5)
    ax.set_aspect('equal')
    ax.set_xlim(lims)
    ax.set_ylim(lims)
    ax.set_xlabel(f"True {label}")
    ax.set_ylabel(f"Predicted {label}")

fig.tight_layout()

In [None]:
# Single-sample sanity check on the first validation point. The input is moved to
# the model's device, and no autograd graph is built, so the displayed tensor has no grad_fn.
x_val = X_val[0]
with torch.no_grad():
    y_val_pred = model_spec(x_val.to(model_spec.device)).cpu()

In [None]:
# Display the model's prediction for the first validation sample (X_val[0]).
y_val_pred