In [None]:
import os
import json
import torch
from torch import nn
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torchvision import transforms
from PIL import Image
from transformers import BertTokenizer, BertModel

In [None]:
import kagglehub

# Kaggle handle of the road-damage dataset used throughout this notebook.
DATASET_HANDLE = "alvarobasily/road-damage"

# Fetch the newest version of the dataset; kagglehub returns the local folder
# that holds the files.
path = kagglehub.dataset_download(DATASET_HANDLE)
print("Path to dataset files:", path)

Path to dataset files: /kaggle/input/road-damage


In [None]:
def load_data(data_path, image_extensions=(".jpg", ".jpeg", ".png")):
    """List the image and label files in a flat dataset directory.

    Only the top level of ``data_path`` is scanned (no recursion).  Suffix
    matching is case-insensitive, so ``IMG_01.JPG`` is picked up too.

    Both lists are sorted by filename stem (ties broken by full path), so the
    order is deterministic -- ``os.listdir`` makes no ordering guarantee.  When
    every image has exactly one same-stem ``.txt`` file and there are no other
    ``.txt`` files, the lists line up index-for-index; in any other case pair
    images with labels by stem rather than by position.

    Args:
        data_path: directory to scan.
        image_extensions: suffixes (with the leading dot) that identify image
            files.  Compared case-insensitively.

    Returns:
        ``(image_files, text_files)``: lists of paths joined onto ``data_path``.
    """
    def _sort_key(file_path):
        # Sort on the stem first so "a.k.jpg"/"a.jpg" and "a.k.txt"/"a.txt"
        # end up in the same relative order; a plain path sort would not.
        stem = os.path.splitext(os.path.basename(file_path))[0]
        return (stem, file_path)

    # The original checked "jpeg" without the dot, which also matched any
    # name merely ending in those letters.
    image_suffixes = tuple(ext.lower() for ext in image_extensions)

    image_files = []
    text_files = []
    for filename in os.listdir(data_path):
        lowered = filename.lower()
        full_path = os.path.join(data_path, filename)
        if lowered.endswith(image_suffixes):
            image_files.append(full_path)
        elif lowered.endswith(".txt"):
            text_files.append(full_path)

    image_files.sort(key=_sort_key)
    text_files.sort(key=_sort_key)
    return image_files, text_files

In [None]:
#data_path = "/kaggle/input/road-damage"
#images, txt = load_data(data_path)

In [None]:
def extract_damage_info(txt_file_path, num_coordinates=4):
    """Read the first annotation from a label file.

    The first line is expected to look like ``<class_id> <v1> <v2> ...``.
    Only that line is read, so a file listing several objects yields just the
    first one.

    Args:
        txt_file_path: path to the label ``.txt`` file.
        num_coordinates: how many values after the class id to return.
            Defaults to 4, matching ``RoadDamageClassifier``'s default
            ``coordinate_dim``.

    Returns:
        ``(class_id, coordinates)`` where ``class_id`` is an ``int`` and
        ``coordinates`` is a list of exactly ``num_coordinates`` floats, or
        ``(None, None)`` when the first line has too few values (e.g. an
        empty file).

    Raises:
        ValueError: if the class id is not an integer or a coordinate is not
            numeric.
    """
    with open(txt_file_path, "r", encoding="utf-8") as file:
        parts = file.readline().split()

    if len(parts) < 1 + num_coordinates:
        return None, None

    class_id = int(parts[0])
    # Slice instead of taking everything after the class id: extra values
    # would produce variable-length vectors that neither the model's
    # fixed-size coordinate layer nor default batch collation can handle.
    coordinates = [float(value) for value in parts[1:1 + num_coordinates]]
    return class_id, coordinates

#for i in range(5):
    #print(extract_damage_info(txt[i]))

In [None]:
# Built once at import time instead of on every call.  Resize -> 224 center
# crop -> tensor -> per-channel normalisation (the mean/std values are the
# commonly used ImageNet statistics).
_PROCESS_IMAGE_TRANSFORM = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def process_image(image_path):
    """Load one image file and turn it into a single-item batch for the model.

    Args:
        image_path: path to an image readable by PIL.

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)``: the RGB image resized,
        center-cropped, normalised, and given a leading batch dimension.
    """
    # PIL opens files lazily and keeps the handle open; the context manager
    # closes it once convert() has produced an independent RGB copy.
    with Image.open(image_path) as img:
        rgb_image = img.convert("RGB")
    img_tensor = _PROCESS_IMAGE_TRANSFORM(rgb_image).unsqueeze(0)
    return img_tensor

