In [23]:
#!g1.1
%pip install transformers
%pip install datasets
# !apt install libomp-dev --yes


import numpy as np
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from datasets import load_dataset, concatenate_datasets
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import DataLoader
from enum import Enum, auto
from tqdm import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cuda':
    %pip install faiss-gpu -q
else:
    %pip install faiss-cpu -q
    
import faiss
import faiss.contrib.torch_utils

Defaulting to user installation because normal site-packages is not writeable
You should consider upgrading via the '/usr/local/bin/python3 -m pip install --upgrade pip' command.[0m
Defaulting to user installation because normal site-packages is not writeable
You should consider upgrading via the '/usr/local/bin/python3 -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/usr/local/bin/python3 -m pip install --upgrade pip' command.[0m


In [24]:
#!g1.1
FLAGS = {
    'batch_size': 16,
    'num_epochs': 1,
    'seed': 1234
}

In [25]:
#!g1.1
tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")

Downloading:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/498 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/150 [00:00<?, ?B/s]

In [26]:
#!g1.1
dataset = load_dataset("code_x_glue_ct_code_to_text", 'python')
train_data, valid_data, test_data = dataset['train'], dataset['validation'], dataset['test']

Downloading builder script:   0%|          | 0.00/5.91k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/17.9k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/25.7k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/2.35k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/3.74k [00:00<?, ?B/s]

Downloading and preparing dataset code_x_glue_ct_code_to_text/python to /tmp/xdg_cache/huggingface/datasets/code_x_glue_ct_code_to_text/python/0.0.0/f8b7e9d51f609a87e7ec7c7431706d4ee0b402e3398560410313d4acc67060a0...


Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/941M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/12.4M [00:00<?, ?B/s]

  

Extracting data files #0:   0%|          | 0/1 [00:00<?, ?obj/s]

Extracting data files #1:   0%|          | 0/1 [00:00<?, ?obj/s]

Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

  

Extracting data files #0:   0%|          | 0/1 [00:00<?, ?obj/s]

Extracting data files #1:   0%|          | 0/1 [00:00<?, ?obj/s]

Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

  

Extracting data files #0:   0%|          | 0/1 [00:00<?, ?obj/s]

Extracting data files #1:   0%|          | 0/1 [00:00<?, ?obj/s]

Generating train split:   0%|          | 0/251820 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/13914 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/14918 [00:00<?, ? examples/s]

Dataset code_x_glue_ct_code_to_text downloaded and prepared to /tmp/xdg_cache/huggingface/datasets/code_x_glue_ct_code_to_text/python/0.0.0/f8b7e9d51f609a87e7ec7c7431706d4ee0b402e3398560410313d4acc67060a0. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

In [27]:
#!g1.1
class SeqType(Enum):
  CODE = auto()
  DOC = auto()

    
class TokenizeTransform(object):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        return self.create_one_batch(batch)

    def create_one_batch(self, batch):
        tokens_batch = self.get_formatted_input(batch, SeqType.CODE)
        batch_encoding = self.tokenizer(tokens_batch, padding=True, return_token_type_ids=True, truncation=True)
        code_tokens_ids = batch_encoding.input_ids
        code_token_type_ids = batch_encoding.token_type_ids
        code_attention_mask = batch_encoding.attention_mask
        
        tokens_batch = self.get_formatted_input(batch, SeqType.DOC)
        batch_encoding = self.tokenizer(tokens_batch, padding=True, return_token_type_ids=True, truncation=True)
        doc_tokens_ids = batch_encoding.input_ids
        doc_token_type_ids = batch_encoding.token_type_ids
        doc_attention_mask = batch_encoding.attention_mask
        
        return {'code_tokens_ids': code_tokens_ids,
                'code_token_type_ids': code_token_type_ids,
                'code_attention_mask': code_attention_mask,
                
                'doc_tokens_ids': doc_tokens_ids,
                'doc_token_type_ids': doc_token_type_ids,
                'doc_attention_mask': doc_attention_mask
               }

    def get_formatted_input(self, batch, seq_type):
        if seq_type == SeqType.CODE:
            return self.get_formatted_input_with_f(batch, TokenizeTransform.get_batched_tokens4code)
        elif seq_type == SeqType.DOC:
            return self.get_formatted_input_with_f(batch, TokenizeTransform.get_batched_tokens4docstring)
        else:
            raise Exception("Incorrect sequence type")
            
    def get_formatted_input_with_f(self, batch, get_tokens):
        return [self.tokenizer.cls_token + doc_tokens + self.tokenizer.sep_token + code_tokens + self.tokenizer.sep_token \
                   for doc_tokens, code_tokens in zip(*get_tokens(batch))]

    def get_batched_tokens4code(batch):
        doc_tokens = TokenizeTransform.get_docstring_tokens(batch)
        code_tokens = TokenizeTransform.get_code_tokens(batch)
        return doc_tokens, code_tokens

    def get_batched_tokens4docstring(batch):
        doc_tokens = TokenizeTransform.get_docstring_tokens(batch)
        code_tokens = [''] * len(doc_tokens) 
        return doc_tokens, code_tokens

    def get_docstring_tokens(batch):
        return [' '.join(tokens) for tokens in batch['docstring_tokens']]
    
    def get_code_tokens(batch):
        return [' '.join(tokens) for tokens in batch['code_tokens']]
    
    def rotate_list(lst, shift=1):
        return lst[shift:] + lst[:shift]
    
    
tokenize_transform = TokenizeTransform(tokenizer)

In [28]:
#!g1.1
columns = train_data.column_names

train_dataset = train_data.map(tokenize_transform, batched=True, remove_columns=columns, batch_size=FLAGS['batch_size']).with_format("torch")
valid_dataset = valid_data.map(tokenize_transform, batched=True, remove_columns=columns, batch_size=FLAGS['batch_size']).with_format("torch")
test_dataset = test_data.map(tokenize_transform, batched=True, remove_columns=columns, batch_size=FLAGS['batch_size']).with_format("torch")



  0%|          | 0/15739 [00:00<?, ?ba/s]

  0%|          | 0/870 [00:00<?, ?ba/s]



  0%|          | 0/933 [00:00<?, ?ba/s]

In [29]:
#!g1.1
class FaissKNeighbors:
    def __init__(self, is_cuda):
        self.index = None
        self.is_cuda = is_cuda

    def fit(self, X):
        self.index = faiss.IndexFlatL2(X.shape[1])
        if self.is_cuda:
            res = faiss.StandardGpuResources()
            self.index = faiss.index_cpu_to_gpu(res, 0, self.index)
        self.index.add(X)

    def predict(self, X, k):
        distances, indices = self.index.search(X, k=k)
        return indices

In [53]:
#!g1.1
def eval(model, dataloader, device):
    model.eval()

    batched_code_embs = []
    batched_doc_embs = []

    running_loss = 0
    for iteration, batch in tqdm(enumerate(dataloader)):
        with torch.no_grad():
            with torch.cuda.amp.autocast():
                tokens_ids, token_type_ids, attention_mask = batch['code_tokens_ids'], batch['code_token_type_ids'], batch['code_attention_mask']
                tokens_ids, token_type_ids, attention_mask = tokens_ids.to(device), token_type_ids.to(device), attention_mask.to(device)
                code_embs = model(input_ids=tokens_ids, token_type_ids=token_type_ids, attention_mask=attention_mask).pooler_output

                tokens_ids, token_type_ids, attention_mask = batch['doc_tokens_ids'], batch['doc_token_type_ids'], batch['doc_attention_mask']
                tokens_ids, token_type_ids, attention_mask = tokens_ids.to(device), token_type_ids.to(device), attention_mask.to(device)
                doc_embs = model(input_ids=tokens_ids, token_type_ids=token_type_ids, attention_mask=attention_mask).pooler_output

                neg_code_embs = code_embs.roll(1, 0)

            loss = loss_fn(doc_embs, code_embs, neg_code_embs)
            running_loss += loss.item()

            batched_code_embs.append(code_embs.type(torch.float32))
            batched_doc_embs.append(doc_embs.type(torch.float32))

    concat_code_embs = torch.cat(batched_code_embs, dim=0)
    concat_doc_embs = torch.cat(batched_doc_embs, dim=0)

    loss = running_loss / len(dataloader)

    faiss = FaissKNeighbors(is_cuda=device == torch.device('cuda'))

    k = 1000
    mrrs = []
    for beg_idx in range(0, len(concat_code_embs), k):
        if beg_idx + k > len(concat_code_embs):
            k = len(concat_code_embs) - beg_idx
        doc_embs_subset = concat_doc_embs[beg_idx:beg_idx + k]
        code_embs_subset = concat_code_embs[beg_idx:beg_idx + k]
        faiss.fit(code_embs_subset)  
        preds = faiss.predict(doc_embs_subset, k=k)
        targets = torch.unsqueeze(torch.tensor(range(k)), 1).expand(-1, k).to(device)

        reciprocal_ranks = 1 / (np.argwhere(torch.eq(preds, targets).cpu().numpy())[:,1] + 1)
        mrr_ = np.mean(reciprocal_ranks)
        mrrs.append(mrr_)

    mrrs = torch.tensor(mrrs, device=device)
    return torch.mean(mrrs), loss

In [31]:
#!g1.1
def train(model, dataloader, val_loader, device, epoch, epoch_start):
    model.train()
    running_loss = 0
    for iteration, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            tokens_ids, token_type_ids, attention_mask = batch['code_tokens_ids'], batch['code_token_type_ids'], batch['code_attention_mask']
            tokens_ids, token_type_ids, attention_mask = tokens_ids.to(device), token_type_ids.to(device), attention_mask.to(device)
            code_embs = model(input_ids=tokens_ids, token_type_ids=token_type_ids, attention_mask=attention_mask).pooler_output
            
            tokens_ids, token_type_ids, attention_mask = batch['doc_tokens_ids'], batch['doc_token_type_ids'], batch['doc_attention_mask']
            tokens_ids, token_type_ids, attention_mask = tokens_ids.to(device), token_type_ids.to(device), attention_mask.to(device)
            doc_embs = model(input_ids=tokens_ids, token_type_ids=token_type_ids, attention_mask=attention_mask).pooler_output

            neg_code_embs = code_embs.roll(1, 0)
            
        loss = loss_fn(doc_embs, code_embs, neg_code_embs)
        running_loss += loss.item()
        loss.backward()
        # xm.optimizer_step(optimizer)
        optimizer.step()

        if iteration % 500 == 0:
            _loss = running_loss / (iteration + 1)
            elapsed_train_time = time.time() - epoch_start
            print("epoch: {}\titeration: {}\tloss: {}\tthis iteration loss: {}\telapsed time: {}\expected time: {}".\
                  format(epoch, iteration, _loss, loss, elapsed_train_time, len(dataloader) / (iteration + 1) * elapsed_train_time))

    elapsed_train_time = time.time() - epoch_start
    print("\tepoch: {}\ttrain loss: {}\telapsed time: {}".format(epoch, running_loss / len(dataloader), elapsed_train_time))

    torch.save(model, 'fine_tuned_codebert.pt')

    val_mrr, val_loss = eval(model, val_loader, device)
    print("epoch: {}\tvalid loss: {}\tvalid mrr: {}".format(epoch, val_loss, val_mrr))

In [32]:
#!g1.1
  
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=FLAGS['batch_size'],
    drop_last=True)

valid_loader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=FLAGS['batch_size'],
    drop_last=True)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=FLAGS['batch_size'],
    drop_last=True)

