# **Causal De-biasing Fusion Pipeline**

In [None]:
!pip install monai torchinfo pytorch-metric-learning



In [None]:
import os
import cv2
import pandas as pd
import numpy as np
from numpy.linalg import norm
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score, confusion_matrix
import matplotlib.pyplot as plt
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.io import read_image
import torchvision.transforms.v2 as transforms
from torchvision.transforms import Compose
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet50, ResNet50_Weights
from pytorch_metric_learning import losses

import nltk
from nltk import word_tokenize
from gensim import models

from transformers import AutoTokenizer, DistilBertModel

import monai.transforms as mt

from torchinfo import summary

from tqdm.notebook import tqdm

from google.colab import drive
drive.mount('/content/drive')

SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
nltk.download('punkt')
device = "cuda" if torch.cuda.is_available() else "cpu"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


## **Data Setup**

In [None]:
df = pd.read_csv('/content/drive/MyDrive/harvard/6.8610/project/data/6.8610proj/dataset_ready.csv')

In [None]:
df.shape

(5000, 14)

In [None]:
df.head()

Unnamed: 0,dicom_id,subject_id,study_id,ViewPosition,Rows,Columns,label,img_path,reports,gender,anchor_age,marital_status,race,general_race_group
0,86a4065d-a51890c6-094034a8-c549f6b9-6660ff93,10406570,57207207,PA,3056,2500,0,/content/drive/My Drive/6.8610proj/86a4065d-a5...,FINAL REPORT\...,F,69,WIDOWED,BLACK/AFRICAN AMERICAN,Black/African Descent
1,a1072339-d1fe3a01-149f489b-9a08c49b-f2fee79e,10965697,51095949,AP,3056,2544,0,/content/drive/My Drive/6.8610proj/a1072339-d1...,FINAL REPORT\...,F,67,WIDOWED,ASIAN - CHINESE,Asian
2,becf8ed6-c5f60c71-89040c32-9d94b7c6-eb956bc1,10104732,59794138,AP,3056,2544,0,/content/drive/My Drive/6.8610proj/becf8ed6-c5...,WET READ: ___ ___ ___ 8:27 PM\n No change fr...,M,49,SINGLE,WHITE,White/Caucasian
3,db41181b-0c240a54-1a370211-3723e7a9-b6cdb316,10898945,50988324,AP,2606,2544,0,/content/drive/My Drive/6.8610proj/db41181b-0c...,FINAL REPORT\...,M,78,MARRIED,WHITE,White/Caucasian
4,1739a403-be126d84-266aab85-442b24fc-a4ebe43c,10867055,59657889,AP,3050,2539,1,/content/drive/My Drive/6.8610proj/1739a403-be...,FINAL REPORT\...,M,47,SINGLE,WHITE,White/Caucasian


In [None]:
df['img_path'] = df['img_path'].str.replace('/content/drive/My Drive/6.8610proj', '/content/drive/MyDrive/harvard/6.8610/project/data/6.8610proj')
df.head()

  df['img_path'] = df['img_path'].str.replace('/content/drive/My Drive/6.8610proj', '/content/drive/MyDrive/harvard/6.8610/project/data/6.8610proj')


Unnamed: 0,dicom_id,subject_id,study_id,ViewPosition,Rows,Columns,label,img_path,reports,gender,anchor_age,marital_status,race,general_race_group
0,86a4065d-a51890c6-094034a8-c549f6b9-6660ff93,10406570,57207207,PA,3056,2500,0,/content/drive/MyDrive/harvard/6.8610/project/...,FINAL REPORT\...,F,69,WIDOWED,BLACK/AFRICAN AMERICAN,Black/African Descent
1,a1072339-d1fe3a01-149f489b-9a08c49b-f2fee79e,10965697,51095949,AP,3056,2544,0,/content/drive/MyDrive/harvard/6.8610/project/...,FINAL REPORT\...,F,67,WIDOWED,ASIAN - CHINESE,Asian
2,becf8ed6-c5f60c71-89040c32-9d94b7c6-eb956bc1,10104732,59794138,AP,3056,2544,0,/content/drive/MyDrive/harvard/6.8610/project/...,WET READ: ___ ___ ___ 8:27 PM\n No change fr...,M,49,SINGLE,WHITE,White/Caucasian
3,db41181b-0c240a54-1a370211-3723e7a9-b6cdb316,10898945,50988324,AP,2606,2544,0,/content/drive/MyDrive/harvard/6.8610/project/...,FINAL REPORT\...,M,78,MARRIED,WHITE,White/Caucasian
4,1739a403-be126d84-266aab85-442b24fc-a4ebe43c,10867055,59657889,AP,3050,2539,1,/content/drive/MyDrive/harvard/6.8610/project/...,FINAL REPORT\...,M,47,SINGLE,WHITE,White/Caucasian


In [None]:
df['reports'] = df['reports'].str.lower().str.replace("\n", " ").str.strip()

In [None]:
df.shape

(5000, 14)

In [None]:
N = 4000
df = df.sample(n=N, random_state=SEED)
df.shape

(4000, 14)

