In [7]:
import os
import pandas as pd
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import timm
from torchvision import transforms as T
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import f1_score
from tqdm import tqdm
import gc
import time

# For multi-label stratification
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold

class CFG:
    # General
    debug = False  # Set to True to run on a small subset for quick debugging
    seed = 42
    num_workers = 4
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Data paths
    data_dir = './'
    train_csv_path = os.path.join(data_dir, 'train.csv')
    labels_csv_path = os.path.join(data_dir, 'labels.csv')
    sample_submission_path = os.path.join(data_dir, 'sample_submission.csv')
    train_img_dir = os.path.join(data_dir, 'train')
    test_img_dir = os.path.join(data_dir, 'test')

    # Model
    model_name = 'tf_efficientnet_b4_ns'
    img_size = 384

    # Training
    epochs = 8
    batch_size = 32
    lr = 1e-4
    weight_decay = 1e-6
    n_folds = 3
    target_fold = 2 # Train fold 2

    # Output
    output_dir = 'models'

if not os.path.exists(CFG.output_dir):
    os.makedirs(CFG.output_dir)

In [8]:
## 1.2 Data Loading and Preprocessing

def get_df():
    # Load dataframes
    train_df = pd.read_csv(CFG.train_csv_path)
    labels_df = pd.read_csv(CFG.labels_csv_path)
    
    # If in debug mode, sample the dataframe first for speed
    if CFG.debug:
        print("Running in debug mode, sampling 1000 unique images.")
        unique_ids = train_df['id'].unique()
        sampled_ids = np.random.choice(unique_ids, size=1000, replace=False)
        train_df = train_df[train_df['id'].isin(sampled_ids)].reset_index(drop=True)

    # Create a mapping from attribute_id to a continuous index
    CFG.attr_ids = labels_df['attribute_id'].values
    CFG.attr_id_to_idx = {attr_id: i for i, attr_id in enumerate(CFG.attr_ids)}
    CFG.idx_to_attr_id = {i: attr_id for i, attr_id in enumerate(CFG.attr_ids)}
    CFG.num_classes = len(labels_df)
    print(f"Number of classes: {CFG.num_classes}")

    # Process train_df to create multi-hot encoded labels
    # Group by id and aggregate attribute_ids
    train_agg = train_df.groupby('id')['attribute_ids'].apply(lambda x: ' '.join(x)).reset_index()
    
    # Create the multi-hot encoded matrix
    targets = np.zeros((len(train_agg), CFG.num_classes), dtype=np.int8)
    for i, row in tqdm(train_agg.iterrows(), total=len(train_agg), desc="Processing labels"):
        attr_ids = [int(attr_id) for attr_id in row['attribute_ids'].split()]
        for attr_id in attr_ids:
            if attr_id in CFG.attr_id_to_idx:
                targets[i, CFG.attr_id_to_idx[attr_id]] = 1
    
    train_agg['targets'] = list(targets)
    
    # Add file paths
    train_agg['filepath'] = train_agg['id'].apply(lambda x: os.path.join(CFG.train_img_dir, x + '.png'))
    
    # Create folds with MultilabelStratifiedKFold
    print("Creating folds with MultilabelStratifiedKFold...")
    y_labels = np.array(train_agg['targets'].tolist())
    mskf = MultilabelStratifiedKFold(n_splits=CFG.n_folds, shuffle=True, random_state=CFG.seed)
    train_agg['fold'] = -1
    for fold, (_, val_idx) in enumerate(mskf.split(np.zeros(len(train_agg)), y_labels)):
        train_agg.loc[val_idx, 'fold'] = fold
        
    return train_agg

df = get_df()
display(df.head())
print(f"Shape of the dataframe: {df.shape}")
print(f"Fold distribution:\n{df['fold'].value_counts()}")

Number of classes: 3474


Processing labels:   0%|          | 0/120801 [00:00<?, ?it/s]

Processing labels:   3%|▎         | 3688/120801 [00:00<00:03, 36875.73it/s]

Processing labels:   6%|▋         | 7654/120801 [00:00<00:02, 38512.99it/s]

Processing labels:  10%|▉         | 11662/120801 [00:00<00:02, 39226.89it/s]

Processing labels:  13%|█▎        | 15637/120801 [00:00<00:02, 39430.82it/s]

