In [1]:
#!/home/rafatmatting/anaconda3/envs/ml/bin/python
import torch
from pytorch_lightning import LightningModule, Trainer
from networks.UNet import UNet
from networks.MODNet import MODNet
from networks.GFM import GFM
from networks.DFM import DFM
from datasets.MattingDataModule import MattingDataModule
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    ModelSummary,
    LearningRateMonitor,
)
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2

Backbone HRNet Pretrained weights at: ./checkpoints/hrnetv2_32_model_best_epoch96.pth, only usable for HRNetv2-32


In [2]:
# --- Run configuration -------------------------------------------------------
# NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been
# initialised yet; `torch` and the network modules are imported in the first
# cell, so confirm none of them touches the GPU at import time.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

# Which network / dataset to run.
model_type = "UNet"
dataset_name = "AMD"
dataset_config = "config/datasets.yaml"  # NOTE(review): not referenced elsewhere in this notebook

# Optimisation and data-loading hyper-parameters.
epochs = 60
learning_rate = 0.01
batch_size = 1
num_workers = 32

# Paths.
resume_from_checkpoint = ""  # NOTE(review): not referenced elsewhere in this notebook
log_folder = "./logs"

In [3]:
def create_test_transform(height: int = 512, width: int = 512, random_crop: bool = False):
    """Build the albumentations pipeline used for test / prediction samples.

    Crops every target to ``height`` x ``width`` and converts it to a tensor
    with ``ToTensorV2``.

    Args:
        height: Crop height in pixels.
        width: Crop width in pixels.
        random_crop: If False (default), take a deterministic centre crop so
            repeated test / predict runs see identical inputs. Set True to get
            the previous ``RandomCrop`` behaviour.

    Returns:
        An ``A.Compose`` pipeline that applies the same crop to the ``mask``,
        ``trimap``, ``fg`` and ``bg`` targets as to ``image``.
    """
    # A random crop makes test/predict outputs non-reproducible, so the
    # deterministic centre crop is the default for this "test" transform.
    crop_transform = A.RandomCrop if random_crop else A.CenterCrop
    return A.Compose(
        [
            crop_transform(height=height, width=width),
            ToTensorV2(),
        ],
        # Every extra target is registered as an "image"-type target, so the
        # crop and ToTensorV2 process it through the image code path
        # (including "mask", which would otherwise use the mask code path).
        additional_targets={
            "image": "image",
            "mask": "image",
            "trimap": "image",
            "fg": "image",
            "bg": "image",
        },
    )

In [4]:
# Hyper-parameters handed to the network constructor.
# Use the configured `learning_rate` instead of a hard-coded 1 so the network
# is built with the same value that `version_name` (and the checkpoint
# directory) reports.
settings = {
    "learning_rate": learning_rate,
    # Kept identical to the `monitor` key of the ModelCheckpoint callback.
    "monitor": "validation_loss",
}

# Map each supported `model_type` to its network class.
network_classes = {
    "MODNet": MODNet,
    "UNet": UNet,
    "GFM": GFM,
    "DFM": DFM,
}
if model_type not in network_classes:
    # ValueError subclasses Exception, so existing `except Exception` still works.
    raise ValueError(
        f"Unknown model_type {model_type!r}; expected one of {sorted(network_classes)}"
    )
network = network_classes[model_type](settings)

# Dataset: build the Lightning data module for the selected dataset with the
# pipeline returned by `create_test_transform()`, then prepare and set it up
# eagerly in this cell.
data_module = MattingDataModule(
    dataset_name=dataset_name,
    num_workers=num_workers,
    batch_size=batch_size,
    transform=create_test_transform(),
)
data_module.prepare_data()
data_module.setup()

from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.loggers import WandbLogger

# Experiment layout: <log_folder>/<experiment_name>/<version_name>/...
experiment_name = f"{model_type}_{dataset_name}"
version_name = f"epochs:{epochs}_lr:{learning_rate}"

# Pass `version=version_name` so TensorBoard writes into the same
# <log_folder>/<experiment_name>/<version_name>/ directory that the checkpoint
# callback below uses; without it the logger picks an auto-numbered version.
tensorboard_logger = TensorBoardLogger(log_folder, name=experiment_name, version=version_name)
# NOTE(review): `wandb_logger` is constructed but never passed to the Trainer.
wandb_logger = WandbLogger(project=experiment_name)

