In [1]:
import json
import numpy as np
import pandas as pd
from itertools import chain
from collections import defaultdict, Counter
from IPython.display import display
import json
from load_image import load_image
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
from load_image import load_image
import sys
import torchvision
from IPython.utils import io

if ".." not in sys.path:
    sys.path.insert(0, "..")
import cbm

In [2]:
# Annotation CSVs of VinDr-Mammo v1.0.0, relative to this notebook.
# Breast-level annotations (columns include breast_birads, breast_density).
GLOBAL_PATH = "../../physionet.org/files/vindr-mammo/1.0.0/breast-level_annotations.csv"
# Finding-level annotations (finding_categories, finding_birads, box coordinates).
LOCAL_PATH = "../../physionet.org/files/vindr-mammo/1.0.0/finding_annotations.csv"
# Finding categories stratified per finding BI-RADS: each becomes one
# "<category>_<BI-RADS n>" stratification column for every n in BIRADS345.
# NOTE: this is a set of strings, so its iteration order can differ between
# interpreter runs (hash randomization); sort it wherever order matters.
birads_LESIONS = {
    "Mass",
    "Suspicious Calcification", "Architectural Distortion", 
    "Focal Asymmetry", "Global Asymmetry", "Asymmetry",
}
# Finding categories counted as a single stratification column each, with no
# BI-RADS breakdown. Same set-ordering caveat as birads_LESIONS.
NO_BIRADS = {
    "Suspicious Lymph Node", 
    'Skin Thickening',
    'Skin Retraction',
    'Nipple Retraction',
    'No Finding',
}
# Finding-level BI-RADS values that get a per-lesion stratification column.
BIRADS345 = ["BI-RADS 3", "BI-RADS 4", "BI-RADS 5"]
# Every finding category, in the row order used by the count_box_birads tables.
ALL_LESIONS = [
    'Suspicious Lymph Node',
    'Mass',
    'Suspicious Calcification',
    'Asymmetry',
    'Focal Asymmetry',
    'Global Asymmetry',
    'Architectural Distortion',
    'Skin Thickening',
    'Skin Retraction',
    'Nipple Retraction',
    'No Finding',
]


def show_df(df):
    """Display ``df`` without pandas truncating rows, columns or cell contents."""
    no_limits = [
        "display.max_rows", None,
        "display.max_columns", None,
        "display.max_colwidth", None,
    ]
    with pd.option_context(*no_limits):
        display(df)


def _count_table(counter, total, index_name):
    """Turn a {category: count} mapping into a sorted count/percent table with a Total row."""
    # Percentages are pre-formatted strings; the Total row has no percentage (NaN).
    percent = {k: f"{100. * v / total:.2f}" for k, v in counter.items()}
    counts = dict(counter)
    counts["Total"] = total
    table = pd.DataFrame({"No. breast": counts, "percent": percent})
    table.index.name = index_name
    return table.sort_index()


def count_birads_densities(df):
    """Count breasts per breast-level BI-RADS category and per breast density.

    A breast is one (study_id, laterality) group of ``df``; its BI-RADS and
    density are taken from the first row of the group.

    Args:
        df: breast-level annotations with study_id, laterality, breast_birads
            and breast_density columns.

    Returns:
        (birads_stats, density_stats): two DataFrames indexed by category
        (index names "BI-RADS" / "DENSITY"), sorted, with columns
        "No. breast" (int) and "percent" (string with 2 decimals), plus a
        "Total" row whose percent is NaN.
    """
    birads_counter = Counter()
    density_counter = Counter()
    for _, rows in df.groupby(["study_id", "laterality"]):
        birads_counter[rows.breast_birads.values[0]] += 1
        density_counter[rows.breast_density.values[0]] += 1

    total = sum(birads_counter.values())
    # Both counters are incremented once per breast, so their totals must agree.
    assert total == sum(density_counter.values())
    return (
        _count_table(birads_counter, total, "BI-RADS"),
        _count_table(density_counter, total, "DENSITY"),
    )


