In [87]:
import numpy as np
import pandas as pd
from torch.utils.data.dataset import Dataset
import sys
import os

def save_splits(split_datasets, column_keys, filename, boolean_style=False):
    """Tabulate the case ids of every split and display the resulting table.

    Args:
        split_datasets: Sequence of DataFrames, one per split, each exposing a
            'case_id' column.
        column_keys: One name per split, in the same order as ``split_datasets``.
            Used as the column labels in both layouts.
        filename: Destination CSV path. Currently unused: the ``to_csv`` call at
            the end of the function is commented out, so the table is only
            displayed, never written.
        boolean_style: If False, build one column per split holding its case ids
            (shorter splits end up padded with NaN by the index alignment).
            If True, build a boolean membership table indexed by case id with
            one column per split.
    """
    splits = [dataset['case_id'] for dataset in split_datasets]
    # splits = [split_datasets[i]['case_id']+'/'+split_datasets[i]['slide_id'] for i in range(len(split_datasets))]
    if not boolean_style:
        df = pd.concat(splits, ignore_index=True, axis=1)
        df.columns = column_keys
    else:
        case_ids = pd.concat(splits, ignore_index=True, axis=0).values.tolist()

        # Row k of the identity matrix is the membership pattern of split k;
        # repeat it once per sample in that split.
        one_hot = np.eye(len(split_datasets)).astype(bool)
        bool_array = np.repeat(one_hot, [len(dset) for dset in split_datasets], axis=0)
        # Label the columns with column_keys (previously hard-coded to
        # train/val/test, which broke for any other number of splits).
        df = pd.DataFrame(bool_array, index=case_ids, columns=list(column_keys))

    display(df)
    # df.to_csv(filename)


class Split_Gene_Clip_Dataset(Dataset):
    """Draw random train/val/test partitions over the samples of a genomics table.

    Samples are the *columns* of the genomics CSV. The partitions are produced as
    positional column indices (``create_splits_index``) and can be turned into
    one-column 'case_id' DataFrames (``create_split_file_name``).

    Args:
        wsi_csv_path: CSV with at least the columns case_id, slide_id, label,
            Subtype, Mutation Count and TMB (nonsynonymous). Loaded into
            ``self.slide_data``; not used by the split logic in this class.
        genomics_csv_path: CSV whose columns are the samples to be split.
        seed: Seed applied to NumPy's global random generator.
        n_splits: Number of partitions ``create_splits_index`` yields.
        val_frac: Fraction of the samples assigned to validation.
        test_frac: Fraction of the samples assigned to test.

    Raises:
        ValueError: If the fractions give a negative size for any split.
    """
    def __init__(self,
        wsi_csv_path = '/shared/j.jang/pathai/CLAM/dataset_csv/TCGA-lung-LUAD+LUSC-TMB-pan_cancer-323.csv',
        genomics_csv_path = '/shared/js.yun/data/CLAM_data/genomics_data/TCGA-lung-LUAD+LUSC-selected_2847_zscore.csv',        
        seed = 1,
        n_splits=10,
        val_frac=0, 
        test_frac=0
        ):
        self.seed = seed
        self.n_splits = n_splits
        # Seeds NumPy's *global* generator, which create_splits_index() draws from.
        np.random.seed(self.seed)

        self.slide_data = pd.read_csv(wsi_csv_path)[['case_id', 'slide_id', 'label', 'Subtype', 'Mutation Count', 'TMB (nonsynonymous)']]
        self.genomics_data = pd.read_csv(genomics_csv_path)
        
        # One column per sample. The trailing numbers are the sizes seen by the
        # author for the default genomics file with a 0.1 / 0.2 val/test setup.
        self.num_data = self.genomics_data.shape[1]                         # 907 patients
        self.num_val = np.round(self.num_data * val_frac).astype(int)       # 91
        self.num_test = np.round(self.num_data * test_frac).astype(int)     # 181
        self.num_train = self.num_data-self.num_val-self.num_test           # 635

        # Fail here with a clear message instead of later inside
        # np.random.choice ("cannot take a larger sample than population").
        if min(self.num_val, self.num_test, self.num_train) < 0:
            raise ValueError(
                'val_frac={} and test_frac={} do not give a valid partition of {} samples '
                '(train={}, val={}, test={})'.format(
                    val_frac, test_frac, self.num_data,
                    self.num_train, self.num_val, self.num_test))
        
        print(f'number of train set: {self.num_train}')                 # 635
        print(f'number of validation set: {self.num_val}')              # 91
        print(f'number of test set: {self.num_test}')                   # 181
        # Show the split sizes as a one-row DataFrame.
        columns = ['train', 'val', 'test']
        num_splits = [self.num_train, self.num_val, self.num_test]
        count_dataset = pd.DataFrame([num_splits], columns=columns)
        display(count_dataset)

    def create_splits_index(self):
        """Yield ``n_splits`` random ``(train_index, val_index, test_index)`` triples.

        Each triple is a disjoint partition of ``range(self.num_data)``. The
        draws come from NumPy's global random state, so the result depends on
        whatever else has consumed or re-seeded that state since ``__init__``.
        """
        all_indices = np.arange(self.num_data).astype(int)
        for _ in range(self.n_splits):
            val_index = np.random.choice(all_indices, self.num_val, replace = False) 
            remaining_ids = np.setdiff1d(all_indices, val_index)
            test_index = np.random.choice(remaining_ids, self.num_test, replace = False) 
            train_index = np.setdiff1d(remaining_ids, test_index)

            total = len(train_index)+len(val_index)+len(test_index)
            assert total > 0
            # Every sample must land in exactly one split.
            assert total == self.num_data
            assert len(np.intersect1d(train_index, test_index)) == 0
            assert len(np.intersect1d(train_index, val_index)) == 0
            assert len(np.intersect1d(val_index, test_index)) == 0

            yield train_index, val_index, test_index

    def create_split_file_name(self, train_index, val_index, test_index):
        """Map positional split indices to genomics column names.

        Returns:
            Three single-column ('case_id') DataFrames for train, val and test.
        """
        sample_ids = self.genomics_data.columns
        train_data = pd.DataFrame(sample_ids[train_index], columns=['case_id'])
        val_data = pd.DataFrame(sample_ids[val_index], columns=['case_id'])
        test_data = pd.DataFrame(sample_ids[test_index], columns=['case_id'])

        return train_data, val_data, test_data


