In [1]:
import torch
import torch.nn as nn
import sys
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
sys.path.append("/kaggle/input/read-data")
from read_data import PCamDataset, transform


In [2]:
batch_size = 32
learning_rate = 1e-3
momentum = 0.9
seed = 42
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_pref = "/kaggle/input/train-data/"
val_pref = "/kaggle/input/val-data/"

# Seed torch's global RNG early so weight init (later cells) and the training
# shuffle order are repeatable across runs.
torch.manual_seed(seed)

# Loader settings shared by the train and validation loaders.
loader_kwargs = dict(batch_size=batch_size, num_workers=4, pin_memory=True, persistent_workers=True)

# NOTE(review): `transform` is imported from read_data but not passed here —
# confirm PCamDataset applies its own preprocessing internally.
train_dataset = PCamDataset(
    train_pref + "camelyonpatch_level_2_split_train_x.h5",
    train_pref + "camelyonpatch_level_2_split_train_y.h5",
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)

val_dataset = PCamDataset(
    val_pref + "camelyonpatch_level_2_split_valid_x.h5",
    val_pref + "camelyonpatch_level_2_split_valid_y.h5",
)
# Validation metrics do not depend on sample order, so there is no reason to
# shuffle; a fixed order keeps evaluation deterministic.
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)


In [3]:
class DepthwiseSeparableConv(nn.Module):
    """3x3 depthwise conv, then 1x1 pointwise conv, then BatchNorm and ReLU.

    The depthwise stage (``groups=in_channels``) filters every input channel on
    its own and applies ``stride``. The pointwise stage mixes channels into
    ``out_channels``. Both convs are bias-free; a bias would be absorbed by the
    BatchNorm that follows.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Keep registration order (depthwise, pointwise, bn, relu) stable so
        # parameter order and state_dict keys match existing checkpoints.
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=in_channels,
            bias=False,
        )
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply depthwise -> pointwise -> BatchNorm -> ReLU to ``x``."""
        return self.relu(self.bn(self.pointwise(self.depthwise(x))))

