In [4]:
from transformers import AutoModel, AutoTokenizer, T5EncoderModel
from torch import nn
import torch

In [5]:
# Checkpoint name used for both the tokenizer and the model.
# Options listed for this notebook: "roberta-large", "t5-large".
model_option = "t5-large"

In [6]:
# The tokenizer is loaded from the same checkpoint name as the model.
tokenizer = AutoTokenizer.from_pretrained(model_option)

# Checkpoints whose name contains 't5' are loaded through T5EncoderModel;
# every other checkpoint goes through the generic AutoModel loader.
model = (T5EncoderModel if 't5' in model_option else AutoModel).from_pretrained(model_option)


For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-large automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.
Some weights of the model checkpoint at t5-large were not used when initializing T5EncoderModel: ['decoder.block.5.layer.2.DenseReluDense.wi.weight', 'decoder.block.14.layer.2.DenseReluDense.wi.weight', 'decoder.block.1.layer.1.layer_norm.weight', 'decoder.block.10.layer.0.SelfAttention.q.weight', 'decoder.block.4.layer.0.SelfAttention.q.weight', 'decoder.block.9.layer.0.SelfAttention.k.weight', 'decoder.block.17.layer.2.DenseReluDense.wo.weight', 'decoder.block.12.layer.2.DenseReluDense.wi.weight', 'decoder.block.18.layer.0.SelfAttention.k.weight', 'decoder.block.20.layer.1.EncDecAttention

Some weights of T5EncoderModel were not initialized from the model checkpoint at t5-large and are newly initialized: ['encoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
# Input strings to embed and compare against each other.
cities = [
    "Belfast",
    "London",
    "Paris",
    "Berlin",
    "Madrid",
    "Dublin",
    "Lisbon",
    "Rome",
    "Moscow",
    "New Delhi",
    "Seoul",
]

In [17]:
# Keep the whole tokenizer output (a mapping of named tensors) for every model.
# The previous non-T5 branch kept only `.input_ids`, a bare tensor, which
# cannot be unpacked by `model(**encoded_input)` and also discarded the
# attention mask that tells the model which positions are padding.
encoded_input = tokenizer(cities, return_tensors='pt', padding=True)

In [18]:
# Inference only: disable gradient tracking so the forward pass does not
# build an autograd graph that is never used.
with torch.no_grad():
    output = model(**encoded_input)

In [19]:
# Display the embedding tensor's shape. Per the Hugging Face model output
# docs, last_hidden_state is (batch_size, sequence_length, hidden_size).
output.last_hidden_state.shape

torch.Size([11, 3, 1024])

In [20]:
# def cosine_similarity(last_hidden_layer):
#     cls_dim = last_hidden_layer[:,0,:]
#     a_norm = cls_dim / cls_dim.norm(dim=1)[:, None]
#     b_norm = cls_dim / cls_dim.norm(dim=1)[:, None]
#     res = torch.mm(a_norm, b_norm.transpose(0,1))
    
#     return res

# def dot_similarity(last_hidden_layer):
#     cls_dim = last_hidden_layer[:,0,:]
#     a_norm = cls_dim
#     b_norm = cls_dim
#     res = torch.mm(a_norm, b_norm.transpose(0,1))
    
#     return res

def get_smilarity(last_hidden_layer, option='consine'):
    """Pairwise similarity between the first-token embeddings of a batch.

    Args:
        last_hidden_layer: Tensor indexed as ``[batch, token, hidden]``. Only
            token position 0 of each sequence is used as its representation.
        option: Similarity to compute. ``'cosine'`` selects cosine similarity;
            the legacy spelling ``'consine'`` is still accepted and remains
            the default. ``'dot'`` selects the unnormalised dot product.

    Returns:
        A ``(batch, batch)`` tensor whose entry ``[i, j]`` is the similarity
        between sequence ``i`` and sequence ``j``.

    Raises:
        ValueError: If ``option`` is not one of the names above. Previously
            any unrecognised value (including the correctly spelled
            ``'cosine'``) silently produced the dot product.
    """
    # NOTE(review): position 0 is a dedicated summary token for BERT/RoBERTa
    # style models, but presumably just the first word piece for T5 --
    # confirm this is the intended pooling for every supported model.
    first_token = last_hidden_layer[:, 0, :]

    if option in ('cosine', 'consine'):
        # Scale each row to unit length; the clamp keeps an all-zero row from
        # turning into NaN through a division by zero.
        row_norms = first_token.norm(dim=1, keepdim=True).clamp_min(1e-12)
        first_token = first_token / row_norms
    elif option != 'dot':
        raise ValueError(
            "Unknown similarity option %r; expected 'cosine' (or 'consine') or 'dot'" % (option,)
        )

    # Every row against every row: (batch, hidden) @ (hidden, batch).
    return torch.mm(first_token, first_token.transpose(0, 1))

In [26]:
# Similarity metric computed over the model output.
# Pass 'consine' (cosine similarity -- this exact spelling is the one
# get_smilarity checks for) or 'dot' (raw dot product).
similarity = get_smilarity(output.last_hidden_state, option='consine')

In [27]:
from pprint import pprint

In [28]:
# Detach from autograd and move to host memory before converting:
# Tensor.numpy() only works on CPU tensors, and .cpu() changes nothing
# when the tensor is already there.
a = similarity.detach().cpu().numpy()

In [29]:
# Display the similarity matrix's shape (one row and one column per input).
a.shape

(11, 11)

In [30]:
# Print the matrix as a Python list-of-lists literal, one row per line,
# so the output can be copied and pasted into another cell or file.
print('[')
for row in a:
    print(f"[{', '.join(map(str, row))}],")
print(']')

[
[1.0000001, 0.46194863, 0.40122962, 0.41659728, 0.2945966, 0.4964512, 0.4321043, 0.2867211, 0.38526952, 0.18646638, 0.39866138],
[0.46194863, 1.0000002, 0.72997266, 0.7026277, 0.57633257, 0.793318, 0.7777086, 0.5779804, 0.70575714, 0.3023165, 0.6565156],
[0.40122962, 0.72997266, 0.99999994, 0.73374236, 0.70930564, 0.7330529, 0.71778286, 0.67648643, 0.7167645, 0.30729184, 0.70543003],
[0.41659728, 0.7026277, 0.73374236, 0.9999999, 0.6328578, 0.6856353, 0.680172, 0.6389555, 0.7365085, 0.4158886, 0.69414234],
[0.2945966, 0.57633257, 0.70930564, 0.6328578, 1.0000001, 0.59341633, 0.6930702, 0.6433421, 0.65286577, 0.24420226, 0.62368095],
[0.4964512, 0.793318, 0.7330529, 0.6856353, 0.59341633, 1.0000002, 0.7765486, 0.5444765, 0.6986886, 0.35437268, 0.682055],
[0.4321043, 0.7777086, 0.71778286, 0.680172, 0.6930702, 0.7765486, 0.9999998, 0.615153, 0.72757876, 0.26245904, 0.66718864],
[0.2867211, 0.5779804, 0.67648643, 0.6389555, 0.6433421, 0.5444765, 0.615153, 1.0, 0.65519696, 0.2634263, 0.5

In [None]:
# Hand-pasted 11x11 matrix of floats, in the same row-per-line layout that
# the printing cell above emits.
# NOTE(review): presumably copied from an earlier run with a different
# model/option -- the values do not match the cosine output shown above, and
# the matrix is not exactly symmetric (e.g. [1][8] vs [8][1]). Confirm
# provenance before relying on it.
# NOTE(review): the name `copy` would shadow the stdlib `copy` module if that
# module is ever imported in this notebook.
copy = [
[1003.59644, 1003.1952, 1003.17896, 1003.2844, 1003.25244, 1003.3141, 1003.18945, 1003.1713, 1003.128, 1003.2482, 1003.3139],
[1003.1952, 1003.54205, 1003.4163, 1003.29974, 1003.26624, 1003.2899, 1003.2155, 1003.1977, 1003.3356, 1003.32477, 1003.3151],
[1003.17896, 1003.4163, 1003.5079, 1003.3141, 1003.27203, 1003.2915, 1003.2136, 1003.2303, 1003.31757, 1003.2936, 1003.2991],
[1003.2844, 1003.29974, 1003.3141, 1003.54095, 1003.32935, 1003.416, 1003.23773, 1003.2526, 1003.23535, 1003.32446, 1003.37195],
[1003.25244, 1003.26624, 1003.27203, 1003.32935, 1003.5245, 1003.33246, 1003.27203, 1003.2493, 1003.1965, 1003.32025, 1003.3255],
[1003.3141, 1003.2899, 1003.2915, 1003.416, 1003.33246, 1003.56146, 1003.21686, 1003.24695, 1003.2071, 1003.376, 1003.38165],
[1003.18945, 1003.2155, 1003.2136, 1003.23773, 1003.27203, 1003.21686, 1003.53204, 1003.1882, 1003.0486, 1003.18634, 1003.1883],
[1003.1713, 1003.1977, 1003.2303, 1003.2526, 1003.2493, 1003.24695, 1003.1882, 1003.5011, 1003.13947, 1003.2319, 1003.2338],
[1003.128, 1003.3357, 1003.31775, 1003.2354, 1003.1963, 1003.2071, 1003.0486, 1003.1395, 1003.49097, 1003.2679, 1003.24945],
[1003.2482, 1003.325, 1003.2936, 1003.3245, 1003.32007, 1003.3758, 1003.18634, 1003.23157, 1003.2679, 1003.57056, 1003.3897],
[1003.31384, 1003.3151, 1003.2992, 1003.372, 1003.3257, 1003.3816, 1003.18823, 1003.2339, 1003.24945, 1003.3897, 1003.5754],
]

In [None]:
# Display the pasted matrix as this cell's output.
copy