Processing labels:  16%|█▌        | 19610/120801 [00:00<00:02, 39535.96it/s]

Processing labels:  20%|█▉        | 23564/120801 [00:00<00:02, 39448.52it/s]

Processing labels:  23%|██▎       | 27521/120801 [00:00<00:02, 39485.42it/s]

Processing labels:  26%|██▌       | 31508/120801 [00:00<00:02, 39607.39it/s]

Processing labels:  29%|██▉       | 35469/120801 [00:00<00:02, 39603.38it/s]

Processing labels:  33%|███▎      | 39468/120801 [00:01<00:02, 39721.10it/s]

Processing labels:  36%|███▌      | 43444/120801 [00:01<00:01, 39732.49it/s]

Processing labels:  39%|███▉      | 47503/120801 [00:01<00:01, 39993.02it/s]

Processing labels:  43%|████▎     | 51557/120801 [00:01<00:01, 40158.08it/s]

Processing labels:  46%|████▌     | 55573/120801 [00:01<00:01, 40122.76it/s]

Processing labels:  49%|████▉     | 59586/120801 [00:01<00:01, 40083.02it/s]

Processing labels:  53%|█████▎    | 63595/120801 [00:01<00:01, 39973.16it/s]

Processing labels:  56%|█████▌    | 67593/120801 [00:01<00:01, 39918.12it/s]

Processing labels:  59%|█████▉    | 71591/120801 [00:01<00:01, 39934.05it/s]

Processing labels:  63%|██████▎   | 75594/120801 [00:01<00:01, 39961.36it/s]

Processing labels:  66%|██████▌   | 79596/120801 [00:02<00:01, 39976.44it/s]

Processing labels:  69%|██████▉   | 83594/120801 [00:02<00:00, 39955.89it/s]

Processing labels:  73%|███████▎  | 87590/120801 [00:02<00:00, 39818.57it/s]

Processing labels:  76%|███████▌  | 91572/120801 [00:02<00:00, 39762.07it/s]

Processing labels:  79%|███████▉  | 95560/120801 [00:02<00:00, 39794.66it/s]

Processing labels:  82%|████████▏ | 99540/120801 [00:02<00:00, 39742.41it/s]

Processing labels:  86%|████████▌ | 103515/120801 [00:02<00:00, 39664.38it/s]

Processing labels:  89%|████████▉ | 107499/120801 [00:02<00:00, 39715.29it/s]

Processing labels:  92%|█████████▏| 111471/120801 [00:02<00:00, 39715.98it/s]

Processing labels:  96%|█████████▌| 115484/120801 [00:02<00:00, 39839.40it/s]

Processing labels:  99%|█████████▉| 119468/120801 [00:03<00:00, 39765.50it/s]

Processing labels: 100%|██████████| 120801/120801 [00:03<00:00, 39719.29it/s]




Creating folds with MultilabelStratifiedKFold...


Unnamed: 0,id,attribute_ids,targets,filepath,fold
0,000040d66f14ced4cdd18cd95d91800f,448 2429 782,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",./train/000040d66f14ced4cdd18cd95d91800f.png,2
1,0000ef13e37ef70412166725ec034a8a,2997 3231 2730 3294 3099 2017 784,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",./train/0000ef13e37ef70412166725ec034a8a.png,2
2,0001eeb4a06e8daa7c6951bcd124c3c7,2436 1715 23,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",./train/0001eeb4a06e8daa7c6951bcd124c3c7.png,1
3,000226398d224de78b191e6db45fd94e,2997 3433 448 782,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",./train/000226398d224de78b191e6db45fd94e.png,2
4,00029c3b0171158d63b1bbf803a7d750,3465 3322 3170 1553 781,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",./train/00029c3b0171158d63b1bbf803a7d750.png,2


Shape of the dataframe: (120801, 5)
Fold distribution:
fold
0    40294
1    40256
2    40251
Name: count, dtype: int64


In [9]:
## 1.3 Dataset and Augmentations

