In [4]:
# injury_classifier.py
import os
import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from sklearn.metrics import accuracy_score, roc_auc_score
from tqdm import tqdm

In [5]:
# All available dataset roots for reference. dataset_2 (stored one level deeper, in dataset2/dataset_2) is the one not included in CONFIG["dataset_paths"] below.
# "C:/Users/sanke/OneDrive/Desktop/The Stray help/dataset_1", "C:/Users/sanke/OneDrive/Desktop/The Stray help/dataset2/dataset_2","C:/Users/sanke/OneDrive/Desktop/The Stray help/dataset_3","C:/Users/sanke/OneDrive/Desktop/The Stray help/dataset_4","C:/Users/sanke/OneDrive/Desktop/The Stray help/dataset_5","C:/Users/sanke/OneDrive/Desktop/The Stray help/dataset_6"

In [6]:
# Root folder holding the dataset_N directories. Set the STRAY_HELP_DATA_ROOT
# environment variable to point elsewhere instead of editing every path.
DATA_ROOT = os.environ.get(
    "STRAY_HELP_DATA_ROOT", "C:/Users/sanke/OneDrive/Desktop/The Stray help"
)

CONFIG = {
    # Each root must contain train/, valid/ and test/ splits, each with
    # images/ and labels/ sub-folders. dataset_2 is not currently included.
    "dataset_paths": [f"{DATA_ROOT}/dataset_{n}" for n in (1, 3, 4, 5, 6)],
    "image_size": 224,  # images are resized to image_size x image_size
    "batch_size": 16,
    "num_epochs": 10,
    # Keep at 0 when running inside Jupyter on Windows. There, DataLoader
    # workers are started with "spawn" and cannot import classes defined in
    # notebook cells (e.g. InjuryDataset), so num_workers > 0 hangs on the
    # first batch. Raise it only when running as a .py script under an
    # `if __name__ == "__main__":` guard.
    "num_workers": 0,
    "lr": 1e-4,
    "device": "cpu",
}

