In [1]:
#!/home/rafatmatting/anaconda3/envs/ml/bin/python
import torch
from pytorch_lightning import LightningModule, Trainer
from networks.UNet import UNet
from networks.MODNet import MODNet
from networks.GFM import GFM
from networks.DFM import DFM
from datasets.MattingDataModule import MattingDataModule
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    ModelSummary,
    LearningRateMonitor,
)
import torchvision
import matplotlib.pyplot as plt
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [None]:
# --- Run configuration --------------------------------------------------------
# Restrict which GPUs this kernel can see; this only takes effect if it is set
# before CUDA is initialised in the process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

# What to evaluate.
model_type = "GFM"  # one of: "MODNet", "UNet", "GFM", "DFM"
dataset_name = "test"
dataset_config = "config/datasets.yaml"  # NOTE(review): not read by any live code here

# Loader / trainer knobs.
batch_size = 1
num_workers = 32
epochs = 60
learning_rate = 1e-2

# Filesystem locations.
log_folder = "./logs"
resume_from_checkpoint = ""  # optional checkpoint path; empty by default

In [None]:
def create_test_transform(height: int = 224, width: int = 224):
    """Build the deterministic albumentations pipeline used at test time.

    Each sample is resized to ``height`` x ``width`` pixels and converted to a
    tensor with ``ToTensorV2``.

    Args:
        height: Output height in pixels.
        width: Output width in pixels.

    Returns:
        An ``A.Compose`` pipeline. Besides ``image`` it accepts ``mask``,
        ``trimap``, ``fg`` and ``bg`` targets. All of them are registered as
        image-type targets, so they get image processing (e.g. the resize
        interpolation) rather than mask-type processing.
    """
    return A.Compose(
        [
            # Pass each dimension to its own keyword; the previous version had
            # them swapped, which broke non-square sizes.
            A.Resize(height=height, width=width),
            ToTensorV2(),
        ],
        additional_targets={
            "image": "image",
            "mask": "image",
            "trimap": "image",
            "fg": "image",
            "bg": "image",
        },
    )

In [None]:
# Hyper-parameters handed to the network constructor. Uses the configured
# `learning_rate` so the model agrees with `version_name` below.
settings = {
    "learning_rate": learning_rate,
    "monitor": "validation_loss",
}

# Supported architectures, keyed by the `model_type` string.
network_classes = {
    "MODNet": MODNet,
    "UNet": UNet,
    "GFM": GFM,
    "DFM": DFM,
}
if model_type not in network_classes:
    raise ValueError(
        f"Unknown model_type {model_type!r}; expected one of {sorted(network_classes)}"
    )
network = network_classes[model_type](settings)

# Dataset: build the data module with the test-time transform, then prepare it.
data_module = MattingDataModule(
    dataset_name=dataset_name,
    num_workers=num_workers,
    batch_size=batch_size,
    transform=create_test_transform(),
)
data_module.prepare_data()
data_module.setup()

from pytorch_lightning.loggers import TensorBoardLogger

experiment_name = f"{model_type}_{dataset_name}"
version_name = f"epochs:{epochs}_lr:{learning_rate}"
tensorboard_logger = TensorBoardLogger(log_folder, name=experiment_name)

checkpoint_path = os.path.join(log_folder, experiment_name, version_name, "checkpoints")

# NOTE(review): checkpointing only takes effect during fit; predict() below
# does not trigger it.
callbacks = [
    ModelCheckpoint(
        dirpath=checkpoint_path,
        every_n_epochs=1,
        mode="min",
        monitor="validation_loss",
        save_last=True,
    ),
]

# Weights to evaluate: `resume_from_checkpoint` wins when non-empty; otherwise
# fall back to a specific trained checkpoint.
# NOTE(review): absolute, machine-specific path — make it configurable.
checkpoint_file = (
    resume_from_checkpoint
    or "/home/rafatmatting/dfm/logs/GFM_AMD/resize/checkpoints/epoch=100-step=403.ckpt"
)

trainer = Trainer(
    logger=tensorboard_logger,
    gpus=1,
    callbacks=callbacks,
    max_epochs=epochs,
)

predict = trainer.predict(
    network,
    dataloaders=data_module.test_dataloader(),
    ckpt_path=checkpoint_file,
)

In [None]:
# Aggregate the scalar metrics returned per batch by trainer.predict.
# NOTE(review): assumes each entry of `predict` is laid out as
# (mae, mse, inference, image, pred, mask) — confirm against the network's predict_step.
# float() accepts Python numbers and single-element tensors (batch_size=1).
mae = torch.tensor([float(outputs[0]) for outputs in predict]).mean()
mse = torch.tensor([float(outputs[1]) for outputs in predict]).mean()
inference = torch.tensor([float(outputs[2]) for outputs in predict]).mean()


In [None]:
# Show sample 9's input and prediction next to sample 22's prediction.
# Each image gets its own axes; repeated plt.imshow calls on one axes would
# leave only the last image visible.
fig, (ax_image, ax_pred, ax_pred_other) = plt.subplots(1, 3, figsize=(15, 5))

# Assumes the input is CHW in the 0-255 range (hence permute and /255 for RGB display).
ax_image.imshow(predict[9][3].squeeze().permute(1, 2, 0).cpu() / 255)
ax_image.set_title("sample 9: input")

ax_pred.imshow(torch.sigmoid(predict[9][4]).squeeze().cpu() / 255, cmap="gray")
ax_pred.set_title("sample 9: sigmoid(prediction)")

ax_pred_other.imshow(predict[22][4].squeeze().cpu() / 255, cmap="gray")
ax_pred_other.set_title("sample 22: prediction")

for ax in (ax_image, ax_pred, ax_pred_other):
    ax.axis("off")
plt.show()

In [None]:
# Write input image, prediction and ground-truth mask for every sample to disk.
output_dir = f"{model_type}_output"
# save_image does not create parent directories, so make sure the folder exists.
os.makedirs(output_dir, exist_ok=True)

for i, outputs in enumerate(predict):
    image, pred, mask = outputs[3], outputs[4], outputs[5]
    # Image and mask are scaled by 1/255 before saving; the prediction is saved as-is
    # (presumably already in [0, 1] — confirm).
    torchvision.utils.save_image(image.squeeze() / 255, os.path.join(output_dir, f"{i}_image.jpg"))
    torchvision.utils.save_image(pred.squeeze(), os.path.join(output_dir, f"{i}_predict.jpg"))
    torchvision.utils.save_image(mask.squeeze() / 255, os.path.join(output_dir, f"{i}_mask.jpg"))

In [None]:
# Display the aggregated MSE; `mse` must already be computed from `predict` in an earlier cell.
mse

In [None]:
# Display the aggregated MAE; `mae` must already be computed from `predict` in an earlier cell.
mae

In [None]:
# Display the aggregated `inference` value; it must already be computed from `predict` in an earlier cell.
inference

In [None]:
# Tile the first sample's input image into a grid tensor (nothing is plotted here).
first_input = predict[0][3].squeeze()
grid = torchvision.utils.make_grid(first_input)