In [None]:
import torch
import pandas as pd
import numpy as np
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import multiprocessing
from itertools import repeat
from tqdm import tqdm


In [None]:
def process_data(df_chunk, gpu_id, corrected_sentences, model_name, example_incorrect, example_correct,
                 max_new_tokens=200):
    """Generate a corrected version of every sentence in ``df_chunk`` on one GPU.

    A one-shot prompt of the form
    ``Incorrect: <example> Correct: <example> Incorrect: <sentence> Correct:``
    is built for each row, the model samples a continuation, and only that
    continuation (not the prompt) is appended to ``corrected_sentences``.
    Results are appended in the row order of ``df_chunk``.

    Args:
        df_chunk: DataFrame slice with a ``'Sentence'`` column.
        gpu_id: Index of the CUDA device this worker runs on.
        corrected_sentences: List-like object with ``append`` (e.g. a
            ``multiprocessing.Manager().list()``) that receives one string
            per input row.
        model_name: Name or path passed to ``from_pretrained`` for both the
            tokenizer and the model.
        example_incorrect: The "incorrect" half of the one-shot example.
        example_correct: The "correct" half of the one-shot example.
        max_new_tokens: Maximum number of tokens generated *after* the prompt.

    Note:
        Generation uses sampling (``do_sample=True``) and no seed is set here,
        so repeated runs can produce different text.
    """
    device = torch.device(f'cuda:{gpu_id}')
    # Set the current device to the specific GPU
    torch.cuda.set_device(gpu_id)

    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    model = GPT2LMHeadModel.from_pretrained(model_name)
    model.to(device)
    model.eval()

    # GPT-2 style tokenizers may not define a pad token; fall back to EOS so
    # generate() does not have to guess (and warn) on every call.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # The one-shot example is identical for every row, so build it once.
    few_shot_prefix = f"Incorrect: {example_incorrect} Correct: {example_correct} Incorrect: "

    progress = tqdm(
        df_chunk['Sentence'],
        total=len(df_chunk),
        desc=f'GPU {gpu_id}',
        position=gpu_id,  # one bar line per worker so concurrent bars don't overwrite each other
    )
    for sentence_to_correct in progress:
        prompt = f"{few_shot_prefix}{sentence_to_correct} Correct:"
        encoded = tokenizer(prompt, return_tensors="pt").to(device)
        prompt_length = encoded["input_ids"].shape[1]

        out = model.generate(
            encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            # max_length would count the prompt tokens too, leaving little or
            # no room for output on long inputs; budget the new tokens instead.
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=pad_token_id,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            no_repeat_ngram_size=2
        )

        # Drop the echoed prompt: only the continuation is the correction.
        generated_text = tokenizer.decode(out[0][prompt_length:], skip_special_tokens=True).strip()
        # tqdm.write prints above the progress bars instead of breaking them.
        tqdm.write(generated_text)
        corrected_sentences.append(generated_text)

In [None]:
def main(num_gpus=4,
         input_csv="/home/ubuntu/Project_Files/Finetune/Data/sentences.csv",
         output_csv="corrected_sentences.csv"):
    """Correct every sentence of the input CSV in parallel, one worker per GPU.

    The input frame is split into ``num_gpus`` contiguous, near-equal slices
    that together cover every row. Each slice is handled by its own process
    running ``process_data`` and writing into its own result list; the lists
    are concatenated in slice order so the output rows line up with the input
    rows, then written to ``output_csv`` under the column
    ``'Corrected Sentence'``.

    Args:
        num_gpus: Number of worker processes / CUDA devices to use.
        input_csv: Path of the CSV file to read; workers read its
            ``'Sentence'`` column.
        output_csv: Path of the CSV file to write.

    Raises:
        ValueError: If ``num_gpus`` is smaller than 1.
        RuntimeError: If any worker process exits with a non-zero exit code;
            nothing is written in that case, since the output would no longer
            match the input row for row.
    """
    if num_gpus < 1:
        raise ValueError(f"num_gpus must be at least 1, got {num_gpus}")

    model_name = "sberbank-ai/mGPT"
    example_incorrect = "In the UBC, which is a type of gene/protein, there is a noted ppi of the gene/protein DCAF1"
    example_correct = "In the UBC, a type of gene/protein, there is a noted protein-protein interaction (ppi) with the gene/protein DCAF1."

    df = pd.read_csv(input_csv, low_memory=False)

    # Slice boundaries that cover all rows exactly once. Using a fixed
    # chunk_size = len(df) // num_gpus would produce an extra, never-processed
    # chunk whenever len(df) is not divisible by num_gpus (and a zero step
    # when len(df) < num_gpus).
    total_rows = len(df)
    bounds = [total_rows * i // num_gpus for i in range(num_gpus + 1)]
    df_chunks = [df.iloc[bounds[i]:bounds[i + 1]] for i in range(num_gpus)]

    manager = multiprocessing.Manager()

    # One result list per worker: a single shared list would interleave the
    # workers' outputs in a non-deterministic order.
    workers = []
    for gpu_id, chunk in enumerate(df_chunks):
        if chunk.empty:
            # Fewer rows than GPUs: don't load a model just to do nothing.
            continue
        chunk_results = manager.list()
        p = multiprocessing.Process(
            target=process_data,
            args=(chunk, gpu_id, chunk_results, model_name, example_incorrect, example_correct),
        )
        p.start()
        workers.append((gpu_id, p, chunk_results))

    for _, p, _ in workers:
        p.join()

    failed_gpus = [gpu_id for gpu_id, p, _ in workers if p.exitcode != 0]
    if failed_gpus:
        raise RuntimeError(
            f"Worker process(es) for GPU(s) {failed_gpus} exited abnormally; "
            "refusing to write incomplete results."
        )

    # Concatenate in chunk order so output row i corresponds to input row i.
    corrected_sentences = [text for _, _, chunk_results in workers for text in chunk_results]

    corrected_df = pd.DataFrame(corrected_sentences, columns=['Corrected Sentence'])
    corrected_df.to_csv(output_csv, index=False)

# Run the pipeline only when this file is executed directly. The guard keeps
# main() from being launched again if the module is imported, which is what
# multiprocessing does in child processes under the "spawn" start method.
if __name__ == '__main__':
    main()
