<a href="https://colab.research.google.com/github/sogand120/sogandsa.github.io/blob/master/sample_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import logging
import os
import re
import xml.etree.ElementTree as ET

import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AlbertTokenizer, AlbertForSequenceClassification, AdamW

## Configure root logging once so INFO-level messages are emitted, and create
## the module-level logger used by the rest of this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

## BDD related keywords (not yet comprehensive, many more to add).
## Entries may be single words or multi-word phrases.
bdd_keywords = [
    "body image",
    "mirror",
    "hate",
    "BDD",
    "diagnosed with BDD",
    "flaw",
    "acne",
    "dislike",
    "body hatred",
    "body comparison",
    "fat",
    "overweight",
    "scar",
    "gained weight",
    "hate my body",
    "too fat",
    "feel ugly",
]

## parse XML function
def parse_xml(file_path):
    """Read the ``<record>`` elements of one XML file into text/label dicts.

    Every ``<record>`` found anywhere in the document is expected to carry a
    ``<text>`` and a ``<label>`` child. Records where either child is missing
    or blank are skipped with a warning, so one malformed record no longer
    discards the valid records that follow it in the same file.

    Args:
        file_path: Path of the XML file to read.

    Returns:
        A list of ``{'text': str, 'label': str}`` dicts, one per usable
        record, in document order. The list is empty when the file cannot be
        read or parsed; the failure is logged rather than raised.
    """
    # Only the parse itself sits in the try block, so a failure here cannot
    # be confused with a problem in an individual record.
    try:
        root = ET.parse(file_path).getroot()
    except ET.ParseError as e:
        logger.error(f"Error parsing {file_path}: {e}")
        return []
    except Exception as e:
        # Deliberate best-effort: an unreadable file must not stop the caller
        # from loading the remaining files.
        logger.error(f"An unexpected error occurred with {file_path}: {e}")
        return []

    data = []
    records = root.findall('.//record')  # Adjust based on your XML structure
    for position, element in enumerate(records, start=1):
        # findtext() returns None for a missing child and '' for an empty one,
        # so both cases can be rejected without touching `.text` on None.
        text = element.findtext('text')
        label = element.findtext('label')
        if not (text and text.strip()) or not (label and label.strip()):
            logger.warning(
                "Skipping record %d in %s: missing or empty <text>/<label>",
                position,
                file_path,
            )
            continue
        data.append({'text': text, 'label': label})
    return data