codebert = AutoModel.from_pretrained("microsoft/codebert-base").to(device).train()
for p in codebert.parameters():
    p.requires_grad = True

learning_rate = 1e-5
optimizer = torch.optim.RMSprop(codebert.parameters(), lr = learning_rate)
loss_fn = torch.nn.TripletMarginLoss(margin=200, p=2)

Downloading:   0%|          | 0.00/476M [00:00<?, ?B/s]

In [51]:
#!g1.1
for epoch in range(FLAGS['num_epochs']):
    epoch_start = time.time()
    train(codebert, train_loader, valid_loader, device, epoch, epoch_start)

  0%|          | 1/15738 [00:00<1:23:15,  3.15it/s]

epoch: 0	iteration: 0	loss: 194.375	this iteration loss: 194.375	elapsed time: 0.321727991104126\expected time: 5063.355123996735


  3%|▎         | 501/15738 [02:18<1:12:55,  3.48it/s]

epoch: 0	iteration: 500	loss: 196.9436127744511	this iteration loss: 195.5	elapsed time: 138.01833772659302\expected time: 4335.594010261718


  6%|▋         | 1001/15738 [04:27<1:07:58,  3.61it/s]

epoch: 0	iteration: 1000	loss: 197.02397602397602	this iteration loss: 200.0	elapsed time: 267.7575364112854\expected time: 4209.758349691118


 10%|▉         | 1501/15738 [06:38<1:02:54,  3.77it/s]

