In [1]:
# Dataset location switches. The dataset-loading cell reads both flags: when
# either one is True it calls load_from_disk with a relative path, otherwise it
# calls load_dataset.
DATASET_ON_GOOGLE_DRIVE = True
DATASET_ON_DISK = False
if DATASET_ON_GOOGLE_DRIVE:
    # Colab-only path: mount Google Drive, then make the project folder the
    # working directory so the relative paths used later resolve inside it.
    from google.colab import drive
    drive.mount('/content/drive')
    %cd /content/drive/MyDrive/PML&DL/Project

Mounted at /content/drive
/content/drive/MyDrive/PML&DL/Project


In [2]:
!pip install transformers --quiet
!pip install datasets --quiet
from datasets import load_dataset, load_from_disk
import torch
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import DataLoader
from enum import Enum, auto
from tqdm import tqdm

# Prefer the GPU whenever torch reports one; otherwise everything runs on CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Pretrained CodeBERT checkpoint: tokenizer first, then the encoder weights.
tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
codebert = AutoModel.from_pretrained("microsoft/codebert-base")

# A saved copy in the working directory is used when either location flag is
# set; otherwise the 'python' configuration is requested through load_dataset.
dataset = (
    load_from_disk('code_x_glue_ct_code_to_text')
    if (DATASET_ON_GOOGLE_DRIVE or DATASET_ON_DISK)
    else load_dataset("code_x_glue_ct_code_to_text", 'python')
)

