In [None]:
!pip install causal-conv1d>=1.2.0

!pip install mamba-ssm
!pip install pandas

Collecting mamba-ssm
  Downloading mamba_ssm-1.2.0.post1.tar.gz (34 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting einops (from mamba-ssm)
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
Building wheels for collected packages: mamba-ssm
  Building wheel for mamba-ssm (setup.py) ... [?25l[?25hdone
  Created wheel for mamba-ssm: filename=mamba_ssm-1.2.0.post1-cp310-cp310-linux_x86_64.whl size=137750683 sha256=b264292652a34fb9dd0ce880a34a4407ba7256a3338388d056769ec29a4581c9
  Stored in directory: /root/.cache/pip/wheels/22/6e/60/ddd5c574b5793a30028f2cabdacd2a3ec2276edaaa8c00fd35
Successfully built mamba-ssm
Installing collected packages: einops, mamba-ssm
Successfully installed einops-0.7.0 mamba-ssm-1.2.0.post1


In [None]:
# Import necessary libraries
from google.colab import drive
import json
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from datetime import datetime
from mamba_ssm import Mamba
from sklearn.model_selection import train_test_split

# Mount Google Drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Read the patient records (a dict keyed by patient ID) from the JSON file on Drive
dataset_path = '/content/drive/MyDrive/S24-11-785/mimic_data.json'

with open(dataset_path, 'r') as file:
    data = json.load(file)

# Show the first few entries to visualize the dataset structure.
# Note: each patient ID has a (widely) varying number of associated keys
for patient_id in list(data)[:15]:
    print(patient_id, data[patient_id])

16252972 {'birthdate': '2041-01-01', 'end_of_data': '2120-12-15', 'events': [{'admid': '29532002', 'admdate': '2120-12-12', 'codes': 'C250'}, {'admid': '29532002', 'admdate': '2120-12-12', 'codes': 'K831'}, {'admid': '29532002', 'admdate': '2120-12-12', 'codes': 'K8590'}, {'admid': '29532002', 'admdate': '2120-12-12', 'codes': 'I130'}, {'admid': '29532002', 'admdate': '2120-12-12', 'codes': 'E46'}, {'admid': '29532002', 'admdate': '2120-12-12', 'codes': 'Z6841'}, {'admid': '29532002', 'admdate': '2120-12-12', 'codes': 'I509'}, {'admid': '29532002', 'admdate': '2120-12-12', 'codes': 'N189'}, {'admid': '29532002', 'admdate': '2120-12-12', 'codes': 'Z66'}, {'admid': '29532002', 'admdate': '2120-12-12', 'codes': 'G3184'}, {'admid': '29532002', 'admdate': '2120-12-12', 'codes': 'I872'}, {'admid': '29532002', 'admdate': '2120-12-12', 'codes': 'E785'}, {'admid': '29532002', 'admdate': '2120-12-12', 'codes': 'E039'}, {'admid': '29532002', 'admdate': '2120-12-12', 'codes': 'R8299'}, {'admid': '

In [None]:
# Load medical code dictionaries

codes1path = '/content/drive/MyDrive/S24-11-785/icd8_disease_descriptions.tsv'
codes2path = '/content/drive/MyDrive/S24-11-785/icd9_disease_descriptions.tsv'
codes3path = '/content/drive/MyDrive/S24-11-785/icd10_disease_descriptions.tsv'

# The two special tokens take the first ids; every later code gets the next free id,
# which is always len(codes_dict) because ids are handed out consecutively from 0.
codes_dict = {'<pad>': 0, '<cls>': 1}
PAD_TOKEN = codes_dict['<pad>']
CLS_TOKEN = codes_dict['<cls>']


def _add_codes_from_tsv(path, vocab, skip_header):
    """Register the first tab-separated column of each row of ``path`` in ``vocab``.

    Args:
        path: location of the TSV file to read.
        vocab: dict mapping code -> integer id; updated in place. Codes that are
            already present keep their id.
        skip_header: when True the first row of the file is ignored.
    """
    with open(path, 'r') as f:
        for row_idx, line in enumerate(f):
            if skip_header and row_idx == 0:
                continue
            code = line.strip().split('\t')[0]
            if code not in vocab:
                vocab[code] = len(vocab)


# NOTE(review): the ICD-8 file is read from its very first row while the first row of
# the ICD-9 / ICD-10 files is skipped -- confirm the ICD-8 file really has no header.
for path, skip_header in ((codes1path, False), (codes2path, True), (codes3path, True)):
    _add_codes_from_tsv(path, codes_dict, skip_header)

In [None]:
class HealthRecordDataset(Dataset):
    """Turns per-patient event records into token sequences with multi-horizon labels.

    For every patient the events recorded before the first pancreatic-cancer code are
    converted to integer code ids (followed by a trailing ``<cls>`` id). Five binary
    labels say whether the cancer date falls within 2, 6, 12, 36 and 60 months
    (30-day months) of the last event used as a feature. Patients with no event
    before the first cancer code are skipped.
    """

    # these are the codes that CancerRiskNet associated with a 'positive' diagnosis
    PANC_CANCER_CODES = {
        'C25', 'C253', 'C252', 'C259', 'C250', 'C251', 'C258',
        'C254', 'C257', '157', '1573', '1570', '1579', '1578',
        '1571', '1574', '1572'
    }

    # Prediction horizons in days, one per label position.
    HORIZONS_IN_DAYS = (30 * 2, 30 * 6, 30 * 12, 30 * 36, 30 * 60)

    # With exclude_past6=True, events this close to the cancer/end date are dropped.
    EXCLUSION_WINDOW_IN_DAYS = 30 * 6

    def __init__(self, data, codes_dict, exclude_past6 = False):
        """Build the feature, date-delta and label lists.

        Args:
            data: dict mapping patient id -> record with 'birthdate', 'end_of_data'
                (ISO dates) and 'events' (list of dicts with 'admdate' and a
                comma-separated 'codes' string).
            codes_dict: dict mapping code -> integer id. Codes that are not present
                are ADDED to this dict in place, so its length can grow.
            exclude_past6: when True, only events more than six months away from the
                cancer date (or end of data) are used as features.
        """
        self.data = []
        self.bdates = []
        self.admdates = []
        self.labels = []

        # Prefer the id stored in the supplied vocabulary; fall back to the notebook-level constant.
        cls_token = codes_dict['<cls>'] if '<cls>' in codes_dict else CLS_TOKEN

        # Convert records to integer sequences & assign labels
        for record in data.values():
            canc_date = datetime.fromisoformat(record['end_of_data'])
            cancer = False
            evnts = []

            for event in record['events']:
                has_cancer_code = False
                for code in event['codes'].split(","):
                    if code not in codes_dict:
                        codes_dict[code] = len(codes_dict)
                    if code in self.PANC_CANCER_CODES:
                        has_cancer_code = True
                        break

                if has_cancer_code:
                    # The diagnosing event and everything after it are never features.
                    canc_date = datetime.fromisoformat(event['admdate'])
                    cancer = True
                    break

                # Append each event exactly once, however many codes it carries, so
                # its codes are not duplicated in the feature sequence.
                evnts.append(event)

            if not evnts:
                continue

            last_feat_date = datetime.fromisoformat(evnts[-1]['admdate'])
            features = []
            for event in evnts:
                event_date = datetime.fromisoformat(event['admdate'])
                if exclude_past6 and abs(canc_date - event_date).days <= self.EXCLUSION_WINDOW_IN_DAYS:
                    continue
                last_feat_date = event_date
                features.extend(codes_dict[code] for code in event['codes'].split(","))

            features.append(cls_token)

            lbls = [0] * len(self.HORIZONS_IN_DAYS)
            if cancer:
                gap_in_days = abs(canc_date - last_feat_date).days
                lbls = [int(gap_in_days < horizon) for horizon in self.HORIZONS_IN_DAYS]

            # NOTE(review): the deltas below cover EVERY event of the record (one entry
            # per event, including events at/after the cancer date), so they are not
            # aligned with `features` (one entry per code plus <cls>). Align them before
            # feeding them to the model.
            admdates = [datetime.fromisoformat(event['admdate']) for event in record['events']]
            bday = datetime.fromisoformat(record['birthdate'])

            # Days between each admission and the last admission of the record.
            adm_deltas = np.array([abs((admdates[-1] - date).days) for date in admdates])
            # Days between each admission and the patient's birthdate.
            bday_deltas = np.array([abs((bday - date).days) for date in admdates])

            self.data.append(features)
            self.admdates.append(adm_deltas)
            self.labels.append(lbls)
            self.bdates.append(bday_deltas)

    def __len__(self):
        """Number of patients kept in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return (code ids, admission deltas, birthdate deltas, labels) as tensors."""
        out1 = torch.tensor(self.data[idx])
        out2 = torch.tensor(self.admdates[idx])
        out3 = torch.tensor(self.bdates[idx])
        out4 = torch.tensor(self.labels[idx])
        return out1, out2, out3, out4

# Collate function for DataLoader zero-padding
def collate_fn(batch):
    """Pad a list of dataset samples into batch tensors.

    Each sample is (features, admdates, bdates, labels). The three variable-length
    sequences are right-padded with PAD_TOKEN to the longest one in the batch.

    Returns:
        (padded features, stacked labels, padded admdates, padded bdates,
         tensor holding the unpadded length of each feature sequence)
    """
    features, admdates, bdates, labels = zip(*batch)

    def _pad(sequences):
        return pad_sequence(sequences, batch_first=True, padding_value=PAD_TOKEN)

    lengths = torch.tensor([len(seq) for seq in features])
    return _pad(features), torch.stack(labels), _pad(admdates), _pad(bdates), lengths


In [None]:
# Creating datasets and DataLoaders
# Building the dataset also adds every code missing from codes_dict to it, so this
# cell has to run before len(codes_dict) is used as the model's vocabulary size.
dataset = HealthRecordDataset(data, codes_dict)
# 80/20 split with a fixed random_state so the same patients land in each split on re-runs.
train_data, test_data = train_test_split(dataset, test_size=0.2, random_state=42)
# collate_fn pads the variable-length sequences of each batch; only the training loader shuffles.
train_loader = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False, collate_fn=collate_fn)

In [None]:
# Peek at a handful of held-out samples. Printing the whole split floods the notebook
# with thousands of lines of patient records (the output gets truncated anyway).
from itertools import islice

for sample in islice(test_data, 3):
    print(sample)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
         8523,  9489,  1866,  8522,  8383,  8506, 12386,  8514,  6459,  6443,
         9402,  8513,  8505,  2224,  6557,  7325,  7442,  2755,  6331, 16271,
        16604, 16788,  6335,  9489,  8603,  8371,  8522, 10873,  8505, 16788,
         6327, 16604, 16794, 16674,  8514,  6443,  6359,  6429, 10980, 10870,
         5078, 12549,  6463,  6557,  9199,  1323, 14026, 15617, 15453,  1586,
         8559,  9489,  7290,  1511,  7291,  8371,  8505,  8522,  8514,  7204,
         6327, 16604, 12424, 12407, 16685, 16674,     1]), tensor([ 999,  999,  999,  999,  999,  999,  999,  999,  999,  999,  999,  999,
         999,  999,  999,  999,  999,  999,  999,  999,  999, 1240, 1240, 1240,
        1240, 1240, 1240, 1240, 1240, 1240, 1240, 1240, 1240, 1240, 1240, 1146,
        1146, 1146, 1146, 1146, 1146, 1146, 1146, 1146, 1146, 1146, 1146, 1082,
        1082, 1082, 1082, 1082, 1082, 1082, 1082, 1082, 1082, 1082, 1082, 1082,
        

In [None]:
# Take the first batch only and report the shapes of the tensors built by collate_fn.
for batch in train_loader:
    x_pad, y, admdates, bdates, x_len = batch

    print(f"x_pad shape:\t\t{x_pad.shape}")
    print(f"x_len shape:\t\t{x_len.shape}\n")
    print(f"y shape:\t{y.shape}")

    # The date-delta tensors are padded independently of x_pad.
    for date_tensor in (admdates, bdates):
        print(date_tensor.shape)

    break

x_pad shape:		torch.Size([32, 167])
x_len shape:		torch.Size([32])

y shape:	torch.Size([32, 5])
torch.Size([32, 211])
torch.Size([32, 211])


In [None]:
class PositionalEncoding(torch.nn.Module):
    """Time-based cosine encoding, in the spirit of the Attention Is All You Need encoding.

    Instead of token positions, the encoding is driven by day differences ("deltas")
    between each event and a reference date. Channel ``k`` of the encoding is
    ``cos(delta * 2*pi / period_k)``, with the ``d_model`` periods spaced linearly
    between MIN_TIME_EMBED_PERIOD_IN_DAYS and MAX_TIME_EMBED_PERIOD_IN_DAYS.
    """

    MIN_TIME_EMBED_PERIOD_IN_DAYS = 10

    MAX_TIME_EMBED_PERIOD_IN_DAYS = 120 * 365

    def __init__(self, d_model):
        """
        Args:
            d_model: number of encoding channels (must match the embedding size of
                the tensor the encoding is added to).
        """
        super().__init__()

        self.d_model = d_model

        periods = np.linspace(
            start=self.MIN_TIME_EMBED_PERIOD_IN_DAYS,
            stop=self.MAX_TIME_EMBED_PERIOD_IN_DAYS,
            num=d_model,
        )
        multipliers = torch.from_numpy(2 * np.pi / periods).float()  # (d_model,)
        # Non-persistent buffer: moves with .to(device) but stays out of state_dict.
        self.register_buffer("multipliers", multipliers, persistent=False)

    def get_time_seq(self, deltas):
        """
            Calculates the positional embeddings depending on the time diff from the events and the reference date.

            Args:
                deltas: tensor of day differences, shape (..., l).

            Returns:
                float tensor of shape (..., l, d_model).
        """
        multipliers = self.multipliers.to(deltas.device)
        # Add a trailing axis so every delta is multiplied by all d_model frequencies.
        angles = deltas.unsqueeze(-1).to(multipliers.dtype) * multipliers
        return torch.cos(angles)

    def forward(self, x, admdates, bdates):
        """Add the admission-date and birthdate encodings to ``x``.

        Args:
            x: embeddings of shape (b, l, d_model).
            admdates, bdates: day deltas of shape (b, l); they must be aligned with
                the sequence axis of ``x``.

        Returns:
            tensor with the shape and dtype of ``x``.
        """
        adm_encoding = self.get_time_seq(admdates).to(x.dtype)
        birth_encoding = self.get_time_seq(bdates).to(x.dtype)
        return x + adm_encoding + birth_encoding

In [None]:
"""Simple, minimal implementation of Mamba in one file of PyTorch.

Suggest reading the following before/while reading the code:
    [1] Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Albert Gu and Tri Dao)
        https://arxiv.org/abs/2312.00752
    [2] The Annotated S4 (Sasha Rush and Sidd Karamcheti)
        https://srush.github.io/annotated-s4

Glossary:
    b: batch size                       (`B` in Mamba paper [1] Algorithm 2)
    l: sequence length                  (`L` in [1] Algorithm 2)
    d or d_model: hidden dim
    n or d_state: latent state dim      (`N` in [1] Algorithm 2)
    expand: expansion factor            (`E` in [1] Section 3.4)
    d_in or d_inner: d * expand         (`D` in [1] Algorithm 2)
    A, B, C, D: state space parameters  (See any state space representation formula)
                                        (B, C are input-dependent (aka selective, a key innovation in Mamba); A, D are not)
    Δ or delta: input-dependent step size
    dt_rank: rank of Δ                  (See [1] Section 3.6 "Parameterization of ∆")

"""
from __future__ import annotations
import json
import math
from dataclasses import dataclass
from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat, einsum

In [None]:
@dataclass
class ModelArgs:
    """Hyper-parameters of the minimal Mamba implementation.

    ``d_inner`` is derived in ``__post_init__`` as ``int(expand * d_model)``;
    ``dt_rank='auto'`` is resolved to ``ceil(d_model / 16)`` and ``vocab_size`` is
    rounded up to the next multiple of ``pad_vocab_size_multiple``.

    NOTE: a later cell of this notebook defines another class that is also named
    ``ModelArgs``; once that cell has run, the name refers to that class instead.
    """
    d_model: int        # hidden dim
    n_layer: int        # number of residual Mamba blocks
    vocab_size: int     # embedding rows (padded in __post_init__)
    d_state: int = 16   # latent state dim
    expand: int = 2     # expansion factor used to derive d_inner
    dt_rank: Union[int, str] = 'auto'   # rank of delta, or 'auto'
    d_conv: int = 4     # convolution kernel size
    pad_vocab_size_multiple: int = 8
    conv_bias: bool = True
    bias: bool = False

    def __post_init__(self):
        """Derive d_inner, resolve an 'auto' dt_rank and pad vocab_size."""
        self.d_inner = int(self.expand * self.d_model)

        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.d_model / 16)

        # Round vocab_size up so that it is a multiple of pad_vocab_size_multiple.
        remainder = self.vocab_size % self.pad_vocab_size_multiple
        if remainder:
            self.vocab_size += self.pad_vocab_size_multiple - remainder

In [None]:
class Mamba(nn.Module):
    """Mamba backbone: token embedding -> stack of residual blocks -> final RMSNorm.

    There is no language-model head; ``forward`` returns the normalized hidden states.
    Note that this class reuses the name ``Mamba`` that the notebook also imports
    from ``mamba_ssm``, so after this cell runs the name refers to this class.
    """

    def __init__(self, args):
        """
        Args:
            args: object exposing ``vocab_size``, ``d_model`` and ``n_layer`` plus the
                attributes read by ``ResidualBlock``.
        """
        super().__init__()
        self.args = args
        self.embedding = nn.Embedding(args.vocab_size, args.d_model)
        self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)])
        self.norm_f = RMSNorm(args.d_model)

    def forward(self, input_ids, adm_deltas, bdeltas):
        """
        Args:
            input_ids: integer token ids, shape (b, l).
            adm_deltas: accepted for interface compatibility; currently unused.
            bdeltas: accepted for interface compatibility; currently unused.

        Returns:
            hidden states of shape (b, l, d_model).
        """
        x = self.embedding(input_ids)

        for layer in self.layers:
            x = layer(x)

        return self.norm_f(x)


    @staticmethod
    def from_pretrained(pretrained_model_name: str):
        """Load pretrained weights from HuggingFace into model.

        Args:
            pretrained_model_name: One of
                * 'state-spaces/mamba-2.8b-slimpj'
                * 'state-spaces/mamba-2.8b'
                * 'state-spaces/mamba-1.4b'
                * 'state-spaces/mamba-790m'
                * 'state-spaces/mamba-370m'
                * 'state-spaces/mamba-130m'

        Returns:
            model: Mamba model with weights loaded

        """
        from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
        from transformers.utils.hub import cached_file

        def load_config_hf(model_name):
            resolved_archive_file = cached_file(model_name, CONFIG_NAME,
                                                _raise_exceptions_for_missing_entries=False)
            # Context manager so the config file handle is closed after parsing.
            with open(resolved_archive_file) as config_file:
                return json.load(config_file)


        def load_state_dict_hf(model_name, device=None, dtype=None):
            resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
                                                _raise_exceptions_for_missing_entries=False)
            return torch.load(resolved_archive_file, weights_only=True, map_location='cpu', mmap=True)

        config_data = load_config_hf(pretrained_model_name)
        args = ModelArgs(
            d_model=config_data['d_model'],
            n_layer=config_data['n_layer'],
            vocab_size=config_data['vocab_size']
        )
        model = Mamba(args)

        state_dict = load_state_dict_hf(pretrained_model_name)
        # Strip the 'backbone.' prefix and drop the LM-head weights: this model has no
        # lm_head module, so passing them to a strict load_state_dict would fail.
        new_state_dict = {
            key.replace('backbone.', ''): value
            for key, value in state_dict.items()
            if not key.startswith('lm_head.')
        }
        model.load_state_dict(new_state_dict)

        return model

