In [11]:
import os
import random
from loguru import logger
import json
save_dir = '/Users/keyi/Desktop/ds_11_25_multilabel_8tries'
train_dir_name = "per_sample_1_train"
pos_patches_per_file = 40
total_patches_per_file = 50
pos_samples_per_batch = 2
total_healthy_patches_per_patient = 2
neg_samples_per_batch = 1
split_number = 4
JSON_file = f'/Users/keyi/Desktop/DL_template/_assets/split_json/AutoPET_Cluster/dataset_split_{split_number}.json'
def collect_healthy_indices():
    healthy_indices = []
    with open(JSON_file, 'r') as f:
        data = json.load(f)
        training = data['TRAINING']
        for idx, entry in enumerate(training):
            if entry['DIAGNOSIS'] == "NEGATIVE":  # healthy
                healthy_indices.append(idx)
    logger.info(f"found {len(healthy_indices)} healthy patients")
    return healthy_indices

healthy_indices = collect_healthy_indices()
training_numbers = 708

def _generate_train_combinations():
    used_patches = set() 
    combinations_batch = []
    # file_list = os.listdir(os.path.join(save_dir,  train_dir_name))
    # file_list = sorted([f for f in file_list if f.endswith('train.pt')])
    # print(f"file_list: {file_list}")
    
    healthy_neg_samples = []
    for healthy_idx in  healthy_indices:
        patches_indices = [(healthy_idx, i) for i in range( total_patches_per_file)]
        selected_indices = random.sample(patches_indices, 10)  # 10
        healthy_neg_samples.extend(selected_indices)
        logger.info(f"healthy_neg_samples: {healthy_neg_samples}")
        logger.info(f"Fixed {len(selected_indices)} negative samples for healthy patient {healthy_idx}")
    
    # first step: generate positive and negative indices for each patient
    # for file_idx, _ in enumerate(file_list):
    for file_idx in range(training_numbers):
        if file_idx not in  healthy_indices: 
            pos_indices = [(file_idx, i) for i in range( pos_patches_per_file)] 
            neg_indices = [(file_idx, i) for i in range( pos_patches_per_file,  total_patches_per_file)] 
            
            
            available_pos = [idx for idx in pos_indices if idx not in used_patches]
            # 10 negative samples
            for _ in range(total_patches_per_file - pos_patches_per_file):
                available_pos = [idx for idx in pos_indices if idx not in used_patches]
                available_neg = [idx for idx in neg_indices if idx not in used_patches]
                
                if len(available_pos) >= pos_samples_per_batch and available_neg:
                    pos_samples = random.sample(available_pos, pos_samples_per_batch)
                    neg_sample = random.choice(available_neg)
                    
                    used_patches.update(pos_samples)
                    used_patches.add(neg_sample)
                    
                    combinations_batch.append({
                        'positive': tuple(pos_samples),
                        'negative': (neg_sample)
                    })
    
    while len(healthy_neg_samples) > 0:
        sick_patients = [idx for idx in range(training_numbers) if idx not in healthy_indices]
        
        # Check if we have enough positive samples available for any sick patient
        has_enough_samples = False
        available_sick_patients = []
        for patient in sick_patients:
            pos_indices = [(patient, i) for i in range(pos_patches_per_file)]
            available_pos = [idx for idx in pos_indices if idx not in used_patches]
            if len(available_pos) >= pos_samples_per_batch:
                available_sick_patients.append(patient)
                has_enough_samples = True
        
        # If no sick patient has enough samples, break the loop
        if not has_enough_samples:
            logger.warning(f"Stopping early: No sick patients have enough positive samples. "
                        f"Remaining healthy samples: {len(healthy_neg_samples)}")
            break
        
        # Choose from patients that have enough samples
        random_chosen_sick_patient = random.choice(available_sick_patients)
        random_chosen_healthy_neg_sample = healthy_neg_samples.pop()
        
        pos_indices = [(random_chosen_sick_patient, i) for i in range(pos_patches_per_file)]
        available_pos = [idx for idx in pos_indices if idx not in used_patches]
        
        pos_samples = random.sample(available_pos, pos_samples_per_batch)
        used_patches.update(pos_samples)
        combinations_batch.append({
            'positive': tuple(pos_samples),
            'negative': (random_chosen_healthy_neg_sample)
        })
                    
    print(f"Generated {len(combinations_batch)} combinations")
    with open(f'combinations_batch_split_{split_number}.json', 'w') as f:
        json.dump(combinations_batch, f, indent=4)
    return combinations_batch
