In [3]:
# !pip install git+https://github.com/huggingface/transformers -q
# !pip install sentencepiece sacremoses -q

In [1]:
import torch
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

# Checkpoint shared by the model and its tokenizer; keeping it in one constant prevents the
# two from silently pointing at different checkpoints.
MODEL_NAME = "facebook/mbart-large-50-many-to-many-mmt"

model = MBartForConditionalGeneration.from_pretrained(MODEL_NAME)
tokenizer = MBart50TokenizerFast.from_pretrained(MODEL_NAME)

In [2]:
import torch.quantization
import torch.nn as nn

# Dynamically quantize the weights of every nn.Linear layer to int8.
# inplace=False (the default, spelled out here) yields a separate quantized copy,
# so `model` stays full precision for the side-by-side comparison below.
quantized_model = torch.quantization.quantize_dynamic(
    model,
    qconfig_spec={nn.Linear},
    dtype=torch.qint8,
    inplace=False,
)

In [3]:
# A single English source sentence, wrapped in a list so it is tokenized as a batch of one.
source_sentence = "Patrick O'Neill, the former chief creative officer at Theranos from 2014 to 2017, said he has removed much of his promotional work for the company from his professional website."
text = [source_sentence]

In [4]:
# Translate the English source into Chinese (en_XX -> zh_CN) with both the full-precision
# model and its dynamically quantized copy, so the two outputs can be compared.
tokenizer.src_lang = "en_XX"
encoded_en = tokenizer(text, return_tensors="pt")

# mBART-50 selects the target language via the first generated token, so force it to the
# zh_CN language code. convert_tokens_to_ids avoids relying on the tokenizer-specific
# lang_code_to_id mapping; the assert guards against a silent fallback to <unk>.
target_lang_id = tokenizer.convert_tokens_to_ids("zh_CN")
assert target_lang_id != tokenizer.unk_token_id, "zh_CN is not a known language code"

# Identical decoding settings for both models, so differences come from quantization alone.
generation_kwargs = {"forced_bos_token_id": target_lang_id, "max_new_tokens": 200}
generated_tokens = model.generate(**encoded_en, **generation_kwargs)
generated_tokens_q = quantized_model.generate(**encoded_en, **generation_kwargs)

decoded_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
decoded_text_q = tokenizer.batch_decode(generated_tokens_q, skip_special_tokens=True)

In [5]:
# Baseline translation first, then the quantized model's translation.
for translation in (decoded_text, decoded_text_q):
    print(translation)

["帕特里克·奥尼尔(Patrick O'Neill)在2014年至2017年间担任Theranos首席创意总监,他说他已经从自己的专业网站中删除了该公司的宣传工作。"]
["帕特里克·奥尼尔(Patrick O'Neill)表示,他已从自己的专业网站中删除了该公司的宣传工作。"]


In [9]:
import requests

import os

HF_API_URL = "https://api-inference.huggingface.co/models/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
# The Hugging Face Inference API authenticates with "Authorization: Bearer <token>".
# The token is read from the HF_API_TOKEN environment variable so it never ends up
# hardcoded in (or committed with) the notebook.
headers = {"Authorization": f"Bearer {os.environ.get('HF_API_TOKEN', '')}"}

def query(payload, max_retries=10, retry_delay=5, timeout=60):
    """POST ``payload`` to ``HF_API_URL`` and return the decoded JSON scores.

    An attempt is retried when the request fails (network error, non-2xx status, or a
    body that is not valid JSON) or when the decoded JSON is not a list of exactly two
    numbers.

    Args:
        payload: JSON-serialisable request body, sent unchanged.
        max_retries: total number of attempts before giving up.
        retry_delay: seconds to wait between attempts (not after the last one).
        timeout: per-request timeout in seconds, so a stalled connection cannot hang forever.

    Returns:
        The decoded JSON list of two numbers.

    Raises:
        Exception: 'Max retries exceeded' when no attempt produced the expected result.
    """
    # Imported here: the notebook's only `import time` lives in a commented-out cell, so the
    # original retry path raised NameError on its first sleep.
    import time

    for attempt in range(1, max_retries + 1):
        try:
            response = requests.post(HF_API_URL, headers=headers, json=payload, timeout=timeout)
            response.raise_for_status()  # turn HTTP error statuses into exceptions
            data = response.json()
            # The two sentences sent by the caller should yield exactly two numeric scores.
            if isinstance(data, list) and len(data) == 2 and all(isinstance(i, (int, float)) for i in data):
                return data
            print(f'Returned value is not as expected, retrying... ({attempt}/{max_retries})')
        # RequestException already covers ConnectionError; ValueError covers bad JSON.
        except (requests.exceptions.RequestException, ValueError) as e:
            print(f'Caught exception: {str(e)}, retrying... ({attempt}/{max_retries})')
        if attempt < max_retries:
            time.sleep(retry_delay)
    print("Payload is: ", payload)
    raise Exception('Max retries exceeded')

