In [None]:
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from CustomDataset import CustomImageDataset
from Models import ResNet101
from utils.Plots import (plot_confusion, plot_roc, reliability_diagram,
                         expected_calibration_error, brier_score, entropy_hist)

In [None]:
# Config: every value a reader might tune for this run lives here.
mode = 'Binary'
model_name = 'resnet101'
epochs = 10
batch_size = 32
seed = 777
save_dir = 'results/resnet101'

# Seed every RNG this notebook can draw from (Python, NumPy, and PyTorch on all
# devices). This happens here, before the model cell runs, so torch-based weight
# initialisation and DataLoader shuffling (which uses torch's default generator)
# are reproducible across Restart & Run All.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Dataset sanity check: construct the train / val / test splits for the
# configured `mode` and report how many samples each one holds.
ds_train = CustomImageDataset(mode=mode, build_div='train')
ds_val = CustomImageDataset(mode=mode, build_div='val')
ds_test = CustomImageDataset(mode=mode, build_div='test')
split_sizes = [len(split) for split in (ds_train, ds_val, ds_test)]
print(*split_sizes)

In [None]:
# Build model: ResNet-101 over 3-channel input with label_num=1. The training
# cell pairs its output with BCEWithLogitsLoss against (B, 1) targets, so this
# presumably yields a single raw logit per sample.
model = ResNet101(
    input_channel=3,
    label_num=1,
)
model  # last expression: rich repr of the layer stack

In [None]:
# Training loop (same as baseline but resnet101), with a validation pass each epoch.
# The validation loss decides which epoch's weights `model` ends up holding: the
# lowest-val-loss state dict is snapshotted and restored into `model` at the end,
# so any later save writes the best epoch rather than simply the last one.
assert len(ds_train) > 0 and len(ds_val) > 0, 'train/val splits must be non-empty'
learning_rate = 1e-4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
train_loader = DataLoader(ds_train, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(ds_val, batch_size=batch_size)
criterion = nn.BCEWithLogitsLoss()
opt = optim.Adam(model.parameters(), lr=learning_rate)


def run_epoch(loader, is_train):
    """Make one pass over `loader`, stepping the optimizer only when `is_train`.

    Args:
        loader: DataLoader yielding (x, y) batches.
        is_train: True for a training pass, False for a no-grad evaluation pass.

    Returns:
        (mean per-batch loss, fraction of samples in `loader.dataset` classified
        correctly at a 0.5 sigmoid threshold).
    """
    # train(False) is equivalent to eval(); it switches mode-dependent layers
    # such as dropout / batch-norm.
    model.train(is_train)
    total_loss, correct = 0.0, 0
    with torch.set_grad_enabled(is_train):
        for x, y in loader:
            # BCEWithLogitsLoss needs float targets with the same (B, 1) shape as the logits.
            x, y = x.to(device), y.to(device).float().unsqueeze(1)
            out = model(x)
            loss = criterion(out, y)
            if is_train:
                opt.zero_grad()
                loss.backward()
                opt.step()
            total_loss += loss.item()
            correct += ((torch.sigmoid(out) >= 0.5).int() == y.int()).sum().item()
    return total_loss / len(loader), correct / len(loader.dataset)


best_val_loss, best_epoch, best_state = float('inf'), None, None
for ep in range(1, epochs + 1):
    train_loss, train_acc = run_epoch(train_loader, is_train=True)
    val_loss, val_acc = run_epoch(val_loader, is_train=False)
    print(ep, train_loss, train_acc, val_loss, val_acc)
    if val_loss < best_val_loss:
        best_val_loss, best_epoch = val_loss, ep
        # Copy to CPU so subsequent optimizer steps don't mutate the snapshot.
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

if best_state is not None:
    # load_state_dict copies into the existing (on-device) parameters in place.
    model.load_state_dict(best_state)
    print('restored best epoch', best_epoch, 'val loss', best_val_loss)

In [None]:
# Save and Eval
# Persist the weights `model` currently holds. torch.save does not create missing
# parent directories, so make sure `save_dir` exists first.
os.makedirs(save_dir, exist_ok=True)
checkpoint_path = os.path.join(save_dir, 'best_resnet101.pt')
torch.save(model.state_dict(), checkpoint_path)
print('saved', checkpoint_path)