# custom Dataset to handle XML files and tokenization
class BDDExtendedDataset(Dataset):
    """Dataset of text/label records loaded from XML files, tokenized on access.

    All records are read eagerly in the constructor via ``parse_xml`` and kept
    in ``self.data``; tokenization happens lazily, one record per
    ``__getitem__`` call.

    Args:
        xml_files: Iterable of XML file paths to load.
        tokenizer: Object exposing ``encode_plus`` (here an ALBERT tokenizer).
        max_len: ``max_length`` handed to the tokenizer for padding/truncation.
    """

    def __init__(self, xml_files, tokenizer, max_len):
        self.tokenizer = tokenizer
        self.max_len = max_len

        ## flatten the records of every XML file into one list, in file order
        self.data = [
            record
            for xml_path in xml_files
            for record in parse_xml(xml_path)
        ]

        logger.info(f"Loaded {len(self.data)} records from XML files.")

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize record ``idx`` and return it as a dict of model inputs.

        Returns:
            A dict with ``input_ids`` and ``attention_mask`` (the tokenizer
            output flattened to 1-D), ``label`` (a ``torch.long`` scalar built
            from ``int(label)``) and the raw ``text``.
        """
        record = self.data[idx]
        text, raw_label = record['text'], record['label']

        ## tokenizer settings: pad/truncate to max_len and return torch tensors
        tokenizer_options = {
            'add_special_tokens': True,
            'max_length': self.max_len,
            'padding': 'max_length',
            'truncation': True,
            'return_attention_mask': True,
            'return_tensors': 'pt',
        }
        encoded = self.tokenizer.encode_plus(text, **tokenizer_options)

        label_tensor = torch.tensor(int(raw_label), dtype=torch.long)

        # return_tensors='pt' yields a leading batch axis; flatten() drops it
        # so the DataLoader can stack the items itself.
        item = {}
        item['input_ids'] = encoded['input_ids'].flatten()
        item['attention_mask'] = encoded['attention_mask'].flatten()
        item['label'] = label_tensor
        item['text'] = text
        return item

## keyword detection function
def detect_keywords(texts, keywords):
    """Report, for each text, whether it contains at least one keyword.

    Matching is case-insensitive and works on whole words or whole phrases:
    ``"fat"`` matches ``"too fat."`` but not ``"father"``, and a multi-word
    keyword such as ``"hate my body"`` matches when its words appear in order,
    separated by any amount of whitespace. Punctuation next to a word does not
    prevent a match.

    Args:
        texts: Iterable of strings to scan. ``None`` or empty entries are
            reported as ``False``.
        keywords: Iterable of keyword strings (single words or phrases).
            Blank entries are ignored.

    Returns:
        A list of booleans, one per entry of ``texts``, in the same order.
    """
    # Turn each keyword into a regex fragment: escape its words and allow any
    # run of whitespace between them so phrases survive line breaks and
    # double spaces.
    fragments = [
        r"\s+".join(re.escape(word) for word in keyword.split())
        for keyword in keywords
        if keyword and keyword.strip()
    ]
    if not fragments:
        return [False for _ in texts]

    # The lookarounds enforce word boundaries on both sides without requiring
    # the keyword itself to start or end with a word character.
    pattern = re.compile(
        r"(?<!\w)(?:" + "|".join(fragments) + r")(?!\w)",
        re.IGNORECASE,
    )
    return [bool(text) and pattern.search(text) is not None for text in texts]

## main function
def main():
    """Fine-tune ALBERT on the XML records, then log keyword hits per text.

    Collects the ``.xml`` files in ``xml_folder``, loads them into a
    ``BDDExtendedDataset``, trains ``albert-base-v2`` (2 labels) for 3 epochs
    with AdamW (lr 2e-5, batch size 16), and finally runs ``detect_keywords``
    over every loaded text, logging the outcome.

    Returns:
        None. Progress and results are reported through ``logger``. The
        function returns early, after logging an error, when the folder does
        not exist or yields no records.
    """
    xml_folder = '/path/to/xml_folder'  ## adjust according to folder
    if not os.path.isdir(xml_folder):
        logger.error("XML folder not found: %s", xml_folder)
        return

    # os.listdir() order is arbitrary; sorting keeps the record order (and the
    # keyword report at the end) reproducible between runs.
    xml_files = sorted(
        os.path.join(xml_folder, file)
        for file in os.listdir(xml_folder)
        if file.endswith('.xml')
    )

    tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
    dataset = BDDExtendedDataset(xml_files, tokenizer, max_len=128)

    # A shuffling DataLoader cannot be built over zero samples, so stop here
    # with a clear message before loading the model.
    if len(dataset) == 0:
        logger.error("No records loaded from %s; nothing to train on.", xml_folder)
        return

    model = AlbertForSequenceClassification.from_pretrained('albert-base-v2', num_labels=2)
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
    optimizer = AdamW(model.parameters(), lr=2e-5)

    ## training Loop
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    for epoch in range(3):
        model.train()
        train_loss = 0.0
        for batch in tqdm(dataloader):
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            # Passing labels makes the model compute the loss itself.
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            train_loss += loss.item()
            loss.backward()
            optimizer.step()

        logger.info(f'Epoch {epoch + 1}, Training Loss: {train_loss / len(dataloader)}')

    logger.info("Performing keyword detection...")

    # Read the raw records directly: iterating the dataset would tokenize
    # every text again only to throw the tensors away.
    texts = [record['text'] for record in dataset.data]
    keyword_results = detect_keywords(texts, bdd_keywords)

    ## output
    for text, found in zip(texts, keyword_results):
        logger.info(f"Text: {text[:50]}...")  # Print the first 50 characters of each text
        logger.info(f"BDD Keywords Found: {found}")

# main() reports through the logger and returns None, so printing its result
# only emitted a stray "None". The guard also keeps training from starting
# when this file is imported rather than run (a notebook cell still runs it).
if __name__ == "__main__":
    main()