[K     |████████████████████████████████| 5.8 MB 6.5 MB/s 
[K     |████████████████████████████████| 7.6 MB 51.7 MB/s 
[K     |████████████████████████████████| 182 kB 61.6 MB/s 
[K     |████████████████████████████████| 451 kB 6.2 MB/s 
[K     |████████████████████████████████| 212 kB 71.7 MB/s 
[K     |████████████████████████████████| 132 kB 69.3 MB/s 
[K     |████████████████████████████████| 127 kB 72.2 MB/s 
[?25h

Downloading:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/498 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/150 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/499M [00:00<?, ?B/s]

In [3]:
# Pull the three splits out of the loaded dataset, one name per split.
train_data = dataset['train']
valid_data = dataset['validation']
test_data = dataset['test']

In [None]:
class SeqType(Enum):
    """Kind of text sequence a TokenizeCollator is asked to build.

    The numeric values are the ones ``auto()`` would assign (1, 2, 3), written
    out explicitly.
    """

    # Docstring followed by its own code.
    CODE = 1
    # Docstring followed by an empty code part.
    DOC = 2
    # Docstring paired with the code of a different row of the batch.
    NEG = 3


class TokenizeCollator(object):
    """DataLoader ``collate_fn``: raw dataset rows in, tokenized tensors out.

    Each row is expected to provide ``'docstring_tokens'`` and ``'code_tokens'``
    (lists of strings). Depending on ``seq_type`` the rows are rendered as
    ``cls + doc + sep + code + sep`` strings in one of three ways:

    * ``SeqType.CODE`` – docstring with its own code, label 1
    * ``SeqType.DOC``  – docstring with an empty code part, label 1
    * ``SeqType.NEG``  – docstring with the code of another row, label 0

    The resulting tensors are moved to the notebook-level ``device``.
    """

    def __init__(self, tokenizer, seq_type):
        self.seq_type = seq_type
        self.tokenizer = tokenizer

    def __call__(self, batch):
        """Let the instance be passed directly as ``collate_fn``."""
        return self.create_one_batch(batch)

    def create_one_batch(self, batch):
        """Return ``(labels, (input_ids, token_type_ids, attention_mask))`` on ``device``."""
        labels, texts = self.get_formatted_input(batch)
        # NOTE(review): the strings already carry cls/sep tokens and this call
        # does not pass add_special_tokens=False, so the tokenizer may add its
        # own special tokens on top — confirm this is intended.
        encoded = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            return_token_type_ids=True,
            return_tensors='pt',
        )
        model_inputs = (
            encoded.input_ids.to(device),
            encoded.token_type_ids.to(device),
            encoded.attention_mask.to(device),
        )
        return (labels.to(device), model_inputs)

    def get_formatted_input(self, batch):
        """Build the ``(len(batch), 1)`` label tensor and the list of input strings."""
        kind = self.seq_type
        if kind == SeqType.NEG:
            return torch.full((len(batch), 1), 0), self.get_negative_input_for_code(batch)
        if kind == SeqType.CODE:
            render = self.get_formatted_input_for_code
        elif kind == SeqType.DOC:
            render = self.get_formatted_input_for_doc
        else:
            raise Exception("Incorrect sequence type")
        return torch.full((len(batch), 1), 1), [render(row) for row in batch]

    def _wrap_pair(self, doc_text, code_text):
        """Surround a (doc, code) text pair with the tokenizer's cls/sep tokens."""
        tok = self.tokenizer
        return tok.cls_token + doc_text + tok.sep_token + code_text + tok.sep_token

    def get_formatted_input_for_code(self, item):
        """Render one row as docstring followed by its own code."""
        doc_text = ' '.join(item['docstring_tokens'])
        code_text = ' '.join(item['code_tokens'])
        return self._wrap_pair(doc_text, code_text)

    def get_formatted_input_for_doc(self, item):
        """Render one row as docstring followed by an empty code part."""
        return self._wrap_pair(' '.join(item['docstring_tokens']), '')

    def get_negative_input_for_code(self, batch):
        """Render every row with the docstring/code of a neighbouring row.

        Rows in the first half take the *next* row's docstring with their own
        code; the remaining rows (except the last) keep their docstring and take
        the *next* row's code; the last row's docstring is paired with the first
        row's code.

        NOTE(review): for a batch of exactly one row the "first" and "last" row
        coincide, so the only pair produced is a matching one (still labelled 0
        by the caller) — confirm whether single-row batches can occur.
        """
        size = len(batch)
        half = size // 2
        negatives = []
        for pos in range(size - 1):
            if pos < half:
                doc_row, code_row = batch[pos + 1], batch[pos]
            else:
                doc_row, code_row = batch[pos], batch[pos + 1]
            negatives.append(
                self._wrap_pair(
                    ' '.join(doc_row['docstring_tokens']),
                    ' '.join(code_row['code_tokens']),
                )
            )
        # Close the cycle: last docstring against the first code.
        negatives.append(
            self._wrap_pair(
                ' '.join(batch[-1]['docstring_tokens']),
                ' '.join(batch[0]['code_tokens']),
            )
        )
        return negatives


# Number of dataset rows per DataLoader batch.
BATCH_SIZE = 4

# One collator per sequence kind, all sharing the same tokenizer.
code_tokenize_collate_fn, doc_tokenize_collate_fn, neg_tokenize_collate_fn = (
    TokenizeCollator(tokenizer, kind)
    for kind in (SeqType.CODE, SeqType.DOC, SeqType.NEG)
)

In [None]:
# Optional: uncomment the next line to release cached GPU memory first.
# torch.cuda.empty_cache()
# Move the pretrained encoder onto the selected device.
codebert = codebert.to(device)

In [None]:
# Unshuffled loaders: for each split the code loader and the doc loader walk the
# same rows in the same order, so batch i of one corresponds to batch i of the other.
train_code_loader, val_code_loader, test_code_loader = (
    DataLoader(
        split,
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=code_tokenize_collate_fn,
        num_workers=0,
    )
    for split in (train_data, valid_data, test_data)
)
train_doc_loader, val_doc_loader, test_doc_loader = (
    DataLoader(
        split,
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=doc_tokenize_collate_fn,
        num_workers=0,
    )
    for split in (train_data, valid_data, test_data)
)

In [None]:
# Shuffled loaders whose collator builds mismatched docstring/code pairs (label 0).
neg_train_code_loader, neg_val_code_loader, neg_test_code_loader = (
    DataLoader(
        split,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=neg_tokenize_collate_fn,
        num_workers=0,
    )
    for split in (train_data, valid_data, test_data)
)

In [None]:
import torch.nn as nn
import torch.nn.functional as F


class FineTunedCodeBert(nn.Module):
    """Encoder plus a linear head that scores a (code, doc) pair in [0, 1].

    The wrapped ``model`` is called with ``input_ids``, ``token_type_ids`` and
    ``attention_mask`` keyword arguments and must return an object exposing
    ``pooler_output``. The two pooled embeddings are concatenated and passed
    through ``Linear(768 + 768, 1)`` followed by a sigmoid.
    """

    def __init__(self, model):
        super(FineTunedCodeBert, self).__init__()
        self.model = model

        # The whole encoder is fine-tuned, not only the head.
        for p in self.model.parameters():
            p.requires_grad = True

        self.cls_layer = nn.Linear(768+768, 1)
        self.sigmoid = nn.Sigmoid()

    def _embed(self, tokens):
        """Run the encoder on one ``(input_ids, token_type_ids, attention_mask)`` triple."""
        tokens_ids, token_type_ids, attention_mask = tokens
        return self.model(
            input_ids=tokens_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        ).pooler_output

    def forward(self, code_tokens, doc_tokens):
        """Score tokenized inputs.

        Args:
            code_tokens: ``(input_ids, token_type_ids, attention_mask)`` for the code side.
            doc_tokens: the same triple for the docstring side.

        Returns:
            Tensor of sigmoid scores with a trailing dimension of size 1.
        """
        # Code side is encoded first, then the doc side.
        code_embs = self._embed(code_tokens)
        doc_embs = self._embed(doc_tokens)
        return self.predict(code_embs, doc_embs)

    def predict(self, code_emb, doc_emb):
        """Score already-computed embeddings.

        Accepts a single pair of 1-D embeddings as well as batches: the two
        inputs are broadcast against each other and concatenated along the
        last dimension, so ``(768,)``/``(768,)``, ``(B, 768)``/``(B, 768)`` and
        ``(B, 768)``/``(768,)`` are all valid.

        Returns:
            Sigmoid scores of shape ``(..., 1)`` — ``(1,)`` for a single pair.
        """
        code_emb, doc_emb = torch.broadcast_tensors(code_emb, doc_emb)
        # dim=-1 equals dim=0 for 1-D inputs and is the feature axis for batches.
        concat_embs = torch.cat([code_emb, doc_emb], dim=-1)
        return self.sigmoid(self.cls_layer(concat_embs))

In [None]:
# Wrap the encoder with the classification head and place the result on the device.
net = FineTunedCodeBert(codebert).to(device)

In [None]:
# Optimisation hyper-parameters.
learning_rate = 1e-5
epochs = 8
# Optimise every parameter of `net`: the CodeBERT encoder AND the linear
# classification head. The head (`cls_layer`) lives on the wrapper module, not
# on `codebert`, so passing `codebert.parameters()` would leave it at its
# random initialisation for the whole run.
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
# BCELoss expects probabilities; the model ends with a sigmoid.
loss_fn = nn.BCELoss()

In [None]:
from time import time
import numpy as np

def eval(model, eval_code_dataloader, eval_doc_dataloader, k=1000):
    """Compute the mean reciprocal rank (MRR) of the true code for each docstring.

    All samples are first embedded with ``model.model``. Each docstring is then
    scored with ``model.predict`` against a window of candidate code
    embeddings: its own code (candidate 0) followed by the next codes in loader
    order, wrapping around at the end. The reciprocal of the 1-based rank of
    candidate 0 is averaged over all docstrings.

    The two loaders must yield the same samples in the same order, since
    candidate 0 is treated as the correct match. Uses the notebook-level
    ``device`` and ``tqdm``. Note that this function shadows the built-in
    ``eval``.

    Args:
        model: object exposing ``eval()``, ``model`` (an encoder returning
            ``pooler_output``) and ``predict(code_emb, doc_emb)``.
        eval_code_dataloader: yields ``(labels, (input_ids, token_type_ids,
            attention_mask))`` batches for the code side.
        eval_doc_dataloader: yields the same structure for the docstring side.
        k: number of candidates per docstring (default 1000). Clamped to the
            number of embedded samples so that no candidate appears twice when
            the evaluation set is smaller than ``k``.

    Returns:
        0-dim float tensor holding the MRR.

    Raises:
        ValueError: if ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError("k must be a positive integer")

    model.eval()

    batched_code_embs = []
    batched_doc_embs = []

    paired_batches = zip(eval_code_dataloader, eval_doc_dataloader)
    for (_, code_tokens), (_, doc_tokens) in tqdm(paired_batches, total=len(eval_code_dataloader)):
        with torch.no_grad():
            with torch.autocast(device_type=device, dtype=torch.float16):

                tokens_ids, token_type_ids, attention_mask = code_tokens
                code_embs = model.model(input_ids=tokens_ids, token_type_ids=token_type_ids, attention_mask=attention_mask).pooler_output

                tokens_ids, token_type_ids, attention_mask = doc_tokens
                doc_embs = model.model(input_ids=tokens_ids, token_type_ids=token_type_ids, attention_mask=attention_mask).pooler_output

            # Keep the stored embeddings in float32 for the scoring stage.
            batched_code_embs.append(code_embs.type(torch.float32))
            batched_doc_embs.append(doc_embs.type(torch.float32))

    code_embeddings = torch.cat(batched_code_embs, dim=0)
    doc_embeddings = torch.cat(batched_doc_embs, dim=0)

    num_samples = len(code_embeddings)
    # Never ask for more candidates than exist; otherwise the wrap-around
    # would list some codes (including the true one) twice.
    window = min(k, num_samples)
    offsets = torch.arange(window, device=code_embeddings.device)

    reciprocal_ranks = []
    for doc_idx, doc_embedding in tqdm(enumerate(doc_embeddings), total=len(doc_embeddings)):
        # Circular window starting at the docstring's own code.
        candidates = code_embeddings[(doc_idx + offsets) % num_samples]
        with torch.no_grad():
            scores = [model.predict(code_embedding, doc_embedding) for code_embedding in candidates]

        # One stack + one host transfer instead of one transfer per score.
        scores = torch.stack(scores).reshape(-1).cpu()
        _, order = torch.sort(scores, descending=True)
        # Position of candidate 0 (the true code) in the sorted order.
        rank = (order == 0).nonzero().item() + 1
        reciprocal_ranks.append(1 / rank)

    return torch.mean(torch.tensor(reciprocal_ranks))

In [None]:
def get_accuracy_from_logits(logits, labels):
    """Fraction of entries whose thresholded score equals the label.

    Despite the parameter name, ``logits`` is compared against 0.5 directly, so
    it is treated as a probability-like score: values strictly above 0.5 count
    as class 1, everything else as class 0.

    Returns:
        0-dim float tensor with the mean agreement between predictions and labels.
    """
    threshold = 0.5
    predicted = (logits > threshold).long()
    hits = (predicted == labels).float()
    return hits.mean()

In [None]:
def train(model, train_code_dataloader, train_doc_dataloader, neg_train_code_dataloader, val_code_dataloader, val_doc_dataloader, epoch):
    """Run one training epoch, save the model, then print the validation MRR.

    For every step the model scores the positive (code, doc) batch and the
    negative batch against the same doc batch; both score sets are
    concatenated and compared with their labels by the notebook-level
    ``loss_fn``. Also uses the notebook-level ``optimizer``, ``device``,
    ``tqdm``, ``get_accuracy_from_logits`` and ``eval``.

    Args:
        model: module called as ``model(code_tokens, doc_tokens)``.
        train_code_dataloader: yields ``(labels, tokens)`` for positive pairs.
        train_doc_dataloader: yields ``(labels, tokens)`` for the doc side; its
            labels are ignored.
        neg_train_code_dataloader: yields ``(labels, tokens)`` for negative pairs.
        val_code_dataloader: code-side loader passed to ``eval``.
        val_doc_dataloader: doc-side loader passed to ``eval``.
        epoch: epoch number, used only in the printed messages.
    """
    model.train()

    running_loss = 0
    batches = zip(train_code_dataloader, train_doc_dataloader, neg_train_code_dataloader)
    for iteration, ((pos_labels, code_tokens), (_, doc_tokens), (neg_labels, neg_tokens)) in tqdm(enumerate(batches), total=len(train_code_dataloader)):
        optimizer.zero_grad()
        with torch.autocast(device_type=device, dtype=torch.float16):
            code_logits = model(code_tokens, doc_tokens)
            neg_code_logits = model(neg_tokens, doc_tokens)
            logits = torch.cat([code_logits, neg_code_logits])
        labels = torch.cat([pos_labels.float(), neg_labels.float()])
        # The scores may come out of the autocast region in reduced precision.
        # Binary cross-entropy is evaluated in float32, outside autocast, so its
        # log terms are not computed in float16; this also keeps both loss
        # inputs on the same dtype when autocast is inactive.
        loss = loss_fn(logits.float(), labels)
        running_loss += loss.item()
        # NOTE(review): float16 autocast is used without gradient scaling;
        # consider torch's GradScaler if gradients underflow — confirm.
        loss.backward()
        optimizer.step()

        if iteration % 500 == 0:
            _loss = running_loss / (iteration + 1)
            acc = get_accuracy_from_logits(logits, labels)
            # Print plain numbers rather than tensor reprs (device/grad_fn noise).
            print("epoch: {}\titeration: {}\tloss: {}\tthis iteration loss: {}\tthis iteration acc: {}".format(epoch, iteration, _loss, loss.item(), acc.item()))

    print("epoch: {}\ttrain loss: {}".format(epoch, running_loss / len(train_code_dataloader)))

    # Saves the whole module object, overwriting the file from the previous epoch.
    torch.save(model, 'fine_tuned_codebert_cls.pt')

    val_mrr = eval(model, val_code_dataloader, val_doc_dataloader)
    print("epoch: {}\tvalid mrr: {}".format(epoch, val_mrr))

In [None]:
# One call per epoch; each call trains, saves the model and prints the validation MRR.
for epoch in range(epochs):
    train(
        model=net,
        train_code_dataloader=train_code_loader,
        train_doc_dataloader=train_doc_loader,
        neg_train_code_dataloader=neg_train_code_loader,
        val_code_dataloader=val_code_loader,
        val_doc_dataloader=val_doc_loader,
        epoch=epoch,
    )

In [None]:
# Final evaluation on the test split (the variable keeps its original name).
val_mrr = eval(
    model=net,
    eval_code_dataloader=test_code_loader,
    eval_doc_dataloader=test_doc_loader,
)
print(f'Mean Reciprocal rank is: {val_mrr}')