class ResidualBlock(nn.Module):
    """Residual unit: a DepthwiseSeparableConv branch summed with a shortcut.

    The shortcut is the identity when input and output shapes already agree.
    Otherwise it is a strided, bias-free 1x1 conv plus BatchNorm that projects
    the input to ``out_channels`` at the branch's resolution. The conv branch
    ends in its own ReLU; nothing is applied after the sum.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv = DepthwiseSeparableConv(in_channels, out_channels, stride)

        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            projection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
            self.skip = nn.Sequential(projection, nn.BatchNorm2d(out_channels))
        else:
            self.skip = nn.Identity()

    def forward(self, x):
        branch = self.conv(x)
        shortcut = self.skip(x)
        return branch + shortcut

class InceptionModule(nn.Module):
    """Inception-style block: three parallel conv branches concatenated on channels.

    Branches (stride 1, padding chosen so spatial size is preserved):
      * ``conv1x1``: 1x1 conv -> ``branch_channels``
      * ``conv3x3``: 3x3 conv -> ``branch_channels``
      * ``conv5x5``: two stacked 3x3 convs (``reduce_channels`` then
        ``branch_channels``), i.e. a 5x5 receptive field

    The output has ``3 * branch_channels`` channels, exposed as
    ``self.out_channels``. No activation or normalization is applied inside
    this block, including between the two convs of ``conv5x5``.

    Args:
        in_channels: channel count of the input tensor.
        branch_channels: output channels of each branch (default 64, the
            original hard-coded width).
        reduce_channels: width of the first conv in the 5x5 branch
            (default 32, the original hard-coded width).
    """

    def __init__(self, in_channels, branch_channels=64, reduce_channels=32):
        super(InceptionModule, self).__init__()
        # Width of the concatenated output, so the next layer can be sized
        # without hard-coding it. Plain int: not part of the state_dict.
        self.out_channels = 3 * branch_channels

        self.conv1x1 = nn.Conv2d(in_channels, branch_channels, kernel_size=1)

        self.conv3x3 = nn.Conv2d(in_channels, branch_channels, kernel_size=3, padding=1)

        # Two stacked 3x3 convs give a 5x5 receptive field.
        self.conv5x5 = nn.Sequential(
            nn.Conv2d(in_channels, reduce_channels, kernel_size=3, padding=1),
            nn.Conv2d(reduce_channels, branch_channels, kernel_size=3, padding=1)
        )

    def forward(self, x):
        """Run all three branches on ``x`` and concatenate them along dim 1."""
        return torch.cat([self.conv1x1(x), self.conv3x3(x), self.conv5x5(x)], dim=1)

class CustomCNN(nn.Module):
    """Binary classifier: conv stem -> 3 residual stages -> inception block -> MLP head.

    Takes a 3-channel image batch and returns one logit per sample (shape
    ``[batch, 1]``), which this notebook feeds to ``BCEWithLogitsLoss``.
    ``res2`` and ``res3`` downsample with stride 2; global average pooling then
    collapses the remaining spatial dimensions.
    """

    def __init__(self):
        super().__init__()
        # Stem: 3 -> 64 channels at input resolution.
        self.initial_conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.res1 = ResidualBlock(64, 64)
        self.res2 = ResidualBlock(64, 128, 2)   # stride 2
        self.res3 = ResidualBlock(128, 256, 2)  # stride 2

        self.inception = InceptionModule(256)
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)

        head_in = 3 * 64  # inception concatenates three 64-channel branches
        self.fc = nn.Sequential(
            nn.Linear(head_in, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(128, 1),
        )

    def forward(self, x):
        stages = (
            self.initial_conv,
            self.res1,
            self.res2,
            self.res3,
            self.inception,
            self.global_avg_pool,
        )
        for stage in stages:
            x = stage(x)
        return self.fc(x.flatten(1))


In [4]:
# Opt-in switch replacing a commented-out experiment: set True to convert the
# model to the torch.channels_last memory format.
use_channels_last = False

cnn = CustomCNN().to(device)
if use_channels_last:
    cnn = cnn.to(memory_format=torch.channels_last)
criterion = nn.BCEWithLogitsLoss()
optimizer_cnn = optim.SGD(cnn.parameters(), lr=learning_rate, momentum=momentum)

# One [train_loss, train_accuracy, val_loss, val_accuracy] row per epoch,
# appended by train_model.
history = []

def evaluate(model, dataloader, loss_fn=None, eval_device=None):
    """Compute the mean loss and accuracy (%) of a binary classifier over a dataloader.

    The model runs in eval mode under ``torch.no_grad()``. Its previous
    train/eval mode is restored before returning, even if an error occurs.

    Args:
        model: module mapping an input batch to logits of shape ``[batch, 1]``.
        dataloader: iterable of ``(inputs, labels)`` batches with one 0/1
            label per sample (any shape that flattens to ``[batch]``).
        loss_fn: loss applied to ``(logits, float_labels)``. Defaults to the
            notebook-level ``criterion``. Assumed to average over the batch
            (``reduction='mean'``), since batches are re-weighted by size below.
        eval_device: device batches are moved to. Defaults to the
            notebook-level ``device``.

    Returns:
        ``(avg_loss, accuracy)``. ``avg_loss`` is the per-sample mean loss:
        each batch is weighted by its size, so a short final batch is not
        over-weighted. ``accuracy`` is the percentage of correct predictions
        (logit probability > 0.5 counts as class 1).

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    if loss_fn is None:
        loss_fn = criterion
    if eval_device is None:
        eval_device = device

    was_training = model.training
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    try:
        with torch.no_grad():
            for images, labels in dataloader:
                images = images.to(eval_device, non_blocking=True)
                # reshape(-1, 1) matches the [batch, 1] logits whether labels
                # arrive as [batch], [batch, 1] or [batch, 1, 1, 1].
                labels = labels.to(eval_device, non_blocking=True).float().reshape(-1, 1)
                outputs = model(images)

                n = labels.size(0)
                # loss_fn returns the batch mean; scale back to a per-sample sum.
                total_loss += loss_fn(outputs, labels).item() * n

                predicted = (torch.sigmoid(outputs) > 0.5).long()
                correct += predicted.eq(labels.long()).sum().item()
                total += n
    finally:
        model.train(was_training)

    if total == 0:
        raise ValueError("evaluate() received a dataloader with no samples")

    avg_loss = total_loss / total
    accuracy = 100.0 * correct / total
    return avg_loss, accuracy


