### model_name = 'VPrashant/cypher-gen' Use for cypher query generaion from text

In [1]:
import torch
from datasets import load_dataset

from transformers import Trainer, TrainingArguments
from transformers import T5Tokenizer, T5ForConditionalGeneration

In [3]:
# Load model and tokenizer
model_name = "./results/"
model = T5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = T5Tokenizer.from_pretrained(model_name)

In [4]:
# Prepare data
data = load_dataset('csv', data_files={'train': 'train.csv'})
data = data['train'].train_test_split(test_size=0.2, seed=42)

In [5]:
def preprocess(batch):
    # Handle all examples in the batch
    input_ids = []
    attention_mask = []
    labels = []
    
    for prompt, query in zip(batch['prompt'], batch['query']):
        # Tokenize input and output
        inputs = tokenizer("translate to Cypher: " + prompt, truncation=True, padding='max_length', max_length=128)
        outputs = tokenizer(query, truncation=True, padding='max_length', max_length=128)
        
        # Collect tokenized inputs and labels
        input_ids.append(inputs['input_ids'])
        attention_mask.append(inputs['attention_mask'])
        labels.append(outputs['input_ids'])
    
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


In [6]:
tokenized_data = data.map(preprocess, batched=True)
tokenized_data.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

Map:   0%|          | 0/519 [00:00<?, ? examples/s]

Map:   0%|          | 0/130 [00:00<?, ? examples/s]

In [7]:
# Fine-tune the model
training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=30,
    weight_decay=0.01,
    save_total_limit=2,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_data['train'],
    eval_dataset=tokenized_data['test'],
    tokenizer=tokenizer
)

