In [71]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms

import timm
import torch.nn as nn

import torch.nn.functional as F

import torch.optim as optim

import numpy as np

from sklearn.model_selection import train_test_split
from tqdm import tqdm

from torch.optim.lr_scheduler import ReduceLROnPlateau

In [72]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [73]:
class LumbarSpineDataset(Dataset):
    def __init__(self, csv_file, img_dir, transform=None):
        self.data = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform
        self.levels = ['L1_L2', 'L2_L3', 'L3_L4', 'L4_L5', 'L5_S1']
        self.conditions = ['spinal_canal_stenosis', 'left_neural_foraminal_narrowing', 
                           'right_neural_foraminal_narrowing', 'left_subarticular_stenosis', 
                           'right_subarticular_stenosis']
        self.valid_indices = self._get_valid_indices()
        
    def _get_valid_indices(self):
        valid_indices = []
        for idx, row in self.data.iterrows():
            study_id = str(row['study_id'])
            study_folder = os.path.join(self.img_dir, study_id)
            if not os.path.exists(study_folder):
                continue
            subfolders = [f for f in os.listdir(study_folder) if os.path.isdir(os.path.join(study_folder, f))]
            if not subfolders:
                continue
            subfolder = subfolders[0]
            if all(os.path.exists(os.path.join(study_folder, subfolder, f"{level}.png")) for level in self.levels):
                valid_indices.append(idx)
        return valid_indices
        
    def __len__(self):
        return len(self.valid_indices)
    
    def __getitem__(self, idx):
        real_idx = self.valid_indices[idx]
        study_id = str(self.data.iloc[real_idx]['study_id'])
        images = []
        labels = []
        
        study_folder = os.path.join(self.img_dir, study_id)
        subfolders = [f for f in os.listdir(study_folder) if os.path.isdir(os.path.join(study_folder, f))]
        subfolder = subfolders[0]  # Assume there's only one subfolder
        
        for level in self.levels:
            img_path = os.path.join(study_folder, subfolder, f"{level}.png")
            image = Image.open(img_path).convert('RGB')
            if self.transform:
                image = self.transform(image)
            images.append(image)
            
            level_labels = []
            for condition in self.conditions:
                col_name = f"{condition}_{level.lower()}"
                severity = self.data.iloc[real_idx][col_name]
                label = torch.zeros(3)
                if severity == 'Normal/Mild':
                    label[0] = 1
                elif severity == 'Moderate':
                    label[1] = 1
                elif severity == 'Severe':
                    label[2] = 1
                level_labels.extend(label)
            labels.extend(level_labels)
        
        return images, torch.tensor(labels)

# Define transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Create dataset and data loader
train_dataset = LumbarSpineDataset('data/train.csv', 'crops224', transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)

In [74]:
class LumbarSpineModel(nn.Module):
    def __init__(self, num_classes=75):  # 5 levels * 5 conditions * 3 severities
        super().__init__()
        self.backbone = timm.create_model('maxvit_tiny_tf_224.in1k', pretrained=True, num_classes=0)

        # Get the number of output features from the backbone
        with torch.no_grad():
            dummy_input = torch.randn(1, 3, 224, 224)
            features = self.backbone(dummy_input)
            num_features = features.shape[1]
        
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(num_features * 5, num_classes)
        
    def forward(self, x):
        if isinstance(x, list):
            # If x is a list of tensors (one for each image)
            features = []
            for i, img in enumerate(x):
                feat = self.backbone(img)
                features.append(feat)
            combined_features = torch.cat(features, dim=1)
        else:
            # If x is a single tensor of shape (batch_size, 5, 3, H, W)
            batch_size, num_images, C, H, W = x.shape
            x = x.view(batch_size * num_images, C, H, W)
            features = self.backbone(x)
            combined_features = features.view(batch_size, -1)
        
        x = self.dropout(combined_features)
        return self.fc(combined_features)

model = LumbarSpineModel().to(device)

In [75]:
def weighted_bce_loss(predictions, targets, pos_weights):
    return F.binary_cross_entropy_with_logits(predictions, targets, pos_weight=pos_weights)

# Define pos_weights based on severity levels
pos_weights = torch.tensor([1.0, 2.0, 4.0]).repeat(25).to(device)  # 25 = 5 levels * 5 conditions

In [76]:
def weighted_log_loss(y_true, y_pred, weights=[1, 2, 4]):
    """
    Calculate weighted log loss
    """
    y_pred = np.clip(y_pred, 1e-7, 1 - 1e-7)
    loss = 0
    for i in range(0, y_true.shape[1], 3):
        loss += np.sum(weights * (y_true[:, i:i+3] * np.log(y_pred[:, i:i+3])))
    return -loss / y_true.shape[0]

# Split the data into train and validation sets
train_indices, val_indices = train_test_split(range(len(train_dataset)), test_size=0.2, random_state=42)
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
val_sampler = torch.utils.data.SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, sampler=train_sampler, num_workers=4)
val_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, sampler=val_sampler, num_workers=4)

model = LumbarSpineModel().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)

