In [None]:
# standard library
import os
from random import randint

# third-party
from matplotlib import pyplot as plt

# import helper class
from utils import SPARKDataset, PyTorchSPARKDataset

In [None]:
# Root folder of the dataset. The original machine-specific location is kept as
# the default so existing setups keep working; set the SPARK_DATASET_ROOT
# environment variable to point the notebook at a different copy of the data.
dataset_root_dir = os.environ.get(
    "SPARK_DATASET_ROOT",
    "/Users/umut/Documents/computer-vision-zeta-25/data/",
)

# Split handed to SPARKDataset for the visualisation cells.
split = "val"

# Class map: class name -> integer class id.
class_map = {
    'VenusExpress': 0,
    'Cheops': 1,
    'LisaPathfinder': 2,
    'ObservationSat1': 3,
    'Proba2': 4,
    'Proba3': 5,
    'Proba3ocs': 6,
    'Smart1': 7,
    'Soho': 8,
    'XMM Newton': 9,
}

# Build the dataset object for the split selected above.
dataset = SPARKDataset(
    class_map,
    root_dir=dataset_root_dir,
    split=split,
)

In [None]:
# Show a rows x cols grid of randomly picked samples.
rows = 3
cols = 4

fig, axes = plt.subplots(rows, cols, figsize=(15, 15))

# NOTE(review): randint is inclusive on both ends, so index 6000 itself can be
# drawn — confirm the split holds at least 6001 samples.
for ax in axes.flat:  # row-major walk over the grid of Axes
    dataset.visualize(randint(0, 6000), size=(10, 10), ax=ax)
    ax.axis('off')

fig.tight_layout()

In [None]:
# Same random grid as before, this time asking visualize() for mask_visualize=True.
rows = 3
cols = 4

fig, axes = plt.subplots(rows, cols, figsize=(15, 15))

# NOTE(review): randint is inclusive on both ends, so index 6000 itself can be
# drawn — confirm the split holds at least 6001 samples.
for ax in axes.flat:  # row-major walk over the grid of Axes
    dataset.visualize(randint(0, 6000), size=(10, 10), ax=ax, mask_visualize=True)
    ax.axis('off')

fig.tight_layout()

In [None]:
# One PyTorch-style dataset per split, built in the order train, then val.
train_dataset, val_dataset = (
    PyTorchSPARKDataset(class_map, root_dir=dataset_root_dir, split=split_name)
    for split_name in ("train", "val")
)

In [None]:
# Pull one (image, target) pair out of the validation dataset.
sample_idx = 0  # position of the sample to retrieve
image, sample = val_dataset[sample_idx]

# Unpack the annotation fields stored in the target dictionary.
mask, bbox, class_label = (sample[field] for field in ('masks', 'boxes', 'class'))

# If you want to display the image, you can use matplotlib, but remember to convert it back to a PIL image or a NumPy array
import matplotlib.pyplot as plt
import numpy as np

# Invert class_map so the numeric label can be turned back into its class name.
reverse_class_map = dict(zip(class_map.values(), class_map.keys()))
class_name = reverse_class_map[class_label.item()]

# The image is laid out as CxHxW; matplotlib draws HxWxC arrays, so convert it
# to NumPy and move the channel axis to the end.
image_np = np.transpose(image.numpy(), (1, 2, 0))

plt.imshow(image_np)
plt.title(f'Class: {class_name}')
plt.show()


In [None]:
import torch
import torchvision
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

In [None]:
def collate_fn(batch):
    """Regroup a list of per-sample tuples into a tuple of per-field tuples.

    For a batch of ``(image, target)`` pairs this yields
    ``((image_0, image_1, ...), (target_0, target_1, ...))`` without stacking
    anything. An empty batch gives an empty tuple.
    """
    fields = zip(*batch)
    return tuple(fields)

# Training batches: two samples each, reshuffled on every pass, incomplete last batch dropped.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn,
)

# Validation batches: same batching, but kept in dataset order.
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=2,
    shuffle=False,
    drop_last=True,
    collate_fn=collate_fn,
)

In [None]:
# One output per entry in class_map, plus one extra slot for the background class.
num_classes = len(class_map) + 1

# Start from torchvision's pretrained Mask R-CNN (ResNet-50 + FPN backbone).
model_det = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")

# The pretrained heads are sized for the COCO label set (91 categories in
# torchvision's checkpoint), so both predictors are replaced with freshly
# initialised ones that output num_classes.

# Box head: reuse the input width of the existing classification layer.
in_features = model_det.roi_heads.box_predictor.cls_score.in_features
model_det.roi_heads.box_predictor = FastRCNNPredictor(
    in_channels=in_features,
    num_classes=num_classes,
)

