# FCN for Dendrite Segmentation

### Set Proxy Environment Variables

In [None]:
import os

# Route HTTP and HTTPS traffic from this process through the local proxy.
# Presumably needed so the pretrained-weight download in the model cell can
# reach the internet; TODO confirm the proxy host/port for your environment.
os.environ.update({
    'http_proxy': 'http://proxy:80',
    'https_proxy': 'http://proxy:80',
})

### Import Required Libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
from torchvision.models.segmentation import fcn_resnet50
from PIL import Image
import os
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.metrics import accuracy_score, precision_score, recall_score
from torchvision.models.segmentation import FCN_ResNet50_Weights

### Define Custom Dendrite Dataset Class

In [None]:
class DendritesDataset(Dataset):
    """Paired (image, binary mask) dataset for dendrite segmentation.

    ``root`` must contain two sub-directories, ``input_images`` and
    ``dendrite_images``. The regular files in each are sorted by name and
    paired by position: the i-th image goes with the i-th mask.

    Args:
        root: Dataset root directory.
        transforms: Optional callable applied separately to the image and to
            the mask; it must return a tensor. Because it is called twice
            independently, it should be deterministic (a random augmentation
            would be drawn differently for image and mask). When ``None``, both
            are converted to float CHW tensors in [0, 1], like ``ToTensor``.

    Raises:
        ValueError: if the two directories hold a different number of files,
            which would otherwise silently misalign the positional pairing.

    Note: the mask is binarized with ``> 0`` *after* the transform, so an
    interpolating resize in ``transforms`` slightly widens mask boundaries.
    """

    def __init__(self, root, transforms=None):
        self.root = root
        self.transforms = transforms

        self.images = self._list_files("input_images")
        self.masks = self._list_files("dendrite_images")
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} input images but {len(self.masks)} "
                f"masks under {root!r}; every input image needs exactly one mask."
            )

    def _list_files(self, subdir):
        """Return the sorted names of the regular files in ``root/subdir``."""
        directory = os.path.join(self.root, subdir)
        return sorted(
            name for name in os.listdir(directory)
            if os.path.isfile(os.path.join(directory, name))
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root, "input_images", self.images[idx])
        mask_path = os.path.join(self.root, "dendrite_images", self.masks[idx])

        # Context managers close the underlying file handles deterministically.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        with Image.open(mask_path) as msk:
            mask = msk.convert("L")

        if self.transforms is not None:
            image = self.transforms(image)
            mask = self.transforms(mask)
        else:
            image = self._pil_to_tensor(image)
            mask = self._pil_to_tensor(mask)

        # Binarize unconditionally so every sample yields a {0., 1.} mask,
        # whether or not a transform was supplied.
        mask = (mask > 0).float()
        return image, mask

    @staticmethod
    def _pil_to_tensor(pil_image):
        """Convert a PIL image to a float32 CHW tensor scaled to [0, 1]."""
        array = np.asarray(pil_image, dtype=np.float32) / 255.0
        if array.ndim == 2:  # single-channel ("L") image -> add channel axis
            array = array[:, :, None]
        return torch.from_numpy(array).permute(2, 0, 1).contiguous()

### Define Transformations for Dataset

In [None]:
# Deterministic preprocessing shared by both splits: resize to 256x256, then
# ToTensor (PIL image -> float CHW tensor scaled to [0, 1]).
# NOTE(review): the dataset applies this same pipeline to the masks and then
# thresholds them at > 0. torchvision's Resize interpolates (bilinear by default
# for PIL input), which can slightly widen mask edges; consider a separate
# nearest-neighbour resize for masks.
train_transform = transforms.Compose(
    [transforms.Resize(size=(256, 256)), transforms.ToTensor()]
)
val_transform = transforms.Compose(
    [transforms.Resize(size=(256, 256)), transforms.ToTensor()]
)

### Create Datasets and DataLoaders

In [None]:
# Dataset roots; each must contain `input_images/` and `dendrite_images/`.
root_train = 'Dataset/DeepD3_Training'
root_val = 'Dataset/DeepD3_Validation'

