In [1]:
import torch 
import torch.nn as nn 


In [2]:
# Toy sentence, one token per position. The list order is also the order in
# which tokens are displayed when a token set is printed.
print_order = ['the', 'cat', 'is', 'on', 'a', 'chair']

# Every position starts out as a singleton set holding only its own token.
sequence = [{token} for token in print_order]
sequence  # list[set[str]]

[{'the'}, {'cat'}, {'is'}, {'on'}, {'a'}, {'chair'}]

In [3]:
# Width of the attention window for this demo.
# NOTE(review): this name is not referenced by the cells visible here, which
# pass 3 directly or rely on the `w=3` default of `sliding_win_attention`.
sliding_window_size = 3 

def sliding_win_attention(seq: list[set[str]], w: int = 3):
    """Build a symbolic causal sliding-window attention matrix.

    Args:
        seq: Token set held by each position.
        w: Window width. A query at position i is paired with keys at
            positions j where 0 <= i - j < w.

    Returns:
        A len(seq) x len(seq) list of lists. Entry [i][j] is a new set holding
        the union of seq[i] and seq[j] when j is inside the window of i, and
        None otherwise (future positions and positions too far back).
    """
    def pair_entry(q_idx, k_idx):
        # The key must not lie after the query, and must be fewer than w steps behind it.
        if 0 <= q_idx - k_idx < w:
            return {*seq[q_idx], *seq[k_idx]}
        return None

    positions = range(len(seq))
    return [[pair_entry(q_idx, k_idx) for k_idx in positions] for q_idx in positions]


def multiple_by_v(attention_scores: list[list[set]], v_sequence: list[set[str]]) -> list[set[str]]:
    """Combine a symbolic attention matrix with the value sequence.

    Args:
        attention_scores: Square matrix of token sets, with None marking a
            masked (query, key) pair.
        v_sequence: Token set held by each value position.

    Returns:
        One set per query position i: the union, over every key position j
        whose attention entry is not None, of v_sequence[j] and
        attention_scores[i][j]. A fully masked row yields an empty set.
    """
    positions = range(len(v_sequence))
    mixed = []
    for q_idx in positions:
        score_row = attention_scores[q_idx]
        gathered = set()
        for k_idx in positions:
            entry = score_row[k_idx]
            if entry is None:
                # Masked pair: this key contributes nothing to the query.
                continue
            # Collect both the value tokens and the tokens recorded in the attention entry.
            gathered.update(v_sequence[k_idx], entry)
        mixed.append(gathered)
    return mixed
    


In [4]:
def print_sequence(seq: list[set[str]]):
    """Print one line per position: its index followed by its tokens.

    Tokens are listed in the order they appear in the notebook-level
    `print_order` list, so every token must be present in that list.
    """
    for position, tokens in enumerate(seq):
        ordered = list(tokens)
        ordered.sort(key=lambda token: print_order.index(token))
        print(f'{position}: {ordered}')

# Show the starting sequence: each position holds only its own token.
print_sequence(sequence)

0: ['the']
1: ['cat']
2: ['is']
3: ['on']
4: ['a']
5: ['chair']


In [5]:
# Build the causal sliding-window attention pattern (window of 3) for the toy
# sentence and display the raw matrix.
attn = sliding_win_attention(sequence, 3)
attn


[[{'the'}, None, None, None, None, None],
 [{'cat', 'the'}, {'cat'}, None, None, None, None],
 [{'is', 'the'}, {'cat', 'is'}, {'is'}, None, None, None],
 [None, {'cat', 'on'}, {'is', 'on'}, {'on'}, None, None],
 [None, None, {'a', 'is'}, {'a', 'on'}, {'a'}, None],
 [None, None, None, {'chair', 'on'}, {'a', 'chair'}, {'chair'}]]

In [7]:
# Apply the attention pattern to the values: each position gathers the tokens
# of the positions it attends to.
multiple_by_v(attn, sequence)


[{'the'},
 {'cat', 'the'},
 {'cat', 'is', 'the'},
 {'cat', 'is', 'on'},
 {'a', 'is', 'on'},
 {'a', 'chair', 'on'}]

In [10]:
def print_attention(attn_score: list[list[set[str]]]):
    """Print an attention matrix as a tab-separated grid, one row per line.

    Masked entries are shown as the text None. Other entries are shown as a
    list of tokens ordered by the notebook-level `print_order` list.
    """
    for score_row in attn_score:
        for cell in score_row:
            if cell is None:
                text = "None"
            else:
                text = str(sorted(cell, key=lambda token: print_order.index(token)))
            print(text, end='\t')
        # Finish the current matrix row.
        print()

