# Fine-Tuning CodeT5 for R Code Generation

In [None]:
import json
import random

import pandas as pd
import torch
from datasets import Dataset, load_dataset
from transformers import (
    RobertaTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
    Trainer,
    TrainingArguments,
)

In [None]:
# Load the masked-code training pairs. JSON is UTF-8 by specification, so the
# encoding is pinned instead of relying on the platform default.
with open('fixed_masked_training_data.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

# Keep only the two text columns, drop incomplete pairs, and rename them to
# the source/target names used by the preprocessing step.
df = pd.DataFrame(data)
df = df[['input', 'output']].dropna()
df = df.rename(columns={'input': 'source', 'output': 'target'})

# dropna() can leave gaps in the index; preserve_index=False stops that index
# from being stored as an extra `__index_level_0__` column in the dataset.
dataset = Dataset.from_pandas(df, preserve_index=False)

In [None]:
# CodeT5 reuses the T5 architecture but ships a RoBERTa-style byte-level BPE
# vocabulary (vocab.json / merges.txt), not a SentencePiece model, so the
# tokenizer must be loaded with RobertaTokenizer rather than T5Tokenizer.
model_checkpoint = "Salesforce/codet5-small"
tokenizer = RobertaTokenizer.from_pretrained(model_checkpoint)
model = T5ForConditionalGeneration.from_pretrained(model_checkpoint)

In [None]:
def preprocess(example):
    """Tokenize one source/target pair into fixed-length model inputs.

    Args:
        example: Mapping with a 'source' string (model input) and a
            'target' string (expected output).

    Returns:
        The tokenizer encoding of the source text with an added 'labels'
        entry holding the target token ids, where padding positions are
        replaced by -100.
    """
    model_inputs = tokenizer(
        example['source'],
        padding="max_length",
        truncation=True,
        max_length=256
    )

    target_encoding = tokenizer(
        example['target'],
        padding="max_length",
        truncation=True,
        max_length=256
    )

    # -100 is the index ignored by the cross-entropy loss; without this the
    # model is also trained to predict the padding that fills each target.
    pad_id = tokenizer.pad_token_id
    model_inputs['labels'] = [
        token_id if token_id != pad_id else -100
        for token_id in target_encoding['input_ids']
    ]
    return model_inputs

encoded_dataset = dataset.map(preprocess)

In [None]:
# Training configuration. No evaluation is run: the evaluation strategy is
# left at its default of "no". The explicit `evaluation_strategy` keyword was
# dropped because it is deprecated (renamed `eval_strategy`) in newer
# transformers releases, while the default gives the same behaviour on all
# versions.
training_args = TrainingArguments(
    output_dir="./codet5-r-generation",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    warmup_steps=50,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    save_steps=100,
    save_total_limit=1,  # keep only the most recent checkpoint on disk
    fp16=torch.cuda.is_available()  # mixed precision only when a GPU exists
)

In [None]:
# Collect the Trainer inputs first so the wiring is visible at a glance,
# then build the trainer and run the fine-tuning loop.
trainer_kwargs = {
    "model": model,
    "args": training_args,
    "train_dataset": encoded_dataset,
    "tokenizer": tokenizer,
}
trainer = Trainer(**trainer_kwargs)

trainer.train()

In [None]:
# Persist the fine-tuned weights and then the tokenizer files side by side,
# so the directory can be reloaded with from_pretrained().
for artifact in (model, tokenizer):
    artifact.save_pretrained("./codet5-r-generation")

In [None]:
def generate_r_code(prompt):
    """Generate R code for a prompt with the fine-tuned model.

    Args:
        prompt: Input text for the model (for example masked code followed
            by a comment describing the goal).

    Returns:
        The decoded text of the first generated sequence, with special
        tokens removed.
    """
    # Training leaves the model in train mode; switch to eval so dropout does
    # not add noise to the generated output.
    model.eval()

    # The Trainer may have moved the model to a GPU; the inputs must live on
    # the same device or generate() fails with a device-mismatch error.
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=encoded.input_ids,
            attention_mask=encoded.attention_mask,
            max_length=256,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

print(generate_r_code("____\n# create a plot"))