In [None]:
#source /scratch/phys/sin/sethih1/venv/MolNexTR_env/bin/activate
import sys
sys.path.append('/home/sethih1/MORAFInator/')

In [45]:
import numpy as np
import h5py

import torch
from torch.utils.data import Dataset, DataLoader

COVALENT_RADII = {
    1: 0.32,
    6: 0.77,
    7: 0.75,
    8: 0.73,
    9: 0.71,
}

ELEMENT_TO_INDEX = {
    1: 0,
    6: 1,
    7: 2,
    8: 3,
    9: 4,
}

INDEX_TO_SYMBOL = {
    1: 'H',
    6: 'C',
    7: 'N',
    8: 'O',
    9: 'F',
}



INDEX_TO_SYMBOL = {
    0: 'H',
    1: 'C',
    2: 'N',
    3: 'O',
    4: 'F',
}

In [46]:
class AFMData(Dataset): 
    def __init__(self, data_path, transform, train_size=0.8, split='train'): 
        self.data_path = data_path
        self.transform = transform
        self.split = split

        with h5py.File(self.data_path, 'r') as f:
            total_length = f['x'].shape[0]
            self.train_length = int(train_size*total_length)
            self.val_length = total_length - self.train_length

    def __len__(self):
        if self.split == 'train':
            return self.train_length
        else: 
            return self.val_length

    def __getitem__(self, idx): 
        if self.split == 'train': 
            idx += 0
        else: 
            idx += self.train_length

        with h5py.File(self.data_path, 'r') as f:
            x = f['x'][idx]
            xyz = f['xyz'][idx]

        # Remove padding atoms
        xyz = xyz[xyz[:, -1] > 0]

        # Get edges 
        edges = []
        for i in range(xyz.shape[0]): 
            for j in range(i+1, xyz.shape[0]): 
                dist = np.linalg.norm(xyz[i, :3] - xyz[j, :3])
                if dist < 1.2*(COVALENT_RADII[xyz[i, -1]] + COVALENT_RADII[xyz[j, -1]]):
                    edges.append([i,j])

        # Normalize xyz to [0.25, 0.75]
        xyzmin = np.min(xyz[:, :3])
        xyz_max = np.max(xyz[:, 3])

        xyz[:, :3] = (xyz[:, :3] - xyzmin[:, :3])/(xyz_max - xyzmin)
        xyz[:, :3] = 0.5*xyz[:, :3] + 0.25

        # map atom types to integers (0,1, ..)
        xyz[:, -1] = [ELEMENT_TO_INDEX[atom_type] for atom_type in xyz[:, -1]]

        sample = {'coords': xyz, edges: np.asarray(edges)}

        if self.transform:
            sample = self.transform(sample)

        # Keep all channels from HDF5 and convert to [C,H,W]
        x = torch.from_numpy(x)
        if x.dim() == 4 and x.size(0) == 1: 
            # Input like [1, H, W, C] -> [H, W, C]
            x = x.squeeze(0)
        if x.dim() == 3: 
            # Assume [H, W, C] from HDF5, move channels first
            x = x.permute(2,0,1).continguous()

        return idx, x, sample


def get_datasets(data_path, train_transform = None, val_transform = None, train_size = 0.8): 

    train_dataset = AFMData(data_path, transform = train_transform, train_size=train_size, split='train')
    val_dataset = AFMData(data_path, transform = val_transform, train_size=train_size, split='val')

    return train_dataset, val_dataset


def afm_collate_fn(batch): 

    sample = {'coords': [], 'edges': []}
    ids = [id[0] for id in batch]
    images = torch.stack([item[1] for item in batch])
    for item in batch:
        sample['coords'].append(torch.from_numpy(item[2]['coords']))
        sample['edges'].append(torch.from_numpy(item[2]['edges']))

    return ids, images, samples

In [47]:
afm_data = h5py.File("/scratch/phys/project/sin/hackathon/data/afm.h5", 'r')
for key in afm_data.keys(): 
    print(f"{key}: {afm_data[key]}")

sw: <HDF5 dataset "sw": shape (59392, 1, 2, 3), type "<f4">
x: <HDF5 dataset "x": shape (59392, 1, 128, 128, 10), type "<f4">
xyz: <HDF5 dataset "xyz": shape (59392, 54, 5), type "<f4">


