In [1]:
import torch
import sys 
import os 



In [2]:
# Make the project's src/ package importable. The parent of the working directory is
# assumed to be the project root (it must contain src/). The membership check keeps the
# cell idempotent: re-running it no longer appends a duplicate sys.path entry each time.
SRC_DIR = os.path.abspath(os.path.join(os.getcwd(), "..", "src"))
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)
from data.make_dataset import get_train_loader, get_test_loader, get_val_loader
from visualization.visualize import plot_data_distribution

In [3]:
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

Using mps device


In [4]:
# Dataset splits live under <project root>/data. Like the sys.path setup earlier, this
# assumes the working directory is a direct child of the project root. That replaces the
# hard-coded, user-specific absolute paths, so the notebook runs on other machines.
DATA_DIR = os.path.abspath(os.path.join(os.getcwd(), "..", "data"))
train_dataset_path = os.path.join(DATA_DIR, "train")
val_dataset_path = os.path.join(DATA_DIR, "val")
test_dataset_path = os.path.join(DATA_DIR, "test")

train_loader = get_train_loader(train_dataset_path)
val_loader = get_val_loader(val_dataset_path)
test_loader = get_test_loader(test_dataset_path)

In [5]:
# Bundle the per-split loaders under the keys that the training and prediction cells index by.
dataloaders = {
    'train': train_loader,
    'val': val_loader,
    'test': test_loader,
}

# Model architecture, loss, and optimizer

In [6]:
from torchvision.models import resnet50, ResNet50_Weights
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from collections import Counter
import torch.nn as nn
import torch
from block.block import CBAMBlock, InceptionBlock, SEBlock

In [7]:
# Backbone: ResNet-50 with torchvision's default (ImageNet-pretrained) weights.
model = resnet50(weights=ResNet50_Weights.DEFAULT)

# Freeze the two earliest residual stages. The stem (conv1/bn1), layer3, layer4 and the new
# blocks/head added below stay trainable.
# NOTE: requires_grad only stops weight updates. BatchNorm layers inside layer1/layer2 still
# refresh their running mean/var whenever the model is in train() mode.
for frozen_stage in (model.layer1, model.layer2):
    frozen_stage.requires_grad_(False)

# Keep layer3 and append a CBAM attention block after it (the stage is wrapped, not
# replaced). 1024 is the output channel count of ResNet-50's layer3.
original_layer3 = model.layer3
model.layer3 = nn.Sequential(
    original_layer3,
    CBAMBlock(1024),
)

# Keep layer4 (2048 output channels) and append an Inception block followed by an SE block.
# The SE block and the classifier head are both sized to the Inception block's output width.
INCEPTION_OUT_CHANNELS = 256  # output channels of InceptionBlock(2048); must match block.block
original_layer4 = model.layer4
model.layer4 = nn.Sequential(
    original_layer4,
    InceptionBlock(2048),
    SEBlock(INCEPTION_OUT_CHANNELS),
)

# ResNet.forward global-average-pools layer4's output before calling fc, so the head's input
# width equals the Inception/SE output width.
num_ftrs = INCEPTION_OUT_CHANNELS
class_names = train_loader.dataset.classes

# Classifier head: Linear -> LayerNorm -> GELU -> Dropout -> Linear(one logit per class).
model.fc = nn.Sequential(
    nn.Linear(num_ftrs, 256),
    nn.LayerNorm(256),
    nn.GELU(),
    nn.Dropout(0.3),
    nn.Linear(256, len(class_names)),
)

# Put the model on the selected device before the optimizer is built from its parameters.
# This keeps it on the same device as the class-weighted loss, whose weights live on `device`.
model = model.to(device)

# Inverse-frequency class weights for CrossEntropyLoss: w_c = N / count_c, normalised to sum
# to 1. With the default 'mean' reduction the loss is divided by the summed weights of the
# batch targets, so only the relative scale of the weights matters.
all_labels = [label for _, label in dataloaders['train'].dataset.samples]
class_counts = Counter(all_labels)
num_samples = sum(class_counts.values())

# Size the weight vector by the classifier's output width (one entry per class name), not by
# the number of classes that happen to occur in the training labels.
num_classes = len(class_names)
missing_classes = [class_names[i] for i in range(num_classes) if class_counts[i] == 0]
if missing_classes:
    # Without this check a missing class causes either a ZeroDivisionError here or a weight
    # vector shorter than the model output, which only fails later inside the loss.
    raise ValueError(f"Training split has no samples for classes: {missing_classes}")

class_weights = torch.tensor(
    [num_samples / class_counts[i] for i in range(num_classes)],
    dtype=torch.float32,
    device=device,
)
class_weights = class_weights / class_weights.sum()

# Loss with class weights + light label smoothing.
criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.005)

# AdamW over the trainable parameters only (layer1/layer2 have requires_grad=False).
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(trainable_params, lr=1e-4, weight_decay=1e-4)