def train_model(model, optimizer, train_loader, val_loader, epochs, model_name, eval_train=True):
    """Train ``model`` for ``epochs`` epochs and record per-epoch metrics.

    Uses the notebook-level ``criterion`` (loss), ``device`` and ``history``.
    After each epoch, one row ``[train_loss, train_accuracy, val_loss,
    val_accuracy]`` is appended to ``history``, and the validation and
    training accuracies are printed.

    Args:
        model: network returning one logit per sample (shape ``[batch, 1]``).
        optimizer: optimizer over ``model``'s parameters, stepped once per batch.
        train_loader: iterable of ``(images, labels)`` training batches.
        val_loader: iterable of ``(images, labels)`` validation batches.
        epochs: number of passes over ``train_loader``.
        model_name: label shown in the tqdm progress bar.
        eval_train: if True (default, original behaviour), the training metrics
            come from a separate ``evaluate`` pass over the whole
            ``train_loader`` in eval mode. That costs an extra forward pass
            over the training set every epoch. If False, the loss/accuracy
            accumulated during the epoch itself are reported instead. These are
            computed in train mode while the weights are still changing, so
            they are an approximation.

    Returns:
        The ``history`` list (also mutated in place).

    Raises:
        ValueError: if ``eval_train`` is False and ``train_loader`` is empty.
    """
    for epoch in range(epochs):
        model.train()
        progress_bar = tqdm(train_loader, desc=f"{model_name} - Epoch {epoch+1}/{epochs}")

        # Running metrics are kept as on-device tensors so that no per-step
        # .item() host sync is needed; only used when eval_train is False.
        running_loss = torch.zeros((), device=device)
        running_correct = torch.zeros((), dtype=torch.long, device=device)
        seen = 0

        for imgs, labels in progress_bar:
            imgs = imgs.to(device, non_blocking=True)
            # reshape(-1, 1) matches the [batch, 1] logits whether labels
            # arrive as [batch], [batch, 1] or [batch, 1, 1, 1].
            labels = labels.to(device, non_blocking=True).float().reshape(-1, 1)

            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if not eval_train:
                n = labels.size(0)
                running_loss += loss.detach() * n
                predicted = (torch.sigmoid(outputs.detach()) > 0.5).long()
                running_correct += predicted.eq(labels.long()).sum()
                seen += n

        if eval_train:
            train_loss, train_accuracy = evaluate(model, train_loader)
        else:
            if seen == 0:
                raise ValueError("train_model() received an empty train_loader")
            train_loss = running_loss.item() / seen
            train_accuracy = 100.0 * running_correct.item() / seen

        val_loss, val_accuracy = evaluate(model, val_loader)
        history.append([train_loss, train_accuracy, val_loss, val_accuracy])
        print("VAL: ", val_accuracy)
        print("TRAIN: ", train_accuracy)

    return history


In [5]:
num_epochs = 25
save_path = "/kaggle/working/cnn.pth"

try:
    train_model(cnn, optimizer_cnn, train_loader, val_loader, num_epochs, "CNN")
finally:
    # Save even if training stops early (e.g. a KeyboardInterrupt), so the
    # epochs already completed are not lost. When interrupted mid-epoch, the
    # saved weights reflect that partially trained state.
    torch.save(cnn.state_dict(), save_path)

print(history)

CNN - Epoch 1/25: 100%|██████████| 8192/8192 [04:11<00:00, 32.57it/s]


VAL:  78.8543701171875
TRAIN:  84.43222045898438


CNN - Epoch 2/25: 100%|██████████| 8192/8192 [04:02<00:00, 33.77it/s]


VAL:  82.5103759765625
TRAIN:  84.38568115234375


CNN - Epoch 3/25: 100%|██████████| 8192/8192 [04:02<00:00, 33.80it/s]


VAL:  80.096435546875
TRAIN:  86.2640380859375


CNN - Epoch 4/25:  68%|██████▊   | 5536/8192 [02:43<01:18, 33.80it/s]


KeyboardInterrupt: 