In [None]:
class ResidualBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        """Pre-norm residual wrapper around a single MambaBlock."""
        super().__init__()
        self.args = args
        self.mixer = MambaBlock(args)
        self.norm = RMSNorm(args.d_model)


    def forward(self, x):
        """Apply norm -> mixer and add the result to the input.

        Args:
            x: shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)

        Returns:
            output: shape (b, l, d)

        Official Implementation:
            Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297

            Note: the official repo orders each block as [Add -> Norm -> Mamba] (with a
            no-op first Add) so that Add and Norm can be fused for speed. This block
            uses the more familiar ordering [Norm -> Mamba -> Add], which the original
            author describes as numerically equivalent.

        """
        normed = self.norm(x)
        mixed = self.mixer(normed)
        return mixed + x

In [None]:
class MambaBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        """Create the layers of one Mamba block (Figure 3, Section 3.4 of the Mamba paper [1])."""
        super().__init__()
        self.args = args
        d_model, d_inner = args.d_model, args.d_inner

        # A single projection yields both halves that forward() splits apart.
        self.in_proj = nn.Linear(d_model, 2 * d_inner, bias=args.bias)

        # One filter per channel (groups == channels); forward() keeps the first l outputs.
        self.conv1d = nn.Conv1d(
            in_channels=d_inner,
            out_channels=d_inner,
            kernel_size=args.d_conv,
            groups=d_inner,
            padding=args.d_conv - 1,
            bias=args.conv_bias,
        )

        # x_proj maps the SSM input to the input-specific Δ (low rank), B and C
        self.x_proj = nn.Linear(d_inner, args.dt_rank + 2 * args.d_state, bias=False)

        # dt_proj lifts Δ from dt_rank to d_in
        self.dt_proj = nn.Linear(args.dt_rank, d_inner, bias=True)

        # A is stored as log(1..d_state), repeated for each of the d_inner channels.
        state_indices = torch.arange(1, args.d_state + 1)
        self.A_log = nn.Parameter(torch.log(repeat(state_indices, 'n -> d n', d=d_inner)))
        self.D = nn.Parameter(torch.ones(d_inner))
        self.out_proj = nn.Linear(d_inner, d_model, bias=args.bias)


    def forward(self, x):
        """Run the block: project, convolve, SiLU, SSM, gate, project back.

        Args:
            x: shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)

        Returns:
            output: shape (b, l, d)

        Official Implementation:
            class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311

        """
        (_, seq_len, _) = x.shape
        d_inner = self.args.d_inner

        # (b, l, 2 * d_in) -> two (b, l, d_in) branches
        (ssm_in, gate) = self.in_proj(x).split(split_size=[d_inner, d_inner], dim=-1)

        # Conv1d expects channels before length; trim the padded output back to l steps.
        conv_out = self.conv1d(rearrange(ssm_in, 'b l d_in -> b d_in l'))[:, :, :seq_len]
        ssm_in = F.silu(rearrange(conv_out, 'b d_in l -> b l d_in'))

        gated = self.ssm(ssm_in) * F.silu(gate)

        return self.out_proj(gated)


    def ssm(self, x):
        """Compute the state space parameters from ``x`` and run the selective scan. See:
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        Args:
            x: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)

        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311

        """
        (_, n) = self.A_log.shape

        # A and D are learned parameters that do not depend on the input;
        # Δ, B and C are computed from x below (the "selective" part of Mamba).
        A = -torch.exp(self.A_log.float())  # shape (d_in, n)
        D = self.D.float()

        # (b, l, dt_rank + 2*n) -> Δ: (b, l, dt_rank), B and C: (b, l, n)
        (delta_low_rank, B, C) = self.x_proj(x).split(split_size=[self.args.dt_rank, n, n], dim=-1)
        delta = F.softplus(self.dt_proj(delta_low_rank))  # (b, l, d_in)

        return self.selective_scan(x, delta, A, B, C, D)


    def selective_scan(self, u, delta, A, B, C, D):
        """Sequential selective scan. See:
            - Section 2 State Space Models in the Mamba paper [1]
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        This is the classic discrete state space formula:
            x(t + 1) = Ax(t) + Bu(t)
            y(t)     = Cx(t) + Du(t)
        except B and C (and the step size delta, which is used for discretization) are dependent on the input x(t).

        Args:
            u: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)
            delta: shape (b, l, d_in)
            A: shape (d_in, n)
            B: shape (b, l, n)
            C: shape (b, l, n)
            D: shape (d_in,)

        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
            Note: parts of `selective_scan_ref` were refactored out, so the functionality doesn't match exactly.

        """
        (b, l, d_in) = u.shape
        n = A.shape[1]

        # Discretization: exp(Δ·A) for A (zero-order hold, Section 2 Equation 4 in [1]);
        # Δ·B·u for B (a simplified Euler step rather than ZOH, per the original note).
        deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
        deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')

        # Step through the sequence one position at a time, carrying the hidden state.
        # (The official implementation uses a much faster hardware-aware parallel scan.)
        state = torch.zeros((b, d_in, n), device=deltaA.device)
        outputs = []
        for step in range(l):
            state = deltaA[:, step] * state + deltaB_u[:, step]
            outputs.append(einsum(state, C[:, step, :], 'b d_in n, b n -> b d_in'))

        # (b, l, d_in), plus the skip term D * u
        return torch.stack(outputs, dim=1) + u * D

