In [None]:
!pip install transformers torch accelerate bitsandbytes jiwer datasets peft loralib tqdm pytesseract
!apt-get install tesseract-ocr
!apt-get install tesseract-ocr-eng

In [25]:
import torch
from transformers import ( LlamaForCausalLM, LlamaTokenizer, AutoProcessor, AutoModelForVision2Seq, TrainingArguments, Trainer, BitsAndBytesConfig )
from datasets import Dataset, load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import jiwer
from PIL import Image
import pandas as pd
import numpy as np
from tqdm import tqdm
import os
import json
from glob import glob
import random
import pytesseract

In [26]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
HF_TOKEN = user_secrets.get_secret("HF_TOKEN")

from huggingface_hub import login

login(token=HF_TOKEN)

In [27]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda




# Load Images from the folder

In [28]:
def load_images_from_folder(folder, extensions=(".jpg",)):
    """Load every matching image file directly inside ``folder`` as an RGB image.

    Files are visited in sorted filename order. ``os.listdir`` returns entries
    in arbitrary order, so sorting keeps the returned lists (and any seeded
    split made from them) identical between runs.

    Args:
        folder: Directory to scan; sub-directories are not searched.
        extensions: Filename suffix, or iterable of suffixes, to accept.
            Matching is case-insensitive, so ".jpg" also accepts ".JPG".

    Returns:
        A tuple ``(images, image_names)`` of two parallel lists: the decoded
        RGB images and the filenames they were read from.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    images = []
    image_names = []
    for filename in sorted(os.listdir(folder)):
        if not filename.lower().endswith(suffixes):
            continue
        # convert() returns a new image, so the source file can be closed
        # as soon as the conversion is done.
        with Image.open(os.path.join(folder, filename)) as source:
            images.append(source.convert("RGB"))
        image_names.append(filename)
    return images, image_names

# Kaggle input directory that holds the .jpg images; change this path when running elsewhere.
image_folder = "/kaggle/input/dataset/images"
# `dataset` is the list of RGB images and `image_names` the parallel list of filenames.
dataset, image_names = load_images_from_folder(image_folder)
print("Dataset loaded")

Dataset loaded


# Generate True text from the images using OCR
Running Tesseract over all images takes about 12–15 minutes. Images for which OCR returns no text are labelled "N/A".

In [29]:
def generate_ground_truth(images, image_names):
    """Run Tesseract OCR on each image and key the result by image name.

    Args:
        images: Images to pass to ``pytesseract.image_to_string``.
        image_names: Names paired positionally with ``images``.

    Returns:
        Dict mapping each name to the stripped OCR text, or to the
        placeholder "N/A" when OCR yields nothing but whitespace.
    """
    return {
        label: pytesseract.image_to_string(picture).strip() or "N/A"
        for picture, label in zip(images, image_names)
    }

# Run OCR once over every loaded image; the result maps filename -> recognised text.
ground_truth_data = generate_ground_truth(dataset, image_names)
# Re-order the texts into a list parallel to `image_names` (and therefore to `dataset`).
true_texts = [ground_truth_data[img_name] for img_name in image_names]
print("true_texts generated")



true_texts generated


In [None]:
# Spot-check the OCR text of one file; raises KeyError if that filename was not loaded.
ground_truth_data['india_news_p000060.jpg']

# Creating and Loading the model "Llama-3.2-11B-Vision"
Loads the Llama-3.2-11B-Vision model with 4-bit quantization. *Note: loading the model can take up to 10 minutes.*

In [6]:
def load_model(model_name="meta-llama/Llama-3.2-11B-Vision"):
    """Load a 4-bit quantized vision-to-sequence model and its processor.

    Args:
        model_name: Hub id or local path of the checkpoint. Defaults to the
            Llama-3.2-11B-Vision checkpoint this notebook was written for.

    Returns:
        A tuple ``(model, processor)`` loaded from ``model_name``.
    """
    # NF4 weights with double quantization; matmuls are computed in float16.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",  # normalized float 4
        bnb_4bit_use_double_quant=True
    )

    # device_map="auto" leaves weight placement to the library.
    model = AutoModelForVision2Seq.from_pretrained(
        model_name,
        quantization_config=quantization_config,
        device_map="auto",
        torch_dtype=torch.float16
    )

    processor = AutoProcessor.from_pretrained(model_name)

    return model, processor

In [7]:
# Downloads the checkpoint shards on first run (see the progress bars below), then loads them.
model, processor = load_model()

config.json:   0%|          | 0.00/5.03k [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/89.4k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/5 [00:00<?, ?it/s]

model-00001-of-00005.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00005.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00003-of-00005.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00004-of-00005.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00005-of-00005.safetensors:   0%|          | 0.00/1.47G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/152 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/437 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/50.7k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/459 [00:00<?, ?B/s]

# Extract baseline texts using the model
Took about 5 hours to extract text using the model.

In [None]:
def extract_text(images,model,processor):
    """Generate one decoded text per image with beam search.

    Args:
        images: Iterable of images; iterated under a ``tqdm`` progress bar.
        model: Object whose ``generate`` accepts the processor output.
        processor: Callable that encodes an image and offers ``batch_decode``.

    Returns:
        List with the first decoded sequence for each image, in input order.

    NOTE(review): the processor is called with the image only (no text
    prompt), and inputs are moved to the notebook-level ``device`` rather
    than the model's own device -- confirm both are intended.
    """
    def _transcribe(picture):
        encoded = processor(images=picture, return_tensors="pt").to(device)
        generated = model.generate(
            **encoded,
            max_length=512,
            num_beams=5,
            early_stopping=True,
        )
        decoded = processor.batch_decode(generated, skip_special_tokens=True)
        return decoded[0]

    return [_transcribe(picture) for picture in tqdm(images)]

# Long-running: one beam-search generation per image in `dataset`.
baseline_texts = extract_text(dataset,model,processor)

# Text Organization
Clean the extracted texts with regular expressions (collapse whitespace and newlines, replace non-ASCII characters with a space), then put each sentence on its own line.

In [None]:
import re
import nltk
from nltk.tokenize import sent_tokenize

nltk.download('punkt_tab')

def clean_text(text):
    """Normalise extracted text to single-spaced ASCII.

    Every run of non-ASCII characters becomes one space, every run of
    whitespace (including newlines and tabs) collapses to one space, and
    leading/trailing whitespace is removed.

    The non-ASCII replacement runs first on purpose: it inserts spaces, so
    collapsing and stripping afterwards guarantees that no double, leading or
    trailing space is left behind.

    Args:
        text: Raw text to clean.

    Returns:
        The cleaned text; an empty string if nothing but whitespace or
        non-ASCII characters was present.
    """
    text = re.sub(r'[^\x00-\x7F]+', ' ', text)  # Replace non-ASCII runs with a space
    text = re.sub(r'\s+', ' ', text)  # Collapse spaces, tabs and newlines
    return text.strip()

def structure_text(text):
    """Return ``text`` with each sentence found by ``sent_tokenize`` on its own line."""
    return "\n".join(sent_tokenize(text))

def process_extracted_text(extracted_texts):
    """Clean each text with ``clean_text``, then lay it out with ``structure_text``.

    Args:
        extracted_texts: Iterable of raw text strings.

    Returns:
        List of processed strings in the same order as the input.
    """
    return [structure_text(clean_text(raw)) for raw in extracted_texts]

# Cleaned, one-sentence-per-line versions of the model outputs, in the same order as `baseline_texts`.
organized_texts = process_extracted_text(baseline_texts)

# Evaluating texts, i.e. calculating the word error rate (WER) and character error rate (CER)

In [None]:
from jiwer import wer, cer
def evaluate_texts(true_texts, predicted_texts):
    """Score predictions against references with jiwer's ``wer`` and ``cer``.

    Args:
        true_texts: Reference text(s), passed as the first argument to both metrics.
        predicted_texts: Hypothesis text(s), passed as the second argument.

    Returns:
        Tuple ``(word_error, char_error)``.
    """
    return wer(true_texts, predicted_texts), cer(true_texts, predicted_texts)

# References are the Tesseract OCR texts; hypotheses are the cleaned model outputs.
word_error, char_error = evaluate_texts(true_texts, organized_texts)
print(f"Word Error Rate: {word_error}")
print(f"Character Error Rate: {char_error}")

# Fine tune the model using LoRA (Low Rank Adaptation)

In [52]:
from torch.utils.data import Dataset

class ImageTextDataset(Dataset):
    """Pairs each image with its reference text for causal-LM fine-tuning.

    Every item comes from ONE joint ``processor(images=..., text=...)`` call
    on ``PROMPT + " " + text + eos``. ``labels`` are then a copy of
    ``input_ids``, so each label sits at the position of the token it
    supervises and the target text is actually part of the model input.
    Padding positions and image placeholder tokens are set to
    ``IGNORE_INDEX`` in the labels so they do not contribute to the loss.

    All tensors returned by the processor are forwarded (batch dimension
    removed), not only ``pixel_values``/``input_ids``/``attention_mask``.
    The batch collator must therefore stack every key it receives if the
    model is to get the vision side-inputs.

    Args:
        images: Sequence of images accepted by ``processor``.
        texts: Sequence of target strings, parallel to ``images``.
        processor: Multimodal processor exposing ``__call__`` and ``tokenizer``.
        max_length: Token length every example is padded/truncated to.

    Raises:
        ValueError: If ``images`` and ``texts`` differ in length.
    """

    # Placeholder the Mllama processor expects once per image in the text.
    IMAGE_TOKEN = "<|image|>"
    PROMPT = IMAGE_TOKEN + "Extract and describe the text in this image.\nText in the image:"
    # Label value skipped by the cross-entropy loss.
    IGNORE_INDEX = -100

    def __init__(self, images, texts, processor, max_length=512):
        if len(images) != len(texts):
            raise ValueError(
                f"images and texts must have the same length, got {len(images)} and {len(texts)}"
            )
        self.images = images
        self.texts = texts
        self.processor = processor
        self.max_length = max_length

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        tokenizer = self.processor.tokenizer
        # Ending the target with EOS teaches the model where the transcription stops.
        eos_token = tokenizer.eos_token or ""
        full_text = f"{self.PROMPT} {self.texts[idx]}{eos_token}"

        encoded = self.processor(
            images=self.images[idx],
            text=full_text,
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
        )

        # The processor returns batch-of-one tensors; drop that leading dimension.
        inputs = {name: tensor[0] for name, tensor in encoded.items()}

        labels = inputs["input_ids"].clone()
        # Mask by attention_mask rather than pad id, so a pad token that
        # doubles as EOS is still supervised where it is a real token.
        labels[inputs["attention_mask"] == 0] = self.IGNORE_INDEX
        # The image placeholder is an input-only token and must not be a prediction target.
        image_token_id = tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
        if image_token_id is not None:
            labels[labels == image_token_id] = self.IGNORE_INDEX

        inputs["labels"] = labels
        return inputs


In [53]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the (image, text) pairs for validation; random_state fixes the shuffle.
train_images, val_images, train_texts, val_texts = train_test_split(
    dataset, true_texts, test_size=0.2, random_state=42
)

In [55]:
# Both splits share the same processor, so they are encoded identically.
train_dataset = ImageTextDataset(train_images, train_texts, processor)
val_dataset = ImageTextDataset(val_images, val_texts, processor)

In [56]:
def collate_fn(examples):
    """Stack per-example tensors into one batch dict.

    Every key of the first example is stacked along a new leading batch
    dimension, so any tensor a dataset returns reaches the model. Nothing is
    silently dropped because it is missing from a fixed key list.

    Args:
        examples: Non-empty list of dicts that share the same keys and hold
            equally shaped tensors per key.

    Returns:
        Dict with the same keys (in the first example's key order), each
        mapped to a stacked tensor.

    Raises:
        ValueError: If ``examples`` is empty.
        KeyError: If a later example lacks a key present in the first one.
    """
    if not examples:
        raise ValueError("collate_fn received an empty batch")
    return {
        key: torch.stack([example[key] for example in examples])
        for key in examples[0]
    }

In [57]:
# LoRA hyper-parameters: rank-16 adapters scaled by lora_alpha=32, attached to
# every module whose name matches "q_proj" or "v_proj"; bias terms are not trained.
lora_config = LoraConfig(
    r=16,  # rank
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)
# NOTE: this cell rebinds `model` twice. Re-running it applies both calls to the
# already-processed model, so reload the model before executing it again.
model = prepare_model_for_kbit_training(model)

# Add LoRA adapters
model = get_peft_model(model, lora_config)

In [58]:
# Batch size 1 with 4 accumulation steps gives an effective batch of 4 per device.
# Evaluation and checkpointing both happen every 100 steps; loss is logged every 10.
training_args = TrainingArguments(
    output_dir="./llama-vision-finetuned",
    num_train_epochs=3,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    learning_rate=2e-4,
    weight_decay=0.01,
    warmup_steps=100,
    logging_steps=10,
    fp16=True,
    report_to="none"
)

In [59]:
# Wire the LoRA-wrapped model, both dataset splits and the custom collator into one Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_fn
)

In [None]:
# Start fine-tuning; checkpoints are written under the `output_dir` set in `training_args`.
trainer.train()