split_dataset = Split_Gene_Clip_Dataset()

split_columns = ['train', 'val', 'test']
# (file-name template, boolean_style) for the two table layouts shown per split.
table_variants = [('splits_{}.csv', False), ('splits_{}_bool.csv', True)]

for i in range(1):
    # A fresh generator is created on every pass and only its first split is taken.
    train_index, val_index, test_index = next(split_dataset.create_splits_index())
    splits = split_dataset.create_split_file_name(train_index, val_index, test_index)
    for name_template, as_boolean in table_variants:
        save_splits(splits, split_columns, name_template.format(i), boolean_style=as_boolean)
    
# split_dir=(f'{args.split_dir}TCGA-lung-label_col_{args.label_column}_sub_{",".join(args.target_subtype)}'
#                 f'-TMB-high-ratio-{args.tmb_high_ratio:.2f}-splits_{args.k}-seed{args.seed}/{args.task}')


number of train set: 907
number of validation set: 0
number of test set: 0


Unnamed: 0,train,val,test
0,907,0,0


Unnamed: 0,train,val,test
0,TCGA-05-4244-01,,
1,TCGA-05-4249-01,,
2,TCGA-05-4250-01,,
3,TCGA-05-4382-01,,
4,TCGA-05-4384-01,,
...,...,...,...
902,TCGA-O2-A52S-01,,
903,TCGA-O2-A52V-01,,
904,TCGA-O2-A52W-01,,
905,TCGA-O2-A5IB-01,,


Unnamed: 0,train,val,test
TCGA-05-4244-01,True,False,False
TCGA-05-4249-01,True,False,False
TCGA-05-4250-01,True,False,False
TCGA-05-4382-01,True,False,False
TCGA-05-4384-01,True,False,False
...,...,...,...
TCGA-O2-A52S-01,True,False,False
TCGA-O2-A52V-01,True,False,False
TCGA-O2-A52W-01,True,False,False
TCGA-O2-A5IB-01,True,False,False


In [86]:
import sys
import random
import numpy as np
from collections import Counter
import torch

