# Build a reduced representation of QM9 that contains only the fields I need and does not crash my kernel.

In [2]:
import math
from itertools import product
from typing import List, Any, Set
import random
import pandas as pd
import torch
from torch import Tensor, LongTensor
from torch.utils.data import Dataset
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [5]:
class MoleculesDataset(Dataset):
    """Map-style dataset that wraps an in-memory list of molecule records.

    The annotations refer to ``Data`` by string (forward reference) so this
    class can be defined before ``Data`` exists; an unquoted name would be
    evaluated immediately and raise ``NameError``.
    """

    def __init__(self, data: List["Data"]) -> None:
        """Store ``data``; the list is kept by reference, not copied."""
        super().__init__()
        self.data = data

    def __len__(self) -> int:
        """Return the number of stored records."""
        return len(self.data)

    def __getitem__(self, idx: int) -> "Data":
        """Return the record at position ``idx``."""
        return self.data[idx]

In [6]:
class Data:
    """Container for one molecule: features ``x``, positions ``pos``, target ``y``.

    The length of a ``Data`` object is the size of the first dimension of
    ``x`` (presumably one row per atom -- confirm against the caller).
    """

    def __init__(self, x: Tensor, pos: Tensor, y: Tensor) -> None:
        """Store the three tensors by reference."""
        self.x = x
        self.pos = pos
        self.y = y

    def __len__(self) -> int:
        """Return the size of the first dimension of ``self.x``."""
        # Must read the instance attribute; a bare ``x`` would pick up
        # whatever global ``x`` happens to exist (or raise NameError).
        return self.x.size(0)

    def __str__(self) -> str:
        """Return a one-line summary of the tensor shapes and the target."""
        # The class has no ``e`` attribute, so only x, pos and y are reported.
        return f"x: {self.x.size()} | pos: {self.pos.size()} | y = {self.y}"

In [7]:
# Load a subset of QM9 from text files and build one Data record per molecule.
path_head = 'QM9/'
# path_head = 'baby_QM9/'

# Row limits for the node-level and graph-level files. Set both to None to
# read the complete files (pandas treats nrows=None as "no limit").
# NOTE(review): 157492 node rows presumably cover exactly the first 10000
# graphs -- confirm before changing one limit without the other.
node_row_limit = 157492
graph_row_limit = 10000

node_features_df = pd.read_csv(path_head + 'node_attributes.txt', header=None, nrows=node_row_limit)
node_features = torch.tensor(node_features_df.values)

graph_features_df = pd.read_csv(path_head + 'Y.txt', header=None, nrows=graph_row_limit)
graph_features = torch.tensor(graph_features_df.values)

# Column 5 of the node table is used as the atomic number.
atomic_numbers = node_features[:, 5].long()

# The last three node columns are used as the position vector.
positions = node_features[:, -3:]

# Column 7 of the graph table is the regression target; it is L2-normalized
# over the loaded graphs (dim=0), so values depend on which rows were read.
internal_energies = graph_features[:, 7]
internal_energies_normalized = torch.nn.functional.normalize(internal_energies, dim=0)
internal_energies_normalized_list = internal_energies_normalized.tolist()

node_indicators_df = pd.read_csv(path_head + 'graph_indicator.txt', header=None, nrows=node_row_limit)
node_indicators = node_indicators_df.values.tolist()

# Row offsets where a new graph begins: 0, every row whose indicator differs
# from the previous row, and finally the total row count as the end sentinel.
new_graph_node_indices = [0]
new_graph_node_indices.extend(
    ix
    for ix in range(1, len(node_indicators))
    if node_indicators[ix] != node_indicators[ix - 1]
)
new_graph_node_indices.append(len(node_indicators))

molecules_list = []
for mol_ix, (start, end) in enumerate(zip(new_graph_node_indices, new_graph_node_indices[1:])):
    # Slice first, then clone: cloning the full tensor per molecule copies
    # every row each iteration and leaves each record holding a view that
    # keeps a full-size copy alive.
    x = atomic_numbers[start:end].clone().view(-1, 1)
    # Index instead of pop(0): O(1) per molecule and the list stays intact.
    # Raises IndexError if there are more molecules than loaded targets.
    y = torch.Tensor([internal_energies_normalized_list[mol_ix]])
    pos = positions[start:end].clone()
    molecule = Data(x=x, pos=pos, y=y)
    molecules_list.append(molecule)

In [9]:
# Display the x tensor of the first molecule as a quick sanity check.
molecules_list[0].x

tensor([[6],
        [1],
        [1],
        [1],
        [1]])

In [28]:
# Serialize every molecule to one line of the form
#   [a1,a2,...][[x1,y1,z1],[x2,y2,z2],...][y]
# where a_i are the entries of m.x, each bracketed triple is one row of m.pos,
# and y is the scalar target.
with open('Zebra/QM9.csv', 'w') as file:
    for m in molecules_list:
        # join() places separators only between items, so there is no trailing
        # comma to trim (trimming ate the opening bracket when a molecule had
        # no rows).
        atoms = ','.join(str(a.item()) for a in m.x)
        coords = ','.join(
            f'[{p[0].item()},{p[1].item()},{p[2].item()}]' for p in m.pos
        )
        string_builder = f'[{atoms}][{coords}][{m.y.item()}]\n'
        # Echo each line to the notebook output as well as the file.
        print(string_builder)
        file.write(string_builder)
    # No explicit close(): the with-statement closes the file on exit.

[6,1,1,1,1][[-0.012699999846518,1.085800051689148,0.0080000003799796],[0.002199999988079,-0.006000000052154,0.0020000000949949],[1.0117000341415403,1.4637999534606934,0.0003000000142492],[-0.5407999753952026,1.4474999904632568,-0.8766000270843506],[-0.5238000154495239,1.4378999471664429,0.9064000248908995]][-0.019301854074001312]

[7,1,1,1][[-0.0403999984264373,1.0240999460220337,0.0626000016927719],[0.0173000004142522,0.0125000001862645,-0.0274000000208616],[0.9157999753952026,1.3587000370025637,-0.0287999995052814],[-0.5202999711036682,1.343500018119812,-0.7754999995231628]][-0.02695363573729992]

[8,1,1][[-0.034400001168251,0.977500021457672,0.0076000001281499],[0.0648000016808509,0.0206000003963708,0.0015000000130385],[0.8718000054359436,1.3007999658584597,0.0006999999750405]][-0.03643259033560753]

[6,6,1,1][[0.5995000004768372,0.0,1.0],[-0.5995000004768372,0.0,1.0],[-1.6615999937057495,0.0,1.0],[1.6615999937057495,0.0,1.0]][-0.036863524466753006]

[6,7,1][[-0.0132999997586011,1.1

tensor([1])