In [48]:
mask = afm_data['xyz'][0, :, -1] > 0
new_data = afm_data['xyz'][0,mask, -1]
atomtok = [INDEX_TO_SYMBOL[value] for value in new_data]
print(atomtok)

KeyError: 9.0

In [None]:
coord_bins = 64

coord_bin_values = np.round(np.linspace(0,1,64),2)
print(coord_bin_values)
coord_tokens = [f'<COORD_{i}>' for i in range(coord_bins)]


In [None]:
new_coord = coord_bin_values[np.argmin(np.abs(coord_bin_values.reshape(1,-1) - coord.reshape(-1,1)), axis = 1)]
print(new_coord)

In [None]:
coords = afm_data['xyz'][0, mask, :3]
coord_tokens = []
coordo_tokens = []
for symbol, coord in zip(atomtok, coords):
    new_coord = coord_bin_values[np.argmin(np.abs(coord_bin_values.reshape(1,-1) - coord.reshape(-1,1)), axis = 1)]
    coord_tokens.append(f'{symbol}: {new_coord[0]}, {new_coord[1]}, {new_coord[2]},')
    coordo_tokens.append(f'{symbol}: {coord[0]}, {coord[1]}, {coord[2]},')
print(coord_tokens[1])
print(coordo_tokens[1])

In [49]:
(coord_bin_values.reshape(1,-1) - coord.reshape(-1,1)).shape

(3, 64)

In [153]:
class AFMDataset(Dataset): 
    def __init__(self, data_path, tokenizer, transform, train_size=0.8, split='train'): 
        self.data_path = data_path
        self.transform = transform
        self.split = split
        self.tokenizer = tokenizer

        with h5py.File(self.data_path, 'r') as f:
            total_length = f['x'].shape[0]
            self.train_length = int(train_size*total_length)
            self.val_length = total_length - self.train_length

    def __len__(self): 
        if self.split == 'train': 
            return self.train_length
        else: 
            return self.val_length

    def __getitem__(self, idx):

        if self.split == 'train': 
            idx += 0
        else: 
            idx += self.train_length

        with h5py.File(self.data_path, 'r') as f:
            x = f['x'][idx]
            xyz = f['xyz'][idx]

        # Remove padding atoms
        xyz = xyz[xyz[:, -1] > 0]

        # Get edges 
        edges = []
        for i in range(xyz.shape[0]): 
            for j in range(i+1, xyz.shape[0]): 
                dist = np.linalg.norm(xyz[i, :3] - xyz[j, :3])
                if dist < 1.2*(COVALENT_RADII[xyz[i, -1]] + COVALENT_RADII[xyz[j, -1]]):
                    edges.append([i,j])

        # Normalize xyz to [0.25, 0.75]
        xyzmin = np.min(xyz[:, :3])
        xyz_max = np.max(xyz[:, 3])

        xyz[:, :3] = (xyz[:, :3] - xyzmin)/(xyz_max - xyzmin)
        xyz[:, :3] = 0.5*xyz[:, :3] + 0.25

        # map atom types to integers (0,1, ..)
        xyz[:, -1] = [ELEMENT_TO_INDEX[atom_type] for atom_type in xyz[:, -1]]

        #sample = {'coords': xyz, edges: np.asarray(edges)}

        #if self.transform:
        #    sample = self.transform(sample)

        # Keep all channels from HDF5 and convert to [C,H,W]
        x = torch.from_numpy(x)
        if x.dim() == 4 and x.size(0) == 1: 
            # Input like [1, H, W, C] -> [H, W, C]
            x = x.squeeze(0)
        if x.dim() == 3: 
            # Assume [H, W, C] from HDF5, move channels first
            x = x.permute(2,0,1).contiguous()

        mask = xyz[:, -1] > 0
        atomtok = " ".join([INDEX_TO_SYMBOL[value] for value in xyz[mask, -1]])

        #atomtok = self.tokenizer['atomtok'].text_to_sequence(atomtok)

        nodes = {'coords': [], 'symbols': ""}
        for symbol, coord in zip(atomtok, xyz):
            new_coord = coord_bin_values[np.argmin(np.abs(coord_bin_values.reshape(1,-1) - coord.reshape(-1,1)), axis = 1)]
            #atomtok_coords.append(f'{symbol}: {new_coord[0]}, {new_coord[1]}, {new_coord[2]},')
            nodes['coords'].append(new_coord[:2])
            nodes['symbols'] = symbol + " "

        atomtok_coords = self.tokenizer['atomtok_coords'].nodes_to_sequence(nodes)

        #ref = {'atomtok': atomtok, 'edges': np.asarray(edges), 'atomtok_coords': atomtok_coords, 'chartok_coords': atomtok_coords}
        ref = {'atomtok_coords': atomtok_coords}

        return idx, x, ref