epoch: 0	iteration: 1500	loss: 197.13357761492338	this iteration loss: 196.625	elapsed time: 398.18438696861267\expected time: 4174.967276556979


 13%|█▎        | 2001/15738 [08:51<1:07:29,  3.39it/s]

epoch: 0	iteration: 2000	loss: 197.1907796101949	this iteration loss: 195.5	elapsed time: 531.310530424118\expected time: 4178.7931673237235


 16%|█▌        | 2501/15738 [11:07<59:10,  3.73it/s]  

epoch: 0	iteration: 2500	loss: 197.20771691323472	this iteration loss: 197.75	elapsed time: 667.3424751758575\expected time: 4199.374599887104


 19%|█▉        | 3001/15738 [13:26<55:16,  3.84it/s]  

epoch: 0	iteration: 3000	loss: 197.12874875041652	this iteration loss: 197.75	elapsed time: 806.4802806377411\expected time: 4229.385756973265


 22%|██▏       | 3501/15738 [15:44<1:02:22,  3.27it/s]

epoch: 0	iteration: 3500	loss: 197.11075407026564	this iteration loss: 196.625	elapsed time: 944.4784622192383\expected time: 4245.701810455977


 25%|██▌       | 4001/15738 [17:59<45:48,  4.27it/s]  