In [None]:
IMG_CLASSES = ['Normal','PleuralEffusion']
NUM_CLASSES = len(IMG_CLASSES)
IMG_SIZE = 224
class MultiModalDataset(Dataset):
    def __init__(self, indices, image_dir, texts, labels, transform = None):
        # load dataset
        self.images = []
        for image_file in tqdm(image_dir, position = 0, leave = True):
            image = transforms.Resize([IMG_SIZE, IMG_SIZE])(read_image(image_file))
            self.images.append(image)


        self.indices = indices
        self.texts = texts
        self.labels = labels

        # transforms
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.transform(self.images[idx]) if self.transform is not None else self.images[idx]
        image = image.repeat(3, 1, 1)
        # image = (image - image.min())/(image.max() - image.min()) # map to [0, 1]
        index = self.indices[idx]
        label = self.labels[idx]
        text = self.texts[idx]
        sample = {'index': index, 'image': image, 'text': text, 'label': label}#transforms.ToDtype(torch.uint8, scale=True)(mask)}
        return sample

In [None]:
# splits
TRAIN_SIZE = 0.7
VAL_SIZE = 0.15
TEST_SIZE = 0.15
samp_df = df.sample(frac=1, random_state=SEED)
train_df, val_df, test_df = samp_df[0:int(TRAIN_SIZE*len(samp_df))], samp_df[int(TRAIN_SIZE*len(samp_df)):int((TRAIN_SIZE+VAL_SIZE)*len(samp_df))], samp_df[int((TRAIN_SIZE+VAL_SIZE)*len(samp_df)):]
train_df.shape, val_df.shape, test_df.shape

((2800, 14), (600, 14), (600, 14))

In [None]:
train_transforms = transforms.Compose([
    mt.NormalizeIntensity(),
])

val_transforms = transforms.Compose([
    mt.NormalizeIntensity(),
])

train_ds = MultiModalDataset(list(train_df.index),
                             train_df['img_path'].values,
                             train_df['reports'].values,
                             train_df['label'].values,
                             transform = train_transforms)
val_ds = MultiModalDataset(list(val_df.index),
                           val_df['img_path'].values,
                           val_df['reports'].values,
                           val_df['label'].values,
                           transform = val_transforms)
test_ds = MultiModalDataset(list(test_df.index),
                            test_df['img_path'].values,
                            test_df['reports'].values,
                            test_df['label'].values,
                            transform = val_transforms)

  0%|          | 0/2800 [00:00<?, ?it/s]



  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/600 [00:00<?, ?it/s]

In [None]:
BATCH_SIZE = 6
train_kwargs = {'batch_size': BATCH_SIZE, 'shuffle': True}
val_kwargs = {'batch_size': BATCH_SIZE, 'shuffle': False}

train_loader = DataLoader(train_ds, **train_kwargs)
val_loader = DataLoader(val_ds, **val_kwargs)
test_loader = DataLoader(test_ds, **val_kwargs)

In [None]:
for batch in test_loader:
    print(batch['index'].shape)
    print(batch['image'].shape)
    print(len(batch['text'])) # will tokenize in training loop...
    print(batch['label'].shape)
    break

torch.Size([6])
torch.Size([6, 3, 224, 224])
6
torch.Size([6])


## **Model**

In [None]:
bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

In [None]:
summary(bert)

Layer (type:depth-idx)                             Param #
DistilBertModel                                    --
├─Embeddings: 1-1                                  --
│    └─Embedding: 2-1                              23,440,896
│    └─Embedding: 2-2                              393,216
│    └─LayerNorm: 2-3                              1,536
│    └─Dropout: 2-4                                --
├─Transformer: 1-2                                 --
│    └─ModuleList: 2-5                             --
│    │    └─TransformerBlock: 3-1                  7,087,872
│    │    └─TransformerBlock: 3-2                  7,087,872
│    │    └─TransformerBlock: 3-3                  7,087,872
│    │    └─TransformerBlock: 3-4                  7,087,872
│    │    └─TransformerBlock: 3-5                  7,087,872
│    │    └─TransformerBlock: 3-6                  7,087,872
Total params: 66,362,880
Trainable params: 66,362,880
Non-trainable params: 0

In [None]:
batch_tokenized = tokenizer(batch['text'], return_tensors="pt", padding='max_length', truncation=True)
test_out = bert(**batch_tokenized)

In [None]:
test_out.last_hidden_state.shape

torch.Size([6, 512, 768])

In [None]:
resnet = torch.nn.Sequential(*list(resnet50(weights=ResNet50_Weights.DEFAULT).children())[:-1])

In [None]:
summary(resnet)

Layer (type:depth-idx)                   Param #
Sequential                               --
├─Conv2d: 1-1                            9,408
├─BatchNorm2d: 1-2                       128
├─ReLU: 1-3                              --
├─MaxPool2d: 1-4                         --
├─Sequential: 1-5                        --
│    └─Bottleneck: 2-1                   --
│    │    └─Conv2d: 3-1                  4,096
│    │    └─BatchNorm2d: 3-2             128
│    │    └─Conv2d: 3-3                  36,864
│    │    └─BatchNorm2d: 3-4             128
│    │    └─Conv2d: 3-5                  16,384
│    │    └─BatchNorm2d: 3-6             512
│    │    └─ReLU: 3-7                    --
│    │    └─Sequential: 3-8              16,896
│    └─Bottleneck: 2-2                   --
│    │    └─Conv2d: 3-9                  16,384
│    │    └─BatchNorm2d: 3-10            128
│    │    └─Conv2d: 3-11                 36,864
│    │    └─BatchNorm2d: 3-12            128
│    │    └─Conv2d: 3-13               