def count_box_birads(df, lesions=None):
    """Tabulate finding counts per lesion category, broken down by finding BI-RADS.

    A row whose ``finding_categories`` lists several categories counts once
    for each of them, so the "All lesions" row sums category occurrences.

    Args:
        df: finding-level frame with a list-valued ``finding_categories``
            column and a ``finding_birads`` column. Missing BI-RADS values are
            tabulated under the empty-string column "". ``df`` is not modified.
        lesions: row set and order of the result; defaults to ``ALL_LESIONS``.
            Categories not listed are dropped; listed ones without findings get 0.

    Returns:
        int32 DataFrame indexed by lesion (index name "Lesion") plus a final
        "All lesions" row, with a "Total" column followed by one column per
        distinct finding BI-RADS value (sorted).
    """
    if lesions is None:
        lesions = ALL_LESIONS
    # Work on a local Series: assigning back into df would mutate the caller's
    # frame (and warn with SettingWithCopy when df is a filtered slice).
    finding_birads = df["finding_birads"].fillna("")
    all_birads = sorted(finding_birads.unique().tolist())

    counter = defaultdict(Counter)
    for categories, birads in zip(df["finding_categories"], finding_birads):
        for clas in categories:
            counter[clas]["Total"] += 1
            counter[clas][birads] += 1

    records = [{**counts, "Lesion": lesion} for lesion, counts in counter.items()]
    table = pd.DataFrame.from_records(records, columns=["Lesion", "Total"] + all_birads)
    table = table.set_index("Lesion").reindex(lesions).fillna(0)
    table.loc["All lesions"] = table.sum()
    return table.astype("int32")


def count_box_label(df):
    """Count how often each label occurs across the list-valued ``box_label`` column."""
    return Counter(label for labels in df.box_label for label in labels)


def df_counts(df):
    """Print the number of distinct studies and distinct images in ``df``."""
    for caption, column in (("no. studies", "study_id"), ("no. images", "image_id")):
        # dropna=False matches len(unique()): a missing id counts as one value.
        print(caption, df[column].nunique(dropna=False))


In [3]:
# import os
# os.listdir("../../physionet.org")
import ast

# finding_categories is stored as a Python list literal with single-quoted
# strings. ast.literal_eval parses it safely. The previous approach (swap ' for "
# and then json.loads) breaks on any category text containing a quote character.
local_df = pd.read_csv(LOCAL_PATH)
local_df["finding_categories"] = local_df["finding_categories"].apply(ast.literal_eval)
# Only the last expression of a cell is rendered, so display the preview explicitly.
display(local_df.head())
local_df["breast_birads"].unique()

array(['BI-RADS 4', 'BI-RADS 3', 'BI-RADS 5', 'BI-RADS 2', 'BI-RADS 1'],
      dtype=object)

In [4]:
# Breast-level annotations; display() is needed because only the cell's last
# expression is rendered, and a bare head() here would be silently discarded.
global_df = pd.read_csv(GLOBAL_PATH)
display(global_df.head())
global_df["breast_birads"].unique()

array(['BI-RADS 2', 'BI-RADS 1', 'BI-RADS 3', 'BI-RADS 4', 'BI-RADS 5'],
      dtype=object)

In [5]:
# Attribute columns used to stratify the train/test split at study level:
# breast BI-RADS, breast density, finding categories without BI-RADS, and
# one "<lesion>_<BI-RADS>" column per BI-RADS-bearing lesion.
# The sets are iterated in sorted order. Iterating a set of strings follows
# hash order, which changes between kernel runs (PYTHONHASHSEED). That would
# reorder the columns of labels_ar and can change the seeded stratification.
split_col = [f"BI-RADS {i}" for i in range(1, 6)]
split_col += [f"DENSITY {x}" for x in "ABCD"]
split_col += sorted(NO_BIRADS)
split_col += [
    f"{box_name}_{box_birads}"
    for box_name in sorted(birads_LESIONS)
    for box_birads in BIRADS345
]

In [6]:
# Per-study attribute counts used for stratification: one row per study, one
# column per entry of split_col.
# - Breast-level attributes (BI-RADS, density) count once per breast, i.e. per
#   (study_id, laterality) group, using the group's first row.
# - Finding-level attributes count once per category listed in each local_df row.
study_ids = sorted(global_df.study_id.unique().tolist())
# Dict lookups instead of list.index(): the latter is O(n) per call, which made
# this cell quadratic in the number of studies/findings.
study_row = {study_id: i for i, study_id in enumerate(study_ids)}
col_of = {name: j for j, name in enumerate(split_col)}

