# CodeBERT: Using With Transformers + Embedding Quality Questions

> ref [Using With Transformers + Embedding Quality Questions](https://github.com/microsoft/CodeBERT/issues/21)

In [2]:
from vectorhub.encoders.text import BaseText2Vec
import torch
from transformers import RobertaTokenizer, RobertaConfig, RobertaModel

In [4]:
class Code2Vec(BaseText2Vec):
    """Embed a natural-language description, optionally paired with code, using CodeBERT.

    The tokenizer and model are both loaded from the
    ``microsoft/codebert-base`` checkpoint when the object is constructed.
    """

    def __init__(self):
        model_name = "microsoft/codebert-base"
        self.tokenizer = RobertaTokenizer.from_pretrained(model_name)
        self.model = RobertaModel.from_pretrained(model_name)

    def encode(self, description, code=None, pooling_method='mean', truncation=True):
        """Return one embedding for ``description`` (and ``code`` if given).

        Args:
            description: Natural-language text, passed to the tokenizer as the
                first sequence.
            code: Optional source code, passed to the tokenizer as the second
                sequence. ``None`` encodes the description on its own.
            pooling_method: How the model output is reduced to one vector.
                ``'pooler_output'`` returns the model's ``pooler_output``;
                ``'mean'`` averages ``last_hidden_state`` over axis 1 using
                ``self._vector_operation``.
            truncation: Forwarded unchanged to the tokenizer.

        Returns:
            The embedding of the single input, converted from a tensor to
            plain Python numbers.

        Raises:
            ValueError: If ``pooling_method`` is not one of the two supported
                values. (Previously such a call silently returned ``None``.)
        """
        if pooling_method not in ('mean', 'pooler_output'):
            raise ValueError(
                "pooling_method must be 'mean' or 'pooler_output', "
                "got {!r}".format(pooling_method)
            )

        # Tokenize once; both pooling strategies consume the same inputs.
        inputs = self.tokenizer.encode_plus(
            description,
            code,
            return_tensors='pt',
            truncation=truncation
        )

        # Inference only: skip autograd bookkeeping. The module is called
        # directly rather than via .forward() so registered hooks still run.
        with torch.no_grad():
            outputs = self.model(**inputs)

        if pooling_method == 'pooler_output':
            return outputs['pooler_output'].detach().numpy().tolist()[0]

        # pooling_method == 'mean'
        hidden_states = outputs['last_hidden_state'].detach().numpy().tolist()
        return self._vector_operation(hidden_states, 'mean', axis=1)[0]

# Build the encoder once and reuse it in every cell below; the constructor
# loads the tokenizer and model via from_pretrained.
model = Code2Vec()

In [7]:
# Example 1: embed a query together with a one-line snippet and print the
# length of the resulting vector.
query = "hello world"
code_1 = 'print("hello")'
vec_1 = model.encode(query, code_1)
print(len(vec_1))

768


In [8]:
# Example 2: embed a query together with a two-line IPython configuration
# snippet and print the length of the resulting vector.
query = "show all cells"
code_2 = """from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
"""
vec_2 = model.encode(query, code_2)
print(len(vec_2))

768


In [9]:
# Example 3: embed a query together with a longer multi-line snippet (kept
# purely as string data; it is never executed here) and print the length of
# the resulting vector.
query = "Install private package in Colab"
code_3 = """import os
from getpass import getpass
import urllib

user = input('User name: ')
password = getpass('Password: ')
password = urllib.parse.quote(password) # your password is converted into url format
https_repo_link = input('Https Repo Link: ') 
end_string = https_repo_link.split('@github.com/')[1]
cmd_string = 'git clone https://{0}:{1}@github.com/{2}'.format(user, password, end_string)
os.system(cmd_string)"""
vec_3 = model.encode(query, code_3)
print(len(vec_3))

768


In [10]:
# Example 4: embed a query together with a function definition (string data
# only) and print the length of the resulting vector.
query = "Download an image"
code_4 = """
def download_image(image_url, output_dir):
    import requests
    r = requests.get(image_url)
    with open(output_dir, 'wb') as f:
        f.write(r.content)
"""
vec_4 = model.encode(query, code_4)
print(len(vec_4))

768


In [13]:
from simpleneighbors import SimpleNeighbors

# Pair every snippet with its embedding: [(code_1, vec_1), ..., (code_4, vec_4)].
# Despite its name, `colors` holds (code, vector) pairs.
colors = list(zip(
    (code_1, code_2, code_3, code_4),
    (vec_1, vec_2, vec_3, vec_4),
))

# Index the pairs; 768 matches the vector length printed for each embedding.
sim = SimpleNeighbors(768)
sim.feed(colors)
sim.build()

In [14]:
# The only query that gave a good result: description-only input (no code
# argument), encoded with the default pooling_method='mean'.
query_vec = model.encode("view all cells in jupyter notebook")
print(list(sim.nearest_matching(query_vec, 1))[0])
# Output below (obtained only with mean pooling):
# from IPython.core.interactiveshell import InteractiveShell
# InteractiveShell.ast_node_interactivity = "all"

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"



In [18]:
# Encode a description-only query (no code argument) and print the first
# item yielded by nearest_matching for it.
query_vec = model.encode("display the string 'hello'")
print(list(sim.nearest_matching(query_vec, 1))[0])

print("hello")
