In [2]:
import os
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
import torch
from tqdm.notebook import tqdm
import pandas as pd

In [3]:
%run covid_x.py

In [4]:
images_dir = os.path.join(DATASET_DIR, 'train')
# images_names = os.listdir(images_dir)
with open(os.path.join(DATASET_DIR, 'further_train_split.txt')) as f:
    images_names = [l.strip() for l in f.readlines()]
len(images_names)

12504

## Calculate mean and std

In [6]:
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
])

In [7]:
total = torch.zeros(3, 512, 512)

In [8]:
for image_name in tqdm(images_names):
    filepath = os.path.join(images_dir, image_name)
    image = Image.open(filepath).convert('RGB')
    image = transform(image)
    total += image
    
    break
    
total /= len(images_names)
total.size()

HBox(children=(FloatProgress(value=0.0, max=12504.0), HTML(value='')))




torch.Size([3, 512, 512])

In [93]:
total.mean(dim=-1).mean(dim=-1)

tensor([0.4919, 0.4920, 0.4920])

In [94]:
total.std(dim=-1).std(dim=-1)

tensor([0.0467, 0.0467, 0.0467])

## Load labels

In [34]:
labels_path = os.path.join(DATASET_DIR, 'train_split.txt')
columns = ['patient_id', 'image_name', 'label', 'source']

df = pd.read_csv(labels_path, sep=' ', header=None, names=columns)
df.head()

Unnamed: 0,patient_id,image_name,label,source
0,3,SARS-10.1148rg.242035193-g04mr34g0-Fig8a-day0....,pneumonia,cohen
1,3,SARS-10.1148rg.242035193-g04mr34g0-Fig8b-day5....,pneumonia,cohen
2,3,SARS-10.1148rg.242035193-g04mr34g0-Fig8c-day10...,pneumonia,cohen
3,7,SARS-10.1148rg.242035193-g04mr34g04a-Fig4a-day...,pneumonia,cohen
4,7,SARS-10.1148rg.242035193-g04mr34g04b-Fig4b-day...,pneumonia,cohen


In [37]:
set(df['label'])

{'covid', 'normal', 'pneumonia'}

In [36]:
df.replace('COVID-19', 'covid', inplace=True)
df.head()

Unnamed: 0,patient_id,image_name,label,source
0,3,SARS-10.1148rg.242035193-g04mr34g0-Fig8a-day0....,pneumonia,cohen
1,3,SARS-10.1148rg.242035193-g04mr34g0-Fig8b-day5....,pneumonia,cohen
2,3,SARS-10.1148rg.242035193-g04mr34g0-Fig8c-day10...,pneumonia,cohen
3,7,SARS-10.1148rg.242035193-g04mr34g04a-Fig4a-day...,pneumonia,cohen
4,7,SARS-10.1148rg.242035193-g04mr34g04b-Fig4b-day...,pneumonia,cohen


In [39]:
set(['1', 'a']) == set(['a', '1'])

True

In [43]:
df[df['label'] != 'covid'].index

Int64Index([    0,     1,     2,     3,     4,     5,     6,     7,     8,
                9,
            ...
            13882, 13883, 13884, 13885, 13886, 13887, 13888, 13889, 13890,
            13891],
           dtype='int64', length=13424)

## Split train-val

In [28]:
import random

In [1]:
%run covid_x.py

In [3]:
train_dataset = CovidXDataset('train')
test_dataset = CovidXDataset('test')
len(train_dataset), len(test_dataset)

(13892, 1579)

In [20]:
train_dataset._metadata_df.groupby('label').count()['image_name']

label
covid         468
normal       7966
pneumonia    5458
Name: image_name, dtype: int64

In [19]:
test_dataset._metadata_df.groupby('label').count()['image_name']

label
covid        100
normal       885
pneumonia    594
Name: image_name, dtype: int64

In [3]:
(13892-468) // 468

28

In [60]:
def split_train_val(df, split=0.1):
    """The label distribution is mantained."""
    images_chosen = []
    
    labels = list(set(df['label']))
    for label in labels:
        # Filter only this label
        df_with_label = df[df['label'] == label]
        
        # Group images by patient
        images_by_patient = df_with_label.groupby('patient_id')['image_name'].apply(list)
        
        patients = list(images_by_patient.index)
        
        # Calculate split length
        n_images = len(df_with_label)
        split_len = int(n_images * split)
        
        # Choose images
        n_chosen = 0
        while n_chosen < split_len:
            # Choose one random patient
            patient = random.choice(patients)
            
            # Patient has 1 or more images
            images_from_patient = images_by_patient[patient]
            n_chosen += len(images_from_patient)

            # Add chosen images to main list
            images_chosen.extend(images_from_patient)

            # Remove patient from posible options
            patients.remove(patient)

    return images_chosen