c = _generate_train_combinations()
for i, combo in enumerate(c):

     print(f"{i}: {combo}")
     
     
# 356 healthy samples
# 708 - 356 = 352 sick samples
  

[32m2024-11-28 22:52:47.175[0m | [1mINFO    [0m | [36m__main__[0m:[36mcollect_healthy_indices[0m:[36m22[0m - [1mfound 349 healthy patients[0m
[32m2024-11-28 22:52:47.176[0m | [1mINFO    [0m | [36m__main__[0m:[36m_generate_train_combinations[0m:[36m40[0m - [1mhealthy_neg_samples: [(1, 48), (1, 20), (1, 44), (1, 22), (1, 47), (1, 35), (1, 4), (1, 40), (1, 33), (1, 46)][0m
[32m2024-11-28 22:52:47.176[0m | [1mINFO    [0m | [36m__main__[0m:[36m_generate_train_combinations[0m:[36m41[0m - [1mFixed 10 negative samples for healthy patient 1[0m
[32m2024-11-28 22:52:47.177[0m | [1mINFO    [0m | [36m__main__[0m:[36m_generate_train_combinations[0m:[36m40[0m - [1mhealthy_neg_samples: [(1, 48), (1, 20), (1, 44), (1, 22), (1, 47), (1, 35), (1, 4), (1, 40), (1, 33), (1, 46), (2, 37), (2, 27), (2, 46), (2, 13), (2, 0), (2, 10), (2, 3), (2, 15), (2, 48), (2, 26)][0m
[32m2024-11-28 22:52:47.177[0m | [1mINFO    [0m | [36m__main__[0m:[36m_generate_train

Generated 7080 combinations
0: {'positive': ((0, 2), (0, 29)), 'negative': (0, 43)}
1: {'positive': ((0, 8), (0, 26)), 'negative': (0, 41)}
2: {'positive': ((0, 23), (0, 22)), 'negative': (0, 49)}
3: {'positive': ((0, 16), (0, 32)), 'negative': (0, 46)}
4: {'positive': ((0, 30), (0, 10)), 'negative': (0, 40)}
5: {'positive': ((0, 21), (0, 15)), 'negative': (0, 48)}
6: {'positive': ((0, 34), (0, 25)), 'negative': (0, 42)}
7: {'positive': ((0, 13), (0, 20)), 'negative': (0, 47)}
8: {'positive': ((0, 37), (0, 39)), 'negative': (0, 45)}
9: {'positive': ((0, 1), (0, 27)), 'negative': (0, 44)}
10: {'positive': ((3, 32), (3, 15)), 'negative': (3, 41)}
11: {'positive': ((3, 27), (3, 10)), 'negative': (3, 47)}
12: {'positive': ((3, 34), (3, 23)), 'negative': (3, 44)}
13: {'positive': ((3, 24), (3, 20)), 'negative': (3, 49)}
14: {'positive': ((3, 11), (3, 29)), 'negative': (3, 45)}
15: {'positive': ((3, 39), (3, 28)), 'negative': (3, 43)}
16: {'positive': ((3, 33), (3, 13)), 'negative': (3, 40)}

In [18]:
import json
with open("combinations_batch_split_0.json", "r") as f:
    data = json.load(f)
    print(len(data))
    pos_1 = data[0]["positive"][0]
    pos_2 = data[0]["positive"][1]
    neg = data[0]["negative"]
    print(pos_1, pos_2, neg)


7040
[0, 22] [0, 24] [0, 41]