# Cosine schedule with warm restarts: first cycle is T_0=3 scheduler steps (presumably one
# step per epoch inside train_model), and each subsequent cycle is T_mult=2 times longer.
# NOTE: the fine-tuning cell under "Training" replaces this with a T_0=1 scheduler.
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=3, T_mult=2)


# Training

In [8]:
from models.train_model import train_model

In [9]:
# Sanity-check that every split is a torch DataLoader before handing them to train_model,
# then display the train loader's type as before.
for split_name, split_loader in dataloaders.items():
    assert isinstance(split_loader, torch.utils.data.DataLoader), (
        f"dataloaders[{split_name!r}] is {type(split_loader)}, expected a DataLoader"
    )
type(dataloaders['train'])

torch.utils.data.dataloader.DataLoader

In [10]:
# Initial training run, kept for reference but disabled. It writes to the same save_path
# ("resnet_cbam_incept.pt") as the checkpoint reloaded below, so it presumably produced that
# checkpoint. The notebook now fine-tunes from it with a shorter warm-restart schedule.
# Uncomment to retrain from the ImageNet-pretrained backbone built above.
# trained_model, history = train_model(
#     model=model,
#     criterion=criterion,
#     optimizer=optimizer,
#     scheduler=scheduler,
#     dataloaders=dataloaders,
#     num_epochs=40,
#     early_stop_patience=10,
#     save_path="resnet_cbam_incept.pt"
# )

In [11]:
from models.predict_model import predict_model

In [15]:
# Resume from the saved checkpoint instead of training from scratch. The path is resolved
# from the project root (parent of the working directory), matching where train_model
# reports saving checkpoints ("../checkpoints/..."), rather than one user's absolute path.
# weights_only=True restricts unpickling to tensors/containers, the safe way to read a
# state_dict.
CHECKPOINT_PATH = os.path.abspath(os.path.join(os.getcwd(), "..", "checkpoints", "resnet_cbam_incept.pt"))
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True))

<All keys matched successfully>

In [16]:
# Fine-tune the loaded checkpoint with a shorter warm-restart cycle.
# T_0=1, T_mult=2 restarts the cosine schedule after 1, 3, 7, 15, ... scheduler steps.
# That matches the training-loss jumps seen at epochs 2, 4, 8 and 16 in the log.
# NOTE: with patience 10 the run can stop partway through a long cycle; in the log below the
# best val accuracy came at epoch 7 and early stopping fired at epoch 17.
# The best model is saved under save_path, which appears to be the checkpoint loaded above.
RESTART_T0 = 1
RESTART_T_MULT = 2
NUM_EPOCHS = 40
EARLY_STOP_PATIENCE = 10

scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=RESTART_T0, T_mult=RESTART_T_MULT)
trained_model, history = train_model(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    dataloaders=dataloaders,
    num_epochs=NUM_EPOCHS,
    early_stop_patience=EARLY_STOP_PATIENCE,
    save_path="resnet_cbam_incept.pt",
)


🔁 Epoch 1/40
------------------------------


python(8182) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(8183) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(8184) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(8185) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🟢 Train    | Loss: 0.3159 | Acc: 0.9307


python(8315) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(8316) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(8317) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(8318) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🔵 Val      | Loss: 1.0182 | Acc: 0.8400
✅ New best model saved at: ../checkpoints/resnet_cbam_incept.pt

🔁 Epoch 2/40
------------------------------


python(8345) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(8346) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(8348) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(8350) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🟢 Train    | Loss: 0.3367 | Acc: 0.9267


python(8618) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(8619) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(8620) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(8621) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🔵 Val      | Loss: 0.9865 | Acc: 0.8233
⚠️ No improvement for 1 epoch(s)

🔁 Epoch 3/40
------------------------------


python(8627) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(8628) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(8629) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(8630) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🟢 Train    | Loss: 0.2573 | Acc: 0.9514


python(8838) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(8839) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(8840) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(8841) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🔵 Val      | Loss: 0.9401 | Acc: 0.8567
✅ New best model saved at: ../checkpoints/resnet_cbam_incept.pt

🔁 Epoch 4/40
------------------------------


python(8866) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(8867) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(8869) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(8870) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🟢 Train    | Loss: 0.2921 | Acc: 0.9406


python(9091) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(9092) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(9093) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(9094) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🔵 Val      | Loss: 1.0652 | Acc: 0.8133
⚠️ No improvement for 1 epoch(s)

🔁 Epoch 5/40
------------------------------


python(9101) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(9102) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(9103) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(9104) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🟢 Train    | Loss: 0.2762 | Acc: 0.9466


python(9260) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(9261) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(9262) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(9263) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🔵 Val      | Loss: 1.0079 | Acc: 0.8433
⚠️ No improvement for 2 epoch(s)

🔁 Epoch 6/40
------------------------------