checkpoint_path = os.path.join(log_folder, experiment_name, version_name, "checkpoints")

# Keep the best checkpoint by validation loss, plus `last.ckpt`.
callbacks = [
    ModelCheckpoint(
        dirpath=checkpoint_path,
        every_n_epochs=1,
        mode="min",
        monitor="validation_loss",
        save_last=True,
    ),
]

# Checkpoint used for prediction: the last checkpoint of this experiment,
# unless `checkpoint_override` names another run. The override currently
# points at the "UNet_AMD_cropped" experiment.
# NOTE(review): absolute, machine-specific path; consider deriving it from
# `log_folder` instead.
checkpoint_override = "/home/rafatmatting/dfm/logs/UNet_AMD_cropped/epochs:60_lr:0.01/checkpoints/last.ckpt"
checkpoint_file = checkpoint_override or os.path.join(checkpoint_path, "last.ckpt")

from pytorch_lightning.plugins import DDPPlugin

# NOTE(review): no `gpus` / `devices` / `accelerator` argument is given, so
# despite the DDP strategy this runs as a single gloo process on CPU (cell
# output: "GPU available: True, used: False"). To use GPUs, pass e.g.
# accelerator="gpu", devices=torch.cuda.device_count().
trainer = Trainer(
    max_epochs=epochs,
    callbacks=callbacks,
    logger=tensorboard_logger,
    strategy=DDPPlugin(),
)

# Restore the weights stored at `checkpoint_file` and run the predict loop.
# FIXME: this currently fails in the dataset's __getitem__ with IndexError on
# `annotations_df.iloc[index, 3]` (fg path); most likely the annotation table
# used for prediction has no 4th column — confirm in MattingDataset.
predict = trainer.predict(
    network,
    datamodule=data_module,
    ckpt_path=checkpoint_file,
)