In [None]:
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis with a learned per-channel scale."""

    def __init__(self, d_model: int, eps: float = 1e-5):
        """
        Args:
            d_model: size of the last axis of the inputs.
            eps: constant added to the mean square before the reciprocal square root.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))


    def forward(self, x):
        """Scale ``x`` by 1 / sqrt(mean(x^2) + eps) along the last axis, then by ``weight``."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight

# Define our Pancreatic Cancer Model using Mamba modules

In [None]:
class PancCancerModel(torch.nn.Module):
    """Mamba backbone followed by a linear head producing 5 logits per sequence."""

    def __init__(self, args, pad_token=PAD_TOKEN, dropout=0.1):
        """
        Args:
            args: model hyper-parameters forwarded to ``Mamba`` (``d_model`` is also
                used to size the output head).
            pad_token: accepted for interface compatibility; currently unused.
            dropout: accepted for interface compatibility; currently unused.
        """
        super(PancCancerModel, self).__init__()

        self.mamba = Mamba(args)
        self.final = torch.nn.Linear(args.d_model, 5)

        # You can experiment with different weight initialization schemes or no initialization here
        # for p in self.parameters():
        #     if p.dim() > 1:
        #         torch.nn.init.xavier_uniform_(p)

    def forward(self, padded_input, input_lengths, adm_deltas, bdeltas):
        """
        Args:
            padded_input: token ids, shape (b, l), right-padded.
            input_lengths: unpadded length of every sequence, shape (b,). The hidden
                state at position ``length - 1`` (the last real token) is classified.
                If this is not a 1-D tensor with one entry per sample, the last
                position of the padded sequence is used instead (previous behavior).
            adm_deltas, bdeltas: forwarded to the backbone.

        Returns:
            logits of shape (b, 5).
        """
        out = self.mamba(padded_input, adm_deltas, bdeltas)  # (b, l, d_model)
        batch_size, seq_len, _ = out.shape

        has_per_sample_lengths = (
            torch.is_tensor(input_lengths)
            and input_lengths.dim() == 1
            and input_lengths.shape[0] == batch_size
        )
        if has_per_sample_lengths:
            # Sequences are right-padded, so position -1 is padding for every sequence
            # shorter than the batch maximum; read each sequence at its own last token.
            last_idx = (input_lengths.to(out.device).long() - 1).clamp(0, seq_len - 1)
            summary = out[torch.arange(batch_size, device=out.device), last_idx]
        else:
            summary = out[:, -1, :]

        return self.final(summary)

In [None]:
import torch
import torch.nn as nn

class ModelArgs:
    """Plain container for the model hyper-parameters.

    This definition replaces the dataclass of the same name defined in an earlier
    cell. ``d_inner``, ``d_conv`` and ``expand`` are optional so that the
    three-argument construction used by ``Mamba.from_pretrained`` keeps working.
    """

    def __init__(self, d_model, n_layer, vocab_size, d_inner=None, d_conv=4, expand=2, bias=True, conv_bias=True, dt_rank=2, d_state=16):
        """
        Args:
            d_model: hidden dim.
            n_layer: number of residual blocks.
            vocab_size: number of embedding rows.
            d_inner: inner dim of the Mamba block; defaults to ``int(expand * d_model)``.
            d_conv: convolution kernel size.
            expand: expansion factor (only used to derive a missing ``d_inner``).
            bias: whether the input/output projections use a bias.
            conv_bias: whether the convolution uses a bias.
            dt_rank: rank of delta.
            d_state: latent state dim.
        """
        self.d_model = d_model
        self.n_layer = n_layer
        self.vocab_size = vocab_size
        self.d_inner = int(expand * d_model) if d_inner is None else d_inner
        self.d_conv = d_conv
        self.expand = expand
        self.bias = bias
        self.conv_bias = conv_bias
        self.dt_rank = dt_rank
        self.d_state = d_state

# Define model arguments
# len(codes_dict) is read here, so the dataset cell (which adds codes that are missing
# from the dictionaries) has to run first for the embedding to cover every token id.
args = ModelArgs(
    d_model=16,    # Model dimension
    n_layer=4,     # Number of layers
    vocab_size=len(codes_dict),  # Vocabulary size
    d_inner=128,    # Dimension of inner layers (given directly, not derived from expand * d_model)
    d_conv=4,      # Convolution kernel size
    expand=2       # Expansion factor (stored on args; MambaBlock sizes its layers from d_inner)
)

# Initialize the model (moved to the GPU, so a CUDA device is required)
# model = Mamba(args).to("cuda")

model = PancCancerModel(args).to("cuda")


In [None]:
print(model)

PancCancerModel(
  (mamba): Mamba(
    (embedding): Embedding(31565, 16)
    (layers): ModuleList(
      (0-3): 4 x ResidualBlock(
        (mixer): MambaBlock(
          (in_proj): Linear(in_features=16, out_features=256, bias=True)
          (conv1d): Conv1d(128, 128, kernel_size=(4,), stride=(1,), padding=(3,), groups=128)
          (x_proj): Linear(in_features=128, out_features=34, bias=False)
          (dt_proj): Linear(in_features=2, out_features=128, bias=True)
          (out_proj): Linear(in_features=128, out_features=16, bias=True)
        )
        (norm): RMSNorm()
      )
    )
    (norm_f): RMSNorm()
  )
  (final): Linear(in_features=16, out_features=5, bias=True)
)


In [None]:
# Test input setup
# Token ids are drawn from the model's own vocabulary so they can never index past
# the embedding table.
batch, length, vocab_size, out_size = 2, 64, args.vocab_size, 5
input_ids = torch.randint(0, vocab_size, (batch, length)).to("cuda")  # Simulate input IDs
input_lengths = torch.full((batch,), length).to("cuda")  # Simulated sequences carry no padding
adm_deltas = torch.randn(batch, length).to("cuda")  # Simulated admission deltas
bdeltas = torch.randn(batch, length).to("cuda")     # Simulated birth date deltas

# Forward pass (the second argument is the per-sequence length, not the token ids).
# The result gets its own name so the `y` labels of the inspected batch stay intact.
with torch.no_grad():
    smoke_out = model(input_ids, input_lengths, adm_deltas, bdeltas)

# Assert output shape
assert smoke_out.shape == (batch, out_size), "Output shape is incorrect"

print("Test passed successfully with output shape:", smoke_out.shape)

Test passed successfully with output shape: torch.Size([2, 5])


In [None]:
# Test model on Panc Cancer Data

# Fetch a fresh batch instead of reusing variables left over from earlier cells: by
# this point `y` may hold the output of the previous smoke test rather than labels.
x_pad, labels, admdates, bdates, x_len = (t.to("cuda") for t in next(iter(train_loader)))

y = model(x_pad, x_len, admdates, bdates)

print("Test passed successfully with output shape:", y.shape)

Test passed successfully with output shape: torch.Size([32, 5])


# Small sanity check 1:
Check that the model can overfit a small subset of the dataset, which confirms that the training loop is actually able to learn.

In [None]:
# import torch
# import torch.optim as optim
# import torch.nn as nn
# import matplotlib.pyplot as plt

# # Setup for a smallish dataset to overfit
# input_ids = torch.randint(0, 5000, (10, 64)).to("cuda")
# adm_deltas = torch.randn(10, 64).to("cuda")
# bdeltas = torch.randn(10, 64).to("cuda")
# # labels = torch.randint(0, 2, (10, 64, 5000)).to("cuda")
# labels = torch.randint(0, 2, (10, 5)).to("cuda")  # randint's upper bound is exclusive; (0, 1) would yield only zeros

# # Initialize model, optimizer, and loss function
# model = PancCancerModel(args).to("cuda")
# optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# loss_function = torch.nn.BCEWithLogitsLoss()

# # Stores losses for plotting
# training_losses = []

# # Train the model
# for epoch in range(100):
#     optimizer.zero_grad()
#     outputs = model(input_ids, input_ids, adm_deltas, bdeltas)
#     loss = loss_function(outputs, labels.float())
#     loss.backward()
#     optimizer.step()
#     training_losses.append(loss.item())
#     print(f"Epoch {epoch+1}, Loss: {loss.item()}")

#     if loss.item() < 1e-4:
#         print("Model successfully overfitted a small sample.")
#         break

# # Plotting the training loss
# plt.figure(figsize=(10, 5))
# plt.plot(training_losses, label='Training Loss')
# plt.title('Training Loss per Epoch')
# plt.xlabel('Epoch')
# plt.ylabel('Loss')
# plt.legend()
# plt.grid(True)
# plt.show()

# Sanity Check 2: Slightly larger data subset

In [None]:
# import torch
# from torch.utils.data import DataLoader, TensorDataset
# import torch.optim as optim
# import torch.nn as nn
# import matplotlib.pyplot as plt

# # Generate a larger dataset
# input_ids_train_2 = torch.randint(0, 5000, (200, 64)).to("cuda")  # 200 training samples
# adm_deltas_train_2 = torch.randn(200, 64).to("cuda")
# bdeltas_train_2 = torch.randn(200, 64).to("cuda")
# # labels_train_2 = torch.randint(0, 2, (200, 64, 5000)).to("cuda")  # Binary multi-label setup
# labels_train_2 = torch.randint(0, 2, (200, 5)).to("cuda")  # Binary multi-label setup (randint's upper bound is exclusive)

# input_ids_val_2 = torch.randint(0, 5000, (40, 64)).to("cuda") # 40 validation samples
# adm_deltas_val_2 = torch.randn(40, 64).to("cuda")
# bdeltas_val_2 = torch.randn(40, 64).to("cuda")
# # labels_val_2 = torch.randint(0, 2, (40, 64, 5000)).to("cuda") # Same binary multi-label setup
# labels_val_2 = torch.randint(0, 2, (40, 5)).to("cuda") # Same binary multi-label setup (randint's upper bound is exclusive)

# # Create DataLoader
# train_dataset_2 = TensorDataset(input_ids_train_2, adm_deltas_train_2, bdeltas_train_2, labels_train_2)
# train_loader_2 = DataLoader(train_dataset_2, batch_size=16, shuffle=True)

# val_dataset_2 = TensorDataset(input_ids_val_2, adm_deltas_val_2, bdeltas_val_2, labels_val_2)
# val_loader_2 = DataLoader(val_dataset_2, batch_size=8)

# # Initialize the model
# model = PancCancerModel(args).to("cuda")
# optimizer = optim.Adam(model.parameters(), lr=0.001)
# loss_function = nn.BCEWithLogitsLoss()

# # Store metrics for plotting
# train_losses = []
# val_losses = []

# # Training and Validation Loop
# for epoch in range(40):  # Fewer epochs than last time
#     model.train()
#     total_loss = 0
#     for inputs, adm_d, bd, labels in train_loader_2:
#         optimizer.zero_grad()
#         outputs = model(inputs, inputs, adm_d, bd)
#         loss = loss_function(outputs, labels.float())
#         loss.backward()
#         optimizer.step()
#         total_loss += loss.item()
#     train_losses.append(total_loss / len(train_loader_2))

#     model.eval()
#     val_loss = 0
#     with torch.no_grad():
#         for inputs, adm_d, bd, labels in val_loader_2:
#             outputs = model(inputs, inputs, adm_d, bd)
#             loss = loss_function(outputs, labels.float())
#             val_loss += loss.item()
#     val_losses.append(val_loss / len(val_loader_2))

#     print(f'Epoch {epoch+1}: Train Loss {total_loss / len(train_loader_2):.4f}, Val Loss {val_loss / len(val_loader_2):.4f}')

# # Plotting training and validation loss
# plt.figure(figsize=(10, 5))
# plt.plot(train_losses, label='Training Loss')
# plt.plot(val_losses, label='Validation Loss')
# plt.title('Training and Validation Loss per Epoch')
# plt.xlabel('Epoch')
# plt.ylabel('Loss')
# plt.legend()
# plt.grid(True)
# plt.show()


# Full Dataset

In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim
import torch.nn as nn
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt

NUM_EPOCHS = 50

dataset = HealthRecordDataset(data, codes_dict)
train_data, test_data = train_test_split(dataset, test_size=0.2, random_state=42)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(test_data, batch_size=32, shuffle=False, collate_fn=collate_fn)

# Initialize the model
model = PancCancerModel(args).to("cuda")
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_function = nn.BCEWithLogitsLoss()


def run_epoch(model, loader, loss_function, optimizer=None):
    """Run one full pass over ``loader`` and return ``(mean batch loss, AUROC)``.

    With an ``optimizer`` the model is put in train mode and updated after every
    batch; without one it is put in eval mode and gradients are disabled. The AUROC
    is computed by ``roc_auc_score`` on the sigmoid of the logits of all batches.
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0
    batch_preds = []
    batch_labels = []

    with torch.set_grad_enabled(training):
        for x_pad, labels, adm_d, bd, x_len in loader:
            x_pad, labels, adm_d, bd, x_len = (
                t.to("cuda") for t in (x_pad, labels, adm_d, bd, x_len)
            )
            if training:
                optimizer.zero_grad()
            outputs = model(x_pad, x_len, adm_d, bd)
            loss = loss_function(outputs, labels.float())
            if training:
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            batch_preds.append(torch.sigmoid(outputs).detach().cpu().numpy())
            batch_labels.append(labels.cpu().numpy())

    auroc = roc_auc_score(np.concatenate(batch_labels), np.concatenate(batch_preds))
    return total_loss / len(loader), auroc