python(9268) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(9269) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(9270) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(9271) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🟢 Train    | Loss: 0.2306 | Acc: 0.9580


python(9558) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(9559) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(9560) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(9561) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🔵 Val      | Loss: 1.0039 | Acc: 0.8567
⚠️ No improvement for 3 epoch(s)

🔁 Epoch 7/40
------------------------------


python(9566) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(9567) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(9568) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(9569) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🟢 Train    | Loss: 0.2038 | Acc: 0.9674


python(9791) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(9792) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(9793) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(9794) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🔵 Val      | Loss: 0.9349 | Acc: 0.8667
✅ New best model saved at: ../checkpoints/resnet_cbam_incept.pt

🔁 Epoch 8/40
------------------------------


python(9799) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(9800) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(9801) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(9802) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🟢 Train    | Loss: 0.2679 | Acc: 0.9459


python(10051) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10052) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10053) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10054) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🔵 Val      | Loss: 0.9463 | Acc: 0.8333
⚠️ No improvement for 1 epoch(s)

🔁 Epoch 9/40
------------------------------


python(10061) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10062) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10063) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10064) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🟢 Train    | Loss: 0.2889 | Acc: 0.9420


python(10268) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10269) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10270) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10271) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🔵 Val      | Loss: 1.0755 | Acc: 0.8467
⚠️ No improvement for 2 epoch(s)

🔁 Epoch 10/40
------------------------------


python(10290) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10291) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10292) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10293) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🟢 Train    | Loss: 0.2645 | Acc: 0.9498


python(10512) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10513) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10514) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10515) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🔵 Val      | Loss: 0.9909 | Acc: 0.8400
⚠️ No improvement for 3 epoch(s)

🔁 Epoch 11/40
------------------------------


python(10520) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10521) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10522) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10523) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🟢 Train    | Loss: 0.2515 | Acc: 0.9543


python(10775) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10776) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10777) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10778) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🔵 Val      | Loss: 1.0145 | Acc: 0.8400
⚠️ No improvement for 4 epoch(s)

🔁 Epoch 12/40
------------------------------


python(10784) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10785) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10786) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10787) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🟢 Train    | Loss: 0.2158 | Acc: 0.9647


python(10931) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10932) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10933) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10934) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🔵 Val      | Loss: 1.0303 | Acc: 0.8467
⚠️ No improvement for 5 epoch(s)

🔁 Epoch 13/40
------------------------------


python(10959) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10960) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10961) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(10962) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🟢 Train    | Loss: 0.1946 | Acc: 0.9703


python(11174) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(11175) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(11176) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(11177) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🔵 Val      | Loss: 0.9678 | Acc: 0.8533
⚠️ No improvement for 6 epoch(s)

🔁 Epoch 14/40
------------------------------


python(11182) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(11183) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(11184) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(11185) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🟢 Train    | Loss: 0.1855 | Acc: 0.9734


python(11383) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(11384) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(11385) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(11386) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🔵 Val      | Loss: 0.9527 | Acc: 0.8567
⚠️ No improvement for 7 epoch(s)

🔁 Epoch 15/40
------------------------------


python(11391) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(11392) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(11393) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(11394) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🟢 Train    | Loss: 0.1757 | Acc: 0.9764


python(11632) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(11633) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(11634) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(11635) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🔵 Val      | Loss: 0.9860 | Acc: 0.8567
⚠️ No improvement for 8 epoch(s)

🔁 Epoch 16/40
------------------------------


python(11640) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(11641) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(11642) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(11643) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🟢 Train    | Loss: 0.2671 | Acc: 0.9483


python(11809) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(11810) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(11811) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(11812) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🔵 Val      | Loss: 1.1089 | Acc: 0.8233
⚠️ No improvement for 9 epoch(s)

🔁 Epoch 17/40
------------------------------


python(11817) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(11818) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(11819) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(11820) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🟢 Train    | Loss: 0.2838 | Acc: 0.9452


python(12028) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(12029) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(12030) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(12031) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


🔵 Val      | Loss: 1.0929 | Acc: 0.8167
⚠️ No improvement for 10 epoch(s)
⛔ Early stopping triggered.

🏁 Training complete in 150m 46s
🏆 Best Validation Accuracy: 0.8667


In [13]:
# Predict on the test split with the best-validation weights. train_model saves its best
# checkpoint to ../checkpoints/<save_path> (see the "New best model saved" log lines). After
# early stopping, the in-memory weights may instead be those of the last epoch, so reload the
# best checkpoint explicitly. This is harmless if they are already the best ones.
best_checkpoint_path = os.path.abspath(os.path.join(os.getcwd(), "..", "checkpoints", "resnet_cbam_incept.pt"))
model.load_state_dict(torch.load(best_checkpoint_path, map_location=device, weights_only=True))
predict_model(model, dataloaders['test'], save_path="c5-pred.csv")