In [1]:
import os
import xml.etree.ElementTree as ET
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import ViTFeatureExtractor, ViTForImageClassification, AutoTokenizer, AutoModelForSeq2SeqLM, AdamW
from tqdm import tqdm

In [2]:
# Hugging Face Hub checkpoint identifiers used throughout the notebook.
# NOTE: the "bert_*" names below actually hold a BART summarisation checkpoint.
vit_model_name = "google/vit-base-patch16-224-in21k"
bert_summary_model_name = "facebook/bart-large-cnn"

# Vision side: image pre-processor plus the ViT classifier.
vit_feature_extractor = ViTFeatureExtractor.from_pretrained(vit_model_name)
vit_model = ViTForImageClassification.from_pretrained(vit_model_name)

# Text side: tokenizer plus the sequence-to-sequence summarisation model.
bert_tokenizer = AutoTokenizer.from_pretrained(bert_summary_model_name)
bert_summary_model = AutoModelForSeq2SeqLM.from_pretrained(bert_summary_model_name)



model.safetensors:  39%|###9      | 136M/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


config.json:   0%|          | 0.00/1.58k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

In [3]:
# Maps the annotation label codes found in the XML files to human-readable
# damage names. Codes missing from this table are reported as
# "Unspecified Damage" by the dataset class.
damage_categories = dict(
    D00="Longitudinal Crack",
    D10="Transverse Crack",
    D20="Alligator Crack",
    D40="Pothole",
    D50="Surface Cracks",
    D60="Erosion/wear",
)

In [11]:
class RoadDamageDataset(Dataset):
    """Pairs each road image with a tokenized text summary of its XML annotation.

    For every image the matching ``<stem>.xml`` file is parsed; each ``object``
    element contributes a damage type (looked up in the module-level
    ``damage_categories`` table) and a severity grade derived from the area of
    its bounding box. The graded objects are rendered into one sentence-like
    string which is tokenized with padding to ``max_length``.

    Args:
        image_folder: Directory containing ``.png`` / ``.jpg`` / ``.jpeg`` images.
        annotation_folder: Directory containing one ``<image stem>.xml`` per image.
        feature_extractor: Callable invoked as
            ``feature_extractor(images=image, return_tensors="pt")`` whose result
            exposes ``['pixel_values']``.
        tokenizer: Callable invoked on a one-element list of strings whose result
            exposes ``['input_ids']`` and ``['attention_mask']``.
        max_length: Token length every summary is padded / truncated to.
    """

    # Bounding-box area thresholds (in square pixels) used to grade severity.
    HIGH_SEVERITY_AREA = 10000
    MEDIUM_SEVERITY_AREA = 5000

    def __init__(self, image_folder, annotation_folder, feature_extractor, tokenizer, max_length=256):
        import warnings

        self.image_folder = image_folder
        self.annotation_folder = annotation_folder
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.max_length = max_length

        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable between runs and machines.
        candidates = sorted(
            f for f in os.listdir(image_folder)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        )
        # Drop images without an annotation up front so that __getitem__ cannot
        # fail with FileNotFoundError in the middle of a training epoch.
        self.image_files = [
            f for f in candidates if os.path.isfile(self._annotation_path_for(f))
        ]
        skipped = len(candidates) - len(self.image_files)
        if skipped:
            warnings.warn(
                f"Skipping {skipped} image(s) in {image_folder} with no matching "
                f"XML annotation in {annotation_folder}"
            )

    def __len__(self):
        """Return the number of images that have a matching annotation file."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(pixel_values, input_ids, attention_mask)`` for image ``idx``.

        Each tensor has its leading batch dimension of size one squeezed away.
        """
        image_file = self.image_files[idx]
        image_path = os.path.join(self.image_folder, image_file)

        # The context manager releases the file handle as soon as the pixels
        # have been decoded into the converted copy.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        inputs = self.feature_extractor(images=image, return_tensors="pt")

        damage_details = self._read_damage_details(self._annotation_path_for(image_file))
        summary_input_text = self._build_summary_text(damage_details)

        summary_inputs = self.tokenizer(
            [summary_input_text],
            return_tensors="pt",
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
        )
        return (
            inputs['pixel_values'].squeeze(0),
            summary_inputs['input_ids'].squeeze(0),
            summary_inputs['attention_mask'].squeeze(0),
        )

    def _annotation_path_for(self, image_file):
        """Return the path of the XML annotation that belongs to ``image_file``."""
        stem = os.path.splitext(image_file)[0]
        return os.path.join(self.annotation_folder, stem + ".xml")

    def _read_damage_details(self, annotation_path):
        """Parse one annotation file into a list of graded damage records."""
        root = ET.parse(annotation_path).getroot()
        damage_details = []
        for obj in root.findall("object"):
            label = obj.find("name").text
            severity, priority, action = self._grade_damage(self._box_area(obj))
            damage_details.append({
                "type": damage_categories.get(label, "Unspecified Damage"),
                "severity": severity,
                "priority": priority,
                "action": action,
            })
        return damage_details

    @staticmethod
    def _box_area(obj):
        """Return the pixel area of the ``bndbox`` child of an ``object`` element."""
        def coord(tag):
            # Go through float() first so coordinates written as "12.0" or
            # "12.7" are accepted instead of raising ValueError in int().
            return int(float(obj.find(f"bndbox/{tag}").text))

        return (coord("xmax") - coord("xmin")) * (coord("ymax") - coord("ymin"))

    @classmethod
    def _grade_damage(cls, area):
        """Map a bounding-box area to a ``(severity, priority, action)`` triple."""
        if area > cls.HIGH_SEVERITY_AREA:
            return "High", "Urgent", "Immediate repair suggested"
        if area > cls.MEDIUM_SEVERITY_AREA:
            return "Medium", "High", "Repair recommended"
        return "Low", "Moderate", "Monitor and schedule repair"

    @staticmethod
    def _build_summary_text(damage_details):
        """Render graded damage records into the text fed to the tokenizer."""
        return "The image shows a road. " + "".join(
            f"Type: {detail['type']}, Severity: {detail['severity']}, "
            f"Priority: {detail['priority']}, Action: {detail['action']}. "
            for detail in damage_details
        )

