In [None]:
import os
from utils import util, trainer, dataset, tools
from utils.tools import MyArgumentParser
from omegaconf import OmegaConf
from sklearn.model_selection import train_test_split
import model
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from PIL import Image

- Set the random seed for reproducibility

In [2]:
# Seed the RNGs once, up front, through the project's helper so that the
# splits and training runs below are repeatable.
SEED = 42
util.setup_seed(SEED)

- Create the split files: `train.txt` and `val.txt` (an 80/20 split of the training images) and `test.txt` (all test images).

In [3]:
# Build the id-list split files read by `dataset.ScribbleClassData`
# (one image id per line, file extension stripped). The training images are
# split 80/20 into train/val; every test image is listed in test.txt.
TRAIN_ROOT = '/home/sunag/Documents/ML/PFA/dataset/train'
TEST_ROOT = '/home/sunag/Documents/ML/PFA/dataset/test1'


def list_image_ids(image_dir):
    """Return the sorted file stems of the non-hidden entries in ``image_dir``.

    Sorting matters: ``os.listdir`` returns entries in arbitrary,
    filesystem-dependent order, so without it ``train_test_split`` with a fixed
    ``random_state`` would still yield different splits on different machines.
    ``os.path.splitext`` strips only the final extension, so ids containing
    dots are kept intact.
    """
    return sorted(os.path.splitext(name)[0]
                  for name in os.listdir(image_dir)
                  if not name.startswith('.'))  # skip e.g. .DS_Store


def write_split(split_dir, filename, ids):
    """Write ``ids`` one per line to ``split_dir/filename``, creating ``split_dir`` if needed."""
    os.makedirs(split_dir, exist_ok=True)
    with open(os.path.join(split_dir, filename), 'w') as f:
        f.writelines(f"{item}\n" for item in ids)


images_list = list_image_ids(os.path.join(TRAIN_ROOT, 'images'))
train_list, val_list = train_test_split(images_list, test_size=0.2, random_state=42)
write_split(os.path.join(TRAIN_ROOT, 'splits'), 'train.txt', train_list)
write_split(os.path.join(TRAIN_ROOT, 'splits'), 'val.txt', val_list)

test_list = list_image_ids(os.path.join(TEST_ROOT, 'images'))
write_split(os.path.join(TEST_ROOT, 'splits'), 'test.txt', test_list)

In [None]:
# Disabled sanity check: load one sample each from the train and test splits
# (no transform) and print what the dataset returns. The recorded output below
# shows each sample is a 4-tuple: (image, scribble, class label, image path).
# obj = dataset.ScribbleClassData(data_list='train.txt',
#                                 data_root='/home/sunag/Documents/ML/PFA/dataset/train',
#                                 transform=None,
#                                 path='scribbles')
# a, b, c, d = obj[8]  # Get the sample at index 8 (the ninth item) of the train split
# print(f"Image shape: {a.shape}, Scribble shape: {b.shape}, Class label shape: {c.shape}, Path: {d}")

# obj2 = dataset.ScribbleClassData(data_list='test.txt',
#                                  data_root='/home/sunag/Documents/ML/PFA/dataset/test1',
#                                  transform=None,
#                                  path='scribbles')
# a2, b2, c2, d2 = obj2[8]  # Get the sample at index 8 (the ninth item) of the test split
# print(f"Image shape: {a2.shape}, Scribble shape: {b2.shape}, Class label shape: {c2.shape}, Path: {d2}")

Image shape: (375, 500, 3), Scribble shape: (375, 500, 3), Class label shape: torch.Size([1]), Path: /home/sunag/Documents/ML/PFA/dataset/train/images/2010_000622.jpg
Image shape: (375, 500, 3), Scribble shape: (375, 500, 3), Class label shape: torch.Size([1]), Path: /home/sunag/Documents/ML/PFA/dataset/test1/images/2010_005888.jpg


In [None]:
# Disabled: construct the train / non-train loaders via util.get_loader.
# NOTE: `cfg` is only defined in the training cell further down, so these lines
# raise NameError on a fresh kernel if uncommented here as-is.
# obj3 = util.get_loader(is_train=True, args=cfg)
# obj4 = util.get_loader(is_train=False, args=cfg)