trainer.train()


  trainer = Trainer(


  0%|          | 0/1950 [00:00<?, ?it/s]

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


  0%|          | 0/17 [00:00<?, ?it/s]

{'eval_loss': 0.00018672127043828368, 'eval_runtime': 0.5892, 'eval_samples_per_second': 220.644, 'eval_steps_per_second': 28.854, 'epoch': 1.0}


  0%|          | 0/17 [00:00<?, ?it/s]

{'eval_loss': 0.00014999558334238827, 'eval_runtime': 0.5601, 'eval_samples_per_second': 232.111, 'eval_steps_per_second': 30.353, 'epoch': 2.0}


  0%|          | 0/17 [00:00<?, ?it/s]

{'eval_loss': 0.00013283672160468996, 'eval_runtime': 0.5792, 'eval_samples_per_second': 224.459, 'eval_steps_per_second': 29.352, 'epoch': 3.0}


  0%|          | 0/17 [00:00<?, ?it/s]

{'eval_loss': 5.6323195167351514e-05, 'eval_runtime': 0.5655, 'eval_samples_per_second': 229.898, 'eval_steps_per_second': 30.064, 'epoch': 4.0}


  0%|          | 0/17 [00:00<?, ?it/s]

{'eval_loss': 5.9614012570818886e-05, 'eval_runtime': 0.5683, 'eval_samples_per_second': 228.737, 'eval_steps_per_second': 29.912, 'epoch': 5.0}


  0%|          | 0/17 [00:00<?, ?it/s]

{'eval_loss': 4.901817374047823e-05, 'eval_runtime': 0.5678, 'eval_samples_per_second': 228.935, 'eval_steps_per_second': 29.938, 'epoch': 6.0}


  0%|          | 0/17 [00:00<?, ?it/s]

{'eval_loss': 6.675178883597255e-05, 'eval_runtime': 0.5685, 'eval_samples_per_second': 228.689, 'eval_steps_per_second': 29.905, 'epoch': 7.0}
{'loss': 0.0018, 'grad_norm': 0.23049570620059967, 'learning_rate': 3.717948717948718e-05, 'epoch': 7.69}


  0%|          | 0/17 [00:00<?, ?it/s]

{'eval_loss': 4.3689906306099147e-05, 'eval_runtime': 0.5601, 'eval_samples_per_second': 232.086, 'eval_steps_per_second': 30.35, 'epoch': 8.0}


  0%|          | 0/17 [00:00<?, ?it/s]

{'eval_loss': 2.2842437829240225e-05, 'eval_runtime': 0.6045, 'eval_samples_per_second': 215.061, 'eval_steps_per_second': 28.123, 'epoch': 9.0}


  0%|          | 0/17 [00:00<?, ?it/s]

{'eval_loss': 3.312152693979442e-05, 'eval_runtime': 0.5592, 'eval_samples_per_second': 232.457, 'eval_steps_per_second': 30.398, 'epoch': 10.0}


  0%|          | 0/17 [00:00<?, ?it/s]

{'eval_loss': 3.030436346307397e-05, 'eval_runtime': 0.602, 'eval_samples_per_second': 215.956, 'eval_steps_per_second': 28.24, 'epoch': 11.0}


  0%|          | 0/17 [00:00<?, ?it/s]

{'eval_loss': 4.2522278818069026e-05, 'eval_runtime': 0.5842, 'eval_samples_per_second': 222.532, 'eval_steps_per_second': 29.1, 'epoch': 12.0}


  0%|          | 0/17 [00:00<?, ?it/s]

{'eval_loss': 3.1799620046513155e-05, 'eval_runtime': 0.5638, 'eval_samples_per_second': 230.583, 'eval_steps_per_second': 30.153, 'epoch': 13.0}


  0%|          | 0/17 [00:00<?, ?it/s]

{'eval_loss': 2.3460257580154575e-05, 'eval_runtime': 0.5661, 'eval_samples_per_second': 229.648, 'eval_steps_per_second': 30.031, 'epoch': 14.0}


  0%|          | 0/17 [00:00<?, ?it/s]

{'eval_loss': 2.2863800040795468e-05, 'eval_runtime': 0.5602, 'eval_samples_per_second': 232.069, 'eval_steps_per_second': 30.347, 'epoch': 15.0}
{'loss': 0.0006, 'grad_norm': 0.034679483622312546, 'learning_rate': 2.435897435897436e-05, 'epoch': 15.38}


  0%|          | 0/17 [00:00<?, ?it/s]

{'eval_loss': 2.2811800590716302e-05, 'eval_runtime': 0.5606, 'eval_samples_per_second': 231.911, 'eval_steps_per_second': 30.327, 'epoch': 16.0}


  0%|          | 0/17 [00:00<?, ?it/s]

{'eval_loss': 2.1358362573664635e-05, 'eval_runtime': 0.6002, 'eval_samples_per_second': 216.58, 'eval_steps_per_second': 28.322, 'epoch': 17.0}


  0%|          | 0/17 [00:00<?, ?it/s]

{'eval_loss': 1.8943026589113288e-05, 'eval_runtime': 0.607, 'eval_samples_per_second': 214.166, 'eval_steps_per_second': 28.006, 'epoch': 18.0}


  0%|          | 0/17 [00:00<?, ?it/s]

{'eval_loss': 1.862281897047069e-05, 'eval_runtime': 0.5472, 'eval_samples_per_second': 237.591, 'eval_steps_per_second': 31.07, 'epoch': 19.0}


  0%|          | 0/17 [00:00<?, ?it/s]

{'eval_loss': 1.6805657651275396e-05, 'eval_runtime': 0.5611, 'eval_samples_per_second': 231.702, 'eval_steps_per_second': 30.299, 'epoch': 20.0}


  0%|          | 0/17 [00:00<?, ?it/s]

{'eval_loss': 2.03176405193517e-05, 'eval_runtime': 0.554, 'eval_samples_per_second': 234.639, 'eval_steps_per_second': 30.684, 'epoch': 21.0}


  0%|          | 0/17 [00:00<?, ?it/s]

{'eval_loss': 1.6855596186360344e-05, 'eval_runtime': 0.5951, 'eval_samples_per_second': 218.442, 'eval_steps_per_second': 28.565, 'epoch': 22.0}


  0%|          | 0/17 [00:00<?, ?it/s]

{'eval_loss': 1.473939664720092e-05, 'eval_runtime': 0.5752, 'eval_samples_per_second': 225.99, 'eval_steps_per_second': 29.553, 'epoch': 23.0}
{'loss': 0.0004, 'grad_norm': 0.00394069030880928, 'learning_rate': 1.153846153846154e-05, 'epoch': 23.08}


  0%|          | 0/17 [00:00<?, ?it/s]

{'eval_loss': 1.4615195141232107e-05, 'eval_runtime': 0.5547, 'eval_samples_per_second': 234.342, 'eval_steps_per_second': 30.645, 'epoch': 24.0}


  0%|          | 0/17 [00:00<?, ?it/s]

{'eval_loss': 1.715413964120671e-05, 'eval_runtime': 0.5554, 'eval_samples_per_second': 234.049, 'eval_steps_per_second': 30.606, 'epoch': 25.0}


  0%|          | 0/17 [00:00<?, ?it/s]

{'eval_loss': 1.7066902728402056e-05, 'eval_runtime': 0.5579, 'eval_samples_per_second': 233.03, 'eval_steps_per_second': 30.473, 'epoch': 26.0}


  0%|          | 0/17 [00:00<?, ?it/s]

{'eval_loss': 1.86057932296535e-05, 'eval_runtime': 0.5581, 'eval_samples_per_second': 232.931, 'eval_steps_per_second': 30.46, 'epoch': 27.0}


  0%|          | 0/17 [00:00<?, ?it/s]

{'eval_loss': 1.9802559108939022e-05, 'eval_runtime': 0.5575, 'eval_samples_per_second': 233.186, 'eval_steps_per_second': 30.494, 'epoch': 28.0}


  0%|          | 0/17 [00:00<?, ?it/s]

{'eval_loss': 1.7332826246274635e-05, 'eval_runtime': 0.6075, 'eval_samples_per_second': 214.007, 'eval_steps_per_second': 27.985, 'epoch': 29.0}


  0%|          | 0/17 [00:00<?, ?it/s]

{'eval_loss': 1.6929872799664736e-05, 'eval_runtime': 0.5633, 'eval_samples_per_second': 230.772, 'eval_steps_per_second': 30.178, 'epoch': 30.0}
{'train_runtime': 281.1879, 'train_samples_per_second': 55.372, 'train_steps_per_second': 6.935, 'train_loss': 0.0008108815856468983, 'epoch': 30.0}


TrainOutput(global_step=1950, training_loss=0.0008108815856468983, metrics={'train_runtime': 281.1879, 'train_samples_per_second': 55.372, 'train_steps_per_second': 6.935, 'total_flos': 526817962229760.0, 'train_loss': 0.0008108815856468983, 'epoch': 30.0})

In [8]:
model.save_pretrained("./results/")
tokenizer.save_pretrained("./results/")

('./results/tokenizer_config.json',
 './results/special_tokens_map.json',
 './results/spiece.model',
 './results/added_tokens.json')

In [13]:
def test(test_input):
    # Load fine-tuned model and tokenizer
    model_name = './results/'  # Path to your fine-tuned model
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name)

    # Set the model to evaluation mode
    model.eval()

    # Check if GPU is available and move model to GPU if possible
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Tokenize the input
    test_encoding = tokenizer(
        test_input,
        return_tensors="pt",
        max_length=128,
        truncation=True,
        padding="max_length"
    )
    test_encoding = {key: val.to(device) for key, val in test_encoding.items()}  # Move input to device

    # Generate Cypher query
    with torch.no_grad():  # Disable gradient calculations for inference
        output = model.generate(
            input_ids=test_encoding['input_ids'],
            attention_mask=test_encoding['attention_mask'],
            max_length=128
        )
    
    # Decode the generated query
    generated_query = tokenizer.decode(output[0], skip_special_tokens=True)

    # Print the generated Cypher query
    print("Input:", test_input)
    print("Generated Cypher Query:", generated_query)

# Call the test function
test_input = "give me a list of 10 genes"

test(test_input)


Input: give me a list of 10 genes
Generated Cypher Query: MATCH (g:Gene) RETURN g LIMIT 10
