In [1]:
import torch
import torch.nn as nn

import numpy as np
import pandas as pd
import os
import sys

In [2]:
#!conda install fastbpe -y
#!conda install clang -y
#!pip install clang
#!pip install sacrebleu=="1.2.11"
#!conda remove sacrebleu -y

In [3]:
import fastBPE

import preprocessing.src.code_tokenizer as code_tokenizer
from XLM.src.data.dictionary import Dictionary, BOS_WORD, EOS_WORD, PAD_WORD, UNK_WORD, MASK_WORD
from XLM.src.model import build_model
from XLM.src.utils import AttrDict

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [5]:
MODEL='model/model_1.pth'

model = torch.load(MODEL, map_location=device)

In [6]:
model

{'epoch': 30,
 'n_total_iter': 3255,
 'best_metrics': {},
 'best_stopping_criterion': None,
 'encoder': OrderedDict([('module.position_embeddings.weight',
               tensor([[-0.0095,  0.0503,  0.0038,  ...,  0.0156, -0.0090, -0.0427],
                       [ 0.0224,  0.0274, -0.0191,  ..., -0.0009, -0.0053, -0.0260],
                       [ 0.0108,  0.0206,  0.0008,  ...,  0.0028, -0.0196, -0.0109],
                       ...,
                       [ 0.0151, -0.0050,  0.0265,  ..., -0.0256,  0.0176, -0.0139],
                       [ 0.0186, -0.0235, -0.0098,  ..., -0.0160,  0.0303,  0.0023],
                       [-0.0230,  0.0086,  0.0015,  ...,  0.0154, -0.0029,  0.0096]],
                      dtype=torch.float16)),
              ('module.lang_embeddings.weight',
               tensor([[ 0.0027,  0.0025,  0.0109,  ..., -0.0143,  0.0063,  0.0066],
                       [ 0.0016, -0.0080,  0.0105,  ..., -0.0008,  0.0116,  0.0109],
                       [ 0.0046, -0.0050,  

In [7]:
### splitting the model params

model['encoder'] = {(k[len('module.'):] if k.startswith('module.') else k):v for k, v in model['encoder'].items()}

In [8]:
if 'decoder' in model:
    decoders_names = ['decoder']
else:
    decoders_names = ['decoder_0','decoder_1']

In [9]:
for decoder_name in decoders_names:
    model[decoder_name] = {(k[len('module.'):] if k.startswith('module.') else k):v for k, v in model[decoder_name].items()}

In [10]:
'decoder' in model or ('decoder_0' in model and 'decoder_1' in model)

True

In [11]:
model['decoder']

{'position_embeddings.weight': tensor([[ 0.0169,  0.0241, -0.0060,  ...,  0.0101, -0.0061, -0.0215],
         [-0.0008,  0.0224, -0.0236,  ..., -0.0059,  0.0085, -0.0200],
         [ 0.0017,  0.0092, -0.0068,  ..., -0.0009, -0.0055, -0.0101],
         ...,
         [ 0.0151, -0.0050,  0.0265,  ..., -0.0256,  0.0176, -0.0139],
         [ 0.0186, -0.0235, -0.0098,  ..., -0.0160,  0.0303,  0.0023],
         [-0.0230,  0.0086,  0.0015,  ...,  0.0154, -0.0029,  0.0096]],
        dtype=torch.float16),
 'lang_embeddings.weight': tensor([[ 0.0026,  0.0028,  0.0092,  ..., -0.0212,  0.0106,  0.0093],
         [ 0.0001, -0.0105,  0.0124,  ...,  0.0010,  0.0093,  0.0107],
         [ 0.0053, -0.0068,  0.0115,  ...,  0.0285, -0.0015,  0.0062]],
        dtype=torch.float16),
 'embeddings.weight': tensor([[-0.0026,  0.0328, -0.0086,  ..., -0.0876,  0.0523,  0.0307],
         [-0.0312, -0.0044,  0.0104,  ...,  0.0131, -0.0293,  0.0129],
         [-0.0401,  0.0246,  0.0135,  ..., -0.0760, -0.0206, -0.00

In [12]:
model['encoder']

{'position_embeddings.weight': tensor([[-0.0095,  0.0503,  0.0038,  ...,  0.0156, -0.0090, -0.0427],
         [ 0.0224,  0.0274, -0.0191,  ..., -0.0009, -0.0053, -0.0260],
         [ 0.0108,  0.0206,  0.0008,  ...,  0.0028, -0.0196, -0.0109],
         ...,
         [ 0.0151, -0.0050,  0.0265,  ..., -0.0256,  0.0176, -0.0139],
         [ 0.0186, -0.0235, -0.0098,  ..., -0.0160,  0.0303,  0.0023],
         [-0.0230,  0.0086,  0.0015,  ...,  0.0154, -0.0029,  0.0096]],
        dtype=torch.float16),
 'lang_embeddings.weight': tensor([[ 0.0027,  0.0025,  0.0109,  ..., -0.0143,  0.0063,  0.0066],
         [ 0.0016, -0.0080,  0.0105,  ..., -0.0008,  0.0116,  0.0109],
         [ 0.0046, -0.0050,  0.0110,  ...,  0.0222,  0.0029,  0.0070]],
        dtype=torch.float16),
 'embeddings.weight': tensor([[-0.0027,  0.0350, -0.0067,  ..., -0.0880,  0.0496,  0.0398],
         [-0.0480,  0.0035,  0.0121,  ...,  0.0016, -0.0303,  0.0152],
         [-0.0391,  0.0273,  0.0157,  ..., -0.0782, -0.0212,  0.00

In [13]:
model_params = AttrDict(model['params'])

In [14]:
model_params

{'dump_path': '/checkpoint/malachaux/dumped/bt_with_comments_sa_final_modif_test/26703283',
 'exp_name': 'bt_with_comments_sa_final_modif_test',
 'save_periodic': 1,
 'exp_id': '26703283',
 'fp16': True,
 'amp': 2,
 'encoder_only': False,
 'emb_dim': 1024,
 'emb_dim_encoder': 1024,
 'emb_dim_decoder': 1024,
 'n_layers': 6,
 'n_layers_encoder': 6,
 'n_layers_decoder': 6,
 'n_heads': 8,
 'dropout': 0.1,
 'attention_dropout': 0.0,
 'gelu_activation': False,
 'share_inout_emb': True,
 'sinusoidal_embeddings': False,
 'use_lang_emb': True,
 'context_size': 0,
 'word_pred': 0.15,
 'sample_alpha': 0.0,
 'word_mask_keep_rand': '0.8,0.1,0.1',
 'word_shuffle': 3.0,
 'word_dropout': 0.1,
 'word_blank': 0.1,
 'data_path': '/private/home/malachaux/data/XLM-cpp-java-python-with-comments-functions-sa-cl-split-test',
 'lgs': 'cpp_sa-java_sa-python_sa',
 'max_vocab': -1,
 'min_count': 0,
 'lg_sampling_factor': -1.0,
 'has_sentences_ids': True,
 'bptt': 256,
 'max_len': 512,
 'group_by_size': True,
 'ba

In [15]:
dico = Dictionary(model['dico_id2word'], model['dico_word2id'], model['dico_counts'])

dico

<XLM.src.data.dictionary.Dictionary at 0x7fc93ea26d60>

In [29]:
print(model_params.n_words, '\t', len(dico))
print(model_params.bos_index, '\t', dico.index(BOS_WORD), '\t', dico.id2word[0]) ## start
print(model_params.eos_index, '\t', dico.index(EOS_WORD), '\t', dico.id2word[1]) ## end of sentence
print(model_params.pad_index, '\t', dico.index(PAD_WORD), '\t', dico.id2word[2])
print(model_params.unk_index, '\t', dico.index(UNK_WORD), '\t', dico.id2word[3])
print(model_params.mask_index, '\t', dico.index(MASK_WORD), '\t', dico.id2word[5]) # mask word

63961 	 63961
0 	 0 	 <s>
1 	 1 	 </s>
2 	 2 	 <pad>
3 	 3 	 <unk>
5 	 5 	 <special1>


In [None]:
self.model_params['reload_model'] = ','.join([model_params] * 2)
encoder, decoder = build_model(self.reloaded_params, self.dico)

build_model()