In [1]:
import os
import pandas as pd
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import timm
from torchvision import transforms as T
from tqdm import tqdm
import gc

class CFG:
    """Namespace of inference settings, read as class attributes (never instantiated).

    Label-derived attributes (``attr_ids``, ``attr_id_to_idx``, ``idx_to_attr_id``,
    ``num_classes``) are not declared here; they are assigned onto this class at
    module level after ``labels.csv`` has been read.
    """

    # ---- Data locations (paths built relative to data_dir) ----------------
    data_dir = './'
    labels_csv_path = os.path.join(data_dir, 'labels.csv')
    sample_submission_path = os.path.join(data_dir, 'sample_submission.csv')
    test_img_dir = os.path.join(data_dir, 'test')
    # Per-class decision thresholds; per the original note these are the
    # blended thresholds generated with alpha=0.20.
    thresholds_path = 'thresholds_final.npy'

    # ---- Architecture & checkpoints --------------------------------------
    model_name = 'tf_efficientnet_b4_ns'
    img_size = 384
    # One checkpoint per CV fold; their predictions are averaged downstream.
    model_paths = [
        'models/tf_efficientnet_b4_ns_fold0_best.pth',
        'models/tf_efficientnet_b4_ns_fold1_best.pth',
        'models/tf_efficientnet_b4_ns_fold2_best.pth',
    ]

    # ---- Runtime ----------------------------------------------------------
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    num_workers = 4
    batch_size = 8

# Build class-index <-> attribute_id lookups. The row order of labels.csv is
# used as the class-index order of the model output (assumed to match the
# order used at training time — TODO confirm against the training notebook).
labels_df = pd.read_csv(CFG.labels_csv_path)
CFG.attr_ids = labels_df['attribute_id'].values
CFG.num_classes = len(labels_df)
CFG.idx_to_attr_id = dict(enumerate(CFG.attr_ids))
CFG.attr_id_to_idx = {attr_id: idx for idx, attr_id in CFG.idx_to_attr_id.items()}

# Release cached CUDA allocator blocks and unreachable Python objects before
# the heavy inference cells run.
torch.cuda.empty_cache()
gc.collect()

0

In [2]:
def get_test_transforms():
    """Build the deterministic test-time preprocessing pipeline.

    Steps: resize so the shorter side equals ``CFG.img_size`` (aspect ratio
    preserved), center-crop to ``CFG.img_size`` x ``CFG.img_size``, convert the
    PIL image to a tensor, and normalize with ImageNet channel statistics.

    Returns:
        A ``torchvision.transforms.Compose`` applying the steps above in order.
    """
    print("--- Applying CORRECTED validation transforms (Resize+CenterCrop, ImageNet Norm) ---")

    side = CFG.img_size
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    pipeline = [
        T.Resize(side),       # int size: match the shorter side, keep aspect ratio
        T.CenterCrop(side),
        T.ToTensor(),
        T.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return T.Compose(pipeline)

class iMetTestDataset(Dataset):
    """Unlabeled image dataset for test-time inference.

    Item ``idx`` is the image at the ``idx``-th (positional) entry of
    ``df['filepath']``, decoded as RGB and passed through ``transforms`` when one
    is given.

    Args:
        df: DataFrame with a ``filepath`` column of image paths.
        transforms: optional callable applied to the RGB PIL image.
    """

    def __init__(self, df, transforms=None):
        self.df = df
        # Positional array, so lookups ignore the DataFrame's index labels.
        self.filepaths = df['filepath'].values
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        filepath = self.filepaths[idx]
        # Use a context manager so the file handle is released deterministically
        # (DataLoader workers touch tens of thousands of files). convert() always
        # returns a new, fully loaded image, so it remains valid after close.
        with Image.open(filepath) as src:
            image = src.convert('RGB')

        if self.transforms:
            image = self.transforms(image)

        return image

class iMetModel(nn.Module):
    """Thin wrapper around a timm classification backbone.

    The backbone is stored as ``self.model``; that attribute name determines the
    ``state_dict`` key prefix (``model.``), so keep it stable for checkpoint loading.

    Args:
        model_name: timm architecture name.
        pretrained: whether timm should load its own pretrained weights.
        num_classes: size of the classification head. Defaults to
            ``CFG.num_classes`` when omitted (the original behavior).
    """

    def __init__(self, model_name, pretrained=False, num_classes=None):
        super().__init__()
        if num_classes is None:
            num_classes = CFG.num_classes
        self.model = timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes)

    def forward(self, x):
        """Return the backbone's raw outputs (logits; no sigmoid applied here)."""
        return self.model(x)

In [3]:
## Prepare Test Data
# Resolve each submission id to its PNG inside the test image directory.
sub_df = pd.read_csv(CFG.sample_submission_path)
sub_df['filepath'] = [os.path.join(CFG.test_img_dir, img_id + '.png') for img_id in sub_df['id']]
display(sub_df.head())

