In [1]:
import torch
from unixcoder import UniXcoder
import datasets

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the pretrained checkpoint and move its weights to the chosen device.
model = UniXcoder("microsoft/unixcoder-base").to(device)
model



UniXcoder(
  (model): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(51416, 768, padding_idx=1)
      (position_embeddings): Embedding(1026, 768, padding_idx=1)
      (token_type_embeddings): Embedding(10, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0): RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,),

In [23]:
# Decoder-only mode: the prompt has no mask token; the model simply continues it.
context = "\nfunction countWords(str) {\n"

def predict_statement(input: str) -> str:
    """Predict the next statement that should follow ``input``.

    The text is tokenized in ``<decoder-only>`` mode, continued by the model
    (``beam_size=3``, at most 128 generated tokens), and the selected
    continuation is trimmed down to its first statement.

    Args:
        input: Source-code prefix to continue. The name shadows the builtin,
            but it is kept so that keyword callers keep working.

    Returns:
        The continuation up to and including the first ``;`` that is directly
        followed by a newline. If the continuation contains no such
        terminator, its first line is returned instead.
    """
    tokens_ids = model.tokenize([input], max_length=512, mode="<decoder-only>")
    source_ids = torch.tensor(tokens_ids).to(device)

    # This is pure inference: disabling autograd avoids building a graph for
    # every decoding step, which saves memory and time without changing the
    # generated tokens.
    with torch.no_grad():
        prediction_ids = model.generate(
            source_ids, decoder_only=True, beam_size=3, max_length=128
        )
    predictions = model.decode(prediction_ids)

    # First candidate for the first (and only) input -- presumably the
    # best-scoring beam; confirm against the UniXcoder implementation.
    prediction = predictions[0][0]

    # Cut at the first ";\n" and restore the ";" that the separator consumed.
    head, terminator, _ = prediction.partition(";\n")
    if terminator:
        return head + ";"
    # No terminated statement found: fall back to the first line.
    return prediction.partition("\n")[0]

# Grow the snippet one predicted statement at a time. Each round feeds the
# text produced so far back into the model, so `context` is extended in place.
for i in range(3):
    context += "\n" + predict_statement(context)
    print(i, len(context), context, '\n')

0 44 
function countWords(str) {

var words = []; 

1 111 
function countWords(str) {

var words = [];

for (var i = 0; i < str.length; i++) {
words.push(str.charAt(i)); 

2 129 
function countWords(str) {

var words = [];

for (var i = 0; i < str.length; i++) {
words.push(str.charAt(i));

}

return words; 



In [45]:
# Encoder-decoder mode: the model is asked to fill in the <mask0> placeholder.
context = "\n".join(
    [
        "",
        "function countWords(str: string): number {",
        "    var words: string[] = []",
        "    <mask0>",
        "    return words.length",
        "}",
        "",
    ]
)

def predict_statement_mask(code: str) -> str:
    """Predict the text that should replace ``<mask0>`` in ``code``.

    The snippet is tokenized in ``<encoder-decoder>`` mode and the model
    generates a fill-in (``beam_size=3``, at most 128 generated tokens).

    Args:
        code: Source code containing a ``<mask0>`` placeholder.

    Returns:
        The selected prediction with any ``<mask0>`` marker removed and
        surrounding whitespace stripped.
    """
    tokens_ids = model.tokenize([code], max_length=512, mode="<encoder-decoder>")
    source_ids = torch.tensor(tokens_ids).to(device)

    # This is pure inference: disabling autograd avoids building a graph for
    # every decoding step, which saves memory and time without changing the
    # generated tokens.
    with torch.no_grad():
        prediction_ids = model.generate(
            source_ids, decoder_only=False, beam_size=3, max_length=128
        )
    predictions = model.decode(prediction_ids)

    # Top-1: first candidate for the first (and only) input.
    top_prediction = predictions[0][0]
    # The decoded text can echo the mask marker; drop it before returning.
    return top_prediction.replace("<mask0>", "").strip()

# Substitute the model's suggestion for the placeholder and show the result.
infill = predict_statement_mask(context)
print(context.replace("<mask0>", infill))


function countWords(str: string): number {
    var words: string[] = []
    words.push(str)
    return words.length
}

