<a href="https://colab.research.google.com/github/rishuatgithub/cs-autograds-code-nmt/blob/main/Load_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [33]:
from google.colab import drive
drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [34]:
cd '/content/drive/MyDrive/cs-autograds-code-nmt/'

/content/drive/MyDrive/cs-autograds-code-nmt


In [35]:
import torch
import torch.nn as nn

import numpy as np
import pandas as pd
import os
import sys

In [36]:
!pip install fastbpe clang sacrebleu=="1.2.11"



In [37]:
#!pip install git+https://github.com/llvm/llvm-project.git

In [40]:
import fastBPE

import preprocessing.src.code_tokenizer as code_tokenizer
from XLM.src.data.dictionary import Dictionary, BOS_WORD, EOS_WORD, PAD_WORD, UNK_WORD, MASK_WORD
from XLM.src.model import build_model
from XLM.src.utils import AttrDict

In [41]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [42]:
MODEL_PATH='model/model_1.pth'

model = torch.load(MODEL_PATH, map_location=device)

In [None]:
model

In [44]:
### splitting the model params

model['encoder'] = {(k[len('module.'):] if k.startswith('module.') else k):v for k, v in model['encoder'].items()}

In [45]:
if 'decoder' in model:
    decoders_names = ['decoder']
else:
    decoders_names = ['decoder_0','decoder_1']

In [46]:
for decoder_name in decoders_names:
    model[decoder_name] = {(k[len('module.'):] if k.startswith('module.') else k):v for k, v in model[decoder_name].items()}

In [47]:
'decoder' in model or ('decoder_0' in model and 'decoder_1' in model)

True

In [None]:
model['decoder']

In [None]:
model['encoder']

In [50]:
model_params = AttrDict(model['params'])

In [None]:
model_params

In [52]:
dico = Dictionary(model['dico_id2word'], model['dico_word2id'], model['dico_counts'])

dico

<XLM.src.data.dictionary.Dictionary at 0x7f867a68b908>

In [53]:
print(model_params.n_words, '\t', len(dico))
print(model_params.bos_index, '\t', dico.index(BOS_WORD), '\t', dico.id2word[0]) ## start
print(model_params.eos_index, '\t', dico.index(EOS_WORD), '\t', dico.id2word[1]) ## end of sentence
print(model_params.pad_index, '\t', dico.index(PAD_WORD), '\t', dico.id2word[2])
print(model_params.unk_index, '\t', dico.index(UNK_WORD), '\t', dico.id2word[3])
print(model_params.mask_index, '\t', dico.index(MASK_WORD), '\t', dico.id2word[5]) # mask word

63961 	 63961
0 	 0 	 <s>
1 	 1 	 </s>
2 	 2 	 <pad>
3 	 3 	 <unk>
5 	 5 	 <special1>


In [54]:
torch._C._cuda_getDevice()

0

In [55]:
model_params['reload_model'] = ','.join([MODEL_PATH] * 2)
#model_params

In [56]:
encoder1, decoder1 = build_model(model_params, dico)

In [None]:
encoder = encoder1[0]
encoder

In [60]:
encoder.load_state_dict(model['encoder'])

<All keys matched successfully>

In [None]:
decoder = decoder1[0]
decoder

In [62]:
decoder.load_state_dict(model['decoder'])

<All keys matched successfully>

In [None]:
encoder.cuda()
decoder.cuda()

In [69]:
lang1="java"
lang2="python"

In [70]:
tokenizer = getattr(code_tokenizer, f'tokenize_{lang1}')
detokenizer = getattr(code_tokenizer, f'tokenize_{lang2}')

In [71]:
lang1 += '_sa'
lang2 += '_sa'

In [72]:
lang1_id = model_params.lang2id[lang1]
lang2_id = model_params.lang2id[lang2]

In [75]:
input="int f(String target, ArrayList<String> array) { "+ \
  "  int count = 0; "+ \
  "  for (String str: array) { "+ \
  "      if (target.equals(str)) { "+ \
  "         count++; "+ \
  "     } "+ \
  "  } "+ \
  "  return count; "+ \
"}"

input

'int f(String target, ArrayList<String> array) {   int count = 0;   for (String str: array) {       if (target.equals(str)) {          count++;      }   }   return count; }'

In [80]:
BPE_PATH='/content/drive/MyDrive/cs-autograds-code-nmt/data/BPE_with_comments_codes'

bpe_model = fastBPE.fastBPE(os.path.abspath(BPE_PATH))

In [81]:
tokens = [t for t in tokenizer(input)]
tokens = bpe_model.apply(tokens)
tokens = ['</s>'] + tokens + ['</s>']
input = " ".join(tokens)

In [82]:
input

'</s> int f ( String target , ArrayList < String > array ) { int count = 0 ; for ( String str : array ) { if ( target . equals ( str ) ) { count ++ ; } } return count ; } </s>'

In [86]:
len1 = len(input.split())
len1 = torch.LongTensor(1).fill_(len1).to(device)

x1 = torch.LongTensor([dico.index(w) for w in input.split()]).to(device)[:, None]

langs1 = x1.clone().fill_(lang1_id)
enc1 = encoder('fwd', x=x1, lengths=len1, langs=langs1, causal=False)
enc1 = enc1.transpose(0, 1)

In [90]:
x2, len2 = decoder.generate(enc1,
                            len1,
                            lang2_id,
                            max_len=int(min(model_params.max_len, 3 * len1.max().item() + 10)),
                            sample_temperature=None)

In [92]:
len2

tensor([40], device='cuda:0')

In [95]:
tok = []
for i in range(x2.shape[1]):
  wid = [dico[x2[j, i].item()] for j in range(len(x2))][1:]
  wid = wid[:wid.index(EOS_WORD)] if EOS_WORD in wid else wid
  tok.append(" ".join(wid).replace("@@ ", ""))

results = []
for t in tok:
  results.append(detokenizer(t))
  #return results
  #print(detokenizer(t))

In [104]:
for out in results[0]:
  #print("=" * 20)
  if out == 'NEWLINE':
    print("\n"+out)
  else:
    print(out)

def
f
(
target
,
array
)
:
NEW_LINE
INDENT
count
=
0
NEW_LINE
for
str
in
array
:
NEW_LINE
INDENT
if
target
==
str
:
NEW_LINE
INDENT
count
+=
1
NEW_LINE
DEDENT
DEDENT
return
count
NEW_LINE
DEDENT
NEW_LINE
