# In [None]:
import cv2
import numpy as np
import random
import pandas as pd
import os
from datasets import load_dataset
import matplotlib.pyplot as plt
import json
import textwrap

from skimage import transform, img_as_ubyte
from concurrent.futures import ThreadPoolExecutor, as_completed,ProcessPoolExecutor
from skimage import transform
from skimage.util import img_as_ubyte
from PIL import Image, ImageDraw, ImageFont
from faker import Faker

import datasets
import huggingface_hub
from cr_renderer import CrelloV5Renderer
# Fetch the Crello font resource (resources/fonts.pickle from the cyberagent/crello
# dataset repo, pinned to revision 5.0.0). The returned local path is handed to
# CrelloV5Renderer in the __main__ cell.
fonts_path = huggingface_hub.hf_hub_download(
    repo_id="cyberagent/crello",
    filename="resources/fonts.pickle",
    repo_type="dataset",
    revision="5.0.0",
)

from PIL import Image
from transformers import AutoTokenizer, AutoProcessor, AutoModelForImageTextToText

# Vision-language OCR model; score() uses `model` and `processor` below to read
# text back from rendered images.
model_path = "nanonets/Nanonets-OCR2-3B"

# dtype and device placement are delegated to transformers ("auto");
# flash-attention is left disabled (uncomment the line to opt in).
model = AutoModelForImageTextToText.from_pretrained(
    model_path, 
    torch_dtype="auto", 
    device_map="auto", 
    # attn_implementation="flash_attention_2"
)
# Inference only.
model.eval()


# NOTE(review): `tokenizer` is not referenced elsewhere in this file; `processor`
# handles chat templating, image preprocessing and decoding in score().
tokenizer = AutoTokenizer.from_pretrained(model_path)
processor = AutoProcessor.from_pretrained(model_path)



# In [2]:

# Per-process cache of parsed mirror tables, keyed by file path.
_MIRROR_TABLE_CACHE = {}


def _load_mirror_table(path="mirror_distortion.json"):
    """Return the parsed mirror-substitution table at *path*, reading it only once."""
    if path not in _MIRROR_TABLE_CACHE:
        with open(path, "r", encoding="utf-8") as f:
            _MIRROR_TABLE_CACHE[path] = json.load(f)
    return _MIRROR_TABLE_CACHE[path]


