In [66]:
import torch

In [67]:
# Toy inputs for the experiment below:
#   A          -- 2 x 6 matrix; its nonzero entries select columns per row.
#   embeddings -- 5 real 3-dim embedding vectors (rows 0-4).
# NOTE(review): A has 6 columns but only 5 real embeddings exist, so once the
# pad row is appended, column index 5 coincides with the padding index.
A = torch.tensor([
    [0., 1., 0., 10., 100., 0.],
    [1., 0., 0., 10., 0., 0.],
])

embeddings = torch.tensor([
    [0., 1., 1.],
    [2., 3., 3.],
    [4., 5., 5.],
    [6., 7., 7.],
    [8., 9., 9.],
])

def add_pad_embedding(embeddings):
    """Append one all-zero row to a 2-D embedding matrix for use as a pad vector.

    Args:
        embeddings: 2-D tensor of shape (num_embeddings, embedding_dim).

    Returns:
        Tensor of shape (num_embeddings + 1, embedding_dim). The new last row is
        zeros, and the result keeps the input's dtype and device.
    """
    # new_zeros inherits dtype/device from `embeddings`. torch.zeros would make a
    # default-dtype CPU tensor, which upcasts half-precision tables and fails
    # outright for CUDA tables.
    pad_row = embeddings.new_zeros(1, embeddings.shape[1])
    return torch.cat([embeddings, pad_row])

# Append the zero pad row and wrap the result in a trainable Embedding whose
# padding index is the last row. from_pretrained normalises -1 to the concrete
# index, which the repr shows.
embeddings = torch.nn.Embedding.from_pretrained(
    add_pad_embedding(embeddings), freeze=False, padding_idx=-1
)
display(embeddings, embeddings.weight, A)

Embedding(6, 3, padding_idx=5)

Parameter containing:
tensor([[0., 1., 1.],
        [2., 3., 3.],
        [4., 5., 5.],
        [6., 7., 7.],
        [8., 9., 9.],
        [0., 0., 0.]], requires_grad=True)

tensor([[  0.,   1.,   0.,  10., 100.,   0.],
        [  1.,   0.,   0.,  10.,   0.,   0.]])

In [68]:
# Scratch design notes for the helpers below. This is a bare string expression
# that is not the cell's last line, so it is evaluated and discarded with no output.
'''
1. turn nonzeros into indices
2. use indices as in normal transformer to get embedding,
pretty much like token indices - should I include padding tokens?
3. but also somehow use these indices to extract the output
4. it may be that I can also accomplish this via masking
'''
def pad(vec, seq_len, pad_token=-1):
    """Right-pad a 1-D tensor with ``pad_token`` to exactly ``seq_len`` elements.

    The pad values are created with the default float dtype, so ``torch.cat``
    promotes the result. Integer index vectors therefore come back as floats.

    Args:
        vec: 1-D tensor to pad.
        seq_len: required output length.
        pad_token: value used to fill the tail.

    Returns:
        1-D tensor of length ``seq_len``: ``vec`` followed by pad values.

    Raises:
        ValueError: if ``vec`` already has more than ``seq_len`` elements.
    """
    n_pad = seq_len - vec.size(0)
    if n_pad < 0:
        # expand() with a negative size either keeps size 1 (for -1, giving a
        # silently over-long result) or raises an obscure error, so fail loudly.
        raise ValueError(
            f"cannot pad a vector of length {vec.size(0)} to seq_len={seq_len}"
        )
    return torch.cat([vec, torch.Tensor([pad_token]).expand(n_pad)])

def select_nonzero(A, seq_len=10, pad_token=-1):
    """Collect each row's nonzero column indices into one padded tensor.

    For every row of ``A``, the positions of its nonzero entries (in increasing
    order) are right-padded with ``pad_token`` to ``seq_len`` via ``pad``. The
    per-row results are then stacked into a (rows, seq_len) tensor.
    """
    return torch.stack([
        pad(torch.nonzero(row, as_tuple=True)[0], seq_len, pad_token)
        for row in A
    ])

