In [46]:
import torch
import random
import pandas as pd
from collections import defaultdict
from torch.utils.data import Dataset, DataLoader, Sampler
from torch.nn.utils.rnn import pad_sequence


In [20]:
def pad_collate(batch):

    inp = [item for item in batch]
    inp = pad_sequence(inp, batch_first=True)


    return inp

In [34]:
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data
    def __getitem__(self, index):
        return self.data[index]
    def __len__(self):
        return len(self.data)

class SequenceLengthSampler(Sampler):

    def __init__(self, data_source, lengths, batch_size, shuffle=True):
        self.data_source = data_source
        self.lengths = lengths
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Group indices by their sequence length
        self.buckets = defaultdict(list)
        for idx, l in enumerate(lengths):
            self.buckets[l].append(idx)

        self.buckets = list(self.buckets.values())  # list of lists (each list = same-length samples)

        if self.shuffle:
            for b in self.buckets:
                random.shuffle(b)
            random.shuffle(self.buckets)

    def __iter__(self):
        batches = []
        for bucket in self.buckets:
            # Split each bucket into batches of given size
            for i in range(0, len(bucket), self.batch_size):
                batch = bucket[i:i+self.batch_size]
                batches.append(batch)

        if self.shuffle:
            random.shuffle(batches)

        for batch in batches:
            yield batch

    def __len__(self):
        return sum((len(bucket) + self.batch_size - 1) // self.batch_size for bucket in self.buckets)

In [31]:
data = [torch.randint(0, 100, (random.randint(3, 10),)) for _ in range(100)]
lengths = [ a.size() for a in data]

In [55]:
df = pd.read_json("../dataset/filtered_binding_finalm_180k_12.52_-6.0/processed_data/train.json")


In [56]:
df

Unnamed: 0,Target_Chain_encoded,Ligand_encoded,logIC50_scaled
0,"[991, 118, 22, 145, 17, 986, 49, 990, 284, 56,...","[135, 5, 203, 201, 10, 202, 15, 200, 87, 5, 20...",0.539957
1,"[131, 49, 986, 22, 18, 97, 385, 97, 51, 6, 76,...","[12, 209, 199, 13, 211, 24, 18, 204, 8, 200, 1...",0.314513
2,"[131, 720, 361, 17, 987, 93, 8, 835, 387, 876,...","[94, 202, 198, 54, 199, 13, 211, 11, 203, 201,...",0.521747
3,"[991, 64, 27, 76, 983, 12, 12, 21, 33, 20, 975...","[23, 5, 203, 201, 205, 209, 199, 13, 211, 136,...",0.520625
4,"[131, 43, 782, 127, 675, 191, 21, 168, 47, 849...","[123, 202, 8, 204, 3, 213, 29, 200, 35, 110, 2...",0.496760
...,...,...,...
145197,"[131, 13, 564, 49, 54, 276, 71, 983, 113, 161,...","[92, 198, 202, 8, 204, 198, 200, 199, 201, 3, ...",0.361715
145198,"[446, 985, 186, 982, 18, 18, 501, 36, 434, 10,...","[40, 202, 198, 200, 203, 201, 198, 5, 203, 201...",0.431965
145199,"[496, 56, 974, 124, 65, 978, 66, 99, 982, 418,...","[23, 200, 199, 201, 199, 209, 199, 208, 211, 2...",0.469800
145200,"[131, 747, 731, 348, 6, 974, 366, 902, 73, 422...","[40, 202, 8, 59, 199, 13, 211, 44, 203, 209, 1...",0.494674


In [51]:
data = df.Target_Chain_encoded.to_list()
data = [ torch.tensor(a) for a in data]
lengths = [ len(a) for a in data]

In [53]:
sampler = SequenceLengthSampler(data, lengths, batch_size=4, shuffle=True)
loader = DataLoader(data, batch_sampler=sampler, collate_fn=pad_collate)

In [54]:
for batch in loader:
    print(batch.size())

torch.Size([4, 263])
torch.Size([4, 206])
torch.Size([4, 252])
torch.Size([4, 151])
torch.Size([4, 211])
torch.Size([4, 163])
torch.Size([4, 252])
torch.Size([4, 220])
torch.Size([4, 311])
torch.Size([4, 172])
torch.Size([1, 348])
torch.Size([4, 289])
torch.Size([4, 273])
torch.Size([4, 172])
torch.Size([4, 160])
torch.Size([4, 135])
torch.Size([4, 183])
torch.Size([4, 545])
torch.Size([4, 553])
torch.Size([4, 208])
torch.Size([4, 132])
torch.Size([4, 131])
torch.Size([4, 123])
torch.Size([4, 474])
torch.Size([4, 240])
torch.Size([4, 238])
torch.Size([4, 449])
torch.Size([4, 106])
torch.Size([4, 530])
torch.Size([4, 574])
torch.Size([4, 435])
torch.Size([4, 381])
torch.Size([4, 125])
torch.Size([4, 406])
torch.Size([4, 565])
torch.Size([4, 295])
torch.Size([4, 140])
torch.Size([4, 95])
torch.Size([4, 224])
torch.Size([4, 259])
torch.Size([4, 219])
torch.Size([4, 126])
torch.Size([4, 291])
torch.Size([4, 213])
torch.Size([4, 303])
torch.Size([4, 307])
torch.Size([4, 209])
torch.Size([4,