# **Causal De-biasing Fusion Pipeline**

In [3]:
!pip install monai torchinfo pytorch-metric-learning



In [60]:
import os
import cv2
import pandas as pd
import numpy as np
from numpy.linalg import norm
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score, confusion_matrix
import matplotlib.pyplot as plt
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.io import read_image
import torchvision.transforms.v2 as transforms
from torchvision.transforms import Compose
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet50, ResNet50_Weights
from pytorch_metric_learning import losses

import nltk
from nltk import word_tokenize
from gensim import models

from transformers import AutoTokenizer, BertModel

import monai.transforms as mt

from torchinfo import summary

from tqdm.notebook import tqdm

from google.colab import drive
drive.mount('/content/drive')

SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
nltk.download('punkt')
device = "cuda" if torch.cuda.is_available() else "cpu"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


## **Data Setup**

In [4]:
df = pd.read_csv('/content/drive/MyDrive/harvard/6.8610/project/data/6.8610proj/dataset_ready.csv')

In [5]:
df['img_path'] = df['img_path'].str.replace('/MyDrive', '/MyDrive/harvard/6.8610/project/data')
df.head()

Unnamed: 0,dicom_id,subject_id,study_id,ViewPosition,Rows,Columns,label,img_path,reports,gender,anchor_age,marital_status,race,general_race_group
0,86a4065d-a51890c6-094034a8-c549f6b9-6660ff93,10406570,57207207,PA,3056,2500,0,/content/drive/MyDrive/harvard/6.8610/project/...,FINAL REPORT\...,F,69,WIDOWED,BLACK/AFRICAN AMERICAN,Black/African Descent
1,a1072339-d1fe3a01-149f489b-9a08c49b-f2fee79e,10965697,51095949,AP,3056,2544,0,/content/drive/MyDrive/harvard/6.8610/project/...,FINAL REPORT\...,F,67,WIDOWED,ASIAN - CHINESE,Asian
2,becf8ed6-c5f60c71-89040c32-9d94b7c6-eb956bc1,10104732,59794138,AP,3056,2544,0,/content/drive/MyDrive/harvard/6.8610/project/...,WET READ: ___ ___ ___ 8:27 PM\n No change fr...,M,49,SINGLE,WHITE,White/Caucasian
3,db41181b-0c240a54-1a370211-3723e7a9-b6cdb316,10898945,50988324,AP,2606,2544,0,/content/drive/MyDrive/harvard/6.8610/project/...,FINAL REPORT\...,M,78,MARRIED,WHITE,White/Caucasian
4,1739a403-be126d84-266aab85-442b24fc-a4ebe43c,10867055,59657889,AP,3050,2539,1,/content/drive/MyDrive/harvard/6.8610/project/...,FINAL REPORT\...,M,47,SINGLE,WHITE,White/Caucasian


In [6]:
df['reports'] = df['reports'].str.lower().str.replace("\n", " ").str.strip()

In [7]:
df.shape

(2000, 14)

In [8]:
IMG_CLASSES = ['Normal','PleuralEffusion']
NUM_CLASSES = len(IMG_CLASSES)
IMG_SIZE = 224
class MultiModalDataset(Dataset):
    def __init__(self, indices, image_dir, texts, labels, transform = None):
        # load dataset
        self.images = []
        for image_file in tqdm(image_dir, position = 0, leave = True):
            image = transforms.Resize([IMG_SIZE, IMG_SIZE])(read_image(image_file))
            self.images.append(image)


        self.indices = indices
        self.texts = texts
        self.labels = labels

        # transforms
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.transform(self.images[idx]) if self.transform is not None else self.images[idx]
        image = image.repeat(3, 1, 1)
        # image = (image - image.min())/(image.max() - image.min()) # map to [0, 1]
        index = self.indices[idx]
        label = self.labels[idx]
        text = self.texts[idx]
        sample = {'index': index, 'image': image, 'text': text, 'label': label}#transforms.ToDtype(torch.uint8, scale=True)(mask)}
        return sample

In [9]:
# splits
TRAIN_SIZE = 0.7
VAL_SIZE = 0.15
TEST_SIZE = 0.15
samp_df = df.sample(frac=1, random_state=SEED)
train_df, val_df, test_df = samp_df[0:int(TRAIN_SIZE*len(samp_df))], samp_df[int(TRAIN_SIZE*len(samp_df)):int((TRAIN_SIZE+VAL_SIZE)*len(samp_df))], samp_df[int((TRAIN_SIZE+VAL_SIZE)*len(samp_df)):]
train_df.shape, val_df.shape, test_df.shape

