In [70]:
import torch
from torch.nn.functional import one_hot
from torch.utils.data import Dataset, ConcatDataset
from torchvision.transforms import Compose, Resize, RandomCrop
import numpy as np
from os import listdir
from os.path import exists
from numpy import zeros_like
from glob import glob
from PIL import Image
from torchvision.transforms.functional import InterpolationMode

class AdaptiveResize:
    """Resize a tensor, choosing the interpolation mode from its dtype.

    ``torch.long`` tensors (label maps) go through nearest-neighbour resizing so
    label values are never blended; tensors of any other dtype are resized
    bilinearly.
    """

    def __init__(self, size):
        # ``size`` is forwarded unchanged to both torchvision Resize instances.
        self.int_resize = Resize(size, InterpolationMode.NEAREST)
        self.float_resize = Resize(size, InterpolationMode.BILINEAR)

    def __call__(self, tensor):
        is_label_map = tensor.dtype == torch.long
        resizer = self.int_resize if is_label_map else self.float_resize
        return resizer(tensor)

class StatefulRandomCrop():
    """RandomCrop that produces the same crop for two consecutive calls.

    Calls are handled in pairs. The first call of a pair snapshots the global
    torch CPU RNG state and crops; the second call restores that snapshot (so
    RandomCrop draws the same offsets) and clears it. This lets an image and
    its mask, transformed one right after the other, receive an identical crop.

    NOTE(review): pairing depends purely on call order. If a pair is
    interleaved with calls for another sample, or the first call of a pair
    raises, every later pair is shifted by one.
    """

    def __init__(self, size):
        self.random_crop = RandomCrop(size)
        # RNG snapshot taken by the first call of a pair; None means the next
        # call starts a new pair.
        self.state = None

    def __call__(self, tensor):
        # ``is None`` rather than ``== None``: the state is a Tensor, and Tensor
        # equality against None only works via a NotImplemented fallback.
        if self.state is None:
            self.state = torch.random.get_rng_state()
        else:
            torch.random.set_rng_state(self.state)
            self.state = None
        return self.random_crop(tensor)

def none_transform(x):
    """Identity transform: hand ``x`` back untouched (default for every transform slot)."""
    identity = x
    return identity

def prepare_cross_validation_datasets(root_dir, cross_validation_index=0, training_transform=none_transform, validation_transform=none_transform):
    """Split the surgeries under ``{root_dir}/Training`` into one train/validation fold.

    Every sub-directory of ``Training`` is a surgery type whose sub-directories
    are surgeries named by integer index. Within each type the surgeries are
    sorted numerically and the one at *position* ``cross_validation_index``
    (not the surgery whose number equals it) is held out for validation; the
    remaining ones go to training.

    Args:
        root_dir: dataset root that contains a ``Training`` directory.
        cross_validation_index: position of the held-out surgery within each
            type's numerically sorted surgery list.
        training_transform: ``common_transform`` for every training RobustDataset.
        validation_transform: ``common_transform`` for every validation RobustDataset.

    Returns:
        ``(training_dataset, validation_dataset)``, each a ``ConcatDataset`` with
        one ``RobustDataset`` per surgery.

    Raises:
        IndexError: if a surgery type has no surgery at ``cross_validation_index``.
    """
    from os.path import isdir

    training_root = f"{root_dir}/Training"
    # listdir() order is filesystem-dependent: sort so folds and the order of the
    # concatenated datasets are reproducible. Stray files (e.g. .DS_Store) are skipped.
    surgery_types = sorted(name for name in listdir(training_root) if isdir(f"{training_root}/{name}"))

    training_surgeries = []
    validation_surgeries = []
    for surgery_type in surgery_types:
        type_dir = f"{training_root}/{surgery_type}"
        # Keep the original directory names (sorted by numeric value) so that
        # zero-padded names such as "01" still map to existing paths.
        surgery_names = sorted((name for name in listdir(type_dir) if name.isdecimal()), key=int)
        held_out = surgery_names.pop(cross_validation_index)
        validation_surgeries.append(f"{type_dir}/{held_out}")
        training_surgeries.extend(f"{type_dir}/{name}" for name in surgery_names)

    training_dataset = ConcatDataset([RobustDataset(f"{surgery_path}/*", training_transform) for surgery_path in training_surgeries])
    validation_dataset = ConcatDataset([RobustDataset(f"{surgery_path}/*", validation_transform) for surgery_path in validation_surgeries])

    return training_dataset, validation_dataset

