# A separate notebook that loads and preprocesses the QM9 dataset.

In [1]:
from typing import Any
import torch
from torch_geometric.datasets import QM9
import torch_geometric.transforms as T

In [32]:
class NormalizedDistance(T.BaseTransform):
    """Replace ``edge_attr`` with the normalized Euclidean length of each edge.

    For every column ``(i, j)`` of ``data.edge_index`` the distance between
    ``data.pos[i]`` and ``data.pos[j]`` is computed and divided by
    ``MAX_DISTANCE``. Any existing ``data.edge_attr`` is overwritten with a
    1-D tensor holding one value per edge.
    """

    # 1.8100 is the max distance between two connected nodes in the dataset
    # (gdb_106558, edge index 10).
    MAX_DISTANCE = 1.8100

    def __call__(self, data: Any) -> Any:
        """Overwrite ``data.edge_attr`` with scaled edge lengths.

        Args:
            data: Graph object exposing ``pos`` (node coordinates) and
                ``edge_index`` (two rows of node indices).

        Returns:
            The same ``data`` object, modified in place.
        """
        sources = data.edge_index[0]
        targets = data.edge_index[1]
        # One gather + norm over all edges instead of a Python-level loop
        # that calls ``torch.dist(...).item()`` once per edge.
        offsets = data.pos[sources] - data.pos[targets]
        data.edge_attr = torch.linalg.norm(offsets, dim=-1) / self.MAX_DISTANCE
        return data

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'

In [39]:
class NormalizedInternalEnergy(T.BaseTransform):
    """Reduce ``data.y`` to a single scaled value taken from ``data.y[0, 7]``."""

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'

    def __call__(self, data: Any) -> Any:
        """Replace ``data.y`` with a one-element tensor.

        The element is ``data.y[0, 7]`` divided by a fixed constant; the
        same ``data`` object is returned after being modified in place.
        """
        # -13388.7246 is the minimum U_0 in the dataset
        min_u0 = -13388.7246
        scaled_u0 = data.y[0, 7] / min_u0
        data.y = torch.Tensor([scaled_u0])
        return data

In [47]:
class LeaveAtomicNumber(T.BaseTransform):
    """Use the contents of ``data.z`` as the node feature tensor ``data.x``."""

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'

    def __call__(self, data: Any) -> Any:
        """Point ``data.x`` at ``data.z`` and return the same ``data`` object."""
        # No copy is made: ``x`` and ``z`` refer to the same object afterwards.
        atomic_numbers = data.z
        data.x = atomic_numbers
        return data

In [48]:
# Chain the three transforms; Compose applies them in the order listed.
transform = T.Compose([
    NormalizedDistance(),
    NormalizedInternalEnergy(),
    LeaveAtomicNumber(),
])

In [49]:
# Build QM9 under 'QM9/' with the composed transform passed as
# ``pre_transform`` (per PyG docs, applied once before the processed data
# is saved to disk).
dataset = QM9(
    root='QM9/',
    pre_transform=transform,
)

Downloading https://data.pyg.org/datasets/qm9_v3.zip
Extracting QM9/raw/qm9_v3.zip
Processing...
Using a pre-processed version of the dataset. Please install 'rdkit' to alternatively process the raw data.
Done!