In [10]:
# Score the baseline and quantized translations against the English source sentence;
# the second print is the absolute gap between the two returned scores.
similarity_inputs = {
    "source_sentence": text[0],
    "sentences": [decoded_text[0], decoded_text_q[0]],
}
output = query({"inputs": similarity_inputs})

print(output)
print(abs(output[1] - output[0]))

[0.8856440186500549, 0.8737610578536987]
0.011882960796356201


In [12]:
# Map token ids back to SentencePiece pieces, dropping the <s> / </s> markers.
# Language-code pieces (en_XX / zh_CN) are kept here; a later cell slices them off.
def _pieces_without_bos_eos(token_ids):
    return [piece for piece in tokenizer.convert_ids_to_tokens(token_ids) if piece not in ('<s>', '</s>')]

input_tokens = _pieces_without_bos_eos(encoded_en['input_ids'][0])
print("Input Tokens: ", input_tokens)

output_tokens = _pieces_without_bos_eos(generated_tokens[0])
output_tokens_q = _pieces_without_bos_eos(generated_tokens_q[0])

print("\nBase Tokens: ", output_tokens)
print("\nQuantized Tokens: ", output_tokens_q)

Input Tokens:  ['en_XX', '▁Patrick', '▁O', "'", 'Ne', 'ill', ',', '▁the', '▁former', '▁chief', '▁creative', '▁officer', '▁at', '▁The', 'rano', 's', '▁from', '▁2014', '▁to', '▁2017', ',', '▁said', '▁he', '▁has', '▁removed', '▁much', '▁of', '▁his', '▁promotion', 'al', '▁work', '▁for', '▁the', '▁company', '▁from', '▁his', '▁professional', '▁website', '.']

Base Tokens:  ['zh_CN', '▁', '帕', '特', '里', '克', '·', '奥', '尼', '尔', '(', 'Patri', 'ck', '▁O', "'", 'Ne', 'ill', ')', '在', '2014', '年', '至', '2017', '年', '间', '担任', 'The', 'rano', 's', '首席', '创意', '总监', ',', '他说', '他', '已经', '从', '自己的', '专业', '网站', '中', '删除', '了', '该公司', '的', '宣传', '工作', '。']

Quantized Tokens:  ['zh_CN', '▁', '帕', '特', '里', '克', '·', '奥', '尼', '尔', '(', 'Patri', 'ck', '▁O', "'", 'Ne', 'ill', ')', '表示', ',', '他', '已', '从', '自己的', '专业', '网站', '中', '删除', '了', '该公司', '的', '宣传', '工作', '。']


In [13]:
def process_tokens(token_list):
    """Strip one leading SentencePiece word-boundary marker from each token.

    The marker is '▁' (U+2581 LOWER ONE EIGHTH BLOCK), not an ASCII underscore.
    Tokens without it pass through unchanged, and a token consisting only of the
    marker becomes ''. A new list is returned; ``token_list`` is left untouched.
    """
    return [tok[1:] if tok.startswith('▁') else tok for tok in token_list]

In [14]:
# Drop the leading language-code piece (en_XX / zh_CN) and the SentencePiece marker.
# NOTE: this rebinds the names produced by the previous cell, so the cell is not
# idempotent -- re-running it removes one more leading token each time.
input_tokens, output_tokens, output_tokens_q = [
    process_tokens(pieces[1:]) for pieces in (input_tokens, output_tokens, output_tokens_q)
]
for pieces in (input_tokens, output_tokens, output_tokens_q):
    print(pieces)

