In [29]:
import torch

def apply_binary_threshold(tensor, threshold):
    """
    Apply a binary threshold to a tensor.

    Elements greater than or equal to their threshold become 1.0, all others 0.0.

    Args:
    - tensor (torch.Tensor): Input tensor to apply threshold to.
    - threshold (int, float, or list/tuple of numbers): Threshold value(s).
      A single number applies the same threshold to every element.
      A list/tuple supplies one threshold per position along the tensor's LAST
      dimension and is broadcast over any leading (e.g. batch) dimensions; for a
      1-D tensor this means one threshold per element.

    Returns:
    - torch.Tensor: float32 tensor of 0.0/1.0 values, same shape and device as `tensor`.

    Raises:
    - ValueError: if a list/tuple threshold's length does not match the size of the
      tensor's last dimension (or the tensor is 0-dimensional).
    - TypeError: if `threshold` is neither a number nor a list/tuple.
    """
    # bool is a subclass of int; reject it so True/False are not silently used as 1/0.
    if isinstance(threshold, (int, float)) and not isinstance(threshold, bool):
        return (tensor >= threshold).float()

    if isinstance(threshold, (list, tuple)):
        if tensor.dim() == 0 or len(threshold) != tensor.size(-1):
            raise ValueError("Number of thresholds must match the size of the tensor's last dimension.")
        # Cast thresholds to the input's floating dtype so the comparison runs at the
        # same precision as comparing against Python scalars; for integer inputs let
        # torch infer the dtype (floats -> default float dtype, ints -> int64).
        threshold_dtype = tensor.dtype if tensor.is_floating_point() else None
        # Keep thresholds on the input's device so the result stays there too.
        threshold_tensor = torch.as_tensor(threshold, dtype=threshold_dtype, device=tensor.device)
        return (tensor >= threshold_tensor).float()

    raise TypeError("Threshold must be either a number or a list/tuple of numbers.")

# Example usage:
# `predictions` below is a synthetic 12-element score vector standing in for one
# sample's 12 AU scores (like predictions[2][0] in the inference cell).
torch.manual_seed(0)  # fixed seed so the printed example is reproducible on re-run
predictions = torch.rand(12)  # Example tensor of random values in [0, 1)
print("Original predictions:")
print(predictions)

# Applying a single threshold to all elements
threshold = 0.5
binary_predictions_single_threshold = apply_binary_threshold(predictions, threshold)
print("\nBinary predictions with single threshold (0.5):")
print(binary_predictions_single_threshold)

# Applying a different threshold to each element (one per score position)
thresholds = [0.44, 0.36, 0.39, 0.34, 0.46, 0.47, 0.38, 0.11, 0.1,  0.49, 0.62, 0.31]
binary_predictions_multiple_thresholds = apply_binary_threshold(predictions, thresholds)
print("\nBinary predictions with multiple thresholds:")
print(binary_predictions_multiple_thresholds)

Original predictions:
tensor([0.9163, 0.4702, 0.0240, 0.3048, 0.3818, 0.9537, 0.6381, 0.5904, 0.1319,
        0.8142, 0.7906, 0.3852])

Binary predictions with single threshold (0.5):
tensor([1., 0., 0., 0., 0., 1., 1., 1., 0., 1., 1., 0.])

Binary predictions with multiple thresholds:
tensor([1., 1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1.])


In [2]:
import torch
from torchvision import transforms
from PIL import Image, UnidentifiedImageError
import os
from networks.DDAM_ABAW import DDAMNet
import numpy as np
from tqdm import tqdm

# Define the image preprocessing function
def preprocess_image(image_path, transform, image_dir="cropped_aligned_test"):
    """Load one image, convert it to RGB, and apply `transform`.

    Args:
        image_path: Path of the image relative to `image_dir`.
        transform: Callable mapping a PIL image to a tensor (e.g. a torchvision Compose).
        image_dir: Directory that `image_path` is resolved against. Defaults to
            "cropped_aligned_test", the directory this notebook has always used.

    Returns:
        The transformed tensor with a leading batch dimension of size 1, or None
        when PIL cannot identify the file as an image (a message is printed).
        Other errors (e.g. a missing file) are not caught and propagate.
    """
    full_path = os.path.join(image_dir, image_path)
    try:
        # The context manager closes the file handle deterministically. convert()
        # returns a new, fully loaded image, so it remains usable after the block.
        with Image.open(full_path) as image:
            rgb_image = image.convert('RGB')
    except UnidentifiedImageError as e:
        print(f"Skipping image {image_path}: {e}")
        return None
    return transform(rgb_image).unsqueeze(0)  # Add batch dimension