def get_transforms(*, data):
    if data == 'train':
        return T.Compose([
            T.RandomResizedCrop(CFG.img_size, scale=(0.8, 1.0)),
            T.RandomHorizontalFlip(p=0.5),
            T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            T.ToTensor(),
            T.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])
    elif data == 'valid':
        return T.Compose([
            T.Resize(CFG.img_size),
            T.CenterCrop(CFG.img_size),
            T.ToTensor(),
            T.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

class iMetDataset(Dataset):
    def __init__(self, df, transforms=None):
        self.df = df
        self.filepaths = df['filepath'].values
        self.labels = df['targets'].values
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        filepath = self.filepaths[idx]
        image = Image.open(filepath).convert('RGB')
        
        if self.transforms:
            image = self.transforms(image)
            
        label = torch.tensor(self.labels[idx], dtype=torch.float)
        return image, label

In [10]:
## 2.1 Model, Loss, and Optimizer

class iMetModel(nn.Module):
    def __init__(self, model_name, pretrained=True):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained, num_classes=CFG.num_classes)

    def forward(self, x):
        return self.model(x)

# Calculate pos_weight for BCEWithLogitsLoss
# This is important for handling class imbalance
targets_matrix = np.stack(df['targets'].values)
pos_counts = targets_matrix.sum(axis=0)
neg_counts = len(df) - pos_counts
pos_weight = neg_counts / (pos_counts + 1e-6) # Add epsilon to avoid division by zero
pos_weight = torch.tensor(pos_weight, dtype=torch.float).to(CFG.device)

print(f"pos_weight tensor shape: {pos_weight.shape}")
print(f"Device: {CFG.device}")

pos_weight tensor shape: torch.Size([3474])
Device: cuda


In [11]:
## 2.2 Training and Validation Functions

def train_fn(loader, model, criterion, optimizer, scaler, device):
    model.train()
    running_loss = 0.0
    pbar = tqdm(loader, desc="Training")
    for images, labels in pbar:
        images = images.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        
        with autocast():
            logits = model(images)
            loss = criterion(logits, labels)
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        running_loss += loss.item()
        pbar.set_postfix(loss=loss.item())
        
    avg_loss = running_loss / len(loader)
    return avg_loss

def valid_fn(loader, model, criterion, device):
    model.eval()
    running_loss = 0.0
    all_preds = []
    all_labels = []
    pbar = tqdm(loader, desc="Validating")
    with torch.no_grad():
        for images, labels in pbar:
            images = images.to(device)
            labels = labels.to(device)
            
            logits = model(images)
            loss = criterion(logits, labels)
            
            running_loss += loss.item()
            
            all_preds.append(logits.sigmoid().cpu().numpy())
            all_labels.append(labels.cpu().numpy())
            
    avg_loss = running_loss / len(loader)
    
    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)
    
    return avg_loss, all_preds, all_labels

def get_best_f1_score(preds, labels):
    best_f1 = 0
    best_thresh = 0
    for thresh in np.arange(0.05, 0.5, 0.01):
        binary_preds = (preds > thresh).astype(int)
        # Handle 'at-least-one' fallback
        for i in range(len(binary_preds)):
            if binary_preds[i].sum() == 0:
                binary_preds[i, preds[i].argmax()] = 1
        
        f1 = f1_score(labels, binary_preds, average='micro')
        if f1 > best_f1:
            best_f1 = f1
            best_thresh = thresh
    return best_f1, best_thresh

In [None]:
## 2.3 Main Training Loop