['Patrick', 'O', "'", 'Ne', 'ill', ',', 'the', 'former', 'chief', 'creative', 'officer', 'at', 'The', 'rano', 's', 'from', '2014', 'to', '2017', ',', 'said', 'he', 'has', 'removed', 'much', 'of', 'his', 'promotion', 'al', 'work', 'for', 'the', 'company', 'from', 'his', 'professional', 'website', '.']
['', '帕', '特', '里', '克', '·', '奥', '尼', '尔', '(', 'Patri', 'ck', 'O', "'", 'Ne', 'ill', ')', '在', '2014', '年', '至', '2017', '年', '间', '担任', 'The', 'rano', 's', '首席', '创意', '总监', ',', '他说', '他', '已经', '从', '自己的', '专业', '网站', '中', '删除', '了', '该公司', '的', '宣传', '工作', '。']
['', '帕', '特', '里', '克', '·', '奥', '尼', '尔', '(', 'Patri', 'ck', 'O', "'", 'Ne', 'ill', ')', '表示', ',', '他', '已', '从', '自己的', '专业', '网站', '中', '删除', '了', '该公司', '的', '宣传', '工作', '。']


In [15]:
def compare_tokens(original, quantized):
    """Report tokens that occur in only one of two token sequences.

    The comparison is by token value only: a token present in both sequences is
    reported in neither result, even if its count or positions differ.

    Args:
        original: tokens of the baseline translation.
        quantized: tokens of the quantized-model translation.

    Returns:
        A pair ``(additional_tokens, missing_tokens)``:
        ``additional_tokens`` maps each token found only in ``quantized`` to its indices there;
        ``missing_tokens`` maps each token found only in ``original`` to its indices there.
        Keys follow first-occurrence order and index lists are ascending.
    """
    def _positions(tokens):
        # Single pass; the previous version rescanned the whole list once per token (O(n^2)).
        index_map = {}
        for i, token in enumerate(tokens):
            index_map.setdefault(token, []).append(i)
        return index_map

    original_indices = _positions(original)
    quantized_indices = _positions(quantized)

    additional_tokens = {tok: idx for tok, idx in quantized_indices.items() if tok not in original_indices}
    missing_tokens = {tok: idx for tok, idx in original_indices.items() if tok not in quantized_indices}
    return additional_tokens, missing_tokens

In [16]:
# Tokens only the quantized translation produced, then tokens only the baseline produced.
addit, miss = compare_tokens(original=output_tokens, quantized=output_tokens_q)
print(addit, miss, sep="\n")

{'表示': [17], '已': [20]}
{'在': [17], '2014': [18], '年': [19, 22], '至': [20], '2017': [21], '间': [23], '担任': [24], 'The': [25], 'rano': [26], 's': [27], '首席': [28], '创意': [29], '总监': [30], '他说': [32], '已经': [34]}


In [17]:
from simalign import SentenceAligner

# Word aligner backed by multilingual BERT embeddings, matching at the subword ("bpe") level.
# The letters in "mai" select the mwmf, inter (ArgMax) and itermax matching methods;
# only "mwmf" is used further down.
myaligner = SentenceAligner(
    model="bert",
    token_type="bpe",
    matching_methods="mai",
)

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
2023-07-16 00:55:16,231 - simalign.simalign - INFO - Initialized the EmbeddingLoader with model: bert-base-multilingual-cased


In [18]:
# simalign treats each marker-stripped SentencePiece piece as a "word". Each result is a dict
# keyed by matching method, whose values are zero-indexed (src_index, trg_index) pairs.
src_sentence, trg_sentence, trg_sentence_q = input_tokens, output_tokens, output_tokens_q
src_sent_len = len(src_sentence)

alignments, alignments_q = [
    myaligner.get_word_aligns(src_sentence, target) for target in (trg_sentence, trg_sentence_q)
]
base_align, quant_align = alignments['mwmf'], alignments_q['mwmf']

print("Baseline: (src_index, trg_index) = ", base_align)
print("Quantized: (src_index, trg_index) = ", quant_align)

