In [1]:
import sys; sys.version_info

sys.version_info(major=3, minor=6, micro=9, releaselevel='final', serial=0)

In [46]:
import os
import pandas as pd
import torch
from torchvision import transforms
from tqdm.notebook import tqdm
from PIL import Image
import matplotlib.pyplot as plt

In [30]:
%run ../covid_uc.py

### Calculate mean and std

In [31]:
labels_fpath = os.path.join(DATASET_DIR, 'metadata.csv')
metadata = pd.read_csv(labels_fpath, index_col=0)
metadata.head()

Unnamed: 0,ID,Fecha consulta SU,Resultado consenso BSTI,date,image_name,view
6,15,2020-03-01,Non-COVID,2020-03-01,15-0-IM-0001-0001-0001.png,AP
7,17,2020-03-02,Non-COVID,2020-03-02,17-0-IM-0001-0001-0001.png,PA
8,17,2020-03-02,Non-COVID,2020-03-02,17-1-IM-0001-0001-0001.png,PA
9,18,2020-03-02,Non-COVID,2020-03-02,18-0-IM-0001-0001-0001.png,AP
10,19,2020-03-03,Non-COVID,2020-03-03,19-0-IM-0001-0001-0001.png,PA


In [72]:
def get_image_names(frontal_only=False):
    if frontal_only:
        df = metadata.loc[metadata['view'].str.contains('P')]
    else:
        df = metadata
    return list(df['image_name'])

In [75]:
image_names_frontal = get_image_names(True)
image_names_all = get_image_names(False)
len(image_names_frontal), len(image_names_all)

(427, 673)

In [68]:
def calculate_stats(image_names, image_size=(1024, 1024), max_samples=None):
    tf = transforms.Compose([
        transforms.Resize(image_size), # lose as less as possible
        transforms.ToTensor(),
    ])

    total = torch.zeros(3, *image_size)

    iterator = image_names
    if max_samples:
        iterator = iterator[:max_samples]

    for image_name in tqdm(iterator):
        fpath = os.path.join(DATASET_DIR, 'images', image_name)
        image = tf(Image.open(fpath).convert('RGB'))

        total += image

    total /= len(iterator)

    mean = total.mean(dim=-1).mean(dim=-1)
    std = total.std(dim=-1).std(dim=-1)

    return mean, std

In [69]:
mean, std = calculate_stats(image_names_all)
mean, std

HBox(children=(FloatProgress(value=0.0, max=673.0), HTML(value='')))




(tensor([0.3296, 0.3296, 0.3296]), tensor([0.0219, 0.0219, 0.0219]))

In [76]:
mean, std = calculate_stats(image_names_frontal)
mean, std

HBox(children=(FloatProgress(value=0.0, max=427.0), HTML(value='')))




(tensor([0.3836, 0.3836, 0.3836]), tensor([0.0143, 0.0143, 0.0143]))

### Test `CovidUCDataset` class

In [152]:
%run ../covid_uc.py

In [156]:
dataset = CovidUCDataset('test', image_size=(256, 256), frontal_only=True)
len(dataset)

64

In [81]:
dataset._metadata_df

Unnamed: 0,ID,Fecha consulta SU,Resultado consenso BSTI,date,image_name,view
0,15,2020-03-01,Non-COVID,2020-03-01,15-0-IM-0001-0001-0001.png,AP
1,17,2020-03-02,Non-COVID,2020-03-02,17-0-IM-0001-0001-0001.png,PA
2,17,2020-03-02,Non-COVID,2020-03-02,17-1-IM-0001-0001-0001.png,PA
3,18,2020-03-02,Non-COVID,2020-03-02,18-0-IM-0001-0001-0001.png,AP
4,19,2020-03-03,Non-COVID,2020-03-03,19-0-IM-0001-0001-0001.png,PA
...,...,...,...,...,...,...
422,636,2020-03-16,normal,2020-03-16,636-0-IM-0001-0001-0001.png,PA
423,637,2020-03-24,normal,2020-03-24,637-0-IM-0001-0001-0001.png,AP
424,638,2020-03-25,Non-COVID,2020-03-25,638-0-IM-0001-0001-0001.png,PA
425,639,2020-04-06,normal,2020-04-06,639-0-IM-0001-0001-0001.png,PA


In [86]:
for image, label in dataset:
    break

In [87]:
label

1

In [88]:
image.size()

torch.Size([3, 256, 256])

In [93]:
dataset.labels

['covid', 'Non-COVID', 'normal']

In [94]:
dataset.get_labels_presence_for('Non-COVID')