# Single deterministic view per image (NO TTA). shuffle=False keeps loader order
# identical to sub_df row order, which the inference cell relies on.
test_dataset = iMetTestDataset(sub_df, transforms=get_test_transforms())
test_loader = DataLoader(
    test_dataset,
    batch_size=CFG.batch_size,
    shuffle=False,
    num_workers=CFG.num_workers,
)

print("Test data prepared.")

Unnamed: 0,id,attribute_ids,filepath
0,347c119163f84420f10f7a8126c1b8a2,0 1 2,./test/347c119163f84420f10f7a8126c1b8a2.png
1,98c91458324cba5415c5f5d8ead68328,0 1 2,./test/98c91458324cba5415c5f5d8ead68328.png
2,3f75d332f579af62ff88d369c0736c76,0 1 2,./test/3f75d332f579af62ff88d369c0736c76.png
3,3fa35a29218b7449c8f03e2a368a708d,0 1 2,./test/3fa35a29218b7449c8f03e2a368a708d.png
4,c848b91558e4edd8034cb7d334b4e448,0 1 2,./test/c848b91558e4edd8034cb7d334b4e448.png


--- Applying CORRECTED validation transforms (Resize+CenterCrop, ImageNet Norm) ---
Test data prepared.


In [None]:
## Inference (Ensemble, NO TTA)
# Sum each model's sigmoid probabilities into a preallocated
# (n_samples, num_classes) array, then divide by the model count to get the mean.
if not CFG.model_paths:
    raise ValueError("CFG.model_paths is empty; at least one checkpoint is required.")

n_samples = len(sub_df)
total_preds = np.zeros((n_samples, CFG.num_classes), dtype=np.float32)

for i, model_path in enumerate(CFG.model_paths):
    print(f"--- Inferencing with model {i+1}/{len(CFG.model_paths)}: {model_path} ---")

    # FIX for OOM: build the model and load its weights on CPU, then move it to
    # the device once. The state_dict is mapped to CPU, so the model's own
    # parameters are the only copy of the weights placed on the GPU.
    model = iMetModel(CFG.model_name, pretrained=False)
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict)
    model.to(CFG.device)
    model.eval()

    current_pos = 0
    # inference_mode() is a stricter/faster no_grad() for pure forward passes.
    with torch.inference_mode():
        for images in tqdm(test_loader, desc=f"Predicting (Model {i+1})"):
            images = images.to(CFG.device)
            probs = model(images).sigmoid().cpu().numpy()

            n_batch = images.shape[0]
            total_preds[current_pos : current_pos + n_batch] += probs
            current_pos += n_batch

    # The loader is unshuffled, so rows must line up 1:1 with sub_df. A short
    # pass would leave zero rows that silently drag the ensemble mean down.
    if current_pos != n_samples:
        raise RuntimeError(
            f"Model {i+1} produced {current_pos} predictions, expected {n_samples}."
        )

    # Free this model before loading the next one.
    del model, state_dict
    torch.cuda.empty_cache()
    gc.collect()

# Average the summed probabilities over all ensemble members.
all_preds = total_preds / len(CFG.model_paths)
print("\nEnsemble predictions calculated.")

