In [1]:
import time
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
import torchvision
from torchvision.models.efficientnet import MBConvConfig, FusedMBConvConfig

sys.path.append("/jet/home/azhang19/stat 214/stat-214-lab2-group6/code/modeling")
from preprocessing import to_NCHW, pad_to_384x384, standardize_images
from autoencoder import EfficientNetEncoder, EfficientNetDecoder, AutoencoderConfig, masked_mse

device = "cuda" if torch.cuda.is_available() else "cpu"

torch.set_float32_matmul_precision('high')
torch.backends.cudnn.benchmark = True

use_amp = True

In [2]:
# Load and preprocess data
data = np.load("/jet/home/azhang19/stat 214/stat-214-lab2-group6/data/array_data.npz")
unlabeled_images, unlabeled_masks, labeled_images, labeled_masks, labels = data["unlabeled_images"], data["unlabeled_masks"], data["labeled_images"], data["labeled_masks"], data["labels"]

unlabeled_images = pad_to_384x384(to_NCHW(unlabeled_images))
unlabeled_masks = pad_to_384x384(unlabeled_masks)

labeled_images = pad_to_384x384(to_NCHW(labeled_images))
labeled_masks = pad_to_384x384(labeled_masks)
labels = pad_to_384x384(labels)

# Convert to tensors and move to GPU
unlabeled_images = torch.tensor(unlabeled_images, dtype=torch.float32).to(device)  # [161, 8, 384, 384]
unlabeled_masks = torch.tensor(unlabeled_masks, dtype=torch.bool).to(device)    # [161, 384, 384]

labeled_images = torch.tensor(labeled_images, dtype=torch.float32).to(device)      # [3, 8, 384, 384]
labeled_masks = torch.tensor(labeled_masks, dtype=torch.bool).to(device)        # [3, 384, 384]
labels = torch.tensor(labels, dtype=torch.long).to(device)                      # [3, 384, 384]


# Standardize images
unlabeled_images, std_channel, mean_channel = standardize_images(unlabeled_images, unlabeled_masks)
labeled_images, _, _ = standardize_images(labeled_images, labeled_masks, std_channel, mean_channel)

In [3]:
config = AutoencoderConfig(num_layers_block=[1, 1, 1], augmentation_flip=True, augmentation_rotate=True)
print(config)

AutoencoderConfig([1, 1, 1], flip=True, rotate=True)


In [4]:
augmentation = []
if config.augmentation_flip:
    augmentation.append(torchvision.transforms.RandomHorizontalFlip(p=0.5))
    augmentation.append(torchvision.transforms.RandomVerticalFlip(p=0.5))
if config.augmentation_rotate:
    augmentation.append(torchvision.transforms.RandomRotation(degrees=180, expand=True,
                        interpolation=torchvision.transforms.InterpolationMode.BILINEAR))
    augmentation.append(torchvision.transforms.RandomCrop(size=384))
augmentation = torchvision.transforms.Compose(augmentation)

def apply_augment(images, masks, augmentation):
    images_masks = torch.cat([masks.unsqueeze(1).float(), images], dim=1)
    images_masks = [augmentation(image_mask) for image_mask in images_masks]
    images_masks = torch.stack(images_masks)
    return images_masks[:, 1:], images_masks[:, 0] > 0.5

augment = lambda images, masks: apply_augment(images, masks, augmentation)

In [5]:
encoder_config = [
    FusedMBConvConfig(1, 3, 1, 16, 16, config.num_layers_block[0]),  # 384x384x8 -> 384x384x16
    FusedMBConvConfig(4, 3, 2, 16, 32, config.num_layers_block[1]),  # 384x384x16 -> 192x192x32
    MBConvConfig(4, 3, 2, 32, 64, config.num_layers_block[2]),       # 192x192x32 -> 96x96x64
]

# Build encoder and decoder
encoder = EfficientNetEncoder(
    inverted_residual_setting=encoder_config,
    dropout=0.1,
    input_channels=8,
    last_channel=64,
)

decoder = EfficientNetDecoder()

autoencoder = nn.Sequential(encoder, decoder).train().to(device)
#compiled_autoencoder = torch.compile(autoencoder)

In [None]:
num_epochs = 20000
ckpt = [100, 200, 400, 800, 1600, 3200, 6400, 12800, 20000]  # Checkpoints for saving model
initial_lr = 1e-3  # Moderate starting LR for AdamW
weight_decay = 1e-2  # Regularization for small dataset

