In [1]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, Dataset, Subset
from torchvision import transforms, models
import pytorch_lightning as pl
from pytorch_lightning.utilities.model_summary import ModelSummary
from pytorch_lightning.loggers import WandbLogger
from PIL import Image
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR
import wandb
import torchmetrics

In [2]:
class ImageEncoder(pl.LightningModule):
    """Small CNN binary classifier.

    Three ``Conv2d(kernel_size=5) -> ReLU -> MaxPool2d(2)`` stages feed an MLP head
    that produces one logit per image; ``forward`` returns ``sigmoid(logit)``.

    Args:
        input_shape: ``(channels, height, width)`` of a single input, used to size the
            first Linear layer. The first conv layer expects 3 input channels.
            ``None`` falls back to ``(3, 96, 96)``.
        num_classes: Recorded in the hyperparameters only; the head always has a
            single output unit (binary classification).
        learning_rate: Initial Adam learning rate.
    """

    def __init__(self, input_shape, num_classes, learning_rate=2e-4):
        super().__init__()

        self.save_hyperparameters()
        self.learning_rate = learning_rate

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Size the head from the shape actually passed in (previously ignored in favour
        # of a hard-coded (3, 96, 96)), so other patch sizes work without edits.
        shape = tuple(input_shape) if input_shape is not None else (3, 96, 96)
        n_sizes = self._get_conv_output(shape)

        self.dense = nn.Sequential(
            nn.Flatten(),
            nn.Linear(n_sizes, 240),
            nn.ReLU(),
            nn.Linear(240, 120),
            nn.ReLU(),
            nn.Linear(120, 80),
            nn.ReLU(),
            nn.Linear(80, 1),
        )

        self.accuracy = torchmetrics.Accuracy(task='binary')

    def _logits(self, x):
        """Raw (pre-sigmoid) scores, one per input image."""
        return self.dense(self.encoder(x))

    def forward(self, x):
        """Return the predicted probability ``sigmoid(logit)`` for each image."""
        return torch.sigmoid(self._logits(x))

    def _shared_step(self, batch):
        """Run one ``(x, y)`` batch; return ``(loss, rounded predictions, targets)``."""
        x, y = batch
        logits = self._logits(x)
        # Fused sigmoid + BCE (log-sum-exp form) is numerically stabler than calling
        # F.binary_cross_entropy on an already-squashed probability.
        loss = F.binary_cross_entropy_with_logits(logits, y)
        # Predictions only feed epoch-level accuracy; detach so the stored per-step
        # outputs do not hold references to the autograd graph.
        pred = torch.sigmoid(logits).round().detach()
        return loss, pred, y

    def _epoch_loss_and_acc(self, outputs, loss_key):
        """Mean of ``out[loss_key]`` over all steps and binary accuracy of all predictions."""
        avg_loss = torch.stack([out[loss_key].detach() for out in outputs]).mean()
        # Use the module's own device (where the metric state lives) instead of
        # hard-coding 'cuda', so CPU runs work as well.
        pred = torch.cat([out['pred'] for out in outputs]).flatten().to(self.device)
        y = torch.cat([out['y'] for out in outputs]).flatten().to(self.device, torch.int32)
        avg_acc = self.accuracy(pred, y)
        # Train and validation share one metric object: clear its accumulated state
        # so the two phases never mix.
        self.accuracy.reset()
        return avg_loss, avg_acc

    def training_step(self, batch, batch_idx):
        loss, pred, y = self._shared_step(batch)
        return {"loss": loss, 'pred': pred, 'y': y}

    def training_epoch_end(self, outputs):
        """Log mean training loss and accuracy over the epoch."""
        if not outputs:  # nothing was run this epoch; nothing to aggregate
            return
        avg_loss, avg_acc = self._epoch_loss_and_acc(outputs, 'loss')
        self.log("train_loss", avg_loss, prog_bar=True, on_epoch=True)
        self.log("train_acc", avg_acc, prog_bar=True, on_epoch=True)

    def validation_step(self, batch, batch_idx):
        val_loss, pred, y = self._shared_step(batch)
        return {'val_loss': val_loss, 'pred': pred, 'y': y}

    def validation_epoch_end(self, outputs):
        """Log mean validation loss and accuracy (``val_loss`` is the checkpoint monitor)."""
        if not outputs:
            return
        avg_loss, avg_acc = self._epoch_loss_and_acc(outputs, 'val_loss')
        self.log("val_loss", avg_loss, prog_bar=True, on_epoch=True)
        self.log("val_acc", avg_acc, prog_bar=True, on_epoch=True)

    def configure_optimizers(self):
        """Adam with a cosine-annealed learning rate (Lightning steps the scheduler per epoch by default)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        # NOTE(review): with T_max=100 the cosine schedule bottoms out at eta_min at
        # epoch 100 and rises again if training runs longer than that.
        scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=0.00001)
        return [optimizer], [scheduler]

    def optimizer_step(self,
                       epoch=None,
                       batch_idx=None,
                       optimizer=None,
                       optimizer_idx=None,
                       optimizer_closure=None,
                       on_tpu=None,
                       using_native_amp=None,
                       using_lbfgs=None
                       ):
        """Step ``optimizer`` (running ``optimizer_closure`` for forward/backward), then clear grads.

        The remaining arguments only mirror Lightning's hook signature and are unused.
        """
        optimizer.step(closure=optimizer_closure)
        optimizer.zero_grad()

    def _get_conv_output(self, shape):
        """Return the flattened feature count ``self.encoder`` produces for one input of ``shape``."""
        batch_size = 1
        # Shape probe only: no autograd tracking, and zeros avoid consuming the RNG.
        with torch.no_grad():
            output_feat = self.encoder(torch.zeros(batch_size, *shape))
        return output_feat.flatten(1).size(1)
    


In [3]:
class CancerDataset(Dataset):
    """Map-style dataset of ``.tif`` images with one label each.

    Args:
        label_dir: Path of a CSV file whose first column is the image id (file name
            without the ``.tif`` extension) and whose second column is the label.
        img_dir: Directory containing the ``<id>.tif`` files.
        transform: Optional callable applied to the RGB PIL image. When ``None`` the
            PIL image itself is returned.

    Items are ``(image, label)`` where ``label`` is a float32 tensor of shape ``(1,)``.
    """

    def __init__(self, label_dir, img_dir, transform=None):
        self.img_labels = pd.read_csv(label_dir)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_id = self.img_labels.iloc[idx, 0]
        # os.path.join avoids a doubled separator when img_dir already ends in '/'.
        img_path = os.path.join(self.img_dir, f"{img_id}.tif")
        # The context manager closes the file handle promptly (Image.open is lazy and
        # would otherwise keep it open); convert() loads pixels into a new image and
        # guarantees the 3 channels the model expects.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        label = torch.tensor([self.img_labels.iloc[idx, 1]], dtype=torch.float32)
        return image, label

In [4]:
def train_val_dataset(dataset, train_split=0.75, val_split=0.25, random_state=None):
    """Randomly split ``dataset`` into disjoint train/validation ``Subset``s.

    ``train_split`` / ``val_split`` are forwarded as ``train_test_split``'s
    ``train_size`` / ``test_size``. Floats are fractions of ``len(dataset)`` and ints
    are absolute sample counts. When two ints sum to less than ``len(dataset)``, the
    leftover indices go into neither subset.

    Args:
        dataset: Any sized, indexable dataset.
        train_split: Size of the training subset (fraction or count).
        val_split: Size of the validation subset (fraction or count).
        random_state: Seed forwarded to ``train_test_split`` for a reproducible
            split. ``None`` (the default) keeps the previous unseeded behaviour.

    Returns:
        ``{'train': Subset, 'val': Subset}``.
    """
    train_idx, val_idx = train_test_split(
        list(range(len(dataset))),
        train_size=train_split,
        test_size=val_split,
        random_state=random_state,
    )
    return {'train': Subset(dataset, train_idx), 'val': Subset(dataset, val_idx)}

In [5]:
class DataModule(pl.LightningDataModule):
    """LightningDataModule serving a random train/validation split of ``CancerDataset``.

    Args:
        batch_size: Batch size for both dataloaders.
        label_dir: CSV path with ``id,label`` rows.
        img_dir: Directory containing the ``<id>.tif`` images.
        data_size: Fraction of the dataset to use. The selection is split 70/30 into
            train/val. ``None`` (the default) reads ``wandb.config['data_size']`` at
            setup time, which requires an active wandb run.
        num_workers: DataLoader worker processes (0 = load in the main process).
    """

    def __init__(self, batch_size, label_dir, img_dir, data_size=None, num_workers=0):
        super().__init__()
        self.batch_size = batch_size
        self.dims = (3, 96, 96)
        self.num_classes = 1
        self.label_dir = label_dir
        self.img_dir = img_dir
        self.data_size = data_size
        self.num_workers = num_workers
        # Training pipeline: random horizontal flip, then scale to [-1, 1].
        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        # Validation pipeline: same normalisation, no random augmentation, so
        # validation metrics are deterministic.
        self.eval_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        self.train_dataset = None
        self.val_dataset = None

    def setup(self, stage=None):
        # Lightning may call setup() once per trainer entry point (fit, validate, ...).
        # Re-splitting each time would draw a fresh random split and could move former
        # training samples into validation, so only split once.
        if self.train_dataset is not None and self.val_dataset is not None:
            return

        data_size = self.data_size if self.data_size is not None else wandb.config['data_size']
        train_data = CancerDataset(label_dir=self.label_dir,
                                   img_dir=self.img_dir,
                                   transform=self.transform)
        val_data = CancerDataset(label_dir=self.label_dir,
                                 img_dir=self.img_dir,
                                 transform=self.eval_transform)
        train_n = int(len(train_data) * data_size * 0.7)
        val_n = int(len(train_data) * data_size * 0.3)
        split = train_val_dataset(train_data, train_split=train_n, val_split=val_n)
        self.train_dataset = split['train']
        # Same validation indices, but served through the un-augmented dataset copy.
        self.val_dataset = Subset(val_data, split['val'].indices)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers)

In [6]:
# Authenticate with wandb, then start the run whose config holds the tunable
# hyperparameters read by later cells via wandb.config[...].
wandb.login()
run = wandb.init(
    # mode='disabled',  # uncomment to run without sending anything to wandb
    project="histopathologic-cancer-classifier",
    name="Test1",  # omit to let wandb assign a random run name
    config={
        "learning_rate": 0.0005,
        "data_size": 1.0,
        "batch_size": 32,
        "seed": 42,
    },
)
# Seed Python, NumPy and torch RNGs so that later random steps (weight init,
# train/val split, shuffling) are reproducible from run to run.
pl.seed_everything(wandb.config['seed'])
run

[34m[1mwandb[0m: Currently logged in as: [33mrespwill[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
# Attach Lightning's logger to the already-active wandb run explicitly, instead of
# letting WandbLogger discover it implicitly (which emits a warning). If no run is
# active, wandb.run is None and this behaves like WandbLogger().
wandb_logger = WandbLogger(experiment=wandb.run)

  rank_zero_warn(


In [8]:
# Data lives outside the repo; both paths are relative to the notebook's working directory.
dm = DataModule(
    label_dir="./train_labels_balance.csv",
    img_dir="../histopathologic-cancer-detection_data/train/",
    batch_size=wandb.config['batch_size'],
)

  rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")


In [9]:
# Model shape comes from the datamodule; learning rate from the wandb run config.
im_encoder = ImageEncoder(
    input_shape=dm.dims,
    num_classes=dm.num_classes,
    learning_rate=wandb.config['learning_rate'],
)

  rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")


In [10]:
# max_depth=-1 expands every nested submodule, not just the top-level blocks.
ModelSummary(model=im_encoder, max_depth=-1)

   | Name      | Type       | Params
------------------------------------------
0  | encoder   | Sequential | 15.7 K
1  | encoder.0 | Conv2d     | 456   
2  | encoder.1 | ReLU       | 0     
3  | encoder.2 | MaxPool2d  | 0     
4  | encoder.3 | Conv2d     | 2.4 K 
5  | encoder.4 | ReLU       | 0     
6  | encoder.5 | MaxPool2d  | 0     
7  | encoder.6 | Conv2d     | 12.8 K
8  | encoder.7 | ReLU       | 0     
9  | encoder.8 | MaxPool2d  | 0     
10 | dense     | Sequential | 530 K 
11 | dense.0   | Flatten    | 0     
12 | dense.1   | Linear     | 491 K 
13 | dense.2   | ReLU       | 0     
14 | dense.3   | Linear     | 28.9 K
15 | dense.4   | ReLU       | 0     
16 | dense.5   | Linear     | 9.7 K 
17 | dense.6   | ReLU       | 0     
18 | dense.7   | Linear     | 81    
19 | accuracy  | Accuracy   | 0     
------------------------------------------
546 K     Trainable params
0         Non-trainable params
546 K     Total params
2.185     Total estimated model params size (MB)

In [11]:
# Keep the five checkpoints with the lowest validation loss.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=5,
    dirpath='./check_point/',
    filename='{epoch}-{train_loss:.4f}-{val_loss:.4f}',
)

In [12]:
# Single-GPU run; metrics go to wandb and checkpoints are handled by the callback above.
trainer = pl.Trainer(
    accelerator='gpu',
    devices=1,
    max_epochs=200,
    callbacks=[checkpoint_callback],
    logger=wandb_logger,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [13]:
# Pass the datamodule by keyword; positionally it would land in `train_dataloaders`.
trainer.fit(model=im_encoder, datamodule=dm)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name     | Type       | Params
----------------------------------------
0 | encoder  | Sequential | 15.7 K
1 | dense    | Sequential | 530 K 
2 | accuracy | Accuracy   | 0     
----------------------------------------
546 K     Trainable params
0         Non-trainable params
546 K     Total params
2.185     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
