## Evaluate the Equivariant Diffusion Model

### Loading dependencies and datasets

In [None]:
import sys
sys.path.append('../../..')
from probai.src.data.mini_qm9 import MiniQM9Dataset
from torch_geometric.loader import DataLoader 
from probai.src.models.ddpm import DDPM
from probai.src.models.egnn import EGNNScore
from probai.src.training.training_loop import Trainer
from probai.src.evaluation.evaluator import Evaluator
import torch
import yaml


In [None]:

# Load the validation split from disk and wrap it in a DataLoader (no shuffling).
valid_pickle_path = "../../raw_data/mini_qm9_valid.pickle"
dataset_valid = MiniQM9Dataset(file_path=valid_pickle_path)
loader_valid = DataLoader(dataset_valid, batch_size=128, shuffle=False)

### Loading models from previous checkpoint

In [None]:
# Initialize the EGNN score network from the default configuration file.
# `config` is kept at notebook scope because later cells read other sections of it.
with open("../../configs/default_config.yml", 'r') as config_file:
    config = yaml.safe_load(config_file)

# EGNN hyperparameters
egnn_config = config['EGNN']
hidden_nf, n_layers = egnn_config['hidden_nf'], egnn_config['n_layers']

score = EGNNScore(
    in_node_nf=5 + 1,  # 5 for the one-hot atom encoding, 1 for the diffusion time
    hidden_nf=hidden_nf,
    n_layers=n_layers,
    out_node_nf=5,  # 5 atom types in QM9
)


In [None]:
# Wrap the EGNN score network in a DDPM and load the saved checkpoint.
ddpm_config = config['DDPM']
N = ddpm_config['N']  # Number of noise levels (default set to 100)
ddpm = DDPM(noise_schedule_type="linear", model=score, N=N)
trainer = Trainer(ddpm)
trainer.load_checkpoint("../../checkpoints/egnn_checkpoint.pth")

### Generate samples

In [None]:
# Generate one batch of samples (same size as loader_valid.batch_size).
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the cell does not fail on machines without a GPU.
evaluator = Evaluator(ddpm, loader_valid=loader_valid)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
x, h, ptr = evaluator.sample_batch(device=device)

### Evaluate molecule and atom stability

<small> For a quickly trained model we expect decent atom stability (over 50%) and 0% molecule stability. The molecule stability is so low because a single unstable atom makes the whole molecule unstable. Therefore, molecule stability is only achievable after long training runs, when atom stability reaches 85% ~ 100%. 
However, some generated structures may still look qualitatively good even if they contain a wrong bond. </small>

In [None]:
# Compute atom-level and molecule-level stability of the generated batch.
# A model trained for only a few epochs should show good atom stability and
# very low molecule stability; high molecule stability requires longer training.

stability = evaluator.eval_stability(x, h, ptr)
atom_st, mol_st = stability
print(f"Atom stability: {atom_st} \t Molecule Stability {mol_st}")


In [None]:
# Plot some of the samples generated with the trained model (at most 10 plots).
evaluator.eval_plot(
    x,
    h,
    ptr,
    max_num_plots=10,
)


### Additional task
If you are curious how the samples would look in the Gaussian domain, without mapping them to the correct distribution, you can add the following two lines before they get plotted:
<code>
h=torch.randn(h.shape)  
x=torch.randn(x.shape)
</code> 

On the other hand, if you want to use a model previously trained for 200 epochs you can use:  
<code>trainer.load_checkpoint("../../checkpoints/egnn_checkpoint_instructors.pth")</code>  