# Training batches are reshuffled every epoch; validation keeps a fixed order.
train_loader = DataLoader(
    dataset=DendritesDataset(root_train, transforms=train_transform),
    batch_size=4, shuffle=True, num_workers=2,
)
val_loader = DataLoader(
    dataset=DendritesDataset(root_val, transforms=val_transform),
    batch_size=4, shuffle=False, num_workers=2,
)

### Initialize Model

In [None]:
# Start from pretrained FCN-ResNet50 weights (COCO images, VOC label set);
# progress=True shows a progress bar while the weights download.
weights = FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1
model = fcn_resnet50(weights=weights, progress=True)

# Replace the final 1x1 convolution of the main classifier head so it emits a
# single logit per pixel (dendrite vs. background) instead of the pretrained
# multi-class output.
# NOTE(review): an auxiliary head (model.aux_classifier), if the pretrained model
# has one, is left unchanged and its output is not used by the training loss.
model.classifier[4] = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1)

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = model.to(device)

### Metrics for Evaluation

In [None]:
def calculate_metrics(outputs, masks, threshold=0.5):
    """Compute pixel-wise binary segmentation metrics for one batch.

    Args:
        outputs: Raw model logits (any shape). A sigmoid is applied and pixels
            with probability > ``threshold`` count as predicted positives.
        masks: Ground-truth tensor with the same number of elements as
            ``outputs``; assumed binary (0/1), values > 0.5 count as positive.
        threshold: Probability cut-off for a positive prediction.

    Returns:
        ``(accuracy, precision, recall, iou)`` as Python floats. Degenerate
        ratios (no positive predictions, no positive pixels, or both empty) are
        reported as 1.0 consistently for all four metrics. This matches the
        ``zero_division=1`` convention for precision/recall and means a correct
        all-background prediction scores IoU 1.0 rather than 0.0.
    """
    preds = (torch.sigmoid(outputs.detach()) > threshold).cpu().numpy().ravel()
    target = (masks.detach() > 0.5).cpu().numpy().ravel()

    # One pass of confusion-matrix counts feeds every metric below.
    tp = int(np.count_nonzero(preds & target))
    fp = int(np.count_nonzero(preds & ~target))
    fn = int(np.count_nonzero(~preds & target))
    total = preds.size
    tn = total - tp - fp - fn

    accuracy = (tp + tn) / total if total else 1.0
    precision = tp / (tp + fp) if tp + fp else 1.0
    recall = tp / (tp + fn) if tp + fn else 1.0
    union = tp + fp + fn
    iou = tp / union if union else 1.0
    return accuracy, precision, recall, iou

### Training Function

