In [1]:
from IPython.core.interactiveshell import InteractiveShell
# Echo the value of every top-level expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = "all"

# (1) Create Datasets

In [2]:
from src.parameters import Parameters
from src.features import Featureset
from src.data_loaders import GroupDataset, GroupDataCollator

#### View Raw Data

In [3]:
# import pandas as pd

# raw_data = pd.read_csv('datasets/ehr/data_ver2.csv')
# raw_data.head()

#### Configure Data

In [4]:
# Column name -> Parameters describing how that column is handled.
# `mode` names the kind of column (uid / categorical / numerical / datetime) and `vector`
# names the encoding requested for it (None, 'embedding', 'onehot' or 'linear').
# NOTE(review): exact semantics live in src.parameters / src.features — confirm there.
# The remaining Parameters fields (token_map, size, ...) show up filled in when the config
# is displayed after the dataset has been created.
data_config = {
    'pat_id':Parameters(mode='uid', vector=None),
    'label':Parameters(mode='categorical', vector=None),
    'race':Parameters(mode='categorical', vector='embedding'),
    'ethnic':Parameters(mode='categorical', vector='embedding'),
    'sex':Parameters(mode='categorical', vector='onehot'),
    'marital_status':Parameters(mode='categorical', vector='onehot'),
    'age':Parameters(mode='numerical', vector='linear'),
    'time':Parameters(mode='datetime', vector=None)
}

#### Create and Save Train/Test Sets

In [5]:
from src.ehr_utils import EHRFeature

In [6]:
# Build a Featureset rooted at 'datasets/ehr' with the column config above, load the raw CSV,
# and create a new dataset with a 0.9 train split under tag 1.
# Presumably the result is persisted, since later cells reload it via load_dataset(tag=1) — confirm.
_features_ehr = Featureset('datasets/ehr', data_config)
_features_ehr.load_ehr('data_ver2.csv')
_features_ehr.create_new_dataset(train_split=0.9, tag=1)

loading dataframe
converting datetimes
updating config
creating features


#### Load and Explore Data

In [7]:
import torch

# Reload the tag-1 dataset; no config is passed here, it is read back via `features.config`.
features = Featureset('datasets/ehr')
features.load_dataset(tag=1)

# Pull a single batch of two examples from the test split to use as a sample input below.
dataset = GroupDataset(features.test_features)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn = GroupDataCollator(features.config))
data = next(iter(data_loader))

# Show the shape of every tensor in the collated batch.
{k:v.shape for k,v in data.items()}

{'label': torch.Size([2, 1]),
 'race': torch.Size([2, 1]),
 'ethnic': torch.Size([2, 1]),
 'sex': torch.Size([2, 1, 4]),
 'marital_status': torch.Size([2, 1, 10]),
 'age': torch.Size([2, 1, 1]),
 'srel': torch.Size([2, 163, 1]),
 'wday': torch.Size([2, 163]),
 'seq_mask': torch.Size([2, 163])}

In [8]:
# Show the stored Parameters for every batch key that also has a config entry
# (batch keys without one, such as 'seq_mask' in the output above, are skipped by the filter).
{k:features.config[k] for k,v in data.items() if k in features.config}