#processed_img = process_image(images[0])
#processed_img, processed_img.shape

In [None]:
class RoadDamageClassifier(nn.Module):
    """Classify road damage from an image plus its annotation coordinates.

    The backbone's ``heads`` module is replaced with ``nn.Identity`` (this
    mutates the ``vit_model`` object passed in), so the backbone's forward
    pass yields ``vit_model.hidden_dim`` features per image.  Image features
    and coordinates are each projected to ``output_dim``, concatenated, fused
    through a ReLU-activated layer, and mapped to ``num_classes`` logits.

    Args:
        vit_model: backbone exposing ``heads`` and ``hidden_dim`` whose output
            is ``(batch, hidden_dim)`` once ``heads`` is the identity (e.g. a
            torchvision ``VisionTransformer``).
        num_classes: number of output logits.
        coordinate_dim: length of each sample's coordinate vector.
        output_dim: width of the projected and fused feature space.
    """

    def __init__(self, vit_model, num_classes, coordinate_dim=4, output_dim=256):
        super().__init__()
        self.vit = vit_model
        self.vit.heads = nn.Identity()
        self.image_projection = nn.Linear(self.vit.hidden_dim, output_dim)
        self.coordinate_projection = nn.Linear(coordinate_dim, output_dim)
        self.combined_projection = nn.Linear(output_dim * 2, output_dim)
        # Without a non-linearity here, combined_projection and classifier
        # compose into a single affine map, so the fusion layer would add
        # parameters but no capacity.  ReLU has no weights, so state_dict
        # keys (and existing checkpoints) are unaffected.
        self.activation = nn.ReLU()
        self.classifier = nn.Linear(output_dim, num_classes)

    def forward(self, image, coordinates):
        """Return ``(batch, num_classes)`` logits.

        Args:
            image: batch accepted by the backbone.
            coordinates: ``(batch, coordinate_dim)`` float tensor.
        """
        image_projected = self.image_projection(self.vit(image))
        coordinate_projected = self.coordinate_projection(coordinates)

        combined_features = torch.cat((image_projected, coordinate_projected), dim=1)
        combined_projected = self.activation(self.combined_projection(combined_features))

        class_output = self.classifier(combined_projected)
        return class_output

In [None]:
# NOTE(review): presumably the number of damage categories in the label
# files -- TODO confirm against the dataset's class ids.
num_classes = 4

# Pretrained ViT-B/16 backbone (torchvision's default weight set); .eval()
# returns the module itself, so it can be chained.
weights = ViT_B_16_Weights.DEFAULT
vit_model = vit_b_16(weights=weights).eval()

# The wrapper replaces the backbone's `heads` with nn.Identity (mutating
# vit_model) and adds the fusion/classification layers.
model = RoadDamageClassifier(vit_model=vit_model, num_classes=num_classes)
# Last expression: the notebook displays the module tree.
model.eval()