Baseline: (src_index, trg_index) =  [(0, 4), (1, 12), (2, 13), (3, 14), (4, 15), (5, 16), (6, 24), (7, 28), (8, 28), (9, 29), (10, 30), (11, 24), (12, 25), (13, 26), (14, 27), (15, 17), (16, 18), (17, 20), (18, 21), (19, 31), (20, 32), (21, 32), (22, 34), (23, 40), (24, 41), (25, 43), (26, 33), (27, 44), (28, 40), (29, 45), (30, 45), (31, 42), (32, 42), (33, 35), (34, 36), (35, 37), (36, 38), (37, 46)]
Quantized: (src_index, trg_index) =  [(0, 4), (1, 12), (2, 13), (3, 14), (4, 15), (5, 9), (6, 1), (7, 25), (8, 23), (9, 30), (10, 24), (11, 22), (12, 28), (13, 10), (13, 11), (14, 11), (15, 17), (16, 8), (17, 5), (18, 18), (19, 16), (20, 17), (21, 19), (22, 20), (23, 26), (24, 27), (25, 29), (26, 22), (27, 30), (28, 26), (29, 31), (30, 31), (31, 28), (32, 28), (33, 21), (34, 22), (35, 23), (36, 24), (37, 32)]


In [19]:
def build_hotmap(bmap, qmap, add, miss, src_len):
    """Count, per source index, alignment links that land on a changed target token.

    Args:
        bmap: (src_index, trg_index) pairs aligning the source to the baseline translation.
        qmap: (src_index, trg_index) pairs aligning the source to the quantized translation.
        add: token -> indices (in the quantized translation) of tokens absent from the baseline.
        miss: token -> indices (in the baseline translation) of tokens absent from the quantized one.
        src_len: number of source tokens; every index in range(src_len) gets an entry.

    Returns:
        dict mapping each source index to its hit count (0 when no changed token aligns to it).
    """
    counts = dict.fromkeys(range(src_len), 0)
    # "Additional" positions refer to the quantized output, so they are looked up in qmap;
    # "missing" positions refer to the baseline output, so they are looked up in bmap.
    for changed, links in ((add, qmap), (miss, bmap)):
        for positions in changed.values():
            for src_idx, trg_idx in links:
                if trg_idx in positions:
                    counts[src_idx] += 1
    return counts

In [20]:
# Keyword arguments make the baseline/quantized pairing explicit.
hot_map = build_hotmap(
    bmap=base_align,
    qmap=quant_align,
    add=addit,
    miss=miss,
    src_len=src_sent_len,
)
print(hot_map)

{0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 1, 7: 1, 8: 1, 9: 1, 10: 1, 11: 1, 12: 1, 13: 1, 14: 1, 15: 2, 16: 1, 17: 1, 18: 1, 19: 0, 20: 2, 21: 1, 22: 2, 23: 0, 24: 0, 25: 0, 26: 0, 27: 0, 28: 0, 29: 0, 30: 0, 31: 0, 32: 0, 33: 0, 34: 0, 35: 0, 36: 0, 37: 0}


In [23]:
import random
import numpy as np

# English function words that are never replaced by <fill> (duplicates removed).
stop_words = [
    'a', 'an', 'the', 'and', 'or', 'but', 'if', 'while', 'as', 'that', 'this',
    'these', 'those', 'to', 'for', 'with', 'at', 'from', 'by', 'on', 'off', 'of',
    'into', 'over', 'under', 'above', 'below', 'is', 'be', 'am', 'are', 'was',
    'were', 'been', 'being', 'have', 'has', 'had', 'do', 'does', 'did', 'can',
    'could', 'shall', 'should', 'will', 'would', 'might', 'must', 'it', 'its',
    "it's", 'he', 'his', 'she', 'her', 'hers', 'they', 'their', 'theirs', 'you',
    'your', 'yours', 'we', 'our', 'ours', 'in', 'out', 'through', 'because',
    'during', 'before', 'after', 'about', 'against', 'between', 'among',
    'up', 'down', 'again',
    'further', 'then', 'once', 'here', 'there', 'when', 'where', 'why', 'how',
    'all', 'any', 'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such',
    'no', 'nor', 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very', 's'
]