def mirror_distortion(text, mirror_data=None):
    """Replace a random subset of characters with mirrored / rotated look-alikes.

    Between 0 and ``len(text) // 2`` distinct positions are sampled. For each
    position, up to 10 attempts pick a random transform family
    (``HORIZONTAL_MIRROR_MULTI``, ``VERTICAL_MIRROR_MULTI``,
    ``ROTATION_180_MULTI``); the first family that lists the character supplies
    a randomly chosen replacement. Characters no family knows are kept.

    Args:
        text: Input string.
        mirror_data: Optional mapping ``{family: {char: [replacement, ...]}}``.
            When ``None`` the table is loaded from ``mirror_distortion.json``
            (parsed once per process and then reused).

    Returns:
        The distorted string.
    """
    if mirror_data is None:
        mirror_data = _load_mirror_table()

    mirror_types = ['HORIZONTAL_MIRROR_MULTI', 'VERTICAL_MIRROR_MULTI', 'ROTATION_180_MULTI']

    number = random.randint(0, len(text) // 2)
    random_index = random.sample(range(len(text)), k=number)
    text_copy = list(text)

    for idx in random_index:
        working_char = text[idx]
        # Families are drawn with replacement, so a char known to only one
        # family may still be skipped after 10 unlucky draws.
        for _ in range(10):
            family = mirror_data.get(random.choice(mirror_types), {})
            if working_char in family:
                replacement = random.choice(family[working_char])
                print("Replacing ", working_char, " with ", replacement)
                text_copy[idx] = replacement
                break

    return ''.join(text_copy)

def char_level_repetition_distortion(text: str, max_repeats: int = 2, max_positions: int = 4):
    """Duplicate characters at a few random positions.

    Between 1 and ``min(max_positions, len(text))`` distinct positions are
    chosen; each chosen character is followed by 1..``max_repeats`` extra
    copies of itself.

    Args:
        text: Input string. Empty input is returned unchanged.
        max_repeats: Upper bound on extra copies per chosen character.
        max_positions: Upper bound on how many positions are repeated.

    Returns:
        The distorted string.
    """
    limit = min(max_positions, len(text))
    if limit < 1:
        # Nothing to repeat; randint(1, 0) would raise.
        return text

    num_positions = random.randint(1, limit)
    chosen = set(random.sample(range(len(text)), k=num_positions))

    pieces = []
    for i, ch in enumerate(text):
        pieces.append(ch)
        if i in chosen:
            pieces.append(ch * random.randint(1, max_repeats))
    return "".join(pieces)

def char_level_drop_distortion(text: str, max_drops: int = 3):
    """Delete a few random characters.

    Between 1 and ``min(max_drops, len(text) // 2)`` distinct positions are
    removed, so at most half of the characters are dropped.

    Args:
        text: Input string. Strings shorter than 2 characters are returned
            unchanged (previously they raised ``ValueError`` from
            ``randint(1, 0)``).
        max_drops: Upper bound on the number of deleted characters.

    Returns:
        The distorted string.
    """
    limit = min(max_drops, len(text) // 2)
    if limit < 1:
        return text

    num_drops = random.randint(1, limit)
    drop_indices = set(random.sample(range(len(text)), k=num_drops))
    return "".join(ch for i, ch in enumerate(text) if i not in drop_indices)

def adjacent_char_swap_distortion(text: str, max_swaps: int = 2):
    """Swap up to ``max_swaps`` random pairs of neighbouring alphanumeric characters.

    Only pairs where both characters are alphanumeric are eligible, so letters
    are never swapped with spaces or punctuation. Text with fewer than two
    characters, or with no eligible pair, is returned as-is.
    """
    chars = list(text)
    if len(chars) < 2:
        return text

    candidates = [
        pos
        for pos, (left, right) in enumerate(zip(chars, chars[1:]))
        if left.isalnum() and right.isalnum()
    ]
    if not candidates:
        return text

    upper = min(max_swaps, len(candidates))
    chosen = random.sample(candidates, k=random.randint(1, upper))

    for pos in chosen:
        left, right = chars[pos], chars[pos + 1]
        # Re-check: an earlier swap may have moved characters around.
        if left.isalnum() and right.isalnum():
            chars[pos], chars[pos + 1] = right, left

    return "".join(chars)

# Per-process cache of parsed same-character tables, keyed by file path.
_SAME_CHAR_TABLE_CACHE = {}


def _load_same_char_table(path="same_char.json"):
    """Return the parsed look-alike table at *path*, reading it only once."""
    if path not in _SAME_CHAR_TABLE_CACHE:
        with open(path, "r", encoding="utf-8") as f:
            _SAME_CHAR_TABLE_CACHE[path] = json.load(f)
    return _SAME_CHAR_TABLE_CACHE[path]


def same_char_distortion(text, same_char_data=None):
    """Replace random characters with visually similar characters.

    Between 0 and ``len(text) // 2`` distinct positions are sampled; every
    sampled character that appears in the table is replaced by a random entry
    from its list. (Previously a stray ``break`` stopped after the first
    replacement, so at most one character was ever changed.)

    Args:
        text: Input string.
        same_char_data: Optional mapping ``{char: [look-alike, ...]}``. When
            ``None`` the table is loaded from ``same_char.json`` (parsed once
            per process and then reused).

    Returns:
        The distorted string.
    """
    if same_char_data is None:
        same_char_data = _load_same_char_table()

    number = random.randint(0, len(text) // 2)
    random_index = random.sample(range(len(text)), k=number)

    text_copy = list(text)
    for idx in random_index:
        working_char = text[idx]
        if working_char in same_char_data:
            replacement = random.choice(same_char_data[working_char])
            print("Replacing ", working_char, " with ", replacement)
            text_copy[idx] = replacement

    return ''.join(text_copy)

def case_shuffle_distortion(text):
    """Return *text* with every letter independently upper- or lower-cased.

    Each alphabetic character is upper-cased with probability 0.5 and
    lower-cased otherwise; all other characters pass through unchanged.
    """
    def flip(ch):
        if not ch.isalpha():
            return ch
        return ch.upper() if random.random() < 0.5 else ch.lower()

    return "".join(map(flip, text))

def noise_injection_distortion(text, max_noise: int = 5):
    """Insert small diacritic-like noise glyphs at random positions.

    Between 1 and ``min(max_noise, len(text))`` glyphs drawn from
    ``· ˙ ` ´ ¨ ˚ °`` are inserted at random offsets (including the ends).

    Args:
        text: Input string. Empty input (or ``max_noise < 1``) is returned
            unchanged instead of raising ``ValueError`` from ``randint(1, 0)``.
        max_noise: Upper bound on the number of inserted glyphs.

    Returns:
        The distorted string.
    """
    noise_chars = ['·', '˙', '`', '´', '¨', '˚', '°']

    limit = min(max_noise, len(text))
    if limit < 1:
        return text

    text_list = list(text)
    for _ in range(random.randint(1, limit)):
        idx = random.randint(0, len(text_list))
        text_list.insert(idx, random.choice(noise_chars))

    return ''.join(text_list)

def ocr_confusion_distortion(text: str, max_confusions: int = 2):
    ocr_pairs = {
        # Multi-character to single character confusions
        'rn': 'm', 'nn': 'u', 'vv': 'w', 'uu': 'w', 'ii': 'u',
        'cl': 'd', 'li': 'h', 'Il': 'H', 'ln': 'h', 'rr': 'n',
        'iii': 'm', 'ri': 'n', 'RN': 'M', 'VV': 'W', 'UU': 'W',
        'tt': 'H', 'IVI': 'M', 'AI': 'N', 'NN': 'M', 'AA': 'M',
        
        # Single character to multi-character confusions
        'm': 'rn', 'w': 'vv', 'u': 'ii', 'n': 'ri', 'h': 'li',
        'M': 'RN', 'W': 'VV', 'H': 'tt', 'N': 'AI',
        
        # Number/letter confusions (multi-char patterns)
        '0O': 'OO', 'O0': '00', 'l1': '11', '1l': 'll', 
        'I1': '11', '1I': 'II', 'S5': '55', '5S': 'SS',
        'B8': '88', '8B': 'BB', 'G6': '66', '6G': 'GG',
        
        # Common word-specific OCR errors
        'tlie': 'the', 'tbe': 'the', 'tiie': 'the', 'thc': 'the',
        'aud': 'and', 'arid': 'and', 'aiid': 'and', 'ancl': 'and',
        'of': 'ol', 'ot': 'of', 'ol': 'of', 'for': 'lor', 'tor': 'for',
        'Mr': 'Nfr', 'Mrs': 'Nfrs', 'Mr.': 'Mr,', 'Mrs.': 'Mrs,',
        'was': 'vvas', 'will': 'vvill', 'with': 'witli', 'from': 'frorn',
        'that': 'tliat', 'this': 'tliis', 'which': 'wliich', 'when': 'wlien',
        'been': 'beeu', 'have': 'liave', 'said': 'sald', 'upon': 'upou',
        
        # Long s (historical OCR)
        'fs': 'ss', 'fl': 'fi', 'fi': 'fl', 'fh': 'sh', 'ft': 'st',
        
        # Punctuation and special character confusions
        ').': ')', '.)': '.)', ',)': ').', ',.': ',', '.,': '.,',
        ';"': ';', '":': ':', "';": "'", ".'": ".'",
        
        # Common prefix/suffix errors
        'tlie': 'the', 'witli': 'with', 'wliich': 'which', 'tliey': 'they',
        'tliis': 'this', 'liere': 'here', 'wlien': 'when', 'wliat': 'what',
        'tbing': 'thing', 'tbat': 'that', 'tion': 'lion', 'sion': 'siou',
        
        # Double letter confusions
        'ff': 'fi', 'fi': 'ff', 'tt': 'H', 'il': 'II', 'oo': 'œ',
        
        # Capitalization OCR errors
        'Tlie': 'The', 'Tbe': 'The', 'Wlien': 'When', 'Witli': 'With',
        'Wliich': 'Which', 'Frorn': 'From', 'Tliis': 'This', 'Tliat': 'That'
    }
    
    # Attempt replacements
    for _ in range(max_confusions):
        for pattern, replacement in ocr_pairs.items():
            if pattern in text and random.random() < 0.3:
                positions = [i for i in range(len(text) - len(pattern) + 1) 
                           if text[i:i+len(pattern)] == pattern]
                if positions:
                    idx = random.choice(positions)
                    text = text[:idx] + replacement + text[idx+len(pattern):]
                    break
    
    return text


def subscript_superscript_distortion(text: str, max_conversions: int = 2):
    """Convert a few digits / letters to Unicode superscript or subscript forms.

    Between 1 and ``min(max_conversions, #eligible)`` eligible characters are
    converted. Digits can become either a superscript or a subscript (chosen
    at random); the letters a-e only have superscript forms. Previously the two
    tables were merged with ``{**sup, **sub}``, so digits could only ever become
    subscripts.

    Returns:
        The distorted string, or *text* unchanged if nothing is eligible.
    """
    superscripts = {'0': '⁰', '1': '¹', '2': '²', '3': '³', '4': '⁴',
                    '5': '⁵', '6': '⁶', '7': '⁷', '8': '⁸', '9': '⁹',
                    'a': 'ᵃ', 'b': 'ᵇ', 'c': 'ᶜ', 'd': 'ᵈ', 'e': 'ᵉ'}

    subscripts = {'0': '₀', '1': '₁', '2': '₂', '3': '₃', '4': '₄',
                  '5': '₅', '6': '₆', '7': '₇', '8': '₈', '9': '₉'}

    valid_indices = [i for i, ch in enumerate(text) if ch in superscripts or ch in subscripts]
    if not valid_indices:
        return text

    limit = min(max_conversions, len(valid_indices))
    if limit < 1:
        return text

    num_conversions = random.randint(1, limit)
    conversion_indices = random.sample(valid_indices, k=num_conversions)

    text_list = list(text)
    for idx in conversion_indices:
        ch = text[idx]
        options = [table[ch] for table in (superscripts, subscripts) if ch in table]
        text_list[idx] = random.choice(options)

    return ''.join(text_list)


# Combining diacritic marks stacked by zalgo_distortion.
_ZALGO_DIACRITICS = (
    # Above
    '\u0300', '\u0301', '\u0302', '\u0303', '\u0304', '\u0305', '\u0306', '\u0307',
    '\u0308', '\u030A', '\u030B', '\u030C', '\u030D', '\u030E', '\u030F', '\u0310',
    '\u0311',
    # Overlay (includes strikethrough)
    '\u0334', '\u0335', '\u0336', '\u0337', '\u0338',
    # Below
    '\u0316', '\u0317', '\u0318', '\u0319', '\u031A', '\u031B', '\u031C', '\u031D',
    '\u031E', '\u031F', '\u0320', '\u0321', '\u0322', '\u0323', '\u0324', '\u0325',
    '\u0326', '\u0327', '\u0328', '\u0329', '\u032A',
)


def zalgo_distortion(text: str, max_intensity: int = 3, max_chars: int = 5, min_length: int = 8):
    """Add stacking 'combining' diacritic marks to random characters.

    Between 1 and ``min(max_chars, #non-space chars)`` non-whitespace
    characters are chosen; each is followed by 1..``max_intensity`` random
    combining marks from ``_ZALGO_DIACRITICS``.

    Args:
        text: Input string. Texts with ``len(text) <= min_length`` (default 8)
            or without non-whitespace characters are returned unchanged.
        max_intensity: Upper bound on marks stacked per chosen character.
        max_chars: Upper bound on how many characters are decorated.
        min_length: Length threshold at or below which nothing is done.

    Returns:
        The distorted string.
    """
    if len(text) <= min_length:
        return text

    valid_indices = [i for i, char in enumerate(text) if not char.isspace()]
    limit = min(max_chars, len(valid_indices))
    if limit < 1:
        return text

    num_chars_to_distort = random.randint(1, limit)
    distort_indices = random.sample(valid_indices, k=num_chars_to_distort)

    text_list = list(text)
    # Insert right-to-left so earlier indices stay valid after each insertion.
    for idx in sorted(distort_indices, reverse=True):
        for _ in range(random.randint(1, max_intensity)):
            text_list.insert(idx + 1, random.choice(_ZALGO_DIACRITICS))

    return "".join(text_list)

def render_faker_text_on_image(image_path, num_texts=5, font_path=None):
    """Draw random Faker sentences onto a copy of an image.

    Args:
        image_path: Either a PIL ``Image`` (it is copied, the argument itself
            is not drawn on) or a path (``str`` / ``os.PathLike``) to an image
            file, which is opened with ``PIL.Image.open``.
        num_texts: Number of sentences to draw.
        font_path: Optional TrueType font file, used at size 12. PIL's default
            font is used when omitted.

    Returns:
        A new PIL ``Image`` with ``num_texts`` sentences from
        ``Faker.sentence()`` drawn at random positions, each in a random dark
        color (every channel in [0, 100]).
    """
    if isinstance(image_path, (str, os.PathLike)):
        with Image.open(image_path) as opened:
            img = opened.copy()
    else:
        img = image_path.copy()

    fake = Faker()
    draw = ImageDraw.Draw(img)
    width, height = img.size

    font = ImageFont.truetype(font_path, 12) if font_path else ImageFont.load_default()

    for _ in range(num_texts):
        text = fake.sentence()
        # Presumably a ~100x20 px margin so the text anchor stays on-canvas.
        x = random.randint(0, max(0, width - 100))
        y = random.randint(0, max(0, height - 20))
        color = tuple(random.randint(0, 100) for _ in range(3))
        draw.text((x, y), text, fill=color, font=font)

    return img


def _apply_named_text_distortion(name, text):
    """Apply the text distortion registered under *name*; unknown names are a no-op."""
    if name == 'repetition':
        return char_level_repetition_distortion(text)
    if name == 'drop':
        return char_level_drop_distortion(text)
    if name == 'mirror':
        return mirror_distortion(text)
    if name == 'same_char':
        return same_char_distortion(text)
    if name == 'case_shuffle':
        return case_shuffle_distortion(text)
    if name == 'noise_injection':
        return noise_injection_distortion(text)
    if name == 'adjacent_char_swap':
        return adjacent_char_swap_distortion(text)
    if name == 'zalgo':
        return zalgo_distortion(text)
    if name == 'ocr_confusion':
        return ocr_confusion_distortion(text)
    if name == 'subscript_superscript':
        return subscript_superscript_distortion(text)
    return text


def distort_text(example, distortion_list):
    """Return a copy of *example* whose ``'text'`` strings have been distorted.

    Every non-empty string in ``example['text']`` is passed, in order, through
    each distortion named in *distortion_list* ('repetition', 'drop', 'mirror',
    'same_char', 'case_shuffle', 'noise_injection', 'adjacent_char_swap',
    'zalgo', 'ocr_confusion', 'subscript_superscript'). Unknown names are
    ignored and empty strings are left untouched.

    Unlike the previous version, neither *example* nor its ``'text'`` list is
    modified: a shallow copy of *example* carrying a new ``'text'`` list is
    returned, so the same source example can be distorted repeatedly.
    """
    import copy

    distorted = []
    for working_text in example['text']:
        if working_text != '':
            for distortion in distortion_list:
                working_text = _apply_named_text_distortion(distortion, working_text)
        distorted.append(working_text)

    result = copy.copy(example)
    result['text'] = distorted
    return result

def distort_image(img, distortion_type, num_texts=5):
    """Apply the image-level distortion *distortion_type* to *img*.

    Supported types:
        ``'faker_text'``: overlay ``num_texts`` random sentences via
        ``render_faker_text_on_image``.

    Raises:
        ValueError: for an unknown *distortion_type* (previously this surfaced
            as an ``UnboundLocalError``).
    """
    if distortion_type == 'faker_text':
        # Pass num_texts by keyword: the old call slipped a file name into the
        # num_texts slot and also passed num_texts=5, raising TypeError.
        return render_faker_text_on_image(img, num_texts=num_texts)
    raise ValueError(f"Unknown image distortion type: {distortion_type!r}")


# In [None]:

def score(image, input_text, max_new_tokens=100):
    """OCR *image* with the global VLM and return its edit distance to *input_text*.

    Args:
        image: Image passed to ``processor`` (assumed to be something the
            processor accepts, e.g. a PIL image or array — TODO confirm).
        input_text: Reference text, either a single string or a sequence of
            strings (the per-element ``example['text']`` list used by the
            __main__ cell). Sequences are joined with newlines, skipping empty
            entries; comparing against a raw list would make Levenshtein match
            characters against whole list items.
        max_new_tokens: Generation budget; longer OCR output is truncated,
            which inflates the distance for long texts.

    Returns:
        ``Levenshtein.distance`` between the decoded OCR output and the
        reference string (lower means the rendered text is easier to read).
    """
    import Levenshtein

    prompt = """Extract the text from the provided image. Remember dont print any extra text just return the text rendered on the image. Also try to ignore the lines or borders used for just styling"""
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt},
        ]},
    ]
    chat_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[chat_text], images=[image], padding=True, return_tensors="pt")
    inputs = inputs.to(model.device)

    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    # Drop the echoed prompt tokens so only newly generated text is decoded.
    generated_ids = [
        row_out[len(row_in):] for row_in, row_out in zip(inputs.input_ids, output_ids)
    ]
    decoded = processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )

    if isinstance(input_text, str):
        reference = input_text
    else:
        # NOTE(review): newline separator assumes the model emits one text
        # element per line — confirm against typical OCR output.
        reference = "\n".join(t for t in input_text if t)

    return Levenshtein.distance(decoded[0], reference)



