In [2]:
import pandas as pd
import numpy as np
import os
from PIL import Image

import torch
import torch.nn as nn
from torchvision.io import read_image, ImageReadMode
from torchvision.transforms import ToTensor
import torchvision.transforms as transforms
import torchvision

from torch.utils.data import Dataset, DataLoader

In [3]:
# Build {slide_id: [patch file paths]} from the per-slide subdirectories of `patch_path`.
# NOTE(review): absolute cluster path is hard-coded; consider moving it to a config cell.
patch_path = '/project/GutIntelligenceLab/ss4yd/gtex_data/process_path_level1/'
# os.listdir returns entries in arbitrary, filesystem-dependent order; sort so the patch
# ordering (and therefore every downstream DataFrame / Dataset index) is reproducible.
# Non-directory entries are skipped because listdir() on a file would raise.
dirs = sorted(
    entry for entry in os.listdir(patch_path)
    if os.path.isdir(os.path.join(patch_path, entry))
)
dir_patch_dict = {}
for pid in dirs:
    pid_dir = os.path.join(patch_path, pid)
    patches = [os.path.join(pid_dir, x) for x in sorted(os.listdir(pid_dir))]
    dir_patch_dict[pid] = patches

In [5]:
# One row per patch file; the slide id is the name of the patch's parent directory.
patch_paths = [path for paths in dir_patch_dict.values() for path in paths]
df = pd.DataFrame(patch_paths, columns=['patch_paths'])
df['pid'] = df['patch_paths'].str.split('/').str[-2]

split = pd.read_csv('../../csv/train_test_val_split.csv')

# Reproducible 20% subsample of slides within each partition.
train, val, test = (
    split[split['dtype'] == part].sample(frac=0.2, random_state=1)
    for part in ('train', 'val', 'test')
)
split_small = pd.concat([train, val, test])
print(f"length smaller split:{len(split_small)}")

# Inner join: keep only patches whose slide is in the subsampled split.
df = df.merge(split_small, on='pid')
print(f"Number of patches: {len(df)}")

length smaller split:1546
Number of patches: 144926


In [6]:
# Patches per partition; bracket access avoids confusion with DataFrame.dtypes.
df['dtype'].value_counts()

train    108920
val       22344
test      13662
Name: dtype, dtype: int64

In [7]:
# Column check after merging the patch list with the split table.
df.columns

Index(['patch_paths', 'pid', 'svs_paths', 'dtype'], dtype='object')

In [9]:
# Persist the per-patch working table (no pandas index column).
df.to_csv(path_or_buf='../../csv/working_df.csv', index=False)

In [10]:
class GetRepsDataset(Dataset):
    """Map-style dataset yielding the decoded image of every patch in one split.

    Args:
        df: DataFrame with at least ``patch_paths`` and ``dtype`` columns.
        dtype: split label to keep (rows where ``df['dtype'] == dtype``),
            e.g. ``'train'``, ``'val'`` or ``'test'``.
        transform: optional callable applied to the image tensor returned by
            ``read_image``.
    """

    def __init__(self, df, dtype, transform=None):
        self.df = df
        self.dtype = dtype
        self.transform = transform
        self.typ_df = df[df['dtype'] == dtype]
        # Plain Python list: cheap positional lookup in __getitem__ instead of
        # going through pandas indexing for every sample.
        self.patch_paths = self.typ_df['patch_paths'].tolist()

    def __len__(self):
        return len(self.patch_paths)

    def __getitem__(self, idx):
        img_path = self.patch_paths[idx]

        # Decode as 3-channel RGB regardless of the file's native mode.
        image = read_image(img_path, mode=ImageReadMode.RGB)

        # `is not None` rather than truthiness: a callable that defines
        # __len__/__bool__ could evaluate falsy and be silently skipped.
        if self.transform is not None:
            image = self.transform(image)
        return image

In [12]:
# Standard ImageNet per-channel mean / std.
normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

# uint8 CHW tensor -> resized to 224x224 -> float in [0, 1] -> normalized.
transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ConvertImageDtype(torch.float),
    normalize,
])

train_dataset = GetRepsDataset(df, 'train', transform)
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=1,
    pin_memory=True,
)

# Sanity check on one sample's shape.
train_dataset[1].shape

torch.Size([3, 224, 224])

In [None]:
# test
# Quick visual check of the first patch of one slide. Raises KeyError if this
# pid is not a key of dir_patch_dict.
Image.open(dir_patch_dict['GTEX-13SLW-2526'][0])

