# In [None]:
import numpy as np
import os.path
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import CiliaImages

class CiliaImages(Dataset):
    """Dataset pairing one frame per sample directory with a stored array.

    Every entry of ``root`` whose name starts with ``"1"`` (training) or
    ``"7"`` (otherwise) is treated as one sample.  For a sample directory
    ``<name>`` the item is built from:

    * ``<root>/<name>/frames/00001.png`` -- the image
    * ``<root>/<name>/<flow_type>/flow/<name>_z.npy`` -- the mask array

    NOTE(review): the file-level ``from torchvision.datasets import
    CiliaImages`` is shadowed by this class, and torchvision does not appear
    to ship a dataset of that name -- confirm and drop that import.

    Args:
        root: Directory containing one sub-directory per sample.
        flow_type: Name of the sub-directory (inside each sample directory)
            that holds the ``flow/<name>_z.npy`` file.
        is_train: Select the ``"1"``-prefixed samples when true, the
            ``"7"``-prefixed samples otherwise.
        transform: Optional callable applied to both the image and the mask.
    """

    def __init__(self, root, flow_type, is_train = True, transform = None):
        self.root = root
        self.transform = transform
        self.flow_type = flow_type

        prefix = "1" if is_train else "7"
        # os.listdir() returns entries in arbitrary order, so sort to keep the
        # index -> sample mapping reproducible across runs and machines.
        # Non-directories are skipped because __getitem__ always descends into
        # the entry and would fail on a stray file.
        self.data_dirs = sorted(
            name
            for name in os.listdir(self.root)
            if name.startswith(prefix)
            and os.path.isdir(os.path.join(self.root, name))
        )

    def __getitem__(self, index):
        """Return the ``(image, mask)`` pair for the sample at ``index``."""
        # The directory name doubles as the prefix of the mask file name.
        prefix = self.data_dirs[index]
        data_path = os.path.join(self.root, prefix)

        # Load the image.
        image = Image.open(os.path.join(data_path, "frames", "00001.png"))

        # Now the mask.
        mask_path = os.path.join(data_path, self.flow_type, "flow", f"{prefix}_z.npy")
        mask_arr = np.load(mask_path)
        mask = Image.fromarray(mask_arr)

        # Any transforms?
        if self.transform is not None:
            image = self.transform(image)
            # NOTE(review): the same transform is applied to the mask; if it
            # rescales or normalises intensities the mask values change too.
            # Confirm this is intended (original author: "Do we want this?").
            mask = self.transform(mask)

        return image, mask

    def __len__(self):
        """Return the number of sample directories selected for this split."""
        return len(self.data_dirs)

# Define the FCN model with ResNet-50 backbone
class FCNResNet(nn.Module):
    """torchvision FCN-ResNet50 adapted to single-channel input images.

    Args:
        num_classes: Number of output channels (classes) of the segmentation
            head.
        output_size: ``(height, width)`` the logits are resized to in
            ``forward``.  Defaults to ``(480, 640)``.
    """

    def __init__(self, num_classes, output_size=(480, 640)):
        super(FCNResNet, self).__init__()
        # The class count must be given to the factory: the FCN wrapper has no
        # `fc` attribute that its forward pass uses, so assigning one later
        # would leave the real classifier head untouched.
        self.resnet = models.segmentation.fcn_resnet50(num_classes=num_classes)
        # The ResNet stem lives on `backbone`, not on the FCN wrapper itself.
        # Replace it there so the network accepts 1-channel input; stride,
        # padding and bias mirror the stock ResNet stem so the downstream
        # feature-map sizes are unchanged.  This layer is freshly initialised.
        self.resnet.backbone.conv1 = nn.Conv2d(
            1, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.output_size = output_size

    def forward(self, x):
        """Return per-pixel class logits resized to ``self.output_size``."""
        # torchvision segmentation models return an OrderedDict; the main
        # prediction is stored under the "out" key.
        x = self.resnet(x)["out"]
        x = torch.nn.functional.interpolate(x, size=self.output_size, mode='bilinear', align_corners=False)
        return x

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set parameters
num_classes = 2
batch_size = 16
lr = 0.001
num_epochs = 10
# Placeholder, like the dataset root below: name of the sub-directory inside
# each sample directory that holds the `flow/<name>_z.npy` mask file.
# TODO: set this to the real flow type before running.
flow_type = 'path_to_flow_type'

# Define transforms
transform = transforms.Compose([
    transforms.Grayscale(1),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load your custom dataset
# `flow_type` is a required argument of CiliaImages; omitting it raises a
# TypeError before any data is read.
train_dataset = CiliaImages(root='path_to_train_images', flow_type=flow_type, transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Create an instance of FCNResNet
model = FCNResNet(num_classes)
model = model.to(device)

# Define loss function and optimizer
# NOTE(review): with class-index targets nn.CrossEntropyLoss expects an integer
# (long) tensor of shape (N, H, W).  The masks produced by the dataset go
# through the same ToTensor/Normalize pipeline as the images -- confirm their
# dtype, shape and value range are what the loss needs.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

# Training loop
total_steps = len(train_dataloader)
for epoch in range(num_epochs):
    for i, (images, masks) in enumerate(train_dataloader):
        images = images.to(device)
        masks = masks.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, masks)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Print training progress
        if (i + 1) % 10 == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{total_steps}], Loss: {loss.item():.4f}")