In [None]:
def train_fn(data_loader, model, criterion, optimizer, epoch, num_epochs):
    """Run one training epoch and return sample-weighted epoch averages.

    Each batch is moved to the module-level ``device``. The function takes one
    optimizer step on ``criterion(model(images)['out'], masks)`` and accumulates
    the loss and the ``calculate_metrics`` results, each weighted by batch size.

    Returns:
        ``(loss, accuracy, precision, recall, iou)``, each divided by
        ``len(data_loader.dataset)`` (an empty dataset raises ZeroDivisionError).
        Metrics use the outputs of the forward pass made before the step.
    """
    model.train()
    # Running sums, in return order: loss, accuracy, precision, recall, IoU.
    sums = [0.0, 0.0, 0.0, 0.0, 0.0]
    bar = tqdm(data_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training")

    for batch_images, batch_masks in bar:
        batch_images = batch_images.to(device)
        batch_masks = batch_masks.to(device).float()

        optimizer.zero_grad()
        logits = model(batch_images)['out']
        batch_loss = criterion(logits, batch_masks)
        batch_loss.backward()
        optimizer.step()

        weight = batch_images.size(0)
        batch_values = (batch_loss.item(), *calculate_metrics(logits, batch_masks))
        sums = [acc + value * weight for acc, value in zip(sums, batch_values)]

        bar.set_postfix(loss=batch_loss.item())

    n_samples = len(data_loader.dataset)
    return tuple(acc / n_samples for acc in sums)

### Evaluation Function

In [None]:
def eval_fn(data_loader, model, criterion, epoch, num_epochs):
    """Evaluate ``model`` for one epoch without gradient tracking.

    The model is put in eval mode and every batch is moved to the module-level
    ``device``. Loss and ``calculate_metrics`` results are accumulated,
    weighted by batch size.

    Returns:
        ``(loss, accuracy, precision, recall, iou)``, each divided by
        ``len(data_loader.dataset)`` (an empty dataset raises ZeroDivisionError).
    """
    model.eval()
    loss_sum = acc_sum = prec_sum = rec_sum = iou_sum = 0.0
    bar = tqdm(data_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation")

    with torch.no_grad():
        for batch_images, batch_masks in bar:
            batch_images = batch_images.to(device)
            batch_masks = batch_masks.to(device).float()

            logits = model(batch_images)['out']
            batch_loss = criterion(logits, batch_masks)

            weight = batch_images.size(0)
            loss_sum += batch_loss.item() * weight
            acc, prec, rec, iou = calculate_metrics(logits, batch_masks)
            acc_sum += acc * weight
            prec_sum += prec * weight
            rec_sum += rec * weight
            iou_sum += iou * weight
            bar.set_postfix(loss=batch_loss.item())

    n_samples = len(data_loader.dataset)
    return (
        loss_sum / n_samples,
        acc_sum / n_samples,
        prec_sum / n_samples,
        rec_sum / n_samples,
        iou_sum / n_samples,
    )

### Set Loss Function, Optimizer, and LR Scheduler

In [None]:
# The model emits one raw logit per pixel, so BCE is applied on logits directly.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)
# mode='min': halve the LR (factor=0.5) once the value passed to step() has not
# improved for 3 consecutive calls (patience=3).
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)

In [None]:
# Training length and on-disk locations for resumable training state.
EPOCHS = 100
CHECKPOINT_DIR = 'checkpoints_dendrites'
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, 'model_checkpoint.pth')  # model/optimizer state
METRICS_PATH = os.path.join(CHECKPOINT_DIR, 'metrics.npz')  # per-epoch metric history
LOG_PATH = os.path.join(CHECKPOINT_DIR, 'metrics_log.log')  # human-readable log (appended)

# Create checkpoint directory if it doesn't exist
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Lowest validation loss seen so far. Use float('inf') rather than np.Inf: the
# np.Inf alias was removed in NumPy 2.0 and raises AttributeError there.
best_valid_loss = float('inf')

# Per-epoch metric history: one list per (split, metric), appended once per epoch.
metrics = {
    'train_losses': [],
    'valid_losses': [],
    'train_accuracies': [],
    'valid_accuracies': [],
    'train_precisions': [],
    'valid_precisions': [],
    'train_recalls': [],
    'valid_recalls': [],
    'train_ious': [],
    'valid_ious': []
}

def load_checkpoint():
    """Restore training state from ``CHECKPOINT_PATH`` if that file exists.

    Restores, in place:
    - the model and optimizer state;
    - the LR-scheduler state, when the checkpoint contains one;
    - the module-level ``best_valid_loss``;
    - the ``metrics`` history from ``METRICS_PATH``, if that file exists.

    Returns:
        The epoch index to resume from: the saved epoch + 1, or 0 when there
        is no checkpoint.
    """
    global best_valid_loss
    if not os.path.exists(CHECKPOINT_PATH):
        return 0

    # map_location='cpu' lets a checkpoint written on a GPU machine load on a
    # CPU-only one; load_state_dict then copies the tensors onto whatever device
    # the model/optimizer parameters already live on.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Scheduler state is optional: checkpoints written without it still load.
    if 'lr_scheduler_state_dict' in checkpoint:
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
    best_valid_loss = checkpoint['best_valid_loss']

    if os.path.exists(METRICS_PATH):
        # np.load on an .npz keeps the archive open; the context manager closes it.
        with np.load(METRICS_PATH) as loaded_metrics:
            for key in metrics:
                metrics[key] = loaded_metrics[key].tolist()
    return checkpoint['epoch'] + 1