((1400, 14), (300, 14), (300, 14))

In [10]:
train_transforms = transforms.Compose([
    mt.NormalizeIntensity(),
])

val_transforms = transforms.Compose([
    mt.NormalizeIntensity(),
])

train_ds = MultiModalDataset(list(train_df.index),
                             train_df['img_path'].values,
                             train_df['reports'].values,
                             train_df['label'].values,
                             transform = train_transforms)
val_ds = MultiModalDataset(list(val_df.index),
                           val_df['img_path'].values,
                           val_df['reports'].values,
                           val_df['label'].values,
                           transform = val_transforms)
test_ds = MultiModalDataset(list(test_df.index),
                            test_df['img_path'].values,
                            test_df['reports'].values,
                            test_df['label'].values,
                            transform = val_transforms)

  0%|          | 0/1400 [00:00<?, ?it/s]



  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

In [11]:
BATCH_SIZE = 4
train_kwargs = {'batch_size': BATCH_SIZE, 'shuffle': True}
val_kwargs = {'batch_size': BATCH_SIZE, 'shuffle': False}

train_loader = DataLoader(train_ds, **train_kwargs)
val_loader = DataLoader(val_ds, **val_kwargs)
test_loader = DataLoader(test_ds, **val_kwargs)

In [12]:
for batch in test_loader:
    print(batch['index'].shape)
    print(batch['image'].shape)
    print(len(batch['text'])) # will tokenize in training loop...
    print(batch['label'].shape)
    break

torch.Size([4])
torch.Size([4, 3, 224, 224])
4
torch.Size([4])


## **Model**

In [13]:
bert = BertModel.from_pretrained("distilbert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

You are using a model of type distilbert to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.


model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of BertModel were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.layer.5.intermediate.dense.weight', 'encoder.layer.7.attention.self.value.weight', 'encoder.layer.10.attention.self.value.weight', 'encoder.layer.5.output.LayerNorm.weight', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.10.attention.output.dense.bias', 'encoder.layer.7.output.LayerNorm.weight', 'encoder.layer.4.output.dense.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.6.output.LayerNorm.bias', 'encoder.layer.6.attention.self.value.bias', 'encoder.layer.11.attention.self.query.bias', 'encoder.layer.4.attention.self.query.bias', 'encoder.layer.10.output.dense.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.9.attention.self.value.bias', 'encoder.layer.10.output.LayerNorm.bias', 'encoder.layer.3.output.LayerNorm.weight', 'encoder.layer.9.output.LayerNorm.bias', 'encoder.layer.9.attention.self.quer

tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [14]:
summary(bert)

Layer (type:depth-idx)                             Param #
BertModel                                          --
├─BertEmbeddings: 1-1                              --
│    └─Embedding: 2-1                              23,440,896
│    └─Embedding: 2-2                              393,216
│    └─Embedding: 2-3                              1,536
│    └─LayerNorm: 2-4                              1,536
│    └─Dropout: 2-5                                --
├─BertEncoder: 1-2                                 --
│    └─ModuleList: 2-6                             --
│    │    └─BertLayer: 3-1                         7,087,872
│    │    └─BertLayer: 3-2                         7,087,872
│    │    └─BertLayer: 3-3                         7,087,872
│    │    └─BertLayer: 3-4                         7,087,872
│    │    └─BertLayer: 3-5                         7,087,872
│    │    └─BertLayer: 3-6                         7,087,872
│    │    └─BertLayer: 3-7                         7,087,872
│    │   

In [15]:
batch_tokenized = tokenizer(batch['text'], return_tensors="pt", padding='max_length', truncation=True)
test_out = bert(**batch_tokenized)

In [16]:
test_out.last_hidden_state.shape

torch.Size([4, 512, 768])

In [17]:
resnet = torch.nn.Sequential(*list(resnet50(weights=ResNet50_Weights.DEFAULT).children())[:-1])

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 175MB/s]


In [18]:
summary(resnet)

Layer (type:depth-idx)                   Param #
Sequential                               --
├─Conv2d: 1-1                            9,408
├─BatchNorm2d: 1-2                       128
├─ReLU: 1-3                              --
├─MaxPool2d: 1-4                         --
├─Sequential: 1-5                        --
│    └─Bottleneck: 2-1                   --
│    │    └─Conv2d: 3-1                  4,096
│    │    └─BatchNorm2d: 3-2             128
│    │    └─Conv2d: 3-3                  36,864
│    │    └─BatchNorm2d: 3-4             128
│    │    └─Conv2d: 3-5                  16,384
│    │    └─BatchNorm2d: 3-6             512
│    │    └─ReLU: 3-7                    --
│    │    └─Sequential: 3-8              16,896
│    └─Bottleneck: 2-2                   --
│    │    └─Conv2d: 3-9                  16,384
│    │    └─BatchNorm2d: 3-10            128
│    │    └─Conv2d: 3-11                 36,864
│    │    └─BatchNorm2d: 3-12            128
│    │    └─Conv2d: 3-13               

