In [7]:
import numpy as np
import h5py

import torch
from torch.utils.data import Dataset

# Per-element lookup tables, keyed by atomic number (the value stored in the
# last column of each `xyz` row). The row order below fixes the class indices.
_ELEMENT_TABLE = (
    # (atomic number, symbol, covalent radius)
    (1, 'H', 0.32),
    (6, 'C', 0.77),
    (7, 'N', 0.75),
    (8, 'O', 0.73),
    (9, 'F', 0.71),
)

# Covalent radii used for distance-based bond detection; assumed to share the
# length unit of the stored coordinates (presumably Å — confirm).
COVALENT_RADII = {z: radius for z, _, radius in _ELEMENT_TABLE}

# Atomic number -> contiguous class index (0, 1, ...).
ELEMENT_TO_INDEX = {z: index for index, (z, _, _) in enumerate(_ELEMENT_TABLE)}

# NOTE: despite its name, this maps atomic number -> element symbol (same keys
# as the two tables above), not class index -> symbol.
INDEX_TO_SYMBOL = {z: symbol for z, symbol, _ in _ELEMENT_TABLE}

In [8]:
class AFMData(Dataset):
    """AFM image / molecule-graph dataset backed by a single HDF5 file.

    The file holds datasets ``'x'`` (images) and ``'xyz'`` (per-atom rows whose
    first three columns are positions and whose last column is the atomic
    number, 0 marking padding). The first ``train_size`` fraction of samples is
    the training split; the remainder is the validation split.
    """

    def __init__(self, data_path, transform, train_size=0.8, split='train'):
        self.data_path = data_path
        self.transform = transform
        self.split = split

        # Only the sample count is read here; the file is reopened per item, so
        # the dataset never holds an open HDF5 handle.
        with h5py.File(self.data_path, 'r') as f:
            total_length = f['x'].shape[0]
        self.train_length = int(train_size * total_length)
        self.val_length = total_length - self.train_length

    def __len__(self):
        return self.train_length if self.split == 'train' else self.val_length

    @staticmethod
    def _bond_edges(positions, radii, tolerance=1.2):
        """Return an ``(E, 2)`` int64 array of atom pairs ``i < j`` considered bonded.

        A pair is bonded when its distance is below
        ``tolerance * (radii[i] + radii[j])``. Pairs come out in the same order a
        nested ``for i: for j > i`` loop would produce them.
        """
        positions = np.asarray(positions)
        radii = np.asarray(radii, dtype=float)
        dists = np.linalg.norm(positions[:, None, :] - positions[None, :, :], axis=-1)
        cutoffs = tolerance * (radii[:, None] + radii[None, :])
        # Strict upper triangle keeps each unordered pair once and drops i == j.
        rows, cols = np.nonzero(np.triu(dists < cutoffs, k=1))
        return np.stack([rows, cols], axis=1).astype(np.int64)

    @staticmethod
    def _normalize_coords(positions, low=0.25, high=0.75):
        """Affinely map all coordinates into ``[low, high]``.

        Uses one global min/max over every axis, so the molecule's aspect ratio
        is preserved.
        """
        positions = np.asarray(positions)
        if positions.size == 0:
            return positions.copy()
        lo = positions.min()
        span = positions.max() - lo
        if span == 0:
            # Single atom (or identical coordinates): avoid 0/0 and place
            # everything at `low`.
            span = 1
        return low + (high - low) * (positions - lo) / span

    def __getitem__(self, idx):
        """Return ``(idx, image, sample)``.

        ``idx`` is the absolute row in the file. ``image`` is channels-first
        when it is 3-D. ``sample`` is ``{'coords': xyz, 'edges': edges}``, passed
        through ``transform`` if one was given.
        """
        # Validation samples follow the training block in the file.
        if self.split != 'train':
            idx += self.train_length

        with h5py.File(self.data_path, 'r') as f:
            x = f['x'][idx]
            xyz = f['xyz'][idx]

        # Remove padding atoms (atomic number 0).
        xyz = xyz[xyz[:, -1] > 0]

        # Bonds are detected on the raw positions; assumes the covalent radii
        # share the coordinates' length unit.
        radii = [COVALENT_RADII[z] for z in xyz[:, -1]]
        edges = self._bond_edges(xyz[:, :3], radii)

        # Normalize positions to [0.25, 0.75].
        xyz[:, :3] = self._normalize_coords(xyz[:, :3])

        # Map atomic numbers to class indices (0, 1, ...).
        xyz[:, -1] = [ELEMENT_TO_INDEX[z] for z in xyz[:, -1]]

        sample = {'coords': xyz, 'edges': edges}

        if self.transform:
            sample = self.transform(sample)

        # Keep all channels from HDF5 and convert to [C, H, W].
        x = torch.from_numpy(x)
        if x.dim() == 4 and x.size(0) == 1:
            x = x.squeeze(0)  # [1, H, W, C] -> [H, W, C]
        if x.dim() == 3:
            x = x.permute(2, 0, 1).contiguous()  # [H, W, C] -> [C, H, W]

        return idx, x, sample


