# Developing a Model to Analyze Health Vials from Diablo 4 Screenshots

## Environment Setup

In [14]:
%%capture
%pip install torchvision matplotlib

In [15]:
# Imports
import json
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os
import torch
import torch.nn as nn

from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image

## Training
Train a model to classify health vials and estimate their fill levels.  
<br>
### Semi-Supervised Learning
Use semi-supervised learning to create a sub-varient of the model to help with the labeleling process
<br>
#### Load labeled and unlabeled data

In [None]:
# Define the path for the datasets
base_path = os.path.expanduser('~/Desktop/health-vials/')

# Load labeled dataset
with open(os.path.join(base_path, 'batch_annotations_full_size(labeled).json'), 'r') as f:
    labeled_data = json.load(f)

# Load unlabeled dataset
with open(os.path.join(base_path, 'batch_annotations_full_size(unlabeled).json'), 'r') as f:
    unlabeled_data = json.load(f)


In [None]:
# Create Custom Dataset
class VialDataset(Dataset):
    def __init__(self, data, is_labeled=True, transform=None):
        self.data = data
        self.is_labeled = is_labeled
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        image_path = os.path.join(base_path, item['image']) 
        image = Image.open(image_path).convert("RGB")

        try:
            if self.transform:
                image = self.transform(image)
        except Exception as e:
            print(f"Error processing image: {image_path}, Error: {str(e)}")
            return None, None, None


        if self.is_labeled:
            # Return labeled data
            fullness = item['attributes']['fullness']
            barrier = item['attributes']['barrier']
            return image, torch.tensor([fullness], dtype=torch.float32), torch.tensor([int(barrier)], dtype=torch.float32)
        
        else:
            # Return only the image and empty tensors for unlabeled data
            return image, torch.zeros(1, dtype=torch.float32), torch.zeros(1, dtype=torch.float32)

In [None]:
# Define the model
class VialModel(nn.Module):
    def __init__(self):
        super(VialModel, self).__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.Linear(16 * 118 * 118, 128),  # Adjust to match image size
            nn.ReLU()
        )
        self.fullness_head = nn.Linear(128, 1)  # Regression for fullness
        self.barrier_head = nn.Linear(128, 1)  # Classification for barrier

    def forward(self, x):
        x = self.backbone(x)
        fullness = self.fullness_head(x)
        barrier = torch.sigmoid(self.barrier_head(x))
        return fullness, barrier

In [None]:
# Pretrain on labeled data
labeled_dataset = VialDataset(labeled_data, is_labeled=True, transform=transforms.ToTensor())
labeled_loader = DataLoader(labeled_dataset, batch_size=16, shuffle=True)

model = VialModel()

criterion_fullness = nn.MSELoss()  # Loss for regression
criterion_barrier = nn.BCELoss()  # Loss for classification
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop for labeled data
model.train()
for epoch in range(10):  # Adjust epochs
    for images, fullness, barrier in labeled_loader:
        optimizer.zero_grad()
        pred_fullness, pred_barrier = model(images)
        loss_fullness = criterion_fullness(pred_fullness, fullness)
        loss_barrier = criterion_barrier(pred_barrier, barrier)
        loss = loss_fullness + loss_barrier
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}: Loss = {loss.item()}")

Epoch 1: Loss = 16.666851043701172
Epoch 2: Loss = 9.9879789352417
Epoch 3: Loss = 0.7389234304428101
Epoch 4: Loss = 0.4268410801887512
Epoch 5: Loss = 1.014037013053894
Epoch 6: Loss = 1.373262643814087
Epoch 7: Loss = 1.422499656677246
Epoch 8: Loss = 0.8485471606254578
Epoch 9: Loss = 0.40497589111328125
Epoch 10: Loss = 0.667794406414032


In [21]:
# Generate pseudo-labels for unlabeled data
unlabeled_dataset = VialDataset(unlabeled_data, is_labeled=False, transform=transforms.ToTensor())
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=16, shuffle=False)

model.eval()
pseudo_labels = []
with torch.no_grad():
    for images, _, _ in unlabeled_loader:
        pred_fullness, pred_barrier = model(images)
        for i in range(len(images)):
            pseudo_labels.append({
                "image": unlabeled_data[i]['image'],
                "bounding_box": unlabeled_data[i]['bounding_box'],
                "attributes": {
                    "fullness": float(pred_fullness[i].item()),
                    "barrier": bool(pred_barrier[i].item() > 0.5)
                }
            })