# helper functions
def mask_tokens(hit_map, tokens, rng=None, unref_weight=8, max_mask=5):
    """Randomly replace a few content tokens with '<fill>' and return the joined string.

    Only tokens for which ``str.isalnum()`` holds are candidates, so punctuation is never
    masked. Candidate ``i`` is sampled with weight ``unref_weight`` when ``hit_map[i] == 0``
    and ``1 / (hit_map[i] + 1)`` otherwise. Sampled stop words are left as-is, and the draw
    is repeated until it contains at least one non-stop-word.

    NOTE(review): with these weights, tokens whose alignment was *not* affected by
    quantization are the most likely to be masked -- confirm this is the intent.

    Args:
        hit_map: mapping (or sequence) from token index to hit count, e.g. ``hot_map``.
        tokens: list of source tokens; it is not modified.
        rng: object with ``randint`` and ``choice`` (e.g. ``np.random.RandomState``).
            Defaults to NumPy's global random state, as before.
        unref_weight: sampling weight for tokens with zero hits.
        max_mask: largest number of tokens drawn per attempt, unless the ``num_unref // 10``
            lower bound is bigger. Draws are also capped at the number of candidates.

    Returns:
        The tokens joined by single spaces, with masked positions replaced by '<fill>'.
        If no candidate is a non-stop-word (including empty input), the tokens are
        returned unmasked instead of looping forever.
    """
    if rng is None:
        rng = np.random
    stop_set = set(stop_words)  # O(1) membership checks

    candidates = [i for i, token in enumerate(tokens) if token.isalnum()]
    if all(tokens[i] in stop_set for i in candidates):
        # Nothing maskable: the resampling loop below could never terminate.
        return ' '.join(tokens)

    weights = [unref_weight if hit_map[i] == 0 else 1 / (hit_map[i] + 1) for i in candidates]
    total = sum(weights)  # hoisted: previously recomputed for every element
    probs = [w / total for w in weights]

    num_unref = sum(1 for i in candidates if hit_map[i] == 0)
    low = max(1, num_unref // 10)
    # randint's upper bound is exclusive and must exceed `low`. With the default max_mask
    # this is randint(low, 6), exactly as before, whenever low <= 5.
    high = max(max_mask, low) + 1

    def _draw():
        # replace=False cannot draw more items than there are candidates.
        size = min(rng.randint(low, high), len(candidates))
        return rng.choice(candidates, size=size, p=probs, replace=False)

    # Keep drawing until at least one chosen token is not a stop word.
    mask_indices = _draw()
    while all(tokens[i] in stop_set for i in mask_indices):
        mask_indices = _draw()

    masked_tokens = list(tokens)
    for i in mask_indices:
        if masked_tokens[i] not in stop_set:
            masked_tokens[i] = '<fill>'
    return ' '.join(masked_tokens)

In [24]:
# Seed NumPy's global RNG (which mask_tokens samples from by default) so the masked
# sentence -- and therefore the prompt sent below -- is reproducible on Restart & Run All.
MASK_SEED = 0
np.random.seed(MASK_SEED)
masked = mask_tokens(hot_map, src_sentence)
masked

"<fill> O ' <fill> ill , the former chief creative officer at The rano s from 2014 to 2017 , said he has removed <fill> of his promotion al work for the company from his professional website ."

In [25]:
from revChatGPT.V1 import Chatbot

import os

# The access token is read from the environment instead of being pasted into the notebook,
# so it cannot leak through a saved or shared copy. A KeyError here means it is unset.
# NOTE(review): revChatGPT is an unofficial client; confirm it is acceptable for this use.
chatbot = Chatbot(config={
  "access_token": os.environ["CHATGPT_ACCESS_TOKEN"]})

prefix = "Complete the sentence by filling up the missing information denoted as <fill> in the original sentence: \n"
prompt = prefix + masked

response = ""

# chatbot.ask yields items while the reply is produced. Only the last item's "message" is
# kept -- presumably the complete reply (verify against the installed revChatGPT version).
for data in chatbot.ask(prompt):
    response = data["message"]

print(response)



John O'Neill, the former chief creative officer at The Rano's from 2014 to 2017, said he has removed all of his promotional work for the company from his professional website.


In [25]:
# import os
# def print_size_of_model(model):
#     torch.save(model.state_dict(), "temp.p")
#     print('Size (MB):', os.path.getsize("temp.p")/1e6)
#     os.remove('temp.p')

# print_size_of_model(model)
# print_size_of_model(quantized_model)

In [14]:
# import time
# torch.set_num_threads(1)

# def time_model_evaluation(model, encoded_data):
#     s = time.time()
#     model.generate(**encoded_data)
#     elapsed = time.time() - s
#     print("elapsed time", elapsed)

# time_model_evaluation(model, encoded_en)
# time_model_evaluation(quantized_model, encoded_en)

elapsed time 6.116252660751343
elapsed time 3.0593185424804688