# Optimizer and scheduler
optimizer = optim.AdamW(autoencoder.parameters(), lr=initial_lr, weight_decay=weight_decay)
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)  # Decay to near-zero
scaler = torch.amp.GradScaler(device, enabled=use_amp)

losses = np.zeros(num_epochs)

In [7]:
@torch.compile
def trainer(images, masks, model, augment, optimizer, scheduler, scaler, loss_fn):
    with torch.inference_mode():
        images, masks = augment(images, masks)
    images, masks = images.clone(), masks.clone()
    model.train()
    optimizer.zero_grad(set_to_none=True)

    with torch.amp.autocast(device, enabled=use_amp):
        reconstructions = model(images)
        loss = loss_fn(images, masks, reconstructions)

    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()

    scheduler.step()

    return loss

In [8]:
ckpt_path = "/jet/home/azhang19/stat 214/stat-214-lab2-group6/code/modeling/ckpt"
os.makedirs(f"{ckpt_path}/{str(config)}", exist_ok=True)

In [None]:
for epoch in range(num_epochs):
    t = time.perf_counter()
    loss = trainer(unlabeled_images, unlabeled_masks, autoencoder, augment, optimizer, scheduler, scaler, masked_mse)
    losses[epoch] = loss.item()
    if epoch + 1 in ckpt:
        torch.save(autoencoder.state_dict(), f"{ckpt_path}/{str(config)}/autoencoder_{epoch + 1}.pth")
    print(f"Epoch {epoch + 1}/{num_epochs} - Loss: {loss:.4f} - Time: {time.perf_counter() - t:.2f}s")

W0308 02:40:33.321000 17407 site-packages/torch/_dynamo/variables/tensor.py:869] [8/0] Graph break from `Tensor.item()`, consider setting:
W0308 02:40:33.321000 17407 site-packages/torch/_dynamo/variables/tensor.py:869] [8/0]     torch._dynamo.config.capture_scalar_outputs = True
W0308 02:40:33.321000 17407 site-packages/torch/_dynamo/variables/tensor.py:869] [8/0] or:
W0308 02:40:33.321000 17407 site-packages/torch/_dynamo/variables/tensor.py:869] [8/0]     env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0308 02:40:33.321000 17407 site-packages/torch/_dynamo/variables/tensor.py:869] [8/0] to include these operations in the captured graph.
W0308 02:40:33.321000 17407 site-packages/torch/_dynamo/variables/tensor.py:869] [8/0] 
W0308 02:40:33.321000 17407 site-packages/torch/_dynamo/variables/tensor.py:869] [8/0] Graph break: from user code at:
W0308 02:40:33.321000 17407 site-packages/torch/_dynamo/variables/tensor.py:869] [8/0]   File "/jet/home/azhang19/.conda/envs/env_214/lib/python3.13/si

Epoch 1/20000 - Loss: 1.9808 - Time: 56.97s