{'label': Parameters(mode='categorical', vector=None, sub_mode=None, token_map={0: 0, 1: 1}, size=2, min=None, max=None, range=None, max_len=None, d_layer_1=None, d_layer_2=None, d_layer_3=None, d_layer_4=None),
 'race': Parameters(mode='categorical', vector='embedding', sub_mode=None, token_map={0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7}, size=8, min=None, max=None, range=None, max_len=None, d_layer_1=None, d_layer_2=None, d_layer_3=None, d_layer_4=None),
 'ethnic': Parameters(mode='categorical', vector='embedding', sub_mode=None, token_map={0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8}, size=9, min=None, max=None, range=None, max_len=None, d_layer_1=None, d_layer_2=None, d_layer_3=None, d_layer_4=None),
 'sex': Parameters(mode='categorical', vector='onehot', sub_mode=None, token_map={1: 0, 2: 1, 0: 2, 3: 3}, size=4, min=None, max=None, range=None, max_len=None, d_layer_1=None, d_layer_2=None, d_layer_3=None, d_layer_4=None),
 'marital_status': Parameters(mode='categorica

# (2) Explore Group Model

In [9]:
from src.modules import MergeLayer, FlattenLayer, GroupLayer, GroupModel
from src.modules import NonLinear, EmbeddingNonLinear
from src.modules import GRU, EmbeddingGRU

In [10]:
# Group model built from named layers: embed -> merge -> groups_1 -> flat -> proj -> pred.
# NOTE(review): the argument meanings below are inferred from the shapes printed in this
# notebook (e.g. NonLinear(in, hidden, out), EmbeddingNonLinear(vocab, out)) — confirm
# against src.modules. Layers are presumably applied in insertion order — confirm.
model = GroupModel({
    # One block per input feature; every block's displayed output is 4 wide.
    'embed':GroupLayer({
        'race':EmbeddingNonLinear(8, 4),
        'ethnic':EmbeddingNonLinear(9, 4),
        'sex':NonLinear(4, 8, 4),
        'marital_status':NonLinear(10, 8, 4),
        'age':NonLinear(1, 2, 4),
        'srel':GRU(1, 2, 4),
        'wday':EmbeddingGRU(8, 4, 2, 4)
    }), 
    
    # Group name -> tuple of feature names whose tensors are joined along the last axis
    # (the displayed 'bkgd' output is race followed by ethnic, 8 wide).
    'merge':MergeLayer({
        'bkgd':('race', 'ethnic'),
        'love':('sex', 'marital_status'),
        'age':('age', ),
        'times':('srel', 'wday')
    }),
    
    # One block per merged group; first argument matches the merged width (8, 8, 4, 8).
    'groups_1':GroupLayer({
        'bkgd':NonLinear(8, 16, 8),
        'love':NonLinear(8, 16, 8),
        'age':NonLinear(4, 8, 4),
        'times':NonLinear(8, 6, 4)
    }),
    
    'flat':FlattenLayer('times'),
    
    # Join all four groups into one 'pred' input: 8 + 8 + 4 + 4 = 24 features.
    'proj':MergeLayer({
        'pred':('bkgd', 'love', 'age', 'times'),
    }), 
    
    # Final block: 24 inputs -> 2 outputs (the label config has size=2).
    'pred':GroupLayer({
        'pred':NonLinear(24, 64, 2)
    })
})

# Put the model in evaluation mode for the exploration cells below.
model.eval()
pass;

#### Example Prediction

In [11]:
# Forward the two-example sample batch through the whole model (output: one row per example, two columns).
model(data)

tensor([[0.1116, 0.0681],
        [0.1044, 0.0642]], grad_fn=<AddmmBackward>)

#### Run data through first layer 'embed'

In [12]:
# Apply only the 'embed' layer: the result is a dict with one tensor per input feature.
model.layers['embed'](data)

{'race': tensor([[[-0.1746, -0.2779,  0.1017,  0.2228]],
 
         [[-0.1746, -0.2779,  0.1017,  0.2228]]], grad_fn=<AddBackward0>),
 'ethnic': tensor([[[ 0.0458,  0.2116, -0.2210, -0.2679]],
 
         [[ 0.0458,  0.2116, -0.2210, -0.2679]]], grad_fn=<AddBackward0>),
 'sex': tensor([[[-0.2944,  0.4751,  0.4800, -0.1711]],
 
         [[-0.1906,  0.2703,  0.4496, -0.0269]]], grad_fn=<AddBackward0>),
 'marital_status': tensor([[[ 0.0701, -0.3316, -0.1484,  0.0023]],
 
         [[ 0.1060, -0.4011, -0.0540, -0.0918]]], grad_fn=<AddBackward0>),
 'age': tensor([[[-0.4764, -0.6471, -0.5324, -0.5595]],
 
         [[-0.4530, -0.6569, -0.5448, -0.5585]]], grad_fn=<AddBackward0>),
 'srel': tensor([[[ 0.7127, -0.4101,  0.0724, -0.2590]],
 
         [[-0.1450,  0.2946, -1.0849, -0.3482]]], grad_fn=<MeanBackward1>),
 'wday': tensor([[[ 0.4433, -0.3746,  0.1728, -0.5243]],
 
         [[ 0.2181, -0.2473,  0.2206, -0.4935]]], grad_fn=<MeanBackward1>)}

#### Run data through first two layers 'embed', 'merge'

In [13]:
# Chain 'embed' -> 'merge': the result is keyed by group name and each value is the
# concatenation of the group's member features along the last axis.
model.layers['merge'](model.layers['embed'](data))

{'bkgd': tensor([[[-0.1746, -0.2779,  0.1017,  0.2228,  0.0458,  0.2116, -0.2210,
           -0.2679]],
 
         [[-0.1746, -0.2779,  0.1017,  0.2228,  0.0458,  0.2116, -0.2210,
           -0.2679]]], grad_fn=<CatBackward>),
 'love': tensor([[[-0.2944,  0.4751,  0.4800, -0.1711,  0.0701, -0.3316, -0.1484,
            0.0023]],
 
         [[-0.1906,  0.2703,  0.4496, -0.0269,  0.1060, -0.4011, -0.0540,
           -0.0918]]], grad_fn=<CatBackward>),
 'age': tensor([[[-0.4764, -0.6471, -0.5324, -0.5595]],
 
         [[-0.4530, -0.6569, -0.5448, -0.5585]]], grad_fn=<CatBackward>),
 'times': tensor([[[ 0.2144, -0.0826, -0.5119, -0.7036,  0.5094, -0.3008, -0.0827,
           -0.6575]],
 
         [[ 0.0427,  0.2964, -0.8879, -0.2301,  0.0534, -0.0280, -0.1910,
           -0.1676]]], grad_fn=<CatBackward>)}

#### Run 'race' through 'race' block of first layer

In [14]:
# Call the 'race' block directly on the raw 'race' tensor; the displayed result matches the
# 'race' entry of the full 'embed' output.
model.layers['embed'].blocks['race'](data['race'])

tensor([[[-0.1746, -0.2779,  0.1017,  0.2228]],

        [[-0.1746, -0.2779,  0.1017,  0.2228]]], grad_fn=<AddBackward0>)

#### Run 'srel' through 'srel' block of first layer

In [15]:
# Sequence blocks take the sequence mask as a second argument.
# NOTE(review): the 'srel' values displayed here differ from the 'srel' values in the
# 'embed' and 'merge' outputs above even though the model is in eval mode — the block looks
# non-deterministic between calls (random initial state?) — confirm.
model.layers['embed'].blocks['srel'](data['srel'], data['seq_mask'])

tensor([[[ 0.1734,  0.1998, -1.0488, -0.2826]],

        [[ 0.2049, -0.0022, -0.4994,  0.2558]]], grad_fn=<MeanBackward1>)

# (3) Train Group Model

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import datetime

In [17]:
# Reload the tag-1 dataset and wrap both splits in DataLoaders (batches of 100).
features = Featureset('datasets/ehr')
features.load_dataset(tag=1)

# Training batches are reshuffled on every pass over the loader.
trainloader = torch.utils.data.DataLoader(
    GroupDataset(features.train_features), batch_size=100, shuffle=True, 
    collate_fn=GroupDataCollator(features.config)
)

# Test batches keep a fixed order.
testloader = torch.utils.data.DataLoader(
    GroupDataset(features.test_features), batch_size=100, shuffle=False, 
    collate_fn=GroupDataCollator(features.config)
)

In [18]:
# Training hyper-parameters: report the losses every `print_every` epochs, for `epochs` epochs in total.
print_every = 2
epochs = 10

# Adam over all parameters of the `model` built in the exploration section above;
# the loss is cross-entropy over the model's two output columns.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9)
criterion = nn.CrossEntropyLoss()

In [19]:
# Train `model` for `epochs` epochs. After each epoch the test split is evaluated without
# gradients, and the mean per-batch train/test losses are printed every `print_every` epochs.
time_start = datetime.datetime.now()
print('Start Time: %s'%time_start.strftime('%H:%M:%S'))

for epoch in range(0, epochs):
    train_loss = 0.0
    test_loss = 0.0
    train_nbatches = 0
    test_nbatches = 0

    # --- training pass ---
    model.train()
    for inputs in trainloader:
        preds = model(inputs)

        # Labels are collated as (batch, 1). Flatten them to (batch,) explicitly: a bare
        # .squeeze() would also drop the batch axis when a batch holds a single example,
        # leaving a 0-dim target that CrossEntropyLoss cannot pair with (1, n_classes) logits.
        loss = criterion(preds, inputs['label'].reshape(-1))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        train_loss += float(loss)
        train_nbatches += 1

    # --- evaluation pass (no gradient tracking) ---
    model.eval()
    with torch.no_grad():
        for inputs in testloader:
            preds = model(inputs)
            loss = criterion(preds, inputs['label'].reshape(-1))
            test_loss += float(loss)
            test_nbatches += 1

    # Average over batches (not over examples).
    train_loss/=train_nbatches
    test_loss/=test_nbatches

    if epoch%print_every == 0:
        print('Epoch {} || Train Loss: {:.3f} || Test Loss: {:.3f}'.format(
            str(epoch).zfill(3), train_loss, test_loss)
             )
time_finish = datetime.datetime.now()
print('End Time: %s'%time_finish.strftime('%H:%M:%S'))
print('Completed in %s seconds'%(time_finish-time_start).total_seconds())
pass;

Start Time: 11:10:25
Epoch 000 || Train Loss: 0.681 || Test Loss: 0.673
Epoch 002 || Train Loss: 0.646 || Test Loss: 0.655
Epoch 004 || Train Loss: 0.641 || Test Loss: 0.646
Epoch 006 || Train Loss: 0.635 || Test Loss: 0.642
Epoch 008 || Train Loss: 0.632 || Test Loss: 0.634
End Time: 11:11:13
Completed in 47.968401 seconds


In [20]:
# Leave the trained model in evaluation mode.
# The trailing `pass;` presumably keeps the module repr from being echoed (the notebook sets
# ast_node_interactivity = "all") — confirm.
model.eval()
pass;

In [21]:
# NOTE(review): `blah` is not defined anywhere in the notebook, so this cell raises NameError
# (see its output). Presumably a deliberate stop so that "Run All" halts before the scratch
# cells below — confirm, and consider replacing it with an explicit, self-describing stop.
blah

NameError: name 'blah' is not defined

# Scratch (Attention)

In [None]:
def rel_shift(x, klen=-1):
    """Relative shift along the first two axes of a 4-D score tensor.

    The tensor is viewed with its first two axis sizes swapped, the leading slice of that
    view is dropped, the remainder is viewed back with axis 1 one shorter, and finally the
    first ``klen`` entries along axis 1 are kept.
    """
    d0, d1, d2, d3 = (x.shape[axis] for axis in range(4))

    # Swap the first two axis sizes (a reshape, not a transpose) and drop the leading slice.
    shifted = x.reshape(d1, d0, d2, d3)[1:, ...]
    shifted = shifted.reshape(d0, d1 - 1, d2, d3)

    keep = torch.arange(klen, device=x.device, dtype=torch.long)
    return torch.index_select(shifted, 1, keep)

def rel_shift_bnij(x, klen=-1):
    """Relative shift along the last two axes of a 4-D score tensor.

    The tensor is viewed with its last two axis sizes swapped, the first row of that view is
    dropped, the remainder is viewed back with the last axis one shorter, and finally the
    first ``klen`` entries along the last axis are kept.
    """
    d0, d1, d2, d3 = (x.shape[axis] for axis in range(4))

    # Swap the last two axis sizes (a reshape, not a transpose) and drop the first row.
    shifted = x.reshape(d0, d1, d3, d2)[:, :, 1:, :]
    shifted = shifted.reshape(d0, d1, d2, d3 - 1)

    keep = torch.arange(klen, device=x.device, dtype=torch.long)
    return torch.index_select(shifted, 3, keep)

In [None]:
import numpy as np

def get_n_heads(d, d_head):
    """Return how many heads of width ``d_head`` are needed to cover ``d`` features.

    Anything narrower than a single head still gets one head; otherwise the count is
    ``ceil(d / d_head)`` as a plain ``int``.
    """
    return 1 if d < d_head else int(np.ceil(d / d_head))

class GroupAttention(nn.Module):
    """Multi-head attention over concatenated feature groups (scratch adaptation of
    XLNet-style relative attention).

    Args:
        config: dict with
            * ``'q'``: mapping of group name -> feature width for the query side,
            * ``'k'``: mapping of group name -> feature width for the key/value side,
            * ``'d_head'``: width of each attention head.

    The position-based score term is currently disabled (``bd = 0.`` in ``rel_attn_core``),
    so ``self.r`` and the ``rel_shift*`` helpers are not used by ``forward``.

    NOTE(review): the einsum subscripts treat axis 0 of the inputs as the attended position
    (``i``/``j``) and axis 1 as the batch (``b``), while the batches displayed earlier in
    this notebook are batch-first (e.g. 'srel' is [2, 163, 1] for batch_size=2) — confirm
    the intended layout before relying on the attention pattern.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.scale = 1 / (config['d_head'] ** 0.5)

        q, z, r = self.init_queries(config['q'], config['d_head'])
        k, v = self.init_keys(config['k'], config['d_head'])

        self.q = nn.Parameter(q)
        self.z = nn.Parameter(z)
        self.r = nn.Parameter(r)
        self.k = nn.Parameter(k)
        self.v = nn.Parameter(v)
#         self.r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
        # Start the content bias at zero; `torch.FloatTensor(1, d)` would leave it holding
        # whatever bytes happened to be in the freshly allocated memory.
        self.w_bias = nn.Parameter(torch.zeros(1, config['d_head']))

        # The normalised axis is the concatenated query width (residual input and output
        # of `post_attention` are both that wide), so derive it from the config.
        self.layer_norm = nn.LayerNorm(sum(config['q'].values()), eps=1e-5)
        self.dropout = nn.Dropout(0.0)

    @staticmethod
    def _head_views(group_config, n_heads):
        """Map each group name to the contiguous slice of heads allotted to it."""
        views = {}
        start = 0
        for name, count in zip(group_config.keys(), n_heads):
            # Each slice starts where the previous group's heads ended.
            views[name] = slice(start, start + count)
            start += count
        return views

    def init_queries(self, q_config, d_head):
        """Create the query-side weights and record ``self.q_head_view``.

        Returns:
            Three tensors ``(q, z, r)``, each of shape
            ``(sum of query widths, total number of heads, d_head)`` and drawn uniformly
            from [-0.1, 0.1].
        """
        q_dims = list(q_config.values())
        n_heads = [get_n_heads(d, d_head) for d in q_dims]

        self.q_head_view = self._head_views(q_config, n_heads)

        q = torch.FloatTensor(sum(q_dims), sum(n_heads), d_head).uniform_(-0.1, 0.1)
        z = torch.FloatTensor(sum(q_dims), sum(n_heads), d_head).uniform_(-0.1, 0.1)
        r = torch.FloatTensor(sum(q_dims), sum(n_heads), d_head).uniform_(-0.1, 0.1)
        return q, z, r

    def init_keys(self, k_config, d_head):
        """Create the key/value-side weights and record ``self.k_head_view``.

        Returns:
            Two tensors ``(k, v)``, each of shape
            ``(sum of key widths, total number of heads, d_head)`` and drawn uniformly
            from [-0.1, 0.1].
        """
        k_dims = list(k_config.values())
        n_heads = [get_n_heads(d, d_head) for d in k_dims]

        self.k_head_view = self._head_views(k_config, n_heads)

        k = torch.FloatTensor(sum(k_dims), sum(n_heads), d_head).uniform_(-0.1, 0.1)
        v = torch.FloatTensor(sum(k_dims), sum(n_heads), d_head).uniform_(-0.1, 0.1)
        return k, v

    @staticmethod
    def rel_shift(x, klen=-1):
        """perform relative shift to form the relative attention score."""
        x_size = x.shape

        x = x.reshape(x_size[1], x_size[0], x_size[2], x_size[3])
        x = x[1:, ...]
        x = x.reshape(x_size[0], x_size[1] - 1, x_size[2], x_size[3])
        x = torch.index_select(x, 1, torch.arange(klen, device=x.device, dtype=torch.long))
        return x

    @staticmethod
    def rel_shift_bnij(x, klen=-1):
        """Same shift as ``rel_shift`` but applied to the last two axes."""
        x_size = x.shape
        x = x.reshape(x_size[0], x_size[1], x_size[3], x_size[2])
        x = x[:, :, 1:, :]
        x = x.reshape(x_size[0], x_size[1], x_size[2], x_size[3] - 1)
        x = torch.index_select(x, 3, torch.arange(klen, device=x.device, dtype=torch.long))
        return x

    def rel_attn_core(
        self,
        q_head,
        k_head_h,
        v_head_h,
        k_head_r=None,
        attn_mask=None,
        head_mask=None,
        output_attentions=False,
    ):
        """Core attention: scaled content scores, optional masking, softmax, weighted values.

        ``k_head_r`` is accepted for interface compatibility but unused while the
        position-based term is disabled. Returns ``attn_vec`` or, when
        ``output_attentions`` is true, ``(attn_vec, attn_prob)``.
        """

        # content based attention score
        ac = torch.einsum("ibnd,jbnd->bnij", q_head + self.w_bias, k_head_h)

        # position based attention score
#         bd = torch.einsum("ibnd,jbnd->bnij", q_head + self.r_bias, k_head_r)
#         bd = self.rel_shift_bnij(bd, klen=ac.shape[3])
        bd = 0.

        # merge attention scores and perform masking
        attn_score = (ac + bd) * self.scale
        if attn_mask is not None:
            # attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask
            # Use a smaller penalty for half precision so it stays representable in fp16.
            if attn_mask.dtype == torch.float16:
                attn_score = attn_score - 65500 * torch.einsum("ijbn->bnij", attn_mask)
            else:
                attn_score = attn_score - 1e30 * torch.einsum("ijbn->bnij", attn_mask)

        # attention probability
        attn_prob = F.softmax(attn_score, dim=3)
        attn_prob = self.dropout(attn_prob)

        # Mask heads if we want to
        if head_mask is not None:
            attn_prob = attn_prob * torch.einsum("ijbn->bnij", head_mask)

        # attention output
        attn_vec = torch.einsum("bnij,jbnd->ibnd", attn_prob, v_head_h)

        if output_attentions:
            return attn_vec, torch.einsum("bnij->ijbn", attn_prob)

        return attn_vec

    def post_attention(self, h, attn_vec, residual=True):
        """Project the per-head vectors back with ``self.z``, add the residual, layer-normalise."""
        # post-attention projection (back to `d_model`)
        attn_out = torch.einsum("ibnd,hnd->ibh", attn_vec, self.z)

        attn_out = self.dropout(attn_out)
        if residual:
            attn_out = attn_out + h
        output = self.layer_norm(attn_out)

        return output

    def forward(
        self,
        h,
        attn_mask_h=None,
        r=None,
        head_mask=None,
        output_attentions=False,
    ):
        """Attend from the query groups to the key groups.

        Args:
            h: mapping of group name -> embedded tensor; it must contain every group named
                in ``config['q']`` and ``config['k']``.
            attn_mask_h: optional mask forwarded to ``rel_attn_core``.
            r: unused (positional term disabled).
            head_mask: optional per-head mask forwarded to ``rel_attn_core``.
            output_attentions: when true, the attention probabilities are returned too.

        Returns:
            ``(output_h,)`` or ``(output_h, attn_prob)``.
        """
        # Build the query/key inputs from the mapping that was passed in and from this
        # module's own config, rather than from notebook-level globals.
        groups = h
        h = torch.cat([groups[name] for name in self.config['q'].keys()], dim=-1)
        g = torch.cat([groups[name] for name in self.config['k'].keys()], dim=-1)

        # content heads
        q_head_h = torch.einsum("ibh,hnd->ibnd", h, self.q)
        k_head_h = torch.einsum("ibh,hnd->ibnd", g, self.k)
        v_head_h = torch.einsum("ibh,hnd->ibnd", g, self.v)

        # positional heads
        # type casting for fp16 support
#         k_head_r = torch.einsum("ibh,hnd->ibnd", r.type(self.r.dtype), self.r)

        # core attention ops
        attn_vec = self.rel_attn_core(
            q_head_h,
            k_head_h,
            v_head_h,
            k_head_r=None,
            attn_mask=attn_mask_h,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )

        if output_attentions:
            attn_vec, attn_prob = attn_vec

        # post processing
        output_h = self.post_attention(h, attn_vec)

        outputs = (output_h,)
        if output_attentions:
            outputs = outputs + (attn_prob,)
        return outputs

In [None]:
# Load the tag-1 dataset through Featureset, as the earlier cells do; no free function
# named `load_dataset` is defined or imported in this notebook.
scratch_features = Featureset('datasets/ehr')
scratch_features.load_dataset(tag=1)
train_features = scratch_features.train_features
test_features = scratch_features.test_features
config = scratch_features.config

# Take one batch of five training examples as the scratch input.
dataset = GroupDataset(train_features)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=5, 
                                         collate_fn=GroupDataCollator(config))
data = next(iter(dataloader))

In [None]:
# Plain per-feature embedding/projection blocks used as inputs for the attention scratch work.
# NOTE(review): several sizes here differ from what the earlier cells display for the
# current dataset ('sex' one-hot width 4 vs Linear(5, ...), 'marital_status' width 10 vs
# Linear(11, ...)) — confirm this cell still matches the data.
testemb = GroupLayer({
        'race':nn.Embedding(8, 4, padding_idx=0),
        'ethnic':nn.Embedding(9, 8, padding_idx=0),
        'sex':nn.Linear(5, 4),
        'marital_status':nn.Linear(11, 8),
        'age':nn.Linear(1, 4),
        'srel':nn.Linear(1, 4),
        'wday':nn.Embedding(9, 8)
    })

# Embed the scratch batch; the result is referenced below as `data_emb`.
data_emb = testemb(data)

In [None]:
# Show the shape of every embedded tensor.
{k:v.shape for k,v in data_emb.items()}

In [None]:
# Attention config: query and key sides both use 'srel' (width 4) and 'wday' (width 8),
# matching the output widths of the corresponding `testemb` blocks; heads are 6 wide,
# so get_n_heads gives 1 head for 'srel' and 2 for 'wday'.
attn_config_1 = {
    'q':{'srel':4, 'wday':8},
    'k':{'srel':4, 'wday':8},
    'd_head':6
}

testattn = GroupAttention(attn_config_1)

In [None]:
# Run the attention module twice to display the shapes of both returned items
# (index 0: attended output, index 1: attention probabilities), then the head slices per query group.
testattn(data_emb, output_attentions=True)[0].shape
testattn(data_emb, output_attentions=True)[1].shape
testattn.q_head_view

In [None]:
class XLNetLayer(nn.Module):
    """Scratch XLNet-style layer: relative attention followed by a feed-forward block.

    NOTE(review): ``XLNetRelativeAttention`` and ``XLNetFeedForward`` are not defined or
    imported in the visible notebook cells, so this class cannot be instantiated as-is —
    confirm where they are meant to come from.
    """

    def __init__(self, config):
        super().__init__()
        self.rel_attn = XLNetRelativeAttention(config)
        self.ff = XLNetFeedForward(config)
        # Created but not used in `forward` below.
        self.dropout = nn.Dropout(config.dropout)

    def forward(
        self,
        output_h,
        output_g,
        attn_mask_h,
        attn_mask_g,
        r,
        seg_mat,
        mems=None,
        target_mapping=None,
        head_mask=None,
        output_attentions=False,
    ):
        """Run attention then feed-forward on ``output_h``.

        ``output_g`` and ``attn_mask_g`` are accepted but never read in this body.
        """
        outputs = self.rel_attn(
            output_h,
            attn_mask_h,
            r,
            seg_mat,
            mems=mems,
            target_mapping=target_mapping,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )
        # NOTE(review): `outputs[:1]` is a one-element slice, not the first element, so
        # `self.ff` receives a sequence rather than a tensor — `outputs[0]` may have been
        # intended; confirm against the attention module's return layout.
        output_h = outputs[:1]

        output_h = self.ff(output_h)

        # NOTE(review): `(output_h)` has no trailing comma, so it is not a tuple; and
        # `outputs[2:]` skips index 1 — confirm the intended return layout.
        outputs = (output_h) + outputs[2:]  # Add again attentions if there are there
        return outputs

In [None]:
class XLNetModel(XLNetPreTrainedModel):
    """Scratch copy of an XLNet encoder: embeddings, a stack of ``XLNetLayer`` modules,
    relative positional encodings and optional memory caching.

    NOTE(review): ``XLNetPreTrainedModel`` and ``XLNetModelOutput`` are not defined or
    imported in the visible notebook cells; the code appears to be adapted from the
    Hugging Face ``transformers`` XLNet implementation — confirm provenance and imports.
    """

    def __init__(self, config):
        super().__init__(config)

        self.mem_len = config.mem_len
        self.reuse_len = config.reuse_len
        self.d_model = config.d_model
        self.same_length = config.same_length
        self.attn_type = config.attn_type
        self.bi_data = config.bi_data
        self.clamp_len = config.clamp_len
        self.n_layer = config.n_layer

        self.word_embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.mask_emb = nn.Parameter(torch.FloatTensor(1, 1, config.d_model))
        self.layer = nn.ModuleList([XLNetLayer(config) for _ in range(config.n_layer)])
        self.dropout = nn.Dropout(config.dropout)

        self.init_weights()

    def get_input_embeddings(self):
        """Return the word-embedding module."""
        return self.word_embedding

    def set_input_embeddings(self, new_embeddings):
        """Replace the word-embedding module."""
        self.word_embedding = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """Head pruning is not supported."""
        raise NotImplementedError

    def create_mask(self, qlen, mlen):
        """
        Creates causal attention mask. Float mask where 1.0 indicates masked, 0.0 indicates not-masked.

        Args:
            qlen: Sequence length
            mlen: Mask length

        ::

                  same_length=False:      same_length=True:
                  <mlen > <  qlen >       <mlen > <  qlen >
               ^ [0 0 0 0 0 1 1 1 1]     [0 0 0 0 0 1 1 1 1]
                 [0 0 0 0 0 0 1 1 1]     [1 0 0 0 0 0 1 1 1]
            qlen [0 0 0 0 0 0 0 1 1]     [1 1 0 0 0 0 0 1 1]
                 [0 0 0 0 0 0 0 0 1]     [1 1 1 0 0 0 0 0 1]
               v [0 0 0 0 0 0 0 0 0]     [1 1 1 1 0 0 0 0 0]

        """
        attn_mask = torch.ones([qlen, qlen])
        mask_up = torch.triu(attn_mask, diagonal=1)
        attn_mask_pad = torch.zeros([qlen, mlen])
        ret = torch.cat([attn_mask_pad, mask_up], dim=1)
        if self.same_length:
            mask_lo = torch.tril(attn_mask, diagonal=-1)
            ret = torch.cat([ret[:, :qlen] + mask_lo, ret[:, qlen:]], dim=1)

        ret = ret.to(self.device)
        return ret

    def cache_mem(self, curr_out, prev_mem):
        """Return the detached memory to carry forward, built from ``prev_mem`` and ``curr_out``."""
        # cache hidden states into memory.
        if self.reuse_len is not None and self.reuse_len > 0:
            curr_out = curr_out[: self.reuse_len]

        if self.mem_len is None or self.mem_len == 0:
            # If :obj:`use_cache` is active but no `mem_len` is defined, the model behaves like GPT-2 at inference time
            # and returns all of the past and current hidden states.
            cutoff = 0
        else:
            # If :obj:`use_cache` is active and `mem_len` is defined, the model returns the last `mem_len` hidden
            # states. This is the preferred setting for training and long-form generation.
            cutoff = -self.mem_len
        if prev_mem is None:
            # No previous memory: keep the current hidden states from `cutoff` onwards.
            new_mem = curr_out[cutoff:]
        else:
            new_mem = torch.cat([prev_mem, curr_out], dim=0)[cutoff:]

        return new_mem.detach()

    @staticmethod
    def positional_embedding(pos_seq, inv_freq, bsz=None):
        """Sinusoidal embedding of ``pos_seq``: sin and cos of the outer product with
        ``inv_freq``, concatenated on the last axis, with a middle axis of size 1 that is
        expanded to ``bsz`` when given."""
        sinusoid_inp = torch.einsum("i,d->id", pos_seq, inv_freq)
        pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], dim=-1)
        pos_emb = pos_emb[:, None, :]

        if bsz is not None:
            pos_emb = pos_emb.expand(-1, bsz, -1)

        return pos_emb

    def relative_positional_encoding(self, qlen, klen, bsz=None):
        """Build the relative positional embedding for the configured ``attn_type``.

        Raises:
            ValueError: if ``self.attn_type`` is neither ``"bi"`` nor ``"uni"``.
        """
        # create relative positional encoding.
        freq_seq = torch.arange(0, self.d_model, 2.0, dtype=torch.float)
        inv_freq = 1 / torch.pow(10000, (freq_seq / self.d_model))

        if self.attn_type == "bi":
            # beg, end = klen - 1, -qlen
            beg, end = klen, -qlen
        elif self.attn_type == "uni":
            # beg, end = klen - 1, -1
            beg, end = klen, -1
        else:
            raise ValueError("Unknown `attn_type` {}.".format(self.attn_type))

        if self.bi_data:
            # Forward and backward position sequences each cover half of the batch.
            fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.float)
            bwd_pos_seq = torch.arange(-beg, -end, 1.0, dtype=torch.float)

            if self.clamp_len > 0:
                fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
                bwd_pos_seq = bwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)

            if bsz is not None:
                fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz // 2)
                bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz // 2)
            else:
                fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq)
                bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq)

            pos_emb = torch.cat([fwd_pos_emb, bwd_pos_emb], dim=1)
        else:
            fwd_pos_seq = torch.arange(beg, end, -1.0)
            if self.clamp_len > 0:
                fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
            pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)

        pos_emb = pos_emb.to(self.device)
        return pos_emb

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        mems=None,
        perm_mask=None,
        target_mapping=None,
        token_type_ids=None,
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode a batch.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given, and at most one of
        ``input_mask`` (1 marks padding) / ``attention_mask`` (0 marks padding).

        Returns:
            An ``XLNetModelOutput`` when ``return_dict`` is true, otherwise a tuple of the
            non-``None`` items among (output, new_mems, hidden_states, attentions).

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are given,
                or if ``self.attn_type`` is unsupported.
            AssertionError: if both ``input_mask`` and ``attention_mask`` are given.
        """

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)

        # the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
        # but we want a unified interface in the library with the batch size on the first dimension
        # so we move here the first dimension (batch) to the end
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_ids = input_ids.transpose(0, 1).contiguous()
            qlen, bsz = input_ids.shape[0], input_ids.shape[1]
        elif inputs_embeds is not None:
            inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
            qlen, bsz = inputs_embeds.shape[0], inputs_embeds.shape[1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        token_type_ids = token_type_ids.transpose(0, 1).contiguous() if token_type_ids is not None else None
        input_mask = input_mask.transpose(0, 1).contiguous() if input_mask is not None else None
        attention_mask = attention_mask.transpose(0, 1).contiguous() if attention_mask is not None else None
        perm_mask = perm_mask.permute(1, 2, 0).contiguous() if perm_mask is not None else None
        target_mapping = target_mapping.permute(1, 2, 0).contiguous() if target_mapping is not None else None

        mlen = mems[0].shape[0] if mems is not None and mems[0] is not None else 0
        klen = mlen + qlen

        dtype_float = self.dtype
        device = self.device

        # Attention mask
        # causal attention mask
        if self.attn_type == "uni":
            attn_mask = self.create_mask(qlen, mlen)
            attn_mask = attn_mask[:, :, None, None]
        elif self.attn_type == "bi":
            attn_mask = None
        else:
            raise ValueError("Unsupported attention type: {}".format(self.attn_type))

        # data mask: input mask & perm mask
        # The message is parenthesised so that both string pieces belong to the assert;
        # without the parentheses the second piece is a separate no-op statement.
        assert input_mask is None or attention_mask is None, (
            "You can only use one of input_mask (uses 1 for padding) "
            "or attention_mask (uses 0 for padding, added for compatibility with BERT). Please choose one."
        )
        if input_mask is None and attention_mask is not None:
            input_mask = 1.0 - attention_mask
        if input_mask is not None and perm_mask is not None:
            data_mask = input_mask[None] + perm_mask
        elif input_mask is not None and perm_mask is None:
            data_mask = input_mask[None]
        elif input_mask is None and perm_mask is not None:
            data_mask = perm_mask
        else:
            data_mask = None

        if data_mask is not None:
            # all mems can be attended to
            if mlen > 0:
                mems_mask = torch.zeros([data_mask.shape[0], mlen, bsz]).to(data_mask)
                data_mask = torch.cat([mems_mask, data_mask], dim=1)
            if attn_mask is None:
                attn_mask = data_mask[:, :, :, None]
            else:
                attn_mask += data_mask[:, :, :, None]

        if attn_mask is not None:
            attn_mask = (attn_mask > 0).to(dtype_float)

        if attn_mask is not None:
            # Subtracting the identity un-masks each position's own entry in the
            # mask that is passed to the layers as `attn_mask_h`.
            non_tgt_mask = -torch.eye(qlen).to(attn_mask)
            if mlen > 0:
                non_tgt_mask = torch.cat([torch.zeros([qlen, mlen]).to(attn_mask), non_tgt_mask], dim=-1)
            non_tgt_mask = ((attn_mask + non_tgt_mask[:, :, None, None]) > 0).to(attn_mask)
        else:
            non_tgt_mask = None

        # Word embeddings and prepare h & g hidden states
        if inputs_embeds is not None:
            word_emb_k = inputs_embeds
        else:
            word_emb_k = self.word_embedding(input_ids)
        output_h = self.dropout(word_emb_k)
        if target_mapping is not None:
            word_emb_q = self.mask_emb.expand(target_mapping.shape[0], bsz, -1)
            # else:  # We removed the inp_q input which was same as target mapping
            #     inp_q_ext = inp_q[:, :, None]
            #     word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
            output_g = self.dropout(word_emb_q)
        else:
            output_g = None

        # Segment embedding
        if token_type_ids is not None:
            # Convert `token_type_ids` to one-hot `seg_mat`
            if mlen > 0:
                mem_pad = torch.zeros([mlen, bsz], dtype=torch.long, device=device)
                cat_ids = torch.cat([mem_pad, token_type_ids], dim=0)
            else:
                cat_ids = token_type_ids

            # `1` indicates not in the same segment [qlen x klen x bsz]
            seg_mat = (token_type_ids[:, None] != cat_ids[None, :]).long()
            seg_mat = F.one_hot(seg_mat, num_classes=2).to(dtype_float)
        else:
            seg_mat = None

        # Positional encoding
        pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
        pos_emb = self.dropout(pos_emb)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
        # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
                head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
            head_mask = head_mask.to(
                dtype=next(self.parameters()).dtype
            )  # switch to float if need + fp16 compatibility
        else:
            head_mask = [None] * self.n_layer

        new_mems = ()
        if mems is None:
            mems = [None] * len(self.layer)

        attentions = [] if output_attentions else None
        hidden_states = [] if output_hidden_states else None
        for i, layer_module in enumerate(self.layer):
            if use_cache:
                # cache new mems
                new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
            if output_hidden_states:
                hidden_states.append((output_h, output_g) if output_g is not None else output_h)

            outputs = layer_module(
                output_h,
                output_g,
                attn_mask_h=non_tgt_mask,
                attn_mask_g=attn_mask,
                r=pos_emb,
                seg_mat=seg_mat,
                mems=mems[i],
                target_mapping=target_mapping,
                head_mask=head_mask[i],
                output_attentions=output_attentions,
            )
            # NOTE(review): this unpacking expects every layer to return
            # (output_h, output_g, [attentions]) — confirm against XLNetLayer.forward.
            output_h, output_g = outputs[:2]
            if output_attentions:
                attentions.append(outputs[2])

        # Add last hidden state
        if output_hidden_states:
            hidden_states.append((output_h, output_g) if output_g is not None else output_h)

        output = self.dropout(output_g if output_g is not None else output_h)

        # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
        output = output.permute(1, 0, 2).contiguous()

        if not use_cache:
            new_mems = None

        if output_hidden_states:
            if output_g is not None:
                hidden_states = tuple(h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs)
            else:
                hidden_states = tuple(hs.permute(1, 0, 2).contiguous() for hs in hidden_states)

        if output_attentions:
            if target_mapping is not None:
                # when target_mapping is provided, there are 2-tuple of attentions
                attentions = tuple(
                    tuple(att_stream.permute(2, 3, 0, 1).contiguous() for att_stream in t) for t in attentions
                )
            else:
                attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions)

        if not return_dict:
            return tuple(v for v in [output, new_mems, hidden_states, attentions] if v is not None)

        return XLNetModelOutput(
            last_hidden_state=output, mems=new_mems, hidden_states=hidden_states, attentions=attentions
        )