def select_nonzero_vec(seq, seq_len=10, pad_token=-1, return_indices=False):
    """Loop-free nonzero selection over the rows of a 2-D tensor.

    Default mode (``return_indices=False``, the original behaviour): returns the
    *values* of ``seq``, with each row's nonzero entries moved to the front in
    their original order and zeros after them. The output has the same shape as
    ``seq``, and ``seq_len`` / ``pad_token`` are not used in this mode.

    Index mode (``return_indices=True``): a vectorised counterpart of
    ``select_nonzero``. It returns each row's nonzero column indices,
    right-padded with ``pad_token`` to exactly ``seq_len`` columns, in the
    default floating dtype.

    Args:
        seq: 2-D tensor (rows x columns).
        seq_len: output width in index mode.
        pad_token: fill value for unused slots in index mode.
        return_indices: select index mode instead of value compaction.

    Raises:
        ValueError: in index mode, if any row has more than ``seq_len`` nonzeros.
    """
    is_zero = seq == 0
    # A stable ascending sort of the 0/1 "is zero" flags puts the nonzero
    # columns first while preserving their left-to-right order. Sorting as
    # uint8 yields the same permutation as sorting the bool mask.
    order = is_zero.to(torch.uint8).sort(dim=1, stable=True).indices
    if not return_indices:
        return seq.gather(1, order)

    counts = (~is_zero).sum(dim=1, keepdim=True)
    if counts.numel() and int(counts.max()) > seq_len:
        raise ValueError(
            f"a row has {int(counts.max())} nonzeros but seq_len={seq_len}"
        )
    n_rows, n_cols = seq.shape
    if seq_len <= n_cols:
        order = order[:, :seq_len]
    else:
        # The extra slots lie beyond every row's nonzero count, so their
        # placeholder content is overwritten by the mask below.
        order = torch.cat([order, order.new_zeros(n_rows, seq_len - n_cols)], dim=1)
    slots = torch.arange(seq_len, device=seq.device)
    return order.to(torch.get_default_dtype()).masked_fill(slots >= counts, pad_token)

def index_tensor(seq, indices, pad_token=None):
    """Per-row lookup: ``out[i, j] = seq[i, indices[i, j]]``.

    Args:
        seq: 2-D tensor to read from (rows x columns).
        indices: 2-D tensor of column indices (any numeric dtype; cast to long),
            one row per row of ``seq``.
        pad_token: optional padding index. When given, slots whose index equals
            ``pad_token`` produce 0 instead of reading ``seq``. When ``None``
            (the default), every index is looked up, matching the previous
            behaviour.

    Returns:
        Tensor shaped like ``indices`` with ``seq``'s dtype.

    Note:
        With ``pad_token=None``, negative indices wrap from the end (e.g. -1
        reads the last column), as before. Padding with -1 therefore silently
        picks up ``seq[:, -1]``; pass ``pad_token`` to avoid that. If
        ``pad_token`` is also a valid column index, genuine entries in that
        column are zeroed too.
    """
    idx = indices.long()
    pad_mask = None
    if pad_token is not None:
        pad_mask = idx == pad_token
        # Point padded slots at a valid column; their values are zeroed below.
        idx = idx.masked_fill(pad_mask, 0)
    # gather rejects negative indices, so wrap them Python-style to keep the
    # old advanced-indexing semantics.
    idx = torch.where(idx < 0, idx + seq.size(1), idx)
    # gather is O(rows * seq_len). The previous seq.T[idx].permute().diagonal()
    # trick built a rows x rows x seq_len intermediate.
    out = seq.gather(1, idx)
    if pad_mask is not None:
        out = out.masked_fill(pad_mask, 0)
    return out

In [70]:
# Show A next to the (row, column) coordinates of its nonzero entries.
display(A)
torch.nonzero(A)

