# HW4 CAML

## Overview

In this question, we will implement Convolutional Attention for Multi-Label classification (CAML) proposed by Mullenbach et al. in the paper "[Explainable Prediction of Medical Codes from Clinical Text](https://www.aclweb.org/anthology/N18-1100/)".

Clinical notes are text documents that are created by clinicians for each patient encounter. They are typically accompanied by medical codes, which describe the diagnosis and treatment. Annotating these codes is labor intensive and error prone; furthermore, the connection between the codes and the text is not annotated, obscuring the reasons and details behind specific diagnoses and treatments. Thus, let us implement CAML, an attentional convolutional network to predict medical codes from clinical text.

<img src='img/clinical notes.png'>

Image courtesy: [link](https://www.aclweb.org/anthology/2020.acl-demos.33/)

In [1]:
import os
import csv
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from torch.utils.data import DataLoader


In [2]:
# set seed
seed = 24
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

# set up your own data path !!!
DATA_PATH = "IU-chest-Xray-dataset"

## Dataset

For this question, we will be using the Indiana University Chest X-Ray dataset. The goal is to predict diseases using chest x-ray reports.

The data folder `DATA_PATH` contains several files:

- `train_df.csv`, `test_df.csv`: these two files contain the data used for training and testing.
    - `Report ID` refers to a unique chest x-ray report.
    - `Text` refers to the clinical report text.
    - `Label` refers to the diseases.
- `vocab.csv`: this file contains the vocabulary used in the clinical text.

In [3]:
!ls {DATA_PATH}

test_df.csv  train_df.csv  vocab.csv


For example, the first chest x-ray report in `train_df.csv` has:
- `Report ID`: 1
- `Text`: the cardiac silhouette and mediastinum size are within normal limits . there is no pulmonary edema . there is no focal consolidation . there are no xxxx of a pleural effusion . there is no evidence of pneumothorax . normal chest xxxxx .
- `Label`: [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

where label is a multi-hot vector representing the following diseases:
```
normal
cardiomegaly
scoliosis / degenerative
fractures bone
pleural effusion
thickening
pneumothorax
hernia hiatal
calcinosis
emphysema / pulmonary emphysema
pneumonia / infiltrate / consolidation
pulmonary edema
pulmonary atelectasis
cicatrix
opacity
nodule / mass
airspace disease
hypoinflation / hyperdistention
catheters indwelling / surgical instruments / tube inserted / medical device
other
```

So this report 1 is labeled as "normal".
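
As a minimal sketch of how such a multi-hot vector can be read, the snippet below maps each position of the vector back to a disease name. It assumes that the positions of the label vector follow the order of the list above; `LABEL_NAMES` and `decode_label` are hypothetical names introduced only for this illustration.

```python
# Label names, assumed to be in the same order as the positions of the label vector.
LABEL_NAMES = [
    "normal",
    "cardiomegaly",
    "scoliosis / degenerative",
    "fractures bone",
    "pleural effusion",
    "thickening",
    "pneumothorax",
    "hernia hiatal",
    "calcinosis",
    "emphysema / pulmonary emphysema",
    "pneumonia / infiltrate / consolidation",
    "pulmonary edema",
    "pulmonary atelectasis",
    "cicatrix",
    "opacity",
    "nodule / mass",
    "airspace disease",
    "hypoinflation / hyperdistention",
    "catheters indwelling / surgical instruments / tube inserted / medical device",
    "other",
]


def decode_label(label):
    """Return the names of the diseases whose entry in the multi-hot vector is 1."""
    return [name for name, flag in zip(LABEL_NAMES, label) if flag == 1]


report_1 = [1] + [0] * 19       # the label of report 1 shown above
print(decode_label(report_1))   # ['normal']
```

A report with several findings simply has several entries set to 1, so `decode_label` may return more than one name.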

## 1 Prepare the Dataset

### 1.1 Helper Functions

To begin with, let us implement some helper functions that we will use later.

In [4]:
def to_index(sequence, token2idx):
    """
    Convert the sequence of tokens to indices.
    If a word is unknown, map it to '<unk>'.
    
    INPUT:
        sequence (type: list of str): a sequence of tokens
        token2idx (type: dict): a dictionary mapping token to the corresponding index
    
    OUTPUT:
        indices (type: list of int): a sequence of indices
        
    EXAMPLE:
        >>> sequence = ['hello', 'world', 'unknown_word']
        >>> token2idx = {'hello': 0, 'world': 1, '<unk>': 2}
        >>> to_index(sequence, token2idx)
        [0, 1, 2]
    """
    return [token2idx[w] if w in token2idx else token2idx['<unk>'] for w in sequence]

### 1.2 CustomDataset [10 points]

Now, let us implement a custom dataset using PyTorch class `Dataset`, which will characterize the key features of the dataset we want to generate.

We will use the clinical text as input and medical codes as output.

In [5]:
from torch.utils.data import Dataset

NUM_WORDS = 1253
NUM_CLASSES = 20


class CustomDataset(Dataset):
    
    def __init__(self, filename):        
        # read in the data files
        self.data = pd.read_csv(filename)
        # load word lookup
        self.idx2word, self.word2idx = self.load_lookup(f'{DATA_PATH}/vocab.csv', padding=True)
        assert len(self.idx2word) == len(self.word2idx) == NUM_WORDS
        
    def load_lookup(self, filename, padding=False):
        """ load lookup for word """
        idx2token = {}
        with open(filename, 'r') as f:
            for i, line in enumerate(f):
                line = line.strip()
                idx2token[i] = line
        token2idx = {w:i for i,w in idx2token.items()}
        return idx2token, token2idx
        
    def __len__(self):
        
        """
        TODO: Return the number of samples (i.e. reports).
        """
        
        ### BEGIN SOLUTION
        return len(self.data)
        ### END SOLUTION
    
    def __getitem__(self, index):
        
        """
        TODO: Generate one sample of data.

        Hint: convert text to indices using to_index();
        """
        data = self.data.iloc[index]
        text = data['Text'].split(' ')
        label = data['Label']
        # convert label string to list
        label = [int(l) for l in label.strip('[]').split(', ')]
        assert len(label) == NUM_CLASSES
        ### BEGIN SOLUTION
        text = to_index(text, self.word2idx)
        ### END SOLUTION
        # return text as long tensor, labels as float tensor;
        return torch.tensor(text, dtype=torch.long), torch.tensor(label, dtype=torch.float)

### 1.3 Collate Function [10 points]

The collate function `collate_fn()` will be called by `DataLoader` after fetching a list of samples using the indices from `CustomDataset` to collate the list of samples into batches.

For example, assume the `DataLoader` gets a list of two samples.

```
[ [3,  1,  2, 8, 5], 
  [12, 13, 6, 7, 12, 23, 11] ]
```

where the first sample has text `[3, 1, 2, 8, 5]` and the second sample has text `[12, 13, 6, 7, 12, 23, 11]`.

The collate function `collate_fn()` is supposed to pad them into the same shape (7), where 7 is the maximum number of tokens.

```
[ [3,  1,  2, 8, 5, *0*, *0*], 
  [12, 13, 6, 7, 12, 23,  11 ] ]
```

where `*0*` indicates the padding token.

We need to pad the sequences to the same length so that they can be stacked into a single tensor and trained in batches on a GPU. The padded values do not carry any information: `pad_sequence()` pads with 0 by default, which matches the `padding_idx=0` of the embedding layer used in the model below.
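
The padding step described above can be sketched in plain Python as follows. This is only an illustration on lists: `pad_batch` is a hypothetical helper, not part of the assignment, and the actual `collate_fn()` is expected to use `pad_sequence()` on tensors instead.

```python
def pad_batch(sequences, pad_value=0):
    """Right-pad every sequence with pad_value up to the longest length in the batch."""
    max_len = max(len(seq) for seq in sequences)
    return [seq + [pad_value] * (max_len - len(seq)) for seq in sequences]


batch = [[3, 1, 2, 8, 5], [12, 13, 6, 7, 12, 23, 11]]
padded = pad_batch(batch)
print(padded)
# [[3, 1, 2, 8, 5, 0, 0], [12, 13, 6, 7, 12, 23, 11]]
```

Note that the padded length depends on the longest sequence in each batch, so different batches may have different widths.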

In [6]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(data):
    """
    TODO: implement the collate function.
    
    STEP: 1. pad the text using pad_sequence(). Set `batch_first=True`.
          2. stack the labels using torch.stack().
          
    OUTPUT:
        text: the padded text, shape: (batch size, max length)
        labels: the stacked labels, shape: (batch size, num classes)
    """
    text, labels = zip(*data)

    ### BEGIN SOLUTION
    text = pad_sequence(text, batch_first=True)
    labels = torch.stack(labels, dim=0)
    ### END SOLUTION
    
    return text, labels

All done, now let us load the dataset and data loader.

In [7]:
train_set = CustomDataset(f'{DATA_PATH}/train_df.csv')
test_set = CustomDataset(f'{DATA_PATH}/test_df.csv')
train_loader = DataLoader(train_set, batch_size=32, collate_fn=collate_fn, shuffle=True)
test_loader = DataLoader(test_set, batch_size=32, collate_fn=collate_fn)

## 2 Model

Next, we will implement the CAML model.

<img src='img/caml.png'>

CAML is a convolutional neural network (CNN)-based model. It employs a per-label attention mechanism, which allows the model to learn distinct document representations for each label.
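
To make the per-label attention idea concrete, here is a small pure-Python sketch on toy numbers. For each label it scores every position against that label's context vector, normalizes the scores with a softmax over the positions, and takes the weighted sum of the position features. The names `softmax`, `per_label_attention`, `H`, and `U` and all values are made up for illustration; the homework model performs the same steps with batched tensor operations.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def per_label_attention(H, U):
    """
    H: seq_len x d list of position features (e.g. the convolution output).
    U: num_labels x d list of label context vectors.
    Returns (alpha, V): attention weights (num_labels x seq_len) and
    per-label document vectors (num_labels x d).
    """
    alpha, V = [], []
    for u in U:
        # Score each position against this label's context vector.
        scores = [sum(ui * hi for ui, hi in zip(u, h)) for h in H]
        a = softmax(scores)  # normalize over positions, not over labels
        # Attention-weighted sum of the position features.
        v = [sum(a_n * h[j] for a_n, h in zip(a, H)) for j in range(len(H[0]))]
        alpha.append(a)
        V.append(v)
    return alpha, V


# Toy input: 3 positions, 2 features per position, 2 labels.
H = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
U = [[2.0, 0.0], [0.0, 2.0]]
alpha, V = per_label_attention(H, U)
for row in alpha:
    print([round(a, 3) for a in row])
```

Each label ends up with its own attention distribution and therefore its own document vector, which is what lets the classifier for one label focus on different parts of the text than the classifier for another.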

In [8]:
from math import floor
from torch.nn.init import xavier_uniform_


class CAML(nn.Module):

    def __init__(self, kernel_size=10, num_filter_maps=16, embed_size=100, dropout=0.5):
        super(CAML, self).__init__()
        
        # embedding layer
        self.embed = nn.Embedding(NUM_WORDS, embed_size, padding_idx=0)
        self.embed_drop = nn.Dropout(p=dropout)

        # initialize conv layer as in section 2.1
        self.conv = nn.Conv1d(embed_size, num_filter_maps, kernel_size=kernel_size, padding=int(floor(kernel_size/2)))
        xavier_uniform_(self.conv.weight)

        # context vectors for computing attention as in section 2.2
        self.U = nn.Linear(num_filter_maps, NUM_CLASSES)
        xavier_uniform_(self.U.weight)

        # final layer: create a matrix to use for the NUM_CLASSES binary classifiers as in section 2.3
        self.final = nn.Linear(num_filter_maps, NUM_CLASSES)
        xavier_uniform_(self.final.weight)
        
    def forward_embed(self, text):
        """
        TODO: Feed text through the embedding (self.embed) and dropout layer (self.embed_drop).
        
        INPUT: 
            text: (batch size, seq_len)
            
        OUTPUT:
            text: (batch size, seq_len, embed_size)
        """
        ### BEGIN SOLUTION
        text = self.embed(text)
        text = self.embed_drop(text)
        return text
        ### END SOLUTION
        
    def forward_conv(self, text):
        """
        TODO: Feed text through the convolution layer (self.conv) and tanh activation function (torch.tanh) 
        in eq (1) in the paper.
        
        INPUT:
            text: (batch size, embed_size, seq_len)
            
        OUTPUT:
            text: (batch size, num_filter_maps, seq_len)
        """
        ### BEGIN SOLUTION
        return torch.tanh(self.conv(text))
        ### END SOLUTION
        
    def forward_calc_atten(self, text):
        """
        TODO: calculate the attention weights in eq (2) in the paper. Be sure to read the documentation for
        F.softmax()
        
        INPUT:
            text: (batch size, seq_len, num_filter_maps)

        OUTPUT:
            alpha: (batch size, num_class, seq_len), the attention weights
            
        STEP: 1. multiply `self.U.weight` with `text` using torch.matmul();
              2. apply softmax using `F.softmax()`.
        """
        # (batch size, seq_len, num_filter_maps) -> (batch size, num_filter_maps, seq_len)
        text = text.transpose(1,2)
        ### BEGIN SOLUTION
        return F.softmax(self.U.weight.matmul(text), dim=2)
        ### END SOLUTION
        
    def forward_aply_atten(self, alpha, text):
        """
        TODO: apply the attention in eq (3) in the paper.

        INPUT: 
            text: (batch size, seq_len, num_filter_maps)
            alpha: (batch size, num_class, seq_len), the attention weights
            
        OUTPUT:
            v: (batch size, num_class, num_filter_maps), vector representations for each label
            
        STEP: multiply `alpha` with `text` using torch.matmul().
        """
        ### BEGIN SOLUTION
        return alpha.matmul(text)
        ### END SOLUTION
    
    def forward_linear(self, v):
        """
        TODO: apply the final linear classification in eq (5) in the paper.
        
        INPUT: 
            v: (batch size, num_class, num_filter_maps), vector representations for each label
            
        OUTPUT:
            y_hat: (batch size, num_class), label probability
            
        STEP: 1. multiply `self.final.weight` with `v` element-wise using torch.mul();
              2. sum the result over dim 2 (i.e. num_filter_maps);
              3. add the result with `self.final.bias`;
              4. apply sigmoid with torch.sigmoid().
        """
        ### BEGIN SOLUTION
        y = self.final.weight.mul(v).sum(dim=2).add(self.final.bias)
        y_hat = torch.sigmoid(y)
        return y_hat
        ### END SOLUTION
        
    def forward(self, text):
        """ 1. get embeddings and apply dropout """
        text = self.forward_embed(text)
        # (batch size, seq_len, embed_size) -> (batch size, embed_size, seq_len);
        text = text.transpose(1, 2)

        """ 2. apply convolution and nonlinearity (tanh) """
        text = self.forward_conv(text)
        # (batch size, num_filter_maps, seq_len) -> (batch size, seq_len, num_filter_maps);
        text = text.transpose(1,2)
        
        """ 3. calculate attention """
        alpha = self.forward_calc_atten(text)
        
        """ 4. apply attention """
        v = self.forward_aply_atten(alpha, text)
           
        """ 5. final layer classification """
        y_hat = self.forward_linear(v)
        
        return y_hat
    
    
model = CAML()

## 3 Training and Inference

In [9]:
model = CAML()

In [10]:
import torch.optim as optim

optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.BCELoss()

Now let us implement the `eval()` and `train()` functions. Note that `train()` should call `eval()` at the end of each training epoch to report the results on the validation dataset.
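
Since `eval()` reports micro-averaged scores, the following pure-Python sketch shows what micro-averaging means for multi-hot labels: true positives, false positives, and false negatives are pooled over all samples and all labels before precision, recall, and F1 are computed. `micro_prf` is a hypothetical helper written for illustration; the assignment itself uses `precision_recall_fscore_support` from scikit-learn. Returning 0.0 when a denominator is zero is a choice made in this sketch.

```python
def micro_prf(y_true, y_pred):
    """Micro-averaged precision, recall and F1 for multi-hot label matrices."""
    tp = fp = fn = 0
    for true_row, pred_row in zip(y_true, y_pred):
        for t, p in zip(true_row, pred_row):
            tp += int(t == 1 and p == 1)
            fp += int(t == 0 and p == 1)
            fn += int(t == 1 and p == 0)
    # Return 0.0 when a denominator is zero (e.g. the model predicts no positives).
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# Two samples, three labels: one positive label is missed by the prediction.
y_true = [[1, 0, 0], [0, 1, 1]]
y_pred = [[1, 0, 0], [0, 1, 0]]
print(micro_prf(y_true, y_pred))
```

In this toy case every predicted label is correct (precision 1.0) but one of the three true labels is missed (recall 2/3). A model that predicts no positive labels at all gets 0.0 for all three scores under this convention.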

In [11]:
from sklearn.metrics import precision_recall_fscore_support


def eval(model, test_loader):
    
    """    
    INPUT:
        model: the CAML model
        test_loader: dataloader
        
    OUTPUT:
        precision: overall micro precision score
        recall: overall micro recall score
        f1: overall micro f1 score
        
    REFERENCE: checkout https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics
    """

    model.eval()
    y_pred = torch.LongTensor()
    y_true = torch.LongTensor()
    for sequences, labels in test_loader:
        """
        TODO: 1. perform forward pass
              2. obtain the predicted class (0, 1) by comparing forward pass output against 0.5, 
                 assign the predicted class to y_hat.
        """
        ### BEGIN SOLUTION
        y_hat = model(sequences)
        y_hat = (y_hat > 0.5).int()
        ### END SOLUTION
        y_pred = torch.cat((y_pred,  y_hat.detach().to('cpu')), dim=0)
        y_true = torch.cat((y_true, labels.detach().to('cpu')), dim=0)
    
    p, r, f, _ = precision_recall_fscore_support(y_true, y_pred, average='micro')
    return p, r, f

In [12]:
def train(model, train_loader, test_loader, n_epochs):
    """    
    INPUT:
        model: the CAML model
        train_loader: training dataloader
        test_loader: dataloader used for validation after each epoch
        n_epochs: total number of epochs
    """
    for epoch in range(n_epochs):
        model.train()
        train_loss = 0
        for sequences, labels in train_loader:
            optimizer.zero_grad()
            """ 
            TODO: 1. perform forward pass using `model`, save the output to y_hat;
                  2. calculate the loss using `criterion`, save the output to loss.
            """
            y_hat, loss = None, None
            ### BEGIN SOLUTION
            y_hat = model(sequences)
            loss = criterion(y_hat, labels)
            ### END SOLUTION
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss = train_loss / len(train_loader)
        print('Epoch: {} \t Training Loss: {:.6f}'.format(epoch+1, train_loss))
        p, r, f = eval(model, test_loader)
        print('Epoch: {} \t Validation p: {:.2f}, r:{:.2f}, f: {:.2f}'.format(epoch+1, p, r, f))

    
# number of epochs to train the model
n_epochs = 20

train(model, train_loader, test_loader, n_epochs)

Epoch: 1 	 Training Loss: 0.457061
Epoch: 1 	 Validation p: 0.00, r:0.00, f: 0.00
Epoch: 2 	 Training Loss: 0.262114
Epoch: 2 	 Validation p: 0.00, r:0.00, f: 0.00


Epoch: 3 	 Training Loss: 0.232308
Epoch: 3 	 Validation p: 0.00, r:0.00, f: 0.00


Epoch: 4 	 Training Loss: 0.217560
Epoch: 4 	 Validation p: 0.83, r:0.15, f: 0.26
Epoch: 5 	 Training Loss: 0.203588
Epoch: 5 	 Validation p: 0.81, r:0.22, f: 0.35
Epoch: 6 	 Training Loss: 0.193507
Epoch: 6 	 Validation p: 0.80, r:0.24, f: 0.37
Epoch: 7 	 Training Loss: 0.185501
Epoch: 7 	 Validation p: 0.83, r:0.25, f: 0.39
Epoch: 8 	 Training Loss: 0.180506
Epoch: 8 	 Validation p: 0.82, r:0.28, f: 0.42
Epoch: 9 	 Training Loss: 0.174325
Epoch: 9 	 Validation p: 0.83, r:0.28, f: 0.42
Epoch: 10 	 Training Loss: 0.166225
Epoch: 10 	 Validation p: 0.86, r:0.32, f: 0.47
Epoch: 11 	 Training Loss: 0.157288
Epoch: 11 	 Validation p: 0.89, r:0.40, f: 0.55
Epoch: 12 	 Training Loss: 0.148467
Epoch: 12 	 Validation p: 0.90, r:0.43, f: 0.58
Epoch: 13 	 Training Loss: 0.144797
Epoch: 13 	 Validation p: 0.89, r:0.47, f: 0.61
Epoch: 14 	 Training Loss: 0.136517
Epoch: 14 	 Validation p: 0.89, r:0.49, f: 0.64
Epoch: 15 	 Training Loss: 0.131915
Epoch: 15 	 Validation p: 0.90, r:0.54, f: 0.68
Epoc