In [1]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# The tokenizer and the sequence-classification model are both loaded from the
# same Hugging Face checkpoint, so the identifier is named once and reused.
MODEL_CHECKPOINT = "DAMO-NLP-SG/zero-shot-classify-SSTuning-XLM-R"
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_CHECKPOINT)

In [10]:
def classify_text(text, candidate_labels):
    """Score ``text`` against ``candidate_labels`` with the zero-shot model.

    The candidate labels are embedded in the model input as lettered options
    ("(A) first. (B) second. ..."), padded with the tokenizer's pad token up to
    the number of outputs of the classification head, followed by the separator
    token and the text. This follows the usage example on the checkpoint's
    model card; without it the logits cannot depend on the labels at all.

    Args:
        text: The document to classify.
        candidate_labels: Iterable of label strings. Its order determines which
            output slot each label is assigned to.

    Returns:
        dict mapping each candidate label to its probability (float). The
        probabilities are normalised over the candidate labels only. An empty
        ``candidate_labels`` yields an empty dict.

    Raises:
        ValueError: If more labels are supplied than the model has output slots.
    """
    labels = list(candidate_labels)
    if not labels:
        # Nothing to score against; keep returning an empty mapping as before.
        return {}

    # The classification head has a fixed number of option slots.
    num_slots = model.config.num_labels
    if len(labels) > num_slots:
        raise ValueError(
            f"classify_text supports at most {num_slots} candidate labels, "
            f"got {len(labels)}"
        )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device).eval()

    # Options are written as short sentences ending in a period; unused slots
    # are filled with the pad token so the prompt always has `num_slots` options.
    options = [label if label.endswith('.') else label + '.' for label in labels]
    options += [tokenizer.pad_token] * (num_slots - len(options))
    lettered = []
    for index, option in enumerate(options):
        letter = chr(ord('A') + index)
        lettered.append(f'({letter}) {option}')
    prompt = f"{' '.join(lettered)} {tokenizer.sep_token} {text}"

    # The options come first, so truncation trims the tail of the text rather
    # than the label list.
    encoded_input = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=512)
    encoded_input = {key: val.to(device) for key, val in encoded_input.items()}

    # Perform inference
    with torch.no_grad():
        logits = model(**encoded_input).logits
        # Only the first len(labels) slots correspond to real labels; drop the
        # padded slots before normalising so the scores sum to 1.
        scores = torch.nn.functional.softmax(logits[:, :len(labels)], dim=-1)

    return {label: float(score) for label, score in zip(labels, scores[0])}

def filter_files(file_paths, relevant_label='relevant football article',
                 threshold=0.1, candidate_labels=None):
    """Return the paths whose text scores above ``threshold`` for ``relevant_label``.

    Each file is read as UTF-8 text and passed to ``classify_text``; the score
    of ``relevant_label`` is printed and compared against ``threshold``.

    Args:
        file_paths: Iterable of paths to text files.
        relevant_label: The label whose score decides whether a file is kept.
        threshold: A file is kept when its score is strictly greater than this
            value. Defaults to 0.1.
        candidate_labels: Labels to classify against. When omitted, the pair
            ``[relevant_label, "irrelevant content"]`` is used, so a custom
            ``relevant_label`` works without further arguments.

    Returns:
        list of the paths from ``file_paths`` that passed the threshold, in
        input order.

    Raises:
        ValueError: If ``candidate_labels`` is given but does not contain
            ``relevant_label``.
    """
    if candidate_labels is None:
        candidate_labels = [relevant_label, "irrelevant content"]
    elif relevant_label not in candidate_labels:
        # Fail before reading any file instead of with a KeyError mid-loop.
        raise ValueError(
            f"relevant_label {relevant_label!r} is not one of the candidate labels"
        )

    relevant_files = []
    for file_path in file_paths:
        with open(file_path, 'r', encoding='utf-8') as file:
            text = file.read()

        score = classify_text(text, candidate_labels)[relevant_label]
        print(score)
        if score > threshold:
            relevant_files.append(file_path)

    return relevant_files


In [None]:
import os

def get_all_file_paths(directory):
    """Recursively collect the path of every file found under ``directory``.

    Paths are built by joining each directory visited by ``os.walk`` with the
    file names it contains, and are returned in the order ``os.walk`` yields
    them.
    """
    return [
        os.path.join(parent, name)
        for parent, _subdirs, names in os.walk(directory)
        for name in names
    ]

In [7]:
# Example usage: collect the path of every file under the raw-data directory.
# The directory is given as a relative path, so it is resolved against the
# current working directory.
directory_path = '../../data/raw'
all_file_paths = get_all_file_paths(directory_path)

['../../data/raw/gazzetta_it/turismo-sportivo_eventi.txt', '../../data/raw/gazzetta_it/Calcio_Europei_26-03-2024_europei-galles-polonia-ucraina-islanda-georgia-grecia-per-gli-ultimi-3-posti.shtml.txt', '../../data/raw/gazzetta_it/Calcio_calcio-femminile.txt']


In [12]:
# Example usage: classify only the first 100 collected paths, then print the
# ones that filter_files kept.
files_to_check = all_file_paths[:100]
filtered_files = filter_files(files_to_check)
print("Filtered files:", filtered_files)