tensor([[  0.,   1.,   0.,  10., 100.,   0.],
        [  1.,   0.,   0.,  10.,   0.,   0.]])

tensor([[0, 1],
        [0, 3],
        [0, 4],
        [1, 0],
        [1, 3]])

In [25]:
# Compare the loop-based index selection (padded column indices) with the
# vectorised variant (compacted values), side by side with the input.
nonzeros = select_nonzero(A)
nonzeros_vec = select_nonzero_vec(A)
display(A, nonzeros, nonzeros_vec)

tensor([[  0.,   1.,   0.,  10., 100.,   0.],
        [  1.,   0.,   0.,  10.,   0.,   0.]])

tensor([[ 1.,  3.,  4., -1., -1., -1., -1., -1., -1., -1.],
        [ 0.,  3., -1., -1., -1., -1., -1., -1., -1., -1.]])

tensor([[  1.,  10., 100.,   0.,   0.,   0.],
        [  1.,  10.,   0.,   0.,   0.,   0.]])

In [26]:
# Re-display A (unchanged since it was defined; no new computation).
A

tensor([[  0.,   1.,   0.,  10., 100.,   0.],
        [  1.,   0.,   0.,  10.,   0.,   0.]])

In [27]:
# Re-display the -1-padded nonzero column indices from select_nonzero(A).
nonzeros

tensor([[ 1.,  3.,  4., -1., -1., -1., -1., -1., -1., -1.],
        [ 0.,  3., -1., -1., -1., -1., -1., -1., -1., -1.]])

In [34]:
# Scratch attempt at a per-row gather. A.T is (6, 2) and nonzeros is (2, 10), so
# advanced indexing gives index_attempt[i, j, k] = A[k, nonzeros[i, j]], a
# (2, 10, 2) tensor; the wanted values are the entries with k == i.
# The -1 padding indexes A.T[-1], i.e. A's last column.
index_attempt = A.T[nonzeros.long()]

In [49]:
# permute(0, 2, 1) -> (2, 2, 10); diagonal() over dims 0 and 1 keeps the entries
# where the two row axes agree -> (10, 2); .T -> (2, 10), so
# out[i, j] = A[i, nonzeros[i, j]]. This builds a rows x rows x seq_len
# intermediate; for non-negative indices torch.gather(A, 1, nonzeros.long())
# gives the same result without it.
index_attempt.permute(0,2,1).diagonal().T

tensor([[  1.,  10., 100.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  1.,  10.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.]])

In [52]:
# Same per-row lookup as the two scratch cells above, via the index_tensor helper.
index_tensor(A, nonzeros)

tensor([[  1.,  10., 100.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  1.,  10.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.]])

In [54]:
# Re-pad the nonzero column indices with the Embedding's padding index, so padded
# slots look up the all-zero pad row. from_pretrained already normalised
# padding_idx=-1 to a concrete index (5 here), so read it from the module rather
# than re-deriving it from the weight shape.
# NOTE: this rebinds `nonzeros`; earlier cells that relied on -1 padding will
# see 5-padding if re-run after this one.
nonzeros = select_nonzero(A, pad_token=embeddings.padding_idx)
nonzeros

tensor([[1., 3., 4., 5., 5., 5., 5., 5., 5., 5.],
        [0., 3., 5., 5., 5., 5., 5., 5., 5., 5.]])

torch.Size([2, 10])

In [55]:
# One embedding row per (row, slot) of `nonzeros`; nn.Embedding needs integer
# indices, and select_nonzero returns floats.
squished_embeddings = embeddings(nonzeros.to(torch.long))
squished_embeddings

