In [1]:
from datasets import load_dataset

datasets = load_dataset('code_x_glue_cc_code_to_code_trans')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'java', 'cs'],
        num_rows: 10300
    })
    validation: Dataset({
        features: ['id', 'java', 'cs'],
        num_rows: 500
    })
    test: Dataset({
        features: ['id', 'java', 'cs'],
        num_rows: 1000
    })
})

In [3]:
load = 'Salesforce/codet5p-220m'

max_source_len = 100
max_target_len = 100

In [4]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(load)



In [5]:
def preprocess_function(examples):
    source = examples["java"]
    target = examples["cs"]

    model_inputs = tokenizer(source, max_length=max_source_len, padding="max_length", truncation=True, return_tensors="pt")
    labels = tokenizer(target, max_length=max_target_len, padding="max_length", truncation=True, return_tensors="pt")

    # model_inputs["labels"] = labels["input_ids"].copy()
    model_inputs["labels"] = [
        [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in model_inputs["input_ids"]
    ]
    return model_inputs

tokenized_datasets = datasets.map(
    preprocess_function,
    batched=True,
    remove_columns=datasets['train'].column_names,
    # num_proc=64,
    num_proc=4,
)

Map (num_proc=4): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10300/10300 [00:04<00:00, 2110.17 examples/s]
Map (num_proc=4): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:00<00:00, 1127.03 examples/s]
Map (num_proc=4): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1470.35 examples/s]


In [6]:
tokenized_datasets['train']

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 10300
})

In [24]:
from transformers import AutoConfig, AutoModelForSeq2SeqLM
import evaluate
import codebleu

config = AutoConfig.from_pretrained('google-t5/t5-base')
model = AutoModelForSeq2SeqLM.from_config(config).to('cuda')

bleu_metric = evaluate.load("bleu")

# computes bleu and codebleu metric
def compute_metrics(eval_pred):
    predictions, labels = eval_pred

    labels = [
        [(l if l != -100 else tokenizer.pad_token_id) for l in label] for label in labels
    ]
    decoded_labels = tokenizer.batch_decode(labels)
    decoded_predictions = tokenizer.batch_decode(predictions)

    print(decoded_predictions)

    return {
        'bleu': bleu_metric.compute(predictions=decoded_predictions, references=decoded_labels),
        'codebleu': codebleu.calc_codebleu(predictions=decoded_predictions, references=decoded_labels, lang="c_sharp")
    }



In [25]:
compute_metrics((tokenized_datasets['validation']['input_ids'], tokenized_datasets['validation']['labels']))

['<s>public DVRecord(RecordInputStream in) {_option_flags = in.readInt();_promptTitle = readUnicodeString(in);_errorTitle = readUnicodeString(in);_promptText = readUnicodeString(in);_errorText = readUnicodeString(in);int field_size_first_formula = in.readUShort();_not_used_1 = in.readShort();_formula1 = Formula.read(field_size</s>', '<s>public String toString() {return pattern();}\n</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>', '<s>public InsertInstanceRequest() {super("Ots", "2016-06-20", "InsertInstance", "ots");setMethod(MethodType.POST);}\n</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><

{'bleu': {'bleu': 1.0,
  'precisions': [1.0, 1.0, 1.0, 1.0],
  'brevity_penalty': 1.0,
  'length_ratio': 1.0,
  'translation_length': 106243,
  'reference_length': 106243},
 'codebleu': {'codebleu': 1.0,
  'ngram_match_score': 1.0,
  'weighted_ngram_match_score': 1.0,
  'syntax_match_score': 1.0,
  'dataflow_match_score': 1.0}}

In [26]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    "our_code_trans",
    evaluation_strategy="epoch",
    predict_with_generate=True,
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=50,
    save_total_limit=3,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    compute_metrics=compute_metrics,
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False)


In [27]:
# trainer.train("our_code_trans/checkpoint-10000")
trainer.train()

Epoch,Training Loss,Validation Loss,Bleu,Codebleu
1,4.7477,3.730758,"{'bleu': 0.06833016092145955, 'precisions': [0.892, 0.8822372881355932, 0.8653793103448276, 0.8321754385964912], 'brevity_penalty': 0.07875343892560092, 'length_ratio': 0.28237154447822443, 'translation_length': 30000, 'reference_length': 106243}","{'codebleu': 0.12228933206011333, 'ngram_match_score': 0, 'weighted_ngram_match_score': 0, 'syntax_match_score': 0.40698800566192356, 'dataflow_match_score': 0.08216932257852974}"




['<pad><s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>', '<pad><s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>', '<pad><s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>', '<pad><s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>', '<pad><s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>', '<pad><s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>', '<pad><s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>', '<pad><s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>', '<pad><s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>', '<pad><s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><p

Trainer is attempting to log a value of "{'bleu': 0.06833016092145955, 'precisions': [0.892, 0.8822372881355932, 0.8653793103448276, 0.8321754385964912], 'brevity_penalty': 0.07875343892560092, 'length_ratio': 0.28237154447822443, 'translation_length': 30000, 'reference_length': 106243}" of type <class 'dict'> for key "eval/bleu" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "{'codebleu': 0.12228933206011333, 'ngram_match_score': 0, 'weighted_ngram_match_score': 0, 'syntax_match_score': 0.40698800566192356, 'dataflow_match_score': 0.08216932257852974}" of type <class 'dict'> for key "eval/codebleu" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.


KeyboardInterrupt: 