In [None]:
import openslide
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms

class WSIDataset(Dataset):
    """Map-style dataset that yields one whole-slide image (WSI) per item.

    Slides are opened lazily with OpenSlide inside ``__getitem__``, so
    constructing the dataset performs no file I/O.

    Args:
        wsi_paths: Sequence of slide file paths.
        labels: Sequence of labels aligned index-by-index with ``wsi_paths``.
        transform: Optional callable applied to the RGB PIL image before it is
            returned (e.g. a ``torchvision.transforms.Compose``).
        level: OpenSlide pyramid level to read. ``0`` (the default, matching
            the previous behavior) is full resolution; reading an entire
            level-0 slide can need a very large amount of memory, so a coarser
            level is often preferable.

    Raises:
        ValueError: If ``wsi_paths`` and ``labels`` differ in length.
    """

    def __init__(self, wsi_paths, labels, transform=None, level=0):
        if len(wsi_paths) != len(labels):
            raise ValueError(
                f"wsi_paths and labels must have the same length "
                f"(got {len(wsi_paths)} and {len(labels)})"
            )
        self.wsi_paths = wsi_paths
        self.labels = labels
        self.transform = transform
        self.level = level

    def __len__(self):
        return len(self.wsi_paths)

    def __getitem__(self, index):
        """Return ``(image, label, path)`` for the slide at ``index``.

        ``image`` is an RGB PIL image covering the full extent of the
        configured level, or whatever ``self.transform`` returns for it.
        """
        wsi_path = self.wsi_paths[index]
        label = self.labels[index]

        wsi = openslide.open_slide(wsi_path)
        try:
            # read_region already returns a PIL image (RGBA); convert('RGB')
            # drops alpha. Wrapping it in Image.fromarray is unnecessary and
            # fails, because a PIL image is not an array/buffer.
            region = wsi.read_region(
                (0, 0), self.level, wsi.level_dimensions[self.level]
            )
            wsi_pil = region.convert('RGB')
        finally:
            # Release the slide handle; without this every access leaks one.
            wsi.close()

        if self.transform is not None:
            wsi_pil = self.transform(wsi_pil)

        return wsi_pil, label, wsi_path

# Example usage with placeholder paths: build the dataset and wrap it in a
# shuffling DataLoader.
wsi_paths = [
    'path/to/wsi1',
    'path/to/wsi2',
    'path/to/wsi3',
]
labels = [0, 1, 0]  # one label per entry of wsi_paths, in the same order

# Resize to a fixed size so samples from different slides have matching
# shapes when they are batched together.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

dataset = WSIDataset(wsi_paths, labels, transform=transform)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
)

In [7]:
import argparse

# Demonstrates how store_true / store_false interact with an explicit default.
# Defaults must be real booleans: a string default such as 'False' is a
# non-empty string and therefore truthy, so `if args.t2:` would always pass
# even though print() shows "False".
parser = argparse.ArgumentParser()
# store_true: passing the flag stores True; otherwise `default` is used.
parser.add_argument('-t1', action='store_true', default=True)   # can never become False
parser.add_argument('-t2', action='store_true', default=False)  # False unless -t2 is given
# store_false: passing the flag stores False; otherwise `default` is used.
parser.add_argument('-f1', action='store_false', default=False)  # can never become True
parser.add_argument('-f2', action='store_false', default=True)   # True unless -f2 is given
# parse_known_args collects argv entries this parser does not define into
# `unknown` instead of exiting; useful in a notebook, where sys.argv holds
# the kernel's own launch arguments.
args, unknown = parser.parse_known_args()

print(args.t1)
print(args.t2)
print(args.f1)
print(args.f2)

True
False
False
True


In [6]:
import torch

# Load one per-slide embedding file and report its size; the recorded output
# for this slide is torch.Size([67, 192]).
# TODO(review): absolute, user-specific path; consider a configurable root.
slide_embedding_path = (
    "/shared/js.yun/HIPT/HIPT_original/3-Self-Supervised-Eval/"
    "embeddings_slide_lib/embeddings_slide_lib/vit256mean_tcga_slide_embeddings/"
    "TCGA-WE-A8ZR-06Z-00-DX1.52E1652B-F713-4DEE-B699-FE9D0719A43C.pt"
)
# map_location="cpu" lets the file load even if it was saved from GPU tensors
# and this kernel has no GPU (or a different one).
pt_file = torch.load(slide_embedding_path, map_location="cpu")
print(pt_file.size())




torch.Size([67, 192])


In [11]:
import pickle

# Load one pickled k-NN subtyping split and show the shape of its
# 'embeddings' entry (recorded output: (785, 192)).
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
vit256mean = (
    "/shared/js.yun/HIPT/HIPT_original/3-Self-Supervised-Eval/"
    "embeddings_slide_lib/embeddings_slide_lib/knn-subtyping/vit256mean/"
    "tcga_lung_vit256mean_class_split_train_0.pkl"
)
with open(vit256mean, 'rb') as f:
    data = pickle.load(f)

# Presumably a dict whose 'embeddings' value is an array/tensor; confirm
# against whatever produced the pickle.
print(data['embeddings'].shape)

(785, 192)