def get_datasets(data_path, train_transform=None, val_transform=None, train_size=0.8):
    """Build the ``(train, val)`` pair of ``AFMData`` views over one HDF5 file.

    Both views share ``data_path`` and ``train_size``, so their index ranges are
    complementary; each view gets its own transform.
    """
    splits = (('train', train_transform), ('val', val_transform))
    return tuple(
        AFMData(data_path, transform=split_transform, train_size=train_size, split=split_name)
        for split_name, split_transform in splits
    )


def afm_collate_fn(batch):
    """Collate ``(idx, image, sample)`` triples into one batch.

    Images are stacked into a single tensor, so they must share a shape.
    Per-molecule coordinates and edge lists vary in length and are returned as
    lists of tensors.

    Returns:
        ids: list of the indices from each item.
        images: tensor of shape ``[B, *image.shape]``.
        sample: dict with ``'coords'`` and ``'edges'``, each a list of B tensors.
    """
    ids = [item[0] for item in batch]
    images = torch.stack([item[1] for item in batch])
    # as_tensor shares memory with numpy input (like from_numpy) but also
    # accepts values a transform may already have turned into tensors.
    sample = {
        'coords': [torch.as_tensor(item[2]['coords']) for item in batch],
        'edges': [torch.as_tensor(item[2]['edges']) for item in batch],
    }
    return ids, images, sample

In [9]:
# The handle stays open because later cells index into `afm_data`.
afm_data = h5py.File("/scratch/phys/project/sin/hackathon/data/afm.h5", 'r')
for dataset_name, dataset in afm_data.items():
    print(f"{dataset_name}: {dataset}")

sw: <HDF5 dataset "sw": shape (59392, 1, 2, 3), type "<f4">
x: <HDF5 dataset "x": shape (59392, 1, 128, 128, 10), type "<f4">
xyz: <HDF5 dataset "xyz": shape (59392, 54, 5), type "<f4">


In [10]:
# First molecule: keep real atoms (non-zero atomic number in the last column)
# and translate their atomic numbers into element symbols.
first_molecule = afm_data['xyz'][0]
mask = first_molecule[:, -1] > 0
atomic_numbers = first_molecule[mask, -1]
atomtok = [INDEX_TO_SYMBOL[z] for z in atomic_numbers]
print(atomtok)