In [7]:
# Training-time preprocessing: square resize, light random augmentation
# (horizontal flip, brightness/contrast jitter), then tensor conversion and
# normalisation with the ImageNet channel mean/std that torchvision's
# pretrained backbones expect.
train_transform = transforms.Compose(
    [
        transforms.Resize(size=(CONFIG["image_size"],) * 2),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


In [8]:
# Evaluation-time preprocessing: the deterministic steps of train_transform
# (square resize, tensor conversion, ImageNet normalisation) with no random
# augmentation, so validation/test inputs are identical on every pass.
val_transform = transforms.Compose(
    [
        transforms.Resize(size=(CONFIG["image_size"],) * 2),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [9]:
class InjuryDataset(Dataset):
    """Image-classification dataset read from ``images/`` + ``labels/`` folders.

    NOTE(review): this class is redefined in a later cell, and that later
    definition is the one ``create_loaders()`` uses when the notebook runs top
    to bottom. Keep only one definition.

    Each directory in ``split_dirs`` must contain an ``images/`` folder and a
    ``labels/`` folder. An image is kept only if ``labels/<stem>.txt`` exists
    and is non-empty. Its label is the first whitespace-separated token of that
    file, parsed as an int. For YOLO-style label files this is the class id of
    the first listed object; parsing the whole file with ``int()`` would fail
    on such lines.

    Args:
        split_dirs: iterable of split directories (e.g. ``<root>/train``).
        transform: optional callable applied to the PIL image in ``__getitem__``.
    """

    def __init__(self, split_dirs, transform=None):
        self.samples = []  # list of (image_path, int_label)
        self.transform = transform

        for dataset_path in split_dirs:
            img_dir = os.path.join(dataset_path, "images")
            label_dir = os.path.join(dataset_path, "labels")

            # sorted() makes sample order independent of filesystem listing order.
            for img_file in sorted(os.listdir(img_dir)):
                stem = os.path.splitext(img_file)[0]
                label = self._read_label(os.path.join(label_dir, f"{stem}.txt"))
                if label is not None:
                    self.samples.append((os.path.join(img_dir, img_file), label))

    @staticmethod
    def _read_label(label_path):
        """Return the first integer token of ``label_path``, or None if missing/empty."""
        if not os.path.exists(label_path):
            return None
        with open(label_path, "r") as f:
            tokens = f.read().split()
        # NOTE(review): empty label files (no annotated object) are skipped;
        # confirm whether they should instead be treated as a negative class.
        return int(tokens[0]) if tokens else None

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)``; label is a float32 scalar tensor for BCELoss."""
        img_path, label = self.samples[idx]

        # The context manager releases the file handle once convert() has copied the pixels.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.float32)

In [10]:
def _make_loader(dataset, **loader_kwargs):
    """Wrap ``dataset`` in a DataLoader using the batch/worker settings from CONFIG."""
    return DataLoader(
        dataset,
        batch_size=CONFIG['batch_size'],
        num_workers=CONFIG['num_workers'],
        # Pinned memory only speeds up host-to-GPU copies; with a CPU device it
        # is useless and PyTorch warns about it.
        pin_memory=str(CONFIG['device']).startswith('cuda'),
        **loader_kwargs,
    )


def _check_binary_labels(dataset, split_name):
    """Raise ValueError if ``dataset`` holds labels other than 0 and 1.

    InjuryClassifier emits a single sigmoid probability trained with BCELoss,
    so any other class id would give a meaningless loss and metrics.
    """
    unexpected = sorted({label for _, label in dataset.samples} - {0, 1})
    if unexpected:
        raise ValueError(
            f"{split_name} split has labels {unexpected}; "
            "this binary classifier only supports 0 and 1."
        )


def create_loaders():
    """Build train/valid/test DataLoaders over every root in CONFIG['dataset_paths'].

    Each root must contain ``train``, ``valid`` and ``test`` split folders.
    Training batches are drawn by a WeightedRandomSampler (with replacement)
    using inverse class frequencies as weights, so each class gets an equal
    expected share of the samples. Validation and test loaders keep their
    natural order.

    Returns:
        (train_loader, val_loader, test_loader)

    Raises:
        ValueError: if the training split is empty, or any split contains
            labels outside {0, 1}.
    """
    def split_dirs(split):
        return [os.path.join(root, split) for root in CONFIG["dataset_paths"]]

    train_dataset = InjuryDataset(split_dirs("train"), train_transform)
    val_dataset = InjuryDataset(split_dirs("valid"), val_transform)
    test_dataset = InjuryDataset(split_dirs("test"), val_transform)

    if len(train_dataset) == 0:
        raise ValueError(f"No labelled training images found under {split_dirs('train')}")
    for name, dataset in (("train", train_dataset), ("valid", val_dataset), ("test", test_dataset)):
        _check_binary_labels(dataset, name)

    # Class balancing: weight each sample by 1 / (number of samples in its class).
    # Only classes that actually occur are looked up, so no division by zero.
    train_labels = [label for _, label in train_dataset.samples]
    class_counts = np.bincount(train_labels)
    sample_weights = [1.0 / class_counts[label] for label in train_labels]

    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True,
    )

    train_loader = _make_loader(train_dataset, sampler=sampler)
    val_loader = _make_loader(val_dataset, shuffle=False)
    test_loader = _make_loader(test_dataset, shuffle=False)

    return train_loader, val_loader, test_loader

In [11]:
class InjuryClassifier(nn.Module):
    """ImageNet-pretrained ResNet-50 with a small MLP head that outputs one probability per image.

    ``forward`` applies a sigmoid and returns values in (0, 1) with shape
    ``(batch, 1)``, which is what ``nn.BCELoss`` in ``train()`` expects.

    Args:
        pretrained: load ImageNet weights for the backbone (default True, as before).
        hidden_dim: width of the hidden layer in the classification head.
        dropout: dropout probability applied after the hidden layer.
    """

    def __init__(self, pretrained=True, hidden_dim=512, dropout=0.3):
        super().__init__()
        self.backbone = self._build_resnet50(pretrained)
        # Read the feature width from the stock fc layer instead of hardcoding 2048.
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )

    @staticmethod
    def _build_resnet50(pretrained):
        """Create ResNet-50, using the weights-enum API when this torchvision provides it."""
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the boolean flag.
            return models.resnet50(pretrained=pretrained)
        # IMAGENET1K_V1 is the checkpoint the deprecated pretrained=True loaded.
        return models.resnet50(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)

    def forward(self, x):
        return torch.sigmoid(self.backbone(x))


In [12]:
class InjuryDataset(Dataset):
    """Image-classification dataset read from ``images/`` + ``labels/`` folders.

    Each directory in ``split_dirs`` must contain an ``images/`` folder and a
    ``labels/`` folder. An image is kept only if ``labels/<stem>.txt`` exists
    and is non-empty. Its label is the first whitespace-separated token of that
    file, parsed as an int; for YOLO-style label files this is the class id of
    the first listed object.

    Args:
        split_dirs: iterable of split directories (e.g. ``<root>/train``).
        transform: optional callable applied to the PIL image in ``__getitem__``.
    """

    def __init__(self, split_dirs, transform=None):
        self.samples = []  # list of (image_path, int_label)
        self.transform = transform

        for dataset_path in split_dirs:
            img_dir = os.path.join(dataset_path, "images")
            label_dir = os.path.join(dataset_path, "labels")

            # sorted() makes sample order independent of filesystem listing order.
            for img_file in sorted(os.listdir(img_dir)):
                stem = os.path.splitext(img_file)[0]
                label = self._read_label(os.path.join(label_dir, f"{stem}.txt"))
                if label is not None:
                    self.samples.append((os.path.join(img_dir, img_file), label))

    @staticmethod
    def _read_label(label_path):
        """Return the first integer token of ``label_path``, or None if missing/empty."""
        if not os.path.exists(label_path):
            return None
        with open(label_path, "r") as f:
            tokens = f.read().split()
        # An empty file previously crashed with IndexError; it is now skipped.
        # NOTE(review): confirm whether empty label files (no annotated object)
        # should instead be treated as a negative class.
        return int(tokens[0]) if tokens else None

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)``; label is a float32 scalar tensor for BCELoss."""
        img_path, label = self.samples[idx]

        # The context manager releases the file handle once convert() has copied the pixels.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.float32)


In [None]:
def _train_one_epoch(model, loader, optimizer, criterion, device, desc):
    """Run one optimisation pass over ``loader`` and return the mean per-sample loss.

    Raises:
        RuntimeError: if a batch produces a NaN/inf loss. The check runs before
            ``backward()`` so a bad loss never reaches the weights.
    """
    model.train()
    loss_sum = 0.0
    n_seen = 0
    for images, labels in tqdm(loader, desc=desc):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        # view(-1) rather than squeeze(): a final batch of one image must stay
        # shape (1,) to match `labels`; squeeze() would make it 0-d and BCELoss
        # would reject the size mismatch.
        outputs = model(images).view(-1)
        loss = criterion(outputs, labels)
        if not torch.isfinite(loss):
            raise RuntimeError(f"Training loss is not finite: {loss.item()}")
        loss.backward()
        optimizer.step()

        loss_sum += loss.item() * images.size(0)
        n_seen += images.size(0)
    return loss_sum / max(n_seen, 1)


def _evaluate(model, loader, device):
    """Return ``(roc_auc, accuracy)`` of ``model`` on ``loader`` at a 0.5 threshold.

    Both metrics are NaN for an empty loader. AUC alone is NaN when the loader
    contains a single class, because ROC AUC is undefined in that case.
    """
    model.eval()
    probs, targets = [], []
    with torch.no_grad():
        for images, labels in loader:
            # view(-1) keeps a 1-D result even for a size-1 batch.
            probs.extend(model(images.to(device)).view(-1).cpu().tolist())
            targets.extend(labels.view(-1).tolist())

    if not targets:
        return float("nan"), float("nan")

    probs = np.asarray(probs)
    targets = np.asarray(targets)
    preds = (probs >= 0.5).astype(targets.dtype)
    acc = accuracy_score(targets, preds)
    auc = roc_auc_score(targets, probs) if len(np.unique(targets)) > 1 else float("nan")
    return auc, acc


def train():
    """Train InjuryClassifier, keep the best-validation-AUC weights, then score them on the test split.

    All hyper-parameters come from ``CONFIG``. The state dict with the highest
    validation ROC AUC is written to ``best_model.pth``. After the last epoch
    those weights are reloaded and evaluated once on the test loader.
    """
    device = CONFIG['device']
    checkpoint_path = "best_model.pth"

    model = InjuryClassifier().to(device)
    optimizer = optim.Adam(model.parameters(), lr=CONFIG['lr'])
    # The scheduler is fed validation AUC (higher is better), so it must use
    # mode='max'; 'min' would treat every AUC improvement as a plateau.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=3)
    # InjuryClassifier.forward already applies a sigmoid, so plain BCELoss is correct.
    criterion = nn.BCELoss()

    train_loader, val_loader, test_loader = create_loaders()

    print(f"Number of batches in train_loader: {len(train_loader)}")

    best_auc = float('-inf')
    for epoch in range(CONFIG['num_epochs']):
        print(f"Starting epoch {epoch+1}")

        train_loss = _train_one_epoch(
            model, train_loader, optimizer, criterion, device, desc=f"Epoch {epoch+1}"
        )
        val_auc, val_acc = _evaluate(model, val_loader, device)

        print(f"Epoch {epoch+1}/{CONFIG['num_epochs']}")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val AUC: {val_auc:.4f} | Val Acc: {val_acc:.4f}")

        if np.isnan(val_auc):
            # Undefined AUC (e.g. single-class val set): nothing to schedule on or compare.
            continue
        scheduler.step(val_auc)

        if val_auc > best_auc:
            best_auc = val_auc
            torch.save(model.state_dict(), checkpoint_path)

    if best_auc == float('-inf'):
        print("No checkpoint saved (validation AUC was never defined); skipping test evaluation.")
        return

    # Score the best checkpoint, not the last epoch's weights, on the held-out test split.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    test_auc, test_acc = _evaluate(model, test_loader, device)
    print(f"Best Val AUC: {best_auc:.4f}")
    print(f"Test AUC: {test_auc:.4f} | Test Acc: {test_acc:.4f}")

# Run the training function
if __name__ == "__main__":
    train()




Number of batches in train_loader: 212
Starting epoch 1


Epoch 1:   0%|          | 0/212 [00:00<?, ?it/s]