labels_ar = np.zeros((len(study_ids), len(split_col)), dtype=np.int32)
for (study_id, lat), rows in global_df.groupby(["study_id", "laterality"]):
    row = study_row[study_id]
    labels_ar[row, col_of[rows.breast_birads.values[0]]] += 1
    labels_ar[row, col_of[rows.breast_density.values[0]]] += 1

for study_id, birads, categories in zip(
    local_df["study_id"], local_df["finding_birads"], local_df["finding_categories"]
):
    row = study_row[study_id]
    for label in categories:
        # BI-RADS-bearing lesions are split by their finding BI-RADS.
        # NOTE(review): a lesion whose finding_birads is not in BIRADS345 has no
        # column and raises KeyError here.
        key = f"{label}_{birads}" if label in birads_LESIONS else label
        labels_ar[row, col_of[key]] += 1

# Column totals, handy for checking how rare each attribute is.
total = labels_ar.sum(axis=0)

In [7]:
from stratification import IterativeStratification
# Seed passed to the stratifier so the split is repeatable.
SEED = 1999
# Target fractions of the two folds; fold 0 is later named "training", fold 1 "test".
SPLITS = np.array([0.8, 0.2])
stratifier = IterativeStratification(SEED)
# Presumably one fold index per row of labels_ar (i.e. per entry of study_ids);
# confirm against the stratification module.
fold_ids = stratifier.stratify(labels_ar, SPLITS)

In [8]:
# Tag every annotation row with the name of the fold its study was assigned to.
global_df["fold"] = ""
local_df["fold"] = ""
fold_name = ["training", "test"]
for fold_k, name in enumerate(fold_name):
    members = [study_ids[i] for i in np.flatnonzero(fold_ids == fold_k)]
    for frame in (global_df, local_df):
        frame.loc[frame.study_id.isin(members), "fold"] = name

In [9]:
# show_df(count_box_birads(local_df[local_df.fold == "training"]))
# show_df(count_box_birads(local_df[local_df.fold == "test"]))

In [10]:
# print("Whole dataset:")
# bi, den = count_birads_densities(global_df)
# show_df(bi)
# show_df(den)

# print("Training split:")
# bi, den = count_birads_densities(global_df[global_df.fold == "training"])
# show_df(bi)
# show_df(den)

# print("Test split:")
# bi, den = count_birads_densities(global_df[global_df.fold == "test"])
# show_df(bi)
# show_df(den)

In [11]:
# Join image-level columns that only the breast-level table has (if any) onto
# the finding rows. The inner join keeps findings whose image is in global_df.
merge_keys = ["series_id", "image_id", "study_id"]
cols_to_use = global_df.columns.difference(local_df.columns).tolist() + merge_keys
overall_df = local_df.merge(global_df[cols_to_use], how="inner", on=merge_keys)
overall_df