# Store metrics for plotting
train_losses = []
val_losses = []
train_aurocs = []
val_aurocs = []

# Training and Validation Loop
for epoch in range(NUM_EPOCHS):
    train_loss, train_auroc = run_epoch(model, train_loader, loss_function, optimizer)
    val_loss, val_auroc = run_epoch(model, val_loader, loss_function)

    train_losses.append(train_loss)
    train_aurocs.append(train_auroc)
    val_losses.append(val_loss)
    val_aurocs.append(val_auroc)

    print(f'Epoch {epoch+1}: Train Loss {train_loss:.4f}, Train AUROC {train_auroc:.4f}, Val Loss {val_loss:.4f}, Val AUROC {val_auroc:.4f}')

# Plotting training and validation metrics
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

ax1.plot(train_losses, label='Training Loss')
ax1.plot(val_losses, label='Validation Loss')
ax1.set_title('Training and Validation Loss per Epoch')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.legend()
ax1.grid(True)

ax2.plot(train_aurocs, label='Training AUROC')
ax2.plot(val_aurocs, label='Validation AUROC')
ax2.set_title('Training and Validation AUROC per Epoch')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('AUROC')
ax2.legend()
ax2.grid(True)

plt.show()

Epoch 1: Train Loss 0.0562, Train AUROC 0.9884, Val Loss 0.5176, Val AUROC 0.8386
Epoch 2: Train Loss 0.0594, Train AUROC 0.9889, Val Loss 0.5195, Val AUROC 0.8406
Epoch 3: Train Loss 0.0481, Train AUROC 0.9912, Val Loss 0.5288, Val AUROC 0.8475
Epoch 4: Train Loss 0.0487, Train AUROC 0.9903, Val Loss 0.5319, Val AUROC 0.8518
Epoch 5: Train Loss 0.0424, Train AUROC 0.9923, Val Loss 0.5333, Val AUROC 0.8519
Epoch 6: Train Loss 0.0380, Train AUROC 0.9933, Val Loss 0.5690, Val AUROC 0.8500
Epoch 7: Train Loss 0.0385, Train AUROC 0.9950, Val Loss 0.6003, Val AUROC 0.8521