In [12]:
def fine_tune(vit_model, bert_summary_model, dataset, epochs=3, batch_size=4, learning_rate=1e-5, save_path_vit = "vit_finetuned.pth", save_path_bart = "bart_finetuned.pth"):
    """Fine-tune the image classifier and the summarisation model, then save both.

    Each batch drives two independent optimisation steps: one on ``vit_model``
    and one on ``bert_summary_model``. After the last epoch both state dicts
    are written to disk. Both models are left in training mode and on the
    selected device.

    Args:
        vit_model: Image model called as ``vit_model(pixel_values, labels=...)``;
            its output must expose ``.logits`` and ``.loss``.
        bert_summary_model: Seq2seq model called with ``input_ids``,
            ``attention_mask`` and ``labels``; its output must expose ``.loss``.
        dataset: Map-style dataset yielding
            ``(pixel_values, input_ids, attention_mask)`` per item.
        epochs: Number of passes over ``dataset``.
        batch_size: Batch size handed to the ``DataLoader``.
        learning_rate: Learning rate shared by both optimisers.
        save_path_vit: Destination of the image model's state dict.
        save_path_bart: Destination of the summarisation model's state dict.

    Raises:
        ValueError: If ``dataset`` is empty.
    """
    if len(dataset) == 0:
        # Fail early with a clear message instead of an obscure sampler error
        # or a NameError when the per-epoch losses are reported.
        raise ValueError("fine_tune received an empty dataset; nothing to train on.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vit_model.to(device)
    bert_summary_model.to(device)
    vit_model.train()
    bert_summary_model.train()

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    # torch.optim.AdamW replaces the deprecated transformers.AdamW; eps and
    # weight_decay are pinned to the values the deprecated class defaulted to.
    optimizer_kwargs = dict(lr=learning_rate, eps=1e-6, weight_decay=0.0)
    vit_optimizer = torch.optim.AdamW(vit_model.parameters(), **optimizer_kwargs)
    bert_optimizer = torch.optim.AdamW(bert_summary_model.parameters(), **optimizer_kwargs)
    num_batches = len(dataloader)

    for epoch in range(epochs):
        vit_loss_total = 0.0
        bert_loss_total = 0.0
        for pixel_values, summary_input_ids, summary_attention_mask in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}"):
            pixel_values = pixel_values.to(device)
            summary_input_ids = summary_input_ids.to(device)
            summary_attention_mask = summary_attention_mask.to(device)

            # NOTE(review): the dataset yields no class labels, so the image
            # model is trained towards its own current argmax predictions.
            # Real per-image labels are needed for this step to learn damage
            # classes. The label pass needs no gradients, so it runs under
            # no_grad to avoid building a second autograd graph.
            with torch.no_grad():
                pseudo_labels = vit_model(pixel_values).logits.argmax(dim=-1)
            vit_outputs = vit_model(pixel_values, labels=pseudo_labels)
            vit_loss = vit_outputs.loss
            vit_optimizer.zero_grad()
            vit_loss.backward()
            vit_optimizer.step()

            # Padded positions are set to -100 so the loss ignores them;
            # otherwise the padding dominates the loss for short summaries.
            summary_labels = summary_input_ids.masked_fill(summary_attention_mask == 0, -100)
            bert_outputs = bert_summary_model(input_ids=summary_input_ids, attention_mask=summary_attention_mask, labels=summary_labels)
            bert_loss = bert_outputs.loss
            bert_optimizer.zero_grad()
            bert_loss.backward()
            bert_optimizer.step()

            # Accumulate detached tensors; converting once per epoch avoids a
            # host/device sync on every batch.
            vit_loss_total = vit_loss_total + vit_loss.detach()
            bert_loss_total = bert_loss_total + bert_loss.detach()

        # Report the epoch mean rather than whichever batch happened to be last.
        mean_vit_loss = (vit_loss_total / num_batches).item()
        mean_bert_loss = (bert_loss_total / num_batches).item()
        print(f"Epoch {epoch + 1}/{epochs}, ViT Loss: {mean_vit_loss}, BART Loss: {mean_bert_loss}")

    torch.save(vit_model.state_dict(), save_path_vit)
    torch.save(bert_summary_model.state_dict(), save_path_bart)
    print("Fine-tuned models saved.")