# Function to read the text file and return a list of image paths
def read_image_paths(txt_file):
    """Return one entry per non-blank line of `txt_file`.

    Whitespace-only lines are ignored. From every other line, any trailing run of
    commas, spaces and newline characters is removed; leading characters are kept.
    A header line, if the file has one, is returned like any other line.
    """
    paths = []
    with open(txt_file, 'r') as handle:
        for raw_line in handle:
            if not raw_line.strip():
                continue  # skip blank / whitespace-only lines
            paths.append(raw_line.rstrip(', \n'))
    return paths

# Function to write results back to the text file
def write_results(txt_file, results):
    """Write prediction rows to `txt_file` as comma-separated text, overwriting it.

    The first line is the header "image,valence,arousal,expression,aus". Each result
    dict then becomes one line: image, valence and arousal rounded to 2 decimals,
    expression, followed by every entry of result['aus'] cast to int as its own column.
    """
    with open(txt_file, 'w') as out:
        print("image,valence,arousal,expression,aus", file=out)
        for entry in results:
            rounded_valence = round(entry['valence'], 2)
            rounded_arousal = round(entry['arousal'], 2)
            # AU flags are stored as floats upstream; emit them as integers.
            au_columns = ','.join(str(flag) for flag in entry['aus'].astype(int))
            print(entry['image'], rounded_valence, rounded_arousal, entry['expression'], au_columns,
                  sep=',', file=out)
# --- Model setup ---------------------------------------------------------------
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The architecture must match the one the checkpoint was saved from, because
# load_state_dict is strict by default. (Flag names suggest valence/arousal,
# expression and action-unit heads - confirm against networks.DDAM_ABAW.)
model = DDAMNet(
    num_class=8,
    num_head=2,
    pretrained=False,
    train_val_arousal=True,
    train_emotions=True,
    train_actions=True,
)
model_path = 'best_multitask_model_all.pth'
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load trusted checkpoints. Since the loaded object goes straight into
# load_state_dict, weights_only=True (torch >= 1.13) should also work here.
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint)
model = model.to(device)
model.eval()  # evaluation mode for inference (affects dropout / batch-norm layers)

# --- Preprocessing ---------------------------------------------------------------
# Resize to 112x112, convert to a tensor, then normalise each channel with the
# commonly used ImageNet mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize((112, 112)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Read the image paths from the txt file
txt_file = 'MTL.txt'
txt_results = "MTL_results.txt"
image_paths = read_image_paths(txt_file)

# Per-action-unit decision thresholds, one per AU score the model emits.
AU_THRESHOLDS = [0.44, 0.36, 0.39, 0.34, 0.46, 0.47, 0.38, 0.11, 0.1, 0.49, 0.62, 0.31]
# Built once (it is loop-invariant) and kept on the inference device so the
# comparison below runs there without a per-element Python loop.
au_threshold_tensor = torch.tensor(AU_THRESHOLDS, device=device)

results = []

# Process each image and get predictions.
# TODO: batching several images per forward pass would speed this up; the
# recorded run processes one image per call (~18 it/s).
with torch.no_grad():
    # The first entry is skipped (presumably the header line of MTL.txt - confirm).
    for image_path in tqdm(image_paths[1:]):
        image_tensor = preprocess_image(image_path, transform)
        if image_tensor is None:
            # Unreadable image: emit a zero placeholder row so every listed image
            # still appears in the output file.
            results.append({
                'image': image_path,
                'valence': 0,
                'arousal': 0,
                'expression': 0,
                'aus': np.zeros(len(AU_THRESHOLDS)),
            })
            continue

        predictions = model(image_tensor.to(device))

        # Assumes the model returns (valence/arousal of shape (1, 2),
        # expression scores, AU scores of shape (1, 12)) - TODO confirm.
        valence = predictions[0][0][0].item()
        arousal = predictions[0][0][1].item()
        # argmax over the flattened tensor - same result as np.argmax on it,
        # without copying to the CPU first.
        expression = predictions[1].argmax().item()

        # Binary AU flags: 1.0 where score >= its threshold. Thresholds are cast to
        # the score dtype so precision matches comparing against Python floats.
        au_scores = predictions[2][0]
        aus = (au_scores >= au_threshold_tensor.to(au_scores.dtype)).float()

        results.append({
            'image': image_path,
            'valence': valence,
            'arousal': arousal,
            'expression': expression,
            'aus': aus.cpu().numpy(),
        })

# Write the results back to the text file
write_results(txt_results, results)

print("Processing and writing results completed.")






 14%|█▍        | 7403/51159 [06:17<39:32, 18.44it/s]  