# Mask head: reuse the input channels of the existing conv5_mask layer and keep
# a 256-channel intermediate layer.
in_features_mask = model_det.roi_heads.mask_predictor.conv5_mask.in_channels
model_det.roi_heads.mask_predictor = MaskRCNNPredictor(
    in_channels=in_features_mask,
    dim_reduced=256,
    num_classes=num_classes,
)

In [None]:
from tqdm import tqdm

# Train on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model_det.to(device)

optimizer = torch.optim.Adam(model_det.parameters(), lr=1e-4)

num_epochs = 20

# NOTE(review): torchvision detection models read 'boxes', 'labels' and 'masks'
# from each target dict and reserve label 0 for background. An earlier cell
# reads sample['class'] and class_map starts at 0 — confirm the dataset also
# provides a correctly offset 'labels' entry.
for epoch in range(num_epochs):
    model_det.train()
    total_loss = 0.0  # sum of the per-batch losses for this epoch

    for imgs, samples in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=True):
        # Move every image and every tensor inside each target dict to the device.
        imgs = [image_tensor.to(device) for image_tensor in imgs]
        samples = [
            {key: value.to(device) for key, value in target.items()}
            for target in samples
        ]

        optimizer.zero_grad()

        # In training mode the model returns a dict of losses; optimise their sum.
        loss_dict = model_det(imgs, samples)
        loss = sum(loss_dict.values())

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs}  Loss: {total_loss:.4f}")


In [None]:
# Switch to evaluation mode and predict on the first validation batch only.
model_det.eval()

with torch.no_grad():  # no gradients are needed for inference
    for imgs, targets in val_loader:
        imgs = list(map(lambda image_tensor: image_tensor.to(device), imgs))
        predictions = model_det(imgs)
        break  # a single batch is enough for the cells that follow

In [None]:
# Show which fields the model returned for the first image of the batch.
predictions[0].keys()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Overlay the first predicted mask and box on the first image of the batch.
# The image tensor is CxHxW; matplotlib expects HxWxC.
img = imgs[0].permute(1, 2, 0).cpu().numpy()
first_prediction = predictions[0]

plt.imshow(img)

if len(first_prediction["boxes"]) == 0:
    # The model returns empty tensors when nothing passes its score threshold;
    # indexing [0] below would then raise IndexError, so show the bare image.
    plt.title("No detections for this image")
else:
    # masks are shaped (N, 1, H, W): take detection 0 and drop the channel axis.
    mask = first_prediction["masks"][0, 0].cpu().numpy()
    box = first_prediction["boxes"][0].cpu().numpy()

    plt.imshow(mask, alpha=0.5)

    # Boxes come as (x1, y1, x2, y2); Rectangle wants origin, width and height.
    x1, y1, x2, y2 = box
    plt.gca().add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                      fill=False, edgecolor='red'))

plt.show()

# PART 2
Our dataset has color-coded segments embedded in the images, so I'll crop out the bounding boxes and use a pre-implemented U-Net model.

In `utils.py` we have:
- hashmap for mapping colors to labels
- utility function for converting the mask colors from the data to labels
- new dataset class

In [None]:
from utils import PART_COLOR_MAP
import segmentation_models_pytorch as smp
import torch.nn as nn

# One output channel per entry in PART_COLOR_MAP plus one extra class
# (presumably background — the original note says this comes to 2; confirm).
# Note that this rebinds the num_classes used for the Mask R-CNN heads.
num_classes = len(PART_COLOR_MAP) + 1

model_unet = smp.Unet(
    encoder_name="resnet34",     # encoder architecture
    encoder_weights="imagenet",  # pretrained encoder weights
    in_channels=3,
    classes=num_classes,
)
model_unet = model_unet.to(device)

Loss function: cross-entropy, because after some research I found that it is the standard choice.

In [None]:
# Cross-entropy loss for the U-Net, optimised with AdamW.
# This rebinds `optimizer`, which previously held the Mask R-CNN optimiser.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=model_unet.parameters(), lr=1e-4)

Training the U-Net.

In [None]:
num_epochs = 20

# NOTE(review): the only train_loader defined in this notebook is the Mask R-CNN
# one, whose collate_fn yields tuples of images and target dicts; calling
# .to(device) on those would fail. Confirm that a loader built from the new
# dataset class in utils.py (batched image and mask tensors) is bound to
# train_loader before this cell runs.
for epoch in range(num_epochs):
    model_unet.train()
    total_loss = 0.0  # sum of the per-batch losses for this epoch

    for imgs, masks in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=True):
        imgs, masks = imgs.to(device), masks.to(device)

        # Forward pass and loss against the target masks.
        preds = model_unet(imgs)
        loss = criterion(preds, masks)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}: loss: {total_loss:.4f}")