In [13]:
def analyze_road_damage(image_path, vit_model, bert_summary_model):
    """Classify a road image and generate a short text summary of the result.

    The image is pre-processed with the module-level ``vit_feature_extractor``,
    classified by ``vit_model``, and the predicted label name is inserted into
    a prompt that ``bert_summary_model`` summarises; the module-level
    ``bert_tokenizer`` encodes the prompt and decodes the generated ids.

    Args:
        image_path: Path of the image to analyse.
        vit_model: Image classifier exposing ``.device`` and ``.config.id2label``.
        bert_summary_model: Seq2seq model exposing ``.device`` and ``.generate``.

    Returns:
        The generated summary string, or an ``"Error: ..."`` string when the
        image file does not exist.
    """
    try:
        # The context manager closes the underlying file once the RGB copy
        # has been created.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
    except FileNotFoundError:
        return f"Error: Image not found at {image_path}"

    vit_inputs = vit_feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = vit_model(vit_inputs['pixel_values'].to(vit_model.device)).logits
    predicted_class_idx = logits.argmax(-1).item()
    predicted_class = vit_model.config.id2label[predicted_class_idx]

    summary_input_text = f"The image shows a road with {predicted_class}. "
    summary_inputs = bert_tokenizer([summary_input_text], return_tensors="pt", max_length=1024, truncation=True)
    summary_device = bert_summary_model.device
    with torch.no_grad():
        # Pass the tokenizer's attention mask explicitly instead of leaving
        # generate() to infer one from the input ids.
        summary_ids = bert_summary_model.generate(
            summary_inputs["input_ids"].to(summary_device),
            attention_mask=summary_inputs["attention_mask"].to(summary_device),
            num_beams=4,
            max_length=256,
            early_stopping=True,
        )
    summary = bert_tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    return summary

In [None]:
# Training data locations: images and their XML annotations.
image_folder = "/Users/abhinayb/Downloads/India/train/images"  # Replace with your image folder path
annotation_folder = "/Users/abhinayb/Downloads/India/train/annotations/xmls"  # Replace with your annotation folder path

dataset = RoadDamageDataset(
    image_folder,
    annotation_folder,
    vit_feature_extractor,
    bert_tokenizer,
)
fine_tune(vit_model, bert_summary_model, dataset)

# Reload the checkpoints written by fine_tune and switch each model to
# inference mode.
cpu_device = torch.device('cpu')
for finetuned_model, checkpoint_file in (
    (vit_model, "vit_finetuned.pth"),
    (bert_summary_model, "bart_finetuned.pth"),
):
    finetuned_model.load_state_dict(torch.load(checkpoint_file, map_location=cpu_device))
    finetuned_model.eval()

image_path = "/content/India_000071.jpg" #Replace with your test image path.

# Run the end-to-end analysis on the single test image and show the summary.
analysis_result = analyze_road_damage(image_path, vit_model, bert_summary_model)
print(analysis_result)

Epoch 1/3:   0%|                                       | 0/1927 [00:00<?, ?it/s]