In [1]:
from bopt.core.tokenizer import Tokenizer
import math

In [2]:
# Toy vocabulary: the padding symbol, the single characters of "hate",
# the multi-character pieces of "hate", and the pieces needed for "ab".
vocab = [
    "[PAD]",
    "h", "a", "t", "e",
    "at", "hat", "ate", "hate",
    "b", "ab",
]

In [3]:
# Per-piece weights stored as logs.  "e", "at", "hat", "ate", "hate" carry
# log(1)..log(5); every other piece is 0.0.  Key order is kept identical to
# the order in which the pieces are listed in `vocab`.
weight = {
    **dict.fromkeys(("[PAD]", "h", "a", "t"), 0.0),
    **{
        piece: math.log(score)
        for piece, score in (
            ("e", 1.0),
            ("at", 2.0),
            ("hat", 3.0),
            ("ate", 4.0),
            ("hate", 5.0),
        )
    },
    **dict.fromkeys(("b", "ab"), 0.0),
}

In [4]:
# Build the tokenizer from the toy vocabulary and its log-weights.
# No continuing-subword prefix is used (explicitly passed as None).
tokenizer = Tokenizer(
    vocab=vocab,
    weights=weight,
    continuing_subword_prefix=None,
)

In [5]:
# M is forwarded to encode_batch; presumably the maximum piece length in
# characters (the labels printed in the next cell go up to 4 chars) — TODO confirm.
M = 4
encoded_batch = tokenizer.encode_batch(["hat", "hate", "ab", "at"], M=M)
# forward returns three values; `c` is exponentiated and printed as "attention" below.
e, la, c = tokenizer.forward(*encoded_batch)

In [6]:
# Substring labels for the positions of "hate" (batch row 1) with M = 4.
labels = ["h", "ha", "hat", "hate", "a", "at", "ate", "t", "te", "e"]
for i, x in enumerate(labels):
    row = []
    for j, y in enumerate(labels):
        attention = math.exp(c[1, i, j].item())
        # Entries whose exponential is exactly 0 are shown as blanks.
        attention_str = str(round(attention, 2)) if attention > 0 else ""
        if i == j:
            # Diagonal entries are wrapped in parentheses ...
            row.append(f"({x:>4} -> {y:>4}) {attention_str:>4}")
        else:
            # ... off-diagonal entries in square brackets.
            row.append(f"[{x:>4} -> {y:>4}] {attention_str:>4}")
    # One space after every entry (including the last), then a newline.
    print(" ".join(row), end=" \n")

(   h ->    h)  1.0 [   h ->   ha]      [   h ->  hat]      [   h -> hate]      [   h ->    a] 0.14 [   h ->   at] 0.29 [   h ->  ate] 0.57 [   h ->    t] 0.14 [   h ->   te]      [   h ->    e] 0.43 
[  ha ->    h]      (  ha ->   ha)      [  ha ->  hat]      [  ha -> hate]      [  ha ->    a]      [  ha ->   at]      [  ha ->  ate]      [  ha ->    t]      [  ha ->   te]      [  ha ->    e]      
[ hat ->    h]      [ hat ->   ha]      ( hat ->  hat)  1.0 [ hat -> hate]      [ hat ->    a]      [ hat ->   at]      [ hat ->  ate]      [ hat ->    t]      [ hat ->   te]      [ hat ->    e]  1.0 
[hate ->    h]      [hate ->   ha]      [hate ->  hat]      (hate -> hate)  1.0 [hate ->    a]      [hate ->   at]      [hate ->  ate]      [hate ->    t]      [hate ->   te]      [hate ->    e]      
[   a ->    h]  1.0 [   a ->   ha]      [   a ->  hat]      [   a -> hate]      (   a ->    a)  1.0 [   a ->   at]      [   a ->  ate]      [   a ->    t]  1.0 [   a ->   te]      [   a ->    e]  

In [7]:
# Same batch as before, re-encoded with M = 3.
M = 3
encoded_batch = tokenizer.encode_batch(["hat", "hate", "ab", "at"], M=M)
e, la, c = tokenizer.forward(*encoded_batch)

In [8]:
# Substring labels for the positions of "hate" (batch row 1) with M = 3.
labels = ["h", "ha", "hat", "a", "at", "ate", "t", "te", "e"]
for i, x in enumerate(labels):
    row = []
    for j, y in enumerate(labels):
        attention = math.exp(c[1, i, j].item())
        # Entries whose exponential is exactly 0 are shown as blanks.
        attention_str = str(round(attention, 2)) if attention > 0 else ""
        if i == j:
            row.append(f"({x:>4} -> {y:>4}) {attention_str:>4}")
        else:
            row.append(f"[{x:>4} -> {y:>4}] {attention_str:>4}")
    # One space after every entry (including the last), then a newline.
    print(" ".join(row), end=" \n")

(   h ->    h)  1.0 [   h ->   ha]      [   h ->  hat]      [   h ->    a] 0.14 [   h ->   at] 0.29 [   h ->  ate] 0.57 [   h ->    t] 0.14 [   h ->   te]      [   h ->    e] 0.43 
[  ha ->    h]      (  ha ->   ha)      [  ha ->  hat]      [  ha ->    a]      [  ha ->   at]      [  ha ->  ate]      [  ha ->    t]      [  ha ->   te]      [  ha ->    e]      
[ hat ->    h]      [ hat ->   ha]      ( hat ->  hat)  1.0 [ hat ->    a]      [ hat ->   at]      [ hat ->  ate]      [ hat ->    t]      [ hat ->   te]      [ hat ->    e]  1.0 
[   a ->    h]  1.0 [   a ->   ha]      [   a ->  hat]      (   a ->    a)  1.0 [   a ->   at]      [   a ->  ate]      [   a ->    t]  1.0 [   a ->   te]      [   a ->    e]  1.0 
[  at ->    h]  1.0 [  at ->   ha]      [  at ->  hat]      [  at ->    a]      (  at ->   at)  1.0 [  at ->  ate]      [  at ->    t]      [  at ->   te]      [  at ->    e]  1.0 
[ ate ->    h]  1.0 [ ate ->   ha]      [ ate ->  hat]      [ ate ->    a]      [ ate ->   at] 

In [9]:
# Gradient of e[0] (first word of the batch, "hat") w.r.t. the tokenizer weights.
tokenizer.zero_grad()  # clear gradients left over from earlier backward calls
target = e[0]
# retain_graph=True keeps the graph alive so later cells can call backward again.
target.backward(retain_graph=True)
tokenizer.weights.weight.grad

tensor([[ 0.0000],
        [ 0.1591],
        [ 0.1301],
        [ 0.1301],
        [ 0.0000],
        [ 0.0145],
        [-0.0530],
        [ 0.0000],
        [ 0.0000],
        [ 0.0000],
        [ 0.0000]])

In [10]:
# Gradient of e[1] (second word of the batch, "hate") w.r.t. the tokenizer weights.
tokenizer.zero_grad()  # clear gradients left over from earlier backward calls
target = e[1]
# retain_graph=True keeps the graph alive so later cells can call backward again.
target.backward(retain_graph=True)
tokenizer.weights.weight.grad

tensor([[ 0.0000],
        [ 0.0228],
        [ 0.1023],
        [ 0.1023],
        [ 0.1454],
        [ 0.0330],
        [-0.0076],
        [-0.0364],
        [ 0.0000],
        [ 0.0000],
        [ 0.0000]])

In [11]:
# Gradient of a single exponentiated entry of c w.r.t. the tokenizer weights.
# Assuming c still comes from the M = 3 forward pass, index [1][0][4] is the
# "h -> at" entry of "hate" under the labelling h, ha, hat, a, at, ...
tokenizer.zero_grad()
attention_entry = c.exp()[1][0][4]
attention_entry.backward(retain_graph=True)
tokenizer.weights.weight.grad

tensor([[ 0.0000],
        [ 0.0000],
        [-0.0408],
        [-0.0408],
        [ 0.1633],
        [ 0.1020],
        [ 0.0000],
        [-0.0408],
        [ 0.0000],
        [ 0.0000],
        [ 0.0000]])

In [12]:
# Same batch as before, re-encoded with M = 2.
M = 2
encoded_batch = tokenizer.encode_batch(["hat", "hate", "ab", "at"], M=M)
e, la, c = tokenizer.forward(*encoded_batch)

In [13]:
# Substring labels for the positions of "hate" (batch row 1) with M = 2.
labels = ["h", "ha", "a", "at", "t", "te", "e"]
for i, x in enumerate(labels):
    row = []
    for j, y in enumerate(labels):
        attention = math.exp(c[1, i, j].item())
        # Entries whose exponential is exactly 0 are shown as blanks.
        attention_str = str(round(attention, 2)) if attention > 0 else ""
        if i == j:
            row.append(f"({x:>4} -> {y:>4}) {attention_str:>4}")
        else:
            row.append(f"[{x:>4} -> {y:>4}] {attention_str:>4}")
    # One space after every entry (including the last), then a newline.
    print(" ".join(row), end=" \n")

(   h ->    h)  1.0 [   h ->   ha]      [   h ->    a] 0.33 [   h ->   at] 0.67 [   h ->    t] 0.33 [   h ->   te]      [   h ->    e]  1.0 
[  ha ->    h]      (  ha ->   ha)      [  ha ->    a]      [  ha ->   at]      [  ha ->    t]      [  ha ->   te]      [  ha ->    e]      
[   a ->    h]  1.0 [   a ->   ha]      (   a ->    a)  1.0 [   a ->   at]      [   a ->    t]  1.0 [   a ->   te]      [   a ->    e]  1.0 
[  at ->    h]  1.0 [  at ->   ha]      [  at ->    a]      (  at ->   at)  1.0 [  at ->    t]      [  at ->   te]      [  at ->    e]  1.0 
[   t ->    h]  1.0 [   t ->   ha]      [   t ->    a]  1.0 [   t ->   at]      (   t ->    t)  1.0 [   t ->   te]      [   t ->    e]  1.0 
[  te ->    h]      [  te ->   ha]      [  te ->    a]      [  te ->   at]      [  te ->    t]      (  te ->   te)      [  te ->    e]      
[   e ->    h]  1.0 [   e ->   ha]      [   e ->    a] 0.33 [   e ->   at] 0.67 [   e ->    t] 0.33 [   e ->   te]      (   e ->    e)  1.0 


In [14]:
# Same batch as before, re-encoded with M = 1.
M = 1
encoded_batch = tokenizer.encode_batch(["hat", "hate", "ab", "at"], M=M)
e, la, c = tokenizer.forward(*encoded_batch)

In [15]:
# With M = 1 the positions of "hate" (batch row 1) are just its characters.
labels = ["h", "a", "t", "e"]
for i, x in enumerate(labels):
    row = []
    for j, y in enumerate(labels):
        attention = math.exp(c[1, i, j].item())
        # Entries whose exponential is exactly 0 are shown as blanks.
        attention_str = str(round(attention, 2)) if attention > 0 else ""
        if i == j:
            row.append(f"({x:>4} -> {y:>4}) {attention_str:>4}")
        else:
            row.append(f"[{x:>4} -> {y:>4}] {attention_str:>4}")
    # One space after every entry (including the last), then a newline.
    print(" ".join(row), end=" \n")

(   h ->    h)  1.0 [   h ->    a]  1.0 [   h ->    t]  1.0 [   h ->    e]  1.0 
[   a ->    h]  1.0 (   a ->    a)  1.0 [   a ->    t]  1.0 [   a ->    e]  1.0 
[   t ->    h]  1.0 [   t ->    a]  1.0 (   t ->    t)  1.0 [   t ->    e]  1.0 
[   e ->    h]  1.0 [   e ->    a]  1.0 [   e ->    t]  1.0 (   e ->    e)  1.0 


In [16]:
# Display e from the most recent forward pass (M = 1 in notebook order).
e

tensor([0., 0., 0., 0.], grad_fn=<SqueezeBackward1>)

In [17]:
# Shannon entropy (in nats) of the distribution proportional to 1, 2, 3, 4, 5;
# 15 is the normaliser 1 + 2 + 3 + 4 + 5.
entropy_terms = []
for x in [1, 2, 3, 4, 5]:
    entropy_terms.append(- x/15 * math.log(x/15))
s = sum(entropy_terms)
print(s)

1.4897503188505912


In [18]:
# Hand-evaluated sum of two terms of the form -(1 + log p)/Z**2 * (delta*Z - w)
# with p = 2/3 and p = 1/3, Z = 3.
# NOTE(review): this has the shape of d/dw of the entropy -sum(p log p) for
# unnormalised weights (2, 1), differentiated w.r.t. the weight 2 — TODO confirm
# which tokenizer quantity this was meant to cross-check.
-(1 + math.log(2/3))/9 * (1*3-1*2)  -(1 + math.log(1/3)) / 9 * (0*3 - 1 * 1)

-0.07701635339554949

In [19]:
# Companion hand-evaluated expression: same -(1 + log p)/Z**2 * (delta*Z - w)
# terms with p = 1/3 and p = 2/3, Z = 3.
# NOTE(review): looks like the entropy derivative w.r.t. the weight 1 of the
# unnormalised pair (2, 1) — TODO confirm what it is being compared against.
-(1 + math.log(1/3))/9 * (1*3-1*1)-(1 + math.log(2/3))/9 * (0*3-1*2)

0.15403270679109898

In [20]:
# Scratch arithmetic; the meaning of the factors is not stated in the notebook.
# NOTE(review): 15 may be the 1+2+3+4+5 normaliser used earlier — TODO confirm or remove.
6 * 15 * 2 * 1

180

In [21]:
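# Position labels when "hat" is packed directly in front of "hate" with M = 4
# (the first 12 labels belong to "hat" and may run into the next word):
# h ha hat hath a at ath atha t th tha that h ha hat hate a at ate t te e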

In [22]:
# Same four words, but passed as one packed group instead of a batch of four.
M = 4
packed_batch = tokenizer.encode_packed_batch([["hat", "hate", "ab", "at"]], M=M)
e, la, c = tokenizer.forward(*packed_batch)

In [23]:
# Labels for the positions starting at "hate" inside the packed sequence; they
# run across the word boundary into the following "ab" (e.g. "atea", "teab").
labels = ["h", "ha", "hat", "hate", "a", "at", "ate", "atea", "t", "te", "tea", "teab", "e"]
# The first 12 positions of packed row 0 are skipped so indexing starts at "hate"
# (presumably len("hat") * M = 3 * 4 slots taken by "hat" — TODO confirm).
offset = 12
for i, x in enumerate(labels):
    row = []
    for j, y in enumerate(labels):
        attention = math.exp(c[0, offset + i, offset + j].item())
        # Entries whose exponential is exactly 0 are shown as blanks.
        attention_str = str(round(attention, 2)) if attention > 0 else ""
        if i == j:
            row.append(f"({x:>4} -> {y:>4}) {attention_str:>4}")
        else:
            row.append(f"[{x:>4} -> {y:>4}] {attention_str:>4}")
    # One space after every entry (including the last), then a newline.
    print(" ".join(row), end=" \n")

(   h ->    h)  1.0 [   h ->   ha]      [   h ->  hat]      [   h -> hate]      [   h ->    a] 0.14 [   h ->   at] 0.29 [   h ->  ate] 0.57 [   h -> atea]      [   h ->    t] 0.14 [   h ->   te]      [   h ->  tea]      [   h -> teab]      [   h ->    e] 0.43 
[  ha ->    h]      (  ha ->   ha)      [  ha ->  hat]      [  ha -> hate]      [  ha ->    a]      [  ha ->   at]      [  ha ->  ate]      [  ha -> atea]      [  ha ->    t]      [  ha ->   te]      [  ha ->  tea]      [  ha -> teab]      [  ha ->    e]      
[ hat ->    h]      [ hat ->   ha]      ( hat ->  hat)  1.0 [ hat -> hate]      [ hat ->    a]      [ hat ->   at]      [ hat ->  ate]      [ hat -> atea]      [ hat ->    t]      [ hat ->   te]      [ hat ->  tea]      [ hat -> teab]      [ hat ->    e]  1.0 
[hate ->    h]      [hate ->   ha]      [hate ->  hat]      (hate -> hate)  1.0 [hate ->    a]      [hate ->   at]      [hate ->  ate]      [hate -> atea]      [hate ->    t]      [hate ->   te]      [hate ->  tea]   

In [24]:
# Back to the unpacked batch with M = 4, for comparison with the packed run.
M = 4
encoded_batch = tokenizer.encode_batch(["hat", "hate", "ab", "at"], M=M)
e, la, c = tokenizer.forward(*encoded_batch)

In [25]:
# Substring labels for the positions of "hate" (batch row 1) with M = 4.
labels = ["h", "ha", "hat", "hate", "a", "at", "ate", "t", "te", "e"]
for i, x in enumerate(labels):
    row = []
    for j, y in enumerate(labels):
        attention = math.exp(c[1, i, j].item())
        # Entries whose exponential is exactly 0 are shown as blanks.
        attention_str = str(round(attention, 2)) if attention > 0 else ""
        if i == j:
            row.append(f"({x:>4} -> {y:>4}) {attention_str:>4}")
        else:
            row.append(f"[{x:>4} -> {y:>4}] {attention_str:>4}")
    # One space after every entry (including the last), then a newline.
    print(" ".join(row), end=" \n")

(   h ->    h)  1.0 [   h ->   ha]      [   h ->  hat]      [   h -> hate]      [   h ->    a] 0.14 [   h ->   at] 0.29 [   h ->  ate] 0.57 [   h ->    t] 0.14 [   h ->   te]      [   h ->    e] 0.43 
[  ha ->    h]      (  ha ->   ha)      [  ha ->  hat]      [  ha -> hate]      [  ha ->    a]      [  ha ->   at]      [  ha ->  ate]      [  ha ->    t]      [  ha ->   te]      [  ha ->    e]      
[ hat ->    h]      [ hat ->   ha]      ( hat ->  hat)  1.0 [ hat -> hate]      [ hat ->    a]      [ hat ->   at]      [ hat ->  ate]      [ hat ->    t]      [ hat ->   te]      [ hat ->    e]  1.0 
[hate ->    h]      [hate ->   ha]      [hate ->  hat]      (hate -> hate)  1.0 [hate ->    a]      [hate ->   at]      [hate ->  ate]      [hate ->    t]      [hate ->   te]      [hate ->    e]      
[   a ->    h]  1.0 [   a ->   ha]      [   a ->  hat]      [   a -> hate]      (   a ->    a)  1.0 [   a ->   at]      [   a ->  ate]      [   a ->    t]  1.0 [   a ->   te]      [   a ->    e]  

In [26]:
# Scratch arithmetic.
# NOTE(review): possibly the number of positions for the 11 packed characters
# with M = 4 (8 starts with 4 pieces each, then 3, 2, 1 at the tail) — TODO confirm.
4 * 8 + 3 + 2 + 1

38