Unnamed: 0,study_id,series_id,image_id,laterality,view_position,height,width,breast_birads,breast_density,finding_categories,finding_birads,xmin,ymin,xmax,ymax,split,fold
0,48575a27b7c992427041a82fa750d3fa,26de4993fa6b8ae50a91c8baf49b92b0,4e3a578fe535ea4f5258d3f7f4419db8,R,CC,3518,2800,BI-RADS 4,DENSITY C,[Mass],BI-RADS 4,2355.139893,1731.640015,2482.979980,1852.750000,training,training
1,48575a27b7c992427041a82fa750d3fa,26de4993fa6b8ae50a91c8baf49b92b0,dac39351b0f3a8c670b7f8dc88029364,R,MLO,3518,2800,BI-RADS 4,DENSITY C,[Mass],BI-RADS 4,2386.679932,1240.609985,2501.800049,1354.040039,training,training
2,75e8e48933289d70b407379a564f8594,853b70e7e6f39133497909d9ca4c756d,c83f780904f25eacb44e9030f32c66e1,R,CC,3518,2800,BI-RADS 3,DENSITY C,[Global Asymmetry],BI-RADS 3,2279.179932,1166.510010,2704.439941,2184.260010,training,training
3,75e8e48933289d70b407379a564f8594,853b70e7e6f39133497909d9ca4c756d,893528bc38a0362928a89364f1b692fd,R,MLO,3518,2800,BI-RADS 3,DENSITY C,[Global Asymmetry],BI-RADS 3,1954.270020,1443.640015,2589.760010,2193.810059,training,training
4,c3487424fee1bdd4515b72dc3fd69813,77619c914263eae44e9099f1ce07192c,318264c881bf12f2c1efe5f93920cc37,R,CC,3518,2800,BI-RADS 4,DENSITY C,[Architectural Distortion],BI-RADS 4,2172.300049,1967.410034,2388.699951,2147.159912,training,training
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
20481,f2093a752e6b44df5990f5fd38c99dd2,2b1b2b8f48abab9819c0b3d091e152ee,ea732154d149f619b20070b78060ae65,R,CC,2812,2012,BI-RADS 2,DENSITY C,[No Finding],,,,,,training,test
20482,b3c8969cd2accfa4dbb2aece1f7158ab,69d7f07ea04572dad5e5aa62fbcfc4b7,4689616c3d0b46fcba7a771107730791,R,CC,3580,2702,BI-RADS 2,DENSITY C,[No Finding],,,,,,training,training
20483,b3c8969cd2accfa4dbb2aece1f7158ab,69d7f07ea04572dad5e5aa62fbcfc4b7,3c22491bcf1d0b004715c28d80981cdd,L,CC,3580,2702,BI-RADS 2,DENSITY C,[No Finding],,,,,,training,training
20484,b3c8969cd2accfa4dbb2aece1f7158ab,69d7f07ea04572dad5e5aa62fbcfc4b7,d443b9725e331b8b27589aa725597801,R,MLO,3580,2686,BI-RADS 2,DENSITY C,[No Finding],,,,,,training,training


In [12]:
# DICOM images live in the "images" folder next to the annotation CSVs
# (joined with "/" to match the other path strings in this notebook).
img_loc = GLOBAL_PATH.rsplit("/", 1)[0] + "/images"

In [13]:
# Sanity check: load one DICOM from the local image tree. The former bare
# os.listdir(img_loc) listed every study directory and its result was discarded.
sample_study_id = "0025a5dc99fd5c742026f0b2b030d3e9"
sample_image_id = "2ddfad7286c2b016931ceccd1e2c7bbc"
img = load_image(img_loc + "/" + sample_study_id + "/" + sample_image_id + ".dicom")


In [14]:
# Annotation rows of the sample image loaded above.
overall_df.loc[overall_df["image_id"].eq("2ddfad7286c2b016931ceccd1e2c7bbc")]


Unnamed: 0,study_id,series_id,image_id,laterality,view_position,height,width,breast_birads,breast_density,finding_categories,finding_birads,xmin,ymin,xmax,ymax,split,fold
6679,0025a5dc99fd5c742026f0b2b030d3e9,47d59b788d64eecab165d97471c4131a,2ddfad7286c2b016931ceccd1e2c7bbc,L,MLO,3518,2800,BI-RADS 1,DENSITY C,[No Finding],,,,,,test,training


In [15]:
sample_study_id = "0025a5dc99fd5c742026f0b2b030d3e9"
sample_image_id = "2ddfad7286c2b016931ceccd1e2c7bbc"
# display() is required: only the last expression of a cell is rendered, so a
# bare filter expression on this line would be silently discarded.
display(overall_df[overall_df["study_id"] == sample_study_id])
os.path.isfile(img_loc + "/" + sample_study_id + "/" + sample_image_id + ".dicom")

True

In [16]:
# !pip install pydicom==2.1.2
# global_df.info()
# global_df.loc[0, "series_id"]

In [None]:
# Only a subset of VinDr-Mammo is available locally: keep the rows whose DICOM
# file exists under img_loc. A single boolean mask replaces the row-by-row
# drop(), which built a new frame for every missing file.
image_paths = img_loc + "/" + overall_df["study_id"] + "/" + overall_df["image_id"] + ".dicom"
overall_df = overall_df.loc[image_paths.map(os.path.isfile)].copy()
overall_df.info()

In [None]:
# Split the merged annotations by the fold assigned earlier.
train_df = overall_df[overall_df.fold == "training"]
test_df = overall_df[overall_df.fold == "test"]