def get_datasets(data_path, tokenizer, train_transform = None, val_transform = None, train_size=0.8):

    train_dataset = AFMDataset(data_path, tokenizer = tokenizer, transform=train_transform, train_size=train_size, split='train')
    val_dataset = AFMDataset(data_path, tokenizer = tokenizer, transform=val_transform, train_size=train_size, split='val')

    return train_dataset, val_dataset


def afm_collate_fn(batch):

    #sample = {'coords':[], 'edges':[]}
    #ref = {'atomtok': [], 'edges': [], 'atomtok_coords': [], 'chartok_coords': []}
    ref = {'atomtok_coords': []}

    PAD_ID = 0
    length = 128

    ids = [id[0] for id in batch]
    images = torch.stack([item[1] for item in batch])
    for item in batch:
        #sample['coords'].append(torch.from_numpy(item[2]['coords']))
        #sample['edges'].append(torch.from_numpy(item[2]['edges']))
        #ref['atomtok'].append(torch.from_numpy(item[2]['atomtok']))
        #ref['edges'].append(torch.from_numpy(item[2]['edges']))

        tok = item[2]['atomtok_coords']
        
        pad = [PAD_ID]*(length - len(tok))
        tok.extend(pad)
        
        ref['atomtok_coords'] = tok
        #ref['atomtok_coords'].append(torch.from_numpy(item[2]['atomtok_coords']))
        #ref['chartok_coords'].append(torch.from_numpy(item[2]['chartok_coords']))

    return ids, images, ref
    #return ids, images, sample

In [154]:
from src.tokenization import get_tokenizer
from types import SimpleNamespace

args_dict = {
    'batch_size': 32,
    'learning_rate': 0.001,

    # related to model encoder
    'encoder': 'swin_base', 
    'use_checkpoint':False, 
    'encoder_dim': 64, 
    'in_chans': 10,

    # related to model decoder
    'dec_hidden_size': 16,
    'enc_pos_emb': True, 
    'dec_num_layers': 3, 
    'dec_attn_heads': 4,
    'hidden_dropout': 0.2,
    'attn_dropout': 0.2,
    'max_relative_positions': 10,
    'compute_confidence': True,

    # related to tokenizer
    'formats':['atomtok_coords'],
    'vocab_file': '/home/sethih1/MORAFInator/src/vocab/vocab_chars.json', 
    'coord_bins': 64, 
    'sep_xy': False, 
    'continuous_coords': True, 
    'input_resolution': 256
    
}

args = SimpleNamespace(**args_dict)
print(args)

tokenizer = get_tokenizer(args)



namespace(batch_size=32, learning_rate=0.001, encoder='swin_base', use_checkpoint=False, encoder_dim=64, in_chans=10, dec_hidden_size=16, enc_pos_emb=True, dec_num_layers=3, dec_attn_heads=4, hidden_dropout=0.2, attn_dropout=0.2, max_relative_positions=10, compute_confidence=True, formats=['atomtok_coords'], vocab_file='/home/sethih1/MORAFInator/src/vocab/vocab_chars.json', coord_bins=64, sep_xy=False, continuous_coords=True, input_resolution=256)


In [155]:
h5_path =  "/scratch/phys/project/sin/hackathon/data/afm.h5"
train_dataset, val_dataset = get_datasets(data_path = h5_path, tokenizer=tokenizer, train_transform = None, val_transform = None, train_size=0.8)

In [156]:
train_loader = DataLoader(train_dataset, batch_size=1, num_workers=1, drop_last = False, collate_fn=afm_collate_fn)

In [157]:
tok = [1, 125, 143, 3, 145, 147, 3, 2]
pad = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

print(tok.extend(pad))

None


In [158]:
ids, imgs, refs = next(iter(train_loader))

In [159]:
ids

[0]

In [160]:
imgs.shape

torch.Size([1, 10, 128, 128])

In [161]:
refs

{'atomtok_coords': [1,
  125,
  143,
  3,
  145,
  147,
  3,
  2,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0]}