In [58]:
labels_fpath = os.path.join(DATASET_DIR, 'train_split.txt')
columns = ['patient_id', 'image_name', 'label', 'source']
df = pd.read_csv(labels_fpath, sep=' ', header=None, names=columns)
df.head()

Unnamed: 0,patient_id,image_name,label,source
0,3,SARS-10.1148rg.242035193-g04mr34g0-Fig8a-day0....,pneumonia,cohen
1,3,SARS-10.1148rg.242035193-g04mr34g0-Fig8b-day5....,pneumonia,cohen
2,3,SARS-10.1148rg.242035193-g04mr34g0-Fig8c-day10...,pneumonia,cohen
3,7,SARS-10.1148rg.242035193-g04mr34g04a-Fig4a-day...,pneumonia,cohen
4,7,SARS-10.1148rg.242035193-g04mr34g04b-Fig4b-day...,pneumonia,cohen


In [61]:
val_images = split_train_val(df, split=0.1)
train_images = [i for i in train_dataset._metadata_df['image_name'] if i not in val_images]

assert len(df) == len(train_images) + len(val_images)

len(train_images), len(val_images)

(12505, 1387)

### Save split to files

In [53]:
def write_to_txt(arr, fname, sep='\n'):
    with open(fname, 'w') as f:
        for line in arr:
            f.write(line + sep)

In [54]:
write_to_txt(train_images, os.path.join(DATASET_DIR, 'further_train_split.txt'))

In [55]:
write_to_txt(val_images, os.path.join(DATASET_DIR, 'further_val_split.txt'))

## Debug `CovidXDataset` class

In [1]:
%run covid_x.py

In [2]:
train_dataset = CovidXDataset('train')
val_dataset = CovidXDataset('val')
test_dataset = CovidXDataset('test')
len(train_dataset), len(val_dataset), len(test_dataset)

(12504, 1388, 1579)

In [79]:
train_patients = set(train_dataset._metadata_df['patient_id'])
val_patients = set(val_dataset._metadata_df['patient_id'])
test_patients = set(test_dataset._metadata_df['patient_id'])
len(train_patients), len(val_patients), len(test_patients)

(12358, 1374, 1550)

In [83]:
train_patients.intersection(test_patients), \
train_patients.intersection(val_patients), \
val_patients.intersection(test_patients)

(set(), set(), set())

In [26]:
def get_dataset_distribution(dataset):
    temp_df = dataset._metadata_df.groupby('label').count()
    temp_df.rename(columns={'patient_id': 'counts'}, inplace=True)
    
    counts = temp_df['counts'].to_numpy()
    total = sum(counts)
    
    temp_df['percentage'] = counts / total * 100
    
    return temp_df[['counts', 'percentage']]

In [31]:
get_dataset_distribution(train_dataset)

Unnamed: 0_level_0,counts,percentage
label,Unnamed: 1_level_1,Unnamed: 2_level_1
covid,421,3.366923
normal,7170,57.341651
pneumonia,4913,39.291427


In [28]:
get_dataset_distribution(val_dataset)

Unnamed: 0_level_0,counts,percentage
label,Unnamed: 1_level_1,Unnamed: 2_level_1
covid,47,3.386167
normal,796,57.348703
pneumonia,545,39.26513


In [29]:
get_dataset_distribution(test_dataset)

Unnamed: 0_level_0,counts,percentage
label,Unnamed: 1_level_1,Unnamed: 2_level_1
covid,100,6.333122
normal,885,56.048132
pneumonia,594,37.618746


In [1]:
100 + 885 + 594

1579

In [3]:
covid = 421 + 47 + 100
pneum = 4913 + 545 + 594
normal = 7170 + 796 + 594
total = covid + pneum + normal
total, covid, pneum, normal

(15180, 568, 6052, 8560)

In [4]:
covid / total * 100, pneum / total * 100, normal /total * 100 

(3.741765480895916, 39.86824769433465, 56.38998682476944)

In [3]:
4913 / 421, 7170 / 421

(11.669833729216151, 17.03087885985748)

In [32]:
100 + 885 + 594

1579