In [None]:
class GetRepsDataset_old(Dataset):
    """Earlier, split-agnostic variant: one item per patch across all slides.

    Args:
        dir_patch_dict: mapping of slide id -> list of patch file paths. Items are
            ordered by the dict's iteration order, then by each list's order.
        transform: optional callable applied to the image tensor returned by
            ``read_image``.
    """

    def __init__(self, dir_patch_dict, transform=None):
        self.dir_patch_dict = dir_patch_dict
        self.transform = transform
        # Flatten once up front so __getitem__ is a plain list lookup.
        self.patch_list = [x for xs in self.dir_patch_dict.values() for x in xs]

    def __len__(self):
        return len(self.patch_list)

    def __getitem__(self, idx):
        img_path = self.patch_list[idx]

        # Decode as 3-channel RGB regardless of the file's native mode.
        image = read_image(img_path, mode=ImageReadMode.RGB)

        # `is not None` (consistent with GetRepsDataset): a falsy-but-real
        # callable must still be applied.
        if self.transform is not None:
            image = self.transform(image)

        return image

# Total patch count (evaluated but not displayed: it is not the cell's last expression).
sum(len(paths) for paths in dir_patch_dict.values())

# test
# https://pytorch.org/vision/stable/auto_examples/plot_scripted_tensor_transforms.html#sphx-glr-auto-examples-plot-scripted-tensor-transforms-py
normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ConvertImageDtype(torch.float),
    normalize,
])

dataset = GetRepsDataset_old(dir_patch_dict, transform)
# NOTE(review): this rebinds `train_loader`, replacing the split-aware loader
# built from GetRepsDataset earlier; running this cell changes later state.
train_loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=1, pin_memory=True)

dataset[1].shape

## Make a patch-level PyTorch dataset

In [11]:
# Per-patch representations (columns '0'..'511') plus the patch path. The leading
# 'Unnamed: 0' column suggests the CSV was written with its index.
# NOTE(review): this reads from '../csv/' while earlier cells use '../../csv/'; confirm.
df = pd.read_csv('../csv/working_train_reps.csv')
df.head(n=5)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,503,504,505,506,507,508,509,510,511,patch_paths
0,0.06503,0.477221,0.097286,0.398688,0.113824,0.871815,0.001729,0.115488,5.24283,0.028892,...,0.469636,0.586778,3.896468,1.705638,0.175322,1.444745,0.349379,0.0,0.116404,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...
1,0.151594,0.351325,0.41649,0.9263,0.196956,0.648142,0.016504,0.076212,5.89944,0.07563,...,0.38138,0.852429,4.49505,1.216366,0.532649,2.238657,0.577874,0.0,0.111423,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...
2,0.140285,0.327091,0.314161,0.668894,0.578952,1.346328,0.004096,0.349218,5.604905,0.055081,...,0.730045,1.527315,3.705997,1.914241,0.567546,1.481182,0.699235,0.0,0.035434,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...
3,0.183119,0.143039,0.528455,0.441182,0.427874,1.163569,0.037215,0.474609,5.949831,1.086014,...,0.283806,1.118277,3.030128,0.862544,0.454116,2.124792,1.304261,0.070249,0.669173,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...
4,1.069007,1.11385,0.251766,0.457149,0.001967,1.410211,0.103539,0.066857,6.737306,0.317485,...,0.683167,0.872111,3.109493,1.700155,0.111501,2.354243,1.007943,0.029701,0.334032,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...


In [13]:
import faiss

In [12]:
# The 512 representation columns as a C-contiguous float32 matrix (what faiss expects).
X = np.ascontiguousarray(df[[str(i) for i in range(512)]].to_numpy(dtype='float32'))

In [14]:
# k-means on the patch representations: 8 centroids, 300 iterations, 20 restarts.
ncentroids, niter, verbose = 8, 300, False
d = X.shape[1]  # representation dimensionality
kmeans = faiss.Kmeans(d, ncentroids, niter=niter, nredo=20, verbose=verbose)
# train() returns the final objective, which becomes the cell output.
kmeans.train(X)

254352.203125

In [15]:
# Nearest centroid for every row of X (k=1): D holds the distances (squared L2
# for faiss's default flat L2 index), I the centroid ids; both are (n, 1).
D, I = kmeans.index.search(X, 1)

In [16]:
# Flatten the (n, 1) id array into one cluster id per patch row.
df['cluster_assignment'] = I[:, 0]