- Training and validation (settings are loaded from `config/train.yaml`)

In [None]:
# Load the training config and create a per-project checkpoint directory.
# The config is reloaded from disk on every run, so the path join below does
# not accumulate if the cell is re-executed.
cfg = OmegaConf.load("/home/sunag/Documents/ML/PFA/config/train.yaml")
#cfg2 = OmegaConf.load("/home/sunag/Documents/ML/PFA/config/test.yaml")
cfg.work_dir.ckpt_dir = os.path.join(cfg.work_dir.ckpt_dir, cfg.project_name)
os.makedirs(cfg.work_dir.ckpt_dir, exist_ok=True)

# Comma-separated GPU id string, e.g. [0, 1] -> "0,1" (empty list -> "").
# NOTE(review): gpu_ids is not used below; if it is meant to restrict the
# visible devices it must be exported (e.g. CUDA_VISIBLE_DEVICES) before CUDA
# is initialised — confirm intent.
gpu_ids = ','.join(str(gpu_id) for gpu_id in cfg.train.gpu_ids)

# Use a new name for the instance: `trainer = trainer.MixTrTrainer(...)` would
# rebind the imported `trainer` module, so re-running this cell would fail.
mix_trainer = trainer.MixTrTrainer(args=cfg)
mix_trainer.train_model()

- Inference: predict a label map for every test image and save it as a PNG in `PFA_preds/`

In [None]:
# Run the network over every test image and save one predicted label map per
# image as an 8-bit PNG in PFA_preds/.
TEST_ROOT = '/home/sunag/Documents/ML/PFA/dataset/test1'
PRED_DIR = os.path.join(TEST_ROOT, 'PFA_preds')

# ImageNet mean/std rescaled to the 0-255 pixel range. Presumably these must
# match the normalisation used at training time — TODO confirm against config.
mean = [item * 255 for item in [0.485, 0.456, 0.406]]
std = [item * 255 for item in [0.229, 0.224, 0.225]]
# NOTE(review): the center crop means predictions are saved at 320x320, not at
# the original image resolution — confirm this is what evaluation expects.
test_transform = tools.Compose([
    tools.Crop([320, 320], crop_type='center', padding=mean, ignore_label=255),
    tools.ToTensor(),
    tools.Normalize(mean=mean, std=std)])

os.makedirs(PRED_DIR, exist_ok=True)

test_data = dataset.ScribbleClassData(data_list='test.txt',
                                      data_root=TEST_ROOT,
                                      transform=test_transform,
                                      path='scribbles')

# Image ids in test.txt order, read from disk so this cell does not depend on
# a `test_list` variable left in the kernel. Used only when a sample carries
# no path; assumes ScribbleClassData indexes test.txt in file order.
with open(os.path.join(TEST_ROOT, 'splits', 'test.txt')) as f:
    test_ids = [line.strip() for line in f if line.strip()]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): `model.model` must be the *trained* network with weights
# loaded; nothing in this notebook loads a checkpoint into it — confirm.
net = model.model
if isinstance(net, torch.nn.Module):
    # eval() switches dropout/batch-norm to inference behaviour; to(device)
    # puts the weights on the same device as the inputs.
    net = net.to(device).eval()

with torch.no_grad():
    for i in range(len(test_data)):
        sample = test_data[i]
        # With transform=None the dataset returned a 4-tuple
        # (image, scribble, class label, path), so index instead of unpacking
        # exactly two values.
        image = sample[0].unsqueeze(0).to(device)  # add batch dim -> (1, C, H, W)
        input_size = image.shape[-2:]              # (H, W), taken after unsqueeze
        _, output = net(image)
        logits = F.interpolate(output, size=input_size, mode='bilinear', align_corners=True)
        # Assumes output is (1, num_classes, h, w) logits with < 256 classes:
        # per-pixel argmax gives a 2-D uint8 label map that PIL can save.
        pred = logits.argmax(dim=1)[0].to(torch.uint8).cpu().numpy()
        if len(sample) > 3 and isinstance(sample[3], str):
            image_id = os.path.splitext(os.path.basename(sample[3]))[0]
        else:
            image_id = test_ids[i]
        Image.fromarray(pred).save(os.path.join(PRED_DIR, f"output_{image_id}.png"))