# Save pseudo-labeled data
with open('batch_annotations_pseudo_labeled.json', 'w') as f:
    json.dump(pseudo_labels, f, indent=4)

In [22]:
# Combine datasets
combined_data = labeled_data + pseudo_labels
combined_dataset = VialDataset(combined_data, is_labeled=True, transform=transforms.ToTensor())
combined_loader = DataLoader(combined_dataset, batch_size=16, shuffle=True)

In [23]:
# Fine-tune the model
for epoch in range(10):  # Adjust epochs
    for images, fullness, barrier in combined_loader:
        optimizer.zero_grad()
        pred_fullness, pred_barrier = model(images)
        loss_fullness = criterion_fullness(pred_fullness, fullness)
        loss_barrier = criterion_barrier(pred_barrier, barrier)
        loss = loss_fullness + loss_barrier
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}: Loss = {loss.item()}")

Epoch 1: Loss = 0.3077651858329773
Epoch 2: Loss = 0.43559712171554565
Epoch 3: Loss = 0.3910799026489258
Epoch 4: Loss = 0.1623603254556656
Epoch 5: Loss = 0.16949298977851868
Epoch 6: Loss = 0.20837630331516266
Epoch 7: Loss = 0.1786826252937317
Epoch 8: Loss = 0.16844432055950165
Epoch 9: Loss = 0.16195252537727356
Epoch 10: Loss = 0.13993768393993378


In [24]:
# Save the fine-tuned model
model_save_path = os.path.join(base_path, 'vial_model.pth')
torch.save(model.state_dict(), model_save_path)
print(f'Model saved at: {model_save_path}')

Model saved at: /Users/jpswaynos/Desktop/health-vials/vial_model.pth


## Inference

In [17]:
# Load the trained model for inference
def load_model(model_path):
    model = VialModel()  # Create a new instance of the model
    model.load_state_dict(torch.load(model_path))
    model.eval()  # Set the model to evaluation mode
    return model

def infer_health_vial(image_path, model):
    # Load and preprocess the image
    image = Image.open(image_path).convert("RGB")
    image = transforms.ToTensor()(image).unsqueeze(0)  # Add batch dimension
    
    with torch.no_grad():
        pred_fullness, pred_barrier = model(image)
        fullness = pred_fullness.item()
        barrier = bool(pred_barrier.item() > 0.5)  # Apply sigmoid and threshold
        return fullness, barrier

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os
import json

def display_image(image_path, fullness_value, barrier_present_value):
    """Display the image in the Jupyter output along with the threshold and barrier values."""
    img = mpimg.imread(image_path)
    plt.imshow(img)
    plt.axis('off')  # Hide axes

    # Display threshold and barrier values below the image
    plt.title(f'Fullness: {fullness_value}, Barrier Present: {barrier_present_value}')
    plt.show()

def process_images(image_dir, model, output_file, fullness_threshold=None, barrier_filter=None):
    """Process images for inference and apply filters based on thresholds."""
    # List all image files in the directory
    image_files = [f for f in os.listdir(image_dir) if f.endswith(('.png', '.jpg', '.jpeg'))]
    
    # To hold results
    results = []
    
    # Define the bounding box (fixed coordinate values)
    bounding_box = {
        "x_min": 0,
        "y_min": 0,
        "x_max": 236,
        "y_max": 236
    }

    # Loop through each image file and perform inference
    for image_file in image_files:
        image_path = os.path.join(image_dir, image_file)

        # Perform inference (Assume you have a pre-defined function to call)
        fullness, barrier = infer_health_vial(image_path, model)  # Replace this with your model's inference method
        
        # Construct result for current image
        result = {
            "image": image_file,
            "bounding_box": bounding_box,
            "attributes": {
                "fullness": fullness,
                "barrier": barrier
            }
        }
        
        # Append result to results list
        results.append(result)

        # Apply filters
        if (fullness_threshold is None or fullness > fullness_threshold) and \
           (barrier_filter is None or barrier == barrier_filter):
            # Print the image in the Jupyter output along with threshold and barrier values
            display_image(image_path, fullness, barrier)

    # Write results to a JSON file
    with open(output_file, 'w') as json_file:
        json.dump(results, json_file, indent=4)

    print(f"Inference results saved to {output_file}")

#base_path = os.path.join(os.getcwd(), '..') # one directory above the current working directory
base_path = os.path.expanduser('~/Desktop/health-vials')