In [19]:
resnet(batch['image']).shape

torch.Size([4, 2048, 1, 1])

In [20]:
class MultiModalModel(nn.Module):
    """
    Model that outputs representations for image and text
    """
    def __init__(self, cv_encoder, nlp_encoder, tokenizer):
        super(MultiModalModel, self).__init__()
        self.tokenizer = tokenizer
        self.nlp_encoder = nlp_encoder
        self.nlp_out = nn.Linear(768, 256) # bert
        self.cv_encoder = cv_encoder
        self.cv_out = nn.Linear(2048, 256) # resnet

    def forward(self, im, text):
        tokens = self.tokenizer(text, return_tensors="pt", padding='max_length', truncation=True).to(device)
        z_text = self.nlp_encoder(**tokens).last_hidden_state
        z_text = self.nlp_out(z_text).mean(axis = 1) # average sequence representation
        z_im = self.cv_encoder(im).squeeze()
        z_im = self.cv_out(z_im)
        return z_im, z_text

In [21]:
class FusionModel(nn.Module):
    """
    Model that performs intermediate fusion and outputs logits
    """
    def __init__(self, cv_encoder, nlp_encoder, tokenizer):
        super(FusionModel, self).__init__()
        multimodal_encoder = MultiModalModel(cv_encoder, nlp_encoder, tokenizer)
        self.multimodal_endover = multimodal_encoder
        self.out = nn.Linear(512, 2)

    def forward(self, im, text):
        z_im, z_text = self.multimodal_endover(im, text)
        z = torch.cat([z_im, z_text], axis = 1)
        z = self.out(z)
        return z

In [22]:
model = FusionModel(resnet, bert, tokenizer)
summary(model)

Layer (type:depth-idx)                                       Param #
FusionModel                                                  --
├─MultiModalModel: 1-1                                       --
│    └─BertModel: 2-1                                        --
│    │    └─BertEmbeddings: 3-1                              23,837,184
│    │    └─BertEncoder: 3-2                                 85,054,464
│    │    └─BertPooler: 3-3                                  590,592
│    └─Linear: 2-2                                           196,864
│    └─Sequential: 2-3                                       --
│    │    └─Conv2d: 3-4                                      9,408
│    │    └─BatchNorm2d: 3-5                                 128
│    │    └─ReLU: 3-6                                        --
│    │    └─MaxPool2d: 3-7                                   --
│    │    └─Sequential: 3-8                                  215,808
│    │    └─Sequential: 3-9                                  1,2

## **Causal De-Biasing**

In [45]:
# w2v lookup matrix
tokenized_text = [word_tokenize(text) for text in list(train_df['reports'].values)]
w2v_model = models.Word2Vec(sentences=tokenized_text, vector_size=100, window=5, min_count=1, workers=4)
w2v_vectors = w2v_model.wv

In [46]:
w2v_vectors.vectors.shape

(4180, 100)

In [47]:
w2v_vectors["<UNK>"] = np.random.rand(100) # deal with unknowns

In [48]:
w2v_vectors['the', 'test'].shape

(2, 100)

In [49]:
# get mean vector representations for each observation in df (only needed in training/val)
def get_w2v_rep(text):
    word_list = word_tokenize(text)
    for i in range(len(word_list)):

        # set unknowns
        if word_list[i] not in w2v_vectors:
            word_list[i] = "<UNK>"

    return w2v_vectors[word_list].mean(axis = 0)

train_df['w2v_rep'] = train_df['reports'].apply(get_w2v_rep)
val_df['w2v_rep'] = val_df['reports'].apply(get_w2v_rep)
test_df['w2v_rep'] = test_df['reports'].apply(get_w2v_rep)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  train_df['w2v_rep'] = train_df['reports'].apply(get_w2v_rep)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  val_df['w2v_rep'] = val_df['reports'].apply(get_w2v_rep)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  test_df['w2v_rep'] = test_df['reports'].apply(get_w2v_rep)


In [27]:
train_df['w2v_rep']