['F', 'O', 'N', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H']


In [27]:
# Quantization grid for atom coordinates: `coord_bins` evenly spaced values in
# [0, 1], rounded to `coord_decimals` decimals (presumably so each value prints
# compactly inside a token). Rounding makes the spacing slightly non-uniform.
coord_bins = 64
coord_decimals = 2

coord_bin_values = np.round(np.linspace(0, 1, coord_bins), coord_decimals)
# Too many bins for the chosen precision would silently merge neighbours.
assert np.unique(coord_bin_values).size == coord_bins, "rounding merged neighbouring bins; raise coord_decimals"
print(coord_bin_values)
coord_tokens = [f'<COORD_{i}>' for i in range(coord_bins)]


[0.   0.02 0.03 0.05 0.06 0.08 0.1  0.11 0.13 0.14 0.16 0.17 0.19 0.21
 0.22 0.24 0.25 0.27 0.29 0.3  0.32 0.33 0.35 0.37 0.38 0.4  0.41 0.43
 0.44 0.46 0.48 0.49 0.51 0.52 0.54 0.56 0.57 0.59 0.6  0.62 0.63 0.65
 0.67 0.68 0.7  0.71 0.73 0.75 0.76 0.78 0.79 0.81 0.83 0.84 0.86 0.87
 0.89 0.9  0.92 0.94 0.95 0.97 0.98 1.  ]


In [42]:
def snap_to_bins(values, bin_values):
    """Replace every entry of ``values`` by its nearest entry in 1-D ``bin_values``.

    Values outside the bin range snap to the end bins; ties go to the first
    matching bin (``np.argmin`` semantics). The result is flattened to 1-D.
    """
    values = np.asarray(values)
    nearest = np.argmin(np.abs(bin_values.reshape(1, -1) - values.reshape(-1, 1)), axis=1)
    return bin_values[nearest]


# Explicit demo input (previously `coord` leaked from another cell): components
# outside [0, 1] clamp to the end bins.
coord = np.array([1.2, -0.3, 0.02])
new_coord = snap_to_bins(coord, coord_bin_values)
print(new_coord)

[1.   0.   0.02]


In [46]:
# Snap every real atom of the first molecule onto the coordinate grid and print
# one atom both snapped and raw for comparison.
# NOTE: these are raw (un-normalized) positions, so components outside [0, 1]
# clamp to the end bins.
coords = afm_data['xyz'][0, mask, :3]
coord_tokens, coordo_tokens = [], []
for symbol, coord in zip(atomtok, coords):
    nearest_bin = np.abs(coord_bin_values[np.newaxis, :] - coord[:, np.newaxis]).argmin(axis=1)
    new_coord = coord_bin_values[nearest_bin]
    coord_tokens.append(f'{symbol}: {new_coord[0]}, {new_coord[1]}, {new_coord[2]},')
    coordo_tokens.append(f'{symbol}: {coord[0]}, {coord[1]}, {coord[2]},')
print(coord_tokens[1])
print(coordo_tokens[1])

O: 0.0, 0.25, 0.0,
O: -0.3661240339279175, 0.25187012553215027, -0.0004621819534804672,


In [36]:
# Shape of the (coordinate component x bin) difference matrix behind the nearest-bin lookup.
np.subtract.outer(coord.reshape(-1), coord_bin_values).shape

(3, 64)

array([0., 0., 0.])

In [48]:
class AFMDataset(Dataset):
    """AFM images paired with tokenized molecule descriptions, read from one HDF5 file.

    The file holds ``'x'`` (images) and ``'xyz'`` (per-atom rows: positions in
    columns 0-2, atomic number in the last column, 0 marking padding). The first
    ``train_size`` fraction of samples is the training split; the rest is
    validation.
    """

    def __init__(self, data_path, transform, train_size=0.8, split='train'):
        self.data_path = data_path
        self.transform = transform
        self.split = split

        # Only the sample count is read here; items reopen the file on demand.
        with h5py.File(self.data_path, 'r') as f:
            total_length = f['x'].shape[0]
        self.train_length = int(train_size * total_length)
        self.val_length = total_length - self.train_length

    def __len__(self):
        return self.train_length if self.split == 'train' else self.val_length

    @staticmethod
    def _bond_edges(positions, radii, tolerance=1.2):
        """Return an ``(E, 2)`` int64 array of atom pairs ``i < j`` considered bonded.

        A pair is bonded when its distance is below
        ``tolerance * (radii[i] + radii[j])``; pairs are in nested-loop order.
        """
        positions = np.asarray(positions)
        radii = np.asarray(radii, dtype=float)
        dists = np.linalg.norm(positions[:, None, :] - positions[None, :, :], axis=-1)
        cutoffs = tolerance * (radii[:, None] + radii[None, :])
        # Strict upper triangle keeps each unordered pair once and drops i == j.
        rows, cols = np.nonzero(np.triu(dists < cutoffs, k=1))
        return np.stack([rows, cols], axis=1).astype(np.int64)

    @staticmethod
    def _normalize_coords(positions, low=0.25, high=0.75):
        """Affinely map all coordinates into ``[low, high]`` with one global min/max."""
        positions = np.asarray(positions)
        if positions.size == 0:
            return positions.copy()
        lo = positions.min()
        span = positions.max() - lo
        if span == 0:
            # Single atom (or identical coordinates): avoid 0/0.
            span = 1
        return low + (high - low) * (positions - lo) / span

    @staticmethod
    def _snap_to_bins(values, bin_values):
        """Replace every entry of ``values`` by its nearest entry in 1-D ``bin_values``.

        Keeps the shape of ``values``. Out-of-range values snap to the end bins,
        and ties go to the first matching bin (``np.argmin`` semantics).
        """
        values = np.asarray(values)
        bin_values = np.asarray(bin_values)
        nearest = np.argmin(np.abs(values[..., None] - bin_values), axis=-1)
        return bin_values[nearest]

    def __getitem__(self, idx):
        """Return ``(idx, image, ref)`` for one sample.

        ``idx`` is the absolute row in the file. ``image`` is channels-first when
        3-D. ``ref`` holds ``'atomtok'`` (element symbols), ``'edges'`` (bonded
        index pairs), and ``'atomtok_coords'`` / ``'chartok_coords'``
        (``"<symbol>: x, y, z,"`` strings with coordinates snapped to
        ``coord_bin_values``).
        """
        # Validation samples follow the training block in the file.
        if self.split != 'train':
            idx += self.train_length

        with h5py.File(self.data_path, 'r') as f:
            x = f['x'][idx]
            xyz = f['xyz'][idx]

        # Remove padding atoms (atomic number 0).
        xyz = xyz[xyz[:, -1] > 0]

        # Bonds from raw positions; assumes radii share the coordinates' length unit.
        radii = [COVALENT_RADII[z] for z in xyz[:, -1]]
        edges = self._bond_edges(xyz[:, :3], radii)

        # Normalize positions to [0.25, 0.75].
        xyz[:, :3] = self._normalize_coords(xyz[:, :3])

        # Map atomic numbers to class indices (0, 1, ...).
        xyz[:, -1] = [ELEMENT_TO_INDEX[z] for z in xyz[:, -1]]

        sample = {'coords': xyz, 'edges': edges}
        if self.transform:
            sample = self.transform(sample)

        # Keep all channels from HDF5 and convert to [C, H, W].
        image = torch.from_numpy(x)
        if image.dim() == 4 and image.size(0) == 1:
            image = image.squeeze(0)  # [1, H, W, C] -> [H, W, C]
        if image.dim() == 3:
            image = image.permute(2, 0, 1).contiguous()  # [H, W, C] -> [C, H, W]

        # Tokenize from the (possibly transformed) sample. The last column now
        # holds class indices, so invert ELEMENT_TO_INDEX to recover symbols;
        # INDEX_TO_SYMBOL itself is keyed by atomic number.
        # NOTE(review): assumes `transform` keeps the {'coords', 'edges'} layout.
        index_to_symbol = {ELEMENT_TO_INDEX[z]: INDEX_TO_SYMBOL[z] for z in ELEMENT_TO_INDEX}
        coords = np.asarray(sample['coords'])
        atomtok = [index_to_symbol[int(t)] for t in coords[:, -1]]
        snapped = self._snap_to_bins(coords[:, :3], coord_bin_values)
        atomtok_coords = [f'{symbol}: {q[0]}, {q[1]}, {q[2]},' for symbol, q in zip(atomtok, snapped)]

        ref = {
            'atomtok': np.asarray(atomtok),
            'edges': np.asarray(sample['edges']),
            'atomtok_coords': np.asarray(atomtok_coords),
            # NOTE(review): currently identical to 'atomtok_coords'.
            'chartok_coords': np.asarray(atomtok_coords),
        }

        return idx, image, ref