class Clip_dataset(Dataset):
    """Paired (genomics vector, WSI feature tensor) dataset for one split.

    Sample ``i`` is column ``i`` of the filtered genomics table together with
    the features of one slide belonging to the same patient.

    Args:
        split_csv_path: CSV whose ``split_key`` column lists the patient ids of
            the split (missing values are ignored).
        genomics_csv_path: CSV with one column per patient id.
        wsi_csv_path: CSV with 'case_id' and 'slide_id' columns.
        wsi_feature_dir: Root directory containing
            ``pt_files/<case_id>/<slide_id>.pt`` feature files.
        split_key: Column of the split CSV to use.
        seed: Seed applied to NumPy's global random generator.

    Attributes:
        patient_ids: Patient id of each sample, in genomics column order.
        length: Number of samples (patients present in both split and genomics).
    """
    def __init__(self,
        split_csv_path = '/shared/js.yun/data/CLAM_data/clip_data/TCGA-lung-splits_5-frac_1_0_0-seed0/splits_0.csv',
        genomics_csv_path = '/shared/js.yun/data/CLAM_data/genomics_data/TCGA-lung-LUAD+LUSC-selected_2847_zscore.csv',        
        wsi_csv_path = '/shared/j.jang/pathai/CLAM/dataset_csv/TCGA-lung-LUAD+LUSC-TMB-pan_cancer-323.csv',
        wsi_feature_dir = '/shared/j.jang/pathai/data/TCGA-lung-x256-features-dino-from-pretrained-vitb-img224/',
        split_key = 'train',
        seed = 1,
        ):
        self.wsi_feature_dir = wsi_feature_dir
        self.seed = seed
        # NOTE: this seeds NumPy only; the slide choice in __getitem__ uses the
        # stdlib `random` module, which is not seeded here.
        np.random.seed(self.seed)

        # genomics dataset
        # Set of patients in the requested split. NaN entries in the split
        # column are dropped so they are not counted as patients.
        split_ids = pd.read_csv(split_csv_path)[split_key].dropna()
        self.selected_columns = set(split_ids)
        genomics_data = pd.read_csv(genomics_csv_path)
        self.genomics_data = genomics_data.loc[:, genomics_data.columns.isin(self.selected_columns)]
        # Ordered patient ids: entry i names column i of self.genomics_data.
        # A set has no stable order, so it must not be used to map idx -> patient.
        self.patient_ids = self.genomics_data.columns.tolist()
        self.length = len(self.patient_ids)

        # WSI dataset
        slide_data = pd.read_csv(wsi_csv_path)[['case_id', 'slide_id']].copy()
        # Patient id = first three '-'-separated fields of the slide id plus the
        # first two characters of the fourth field.
        slide_data['patient'] = slide_data['slide_id'].str.split('-').apply(lambda x: '-'.join(x[:3] + [x[3][:2]]))
        self.slide_data = slide_data[slide_data['patient'].isin(self.selected_columns)]
        # patient id -> row labels of that patient's slides in self.slide_data
        self.patient_to_rows = {
            patient: rows.tolist()
            for patient, rows in self.slide_data.groupby('patient').groups.items()
        }

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``(genomics, features)`` for the patient at position ``idx``.

        Raises:
            IndexError: If ``idx`` is out of range (this also ends plain
                ``for`` iteration over the dataset).
            LookupError: If the patient has no slide in the WSI table. This is
                deliberately not an IndexError, which would silently end
                iteration.
        """
        patient = self.patient_ids[idx]

        # Load the genomics vector of the same patient.
        genomics = self.genomics_data.iloc[:,idx].to_numpy()

        indices = self.patient_to_rows.get(patient, [])
        if not indices:
            raise LookupError('no WSI slide found for patient {!r}'.format(patient))
        # When the patient has several slides, pick one at random.
        index = random.choice(indices) if len(indices) > 1 else indices[0]

        case_id = self.slide_data.loc[index, 'case_id']
        slide_id = self.slide_data.loc[index, 'slide_id']
        # Jongseong's CLAM layout - the path has to be built this way.
        full_path = os.path.join(self.wsi_feature_dir, 'pt_files','{}/{}.pt'.format(case_id, slide_id)).replace('.svs', '')
        features = torch.load(full_path)

        return genomics, features



# Smoke test: build the training split and print the shape of every pair.
train_dataset = Clip_dataset(split_key='train')

# Iterating the dataset directly calls __getitem__(0), __getitem__(1), ... until IndexError.
for sample in train_dataset:
    genomics, features = sample
    print(genomics.shape, features.size())


print('exit')
# Ends the cell here, so anything defined after this line in the same cell never runs.
sys.exit(0)

class Multi_Task_Dataset(Dataset):
    '''
    Slide-level dataset returning WSI features together with a TMB label and a subtype label.

    Design notes (230926):
        - emit the subtype label together with the TMB label
        - let the dataset itself select the split
        - support regression as well
        - allow restricting the data to a single subtype (LUSC or LUAD)

    Args:
        csv_path: CSV with the columns case_id, slide_id, label, Subtype,
            Mutation Count and TMB (nonsynonymous). A 'source' column is also
            loaded when ``data_dir`` is a dict and the column exists.
        split_path: CSV whose ``split_key`` column lists 'case_id/slide_id' strings.
        data_dir: Feature root directory, or a dict mapping a row's 'source'
            value to a root directory.
        shuffle: Unused here; shuffle in the data loader instead.
        seed: Stored on the instance.
        print_info: If truthy, print the per-class sample counts.
        label_dict: Maps 'TMB_low' / 'TMB_high' to integer class ids.
        label_dict2: Maps Subtype strings to integer class ids.
        use_h5: If True read ``h5_files/<slide_id>.h5``, otherwise
            ``pt_files/<case_id>/<slide_id>.pt``.
        target_subtype: Subtypes to keep. Setting only 'LUSC' or 'LUAD' creates a
            dataset with a single subtype.
        label_column: The label column to use. [label, TMB (nonsynonymous)]
            Default: TMB (nonsynonymous).
        tmb_threshold: Per-subtype thresholds, paired positionally with
            ``target_subtype``; values >= threshold become TMB_high. Ignored when
            ``regression`` is True or when None.
        regression: If True, keep ``label_column`` as raw values.
        split_key: Which split column to use ('train', 'val' or 'test').
        balance: ``balance[0]`` enables the TMB summary and ``balance[1]`` the
            subtype summary in ``summarize``.
        genomics: Optional CSV path loaded into ``self.genomics_data``.

    Raises:
        ValueError: If thresholding is requested with fewer thresholds than
            target subtypes.
    '''
    def __init__(
            self, 
            csv_path, 
            split_path, 
            data_dir,
            shuffle,        # not done here; leave it to the loader
            seed, 
            print_info,
            label_dict, 
            label_dict2, 
            use_h5=True, 
            target_subtype=['LUSC', 'LUAD'],
            label_column='TMB (nonsynonymous)',
            tmb_threshold = [0.5,0.5],
            regression = False,
            split_key = 'train',
            balance = [1,0,1],
            genomics = None
    ):
        
        slide_columns = ['case_id', 'slide_id', 'label', 'Subtype', 'Mutation Count', 'TMB (nonsynonymous)']
        slide_table = pd.read_csv(csv_path)
        # __getitem__ looks up slide_data['source'] when data_dir is a dict, so
        # the column has to survive the column selection in that case.
        if isinstance(data_dir, dict) and 'source' in slide_table.columns:
            slide_columns.append('source')
        self.slide_data = slide_table[slide_columns].reset_index(drop=True)
        # Keep only the rows whose subtype is listed in target_subtype.
        self.slide_data = self.slide_data[self.slide_data['Subtype'].isin(target_subtype)]
        self.split_path = split_path        # split file (only target_subtype rows are matched against it)
        self.data_dir = data_dir
        self.label_dict = label_dict        # tmb label
        self.label_dict2 = label_dict2      # subtype label
        self.use_h5 = use_h5
        self.num_classes = len(label_dict)
        self.seed = seed
        self.label_column = label_column
        self.regression = regression
        self.split_key = split_key
        self.balance = balance

        if genomics:
            self.genomics_data = pd.read_csv(genomics)

        # Keep only the rows of self.slide_data that belong to split_key (one of train, val, test).
        # Without "dtype=self.slide_data['slide_id'].dtype", read_csv() will convert all-number
        # columns to a numerical type. Even if we convert numerical columns back to objects later,
        # we may lose zero-padding in the process; the columns must be correctly read in from the
        # get-go. When we compare the individual train/val/test columns to
        # self.slide_data['slide_id'], we cannot compare objects (strings) to numbers or even to
        # incorrectly zero-padded objects/strings. An example of this breaking is shown in
        # https://github.com/andrew-weisman/clam_analysis/tree/main/datatype_comparison_bug-2021-12-01.
        all_splits = pd.read_csv(split_path, dtype=self.slide_data['slide_id'].dtype)
        split = all_splits[split_key]
        split = split.dropna().reset_index(drop=True)
        mask = (self.slide_data['case_id']+'/'+self.slide_data['slide_id']).isin(split.tolist())      # clam  
        self.slide_data = self.slide_data[mask].reset_index(drop=True)
        self.length = len(self.slide_data)

        # Unless this is a regression task, turn slide_data[label_column] into 0/1 labels.
        if not regression and tmb_threshold is not None:
            # zip() below stops at the shorter sequence. With too few thresholds
            # some subtypes would keep raw values that astype(int) then
            # truncates into meaningless labels.
            if len(tmb_threshold) < len(target_subtype):
                raise ValueError(
                    'tmb_threshold has {} entries but target_subtype has {}; '
                    'one threshold per subtype is required'.format(
                        len(tmb_threshold), len(target_subtype)))
            self.tmb_threshold = tmb_threshold
            # For each target subtype, replace its label_column values by class ids:
            # tmb_low -> label_dict['TMB_low'], tmb_high -> label_dict['TMB_high'].
            for subtype, thr in zip(target_subtype, self.tmb_threshold):
                mask = self.slide_data['Subtype'] == subtype
                self.slide_data.loc[mask, label_column] = self.slide_data.loc[mask, label_column].apply(lambda x: label_dict['TMB_high'] if x >= thr else label_dict['TMB_low'])
            self.slide_data[label_column] = self.slide_data[label_column].astype(int) 
        # Convert the Subtype column from strings to integer ids.
        self.slide_data['Subtype'] = self.slide_data['Subtype'].map(self.label_dict2)

        # Original note: "label conversion will happen in __getitem__, so this
        # cannot be used" - the call is nevertheless made here.
        self.cls_ids_prep()
        if print_info:
            self.summarize()

    def __len__(self):
        return self.length
    
    def __getitem__(self, idx):
        """Return the sample at row ``idx`` of ``self.slide_data``.

        Returns:
            ``(features, label, label2)`` when reading .pt files,
            ``(slide_id, label)`` when ``use_h5`` is False and ``data_dir`` is
            falsy, or ``(features, label, coords)`` when reading .h5 files.
        """
        case_id = self.slide_data['case_id'][idx]
        slide_id = self.slide_data['slide_id'][idx]
        label = self.slide_data[self.label_column][idx]
        label2 = self.slide_data['Subtype'][idx]
        if isinstance(self.data_dir, dict):
            # Multi-source setup: the row's 'source' value selects the directory.
            source = self.slide_data['source'][idx]
            data_dir = self.data_dir[source]
        else:
            data_dir = self.data_dir

        if not self.use_h5:
            if self.data_dir:       # can this ever be missing??
                # full_path = os.path.join(data_dir, 'pt_files', '{}.pt'.format(slide_id))	# this seems to be the original CLAM layout
                # Jongseong's CLAM layout - the path has to be built this way.
                full_path = os.path.join(data_dir, 'pt_files','{}/{}.pt'.format(case_id, slide_id)).replace('.svs', '')
                features = torch.load(full_path)
                return features, label, label2
            
            else:
                return slide_id, label

        else:
            # The h5 files also contain patch coordinates; using them might help,
            # but they are not used downstream at the moment.
            # NOTE(review): h5py is not imported in the visible cells - confirm it
            # is imported elsewhere before using use_h5=True.
            full_path = os.path.join(data_dir,'h5_files','{}.h5'.format(slide_id))
            with h5py.File(full_path,'r') as hdf5_file:
                features = hdf5_file['features'][:]
                coords = hdf5_file['coords'][:]

            features = torch.from_numpy(features)
            return features, label, coords

    def cls_ids_prep(self):
        """Store, per class, the row positions of the slides belonging to it."""
        # store ids corresponding each class at the slide level
        self.slide_subtype_cls_ids = [[] for i in range(len(self.label_dict2))]
        for i in range(len(self.label_dict2)):
            self.slide_subtype_cls_ids[i] = np.where(self.slide_data['Subtype'] == i)[0]
        # If not regression, store ids corresponding each tmb class at the slide level
        if not self.regression and self.label_column in ['TMB (nonsynonymous)', 'Mutation Count']:
            self.slide_cls_ids = [[] for i in range(len(self.label_dict))]
            for i in range(len(self.label_dict)):
                # self.slide_cls_ids[i] = np.where(self.slide_data[self.label_column] == list(self.label_dict.keys())[i])[0]
                self.slide_cls_ids[i] = np.where(self.slide_data[self.label_column] == i)[0]
        else:
            # self.slide_cls_ids is needed to compute a weighted loss. It is
            # created for regression too, in case a per-subtype weighted loss is
            # wanted there.
            self.slide_cls_ids = self.slide_subtype_cls_ids
            

    def summarize(self):
        """Print the per-class sample counts selected by ``self.balance``."""
        if self.balance[1]:
            for i in range(len(self.label_dict2)):
                print(f'{self.split_key} Slide-LVL; Number of samples registered in subtype class {i}: {self.slide_subtype_cls_ids[i].shape[0]}')

        if self.balance[0] and not self.regression and self.label_column in ['TMB (nonsynonymous)', 'Mutation Count']:
            for i in range(len(self.label_dict)): 
                print(f'{self.split_key} Slide-LVL; Number of samples registered in tmb class {i}: {self.slide_cls_ids[i].shape[0]}')

    def getlabel(self, ids):
        '''
        Return the ``label_column`` values at ``ids``.

        Caveat from the original author: in regression mode the column is not
        converted to integer class ids, so callers expecting class labels will
        get raw values instead.
        '''
        return self.slide_data[self.label_column][ids]

(2847,) torch.Size([44, 256, 768])
(2847,) torch.Size([89, 256, 768])
(2847,) torch.Size([36, 256, 768])
(2847,) torch.Size([130, 256, 768])
(2847,) torch.Size([155, 256, 768])
(2847,) torch.Size([310, 256, 768])
(2847,) torch.Size([168, 256, 768])
(2847,) torch.Size([59, 256, 768])
(2847,) torch.Size([28, 256, 768])
(2847,) torch.Size([38, 256, 768])
(2847,) torch.Size([84, 256, 768])
(2847,) torch.Size([55, 256, 768])
(2847,) torch.Size([335, 256, 768])
(2847,) torch.Size([303, 256, 768])
(2847,) torch.Size([243, 256, 768])
(2847,) torch.Size([76, 256, 768])
(2847,) torch.Size([64, 256, 768])
(2847,) torch.Size([62, 256, 768])


KeyboardInterrupt: 

In [1]:
import torch
import torch.nn as nn

# Toy check of nn.CrossEntropyLoss: random (5, 10) scores against class-index targets 0..4.
a = torch.randn((5,10))
b = torch.arange(len(a))
for shown in (a, b):
    print(shown)
for shown in (a, b):
    print(shown.type())
loss = nn.CrossEntropyLoss()

print(loss(a,b))

tensor([[-0.1730, -0.5198, -0.7209, -0.5456,  0.5410, -1.3605,  1.4543,  0.6625,
          0.1969, -0.6286],
        [ 1.0279,  0.7707,  0.4302, -0.4608,  0.5533, -1.4768,  1.3008, -0.6745,
          0.5119,  0.2290],
        [-1.8112, -0.7118,  1.2214, -0.2534, -0.2556,  0.9658, -0.5006, -0.2799,
          0.4337, -0.0267],
        [ 0.8516, -0.1119,  0.7187,  0.2937,  1.9620, -1.1525, -0.2221, -0.0307,
          1.7035,  0.6476],
        [ 1.5664,  0.8207, -2.1738, -0.3423, -0.1067, -0.6570,  0.1420,  0.1107,
          0.6497, -0.3075]])
tensor([0, 1, 2, 3, 4])
torch.FloatTensor
torch.LongTensor
tensor(2.3189)


  from .autonotebook import tqdm as notebook_tqdm


In [10]:
# range(-4) is empty (the default start 0 is already past the stop), so nothing is printed.
for i in range(0, -4):
    print(i)