# Load the trained model for inference
model_path = os.path.join(base_path, 'vial_model.pth')
model = load_model(model_path)

# Directory containing the images for inference
image_dir = os.path.expanduser('~/Desktop/health-vials-eval')
output_file = os.path.join(base_path, 'vial_inference_results.json')

# Specify filters:
# fullness_threshold: Change this to set the minimum fullness to show
# barrier_filter:
#   None = don't filter on barrier
#   True = show images with barrier
#   False = show images without barrier
fullness_threshold = 0.98 
barrier_filter = None

# Process images with specified filters
process_images(image_dir, model, output_file, fullness_threshold, barrier_filter)

## Data Preparation
Utilities for preparing the data for training.

### Extract Health Vials from Screenshots
Using a sample image ```../images/health-vial-mask.1440.png``` this notebook will apply the mask to 1440p screenshots from the game and extract the health vial image.  
<br>
Additionally, the notebook will resize the image to a size just larger than the extracted health vial, and use a white background.

In [None]:
%%capture
%pip install Pillow

In [None]:
def extract_masked_content(input_image_path, output_image_path, mask_image_path='../images/health-vial-mask.1440.png'):
    # Open the input image
    img = Image.open(input_image_path).convert("RGBA")
    
    # Open the mask image and convert to greyscale
    mask = Image.open(mask_image_path).convert("L")  # Convert the mask to greyscale

    # Create a new image by applying the mask
    masked_data = []
    data = img.getdata()

    for index, item in enumerate(data):
        # Get mask pixel value (0: black, 255: white)
        mask_value = mask.getpixel((index % img.width, index // img.width))
        
        # If the mask pixel is white (255), keep the corresponding input image pixel
        if mask_value < 255:  # If not white, keep the pixel
            masked_data.append(item)  # Keep the original pixel
        else:
            masked_data.append((0, 0, 0, 0))  # Fully transparent pixel

    # Create a new image with the masked content
    masked_image = Image.new("RGBA", img.size)
    masked_image.putdata(masked_data)

    # Find the bounding box of the non-transparent pixels
    bbox = masked_image.getbbox()

    # Crop the image to the bounding box
    if bbox:
        cropped_image = masked_image.crop(bbox)

        # Create a new transparent image for resizing
        max_size = max(cropped_image.size)
        new_image = Image.new("RGBA", (max_size, max_size), (0, 0, 0, 0))  # Transparent background

        # Paste cropped image onto the center of the new transparent image
        new_image.paste(cropped_image, ((max_size - cropped_image.width) // 2, (max_size - cropped_image.height) // 2))

        # Save the result
        new_image.save(output_image_path, format="PNG")
    else:
        print("No visible content to extract.")

In [None]:
# Define the base directories
input_base_dir = os.path.expanduser('~/Desktop/Diablo 4 Captures')
output_base_dir = os.path.expanduser('~/Desktop/health-vials-eval')

# Create the output directory if it doesn't exist
os.makedirs(output_base_dir, exist_ok=True)

# Get a list of all files in the input directory
all_files = os.listdir(input_base_dir)

# Filter to get only PNG files
png_image_paths = [os.path.join(input_base_dir, f) for f in all_files if f.endswith('.png')]

# Process each PNG image
for input_image_path in png_image_paths:
    # Construct the output image path
    output_image_path = os.path.join(output_base_dir, f'extracted_{os.path.basename(input_image_path)}')
    
    extract_masked_content(input_image_path, output_image_path)
    print(f'Processing {input_image_path} => {output_image_path}')

### Crop and Resize Screenshots
Remove top and bottom 105px and resize the screenshot to 2560x1440.

In [None]:
# Define the path to your desktop
images_path = os.path.expanduser("~/Desktop/Diablo 4 Captures")

# Loop through all files in the directory
for file_name in os.listdir(images_path):
    if file_name.endswith(".png"):  # Process only PNG files
        file_path = os.path.join(images_path, file_name)
        
        # Open the image
        with Image.open(file_path) as img:
            width, height = img.size
            
            # Check if the image dimensions are 2560 x 1440
            if (width, height) > (2560, 1440):
                
                # Crop the image to remove 105px from top and bottom
                cropped_img = img.crop((0, 105, width, height - 105))

                # Resize the cropped image to 2560x1440 if it isn't already
                resized_img = cropped_img.resize((2560, 1440), Image.LANCZOS)
                
                # Save the resized image (overwrite the original file)
                cropped_img.save(file_path)