# FathomNet

Training PROB on custom data.

In [None]:
# # run this to create the necessary folder
# import os 

# data_dir = "./data/OWOD/"
# folders = ['JPEGImages', 'Annotations', 'ImageSets']

# for folder in folders:
#     try:
#         os.makedirs(os.path.join(data_dir, folder, data_name))
#     except OSError as e:
#         print(f"Can't create folder: {str(e)}")

Let's say you want to add a dataset "DATASET_A". Then you need to:

1. In `./datasets/torchvision_datasets/open_world.py` line 120, add to the dictionary VOC_COCO_CLASS_NAMES a key-value pair: VOC_COCO_CLASS_NAMES["DATASET_A"]=["a","b","c",...]
2. Store DATASET_A's images under "data/OWOD/JPEGImages/"
3. Store DATASET_A's Annotations under "data/OWOD/Annotations/"
4. Store DATASET_A's ImageSets files under "data/OWOD/ImageSets/DATASET_A/"
5. When you train, the input --dataset should be set to DATASET_A (e.g., --dataset DATASET_A)

In [None]:
import itertools

# Class names grouped by the task in which they are introduced (T1..T4).
T1_CLASS_NAMES = [
    'Urchin',
    'Fish',
    'Sea star',
    'Anemone',
    'Sea cucumber',
    'Sea pen',
    'Sea fan',
    'Worm',
    'Crab',
    'Gastropod',
]
T2_CLASS_NAMES = ['Shrimp', 'Soft coral']
T3_CLASS_NAMES = ['Glass sponge', 'Feather star']
T4_CLASS_NAMES = [
    'Eel',
    'Squat lobster',
    'Barnacle',
    'Stony coral',
    'Black coral',
    'Sea spider',
]

# Catch-all label appended after every named class.
UNK_CLASS = ["unknown"]

# Dataset name -> flat tuple of class names: T1, T2, T3, T4, then "unknown".
VOC_COCO_CLASS_NAMES = {}
VOC_COCO_CLASS_NAMES["fathomnet"] = tuple(
    itertools.chain.from_iterable(
        (T1_CLASS_NAMES, T2_CLASS_NAMES, T3_CLASS_NAMES, T4_CLASS_NAMES, UNK_CLASS)
    )
)


In [None]:
VOC_COCO_CLASS_NAMES

Other files to change:

- `configs/M_OWOD_BENCHMARK.sh` update all paths to point to the correct ImageSet files
- `run.sh` update the number of GPUs you have in your machine

Sort out WANDB:
- change entity (aka wandb username) in lines 165 and 167 in the file `/main_open_world.py`

In [None]:
import wandb

# confirm login
wandb.login()

There are issues with the file names: they have to be numbers, like in the VOC dataset, for some reason. It is quicker to rename the files than to change the code.
- Annotations and JPEGImages need a new file name,
- File paths need to be updated inside all annotation files, and 
- all txt files inside ImageSets need to be updated to match.

If it all goes wrong and all files need to be copied again, do it in bash from the `./data/processed` folder: `cp -frp VOC-backup -T VOC-test`. Make sure the target folder name doesn't exist already ([stackoverflow](https://stackoverflow.com/questions/33343840/bash-duplicate-rename-folder)).

In `datasets/torchvision_datasets/open_world.py` lines 193 and 198, change .jpg to .png as our images are pngs.

## Set up

### Compiling CUDA operators

In [None]:
# %cd models
# !wget https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth
# %cd ops
# !sh ./make.sh
# %cd ../..

## Inference

Orr's [reply](https://github.com/orrzohar/PROB/issues/34).