('Grad tensors ["L['self'].param_groups[0]['params'][0].grad", "L['self'].param_groups[0]['params'][1].grad", "L['self'].param_groups[0]['params'][2].grad", "L['self'].param_groups[0]['params'][3].grad", "L['self'].param_groups[0]['params'][4].grad", "L['self'].param_groups[0]['params'][5].grad", "L['self'].param_groups[0]['params'][6].grad", "L['self'].param_groups[0]['params'][7].grad", "L['self'].param_groups[0]['params'][8].grad", "L['self'].param_groups[0]['params'][9].grad", "L['self'].param_groups[0]['params'][10].grad", "L['self'].param_groups[0]['params'][11].grad", "L['self'].param_groups[0]['params'][12].grad", "L['self'].param_groups[0]['params'][13].grad", "L['self'].param_groups[0]['params'][14].grad", "L['self'].param_groups[0]['params'][15].grad", "L['self'].param_groups[0]['params'][16].grad", "L['self'].param_groups[0]['params'][17].grad", "L['self'].param_groups[0]['params'][18].grad", "L['self'].param_groups[0]['params'][19].grad", "L['self'].param_groups[0]['params

Epoch 2/20000 - Loss: 1.6613 - Time: 7.43s


('Grad tensors ["L['self'].param_groups[0]['params'][0].grad", "L['self'].param_groups[0]['params'][1].grad", "L['self'].param_groups[0]['params'][2].grad", "L['self'].param_groups[0]['params'][3].grad", "L['self'].param_groups[0]['params'][4].grad", "L['self'].param_groups[0]['params'][5].grad", "L['self'].param_groups[0]['params'][6].grad", "L['self'].param_groups[0]['params'][7].grad", "L['self'].param_groups[0]['params'][8].grad", "L['self'].param_groups[0]['params'][9].grad", "L['self'].param_groups[0]['params'][10].grad", "L['self'].param_groups[0]['params'][11].grad", "L['self'].param_groups[0]['params'][12].grad", "L['self'].param_groups[0]['params'][13].grad", "L['self'].param_groups[0]['params'][14].grad", "L['self'].param_groups[0]['params'][15].grad", "L['self'].param_groups[0]['params'][16].grad", "L['self'].param_groups[0]['params'][17].grad", "L['self'].param_groups[0]['params'][18].grad", "L['self'].param_groups[0]['params'][19].grad", "L['self'].param_groups[0]['params

Epoch 3/20000 - Loss: 1.4127 - Time: 4.89s


('Grad tensors ["L['self'].param_groups[0]['params'][0].grad", "L['self'].param_groups[0]['params'][1].grad", "L['self'].param_groups[0]['params'][2].grad", "L['self'].param_groups[0]['params'][3].grad", "L['self'].param_groups[0]['params'][4].grad", "L['self'].param_groups[0]['params'][5].grad", "L['self'].param_groups[0]['params'][6].grad", "L['self'].param_groups[0]['params'][7].grad", "L['self'].param_groups[0]['params'][8].grad", "L['self'].param_groups[0]['params'][9].grad", "L['self'].param_groups[0]['params'][10].grad", "L['self'].param_groups[0]['params'][11].grad", "L['self'].param_groups[0]['params'][12].grad", "L['self'].param_groups[0]['params'][13].grad", "L['self'].param_groups[0]['params'][14].grad", "L['self'].param_groups[0]['params'][15].grad", "L['self'].param_groups[0]['params'][16].grad", "L['self'].param_groups[0]['params'][17].grad", "L['self'].param_groups[0]['params'][18].grad", "L['self'].param_groups[0]['params'][19].grad", "L['self'].param_groups[0]['params

Epoch 4/20000 - Loss: 1.2305 - Time: 5.81s


('Grad tensors ["L['self'].param_groups[0]['params'][0].grad", "L['self'].param_groups[0]['params'][1].grad", "L['self'].param_groups[0]['params'][2].grad", "L['self'].param_groups[0]['params'][3].grad", "L['self'].param_groups[0]['params'][4].grad", "L['self'].param_groups[0]['params'][5].grad", "L['self'].param_groups[0]['params'][6].grad", "L['self'].param_groups[0]['params'][7].grad", "L['self'].param_groups[0]['params'][8].grad", "L['self'].param_groups[0]['params'][9].grad", "L['self'].param_groups[0]['params'][10].grad", "L['self'].param_groups[0]['params'][11].grad", "L['self'].param_groups[0]['params'][12].grad", "L['self'].param_groups[0]['params'][13].grad", "L['self'].param_groups[0]['params'][14].grad", "L['self'].param_groups[0]['params'][15].grad", "L['self'].param_groups[0]['params'][16].grad", "L['self'].param_groups[0]['params'][17].grad", "L['self'].param_groups[0]['params'][18].grad", "L['self'].param_groups[0]['params'][19].grad", "L['self'].param_groups[0]['params

Epoch 5/20000 - Loss: 1.0779 - Time: 5.06s


('Grad tensors ["L['self'].param_groups[0]['params'][0].grad", "L['self'].param_groups[0]['params'][1].grad", "L['self'].param_groups[0]['params'][2].grad", "L['self'].param_groups[0]['params'][3].grad", "L['self'].param_groups[0]['params'][4].grad", "L['self'].param_groups[0]['params'][5].grad", "L['self'].param_groups[0]['params'][6].grad", "L['self'].param_groups[0]['params'][7].grad", "L['self'].param_groups[0]['params'][8].grad", "L['self'].param_groups[0]['params'][9].grad", "L['self'].param_groups[0]['params'][10].grad", "L['self'].param_groups[0]['params'][11].grad", "L['self'].param_groups[0]['params'][12].grad", "L['self'].param_groups[0]['params'][13].grad", "L['self'].param_groups[0]['params'][14].grad", "L['self'].param_groups[0]['params'][15].grad", "L['self'].param_groups[0]['params'][16].grad", "L['self'].param_groups[0]['params'][17].grad", "L['self'].param_groups[0]['params'][18].grad", "L['self'].param_groups[0]['params'][19].grad", "L['self'].param_groups[0]['params

Epoch 6/20000 - Loss: 0.9597 - Time: 5.82s


('Grad tensors ["L['self'].param_groups[0]['params'][0].grad", "L['self'].param_groups[0]['params'][1].grad", "L['self'].param_groups[0]['params'][2].grad", "L['self'].param_groups[0]['params'][3].grad", "L['self'].param_groups[0]['params'][4].grad", "L['self'].param_groups[0]['params'][5].grad", "L['self'].param_groups[0]['params'][6].grad", "L['self'].param_groups[0]['params'][7].grad", "L['self'].param_groups[0]['params'][8].grad", "L['self'].param_groups[0]['params'][9].grad", "L['self'].param_groups[0]['params'][10].grad", "L['self'].param_groups[0]['params'][11].grad", "L['self'].param_groups[0]['params'][12].grad", "L['self'].param_groups[0]['params'][13].grad", "L['self'].param_groups[0]['params'][14].grad", "L['self'].param_groups[0]['params'][15].grad", "L['self'].param_groups[0]['params'][16].grad", "L['self'].param_groups[0]['params'][17].grad", "L['self'].param_groups[0]['params'][18].grad", "L['self'].param_groups[0]['params'][19].grad", "L['self'].param_groups[0]['params

Epoch 7/20000 - Loss: 0.8502 - Time: 4.88s


('Grad tensors ["L['self'].param_groups[0]['params'][0].grad", "L['self'].param_groups[0]['params'][1].grad", "L['self'].param_groups[0]['params'][2].grad", "L['self'].param_groups[0]['params'][3].grad", "L['self'].param_groups[0]['params'][4].grad", "L['self'].param_groups[0]['params'][5].grad", "L['self'].param_groups[0]['params'][6].grad", "L['self'].param_groups[0]['params'][7].grad", "L['self'].param_groups[0]['params'][8].grad", "L['self'].param_groups[0]['params'][9].grad", "L['self'].param_groups[0]['params'][10].grad", "L['self'].param_groups[0]['params'][11].grad", "L['self'].param_groups[0]['params'][12].grad", "L['self'].param_groups[0]['params'][13].grad", "L['self'].param_groups[0]['params'][14].grad", "L['self'].param_groups[0]['params'][15].grad", "L['self'].param_groups[0]['params'][16].grad", "L['self'].param_groups[0]['params'][17].grad", "L['self'].param_groups[0]['params'][18].grad", "L['self'].param_groups[0]['params'][19].grad", "L['self'].param_groups[0]['params

Epoch 8/20000 - Loss: 0.7796 - Time: 5.86s
Epoch 9/20000 - Loss: 0.7119 - Time: 0.22s
Epoch 10/20000 - Loss: 0.6626 - Time: 0.19s
Epoch 11/20000 - Loss: 0.6097 - Time: 0.20s
Epoch 12/20000 - Loss: 0.5769 - Time: 0.20s
Epoch 13/20000 - Loss: 0.5449 - Time: 0.20s
Epoch 14/20000 - Loss: 0.5047 - Time: 0.19s
Epoch 15/20000 - Loss: 0.4744 - Time: 0.20s
Epoch 16/20000 - Loss: 0.4493 - Time: 0.19s
Epoch 17/20000 - Loss: 0.4252 - Time: 0.19s
Epoch 18/20000 - Loss: 0.4010 - Time: 0.19s
Epoch 19/20000 - Loss: 0.3806 - Time: 0.20s
Epoch 20/20000 - Loss: 0.3656 - Time: 0.19s
Epoch 21/20000 - Loss: 0.3368 - Time: 0.20s
Epoch 22/20000 - Loss: 0.3290 - Time: 0.19s
Epoch 23/20000 - Loss: 0.3071 - Time: 0.20s
Epoch 24/20000 - Loss: 0.2931 - Time: 0.19s
Epoch 25/20000 - Loss: 0.2820 - Time: 0.20s
Epoch 26/20000 - Loss: 0.2740 - Time: 0.19s
Epoch 27/20000 - Loss: 0.2557 - Time: 0.20s
Epoch 28/20000 - Loss: 0.2500 - Time: 0.20s
Epoch 29/20000 - Loss: 0.2388 - Time: 0.20s
Epoch 30/20000 - Loss: 0.2308 - Ti