In [1]:
import torch
import torchvision
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

import os



In [2]:
batch_size_train = 1
batch_size_test = 1

In [3]:
train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('/files/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_train, shuffle=True)

test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('/files/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_test, shuffle=True)

In [4]:
torchvision.models.mobilenet_v2(torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V2)



MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=

In [27]:
class AlgaeLightning(pl.LightningModule):
    def __init__(self, model_kwargs):
        super().__init__()
        self.model = torchvision.models.mobilenet_v2(torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V2)
        self.model.classifier[1] = torch.nn.Linear(self.model.classifier[1].in_features, 10)
        self.criterion = model_kwargs["criterion"]
        self.optimizer = model_kwargs["optimizer"]
        self.lr = model_kwargs["lr"]
        
        self.transform=transforms.Compose([
        torchvision.transforms.Grayscale(num_output_channels=3),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,))
        ])

    def forward(self, x):
        
        x = self.transform(x)
        
        return self.model(x)

    def training_step(self, batch, batch_idx):
        # 1       
        loss = self._loss_n_metrics(batch, mode="train")
        print("training_step")
        #loss = torch.tensor(1)
        return loss
    
    def validation_step(self):
        # 2
        print("validation_step")
        #self._loss_n_metrics(batch, mode="val")
    
    def on_validation_epoch_end(self):
        print("on_validation_epoch_end")
        # 3
        pass
    
    def on_train_epoch_end(self):
        print("on_validation_epoch_end")
        # 4
        pass

    def configure_optimizers(self):
        print("configure_optimizers")
        
        if self.optimizer == "adamw":
            optimizer = optim.AdamW(self.parameters(), lr=self.lr)
            lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100,150], gamma=0.1)
            return [optimizer], [lr_scheduler]
        else:
            return optim.SGD(self.parameters(), lr=self.lr)
        
    def _loss_n_metrics(self, batch, mode="train"):
        
        img, ground_truth = batch
        model_output = self(img.to("cpu")) # cause of the forward function
        
        ground_truth = ground_truth.unsqueeze(1)
        
        ground_truth_multi_hot = torch.zeros(ground_truth.size(0), 10).scatter_(1, ground_truth.to("cpu"), 1.)
        
        # this needs fixing
        # ground_truth_multi_hot = torch.zeros(ground_truth.size(0), 10).to("cuda").scatter_(torch.tensor(1).to("cuda"), ground_truth.to("cuda"), torch.tensor(1.).to("cuda")).to("cuda")
        
        loss = self.criterion(model_output, ground_truth_multi_hot.to("cpu"))
        acc = (model_output.argmax(dim=-1) == ground_truth).float().mean()

        self.log(f'{mode}_loss', loss, on_step=False, on_epoch=True)
        self.log(f'{mode}_acc', acc, on_step=False, on_epoch=True)
        
        return loss

In [29]:
transform=transforms.Compose([
    torchvision.transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
    ])

dataset1 = datasets.MNIST('example_data', train=True, download=True,
                       transform=transform)
dataset1

Dataset MNIST
    Number of datapoints: 60000
    Root location: example_data
    Split: Train
    StandardTransform
Transform: Compose(
               Grayscale(num_output_channels=3)
               ToTensor()
               Normalize(mean=(0.1307,), std=(0.3081,))
           )

In [30]:
import medmnist
from medmnist import INFO
info = INFO['retinamnist']
DataClass = getattr(medmnist, info['python_class'])
train_dataset = DataClass(split='train', download=True)       

Using downloaded and verified file: C:\Users\Prinzessin\.medmnist\retinamnist.npz


In [32]:
dir(train_dataset)

['__add__',
 '__class__',
 '__class_getitem__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__orig_bases__',
 '__parameters__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__slots__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_is_protocol',
 'as_rgb',
 'download',
 'flag',
 'imgs',
 'info',
 'labels',
 'montage',
 'root',
 'save',
 'split',
 'target_transform',
 'transform']

In [23]:
dir(dataset1)

