In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import os
import random
from tqdm import tqdm
import torch
import numpy as np
import cv2
from PIL import Image
from torch import nn
from torchvision.models import vgg19
from torch.optim import Adam
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import transforms
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torchvision.models.feature_extraction import create_feature_extractor
from torch.utils.data.dataset import random_split
import pickle
from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau
from torchvision.transforms import ToPILImage
from sklearn.model_selection import train_test_split
from torchvision.transforms.functional import to_pil_image
from scipy.spatial.distance import cosine
from torchvision.models.resnet import ResNet18_Weights, ResNet50_Weights
from torch import nn

### MODELS

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator with a fully-connected head.

    A stack of 4x4 convolutions (``n_layers`` stride-2 layers followed by one
    stride-1 layer, LeakyReLU(0.2) throughout) is flattened and mapped by a
    single ``nn.Linear`` + sigmoid to one probability per input image.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d, input_size=256):
        """
        input_nc (int): Number of channels in the input images.
        ndf (int): Number of filters in the first convolutional layer.
        n_layers (int): Number of stride-2 convolutional layers.
        norm_layer: Normalization layer class (applied after every conv except the first).
        input_size (int or (int, int)): Spatial size (H, W) of the images this
            discriminator will receive. The linear head is sized for exactly this
            resolution; the default keeps the original 256x256 behaviour.
        """
        super(Discriminator, self).__init__()
        kw = 4
        padw = 1
        stride = 2

        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=stride, padding=padw), nn.LeakyReLU(0.2, True)]

        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)  # channel multiplier doubles per layer, capped at 8x ndf
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=stride, padding=padw, bias=False),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=False),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        self.model = nn.Sequential(*sequence)

        self.flatten_size = self._get_flatten_size(input_nc, input_size)

        self.fc = nn.Linear(self.flatten_size, 1)
        self.sigmoid = nn.Sigmoid()

    def _get_flatten_size(self, input_nc, input_size=256):
        """Number of features the conv trunk produces for one ``input_size`` image.

        The dummy forward pass runs in eval mode (and restores the previous mode)
        so normalization layers with running statistics, e.g. BatchNorm2d, are not
        updated by the probe.
        """
        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        height, width = input_size
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                output = self.model(torch.zeros(1, input_nc, height, width))
        finally:
            self.model.train(was_training)
        return int(output[0].numel())

    def forward(self, input):
        output = self.model(input)
        output = output.view(output.size(0), -1)  # (B, C, H', W') -> (B, C*H'*W')
        output = self.fc(output)
        output = self.sigmoid(output)
        return output
    
    
class ResidualBlock(nn.Module):
    """Residual unit computing ``x + F(x)``.

    ``F`` is conv -> BN -> LeakyReLU(0.2) -> Dropout(0.25) -> conv -> BN, with
    3x3 kernels at dilation 2 and padding 2, which keeps H and W unchanged so the
    skip connection is a plain addition.
    """

    def __init__(self, in_features):
        super().__init__()
        channels = in_features
        dilated = dict(kernel_size=3, stride=1, padding=2, dilation=2)
        layers = [
            nn.Conv2d(channels, channels, **dilated),
            nn.BatchNorm2d(channels),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.25),
            nn.Conv2d(channels, channels, **dilated),
            nn.BatchNorm2d(channels),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual
    
    