# Training loop
num_epochs = 50
best_val_loss = float('inf')
patience = 10
patience_counter = 0

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    train_weighted_log_loss = 0
    train_batches = 0

    # Training loop with tqdm
    train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]', leave=False)
    for data, target in train_pbar:
        try:
            data = [img.to(device) for img in data]
            target = target.to(device)
            
            optimizer.zero_grad()
            output = model(data)
            loss = weighted_bce_loss(output, target, pos_weights)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            train_loss += loss.item()
            train_weighted_log_loss += weighted_log_loss(target.cpu().numpy(), torch.sigmoid(output).detach().cpu().numpy())
            train_batches += 1

            # Update progress bar
            train_pbar.set_postfix({'Loss': f'{train_loss/train_batches:.4f}', 'WLogLoss': f'{train_weighted_log_loss/train_batches:.4f}'})
        except Exception as e:
            print(f"Error occurred: {str(e)}")
            print(f"Data shapes: {[img.shape for img in data]}")
            print(f"Target shape: {target.shape}")
            print(f"Model structure: {model}")
            raise e
        
    # Validation step
    model.eval()
    val_loss = 0
    val_weighted_log_loss = 0
    val_batches = 0

    # Validation loop with tqdm
    val_pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Val]', leave=False)
    with torch.no_grad():
        for data, target in val_pbar:
            data = [img.to(device) for img in data]
            target = target.to(device)
            
            output = model(data)
            loss = weighted_bce_loss(output, target, pos_weights)
            
            val_loss += loss.item()
            val_weighted_log_loss += weighted_log_loss(target.cpu().numpy(), torch.sigmoid(output).cpu().numpy())
            val_batches += 1

            # Update progress bar
            val_pbar.set_postfix({'Loss': f'{val_loss/val_batches:.4f}', 'WLogLoss': f'{val_weighted_log_loss/val_batches:.4f}'})

    # Print epoch results
    print(f'Epoch {epoch+1}/{num_epochs}:')
    print(f'Train Loss: {train_loss/train_batches:.4f}')
    print(f'Train Weighted Log Loss: {train_weighted_log_loss/train_batches:.4f}')
    print(f'Validation Loss: {val_loss/val_batches:.4f}')
    print(f'Validation Weighted Log Loss: {val_weighted_log_loss/val_batches:.4f}')
    print('-' * 50)

    scheduler.step(val_weighted_log_loss/val_batches)

    # Early stopping
    if val_weighted_log_loss/val_batches < best_val_loss:
        best_val_loss = val_weighted_log_loss/val_batches
        patience_counter = 0
        # Save the best model
        torch.save(model.state_dict(), 'best_lumbar_spine_model.pth')
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping triggered after epoch {epoch+1}")
            break

# Save the final model
torch.save(model.state_dict(), 'lumbar_spine_model.pth')

                                                                                                    

Epoch 1/50:
Train Loss: 0.4831
Train Weighted Log Loss: 18.7591
Validation Loss: 0.4447
Validation Weighted Log Loss: 16.0321
--------------------------------------------------


                                                                                                    

Epoch 2/50:
Train Loss: 0.4052
Train Weighted Log Loss: 15.6060
Validation Loss: 0.4241
Validation Weighted Log Loss: 16.1417
--------------------------------------------------


                                                                                                    

Epoch 3/50:
Train Loss: 0.3733
Train Weighted Log Loss: 14.2439
Validation Loss: 0.4249
Validation Weighted Log Loss: 16.8188
--------------------------------------------------


                                                                                                    

Epoch 4/50:
Train Loss: 0.3400
Train Weighted Log Loss: 12.8922
Validation Loss: 0.4166
Validation Weighted Log Loss: 16.9980
--------------------------------------------------


                                                                                                    

Epoch 5/50:
Train Loss: 0.2999
Train Weighted Log Loss: 11.1046
Validation Loss: 0.4256
Validation Weighted Log Loss: 17.7534
--------------------------------------------------


                                                                                                   

Epoch 6/50:
Train Loss: 0.2572
Train Weighted Log Loss: 9.3945
Validation Loss: 0.4373
Validation Weighted Log Loss: 19.2279
--------------------------------------------------


                                                                                                   

Epoch 7/50:
Train Loss: 0.2154
Train Weighted Log Loss: 7.7843
Validation Loss: 0.4510
Validation Weighted Log Loss: 20.4508
--------------------------------------------------


                                                                                                   

Epoch 8/50:
Train Loss: 0.1642
Train Weighted Log Loss: 5.8250
Validation Loss: 0.4714
Validation Weighted Log Loss: 22.2791
--------------------------------------------------


                                                                                                   

Epoch 9/50:
Train Loss: 0.1303
Train Weighted Log Loss: 4.5781
Validation Loss: 0.4710
Validation Weighted Log Loss: 21.8398
--------------------------------------------------


                                                                                                    

Epoch 10/50:
Train Loss: 0.1048
Train Weighted Log Loss: 3.6522
Validation Loss: 0.5050
Validation Weighted Log Loss: 24.5458
--------------------------------------------------


                                                                                                   

KeyboardInterrupt: 

In [64]:
def generate_submission(model, test_loader, submission_file):
    model.eval()
    predictions = []
    
    with torch.no_grad():
        for data, _ in test_loader:
            data = [img.to(device) for img in data]
            output = model(data)
            probs = torch.sigmoid(output).cpu().numpy()
            predictions.append(probs)
    
    predictions = np.concatenate(predictions, axis=0)
    
    # Create submission DataFrame
    submission = pd.read_csv('sample_submission.csv')
    submission.iloc[:, 1:] = predictions
    
    # Ensure probabilities sum to 1 for each condition
    for i in range(0, submission.shape[1] - 1, 3):
        submission.iloc[:, i+1:i+4] = submission.iloc[:, i+1:i+4].div(submission.iloc[:, i+1:i+4].sum(axis=1), axis=0)
    
    submission.to_csv(submission_file, index=False)

# Create test dataset and dataloader
test_dataset = LumbarSpineDataset('test.csv', 'test_images', transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=4)

# Load the trained model
model = LumbarSpineModel().to(device)
model.load_state_dict(torch.load('lumbar_spine_model.pth'))

# Generate submission
generate_submission(model, test_loader, 'submission.csv')