['__add__',
 '__class__',
 '__class_getitem__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__orig_bases__',
 '__parameters__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__slots__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_check_exists',
 '_check_legacy_exist',
 '_format_transform_repr',
 '_is_protocol',
 '_load_data',
 '_load_legacy_data',
 '_repr_indent',
 'class_to_idx',
 'classes',
 'data',
 'download',
 'extra_repr',
 'mirrors',
 'processed_folder',
 'raw_folder',
 'resources',
 'root',
 'target_transform',
 'targets',
 'test_data',
 'test_file',
 'test_labels',
 'train',
 'train_data',
 'train_labels',
 'training_file',
 'transform',
 'transforms']

In [6]:
def dev_routine(**kwargs):
    
    train_kwargs = kwargs['train_kwargs']
    
    
    transform=transforms.Compose([
    torchvision.transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
    ])

    if False:
        dataset1 = datasets.MNIST('example_data', train=True, download=True,
                           transform=transform)
    else:
        info = INFO['retinamnist']
        DataClass = getattr(medmnist, info['python_class'])
        dataset1 = DataClass(split='train', download=True)
        
        
    #dataset2 = datasets.MNIST('example_data', train=False,
    #                   transform=transform)
    train_dataloader = torch.utils.data.DataLoader(dataset1, batch_size=train_kwargs["batch_size"])
    #inference_loader = torch.utils.data.DataLoader(dataset2, **train_kwargs)
    
    print(train_kwargs)
    trainer = pl.Trainer(default_root_dir=os.path.join(train_kwargs["ckpt_path"], "example_results"),
                         accelerator="gpu" if str(train_kwargs["device"]).startswith("cuda") else "cpu",
                         devices=1,
                         max_epochs=train_kwargs["epochs"],
                         callbacks=[ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"),
                                    LearningRateMonitor("epoch")])
    
    trainer.logger._log_graph = True         # If True, we plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None # Optional logging argument that we don't need

    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(train_kwargs["ckpt_path"], "DecentNet.ckpt")
    if os.path.isfile(pretrained_filename):
        print(f"Found pretrained model at {pretrained_filename}, loading...")
        model = AlgaeLightning.load_from_checkpoint(pretrained_filename) # Automatically loads the model with the saved hyperparameters
    else:
        pl.seed_everything(19) # To be reproducable
        
        # Initialize the LightningModule and LightningDataModule
        model = AlgaeLightning(kwargs['model_kwargs'])

        # Train the model using a Trainer
        trainer.fit(model=model, train_dataloaders=train_dataloader, val_dataloaders=None)
        
        model = AlgaeLightning.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) # Load best checkpoint after training

    # Test best model on validation and test set
    val_result = trainer.test(model, val_loader, verbose=False)
    test_result = trainer.test(model, test_loader, verbose=False)
    result = {"test": test_result[0]["test_acc"], "val": val_result[0]["test_acc"]}

    return model, result
    

In [10]:
model, results = dev_routine(model_kwargs={
                                'embed_dim': 256,
                                'hidden_dim': 512,
                                'num_heads': 8,
                                'num_layers': 6,
                                'patch_size': 4,
                                'num_channels': 3,
                                'num_patches': 64,
                                'num_classes': 10,
                                'criterion': torch.nn.CrossEntropyLoss(),
                                'optimizer': "adamw",
                                'lr': 0.001,
                            },
                            train_kwargs={
                                'epochs': 100,
                                'batch_size': 16,
                                # 'test_batch_size': 1,
                                'ckpt_path': "",
                                'device': "cpu"
                            }
                            )

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
Global seed set to 19


Using downloaded and verified file: C:\Users\Prinzessin\.medmnist\retinamnist.npz
{'epochs': 100, 'batch_size': 16, 'ckpt_path': '', 'device': 'cpu'}


  rank_zero_warn(

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | MobileNetV2      | 2.2 M 
1 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
2.2 M     Trainable params
0         Non-trainable params
2.2 M     Total params
8.947     Total estimated model params size (MB)
  rank_zero_warn(


configure_optimizers


  rank_zero_warn(


Training: 0it [00:00, ?it/s]

TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>