[(0, 1),
 (1, 1),
 (2, 1),
 (3, 1),
 (4, 1),
 (5, 1),
 (6, 1),
 (7, 1),
 (8, 1),
 (9, 1),
 (10, 1),
 (11, 1),
 (12, 1),
 (13, 1),
 (14, 1),
 (15, 1),
 (16, 1),
 (17, 1),
 (18, 1),
 (19, 1),
 (20, 1),
 (21, 1),
 (22, 0),
 (23, 0),
 (24, 0),
 (25, 0),
 (26, 0),
 (27, 0),
 (28, 0),
 (29, 0),
 (30, 0),
 (31, 0),
 (32, 0),
 (33, 0),
 (34, 0),
 (35, 0),
 (36, 0),
 (37, 0),
 (38, 0),
 (39, 0),
 (40, 0),
 (41, 0),
 (42, 0),
 (43, 0),
 (44, 0),
 (45, 0),
 (46, 0),
 (47, 0),
 (48, 0),
 (49, 0),
 (50, 0),
 (51, 0),
 (52, 0),
 (53, 0),
 (54, 0),
 (55, 0),
 (56, 0),
 (57, 0),
 (58, 0),
 (59, 0),
 (60, 0),
 (61, 0),
 (62, 0),
 (63, 0),
 (64, 0),
 (65, 0),
 (66, 0),
 (67, 0),
 (68, 0),
 (69, 0),
 (70, 0),
 (71, 0),
 (72, 0),
 (73, 0),
 (74, 0),
 (75, 0),
 (76, 0),
 (77, 0),
 (78, 0),
 (79, 0),
 (80, 0),
 (81, 0),
 (82, 0),
 (83, 0),
 (84, 0),
 (85, 0),
 (86, 0),
 (87, 0),
 (88, 0),
 (89, 0),
 (90, 0),
 (91, 0),
 (92, 0),
 (93, 0),
 (94, 0),
 (95, 0),
 (96, 0),
 (97, 0),
 (98, 0),
 (99, 0),
 (100, 0),

### Split

In [95]:
%run ../covid_uc.py

In [135]:
import random
import os
from collections import Counter

In [134]:
LABEL_COL = 'Resultado consenso BSTI'

In [97]:
def save_list(items, name):
    filepath = os.path.join(DATASET_DIR, f'{name}.txt')
    with open(filepath, 'w') as f:
        for item in items:
            f.write(f'{item}\n')

    print(f'List saved to: {filepath}')

In [98]:
labels_fpath = os.path.join(DATASET_DIR, 'metadata.csv')
metadata = pd.read_csv(labels_fpath, index_col=0)
metadata.head()

Unnamed: 0,ID,Fecha consulta SU,Resultado consenso BSTI,date,image_name,view
16,15,2020-03-01,Non-COVID,2020-03-01,15-0-IM-0001-0001-0001.png,AP
17,17,2020-03-02,Non-COVID,2020-03-02,17-0-IM-0001-0001-0001.png,PA
18,17,2020-03-02,Non-COVID,2020-03-02,17-1-IM-0001-0001-0001.png,PA
19,18,2020-03-02,Non-COVID,2020-03-02,18-0-IM-0001-0001-0001.png,AP
20,19,2020-03-03,Non-COVID,2020-03-03,19-0-IM-0001-0001-0001.png,PA


In [100]:
images_by_patient = metadata.groupby('ID')['image_name'].apply(list)
images_by_patient

ID
100    [100-0-IM-0001-0001-0001.png, 100-1-IM-0001-00...
101                        [101-0-IM-0001-0001-0001.png]
102    [102-0-IM-0001-0001-0001.png, 102-1-IM-0001-00...
103    [103-0-IM-0001-0001-0001.png, 103-1-IM-0001-00...
104    [104-0-IM-0001-0001-0001.png, 104-1-IM-0001-00...
                             ...                        
94     [94-0-IM-0001-0001-0001.png, 94-1-IM-0001-0002...
95     [95-0-IM-0001-0001-0001.png, 95-1-IM-0001-0002...
97     [97-0-IM-0001-0001-0001.png, 97-1-IM-0001-0002...
98     [98-0-IM-0001-0001-0001.png, 98-1-IM-0001-0001...
99     [99-0-IM-0001-0001-0001.png, 99-1-IM-0001-0002...
Name: image_name, Length: 573, dtype: object

In [107]:
patients = list(set(metadata['ID']))
len(patients)

573

In [110]:
VAL_SPLIT = 0
TEST_SPLIT = 0.1

In [111]:
n_val = int(VAL_SPLIT * len(patients))
n_test = int(TEST_SPLIT * len(patients))
n_val, n_test

(0, 57)

In [114]:
val_test_patients = random.sample(patients, n_val + n_test)
len(val_test_patients)

57

In [116]:
train_patients = [pat for pat in patients if pat not in val_test_patients]
len(train_patients)

516

In [127]:
combine_images = lambda pats: sum((images_by_patient[pat] for pat in pats), [])
count_images = lambda pats: sum(len(images_by_patient[pat]) for pat in pats)

In [122]:
count_images(train_patients), count_images(val_test_patients)

(795, 88)

In [129]:
train_images = combine_images(train_patients)
val_test_images = combine_images(val_test_patients)

In [130]:
filter_meta = lambda images: metadata.loc[metadata['image_name'].isin(images)]

In [137]:
train_df = filter_meta(train_images)
val_test_df = filter_meta(val_test_images)
# train_df

In [139]:
Counter(train_df[LABEL_COL])

Counter({'Non-COVID': 94, 'normal': 678, 'covid': 23})

In [145]:
total = (94 + 678 + 23)
mult = 100/total
94 * mult, 678 * mult, 23 * mult

(11.823899371069182, 85.28301886792453, 2.893081761006289)

In [141]:
Counter(val_test_df[LABEL_COL])

Counter({'Non-COVID': 4, 'normal': 79, 'covid': 5})

In [143]:
4 / 88 * 100, 79 / 88 * 100, 5 / 88 * 100

(4.545454545454546, 89.77272727272727, 5.681818181818182)

In [146]:
len(train_images), len(val_test_images)

(795, 88)

In [147]:
val_images = val_test_images[:n_val]
test_images = val_test_images[n_val:]
len(val_images), len(test_images)

(0, 88)

In [None]:
save_list(train_images, 'train')
save_list(val_images, 'val')
save_list(test_images, 'test')