class RobustDataset(Dataset):
    """Image / binary instrument-mask pairs, one sample per directory.

    Each path matched by ``sample_path_glob`` is a sample directory containing
    ``raw.png`` and, optionally, ``instrument_instances.png``; a missing mask is
    treated as all background.

    ``common_transform`` runs on the image first and then on the mask of the
    same sample (see ``__getitem__``). Stateful transforms such as
    StatefulRandomCrop depend on that image-then-mask call order.
    """

    def __init__(self, sample_path_glob, common_transform=none_transform, img_transform=none_transform, seg_transform=none_transform):
        # glob() order is filesystem-dependent; sort so index -> sample is stable.
        self.sample_paths = sorted(glob(sample_path_glob))

        self.img_transform = Compose([common_transform, img_transform])
        self.seg_transform = Compose([common_transform, seg_transform])

    def __len__(self):
        return len(self.sample_paths)

    @staticmethod
    def _sample_name(sample_path):
        """Return the last three '/'-separated components of ``sample_path``.

        For directories produced by prepare_cross_validation_datasets this is
        ``<surgery_type>/<surgery_index>/<sample>``, independent of how deep the
        dataset root is (the old fixed split indices only worked for one root).
        """
        return "/".join(sample_path.rstrip("/").split("/")[-3:])

    def __getitem__(self, index):
        """Load sample ``index``.

        Returns:
            img: float tensor, channels first, pixel values divided by 255
                before the transforms run.
            seg: float one-hot tensor with 2 channels; channel 1 marks every
                non-zero value of the instance mask, channel 0 the rest.
            size: ``np.ndarray`` holding the shape of the untransformed
                ``raw.png`` array (height, width, channels).
            name: the last three path components of the sample directory.
        """
        sample_path = self.sample_paths[index]

        img_path = f"{sample_path}/raw.png"
        seg_path = f"{sample_path}/instrument_instances.png"

        # Context managers make sure the file handles are released.
        with Image.open(img_path) as image_file:
            img = np.array(image_file)
        if exists(seg_path):
            with Image.open(seg_path) as mask_file:
                seg = np.array(mask_file)
        else:
            seg = zeros_like(img[..., 0])

        size = np.array(img.shape)

        img = torch.from_numpy(img).permute(2, 0, 1) / 255
        seg = torch.from_numpy(seg).unsqueeze(0).long()

        # Image first, then mask: StatefulRandomCrop pairs calls in this order.
        img = self.img_transform(img)
        seg = self.seg_transform(seg)

        # Collapse instance ids to a binary mask; squeeze only the channel dim so
        # a spatial dim of size 1 is never dropped.
        binary_mask = (seg != 0).long().squeeze(0)
        seg = one_hot(binary_mask, num_classes=2).permute(2, 0, 1).float()

        return img, seg, size, self._sample_name(sample_path)

In [71]:
from torchvision.transforms import Compose, CenterCrop

# Configuration. DATA_ROOT is machine-specific — adjust for your environment.
DATA_ROOT = "/redresearch1/shared/mwei/robustmislite"
CV_FOLD = 1              # position of the held-out surgery within each surgery type
RESIZE_HW = (270, 480)   # (height, width) after resizing
CROP_HW = (256, 448)     # (height, width) of the crop fed to the model

# Training: resize, then a random crop that is shared between image and mask.
training_transform = Compose([AdaptiveResize(RESIZE_HW), StatefulRandomCrop(CROP_HW)])
# Validation: same resize, deterministic center crop of the same size.
validation_transform = Compose([AdaptiveResize(RESIZE_HW), CenterCrop(CROP_HW)])
training_dataset, validation_dataset = prepare_cross_validation_datasets(DATA_ROOT, CV_FOLD, training_transform, validation_transform)

In [72]:
from torch.utils.data import DataLoader

BATCH_SIZE = 16
NUM_WORKERS = 8

# Only the training split is shuffled; validation keeps the DataLoader default order.
training_dataloader = DataLoader(training_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True, shuffle=True)
validation_dataloader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)

In [None]:
import matplotlib.pyplot as plt
from itertools import islice

# Preview the foreground channel of a few training masks. Looping over the
# whole dataloader would draw one figure per batch for an entire epoch.
N_PREVIEW_BATCHES = 4
SAMPLE_IN_BATCH = 0  # index 0 exists in every batch, including a short final one

for batch_idx, (img, seg, size, name) in enumerate(islice(training_dataloader, N_PREVIEW_BATCHES)):
    mask = seg[SAMPLE_IN_BATCH, 1].numpy()
    fig, ax = plt.subplots()
    ax.imshow(mask)
    ax.set_title(f"batch {batch_idx}: {name[SAMPLE_IN_BATCH]} — values {np.unique(mask)}")
    ax.set_axis_off()
    plt.show()
    plt.close(fig)