epoch: 0	iteration: 4000	loss: 197.09666333416646	this iteration loss: 196.5	elapsed time: 1079.2290239334106\expected time: 4245.165303340169


 29%|██▊       | 4501/15738 [20:13<49:16,  3.80it/s]  

epoch: 0	iteration: 4500	loss: 197.08292601644078	this iteration loss: 198.5	elapsed time: 1213.9996111392975\expected time: 4244.818013799214


 32%|███▏      | 5001/15738 [22:28<47:58,  3.73it/s]

epoch: 0	iteration: 5000	loss: nan	this iteration loss: nan	elapsed time: 1348.0425970554352\expected time: 4242.250428406006


 35%|███▍      | 5501/15738 [24:40<48:52,  3.49it/s]

epoch: 0	iteration: 5500	loss: nan	this iteration loss: nan	elapsed time: 1480.0579988956451\expected time: 4234.348806875053


 38%|███▊      | 6001/15738 [26:52<40:53,  3.97it/s]

epoch: 0	iteration: 6000	loss: nan	this iteration loss: nan	elapsed time: 1612.03231883049\expected time: 4227.656162931887


 41%|████▏     | 6501/15738 [29:01<43:25,  3.55it/s]

epoch: 0	iteration: 6500	loss: nan	this iteration loss: nan	elapsed time: 1741.5308396816254\expected time: 4215.999439303095


 44%|████▍     | 7001/15738 [31:15<42:23,  3.44it/s]

epoch: 0	iteration: 7000	loss: nan	this iteration loss: nan	elapsed time: 1875.2753491401672\expected time: 4215.552556030275


 48%|████▊     | 7501/15738 [33:28<37:26,  3.67it/s]

epoch: 0	iteration: 7500	loss: nan	this iteration loss: nan	elapsed time: 2008.9621241092682\expected time: 4215.044115348842


 51%|█████     | 8001/15738 [35:41<36:49,  3.50it/s]

epoch: 0	iteration: 8000	loss: nan	this iteration loss: nan	elapsed time: 2141.75319647789\expected time: 4212.837371099741


 54%|█████▍    | 8501/15738 [37:51<35:26,  3.40it/s]