RoadDamageClassifier(
  (vit): VisionTransformer(
    (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (encoder): Encoder(
      (dropout): Dropout(p=0.0, inplace=False)
      (layers): Sequential(
        (encoder_layer_0): EncoderBlock(
          (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (self_attention): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (dropout): Dropout(p=0.0, inplace=False)
          (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): MLPBlock(
            (0): Linear(in_features=768, out_features=3072, bias=True)
            (1): GELU(approximate='none')
            (2): Dropout(p=0.0, inplace=False)
            (3): Linear(in_features=3072, out_features=768, bias=True)
            (4): Dropout(p=0.0, inplace=False)
          )
        )
        (encoder_layer_1): EncoderBlock(
          (ln_

In [None]:
# Resolve the dataset folder through kagglehub rather than a hard-coded
# cache path: the download cell reports /kaggle/input/road-damage, while the
# old literal pointed at /root/.cache/..., so it only worked in one
# environment.  kagglehub should reuse an already-downloaded copy.
path = kagglehub.dataset_download("alvarobasily/road-damage")
images, txt = load_data(path)

In [None]:
from torch.utils.data import Dataset


class RoadDamageDataset(Dataset):
    """Image / label-file pairs for road-damage classification.

    Labels are matched to images by filename stem (``img_01.jpg`` <->
    ``img_01.txt``), not by list position.  ``load_data`` builds its lists
    from ``os.listdir``, whose order is arbitrary, so ``image_files[i]`` and
    ``text_files[i]`` need not describe the same sample.  NOTE(review): the
    same-stem convention is assumed for this dataset -- TODO confirm.

    Images with no matching label file, or whose label file
    ``extract_damage_info`` cannot parse (it returns ``(None, None)``, e.g. for
    an empty file), are dropped up front.  Otherwise ``__getitem__`` would
    pass ``None`` to ``torch.tensor`` and crash mid-epoch.  Labels are parsed
    once here rather than on every access.

    Args:
        image_files: iterable of image paths.
        text_files: iterable of label ``.txt`` paths.
        transform: optional callable applied to the RGB ``PIL.Image``.

    Raises:
        ValueError: if images were given but none could be paired with a
            usable label file.

    Each item is ``(image, coordinates, class_id)`` where ``coordinates`` is a
    float32 tensor.  ``self.image_files[i]`` / ``self.text_files[i]`` hold the
    paths of the i-th kept pair.
    """

    def __init__(self, image_files, text_files, transform=None):
        image_files = list(image_files)
        labels_by_stem = {self._stem(p): p for p in text_files}

        self.image_files = []
        self.text_files = []
        self._labels = []  # (class_id, coordinates) for each kept image
        for image_path in image_files:
            txt_path = labels_by_stem.get(self._stem(image_path))
            if txt_path is None:
                continue
            class_id, coordinates = extract_damage_info(txt_path)
            if class_id is None:
                continue
            self.image_files.append(image_path)
            self.text_files.append(txt_path)
            self._labels.append((class_id, coordinates))

        if image_files and not self.image_files:
            raise ValueError(
                "No image in image_files has a parseable label file with the "
                "same filename stem in text_files"
            )
        self.transform = transform

    @staticmethod
    def _stem(file_path):
        """Return the filename without its directory or extension."""
        return os.path.splitext(os.path.basename(file_path))[0]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_path = self.image_files[idx]
        class_id, coordinates = self._labels[idx]

        # PIL keeps the file handle open on a lazily-loaded image; close it
        # once convert() has produced an independent RGB copy.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)

        coordinates = torch.tensor(coordinates, dtype=torch.float32)
        return image, coordinates, class_id

In [None]:
from torch.utils.data import DataLoader

# Same preprocessing as `process_image`, without the batch dimension (the
# DataLoader adds it when collating).
dataset = RoadDamageDataset(images, txt, transform=transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]))

# 80/20 train/validation split, seeded so re-running the notebook reproduces
# the same split.
SPLIT_SEED = 42
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Peek at one batch under new names.  The previous version rebound `images`
# (the list of image paths from load_data) to a tensor, so re-running this
# cell passed that tensor to RoadDamageDataset.
sample_images, sample_coordinates, sample_labels = next(iter(train_loader))
sample_images.shape, sample_coordinates.shape, sample_labels

(tensor([[[[-0.0458, -0.1486, -0.1314,  ...,  2.1975,  2.1975,  2.1975],
           [-0.0287, -0.1657, -0.2171,  ...,  2.1975,  2.1975,  2.1975],
           [-0.1314, -0.2342, -0.1143,  ...,  2.1975,  2.1975,  2.1975],
           ...,
           [ 0.2796,  0.2796,  0.2796,  ...,  0.5536,  0.5536,  0.5364],
           [ 0.3138,  0.3309,  0.3138,  ...,  0.5364,  0.5193,  0.5022],
           [ 0.2967,  0.2967,  0.2967,  ...,  0.5364,  0.5364,  0.5022]],
 
          [[ 0.3978,  0.4328,  0.4503,  ...,  2.4111,  2.4111,  2.4111],
           [ 0.4503,  0.4503,  0.4153,  ...,  2.4111,  2.4111,  2.4111],
           [ 0.3452,  0.3452,  0.5203,  ...,  2.4111,  2.4111,  2.4111],
           ...,
           [ 0.3803,  0.3627,  0.3627,  ...,  0.6779,  0.6604,  0.6604],
           [ 0.3978,  0.3803,  0.3452,  ...,  0.6604,  0.6604,  0.6429],
           [ 0.3803,  0.3627,  0.3627,  ...,  0.6604,  0.6604,  0.6429]],
 
          [[ 0.1825, -0.0441, -0.0964,  ...,  2.5703,  2.5703,  2.5703],
           [ 

In [None]:
import torch.optim as optim

# NOTE(review): Adam at lr=1e-3 updates every weight of the pretrained ViT
# backbone (nothing is frozen).  That rate is usually considered high for
# fine-tuning a pretrained transformer; consider ~1e-4 or freezing
# `model.vit`.  Kept at the original value here.
LEARNING_RATE = 0.001
num_epochs = 10
# The inference cell feeds CPU tensors to `model`, so training stays on the
# CPU as before; change both places together if moving to a GPU.
device = torch.device("cpu")

model.to(device)

# Define optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    # The model-building cell left `model` in eval mode, and the validation
    # pass below switches back to eval, so re-enable training mode each epoch.
    model.train()
    # Batch names differ from the notebook-level `images` (the list of image
    # paths) so this loop does not clobber it.
    for batch_images, batch_coordinates, batch_labels in train_loader:
        batch_images = batch_images.to(device)
        batch_coordinates = batch_coordinates.to(device)
        batch_labels = batch_labels.to(device)

        outputs = model(batch_images, batch_coordinates)
        loss = criterion(outputs, batch_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Epoch: {epoch}, loss: {loss.item()}")

    # Score the held-out split (val_loader was previously built but unused).
    # This also leaves the model in eval mode for the inference cell.
    model.eval()
    val_loss_total, val_correct, val_count = 0.0, 0, 0
    with torch.no_grad():
        for batch_images, batch_coordinates, batch_labels in val_loader:
            batch_images = batch_images.to(device)
            batch_coordinates = batch_coordinates.to(device)
            batch_labels = batch_labels.to(device)

            outputs = model(batch_images, batch_coordinates)
            # CrossEntropyLoss averages over the batch by default; re-weight
            # by batch size so a short final batch does not skew the mean.
            val_loss_total += criterion(outputs, batch_labels).item() * batch_labels.size(0)
            val_correct += (outputs.argmax(dim=1) == batch_labels).sum().item()
            val_count += batch_labels.size(0)

    if val_count:
        print(
            f"Epoch: {epoch}, val_loss: {val_loss_total / val_count:.4f}, "
            f"val_acc: {val_correct / val_count:.4f}"
        )

Epoch: 0, loss: 1.3808602094650269
Epoch: 0, loss: 1.9254025220870972
Epoch: 0, loss: 1.481019377708435
Epoch: 0, loss: 1.3822617530822754
Epoch: 0, loss: 1.3485584259033203
Epoch: 0, loss: 1.2332817316055298
Epoch: 0, loss: 1.334702491760254
Epoch: 0, loss: 1.4345779418945312


In [None]:
# Resolve the dataset folder the same way as the download cell instead of a
# hard-coded path that differed from the one used for training.
path = kagglehub.dataset_download("alvarobasily/road-damage")
image_files, text_files = load_data(path)

# load_data's lists follow os.listdir order, so image_files[0] and
# text_files[0] need not describe the same sample.  Pick the first image
# (by name) that has a same-stem label file instead.
labels_by_stem = {
    os.path.splitext(os.path.basename(p))[0]: p for p in text_files
}
image_path = txt_path = None
for candidate in sorted(image_files):
    stem = os.path.splitext(os.path.basename(candidate))[0]
    if stem in labels_by_stem:
        image_path, txt_path = candidate, labels_by_stem[stem]
        break

if image_path is None:
    print("No image with a matching label file found in the dataset.")
else:
    image_tensor = process_image(image_path)
    class_id, coordinates = extract_damage_info(txt_path)

    if coordinates is not None:
        coordinates_tensor = torch.tensor(coordinates, dtype=torch.float32).unsqueeze(0)

        # Make inference independent of whichever mode an earlier cell left.
        model.eval()
        with torch.no_grad():
            class_output = model(image_tensor, coordinates_tensor)
        predicted_class = torch.argmax(class_output, dim=1).item()
        print(f"Predicted class: {predicted_class}")
        print(f"Labelled class: {class_id}")
    else:
        print(f"Error reading coordinates from {txt_path}")


Predicted class: 1