def save_checkpoint(epoch):
    """Persist training state after ``epoch`` and append its metrics to the log.

    Writes three things:
    - model, optimizer and LR-scheduler state plus ``best_valid_loss`` to
      ``CHECKPOINT_PATH``;
    - the whole ``metrics`` history to ``METRICS_PATH``;
    - one human-readable entry to ``LOG_PATH``.

    Every ``metrics`` list must already contain this epoch's value, since the
    log reads ``[-1]``.

    The checkpoint and metrics files are first written under a temporary name
    and then moved into place with ``os.replace``. An interruption mid-write
    therefore leaves the previous files intact instead of corrupting them.
    """
    tmp_checkpoint = CHECKPOINT_PATH + '.tmp'
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'lr_scheduler_state_dict': lr_scheduler.state_dict(),
        'best_valid_loss': best_valid_loss
    }, tmp_checkpoint)
    os.replace(tmp_checkpoint, CHECKPOINT_PATH)

    # np.savez appends '.npz' to string paths lacking it, so write through a
    # file object to keep the temporary file name exact.
    tmp_metrics = METRICS_PATH + '.tmp'
    with open(tmp_metrics, 'wb') as metrics_file:
        np.savez(metrics_file, **metrics)
    os.replace(tmp_metrics, METRICS_PATH)

    # Log metrics to log file
    with open(LOG_PATH, 'a', encoding='utf-8') as logfile:
        logfile.write(f"Epoch {epoch + 1}/{EPOCHS}\n")
        logfile.write(f"Train Loss: {metrics['train_losses'][-1]:.4f} | Acc: {metrics['train_accuracies'][-1]:.4f} | Precision: {metrics['train_precisions'][-1]:.4f} | Recall: {metrics['train_recalls'][-1]:.4f} | IoU: {metrics['train_ious'][-1]:.4f}\n")
        logfile.write(f"Valid Loss: {metrics['valid_losses'][-1]:.4f} | Acc: {metrics['valid_accuracies'][-1]:.4f} | Precision: {metrics['valid_precisions'][-1]:.4f} | Recall: {metrics['valid_recalls'][-1]:.4f} | IoU: {metrics['valid_ious'][-1]:.4f}\n\n")

### Training Loop

In [None]:
# Resume from the last checkpoint (if any) and train up to EPOCHS.
start_epoch = load_checkpoint()
for epoch in range(start_epoch, EPOCHS):
    train_loss, train_accuracy, train_precision, train_recall, train_iou = train_fn(train_loader, model, criterion, optimizer, epoch, EPOCHS)
    valid_loss, valid_accuracy, valid_precision, valid_recall, valid_iou = eval_fn(val_loader, model, criterion, epoch, EPOCHS)

    metrics['train_losses'].append(train_loss)
    metrics['valid_losses'].append(valid_loss)
    metrics['train_accuracies'].append(train_accuracy)
    metrics['valid_accuracies'].append(valid_accuracy)
    metrics['train_precisions'].append(train_precision)
    metrics['valid_precisions'].append(valid_precision)
    metrics['train_recalls'].append(train_recall)
    metrics['valid_recalls'].append(valid_recall)
    metrics['train_ious'].append(train_iou)
    metrics['valid_ious'].append(valid_iou)

    print(f"Epoch {epoch+1}/{EPOCHS}")
    print(f"Train Loss: {train_loss:.4f} | Acc: {train_accuracy:.4f} | Precision: {train_precision:.4f} | Recall: {train_recall:.4f} | IoU: {train_iou:.4f}")
    print(f"Valid Loss: {valid_loss:.4f} | Acc: {valid_accuracy:.4f} | Precision: {valid_precision:.4f} | Recall: {valid_recall:.4f} | IoU: {valid_iou:.4f}")

    # Keep a separate copy of the weights with the lowest validation loss so far.
    if valid_loss < best_valid_loss:
        torch.save(model.state_dict(), 'dendrite_model.pt')
        best_valid_loss = valid_loss

    # Step the scheduler BEFORE checkpointing. If the scheduler lowers the LR
    # this epoch, the new LR lives in the optimizer state; saving first would
    # persist the old LR, and resuming from this checkpoint would lose the reduction.
    lr_scheduler.step(valid_loss)
    save_checkpoint(epoch)

### End of Code