In [5]:
root_dir = "/redresearch1/shared/mwei/robustmislite"
# listdir() order is filesystem-dependent; sort so the type order (and every
# split/file derived from it) is reproducible.
surgery_types = sorted(listdir(f"{root_dir}/Training"))
surgery_indices_per_type = [sorted(map(int, listdir(f"{root_dir}/Training/{surgery_type}"))) for surgery_type in surgery_types]

In [6]:
# Display the numerically sorted surgery indices found for each surgery type
# (one inner list per entry of surgery_types, in the same order).
surgery_indices_per_type

[[1, 2, 3, 4, 5, 8, 9, 10], [1, 2, 3, 6, 7, 8, 9, 10]]

In [8]:
# Hold out the surgery at position VAL_POSITION of each type's sorted index
# list (the same fold as cross_validation_index=1 in the dataset cell).
VAL_POSITION = 1
training_surgeries = []
validation_surgeries = []
for surgery_type, type_indices in zip(surgery_types, surgery_indices_per_type):
    # Pop from a copy: popping surgery_indices_per_type itself dropped one more
    # surgery every time this cell was re-run.
    remaining = list(type_indices)
    validation_surgeries.append(f"{root_dir}/Training/{surgery_type}/{remaining.pop(VAL_POSITION)}")
    training_surgeries.extend(f"{root_dir}/Training/{surgery_type}/{index}" for index in remaining)

In [60]:
# NOTE(review): img_temp is only assigned by the val.txt-writing cell further
# down, so this cell raises NameError on a fresh kernel run top to bottom.
''.join(img_temp.split(' '))

'/Training/Rectalresection/10/99000/raw.png'

In [63]:
# Write the validation split as one "<image> <mask>" line per sample, with
# paths relative to root_dir. Samples are sorted so the file is reproducible.
# NOTE(review): spaces are stripped from every path (a directory named
# "Rectal resection" becomes "Rectalresection") because the file is
# space-delimited; these paths only resolve if the consumer's data was renamed
# the same way — confirm.
with open('val.txt', 'w') as f:
    for surgery_path in validation_surgeries:
        for sample in sorted(glob(f"{surgery_path}/*")):
            # Strip the root prefix by its actual length instead of a hard-coded 39.
            relative_sample = sample[len(root_dir):]
            img_temp = f"{relative_sample}/raw.png".replace(' ', '')
            label_temp = f"{relative_sample}/instrument_instances.png".replace(' ', '')
            f.write(f"{img_temp} {label_temp}\n")

In [64]:
import numpy as np
import random

# Split train.txt into N_SUPERVISED randomly chosen lines (supervised) and the
# remaining lines (unsupervised). Seeded so the split can be reproduced.
N_SUPERVISED = 1000
SPLIT_SEED = 0

with open('train.txt') as f:
    lines = f.read().splitlines()

split_rng = np.random.default_rng(SPLIT_SEED)
all_lines = np.array(lines)
picked = split_rng.choice(len(all_lines), N_SUPERVISED, replace=False)
# Split by index rather than np.setdiff1d so duplicate lines are not collapsed.
is_supervised = np.zeros(len(all_lines), dtype=bool)
is_supervised[picked] = True
myline = all_lines[picked]
suline = all_lines[~is_supervised]

with open(f'{N_SUPERVISED}_train_supervised.txt', 'w') as f:
    f.writelines(f"{line}\n" for line in myline)
with open(f'{N_SUPERVISED}_train_unsupervised.txt', 'w') as f:
    f.writelines(f"{line}\n" for line in suline)

In [65]:
def get_voc_pallete(num_classes):
    """Return the PASCAL-VOC-style colour map as a flat ``[r0, g0, b0, r1, ...]`` list.

    For each label, bits 0/1/2 of successive 3-bit groups are written into
    R/G/B respectively, starting at the most significant bit (bit 7) and moving
    down one bit per group.
    """
    palette = []
    for label in range(num_classes):
        rgb = [0, 0, 0]
        remaining = label
        shift = 7
        while remaining > 0:
            for channel in range(3):
                rgb[channel] |= ((remaining >> channel) & 1) << shift
            shift -= 1
            remaining >>= 3
        palette.extend(rgb)
    return palette

In [78]:
# NOTE(review): the SyntaxError recorded for this cell is stale — the
# expression below is valid; re-run the cell to refresh its output.
torch.rand(1, 1, 1) - 0.5

SyntaxError: invalid syntax (2804947235.py, line 1)