epoch: 0	iteration: 8500	loss: nan	this iteration loss: nan	elapsed time: 2271.5372116565704\expected time: 4205.323213392672


 57%|█████▋    | 9001/15738 [40:01<31:58,  3.51it/s]

epoch: 0	iteration: 9000	loss: nan	this iteration loss: nan	elapsed time: 2401.6979863643646\expected time: 4199.302622975489


 60%|██████    | 9501/15738 [42:15<28:04,  3.70it/s]

epoch: 0	iteration: 9500	loss: nan	this iteration loss: nan	elapsed time: 2535.22300696373\expected time: 4199.488441595114


 64%|██████▎   | 10001/15738 [44:22<26:15,  3.64it/s]

epoch: 0	iteration: 10000	loss: nan	this iteration loss: nan	elapsed time: 2662.8012280464172\expected time: 4190.297542945157


 67%|██████▋   | 10501/15738 [46:34<17:59,  4.85it/s]

epoch: 0	iteration: 10500	loss: nan	this iteration loss: nan	elapsed time: 2794.6674082279205\expected time: 4188.408310702886


 70%|██████▉   | 11001/15738 [48:48<22:25,  3.52it/s]

epoch: 0	iteration: 11000	loss: nan	this iteration loss: nan	elapsed time: 2928.113991498947\expected time: 4188.951731498085


 73%|███████▎  | 11501/15738 [50:57<20:19,  3.47it/s]

epoch: 0	iteration: 11500	loss: nan	this iteration loss: nan	elapsed time: 3057.666138648987\expected time: 4184.118745331515


 76%|███████▋  | 12001/15738 [53:08<16:17,  3.82it/s]

epoch: 0	iteration: 12000	loss: nan	this iteration loss: nan	elapsed time: 3188.164508342743\expected time: 4180.9293419130145


 79%|███████▉  | 12501/15738 [55:23<16:07,  3.35it/s]

epoch: 0	iteration: 12500	loss: nan	this iteration loss: nan	elapsed time: 3323.235972881317\expected time: 4183.752319110965


 83%|████████▎ | 13001/15738 [57:37<13:04,  3.49it/s]

epoch: 0	iteration: 13000	loss: nan	this iteration loss: nan	elapsed time: 3457.367675304413\expected time: 4185.220557952531


 86%|████████▌ | 13501/15738 [59:57<11:05,  3.36it/s]

epoch: 0	iteration: 13500	loss: nan	this iteration loss: nan	elapsed time: 3597.2784028053284\expected time: 4193.316606425469


 89%|████████▉ | 14001/15738 [1:02:10<07:33,  3.83it/s]

epoch: 0	iteration: 14000	loss: nan	this iteration loss: nan	elapsed time: 3730.6723136901855\expected time: 4193.509097411337


 92%|█████████▏| 14501/15738 [1:04:18<05:56,  3.47it/s]

epoch: 0	iteration: 14500	loss: nan	this iteration loss: nan	elapsed time: 3858.973160505295\expected time: 4188.160788913339


 95%|█████████▌| 15001/15738 [1:06:29<03:18,  3.72it/s]

epoch: 0	iteration: 15000	loss: nan	this iteration loss: nan	elapsed time: 3989.075571537018\expected time: 4185.059085717591


 98%|█████████▊| 15501/15738 [1:08:43<01:11,  3.33it/s]

epoch: 0	iteration: 15500	loss: nan	this iteration loss: nan	elapsed time: 4123.380038738251\expected time: 4186.423782314856


100%|██████████| 15738/15738 [1:09:45<00:00,  3.76it/s]

	epoch: 0	train loss: nan	elapsed time: 4185.117289066315



869it [01:24, 10.30it/s]


AttributeError: module 'torch' has no attribute 'argwhere'

In [54]:
#!g1.1
val_mrr, val_loss = eval(codebert, valid_loader, device)
print("loss: {}\tmrr: {}".format(val_loss, val_mrr))

869it [01:24, 10.30it/s]
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


loss: nan	mrr: nan
