In [1]:
import os
import tempfile
from typing import List, Tuple

import mlflow
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from dotenv import load_dotenv
from git import Diff, Repo
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from pytorch_lightning.loggers.mlflow import MLFlowLogger
from torch import nn
from torch.utils.data import DataLoader, random_split

# Pull settings from the local .env file into the process environment.
load_dotenv("./.env")

# Report only that a tracking token is present; its value is never printed.
if os.getenv("MLFLOW_TRACKING_TOKEN"):
    print("Token set!")

  from .autonotebook import tqdm as notebook_tqdm


Token set!


In [2]:
# Preprocessing applied to every image of every split: resize, convert to a
# tensor, then normalise each of the three channels with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# CIFAR-10 training data (downloaded into ./data when missing).
train_set = torchvision.datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)

# Hold out 20% of the training data for validation. The generator is seeded so
# the split is the same on every run.
train_size = int(len(train_set) * 0.8)
val_size = len(train_set) - train_size
train_set, val_set = random_split(
    train_set,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)

# CIFAR-10 test data, preprocessed the same way as the training data.
test_set = torchvision.datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transform,
)

# Human-readable label names; presumably indexed by the integer class id
# (verify against the dataset's own class list).
classes = (
    "plane", "car", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
class LitResNet(pl.LightningModule):
    """ResNet-18 feature extractor with a new linear head, trained with MSE.

    The network outputs softmax probabilities and the loss is the mean squared
    error between those probabilities and the one-hot encoded labels.

    When the attached logger is an ``MLFlowLogger``, the start of training also
    records the git commit, branch and remote URL as run tags, uploads every
    uncommitted file as an artifact, and writes a summary into the run note.

    Args:
        num_classes: Number of output classes (size of the linear head and of
            the one-hot targets). Defaults to 10.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        original_model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18')
        # Keep every child module except the last one (presumably the hub
        # model's own classifier) and attach a fresh head instead.
        self.model = nn.Sequential(*list(original_model.children())[:-1])
        self.num_classes = num_classes
        self.fc = nn.Linear(in_features=512, out_features=num_classes, bias=True)
        self.softmax = nn.Softmax(1)

    def forward(self, x) -> torch.Tensor:
        """Return per-class probabilities (softmax over dim 1) for ``x``."""
        out = self.model(x)
        # Collapse everything after the batch dimension before the linear head.
        out = torch.flatten(out, start_dim=1, end_dim=-1)
        out = self.fc(out)
        out = self.softmax(out)
        return out

    @staticmethod
    def _active_branch_name(repo: Repo):
        """Return the checked-out branch name, or ``None`` on a detached HEAD.

        ``repo.active_branch`` raises ``TypeError`` when HEAD does not point to
        a branch (common in CI checkouts); that must not abort training.
        """
        try:
            return repo.active_branch.name
        except TypeError:
            return None

    def __get_dirty_files(self, repo: Repo) -> List[str]:
        """List untracked, staged and unstaged paths, each exactly once.

        Paths are returned in first-seen order: untracked files, then staged
        changes (index vs. HEAD), then unstaged changes (working tree vs. index).
        """
        candidates = list(repo.untracked_files)
        candidates += [diff.b_path for diff in repo.index.diff("HEAD", create_patch=True, R=True)]
        candidates += [diff.b_path for diff in repo.index.diff(None)]
        # A file with both staged and unstaged edits shows up in both diffs;
        # dict.fromkeys removes duplicates while keeping order. Empty/None
        # paths are dropped defensively.
        return list(dict.fromkeys(path for path in candidates if path))

    def __get_file_content(self, file_path: str) -> str:
        """Read ``file_path`` as UTF-8 text and return its content."""
        with open(file_path, "r", encoding="utf-8") as f:
            text = f.read()
        return text

    def __log_dirty_files(self, repo: Repo):
        """Upload uncommitted files to MLflow and describe them in the run note.

        Does nothing unless the attached logger is an ``MLFlowLogger``.
        """
        if not isinstance(self.logger, MLFlowLogger):
            return

        client = self.logger._mlflow_client
        run_id = self.logger.run_id
        dirty_files = self.__get_dirty_files(repo)
        worktree = repo.working_tree_dir or ""

        for dirty_file in dirty_files:
            local_path = os.path.join(worktree, dirty_file)
            # Files deleted from the working tree are still reported by the
            # diff but cannot be uploaded; they remain listed in the note.
            if not os.path.isfile(local_path):
                continue
            # MLflow's artifact_path is the destination *directory*, so pass
            # the file's parent directory rather than the file path itself.
            parent = os.path.dirname(dirty_file)
            artifact_dir = "uncommit-files/{}".format(parent) if parent else "uncommit-files"
            client.log_artifact(run_id, local_path, artifact_dir)

        note_lines = ["# Uncommit files"]
        note_lines += ["- {}".format(dirty_file) for dirty_file in dirty_files]
        if len(dirty_files) == 0:
            note_lines.append("(empty)")
        note_lines.append("# Git info")
        note_lines.append("- commit hash: {}".format(repo.head.commit.hexsha))
        branch = self._active_branch_name(repo)
        note_lines.append("- branch: {}".format(branch if branch is not None else "(detached HEAD)"))
        # Mirror on_train_start: a repository may have no remote configured.
        remotes = repo.remotes
        if len(remotes) >= 1:
            note_lines.append("- repository: {}".format(remotes[0].url))
        comment = "\n".join(note_lines) + "\n"

        client.set_tag(run_id, "mlflow.note.content", comment)

    def on_train_start(self):
        """Log hyperparameters and git provenance to MLflow when training starts.

        Does nothing unless the attached logger is an ``MLFlowLogger``. The
        repository is opened with ``Repo()``, i.e. using its default location.
        """
        if not isinstance(self.logger, MLFlowLogger):
            return
        run_id = self.logger.run_id
        self.logger.log_hyperparams(self.hparams)
        repo = Repo()
        client = self.logger._mlflow_client
        client.set_tag(run_id, "mlflow.source.git.commit", repo.head.commit.hexsha)
        branch = self._active_branch_name(repo)
        if branch is not None:
            client.set_tag(run_id, "mlflow.source.git.branch", branch)
        remotes = repo.remotes
        if len(remotes) >= 1:
            client.set_tag(run_id, "mlflow.source.git.repoURL", remotes[0].url)

        self.__log_dirty_files(repo)

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``loss``) and return the loss for one training batch."""
        loss = self._common_step(batch, batch_idx)
        self.log_dict({"loss": loss})
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log (as ``val_loss``) the loss for one validation batch."""
        loss = self._common_step(batch, batch_idx)
        self.log_dict({"val_loss": loss})

    def _common_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        """Return the MSE between predicted probabilities and one-hot labels.

        Args:
            batch: Pair whose first element is the model input and whose second
                element holds the integer class labels.
            batch_idx: Index of the batch (unused).
        """
        X_common, y_common = batch[0], batch[1]
        y_hat = self.forward(X_common)
        target = F.one_hot(y_common, num_classes=self.num_classes)
        # .float() casts to float32 on the tensors' current device;
        # .type(torch.FloatTensor) would copy both to the CPU every step.
        loss = F.mse_loss(y_hat.float(), target.float())
        return loss

    def configure_optimizers(self):
        """Return the SGD optimizer (lr 0.1, momentum 0.9, weight decay 1e-4)."""
        optimizer = torch.optim.SGD(self.parameters(), lr=1e-1, momentum=0.9, weight_decay=1e-4)
        return optimizer

In [4]:
# Run configuration shared by the statements below.
SEED = 42
EXPERIMENT_NAME = "cifar10-practice/resnet-18"
NUM_WORKERS = 8

# Seed the global RNGs (and dataloader workers) so that weight initialisation
# and batch shuffling are repeatable across runs.
pl.seed_everything(SEED, workers=True)
torch.set_float32_matmul_precision('medium')

train_loader = DataLoader(train_set, shuffle=True, batch_size=128, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_set, shuffle=False, batch_size=32, num_workers=NUM_WORKERS)
test_loader = DataLoader(test_set, shuffle=False, batch_size=128, num_workers=NUM_WORKERS)

mlf_logger = MLFlowLogger(
    experiment_name=EXPERIMENT_NAME,
    tracking_uri=os.environ.get("MLFLOW_TRACKING_URI"),
    log_model=False,
)
model = LitResNet()

# Checkpoints are grouped per experiment and per MLflow run. The experiment
# name comes from the shared constant instead of the logger's private
# `_experiment_name` attribute.
checkpoint_dir = os.path.join("lightning_logs/", EXPERIMENT_NAME, mlf_logger.run_id)

trainer = pl.Trainer(
    max_epochs=500,
    accelerator="gpu",
    logger=mlf_logger,
    callbacks=[
        # Keep the two checkpoints with the lowest validation loss.
        ModelCheckpoint(
            dirpath=checkpoint_dir,
            save_top_k=2,
            monitor="val_loss",
            mode="min",
            filename="checkpoint-{epoch:02d}-{val_loss:.5f}",
        ),
        # Stop early based on the monitored validation loss.
        EarlyStopping("val_loss"),
        LearningRateMonitor('step', log_momentum=True),
    ],
    check_val_every_n_epoch=1,
)
trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)

Using cache found in /home/clx/.cache/torch/hub/pytorch_vision_v0.10.0
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type       | Params
---------------------------------------
0 | model   | Sequential | 11.2 M
1 | fc      | Linear     | 5.1 K 
2 | softmax | Softmax    | 0     
---------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)


Epoch 1:  42%|████▏     | 261/626 [00:44<01:01,  5.90it/s, loss=0.0586, v_num=466b]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