train_file_list = (img_loc + "/" + train_df["study_id"] + "/" + train_df["image_id"] + ".dicom").to_list()
test_file_list = (img_loc + "/" + test_df["study_id"] + "/" + test_df["image_id"] + ".dicom").to_list()

# Task labels: breast-level BI-RADS category as an integer 1..5 (the last
# character of e.g. "BI-RADS 4"). Note these are NOT 0-based class indices.
train_labels = pd.to_numeric(train_df["breast_birads"].str[-1]).to_numpy()
test_labels = pd.to_numeric(test_df["breast_birads"].str[-1]).to_numpy()

# Concept matrices: per row, how many times each finding category appears in
# finding_categories.
train_dummies = pd.get_dummies(train_df.finding_categories.explode()).groupby(level=0).sum()
test_dummies = pd.get_dummies(test_df.finding_categories.explode()).groupby(level=0).sum()
# A category present in only one split would otherwise give the splits
# different concept columns. Align both on the union of categories, and
# align rows with the file lists, because MammoDataset looks concepts up
# positionally (.iloc).
concept_labels = train_dummies.columns.union(test_dummies.columns)
train_concepts = train_dummies.reindex(index=train_df.index, columns=concept_labels, fill_value=0)
test_concepts = test_dummies.reindex(index=test_df.index, columns=concept_labels, fill_value=0)
# NOTE(review): overall_df appears to hold one row per finding box
# (xmin..ymax columns), so an image with several boxes is listed several
# times, each time with only that box's categories as concepts. Consider
# aggregating per image.

train_df.info()
test_df.info()

In [None]:
class MammoDataset(Dataset):
    """Map-style dataset yielding ``(image, label, concepts)`` per sample.

    Args:
        file_list: image paths, one per sample.
        labels: per-sample task labels, indexed positionally.
        concepts: DataFrame of per-sample concept values; row ``i`` must
            describe ``file_list[i]`` (it is read with ``.iloc``).
        transform: optional callable applied to the loaded image.
        loader: optional callable ``path -> image``; defaults to ``load_image``.

    Raises:
        ValueError: if file_list, labels and concepts differ in length.
    """

    def __init__(self, file_list, labels, concepts, transform=None, loader=None):
        # A length mismatch would silently pair images with the wrong labels/concepts.
        if not (len(file_list) == len(labels) == len(concepts)):
            raise ValueError(
                f"length mismatch: {len(file_list)} files, {len(labels)} labels, "
                f"{len(concepts)} concept rows"
            )
        self.file_list = file_list
        self.labels = labels
        self.transform = transform
        self.concepts = concepts
        self.loader = loader

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        load = self.loader if self.loader is not None else load_image
        image = load(self.file_list[idx])
        if self.transform is not None:
            image = self.transform(image)
        label = self.labels[idx]
        # Return concepts as a float32 tensor. A Python list would be collated
        # by DataLoader into a list of per-concept tensors instead of a
        # (batch, n_concepts) tensor usable by a loss function.
        concepts = torch.as_tensor(self.concepts.iloc[idx].to_numpy(dtype="float32"))
        return image, label, concepts

    
BATCH_SIZE = 32
IMAGE_SIZE = (256 // 2, 256 // 2)

# NOTE(review): Resize runs before ToTensor, which assumes load_image returns a
# PIL image (or tensor) — confirm.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(IMAGE_SIZE),
    torchvision.transforms.ToTensor(),
])

train_dataset = MammoDataset(train_file_list, train_labels, train_concepts, transform=transform)
# Seeded generator so the shuffling order is reproducible across runs.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

test_dataset = MammoDataset(test_file_list, test_labels, test_concepts, transform=transform)
# Evaluation does not need shuffling; a fixed order keeps batches aligned with test_file_list.
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Concept bottleneck model: one concept per finding category, 5 task outputs,
# single-channel image input.
model = cbm.CBM(
    n_concepts=len(concept_labels),
    concept_names=concept_labels,
    latent_dims=16,
    channels_in=1,
    n_out=5,
)

In [None]:
# Training hyper-parameters.
config = {
    "n_epochs": 10,
    # Weight of the concept loss relative to the task loss. The training cell
    # reads this key (it was missing, which raised KeyError).
    # NOTE(review): placeholder value — tune.
    "concept_loss_weight": 1.0,
    # Optimizer learning rate. NOTE(review): placeholder value — tune.
    "lr": 1e-3,
}


