In [1]:
import numpy as np
import torch
import HospitalGenerators as hg

In [2]:
(2, 0, 3, 4)[:2]

(2, 0)

In [12]:
class RealisticHospitalNoTissue:
    def __init__(self, k, with_tissue=False):
        # [O, A, B, AB] Probabilities
        self.bloodtypes = ['', 'A', 'B', 'AB']
        self.bloodtype_probs = [0.04, 0.34, 0.14, 0.48]
        self.bloodtype_cutoffs = [0.48, 0.82, 0.96, 1.0]
        
        self.patients = k
        self.tissue = with_tissue
        
    def generate(self, batch_size):
        num_types = len(self.bloodtype_probs)
        
        # patient type, donor_type, patient tissue_type, donor tissue type
        bids = np.zeros((batch_size, num_types, num_types))
        
        for i in range(batch_size):
            for j in range(self.patients):
                bids[(i, ) + self.generate_pair()] += 1.0
            
        return bids.reshape((batch_size, -1))
    
    def generate_pair(self):
        blood_idx = np.random.choice(len(self.bloodtype_probs), p=self.bloodtype_probs, size=2)
        p_blood_idx = blood_idx[0]
        d_blood_idx = blood_idx[1]

        ret_tuple = (p_blood_idx, d_blood_idx)
                
        return ret_tuple
    
    def is_blood_compat(self, p_idx, d_idx):
        p_type = self.bloodtypes[p_idx]
        d_type = self.bloodtypes[d_idx]
        for ch in d_type:
            if ch not in p_type:
                return False
        return True

In [36]:
class RealisticHospital:
    def __init__(self, k, with_tissue=False):
        # [O, A, B, AB] Probabilities
        self.bloodtypes = ['', 'A', 'B', 'AB']
        self.bloodtype_probs = [0.48, 0.34, 0.14, 0.04]
        self.bloodtype_cutoffs = [0.48, 0.82, 0.96, 1.0]

        # Tissue incompatibilitiy probabilities
        self.tissue_probs = [.7, .2, .1]
        self.tissue_cutoffs = [.7, .9, 1.0]

        # Tissue incompat values
        self.tissue_vals = [0.05, 0.45, 0.9]
        self.patients = k
        self.tissue = with_tissue

    def generate(self, batch_size):
        num_types = len(self.bloodtype_probs)
        num_tissue = len(self.tissue_probs)
        
        if self.tissue:
            # patient type, donor_type, patient tissue_type, donor tissue type
            bids = np.zeros((batch_size, num_types, num_types, num_tissue, num_tissue))
            for i in range(batch_size):
                for j in range(self.patients):
                    bids[(i, ) + self.generate_pair()] += 1.0
        else:
            # patient type, donor_type,
            bids = np.zeros((batch_size, num_types, num_types))
            for i in range(batch_size):
                for j in range(self.patients):
                    bids[(i, ) + self.generate_pair()[:2]] += 1.0
            
        return bids.reshape((batch_size, -1))
    
    def generate_pair(self):
        incompat = False
        ret_tuple = None
        while (not incompat):
            tissue_idx = np.random.choice(len(self.tissue_probs), p=self.tissue_probs, size=2)
            p_tissue_idx = tissue_idx[0]
            d_tissue_idx = tissue_idx[1]

            blood_idx = np.random.choice(len(self.bloodtype_probs), p=self.bloodtype_probs, size=2)
            p_blood_idx = blood_idx[0]
            d_blood_idx = blood_idx[1]
            
            if (self.tissue_vals[d_tissue_idx] < self.tissue_vals[p_tissue_idx] 
                or not self.is_blood_compat(p_blood_idx, d_blood_idx)):
                incompat = True
                ret_tuple = (p_blood_idx, d_blood_idx, p_tissue_idx, d_tissue_idx)
        return ret_tuple
    
    def is_blood_compat(self, p_idx, d_idx):
        p_type = self.bloodtypes[p_idx]
        d_type = self.bloodtypes[d_idx]
        for ch in d_type:
            if ch not in p_type:
                return False
        return True
            
        
        

In [43]:
np.random.seed(0)
test = RealisticHospital(100)

In [47]:
test.generate(10)[0, :].reshape((4, 4))

array([[11., 26., 16.,  2.],
       [12.,  8., 10.,  0.],
       [ 1.,  9.,  1.,  2.],
       [ 1.,  1.,  0.,  0.]])

In [13]:
hos_list = [RealisticHospitalNoTissue(200), RealisticHospitalNoTissue(100), RealisticHospitalNoTissue(10)]
gen = hg.ReportGenerator(hos_list, (3, 16))

In [14]:
next(gen.generate_report(10))

array([[[ 0.,  0.,  1.,  3.,  2., 18., 13., 29.,  0., 14.,  4., 14.,
          4., 29., 14., 55.],
        [ 0.,  1.,  0.,  2.,  0., 18.,  5., 12.,  0.,  9.,  2.,  7.,
          2., 17.,  6., 19.],
        [ 0.,  0.,  0.,  0.,  0.,  1.,  2.,  1.,  0.,  0.,  1.,  0.,
          1.,  2.,  1.,  1.]],

       [[ 2.,  4.,  1.,  2.,  4., 28.,  7., 36.,  1., 13.,  3.,  9.,
          1., 36., 13., 40.],
        [ 0.,  1.,  1.,  0.,  2.,  9.,  9., 11.,  1.,  8.,  1.,  5.,
          2., 19.,  4., 27.],
        [ 0.,  0.,  0.,  0.,  0.,  2.,  2.,  1.,  0.,  0.,  1.,  1.,
          0.,  1.,  1.,  1.]],

       [[ 0.,  1.,  0.,  2.,  2., 28.,  9., 28.,  1., 11.,  4.,  7.,
         10., 31., 14., 52.],
        [ 1.,  1.,  0.,  5.,  0., 15.,  2., 15.,  0.,  5.,  0.,  9.,
          4., 12.,  6., 25.],
        [ 0.,  0.,  0.,  3.,  1.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,
          0.,  1.,  0.,  4.]],

       [[ 0.,  1.,  0.,  3.,  3., 18.,  7., 31.,  0.,  8.,  5., 15.,
          7., 38., 21., 43.],
    