In [17]:
# Slide id = name of the patch file's parent directory.
df['pid'] = df['patch_paths'].str.split('/').str[-2]
df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,505,506,507,508,509,510,511,patch_paths,cluster_assignment,pid
0,0.065030,0.477221,0.097286,0.398688,0.113824,0.871815,0.001729,0.115488,5.242830,0.028892,...,3.896468,1.705638,0.175322,1.444745,0.349379,0.000000,0.116404,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,0,GTEX-R55E-1726
1,0.151594,0.351325,0.416490,0.926300,0.196956,0.648142,0.016504,0.076212,5.899440,0.075630,...,4.495050,1.216366,0.532649,2.238657,0.577874,0.000000,0.111423,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,0,GTEX-R55E-1726
2,0.140285,0.327091,0.314161,0.668894,0.578952,1.346328,0.004096,0.349218,5.604905,0.055081,...,3.705997,1.914241,0.567546,1.481182,0.699235,0.000000,0.035434,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,1,GTEX-R55E-1726
3,0.183119,0.143039,0.528455,0.441182,0.427874,1.163569,0.037215,0.474609,5.949831,1.086014,...,3.030128,0.862544,0.454116,2.124792,1.304261,0.070249,0.669173,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,1,GTEX-R55E-1726
4,1.069007,1.113850,0.251766,0.457149,0.001967,1.410211,0.103539,0.066857,6.737306,0.317485,...,3.109493,1.700155,0.111501,2.354243,1.007943,0.029701,0.334032,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,1,GTEX-R55E-1726
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
108915,3.720138,0.284489,1.545707,0.724670,0.319680,0.341031,0.015502,0.452460,6.731671,1.219829,...,2.004468,0.626353,0.000000,0.791891,1.904577,0.246829,0.122483,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,2,GTEX-14753-1826
108916,3.661222,0.321211,0.231825,0.906495,0.686602,0.667976,0.030586,0.545029,4.750716,1.579729,...,1.359481,1.435388,0.000000,0.766770,1.292814,0.247202,0.055415,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,2,GTEX-14753-1826
108917,3.959967,0.375449,0.883864,0.232680,0.295611,0.926879,0.000000,0.586061,6.204812,2.058973,...,1.312047,0.298206,0.000000,0.319243,1.694219,0.288718,0.278142,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,2,GTEX-14753-1826
108918,3.730825,0.342591,0.730703,0.729921,0.633500,0.546797,0.007914,1.137852,5.849081,1.690261,...,1.648091,0.708258,0.000000,0.418720,1.424751,0.324934,0.143382,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,2,GTEX-14753-1826


In [18]:
# Patches per slide.
df['pid'].value_counts()

GTEX-15UF7-0426    747
GTEX-1IDJV-0826    747
GTEX-1CB4J-0726    669
GTEX-1N2DW-2626    577
GTEX-14PKV-0826    574
                  ... 
GTEX-1LSNL-2826      1
GTEX-14PJ6-1226      1
GTEX-1JJ6O-0226      1
GTEX-X15G-0326       1
GTEX-ZG7Y-2426       1
Name: pid, Length: 801, dtype: int64

In [19]:
# Patches per k-means cluster.
df['cluster_assignment'].value_counts()

6    24706
0    17891
2    14198
3    13094
5    11356
1    11342
4     9553
7     6780
Name: cluster_assignment, dtype: int64

In [20]:
# Pack the 512 representation columns into one list-valued 'reps' column.
df['reps'] = df[[str(i) for i in range(512)]].to_numpy().tolist()

In [21]:
# Drop the now-redundant per-dimension columns.
# NOTE(review): not idempotent; re-running raises KeyError once the columns are gone.
df = df.drop(columns=list(map(str, range(512))))

In [22]:
# Frame after packing the representations into 'reps' (pandas truncates the display).
df

Unnamed: 0,patch_paths,cluster_assignment,pid,reps
0,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,0,GTEX-R55E-1726,"[0.06503012, 0.4772212, 0.097285874, 0.3986875..."
1,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,0,GTEX-R55E-1726,"[0.15159419, 0.3513252, 0.41649047, 0.9262995,..."
2,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,1,GTEX-R55E-1726,"[0.14028509, 0.32709098, 0.31416133, 0.6688937..."
3,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,1,GTEX-R55E-1726,"[0.18311931, 0.14303893, 0.52845514, 0.4411823..."
4,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,1,GTEX-R55E-1726,"[1.069007, 1.1138498, 0.2517664, 0.45714885, 0..."
...,...,...,...,...
108915,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,2,GTEX-14753-1826,"[3.7201376, 0.28448898, 1.5457067, 0.7246698, ..."
108916,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,2,GTEX-14753-1826,"[3.661222, 0.32121077, 0.23182462, 0.9064946, ..."
108917,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,2,GTEX-14753-1826,"[3.959967, 0.3754492, 0.8838644, 0.23268037, 0..."
108918,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,2,GTEX-14753-1826,"[3.730825, 0.34259093, 0.7307033, 0.72992116, ..."