In [None]:
# Joint training of the concept bottleneck model: a task loss on the BI-RADS
# prediction plus a weighted concept loss.
# NOTE(review): assumes cbm.CBM returns unnormalised logits in out["pred"]
# (batch, n_out) and out["lay"] (batch, n_concepts) — confirm; use NLLLoss /
# BCELoss instead if the model already applies softmax / sigmoid.
task_loss_fn = torch.nn.CrossEntropyLoss()
concept_loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.get("lr", 1e-3))
concept_loss_weight = config.get("concept_loss_weight", 1.0)

history = []
# capture_output keeps stdout/stderr produced during training out of the
# notebook. Prints inside this block would be swallowed, so per-epoch metrics
# are collected in `history` and shown after the block.
with io.capture_output() as captured:
    model.train()
    for epoch in range(config["n_epochs"]):
        loss_sum, n_correct, n_seen = 0.0, 0, 0
        for x, y, c in train_dataloader:
            # Default collation turns a per-sample Python list into a list of
            # per-concept tensors; stack it back into (batch, n_concepts).
            if isinstance(c, (list, tuple)):
                c = torch.stack(list(c), dim=1)
            # Concept columns are occurrence counts; BCE needs 0/1 targets.
            c = c.float().clamp(max=1)
            # Labels are BI-RADS categories 1..5; CrossEntropyLoss expects
            # class indices 0..n_out-1.
            target = torch.as_tensor(y).long() - 1

            out = model(x)
            c_pred = out["lay"]
            y_pred = out["pred"]

            loss = task_loss_fn(y_pred, target) + concept_loss_weight * concept_loss_fn(c_pred, c)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_size = target.shape[0]
            loss_sum += loss.item() * batch_size
            # Multi-class accuracy: arg-max class vs target (not a 0.5 threshold).
            n_correct += (y_pred.argmax(dim=1) == target).sum().item()
            n_seen += batch_size
        history.append({
            "epoch": epoch + 1,
            "loss": loss_sum / max(n_seen, 1),
            "task_acc": n_correct / max(n_seen, 1),
        })

pd.DataFrame(history)

In [None]:
import torch.nn as nn

class Encoder(nn.Module):
    """Small convolutional encoder for the CBM.

    Four 3x3 'same'-padded conv layers with 4 channels and LeakyReLU
    (BatchNorm after the 2nd and 4th), a 5x5 max-pool, then a linear
    projection of the flattened features to ``latent_dims`` plus LeakyReLU.

    Args:
        latent_dims: size of the output embedding.
        channels_in: number of input image channels.
        input_hw: optional ``(height, width)`` of the input images, used to
            size the final Linear as ``4 * (H // 5) * (W // 5)``. When omitted
            the original hard-coded size 576 is used, which only fits inputs
            of 60-64 pixels per side.
    """

    def __init__(self, latent_dims, channels_in=3, input_hw=None):
        super().__init__()
        if input_hw is None:
            flat_features = 576
        else:
            height, width = input_hw
            # 'same' convs keep H x W; MaxPool2d(5) (stride 5) floors each side by 5.
            flat_features = 4 * (height // 5) * (width // 5)
        self.net = nn.Sequential(
            nn.Conv2d(channels_in, 4, (3, 3), padding='same'),
            nn.LeakyReLU(),

            nn.Conv2d(4, 4, (3, 3), padding='same'),
            nn.LeakyReLU(),
            nn.BatchNorm2d(4),

            nn.Conv2d(4, 4, (3, 3), padding='same'),
            nn.LeakyReLU(),

            nn.Conv2d(4, 4, (3, 3), padding='same'),
            nn.LeakyReLU(),
            nn.BatchNorm2d(4),

            nn.MaxPool2d((5, 5)),

            nn.Flatten(start_dim=1, end_dim=-1),
            nn.Linear(flat_features, latent_dims),
            nn.LeakyReLU(),
        )

    def forward(self, x):
        return self.net(x)


# Sized for the 128x128 images produced by the dataset transform (Resize to 256 // 2).
enc = Encoder(32, 1, input_hw=(128, 128))
print(enc)