In [4]:

import os
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from tqdm import tqdm  # Import tqdm for the progress bar

# Import configuration settings
import config

# Define a PyTorch Dataset
class SmokeAlarmDataset(Dataset):
    """Binary-labelled image dataset driven by a DataFrame.

    Each row of ``df`` must provide the columns ``job_no``, ``image_name`` and
    ``image_status``. The image for a row is read from
    ``<image_dir>/<job_no>/<image_name>`` and converted to RGB.

    Args:
        df: DataFrame with one row per image.
        image_dir: Root directory containing one sub-directory per job number.
        transform: Optional callable applied to the loaded RGB PIL image.

    ``__getitem__`` returns ``(image, label)``, where ``label`` is 1 for
    "Approved" and 0 for "Declined" or any other/missing status. When the image
    file does not exist it prints a warning and returns ``(None, None)``, so the
    dataset must be paired with a collate function that drops such samples.
    """

    def __init__(self, df, image_dir, transform=None):
        self.df = df
        self.image_dir = image_dir
        self.transform = transform
        # Any status not listed here falls back to 0 in __getitem__.
        self.label_map = {"Approved": 1, "Declined": 0}

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Fetch the row once instead of re-indexing the DataFrame per column.
        row = self.df.iloc[idx]
        # job_no is used as a directory name, so it must be a string.
        # NOTE(review): str() of an int drops leading zeros; if job folders are
        # zero-padded, df should be loaded with job_no as a text column.
        job_no = str(row["job_no"])
        image_name = row["image_name"]
        # Unknown or missing statuses are treated as 0 ("Declined").
        label = self.label_map.get(row["image_status"], 0)

        image_path = os.path.join(self.image_dir, job_no, image_name)

        if not os.path.exists(image_path):
            print(f"Warning: Image not found at {image_path}. Skipping this file.")
            return None, None

        # Open via a context manager so the file handle is released
        # deterministically; convert() returns an independent loaded image.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label

# Custom collate function to handle missing files
def collate_fn(batch):
    """Collate a batch, dropping samples whose image could not be loaded.

    The dataset returns ``(None, None)`` for missing image files; those entries
    are filtered out before delegating to PyTorch's default collate function.

    Args:
        batch: List of ``(image, label)`` samples produced by the dataset.

    Returns:
        The result of ``default_collate`` on the remaining samples.

    Raises:
        RuntimeError: If every sample in the batch was missing. Without this
            check, ``default_collate`` would be called on an empty list and
            fail with an opaque error.
    """
    valid = [item for item in batch if item[0] is not None]
    if not valid:
        raise RuntimeError(
            f"All {len(batch)} samples in this batch were missing; "
            "check that the image directory and CSV entries are consistent."
        )
    return torch.utils.data.dataloader.default_collate(valid)

# Define transformations for data preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize every image to 224x224
    transforms.ToTensor(),
    transforms.Normalize(mean=config.NORMALIZE_MEAN, std=config.NORMALIZE_STD),
])

# Load the CSV file.
# NOTE(review): if job numbers are zero-padded folder names, consider
# pd.read_csv(..., dtype={"job_no": str}) so leading zeros are preserved.
df = pd.read_csv(config.CSV_FILE_PATH)

# Initialize the dataset and dataloader with the custom collate function,
# which drops samples whose image file is missing.
dataset = SmokeAlarmDataset(df, config.IMAGE_DIR, transform=transform)

# Optional config.NUM_WORKERS enables parallel image loading; it defaults to 0
# (load in the main process, the previous behavior).
# NOTE(review): worker processes may fail to pickle classes defined inside a
# notebook on platforms that use the "spawn" start method.
dataloader = DataLoader(
    dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=getattr(config, "NUM_WORKERS", 0),
)

# Define the model
class SimpleCNN(nn.Module):
    """Pretrained torchvision backbone with a single-output sigmoid head.

    The architecture is looked up by name from ``config.MODEL_NAME`` and built
    with ``pretrained=True``. Its ``fc`` layer is replaced by a linear layer
    with one output, and ``forward`` applies a sigmoid to that output.

    NOTE(review): this assumes the chosen model exposes its final layer as
    ``.fc`` (e.g. ResNet-style models); architectures with a differently named
    head would raise AttributeError here. ``pretrained=`` is deprecated in
    newer torchvision releases in favour of ``weights=`` — confirm the
    installed version.
    """

    def __init__(self):
        super().__init__()
        backbone_factory = getattr(models, config.MODEL_NAME)
        backbone = backbone_factory(pretrained=True)
        # Replace the classification head with a single-unit output layer.
        num_features = backbone.fc.in_features
        backbone.fc = nn.Linear(num_features, 1)
        self.model = backbone

    def forward(self, x):
        logits = self.model(x)
        return torch.sigmoid(logits)

model = SimpleCNN()

# Device placement is opt-in via config.DEVICE (e.g. "cuda"). The default "cpu"
# keeps the model where it was created, so later cells see the same placement.
device = torch.device(getattr(config, "DEVICE", "cpu"))
model = model.to(device)