In [23]:
# Index rows by (slide id, cluster id) so df.loc[pid] selects one slide's patches.
df = df.set_index(keys=['pid', 'cluster_assignment'])

In [24]:
# Same rows, now indexed by (pid, cluster_assignment).
df

Unnamed: 0_level_0,Unnamed: 1_level_0,patch_paths,reps
pid,cluster_assignment,Unnamed: 2_level_1,Unnamed: 3_level_1
GTEX-R55E-1726,0,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,"[0.06503012, 0.4772212, 0.097285874, 0.3986875..."
GTEX-R55E-1726,0,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,"[0.15159419, 0.3513252, 0.41649047, 0.9262995,..."
GTEX-R55E-1726,1,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,"[0.14028509, 0.32709098, 0.31416133, 0.6688937..."
GTEX-R55E-1726,1,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,"[0.18311931, 0.14303893, 0.52845514, 0.4411823..."
GTEX-R55E-1726,1,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,"[1.069007, 1.1138498, 0.2517664, 0.45714885, 0..."
...,...,...,...
GTEX-14753-1826,2,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,"[3.7201376, 0.28448898, 1.5457067, 0.7246698, ..."
GTEX-14753-1826,2,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,"[3.661222, 0.32121077, 0.23182462, 0.9064946, ..."
GTEX-14753-1826,2,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,"[3.959967, 0.3754492, 0.8838644, 0.23268037, 0..."
GTEX-14753-1826,2,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,"[3.730825, 0.34259093, 0.7307033, 0.72992116, ..."


In [25]:
# All patches of one slide that fall in cluster 2.
df.xs('GTEX-14753-1826', level='pid').loc[2]

Unnamed: 0_level_0,patch_paths,reps
cluster_assignment,Unnamed: 1_level_1,Unnamed: 2_level_1
2,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,"[3.2250712, 0.8016, 0.836539, 0.78275037, 0.11..."
2,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,"[2.3976767, 0.18086497, 0.15328069, 0.26081064..."
2,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,"[2.9110312, 0.1249348, 0.3478021, 0.40884006, ..."
2,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,"[4.0418515, 0.10833909, 0.7948832, 0.7046966, ..."
2,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,"[3.7213733, 0.41220257, 0.35996193, 2.1218994,..."
...,...,...
2,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,"[3.7201376, 0.28448898, 1.5457067, 0.7246698, ..."
2,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,"[3.661222, 0.32121077, 0.23182462, 0.9064946, ..."
2,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,"[3.959967, 0.3754492, 0.8838644, 0.23268037, 0..."
2,/project/GutIntelligenceLab/ss4yd/gtex_data/pr...,"[3.730825, 0.34259093, 0.7307033, 0.72992116, ..."


In [48]:
# Slides per batch, per-slide patch-count threshold for subsampling, and the number
# of k-means clusters (presumably meant to equal `ncentroids` above; confirm).
pid_batch_size, img_per_pid, num_cluster = 8, 8, 8

In [35]:
# Slide ids present in df, in the (sorted) order of the MultiIndex level.
# remove_unused_levels() drops ids that would linger in the level after row filtering
# and then fail in df.loc[pid].
unique_pids = df.index.remove_unused_levels().levels[0]
# Consecutive chunks of `pid_batch_size` slide ids; only the last may be shorter.
# range(0, n, step) never yields an empty trailing chunk, unlike the previous
# `n // step + 1` start offsets when n is a multiple of step.
batches = [unique_pids[i:i + pid_batch_size]
           for i in range(0, len(unique_pids), pid_batch_size)]

In [102]:
# Prototype of the per-batch selection: a slide with fewer than `img_per_pid` patches
# contributes all of them, otherwise one random patch per cluster id present in
# 0..num_cluster-1.
img_list = []
for pids in batches:
    for pid in pids:
        # Slice the slide once (rows in df order, indexed by cluster_assignment)
        # instead of re-masking the whole frame for every cluster.
        pid_rows = df.loc[pid]
        if len(pid_rows) < img_per_pid:
            img_list += list(pid_rows['patch_paths'])
        else:
            present = set(pid_rows.index)
            for c in range(num_cluster):
                if c in present:
                    img_list += list(pid_rows.loc[pid_rows.index == c, 'patch_paths'].sample(1))

