In [None]:
import timm
import torchinfo

### 1. Import MNIST data

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Preprocessing pipeline: convert PIL -> tensor, normalise with the MNIST mean/std,
# then replicate the single grey channel 3x so images fit a 3-channel backbone.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
        transforms.Lambda(lambda img: img.repeat(3, 1, 1)),  # 1 channel -> 3 channels
    ]
)

# Train/test splits share everything except the `train` flag.
mnist_kwargs = dict(root='./data', download=True, transform=transform)
train_dataset = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = torchvision.datasets.MNIST(train=False, **mnist_kwargs)

# Batching: reshuffle the training split every epoch, keep the test split in a fixed order.
BATCH_SIZE = 64

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Preview the first training image.
# The transformed tensor is normalised and 3-channel (C x H x W). matplotlib expects
# H x W (x C), and for RGB floats it clips values to [0, 1]. Normalised MNIST pixels span
# roughly [-0.42, 2.82], so showing the RGB tensor directly distorts the image.
# All three channels are identical copies, so display channel 0 as a grey image instead.
# For single-channel data, imshow rescales to the data range.
sample_img, sample_label = train_dataset[0]

fig, ax = plt.subplots()
ax.imshow(sample_img[0].numpy(), cmap='gray')
ax.set_title(f"MNIST training sample 0 (label {sample_label})")
ax.axis('off');

### 2. Define Architecture

In [None]:
# Pretrained EfficientNet-B0 from timm. features_only=False keeps the full classification
# network, including its `classifier` head (which is replaced further below).
model = timm.create_model(
    model_name='efficientnet_b0',
    pretrained=True,
    features_only=False,
)

In [None]:
# Summarise the architecture for a batch shaped like the real training data
# (BATCH_SIZE x 3 x 28 x 28 after the transform above), rather than a 224 x 224 ImageNet
# input. This way the reported output sizes match what the model actually sees
# during training and in the MLflow signature.
torchinfo.summary(
    model,
    input_size=(BATCH_SIZE, *train_dataset[0][0].shape),
    col_names=["input_size", "output_size", "num_params", "trainable"],
)

In [None]:
# Freeze every parameter of the pretrained network in one call
# (Module.requires_grad_ flips requires_grad on all parameters in place).
# The trailing ';' suppresses the cell from echoing the model repr.
model.requires_grad_(False);

In [None]:
NUM_CLASSES = 10

# Replace the ImageNet head with a fresh NUM_CLASSES-way linear layer.
# The input width is taken from the backbone's `num_features` (1280 for efficientnet_b0)
# instead of being hard-coded, so the cell still works if the backbone changes.
# It also works if the cell is re-run after `classifier` has already been replaced.
# The Sequential wrapper keeps the state_dict keys as `classifier.0.*`.
model.classifier = torch.nn.Sequential(
    torch.nn.Linear(in_features=model.num_features, out_features=NUM_CLASSES, bias=True)
)

### 3. Train 

In [None]:
import mlflow 
import os
from dotenv import load_dotenv

# Pull environment variables from a local .env file (if present), then point MLflow at the
# tracking server given by MLFLOW_TRACKING_URI. The value is None when the variable is
# unset; presumably MLflow then falls back to its default store — confirm.
load_dotenv()
MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI")

mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
mlflow.set_experiment("DummyMNIST")

In [None]:
"""
File which contains PyTorch Trainer
"""

import logging
import torch
from tqdm.auto import tqdm
from schemas.trainers.torch_trainer import TorchTrainerConfig