--- Inferencing with model 1/3: models/tf_efficientnet_b4_ns_fold0_best.pth ---


  model = create_fn(


Predicting (Model 1):   0%|          | 0/2665 [00:00<?, ?it/s]

Predicting (Model 1):   0%|          | 1/2665 [00:00<08:47,  5.05it/s]

Predicting (Model 1):   0%|          | 4/2665 [00:00<02:59, 14.85it/s]

Predicting (Model 1):   0%|          | 7/2665 [00:00<02:15, 19.65it/s]

Predicting (Model 1):   0%|          | 10/2665 [00:00<01:58, 22.41it/s]

Predicting (Model 1):   0%|          | 13/2665 [00:00<01:49, 24.11it/s]

Predicting (Model 1):   1%|          | 16/2665 [00:00<01:45, 25.15it/s]

Predicting (Model 1):   1%|          | 19/2665 [00:00<01:42, 25.82it/s]

Predicting (Model 1):   1%|          | 22/2665 [00:00<01:40, 26.29it/s]

Predicting (Model 1):   1%|          | 25/2665 [00:01<01:39, 26.49it/s]

Predicting (Model 1):   1%|          | 28/2665 [00:01<01:38, 26.69it/s]

Predicting (Model 1):   1%|          | 31/2665 [00:01<01:38, 26.83it/s]

Predicting (Model 1):   1%|▏         | 34/2665 [00:01<01:37, 27.00it/s]

Predicting (Model 1):   1%|▏         | 37/2665 [00:01<01:36, 27.12it/s]

Predicting (Model 1):   2%|▏         | 40/2665 [00:01<01:36, 27.12it/s]

Predicting (Model 1):   2%|▏         | 43/2665 [00:01<01:36, 27.19it/s]

Predicting (Model 1):   2%|▏         | 46/2665 [00:01<01:36, 27.19it/s]

Predicting (Model 1):   2%|▏         | 49/2665 [00:01<01:36, 27.16it/s]

Predicting (Model 1):   2%|▏         | 52/2665 [00:02<01:36, 27.08it/s]

Predicting (Model 1):   2%|▏         | 55/2665 [00:02<01:36, 27.09it/s]

Predicting (Model 1):   2%|▏         | 58/2665 [00:02<01:36, 27.12it/s]

Predicting (Model 1):   2%|▏         | 61/2665 [00:02<01:36, 27.09it/s]

Predicting (Model 1):   2%|▏         | 64/2665 [00:02<01:35, 27.19it/s]

Predicting (Model 1):   3%|▎         | 67/2665 [00:02<01:35, 27.27it/s]

Predicting (Model 1):   3%|▎         | 70/2665 [00:02<01:35, 27.32it/s]

Predicting (Model 1):   3%|▎         | 73/2665 [00:02<01:34, 27.36it/s]

Predicting (Model 1):   3%|▎         | 76/2665 [00:02<01:34, 27.31it/s]

Predicting (Model 1):   3%|▎         | 79/2665 [00:03<01:34, 27.28it/s]

Predicting (Model 1):   3%|▎         | 82/2665 [00:03<01:34, 27.29it/s]

Predicting (Model 1):   3%|▎         | 85/2665 [00:03<01:34, 27.34it/s]

Predicting (Model 1):   3%|▎         | 88/2665 [00:03<01:34, 27.35it/s]

Predicting (Model 1):   3%|▎         | 91/2665 [00:03<01:34, 27.33it/s]

Predicting (Model 1):   4%|▎         | 94/2665 [00:03<01:34, 27.33it/s]

Predicting (Model 1):   4%|▎         | 97/2665 [00:03<01:33, 27.33it/s]

Predicting (Model 1):   4%|▍         | 100/2665 [00:03<01:33, 27.35it/s]

Predicting (Model 1):   4%|▍         | 103/2665 [00:03<01:33, 27.33it/s]

Predicting (Model 1):   4%|▍         | 106/2665 [00:04<01:33, 27.39it/s]

Predicting (Model 1):   4%|▍         | 109/2665 [00:04<01:33, 27.37it/s]

Predicting (Model 1):   4%|▍         | 112/2665 [00:04<01:33, 27.42it/s]

Predicting (Model 1):   4%|▍         | 115/2665 [00:04<01:33, 27.38it/s]

Predicting (Model 1):   4%|▍         | 118/2665 [00:04<01:32, 27.40it/s]

Predicting (Model 1):   5%|▍         | 121/2665 [00:04<01:32, 27.37it/s]

Predicting (Model 1):   5%|▍         | 124/2665 [00:04<01:32, 27.41it/s]

In [None]:
## Create Submission with Blended Thresholds (NO MAX-K GUARD)

# Load the blended per-class thresholds
try:
    thresholds = np.load(CFG.thresholds_path)
    print(f"Loaded blended thresholds from: {CFG.thresholds_path}")
    print(f"Thresholds shape: {thresholds.shape}")
except FileNotFoundError:
    print(f"ERROR: Threshold file not found at {CFG.thresholds_path}. Make sure it exists.")
    # Stop execution if the file is missing
    raise

# Only a scalar or one threshold per class broadcasts correctly against the
# (n_samples, num_classes) probability matrix; other shapes would crash or
# silently mislabel classes.
if thresholds.shape not in ((), (CFG.num_classes,)):
    raise ValueError(
        f"thresholds must be a scalar or have shape ({CFG.num_classes},), "
        f"got {thresholds.shape}"
    )

# Vectorized thresholding: one boolean row of predicted classes per image.
pred_mask = all_preds > thresholds

# Fallback: an image with no class above threshold gets its single most probable class.
no_label_rows = np.flatnonzero(~pred_mask.any(axis=1))
pred_mask[no_label_rows, all_preds[no_label_rows].argmax(axis=1)] = True

# Map class indices back to attribute_ids, in ascending class-index order.
idx_to_attr = np.array([CFG.idx_to_attr_id[i] for i in range(CFG.num_classes)])
predictions = [
    ' '.join(map(str, idx_to_attr[np.flatnonzero(row)]))
    for row in tqdm(pred_mask, desc="Formatting submission")
]

sub_df['attribute_ids'] = predictions
sub_df[['id', 'attribute_ids']].to_csv('submission.csv', index=False)
print("Submission file created successfully!")
display(sub_df.head())