In [None]:
resnet(batch['image']).shape

torch.Size([6, 2048, 1, 1])

In [None]:
class MultiModalModel(nn.Module):
    """
    Model that outputs representations for image and text
    """
    def __init__(self, cv_encoder, nlp_encoder, tokenizer):
        super(MultiModalModel, self).__init__()
        self.tokenizer = tokenizer
        self.nlp_encoder = nlp_encoder
        self.nlp_out = nn.Linear(768, 256) # bert
        self.cv_encoder = cv_encoder
        self.cv_out = nn.Linear(2048, 256) # resnet

    def forward(self, im, text):
        tokens = self.tokenizer(text, return_tensors="pt", padding='max_length', truncation=True).to(device)
        z_text = self.nlp_encoder(**tokens).last_hidden_state
        z_text = self.nlp_out(z_text).mean(axis = 1) # average sequence representation
        z_im = self.cv_encoder(im).squeeze()
        z_im = self.cv_out(z_im)
        return z_im, z_text

In [None]:
class FusionModel(nn.Module):
    """
    Model that performs intermediate fusion and outputs logits
    """
    def __init__(self, cv_encoder, nlp_encoder, tokenizer):
        super(FusionModel, self).__init__()
        multimodal_encoder = MultiModalModel(cv_encoder, nlp_encoder, tokenizer)
        self.multimodal_endover = multimodal_encoder
        self.out = nn.Linear(512, 2)

    def forward(self, im, text):
        z_im, z_text = self.multimodal_endover(im, text)
        z = torch.cat([z_im, z_text], axis = 1)
        z = self.out(z)
        return z

In [None]:
model = FusionModel(resnet, bert, tokenizer)
summary(model)

Layer (type:depth-idx)                                       Param #
FusionModel                                                  --
├─MultiModalModel: 1-1                                       --
│    └─DistilBertModel: 2-1                                  --
│    │    └─Embeddings: 3-1                                  23,835,648
│    │    └─Transformer: 3-2                                 42,527,232
│    └─Linear: 2-2                                           196,864
│    └─Sequential: 2-3                                       --
│    │    └─Conv2d: 3-3                                      9,408
│    │    └─BatchNorm2d: 3-4                                 128
│    │    └─ReLU: 3-5                                        --
│    │    └─MaxPool2d: 3-6                                   --
│    │    └─Sequential: 3-7                                  215,808
│    │    └─Sequential: 3-8                                  1,219,584
│    │    └─Sequential: 3-9                                  7

## **Causal De-Biasing**

In [None]:
# w2v lookup matrix
tokenized_text = [word_tokenize(text) for text in list(train_df['reports'].values)]
w2v_model = models.Word2Vec(sentences=tokenized_text, vector_size=100, window=5, min_count=1, workers=4)
w2v_vectors = w2v_model.wv

In [None]:
w2v_vectors.vectors.shape

(5651, 100)

In [None]:
w2v_vectors["<UNK>"] = np.random.rand(100) # deal with unknowns

In [None]:
w2v_vectors['the'].shape

(100,)

In [None]:
# get mean vector representations for each observation in df (only needed in training/val)
def get_w2v_rep(text):
    word_list = word_tokenize(text)
    for i in range(len(word_list)):

        # set unknowns
        if word_list[i] not in w2v_vectors:
            word_list[i] = "<UNK>"

    return w2v_vectors[word_list].mean(axis = 0)

train_df['w2v_rep'] = train_df['reports'].apply(get_w2v_rep)
val_df['w2v_rep'] = val_df['reports'].apply(get_w2v_rep)
test_df['w2v_rep'] = test_df['reports'].apply(get_w2v_rep)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  train_df['w2v_rep'] = train_df['reports'].apply(get_w2v_rep)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  val_df['w2v_rep'] = val_df['reports'].apply(get_w2v_rep)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  test_df['w2v_rep'] = test_df['reports'].apply(get_w2v_rep)


In [None]:
train_df['w2v_rep']