class TorchTrainerModule:
    """Minimal epoch-based PyTorch training loop that logs per-epoch metrics to MLflow.

    The config is read by key and must provide ``num_epochs``, ``device`` and
    ``log_every_n_steps``; a plain dict works.
    """

    def __init__(self, cfg: "TorchTrainerConfig"):
        # String annotation: the class definition does not need the schema import at runtime.
        self.num_epochs = cfg["num_epochs"]
        self.device = cfg["device"]
        self.log_every_n_steps = cfg["log_every_n_steps"]
        self.logger = None  # TODO: plug in an external logger

    @staticmethod
    def _mean(total: float, num_batches: int) -> float:
        """Average a per-batch sum; NaN for an empty dataloader instead of ZeroDivisionError."""
        return total / num_batches if num_batches else float("nan")

    def _train_step(self, model: torch.nn.Module, dataloader: torch.utils.data.DataLoader,
                    optimizer: torch.optim.Optimizer, loss_fn: torch.nn.Module):
        """Run one training epoch.

        Returns:
            dict: ``{'loss': mean of the per-batch losses}``.
        """
        model.train()
        total_loss = 0.0

        for X, y in dataloader:
            X, y = X.to(self.device), y.to(self.device)

            # Clear stale gradients before computing new ones.
            optimizer.zero_grad()
            loss = loss_fn(model(X), y)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        return {'loss': self._mean(total_loss, len(dataloader))}

    def _val_step(self, model: torch.nn.Module, dataloader: torch.utils.data.DataLoader,
                  loss_fn: torch.nn.Module):
        """Evaluate the model for one epoch without changing its weights.

        Returns:
            dict: ``{'loss': mean of the per-batch losses}``.
        """
        model.eval()
        total_loss = 0.0

        # Evaluation only: no autograd graph, no backward pass, no optimizer step.
        # Stepping an optimizer here would train the model on the validation data.
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(self.device), y.to(self.device)
                total_loss += loss_fn(model(X), y).item()

        return {'loss': self._mean(total_loss, len(dataloader))}

    def train(self, model: torch.nn.Module, train_dataloaders: torch.utils.data.DataLoader,
              val_dataloaders: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer,
              loss_fn: torch.nn.Module):
        """Train for ``num_epochs``, validating after every epoch.

        Metrics are logged to MLflow as ``train_<name>`` / ``val_<name>`` with
        ``step=epoch`` so they form a per-epoch curve. Relies on the notebook-level
        ``mlflow`` and ``tqdm`` imports.
        """
        model.to(self.device)

        for epoch in tqdm(range(self.num_epochs)):
            logging.info(f"Epoch {epoch} - Training Step")
            train_metrics = self._train_step(model=model, dataloader=train_dataloaders,
                                             optimizer=optimizer, loss_fn=loss_fn)

            logging.info(f"Epoch {epoch} - Validation Step")
            val_metrics = self._val_step(model=model, dataloader=val_dataloaders, loss_fn=loss_fn)

            # Save metrics; without `step`, every epoch would be recorded at step 0.
            for prefix, metrics in (("train", train_metrics), ("val", val_metrics)):
                for name, value in metrics.items():
                    mlflow.log_metric(f"{prefix}_{name}", value, step=epoch)

            # Check the logger first so log_every_n_steps == 0 cannot raise ZeroDivisionError.
            if self.logger and self.log_every_n_steps and epoch % self.log_every_n_steps == 0:
                pass  # TODO: log metrics, model, ...

    def test(self, model: torch.nn.Module, dataloaders: torch.utils.data.DataLoader):
        """Placeholder (TODO): not implemented yet; currently returns None."""
        pass

    def predict(self, model: torch.nn.Module, dataloaders: torch.utils.data.DataLoader):
        """Placeholder (TODO): not implemented yet; currently returns None."""
        pass


In [None]:
# NOTE(review): per the MLflow docs, mlflow.pytorch.autolog targets PyTorch Lightning.
# The plain loop in TorchTrainerModule logs its own metrics via mlflow.log_metric.
mlflow.pytorch.autolog()

# Pick the best available accelerator instead of hard-coding 'mps',
# so the notebook also runs (Restart & Run All) on CUDA or CPU-only machines.
if torch.cuda.is_available():
    DEVICE = 'cuda'
elif torch.backends.mps.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'

cfg = {
    'num_epochs': 2,
    'device': DEVICE,
    'log_every_n_steps': 1,
}

trainer = TorchTrainerModule(cfg)

In [None]:
from mlflow.models import infer_signature

# Build an MLflow signature from a dummy batch shaped like the real data (2 x 3 x 28 x 28).
# The model has not been moved by the trainer yet, so create the dummy batch on whatever
# device the model currently lives on; hard-coding 'mps' causes a device mismatch.
# Run the probe in eval mode without autograd so BatchNorm running stats are untouched,
# then restore the previous train/eval mode.
model_device = next(model.parameters()).device
X_dummy = torch.randn((2, 3, 28, 28), device=model_device)

was_training = model.training
model.eval()
with torch.no_grad():
    y_dummy = model(X_dummy)
model.train(was_training)

# infer_signature is given numpy arrays (CPU copies of the tensors).
signature = infer_signature(X_dummy.cpu().numpy(), y_dummy.cpu().numpy())

In [None]:
from datetime import datetime

# e.g. datetime.now().strftime('%Y-%m-%d %H:%M')  -- note: %M is minutes, %m is the month

In [None]:

LEARNING_RATE = 1e-4

# The backbone is frozen, so only hand the trainable parameters (the new head) to Adam.
optimizer = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad], lr=LEARNING_RATE
)
loss_fn = torch.nn.CrossEntropyLoss()

# %M is minutes (%m would repeat the month). ':' is avoided because the run name is also
# used as a directory name below.
run_name = f"test_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
model_dir = f"models/{run_name}"

with mlflow.start_run(run_name=run_name):
    mlflow.log_params({
        **cfg,
        'lr': LEARNING_RATE,
        'batch_size': BATCH_SIZE,
        'optimizer': type(optimizer).__name__,
    })

    trainer.train(model, train_loader, test_loader, optimizer, loss_fn)

    # NOTE: mlflow.pytorch.log_model(...) did not work in this setup (earlier experiment;
    # cause not recorded). Instead, save the model locally and upload that same directory
    # as the run's "model" artifacts.
    mlflow.pytorch.save_model(pytorch_model=model, path=model_dir)
    mlflow.log_artifacts(model_dir, artifact_path="model")

In [None]:
# Scratch experiment: autolog + train WITHOUT an explicit mlflow.start_run().
# Earlier finding: this "doesn't work"; it needs `with mlflow.start_run()`.
# NOTE(review): per the MLflow docs, mlflow.pytorch.autolog targets PyTorch Lightning.
# A plain nn.Module loop like TorchTrainerModule.train is presumably not captured — confirm.
# NOTE(review): this continues training the same `model` with the same `optimizer`
# created in the cell above, so running it changes the model that was already saved.
mlflow.pytorch.autolog()
trainer.train(model, train_loader, test_loader, optimizer, loss_fn)