2022-08-29 16:39:24,406:[DEBUG]:utils.download:/home/rafatmatting/dfm/data/AMD already exist!
GPU available: True, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  rank_zero_warn(
initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=gloo
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

Restoring states from the checkpoint path at /home/rafatmatting/dfm/logs/UNet_AMD_cropped/epochs:60_lr:0.01/checkpoints/last.ckpt
Missing logger folder: ./logs/UNet_AMD
Loaded model weights from checkpoint at /home/rafatmatting/dfm/logs/UNet_AMD_cropped/epochs:60_lr:0.01/checkpoints/last.ckpt


Predicting: 0it [00:00, ?it/s]

IndexError: Caught IndexError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/rafatmatting/anaconda3/envs/ml/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/rafatmatting/anaconda3/envs/ml/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/rafatmatting/anaconda3/envs/ml/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/rafatmatting/dfm/datasets/MattingDataset.py", line 83, in __getitem__
    fg_path = self.annotations_df.iloc[index, 3]
  File "/home/rafatmatting/anaconda3/envs/ml/lib/python3.9/site-packages/pandas/core/indexing.py", line 925, in __getitem__
    return self._getitem_tuple(key)
  File "/home/rafatmatting/anaconda3/envs/ml/lib/python3.9/site-packages/pandas/core/indexing.py", line 1506, in _getitem_tuple
    self._has_valid_tuple(tup)
  File "/home/rafatmatting/anaconda3/envs/ml/lib/python3.9/site-packages/pandas/core/indexing.py", line 754, in _has_valid_tuple
    self._validate_key(k, i)
  File "/home/rafatmatting/anaconda3/envs/ml/lib/python3.9/site-packages/pandas/core/indexing.py", line 1409, in _validate_key
    self._validate_integer(key, axis)
  File "/home/rafatmatting/anaconda3/envs/ml/lib/python3.9/site-packages/pandas/core/indexing.py", line 1500, in _validate_integer
    raise IndexError("single positional indexer is out-of-bounds")
IndexError: single positional indexer is out-of-bounds


In [None]:
# Inspect the first entry returned by `trainer.predict`. The plotting cell below
# indexes `predict[0][6]`, which suggests an entry may be a tuple/list of
# tensors rather than a single tensor; report every element's shape in that
# case instead of failing on `.shape`.
first_batch_output = predict[0]
(
    [getattr(part, "shape", None) for part in first_batch_output]
    if isinstance(first_batch_output, (list, tuple))
    else first_batch_output.shape
)

In [None]:
import matplotlib.pyplot as plt

# Show element 6 of the first predict entry as a grayscale image.
# NOTE(review): assumes `predict[0][6]` is a (C, H, W) tensor — TODO confirm.
# `.cpu()` makes `.numpy()` work for CUDA tensors too (no-op on CPU), and a
# single-channel image is reduced to (H, W) so `cmap="gray"` is applied.
image = predict[0][6].detach().cpu().permute(1, 2, 0).numpy()
if image.shape[-1] == 1:
    image = image[..., 0]
fig, ax = plt.subplots()
ax.imshow(image, cmap="gray")
ax.set_title("predict[0][6]")
ax.set_axis_off()
plt.show()

In [None]:
import matplotlib.pyplot as plt

# FIXME: `gt` is not defined anywhere in this notebook, so this cell raises
# NameError on a fresh kernel. It presumably held a ground-truth batch from an
# earlier (deleted) cell — define it before running this cell.
# NOTE(review): assumes `gt[0]` is a (C, H, W) tensor — TODO confirm.
# `.cpu()` makes `.numpy()` work for CUDA tensors too (no-op on CPU), and a
# single-channel image is reduced to (H, W) so `cmap="gray"` is applied.
image = gt[0].detach().cpu().permute(1, 2, 0).numpy()
if image.shape[-1] == 1:
    image = image[..., 0]
fig, ax = plt.subplots()
ax.imshow(image, cmap="gray")
ax.set_title("gt[0]")
ax.set_axis_off()
plt.show()

In [None]:
import matplotlib.pyplot as plt

# FIXME: `predict_global` is not defined anywhere in this notebook, so this
# cell raises NameError on a fresh kernel; define it before running this cell.
# NOTE(review): assumes `predict_global[0]` is a (C, H, W) tensor — TODO confirm.
# `.cpu()` makes `.numpy()` work for CUDA tensors too (no-op on CPU), and a
# single-channel image is reduced to (H, W) so `cmap="gray"` is applied.
image = predict_global[0].detach().cpu().permute(1, 2, 0).numpy()
if image.shape[-1] == 1:
    image = image[..., 0]
fig, ax = plt.subplots()
ax.imshow(image, cmap="gray")
ax.set_title("predict_global[0]")
ax.set_axis_off()
plt.show()

In [None]:
import matplotlib.pyplot as plt

# FIXME: `predict_local` is not defined anywhere in this notebook, so this
# cell raises NameError on a fresh kernel; define it before running this cell.
# NOTE(review): assumes `predict_local[0]` is a (C, H, W) tensor — TODO confirm.
# `.cpu()` makes `.numpy()` work for CUDA tensors too (no-op on CPU), and a
# single-channel image is reduced to (H, W) so `cmap="gray"` is applied.
image = predict_local[0].detach().cpu().permute(1, 2, 0).numpy()
if image.shape[-1] == 1:
    image = image[..., 0]
fig, ax = plt.subplots()
ax.imshow(image, cmap="gray")
ax.set_title("predict_local[0]")
ax.set_axis_off()
plt.show()

In [None]:
import matplotlib.pyplot as plt

# FIXME: `predict_fusion` is not defined anywhere in this notebook, so this
# cell raises NameError on a fresh kernel; define it before running this cell.
# NOTE(review): assumes `predict_fusion[0]` is a (C, H, W) tensor — TODO confirm.
# `.cpu()` makes `.numpy()` work for CUDA tensors too (no-op on CPU), and a
# single-channel image is reduced to (H, W) so `cmap="gray"` is applied.
image = predict_fusion[0].detach().cpu().permute(1, 2, 0).numpy()
if image.shape[-1] == 1:
    image = image[..., 0]
fig, ax = plt.subplots()
ax.imshow(image, cmap="gray")
ax.set_title("predict_fusion[0]")
ax.set_axis_off()
plt.show()