In [None]:
import sys
import os
sys.path.append(os.path.abspath('..')) 

import torch
from torch import nn
from torch.utils.data import DataLoader
import yaml
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

from models.basic_cnn import BasicCNN
from models.pretrained_cnn import PretrainedCNN
from utils.dataset import create_data_loaders


## Loss functions

In [None]:
class FocalLoss(nn.Module):
    """Focal loss computed from per-sample cross-entropy.

    Each sample's cross-entropy ``ce`` is rescaled to
    ``alpha * (1 - exp(-ce)) ** gamma * ce`` and the rescaled values are
    averaged over the batch.

    Args:
        gamma: Exponent applied to ``1 - pt``; larger values shrink the
            contribution of samples the model already classifies well.
        alpha: Constant factor multiplied into every sample's loss.
    """

    def __init__(self, gamma=2.0, alpha=0.25):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Return the scalar mean focal loss of logits ``inputs`` vs ``targets``."""
        per_sample_ce = nn.functional.cross_entropy(inputs, targets, reduction='none')
        # exp(-CE) is the probability the model assigned to the target class.
        target_prob = torch.exp(-per_sample_ce)
        modulator = (1 - target_prob) ** self.gamma
        weighted = self.alpha * modulator * per_sample_ce
        return weighted.mean()

class BalancedCrossEntropyLoss(nn.Module):
    """Cross-entropy loss with optional per-class weights.

    Args:
        weight: ``None`` for unweighted cross-entropy, or a 1-D tensor /
            sequence of numbers holding one weight per class. Tensors are
            used as given; other sequences are converted to a tensor of the
            default floating-point dtype.
    """

    def __init__(self, weight=None):
        super().__init__()
        if weight is not None and not torch.is_tensor(weight):
            weight = torch.as_tensor(weight, dtype=torch.get_default_dtype())
        # A buffer (unlike a plain attribute) follows the module through
        # .to()/.cuda(); persistent=False keeps state_dict() unchanged.
        self.register_buffer('weight', weight, persistent=False)

    def forward(self, inputs, targets):
        """Return the mean (weighted) cross-entropy of logits ``inputs`` vs ``targets``."""
        weight = self.weight
        # Guard against a criterion that was never moved alongside the model:
        # a CPU weight with CUDA logits would otherwise raise a device error.
        if weight is not None and weight.device != inputs.device:
            weight = weight.to(inputs.device)
        return nn.functional.cross_entropy(inputs, targets, weight=weight)

## Trainer

In [None]:
class ClassifierTrainer:
    """Train and validate a classification model with the Adam optimizer.

    The model, the loss function (including any class-weight tensors) and
    every batch are placed on CUDA when it is available, otherwise on CPU.

    Args:
        model: ``nn.Module`` mapping a batch of inputs to class logits.
        train_loader: Iterable yielding ``(inputs, labels)`` training batches.
        val_loader: Iterable yielding ``(inputs, labels)`` validation batches.
        loss: ``'focal'`` for ``FocalLoss``, ``'wce'`` for
            ``BalancedCrossEntropyLoss``; any other value selects
            ``nn.CrossEntropyLoss``. Hyperparameters come from ``self.losses``.
        lr: Learning rate passed to Adam.
    """

    def __init__(self, model, train_loader, val_loader, loss="ce", lr=0.001):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Device: {self.device}")
        self.losses = {
            'focal': {'gamma': 2.0, 'alpha': 0.25},
            'wce': {'weight': [1.0, 2.0, 1.5]},  # Example weights
            'ce': {'weight': [1.0, 1.0, 1.0]}  # Example weights
        }

        # Initialize loss function
        if loss == 'focal':
            self.criterion = FocalLoss(
                gamma=self.losses['focal']['gamma'],
                alpha=self.losses['focal']['alpha']
            )
        elif loss == 'wce':
            self.criterion = BalancedCrossEntropyLoss(
                weight=self._class_weights('wce')
            )
        else:  # CE loss
            self.criterion = nn.CrossEntropyLoss(
                weight=self._class_weights('ce')
            )
        # Keep any buffers registered on the criterion next to the model;
        # otherwise a CPU weight meets CUDA logits and the loss call fails.
        self.criterion.to(self.device)

        self.model.to(self.device)
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=lr
        )

    def _class_weights(self, loss_name):
        """Return the configured class weights for ``loss_name`` as a tensor on ``self.device``."""
        return torch.tensor(self.losses[loss_name]['weight'], device=self.device)

    def _prepare_batch(self, inputs, labels):
        """Move a batch to ``self.device``, casting inputs to float and labels to long."""
        return inputs.to(self.device).float(), labels.to(self.device).long()

    def train_epoch(self):
        """Run one optimisation pass over ``train_loader``.

        Returns:
            Mean per-batch training loss, or NaN if the loader yielded no batches.
        """
        self.model.train()
        running_loss = 0.0
        num_batches = 0

        for inputs, labels in tqdm(self.train_loader):
            inputs, labels = self._prepare_batch(inputs, labels)

            self.optimizer.zero_grad()
            outputs = self.model(inputs)

            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        return running_loss / num_batches if num_batches else float('nan')

    def validate(self):
        """Evaluate the model on ``val_loader`` without tracking gradients.

        Returns:
            ``(mean per-batch loss, accuracy)``; each is NaN when the loader
            yielded no batches / no samples.
        """
        self.model.eval()
        val_loss = 0.0
        num_batches = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, labels in self.val_loader:
                # Same casts as train_epoch so both phases accept the same batches.
                inputs, labels = self._prepare_batch(inputs, labels)

                outputs = self.model(inputs)
                val_loss += self.criterion(outputs, labels).item()
                num_batches += 1

                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        mean_loss = val_loss / num_batches if num_batches else float('nan')
        accuracy = correct / total if total else float('nan')
        return mean_loss, accuracy

    def train(self, epochs):
        """Train for ``epochs`` epochs, printing train/validation metrics after each."""
        for epoch in range(epochs):
            train_loss = self.train_epoch()
            val_loss, val_acc = self.validate()
            print(f"Epoch {epoch+1}/{epochs}")
            print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")



In [None]:
# Run configuration: everything a reader might tune lives here.
data_dir = "../datasets/"
batch_size = 4
epochs = 50
loss = 'ce'  # 'ce', 'wce' or 'focal'
seed = 42

# Seed torch before models/loaders are built so weight initialisation is
# repeatable across Restart & Run All (loader shuffling too, if it draws
# from torch's global RNG; verify in create_data_loaders).
torch.manual_seed(seed)

# Create data loaders
train_loader, val_loader = create_data_loaders(data_dir, batch_size)

# Initialize models
basic_model = BasicCNN()
# pretrained_model = PretrainedCNN(model_name='resnet18', num_classes=3)

# Initialize trainers (pass the configured loss instead of relying on the default)
basic_trainer = ClassifierTrainer(basic_model, train_loader, val_loader, loss=loss)
# pretrained_trainer = ClassifierTrainer(pretrained_model, train_loader, val_loader, loss=loss)

# Train models for the configured number of epochs
basic_trainer.train(epochs=epochs)
# pretrained_trainer.train(epochs=epochs)