Notebooks for ref: 
- [Objective: fine-tuning DETR](https://github.com/woctezuma/finetune-detr/blob/master/finetune_detr.ipynb)
- [Object Detection with DETR - a minimal implementation](https://colab.research.google.com/github/facebookresearch/detr/blob/colab/notebooks/detr_demo.ipynb#scrollTo=kqe_0nc5dyAq)

Questions:
- How do I load the model from the pre-trained weights?

In [None]:
# from models.prob_deformable_detr import build
from models.deformable_detr import build
from models import build_model
import numpy as np
import torch

In [None]:
import numpy as np

class Args:
    # Stand-in for a parsed command-line namespace: a plain bag of class
    # attributes that this notebook hands to build_model(), OWDetection() and
    # OWEvaluator(). Presumably mirrors the options of main_open_world.py —
    # TODO confirm names and defaults against that script.

    # Learning rates used by the optimizer cell: `lr` is the base rate;
    # parameters whose name contains an entry of `lr_backbone_names` use
    # `lr_backbone`, and those matching `lr_linear_proj_names` use
    # `lr * lr_linear_proj_mult`.
    lr = 2e-4
    lr_backbone_names = ["backbone.0"]
    lr_backbone = 2e-5
    lr_linear_proj_names = ['reference_points', 'sampling_offsets']
    lr_linear_proj_mult = 0.1
    # Batch size of the validation DataLoader built in this notebook.
    batch_size = 5
    # Passed as weight_decay to whichever optimizer is created.
    weight_decay = 1e-4
    epochs = 51
    lr_drop = 35
    lr_drop_epochs = None
    clip_max_norm = 0.1
    # False selects torch.optim.AdamW; True selects SGD with momentum 0.9.
    sgd = False
    with_box_refine = False
    two_stage = False
    masks = False
    backbone = 'dino_resnet50'
    frozen_weights = None
    dilation = False
    position_embedding = 'sine'
    position_embedding_scale = 2 * np.pi
    num_feature_levels = 4
    enc_layers = 6
    dec_layers = 6
    dim_feedforward = 1024
    hidden_dim = 256
    dropout = 0.1
    nheads = 8
    num_queries = 100
    dec_n_points = 4
    enc_n_points = 4
    aux_loss = True
    set_cost_class = 2
    set_cost_bbox = 5
    set_cost_giou = 2
    cls_loss_coef = 2
    bbox_loss_coef = 5
    giou_loss_coef = 2
    focal_alpha = 0.25
    coco_panoptic_path = None
    remove_difficult = False
    # Forwarded to engine.evaluate() as its output directory argument.
    output_dir = ''
    # Turned into a torch.device; the model and each batch are moved onto it.
    device = 'cuda'
    seed = 42
    # Not read by any of the cells shown in this notebook (`pretrain` is the
    # checkpoint that actually gets loaded).
    resume = './exps/MOWODB/PROB/t1/checkpoint0040.pth'
    # Overwritten with checkpoint['epoch'] + 1 when `pretrain` is loaded.
    start_epoch = 0
    # When True, the checkpoint-loading cell also runs engine.evaluate().
    eval = False
    viz = False
    eval_every = 5
    # Worker-process count of the validation DataLoader.
    num_workers = 3
    cache_mode = False
    PREV_INTRODUCED_CLS = 0
    # Same count as the 10 entries of T1_CLASS_NAMES — presumably the number
    # of classes introduced by task 1; confirm against the training config.
    CUR_INTRODUCED_CLS = 10
    unmatched_boxes = False
    top_unk = 5
    featdim = 1024
    invalid_cls_logits = False
    NC_branch = False
    bbox_thresh = 0.3
    # Checkpoint read with torch.load(); its 'model' entry is loaded into the
    # model with strict=False.
    pretrain = './exps/MOWODB/PROB/t1/checkpoint0040.pth'
    nc_loss_coef = 2
    # Image-set names given to OWDetection(image_set=...) and to
    # make_coco_transforms(), which picks its pipeline by substring match
    # ('train', 'ft', 'val', 'test') — 'all_eval' therefore matches 'val'.
    train_set = 'task1_train'
    test_set = 'all_eval'
    # Equals len(VOC_COCO_CLASS_NAMES["fathomnet"]): 20 named classes plus
    # "unknown".
    num_classes = 21
    nc_epoch = 0
    # Dataset key passed to OWDetection(dataset=...).
    dataset = 'fathomnet'
    # Absolute, machine-specific path; adjust to the local checkout.
    data_root = '/home/sabrina/code/PROB/data/OWOD'
    unk_conf_w = 1.0
    # Passed to build_model() as its `mode` argument.
    model_type = 'prob'
    wandb_name = ''
    wandb_project = 'fathomnet'
    obj_loss_coef = 1
    obj_temp = 1
    freeze_prob_model = False
    num_inst_per_class = 50
    exemplar_replay_selection = False
    exemplar_replay_max_length = 1e10
    exemplar_replay_dir = ''
    exemplar_replay_prev_file = ''
    exemplar_replay_cur_file = ''
    exemplar_replay_random = False

args = Args()

In [None]:
import torch
from models import build_model

model, criterion, postprocessors, exemplar_selection = build_model(args, mode=args.model_type)

In [None]:
# Device named by the config ('cuda' by default in Args).
device = torch.device(args.device)

# Move the model's parameters and buffers onto that device.
model.to(device)

# No DistributedDataParallel wrapper is applied in this notebook, so the
# "unwrapped" model is the model itself; the alias keeps the names used by the
# optimizer and checkpoint cells.
model_without_ddp = model

def match_name_keywords(n, name_keywords):
    """Tell whether the name `n` contains any of the given keywords.

    Args:
        n: Name to inspect (here: a parameter name).
        name_keywords: Iterable of substrings to look for.

    Returns:
        True as soon as one keyword occurs inside `n`; False when none does,
        including when `name_keywords` is empty.
    """
    return any(keyword in n for keyword in name_keywords)

# Sort the trainable parameters into three learning-rate groups in one pass
# over the model: backbone, linear-projection, and everything else.
default_params, backbone_params, linear_proj_params = [], [], []
for param_name, param in model_without_ddp.named_parameters():
    if not param.requires_grad:
        continue
    in_backbone = match_name_keywords(param_name, args.lr_backbone_names)
    in_linear_proj = match_name_keywords(param_name, args.lr_linear_proj_names)
    # The two tests are independent on purpose: a name matching both keyword
    # lists is placed in both groups.
    if in_backbone:
        backbone_params.append(param)
    if in_linear_proj:
        linear_proj_params.append(param)
    if not (in_backbone or in_linear_proj):
        default_params.append(param)

param_dicts = [
    {"params": default_params, "lr": args.lr},
    {"params": backbone_params, "lr": args.lr_backbone},
    {"params": linear_proj_params, "lr": args.lr * args.lr_linear_proj_mult},
]

# args.sgd switches between SGD (momentum 0.9) and AdamW; both share the base
# learning rate and weight decay.
optimizer = (
    torch.optim.SGD(param_dicts, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
    if args.sgd
    else torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay)
)

In [None]:
import engine

In [None]:
from datasets.torchvision_datasets.open_world import OWDetection

In [None]:
import datasets.transforms as T

def make_coco_transforms(image_set):
    """Build the transform pipeline for an image-set name.

    The pipeline is chosen by substring match on `image_set`, testing 'train',
    'ft', 'val' and 'test' in that order and stopping at the first hit (so a
    name such as 'all_eval' selects the 'val' pipeline).

    Args:
        image_set: Name of the image set, e.g. 'task1_train'.

    Returns:
        A two-element list: `[[tag], pipeline]`, where `tag` is the matched
        substring and `pipeline` is a `T.Compose`.

    Raises:
        ValueError: If none of the four tags occurs in `image_set`.
    """
    to_normalized_tensor = T.Compose([
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # 480, 512, ..., 800
    resize_scales = list(range(480, 801, 32))

    def augmenting_pipeline():
        # Shared by the 'train' and 'ft' sets: flip, then either a plain
        # resize or a resize/crop/resize chain, then normalisation.
        return T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomSelect(
                T.RandomResize(resize_scales, max_size=1333),
                T.Compose([
                    T.RandomResize([400, 500, 600]),
                    T.RandomSizeCrop(384, 600),
                    T.RandomResize(resize_scales, max_size=1333),
                ])
            ),
            to_normalized_tensor,
        ])

    def fixed_size_pipeline():
        # Shared by the 'val' and 'test' sets: single-scale resize, then
        # normalisation.
        return T.Compose([
            T.RandomResize([800], max_size=1333),
            to_normalized_tensor,
        ])

    # Order matters: the first tag found inside `image_set` wins.
    pipeline_by_tag = (
        ('train', augmenting_pipeline),
        ('ft', augmenting_pipeline),
        ('val', fixed_size_pipeline),
        ('test', fixed_size_pipeline),
    )
    for tag, build_pipeline in pipeline_by_tag:
        if tag in image_set:
            return [[tag], build_pipeline()]

    raise ValueError(f'unknown {image_set}')

In [None]:
def get_datasets(args):
    """Create the training and validation `OWDetection` datasets.

    Prints the dataset key, both image-set names and both dataset objects as
    a side effect.

    Args:
        args: Config object providing `dataset`, `data_root`, `train_set` and
            `test_set`; it is also passed through to `OWDetection`.

    Returns:
        Tuple `(dataset_train, dataset_val)`.
    """
    print(args.dataset)

    train_split, eval_split = args.train_set, args.test_set

    dataset_train = OWDetection(
        args,
        args.data_root,
        image_set=train_split,
        transforms=make_coco_transforms(train_split),
        dataset=args.dataset,
    )
    dataset_val = OWDetection(
        args,
        args.data_root,
        image_set=eval_split,
        dataset=args.dataset,
        transforms=make_coco_transforms(eval_split),
    )

    for shown in (train_split, eval_split, dataset_train, dataset_val):
        print(shown)

    return dataset_train, dataset_val

In [None]:
OWDetection(args, args.data_root, image_set=args.test_set, dataset = args.dataset)


In [None]:
dataset_train, dataset_val = get_datasets(args)

In [None]:
sampler_train = torch.utils.data.RandomSampler(dataset_train)
sampler_val = torch.utils.data.SequentialSampler(dataset_val)

In [None]:
from typing import Optional, List
from torch import Tensor

from util.misc import NestedTensor

def collate_fn(batch):
    """Collate a list of samples into per-field tuples for the DataLoader.

    The samples are transposed so that each field becomes one tuple; the first
    field (the images) is then packed by `nested_tensor_from_tensor_list`.

    Args:
        batch: Sequence of samples, each a sequence whose first item is an
            image tensor.

    Returns:
        Tuple whose first element is the packed images and whose remaining
        elements are tuples holding the other fields of every sample.
    """
    fields = [*zip(*batch)]
    fields[0] = nested_tensor_from_tensor_list(fields[0])
    return tuple(fields)


def _max_by_axis(the_list):
    """Return the element-wise maximum over a list of equally long lists.

    The result is a new list; the input sublists are left untouched (the
    previous implementation accumulated the maxima inside `the_list[0]`,
    silently modifying the caller's data).

    Args:
        the_list: Non-empty sequence of sequences of comparable items, e.g.
            one `[C, H, W]` shape per image.

    Returns:
        List where position `i` holds the largest `i`-th item found in any
        sublist.

    Raises:
        IndexError: If `the_list` is empty, or a later sublist is longer than
            the first one.
    """
    # Copy the first sublist so the running maxima never alias caller data.
    maxes = list(the_list[0])
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes


def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    """Pad a list of 3-D tensors to a common size and pack them with a mask.

    Each tensor is copied into the top-left corner of a zero-filled slot whose
    size is the per-axis maximum over the list. The boolean mask is False
    where a slot holds copied data and True on the padding.

    Args:
        tensor_list: Non-empty sequence of tensors; the first one must have
            exactly three dimensions.

    Returns:
        `NestedTensor(tensor, mask)` with `tensor` of shape
        `[len(tensor_list), *max_size]` and `mask` of shape `[b, h, w]`.

    Raises:
        ValueError: If the first tensor is not 3-dimensional.
    """
    reference = tensor_list[0]
    # TODO make this more general
    if reference.ndim != 3:
        raise ValueError('not supported')

    # TODO make it support different-sized images
    largest_shape = _max_by_axis([list(img.shape) for img in tensor_list])
    batch_shape = [len(tensor_list)] + largest_shape
    b, _, h, w = batch_shape

    padded = torch.zeros(batch_shape, dtype=reference.dtype, device=reference.device)
    padding_mask = torch.ones((b, h, w), dtype=torch.bool, device=reference.device)

    for slot, img in enumerate(tensor_list):
        # Copy the image into its slot and mark the covered area as valid.
        padded[slot][: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
        padding_mask[slot][: img.shape[1], : img.shape[2]] = False

    return NestedTensor(padded, padding_mask)

# class NestedTensor(object):
#     def __init__(self, tensors, mask: Optional[Tensor]):
#         self.tensors = tensors
#         self.mask = mask

#     def to(self, device, non_blocking=False):
#         cast_tensor = self.tensors.to(device, non_blocking=non_blocking)
#         mask = self.mask
#         if mask is not None:
#             assert mask is not None
#             cast_mask = mask.to(device, non_blocking=non_blocking)
#         else:
#             cast_mask = None
#         return NestedTensor(cast_tensor, cast_mask)

#     def record_stream(self, *args, **kwargs):
#         self.tensors.record_stream(*args, **kwargs)
#         if self.mask is not None:
#             self.mask.record_stream(*args, **kwargs)

#     def decompose(self):
#         return self.tensors, self.mask

#     def __repr__(self):
#         return str(self.tensors)

In [None]:
from torch.utils.data import DataLoader

# Validation loader: visits dataset_val in the order given by sampler_val,
# keeps the final partial batch (drop_last=False) and packs each batch with
# the notebook's own collate_fn.
data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
                                drop_last=False, collate_fn=collate_fn, num_workers=args.num_workers,
                                pin_memory=True)

In [None]:
# Dataset handed to the evaluator as its reference.
base_ds = dataset_val

if args.pretrain:
    print('Initialized from the pre-training model')
    # Load on the CPU first; load_state_dict copies the values into the
    # model's own (already placed) parameters.
    checkpoint = torch.load(args.pretrain, map_location='cpu')
    state_dict = checkpoint['model']
    # strict=False tolerates missing/unexpected keys; they are reported in
    # `msg`, which is printed so mismatches do not go unnoticed.
    msg = model_without_ddp.load_state_dict(state_dict, strict=False)
    print(msg)
    # Continue counting epochs from the one after the checkpoint's.
    args.start_epoch = checkpoint['epoch'] + 1
    if args.eval:
        test_stats, coco_evaluator = engine.evaluate(model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir, args)
        #return

In [None]:
from datasets import open_world_eval

In [None]:
# Put the model and the criterion into evaluation mode before inference.
model.eval()
criterion.eval()
# metric_logger = utils.MetricLogger(delimiter="  ")
header = 'Test:'
# Keep only the IoU types for which a postprocessor exists.
iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys())
coco_evaluator = open_world_eval.OWEvaluator(base_ds, iou_types, args=args)

In [None]:
from util.misc import NestedTensor, is_main_process

In [None]:
data_loader_val.batch_size

In [None]:
all_results = []

In [None]:
# Run the model over the whole validation loader and keep the post-processed
# predictions of every batch in `all_results` (one list entry per batch).
# Inference only: without no_grad() every forward pass records an autograd
# graph, which wastes GPU memory on tensors that are never back-propagated.
with torch.no_grad():
    for samples, targets in data_loader_val:
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        outputs = model(samples)
        # The bbox postprocessor is given each image's original size.
        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        results = postprocessors['bbox'](outputs, orig_target_sizes)
        all_results.append(results)

In [None]:
# # import pickle
  
# # # Open a file and use dump()
# # with open('results.pkl', 'wb') as file:
      
# #     # A new file will be created
# #     pickle.dump(results, file)

# import pickle
  
# # Open the file in binary mode
# with open('results.pkl', 'rb') as file:
      
#     # Call load method to deserialize
#     results = pickle.load(file)

In [None]:
# Class index -> class name, numbered from 0 in the dataset's CLASS_NAMES order.
label_map = dict(enumerate(data_loader_val.dataset.CLASS_NAMES))

label_map

In [None]:
targets[1]

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

def plot_image_with_boxes(target, image_dir, label_map):
    """Show one image with its annotated boxes and class names drawn on top.

    A later cell of this notebook defines a function with the same name,
    which replaces this one once that cell has run.

    Args:
        target: Mapping with tensor entries 'image_id', 'size', 'boxes' and
            'labels' (e.g. one element of the `targets` list of a batch).
        image_dir: Directory containing `<image_id>.png`.
        label_map: Mapping from integer class index to display name.
    """
    # Get the image id and convert it to int
    image_id = int(target['image_id'].cpu().item())

    # Construct the image file path
    image_path = f'{image_dir}/{image_id}.png'  # adjust the file extension if needed

    # Load the image
    img = Image.open(image_path)

    # Create figure and axes
    fig, ax = plt.subplots(1)

    # Display the image
    ax.imshow(img)

    # Get the image size; indexed below as (height, width)
    img_size = target['size'].cpu().numpy()

    # Get the boxes
    boxes = target['boxes'].cpu().numpy()

    # Get the labels (previously never read, so the label lookup below
    # failed with a NameError)
    labels = target['labels'].cpu().numpy()

    # For each box
    for i, box in enumerate(boxes):
        # Rescale the box from fractions of the image to pixels
        box = box * [img_size[1], img_size[0], img_size[1], img_size[0]]

        # Create a Rectangle patch
        # NOTE(review): box[0:2] is used as the rectangle's corner and
        # box[2:4] as width/height — confirm the targets are not stored as
        # (cx, cy, w, h), in which case the corner needs a half-size shift.
        rect = patches.Rectangle((box[0], box[1]), box[2], box[3], linewidth=1, edgecolor='r', facecolor='none')

        # Add the patch to the Axes
        ax.add_patch(rect)

        # Get the label name
        label_name = label_map[labels[i]]

        # Add the label name
        plt.text(box[0], box[1], label_name, fontsize=10, color='white', bbox=dict(facecolor='red', alpha=0.5))

    plt.show()


In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

def plot_image_with_boxes(data, image_dir, label_map):
    """Display an image from disk with its boxes and class names overlaid.

    Prints the image id, opens `<image_dir>/<image_id>.png`, and draws one red
    rectangle plus a text tag per box on a 15x15 inch figure without axes.

    Args:
        data: Mapping with tensor entries 'image_id', 'size', 'boxes' and
            'labels'.
        image_dir: Directory that holds the PNG files.
        label_map: Mapping from integer class index to display name.
    """
    image_id = int(data['image_id'].cpu().item())
    print(image_id)

    # adjust the file extension if needed
    picture = Image.open(f'{image_dir}/{image_id}.png')

    fig, ax = plt.subplots(1, figsize=(15,15))
    ax.axis('off')
    ax.imshow(picture)

    # 'size' is indexed as (height, width); boxes are multiplied by
    # (width, height, width, height) to turn fractions into pixels.
    # NOTE(review): 'size' may describe the transformed image rather than the
    # file opened above — confirm the two agree, otherwise boxes are offset.
    size_hw = data['size'].cpu().numpy()
    to_pixels = [size_hw[1], size_hw[0], size_hw[1], size_hw[0]]

    box_array = data['boxes'].cpu().numpy()
    class_ids = data['labels'].cpu().numpy()

    for position, fractional_box in enumerate(box_array):
        x, y, width, height = fractional_box * to_pixels

        # NOTE(review): (x, y) is used as the rectangle corner — confirm the
        # boxes are not centre-based (cx, cy, w, h).
        ax.add_patch(
            patches.Rectangle((x, y), width, height, linewidth=1, edgecolor='r', facecolor='none')
        )

        plt.text(
            x,
            y,
            label_map[class_ids[position]],
            fontsize=10,
            color='white',
            bbox=dict(facecolor='red', alpha=0.5),
        )

    plt.show()


In [None]:
targets

In [None]:
plot_image_with_boxes(targets[1], 'data/OWOD/JPEGImages', label_map)

In [None]:
results[1].keys()

In [None]:
# import matplotlib.pyplot as plt
# import matplotlib.patches as patches
# from PIL import Image
# import numpy as np

# def draw_boxes(image_path, boxes, labels, label_map, scores, threshold=0.5):
#     """
#     Draw bounding boxes on an image.
    
#     Args:
#     image_path (str): Path to the image.
#     boxes (tensor): Bounding boxes tensor.
#     labels (tensor): Labels tensor.
#     scores (tensor): Scores tensor.
#     threshold (float): Score threshold for displaying bounding boxes.
#     """
#     # Move tensors to CPU and convert to numpy
#     boxes = boxes.cpu().numpy()
#     labels = labels.cpu().numpy()
#     scores = scores.cpu().numpy()

#     # Open the image
#     im = np.array(Image.open(image_path), dtype=np.uint8)

#     # Create figure and axes
#     fig, ax = plt.subplots(1, figsize=(15,15))

#     # Display the image
#     ax.imshow(im)

#     # Iterate through the boxes
#     for box, label, score in zip(boxes, labels, scores):
#         if score > threshold:
#             # Create a Rectangle patch
#             rect = patches.Rectangle((box[0],box[1]),box[2]-box[0],box[3]-box[1],
#                                      linewidth=1,edgecolor='r',facecolor='none')

#             # Add the patch to the Axes
#             ax.add_patch(rect)

#             # Get the label name
#             label_name = label_map[labels[label]]


#             # Add label and score text
#             plt.text(box[0], box[1], f'{label_name}: {score:.2f}', 
#                      color='white', fontsize=10,
#                      bbox=dict(facecolor='red', alpha=0.5))

#     plt.show()


In [None]:
def draw_boxes(image_path, boxes, labels, label_map, scores, threshold=0.5, background_color='#080e26'):
    """
    Draw predicted bounding boxes on an image loaded from disk.

    Only detections whose score is strictly greater than `threshold` are
    drawn. Boxes are read as (x_min, y_min, x_max, y_max) in the pixel
    coordinates of the image file.

    Args:
    image_path (str): Path to the image.
    boxes (tensor): Bounding boxes tensor.
    labels (tensor): Labels tensor (one class index per box).
    label_map (dict): Mapping from integer class index to display name.
    scores (tensor): Scores tensor.
    threshold (float): Score threshold for displaying bounding boxes.
    background_color (str): Background color of the plot.

    Raises:
    KeyError: If a drawn detection's class index is missing from `label_map`.
    """
    # Move tensors to CPU and convert to numpy
    boxes = boxes.cpu().numpy()
    labels = labels.cpu().numpy()
    scores = scores.cpu().numpy()

    # Open the image
    im = np.array(Image.open(image_path), dtype=np.uint8)

    # Create figure and axes
    fig, ax = plt.subplots(1, figsize=(15,15))

    # Set background color and remove axes
    ax.set_facecolor(background_color)
    ax.axis('off')

    # Display the image
    ax.imshow(im)

    # Iterate through the boxes
    for box, label, score in zip(boxes, labels, scores):
        if score <= threshold:
            continue

        # Create a Rectangle patch
        rect = patches.Rectangle((box[0],box[1]),box[2]-box[0],box[3]-box[1],
                                 linewidth=1,edgecolor='#9413C1',facecolor='none')

        # Add the patch to the Axes
        ax.add_patch(rect)

        # `label` already is this detection's class index, so look it up
        # directly. (Indexing `labels` with it, as before, returned the class
        # of an unrelated detection.)
        label_name = label_map[int(label)]

        # Add label and score text
        plt.text(box[0], box[1], f'{label_name}: {score:.2f}',
                 color='white', fontsize=10,
                 bbox=dict(facecolor='#9413C1', alpha=0.5))

    plt.show()


In [None]:
# Pick one prediction of the most recent batch and the image file it belongs to.
target_index = 1
image_dir = 'data/OWOD/JPEGImages'
image_ext = '.png'
# Derive the file name from the matching target instead of hard-coding an id,
# the same way plot_image_with_boxes() does, so that the picture and the
# predictions stay in sync when target_index changes.
# Assumes results[i] and targets[i] describe the same image of the last
# batch — TODO confirm.
image_id = int(targets[target_index]['image_id'].cpu().item())
image_path = f'{image_dir}/{image_id}{image_ext}'
prediction = results[target_index]
boxes = prediction['boxes']
labels = prediction['labels']
scores = prediction['scores']

In [None]:
labels

In [None]:
draw_boxes(image_path, boxes, labels, label_map, scores, threshold=0.7)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np

def draw_boxes(image_tensor, boxes, labels, scores, threshold=0.5):
    """
    Plot an image tensor and overlay the detections scoring above a threshold.

    This definition takes the image as a tensor and prints the raw label
    value; it replaces the earlier `draw_boxes` that takes an image path and
    a label map.

    Args:
    image_tensor (tensor): Image tensor, permuted with (1, 2, 0) before being
        shown (i.e. channel axis first on input).
    boxes (tensor): Bounding boxes tensor, read as (x_min, y_min, x_max, y_max).
    labels (tensor): Labels tensor.
    scores (tensor): Scores tensor.
    threshold (float): Only detections with score > threshold are drawn.
    """
    # NOTE(review): if the tensor went through the dataset's normalisation,
    # imshow will clip its values — confirm the colours are meaningful.
    picture = image_tensor.permute(1, 2, 0).cpu().numpy()
    box_array = boxes.cpu().numpy()
    label_array = labels.cpu().numpy()
    score_array = scores.cpu().numpy()

    fig, ax = plt.subplots(1)
    ax.imshow(picture)

    for corners, class_id, confidence in zip(box_array, label_array, score_array):
        if not confidence > threshold:
            continue

        left, top = corners[0], corners[1]
        outline = patches.Rectangle(
            (left, top),
            corners[2] - corners[0],
            corners[3] - corners[1],
            linewidth=1,
            edgecolor='r',
            facecolor='none',
        )
        ax.add_patch(outline)

        plt.text(
            left,
            top,
            f'{class_id}: {confidence:.2f}',
            color='white',
            fontsize=10,
            bbox=dict(facecolor='red', alpha=0.2),
        )

    plt.show()


In [None]:
data_loader_val.dataset.images

In [None]:
data_loader_val.dataset.CLASS_NAMES

In [None]:
data_loader_val.dataset

In [None]:
# Pair one dataset sample with the predictions made for it.
target_index = 15
image = data_loader_val.dataset[target_index][0]
# `results` only holds the final batch (at most batch_size entries), so a
# dataset-wide index must be looked up in the predictions of all batches.
# The loader uses a sequential sampler, so flattening `all_results` in batch
# order lines the predictions up with the dataset indices.
flat_results = [prediction for batch_results in all_results for prediction in batch_results]
prediction = flat_results[target_index]
# NOTE(review): the postprocessor was given the original image sizes, while
# `image` is the transformed tensor — confirm the boxes line up when drawn.
boxes = prediction['boxes']
labels = prediction['labels']
scores = prediction['scores']

In [None]:
image.shape

In [None]:
draw_boxes(image, boxes, labels, scores, threshold=0.5)


## Update XML files
Also in script `update_xml.py` - still needs a main.

### Add missing tags to XML

In [None]:
import os
import re

def replace_in_file(file_path, pattern, replacement):
    """Apply a regex substitution to a file's text, rewriting it in place.

    Args:
        file_path: Path of an existing text file (opened in 'r+' mode).
        pattern: Regular expression passed to `re.sub`.
        replacement: Replacement string (or callable) passed to `re.sub`.
    """
    with open(file_path, 'r+') as handle:
        updated_text = re.sub(pattern, replacement, handle.read())
        # Rewind, overwrite, then cut off any leftover tail in case the new
        # text is shorter than the old one.
        handle.seek(0)
        handle.write(updated_text)
        handle.truncate()

def replace_in_all_files(directory, pattern, replacement):
    """Run `replace_in_file` on every file below `directory`, recursively.

    Every file found by `os.walk` is processed, whatever its extension.

    Args:
        directory: Root folder to walk.
        pattern: Regular expression to search for.
        replacement: Text substituted for each match.
    """
    for current_dir, _subdirs, file_names in os.walk(directory):
        for file_name in file_names:
            replace_in_file(os.path.join(current_dir, file_name), pattern, replacement)

directory = "./data/OWOD/Annotations/"
pattern = "/name>\n        <bndbox>"
replacement = "/name>\n        <truncated>0</truncated>\n        <difficult>0</difficult>\n        <bndbox>"


In [None]:
# replace_in_all_files(directory, pattern, replacement)

### Update annotation files with new path and new name.

In [None]:
import pandas as pd
import os
import xml.etree.ElementTree as ET
import logging
from tqdm import tqdm

# CSV with 'old_name' and 'new_name' columns, read by main().
dict_csv = "./data/OWOD/filename_map.csv"
# Folder whose *.xml annotation files are rewritten in place.
xml_dir = "./data/OWOD/Annotations/"
# Prefix written in front of the new file name in every <path> element.
# NOTE(review): the setup notes store the images under JPEGImages/, not
# ImageSets/ — confirm this prefix is the intended one.
new_path_prefix = "/home/sabrina/code/PROB/data/OWOD/ImageSets/"

def main(dict_csv, xml_dir, new_path_prefix):
    """Rename the image referenced by every annotation XML in a folder.

    For each `*.xml` file in `xml_dir`, the text of its `filename` element(s)
    is replaced by the mapped new name (keeping the extension), and every
    `path` element is set to `new_path_prefix` + new file name. Files are
    rewritten in place. Problems are logged to `xml_update.log` and the
    affected file is left untouched; processing continues with the next file.

    Args:
        dict_csv: CSV file with the columns 'old_name' and 'new_name' (names
            without extension).
        xml_dir: Directory containing the annotation files.
        new_path_prefix: String prepended to the new file name to form the
            `path` value.
    """
    logging.basicConfig(filename='xml_update.log', level=logging.INFO)

    # Read every column as text: XML element text is always a string, so
    # numeric-looking names must not be parsed into ints (that would make
    # every lookup miss and would drop leading zeros from the new names).
    df = pd.read_csv(dict_csv, dtype=str)
    name_dict = df.set_index('old_name')['new_name'].to_dict()

    xml_files = [f for f in os.listdir(xml_dir) if f.endswith('.xml')]

    for xml_file in tqdm(xml_files, desc="Updating XML files"):
        xml_path = os.path.join(xml_dir, xml_file)
        try:
            tree = ET.parse(xml_path)
        except ET.ParseError as e:
            logging.error(f"Error processing XML file {xml_file}: {e}")
            continue
        root = tree.getroot()

        # Resolve the new name first and per file, so <path> never receives a
        # value left over from a previous element or a previous file, and the
        # order of <filename> and <path> in the document does not matter.
        new_path = None
        for elem in root.iter('filename'):
            # splitext keeps dots inside the stem intact ("a.b.png" -> "a.b").
            old_filename, file_extension = os.path.splitext(elem.text or '')
            if old_filename not in name_dict:
                logging.error(
                    f"Error processing XML file {xml_file}: "
                    f"no new name for {old_filename!r}"
                )
                continue

            new_filename = f"{name_dict[old_filename]}{file_extension}"
            elem.text = new_filename
            new_path = f"{new_path_prefix}{new_filename}"

        if new_path is None:
            # Nothing could be mapped: leave the file exactly as it was.
            continue

        for elem in root.iter('path'):
            elem.text = new_path

        tree.write(xml_path)


In [None]:
# main(dict_csv, xml_dir, new_path_prefix)