In [1]:
import torch
from torchvision.transforms import v2 as transforms
from bioscan_dataset import BIOSCAN5M
from bioscan_dataset.bioscan5m import RGB_MEAN, RGB_STDEV

image_transform = transforms.Compose(
    [
        transforms.CenterCrop(256),
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),
        transforms.Normalize(mean=RGB_MEAN, std=RGB_STDEV),
    ]
)
# Create a DNA transform, mapping from characters to integers and padding to a fixed length
charmap = {"P": 0, "A": 1, "C": 2, "G": 3, "T": 4, "N": 5}
dna_transform = lambda seq: torch.tensor(
    [charmap[char] for char in seq] + [0] * (660 - len(seq)), dtype=torch.long
)

In [4]:
ds_train = BIOSCAN5M(
    root="~/Desktop/Bridged Clustering Project/Code/bioscan_data/bioscan-5m",
    split="val",
    transform=image_transform,
    dna_transform=dna_transform,
)

In [9]:
ds_train[0]

(Image([[[0.2617, 0.2617, 0.2617,  ..., 0.2617, 0.2617, 0.2617],
         [0.2617, 0.2617, 0.2617,  ..., 0.2617, 0.2617, 0.2617],
         [0.2617, 0.2617, 0.2617,  ..., 0.2617, 0.2617, 0.2617],
         ...,
         [1.1727, 1.2556, 1.2832,  ..., 1.6144, 1.6697, 1.6697],
         [1.3384, 1.4212, 1.4212,  ..., 1.6421, 1.6697, 1.6697],
         [1.4764, 1.5040, 1.5040,  ..., 1.6421, 1.6697, 1.6697]],
 
        [[0.5745, 0.5745, 0.5745,  ..., 0.5745, 0.5745, 0.5745],
         [0.5745, 0.5745, 0.5745,  ..., 0.5745, 0.5745, 0.5745],
         [0.5745, 0.5745, 0.5745,  ..., 0.5745, 0.5745, 0.5745],
         ...,
         [1.2263, 1.3015, 1.3517,  ..., 1.4519, 1.4519, 1.4519],
         [1.3767, 1.4770, 1.4770,  ..., 1.4269, 1.4519, 1.4519],
         [1.5021, 1.5021, 1.5021,  ..., 1.4269, 1.4519, 1.4519]],
 
        [[1.1158, 1.1158, 1.1158,  ..., 1.1158, 1.1158, 1.1158],
         [1.1158, 1.1158, 1.1158,  ..., 1.1158, 1.1158, 1.1158],
         [1.1158, 1.1158, 1.1158,  ..., 1.1158, 1.1158, 

In [23]:
for image, dna_barcode, label in ds_train:
    print(image.shape, dna_barcode.shape, label)

torch.Size([3, 256, 256]) torch.Size([660]) 240
torch.Size([3, 256, 256]) torch.Size([660]) 213
torch.Size([3, 256, 256]) torch.Size([660]) 47
torch.Size([3, 256, 256]) torch.Size([660]) 459
torch.Size([3, 256, 256]) torch.Size([660]) 498
torch.Size([3, 256, 256]) torch.Size([660]) 425
torch.Size([3, 256, 256]) torch.Size([660]) 514
torch.Size([3, 256, 256]) torch.Size([660]) 199
torch.Size([3, 256, 256]) torch.Size([660]) 135
torch.Size([3, 256, 256]) torch.Size([660]) 452
torch.Size([3, 256, 256]) torch.Size([660]) 333
torch.Size([3, 256, 256]) torch.Size([660]) 243
torch.Size([3, 256, 256]) torch.Size([660]) 97
torch.Size([3, 256, 256]) torch.Size([660]) 426
torch.Size([3, 256, 256]) torch.Size([660]) 198
torch.Size([3, 256, 256]) torch.Size([660]) 198
torch.Size([3, 256, 256]) torch.Size([660]) 260
torch.Size([3, 256, 256]) torch.Size([660]) 340
torch.Size([3, 256, 256]) torch.Size([660]) 398
torch.Size([3, 256, 256]) torch.Size([660]) 413
torch.Size([3, 256, 256]) torch.Size([660]

KeyboardInterrupt: 

In [None]:
#only keep the datapoints with "genus" information
#randomly select 5 datapoints from the dataset as our roots, then only keep the datapoints that are of the same (class,order,family) as the roots