KeyboardInterrupt: 

In [None]:
from sklearn.metrics import precision_score, recall_score
def get_time_acc(model, val_loader):
    """Evaluate thresholded predictions over a whole loader.

    Logits are passed through a sigmoid and rounded to 0/1. All metrics are computed
    once over every sample of the loader (not averaged per batch), so a smaller last
    batch or a batch without positives does not skew the result.

    Args:
        model: callable as ``model(x_pad, x_len, adm_d, bd)`` returning logits (b, 5).
        val_loader: iterable of ``(x_pad, labels, adm_d, bd, x_len)`` batches.

    Returns:
        (acc, prec, rec): ``acc`` is a tensor with one accuracy per label column,
        placed on the model's device; ``prec`` and ``rec`` are the macro-averaged
        precision and recall (undefined per-label values count as 0).

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    model.eval()

    # Run on whatever device the model lives on (falls back to CUDA as before).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    all_predictions = []
    all_labels = []
    with torch.no_grad():
        for x_pad, labels, adm_d, bd, x_len in val_loader:
            x_pad, adm_d, bd, x_len = (t.to(device) for t in (x_pad, adm_d, bd, x_len))
            outputs = torch.sigmoid(model(x_pad, x_len, adm_d, bd))
            all_predictions.append(torch.round(outputs).cpu())
            all_labels.append(labels.cpu())

    if not all_labels:
        raise ValueError("val_loader yielded no batches")

    predictions = torch.cat(all_predictions)
    labels = torch.cat(all_labels)

    acc = (predictions == labels).float().mean(dim=0).to(device)
    prec = precision_score(labels.numpy(), predictions.numpy(), average='macro', zero_division=0)
    rec = recall_score(labels.numpy(), predictions.numpy(), average='macro', zero_division=0)
    return acc, prec, rec

get_time_acc(model, val_loader)

(tensor([0.8635, 0.8635, 0.8452, 0.8274, 0.8166], device='cuda:0'),
 0.6654128821370201,
 0.662216748768473)