In [None]:
!pip install transformers
import os
import sys
# Project root on Google Drive.
# NOTE(review): this variable is not referenced by any other cell visible in this notebook.
folder_path = '/content/drive/MyDrive/HiCu-ICD-main'
# Put the project sources (and its utils/ package) on the import path of the notebook kernel.
# NOTE(review): these entries live under CS598_DLH_Project/, whereas folder_path, the later
# %cd and the main.py arguments all use /content/drive/MyDrive/HiCu-ICD-main — confirm which
# location is the intended one.
sys.path.append('/content/drive/MyDrive/CS598_DLH_Project/HiCu-ICD-main')
sys.path.append('/content/drive/MyDrive/CS598_DLH_Project/HiCu-ICD-main/utils')

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.28.1-py3-none-any.whl (7.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m55.6 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.11.0
  Downloading huggingface_hub-0.14.1-py3-none-any.whl (224 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.5/224.5 kB[0m [31m29.1 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m108.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.14.1 tokenizers-0.13.3 transformers-4.28.1


In [None]:
from google.colab import drive
# Mount Google Drive so the project files under MyDrive are reachable from the Colab VM.
drive.mount('/content/drive')
# Switch to the project root: the training cell further down invokes `main.py` by relative
# name, so it depends on this working directory.
%cd /content/drive/MyDrive/HiCu-ICD-main/

Mounted at /content/drive
/content/drive/.shortcut-targets-by-id/1i33jnEExQNdrh2x4f7zRXTC8io_INa2u/HiCu-ICD-main


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_ as xavier_uniform
import numpy as np
from math import floor, sqrt
# Embedding front-end: maps token ids to (optionally pretrained) word vectors and applies dropout.
class WordRep(nn.Module):
    """Word-embedding layer followed by dropout.

    Args:
        args: configuration namespace. Reads ``embed_file`` (path to pretrained
            vectors, falsy to train from scratch), ``use_ext_emb`` (only when
            ``embed_file`` is set), ``embed_size`` (only when it is not),
            ``dropout`` and ``num_filter_maps``.
        Y: unused here; kept so every model component shares one constructor
            signature.
        dicts: must provide ``dicts['w2ind']`` (word -> index mapping).

    Attributes:
        embed: the ``nn.Embedding`` table; index 0 is the padding index.
        feature_size: embedding dimension.
        conv_dict: channel plan keyed by the number of residual layers; entry
            ``k`` lists ``k + 1`` channel sizes, starting at ``feature_size``
            and ending at ``args.num_filter_maps``.
    """

    def __init__(self, args, Y, dicts):
        super(WordRep, self).__init__()

        if args.embed_file:
            print("loading pretrained embeddings from {}".format(args.embed_file))
            if args.use_ext_emb:
                pretrain_word_embedding, _ = build_pretrain_embedding(args.embed_file, dicts['w2ind'],
                                                                      True)
                W = torch.from_numpy(pretrain_word_embedding)
            else:
                W = torch.Tensor(load_embeddings(args.embed_file))

            # torch.from_numpy keeps the NumPy dtype, so a float64 matrix would yield
            # float64 embedding weights that do not match the rest of the model.
            # Casting is a no-op when the dtype is already the default one.
            W = W.to(torch.get_default_dtype())

            self.embed = nn.Embedding(W.size()[0], W.size()[1], padding_idx=0)
            self.embed.weight.data = W.clone()
        else:
            # add 2 to include UNK and PAD
            self.embed = nn.Embedding(len(dicts['w2ind']) + 2, args.embed_size, padding_idx=0)
        self.feature_size = self.embed.embedding_dim

        self.embed_drop = nn.Dropout(p=args.dropout)

        # Channel sizes for a stack of 1..4 residual blocks (looked up by layer count).
        self.conv_dict = {1: [self.feature_size, args.num_filter_maps],
                          2: [self.feature_size, 100, args.num_filter_maps],
                          3: [self.feature_size, 150, 100, args.num_filter_maps],
                          4: [self.feature_size, 200, 150, 100, args.num_filter_maps]
                          }

    def forward(self, x):
        """Embed the index tensor ``x`` and apply dropout.

        Returns a tensor with the shape of ``x`` plus a trailing dimension of
        size ``feature_size``.
        """
        # The embedding is the only feature source, so no concatenation is needed.
        return self.embed_drop(self.embed(x))

The `ResidualBlock` module is a building block for deeper 1D CNN architectures: two convolution + batch-normalization layers with an optional projected shortcut (residual) connection. It can be used for a variety of sequence-modeling tasks.

In [None]:
class ResidualBlock(nn.Module):
    """1D convolutional block with an optional projected shortcut.

    Main path (``self.left``): Conv1d -> BatchNorm1d -> Tanh -> Conv1d -> BatchNorm1d.
    Only the first convolution uses ``stride``; both use ``floor(kernel_size / 2)``
    padding and no bias.

    Shortcut (``self.shortcut``, created only when ``use_res`` is truthy):
    a kernel-size-1 Conv1d with the same ``stride`` followed by BatchNorm1d.

    The sum of both paths (or the main path alone) goes through tanh and dropout.
    """

    def __init__(self, inchannel, outchannel, kernel_size, stride, use_res, dropout):
        super().__init__()

        half_kernel = int(floor(kernel_size / 2))
        main_path = [
            nn.Conv1d(inchannel, outchannel, kernel_size=kernel_size, stride=stride,
                      padding=half_kernel, bias=False),
            nn.BatchNorm1d(outchannel),
            nn.Tanh(),
            nn.Conv1d(outchannel, outchannel, kernel_size=kernel_size, stride=1,
                      padding=half_kernel, bias=False),
            nn.BatchNorm1d(outchannel),
        ]
        self.left = nn.Sequential(*main_path)

        self.use_res = use_res
        if use_res:
            # 1x1 projection so the input matches the main path's channel count and stride.
            projection = nn.Conv1d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False)
            self.shortcut = nn.Sequential(projection, nn.BatchNorm1d(outchannel))

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Run ``x`` through the block and return the activated, dropped-out result."""
        combined = self.left(x)
        if self.use_res:
            combined = combined + self.shortcut(x)
        return self.dropout(torch.tanh(combined))

In [None]:
class RandomlyInitializedDecoder(nn.Module):
    """
    Per-label attention decoder whose label queries start from a random
    (Xavier-uniform) initialization.

    Only the last entry of ``Y`` is used as the number of labels; ``args`` and
    ``dicts`` are accepted for signature compatibility and not read.
    """

    def __init__(self, args, Y, dicts, input_size):
        super().__init__()

        num_labels = Y[-1]

        # Rows of U.weight act as the per-label attention queries.
        self.U = nn.Linear(input_size, num_labels)
        xavier_uniform(self.U.weight)

        # Rows of final.weight score each label's attended representation.
        self.final = nn.Linear(input_size, num_labels)
        xavier_uniform(self.final.weight)

        self.loss_function = nn.BCEWithLogitsLoss()

    def forward(self, x, target, text_inputs):
        """Attend over ``x`` once per label and score the result.

        ``x`` must be 3-D with ``input_size`` as its last dimension; the softmax
        is taken over dimension 1 of ``x``. ``text_inputs`` is not used.

        Returns ``(y, loss, alpha, m)``: logits, BCE-with-logits loss against
        ``target``, attention weights, and the per-label attended vectors.
        """
        attention_scores = torch.matmul(self.U.weight, x.transpose(1, 2))
        alpha = F.softmax(attention_scores, dim=2)

        m = torch.matmul(alpha, x)

        y = (self.final.weight * m).sum(dim=2) + self.final.bias

        loss = self.loss_function(y, target)
        return y, loss, alpha, m

    def change_depth(self, depth=0):
        """No-op placeholder; this decoder keeps no per-depth state."""
        return None

In [None]:
# The Decoder module computes a binary cross-entropy loss or an asymmetric loss between 
# the predicted logits and the target labels, depending on the specified loss function, 
# and returns the predicted logits, loss, attention weights, and context vector as output.
class Decoder(nn.Module):
    """
    Decoder: knowledge transfer initialization and hyperbolic embedding correction

    One pair of linear layers is kept per hierarchy level ``i`` in
    ``decoder_dict``: ``'<i>_0'`` holds the per-label attention queries and
    ``'<i>_1'`` the per-label output weights. ``change_depth`` switches the
    active level and seeds the new level's rows from the parent label's rows at
    the previous level.

    Args:
        args: configuration namespace. Reads ``decoder`` (hyperbolic handling is
            enabled when the name contains ``"Hyperbolic"``), ``depth``,
            ``loss`` (``'BCE'``, ``'ASL'`` or ``'ASLO'``) and, depending on those,
            ``cat_hyperbolic``, ``hyperbolic_dim``, ``asl_config``,
            ``asl_reduction``.
        Y: sequence with the number of labels at each hierarchy level.
        dicts: must provide ``'ind2c'``, ``'c2ind'`` and ``'hierarchy_dist'``
            indexed by level, plus ``'poincare_embeddings'`` (object exposing
            ``get_vector(code)``) when hyperbolic handling is enabled.
        input_size: feature size of the encoder output.

    Raises:
        RuntimeError: if ``args.loss`` is not one of the supported names.
    """
    def __init__(self, args, Y, dicts, input_size):
        super(Decoder, self).__init__()

        self.dicts = dicts

        self.decoder_dict = nn.ModuleDict()
        for i in range(len(Y)):
            y = Y[i]
            self.decoder_dict[str(i) + '_' + '0'] = nn.Linear(input_size, y)
            self.decoder_dict[str(i) + '_' + '1'] = nn.Linear(input_size, y)
            xavier_uniform(self.decoder_dict[str(i) + '_' + '0'].weight)
            xavier_uniform(self.decoder_dict[str(i) + '_' + '1'].weight)

        self.use_hyperbolic = args.decoder.find("Hyperbolic") != -1
        if self.use_hyperbolic:
            self.cat_hyperbolic = args.cat_hyperbolic
            if not self.cat_hyperbolic:
                # additive correction: project the hyperbolic embedding to input_size
                self.hyperbolic_fc_dict = nn.ModuleDict()
                for i in range(len(Y)):
                    self.hyperbolic_fc_dict[str(i)] = nn.Linear(args.hyperbolic_dim, input_size)
            else:
                # concatenation: fuse [query ; hyperbolic embedding] back to input_size
                self.query_fc_dict = nn.ModuleDict()
                for i in range(len(Y)):
                    self.query_fc_dict[str(i)] = nn.Linear(input_size + args.hyperbolic_dim, input_size)

            # build hyperbolic embedding matrix
            self.hyperbolic_emb_dict = {}
            for i in range(len(Y)):
                self.hyperbolic_emb_dict[i] = np.zeros((Y[i], args.hyperbolic_dim))
                for idx, code in dicts['ind2c'][i].items():
                    self.hyperbolic_emb_dict[i][idx, :] = np.copy(dicts['poincare_embeddings'].get_vector(code))
                # registered as a buffer so it follows the module across devices without being trained
                self.register_buffer(name='hb_emb_' + str(i), tensor=torch.tensor(self.hyperbolic_emb_dict[i], dtype=torch.float32))

        # NOTE(review): 5 is presumably the total number of hierarchy levels — confirm.
        self.cur_depth = 5 - args.depth
        self.is_init = False
        self.change_depth(self.cur_depth)

        if args.loss == 'BCE':
            self.loss_function = nn.BCEWithLogitsLoss()
        elif args.loss == 'ASL':
            asl_config = [float(c) for c in args.asl_config.split(',')]
            self.loss_function = AsymmetricLoss(gamma_neg=asl_config[0], gamma_pos=asl_config[1],
                                                clip=asl_config[2], reduction=args.asl_reduction)
        elif args.loss == 'ASLO':
            asl_config = [float(c) for c in args.asl_config.split(',')]
            self.loss_function = AsymmetricLossOptimized(gamma_neg=asl_config[0], gamma_pos=asl_config[1],
                                                         clip=asl_config[2], reduction=args.asl_reduction)
        else:
            # Fail at construction instead of with an AttributeError on the first forward pass.
            raise RuntimeError("wrong loss name: {}".format(args.loss))

    def change_depth(self, depth=0):
        """Make ``depth`` the active hierarchy level.

        On every call after the first one (the first happens in ``__init__``),
        and only for ``depth > 0``, each label's rows at ``depth`` are
        overwritten with its parent label's rows from level ``depth - 1``.
        Level 0 has no parent level, so nothing is copied for it.
        """
        if self.is_init and depth > 0:
            # copy previous attention weights to current attention network based on ICD hierarchy
            ind2c = self.dicts['ind2c']
            c2ind = self.dicts['c2ind']
            hierarchy_dist = self.dicts['hierarchy_dist']

            cur_query = self.decoder_dict[str(depth) + '_' + '0'].weight.data
            cur_output = self.decoder_dict[str(depth) + '_' + '1'].weight.data
            prev_query = self.decoder_dict[str(depth - 1) + '_' + '0'].weight.data
            prev_output = self.decoder_dict[str(depth - 1) + '_' + '1'].weight.data

            for i, code in ind2c[depth].items():
                # hierarchy_dist[depth][code] is indexed by level; entry depth - 1 is the parent code
                tree = hierarchy_dist[depth][code]
                pre_idx = c2ind[depth - 1][tree[depth - 1]]

                cur_query[i, :] = prev_query[pre_idx, :].clone()
                cur_output[i, :] = prev_output[pre_idx, :].clone()

        self.is_init = True
        self.cur_depth = depth

    def forward(self, x, target, text_inputs):
        """Per-label attention over ``x`` at the active level.

        ``x`` must be 3-D with ``input_size`` as its last dimension; the softmax
        is taken over dimension 1 of ``x``. ``text_inputs`` is not used.

        Returns ``(y, loss, alpha, m)``: logits, loss against ``target``,
        attention weights, and the per-label attended vectors.
        """
        level = str(self.cur_depth)
        query_layer = self.decoder_dict[level + '_' + '0']
        output_layer = self.decoder_dict[level + '_' + '1']

        # attention
        if self.use_hyperbolic:
            hb_emb = getattr(self, 'hb_emb_' + level)
            if not self.cat_hyperbolic:
                query = query_layer.weight + self.hyperbolic_fc_dict[level](hb_emb)
            else:
                query = torch.cat([query_layer.weight, hb_emb], dim=1)
                query = self.query_fc_dict[level](query)
        else:
            query = query_layer.weight

        alpha = F.softmax(query.matmul(x.transpose(1, 2)), dim=2)
        m = alpha.matmul(x)

        y = output_layer.weight.mul(m).sum(dim=2).add(output_layer.bias)

        loss = self.loss_function(y, target)

        return y, loss, alpha, m

In [None]:
# the MultiResCNN class provides a flexible and powerful architecture for multi-label text 
# classification tasks, with the ability to incorporate pre-trained embeddings, 
# hierarchical label structures, and hyperbolic embeddings.
class MultiResCNN(nn.Module):
    """Multi-filter residual CNN encoder followed by a label decoder.

    For every kernel size listed in ``args.filter_size`` (comma-separated) a
    branch is built from a base Conv1d, a BatchNorm1d and ``args.conv_layer``
    ``ResidualBlock`` modules. The branch outputs are concatenated along the
    feature dimension and handed to the decoder selected by ``args.decoder``.

    Raises:
        RuntimeError: if ``args.decoder`` is not a recognised decoder name.
    """

    def __init__(self, args, Y, dicts):
        super().__init__()

        self.word_rep = WordRep(args, Y, dicts)
        embed_dim = self.word_rep.feature_size

        self.conv = nn.ModuleList()
        kernel_widths = args.filter_size.split(',')
        self.filter_num = len(kernel_widths)

        for width_text in kernel_widths:
            width = int(width_text)
            branch = nn.ModuleList()

            base_conv = nn.Conv1d(embed_dim, embed_dim, kernel_size=width,
                                  padding=int(floor(width / 2)))
            xavier_uniform(base_conv.weight)
            branch.add_module('baseconv', base_conv)
            branch.add_module('base_bn', nn.BatchNorm1d(embed_dim))

            # Channel sizes for the residual stack, as planned by the word representation.
            channel_plan = self.word_rep.conv_dict[args.conv_layer]
            for layer_no in range(args.conv_layer):
                res_block = ResidualBlock(channel_plan[layer_no], channel_plan[layer_no + 1],
                                          width, 1, True, args.dropout)
                branch.add_module('resconv-{}'.format(layer_no), res_block)

            self.conv.add_module('channel-{}'.format(width), branch)

        # Resolve the decoder class first; names are looked up only in the branch taken.
        if args.decoder in ("HierarchicalHyperbolic", "Hierarchical"):
            decoder_cls = Decoder
        elif args.decoder == "RandomlyInitialized":
            decoder_cls = RandomlyInitializedDecoder
        elif args.decoder == "CodeTitle":
            decoder_cls = RACDecoder
        else:
            raise RuntimeError("wrong decoder name")
        self.decoder = decoder_cls(args, Y, dicts, self.filter_num * args.num_filter_maps)

        self.cur_depth = 5 - args.depth

    def forward(self, x, target, text_inputs):
        """Encode ``x`` with every branch and decode.

        Returns whatever the decoder returns: ``(y, loss, alpha, m)``.
        """
        # Conv1d wants channels on dimension 1.
        embedded = self.word_rep(x).transpose(1, 2)

        branch_outputs = []
        for branch in self.conv:
            hidden = embedded
            for position, layer in enumerate(branch):
                hidden = layer(hidden)
                # ReLU follows the modules at even positions within a branch only.
                if position % 2 == 0:
                    hidden = F.relu(hidden)
            branch_outputs.append(hidden.transpose(1, 2))
        encoded = torch.cat(branch_outputs, dim=2)

        y, loss, alpha, m = self.decoder(encoded, target, text_inputs)
        return y, loss, alpha, m

    def freeze_net(self):
        """Stop gradient updates to the word-embedding parameters."""
        for weight in self.word_rep.embed.parameters():
            weight.requires_grad = False

In [None]:
# Launch training through the project's main.py (MultiResCNN encoder, RandomlyInitialized decoder).
# The `!` escape runs main.py in a separate process, so the classes defined in the cells above
# are not what this command executes unless main.py imports equivalent code itself.
# NOTE(review): --data_path points at train_50.csv while --Y is `full` — confirm that this
# split/label-set combination is intended.
!python main.py \
    --MODEL_DIR /content/drive/MyDrive/HiCu-ICD-main/models \
    --DATA_DIR /content/drive/MyDrive/HiCu-ICD-main/data \
    --MIMIC_3_DIR /content/drive/MyDrive/HiCu-ICD-main/data/mimic3 \
    --data_path /content/drive/MyDrive/HiCu-ICD-main/data/mimic3/train_50.csv \
    --embed_file /content/drive/MyDrive/HiCu-ICD-main/data/mimic3/processed_full_100.embed \
    --vocab /content/drive/MyDrive/HiCu-ICD-main/data/mimic3/vocab.csv \
    --Y full \
    --model MultiResCNN \
    --decoder RandomlyInitialized \
    --criterion prec_at_8 \
    --MAX_LENGTH 1024 \
    --batch_size 2  \
    --lr 5e-5 \
    --depth 5 \
    --n_epochs '2,3,3,3,5'  \
    --num_workers 8 

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
Namespace(MODEL_DIR='/content/drive/MyDrive/HiCu-ICD-main/models', DATA_DIR='/content/drive/MyDrive/HiCu-ICD-main/data', MIMIC_3_DIR='/content/drive/MyDrive/HiCu-ICD-main/data/mimic3', MIMIC_2_DIR='./data/mimic2', data_path='/content/drive/MyDrive/HiCu-ICD-main/data/mimic3/train_50.csv', vocab='/content/drive/MyDrive/HiCu-ICD-main/data/mimic3/vocab.csv', Y='full', version='mimic3', MAX_LENGTH=1024, model='MultiResCNN', decoder='RandomlyInitialized', filter_size='3,5,9,15,19,25', num_filter_maps=50, conv_layer=1, embed_file='/content/drive/MyDrive/HiCu-ICD-main/data/mimic3/processed_full_100.embed', hyperbolic_dim=50, test_model=None, use_ext_emb=False, cat_hyperbolic=False, loss='BCE', asl_config='0,0,0', asl_reduction='sum', n_epochs='2,3,3,3,5', depth=5, dropout=0.2, patience=10, batch_size=2, lr=5e-05, weight_decay=0, criterion='prec_at_8', gpu='0', num_workers=8, tune_wordemb=T