1391    [0.1431831, 0.8061786, 0.3222398, 0.056575578,...
1285    [0.1886897, 0.77073896, 0.6276552, -0.18130088...
3142    [-0.072228104, 0.28583893, 0.6208517, -0.12407...
81      [0.2885824, 0.80615133, 0.57626206, -0.0283374...
1092    [0.03490969, 0.40188745, 0.71527094, -0.060938...
                              ...                        
441     [0.019018777, 0.34156767, 0.5149189, -0.110178...
4120    [-0.0701188, 0.49396485, 0.6116898, 0.02462575...
4347    [0.040428985, 0.06380252, 0.6532868, 0.0406045...
1183    [-0.026629562, 0.50979793, 0.60838556, -0.0547...
3402    [-0.0007209983, 0.32048064, 0.59073085, -0.085...
Name: w2v_rep, Length: 2800, dtype: object

In [None]:
val_df['w2v_rep']

3608    [0.067898184, 0.47818714, 0.51573724, -0.04206...
867     [0.093796715, 0.3781239, 0.6234552, -0.0290211...
3932    [0.035818752, 0.20288226, 0.6414697, 0.1171203...
461     [-0.06937614, 0.48614416, 0.6236977, -0.064642...
4629    [-0.09949541, 0.6197796, 0.5834543, 0.02398801...
                              ...                        
1525    [0.2092606, 0.4226435, 0.69694525, 0.016435388...
278     [0.14337681, 0.5226113, 0.7490471, -0.1685027,...
2109    [-0.13036576, 0.8047973, 0.527713, -0.13033654...
83      [0.09926721, 0.22341128, 0.6250036, -0.0584550...
4677    [0.13975702, 0.89748096, 0.44178632, -0.020974...
Name: w2v_rep, Length: 600, dtype: object

In [None]:
## create columns for counterfactual within each demographic group
## each column will contain the subject id of the counterfactual that is associated with the observed individual in each group
## if the counterfactual group is the same as the observed group, then fill in the subject id of the observed individual
def find_counterfactual(cols, group, index, data, id_name):
    new_df = data
    new_cols = [col for col in cols]
    # pivot values in w2v_rep
    if 'w2v_rep' in new_cols:
        new_df = pd.concat([new_df, pd.DataFrame(new_df['w2v_rep'].to_list()).set_index(new_df.index)], axis = 1).drop('w2v_rep', axis = 1)
        new_cols.remove('w2v_rep')
        for i in range(100): # w2v context size
            new_cols.append(i)

    # get specific group
    cft_group = new_df[new_df['general_race_group'] == group]
    # ids = cft_group[id_name].values
    # df indices for now - makes indexing faster in training phase (otw would have to switch df indices to subject_id - which could also work)
    ids = list(new_df.index)

    # nearest neighbour using L2 distance - could get scaling issues due to w2v, but will see if empirically it is ok...
    distances = norm(cft_group[new_cols].to_numpy().astype(float) - new_df.loc[index].to_frame().T[new_cols].to_numpy().astype(float), axis=1)

    min_idx = distances.argmin()

    return ids[min_idx]

def create_counterfactuals(in_df, cols, cols_to_labenc):

    # encode labels
    label = LabelEncoder()
    label_cols = cols_to_labenc
    in_df[label_cols] = in_df[label_cols].apply(label.fit_transform)

    in_df[['Black/African Descent Counterfactual', 'Asian Counterfactual', 'White/Caucasian Counterfactual',
        'Hispanic/Latino Counterfactual', 'Other Races Counterfactual']] = 0

    # for all subjects
    for j in tqdm(list(in_df.index), position=0, leave=True):
        j_id = in_df.loc[j, 'subject_id']
        obs_group = in_df.loc[j, 'general_race_group']
        complement = list(set(in_df['general_race_group'].unique()) - set(obs_group))
        in_df.loc[j, obs_group + ' Counterfactual'] = j_id

        for grp in complement:
            in_df.loc[j, grp + ' Counterfactual'] = find_counterfactual(cols, grp, j, in_df, 'subject_id')

    return in_df

## cols to include
cols = ['anchor_age', 'gender', 'marital_status', 'w2v_rep']
cols_to_labenc = ['gender', 'marital_status']

train_cf_df = create_counterfactuals(train_df, cols, cols_to_labenc)
val_cf_df = create_counterfactuals(val_df, cols, cols_to_labenc)
test_cf_df = create_counterfactuals(test_df, cols, cols_to_labenc)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  in_df[label_cols] = in_df[label_cols].apply(label.fit_transform)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  in_df[['Black/African Descent Counterfactual', 'Asian Counterfactual', 'White/Caucasian Counterfactual',
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  in_df[['Black/African Descent Count

  0%|          | 0/2800 [00:00<?, ?it/s]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  in_df[label_cols] = in_df[label_cols].apply(label.fit_transform)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  in_df[['Black/African Descent Counterfactual', 'Asian Counterfactual', 'White/Caucasian Counterfactual',
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  in_df[['Black/African Descent Count

  0%|          | 0/600 [00:00<?, ?it/s]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  in_df[label_cols] = in_df[label_cols].apply(label.fit_transform)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  in_df[['Black/African Descent Counterfactual', 'Asian Counterfactual', 'White/Caucasian Counterfactual',
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  in_df[['Black/African Descent Count

  0%|          | 0/600 [00:00<?, ?it/s]

In [None]:
train_cf_df.head()

Unnamed: 0,dicom_id,subject_id,study_id,ViewPosition,Rows,Columns,label,img_path,reports,gender,anchor_age,marital_status,race,general_race_group,w2v_rep,Black/African Descent Counterfactual,Asian Counterfactual,White/Caucasian Counterfactual,Hispanic/Latino Counterfactual,Other Races Counterfactual
1391,d05a53ee-448e7f8d-e5062343-a82beabd-482399fe,10324973,55534471,PA,3056,2524,0,/content/drive/MyDrive/harvard/6.8610/project/...,final report history: decreased breath sound...,1,80,1,WHITE,White/Caucasian,"[0.1431831, 0.8061786, 0.3222398, 0.056575578,...",4866,658,1391,2980,3167
1285,82b13a50-a56eab3b-0cd7ccc6-d8f3394b-61104e92,10506015,52385680,PA,2022,2022,0,/content/drive/MyDrive/harvard/6.8610/project/...,final report history: cough. comparison: ...,1,25,2,ASIAN - ASIAN INDIAN,Asian,"[0.1886897, 0.77073896, 0.6276552, -0.18130088...",3908,1391,4497,2575,3604
3142,ca93beeb-51ff6d55-ad032e67-32b9cde8-0bf04394,10078115,51981600,AP,3043,2539,1,/content/drive/MyDrive/harvard/6.8610/project/...,"final report ap chest, 7:53 a.m. on ___. h...",1,50,1,OTHER,Other Races,"[-0.072228104, 0.28583893, 0.6208517, -0.12407...",4435,1587,4759,1559,1391
81,c7c31dca-c47dbf98-9cdfb286-f96083ae-715abd88,10279956,56898314,PA,2544,3056,0,/content/drive/MyDrive/harvard/6.8610/project/...,final report history: bilateral lower extrem...,1,46,2,WHITE,White/Caucasian,"[0.2885824, 0.80615133, 0.57626206, -0.0283374...",3568,2682,1285,1032,1704
1092,35b972db-6b959ede-d70061ea-53341c0a-5e320b15,10692735,55032282,PA,2021,1758,1,/content/drive/MyDrive/harvard/6.8610/project/...,final report chest radiograph indication: ...,0,68,3,BLACK/AFRICAN AMERICAN,Black/African Descent,"[0.03490969, 0.40188745, 0.71527094, -0.060938...",1391,2248,1862,4041,4546


In [None]:
val_cf_df.head()

Unnamed: 0,dicom_id,subject_id,study_id,ViewPosition,Rows,Columns,label,img_path,reports,gender,anchor_age,marital_status,race,general_race_group,w2v_rep,Black/African Descent Counterfactual,Asian Counterfactual,White/Caucasian Counterfactual,Hispanic/Latino Counterfactual,Other Races Counterfactual
3608,1617c7ef-e1a6c62a-bbdd1ca0-eae6d065-5dbe21f0,10190130,55005273,PA,2021,2021,1,/content/drive/MyDrive/harvard/6.8610/project/...,"final report pa and lateral chest, ___ h...",0,65,1,WHITE,White/Caucasian,"[0.067898184, 0.47818714, 0.51573724, -0.04206...",4411,1855,3608,1855,640
867,b29075d4-fd3ac67a-d6c9c3b9-1b792505-a385e688,10627650,51109572,PA,3056,2544,1,/content/drive/MyDrive/harvard/6.8610/project/...,final report exam: chest frontal and lateral...,1,41,2,BLACK/CAPE VERDEAN,Black/African Descent,"[0.093796715, 0.3781239, 0.6234552, -0.0290211...",3608,59,181,1163,4777
3932,3b77e898-49f8ee91-7b585a97-17fb06cc-8b342b39,10196757,54943509,AP,3056,2544,0,/content/drive/MyDrive/harvard/6.8610/project/...,final report chest radiograph indication: ...,1,70,1,OTHER,Other Races,"[0.035818752, 0.20288226, 0.6414697, 0.1171203...",1621,4888,1474,1855,3608
461,f4b59688-e1fa4d49-b3eb6419-2b731f85-7ffc359b,10213275,55321436,AP,2544,3056,0,/content/drive/MyDrive/harvard/6.8610/project/...,final report examination: chest (portable ap...,0,64,2,WHITE,White/Caucasian,"[-0.06937614, 0.48614416, 0.6236977, -0.064642...",2596,964,867,4831,1620
4629,874ebb9d-46f1609c-b01de1e4-89a59ab0-823abf57,10291098,57722714,AP,2544,3056,0,/content/drive/MyDrive/harvard/6.8610/project/...,final report examination: chest (portable ap...,1,29,2,WHITE,White/Caucasian,"[-0.09949541, 0.6197796, 0.5834543, 0.02398801...",2432,640,3932,1621,634


In [None]:
# weights based on training data
sample_size = train_df.shape[0]
weight_df = train_df.groupby('general_race_group').agg(weight = ('subject_id', lambda x: 1 / (len(x) / sample_size))).reset_index()
weight_df = weight_df.set_index('general_race_group')
weights = {
    'Black/African Descent Counterfactual': weight_df.loc['Black/African Descent'].item(),
    'Asian Counterfactual': weight_df.loc['Asian'].item(),
    'White/Caucasian Counterfactual' : weight_df.loc['White/Caucasian'].item(),
    'Hispanic/Latino Counterfactual': weight_df.loc['Hispanic/Latino'].item(),
    'Other Races Counterfactual': weight_df.loc['Other Races'].item()
}
weights

{'Black/African Descent Counterfactual': 7.671232876712328,
 'Asian Counterfactual': 23.52941176470588,
 'White/Caucasian Counterfactual': 1.3539651837524178,
 'Hispanic/Latino Counterfactual': 25.454545454545457,
 'Other Races Counterfactual': 20.28985507246377}

## **Training**

In [None]:
def train_step(model, train_loader, optimizer, criterion, epoch, debias = False, counterfactual_df = None, counterfactual_weights = None):
    """
    A single training epoch
    """
    total_loss = 0
    model.train()
    for batch in tqdm(train_loader, position = 0, leave = True):
        optimizer.zero_grad()
        if not debias:
            logits = model(batch['image'].to(device), batch['text'])
            loss = criterion(logits, batch['label'].to(device))
        else:
            indices = batch['index']

            groups = ['Black/African Descent Counterfactual',
                      'Asian Counterfactual',
                      'White/Caucasian Counterfactual',
                      'Hispanic/Latino Counterfactual',
                      'Other Races Counterfactual']

            loss = 0
            for group in groups: # if group is same as origin, then will match w/ itself so for |groups| = 1, reduces to non-debiased case

                # 1. get counterfactual report for each group
                counterfactual_indices = []
                for index in indices:
                    counterfactual_indices.append(counterfactual_df['train_df'].loc[index.item(), group])

                # 2. build counterfactual batch B x 1 x 224 x 224
                images = []
                for index in counterfactual_indices:
                    image = train_transforms(
                        transforms.Resize([IMG_SIZE, IMG_SIZE])(read_image(counterfactual_df['train_df'].loc[index, 'img_path']))
                    )
                    images.append(image.repeat(1, 3, 1, 1)) # so we get 3 channels for resnet
                images = torch.stack(images).reshape(batch['image'].shape)


                # 3. get counterfactual loss
                logits = model(images.to(device), batch['text'])

                # 4. aggregate counterfactual losses using weights
                group_loss = criterion(logits, batch['label'].to(device))
                loss += group_loss * counterfactual_weights[group]
                del images, logits
                torch.cuda.empty_cache()

        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)


def infer_step(model, eval_loader, criterion, epoch, debias = False, counterfactual_df = None, counterfactual_weights = None):
    """
    A single inference epoch
    """
    model.eval()
    total_loss = 0
    auc = 0
    count = 0
    with torch.no_grad():
        for batch in tqdm(eval_loader, position = 0, leave = True):
            if not debias:
                logits = model(batch['image'].to(device), batch['text'])
                loss = criterion(logits, batch['label'].to(device))

                try:
                    batch_auc = roc_auc_score(batch['label'].to('cpu'), F.softmax(logits, dim = 1)[:, 1].to('cpu'))
                    count += 1
                except:
                    batch_auc = 0
            else:
                indices = batch['index']
                groups = ['Black/African Descent Counterfactual',
                        'Asian Counterfactual',
                        'White/Caucasian Counterfactual',
                        'Hispanic/Latino Counterfactual',
                        'Other Races Counterfactual']

                loss = 0
                batch_auc = 0
                for group in groups: # if group is same as origin, then will match w/ itself so for |groups| = 1, reduces to non-debiased case

                    # 1. get counterfactual report for each group
                    counterfactual_indices = []
                    for index in indices:
                        counterfactual_indices.append(counterfactual_df['val_df'].loc[index.item(), group])

                    # 2. build counterfactual batch B x 1 x 224 x 224
                    images = []
                    for index in counterfactual_indices:
                        image = train_transforms(
                            transforms.Resize([IMG_SIZE, IMG_SIZE])(read_image(counterfactual_df['val_df'].loc[index, 'img_path']))
                        )
                        images.append(image.repeat(1, 3, 1, 1))
                    images = torch.stack(images).reshape(batch['image'].shape)

                    # 3. get counterfactual loss
                    logits = model(images.to(device), batch['text'])

                    # 4. aggregate counterfactual losses using weights
                    group_loss = criterion(logits, batch['label'].to(device))
                    loss += group_loss * counterfactual_weights[group]

                    try:
                        batch_auc += roc_auc_score(batch['label'].to('cpu'), F.softmax(logits, dim = 1)[:, 1].to('cpu'))
                        count += 1
                    except:
                        batch_auc += 0
                    del images, logits
                    torch.cuda.empty_cache()

            total_loss += loss.item()
            auc += batch_auc
    return total_loss / len(eval_loader), auc/count

In [None]:
def train(model, train_loader, val_loader, optimizer, criterion, epochs, debias = False, counterfactual_df = None, counterfactual_weights = None, save_path = "./model.pt"):
    best_val = 0
    for epoch in tqdm(range(epochs), position = 0, leave = True):
        train_loss = train_step(model, train_loader, optimizer, criterion, epoch, debias, counterfactual_df, counterfactual_weights)
        val_loss, val_auc = infer_step(model, val_loader, criterion, epoch, debias, counterfactual_df, counterfactual_weights)
        print(f"Epoch {epoch}: train loss {train_loss}, eval loss {val_loss}, eval auc {val_auc}")

        if val_auc >= best_val:
            print(f"Updating best model at epoch {epoch}")
            best_val = val_auc
            best_model = copy.deepcopy(model)
            torch.save(model.state_dict(), save_path)
    return best_model


In [None]:
model = FusionModel(resnet, bert, tokenizer).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-5)
criterion = nn.CrossEntropyLoss()
mm_net = train(model, train_loader, val_loader, optimizer, criterion, 15, save_path = f"/content/drive/MyDrive/harvard/6.8610/project/weights/fusion_n{N}.pt") # no debiasing

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/467 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 0: train loss 0.24155575685338684, eval loss 0.0924852819275111, eval auc 0.978125
Updating best model at epoch 0


  0%|          | 0/467 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1: train loss 0.06586148742809732, eval loss 0.0549378450633958, eval auc 0.9917613636363637
Updating best model at epoch 1


  0%|          | 0/467 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 2: train loss 0.043303034179574734, eval loss 0.06484663819079288, eval auc 0.9931818181818183
Updating best model at epoch 2


  0%|          | 0/467 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 3: train loss 0.029900040497955502, eval loss 0.07206128865596838, eval auc 0.9963068181818181
Updating best model at epoch 3


  0%|          | 0/467 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 4: train loss 0.019516722957640526, eval loss 0.05883740985125769, eval auc 0.9977272727272727
Updating best model at epoch 4


  0%|          | 0/467 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 5: train loss 0.0074704479884452986, eval loss 0.09181046070356387, eval auc 0.9940340909090909


  0%|          | 0/467 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 6: train loss 0.007325967392453931, eval loss 0.055510259942093396, eval auc 0.9964646464646464


  0%|          | 0/467 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 7: train loss 0.004589997847814799, eval loss 0.07693920117235394, eval auc 0.993939393939394


  0%|          | 0/467 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 8: train loss 0.0016472563573196172, eval loss 0.05728195048395719, eval auc 0.9964646464646464


  0%|          | 0/467 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 9: train loss 0.0005031963279351145, eval loss 0.06190289822592604, eval auc 0.9964646464646464


  0%|          | 0/467 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 10: train loss 0.0002016485548283799, eval loss 0.06855436816298606, eval auc 0.9964646464646464


  0%|          | 0/467 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 11: train loss 0.00013166166953943178, eval loss 0.06946638467656158, eval auc 0.9964646464646464


  0%|          | 0/467 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 12: train loss 8.261707837045756e-05, eval loss 0.07414491315597843, eval auc 0.9964646464646464


  0%|          | 0/467 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 13: train loss 6.330895405548138e-05, eval loss 0.072987577621343, eval auc 0.9964646464646464


  0%|          | 0/467 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 14: train loss 4.540612756030841e-05, eval loss 0.07799630282602266, eval auc 0.9964646464646464


In [39]:
model = FusionModel(resnet, bert, tokenizer).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-5)
criterion = nn.CrossEntropyLoss()
cf_df_dict = {
    'train_df': train_cf_df,
    'val_df': val_cf_df,
}
mm_net_debiased = train(model, train_loader, val_loader, optimizer, criterion, 15, True, cf_df_dict, weights, f"/content/drive/MyDrive/harvard/6.8610/project/weights/debiased_fusion_n{N}.pt")

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/467 [00:00<?, ?it/s]



  0%|          | 0/100 [00:00<?, ?it/s]



Epoch 0: train loss 2.096703511600959, eval loss 3.839712303876877, eval auc 0.9974747474747475
Updating best model at epoch 0


  0%|          | 0/467 [00:00<?, ?it/s]



  0%|          | 0/100 [00:00<?, ?it/s]



Epoch 1: train loss 0.10557481759014425, eval loss 5.764085946567357, eval auc 0.9962121212121213


  0%|          | 0/467 [00:00<?, ?it/s]



  0%|          | 0/100 [00:00<?, ?it/s]



Epoch 2: train loss 0.6047993024543917, eval loss 4.550201182365417, eval auc 0.9977272727272727
Updating best model at epoch 2


  0%|          | 0/467 [00:00<?, ?it/s]



  0%|          | 0/100 [00:00<?, ?it/s]



Epoch 3: train loss 0.09406592686136096, eval loss 3.9605950271338224, eval auc 1.0
Updating best model at epoch 3


  0%|          | 0/467 [00:00<?, ?it/s]



  0%|          | 0/100 [00:00<?, ?it/s]



Epoch 4: train loss 0.02049112623473689, eval loss 4.379624453852885, eval auc 1.0
Updating best model at epoch 4


  0%|          | 0/467 [00:00<?, ?it/s]



  0%|          | 0/100 [00:00<?, ?it/s]



Epoch 5: train loss 0.12766854112620846, eval loss 15.874561543026939, eval auc 0.994760101010101


  0%|          | 0/467 [00:00<?, ?it/s]



  0%|          | 0/100 [00:00<?, ?it/s]



Epoch 6: train loss 1.6578650065962968, eval loss 2.69776316517964, eval auc 0.9977272727272727


  0%|          | 0/467 [00:00<?, ?it/s]



  0%|          | 0/100 [00:00<?, ?it/s]



Epoch 7: train loss 0.06291841233161137, eval loss 2.510889150155708, eval auc 0.9977272727272727


  0%|          | 0/467 [00:00<?, ?it/s]



  0%|          | 0/100 [00:00<?, ?it/s]



Epoch 8: train loss 0.01610418662779847, eval loss 2.666053144931793, eval auc 0.9977272727272727


  0%|          | 0/467 [00:00<?, ?it/s]



  0%|          | 0/100 [00:00<?, ?it/s]



Epoch 9: train loss 0.010927840104083306, eval loss 2.7782663804036565, eval auc 0.9977272727272727


  0%|          | 0/467 [00:00<?, ?it/s]



  0%|          | 0/100 [00:00<?, ?it/s]



Epoch 10: train loss 0.007438562781504122, eval loss 2.9228328844555653, eval auc 0.9977272727272727


  0%|          | 0/467 [00:00<?, ?it/s]



  0%|          | 0/100 [00:00<?, ?it/s]



KeyboardInterrupt: ignored

## **Evaluation**

Evaluation options with or without matching

In [40]:
model_ndb = FusionModel(resnet, bert, tokenizer).to(device)
model_ndb.load_state_dict(torch.load(f"/content/drive/MyDrive/harvard/6.8610/project/weights/fusion_n{N}.pt"))

<All keys matched successfully>

In [41]:
model_db = FusionModel(resnet, bert, tokenizer).to(device)
model_db.load_state_dict(torch.load(f"/content/drive/MyDrive/harvard/6.8610/project/weights/debiased_fusion_n{N}.pt"))

<All keys matched successfully>

In [42]:
def get_preds(model, test_loader):
    """
    Get predictions
    """
    indices = []
    preds = []
    labels = []
    with torch.no_grad():
        for batch in tqdm(test_loader, position = 0, leave = True):
            logits = model(batch['image'].to(device), batch['text'])
            preds.append(torch.argmax(logits, dim = 1).cpu().numpy())
            labels.append(batch['label'].numpy())
            indices.append(batch['index'].numpy())

    indices = np.concatenate(indices)
    preds = np.concatenate(preds)
    labels = np.concatenate(labels)

    return indices, preds, labels

def get_metrics(preds, labels):
    """
    Get metrics
    """

    accuracy = accuracy_score(labels, preds)
    precision = precision_score(labels, preds)
    recall = recall_score(labels, preds)
    f1 = f1_score(labels, preds)
    aucroc = roc_auc_score(labels, preds)
    confm = confusion_matrix(labels, preds, normalize="true")

    print(f"Accuracy: {accuracy}")
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    print(f"F1: {f1}")
    print(f"AUC-ROC: {aucroc}")
    print(f"Confusion Matrix: {confm}")

    return accuracy, precision, recall, f1, aucroc, confm

In [43]:
# not debiased
test_indices, test_preds, test_labels = get_preds(model_ndb, test_loader)
print(sum(test_indices == test_df.index.values)) # making sure order is same
test_df['pred'] = test_preds # assigning predicted labels

  0%|          | 0/100 [00:00<?, ?it/s]

600


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  test_df['pred'] = test_preds # assigning predicted labels


In [44]:
for group in np.unique(test_cf_df.general_race_group.values):
    print(f"########### {group} ###########")
    get_metrics(test_df[test_df['general_race_group'] == group]['pred'].values, test_df[test_df['general_race_group'] == group]['label'].values)

########### Asian ###########
Accuracy: 0.9655172413793104
Precision: 1.0
Recall: 0.95
F1: 0.9743589743589743
AUC-ROC: 0.975
Confusion Matrix: [[1.   0.  ]
 [0.05 0.95]]
########### Black/African Descent ###########
Accuracy: 0.9873417721518988
Precision: 1.0
Recall: 0.9814814814814815
F1: 0.9906542056074767
AUC-ROC: 0.9907407407407407
Confusion Matrix: [[1.         0.        ]
 [0.01851852 0.98148148]]
########### Hispanic/Latino ###########
Accuracy: 1.0
Precision: 1.0
Recall: 1.0
F1: 1.0
AUC-ROC: 1.0
Confusion Matrix: [[1. 0.]
 [0. 1.]]
########### Other Races ###########
Accuracy: 1.0
Precision: 1.0
Recall: 1.0
F1: 1.0
AUC-ROC: 1.0
Confusion Matrix: [[1. 0.]
 [0. 1.]]
########### White/Caucasian ###########
Accuracy: 0.9836065573770492
Precision: 0.9933554817275747
Recall: 0.9835526315789473
F1: 0.9884297520661156
AUC-ROC: 0.9836462344886607
Confusion Matrix: [[0.98373984 0.01626016]
 [0.01644737 0.98355263]]


In [45]:
# debiased
test_indices, test_preds, test_labels = get_preds(model_db, test_loader)
print(sum(test_indices == test_df.index.values)) # making sure order is same
test_df['db_pred'] = test_preds # assigning predicted labels

  0%|          | 0/100 [00:00<?, ?it/s]

600


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  test_df['db_pred'] = test_preds # assigning predicted labels


In [46]:
for group in np.unique(test_cf_df.general_race_group.values):
    print(f"########### {group} ###########")
    get_metrics(test_df[test_df['general_race_group'] == group]['db_pred'].values, test_df[test_df['general_race_group'] == group]['label'].values)

########### Asian ###########
Accuracy: 1.0
Precision: 1.0
Recall: 1.0
F1: 1.0
AUC-ROC: 1.0
Confusion Matrix: [[1. 0.]
 [0. 1.]]
########### Black/African Descent ###########
Accuracy: 1.0
Precision: 1.0
Recall: 1.0
F1: 1.0
AUC-ROC: 1.0
Confusion Matrix: [[1. 0.]
 [0. 1.]]
########### Hispanic/Latino ###########
Accuracy: 1.0
Precision: 1.0
Recall: 1.0
F1: 1.0
AUC-ROC: 1.0
Confusion Matrix: [[1. 0.]
 [0. 1.]]
########### Other Races ###########
Accuracy: 1.0
Precision: 1.0
Recall: 1.0
F1: 1.0
AUC-ROC: 1.0
Confusion Matrix: [[1. 0.]
 [0. 1.]]
########### White/Caucasian ###########
Accuracy: 0.990632318501171
Precision: 0.993421052631579
Recall: 0.993421052631579
F1: 0.993421052631579
AUC-ROC: 0.9885804450149765
Confusion Matrix: [[0.98373984 0.01626016]
 [0.00657895 0.99342105]]