tensor([[[2., 3., 3.],
         [6., 7., 7.],
         [8., 9., 9.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 1., 1.],
         [6., 7., 7.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]], grad_fn=<EmbeddingBackward0>)

In [57]:
# nonzeros.shape + (embedding_dim,), i.e. (2, 10, 3) for this toy data.
squished_embeddings.shape

torch.Size([2, 10, 3])

In [49]:
# Show the looked-up embeddings alongside the (pad-filled) indices that produced them.
display(squished_embeddings, nonzeros)

tensor([[[2., 3., 3.],
         [6., 7., 7.],
         [8., 9., 9.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 1., 1.],
         [6., 7., 7.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]], grad_fn=<EmbeddingBackward>)

tensor([[1., 3., 4., 5., 5., 5., 5., 5., 5., 5.],
        [0., 3., 5., 5., 5., 5., 5., 5., 5., 5.]])

In [59]:
# Scale each looked-up embedding by its value in A, i.e. A[i, nonzeros[i, j]].
# Padded slots (index 5) read A's column 5, which is 0 for this toy A. They also
# use the all-zero pad embedding row, so they stay zero as long as that row does.
squished_embeddings * index_tensor(A, nonzeros).unsqueeze(-1)

tensor([[[  2.,   3.,   3.],
         [ 60.,  70.,  70.],
         [800., 900., 900.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.]],

        [[  0.,   1.,   1.],
         [ 60.,  70.,  70.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.]]], grad_fn=<MulBackward0>)

In [62]:
# NOTE(review): this scales each embedding by its column *index* (row 0 gets
# factors 1, 3, 4; row 1's first slot gets 0), not by A's value at that column
# as the index_tensor cell does. That is probably not the intended weighting;
# confirm. Padded slots (index 5) multiply the all-zero pad row and stay zero.
squished_embeddings * nonzeros.unsqueeze(-1)

tensor([[[ 2.,  3.,  3.],
         [18., 21., 21.],
         [32., 36., 36.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.]],

        [[ 0.,  0.,  0.],
         [18., 21., 21.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.]]], grad_fn=<MulBackward0>)

In [58]:
# Duplicate of the earlier shape check: (2, 10, 3) for this toy data.
squished_embeddings.shape

torch.Size([2, 10, 3])

In [71]:
# Load the pretrained DistilBERT encoder (downloads the checkpoint on first run).
# The "weights ... were not used" warning lists only vocab_* parameters; these
# presumably belong to the masked-LM head, which DistilBertModel does not include.
# NOTE(review): this import belongs in the notebook's top import cell.
from transformers import DistilBertModel
model = DistilBertModel.from_pretrained('distilbert-base-uncased')

Downloading config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/256M [00:00<?, ?B/s]

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_projector.weight', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [80]:
# Replace the token-embedding table with a freshly (randomly) initialised one
# that covers a larger vocabulary. The width and padding index are copied from
# the existing table, so they always match the encoder (width 768 for
# distilbert-base-uncased), and re-running this cell is idempotent.
# NOTE: this discards the pretrained token embeddings;
# model.resize_token_embeddings(NEW_VOCAB_SIZE) would keep the existing rows.
NEW_VOCAB_SIZE = 115000
old_word_embeddings = model.embeddings.word_embeddings
model.embeddings.word_embeddings = torch.nn.Embedding(
    NEW_VOCAB_SIZE,
    old_word_embeddings.embedding_dim,
    padding_idx=old_word_embeddings.padding_idx,
)
# Keep the config consistent with the new table so save/reload sees the right shape.
model.config.vocab_size = NEW_VOCAB_SIZE

In [83]:
# Smoke test: one sequence with token ids 1 and 3. last_hidden_state should be
# (1, 2, hidden). torch.tensor infers int64 from Python ints, which is the index
# dtype the embedding lookup needs.
model(torch.tensor([[1, 3]]))

BaseModelOutput(last_hidden_state=tensor([[[ 0.0990,  0.2195,  0.1753,  ...,  0.2309,  0.0430,  0.2853],
         [ 0.3281,  0.2556,  0.2306,  ...,  0.1400, -0.0809,  0.3876]]],
       grad_fn=<NativeLayerNormBackward0>), hidden_states=None, attentions=None)