In [7]:
import itertools

import torch
import matchbox
from torch import nn
from torch.nn import functional as F
from torchtext import data, datasets
from matchbox.data import MaskedBatchField

In [8]:
# IWSLT German -> English. One MaskedBatchField is shared by source and
# target, so both sides use a single vocabulary built from the training split.
BATCH_SIZE = 4  # small on purpose so the printed tensors stay readable
DEVICE = -1     # legacy torchtext Iterator convention: -1 selects the CPU

TEXT = MaskedBatchField(batch_first=True)
train, dev, test = datasets.IWSLT.splits(('.de', '.en'), (TEXT, TEXT))
TEXT.build_vocab(train)
# shuffle=False + sort=True give a fixed batch order, so picking a specific
# batch by index below is repeatable; repeat=False makes the iterator finite.
dev_iter = data.Iterator(dev, batch_size=BATCH_SIZE, device=DEVICE,
                         shuffle=False, sort=True, repeat=False)

In [9]:
# Pull out one representative dev batch. islice stops as soon as the
# requested batch is reached instead of materializing every dev batch
# (which is what list(dev_iter)[5] did).
BATCH_INDEX = 5
batch = next(itertools.islice(dev_iter, BATCH_INDEX, None), None)
if batch is None:
    raise IndexError(
        'dev_iter yielded fewer than {} batches'.format(BATCH_INDEX + 1))
print(batch.src)
print(batch.trg)

MaskedBatch (True,) with:
 data: Variable containing:
    60     17   1477
   437   1876   5524
 16570     24  53093
 24199    943      0
[torch.LongTensor of size 4x3]

 mask: Variable containing:
 1  1  1
 1  1  1
 1  1  1
 1  1  0
[torch.LongTensor of size 4x3]

MaskedBatch (True,) with:
 data: Variable containing:
   343    236   1038
   415    473   2138
 26785   2911      0
 82312    350    298
[torch.LongTensor of size 4x3]

 mask: Variable containing:
 1  1  1
 1  1  1
 1  1  0
 1  1  1
[torch.LongTensor of size 4x3]



In [10]:
# Randomly initialised embeddings. Seed right before construction so the
# printed values are reproducible under Restart Kernel -> Run All.
EMBED_DIM = 8  # tiny on purpose so each embedded token prints on one line
torch.manual_seed(0)
embed = nn.Embedding(len(TEXT.vocab), EMBED_DIM)
# Embedding a MaskedBatch yields a MaskedBatch. In the printed output the
# padded position comes out as zeros (presumably the mask is applied by
# matchbox -- confirm).
src = embed(batch.src)
trg = embed(batch.trg)
print(src)
print(trg)

MaskedBatch (True, False) with:
 data: Variable containing:
(0 ,.,.) = 
  2.4444 -3.2256 -0.7363 -0.8431 -1.6497 -0.0764  0.0299  0.8066
 -0.3451  1.0012 -0.8611 -0.0799  1.1851 -0.9093  0.1243  0.1627
  0.0059  0.2839 -0.7120 -1.3900  1.3094  0.3581  0.8182  0.5630

(1 ,.,.) = 
  1.3210  1.5820  0.7448 -1.4880 -0.2247  0.7374  0.8837 -1.4849
 -1.5257 -0.5993 -0.8784  0.7393  1.0647 -1.3780  1.9431 -0.3812
 -0.4972  0.7124  0.8384 -0.2372  1.0253 -0.4328  2.1098  1.0806

(2 ,.,.) = 
  0.6512  1.4799 -1.3044 -0.7127  0.5442 -0.1782  0.8608 -0.3888
 -0.9716 -0.7143  0.1118 -0.5444 -0.9741  1.0346 -0.8547  0.7575
  0.7516 -1.2743 -0.2951 -0.4151 -1.6204  0.9836  0.5326  1.1599

(3 ,.,.) = 
 -1.8505 -1.3849  0.4641 -0.6601  0.5085  1.1711  1.7459  1.6988
  0.4792  0.4383 -0.1869 -0.4333 -0.5630 -1.8269  0.2497  0.2229
 -0.0000  0.0000 -0.0000  0.0000  0.0000 -0.0000 -0.0000 -0.0000
[torch.FloatTensor of size 4x3x8]

 mask: Variable containing:
(0 ,.,.) = 
  1
  1
  1

(1 ,.,.) = 
  1
  1
 

In [11]:
# Raw (unscaled -- no 1/sqrt(d) factor) dot-product attention scores between
# every source position and every target position of each batch element.
# Per the printed shapes, src/trg are 4x3x8 and alphas is 4x3x3, i.e.
# (batch, source position, target position) since batch_first=True.
# The printed mask zeroes rows for padded source positions and columns for
# padded target positions.
alphas = src @ trg.transpose(1, 2)
print(alphas)

MaskedBatch (True, True) with:
 data: Variable containing:
(0 ,.,.) = 
 -1.1876  2.7584 -1.2024
  1.7051 -0.9500  5.3445
 -2.7143 -0.1355  4.4059

(1 ,.,.) = 
 -2.4840 -0.5720 -3.8392
 -0.4141  0.3672 -2.9099
 -1.6834  4.2161 -2.5391

(2 ,.,.) = 
  0.6744 -1.9318  0.0000
  1.7707  0.3603  0.0000
  0.6322 -3.8415  0.0000

(3 ,.,.) = 
  4.9061 -3.5177 -8.0516
  3.4813  2.2154 -0.1299
  0.0000  0.0000  0.0000
[torch.FloatTensor of size 4x3x3]

 mask: Variable containing:
(0 ,.,.) = 
  1  1  1
  1  1  1
  1  1  1

(1 ,.,.) = 
  1  1  1
  1  1  1
  1  1  1

(2 ,.,.) = 
  1  1  0
  1  1  0
  1  1  0

(3 ,.,.) = 
  1  1  1
  1  1  1
  0  0  0
[torch.FloatTensor of size 4x3x3]



In [12]:
# Normalise the scores over the last (target) axis: one attention
# distribution over target positions per source position.
# Because a MaskedBatch comes back, F.softmax is mask-aware here (presumably
# patched by matchbox on import -- confirm). In the printed result:
#   - masked target columns receive zero weight;
#   - rows for padded source positions are all zero;
#   - the result's mask collapses the target axis to size 1 (4x3x1).
attns = F.softmax(alphas, -1)
print(attns)

MaskedBatch (True, False) with:
 data: Variable containing:
(0 ,.,.) = 
  0.0186  0.9630  0.0183
  0.0255  0.0018  0.9727
  0.0008  0.0105  0.9887

(1 ,.,.) = 
  0.1246  0.8432  0.0321
  0.3061  0.6686  0.0252
  0.0027  0.9961  0.0012

(2 ,.,.) = 
  0.9313  0.0687  0.0000
  0.8038  0.1962  0.0000
  0.9887  0.0113  0.0000

(3 ,.,.) = 
  0.9998  0.0002  0.0000
  0.7640  0.2154  0.0206
  0.0000  0.0000  0.0000
[torch.FloatTensor of size 4x3x3]

 mask: Variable containing:
(0 ,.,.) = 
  1
  1
  1

(1 ,.,.) = 
  1
  1
  1

(2 ,.,.) = 
  1
  1
  1

(3 ,.,.) = 
  1
  1
  0
[torch.FloatTensor of size 4x3x1]