1860    [-0.033205464, 0.2066226, 0.2607411, 0.1040353...
353     [-0.09625494, 0.30258718, 0.2565856, 0.0369223...
1333    [-0.03160494, 0.21052597, 0.15867855, 0.187000...
905     [0.06557725, 0.23562005, 0.14597449, 0.1480740...
1289    [-0.14926514, 0.30984622, 0.27443218, -0.02234...
                              ...                        
1917    [-0.05960874, 0.31634578, 0.1497153, 0.0519511...
753     [-0.13595995, 0.24625152, 0.20256452, 0.035814...
498     [-0.033588264, 0.26611727, 0.24642119, 0.07896...
1276    [0.018732648, 0.23649655, 0.20593771, 0.138350...
1325    [-0.09468327, 0.22715266, 0.1459681, 0.1139262...
Name: w2v_rep, Length: 1400, dtype: object

In [28]:
val_df['w2v_rep']

11      [-0.1611167, 0.23175792, 0.1479323, 0.11222999...
396     [0.0019928045, 0.25537673, 0.27234286, 0.09330...
284     [-0.19908889, 0.45935753, 0.26423538, -0.04694...
1066    [-0.087428145, 0.24361074, 0.21137945, 0.09308...
1191    [0.06580996, 0.17966212, 0.1610907, 0.13195044...
                              ...                        
1372    [-0.11031494, 0.3017788, 0.39083362, -0.001839...
463     [-0.023597201, 0.34299082, 0.3060334, 0.033079...
1349    [-0.1931523, 0.34236702, 0.16111599, 0.0683558...
186     [0.07413233, 0.2063625, 0.19092868, 0.12665959...
1147    [-0.15483484, 0.26143703, 0.22782367, 0.076561...
Name: w2v_rep, Length: 300, dtype: object

In [50]:
## create columns for counterfactual within each demographic group
## each column will contain the subject id of the counterfactual that is associated with the observed individual in each group
## if the counterfactual group is the same as the observed group, then fill in the subject id of the observed individual
def find_counterfactual(cols, group, index, data, id_name):
    new_df = data
    new_cols = [col for col in cols]
    # pivot values in w2v_rep
    if 'w2v_rep' in new_cols:
        new_df = pd.concat([new_df, pd.DataFrame(new_df['w2v_rep'].to_list()).set_index(new_df.index)], axis = 1).drop('w2v_rep', axis = 1)
        new_cols.remove('w2v_rep')
        for i in range(100): # w2v context size
            new_cols.append(i)

    # get specific group
    cft_group = new_df[new_df['general_race_group'] == group]
    # ids = cft_group[id_name].values
    # df indices for now - makes indexing faster in training phase (otw would have to switch df indices to subject_id - which could also work)
    ids = list(new_df.index)

    # nearest neighbour using L2 distance - could get scaling issues due to w2v, but will see if empirically it is ok...
    distances = norm(cft_group[new_cols].to_numpy().astype(float) - new_df.loc[index].to_frame().T[new_cols].to_numpy().astype(float), axis=1)

    min_idx = distances.argmin()

    return ids[min_idx]

def create_counterfactuals(in_df, cols, cols_to_labenc):

    # encode labels
    label = LabelEncoder()
    label_cols = cols_to_labenc
    in_df[label_cols] = in_df[label_cols].apply(label.fit_transform)

    in_df[['Black/African Descent Counterfactual', 'Asian Counterfactual', 'White/Caucasian Counterfactual',
        'Hispanic/Latino Counterfactual', 'Other Races Counterfactual']] = 0

    # for all subjects
    for j in tqdm(list(in_df.index), position=0, leave=True):
        j_id = in_df.loc[j, 'subject_id']
        obs_group = in_df.loc[j, 'general_race_group']
        complement = list(set(in_df['general_race_group'].unique()) - set(obs_group))
        in_df.loc[j, obs_group + ' Counterfactual'] = j_id

        for grp in complement:
            in_df.loc[j, grp + ' Counterfactual'] = find_counterfactual(cols, grp, j, in_df, 'subject_id')

    return in_df

## cols to include
cols = ['anchor_age', 'gender', 'marital_status', 'w2v_rep']
cols_to_labenc = ['gender', 'marital_status']

train_cf_df = create_counterfactuals(train_df, cols, cols_to_labenc)
val_cf_df = create_counterfactuals(val_df, cols, cols_to_labenc)
test_cf_df = create_counterfactuals(test_df, cols, cols_to_labenc)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  in_df[label_cols] = in_df[label_cols].apply(label.fit_transform)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  in_df[['Black/African Descent Counterfactual', 'Asian Counterfactual', 'White/Caucasian Counterfactual',


  0%|          | 0/1400 [00:00<?, ?it/s]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  in_df[label_cols] = in_df[label_cols].apply(label.fit_transform)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  in_df[['Black/African Descent Counterfactual', 'Asian Counterfactual', 'White/Caucasian Counterfactual',
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  in_df[['Black/African Descent Count

  0%|          | 0/300 [00:00<?, ?it/s]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  in_df[label_cols] = in_df[label_cols].apply(label.fit_transform)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  in_df[['Black/African Descent Counterfactual', 'Asian Counterfactual', 'White/Caucasian Counterfactual',
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  in_df[['Black/African Descent Count

  0%|          | 0/300 [00:00<?, ?it/s]

In [30]:
train_cf_df.head()

Unnamed: 0,dicom_id,subject_id,study_id,ViewPosition,Rows,Columns,label,img_path,reports,gender,anchor_age,marital_status,race,general_race_group,w2v_rep,Black/African Descent Counterfactual,Asian Counterfactual,White/Caucasian Counterfactual,Hispanic/Latino Counterfactual,Other Races Counterfactual
1860,fec7f65a-3b7a69b1-59eafdc1-2b4a4538-04f521cc,10440642,55577077,AP,3056,2544,0,/content/drive/MyDrive/harvard/6.8610/project/...,final report chest radiograph indication: ...,0,91,3,BLACK/AFRICAN AMERICAN,Black/African Descent,"[-0.033205464, 0.2066226, 0.2607411, 0.1040353...",1860,1078,1442,585,907
353,56ecabed-e3e052bb-180ed50b-89653be3-c5cacb88,10435691,54890599,PA,2021,2021,1,/content/drive/MyDrive/harvard/6.8610/project/...,final report examination: chest (pa and lat)...,1,72,2,WHITE - OTHER EUROPEAN,White/Caucasian,"[-0.09625494, 0.30258718, 0.2565856, 0.0369223...",120,1696,1860,1290,585
1333,4df853b2-840e4f5a-5387f41e-070334d2-ffced7e8,10290812,52962263,AP,3056,2544,0,/content/drive/MyDrive/harvard/6.8610/project/...,final report single frontal view of the chest...,1,64,1,WHITE,White/Caucasian,"[-0.03160494, 0.21052597, 0.15867855, 0.187000...",494,824,353,128,1453
905,019cdb57-e50733c5-2db84dac-c4035068-0dc2ec4e,10690270,55198267,PA,2022,1910,0,/content/drive/MyDrive/harvard/6.8610/project/...,final report chest radiograph indication: ...,1,70,2,BLACK/AFRICAN AMERICAN,Black/African Descent,"[0.06557725, 0.23562005, 0.14597449, 0.1480740...",353,1696,408,944,944
1289,dc5186e1-893dc185-2328195c-5dc2716e-5912229f,10233088,52543539,PA,2021,1803,0,/content/drive/MyDrive/harvard/6.8610/project/...,final report examination: chest (pa and lat)...,0,52,1,WHITE,White/Caucasian,"[-0.14926514, 0.30984622, 0.27443218, -0.02234...",247,628,1333,70,1333


In [31]:
val_cf_df.head()

Unnamed: 0,dicom_id,subject_id,study_id,ViewPosition,Rows,Columns,label,img_path,reports,gender,anchor_age,marital_status,race,general_race_group,w2v_rep,Black/African Descent Counterfactual,Asian Counterfactual,White/Caucasian Counterfactual,Hispanic/Latino Counterfactual,Other Races Counterfactual
11,c1b39f26-3121f632-35e3b8cf-81b31e1c-91b3c81c,10521666,51897926,AP,2539,3050,1,/content/drive/MyDrive/harvard/6.8610/project/...,final report examination: chest (portable ap...,1,85,1,WHITE,White/Caucasian,"[-0.1611167, 0.23175792, 0.1479323, 0.11222999...",1575,1575,11,851,851
396,ce815f18-c37cd2fe-ae0dfd77-17c8be60-586e3f77,10630336,54142541,AP,2544,3056,1,/content/drive/MyDrive/harvard/6.8610/project/...,final report history: history of lung cancer...,1,80,1,WHITE,White/Caucasian,"[0.0019928045, 0.25537673, 0.27234286, 0.09330...",1983,1066,396,851,851
284,96d056c9-2ea5ae18-dcc8f649-6a91b05f-e5a1193c,10778904,59093298,AP,2539,3050,1,/content/drive/MyDrive/harvard/6.8610/project/...,final report examination: chest (portable ap...,0,65,2,WHITE,White/Caucasian,"[-0.19908889, 0.45935753, 0.26423538, -0.04694...",521,257,284,1875,11
1066,ce664629-a5a2f916-11663a12-e0c1d40f-935df3e8,10290812,50687128,AP,3056,2544,1,/content/drive/MyDrive/harvard/6.8610/project/...,"final report ap chest, 12:16 a.m. on ___ h...",1,64,1,WHITE,White/Caucasian,"[-0.087428145, 0.24361074, 0.21137945, 0.09308...",895,396,1066,1875,1879
1191,c41e558d-df74e42b-ebae321c-25deab27-41dea22c,10630310,50289295,AP,3050,2539,0,/content/drive/MyDrive/harvard/6.8610/project/...,final report reason for examination: seizure...,1,47,2,WHITE,White/Caucasian,"[0.06580996, 0.17966212, 0.1610907, 0.13195044...",768,1875,1191,689,335


In [32]:
# weights based on training data
sample_size = train_df.shape[0]
weight_df = train_df.groupby('general_race_group').agg(weight = ('subject_id', lambda x: 1 / (len(x) / sample_size))).reset_index()
weight_df = weight_df.set_index('general_race_group')
weights = {
    'Black/African Descent Counterfactual': weight_df.loc['Black/African Descent'].item(),
    'Asian Counterfactual': weight_df.loc['Asian'].item(),
    'White/Caucasian Counterfactual' : weight_df.loc['White/Caucasian'].item(),
    'Hispanic/Latino Counterfactual': weight_df.loc['Hispanic/Latino'].item(),
    'Other Races Counterfactual': weight_df.loc['Other Races'].item()
}
weights

{'Black/African Descent Counterfactual': 8.383233532934131,
 'Asian Counterfactual': 20.28985507246377,
 'White/Caucasian Counterfactual': 1.347449470644851,
 'Hispanic/Latino Counterfactual': 27.450980392156865,
 'Other Races Counterfactual': 18.91891891891892}

## **Training**

In [37]:
def train_step(model, train_loader, optimizer, criterion, epoch, debias = False, counterfactual_df = None, counterfactual_weights = None):
    """
    A single training epoch
    """
    total_loss = 0
    model.train()
    for batch in tqdm(train_loader, position = 0, leave = True):
        optimizer.zero_grad()
        if not debias:
            logits = model(batch['image'].to(device), batch['text'])
            loss = criterion(logits, batch['label'].to(device))
        else:
            indices = batch['index']

            groups = ['Black/African Descent Counterfactual',
                      'Asian Counterfactual',
                      'White/Caucasian Counterfactual',
                      'Hispanic/Latino Counterfactual',
                      'Other Races Counterfactual']

            loss = 0
            for group in groups: # if group is same as origin, then will match w/ itself so for |groups| = 1, reduces to non-debiased case

                # 1. get counterfactual report for each group
                counterfactual_indices = []
                for index in indices:
                    counterfactual_indices.append(counterfactual_df['train_df'].loc[index.item(), group])

                # 2. build counterfactual batch B x 1 x 224 x 224
                images = []
                for index in counterfactual_indices:
                    image = train_transforms(
                        transforms.Resize([IMG_SIZE, IMG_SIZE])(read_image(counterfactual_df['train_df'].loc[index, 'img_path']))
                    )
                    images.append(image.repeat(1, 3, 1, 1)) # so we get 3 channels for resnet
                images = torch.stack(images).reshape(batch['image'].shape)


                # 3. get counterfactual loss
                logits = model(images.to(device), batch['text'])

                # 4. aggregate counterfactual losses using weights
                group_loss = criterion(logits, batch['label'].to(device))
                loss += group_loss * counterfactual_weights[group]
                del images, logits

        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)


def infer_step(model, eval_loader, criterion, epoch, debias = False, counterfactual_df = None, counterfactual_weights = None):
    """
    A single inference epoch
    """
    model.eval()
    total_loss = 0
    auc = 0
    with torch.no_grad():
        for batch in tqdm(eval_loader, position = 0, leave = True):
            if not debias:
                logits = model(batch['image'].to(device), batch['text'])
                loss = criterion(logits, batch['label'].to(device))
            else:
                indices = batch['index']
                groups = ['Black/African Descent Counterfactual',
                        'Asian Counterfactual',
                        'White/Caucasian Counterfactual',
                        'Hispanic/Latino Counterfactual',
                        'Other Races Counterfactual']

                loss = 0
                for group in groups: # if group is same as origin, then will match w/ itself so for |groups| = 1, reduces to non-debiased case

                    # 1. get counterfactual report for each group
                    counterfactual_indices = []
                    for index in indices:
                        counterfactual_indices.append(counterfactual_df['val_df'].loc[index.item(), group])

                    # 2. build counterfactual batch B x 1 x 224 x 224
                    images = []
                    for index in counterfactual_indices:
                        image = train_transforms(
                            transforms.Resize([IMG_SIZE, IMG_SIZE])(read_image(counterfactual_df['val_df'].loc[index, 'img_path']))
                        )
                        images.append(image.repeat(1, 3, 1, 1))
                    images = torch.stack(images).reshape(batch['image'].shape)

                    # 3. get counterfactual loss
                    logits = model(images.to(device), batch['text'])

                    # 4. aggregate counterfactual losses using weights
                    group_loss = criterion(logits, batch['label'].to(device))
                    loss += group_loss * counterfactual_weights[group]
                    del images, logits

            total_loss += loss.item()
    return total_loss / len(eval_loader)

In [34]:
def train(model, train_loader, val_loader, optimizer, criterion, epochs, debias = False, counterfactual_df = None, counterfactual_weights = None, save_path = "./model.pt"):
    best_val = float("inf")
    for epoch in tqdm(range(epochs), position = 0, leave = True):
        train_loss = train_step(model, train_loader, optimizer, criterion, epoch, debias, counterfactual_df, counterfactual_weights)
        val_loss = infer_step(model, val_loader, criterion, epoch, debias, counterfactual_df, counterfactual_weights)
        print(f"Epoch {epoch}: train loss {train_loss}, eval loss {val_loss}")

        if val_loss <= best_val:
            print(f"Updating best model at epoch {epoch}")
            best_val = val_loss
            best_model = copy.deepcopy(model)
            torch.save(model.state_dict(), save_path)
    return best_model


In [35]:
model = FusionModel(resnet, bert, tokenizer).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-5)
criterion = nn.CrossEntropyLoss()
mm_net = train(model, train_loader, val_loader, optimizer, criterion, 5, save_path = "/content/drive/MyDrive/harvard/6.8610/project/fusion.pt") # no debiasing

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 0: train loss 0.5905158065578767, eval loss 0.5378913669784864
Updating best model at epoch 0


  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 1: train loss 0.4053913444812809, eval loss 0.3831513737266262
Updating best model at epoch 1


  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 2: train loss 0.2428585591373433, eval loss 0.30242584106201925
Updating best model at epoch 2


  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 3: train loss 0.17114859639933067, eval loss 0.22823509414990742
Updating best model at epoch 3


  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 4: train loss 0.09000893695496155, eval loss 0.28707568579275783


In [38]:
model = FusionModel(resnet, bert, tokenizer).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-5)
criterion = nn.CrossEntropyLoss()
cf_df_dict = {
    'train_df': train_cf_df,
    'val_df': val_cf_df,
}
mm_net_debiased = train(model, train_loader, val_loader, optimizer, criterion, 5, True, cf_df_dict, weights, "/content/drive/MyDrive/harvard/6.8610/project/debiased_fusion.pt")

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 0: train loss 9.36624675576176, eval loss 29.57999659438928
Updating best model at epoch 0


  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 1: train loss 4.836592130144792, eval loss 28.950452362298964
Updating best model at epoch 1


  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 2: train loss 4.559580465054938, eval loss 36.596940129896005


  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 3: train loss 2.7861406703639244, eval loss 40.62518711258968


  0%|          | 0/350 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

Epoch 4: train loss 0.6470431784926248, eval loss 48.194140402295936


## **Evaluation**

Evaluation options with or without matching

In [23]:
model_ndb = FusionModel(resnet, bert, tokenizer).to(device)
model_ndb.load_state_dict(torch.load("/content/drive/MyDrive/harvard/6.8610/project/fusion.pt"))

<All keys matched successfully>

In [24]:
model_db = FusionModel(resnet, bert, tokenizer).to(device)
model_db.load_state_dict(torch.load("/content/drive/MyDrive/harvard/6.8610/project/debiased_fusion.pt"))

<All keys matched successfully>

In [70]:
def get_preds(model, test_loader):
    """
    Get predictions
    """
    indices = []
    preds = []
    labels = []
    with torch.no_grad():
        for batch in tqdm(test_loader, position = 0, leave = True):
            logits = model(batch['image'].to(device), batch['text'])
            preds.append(torch.argmax(logits, dim = 1).cpu().numpy())
            labels.append(batch['label'].numpy())
            indices.append(batch['index'].numpy())

    indices = np.concatenate(indices)
    preds = np.concatenate(preds)
    labels = np.concatenate(labels)

    return indices, preds, labels

def get_metrics(preds, labels):
    """
    Get metrics
    """

    accuracy = accuracy_score(labels, preds)
    precision = precision_score(labels, preds)
    recall = recall_score(labels, preds)
    f1 = f1_score(labels, preds)
    aucroc = roc_auc_score(labels, preds)
    confm = confusion_matrix(labels, preds, normalize="true")

    print(f"Accuracy: {accuracy}")
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    print(f"F1: {f1}")
    print(f"AUC-ROC: {aucroc}")
    print(f"Confusion Matrix: {confm}")

    return accuracy, precision, recall, f1, aucroc, confm

In [68]:
# not debiased
test_indices, test_preds, test_labels = get_preds(model_ndb, test_loader)
print(sum(test_indices == test_df.index.values)) # making sure order is same
test_df['pred'] = test_preds # assigning predicted labels

  0%|          | 0/75 [00:00<?, ?it/s]

300


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  test_df['pred'] = test_preds # assigning predicted labels


In [71]:
for group in np.unique(test_cf_df.general_race_group.values):
    print(f"########### {group} ###########")
    get_metrics(test_df[test_df['general_race_group'] == group]['pred'].values, test_df[test_df['general_race_group'] == group]['label'].values)

########### Asian ###########
Accuracy: 0.9333333333333333
Precision: 0.9
Recall: 1.0
F1: 0.9473684210526316
AUC-ROC: 0.9166666666666667
Confusion Matrix: [[0.83333333 0.16666667]
 [0.         1.        ]]
########### Black/African Descent ###########
Accuracy: 0.7916666666666666
Precision: 0.7435897435897436
Recall: 1.0
F1: 0.8529411764705882
AUC-ROC: 0.736842105263158
Confusion Matrix: [[0.47368421 0.52631579]
 [0.         1.        ]]
########### Hispanic/Latino ###########
Accuracy: 0.8571428571428571
Precision: 0.8333333333333334
Recall: 1.0
F1: 0.9090909090909091
AUC-ROC: 0.75
Confusion Matrix: [[0.5 0.5]
 [0.  1. ]]
########### Other Races ###########
Accuracy: 1.0
Precision: 1.0
Recall: 1.0
F1: 1.0
AUC-ROC: 1.0
Confusion Matrix: [[1. 0.]
 [0. 1.]]
########### White/Caucasian ###########
Accuracy: 0.8809523809523809
Precision: 0.8588235294117647
Recall: 0.9931972789115646
F1: 0.9211356466876972
AUC-ROC: 0.8061224489795918
Confusion Matrix: [[0.61904762 0.38095238]
 [0.00680272 0

In [72]:
# debiased
test_indices, test_preds, test_labels = get_preds(model_db, test_loader)
print(sum(test_indices == test_df.index.values)) # making sure order is same
test_df['db_pred'] = test_preds # assigning predicted labels

  0%|          | 0/75 [00:00<?, ?it/s]

300


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  test_df['db_pred'] = test_preds # assigning predicted labels


In [73]:
for group in np.unique(test_cf_df.general_race_group.values):
    print(f"########### {group} ###########")
    get_metrics(test_df[test_df['general_race_group'] == group]['db_pred'].values, test_df[test_df['general_race_group'] == group]['label'].values)

########### Asian ###########
Accuracy: 0.9333333333333333
Precision: 1.0
Recall: 0.8888888888888888
F1: 0.9411764705882353
AUC-ROC: 0.9444444444444444
Confusion Matrix: [[1.         0.        ]
 [0.11111111 0.88888889]]
########### Black/African Descent ###########
Accuracy: 0.9375
Precision: 0.9333333333333333
Recall: 0.9655172413793104
F1: 0.9491525423728815
AUC-ROC: 0.9301270417422869
Confusion Matrix: [[0.89473684 0.10526316]
 [0.03448276 0.96551724]]
########### Hispanic/Latino ###########
Accuracy: 0.8571428571428571
Precision: 1.0
Recall: 0.8
F1: 0.888888888888889
AUC-ROC: 0.9
Confusion Matrix: [[1.  0. ]
 [0.2 0.8]]
########### Other Races ###########
Accuracy: 1.0
Precision: 1.0
Recall: 1.0
F1: 1.0
AUC-ROC: 1.0
Confusion Matrix: [[1. 0.]
 [0. 1.]]
########### White/Caucasian ###########
Accuracy: 0.9142857142857143
Precision: 0.9161290322580645
Recall: 0.9659863945578231
F1: 0.9403973509933774
AUC-ROC: 0.8798185941043084
Confusion Matrix: [[0.79365079 0.20634921]
 [0.03401361