# In [None]:

if __name__ == "__main__":
    import time
    from functools import lru_cache

    # # Pulkit- Just change the start and end values
    start = 0
    end = 10
    type_of_process = 1
    min_val = 3
    max_val = 11
    total_sample = 150
    batch_size = 16

    output_dir = "dataset"
    final_csv_path = os.path.join(output_dir, f"final_dataset_{start}_{end}")
    final_json_path = os.path.join(output_dir, f"scores_data_{start}_{end}")
    # Create every output folder idempotently; previously the sub-folders were
    # only created when `output_dir` itself did not exist yet.
    os.makedirs(os.path.join(output_dir, "win"), exist_ok=True)
    for n in range(1, batch_size + 1):
        os.makedirs(os.path.join(output_dir, f"lose{n}"), exist_ok=True)

    # Load the dataset from Hugging Face
    dataset = load_dataset("data-is-better-together/open-image-preferences-v1-binarized")
    # NOTE(review): without `split=` this is a DatasetDict (rows are read from
    # dataset['train'] below); confirm it exposes `.features` for the renderer.
    renderer = CrelloV5Renderer(dataset.features, fonts_path)

    available_distortions_text_only = [
        char_level_drop_distortion,
        char_level_repetition_distortion,
        adjacent_char_swap_distortion,
        case_shuffle_distortion,
        noise_injection_distortion,
        ocr_confusion_distortion,
        subscript_superscript_distortion,
        zalgo_distortion,
        mirror_distortion,
        same_char_distortion,
    ]

    available_distortions_image_only = [
        render_faker_text_on_image,
    ]

    available_distortion = available_distortions_text_only + available_distortions_image_only

    lose_cols = [f"lose_image{n}" for n in range(1, batch_size + 1)]
    dataset_rows = []  # built into a DataFrame once, instead of pd.concat per row
    json_dict_for_scores = []

    def to_uint8(img):
        """Return *img* as a uint8 array (clipped to [0, 255] if needed) for cv2.imwrite."""
        arr = np.asarray(img)
        if arr.dtype != np.uint8:
            arr = np.clip(arr, 0, 255).astype(np.uint8)
        return arr

    def apply_text_distortions(text, funcs):
        """Pass *text* through each text-distortion callable in *funcs*, in order."""
        for func in funcs:
            text = func(text)
        return text

    def max_variation_dp(data, k):
        """Pick *k* items from *data* maximising the summed score gaps between picks.

        *data* holds ``(item, score)`` pairs; they are sorted by score and a
        memoised take/skip DP chooses the subset whose consecutive picks have
        the largest total ``|score difference|``.
        """
        data = sorted(data, key=lambda x: x[1])
        values = list(data)
        scores = [val[1] for val in data]
        N = len(data)

        # Use indices instead of actual score values to make caching effective
        @lru_cache(maxsize=None)
        def dp(pos, rem, last_idx):
            if rem == 0:
                return 0, []
            if pos == N:
                return float("-inf"), []

            # Option 1: Take current element
            take_score = abs(scores[pos] - scores[last_idx]) if last_idx != -1 else 0
            take_sum, take_list = dp(pos + 1, rem - 1, pos)
            take_sum += take_score

            # Option 2: Skip current element
            skip_sum, skip_list = dp(pos + 1, rem, last_idx)

            if take_sum > skip_sum:
                return take_sum, [values[pos]] + take_list
            else:
                return skip_sum, skip_list

        _, best_subset = dp(0, k, -1)
        return best_subset

    for i in range(start, end):

        temp_dict = {}
        if i % 100 == 0:
            print(f"Processing row {i}")

        example = dataset['train'][i]

        try:
            win_image = renderer.render(example)
            if win_image is None:
                continue

            distorted_images = []

            # Generate `total_sample` distorted images
            t0 = time.time()  # separate name: `start` is the row-range bound

            def generate_distorted(_ind, row=example):
                """Render one randomly distorted variant of *row*, or return None."""
                num_ops = random.randint(3, len(available_distortion))
                funcs = random.choices(available_distortion, k=num_ops)
                text_funcs = [f for f in funcs if f in available_distortions_text_only]
                add_faker_text = any(f in available_distortions_image_only for f in funcs)

                # Distort a copy so workers never mutate the shared source row.
                distorted_example = dict(row)
                distorted_example['text'] = [
                    t if t == '' else apply_text_distortions(t, text_funcs)
                    for t in row['text']
                ]

                # Always render: callers score and save an image, never the dict.
                distorted_img = renderer.render(distorted_example)
                if distorted_img is None:
                    return None
                if add_faker_text:
                    distorted_img = distort_image(distorted_img, 'faker_text')
                return distorted_img

            # Threads, not processes: a nested closure cannot be pickled for a
            # ProcessPoolExecutor.
            # NOTE(review): confirm the renderer is safe to call from several threads.
            with ThreadPoolExecutor(max_workers=50) as executor:
                futures = [executor.submit(generate_distorted, ind) for ind in range(total_sample)]

                for future in as_completed(futures):
                    try:
                        distorted = future.result()
                    except Exception as e:
                        # One failing sample should not discard the whole row.
                        print(f"Distortion failed on row {i}: {e}")
                        continue
                    if distorted is None:
                        continue
                    # Distinct name: assigning to `score` would shadow the function.
                    distortion_score = score(distorted, example['text'])
                    distorted_images.append((distorted, distortion_score))

            if i % 200 == 0:
                save_folder_path = os.path.join(output_dir, f"ckpt_{i}")
                os.makedirs(save_folder_path, exist_ok=True)
                for idx, (img, _) in enumerate(distorted_images):
                    save_path = os.path.join(save_folder_path, f'{idx}.png')
                    cv2.imwrite(save_path, to_uint8(img))

            print(f"Distortion time: {time.time() - t0:.2f} seconds")

            # NOTE(review): buckets are [0, 1/3), [1/3, 1/2), [1/2, 1] of the
            # score-sorted list — confirm the uneven split is intended.
            distorted_images = sorted(distorted_images, key=lambda x: x[1])
            n_imgs = len(distorted_images)
            distorted_images1 = distorted_images[:n_imgs // 3]
            distorted_images2 = distorted_images[n_imgs // 3:n_imgs // 2]
            distorted_images3 = distorted_images[n_imgs // 2:]
            sample_from_each_bucket = batch_size // 3
            best_subset = max_variation_dp(distorted_images1, k=sample_from_each_bucket)
            best_subset += max_variation_dp(distorted_images2, k=sample_from_each_bucket)
            best_subset += max_variation_dp(distorted_images3, k=batch_size - 2 * sample_from_each_bucket)

            best_subset = sorted(best_subset, key=lambda x: x[1])
            if len(best_subset) < batch_size:
                continue

            # Save win image
            win_path = os.path.join(output_dir, "win", f"{i}.png")
            win_image = to_uint8(win_image)
            cv2.imwrite(win_path, win_image)

            # NOTE(review): cv2.imwrite expects BGR; confirm the renderer's channel order.
            lose_paths = []
            for j, (img, score_val) in enumerate(best_subset):
                path = os.path.join(output_dir, f"lose{j+1}", f"{i}.png")
                cv2.imwrite(path, to_uint8(img))
                lose_paths.append(path)
                temp_dict[j + 1] = score_val

            win_image_score = score(win_image, example['text'])
            temp_dict['win_image_score'] = win_image_score

            #Pulkit- Comment this line to avoid visualization
            # visualize_generated_dataset(best_subset, win_image, prompt,win_image_score)

            # Append to final dataset
            data_row = {"prompt": example['prompt'], "win_image": win_path}
            for k in range(batch_size):
                data_row[f"lose_image{k+1}"] = lose_paths[k]
            dataset_rows.append(data_row)

        except Exception as e:
            print(f"Error on row {i}: {e}")
            continue

        json_dict_for_scores.append({i: temp_dict})
        # Rewritten after every row so partial progress survives a crash.
        with open(final_json_path, 'w') as f:
            json.dump(json_dict_for_scores, f, indent=4)

    # Save full dataset
    final_dataset = pd.DataFrame(dataset_rows, columns=["prompt", "win_image"] + lose_cols)
    final_dataset.to_csv(final_csv_path, index=False)

   
        