# Define loss and optimizer. BCELoss pairs with SimpleCNN's sigmoid output.
criterion = nn.BCELoss()
# Build the optimizer after moving the model, as torch.optim recommends.
optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE)

# Training loop with tqdm progress bar
for epoch in range(config.NUM_EPOCHS):
    model.train()
    running_loss = 0.0  # Sum of per-sample losses over the epoch
    num_samples = 0     # Samples actually trained on (missing images dropped)

    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{config.NUM_EPOCHS}", leave=False)
    for images, labels in progress_bar:
        images = images.to(device)
        # Float targets with a trailing unit dim to match the 1-output head.
        labels = labels.unsqueeze(1).float().to(device)

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        # BCELoss averages over the batch by default. Weight by batch size so
        # the epoch mean is not skewed by a short final batch or by batches
        # that collate_fn shrank after dropping missing images.
        batch_loss = loss.item()
        batch_size = images.size(0)
        running_loss += batch_loss * batch_size
        num_samples += batch_size

        # Update progress bar with the current batch loss.
        progress_bar.set_postfix(loss=batch_loss)

    # Guard against an empty dataloader instead of dividing by zero.
    epoch_loss = running_loss / num_samples if num_samples else float("nan")
    print(f"Epoch [{epoch + 1}/{config.NUM_EPOCHS}], Loss: {epoch_loss}")

print("Training complete.")


Epoch 1/5:   2%|▏         | 8/385 [00:20<16:13,  2.58s/it, loss=0.435]



Epoch 1/5:   9%|▉         | 35/385 [01:29<14:34,  2.50s/it, loss=0.212]



Epoch 1/5:  12%|█▏        | 48/385 [02:02<14:12,  2.53s/it, loss=0.334]



Epoch 1/5:  20%|██        | 78/385 [03:18<14:51,  2.90s/it, loss=0.22] 



Epoch 1/5:  69%|██████▉   | 267/385 [11:33<04:53,  2.48s/it, loss=0.286] 



Epoch 1/5:  78%|███████▊  | 300/385 [12:55<03:30,  2.47s/it, loss=0.416]



                                                                        

Epoch [1/5], Loss: 0.3561022961488018


Epoch 2/5:   7%|▋         | 27/385 [01:05<14:17,  2.40s/it, loss=0.229] 



Epoch 2/5:  20%|█▉        | 76/385 [03:14<12:39,  2.46s/it, loss=0.312] 



Epoch 2/5:  57%|█████▋    | 220/385 [09:03<06:40,  2.43s/it, loss=0.384]



Epoch 2/5:  64%|██████▎   | 245/385 [10:01<05:22,  2.30s/it, loss=0.2]  



                                                                        

Epoch [2/5], Loss: 0.31672789707973403


Epoch 3/5:   9%|▉         | 34/385 [01:25<14:44,  2.52s/it, loss=0.357]



Epoch 3/5:  21%|██▏       | 82/385 [03:31<13:41,  2.71s/it, loss=0.356]



Epoch 3/5:  65%|██████▍   | 250/385 [11:05<05:42,  2.54s/it, loss=0.33] 



Epoch 3/5:  71%|███████   | 272/385 [12:02<04:51,  2.58s/it, loss=0.219]



Epoch 3/5:  72%|███████▏  | 277/385 [12:14<04:35,  2.55s/it, loss=0.266]



Epoch 3/5:  79%|███████▉  | 305/385 [13:26<03:29,  2.61s/it, loss=0.387]



                                                                        

Epoch [3/5], Loss: 0.3028861344828234


Epoch 4/5:  17%|█▋        | 66/385 [02:49<13:20,  2.51s/it, loss=0.142]



Epoch 4/5:  38%|███▊      | 146/385 [06:13<10:05,  2.53s/it, loss=0.983]



Epoch 4/5:  42%|████▏     | 163/385 [06:56<09:22,  2.53s/it, loss=0.188]



Epoch 4/5:  70%|██████▉   | 269/385 [11:28<04:57,  2.56s/it, loss=0.256]



Epoch 4/5:  86%|████████▌ | 331/385 [14:07<02:18,  2.56s/it, loss=0.229]



Epoch 4/5:  95%|█████████▍| 365/385 [15:34<00:51,  2.57s/it, loss=0.26] 



                                                                        

Epoch [4/5], Loss: 0.2895987347542466


Epoch 5/5:  10%|█         | 40/385 [01:43<14:37,  2.54s/it, loss=0.34] 



Epoch 5/5:  13%|█▎        | 50/385 [02:09<14:28,  2.59s/it, loss=0.382]



Epoch 5/5:  36%|███▌      | 139/385 [05:59<10:53,  2.66s/it, loss=0.227]



Epoch 5/5:  55%|█████▌    | 213/385 [09:10<07:12,  2.52s/it, loss=0.246]



Epoch 5/5:  61%|██████    | 234/385 [10:04<06:27,  2.57s/it, loss=0.1]  



Epoch 5/5:  82%|████████▏ | 314/385 [13:28<03:01,  2.56s/it, loss=0.129] 



                                                                         

Epoch [5/5], Loss: 0.27730527924639836
Training complete.