class Generator(nn.Module):
    """Encoder / residual bottleneck / decoder producing a 1-channel mask in [0, 1].

    Layout: 3x3 stem (64 ch) -> two stride-2 conv stages (128, 256 ch) ->
    ``num_residual_blocks`` ``ResidualBlock``s at 1/4 resolution -> two
    transposed-conv stages that double H and W, each added to the matching
    encoder feature map -> 3x3 conv + sigmoid.
    The additive skips need the decoder sizes to match the encoder's, which with
    these kernels holds when H and W are divisible by 4.
    """

    def __init__(self, input_channels, num_residual_blocks=8):
        super().__init__()

        def conv_stage(c_in, c_out, stride):
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            )

        def up_stage(c_in, c_out):
            return nn.Sequential(
                nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            )

        # Construction order matches parameter registration / RNG consumption order.
        self.init_conv = conv_stage(input_channels, 64, stride=1)
        self.down1 = conv_stage(64, 128, stride=2)
        self.down2 = conv_stage(128, 256, stride=2)
        self.res_blocks = nn.Sequential(*(ResidualBlock(256) for _ in range(num_residual_blocks)))
        self.up1 = up_stage(256, 128)
        self.up2 = up_stage(128, 64)
        self.out = nn.Sequential(
            nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        stem = self.init_conv(x)
        half = self.down1(stem)
        bottleneck = self.res_blocks(self.down2(half))
        decoded = self.up1(bottleneck) + half   # skip from 1/2 resolution
        decoded = self.up2(decoded) + stem      # skip from full resolution
        return self.out(decoded)
    
class EdgeRefineModule(nn.Module):
    """Stack of 3x3 conv/BN/ReLU layers (128 channels) ending in a 1-channel sigmoid map.

    input_channels: channels of the refinement input (default 5).
    num_layers: number of conv/BN/ReLU triples before the final conv; values
        below 1 still build a single triple.
    """

    def __init__(self, input_channels=5, num_layers=3):
        super().__init__()
        width = 128
        layers = []
        c_in = input_channels
        for _ in range(max(num_layers, 1)):
            layers += [
                nn.Conv2d(c_in, width, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(width),
                nn.ReLU(inplace=True),
            ]
            c_in = width
        self.model = nn.Sequential(*layers)
        self.final_conv = nn.Conv2d(width, 1, kernel_size=3, stride=1, padding=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        logits = self.final_conv(self.model(x))
        return self.sigmoid(logits)

### KINS DATASET

In [None]:
class KINSDataset(Dataset):
    """Dataset over ``<root>/images``, ``<root>/gt_masks`` and ``<root>/occ_masks`` PNGs.

    For an image file named ``<prefix>_<suffix>`` the masks are
    ``gt_masks/gt_mask_<suffix>`` and ``occ_masks/occ_mask_<suffix>``.
    Each item is a dict with keys 'image' (RGB), 'processed_img_gt' and
    'processed_img_occ' (grayscale), and 'path' (the gt-mask path); images are
    passed through ``transform`` when one is given.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images_dir = os.path.join(root_dir, 'images')
        self.gt_masks_dir = os.path.join(root_dir, 'gt_masks')
        self.occ_masks_dir = os.path.join(root_dir, 'occ_masks')
        # Sorted: os.listdir order is arbitrary, so without it the index -> file
        # mapping (and any seeded split over indices) differs between machines.
        self.image_files = sorted(f for f in os.listdir(self.images_dir) if f.endswith('.png'))

    def __len__(self):
        return len(self.image_files)

    def _mask_paths(self, image_file):
        """Return (gt_mask_path, occ_mask_path) for an image file name."""
        suffix = image_file.split("_")[-1]
        return (os.path.join(self.gt_masks_dir, f'gt_mask_{suffix}'),
                os.path.join(self.occ_masks_dir, f'occ_mask_{suffix}'))

    def __getitem__(self, idx):
        image_file = self.image_files[idx]
        img_name = os.path.join(self.images_dir, image_file)
        gt_mask_name, occ_mask_name = self._mask_paths(image_file)

        # Context managers close the underlying file handles once converted.
        with Image.open(img_name) as raw:
            image = raw.convert('RGB')
        with Image.open(gt_mask_name) as raw:
            gt_mask = raw.convert('L')
        with Image.open(occ_mask_name) as raw:
            occ_mask = raw.convert('L')

        if self.transform:
            image = self.transform(image)
            gt_mask = self.transform(gt_mask)
            occ_mask = self.transform(occ_mask)

        return {'image': image, 'processed_img_gt': gt_mask, 'processed_img_occ': occ_mask, "path": gt_mask_name}
# ToTensor only: PIL images become float CHW tensors (torchvision scales 8-bit data to [0, 1]).
transform = transforms.Compose([transforms.ToTensor()])

### OCCLUDED VEHICLES (PASCAL) DATASET

In [None]:
def process_occluder(ann, sz):
    """Rasterize every occluder polygon in ``ann["occluder_mask"]`` into one mask.

    Returns a float64 array of shape ``sz`` holding the pixel-wise maximum of the
    filled polygons (``points2mask`` fills with 255), or ``None`` when the entry is
    None or a 0-d array (i.e. no occluder polygons).
    """
    occluder = ann["occluder_mask"]
    if occluder is None or occluder.ndim == 0:
        return None
    occ_mask = np.zeros(sz)
    for polygon in occluder:
        points = np.array(polygon).reshape((-1, 1, 2))
        occ_mask = np.maximum(occ_mask, points2mask(points, sz))
    return occ_mask

def get_bbox(ann, img_size, enlarge_factor = 1.2):
    """Enlarge the annotation box about its centre and clip it to the image.

    ``ann["box"][:4]`` is read as (y0, x0, y1, x1) and cast to int; ``img_size``
    is (h, w). Each side moves outward by half of ``size * enlarge_factor - size``,
    truncated toward zero. Returns [y0, x0, y1, x1] clipped to [0, h] / [0, w].
    """
    y0, x0, y1, x1 = np.array(ann["box"][0:4]).astype(int)
    h, w = img_size

    def half_growth(extent):
        return int((extent * enlarge_factor - extent) / 2)

    dx = half_growth(x1 - x0)
    dy = half_growth(y1 - y0)
    return [max(0, y0 - dy), max(0, x0 - dx), min(h, y1 + dy), min(w, x1 + dx)]

def points2mask(points, img_size):
    """Fill the polygon ``points`` into a fresh uint8 mask of shape ``img_size``.

    Pixels inside the polygon become 255, all others stay 0. ``points`` must be a
    numpy array; it is cast to int32 before filling.
    """
    canvas = np.zeros(img_size, dtype=np.uint8)
    polygon = points.astype(np.int32)
    cv2.fillPoly(canvas, [polygon], 255)
    return canvas

def draw_points(img, points):
    """Draw a filled radius-3 circle, colour (255, 0, 0), at every point of ``points``.

    Coordinates are cast to int. ``img`` is drawn on by ``cv2.circle`` and the same
    object is returned.
    """
    for point in points:
        center = tuple(map(int, point))
        cv2.circle(img, center, 3, (255, 0, 0), -1)  # thickness -1 -> filled
    return img

def process_annotation(ann, img_size):
    """Rasterize one annotation into a visible-object mask and an enlarged bbox.

    ann: mapping with "mask" (sequence of coordinate sequences, flattened and
        reshaped into 2-D integer points), "box" and "occluder_mask" entries.
    img_size: (h, w) of the image.

    Returns (final_mask, amodal_bbox):
        final_mask: the uint8 0/255 polygon mask when there is no occluder; when an
            occluder mask exists, ``obj_mask * (255 - occluder_mask)`` as float64,
            i.e. 0 under the occluder and 255 * 255 elsewhere inside the object.
        amodal_bbox: ``get_bbox(ann, img_size)``.

    Errors from ``points2mask`` are re-raised after printing the offending points.
    """
    flat_list = [item for sublist in ann["mask"] for item in sublist]
    obj_points = np.array(flat_list).reshape((-1, 2)).astype(int)
    try:
        obj_mask = points2mask(obj_points, img_size)
    except Exception as e:
        # Print diagnostics, then propagate: swallowing the error here only led to
        # an UnboundLocalError on ``obj_mask`` a few lines later.
        print(e)
        print(obj_points)
        raise

    amodal_bbox = get_bbox(ann, img_size)
    occluder_mask = process_occluder(ann, img_size)

    if occluder_mask is not None:
        final_mask = obj_mask * (255 - occluder_mask)
    else:
        final_mask = obj_mask

    return final_mask, amodal_bbox

def divide_gt(gt_images_path, test_size=0.2, random_state=42):
    """Split the entry names of ``gt_images_path`` into (train, val) name lists.

    Names are sorted before splitting. ``os.listdir`` order is arbitrary, so
    without sorting the same ``random_state`` could give different splits on
    different machines.
    """
    names = sorted(os.listdir(gt_images_path))
    train, val = train_test_split(names, test_size=test_size, random_state=random_state)
    return train, val

def process_directories(path):
    """Collect image paths under ``<path>/images/<subdir>/`` and split them train/val.

    The split is keyed on file name. ``divide_gt`` splits the names found in
    ``images/carFGL0_BGL0``. Every file in the *other* subdirectories whose name
    falls in that validation part goes to validation, and everything else goes to
    training. ``carFGL0_BGL0`` itself is not included in either list.

    Returns (all_train_images, all_val_images) as lists of full paths, in sorted
    directory / file order.
    """
    images = os.path.join(path, "images")
    gt_images = os.path.join(images, "carFGL0_BGL0")
    train_img_dirs = [os.path.join(images, d) for d in sorted(os.listdir(images))
                      if os.path.join(images, d) != gt_images]

    _, gt_names_val = divide_gt(gt_images)
    val_names = set(gt_names_val)  # O(1) membership instead of scanning a list per file

    all_train_images = []
    all_val_images = []
    for directory in train_img_dirs:
        for file in sorted(os.listdir(directory)):
            target = all_val_images if file in val_names else all_train_images
            target.append(os.path.join(directory, file))

    return all_train_images, all_val_images


# Dataset root: only the "testing" folder of occluded_vehicles is used here, and
# process_directories splits its images into train / validation path lists.
path = "/kaggle/input/occ-vehicles-final/occluded_vehicles/testing"
all_train_images, all_val_images = process_directories(path)

def get_ann_path(image_path, gt_dir_name="carFGL0_BGL0"):
    """Map an image path to its own annotation path and the reference annotation path.

    Expects ``<root>/<images_dir>/<subset_dir>/<name>.JPEG`` (``images_dir``
    normally "images"). Returns ``(ann_path, gt_ann_path)``:
        <root>/<anns_dir>/<subset_dir>/<name>.npz
        <root>/<anns_dir>/<gt_dir_name>/<name>.npz
    where ``anns_dir`` is ``images_dir`` with "images" replaced by "anns".
    Only the last three path components are rewritten, so "images" or the subset
    name appearing elsewhere in the path (root, file name) is left untouched.

    Raises ValueError if the path has fewer than three "/"-separated components.
    """
    *root, images_dir, subset_dir, file_name = image_path.split("/")
    ann_dir = images_dir.replace("images", "anns")
    if file_name.endswith(".JPEG"):
        file_name = file_name[: -len(".JPEG")] + ".npz"
    ann_path = "/".join([*root, ann_dir, subset_dir, file_name])
    gt_ann_path = "/".join([*root, ann_dir, gt_dir_name, file_name])
    return ann_path, gt_ann_path
# Sanity check of get_ann_path on a known sample (fails fast if the path mapping changes).
ann_path, gt_ann_path = get_ann_path("/kaggle/input/occ-vehicles-final/occluded_vehicles/testing/images/carFGL3_BGL1/n04166281_18171.JPEG")
assert ann_path == "/kaggle/input/occ-vehicles-final/occluded_vehicles/testing/anns/carFGL3_BGL1/n04166281_18171.npz", ann_path
assert gt_ann_path == "/kaggle/input/occ-vehicles-final/occluded_vehicles/testing/anns/carFGL0_BGL0/n04166281_18171.npz", gt_ann_path

In [None]:
class OccludedVehiclesDataset(Dataset):
    """Samples of (cropped RGB image, visible mask, amodal mask) for occluded vehicles.

    For each image path, ``get_ann_path`` gives the image's own annotation
    (rasterized into 'processed_img_occ') and the ``carFGL0_BGL0`` annotation
    (rasterized into 'processed_img_gt'). Image and both masks are cropped to the
    latter annotation's enlarged bbox and letterboxed with ``resize_and_pad``.
    With a ``transform``, masks are additionally normalized by ``_binarize``.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _binarize(mask):
        """Floor-divide ``mask`` by its maximum (0/1 masks come back unchanged).

        Returns all zeros for an empty mask instead of the NaNs that ``0 // 0``
        produces for float tensors.
        """
        peak = torch.max(mask)
        if peak <= 0:
            return torch.zeros_like(mask)
        return mask // peak

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        ann_path_occ, ann_path_gt = get_ann_path(img_path)

        # Context managers close the file handles PIL and np.load (.npz) keep open;
        # leaking them adds up over many epochs / DataLoader workers.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")
        img_size = image.size[::-1]  # PIL gives (w, h); the mask helpers take (h, w)

        # NOTE: allow_pickle=True can execute code stored in the file — annotations must be trusted.
        with np.load(ann_path_occ, allow_pickle=True) as ann_occ:
            processed_img_occ, _ = process_annotation(ann_occ, img_size)
        # The bbox of the carFGL0_BGL0 annotation is the one used for all crops.
        with np.load(ann_path_gt, allow_pickle=True) as ann_gt:
            processed_img_gt, bbox = process_annotation(ann_gt, img_size)

        processed_img_occ = self.crop_using_bbox(processed_img_occ, bbox, img_path, img_size)
        processed_img_gt = self.crop_using_bbox(processed_img_gt, bbox, img_path, img_size)
        image = self.crop_using_bbox(np.array(image), bbox, img_path, img_size)
        image = resize_and_pad(image)
        processed_img_gt = resize_and_pad(processed_img_gt.convert("L"))
        processed_img_occ = resize_and_pad(processed_img_occ.convert("L"))

        if self.transform:
            image = self.transform(image)
            processed_img_occ = self._binarize(self.transform(processed_img_occ))
            processed_img_gt = self._binarize(self.transform(processed_img_gt))
        return {
            'image': image,
            'processed_img_occ': processed_img_occ,
            'processed_img_gt': processed_img_gt,
            'path' : ann_path_gt,
        }

    def crop_using_bbox(self, image_array, bbox, img_path, img_size):
        """Crop ``image_array`` to bbox (y_min, x_min, y_max, x_max) and return a PIL image.

        ``img_path`` and ``img_size`` are unused; they are kept for call compatibility.
        """
        y_min, x_min, y_max, x_max = bbox

        cropped_img = Image.fromarray(image_array[y_min:y_max, x_min:x_max])
        return cropped_img

def resize_and_pad(item, target_size = (256,256), padding_value=0):
    """Letterbox a PIL image into ``target_size`` (width, height), keeping aspect ratio.

    The image is scaled by the largest factor that fits, then pasted centred on a
    new canvas filled with ``padding_value``. Mode "L" inputs (masks) use
    nearest-neighbour resampling so no new grey levels appear and the canvas is
    "L"; any other mode uses Lanczos and the canvas is "RGB".
    """
    original_width, original_height = item.size
    target_width, target_height = target_size[0], target_size[1]
    ratio = min(target_width / original_width, target_height / original_height)
    # round() rather than int(): floating-point error can turn an exact fit such as
    # 256 into 255.999..., which int() truncates into a 1-px padding strip.
    # max(1, ...) avoids a zero-sized resize for extreme aspect ratios.
    new_size = (max(1, round(original_width * ratio)), max(1, round(original_height * ratio)))

    is_mask = item.mode == 'L'
    resample = Image.Resampling.NEAREST if is_mask else Image.Resampling.LANCZOS
    resized = item.resize(new_size, resample)

    canvas = Image.new('L' if is_mask else 'RGB', target_size, padding_value)
    paste_x = (target_width - new_size[0]) // 2
    paste_y = (target_height - new_size[1]) // 2
    canvas.paste(resized, (paste_x, paste_y))

    return canvas

# ToTensor-only transform, identical to the one defined in the KINS dataset section.
transform = transforms.Compose([transforms.ToTensor()])


### GET RANDOM MASKS

In [None]:
def get_smasks_batch(batch = 16, mask_dir="/kaggle/input/good3dmasks200/kaggle/working/3dmasks"):
    """Sample ``batch`` mask images at random (with replacement) from ``mask_dir``.

    Each file is read as grayscale with OpenCV. The result is a float tensor of
    shape (batch, H, W) holding pixel values divided by 255. All mask files must
    share the same size.

    Raises:
        FileNotFoundError: if ``mask_dir`` contains no files.
        ValueError: if a sampled file cannot be decoded as an image.
    """
    # List once (not once per sample) and sort so a seeded `random` picks the
    # same files on every machine.
    candidates = sorted(os.listdir(mask_dir))
    if not candidates:
        raise FileNotFoundError(f"no mask files found in {mask_dir}")

    masks = []
    for _ in range(batch):
        mask_path = os.path.join(mask_dir, random.choice(candidates))
        sampled_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if sampled_mask is None:
            raise ValueError(f"could not read mask image: {mask_path}")
        masks.append(sampled_mask)

    return torch.from_numpy(np.array(masks)) / 255.0

### CLOSEST MASKS

In [None]:
def load_dict(path_to_dict):
    """Unpickle and return the object stored at ``path_to_dict``.

    Security: ``pickle.load`` can execute arbitrary code; only use on trusted files.
    """
    with open(path_to_dict, 'rb') as handle:
        return pickle.load(handle)

# class MaskFeatureExtractor(torch.nn.Module):
#     def __init__(self, base_model):
#         super().__init__()
#         self.base_model = base_model
#         self.squeeze_layer = torch.nn.Linear(1000, 400)

#     def forward(self, x):
#         x = self.base_model(x)
#         x = torch.flatten(x, 1)
#         x = self.squeeze_layer(x)
#         return x

# def init_feature_extractor():
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     resnet = models.resnet50(weights=ResNet50_Weights.DEFAULT)
#     resnet.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
#     model = MaskFeatureExtractor(resnet).to(device)
#     model.eval()
#     return model

# def get_features_batch(masks, resnet_model):
#     with torch.no_grad():
#         mask_features = resnet_model(masks)  
#     return mask_features


# def find_closest_masks(batch_target_features, mask_features_dict):
#     mask_names = list(mask_features_dict.keys())
#     mask_features_list = [mask_features_dict[name] for name in mask_names]
#     mask_features = torch.stack([torch.tensor(features, dtype=batch_target_features.dtype, device=batch_target_features.device)
#                                  if not isinstance(features, torch.Tensor) else features
#                                  for features in mask_features_list])
#     cosine_sim = F.cosine_similarity(batch_target_features.unsqueeze(1), mask_features.unsqueeze(0), dim=2)
#     cosine_distance = 1 - cosine_sim
#     min_distance_indices = torch.argmin(cosine_distance, dim=1)
#     closest_masks = [mask_names[idx] for idx in min_distance_indices]
#     return closest_masks

# def find_closest_mask(target_features, mask_features_dict):
#     closest_mask = None
#     min_distance = float('inf')  # Start with the highest possible distance

#     for mask_name, features in mask_features_dict.items():

#         if not isinstance(features, torch.Tensor):
#             features = torch.tensor(features, dtype=target_features.dtype, device=target_features.device)
#         features = features.unsqueeze(0)  # Add batch dimension
#         cosine_sim = F.cosine_similarity(target_features, features, dim=1).item()
#         cosine_distance = 1 - cosine_sim  # Calculate cosine distance
#         if cosine_distance < min_distance:
#             min_distance = cosine_distance
#             closest_mask = mask_name

#     return closest_mask


# def find_closest_masks_batch(current_features, vocab):
#     res = []
#     for feature in current_features:
#         closest_mask = find_closest_mask(feature, vocab)
#         res.append(closest_mask)
#     return r

# def find_closest_masks_batch(target_features, mask_features_dict):
#     for mask_name, features in mask_features_dict.items():
#         if not isinstance(features, torch.Tensor):
#             mask_features_dict[mask_name] = torch.tensor(features, dtype=target_features.dtype, device=target_features.device)

#     all_mask_features = torch.stack(list(mask_features_dict.values()))
#     all_mask_names = list(mask_features_dict.keys())

#     cosine_sim = F.cosine_similarity(target_features.unsqueeze(1), all_mask_features.unsqueeze(0), dim=2)
#     cosine_distances = 1 - cosine_sim  # Convert similarity to distance

#     min_indices = torch.argmin(cosine_distances, dim=1)

#     closest_masks = [all_mask_names[idx] for idx in min_indices]
#     return closest_masks


# def get_closest_masks(loader, resnet_model, vocab):
#     mapping = {}
#     for data in tqdm(loader):
#         _, _, amodal_mask, image_name = data["image"], data["processed_img_occ"],\
#                                             data["processed_img_gt"], data["path"]
#         current_features = get_features_batch(amodal_mask.to("cuda"), resnet_model)
#         closest_masks = find_closest_masks_batch(current_features, vocab)
#         for n, m in zip(image_name, closest_masks):
#             mapping[n] = m
#     return mapping
        

### Contours

In [None]:
def smooth_contour(contour, epsilon_factor=0.0001):
    """Simplify a closed contour with ``cv2.approxPolyDP``.

    The approximation tolerance is ``epsilon_factor`` times the contour's closed
    perimeter (``cv2.arcLength``).
    """
    tolerance = cv2.arcLength(contour, True) * epsilon_factor
    return cv2.approxPolyDP(contour, tolerance, True)

def get_contour(mask):
    """Return the smoothed largest (by area) external contour of ``mask``.

    mask: single-channel array accepted by ``cv2.findContours`` (callers pass uint8).

    Raises:
        ValueError: if the mask has no foreground, so no contour exists.
    """
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if len(contours) == 0:
        raise ValueError("mask has no foreground pixels; cannot extract a contour")
    largest = max(contours, key=cv2.contourArea)
    return smooth_contour(largest)

def find_bounding_box(mask, increase_by=0.5):
    """Bounding box of the non-zero pixels of ``mask``, grown by ``increase_by``.

    The tight box is scaled about its centre by ``1 + increase_by`` and returned as
    ``(xmin, ymin, xmax, ymax)`` ints (truncated toward zero). The minimums are
    clipped at 0; the maximums are NOT clipped to the mask shape, so callers must
    do that. An all-zero mask raises IndexError.
    """
    row_hits = np.flatnonzero(np.any(mask, axis=1))
    col_hits = np.flatnonzero(np.any(mask, axis=0))
    top, bottom = row_hits[0], row_hits[-1]
    left, right = col_hits[0], col_hits[-1]

    scale = 1 + increase_by
    half_w = (right - left) * scale / 2
    half_h = (bottom - top) * scale / 2
    cx = (left + right) / 2
    cy = (top + bottom) / 2

    return (
        max(0, int(cx - half_w)),
        max(0, int(cy - half_h)),
        int(cx + half_w),
        int(cy + half_h),
    )


def crop_and_center_mask(mask, output_size=(256,256)):
    """Crop ``mask`` to its bounding box (grown by 50%) and centre it on a blank canvas.

    output_size: (width, height) of the returned uint8 canvas, or None to return
        the crop itself. Crops larger than ``output_size`` are not resized, so the
        final assignment will typically fail for them.
    """
    left, top, right, bottom = find_bounding_box(mask, increase_by=0.5)
    bottom = min(bottom, mask.shape[0])
    right = min(right, mask.shape[1])
    crop = mask[top:bottom, left:right]

    if output_size is None:
        return crop

    out_w, out_h = output_size[0], output_size[1]
    crop_h, crop_w = crop.shape[0], crop.shape[1]
    top_pad = (out_h - crop_h) // 2
    left_pad = (out_w - crop_w) // 2
    canvas = np.zeros((out_h, out_w), dtype=np.uint8)
    canvas[top_pad:top_pad + crop_h, left_pad:left_pad + crop_w] = crop
    return canvas
    
def compare_masks(mask1, mask2):
    """Shape distance between the largest external contours of two masks.

    Each mask goes through ``preprocess_mask``, then ``cv2.matchShapes`` (method 1)
    compares the largest contours. Lower scores mean more similar shapes.
    NOTE(review): ``preprocess_mask`` is not defined in this part of the notebook —
    confirm it exists before calling this.
    """
    def largest_contour(mask):
        found, _ = cv2.findContours(preprocess_mask(mask), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        return max(found, key=cv2.contourArea)

    return cv2.matchShapes(largest_contour(mask1), largest_contour(mask2), 1, 0.0)

def get_best(current_c, vocab):
    """Key of the ``vocab`` contour closest to ``current_c`` under ``cv2.matchShapes``.

    Ties keep the first key in iteration order. Returns "" when ``vocab`` is empty
    or when no score is strictly below infinity.
    """
    best_key, best_score = "", float("inf")
    for key, contour in vocab.items():
        score = cv2.matchShapes(current_c, contour, 1, 0.0)
        if score < best_score:
            best_key, best_score = key, score
    return best_key

def get_best_batch(amodal_masks, vocab):
    """For each mask in a batch, return the closest ``vocab`` key by contour shape.

    amodal_masks: iterable of tensors whose first channel (``mask[0]``) is the
        mask; presumably (B, 1, H, W) with 0/1 values — TODO confirm.
    Each mask is cast to uint8, cropped/centred, reduced to its largest contour and
    matched with ``get_best``.
    """
    def closest_key(mask_tensor):
        as_uint8 = mask_tensor.cpu().numpy()[0].astype(np.uint8)
        centred = crop_and_center_mask(as_uint8)
        return get_best(get_contour(centred), vocab)

    return [closest_key(m) for m in amodal_masks]


def draw_contours(contours, canvas_size, contour_color=(0, 255, 0), thickness=2):
    """Draw ``contours`` on a black canvas, show it with matplotlib, and return it.

    canvas_size: (width, height). The returned array is uint8 with shape
    (height, width, 3).
    """
    width, height = canvas_size[0], canvas_size[1]
    canvas = np.zeros((height, width, 3), dtype=np.uint8)

    cv2.drawContours(canvas, contours, -1, contour_color, thickness)  # -1: draw every contour
    plt.imshow(canvas)
    plt.show()

    return canvas


def get_closest_by_c(names):
    """Read every path in ``names`` as a grayscale image; returns them in a list, in order."""
    return [cv2.imread(name, cv2.IMREAD_GRAYSCALE) for name in names]

def create_mapping(loader, vocab):
    """Map every sample path in ``loader`` to its closest vocabulary mask image.

    For each batch, ``get_best_batch`` picks the best-matching ``vocab`` key per
    amodal mask (``data["processed_img_gt"]``). That key is read as a grayscale
    image with OpenCV and stored under the sample's ``data["path"]``.

    Raises:
        FileNotFoundError: if a selected vocabulary image cannot be read
            (``cv2.imread`` returned None), instead of silently storing None.
    """
    mapping = {}
    for data in tqdm(loader, desc = "Create mapping"):
        best_names = get_best_batch(data["processed_img_gt"], vocab)
        for sample_path, best_name in zip(data["path"], best_names):
            best_mask = cv2.imread(best_name, cv2.IMREAD_GRAYSCALE)
            if best_mask is None:
                raise FileNotFoundError(f"could not read vocabulary mask: {best_name!r}")
            mapping[sample_path] = best_mask
    return mapping

def get_sampled_batch(names, mapping, pascal=False):
    """Stack the mapped masks for ``names`` into one (N, H, W) tensor.

    names: sample keys. With ``pascal=True`` each is first converted from an
        annotation path (".../anns/.../<stem>.<ext>") to the image-path key
        (".../images/.../<stem>.jpeg").
    mapping: name -> 2-D numpy array of a common shape (e.g. from ``create_mapping``).

    The result dtype is the promotion of float32 with the arrays' dtype (float32
    for uint8 masks). Values are NOT rescaled.
    NOTE(review): masks read with cv2 may be 0..255 while dataset masks are 0/1;
    confirm whether a /255 is intended before feeding the discriminators.
    Returns an empty float32 tensor when ``names`` is empty. Raises KeyError for a
    name missing from ``mapping``.
    """
    def to_key(name):
        if not pascal:
            return name
        parts = name.split("/")
        parts[-1] = parts[-1].split(".")[0] + ".jpeg"
        return "/".join(parts).replace("anns", "images")

    # One stack instead of repeated torch.cat in a loop (which is quadratic).
    arrays = [torch.from_numpy(mapping[to_key(name)]) for name in names]
    if not arrays:
        return torch.Tensor()
    stacked = torch.stack(arrays)
    return stacked.to(torch.promote_types(torch.float32, stacked.dtype))
        
        

### LOSSES

In [None]:
class PerceptualLoss(nn.Module):
    """Focal binary cross-entropy on probabilities (despite the class name).

    loss = alpha * (1 - p_t) ** gamma * BCE(p, y), where p is the prediction
    clamped to [1e-7, 1 - 1e-7] and p_t = p where y == 1, else 1 - p.
    reduction: 'mean', 'sum', or anything else for the unreduced tensor.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, inputs, targets):
        probs = inputs.clamp(min=1e-7, max=1 - 1e-7)
        bce = F.binary_cross_entropy(probs, targets, reduction='none')
        p_t = torch.where(targets == 1, probs, 1 - probs)
        per_elem = self.alpha * (1 - p_t) ** self.gamma * bce

        if self.reduction == 'sum':
            return per_elem.sum()
        if self.reduction == 'mean':
            return per_elem.mean()
        return per_elem

    

class GradLayer(nn.Module):
    """Gradient-magnitude filter for single-channel (N, 1, H, W) maps.

    Central differences (x[i+1] - x[i-1] vertically, x[j+1] - x[j-1]
    horizontally) with zero padding; returns sqrt(gv^2 + gh^2 + 1e-6).
    The kernels are frozen ``nn.Parameter``s (requires_grad=False).
    """

    def __init__(self):
        super().__init__()

        def frozen_kernel(rows):
            kernel = torch.tensor(rows, dtype=torch.float32).view(1, 1, 3, 3)
            return nn.Parameter(data=kernel, requires_grad=False)

        # Registered weight_h first, then weight_v (same names in the state_dict).
        self.weight_h = frozen_kernel([[0, 0, 0], [-1, 0, 1], [0, 0, 0]])
        self.weight_v = frozen_kernel([[0, -1, 0], [0, 0, 0], [0, 1, 0]])

    def forward(self, x):
        grad_v = F.conv2d(x, self.weight_v, padding=1)
        grad_h = F.conv2d(x, self.weight_h, padding=1)
        return torch.sqrt(grad_v.pow(2) + grad_h.pow(2) + 1e-6)

class FocalLoss(nn.Module):
    """Focal loss on raw logits (built on ``binary_cross_entropy_with_logits``).

    loss = alpha * (1 - p_t) ** gamma * BCE, with p_t = exp(-BCE).
    reduction: 'mean', 'sum', or anything else for the unreduced tensor.
    """

    def __init__(self, gamma=2, alpha=0.25, reduction='mean'):
        super().__init__()
        self.gamma, self.alpha, self.reduction = gamma, alpha, reduction

    def forward(self, inputs, targets):
        bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        p_t = torch.exp(-bce)
        weighted = self.alpha * (1 - p_t) ** self.gamma * bce

        if self.reduction == 'sum':
            return torch.sum(weighted)
        if self.reduction == 'mean':
            return torch.mean(weighted)
        return weighted


class EdgeRefineLoss(nn.Module):
    """Loss for the edge-refinement head: edge L1 term plus focal segmentation terms.

    total = alpha * L1(grad(pred), grad(gt_full))
          + beta * (0.5 * focal(pred, gt_full)
                    + 0.5 * focal(pred restricted to the invisible region, gt_invisible))

    ``FocalLoss`` expects logits, but ``EdgeRefineModule`` ends in a sigmoid. With
    ``from_probs=True`` (default) predictions are therefore converted back to
    logits (``torch.logit`` with clamping to [eps, 1 - eps]) before the focal
    terms. Feeding probabilities in directly applied a second sigmoid and squashed
    every prediction into [0.5, 0.73]. Set ``from_probs=False`` if the model
    outputs logits.
    """

    def __init__(self, alpha=0.5, beta=10.0, gamma=2, focal_alpha=0.25, from_probs=True, eps=1e-6):
        super(EdgeRefineLoss, self).__init__()
        self.grad_layer = GradLayer()
        self.grad_loss = nn.L1Loss()
        self.focal_loss = FocalLoss(gamma=gamma, alpha=focal_alpha)
        self.alpha = alpha  # weight of the edge (gradient-magnitude) term
        self.beta = beta    # weight of the combined segmentation terms
        self.from_probs = from_probs
        self.eps = eps

    def _to_logits(self, x):
        """Convert probabilities to logits when ``from_probs`` is set; else pass through."""
        if not self.from_probs:
            return x
        return torch.logit(x, eps=self.eps)

    def forward(self, predicted, gt_full_mask, gt_visible_mask):
        predicted_edges = self.grad_layer(predicted)
        target_full_edges = self.grad_layer(gt_full_mask)
        edge_loss = self.grad_loss(predicted_edges, target_full_edges)

        seg_loss_full = self.focal_loss(self._to_logits(predicted), gt_full_mask)

        # Occluded part of the object. Clamped so pixels that are "visible" but
        # outside the full mask do not become negative targets.
        gt_invisible_mask = (gt_full_mask - gt_visible_mask).clamp(min=0)
        predicted_invisible = predicted * (gt_invisible_mask > 0).float()
        seg_loss_invisible = self.focal_loss(self._to_logits(predicted_invisible), gt_invisible_mask)

        total_loss = self.alpha * edge_loss + self.beta * (0.5 * seg_loss_full + 0.5 * seg_loss_invisible)
        return total_loss




In [None]:
def smooth_positive_labels(y):
    """Label smoothing for "real" targets.

    Lowers ``y`` by 0.1 (floored at 0), adds uniform noise in [0, 0.1) drawn with
    ``torch.rand`` on ``y``'s device, and caps the result at 1.0. All-ones labels
    therefore end up in [0.9, 1.0].
    """
    lowered = (y - 0.1).clamp(min=0.0)
    noise = 0.1 * torch.rand(lowered.shape, device=lowered.device)
    return (lowered + noise).clamp(max=1.0)

def flip_labels(labels, flip_prob=0.05):
    """Replace each label ``l`` by ``1 - l`` independently with probability ``flip_prob``.

    The result is clamped to [0, 1].
    """
    should_flip = torch.rand(labels.shape, device=labels.device) < flip_prob
    return torch.where(should_flip, 1 - labels, labels).clamp(0.0, 1.0)


class OBJ_D_loss(nn.Module):
    """Object-discriminator loss with two "real" inputs and one "fake" input.

    Real targets are smoothed ones (``smooth_positive_labels``) and fake targets are
    zeros. Both are randomly flipped with probability ``flip_prob``. One real-target
    tensor is shared by ``pred_real_gt`` and ``pred_real_s``, so those predictions
    must have the same shape.
    Returns mean(BCE(real_gt), BCE(real_s)) + BCE(fake).
    """

    def __init__(self, flip_prob=0.1):
        super().__init__()
        self.bce_loss = torch.nn.BCELoss()
        self.flip_prob = flip_prob

    def forward(self, pred_real_gt, pred_fake , pred_real_s):
        # RNG draw order: smoothing, flip of real targets, flip of fake targets.
        real_targets = flip_labels(smooth_positive_labels(torch.ones_like(pred_real_gt)), self.flip_prob)
        fake_targets = flip_labels(torch.zeros_like(pred_fake), self.flip_prob)

        bce = self.bce_loss
        real_term = (bce(pred_real_gt, real_targets) + bce(pred_real_s, real_targets)) / 2
        return real_term + bce(pred_fake, fake_targets)

class INS_D_loss(nn.Module):
    """Instance-discriminator loss with one "real" input and two "fake" inputs.

    Real targets are smoothed ones (``smooth_positive_labels``) and fake targets are
    zeros. Both are randomly flipped with probability ``flip_prob``. One fake-target
    tensor is shared by ``pred_fake_G_output`` and ``pred_fake_M_s``, so those
    predictions must have the same shape.
    Returns BCE(real) + mean(BCE(fake_G_output), BCE(fake_M_s)).
    """

    def __init__(self, flip_prob=0.1):
        super().__init__()
        self.bce_loss = torch.nn.BCELoss()
        self.flip_prob = flip_prob

    def forward(self, pred_real_gt, pred_fake_G_output, pred_fake_M_s):
        # RNG draw order: smoothing, flip of real targets, flip of fake targets.
        real_targets = flip_labels(smooth_positive_labels(torch.ones_like(pred_real_gt)), self.flip_prob)
        fake_targets = flip_labels(torch.zeros_like(pred_fake_G_output), self.flip_prob)

        bce = self.bce_loss
        real_term = bce(pred_real_gt, real_targets)
        fake_term = (bce(pred_fake_G_output, fake_targets) + bce(pred_fake_M_s, fake_targets)) / 2
        return real_term + fake_term

class SegmentationLoss(nn.Module):
    """Generator objective: ``lambda_val * L1 + beta_val * perceptual + obj_loss + ins_loss``.

    ``obj_loss``, ``ins_loss`` and ``perceptual_loss`` are precomputed values that
    are simply added. Tensors keep their gradients; Python floats contribute none.
    ``perceptual_loss_fn`` is accepted for call compatibility but neither stored nor used.
    """

    def __init__(self, lambda_val, beta_val, perceptual_loss_fn):
        super().__init__()
        self.lambda_val = lambda_val
        self.beta_val = beta_val
        self.l1_loss_fn = nn.L1Loss()

    def forward(self, G_output, M_gt, obj_loss, ins_loss, perceptual_loss):
        reconstruction = self.l1_loss_fn(G_output, M_gt)
        weighted = self.lambda_val * reconstruction + self.beta_val * perceptual_loss
        return weighted + obj_loss + ins_loss



### TRAIN EPOCH

In [None]:
def train_epoch(train_loader, generator, d_obj, d_ins, egde_refine_module,
                device, optim_g, optim_d_obj, optim_d_ins, optim_edge_refine,
                segmentation_loss, obj_d_loss_fn, ins_d_loss_fn,
                perceptual_loss_fn, edge_refine_loss_fn, epoch, save_folder, mapping):
    """Run one training epoch over ``train_loader``.

    Each batch goes through four updates, in order:
      1. The object discriminator, on GT amodal masks and mapped vocabulary masks
         (real) vs. detached generator output (fake).
      2. The instance discriminator, on (mask, image, visible mask) stacks: GT as
         real vs. detached generator output and random sampled masks as fake.
      3. The generator, via ``segmentation_loss`` (L1 + perceptual + adversarial).
      4. The edge-refinement module, on the generator output binarized at 0.8 and
         detached.
    When ``epoch % 5 == 0`` the last batch is saved with ``save_images_with_overlay``.

    Returns per-batch averages:
    (d_obj_loss, d_ins_loss, seg_loss, perceptual, edge_refine_loss).
    """
    # validate_epoch switches these modules to eval(); restore training mode so
    # BatchNorm / Dropout behave correctly from the second epoch on.
    for module in (generator, d_obj, d_ins, egde_refine_module):
        module.train()

    total_d_obj_loss = 0
    total_d_ins_loss = 0
    total_seg_loss = 0
    total_perceptual = 0
    total_edge_refine_loss = 0

    num_batches = len(train_loader)
    print("len map in train ", len(mapping))

    for i, data in enumerate(tqdm(train_loader)):
        input_image, inmodal_mask, amodal_mask, names = data["image"], data["processed_img_occ"], data["processed_img_gt"], data["path"]
        input_image, inmodal_mask, amodal_mask = input_image.to(device), inmodal_mask.to(device), amodal_mask.to(device)
        input_tensor = torch.cat((input_image, inmodal_mask), dim=1)
        batch_size = input_image.size(0)

        # Match the real batch size: the last batch can be smaller than the default
        # 16, which made the torch.cat calls for d_ins fail.
        sampled_masks = get_smasks_batch(batch_size).unsqueeze(dim=1).to(device)
        sampled_masks_obj = get_sampled_batch(names, mapping).unsqueeze(1).to(device)

        # Generator forward pass
        fake_masks = generator(input_tensor)

        # Edge-refinement input: binarized (>= 0.8 -> 1), detached generator output.
        fake_masks_to_refine = fake_masks.clone().detach()
        fake_masks_to_refine = torch.where(fake_masks_to_refine < 0.8, torch.zeros_like(fake_masks_to_refine), torch.ones_like(fake_masks_to_refine))
        refine_input = torch.cat((inmodal_mask, input_image, fake_masks_to_refine), dim=1)
        refined_masks = egde_refine_module(refine_input)

        # Discriminator predictions (generator output detached)
        d_obj_pred_real = d_obj(amodal_mask)
        d_obj_pred_fake = d_obj(fake_masks.detach())
        d_obj_pred_sampled = d_obj(sampled_masks_obj)

        d_ins_pred_real = d_ins(torch.cat((amodal_mask, input_image, inmodal_mask), dim=1))
        d_ins_pred_fake = d_ins(torch.cat((fake_masks.detach(), input_image, inmodal_mask), dim=1))
        d_ins_pred_sampled = d_ins(torch.cat((sampled_masks, input_image, inmodal_mask), dim=1))

        # Discriminator OBJ loss and update
        optim_d_obj.zero_grad()
        d_obj_loss = obj_d_loss_fn(d_obj_pred_real, d_obj_pred_fake, d_obj_pred_sampled)
        d_obj_loss.backward()
        optim_d_obj.step()

        # Discriminator INS loss and update
        optim_d_ins.zero_grad()
        d_ins_loss = ins_d_loss_fn(d_ins_pred_real, d_ins_pred_fake, d_ins_pred_sampled)
        d_ins_loss.backward()
        optim_d_ins.step()

        # Generator loss and update. All terms are passed as tensors: values
        # converted with .item() are plain floats and give the generator no gradient.
        optim_g.zero_grad()
        perceptual = perceptual_loss_fn(fake_masks, amodal_mask)
        # Adversarial terms: the generator wants both discriminators to classify its
        # non-detached output as real. Gradients this leaves on the discriminator
        # parameters are cleared by their zero_grad() before the next D update.
        g_obj_pred = d_obj(fake_masks)
        g_ins_pred = d_ins(torch.cat((fake_masks, input_image, inmodal_mask), dim=1))
        g_obj_adv = F.binary_cross_entropy(g_obj_pred, torch.ones_like(g_obj_pred))
        g_ins_adv = F.binary_cross_entropy(g_ins_pred, torch.ones_like(g_ins_pred))

        seg_loss = segmentation_loss(fake_masks, amodal_mask, g_obj_adv, g_ins_adv, perceptual)
        seg_loss.backward()
        optim_g.step()

        # Edge-refinement update (its input was detached from the generator)
        optim_edge_refine.zero_grad()
        edge_refine_loss = edge_refine_loss_fn(refined_masks, amodal_mask, inmodal_mask)
        edge_refine_loss.backward()
        optim_edge_refine.step()

        total_d_obj_loss += d_obj_loss.item()
        total_d_ins_loss += d_ins_loss.item()
        total_seg_loss += seg_loss.item()
        total_perceptual += perceptual.item()
        total_edge_refine_loss += edge_refine_loss.item()

        if i == num_batches - 1 and epoch % 5 == 0:
            save_images_with_overlay(input_image, fake_masks, refined_masks, amodal_mask, save_folder, epoch)

    avg_d_obj_loss = total_d_obj_loss / num_batches
    avg_d_ins_loss = total_d_ins_loss / num_batches
    avg_seg_loss = total_seg_loss / num_batches
    avg_perceptual = total_perceptual / num_batches
    avg_edge_refine_loss = total_edge_refine_loss / num_batches
    return avg_d_obj_loss, avg_d_ins_loss, avg_seg_loss, avg_perceptual, avg_edge_refine_loss


### VALIDATION EPOCH

In [None]:
def validate_epoch(val_loader, generator, d_obj, d_ins, edge_refine_module,
                   device, segmentation_loss, perceptual_loss_fn,
                   obj_d_loss_fn, ins_d_loss_fn, edge_refine_loss_fn, epoch, save_folder,
                   mapping):
    """Run one validation pass and return losses/metrics averaged over batches.

    All four networks are evaluated in eval mode under ``torch.no_grad()``. Their
    previous train/eval modes are restored before returning (even on error), so
    a subsequent training epoch does not silently run in eval mode.

    Args:
        val_loader: iterable of batch dicts with keys "image", "processed_img_occ"
            (inmodal mask), "processed_img_gt" (amodal ground truth) and "path".
        generator: predicts amodal masks from image + inmodal mask concatenated on dim 1.
        d_obj: object discriminator, fed masks only.
        d_ins: instance discriminator, fed mask + image + inmodal mask (dim 1).
        edge_refine_module: refines the binarised generator output.
        device: device the batch tensors are moved to.
        segmentation_loss, perceptual_loss_fn, obj_d_loss_fn, ins_d_loss_fn,
        edge_refine_loss_fn: loss callables; only evaluated, never back-propagated.
        epoch: current epoch; overlays of the last batch are saved when ``epoch % 5 == 0``.
        save_folder: directory for the overlay images.
        mapping: forwarded to ``get_sampled_batch`` together with the batch paths.

    Returns:
        dict mapping metric name -> mean over batches. Keys ending in ``_r`` are
        computed on the refined masks, the others on the raw generator output.

    Raises:
        ValueError: if ``val_loader`` has no batches (averaging would divide by zero).
    """
    num_batches = len(val_loader)
    if num_batches == 0:
        # e.g. drop_last=True with fewer samples than batch_size
        raise ValueError("val_loader has no batches; cannot average validation metrics")

    # Accumulators, in the order the metrics are reported.
    metric_names = (
        "d_obj_loss", "d_ins_loss", "seg_loss",
        "iou", "precision", "recall", "f1",
        "l1_error", "l2_error", "refine_module_loss",
        "iou_r", "precision_r", "recall_r", "f1_r",
    )
    totals = dict.fromkeys(metric_names, 0.0)

    models = (generator, d_obj, d_ins, edge_refine_module)
    previous_modes = [model.training for model in models]
    for model in models:
        model.eval()

    try:
        with torch.no_grad():
            for i, data in enumerate(tqdm(val_loader, desc="Validation")):
                input_image = data["image"].to(device)
                inmodal_mask = data["processed_img_occ"].to(device)
                amodal_mask = data["processed_img_gt"].to(device)
                names = data["path"]

                fake_masks = generator(torch.cat((input_image, inmodal_mask), dim=1))

                # Hard-binarise at 0.8 (the same threshold compute_segmentation_metrics
                # uses by default) before handing the mask to the refiner.
                fake_masks_binary = torch.where(fake_masks < 0.8,
                                                torch.zeros_like(fake_masks),
                                                torch.ones_like(fake_masks))
                refine_input = torch.cat((inmodal_mask, input_image, fake_masks_binary), dim=1)
                refined_masks = edge_refine_module(refine_input)

                sampled_masks = get_smasks_batch().unsqueeze(dim=1).to(device)
                sampled_masks_obj = get_sampled_batch(names, mapping).unsqueeze(1).to(device)

                precision, recall, f1, iou = compute_segmentation_metrics(fake_masks, amodal_mask)
                precision_r, recall_r, f1_r, iou_r = compute_segmentation_metrics(refined_masks, amodal_mask)

                l1_error = torch.nn.functional.l1_loss(fake_masks, amodal_mask).item()
                l2_error = torch.nn.functional.mse_loss(fake_masks, amodal_mask).item()

                d_obj_pred_real = d_obj(amodal_mask)
                d_obj_pred_fake = d_obj(fake_masks)
                d_obj_pred_sampled = d_obj(sampled_masks_obj)

                d_ins_pred_real = d_ins(torch.cat((amodal_mask, input_image, inmodal_mask), dim=1))
                d_ins_pred_fake = d_ins(torch.cat((fake_masks, input_image, inmodal_mask), dim=1))
                d_ins_pred_sampled = d_ins(torch.cat((sampled_masks, input_image, inmodal_mask), dim=1))

                d_obj_loss = obj_d_loss_fn(d_obj_pred_real, d_obj_pred_fake, d_obj_pred_sampled)
                d_ins_loss = ins_d_loss_fn(d_ins_pred_real, d_ins_pred_fake, d_ins_pred_sampled)
                perceptual_loss = perceptual_loss_fn(fake_masks, amodal_mask)

                seg_loss = segmentation_loss(fake_masks, amodal_mask, d_obj_loss.item(),
                                             d_ins_loss.item(), perceptual_loss.item())
                edge_refine_loss = edge_refine_loss_fn(refined_masks, amodal_mask, inmodal_mask)

                batch_values = {
                    "d_obj_loss": d_obj_loss.item(),
                    "d_ins_loss": d_ins_loss.item(),
                    "seg_loss": seg_loss.item(),
                    "iou": iou,
                    "precision": precision,
                    "recall": recall,
                    "f1": f1,
                    "l1_error": l1_error,
                    "l2_error": l2_error,
                    "refine_module_loss": edge_refine_loss.item(),
                    "iou_r": iou_r,
                    "precision_r": precision_r,
                    "recall_r": recall_r,
                    "f1_r": f1_r,
                }
                for key, value in batch_values.items():
                    totals[key] += value

                if i == num_batches - 1 and epoch % 5 == 0:
                    save_images_with_overlay(input_image, fake_masks, refined_masks, amodal_mask,
                                             save_folder, epoch)
    finally:
        # Put every network back in the mode it was in before validation.
        for model, was_training in zip(models, previous_modes):
            model.train(was_training)

    return {name: totals[name] / num_batches for name in metric_names}

def compute_segmentation_metrics(pred, target, threshold=0.8):
    """Precision, recall, F1 and IoU of a thresholded prediction against a mask.

    Counts are pooled over every element of the inputs. For a batch this is a
    micro-average over all pixels, not a mean of per-image scores.

    Args:
        pred: tensor of predicted scores; elements strictly greater than
            ``threshold`` are treated as foreground.
        target: ground-truth mask with as many elements as ``pred``. It is used
            as-is (not thresholded), so it is expected to hold 0/1 values. Bool
            and integer masks are accepted.
        threshold: binarisation threshold applied to ``pred``.

    Returns:
        ``(precision, recall, f1, iou)`` as Python floats. The 1e-8 added to each
        denominator makes an empty prediction/target give 0.0 rather than NaN.
    """
    # reshape (unlike view) also accepts non-contiguous tensors, e.g. permuted or sliced ones.
    pred = (pred > threshold).float().reshape(-1)
    target = target.reshape(-1)
    if not target.is_floating_point():
        # `1 - bool_tensor` raises in PyTorch, so promote bool/int masks first.
        target = target.float()

    true_positive = (pred * target).sum()
    false_positive = (pred * (1 - target)).sum()
    false_negative = ((1 - pred) * target).sum()

    precision = true_positive / (true_positive + false_positive + 1e-8)
    recall = true_positive / (true_positive + false_negative + 1e-8)
    f1_score = 2 * (precision * recall) / (precision + recall + 1e-8)
    # Intersection is TP; union is TP + FP + FN.
    iou = true_positive / (true_positive + false_positive + false_negative + 1e-8)

    return precision.item(), recall.item(), f1_score.item(), iou.item()



def overlay_masks(input_images, masks, alpha=0.3, color=(1.0, 0.0, 0.0)):
    """Alpha-blend a solid colour, weighted by ``masks``, over a batch of images.

    Computes ``input_images * (1 - alpha) + masks * color * alpha``. The whole
    image is dimmed by ``(1 - alpha)``, not only the masked region.

    Args:
        input_images: image batch, presumably (B, 3, H, W) RGB in [0, 1]
            (TODO confirm against the dataset transform).
        masks: mask batch with 1 channel (or 3) that broadcasts against the
            (1, 3, 1, 1) colour tensor.
        alpha: weight of the colour layer.
        color: RGB triple. A tuple default avoids the shared-mutable-default
            pitfall; lists are still accepted.

    Returns:
        A new tensor; ``input_images`` is not modified.
    """
    color_rgb = torch.tensor(color, device=masks.device).view(1, 3, 1, 1)
    # Broadcasting expands a 1-channel mask to 3 channels (equivalent to
    # repeat(1, 3, 1, 1)) and also works when the mask already has 3 channels.
    masks_rgb = masks * color_rgb
    return input_images * (1 - alpha) + masks_rgb * alpha

def save_images_with_overlay(input_images, fake_masks, refined_masks, amodal_masks, folder, epoch):
    """Save one PNG per sample for the generator output, the refined output and the ground truth.

    For sample ``i`` of the batch, three files are written into ``folder``
    (created if missing):

    * ``fake_{i}_{epoch}.png``    -- ``fake_masks`` overlaid in red
    * ``refined_{i}_{epoch}.png`` -- ``refined_masks`` overlaid in blue
    * ``amodal_{i}_{epoch}.png``  -- ``amodal_masks`` (ground truth) overlaid in blue

    The ground-truth overlay is written on every call. The saved batch can
    differ between calls (e.g. the last batch of a shuffled training loader),
    so a single ground-truth snapshot would not line up with predictions saved
    at other epochs.
    """
    overlays = {
        "fake": overlay_masks(input_images, fake_masks, alpha=0.3, color=(1.0, 0.0, 0.0)),
        "refined": overlay_masks(input_images, refined_masks, alpha=0.3, color=(0.0, 0.0, 1.0)),
        "amodal": overlay_masks(input_images, amodal_masks, alpha=0.3, color=(0.0, 0.0, 1.0)),
    }

    os.makedirs(folder, exist_ok=True)
    for prefix, batch in overlays.items():
        for i, img in enumerate(batch):
            # detach(): during training the masks may still be attached to the autograd graph.
            image = to_pil_image(img.detach().cpu()).convert("RGB")
            image.save(os.path.join(folder, f"{prefix}_{i}_{epoch}.png"))


In [None]:
def train_and_validate(train_loader, val_loader, d_obj, d_ins, generator, edge_refine_module, mapping,
                       epochs=100, checkpoint_every=10,
                       train_output_dir="/kaggle/working/training_output",
                       val_output_dir="/kaggle/working/validation_output"):
    """Jointly train the generator, both discriminators and the edge-refine module.

    A validation pass runs after every training epoch. Each optimiser's
    ReduceLROnPlateau scheduler is stepped on the matching validation loss.

    Generator and edge-refine weights are written to ``model_weights_gen.pth``
    and ``edge_refine_weights.pth`` every ``checkpoint_every`` epochs (epoch 0
    excluded). They are also always written after the final epoch, so the
    weights from the last epochs are never lost. Loss/metric curves are plotted
    once training finishes.

    Args:
        train_loader, val_loader: loaders forwarded to train_epoch / validate_epoch.
        d_obj, d_ins: object- and instance-level discriminators.
        generator: amodal-mask generator.
        edge_refine_module: refinement network applied to the generator output.
        mapping: forwarded to train_epoch / validate_epoch.
        epochs: number of epochs to run (default 100).
        checkpoint_every: checkpoint interval in epochs; must be >= 1.
        train_output_dir, val_output_dir: folders for the overlay images.

    Raises:
        ValueError: if ``checkpoint_every`` < 1.
    """
    if checkpoint_every < 1:
        raise ValueError(f"checkpoint_every must be >= 1, got {checkpoint_every}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"\nAvailable device: {device}")

    # Discriminators use a 10x smaller learning rate than the generator / refiner.
    optim_d_obj = Adam(d_obj.parameters(), lr=1e-5, betas=(0.5, 0.999))
    optim_d_ins = Adam(d_ins.parameters(), lr=1e-5, betas=(0.5, 0.999))
    optim_g = Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.999))
    optim_refine = Adam(edge_refine_module.parameters(), lr=1e-4, betas=(0.5, 0.999))

    scheduler_d_obj = ReduceLROnPlateau(optim_d_obj, mode='min', factor=0.5, patience=15, verbose=True, min_lr=1e-7)
    scheduler_d_ins = ReduceLROnPlateau(optim_d_ins, mode='min', factor=0.5, patience=15, verbose=True, min_lr=1e-7)
    scheduler_g = ReduceLROnPlateau(optim_g, mode='min', factor=0.5, patience=15, verbose=True, min_lr=1e-7)
    scheduler_refine = ReduceLROnPlateau(optim_refine, mode='min', factor=0.5, patience=10, verbose=True, min_lr=1e-7)

    perceptual_loss_fn = PerceptualLoss().to(device)
    segmentation_loss = SegmentationLoss(20, 1, perceptual_loss_fn).to(device)
    obj_d_loss_fn = OBJ_D_loss().to(device)
    ins_d_loss_fn = INS_D_loss().to(device)
    refine_loss_fn = EdgeRefineLoss().to(device)

    metrics = {key: [] for key in (
        'train_d_obj_loss', 'train_d_ins_loss', 'train_seg_loss',
        'val_d_obj_loss', 'val_d_ins_loss', 'val_seg_loss',
        'val_iou', 'val_precision', 'val_recall', 'val_f1',
    )}

    for epoch in range(epochs):
        train_losses = train_epoch(train_loader, generator, d_obj, d_ins, edge_refine_module, device,
                                   optim_g, optim_d_obj, optim_d_ins, optim_refine, segmentation_loss,
                                   obj_d_loss_fn, ins_d_loss_fn, perceptual_loss_fn, refine_loss_fn, epoch,
                                   train_output_dir, mapping)

        validation_losses = validate_epoch(val_loader, generator,
                                           d_obj, d_ins, edge_refine_module, device, segmentation_loss,
                                           perceptual_loss_fn, obj_d_loss_fn,
                                           ins_d_loss_fn, refine_loss_fn, epoch,
                                           val_output_dir,
                                           mapping)

        is_last_epoch = epoch == epochs - 1
        if (epoch != 0 and epoch % checkpoint_every == 0) or is_last_epoch:
            torch.save(generator.state_dict(), 'model_weights_gen.pth')
            torch.save(edge_refine_module.state_dict(), 'edge_refine_weights.pth')

        print(f"Epoch {epoch} train losses = ", train_losses)
        print(f"Epoch {epoch} validation losses = ", validation_losses)

        scheduler_d_obj.step(validation_losses["d_obj_loss"])
        scheduler_d_ins.step(validation_losses["d_ins_loss"])
        scheduler_g.step(validation_losses["seg_loss"])
        scheduler_refine.step(validation_losses["refine_module_loss"])

        # train_epoch returns (d_obj, d_ins, seg, perceptual, edge_refine) averages;
        # only the first three are tracked here.
        for index, key in enumerate(('train_d_obj_loss', 'train_d_ins_loss', 'train_seg_loss')):
            metrics[key].append(train_losses[index])
        for key in ('d_obj_loss', 'd_ins_loss', 'seg_loss', 'iou', 'precision', 'recall', 'f1'):
            metrics[f'val_{key}'].append(validation_losses[key])

    _plot_training_curves(metrics)


def _plot_training_curves(metrics):
    """Plot the per-epoch losses and validation metrics collected by train_and_validate."""
    panels = [
        ('Discriminator Obj Loss', [('train_d_obj_loss', 'Train Discriminator Obj Loss'),
                                    ('val_d_obj_loss', 'Val Discriminator Obj Loss')]),
        ('Discriminator Ins Loss', [('train_d_ins_loss', 'Train Discriminator Ins Loss'),
                                    ('val_d_ins_loss', 'Val Discriminator Ins Loss')]),
        ('Segmentation Loss', [('train_seg_loss', 'Train Segmentation Loss'),
                               ('val_seg_loss', 'Val Segmentation Loss')]),
        ('IOU', [('val_iou', 'Val IOU')]),
        ('Precision, Recall, F1 Score', [('val_precision', 'Val Precision'),
                                         ('val_recall', 'Val Recall'),
                                         ('val_f1', 'Val F1 Score')]),
    ]

    fig, axes = plt.subplots(3, 2, figsize=(15, 10))
    flat_axes = axes.ravel()
    for ax, (title, series) in zip(flat_axes, panels):
        for key, label in series:
            ax.plot(metrics[key], label=label)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.legend()
    # The 3x2 grid has one more slot than there are panels; remove the empty axis.
    for ax in flat_axes[len(panels):]:
        fig.delaxes(ax)

    fig.tight_layout()
    plt.show()


In [None]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# kins_dataset = KINSDataset("/kaggle/input/kins-data-correct/kaggle/working/kins_data",transform = transform)

# total_count = len(kins_dataset)
# train_count = int(int(0.4 * total_count) + 16 - (int(0.4 * total_count) % 16))
# val_count = int(int(0.2 * total_count) + 16 - (int(0.2 * total_count)  % 16))
# rest_count = total_count - train_count - val_count

# # Split the dataset
# ds_train_kins, ds_val_kins, _ = random_split(kins_dataset, [train_count, val_count, rest_count])

# print(len(ds_train_kins))

# train_loader = DataLoader(ds_train_kins, batch_size=16, shuffle=True, drop_last=False, num_workers = 4)
# val_loader = DataLoader(ds_val_kins, batch_size=16, drop_last=False, num_workers = 4)


### PASCAL LOADERS

In [None]:
seed = 42

# Seed every RNG used downstream: Python's `random`, NumPy's legacy global RNG,
# and torch (CPU generator plus all CUDA devices).
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(seed)

In [None]:
ds_train = OccludedVehiclesDataset(all_train_images, transform=transform)
ds_val = OccludedVehiclesDataset(all_val_images, transform=transform)

# Only a random 15% of the training set and 30% of the validation set are used.
# A dedicated generator seeded from `seed` (seeding cell above) makes the
# selected subsets independent of how much of the global torch RNG has already
# been consumed, so re-running this cell picks the same samples.
split_generator = torch.Generator().manual_seed(seed)
train_dataset, _ = random_split(ds_train, [0.15, 0.85], generator=split_generator)
val_dataset, _ = random_split(ds_val, [0.3, 0.7], generator=split_generator)

# drop_last=True keeps every batch at exactly 16 samples; presumably required
# because get_smasks_batch() takes no batch-size argument (TODO confirm).
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, drop_last=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=16, drop_last=True, num_workers=4)


In [None]:
# from IPython.display import FileLink
# FileLink(r'mapping.pkl')

In [None]:
# Load pre-built artefacts with load_dict (defined earlier in the notebook;
# presumably unpickles, so only point it at trusted files).
# To rebuild the mapping instead: create_mapping(train_loader or val_loader, vocab).
vocab, mapping = (
    load_dict("/kaggle/input/masks200contours/my_dict-6.pkl"),
    load_dict("/kaggle/input/pascal-mapping/pascal_mapping.pkl"),
)

In [None]:
# Input channel counts match the concatenations fed to each network in
# validate_epoch (assuming 1-channel masks and a 3-channel image):
#   d_obj     <- mask                          -> 1 channel
#   d_ins     <- mask + image + inmodal mask   -> 5 channels
#   generator <- image + inmodal mask          -> 4 channels
# Construction order is kept so seeded weight initialisation is unchanged.
d_obj = Discriminator(input_nc=1, n_layers=3).to(device)
d_ins = Discriminator(input_nc=5, n_layers=3).to(device)
generator = Generator(4).to(device)
refine_module = EdgeRefineModule().to(device)

In [None]:
# !rm -rf validation_output train_output

In [None]:
# !mkdir validation_output train_output

In [None]:
train_and_validate(
    train_loader,
    val_loader,
    d_obj=d_obj,
    d_ins=d_ins,
    generator=generator,
    edge_refine_module=refine_module,
    mapping=mapping,
)