In [138]:
def return_batch_img_list(df, pids, num_cluster=8, img_per_pid=8, random_state=None):
    """Pick the patch paths to load for a batch of slides.

    For each slide id in ``pids``:

    * fewer than ``img_per_pid`` patches -> all of its patches;
    * patches spread over exactly ``num_cluster`` distinct cluster ids -> one
      random patch from each cluster id ``0 .. num_cluster - 1``;
    * otherwise -> ``num_cluster`` random patches from the slide, capped at the
      number of patches it has.

    Args:
        df: DataFrame indexed by a ``(pid, cluster_assignment)`` MultiIndex with
            a ``patch_paths`` column.
        pids: iterable of slide ids present in level 0 of ``df.index``.
        num_cluster: number of k-means clusters.
        img_per_pid: minimum patch count for a slide to be subsampled. This was
            previously read from the notebook-global ``img_per_pid`` (8).
        random_state: int seed, ``np.random.RandomState`` or ``np.random.Generator``
            for reproducible sampling. ``None`` uses numpy's global RNG, as before.

    Returns:
        List of patch path strings, grouped by slide in ``pids`` order.

    Raises:
        KeyError: if a pid is not present in ``df``.
    """
    # One RNG shared by every draw: passing the same int to each .sample call
    # would restart the stream and correlate the draws.
    if isinstance(random_state, (int, np.integer)):
        random_state = np.random.RandomState(random_state)

    img_list = []
    for pid in pids:
        # Slice the slide once instead of masking the whole frame per cluster.
        pid_rows = df.loc[pid]
        paths = pid_rows['patch_paths']
        if len(pid_rows) < img_per_pid:
            img_list += list(paths)
        elif pid_rows.index.nunique() == num_cluster:
            clusters = pid_rows.index
            for c in range(num_cluster):
                in_cluster = paths[clusters == c]
                if len(in_cluster):
                    img_list += list(in_cluster.sample(1, random_state=random_state))
        else:
            # Cap the draw: when img_per_pid < num_cluster a slide can have
            # >= img_per_pid but < num_cluster patches, where sample(num_cluster)
            # would raise ValueError.
            n_draw = min(num_cluster, len(paths))
            img_list += list(paths.sample(n_draw, random_state=random_state))

    return img_list

In [139]:
# Patch selection for the first batch of slides.
return_batch_img_list(df, pids=batches[0])

['/project/GutIntelligenceLab/ss4yd/gtex_data/process_path_level1/GTEX-1117F-0526/tile_1_level1_4324-16409-5348-17433.png',
 '/project/GutIntelligenceLab/ss4yd/gtex_data/process_path_level1/GTEX-1117F-0526/tile_3_level1_5348-15385-6372-16409.png',
 '/project/GutIntelligenceLab/ss4yd/gtex_data/process_path_level1/GTEX-1117F-0526/tile_2_level1_5348-14361-6372-15385.png',
 '/project/GutIntelligenceLab/ss4yd/gtex_data/process_path_level1/GTEX-1117F-0526/tile_0_level1_4324-14361-5348-15385.png',
 '/project/GutIntelligenceLab/ss4yd/gtex_data/process_path_level1/GTEX-1117F-1226/tile_117_level1_16197-26981-17221-28005.png',
 '/project/GutIntelligenceLab/ss4yd/gtex_data/process_path_level1/GTEX-1117F-1226/tile_48_level1_11076-19813-12100-20837.png',
 '/project/GutIntelligenceLab/ss4yd/gtex_data/process_path_level1/GTEX-1117F-1226/tile_58_level1_12100-18789-13124-19813.png',
 '/project/GutIntelligenceLab/ss4yd/gtex_data/process_path_level1/GTEX-1117F-1226/tile_45_level1_11076-16741-12100-17765.p

In [136]:
# Size of the selection for the last (possibly short) batch. The frame must be passed
# first, because the function's signature is (df, pids, ...).
len(return_batch_img_list(df, batches[-1]))

8

In [146]:
# Cluster distribution of the first slide in the last batch.
df.xs(batches[-1][0], level='pid').index.value_counts()

3    95
5    65
1     7
6     3
Name: cluster_assignment, dtype: int64