def print_layer(input:list[set[str]], layer_num:int, w:int=3) -> list[set[str]]:
    """Run one sliding-window attention layer over `input`, printing each stage.

    Prints the layer input, the attention pattern and the layer output, then
    returns the output so it can be fed to the next layer.

    Args:
        input: Token set per position for this layer. The name shadows the
            built-in `input`; it is kept so keyword callers keep working.
        layer_num: Label used in the printed headings.
        w: Sliding-window width forwarded to `sliding_win_attention`.
            Defaults to 3, the width that was previously implied.

    Returns:
        The token set per position after this layer.
    """
    print(f'Layer {layer_num} input: ')
    print_sequence(input)
    attn_scores = sliding_win_attention(input, w)
    print()
    print(f'Layer {layer_num} atention scores: ')
    print_attention(attn_scores)
    # Use this layer's own input as the values rather than the notebook-global
    # `sequence`, so the result depends only on the arguments and stays
    # correct for inputs of any length.
    output = multiple_by_v(attn_scores, input)
    print()
    print(f'Layer {layer_num} output: ')
    print_sequence(output)
    return output


In [9]:
# Render the same attention matrix as a tab-separated grid, tokens in sentence order.
print_attention(attn)

['the']	None	None	None	None	None	
['the', 'cat']	['cat']	None	None	None	None	
['the', 'is']	['cat', 'is']	['is']	None	None	None	
None	['cat', 'on']	['is', 'on']	['on']	None	None	
None	None	['is', 'a']	['on', 'a']	['a']	None	
None	None	None	['on', 'chair']	['a', 'chair']	['chair']	


In [11]:
# Layer 1: with the default window of 3, each position sees at most itself and
# the two previous tokens.
output_layer_1 = print_layer(sequence, 1)

Layer 1 input: 
0: ['the']
1: ['cat']
2: ['is']
3: ['on']
4: ['a']
5: ['chair']

Layer 1 atention scores: 
['the']	None	None	None	None	None	
['the', 'cat']	['cat']	None	None	None	None	
['the', 'is']	['cat', 'is']	['is']	None	None	None	
None	['cat', 'on']	['is', 'on']	['on']	None	None	
None	None	['is', 'a']	['on', 'a']	['a']	None	
None	None	None	['on', 'chair']	['a', 'chair']	['chair']	

Layer 1 output: 
0: ['the']
1: ['the', 'cat']
2: ['the', 'cat', 'is']
3: ['cat', 'is', 'on']
4: ['is', 'on', 'a']
5: ['on', 'a', 'chair']


In [12]:
# Layer 2: feed layer 1's output back in. Information now reaches further back
# than a single window (e.g. position 4 ends up containing 'the').
output_layer_2 = print_layer(output_layer_1, 2)

Layer 2 input: 
0: ['the']
1: ['the', 'cat']
2: ['the', 'cat', 'is']
3: ['cat', 'is', 'on']
4: ['is', 'on', 'a']
5: ['on', 'a', 'chair']

Layer 2 atention scores: 
['the']	None	None	None	None	None	
['the', 'cat']	['the', 'cat']	None	None	None	None	
['the', 'cat', 'is']	['the', 'cat', 'is']	['the', 'cat', 'is']	None	None	None	
None	['the', 'cat', 'is', 'on']	['the', 'cat', 'is', 'on']	['cat', 'is', 'on']	None	None	
None	None	['the', 'cat', 'is', 'on', 'a']	['cat', 'is', 'on', 'a']	['is', 'on', 'a']	None	
None	None	None	['cat', 'is', 'on', 'a', 'chair']	['is', 'on', 'a', 'chair']	['on', 'a', 'chair']	

Layer 2 output: 
0: ['the']
1: ['the', 'cat']
2: ['the', 'cat', 'is']
3: ['the', 'cat', 'is', 'on']
4: ['the', 'cat', 'is', 'on', 'a']
5: ['cat', 'is', 'on', 'a', 'chair']


In [13]:
# Layer 3: after three stacked layers the last position has gathered every
# token in the sentence.
output_layer_3 = print_layer(output_layer_2, 3)

Layer 3 input: 
0: ['the']
1: ['the', 'cat']
2: ['the', 'cat', 'is']
3: ['the', 'cat', 'is', 'on']
4: ['the', 'cat', 'is', 'on', 'a']
5: ['cat', 'is', 'on', 'a', 'chair']

Layer 3 atention scores: 
['the']	None	None	None	None	None	
['the', 'cat']	['the', 'cat']	None	None	None	None	
['the', 'cat', 'is']	['the', 'cat', 'is']	['the', 'cat', 'is']	None	None	None	
None	['the', 'cat', 'is', 'on']	['the', 'cat', 'is', 'on']	['the', 'cat', 'is', 'on']	None	None	
None	None	['the', 'cat', 'is', 'on', 'a']	['the', 'cat', 'is', 'on', 'a']	['the', 'cat', 'is', 'on', 'a']	None	
None	None	None	['the', 'cat', 'is', 'on', 'a', 'chair']	['the', 'cat', 'is', 'on', 'a', 'chair']	['cat', 'is', 'on', 'a', 'chair']	

Layer 3 output: 
0: ['the']
1: ['the', 'cat']
2: ['the', 'cat', 'is']
3: ['the', 'cat', 'is', 'on']
4: ['the', 'cat', 'is', 'on', 'a']
5: ['the', 'cat', 'is', 'on', 'a', 'chair']