def run_training(fold):
    print(f"========== Fold: {fold} ==========")
    
    # Create datasets
    train_df = df[df['fold'] != fold].reset_index(drop=True)
    valid_df = df[df['fold'] == fold].reset_index(drop=True)
    
    train_dataset = iMetDataset(train_df, transforms=get_transforms(data='train'))
    valid_dataset = iMetDataset(valid_df, transforms=get_transforms(data='valid'))
    
    train_loader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True, num_workers=CFG.num_workers, pin_memory=True)
    valid_loader = DataLoader(valid_dataset, batch_size=CFG.batch_size * 2, shuffle=False, num_workers=CFG.num_workers, pin_memory=True)
    
    # Init model, criterion, optimizer
    model = iMetModel(CFG.model_name, pretrained=True).to(CFG.device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
    scaler = GradScaler()
    
    best_f1 = 0
    best_epoch = -1
    
    for epoch in range(CFG.epochs):
        start_time = time.time()
        
        train_loss = train_fn(train_loader, model, criterion, optimizer, scaler, CFG.device)
        valid_loss, preds, labels = valid_fn(valid_loader, model, criterion, CFG.device)
        
        f1, thresh = get_best_f1_score(preds, labels)
        
        elapsed = time.time() - start_time
        
        print(f"Epoch {epoch+1}/{CFG.epochs} - Train Loss: {train_loss:.4f}, Val Loss: {valid_loss:.4f}, F1: {f1:.4f}, Best Thresh: {thresh:.2f}, Time: {elapsed:.0f}s")
        
        if f1 > best_f1:
            best_f1 = f1
            best_epoch = epoch
            torch.save(model.state_dict(), os.path.join(CFG.output_dir, f'{CFG.model_name}_fold{fold}_best.pth'))
            print(f"  -> New best F1 score: {best_f1:.4f}. Model saved.")
            
    print(f"\nBest F1 score for fold {fold} was {best_f1:.4f} at epoch {best_epoch+1}")
    return best_f1

# Start training for the target fold
run_training(CFG.target_fold)



  model = create_fn(


  scaler = GradScaler()


Training:   0%|          | 0/2518 [00:00<?, ?it/s]

  with autocast():


Training:   0%|          | 0/2518 [00:00<?, ?it/s, loss=1.14]

Training:   0%|          | 1/2518 [00:00<38:25,  1.09it/s, loss=1.14]

Training:   0%|          | 1/2518 [00:01<38:25,  1.09it/s, loss=1.38]

Training:   0%|          | 2/2518 [00:01<21:43,  1.93it/s, loss=1.38]

Training:   0%|          | 2/2518 [00:01<21:43,  1.93it/s, loss=1.23]

Training:   0%|          | 3/2518 [00:01<16:22,  2.56it/s, loss=1.23]

Training:   0%|          | 3/2518 [00:01<16:22,  2.56it/s, loss=2.36]

Training:   0%|          | 4/2518 [00:01<13:55,  3.01it/s, loss=2.36]

Training:   0%|          | 4/2518 [00:01<13:55,  3.01it/s, loss=1.37]

Training:   0%|          | 5/2518 [00:01<12:32,  3.34it/s, loss=1.37]

Training:   0%|          | 5/2518 [00:02<12:32,  3.34it/s, loss=1.27]

Training:   0%|          | 6/2518 [00:02<11:41,  3.58it/s, loss=1.27]

Training:   0%|          | 6/2518 [00:02<11:41,  3.58it/s, loss=1.22]

Training:   0%|          | 7/2518 [00:02<11:09,  3.75it/s, loss=1.22]

Training:   0%|          | 7/2518 [00:02<11:09,  3.75it/s, loss=1.3] 

Training:   0%|          | 8/2518 [00:02<10:49,  3.87it/s, loss=1.3]

Training:   0%|          | 8/2518 [00:02<10:49,  3.87it/s, loss=1.24]

Training:   0%|          | 9/2518 [00:02<10:35,  3.95it/s, loss=1.24]

Training:   0%|          | 9/2518 [00:03<10:35,  3.95it/s, loss=1.05]

Training:   0%|          | 10/2518 [00:03<10:25,  4.01it/s, loss=1.05]

Training:   0%|          | 10/2518 [00:03<10:25,  4.01it/s, loss=1]   

Training:   0%|          | 11/2518 [00:03<10:18,  4.05it/s, loss=1]

Training:   0%|          | 11/2518 [00:03<10:18,  4.05it/s, loss=1.02]

Training:   0%|          | 12/2518 [00:03<10:14,  4.08it/s, loss=1.02]

Training:   0%|          | 12/2518 [00:03<10:14,  4.08it/s, loss=1.49]

Training:   1%|          | 13/2518 [00:03<10:10,  4.10it/s, loss=1.49]

Training:   1%|          | 13/2518 [00:04<10:10,  4.10it/s, loss=1.27]

Training:   1%|          | 14/2518 [00:04<10:07,  4.12it/s, loss=1.27]

Training:   1%|          | 14/2518 [00:04<10:07,  4.12it/s, loss=1.01]

Training:   1%|          | 15/2518 [00:04<10:05,  4.13it/s, loss=1.01]

Training:   1%|          | 15/2518 [00:04<10:05,  4.13it/s, loss=1.31]